This commit is contained in:
mrT23
2024-06-26 20:11:20 +03:00
parent 55a82382ef
commit 0f920bcc5b
3 changed files with 333 additions and 123 deletions

View File

@ -7,11 +7,14 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, get_pr_diff_multiple_patchs, \
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels, ModelType, show_relevant_configurations
from pr_agent.algo.utils import set_custom_labels
from pr_agent.algo.utils import load_yaml, get_user_labels, ModelType, show_relevant_configurations, get_max_tokens, \
clip_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider, get_git_provider_with_context
from pr_agent.git_providers import get_git_provider, GithubProvider, get_git_provider_with_context
from pr_agent.git_providers.git_provider import get_main_pr_language
from pr_agent.log import get_logger
from pr_agent.servers.help import HelpMessage
@ -56,6 +59,7 @@ class PRDescription:
"custom_labels_class": "", # will be filled if necessary in 'set_custom_labels' function
"enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
}
self.user_description = self.git_provider.get_user_description()
# Initialize the token handler
@ -163,32 +167,105 @@ class PRDescription:
if get_settings().pr_description.use_description_markers and 'pr_agent:' not in self.user_description:
return None
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
if self.patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff)
self.prediction = await self._get_prediction(model)
large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling)
if not large_pr_handling or patches_diff:
self.patches_diff = patches_diff
if patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff)
self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
else:
get_logger().error(f"Error getting PR diff {self.pr_id}")
self.prediction = None
else:
get_logger().error(f"Error getting PR diff {self.pr_id}")
self.prediction = None
# get the diff in multiple patches, with the token handler only for the files prompt
get_logger().debug('large_pr_handling for describe')
token_handler_only_files_prompt = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_description_only_files_prompts.system,
get_settings().pr_description_only_files_prompts.user,
)
(patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict,
files_in_patches_list) = get_pr_diff_multiple_patchs(
self.git_provider, token_handler_only_files_prompt, model)
async def _get_prediction(self, model: str) -> str:
"""
Generate an AI prediction for the PR description based on the provided model.
# get the files prediction for each patch
file_description_str_list = []
for i, patches in enumerate(patches_compressed_list):
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
prediction_files = await self._get_prediction(model, patches_diff,
prompt="pr_description_only_files_prompts")
prediction_files = prediction_files.strip().removeprefix('```yaml').strip('`').strip()
if load_yaml(prediction_files) and prediction_files.startswith('pr_files'):
prediction_files = prediction_files.removeprefix('pr_files:').strip()
file_description_str_list.append(prediction_files)
else:
get_logger().debug(f"failed to generate predictions in iteration {i + 1} for describe files")
Args:
model (str): The name of the model to be used for generating the prediction.
# generate files_walkthrough string, with proper token handling
token_handler_only_description_prompt = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_description_only_description_prompts.system,
get_settings().pr_description_only_description_prompts.user)
files_walkthrough = "\n".join(file_description_str_list)
if remaining_files_list:
files_walkthrough += "\n\nNo more token budget. Additional unprocessed files:"
for file in remaining_files_list:
files_walkthrough += f"\n- {file}"
if deleted_files_list:
files_walkthrough += "\n\nAdditional deleted files:"
for file in deleted_files_list:
files_walkthrough += f"\n- {file}"
tokens_files_walkthrough = len(token_handler_only_description_prompt.encoder.encode(files_walkthrough))
total_tokens = token_handler_only_description_prompt.prompt_tokens + tokens_files_walkthrough
max_tokens_model = get_max_tokens(model)
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
# clip files_walkthrough to git the tokens within the limit
files_walkthrough = clip_tokens(files_walkthrough,
max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD - token_handler_only_description_prompt.prompt_tokens,
num_input_tokens=tokens_files_walkthrough)
Returns:
str: The generated AI prediction.
"""
# PR header inference
# toDo - add deleted and unprocessed files to the prompt ('files_walkthrough'), as extra data
get_logger().debug(f"PR diff only description", artifact=files_walkthrough)
prediction_headers = await self._get_prediction(model, patches_diff=files_walkthrough,
prompt="pr_description_only_description_prompts")
prediction_headers = prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
if get_settings().pr_description.mention_extra_files:
for file in remaining_files_list:
extra_file_yaml = f"""\
- filename: |
{file}
changes_summary: |
...
changes_title: |
...
label: |
not processed (token-limit)
"""
files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip()
# final processing
self.prediction = prediction_headers + "\n" + "pr_files:\n" + files_walkthrough
if not load_yaml(self.prediction):
get_logger().error(f"Error getting valid YAML in large PR handling for describe {self.pr_id}")
if load_yaml(prediction_headers):
get_logger().debug(f"Using only headers for describe {self.pr_id}")
self.prediction = prediction_headers
async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
variables["diff"] = patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
system_prompt = environment.from_string(get_settings().get(prompt, {}).get("system", "")).render(variables)
user_prompt = environment.from_string(get_settings().get(prompt, {}).get("user", "")).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
@ -351,7 +428,7 @@ class PRDescription:
filename = file['filename'].replace("'", "`").replace('"', '`')
changes_summary = file['changes_summary']
changes_title = file['changes_title'].strip()
label = file.get('label')
label = file.get('label').strip().lower()
if label not in file_label_dict:
file_label_dict[label] = []
file_label_dict[label].append((filename, changes_title, changes_summary))
@ -392,6 +469,7 @@ class PRDescription:
for filename, file_changes_title, file_change_description in list_tuples:
filename = filename.replace("'", "`").rstrip()
filename_publish = filename.split("/")[-1]
file_changes_title_code = f"<code>{file_changes_title}</code>"
file_changes_title_code_br = insert_br_after_x_chars(file_changes_title_code, x=(delta - 5)).strip()
if len(file_changes_title_code_br) < (delta - 5):
@ -423,14 +501,16 @@ class PRDescription:
<hr>
{filename}
{file_change_description_br}
</details>
</td>
<td><a href="{link}">{diff_plus_minus}</a>{delta_nbsp}</td>
</tr>
"""
if use_collapsible_file_list: