diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py index f51c4415..f7aa6b60 100644 --- a/pr_agent/algo/__init__.py +++ b/pr_agent/algo/__init__.py @@ -46,6 +46,7 @@ MAX_TOKENS = { 'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000, 'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000, 'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000, + 'claude-3-5-sonnet': 100000, 'groq/llama3-8b-8192': 8192, 'groq/llama3-70b-8192': 8192, 'groq/mixtral-8x7b-32768': 32768, diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 80a8ded7..1cad30b6 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -400,10 +400,13 @@ def get_pr_multi_diffs(git_provider: GitProvider, for lang in pr_languages: sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True)) - # try first a single run with standard diff string, with patch extension, and no deletions patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff( - pr_languages, token_handler, add_line_numbers_to_hunks=True) + pr_languages, token_handler, add_line_numbers_to_hunks=True, + patch_extra_lines_before=get_settings().config.patch_extra_lines_before, + patch_extra_lines_after=get_settings().config.patch_extra_lines_after) + + # if we are under the limit, return the full diff if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model): return ["\n".join(patches_extended)] if patches_extended else [] diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 7598462d..6264631a 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -20,7 +20,7 @@ max_commits_tokens = 500 max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities. 
custom_model_max_tokens=-1 # for models not in the default list # -patch_extra_lines_before = 6 +patch_extra_lines_before = 4 patch_extra_lines_after = 2 secret_provider="" cli_mode=false @@ -97,7 +97,7 @@ enable_help_text=false [pr_code_suggestions] # /improve # -max_context_tokens=10000 +max_context_tokens=16000 num_code_suggestions=4 commitable_code_suggestions = false extra_instructions = "" diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index f98590ce..1a965192 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -286,7 +286,7 @@ class PRCodeSuggestions: self.token_handler, model, add_line_numbers_to_hunks=True, - disable_extra_lines=True) + disable_extra_lines=False) if self.patches_diff: get_logger().debug(f"PR diff", artifact=self.patches_diff) diff --git a/tests/unittest/test_extend_patch.py b/tests/unittest/test_extend_patch.py index ba0af881..cb2b3c9c 100644 --- a/tests/unittest/test_extend_patch.py +++ b/tests/unittest/test_extend_patch.py @@ -1,44 +1,7 @@ - -# Generated by CodiumAI - - +import pytest from pr_agent.algo.git_patch_processing import extend_patch - -"""
-Code Analysis - -Objective: -The objective of the 'extend_patch' function is to extend a given patch to include a specified number of surrounding -lines. This function takes in an original file string, a patch string, and the number of lines to extend the patch by, -and returns the extended patch string. - -Inputs: -- original_file_str: a string representing the original file -- patch_str: a string representing the patch to be extended -- num_lines: an integer representing the number of lines to extend the patch by - -Flow: -1. Split the original file string and patch string into separate lines -2. Initialize variables to keep track of the current hunk's start and size for both the original file and the patch -3. Iterate through each line in the patch string -4. 
If the line starts with '@@', extract the start and size values for both the original file and the patch, and -calculate the extended start and size values -5. Append the extended hunk header to the extended patch lines list -6. Append the specified number of lines before the hunk to the extended patch lines list -7. Append the current line to the extended patch lines list -8. If the line is not a hunk header, append it to the extended patch lines list -9. Return the extended patch string - -Outputs: -- extended_patch_str: a string representing the extended patch - -Additional aspects: -- The function uses regular expressions to extract the start and size values from the hunk header -- The function handles cases where the start value of a hunk is less than the number of lines to extend by by setting -the extended start value to 1 -- The function handles cases where the hunk extends beyond the end of the original file by only including lines up to -the end of the original file in the extended patch -""" +from pr_agent.algo.pr_processing import pr_generate_extended_diff +from pr_agent.algo.token_handler import TokenHandler class TestExtendPatch: @@ -48,7 +10,8 @@ class TestExtendPatch: patch_str = '@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\nline3' num_lines = 1 expected_output = '@@ -1,4 +1,4 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4' - actual_output = extend_patch(original_file_str, patch_str, num_lines) + actual_output = extend_patch(original_file_str, patch_str, + patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) assert actual_output == expected_output # Tests that the function returns an empty string when patch_str is empty @@ -57,14 +20,16 @@ class TestExtendPatch: patch_str = '' num_lines = 1 expected_output = '' - assert extend_patch(original_file_str, patch_str, num_lines) == expected_output + assert extend_patch(original_file_str, patch_str, + patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) == expected_output # Tests that the function returns the original patch when 
num_lines is 0 def test_zero_num_lines(self): original_file_str = 'line1\nline2\nline3\nline4\nline5' patch_str = '@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\nline3' num_lines = 0 - assert extend_patch(original_file_str, patch_str, num_lines) == patch_str + assert extend_patch(original_file_str, patch_str, + patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) == patch_str # Tests that the function returns the original patch when patch_str contains no hunks def test_no_hunks(self): @@ -80,7 +45,8 @@ class TestExtendPatch: patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\nline3\nline4' num_lines = 1 expected_output = '@@ -1,5 +1,5 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4\nline5' - actual_output = extend_patch(original_file_str, patch_str, num_lines) + actual_output = extend_patch(original_file_str, patch_str, + patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) assert actual_output == expected_output # Tests the functionality of extending a patch with multiple hunks. 
@@ -89,5 +55,59 @@ class TestExtendPatch: patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\nline3\nline4\n@@ -4,1 +4,1 @@ init2()\n-line4\n+new_line4' # noqa: E501 num_lines = 1 expected_output = '@@ -1,5 +1,5 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4\nline5\n@@ -3,3 +3,3 @@ init2()\nline3\n-line4\n+new_line4\nline5' # noqa: E501 - actual_output = extend_patch(original_file_str, patch_str, num_lines) + actual_output = extend_patch(original_file_str, patch_str, + patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) assert actual_output == expected_output + + +class TestExtendedPatchMoreLines: + class File: + def __init__(self, base_file, patch, filename): + self.base_file = base_file + self.patch = patch + self.filename = filename + + @pytest.fixture + def token_handler(self): + # Create a TokenHandler instance with dummy data + th = TokenHandler(system="System prompt", user="User prompt") + th.prompt_tokens = 100 + return th + + @pytest.fixture + def pr_languages(self): + # Create a list of languages with files containing base_file and patch data + return [ + { + 'files': [ + self.File(base_file="line000\nline00\nline0\nline1\noriginal content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10", + patch="@@ -5,5 +5,5 @@\n-original content\n+modified content\nline2\nline3\nline4\nline5", + filename="file1"), + self.File(base_file="original content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10", + patch="@@ -6,5 +6,5 @@\nline6\nline7\nline8\n-line9\n+modified line9\nline10", + filename="file2") + ] + } + ] + + def test_extend_patches_with_extra_lines(self, token_handler, pr_languages): + patches_extended_no_extra_lines, total_tokens, patches_extended_tokens = pr_generate_extended_diff( + pr_languages, token_handler, add_line_numbers_to_hunks=False, + patch_extra_lines_before=0, + patch_extra_lines_after=0 + ) + + # Check that with no extra lines, the patches are the same as the original patches + p0 = 
patches_extended_no_extra_lines[0].strip() + p1 = patches_extended_no_extra_lines[1].strip() + assert p0 == '## file1\n\n' + pr_languages[0]['files'][0].patch.strip() + assert p1 == '## file2\n\n' + pr_languages[0]['files'][1].patch.strip() + + patches_extended_with_extra_lines, total_tokens, patches_extended_tokens = pr_generate_extended_diff( + pr_languages, token_handler, add_line_numbers_to_hunks=False, + patch_extra_lines_before=2, + patch_extra_lines_after=1 + ) + + p0_extended = patches_extended_with_extra_lines[0].strip() + assert p0_extended == '## file1\n\n@@ -3,8 +3,8 @@ \nline0\nline1\n-original content\n+modified content\nline2\nline3\nline4\nline5\nline6'