mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-03 04:10:49 +08:00
s
This commit is contained in:
@ -24,26 +24,10 @@ ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n
|
||||
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000
|
||||
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
|
||||
|
||||
|
||||
|
||||
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
|
||||
"""
|
||||
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
|
||||
|
||||
Args:
|
||||
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull
|
||||
request.
|
||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
|
||||
pull request.
|
||||
model (str): The name of the model used for tokenization.
|
||||
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the
|
||||
diff. Defaults to False.
|
||||
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with
|
||||
extra lines of context. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: A string with the diff of the pull request, applying diff minimization techniques if needed.
|
||||
"""
|
||||
|
||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False, large_pr_handling=False) -> str:
|
||||
if disable_extra_lines:
|
||||
PATCH_EXTRA_LINES = 0
|
||||
else:
|
||||
@ -87,39 +71,99 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
|
||||
# if we are over the limit, start pruning
|
||||
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
|
||||
f"pruning diff.")
|
||||
patches_compressed, modified_file_names, deleted_file_names, added_file_names, total_tokens_new = \
|
||||
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
|
||||
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
|
||||
|
||||
if large_pr_handling and len(patches_compressed_list) > 1:
|
||||
get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
|
||||
return "" # return empty string, as we generate multiple patches with a different prompt
|
||||
|
||||
# return the first patch
|
||||
patches_compressed = patches_compressed_list[0]
|
||||
total_tokens_new = total_tokens_list[0]
|
||||
files_in_patch = files_in_patches_list[0]
|
||||
|
||||
# Insert additional information about added, modified, and deleted files if there is enough space
|
||||
max_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
|
||||
curr_token = total_tokens_new # == token_handler.count_tokens(final_diff)+token_handler.prompt_tokens
|
||||
final_diff = "\n".join(patches_compressed)
|
||||
delta_tokens = 10
|
||||
if added_file_names and (max_tokens - curr_token) > delta_tokens:
|
||||
added_list_str = ADDED_FILES_ + "\n".join(added_file_names)
|
||||
added_list_str = clip_tokens(added_list_str, max_tokens - curr_token)
|
||||
if added_list_str:
|
||||
final_diff = final_diff + "\n\n" + added_list_str
|
||||
curr_token += token_handler.count_tokens(added_list_str) + 2
|
||||
if modified_file_names and (max_tokens - curr_token) > delta_tokens:
|
||||
modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names)
|
||||
modified_list_str = clip_tokens(modified_list_str, max_tokens - curr_token)
|
||||
if modified_list_str:
|
||||
final_diff = final_diff + "\n\n" + modified_list_str
|
||||
curr_token += token_handler.count_tokens(modified_list_str) + 2
|
||||
if deleted_file_names and (max_tokens - curr_token) > delta_tokens:
|
||||
deleted_list_str = DELETED_FILES_ + "\n".join(deleted_file_names)
|
||||
deleted_list_str = clip_tokens(deleted_list_str, max_tokens - curr_token)
|
||||
if deleted_list_str:
|
||||
final_diff = final_diff + "\n\n" + deleted_list_str
|
||||
try:
|
||||
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
|
||||
f"deleted_list_str: {deleted_list_str}")
|
||||
except Exception as e:
|
||||
pass
|
||||
added_list_str = modified_list_str = deleted_list_str = ""
|
||||
unprocessed_files = []
|
||||
# generate the added, modified, and deleted files lists
|
||||
if (max_tokens - curr_token) > delta_tokens:
|
||||
for filename, file_values in file_dict.items():
|
||||
if filename in files_in_patch:
|
||||
continue
|
||||
if file_values['edit_type'] == EDIT_TYPE.ADDED:
|
||||
unprocessed_files.append(filename)
|
||||
if not added_list_str:
|
||||
added_list_str = ADDED_FILES_ + f"\n{filename}"
|
||||
else:
|
||||
added_list_str = added_list_str + f"\n{filename}"
|
||||
elif file_values['edit_type'] == EDIT_TYPE.MODIFIED or EDIT_TYPE.RENAMED:
|
||||
unprocessed_files.append(filename)
|
||||
if not modified_list_str:
|
||||
modified_list_str = MORE_MODIFIED_FILES_ + f"\n{filename}"
|
||||
else:
|
||||
modified_list_str = modified_list_str + f"\n{filename}"
|
||||
elif file_values['edit_type'] == EDIT_TYPE.DELETED:
|
||||
# unprocessed_files.append(filename) # not needed here, because the file was deleted, so no need to process it
|
||||
if not deleted_list_str:
|
||||
deleted_list_str = DELETED_FILES_ + f"\n{filename}"
|
||||
else:
|
||||
deleted_list_str = deleted_list_str + f"\n{filename}"
|
||||
|
||||
# prune the added, modified, and deleted files lists, and add them to the final diff
|
||||
added_list_str = clip_tokens(added_list_str, max_tokens - curr_token)
|
||||
if added_list_str:
|
||||
final_diff = final_diff + "\n\n" + added_list_str
|
||||
curr_token += token_handler.count_tokens(added_list_str) + 2
|
||||
modified_list_str = clip_tokens(modified_list_str, max_tokens - curr_token)
|
||||
if modified_list_str:
|
||||
final_diff = final_diff + "\n\n" + modified_list_str
|
||||
curr_token += token_handler.count_tokens(modified_list_str) + 2
|
||||
deleted_list_str = clip_tokens(deleted_list_str, max_tokens - curr_token)
|
||||
if deleted_list_str:
|
||||
final_diff = final_diff + "\n\n" + deleted_list_str
|
||||
|
||||
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
|
||||
f"deleted_list_str: {deleted_list_str}")
|
||||
return final_diff
|
||||
|
||||
|
||||
def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False):
|
||||
try:
|
||||
diff_files_original = git_provider.get_diff_files()
|
||||
except RateLimitExceededException as e:
|
||||
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
|
||||
raise
|
||||
|
||||
diff_files = filter_ignored(diff_files_original)
|
||||
if diff_files != diff_files_original:
|
||||
try:
|
||||
get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
|
||||
new_names = set([a.filename for a in diff_files])
|
||||
orig_names = set([a.filename for a in diff_files_original])
|
||||
get_logger().info(f"Filtered out files: {orig_names - new_names}")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# get pr languages
|
||||
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
|
||||
if pr_languages:
|
||||
try:
|
||||
get_logger().info(f"PR main language: {pr_languages[0]['language']}")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
|
||||
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
|
||||
|
||||
return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
|
||||
|
||||
|
||||
def pr_generate_extended_diff(pr_languages: list,
|
||||
token_handler: TokenHandler,
|
||||
add_line_numbers_to_hunks: bool,
|
||||
@ -164,41 +208,16 @@ def pr_generate_extended_diff(pr_languages: list,
|
||||
|
||||
|
||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
||||
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, int]:
|
||||
"""
|
||||
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of
|
||||
tokens used.
|
||||
Args:
|
||||
top_langs (list): A list of dictionaries representing the languages used in the pull request and their
|
||||
corresponding files.
|
||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
|
||||
pull request.
|
||||
model (str): The model used for tokenization.
|
||||
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
|
||||
Returns:
|
||||
Tuple[list, list, list]: A tuple containing the following lists:
|
||||
- patches: A list of compressed diff patches for each file in the pull request.
|
||||
- modified_files_list: A list of file names that were skipped due to large patch size.
|
||||
- deleted_files_list: A list of file names that were deleted in the pull request.
|
||||
|
||||
Minimization techniques to reduce the number of tokens:
|
||||
0. Start from the largest diff patch to smaller ones
|
||||
1. Don't use extend context lines around diff
|
||||
2. Minimize deleted files
|
||||
3. Minimize deleted hunks
|
||||
4. Minimize all remaining files when you reach token limit
|
||||
"""
|
||||
|
||||
patches = []
|
||||
added_files_list = []
|
||||
modified_files_list = []
|
||||
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, dict, list]:
|
||||
deleted_files_list = []
|
||||
|
||||
# sort each one of the languages in top_langs by the number of tokens in the diff
|
||||
sorted_files = []
|
||||
for lang in top_langs:
|
||||
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
|
||||
|
||||
total_tokens = token_handler.prompt_tokens
|
||||
# generate patches for each file, and count tokens
|
||||
file_dict = {}
|
||||
for file in sorted_files:
|
||||
original_file_content_str = file.base_file
|
||||
new_file_content_str = file.head_file
|
||||
@ -210,55 +229,85 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
|
||||
patch = handle_patch_deletions(patch, original_file_content_str,
|
||||
new_file_content_str, file.filename, file.edit_type)
|
||||
if patch is None:
|
||||
# if not deleted_files_list:
|
||||
# total_tokens += token_handler.count_tokens(DELETED_FILES_)
|
||||
if file.filename not in deleted_files_list:
|
||||
deleted_files_list.append(file.filename)
|
||||
# total_tokens += token_handler.count_tokens(file.filename) + 1
|
||||
continue
|
||||
|
||||
if convert_hunks_to_line_numbers:
|
||||
patch = convert_to_hunks_with_lines_numbers(patch, file)
|
||||
|
||||
new_patch_tokens = token_handler.count_tokens(patch)
|
||||
file_dict[file.filename] = {'patch': patch, 'tokens': new_patch_tokens, 'edit_type': file.edit_type}
|
||||
|
||||
max_tokens_model = get_max_tokens(model)
|
||||
|
||||
# first iteration
|
||||
files_in_patches_list = []
|
||||
remaining_files_list = [file.filename for file in sorted_files]
|
||||
patches_list =[]
|
||||
total_tokens_list = []
|
||||
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, file_dict,
|
||||
max_tokens_model, remaining_files_list, token_handler)
|
||||
patches_list.append(patches)
|
||||
total_tokens_list.append(total_tokens)
|
||||
files_in_patches_list.append(files_in_patch_list)
|
||||
|
||||
# additional iterations (if needed)
|
||||
NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize
|
||||
for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
|
||||
if remaining_files_list:
|
||||
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
|
||||
file_dict,
|
||||
max_tokens_model,
|
||||
remaining_files_list, token_handler)
|
||||
patches_list.append(patches)
|
||||
total_tokens_list.append(total_tokens)
|
||||
files_in_patches_list.append(files_in_patch_list)
|
||||
else:
|
||||
break
|
||||
|
||||
return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
|
||||
|
||||
|
||||
def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_model,remaining_files_list_prev, token_handler):
|
||||
total_tokens = token_handler.prompt_tokens # initial tokens
|
||||
patches = []
|
||||
remaining_files_list_new = []
|
||||
files_in_patch_list = []
|
||||
for filename, data in file_dict.items():
|
||||
if filename not in remaining_files_list_prev:
|
||||
continue
|
||||
|
||||
patch = data['patch']
|
||||
new_patch_tokens = data['tokens']
|
||||
edit_type = data['edit_type']
|
||||
|
||||
# Hard Stop, no more tokens
|
||||
if total_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
||||
get_logger().warning(f"File was fully skipped, no more tokens: {file.filename}.")
|
||||
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
||||
get_logger().warning(f"File was fully skipped, no more tokens: {filename}.")
|
||||
continue
|
||||
|
||||
# If the patch is too large, just show the file name
|
||||
if total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
if total_tokens + new_patch_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
# Current logic is to skip the patch if it's too large
|
||||
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
|
||||
# until we meet the requirements
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().warning(f"Patch too large, minimizing it, {file.filename}")
|
||||
if file.edit_type == EDIT_TYPE.ADDED:
|
||||
# if not added_files_list:
|
||||
# total_tokens += token_handler.count_tokens(ADDED_FILES_)
|
||||
if file.filename not in added_files_list:
|
||||
added_files_list.append(file.filename)
|
||||
# total_tokens += token_handler.count_tokens(file.filename) + 1
|
||||
else:
|
||||
# if not modified_files_list:
|
||||
# total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
|
||||
if file.filename not in modified_files_list:
|
||||
modified_files_list.append(file.filename)
|
||||
# total_tokens += token_handler.count_tokens(file.filename) + 1
|
||||
get_logger().warning(f"Patch too large, skipping it, {filename}")
|
||||
remaining_files_list_new.append(filename)
|
||||
continue
|
||||
|
||||
if patch:
|
||||
if not convert_hunks_to_line_numbers:
|
||||
patch_final = f"\n\n## file: '{file.filename.strip()}\n\n{patch.strip()}\n'"
|
||||
patch_final = f"\n\n## file: '{filename.strip()}\n\n{patch.strip()}\n'"
|
||||
else:
|
||||
patch_final = "\n\n" + patch.strip()
|
||||
patches.append(patch_final)
|
||||
total_tokens += token_handler.count_tokens(patch_final)
|
||||
files_in_patch_list.append(filename)
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")
|
||||
|
||||
return patches, modified_files_list, deleted_files_list, added_files_list, total_tokens
|
||||
get_logger().info(f"Tokens: {total_tokens}, last filename: {filename}")
|
||||
return total_tokens, patches, remaining_files_list_new, files_in_patch_list
|
||||
|
||||
|
||||
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
|
||||
@ -417,4 +466,47 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
||||
final_diff = "\n".join(patches)
|
||||
final_diff_list.append(final_diff)
|
||||
|
||||
return final_diff_list
|
||||
return final_diff_list
|
||||
|
||||
|
||||
def prune_context(token_handler, curr_component_str, component_context_str, minium_output_tokens, max_tokens=None) -> Tuple[str, str]:
|
||||
try:
|
||||
# Get the max tokens possible
|
||||
if not max_tokens:
|
||||
get_logger().error(f"Max tokens not provided, using default value")
|
||||
max_tokens = get_max_tokens(get_settings().config.model_turbo)
|
||||
|
||||
# Check if the component + context are too long
|
||||
component_tokens = token_handler.count_tokens(curr_component_str)
|
||||
context_tokens = token_handler.count_tokens(component_context_str)
|
||||
total_tokens = component_tokens + context_tokens + token_handler.prompt_tokens
|
||||
get_logger().info(
|
||||
f"Total tokens: {total_tokens}, context_tokens: {context_tokens}, component_tokens: {component_tokens}, prompt_tokens: {token_handler.prompt_tokens}, max_tokens: {max_tokens}")
|
||||
|
||||
# clip the context to fit the max tokens
|
||||
if total_tokens > max_tokens - minium_output_tokens:
|
||||
# clip the context to fit the max tokens
|
||||
max_context_tokens = max_tokens - (minium_output_tokens) - component_tokens - token_handler.prompt_tokens
|
||||
component_context_str = clip_tokens(component_context_str,
|
||||
max_context_tokens, num_input_tokens=context_tokens)
|
||||
context_tokens_old = context_tokens
|
||||
context_tokens = token_handler.count_tokens(component_context_str)
|
||||
total_tokens = component_tokens + context_tokens + token_handler.prompt_tokens
|
||||
get_logger().info(f"Clipped context from {context_tokens_old} to {context_tokens} tokens, total tokens: {total_tokens}")
|
||||
|
||||
# clip the class itself to fit the max tokens, if needed
|
||||
delta = 50 # extra tokens to prevent clipping the component if not necessary
|
||||
if total_tokens > (max_tokens - minium_output_tokens-delta):
|
||||
max_context_tokens = max_tokens - minium_output_tokens - context_tokens - token_handler.prompt_tokens # notice 'context_tokens'
|
||||
curr_component_str= clip_tokens(curr_component_str,
|
||||
max_context_tokens, num_input_tokens=component_tokens)
|
||||
component_tokens_new = token_handler.count_tokens(curr_component_str)
|
||||
total_tokens = component_tokens_new + context_tokens + token_handler.prompt_tokens
|
||||
get_logger().info(f"Clipped component to fit the max tokens, from {component_tokens} to {component_tokens_new} tokens, total tokens: {total_tokens}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
component_context_str = ''
|
||||
curr_component_str = ''
|
||||
|
||||
return curr_component_str, component_context_str
|
Reference in New Issue
Block a user