Merge pull request #1227 from Codium-ai/tr/dynamic

refactor logic
Authored by Tal on 2024-09-13 22:22:47 +03:00, committed by GitHub
4 changed files with 83 additions and 58 deletions

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import re
+import traceback

 from pr_agent.config_loader import get_settings
 from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
@@ -12,27 +13,48 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
     if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
         return patch_str

-    if type(original_file_str) == bytes:
-        try:
-            original_file_str = original_file_str.decode('utf-8')
-        except UnicodeDecodeError:
-            return ""
+    original_file_str = decode_if_bytes(original_file_str)
+    if not original_file_str:
+        return patch_str

-    # skip patches
-    patch_extension_skip_types = get_settings().config.patch_extension_skip_types #[".md",".txt"]
-    if patch_extension_skip_types and filename:
-        if any([filename.endswith(skip_type) for skip_type in patch_extension_skip_types]):
-            return patch_str
+    if should_skip_patch(filename):
+        return patch_str

-    # dynamic context settings
-    allow_dynamic_context = get_settings().config.allow_dynamic_context
-    max_extra_lines_before_dynamic_context = get_settings().config.max_extra_lines_before_dynamic_context
-    patch_extra_lines_before_dynamic = patch_extra_lines_before
-    if allow_dynamic_context:
-        if max_extra_lines_before_dynamic_context > patch_extra_lines_before:
-            patch_extra_lines_before_dynamic = max_extra_lines_before_dynamic_context
-        else:
-            get_logger().warning(f"'max_extra_lines_before_dynamic_context' should be greater than 'patch_extra_lines_before'")
+    try:
+        extended_patch_str = process_patch_lines(patch_str, original_file_str,
+                                                 patch_extra_lines_before, patch_extra_lines_after)
+    except Exception as e:
+        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
+        return patch_str
+
+    return extended_patch_str
+
+
+def decode_if_bytes(original_file_str):
+    if isinstance(original_file_str, bytes):
+        try:
+            return original_file_str.decode('utf-8')
+        except UnicodeDecodeError:
+            encodings_to_try = ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']
+            for encoding in encodings_to_try:
+                try:
+                    return original_file_str.decode(encoding)
+                except UnicodeDecodeError:
+                    continue
+            return ""
+    return original_file_str
+
+
+def should_skip_patch(filename):
+    patch_extension_skip_types = get_settings().config.patch_extension_skip_types
+    if patch_extension_skip_types and filename:
+        return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
+    return False
+
+
+def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
+    allow_dynamic_context = get_settings().config.allow_dynamic_context
+    patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context

     original_lines = original_file_str.splitlines()
     len_original_lines = len(original_lines)
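
Reviewer's note: the new decode_if_bytes tries UTF-8 first and then walks a fixed fallback list. Worth noting that 'iso-8859-1' can decode any byte sequence, so the 'latin-1', 'ascii' and 'utf-16' entries after it are effectively unreachable. A minimal standalone demo of the observable behavior (the byte strings are illustrative, not from the repo's tests):

# assumes decode_if_bytes as defined in the hunk above
assert decode_if_bytes(b'caf\xc3\xa9') == 'café'          # valid UTF-8 decodes directly
assert decode_if_bytes(b'caf\xe9') == 'café'              # invalid UTF-8 falls back to iso-8859-1
assert decode_if_bytes('already text') == 'already text'  # non-bytes input passes through unchanged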
@@ -46,23 +68,14 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
         for line in patch_lines:
             if line.startswith('@@'):
                 match = RE_HUNK_HEADER.match(line)
+                # identify hunk header
                 if match:
-                    # finish last hunk
+                    # finish processing previous hunk
                     if start1 != -1 and patch_extra_lines_after > 0:
-                        delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
-                        delta_lines = [f' {line}' for line in delta_lines]
+                        delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
                         extended_patch_lines.extend(delta_lines)

-                    res = list(match.groups())
-                    for i in range(len(res)):
-                        if res[i] is None:
-                            res[i] = 0
-                    try:
-                        start1, size1, start2, size2 = map(int, res[:4])
-                    except: # '@@ -0,0 +1 @@' case
-                        start1, size1, size2 = map(int, res[:3])
-                        start2 = 0
-                    section_header = res[4]
+                    section_header, size1, size2, start1, start2 = extract_hunk_headers(match)

                     if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
                         def _calc_context_limits(patch_lines_before):
@@ -82,7 +95,7 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
                                 _calc_context_limits(patch_extra_lines_before_dynamic)
                             lines_before = original_lines[extended_start1 - 1:start1 - 1]
                             found_header = False
-                            for i,line, in enumerate(lines_before):
+                            for i, line, in enumerate(lines_before):
                                 if section_header in line:
                                     found_header = True
                                     # Update start and size in one line each
@@ -99,12 +112,13 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
                             extended_start1, extended_size1, extended_start2, extended_size2 = \
                                 _calc_context_limits(patch_extra_lines_before)
-                            delta_lines = original_lines[extended_start1 - 1:start1 - 1]
-                            delta_lines = [f' {line}' for line in delta_lines]
+                            delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]
+
+                        # logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
                         if section_header and not allow_dynamic_context:
                             for line in delta_lines:
                                 if section_header in line:
                                     section_header = '' # remove section header if it is in the extra delta lines
                                     break

                     else:
                         extended_start1 = start1
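
Reviewer's note: the two hunks above carry the PR's core behavior. With allow_dynamic_context enabled, the leading context now grows up to max_extra_lines_before_dynamic_context lines, stopping early if the enclosing section header (typically the def line) is found; otherwise the fixed patch_extra_lines_before offset is used. A toy sketch of that upward search, with illustrative names only, not the repo's internals:

def find_dynamic_start(original_lines, start1, section_header, max_lines_back=10):
    # scan upward from the line just above the hunk, at most max_lines_back lines
    lowest = max(0, start1 - 1 - max_lines_back)
    for i in range(start1 - 2, lowest - 1, -1):
        if section_header and section_header in original_lines[i]:
            return i + 1  # 1-based line where the enclosing header sits
    return start1         # header not found: keep the original hunk start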
@@ -120,11 +134,10 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
                     continue
             extended_patch_lines.append(line)
     except Exception as e:
-        if get_settings().config.verbosity_level >= 2:
-            get_logger().error(f"Failed to extend patch: {e}")
+        get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
         return patch_str

-    # finish last hunk
+    # finish processing last hunk
     if start1 != -1 and patch_extra_lines_after > 0:
         delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
         # add space at the beginning of each extra line
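
Reviewer's note: error handling here switches from a verbosity-gated error log to a single warning with the traceback attached as a structured artifact. The pattern in isolation (risky_patch_operation is a hypothetical stand-in; get_logger is the project's logger):

import traceback
try:
    risky_patch_operation()  # hypothetical stand-in for the patch-extension work
except Exception as e:
    get_logger().warning(f"Failed to extend patch: {e}",
                         artifact={"traceback": traceback.format_exc()})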
@@ -135,6 +148,20 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
     return extended_patch_str


+def extract_hunk_headers(match):
+    res = list(match.groups())
+    for i in range(len(res)):
+        if res[i] is None:
+            res[i] = 0
+    try:
+        start1, size1, start2, size2 = map(int, res[:4])
+    except: # '@@ -0,0 +1 @@' case
+        start1, size1, size2 = map(int, res[:3])
+        start2 = 0
+    section_header = res[4]
+    return section_header, size1, size2, start1, start2
+
+
 def omit_deletion_hunks(patch_lines) -> str:
     """
     Omit deletion hunks from the patch and return the modified patch.
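
Reviewer's note: extract_hunk_headers centralizes the hunk-header parsing that was previously copy-pasted in three places. A quick demo of what it returns (the regex is an approximation of the module's RE_HUNK_HEADER; the canonical pattern lives in the repo):

import re
RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

m = RE_HUNK_HEADER.match('@@ -12,27 +13,48 @@ def extend_patch(original_file_str, patch_str):')
section_header, size1, size2, start1, start2 = extract_hunk_headers(m)
# section_header == 'def extend_patch(original_file_str, patch_str):'
# start1 == 12, size1 == 27, start2 == 13, size2 == 48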
@@ -253,8 +280,8 @@ __old hunk__
     start1, size1, start2, size2 = -1, -1, -1, -1
     prev_header_line = []
     header_line = []
-    for line in patch_lines:
-        if 'no newline at end of file' in line.lower():
+    for line_i, line in enumerate(patch_lines):
+        if 'no newline at end of file' in line.lower().strip().strip('//'):
             continue

         if line.startswith('@@'):
@@ -280,21 +307,18 @@ __old hunk__
             if match:
                 prev_header_line = header_line

-                res = list(match.groups())
-                for i in range(len(res)):
-                    if res[i] is None:
-                        res[i] = 0
-                try:
-                    start1, size1, start2, size2 = map(int, res[:4])
-                except: # '@@ -0,0 +1 @@' case
-                    start1, size1, size2 = map(int, res[:3])
-                    start2 = 0
+                section_header, size1, size2, start1, start2 = extract_hunk_headers(match)

         elif line.startswith('+'):
             new_content_lines.append(line)
         elif line.startswith('-'):
             old_content_lines.append(line)
         else:
+            if not line and line_i: # if this line is empty and the next line is a hunk header, skip it
+                if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
+                    continue
+                elif line_i + 1 == len(patch_lines):
+                    continue
             new_content_lines.append(line)
             old_content_lines.append(line)
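
Reviewer's note: the new else-branch drops an empty line when it immediately precedes the next hunk header, or when it is the final line of the patch, so stray separators don't leak into the generated hunk content. On toy data:

patch_lines = ['@@ -1,2 +1,2 @@', ' a', '', '@@ -10,2 +10,2 @@', ' b', '']
# line_i == 2: empty and followed by a hunk header -> skipped
# line_i == 5: empty and last line of the patch    -> skipped
# every other line is appended to both old and new content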
@@ -339,15 +363,7 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
             match = RE_HUNK_HEADER.match(line)

-            res = list(match.groups())
-            for i in range(len(res)):
-                if res[i] is None:
-                    res[i] = 0
-
-            try:
-                start1, size1, start2, size2 = map(int, res[:4])
-            except: # '@@ -0,0 +1 @@' case
-                start1, size1, size2 = map(int, res[:3])
-                start2 = 0
+            section_header, size1, size2, start1, start2 = extract_hunk_headers(match)

             # check if line range is in this hunk
             if side.lower() == 'left':
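
For reference, the settings this file now reads (names are taken from the diff; the sample values are assumptions, not the project's shipped defaults):

config = get_settings().config
config.patch_extension_skip_types              # e.g. [".md", ".txt"]: files whose patches are never extended
config.allow_dynamic_context                   # e.g. True: grow leading context toward the enclosing header
config.max_extra_lines_before_dynamic_context  # e.g. 10: upper bound for that dynamic growth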

View File

@@ -347,11 +347,9 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT
         except:
             get_logger().warning(
                 f"Failed to generate prediction with {model}"
-                f"{(' from deployment ' + deployment_id) if deployment_id else ''}: "
-                f"{traceback.format_exc()}"
             )
             if i == len(all_models) - 1: # If it's the last iteration
-                raise # Re-raise the last exception
+                raise Exception(f"Failed to generate prediction with any model of {all_models}")


 def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
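
Reviewer's note: on the last fallback model, the bare raise (which re-raised the model-specific exception) is replaced with a generic Exception naming all attempted models, so callers see one summary error. The loop's shape, simplified to a synchronous sketch (the real function is async and also iterates deployments):

def retry_with_fallback(f, all_models):
    for i, model in enumerate(all_models):
        try:
            return f(model)
        except Exception:
            get_logger().warning(f"Failed to generate prediction with {model}")
            if i == len(all_models) - 1:  # last model: surface one summary error
                raise Exception(f"Failed to generate prediction with any model of {all_models}")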

View File

@@ -164,7 +164,7 @@ class PRReviewer:
                 self.token_handler,
                 model,
                 add_line_numbers_to_hunks=True,
-                disable_extra_lines=True,)
+                disable_extra_lines=False,)

         if self.patches_diff:
             get_logger().debug(f"PR diff", diff=self.patches_diff)
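
Reviewer's note: flipping disable_extra_lines to False means the review flow now receives the extended (dynamic) context produced by the refactor above. Reconstructed around the visible lines for context (get_pr_diff and the assignment target are assumed from the rest of the file):

self.patches_diff = get_pr_diff(self.git_provider,
                                self.token_handler,
                                model,
                                add_line_numbers_to_hunks=True,
                                disable_extra_lines=False,)  # was True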

View File

@@ -60,11 +60,22 @@ class TestExtendPatch:
         original_file_str = 'line1\nline2\nline3\nline4\nline5\nline6'
         patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\n line3\n line4\n@@ -4,1 +4,1 @@ init2()\n-line4\n+new_line4' # noqa: E501
         num_lines = 1
+        original_allow_dynamic_context = get_settings().config.allow_dynamic_context
+
+        get_settings().config.allow_dynamic_context = False
         expected_output = '\n@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5' # noqa: E501
         actual_output = extend_patch(original_file_str, patch_str,
                                      patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
         assert actual_output == expected_output

+        get_settings().config.allow_dynamic_context = True
+        expected_output = '\n@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5' # noqa: E501
+        actual_output = extend_patch(original_file_str, patch_str,
+                                     patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
+        assert actual_output == expected_output
+        get_settings().config.allow_dynamic_context = original_allow_dynamic_context
+
     def test_dynamic_context(self):
         get_settings().config.max_extra_lines_before_dynamic_context = 10
         original_file_str = "def foo():"