mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-04 21:00:40 +08:00
Merge pull request #1156 from Codium-ai/tr/dynamic_context
Tr/dynamic context
This commit is contained in:
@ -17,6 +17,15 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch
|
|||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
allow_dynamic_context = get_settings().config.allow_dynamic_context
|
||||||
|
max_extra_lines_before_dynamic_context = get_settings().config.max_extra_lines_before_dynamic_context
|
||||||
|
patch_extra_lines_before_dynamic = patch_extra_lines_before
|
||||||
|
if allow_dynamic_context:
|
||||||
|
if max_extra_lines_before_dynamic_context > patch_extra_lines_before:
|
||||||
|
patch_extra_lines_before_dynamic = max_extra_lines_before_dynamic_context
|
||||||
|
else:
|
||||||
|
get_logger().warning(f"'max_extra_lines_before_dynamic_context' should be greater than 'patch_extra_lines_before'")
|
||||||
|
|
||||||
original_lines = original_file_str.splitlines()
|
original_lines = original_file_str.splitlines()
|
||||||
len_original_lines = len(original_lines)
|
len_original_lines = len(original_lines)
|
||||||
patch_lines = patch_str.splitlines()
|
patch_lines = patch_str.splitlines()
|
||||||
@ -48,18 +57,41 @@ def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch
|
|||||||
section_header = res[4]
|
section_header = res[4]
|
||||||
|
|
||||||
if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
|
if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
|
||||||
extended_start1 = max(1, start1 - patch_extra_lines_before)
|
def _calc_context_limits(patch_lines_before):
|
||||||
|
extended_start1 = max(1, start1 - patch_lines_before)
|
||||||
extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
|
extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
|
||||||
extended_start2 = max(1, start2 - patch_extra_lines_before)
|
extended_start2 = max(1, start2 - patch_lines_before)
|
||||||
extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
|
extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
|
||||||
if extended_start1 - 1 + extended_size1 > len_original_lines:
|
if extended_start1 - 1 + extended_size1 > len_original_lines:
|
||||||
# we cannot extend beyond the original file
|
# we cannot extend beyond the original file
|
||||||
delta_cap = extended_start1 - 1 + extended_size1 - len_original_lines
|
delta_cap = extended_start1 - 1 + extended_size1 - len_original_lines
|
||||||
extended_size1 = max(extended_size1 - delta_cap, size1)
|
extended_size1 = max(extended_size1 - delta_cap, size1)
|
||||||
extended_size2 = max(extended_size2 - delta_cap, size2)
|
extended_size2 = max(extended_size2 - delta_cap, size2)
|
||||||
|
return extended_start1, extended_size1, extended_start2, extended_size2
|
||||||
|
|
||||||
|
if allow_dynamic_context:
|
||||||
|
extended_start1, extended_size1, extended_start2, extended_size2 = \
|
||||||
|
_calc_context_limits(patch_extra_lines_before_dynamic)
|
||||||
|
lines_before = original_lines[extended_start1 - 1:start1 - 1]
|
||||||
|
found_header = False
|
||||||
|
for i,line, in enumerate(lines_before):
|
||||||
|
if section_header in line:
|
||||||
|
found_header = True
|
||||||
|
extended_start1 = extended_start1 + i
|
||||||
|
get_logger().debug(f"Found section header in line {i} before the hunk")
|
||||||
|
section_header = ''
|
||||||
|
break
|
||||||
|
if not found_header:
|
||||||
|
get_logger().debug(f"Section header not found in the extra lines before the hunk")
|
||||||
|
extended_start1, extended_size1, extended_start2, extended_size2 = \
|
||||||
|
_calc_context_limits(patch_extra_lines_before)
|
||||||
|
else:
|
||||||
|
extended_start1, extended_size1, extended_start2, extended_size2 = \
|
||||||
|
_calc_context_limits(patch_extra_lines_before)
|
||||||
|
|
||||||
delta_lines = original_lines[extended_start1 - 1:start1 - 1]
|
delta_lines = original_lines[extended_start1 - 1:start1 - 1]
|
||||||
delta_lines = [f' {line}' for line in delta_lines]
|
delta_lines = [f' {line}' for line in delta_lines]
|
||||||
if section_header:
|
if section_header and not allow_dynamic_context:
|
||||||
for line in delta_lines:
|
for line in delta_lines:
|
||||||
if section_header in line:
|
if section_header in line:
|
||||||
section_header = '' # remove section header if it is in the extra delta lines
|
section_header = '' # remove section header if it is in the extra delta lines
|
||||||
|
@ -20,6 +20,8 @@ max_commits_tokens = 500
|
|||||||
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
|
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
|
||||||
custom_model_max_tokens=-1 # for models not in the default list
|
custom_model_max_tokens=-1 # for models not in the default list
|
||||||
#
|
#
|
||||||
|
allow_dynamic_context=false
|
||||||
|
max_extra_lines_before_dynamic_context = 10 # will try to include up to 10 extra lines before the hunk in the patch, until we reach an enclosing function or class
|
||||||
patch_extra_lines_before = 3 # Number of extra lines (+3 default ones) to include before each hunk in the patch
|
patch_extra_lines_before = 3 # Number of extra lines (+3 default ones) to include before each hunk in the patch
|
||||||
patch_extra_lines_after = 1 # Number of extra lines (+3 default ones) to include after each hunk in the patch
|
patch_extra_lines_after = 1 # Number of extra lines (+3 default ones) to include after each hunk in the patch
|
||||||
secret_provider=""
|
secret_provider=""
|
||||||
|
@ -2,9 +2,13 @@ import pytest
|
|||||||
from pr_agent.algo.git_patch_processing import extend_patch
|
from pr_agent.algo.git_patch_processing import extend_patch
|
||||||
from pr_agent.algo.pr_processing import pr_generate_extended_diff
|
from pr_agent.algo.pr_processing import pr_generate_extended_diff
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
|
from pr_agent.config_loader import get_settings
|
||||||
|
|
||||||
|
|
||||||
class TestExtendPatch:
|
class TestExtendPatch:
|
||||||
|
def setUp(self):
|
||||||
|
get_settings().config.allow_dynamic_context = False
|
||||||
|
|
||||||
# Tests that the function works correctly with valid input
|
# Tests that the function works correctly with valid input
|
||||||
def test_happy_path(self):
|
def test_happy_path(self):
|
||||||
original_file_str = 'line1\nline2\nline3\nline4\nline5'
|
original_file_str = 'line1\nline2\nline3\nline4\nline5'
|
||||||
@ -61,8 +65,34 @@ class TestExtendPatch:
|
|||||||
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
|
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
|
||||||
assert actual_output == expected_output
|
assert actual_output == expected_output
|
||||||
|
|
||||||
|
def test_dynamic_context(self):
|
||||||
|
get_settings().config.max_extra_lines_before_dynamic_context = 10
|
||||||
|
original_file_str = "def foo():"
|
||||||
|
for i in range(9):
|
||||||
|
original_file_str += f"\n line({i})"
|
||||||
|
patch_str ="@@ -11,1 +11,1 @@ def foo():\n- line(9)\n+ new_line(9)"
|
||||||
|
num_lines=1
|
||||||
|
|
||||||
|
get_settings().config.allow_dynamic_context = True
|
||||||
|
actual_output = extend_patch(original_file_str, patch_str,
|
||||||
|
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
|
||||||
|
expected_output='\n@@ -1,10 +1,10 @@ \n def foo():\n line(0)\n line(1)\n line(2)\n line(3)\n line(4)\n line(5)\n line(6)\n line(7)\n line(8)\n- line(9)\n+ new_line(9)'
|
||||||
|
assert actual_output == expected_output
|
||||||
|
|
||||||
|
get_settings().config.allow_dynamic_context = False
|
||||||
|
actual_output2 = extend_patch(original_file_str, patch_str,
|
||||||
|
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
|
||||||
|
expected_output_no_dynamic_context = '\n@@ -10,1 +10,1 @@ def foo():\n line(8)\n- line(9)\n+ new_line(9)'
|
||||||
|
assert actual_output2 == expected_output_no_dynamic_context
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestExtendedPatchMoreLines:
|
class TestExtendedPatchMoreLines:
|
||||||
|
def setUp(self):
|
||||||
|
get_settings().config.allow_dynamic_context = False
|
||||||
|
|
||||||
class File:
|
class File:
|
||||||
def __init__(self, base_file, patch, filename):
|
def __init__(self, base_file, patch, filename):
|
||||||
self.base_file = base_file
|
self.base_file = base_file
|
||||||
|
Reference in New Issue
Block a user