Files
pr-agent/pr_agent/algo/pr_processing.py

420 lines
20 KiB
Python

from __future__ import annotations
import traceback
from typing import Callable, List, Tuple
from github import RateLimitExceededException
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens, clip_tokens, ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from pr_agent.log import get_logger
# Section headers that get_pr_diff prepends to the lists of file names it appends to a pruned diff.
DELETED_FILES_ = "Deleted files:\n"
MORE_MODIFIED_FILES_ = "Additional modified files (insufficient token budget to process):\n"
ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n"
# Token head-room subtracted from the model's maximum token count.
# Soft threshold: used when deciding whether the full diff (or one more patch) still fits.
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000
# Hard threshold: once the running total exceeds (max tokens - this value), remaining files are skipped.
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
                add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
    """
    Returns a string with the diff of the pull request, applying diff minimization techniques if needed.

    The full (context-extended) diff is returned when it fits in the model's token limit minus the soft
    output buffer. Otherwise a compressed diff is generated and, if tokens remain, lists of the added,
    modified and deleted file names that could not be included are appended to it.

    Args:
        git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull
            request.
        token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
            pull request.
        model (str): The name of the model used for tokenization.
        add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the
            diff. Defaults to False.
        disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with
            extra lines of context. Defaults to False.

    Returns:
        str: A string with the diff of the pull request, applying diff minimization techniques if needed.

    Raises:
        RateLimitExceededException: Re-raised (after logging) when fetching the diff files hits the rate limit.
    """
    if disable_extra_lines:
        patch_extra_lines = 0
    else:
        patch_extra_lines = get_settings().config.patch_extra_lines

    try:
        diff_files_original = git_provider.get_diff_files()
    except RateLimitExceededException as e:
        get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
        raise

    diff_files = filter_ignored(diff_files_original)
    if diff_files != diff_files_original:
        # Logging which files were filtered is best-effort; it must never break diff generation.
        try:
            get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
            new_names = {a.filename for a in diff_files}
            orig_names = {a.filename for a in diff_files_original}
            get_logger().info(f"Filtered out files: {orig_names - new_names}")
        except Exception:
            pass

    # get pr languages
    pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
    if pr_languages:
        # Best-effort log; the entry shape is not validated here.
        try:
            get_logger().info(f"PR main language: {pr_languages[0]['language']}")
        except Exception:
            pass

    # generate a standard diff string, with patch extension
    patches_extended, total_tokens, _ = pr_generate_extended_diff(
        pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=patch_extra_lines)

    max_tokens_model = get_max_tokens(model)

    # if we are under the limit, return the full diff
    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < max_tokens_model:
        get_logger().info(f"Tokens: {total_tokens}, total tokens under limit: {max_tokens_model}, "
                          f"returning full diff.")
        return "\n".join(patches_extended)

    # if we are over the limit, start pruning
    get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {max_tokens_model}, "
                      f"pruning diff.")
    patches_compressed, modified_file_names, deleted_file_names, added_file_names, total_tokens_new = \
        pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)

    # Insert additional information about added, modified, and deleted files if there is enough space
    max_tokens = max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
    curr_token = total_tokens_new  # == token_handler.count_tokens(final_diff)+token_handler.prompt_tokens
    final_diff = "\n".join(patches_compressed)
    delta_tokens = 10  # minimum remaining budget required before appending another list

    # Bind all three up front so the debug log below works even when a list is skipped.
    added_list_str = modified_list_str = deleted_list_str = ""

    if added_file_names and (max_tokens - curr_token) > delta_tokens:
        added_list_str = ADDED_FILES_ + "\n".join(added_file_names)
        added_list_str = clip_tokens(added_list_str, max_tokens - curr_token)
        if added_list_str:
            final_diff = final_diff + "\n\n" + added_list_str
            curr_token += token_handler.count_tokens(added_list_str) + 2  # +2 for the "\n\n" separator

    if modified_file_names and (max_tokens - curr_token) > delta_tokens:
        modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names)
        modified_list_str = clip_tokens(modified_list_str, max_tokens - curr_token)
        if modified_list_str:
            final_diff = final_diff + "\n\n" + modified_list_str
            curr_token += token_handler.count_tokens(modified_list_str) + 2

    if deleted_file_names and (max_tokens - curr_token) > delta_tokens:
        deleted_list_str = DELETED_FILES_ + "\n".join(deleted_file_names)
        deleted_list_str = clip_tokens(deleted_list_str, max_tokens - curr_token)
        if deleted_list_str:
            final_diff = final_diff + "\n\n" + deleted_list_str

    get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
                       f"deleted_list_str: {deleted_list_str}")
    return final_diff
