patch_extra_lines_before and patch_extra_lines_after

This commit is contained in:
mrT23
2024-08-10 21:55:51 +03:00
parent 1a8b143f58
commit 61bdfd3b99
4 changed files with 22 additions and 36 deletions

View File

@ -66,7 +66,8 @@ By default, around any change in your PR, git patch provides three lines of cont
For the `review`, `describe`, `ask` and `add_docs` tools, if the token budget allows, PR-Agent tries to increase the number of lines of context, via the parameter: For the `review`, `describe`, `ask` and `add_docs` tools, if the token budget allows, PR-Agent tries to increase the number of lines of context, via the parameter:
``` ```
[config] [config]
patch_extra_lines=3 patch_extra_lines_before=6
patch_extra_lines_after=2
``` ```
Increasing this number provides more context to the model, but will also increase the token budget. Increasing this number provides more context to the model, but will also increase the token budget.

View File

@ -7,19 +7,8 @@ from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from pr_agent.log import get_logger from pr_agent.log import get_logger
def extend_patch(original_file_str, patch_str, num_lines) -> str: def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch_extra_lines_after=0) -> str:
""" if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0):
Extends the given patch to include a specified number of surrounding lines.
Args:
original_file_str (str): The original file to which the patch will be applied.
patch_str (str): The patch to be applied to the original file.
num_lines (int): The number of surrounding lines to include in the extended patch.
Returns:
str: The extended patch string.
"""
if not patch_str or num_lines == 0:
return patch_str return patch_str
if type(original_file_str) == bytes: if type(original_file_str) == bytes:
@ -43,7 +32,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
# finish previous hunk # finish previous hunk
if start1 != -1: if start1 != -1:
extended_patch_lines.extend( extended_patch_lines.extend(
original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines]) original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after])
res = list(match.groups()) res = list(match.groups())
for i in range(len(res)): for i in range(len(res)):
@ -55,10 +44,10 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
start1, size1, size2 = map(int, res[:3]) start1, size1, size2 = map(int, res[:3])
start2 = 0 start2 = 0
section_header = res[4] section_header = res[4]
extended_start1 = max(1, start1 - num_lines) extended_start1 = max(1, start1 - patch_extra_lines_before)
extended_size1 = size1 + (start1 - extended_start1) + num_lines extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
extended_start2 = max(1, start2 - num_lines) extended_start2 = max(1, start2 - patch_extra_lines_before)
extended_size2 = size2 + (start2 - extended_start2) + num_lines extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
extended_patch_lines.append( extended_patch_lines.append(
f'@@ -{extended_start1},{extended_size1} ' f'@@ -{extended_start1},{extended_size1} '
f'+{extended_start2},{extended_size2} @@ {section_header}') f'+{extended_start2},{extended_size2} @@ {section_header}')
@ -74,7 +63,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
# finish previous hunk # finish previous hunk
if start1 != -1: if start1 != -1:
extended_patch_lines.extend( extended_patch_lines.extend(
original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines]) original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after])
extended_patch_str = '\n'.join(extended_patch_lines) extended_patch_str = '\n'.join(extended_patch_lines)
return extended_patch_str return extended_patch_str

View File

@ -33,9 +33,11 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
large_pr_handling=False, large_pr_handling=False,
return_remaining_files=False): return_remaining_files=False):
if disable_extra_lines: if disable_extra_lines:
PATCH_EXTRA_LINES = 0 PATCH_EXTRA_LINES_BEFORE = 0
PATCH_EXTRA_LINES_AFTER = 0
else: else:
PATCH_EXTRA_LINES = get_settings().config.patch_extra_lines PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
try: try:
diff_files_original = git_provider.get_diff_files() diff_files_original = git_provider.get_diff_files()
@ -64,7 +66,8 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
# generate a standard diff string, with patch extension # generate a standard diff string, with patch extension
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff( patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES) pr_languages, token_handler, add_line_numbers_to_hunks,
patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
# if we are under the limit, return the full diff # if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model): if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
@ -174,17 +177,8 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH
def pr_generate_extended_diff(pr_languages: list, def pr_generate_extended_diff(pr_languages: list,
token_handler: TokenHandler, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool, add_line_numbers_to_hunks: bool,
patch_extra_lines: int = 0) -> Tuple[list, int, list]: patch_extra_lines_before: int = 0,
""" patch_extra_lines_after: int = 0) -> Tuple[list, int, list]:
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
minimization techniques if needed.
Args:
- pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding
files.
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
"""
total_tokens = token_handler.prompt_tokens # initial tokens total_tokens = token_handler.prompt_tokens # initial tokens
patches_extended = [] patches_extended = []
patches_extended_tokens = [] patches_extended_tokens = []
@ -196,7 +190,8 @@ def pr_generate_extended_diff(pr_languages: list,
continue continue
# extend each patch with extra lines of context # extend each patch with extra lines of context
extended_patch = extend_patch(original_file_content_str, patch, num_lines=patch_extra_lines) extended_patch = extend_patch(original_file_content_str, patch,
patch_extra_lines_before, patch_extra_lines_after)
if not extended_patch: if not extended_patch:
get_logger().warning(f"Failed to extend patch for file: {file.filename}") get_logger().warning(f"Failed to extend patch for file: {file.filename}")
continue continue

View File

@ -20,7 +20,8 @@ max_commits_tokens = 500
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities. max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
custom_model_max_tokens=-1 # for models not in the default list custom_model_max_tokens=-1 # for models not in the default list
# #
patch_extra_lines = 1 patch_extra_lines_before = 6
patch_extra_lines_after = 2
secret_provider="" secret_provider=""
cli_mode=false cli_mode=false
ai_disclaimer_title="" # Pro feature, title for a collapsible disclaimer to AI outputs ai_disclaimer_title="" # Pro feature, title for a collapsible disclaimer to AI outputs