Merge remote-tracking branch 'origin/main' into tr/review_redesign

Author: mrT23
Date: 2024-07-03 07:54:26 +03:00
4 changed files with 93 additions and 34 deletions

View File

@@ -26,8 +26,12 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600

-def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
-                add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False, large_pr_handling=False) -> str:
+def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
+                model: str,
+                add_line_numbers_to_hunks: bool = False,
+                disable_extra_lines: bool = False,
+                large_pr_handling=False,
+                return_remaining_files=False):
     if disable_extra_lines:
         PATCH_EXTRA_LINES = 0
     else:
@@ -72,7 +76,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
         get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
                           f"pruning diff.")
         patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
-            pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
+            pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling)

     if large_pr_handling and len(patches_compressed_list) > 1:
         get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
@@ -129,7 +133,10 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
     get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
                        f"deleted_list_str: {deleted_list_str}")
-    return final_diff
+    if not return_remaining_files:
+        return final_diff
+    else:
+        return final_diff, remaining_files_list


 def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
@@ -159,7 +166,7 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
         pass

     patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
-        pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
+        pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling=True)

     return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
@@ -208,7 +215,8 @@ def pr_generate_extended_diff(pr_languages: list,
 def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
-                                convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, dict, list]:
+                                convert_hunks_to_line_numbers: bool,
+                                large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]:
     deleted_files_list = []

     # sort each one of the languages in top_langs by the number of tokens in the diff
@@ -253,18 +261,20 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
     files_in_patches_list.append(files_in_patch_list)

     # additional iterations (if needed)
-    NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1  # one more call is to summarize
-    for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
-        if remaining_files_list:
-            total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
-                                                                                                   file_dict,
-                                                                                                   max_tokens_model,
-                                                                                                   remaining_files_list, token_handler)
-            patches_list.append(patches)
-            total_tokens_list.append(total_tokens)
-            files_in_patches_list.append(files_in_patch_list)
-        else:
-            break
+    if large_pr_handling:
+        NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1  # one more call is to summarize
+        for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
+            if remaining_files_list:
+                total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
+                                                                                                       file_dict,
+                                                                                                       max_tokens_model,
+                                                                                                       remaining_files_list, token_handler)
+                if patches:
+                    patches_list.append(patches)
+                    total_tokens_list.append(total_tokens)
+                    files_in_patches_list.append(files_in_patch_list)
+            else:
+                break

     return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
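
Note on the new return_remaining_files flag: get_pr_diff now returns either a plain diff string or a (diff, remaining_files) tuple, so callers must branch on the shape. A minimal sketch of that calling convention follows; the stub function and sample values are hypothetical, only the return shape comes from this diff:

# Hypothetical stub standing in for get_pr_diff; only the dual return
# shape (str vs. tuple) is taken from the change above.
from typing import List, Tuple, Union

def get_pr_diff_stub(return_remaining_files: bool = False) -> Union[str, Tuple[str, List[str]]]:
    final_diff = "diff --git a/app.py b/app.py ..."      # pruned diff text
    remaining_files_list = ["docs/extra_notes.md"]        # files dropped for token budget
    if not return_remaining_files:
        return final_diff
    else:
        return final_diff, remaining_files_list

# Callers that ask for the extra data unpack defensively, as
# PRDescription._prepare_prediction does in the last file of this commit:
output = get_pr_diff_stub(return_remaining_files=True)
if isinstance(output, tuple):
    patches_diff, remaining_files_list = output
else:
    patches_diff, remaining_files_list = output, []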

View File

@@ -76,7 +76,8 @@ use_description_markers=false
 include_generated_by_header=true
 # large pr mode 💎
 enable_large_pr_handling=true
-max_ai_calls=3
+max_ai_calls=4
+async_ai_calls=true
 mention_extra_files=true
 #custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other']
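
Read together with the pr_generate_compressed_diff change above, the bump to max_ai_calls=4 appears to budget one call for summarizing and the rest for describing patches. A back-of-the-envelope sketch; the breakdown is an inference from the loop bounds, not stated in the commit:

# Hedged sketch: how max_ai_calls=4 seems to be spent in large-PR mode,
# following NUMBER_OF_ALLOWED_ITERATIONS in pr_generate_compressed_diff above.
max_ai_calls = 4                                  # new default in configuration.toml
NUMBER_OF_ALLOWED_ITERATIONS = max_ai_calls - 1   # one call reserved to summarize
extra_rounds = NUMBER_OF_ALLOWED_ITERATIONS - 1   # the loop runs range(N - 1)
total_calls = 1 + extra_rounds + 1                # initial patch pass + extra rounds + summary
assert total_calls == max_ai_calls                # 1 + 2 + 1 == 4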

View File

@@ -71,10 +71,6 @@ class PRAddDocs:
     async def _prepare_prediction(self, model: str):
         get_logger().info('Getting PR diff...')

-        # Disable adding docs to scripts and other non-relevant text files
-        from pr_agent.algo.language_handler import bad_extensions
-        bad_extensions += get_settings().docs_blacklist_extensions.docs_blacklist
-
         self.patches_diff = get_pr_diff(self.git_provider,
                                         self.token_handler,
                                         model,
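
One plausible reason for dropping the removed block: bad_extensions += ... extends the imported module-level list in place, so the docs blacklist would leak into every later caller of language_handler and keep growing across runs. A small standalone illustration of that hazard, using hypothetical names rather than repo code:

# Standalone illustration (hypothetical names) of why in-place "+=" on an
# imported module-level list is risky: every caller shares the same object.
bad_extensions = ['ipynb', 'lock']   # stands in for language_handler.bad_extensions

def prepare_prediction(blacklist):
    exts = bad_extensions            # same object, not a copy
    exts += blacklist                # += mutates the shared list in place

prepare_prediction(['md', 'rst'])
prepare_prediction(['md', 'rst'])    # a second run appends duplicates
print(bad_extensions)                # ['ipynb', 'lock', 'md', 'rst', 'md', 'rst']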

View File

@@ -1,8 +1,10 @@
+import asyncio
 import copy
 import re
 from functools import partial
 from typing import List, Tuple

+import yaml
 from jinja2 import Environment, StrictUndefined

 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
@@ -168,13 +170,21 @@ class PRDescription:
             return None

         large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
-        patches_diff = get_pr_diff(self.git_provider, self.token_handler, model,
-                                   large_pr_handling=large_pr_handling)
+        output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
+        if isinstance(output, tuple):
+            patches_diff, remaining_files_list = output
+        else:
+            patches_diff = output
+            remaining_files_list = []
         if not large_pr_handling or patches_diff:
             self.patches_diff = patches_diff
             if patches_diff:
                 get_logger().debug(f"PR diff", artifact=self.patches_diff)
                 self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
+                if (remaining_files_list and 'pr_files' in self.prediction and 'label:' in self.prediction and
+                        get_settings().pr_description.mention_extra_files):
+                    get_logger().debug(f"Extending additional files, {len(remaining_files_list)} files")
+                    self.prediction = await self.extend_additional_files(remaining_files_list)
             else:
                 get_logger().error(f"Error getting PR diff {self.pr_id}")
                 self.prediction = None
@@ -192,13 +202,27 @@ class PRDescription:
             self.git_provider, token_handler_only_files_prompt, model)

         # get the files prediction for each patch
+        if not get_settings().pr_description.async_ai_calls:
+            results = []
+            for i, patches in enumerate(patches_compressed_list):  # sync calls
+                patches_diff = "\n".join(patches)
+                get_logger().debug(f"PR diff number {i + 1} for describe files")
+                prediction_files = await self._get_prediction(model, patches_diff,
+                                                              prompt="pr_description_only_files_prompts")
+                results.append(prediction_files)
+        else:  # async calls
+            tasks = []
+            for i, patches in enumerate(patches_compressed_list):
+                patches_diff = "\n".join(patches)
+                get_logger().debug(f"PR diff number {i + 1} for describe files")
+                task = asyncio.create_task(
+                    self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts"))
+                tasks.append(task)
+            # Wait for all tasks to complete
+            results = await asyncio.gather(*tasks)
         file_description_str_list = []
-        for i, patches in enumerate(patches_compressed_list):
-            patches_diff = "\n".join(patches)
-            get_logger().debug(f"PR diff number {i + 1} for describe files")
-            prediction_files = await self._get_prediction(model, patches_diff,
-                                                          prompt="pr_description_only_files_prompts")
-            prediction_files = prediction_files.strip().removeprefix('```yaml').strip('`').strip()
+        for i, result in enumerate(results):
+            prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
             if load_yaml(prediction_files) and prediction_files.startswith('pr_files'):
                 prediction_files = prediction_files.removeprefix('pr_files:').strip()
                 file_description_str_list.append(prediction_files)
@@ -248,7 +272,7 @@ class PRDescription:
   changes_title: |
     ...
   label: |
-    not processed (token-limit)
+    additional files (token-limit)
 """
         files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip()
         # final processing
@@ -259,6 +283,34 @@ class PRDescription:
             get_logger().debug(f"Using only headers for describe {self.pr_id}")
             self.prediction = prediction_headers

+    async def extend_additional_files(self, remaining_files_list) -> str:
+        prediction = self.prediction
+        try:
+            original_prediction_dict = load_yaml(self.prediction)
+            prediction_extra = "pr_files:"
+            for file in remaining_files_list:
+                extra_file_yaml = f"""\
+- filename: |
+    {file}
+  changes_summary: |
+    ...
+  changes_title: |
+    ...
+  label: |
+    additional files (token-limit)
+"""
+                prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
+            prediction_extra_dict = load_yaml(prediction_extra)
+            # merge the two dictionaries
+            if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
+                original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
+                new_yaml = yaml.dump(original_prediction_dict)
+                if load_yaml(new_yaml):
+                    prediction = new_yaml
+            return prediction
+        except Exception as e:
+            get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
+            return self.prediction
+
     async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
         variables = copy.deepcopy(self.vars)
@@ -481,9 +533,9 @@ class PRDescription:
                 filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
             diff_plus_minus = ""
             delta_nbsp = ""
-            diff_files = self.git_provider.diff_files
+            diff_files = self.git_provider.get_diff_files()
             for f in diff_files:
-                if f.filename.lower() == filename.lower():
+                if f.filename.lower().strip('/') == filename.lower().strip('/'):
                     num_plus_lines = f.num_plus_lines
                     num_minus_lines = f.num_minus_lines
                     diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}"
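
The new extend_additional_files relies on a YAML round-trip to merge the model's walkthrough with placeholder entries for the files that were pruned. A self-contained sketch of that merge pattern follows; yaml.safe_load stands in for the repo's more forgiving load_yaml helper, and the sample filenames are invented:

# Self-contained sketch of the merge performed by extend_additional_files above.
# yaml.safe_load is a stand-in for pr_agent's load_yaml (which adds error recovery).
import yaml

original_prediction_dict = yaml.safe_load(
    "pr_files:\n"
    "- filename: |\n"
    "    src/app.py\n"
    "  label: |\n"
    "    Enhancement\n"
)

remaining_files_list = ["docs/extra_a.md", "docs/extra_b.md"]  # files pruned for token budget
prediction_extra = "pr_files:"
for file in remaining_files_list:
    extra_file_yaml = f"""\
- filename: |
    {file}
  changes_summary: |
    ...
  changes_title: |
    ...
  label: |
    additional files (token-limit)
"""
    prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()

prediction_extra_dict = yaml.safe_load(prediction_extra)
# merge the two walkthroughs, then re-serialize so the result can be re-validated
if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
    original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
print(yaml.dump(original_prediction_dict))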