enable ai_metadata

This commit is contained in:
mrT23
2024-09-07 17:25:05 +03:00
parent 24f7e8622f
commit 8706f643ef
32 changed files with 338 additions and 117 deletions

View File

@ -7,7 +7,8 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models, \
add_ai_metadata_to_diff_files
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType, show_relevant_configurations
from pr_agent.config_loader import get_settings
@ -54,16 +55,27 @@ class PRCodeSuggestions:
self.prediction = None
self.pr_url = pr_url
self.cli_mode = cli_mode
self.pr_description, self.pr_description_files = (
self.git_provider.get_pr_description(split_changes_walkthrough=True))
if (self.pr_description_files and get_settings().get("config.is_auto_command", False) and
get_settings().get("config.enable_ai_metadata", False)):
add_ai_metadata_to_diff_files(self.git_provider, self.pr_description_files)
get_logger().debug(f"AI metadata added to the this command")
else:
get_settings().set("config.enable_ai_metadata", False)
get_logger().debug(f"AI metadata is disabled for this command")
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"description": self.pr_description,
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"num_code_suggestions": num_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
"relevant_best_practices": "",
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
}
if 'claude' in get_settings().config.model:
# prompt for Claude, with minor adjustments
@ -505,7 +517,8 @@ class PRCodeSuggestions:
async def _prepare_prediction_extended(self, model: str) -> dict:
self.patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
max_calls=get_settings().pr_code_suggestions.max_number_of_calls,
pr_description_files =self.pr_description_files)
if self.patches_diff_list:
get_logger().info(f"Number of PR chunk calls: {len(self.patches_diff_list)}")
get_logger().debug(f"PR diff:", artifact=self.patches_diff_list)