feat: improve patch extension with new file content comparison

This commit is contained in:
mrT23
2025-02-24 11:46:12 +02:00
parent 2a647709c4
commit 56250f5ea8
4 changed files with 103 additions and 55 deletions

View File

@ -9,11 +9,12 @@ from pr_agent.log import get_logger
def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
patch_extra_lines_after=0, filename: str = "") -> str:
patch_extra_lines_after=0, filename: str = "", new_file_str="") -> str:
if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
return patch_str
original_file_str = decode_if_bytes(original_file_str)
new_file_str = decode_if_bytes(new_file_str)
if not original_file_str:
return patch_str
@ -22,7 +23,7 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
try:
extended_patch_str = process_patch_lines(patch_str, original_file_str,
patch_extra_lines_before, patch_extra_lines_after)
patch_extra_lines_before, patch_extra_lines_after, new_file_str)
except Exception as e:
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
return patch_str
@ -52,12 +53,13 @@ def should_skip_patch(filename):
return False
def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after, new_file_str=""):
allow_dynamic_context = get_settings().config.allow_dynamic_context
patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context
original_lines = original_file_str.splitlines()
len_original_lines = len(original_lines)
file_original_lines = original_file_str.splitlines()
file_new_lines = new_file_str.splitlines() if new_file_str else []
len_original_lines = len(file_original_lines)
patch_lines = patch_str.splitlines()
extended_patch_lines = []
@ -73,12 +75,12 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
if match:
# finish processing previous hunk
if is_valid_hunk and (start1 != -1 and patch_extra_lines_after > 0):
delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
extended_patch_lines.extend(delta_lines)
delta_lines_original = [f' {line}' for line in file_original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
extended_patch_lines.extend(delta_lines_original)
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
is_valid_hunk = check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1)
is_valid_hunk = check_if_hunk_lines_matches_to_file(i, file_original_lines, patch_lines, start1)
if is_valid_hunk and (patch_extra_lines_before > 0 or patch_extra_lines_after > 0):
def _calc_context_limits(patch_lines_before):
@ -93,20 +95,28 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
extended_size2 = max(extended_size2 - delta_cap, size2)
return extended_start1, extended_size1, extended_start2, extended_size2
if allow_dynamic_context:
if allow_dynamic_context and file_new_lines:
extended_start1, extended_size1, extended_start2, extended_size2 = \
_calc_context_limits(patch_extra_lines_before_dynamic)
lines_before = original_lines[extended_start1 - 1:start1 - 1]
lines_before_original = file_original_lines[extended_start1 - 1:start1 - 1]
lines_before_new = file_new_lines[extended_start2 - 1:start2 - 1]
found_header = False
for i, line, in enumerate(lines_before):
if section_header in line:
found_header = True
# Update start and size in one line each
extended_start1, extended_start2 = extended_start1 + i, extended_start2 + i
extended_size1, extended_size2 = extended_size1 - i, extended_size2 - i
# get_logger().debug(f"Found section header in line {i} before the hunk")
section_header = ''
break
if lines_before_original == lines_before_new: # Making sure no changes from a previous hunk
for i, line, in enumerate(lines_before_original):
if section_header in line:
found_header = True
# Update start and size in one line each
extended_start1, extended_start2 = extended_start1 + i, extended_start2 + i
extended_size1, extended_size2 = extended_size1 - i, extended_size2 - i
# get_logger().debug(f"Found section header in line {i} before the hunk")
section_header = ''
break
else:
get_logger().debug(f"Extra lines before hunk are different in original and new file - dynamic context",
artifact={"lines_before_original": lines_before_original,
"lines_before_new": lines_before_new})
if not found_header:
# get_logger().debug(f"Section header not found in the extra lines before the hunk")
extended_start1, extended_size1, extended_start2, extended_size2 = \
@ -115,11 +125,23 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
extended_start1, extended_size1, extended_start2, extended_size2 = \
_calc_context_limits(patch_extra_lines_before)
delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]
# check if extra lines before hunk are different in original and new file
delta_lines_original = [f' {line}' for line in file_original_lines[extended_start1 - 1:start1 - 1]]
if file_new_lines:
delta_lines_new = [f' {line}' for line in file_new_lines[extended_start2 - 1:start2 - 1]]
if delta_lines_original != delta_lines_new:
get_logger().debug(f"Extra lines before hunk are different in original and new file",
artifact={"delta_lines_original": delta_lines_original,
"delta_lines_new": delta_lines_new})
extended_start1 = start1
extended_size1 = size1
extended_start2 = start2
extended_size2 = size2
delta_lines_original = []
# logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
if section_header and not allow_dynamic_context:
for line in delta_lines:
for line in delta_lines_original:
if section_header in line:
section_header = '' # remove section header if it is in the extra delta lines
break
@ -128,12 +150,12 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
extended_size1 = size1
extended_start2 = start2
extended_size2 = size2
delta_lines = []
delta_lines_original = []
extended_patch_lines.append('')
extended_patch_lines.append(
f'@@ -{extended_start1},{extended_size1} '
f'+{extended_start2},{extended_size2} @@ {section_header}')
extended_patch_lines.extend(delta_lines) # one to zero based
extended_patch_lines.extend(delta_lines_original) # one to zero based
continue
extended_patch_lines.append(line)
except Exception as e:
@ -142,15 +164,14 @@ def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before,
# finish processing last hunk
if start1 != -1 and patch_extra_lines_after > 0 and is_valid_hunk:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
delta_lines_original = file_original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
# add space at the beginning of each extra line
delta_lines = [f' {line}' for line in delta_lines]
extended_patch_lines.extend(delta_lines)
delta_lines_original = [f' {line}' for line in delta_lines_original]
extended_patch_lines.extend(delta_lines_original)
extended_patch_str = '\n'.join(extended_patch_lines)
return extended_patch_str
def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1):
"""
Check if the hunk lines match the original file content. We saw cases where the hunk header line doesn't match the original file content, and then
@ -160,8 +181,18 @@ def check_if_hunk_lines_matches_to_file(i, original_lines, patch_lines, start1):
try:
if i + 1 < len(patch_lines) and patch_lines[i + 1][0] == ' ': # an existing line in the file
if patch_lines[i + 1].strip() != original_lines[start1 - 1].strip():
# check if different encoding is needed
original_line = original_lines[start1 - 1].strip()
for encoding in ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']:
try:
if original_line.encode(encoding).decode().strip() == patch_lines[i + 1].strip():
get_logger().info(f"Detected different encoding in hunk header line {start1}, needed encoding: {encoding}")
return False # we still want to avoid extending the hunk. But we don't want to log an error
except:
pass
is_valid_hunk = False
get_logger().error(
get_logger().info(
f"Invalid hunk in PR, line {start1} in hunk header doesn't match the original file content")
except:
pass
@ -288,7 +319,7 @@ __old hunk__
"""
# if the file was deleted, return a message indicating that the file was deleted
if hasattr(file, 'edit_type') and file.edit_type == EDIT_TYPE.DELETED:
return f"\n\n## file '{file.filename.strip()}' was deleted\n"
return f"\n\n## File '{file.filename.strip()}' was deleted\n"
patch_with_lines_str = f"\n\n## File: '{file.filename.strip()}'\n"
patch_lines = patch.splitlines()
@ -363,7 +394,7 @@ __old hunk__
return patch_with_lines_str.rstrip()
def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, side) -> tuple[str, str]:
def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, side, remove_trailing_chars: bool = True) -> tuple[str, str]:
try:
patch_with_lines_str = f"\n\n## File: '{file_name.strip()}'\n\n"
selected_lines = ""
@ -411,4 +442,8 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
get_logger().error(f"Failed to extract hunk lines from patch: {e}", artifact={"traceback": traceback.format_exc()})
return "", ""
return patch_with_lines_str.rstrip(), selected_lines.rstrip()
if remove_trailing_chars:
patch_with_lines_str = patch_with_lines_str.rstrip()
selected_lines = selected_lines.rstrip()
return patch_with_lines_str, selected_lines

