diff --git a/docs/docs/usage-guide/additional_configurations.md b/docs/docs/usage-guide/additional_configurations.md index 4ae01414..8d8e75dc 100644 --- a/docs/docs/usage-guide/additional_configurations.md +++ b/docs/docs/usage-guide/additional_configurations.md @@ -66,7 +66,8 @@ By default, around any change in your PR, git patch provides three lines of cont For the `review`, `describe`, `ask` and `add_docs` tools, if the token budget allows, PR-Agent tries to increase the number of lines of context, via the parameter: ``` [config] -patch_extra_lines=3 +patch_extra_lines_before=6 +patch_extra_lines_after=2 ``` Increasing this number provides more context to the model, but will also increase the token budget. diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py index 15343c97..5cb18b3a 100644 --- a/pr_agent/algo/git_patch_processing.py +++ b/pr_agent/algo/git_patch_processing.py @@ -7,19 +7,8 @@ from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo from pr_agent.log import get_logger -def extend_patch(original_file_str, patch_str, num_lines) -> str: - """ - Extends the given patch to include a specified number of surrounding lines. - - Args: - original_file_str (str): The original file to which the patch will be applied. - patch_str (str): The patch to be applied to the original file. - num_lines (int): The number of surrounding lines to include in the extended patch. - - Returns: - str: The extended patch string. - """ - if not patch_str or num_lines == 0: +def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch_extra_lines_after=0) -> str: + if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0): return patch_str if type(original_file_str) == bytes: @@ -43,7 +32,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str: # finish previous hunk if start1 != -1: extended_patch_lines.extend( - original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines]) + original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]) res = list(match.groups()) for i in range(len(res)): @@ -55,10 +44,10 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str: start1, size1, size2 = map(int, res[:3]) start2 = 0 section_header = res[4] - extended_start1 = max(1, start1 - num_lines) - extended_size1 = size1 + (start1 - extended_start1) + num_lines - extended_start2 = max(1, start2 - num_lines) - extended_size2 = size2 + (start2 - extended_start2) + num_lines + extended_start1 = max(1, start1 - patch_extra_lines_before) + extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after + extended_start2 = max(1, start2 - patch_extra_lines_before) + extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after extended_patch_lines.append( f'@@ -{extended_start1},{extended_size1} ' f'+{extended_start2},{extended_size2} @@ {section_header}') @@ -74,7 +63,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str: # finish previous hunk if start1 != -1: extended_patch_lines.extend( - original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines]) + original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]) extended_patch_str = '\n'.join(extended_patch_lines) return extended_patch_str diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index d635ec35..80a8ded7 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -33,9 +33,11 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, large_pr_handling=False, return_remaining_files=False): if disable_extra_lines: - PATCH_EXTRA_LINES = 0 + PATCH_EXTRA_LINES_BEFORE = 0 + PATCH_EXTRA_LINES_AFTER = 0 else: - PATCH_EXTRA_LINES = get_settings().config.patch_extra_lines + PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before + PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after try: diff_files_original = git_provider.get_diff_files() @@ -64,7 +66,8 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, # generate a standard diff string, with patch extension patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff( - pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES) + pr_languages, token_handler, add_line_numbers_to_hunks, + patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER) # if we are under the limit, return the full diff if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model): @@ -174,17 +177,8 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler, add_line_numbers_to_hunks: bool, - patch_extra_lines: int = 0) -> Tuple[list, int, list]: - """ - Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff - minimization techniques if needed. - - Args: - - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding - files. - - token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request. - - add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff. - """ + patch_extra_lines_before: int = 0, + patch_extra_lines_after: int = 0) -> Tuple[list, int, list]: total_tokens = token_handler.prompt_tokens # initial tokens patches_extended = [] patches_extended_tokens = [] @@ -196,7 +190,8 @@ def pr_generate_extended_diff(pr_languages: list, continue # extend each patch with extra lines of context - extended_patch = extend_patch(original_file_content_str, patch, num_lines=patch_extra_lines) + extended_patch = extend_patch(original_file_content_str, patch, + patch_extra_lines_before, patch_extra_lines_after) if not extended_patch: get_logger().warning(f"Failed to extend patch for file: {file.filename}") continue diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 5336a48a..7598462d 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -20,7 +20,8 @@ max_commits_tokens = 500 max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities. custom_model_max_tokens=-1 # for models not in the default list # -patch_extra_lines = 1 +patch_extra_lines_before = 6 +patch_extra_lines_after = 2 secret_provider="" cli_mode=false ai_disclaimer_title="" # Pro feature, title for a collapsible disclaimer to AI outputs