mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-03 20:30:41 +08:00
Merge remote-tracking branch 'upstream/main' into abstract-BaseAiHandler
This commit is contained in:
@ -1,7 +1,5 @@
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
@ -9,10 +7,11 @@ from jinja2 import Environment, StrictUndefined
|
||||
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import load_yaml
|
||||
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
from pr_agent.log import get_logger
|
||||
|
||||
|
||||
class PRDescription:
|
||||
@ -31,6 +30,11 @@ class PRDescription:
|
||||
)
|
||||
self.pr_id = self.git_provider.get_pr_id()
|
||||
|
||||
if get_settings().pr_description.enable_semantic_files_types and not self.git_provider.is_supported(
|
||||
"gfm_markdown"):
|
||||
get_logger().debug(f"Disabling semantic files types for {self.pr_id}")
|
||||
get_settings().pr_description.enable_semantic_files_types = False
|
||||
|
||||
# Initialize the AI handler
|
||||
self.ai_handler = ai_handler
|
||||
|
||||
@ -41,8 +45,13 @@ class PRDescription:
|
||||
"description": self.git_provider.get_pr_description(full=False),
|
||||
"language": self.main_pr_language,
|
||||
"diff": "", # empty diff for initial calculation
|
||||
"use_bullet_points": get_settings().pr_description.use_bullet_points,
|
||||
"extra_instructions": get_settings().pr_description.extra_instructions,
|
||||
"commit_messages_str": self.git_provider.get_commit_messages()
|
||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||
"enable_custom_labels": get_settings().config.enable_custom_labels,
|
||||
"custom_labels_class": "", # will be filled if necessary in 'set_custom_labels' function
|
||||
"enable_file_walkthrough": get_settings().pr_description.enable_file_walkthrough,
|
||||
"enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
|
||||
}
|
||||
|
||||
self.user_description = self.git_provider.get_user_description()
|
||||
@ -65,18 +74,21 @@ class PRDescription:
|
||||
"""
|
||||
|
||||
try:
|
||||
logging.info(f"Generating a PR description {self.pr_id}")
|
||||
get_logger().info(f"Generating a PR description {self.pr_id}")
|
||||
if get_settings().config.publish_output:
|
||||
self.git_provider.publish_comment("Preparing PR description...", is_temporary=True)
|
||||
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
|
||||
logging.info(f"Preparing answer {self.pr_id}")
|
||||
get_logger().info(f"Preparing answer {self.pr_id}")
|
||||
if self.prediction:
|
||||
self._prepare_data()
|
||||
else:
|
||||
return None
|
||||
|
||||
if get_settings().pr_description.enable_semantic_files_types:
|
||||
self._prepare_file_labels()
|
||||
|
||||
pr_labels = []
|
||||
if get_settings().pr_description.publish_labels:
|
||||
pr_labels = self._prepare_labels()
|
||||
@ -88,19 +100,25 @@ class PRDescription:
|
||||
full_markdown_description = f"## Title\n\n{pr_title}\n\n___\n{pr_body}"
|
||||
|
||||
if get_settings().config.publish_output:
|
||||
logging.info(f"Pushing answer {self.pr_id}")
|
||||
get_logger().info(f"Pushing answer {self.pr_id}")
|
||||
if get_settings().pr_description.publish_description_as_comment:
|
||||
self.git_provider.publish_comment(full_markdown_description)
|
||||
else:
|
||||
self.git_provider.publish_description(pr_title, pr_body)
|
||||
if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"):
|
||||
current_labels = self.git_provider.get_labels()
|
||||
if current_labels is None:
|
||||
current_labels = []
|
||||
self.git_provider.publish_labels(pr_labels + current_labels)
|
||||
user_labels = get_user_labels(current_labels)
|
||||
self.git_provider.publish_labels(pr_labels + user_labels)
|
||||
|
||||
if (get_settings().pr_description.final_update_message and
|
||||
hasattr(self.git_provider, 'pr_url') and self.git_provider.pr_url):
|
||||
latest_commit_url = self.git_provider.get_latest_commit_url()
|
||||
if latest_commit_url:
|
||||
self.git_provider.publish_comment(
|
||||
f"**[PR Description]({self.git_provider.pr_url})** updated to latest commit ({latest_commit_url})")
|
||||
self.git_provider.remove_initial_comment()
|
||||
except Exception as e:
|
||||
logging.error(f"Error generating PR description {self.pr_id}: {e}")
|
||||
get_logger().error(f"Error generating PR description {self.pr_id}: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
@ -121,9 +139,9 @@ class PRDescription:
|
||||
if get_settings().pr_description.use_description_markers and 'pr_agent:' not in self.user_description:
|
||||
return None
|
||||
|
||||
logging.info(f"Getting PR diff {self.pr_id}")
|
||||
get_logger().info(f"Getting PR diff {self.pr_id}")
|
||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
|
||||
logging.info(f"Getting AI prediction {self.pr_id}")
|
||||
get_logger().info(f"Getting AI prediction {self.pr_id}")
|
||||
self.prediction = await self._get_prediction(model)
|
||||
|
||||
async def _get_prediction(self, model: str) -> str:
|
||||
@ -140,12 +158,13 @@ class PRDescription:
|
||||
variables["diff"] = self.patches_diff # update diff
|
||||
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
set_custom_labels(variables)
|
||||
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
|
||||
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
|
||||
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||
get_logger().info(f"\nSystem prompt:\n{system_prompt}")
|
||||
get_logger().info(f"\nUser prompt:\n{user_prompt}")
|
||||
|
||||
response, finish_reason = await self.ai_handler.chat_completion(
|
||||
model=model,
|
||||
@ -154,8 +173,10 @@ class PRDescription:
|
||||
user=user_prompt
|
||||
)
|
||||
|
||||
return response
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
get_logger().info(f"\nAI response:\n{response}")
|
||||
|
||||
return response
|
||||
|
||||
def _prepare_data(self):
|
||||
# Load the AI prediction data into a dictionary
|
||||
@ -169,16 +190,20 @@ class PRDescription:
|
||||
pr_types = []
|
||||
|
||||
# If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types'
|
||||
if 'PR Type' in self.data:
|
||||
if type(self.data['PR Type']) == list:
|
||||
pr_types = self.data['PR Type']
|
||||
elif type(self.data['PR Type']) == str:
|
||||
pr_types = self.data['PR Type'].split(',')
|
||||
|
||||
if 'labels' in self.data:
|
||||
if type(self.data['labels']) == list:
|
||||
pr_types = self.data['labels']
|
||||
elif type(self.data['labels']) == str:
|
||||
pr_types = self.data['labels'].split(',')
|
||||
elif 'type' in self.data:
|
||||
if type(self.data['type']) == list:
|
||||
pr_types = self.data['type']
|
||||
elif type(self.data['type']) == str:
|
||||
pr_types = self.data['type'].split(',')
|
||||
return pr_types
|
||||
|
||||
def _prepare_pr_answer_with_markers(self) -> Tuple[str, str]:
|
||||
logging.info(f"Using description marker replacements {self.pr_id}")
|
||||
get_logger().info(f"Using description marker replacements {self.pr_id}")
|
||||
title = self.vars["title"]
|
||||
body = self.user_description
|
||||
if get_settings().pr_description.include_generated_by_header:
|
||||
@ -186,7 +211,12 @@ class PRDescription:
|
||||
else:
|
||||
ai_header = ""
|
||||
|
||||
ai_summary = self.data.get('PR Description')
|
||||
ai_type = self.data.get('type')
|
||||
if ai_type and not re.search(r'<!--\s*pr_agent:type\s*-->', body):
|
||||
pr_type = f"{ai_header}{ai_type}"
|
||||
body = body.replace('pr_agent:type', pr_type)
|
||||
|
||||
ai_summary = self.data.get('description')
|
||||
if ai_summary and not re.search(r'<!--\s*pr_agent:summary\s*-->', body):
|
||||
summary = f"{ai_header}{ai_summary}"
|
||||
body = body.replace('pr_agent:summary', summary)
|
||||
@ -215,12 +245,17 @@ class PRDescription:
|
||||
|
||||
# Iterate over the dictionary items and append the key and value to 'markdown_text' in a markdown format
|
||||
markdown_text = ""
|
||||
# Don't display 'PR Labels'
|
||||
if 'labels' in self.data and self.git_provider.is_supported("get_labels"):
|
||||
self.data.pop('labels')
|
||||
if not get_settings().pr_description.enable_pr_type:
|
||||
self.data.pop('type')
|
||||
for key, value in self.data.items():
|
||||
markdown_text += f"## {key}\n\n"
|
||||
markdown_text += f"{value}\n\n"
|
||||
|
||||
# Remove the 'PR Title' key from the dictionary
|
||||
ai_title = self.data.pop('PR Title', self.vars["title"])
|
||||
ai_title = self.data.pop('title', self.vars["title"])
|
||||
if get_settings().pr_description.keep_original_user_title:
|
||||
# Assign the original PR title to the 'title' variable
|
||||
title = self.vars["title"]
|
||||
@ -232,26 +267,131 @@ class PRDescription:
|
||||
# except for the items containing the word 'walkthrough'
|
||||
pr_body = ""
|
||||
for idx, (key, value) in enumerate(self.data.items()):
|
||||
pr_body += f"## {key}:\n"
|
||||
if key == 'pr_files':
|
||||
value = self.file_label_dict
|
||||
key_publish = "PR changes walkthrough"
|
||||
else:
|
||||
key_publish = key.rstrip(':').replace("_", " ").capitalize()
|
||||
pr_body += f"## {key_publish}\n"
|
||||
if 'walkthrough' in key.lower():
|
||||
# for filename, description in value.items():
|
||||
if self.git_provider.is_supported("gfm_markdown"):
|
||||
pr_body += "<details> <summary>files:</summary>\n\n"
|
||||
for file in value:
|
||||
filename = file['filename'].replace("'", "`")
|
||||
description = file['changes in file']
|
||||
pr_body += f'`{filename}`: {description}\n'
|
||||
description = file['changes_in_file']
|
||||
pr_body += f'- `{filename}`: {description}\n'
|
||||
if self.git_provider.is_supported("gfm_markdown"):
|
||||
pr_body +="</details>\n"
|
||||
pr_body += "</details>\n"
|
||||
elif 'pr_files' in key.lower():
|
||||
pr_body = self.process_pr_files_prediction(pr_body, value)
|
||||
else:
|
||||
# if the value is a list, join its items by comma
|
||||
if type(value) == list:
|
||||
if isinstance(value, list):
|
||||
value = ', '.join(v for v in value)
|
||||
pr_body += f"{value}\n"
|
||||
if idx < len(self.data) - 1:
|
||||
pr_body += "\n___\n"
|
||||
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
logging.info(f"title:\n{title}\n{pr_body}")
|
||||
get_logger().info(f"title:\n{title}\n{pr_body}")
|
||||
|
||||
return title, pr_body
|
||||
return title, pr_body
|
||||
|
||||
def _prepare_file_labels(self):
|
||||
self.file_label_dict = {}
|
||||
for file in self.data['pr_files']:
|
||||
try:
|
||||
filename = file['filename'].replace("'", "`").replace('"', '`')
|
||||
changes_summary = file['changes_summary']
|
||||
label = file['label']
|
||||
if label not in self.file_label_dict:
|
||||
self.file_label_dict[label] = []
|
||||
self.file_label_dict[label].append((filename, changes_summary))
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error preparing file label dict {self.pr_id}: {e}")
|
||||
pass
|
||||
|
||||
def process_pr_files_prediction(self, pr_body, value):
|
||||
if not self.git_provider.is_supported("gfm_markdown"):
|
||||
get_logger().info(f"Disabling semantic files types for {self.pr_id} since gfm_markdown is not supported")
|
||||
return pr_body
|
||||
|
||||
try:
|
||||
pr_body += "<table>"
|
||||
header = f"Relevant files"
|
||||
delta = 65
|
||||
header += " " * delta
|
||||
pr_body += f"""<thead><tr><th></th><th>{header}</th></tr></thead>"""
|
||||
pr_body += """<tbody>"""
|
||||
for semantic_label in value.keys():
|
||||
s_label = semantic_label.strip("'").strip('"')
|
||||
pr_body += f"""<tr><td><strong>{s_label.capitalize()}</strong></td>"""
|
||||
list_tuples = value[semantic_label]
|
||||
pr_body += f"""<td><details><summary>{len(list_tuples)} files</summary><table>"""
|
||||
for filename, file_change_description in list_tuples:
|
||||
filename = filename.replace("'", "`")
|
||||
filename_publish = filename.split("/")[-1]
|
||||
filename_publish = f"{filename_publish}"
|
||||
if len(filename_publish) < (delta - 5):
|
||||
filename_publish += " " * ((delta - 5) - len(filename_publish))
|
||||
diff_plus_minus = ""
|
||||
diff_files = self.git_provider.diff_files
|
||||
for f in diff_files:
|
||||
if f.filename.lower() == filename.lower():
|
||||
num_plus_lines = f.num_plus_lines
|
||||
num_minus_lines = f.num_minus_lines
|
||||
diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}"
|
||||
break
|
||||
|
||||
# try to add line numbers link to code suggestions
|
||||
link = ""
|
||||
if hasattr(self.git_provider, 'get_line_link'):
|
||||
filename = filename.strip()
|
||||
link = self.git_provider.get_line_link(filename, relevant_line_start=-1)
|
||||
|
||||
file_change_description = self._insert_br_after_x_chars(file_change_description, x=(delta - 5))
|
||||
pr_body += f"""
|
||||
<tr>
|
||||
<td>
|
||||
<details>
|
||||
<summary><strong>{filename_publish}</strong></summary>
|
||||
<ul>
|
||||
{filename}<br><br>
|
||||
<strong>{file_change_description}</strong>
|
||||
</ul>
|
||||
</details>
|
||||
</td>
|
||||
<td><a href="{link}"> {diff_plus_minus}</a></td>
|
||||
|
||||
</tr>
|
||||
"""
|
||||
pr_body += """</table></details></td></tr>"""
|
||||
pr_body += """</tr></tbody></table>"""
|
||||
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error processing pr files to markdown {self.pr_id}: {e}")
|
||||
pass
|
||||
return pr_body
|
||||
|
||||
def _insert_br_after_x_chars(self, text, x=70):
|
||||
"""
|
||||
Insert <br> into a string after a word that increases its length above x characters.
|
||||
"""
|
||||
if len(text) < x:
|
||||
return text
|
||||
|
||||
words = text.split(' ')
|
||||
new_text = ""
|
||||
current_length = 0
|
||||
|
||||
for word in words:
|
||||
# Check if adding this word exceeds x characters
|
||||
if current_length + len(word) > x:
|
||||
new_text += "<br>" # Insert line break
|
||||
current_length = 0 # Reset counter
|
||||
|
||||
# Add the word to the new text
|
||||
new_text += word + " "
|
||||
current_length += len(word) + 1 # Add 1 for the space
|
||||
|
||||
return new_text.strip() # Remove trailing space
|
||||
|
Reference in New Issue
Block a user