View File

@ -195,13 +195,15 @@ def pr_generate_extended_diff(pr_languages: list,
for lang in pr_languages:
for file in lang['files']:
original_file_content_str = file.base_file
new_file_content_str = file.head_file
patch = file.patch
if not patch:
continue
# extend each patch with extra lines of context
extended_patch = extend_patch(original_file_content_str, patch,
patch_extra_lines_before, patch_extra_lines_after, file.filename)
patch_extra_lines_before, patch_extra_lines_after, file.filename,
new_file_str=new_file_content_str)
if not extended_patch:
get_logger().warning(f"Failed to extend patch for file: {file.filename}")
continue
@ -212,7 +214,7 @@ def pr_generate_extended_diff(pr_languages: list,
full_extended_patch = f"\n\n## File: '{file.filename.strip()}'\n{extended_patch.rstrip()}\n"
# add AI-summary metadata to the patch
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
full_extended_patch = add_ai_summary_top_patch(file, full_extended_patch)
patch_tokens = token_handler.count_tokens(full_extended_patch)
@ -384,7 +386,8 @@ def _get_all_deployments(all_models: List[str]) -> List[str]:
def get_pr_multi_diffs(git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
max_calls: int = 5) -> List[str]:
max_calls: int = 5,
add_line_numbers: bool = True) -> List[str]:
"""
Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
@ -425,7 +428,8 @@ def get_pr_multi_diffs(git_provider: GitProvider,
# try first a single run with standard diff string, with patch extension, and no deletions
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks=True,
pr_languages, token_handler,
add_line_numbers_to_hunks=add_line_numbers,
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE,
patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
@ -454,7 +458,12 @@ def get_pr_multi_diffs(git_provider: GitProvider,
if patch is None:
continue
patch = convert_to_hunks_with_lines_numbers(patch, file)
# Add line numbers and metadata to the patch
if add_line_numbers:
patch = convert_to_hunks_with_lines_numbers(patch, file)
else:
patch = f"\n\n## File: '{file.filename.strip()}'\n\n{patch.strip()}\n"
# add AI-summary metadata to the patch
if file.ai_file_summary and get_settings().get("config.enable_ai_metadata", False):
patch = add_ai_summary_top_patch(file, patch)

View File

@ -34,7 +34,7 @@ global_settings = Dynaconf(
)
def get_settings():
def get_settings(use_context=False):
"""
Retrieves the current settings.

