refactor logic

This commit is contained in:
mrT23
2024-09-13 22:17:24 +03:00
parent 0fb158fd47
commit cc0e432247
4 changed files with 83 additions and 58 deletions

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import re
import traceback
from pr_agent.config_loader import get_settings
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
@ -12,27 +13,48 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0) or not original_file_str:
return patch_str
if type(original_file_str) == bytes:
original_file_str = decode_if_bytes(original_file_str)
if not original_file_str:
return patch_str
if should_skip_patch(filename):
return patch_str
try:
extended_patch_str = process_patch_lines(patch_str, original_file_str,
patch_extra_lines_before, patch_extra_lines_after)
except Exception as e:
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
return patch_str
return extended_patch_str
def decode_if_bytes(original_file_str):
if isinstance(original_file_str, bytes):
try:
original_file_str = original_file_str.decode('utf-8')
return original_file_str.decode('utf-8')
except UnicodeDecodeError:
encodings_to_try = ['iso-8859-1', 'latin-1', 'ascii', 'utf-16']
for encoding in encodings_to_try:
try:
return original_file_str.decode(encoding)
except UnicodeDecodeError:
continue
return ""
return original_file_str
# skip patches
patch_extension_skip_types = get_settings().config.patch_extension_skip_types #[".md",".txt"]
def should_skip_patch(filename):
patch_extension_skip_types = get_settings().config.patch_extension_skip_types
if patch_extension_skip_types and filename:
if any([filename.endswith(skip_type) for skip_type in patch_extension_skip_types]):
return patch_str
return any(filename.endswith(skip_type) for skip_type in patch_extension_skip_types)
return False
# dynamic context settings
def process_patch_lines(patch_str, original_file_str, patch_extra_lines_before, patch_extra_lines_after):
allow_dynamic_context = get_settings().config.allow_dynamic_context
max_extra_lines_before_dynamic_context = get_settings().config.max_extra_lines_before_dynamic_context
patch_extra_lines_before_dynamic = patch_extra_lines_before
if allow_dynamic_context:
if max_extra_lines_before_dynamic_context > patch_extra_lines_before:
patch_extra_lines_before_dynamic = max_extra_lines_before_dynamic_context
else:
get_logger().warning(f"'max_extra_lines_before_dynamic_context' should be greater than 'patch_extra_lines_before'")
patch_extra_lines_before_dynamic = get_settings().config.max_extra_lines_before_dynamic_context
original_lines = original_file_str.splitlines()
len_original_lines = len(original_lines)
@ -46,23 +68,14 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
for line in patch_lines:
if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line)
# identify hunk header
if match:
# finish last hunk
# finish processing previous hunk
if start1 != -1 and patch_extra_lines_after > 0:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
delta_lines = [f' {line}' for line in delta_lines]
delta_lines = [f' {line}' for line in original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]]
extended_patch_lines.extend(delta_lines)
res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header = res[4]
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
def _calc_context_limits(patch_lines_before):
@ -82,7 +95,7 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
_calc_context_limits(patch_extra_lines_before_dynamic)
lines_before = original_lines[extended_start1 - 1:start1 - 1]
found_header = False
for i,line, in enumerate(lines_before):
for i, line, in enumerate(lines_before):
if section_header in line:
found_header = True
# Update start and size in one line each
@ -99,12 +112,13 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
extended_start1, extended_size1, extended_start2, extended_size2 = \
_calc_context_limits(patch_extra_lines_before)
delta_lines = original_lines[extended_start1 - 1:start1 - 1]
delta_lines = [f' {line}' for line in delta_lines]
delta_lines = [f' {line}' for line in original_lines[extended_start1 - 1:start1 - 1]]
# logic to remove section header if its in the extra delta lines (in dynamic context, this is also done)
if section_header and not allow_dynamic_context:
for line in delta_lines:
if section_header in line:
section_header = '' # remove section header if it is in the extra delta lines
section_header = '' # remove section header if it is in the extra delta lines
break
else:
extended_start1 = start1
@ -120,11 +134,10 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
continue
extended_patch_lines.append(line)
except Exception as e:
if get_settings().config.verbosity_level >= 2:
get_logger().error(f"Failed to extend patch: {e}")
get_logger().warning(f"Failed to extend patch: {e}", artifact={"traceback": traceback.format_exc()})
return patch_str
# finish last hunk
# finish processing last hunk
if start1 != -1 and patch_extra_lines_after > 0:
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
# add space at the beginning of each extra line
@ -135,6 +148,20 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0,
return extended_patch_str
def extract_hunk_headers(match):
res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header = res[4]
return section_header, size1, size2, start1, start2
def omit_deletion_hunks(patch_lines) -> str:
"""
Omit deletion hunks from the patch and return the modified patch.
@ -253,8 +280,8 @@ __old hunk__
start1, size1, start2, size2 = -1, -1, -1, -1
prev_header_line = []
header_line = []
for line in patch_lines:
if 'no newline at end of file' in line.lower():
for line_i, line in enumerate(patch_lines):
if 'no newline at end of file' in line.lower().strip().strip('//'):
continue
if line.startswith('@@'):
@ -280,21 +307,18 @@ __old hunk__
if match:
prev_header_line = header_line
res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
elif line.startswith('+'):
new_content_lines.append(line)
elif line.startswith('-'):
old_content_lines.append(line)
else:
if not line and line_i: # if this line is empty and the next line is a hunk header, skip it
if line_i + 1 < len(patch_lines) and patch_lines[line_i + 1].startswith('@@'):
continue
elif line_i + 1 == len(patch_lines):
continue
new_content_lines.append(line)
old_content_lines.append(line)
@ -339,15 +363,7 @@ def extract_hunk_lines_from_patch(patch: str, file_name, line_start, line_end, s
match = RE_HUNK_HEADER.match(line)
res = list(match.groups())
for i in range(len(res)):
if res[i] is None:
res[i] = 0
try:
start1, size1, start2, size2 = map(int, res[:4])
except: # '@@ -0,0 +1 @@' case
start1, size1, size2 = map(int, res[:3])
start2 = 0
section_header, size1, size2, start1, start2 = extract_hunk_headers(match)
# check if line range is in this hunk
if side.lower() == 'left':