Merge pull request #1227 from Codium-ai/tr/dynamic

refactor logic
This commit is contained in:
Tal
2024-09-13 22:22:47 +03:00
committed by GitHub
4 changed files with 83 additions and 58 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import re
import traceback
from pr_agent.config_loader import get_settings
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
@ -12,27 +13,48 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
return patch_str
if type(original_file_str) == bytes:
try:
original_file_str = original_file_str.decode('utf-8')
except UnicodeDecodeError:
return ""
# skip patches
patch_extension_skip_types = get_settings().config.patch_extension_skip_types #[".md",".txt"]
if patch_extension_skip_types and filename:
if any([filename.endswith(skip_type) for skip_type in patch_extension_skip_types]):
original_file_str = decode_if_bytes(original_file_str)
if not original_file_str:
return patch_str
# dynamic context settings
if should_skip_patch(filename):
return patch_str
try:
extended_patch_str = process_patch_lines(patch_str, original_file_str,
patch_extra_lines_before, patch_extra_lines_after)
except Exception as e:
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
return patch_str
return extended_patch_str
def decode_if_bytes(original_file_str):
if isinstance(original_file_str, bytes):
try:
return original_file_str.decode('utf-8')
except UnicodeDecodeError:
encodings_to_try = ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']
for encoding in encodings_to_try:
try:
return original_file_str.decode(encoding)
except UnicodeDecodeError:
continue
return ""
return original_file_str
def should_skip_patch(filename):
patch_extension_skip_types = get_settings().config.patch_extension_skip_types
if patch_extension_skip_types and filename:
return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
return False
def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
allow_dynamic_context = get_settings().config.allow_dynamic_context
max_extra_lines_before_dynamic_context = get_settings().config.max_extra_lines_before_dynamic_context
patch_extra_lines_before_dynamic = patch_extra_lines_before
if allow_dynamic_context:
if max_extra_lines_before_dynamic_context > patch_extra_lines_before:
patch_extra_lines_before_dynamic = max_extra_lines_before_dynamic_context
else:
get_logger().warning(f"'max_extra_lines_before_dynamic_context' should be greater than 'patch_extra_lines_before'")
patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context
original_lines = original_file_str.splitlines()
len_original_lines = len(original_lines)
@ -46,23 +68,14 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
for line in patch_lines:
if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line)
# identify hunk header
if match:
# finish last hunk
# finish processing previous hunk
if start1 != -1 and patch_extra_lines_after > 0:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
delta_lines = [f' {line}' for line in delta_lines]
delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
extended_patch_lines.extend(delta_lines)
res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header = res[4]
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
def _calc_context_limits(patch_lines_before):
@ -82,7 +95,7 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
_calc_context_limits(patch_extra_lines_before_dynamic)
lines_before = original_lines[extended_start1 - 1:start1 - 1]
found_header = False
for i,line, in enumerate(lines_before):
for i, line, in enumerate(lines_before):
if section_header in line:
found_header = True
# Update start and size in one line each
@ -99,8 +112,9 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
extended_start1, extended_size1, extended_start2, extended_size2 = \
_calc_context_limits(patch_extra_lines_before)
delta_lines = original_lines[extended_start1 - 1:start1 - 1]
delta_lines = [f' {line}' for line in delta_lines]
delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]
# logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
if section_header and not allow_dynamic_context:
for line in delta_lines:
if section_header in line:
@ -120,11 +134,10 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
continue
extended_patch_lines.append(line)
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to extend patch: {e}")
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
return patch_str
# finish last hunk
# finish processing last hunk
if start1 != -1 and patch_extra_lines_after > 0:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
# add space at the beginning of each extra line
@ -135,6 +148,20 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
return extended_patch_str
def extract_hunk_headers(match):
res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header = res[4]
return section_header, size1, size2, start1, start2
def omit_deletion_hunks(patch_lines) -> str:
"""
Omit deletion hunks from the patch and return the modified patch.
@ -253,8 +280,8 @@ __old hunk__
start1, size1, start2, size2 = -1, -1, -1, -1
prev_header_line = []
header_line = []
for line in patch_lines:
if 'no newline at end of file' in line.lower():
for line_i, line in enumerate(patch_lines):
if 'no newline at end of file' in line.lower().strip().strip('//'):
continue
if line.startswith('@@'):
@ -280,21 +307,18 @@ __old hunk__
if match:
prev_header_line = header_line
res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
elif line.startswith('+'):
new_content_lines.append(line)
elif line.startswith('-'):
old_content_lines.append(line)
else:
if not line and line_i: # if this line is empty and the next line is a hunk header, skip it
if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
continue
elif line_i + 1 == len(patch_lines):
continue
new_content_lines.append(line)
old_content_lines.append(line)
@ -339,15 +363,7 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
match = RE_HUNK_HEADER.match(line)
res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
# check if line range is in this hunk
if side.lower() == 'left':

View File

@ -347,11 +347,9 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT
except:
get_logger().warning(
f"Failed to generate prediction with {model}"
f"{(' from deployment ' + deployment_id) if deployment_id else ''}: "
f"{traceback.format_exc()}"
)
if i == len(all_models) - 1: # If it's the last iteration
raise # Re-raise the last exception
raise Exception(f"Failed to generate prediction with any model of {all_models}")
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:

View File

@ -164,7 +164,7 @@ class PRReviewer:
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=True,)
disable_extra_lines=False,)
if self.patches_diff:
get_logger().debug(f"PR diff", diff=self.patches_diff)

View File

@ -60,11 +60,22 @@ class TestExtendPatch:
original_file_str = 'line1\nline2\nline3\nline4\nline5\nline6'
patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\n line3\n line4\n@@ -4,1 +4,1 @@ init2()\n-line4\n+new_line4' # noqa: E501
num_lines = 1
original_allow_dynamic_context = get_settings().config.allow_dynamic_context
get_settings().config.allow_dynamic_context = False
expected_output = '\n@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5' # noqa: E501
actual_output = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
assert actual_output == expected_output
get_settings().config.allow_dynamic_context = True
expected_output = '\n@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5' # noqa: E501
actual_output = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
assert actual_output == expected_output
get_settings().config.allow_dynamic_context = original_allow_dynamic_context
def test_dynamic_context(self):
get_settings().config.max_extra_lines_before_dynamic_context = 10
original_file_str = "def foo():"