View File

@ -5,12 +5,11 @@ from pr_agent.algo.pr_processing import pr_generate_extended_diff
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_large_diff
from pr_agent.config_loader import get_settings
get_settings().set("CONFIG.CLI_MODE", True)
get_settings().config.allow_dynamic_context = False
class TestExtendPatch:
def setUp(self):
get_settings().config.allow_dynamic_context = False
# Tests that the function works correctly with valid input
def test_happy_path(self):
original_file_str = 'line1\nline2\nline3\nline4\nline5'
@ -75,41 +74,46 @@ class TestExtendPatch:
actual_output = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
assert actual_output == expected_output
get_settings().config.allow_dynamic_context = original_allow_dynamic_context
get_settings(use_context=False).config.allow_dynamic_context = original_allow_dynamic_context
def test_dynamic_context(self):
get_settings().config.max_extra_lines_before_dynamic_context = 10
get_settings(use_context=False).config.max_extra_lines_before_dynamic_context = 10
original_file_str = "def foo():"
for i in range(9):
original_file_str += f"\n line({i})"
patch_str ="@@ -11,1 +11,1 @@ def foo():\n- line(9)\n+ new_line(9)"
patch_str ="@@ -10,1 +10,1 @@ def foo():\n- line(8)\n+ new_line(8)"
new_file_str = "\n".join(original_file_str.splitlines()[:-1] + [" new_line(8)"])
num_lines=1
get_settings().config.allow_dynamic_context = True
get_settings(use_context=False).config.allow_dynamic_context = True
actual_output = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
expected_output='\n@@ -1,10 +1,10 @@ \n def foo():\n line(0)\n line(1)\n line(2)\n line(3)\n line(4)\n line(5)\n line(6)\n line(7)\n line(8)\n- line(9)\n+ new_line(9)'
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines, new_file_str=new_file_str)
expected_output='\n@@ -1,10 +1,10 @@ \n def foo():\n line(0)\n line(1)\n line(2)\n line(3)\n line(4)\n line(5)\n line(6)\n line(7)\n- line(8)\n+ new_line(8)'
assert actual_output == expected_output
get_settings().config.allow_dynamic_context = False
get_settings(use_context=False).config.allow_dynamic_context = False
actual_output2 = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
expected_output_no_dynamic_context = '\n@@ -10,1 +10,1 @@ def foo():\n line(8)\n- line(9)\n+ new_line(9)'
patch_extra_lines_before=1, patch_extra_lines_after=1)
expected_output_no_dynamic_context = '\n@@ -9,2 +9,2 @@ def foo():\n line(7)\n- line(8)\n+ new_line(8)'
assert actual_output2 == expected_output_no_dynamic_context
get_settings(use_context=False).config.allow_dynamic_context = False
actual_output3 = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=3, patch_extra_lines_after=3)
expected_output_no_dynamic_context = '\n@@ -7,4 +7,4 @@ def foo():\n line(5)\n line(6)\n line(7)\n- line(8)\n+ new_line(8)'
assert actual_output3 == expected_output_no_dynamic_context
class TestExtendedPatchMoreLines:
def setUp(self):
get_settings().config.allow_dynamic_context = False
class File:
def __init__(self, base_file, patch, filename, ai_file_summary=None):
def __init__(self, base_file, patch, head_file, filename, ai_file_summary=None):
self.base_file = base_file
self.patch = patch
self.head_file = head_file
self.filename = filename
self.ai_file_summary = ai_file_summary
@ -128,9 +132,11 @@ class TestExtendedPatchMoreLines:
'files': [
self.File(base_file="line000\nline00\nline0\nline1\noriginal content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10",
patch="@@ -5,5 +5,5 @@\n-original content\n+modified content\n line2\n line3\n line4\n line5",
head_file="line000\nline00\nline0\nline1\nmodified content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10",
filename="file1"),
self.File(base_file="original content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10",
patch="@@ -6,5 +6,5 @@\nline6\nline7\nline8\n-line9\n+modified line9\nline10",
head_file="original content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nmodified line9\nline10",
filename="file2")
]
}
@ -155,11 +161,9 @@ class TestExtendedPatchMoreLines:
patch_extra_lines_after=1
)
p0_extended = patches_extended_with_extra_lines[0].strip()
assert p0_extended == "## File: 'file1'\n\n@@ -3,8 +3,8 @@ \n line0\n line1\n-original content\n+modified content\n line2\n line3\n line4\n line5\n line6"
class TestLoadLargeDiff:
def test_no_newline(self):
patch = load_large_diff("test.py",