diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index 59b9da26..bc17aedd 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -377,9 +377,25 @@ def get_pr_multi_diffs(git_provider: GitProvider,
             patch = convert_to_hunks_with_lines_numbers(patch, file)
             new_patch_tokens = token_handler.count_tokens(patch)
 
-        if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
-            get_logger().warning(f"Patch too large, skipping: {file.filename}")
-            continue
+        if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
+                model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+            if get_settings().config.get('large_patch_policy', 'skip') == 'skip':
+                get_logger().warning(f"Patch too large, skipping: {file.filename}")
+                continue
+            elif get_settings().config.get('large_patch_policy') == 'clip':
+                delta_tokens = int(0.9 * (get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens))
+                patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True)
+                new_patch_tokens = token_handler.count_tokens(patch_clipped)
+                if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
+                        model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+                    get_logger().warning(f"Patch too large, skipping: {file.filename}")
+                    continue
+                else:
+                    get_logger().info(f"Clipped large patch for file: {file.filename}")
+                    patch = patch_clipped
+            else:
+                get_logger().warning(f"Patch too large, skipping: {file.filename}")
+                continue
 
         if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
             final_diff = "\n".join(patches)
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 0888a15a..d24d4244 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -552,7 +552,7 @@ def get_max_tokens(model):
     return max_tokens_model
 
 
-def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str:
+def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_line=False) -> str:
     """
     Clip the number of tokens in a string to a maximum number of tokens.
 
@@ -575,6 +575,8 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str:
         chars_per_token = num_chars / num_input_tokens
         num_output_chars = int(chars_per_token * max_tokens)
         clipped_text = text[:num_output_chars]
+        if delete_last_line:
+            clipped_text = clipped_text.rsplit('\n', 1)[0]
         if add_three_dots:
             clipped_text += "\n...(truncated)"
         return clipped_text
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index c50ab0f7..98d14414 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -20,6 +20,7 @@ cli_mode=false
 ai_disclaimer_title="" # Pro feature, title for a collapsible disclaimer to AI outputs
 ai_disclaimer="" # Pro feature, full text for the AI disclaimer output
 output_relevant_configurations=false
+large_patch_policy = "clip" # "clip", "skip"
 
 [pr_reviewer] # /review #
 # enable/disable features
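
Note on the change above: with large_patch_policy = "clip", an oversized patch is clipped to roughly 90% of the remaining token budget rather than dropped; the 0.9 factor leaves a safety margin because clip_tokens cuts by an estimated chars-per-token ratio, and the result is re-counted and re-checked against the hard budget before being accepted. The code-side default when the key is absent is 'skip' (the old behavior), while the shipped configuration.toml now sets "clip". The sketch below is a minimal, self-contained illustration of that behavior, not the code above verbatim: approx_clip_tokens is a hypothetical stand-in for clip_tokens, assuming a fixed 4-chars-per-token ratio instead of the real tokenizer-derived one.

    # Hypothetical stand-in for pr_agent.algo.utils.clip_tokens; the fixed
    # chars-per-token ratio is an assumption made purely for demonstration.
    def approx_clip_tokens(text: str, max_tokens: int, add_three_dots: bool = True,
                           delete_last_line: bool = False) -> str:
        chars_per_token = 4.0  # assumed; the real code derives this from the input
        num_output_chars = int(chars_per_token * max_tokens)
        if num_output_chars >= len(text):
            return text  # already within budget, nothing to clip
        clipped_text = text[:num_output_chars]
        if delete_last_line:
            # Drop the trailing (likely partial) line so a clipped diff never
            # ends mid-hunk-line; this is what delete_last_line=True buys.
            clipped_text = clipped_text.rsplit('\n', 1)[0]
        if add_three_dots:
            clipped_text += "\n...(truncated)"
        return clipped_text

    patch = "\n".join(f"+line {i}: some added code" for i in range(100))
    budget = 120  # hypothetical token budget left for this patch
    delta_tokens = int(0.9 * budget)  # same 10% margin as in the diff
    clipped = approx_clip_tokens(patch, delta_tokens, delete_last_line=True)
    # The last real line is complete, followed by the "...(truncated)" marker:
    print(clipped.splitlines()[-2:])

Keeping the truncated patch free of a dangling partial line matters here because downstream prompts treat each hunk line as a unit; a half-line could be misread by the model as a real code change.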