mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-07 06:10:39 +08:00
Merge pull request #1114 from Codium-ai/tr/patch_extra_lines_before_and_after
Tr/patch extra lines before and after
This commit is contained in:
@ -66,7 +66,8 @@ By default, around any change in your PR, git patch provides three lines of cont
|
||||
For the `review`, `describe`, `ask` and `add_docs` tools, if the token budget allows, PR-Agent tries to increase the number of lines of context, via the parameter:
|
||||
```
|
||||
[config]
|
||||
patch_extra_lines=3
|
||||
patch_extra_lines_before=4
|
||||
patch_extra_lines_after=2
|
||||
```
|
||||
|
||||
Increasing this number provides more context to the model, but will also increase the token budget.
|
||||
|
@ -46,6 +46,7 @@ MAX_TOKENS = {
|
||||
'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000,
|
||||
'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
|
||||
'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000,
|
||||
'claude-3-5-sonnet': 100000,
|
||||
'groq/llama3-8b-8192': 8192,
|
||||
'groq/llama3-70b-8192': 8192,
|
||||
'groq/mixtral-8x7b-32768': 32768,
|
||||
|
@ -7,19 +7,8 @@ from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from pr_agent.log import get_logger
|
||||
|
||||
|
||||
def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
||||
"""
|
||||
Extends the given patch to include a specified number of surrounding lines.
|
||||
|
||||
Args:
|
||||
original_file_str (str): The original file to which the patch will be applied.
|
||||
patch_str (str): The patch to be applied to the original file.
|
||||
num_lines (int): The number of surrounding lines to include in the extended patch.
|
||||
|
||||
Returns:
|
||||
str: The extended patch string.
|
||||
"""
|
||||
if not patch_str or num_lines == 0:
|
||||
def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch_extra_lines_after=0) -> str:
|
||||
if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0):
|
||||
return patch_str
|
||||
|
||||
if type(original_file_str) == bytes:
|
||||
@ -29,6 +18,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
||||
return ""
|
||||
|
||||
original_lines = original_file_str.splitlines()
|
||||
len_original_lines = len(original_lines)
|
||||
patch_lines = patch_str.splitlines()
|
||||
extended_patch_lines = []
|
||||
|
||||
@ -40,10 +30,11 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
||||
if line.startswith('@@'):
|
||||
match = RE_HUNK_HEADER.match(line)
|
||||
if match:
|
||||
# finish previous hunk
|
||||
if start1 != -1:
|
||||
extended_patch_lines.extend(
|
||||
original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])
|
||||
# finish last hunk
|
||||
if start1 != -1 and patch_extra_lines_after > 0:
|
||||
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
|
||||
delta_lines = [f' {line}' for line in delta_lines]
|
||||
extended_patch_lines.extend(delta_lines)
|
||||
|
||||
res = list(match.groups())
|
||||
for i in range(len(res)):
|
||||
@ -55,15 +46,33 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
||||
start1, size1, size2 = map(int, res[:3])
|
||||
start2 = 0
|
||||
section_header = res[4]
|
||||
extended_start1 = max(1, start1 - num_lines)
|
||||
extended_size1 = size1 + (start1 - extended_start1) + num_lines
|
||||
extended_start2 = max(1, start2 - num_lines)
|
||||
extended_size2 = size2 + (start2 - extended_start2) + num_lines
|
||||
|
||||
if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
|
||||
extended_start1 = max(1, start1 - patch_extra_lines_before)
|
||||
extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
|
||||
if extended_start1 - 1 + extended_size1 > len(original_lines):
|
||||
extended_size1 = len_original_lines - extended_start1 + 1
|
||||
extended_start2 = max(1, start2 - patch_extra_lines_before)
|
||||
extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
|
||||
if extended_start2 - 1 + extended_size2 > len_original_lines:
|
||||
extended_size2 = len_original_lines - extended_start2 + 1
|
||||
delta_lines = original_lines[extended_start1 - 1:start1 - 1]
|
||||
delta_lines = [f' {line}' for line in delta_lines]
|
||||
if section_header:
|
||||
for line in delta_lines:
|
||||
if section_header in line:
|
||||
section_header = '' # remove section header if it is in the extra delta lines
|
||||
break
|
||||
else:
|
||||
extended_start1 = start1
|
||||
extended_size1 = size1
|
||||
extended_start2 = start2
|
||||
extended_size2 = size2
|
||||
delta_lines = []
|
||||
extended_patch_lines.append(
|
||||
f'@@ -{extended_start1},{extended_size1} '
|
||||
f'+{extended_start2},{extended_size2} @@ {section_header}')
|
||||
extended_patch_lines.extend(
|
||||
original_lines[extended_start1 - 1:start1 - 1]) # one to zero based
|
||||
extended_patch_lines.extend(delta_lines) # one to zero based
|
||||
continue
|
||||
extended_patch_lines.append(line)
|
||||
except Exception as e:
|
||||
@ -71,10 +80,12 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
||||
get_logger().error(f"Failed to extend patch: {e}")
|
||||
return patch_str
|
||||
|
||||
# finish previous hunk
|
||||
if start1 != -1:
|
||||
extended_patch_lines.extend(
|
||||
original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])
|
||||
# finish last hunk
|
||||
if start1 != -1 and patch_extra_lines_after > 0:
|
||||
delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
|
||||
# add space at the beginning of each extra line
|
||||
delta_lines = [f' {line}' for line in delta_lines]
|
||||
extended_patch_lines.extend(delta_lines)
|
||||
|
||||
extended_patch_str = '\n'.join(extended_patch_lines)
|
||||
return extended_patch_str
|
||||
|
@ -33,9 +33,11 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||
large_pr_handling=False,
|
||||
return_remaining_files=False):
|
||||
if disable_extra_lines:
|
||||
PATCH_EXTRA_LINES = 0
|
||||
PATCH_EXTRA_LINES_BEFORE = 0
|
||||
PATCH_EXTRA_LINES_AFTER = 0
|
||||
else:
|
||||
PATCH_EXTRA_LINES = get_settings().config.patch_extra_lines
|
||||
PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
|
||||
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
|
||||
|
||||
try:
|
||||
diff_files_original = git_provider.get_diff_files()
|
||||
@ -64,7 +66,8 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||
|
||||
# generate a standard diff string, with patch extension
|
||||
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
|
||||
pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES)
|
||||
pr_languages, token_handler, add_line_numbers_to_hunks,
|
||||
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
|
||||
|
||||
# if we are under the limit, return the full diff
|
||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
|
||||
@ -72,7 +75,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||
f"returning full diff.")
|
||||
return "\n".join(patches_extended)
|
||||
|
||||
# if we are over the limit, start pruning
|
||||
# if we are over the limit, start pruning (If we got here, we will not extend the patches with extra lines)
|
||||
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
|
||||
f"pruning diff.")
|
||||
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
|
||||
@ -174,17 +177,8 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH
|
||||
def pr_generate_extended_diff(pr_languages: list,
|
||||
token_handler: TokenHandler,
|
||||
add_line_numbers_to_hunks: bool,
|
||||
patch_extra_lines: int = 0) -> Tuple[list, int, list]:
|
||||
"""
|
||||
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
|
||||
minimization techniques if needed.
|
||||
|
||||
Args:
|
||||
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding
|
||||
files.
|
||||
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
|
||||
"""
|
||||
patch_extra_lines_before: int = 0,
|
||||
patch_extra_lines_after: int = 0) -> Tuple[list, int, list]:
|
||||
total_tokens = token_handler.prompt_tokens # initial tokens
|
||||
patches_extended = []
|
||||
patches_extended_tokens = []
|
||||
@ -196,7 +190,8 @@ def pr_generate_extended_diff(pr_languages: list,
|
||||
continue
|
||||
|
||||
# extend each patch with extra lines of context
|
||||
extended_patch = extend_patch(original_file_content_str, patch, num_lines=patch_extra_lines)
|
||||
extended_patch = extend_patch(original_file_content_str, patch,
|
||||
patch_extra_lines_before, patch_extra_lines_after)
|
||||
if not extended_patch:
|
||||
get_logger().warning(f"Failed to extend patch for file: {file.filename}")
|
||||
continue
|
||||
@ -405,10 +400,13 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
||||
for lang in pr_languages:
|
||||
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
|
||||
|
||||
|
||||
# try first a single run with standard diff string, with patch extension, and no deletions
|
||||
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
|
||||
pr_languages, token_handler, add_line_numbers_to_hunks=True)
|
||||
pr_languages, token_handler, add_line_numbers_to_hunks=True,
|
||||
patch_extra_lines_before=get_settings().config.patch_extra_lines_before,
|
||||
patch_extra_lines_after=get_settings().config.patch_extra_lines_after)
|
||||
|
||||
# if we are under the limit, return the full diff
|
||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
|
||||
return ["\n".join(patches_extended)] if patches_extended else []
|
||||
|
||||
|
@ -47,3 +47,17 @@ def apply_repo_settings(pr_url):
|
||||
os.remove(repo_settings_file)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to remove temporary settings file {repo_settings_file}", e)
|
||||
|
||||
# enable switching models with a short definition
|
||||
if get_settings().config.model.lower()=='claude-3-5-sonnet':
|
||||
set_claude_model()
|
||||
|
||||
|
||||
def set_claude_model():
|
||||
"""
|
||||
set the claude-sonnet-3.5 model easily (even by users), just by stating: --config.model='claude-3-5-sonnet'
|
||||
"""
|
||||
model_claude = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
get_settings().set('config.model', model_claude)
|
||||
get_settings().set('config.model_turbo', model_claude)
|
||||
get_settings().set('config.fallback_models', [model_claude])
|
||||
|
@ -20,7 +20,8 @@ max_commits_tokens = 500
|
||||
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
|
||||
custom_model_max_tokens=-1 # for models not in the default list
|
||||
#
|
||||
patch_extra_lines = 1
|
||||
patch_extra_lines_before = 3 # Number of extra lines (+3 default ones) to include before each hunk in the patch
|
||||
patch_extra_lines_after = 1 # Number of extra lines (+3 default ones) to include after each hunk in the patch
|
||||
secret_provider=""
|
||||
cli_mode=false
|
||||
ai_disclaimer_title="" # Pro feature, title for a collapsible disclaimer to AI outputs
|
||||
@ -96,7 +97,7 @@ enable_help_text=false
|
||||
|
||||
|
||||
[pr_code_suggestions] # /improve #
|
||||
max_context_tokens=10000
|
||||
max_context_tokens=14000
|
||||
num_code_suggestions=4
|
||||
commitable_code_suggestions = false
|
||||
extra_instructions = ""
|
||||
|
@ -286,7 +286,7 @@ class PRCodeSuggestions:
|
||||
self.token_handler,
|
||||
model,
|
||||
add_line_numbers_to_hunks=True,
|
||||
disable_extra_lines=True)
|
||||
disable_extra_lines=False)
|
||||
|
||||
if self.patches_diff:
|
||||
get_logger().debug(f"PR diff", artifact=self.patches_diff)
|
||||
|
@ -1,44 +1,7 @@
|
||||
|
||||
# Generated by CodiumAI
|
||||
|
||||
|
||||
import pytest
|
||||
from pr_agent.algo.git_patch_processing import extend_patch
|
||||
|
||||
"""
|
||||
Code Analysis
|
||||
|
||||
Objective:
|
||||
The objective of the 'extend_patch' function is to extend a given patch to include a specified number of surrounding
|
||||
lines. This function takes in an original file string, a patch string, and the number of lines to extend the patch by,
|
||||
and returns the extended patch string.
|
||||
|
||||
Inputs:
|
||||
- original_file_str: a string representing the original file
|
||||
- patch_str: a string representing the patch to be extended
|
||||
- num_lines: an integer representing the number of lines to extend the patch by
|
||||
|
||||
Flow:
|
||||
1. Split the original file string and patch string into separate lines
|
||||
2. Initialize variables to keep track of the current hunk's start and size for both the original file and the patch
|
||||
3. Iterate through each line in the patch string
|
||||
4. If the line starts with '@@', extract the start and size values for both the original file and the patch, and
|
||||
calculate the extended start and size values
|
||||
5. Append the extended hunk header to the extended patch lines list
|
||||
6. Append the specified number of lines before the hunk to the extended patch lines list
|
||||
7. Append the current line to the extended patch lines list
|
||||
8. If the line is not a hunk header, append it to the extended patch lines list
|
||||
9. Return the extended patch string
|
||||
|
||||
Outputs:
|
||||
- extended_patch_str: a string representing the extended patch
|
||||
|
||||
Additional aspects:
|
||||
- The function uses regular expressions to extract the start and size values from the hunk header
|
||||
- The function handles cases where the start value of a hunk is less than the number of lines to extend by by setting
|
||||
the extended start value to 1
|
||||
- The function handles cases where the hunk extends beyond the end of the original file by only including lines up to
|
||||
the end of the original file in the extended patch
|
||||
"""
|
||||
from pr_agent.algo.pr_processing import pr_generate_extended_diff
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
|
||||
|
||||
class TestExtendPatch:
|
||||
@ -48,7 +11,8 @@ class TestExtendPatch:
|
||||
patch_str = '@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\n line3'
|
||||
num_lines = 1
|
||||
expected_output = '@@ -1,4 +1,4 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4'
|
||||
actual_output = extend_patch(original_file_str, patch_str, num_lines)
|
||||
actual_output = extend_patch(original_file_str, patch_str,
|
||||
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
|
||||
assert actual_output == expected_output
|
||||
|
||||
# Tests that the function returns an empty string when patch_str is empty
|
||||
@ -57,14 +21,16 @@ class TestExtendPatch:
|
||||
patch_str = ''
|
||||
num_lines = 1
|
||||
expected_output = ''
|
||||
assert extend_patch(original_file_str, patch_str, num_lines) == expected_output
|
||||
assert extend_patch(original_file_str, patch_str,
|
||||
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) == expected_output
|
||||
|
||||
# Tests that the function returns the original patch when num_lines is 0
|
||||
def test_zero_num_lines(self):
|
||||
original_file_str = 'line1\nline2\nline3\nline4\nline5'
|
||||
patch_str = '@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\nline3'
|
||||
num_lines = 0
|
||||
assert extend_patch(original_file_str, patch_str, num_lines) == patch_str
|
||||
assert extend_patch(original_file_str, patch_str,
|
||||
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) == patch_str
|
||||
|
||||
# Tests that the function returns the original patch when patch_str contains no hunks
|
||||
def test_no_hunks(self):
|
||||
@ -78,9 +44,11 @@ class TestExtendPatch:
|
||||
def test_single_hunk(self):
|
||||
original_file_str = 'line1\nline2\nline3\nline4\nline5'
|
||||
patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\n line3\n line4'
|
||||
num_lines = 1
|
||||
|
||||
for num_lines in [1, 2, 3]: # check that even if we are over the number of lines in the file, the function still works
|
||||
expected_output = '@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5'
|
||||
actual_output = extend_patch(original_file_str, patch_str, num_lines)
|
||||
actual_output = extend_patch(original_file_str, patch_str,
|
||||
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
|
||||
assert actual_output == expected_output
|
||||
|
||||
# Tests the functionality of extending a patch with multiple hunks.
|
||||
@ -89,5 +57,59 @@ class TestExtendPatch:
|
||||
patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\n line3\n line4\n@@ -4,1 +4,1 @@ init2()\n-line4\n+new_line4' # noqa: E501
|
||||
num_lines = 1
|
||||
expected_output = '@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5' # noqa: E501
|
||||
actual_output = extend_patch(original_file_str, patch_str, num_lines)
|
||||
actual_output = extend_patch(original_file_str, patch_str,
|
||||
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
|
||||
assert actual_output == expected_output
|
||||
|
||||
|
||||
class TestExtendedPatchMoreLines:
|
||||
class File:
|
||||
def __init__(self, base_file, patch, filename):
|
||||
self.base_file = base_file
|
||||
self.patch = patch
|
||||
self.filename = filename
|
||||
|
||||
@pytest.fixture
|
||||
def token_handler(self):
|
||||
# Create a TokenHandler instance with dummy data
|
||||
th = TokenHandler(system="System prompt", user="User prompt")
|
||||
th.prompt_tokens = 100
|
||||
return th
|
||||
|
||||
@pytest.fixture
|
||||
def pr_languages(self):
|
||||
# Create a list of languages with files containing base_file and patch data
|
||||
return [
|
||||
{
|
||||
'files': [
|
||||
self.File(base_file="line000\nline00\nline0\nline1\noriginal content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10",
|
||||
patch="@@ -5,5 +5,5 @@\n-original content\n+modified content\n line2\n line3\n line4\n line5",
|
||||
filename="file1"),
|
||||
self.File(base_file="original content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10",
|
||||
patch="@@ -6,5 +6,5 @@\nline6\nline7\nline8\n-line9\n+modified line9\nline10",
|
||||
filename="file2")
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
def test_extend_patches_with_extra_lines(self, token_handler, pr_languages):
|
||||
patches_extended_no_extra_lines, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
|
||||
pr_languages, token_handler, add_line_numbers_to_hunks=False,
|
||||
patch_extra_lines_before=0,
|
||||
patch_extra_lines_after=0
|
||||
)
|
||||
|
||||
# Check that with no extra lines, the patches are the same as the original patches
|
||||
p0 = patches_extended_no_extra_lines[0].strip()
|
||||
p1 = patches_extended_no_extra_lines[1].strip()
|
||||
assert p0 == '## file1\n\n' + pr_languages[0]['files'][0].patch.strip()
|
||||
assert p1 == '## file2\n\n' + pr_languages[0]['files'][1].patch.strip()
|
||||
|
||||
patches_extended_with_extra_lines, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
|
||||
pr_languages, token_handler, add_line_numbers_to_hunks=False,
|
||||
patch_extra_lines_before=2,
|
||||
patch_extra_lines_after=1
|
||||
)
|
||||
|
||||
p0_extended = patches_extended_with_extra_lines[0].strip()
|
||||
assert p0_extended == '## file1\n\n@@ -3,8 +3,8 @@ \n line0\n line1\n-original content\n+modified content\n line2\n line3\n line4\n line5\n line6'
|
||||
|
Reference in New Issue
Block a user