mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-13 09:10:38 +08:00
extend additional files
This commit is contained in:
@ -71,10 +71,6 @@ class PRAddDocs:
|
||||
async def _prepare_prediction(self, model: str):
|
||||
get_logger().info('Getting PR diff...')
|
||||
|
||||
# Disable adding docs to scripts and other non-relevant text files
|
||||
from pr_agent.algo.language_handler import bad_extensions
|
||||
bad_extensions += get_settings().docs_blacklist_extensions.docs_blacklist
|
||||
|
||||
self.patches_diff = get_pr_diff(self.git_provider,
|
||||
self.token_handler,
|
||||
model,
|
||||
|
@ -4,6 +4,7 @@ import re
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import yaml
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
@ -169,12 +170,21 @@ class PRDescription:
|
||||
return None
|
||||
|
||||
large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
|
||||
patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling)
|
||||
output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
|
||||
if isinstance(output, tuple):
|
||||
patches_diff, remaining_files_list = output
|
||||
else:
|
||||
patches_diff = output
|
||||
remaining_files_list = []
|
||||
if not large_pr_handling or patches_diff:
|
||||
self.patches_diff = patches_diff
|
||||
if patches_diff:
|
||||
get_logger().debug(f"PR diff", artifact=self.patches_diff)
|
||||
self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
|
||||
if (remaining_files_list and 'pr_files' in self.prediction and 'label:' in self.prediction and
|
||||
get_settings().pr_description.mention_extra_files):
|
||||
get_logger().debug(f"Extending additional files, {len(remaining_files_list)} files")
|
||||
self.prediction = await self.extend_additional_files(remaining_files_list)
|
||||
else:
|
||||
get_logger().error(f"Error getting PR diff {self.pr_id}")
|
||||
self.prediction = None
|
||||
@ -273,6 +283,35 @@ class PRDescription:
|
||||
get_logger().debug(f"Using only headers for describe {self.pr_id}")
|
||||
self.prediction = prediction_headers
|
||||
|
||||
async def extend_additional_files(self, remaining_files_list) -> str:
|
||||
prediction = self.prediction
|
||||
try:
|
||||
original_prediction_dict = load_yaml(self.prediction)
|
||||
prediction_extra = "pr_files:"
|
||||
for file in remaining_files_list:
|
||||
extra_file_yaml = f"""\
|
||||
- filename: |
|
||||
{file}
|
||||
changes_summary: |
|
||||
...
|
||||
changes_title: |
|
||||
...
|
||||
label: |
|
||||
additional files (token-limit)
|
||||
"""
|
||||
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
|
||||
prediction_extra_dict = load_yaml(prediction_extra)
|
||||
# merge the two dictionaries
|
||||
if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
|
||||
original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
|
||||
new_yaml = yaml.dump(original_prediction_dict)
|
||||
if load_yaml(new_yaml):
|
||||
prediction = new_yaml
|
||||
return prediction
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
|
||||
return self.prediction
|
||||
|
||||
async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = patches_diff # update diff
|
||||
|
Reference in New Issue
Block a user