extend additional files

This commit is contained in:
mrT23
2024-06-30 18:38:06 +03:00
parent 3e6263e1cc
commit f2cb70ea67
3 changed files with 66 additions and 22 deletions

View File

@ -26,8 +26,12 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str, def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False, large_pr_handling=False) -> str: model: str,
add_line_numbers_to_hunks: bool = False,
disable_extra_lines: bool = False,
large_pr_handling=False,
return_remaining_files=False):
if disable_extra_lines: if disable_extra_lines:
PATCH_EXTRA_LINES = 0 PATCH_EXTRA_LINES = 0
else: else:
@ -72,7 +76,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, " get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
f"pruning diff.") f"pruning diff.")
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \ patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks) pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling)
if large_pr_handling and len(patches_compressed_list) > 1: if large_pr_handling and len(patches_compressed_list) > 1:
get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.") get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
@ -129,7 +133,10 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, " get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
f"deleted_list_str: {deleted_list_str}") f"deleted_list_str: {deleted_list_str}")
return final_diff if not return_remaining_files:
return final_diff
else:
return final_diff, remaining_files_list
def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str, def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
@ -208,7 +215,8 @@ def pr_generate_extended_diff(pr_languages: list,
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str, def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, dict, list]: convert_hunks_to_line_numbers: bool,
large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]:
deleted_files_list = [] deleted_files_list = []
# sort each one of the languages in top_langs by the number of tokens in the diff # sort each one of the languages in top_langs by the number of tokens in the diff
@ -253,18 +261,19 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
files_in_patches_list.append(files_in_patch_list) files_in_patches_list.append(files_in_patch_list)
# additional iterations (if needed) # additional iterations (if needed)
NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize if large_pr_handling:
for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1): NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize
if remaining_files_list: for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, if remaining_files_list:
file_dict, total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
max_tokens_model, file_dict,
remaining_files_list, token_handler) max_tokens_model,
patches_list.append(patches) remaining_files_list, token_handler)
total_tokens_list.append(total_tokens) patches_list.append(patches)
files_in_patches_list.append(files_in_patch_list) total_tokens_list.append(total_tokens)
else: files_in_patches_list.append(files_in_patch_list)
break else:
break
return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list

View File

@ -71,10 +71,6 @@ class PRAddDocs:
async def _prepare_prediction(self, model: str): async def _prepare_prediction(self, model: str):
get_logger().info('Getting PR diff...') get_logger().info('Getting PR diff...')
# Disable adding docs to scripts and other non-relevant text files
from pr_agent.algo.language_handler import bad_extensions
bad_extensions += get_settings().docs_blacklist_extensions.docs_blacklist
self.patches_diff = get_pr_diff(self.git_provider, self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler, self.token_handler,
model, model,

View File

@ -4,6 +4,7 @@ import re
from functools import partial from functools import partial
from typing import List, Tuple from typing import List, Tuple
import yaml
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
@ -169,12 +170,21 @@ class PRDescription:
return None return None
large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings() large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling) output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
if isinstance(output, tuple):
patches_diff, remaining_files_list = output
else:
patches_diff = output
remaining_files_list = []
if not large_pr_handling or patches_diff: if not large_pr_handling or patches_diff:
self.patches_diff = patches_diff self.patches_diff = patches_diff
if patches_diff: if patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff) get_logger().debug(f"PR diff", artifact=self.patches_diff)
self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt") self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
if (remaining_files_list and 'pr_files' in self.prediction and 'label:' in self.prediction and
get_settings().pr_description.mention_extra_files):
get_logger().debug(f"Extending additional files, {len(remaining_files_list)} files")
self.prediction = await self.extend_additional_files(remaining_files_list)
else: else:
get_logger().error(f"Error getting PR diff {self.pr_id}") get_logger().error(f"Error getting PR diff {self.pr_id}")
self.prediction = None self.prediction = None
@ -273,6 +283,35 @@ class PRDescription:
get_logger().debug(f"Using only headers for describe {self.pr_id}") get_logger().debug(f"Using only headers for describe {self.pr_id}")
self.prediction = prediction_headers self.prediction = prediction_headers
async def extend_additional_files(self, remaining_files_list) -> str:
prediction = self.prediction
try:
original_prediction_dict = load_yaml(self.prediction)
prediction_extra = "pr_files:"
for file in remaining_files_list:
extra_file_yaml = f"""\
- filename: |
{file}
changes_summary: |
...
changes_title: |
...
label: |
additional files (token-limit)
"""
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
prediction_extra_dict = load_yaml(prediction_extra)
# merge the two dictionaries
if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
new_yaml = yaml.dump(original_prediction_dict)
if load_yaml(new_yaml):
prediction = new_yaml
return prediction
except Exception as e:
get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
return self.prediction
async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str: async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = patches_diff # update diff variables["diff"] = patches_diff # update diff