def pr_generate_extended_diff(pr_languages: list,
                              token_handler: TokenHandler,
                              add_line_numbers_to_hunks: bool,
                              patch_extra_lines: int = 0) -> Tuple[list, int, list]:
    """
    Build the context-extended diff section of every file in the pull request and count its tokens.

    Files without a patch are skipped silently; files whose patch cannot be extended are skipped with a
    warning. For every file that is kept, its token count is also stored on the file object as `tokens`.

    Args:
        - pr_languages: A list of dictionaries, each holding the files of one language under the 'files' key.
        - token_handler: The TokenHandler used to count tokens; its `prompt_tokens` seeds the running total.
        - add_line_numbers_to_hunks: When True, the section is produced by converting the extended patch to
          hunks with line numbers instead of using the plain "## <filename>" layout.
        - patch_extra_lines: Number of extra context lines requested when extending each patch.

    Returns:
        A tuple (sections, total_tokens, section_tokens): the per-file diff sections, the prompt tokens plus
        the tokens of all sections, and the token count of each section in the same order.
    """
    running_total = token_handler.prompt_tokens  # start from the prompt's own cost
    sections = []
    section_tokens = []

    for language_group in pr_languages:
        for pr_file in language_group['files']:
            base_content = pr_file.base_file
            raw_patch = pr_file.patch
            if raw_patch:
                # widen the patch with surrounding context lines
                widened = extend_patch(base_content, raw_patch, num_lines=patch_extra_lines)
                if widened:
                    section = f"\n\n## {pr_file.filename}\n\n{widened}\n"
                    if add_line_numbers_to_hunks:
                        section = convert_to_hunks_with_lines_numbers(widened, pr_file)

                    cost = token_handler.count_tokens(section)
                    pr_file.tokens = cost
                    running_total += cost
                    section_tokens.append(cost)
                    sections.append(section)
                else:
                    get_logger().warning(f"Failed to extend patch for file: {pr_file.filename}")

    return sections, running_total, section_tokens
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
                                convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, int]:
    """
    Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of
    tokens used.

    Args:
        top_langs (list): A list of dictionaries representing the languages used in the pull request and their
            corresponding files.
        token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
            pull request.
        model (str): The model used for tokenization.
        convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.

    Returns:
        Tuple[list, list, list, list, int]: A tuple containing, in this order:
            - patches: A list of compressed diff patches for each file that fit in the token budget.
            - modified_files_list: Names of non-added files whose patch was skipped for being too large.
            - deleted_files_list: Names of files for which handle_patch_deletions returned None.
            - added_files_list: Names of added files whose patch was skipped for being too large.
            - total_tokens: The prompt tokens plus the tokens of all patches that were kept.

    Minimization techniques to reduce the number of tokens:
    0. Start from the largest diff patch to smaller ones
    1. Don't use extend context lines around diff
    2. Minimize deleted files
    3. Minimize deleted hunks
    4. Minimize all remaining files when you reach token limit
    """
    patches = []
    added_files_list = []
    modified_files_list = []
    deleted_files_list = []

    # sort each one of the languages in top_langs by the number of tokens in the diff
    # (relies on `file.tokens`, which pr_generate_extended_diff stores on each file it processes)
    sorted_files = []
    for lang in top_langs:
        sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))

    total_tokens = token_handler.prompt_tokens
    for file in sorted_files:
        original_file_content_str = file.base_file
        new_file_content_str = file.head_file
        patch = file.patch
        if not patch:
            continue

        # removing delete-only hunks
        patch = handle_patch_deletions(patch, original_file_content_str,
                                       new_file_content_str, file.filename, file.edit_type)
        if patch is None:
            # Only the file name is reported. The token cost of the name lists is not added to
            # total_tokens here; the caller clips those lists to whatever budget remains.
            if file.filename not in deleted_files_list:
                deleted_files_list.append(file.filename)
            continue

        if convert_hunks_to_line_numbers:
            patch = convert_to_hunks_with_lines_numbers(patch, file)

        new_patch_tokens = token_handler.count_tokens(patch)

        # Hard Stop, no more tokens
        if total_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
            get_logger().warning(f"File was fully skipped, no more tokens: {file.filename}.")
            continue

        # If the patch is too large, just show the file name
        if total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
            # Current logic is to skip the patch if it's too large
            # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
            #  until we meet the requirements
            if get_settings().config.verbosity_level >= 2:
                get_logger().warning(f"Patch too large, minimizing it, {file.filename}")
            if file.edit_type == EDIT_TYPE.ADDED:
                if file.filename not in added_files_list:
                    added_files_list.append(file.filename)
            else:
                if file.filename not in modified_files_list:
                    modified_files_list.append(file.filename)
            continue

        if patch:
            if not convert_hunks_to_line_numbers:
                # The closing quote belongs right after the file name, not after the patch body.
                patch_final = f"\n\n## file: '{file.filename.strip()}'\n\n{patch.strip()}\n"
            else:
                patch_final = "\n\n" + patch.strip()
            patches.append(patch_final)
            total_tokens += token_handler.count_tokens(patch_final)
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")

    return patches, modified_files_list, deleted_files_list, added_files_list, total_tokens
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
    """
    Await `f(model)` for the configured model and, on failure, for each fallback model in turn.

    Before every attempt the matching deployment id is written to the "openai.deployment_id" setting.
    A failed attempt is logged as a warning (with its traceback) and the next pair is tried.

    Args:
        f (Callable): An awaitable-returning callable that takes the model name as its only argument.
        model_type (ModelType, optional): Selects which primary model is used. Defaults to ModelType.REGULAR.

    Returns:
        Whatever `f` returns for the first model that succeeds.

    Raises:
        Exception: The exception of the last attempt, when every (model, deployment) pair failed.
        BaseException: Non-`Exception` errors such as task cancellation or KeyboardInterrupt propagate
            immediately and never trigger a fallback attempt.
    """
    all_models = _get_all_models(model_type)
    all_deployments = _get_all_deployments(all_models)
    # try each (model, deployment_id) pair until one is successful, otherwise raise exception
    for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
        try:
            get_logger().debug(
                f"Generating prediction with {model}"
                f"{(' from deployment ' + deployment_id) if deployment_id else ''}"
            )
            get_settings().set("openai.deployment_id", deployment_id)
            return await f(model)
        except Exception:
            # `Exception` rather than a bare except: a cancelled task must stop, not move on to the next model.
            get_logger().warning(
                f"Failed to generate prediction with {model}"
                f"{(' from deployment ' + deployment_id) if deployment_id else ''}: "
                f"{traceback.format_exc()}"
            )
            if i == len(all_models) - 1:  # If it's the last iteration
                raise  # Re-raise the last exception
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
    """
    Return the primary model followed by the configured fallback models.

    Args:
        model_type (ModelType, optional): ModelType.TURBO selects `config.model_turbo` as the primary model;
            any other value selects `config.model`. Defaults to ModelType.REGULAR.

    Returns:
        List[str]: The primary model first, then every fallback model. `config.fallback_models` may be a list
        or a comma-separated string; an empty or missing value yields no fallbacks, and blank entries of a
        comma-separated string are dropped.
    """
    config = get_settings().config
    if model_type == ModelType.TURBO:
        model = config.model_turbo
    else:
        model = config.model

    fallback_models = config.fallback_models
    if not fallback_models:
        # None, "" or [] all mean "no fallbacks"; avoids producing an empty model name from "".split(",")
        fallback_models = []
    elif not isinstance(fallback_models, list):
        fallback_models = [m.strip() for m in fallback_models.split(",") if m.strip()]

    return [model] + list(fallback_models)
