Merge remote-tracking branch 'origin/main' into tr/review_redesign

This commit is contained in:
mrT23
2024-07-03 07:54:26 +03:00
4 changed files with 93 additions and 34 deletions

View File

@ -71,10 +71,6 @@ class PRAddDocs:
async def _prepare_prediction(self, model: str):
get_logger().info('Getting PR diff...')
# Disable adding docs to scripts and other non-relevant text files
from pr_agent.algo.language_handler import bad_extensions
bad_extensions += get_settings().docs_blacklist_extensions.docs_blacklist
self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
model,

View File

@ -1,8 +1,10 @@
import asyncio
import copy
import re
from functools import partial
from typing import List, Tuple
import yaml
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
@ -168,13 +170,21 @@ class PRDescription:
return None
large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
patches_diff = get_pr_diff(self.git_provider, self.token_handler, model,
large_pr_handling=large_pr_handling)
output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
if isinstance(output, tuple):
patches_diff, remaining_files_list = output
else:
patches_diff = output
remaining_files_list = []
if not large_pr_handling or patches_diff:
self.patches_diff = patches_diff
if patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff)
self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
if (remaining_files_list and 'pr_files' in self.prediction and 'label:' in self.prediction and
get_settings().pr_description.mention_extra_files):
get_logger().debug(f"Extending additional files, {len(remaining_files_list)} files")
self.prediction = await self.extend_additional_files(remaining_files_list)
else:
get_logger().error(f"Error getting PR diff {self.pr_id}")
self.prediction = None
@ -192,13 +202,27 @@ class PRDescription:
self.git_provider, token_handler_only_files_prompt, model)
# get the files prediction for each patch
if not get_settings().pr_description.async_ai_calls:
results = []
for i, patches in enumerate(patches_compressed_list): # sync calls
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
prediction_files = await self._get_prediction(model, patches_diff,
prompt="pr_description_only_files_prompts")
results.append(prediction_files)
else: # async calls
tasks = []
for i, patches in enumerate(patches_compressed_list):
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
task = asyncio.create_task(
self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts"))
tasks.append(task)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks)
file_description_str_list = []
for i, patches in enumerate(patches_compressed_list):
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
prediction_files = await self._get_prediction(model, patches_diff,
prompt="pr_description_only_files_prompts")
prediction_files = prediction_files.strip().removeprefix('```yaml').strip('`').strip()
for i, result in enumerate(results):
prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
if load_yaml(prediction_files) and prediction_files.startswith('pr_files'):
prediction_files = prediction_files.removeprefix('pr_files:').strip()
file_description_str_list.append(prediction_files)
@ -248,7 +272,7 @@ class PRDescription:
changes_title: |
...
label: |
not processed (token-limit)
additional files (token-limit)
"""
files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip()
# final processing
@ -259,6 +283,34 @@ class PRDescription:
get_logger().debug(f"Using only headers for describe {self.pr_id}")
self.prediction = prediction_headers
async def extend_additional_files(self, remaining_files_list) -> str:
prediction = self.prediction
try:
original_prediction_dict = load_yaml(self.prediction)
prediction_extra = "pr_files:"
for file in remaining_files_list:
extra_file_yaml = f"""\
- filename: |
{file}
changes_summary: |
...
changes_title: |
...
label: |
additional files (token-limit)
"""
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
prediction_extra_dict = load_yaml(prediction_extra)
# merge the two dictionaries
if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
new_yaml = yaml.dump(original_prediction_dict)
if load_yaml(new_yaml):
prediction = new_yaml
return prediction
except Exception as e:
get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
return self.prediction
async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
variables = copy.deepcopy(self.vars)
@ -481,9 +533,9 @@ class PRDescription:
filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
diff_plus_minus = ""
delta_nbsp = ""
diff_files = self.git_provider.diff_files
diff_files = self.git_provider.get_diff_files()
for f in diff_files:
if f.filename.lower() == filename.lower():
if f.filename.lower().strip('/') == filename.lower().strip('/'):
num_plus_lines = f.num_plus_lines
num_minus_lines = f.num_minus_lines
diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}"