def _get_all_deployments(all_models: List[str]) -> List[str]:
    """
    Pair every model with a deployment id taken from the settings.

    Args:
        all_models (List[str]): The models that will be tried, primary model first.

    Returns:
        List[str]: The "openai.deployment_id" setting followed by the "openai.fallback_deployments" entries
        (a list or a comma-separated string). When no fallback deployments are configured, the primary
        deployment id is repeated once per model.

    Raises:
        ValueError: If fallback deployments are configured but there are fewer deployments than models.
    """
    primary_deployment = get_settings().get("openai.deployment_id", None)
    fallbacks = get_settings().get("openai.fallback_deployments", [])

    # Nothing configured: every model uses the primary deployment.
    if not fallbacks:
        return [primary_deployment] * len(all_models)

    # A non-list value is treated as a comma-separated string.
    if not isinstance(fallbacks, list):
        fallbacks = [entry.strip() for entry in fallbacks.split(",")]

    all_deployments = [primary_deployment] + fallbacks
    if len(all_deployments) < len(all_models):
        raise ValueError(f"The number of deployments ({len(all_deployments)}) "
                         f"is less than the number of models ({len(all_models)})")
    return all_deployments
def get_pr_multi_diffs(git_provider: GitProvider,
                       token_handler: TokenHandler,
                       model: str,
                       max_calls: int = 5) -> List[str]:
    """
    Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
    The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.

    If the full extended diff (with line numbers) fits in the token budget, it is returned as a single group.
    Otherwise patches are packed greedily into groups; a patch that does not fit on its own is skipped or
    clipped according to the `large_patch_policy` setting.

    Args:
        git_provider (GitProvider): An object that provides access to Git provider APIs.
        token_handler (TokenHandler): An object that handles tokens in the context of a pull request.
        model (str): The name of the model.
        max_calls (int, optional): The maximum number of diff groups to return. Defaults to 5.

    Returns:
        List[str]: A list of at most `max_calls` diff strings (a single one when the whole diff fits), split
        based on the maximum number of tokens allowed for the given model.

    Raises:
        RateLimitExceededException: If the rate limit for the Git provider API is exceeded.
    """
    try:
        diff_files = git_provider.get_diff_files()
    except RateLimitExceededException as e:
        get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
        raise

    diff_files = filter_ignored(diff_files)

    # Sort files by main language
    pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)

    # Sort files within each language group by tokens in descending order
    # NOTE(review): `tokens` is only assigned by pr_generate_extended_diff, which runs after this sort,
    # so the ordering presumably relies on whatever default the file objects carry - confirm intent.
    sorted_files = []
    for lang in pr_languages:
        sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))

    # try first a single run with standard diff string, with patch extension, and no deletions
    patches_extended, total_tokens, _ = pr_generate_extended_diff(
        pr_languages, token_handler, add_line_numbers_to_hunks=True)
    max_model_tokens = get_max_tokens(model)
    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < max_model_tokens:
        return ["\n".join(patches_extended)]

    # Largest token count (prompt included) that a single group may reach.
    token_budget = max_model_tokens - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD

    patches = []
    final_diff_list = []
    total_tokens = token_handler.prompt_tokens
    call_number = 1
    for file in sorted_files:
        if call_number > max_calls:
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Reached max calls ({max_calls})")
            break

        original_file_content_str = file.base_file
        new_file_content_str = file.head_file
        patch = file.patch
        if not patch:
            continue

        # Remove delete-only hunks
        patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str,
                                       file.filename, file.edit_type)
        if patch is None:
            continue

        patch = convert_to_hunks_with_lines_numbers(patch, file)
        new_patch_tokens = token_handler.count_tokens(patch)

        # The patch does not fit even in an otherwise empty group
        if patch and (token_handler.prompt_tokens + new_patch_tokens) > token_budget:
            large_patch_policy = get_settings().config.get('large_patch_policy', 'skip')
            if large_patch_policy == 'clip':
                delta_tokens = token_budget - token_handler.prompt_tokens
                patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True,
                                            num_input_tokens=new_patch_tokens)
                new_patch_tokens = token_handler.count_tokens(patch_clipped)
                if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > token_budget:
                    get_logger().warning(f"Patch too large, skipping: {file.filename}")
                    continue
                get_logger().info(f"Clipped large patch for file: {file.filename}")
                patch = patch_clipped
            else:
                # 'skip' and any unrecognised policy both drop the file
                get_logger().warning(f"Patch too large, skipping: {file.filename}")
                continue

        # The patch does not fit in the current group: close the group and start a new one
        if patch and (total_tokens + new_patch_tokens > token_budget):
            final_diff = "\n".join(patches)
            final_diff_list.append(final_diff)
            patches = []
            total_tokens = token_handler.prompt_tokens
            call_number += 1
            if call_number > max_calls:
                # Stop before queuing this patch; otherwise the trailing "last chunk" step would
                # emit one group more than max_calls.
                if get_settings().config.verbosity_level >= 2:
                    get_logger().info(f"Reached max calls ({max_calls})")
                break
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Call number: {call_number}")

        if patch:
            patches.append(patch)
            total_tokens += new_patch_tokens
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")

    # Add the last chunk
    if patches:
        final_diff = "\n".join(patches)
        final_diff_list.append(final_diff)

    return final_diff_list