Mirror of https://github.com/qodo-ai/pr-agent.git (synced 2025-07-03 12:20:38 +08:00)

Commit: self_reflect
@@ -22,53 +22,54 @@ class PRCodeSuggestions:
 
     def __init__(self, pr_url: str, cli_mode=False, args: list = None,
                  ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
-        self.git_provider = get_git_provider()(pr_url)
-        self.main_language = get_main_pr_language(
-            self.git_provider.get_languages(), self.git_provider.get_files()
-        )
+        if pr_url:
+            self.git_provider = get_git_provider()(pr_url)
+            self.main_language = get_main_pr_language(
+                self.git_provider.get_languages(), self.git_provider.get_files()
+            )
 
-        # limit context specifically for the improve command, which has hard input to parse:
-        if get_settings().pr_code_suggestions.max_context_tokens:
-            MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
-            if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
-                get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
-                get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
+            # limit context specifically for the improve command, which has hard input to parse:
+            if get_settings().pr_code_suggestions.max_context_tokens:
+                MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
+                if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
+                    get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
+                    get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
 
 
-        # extended mode
-        try:
-            self.is_extended = self._get_is_extended(args or [])
-        except:
-            self.is_extended = False
-        if self.is_extended:
-            num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
-        else:
-            num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
+            # extended mode
+            try:
+                self.is_extended = self._get_is_extended(args or [])
+            except:
+                self.is_extended = False
+            if self.is_extended:
+                num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
+            else:
+                num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
 
-        self.ai_handler = ai_handler()
-        self.ai_handler.main_pr_language = self.main_language
-        self.patches_diff = None
-        self.prediction = None
-        self.cli_mode = cli_mode
-        self.vars = {
-            "title": self.git_provider.pr.title,
-            "branch": self.git_provider.get_pr_branch(),
-            "description": self.git_provider.get_pr_description(),
-            "language": self.main_language,
-            "diff": "",  # empty diff for initial calculation
-            "num_code_suggestions": num_code_suggestions,
-            "commitable_code_suggestions_mode": get_settings().pr_code_suggestions.commitable_code_suggestions,
-            "extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
-            "commit_messages_str": self.git_provider.get_commit_messages(),
-        }
-        self.token_handler = TokenHandler(self.git_provider.pr,
-                                          self.vars,
-                                          get_settings().pr_code_suggestions_prompt.system,
-                                          get_settings().pr_code_suggestions_prompt.user)
+            self.ai_handler = ai_handler()
+            self.ai_handler.main_pr_language = self.main_language
+            self.patches_diff = None
+            self.prediction = None
+            self.cli_mode = cli_mode
+            self.vars = {
+                "title": self.git_provider.pr.title,
+                "branch": self.git_provider.get_pr_branch(),
+                "description": self.git_provider.get_pr_description(),
+                "language": self.main_language,
+                "diff": "",  # empty diff for initial calculation
+                "num_code_suggestions": num_code_suggestions,
+                "commitable_code_suggestions_mode": get_settings().pr_code_suggestions.commitable_code_suggestions,
+                "extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
+                "commit_messages_str": self.git_provider.get_commit_messages(),
+            }
+            self.token_handler = TokenHandler(self.git_provider.pr,
+                                              self.vars,
+                                              get_settings().pr_code_suggestions_prompt.system,
+                                              get_settings().pr_code_suggestions_prompt.user)
 
-        self.progress = f"## Generating PR code suggestions\n\n"
-        self.progress += f"""\nWork in progress ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
-        self.progress_response = None
+            self.progress = f"## Generating PR code suggestions\n\n"
+            self.progress += f"""\nWork in progress ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
+            self.progress_response = None
 
     async def run(self):
         try:
@@ -83,12 +84,11 @@ class PRCodeSuggestions:
             self.git_provider.publish_comment("Preparing suggestions...", is_temporary=True)
 
             if not self.is_extended:
-                await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
-                data = self._prepare_pr_code_suggestions()
+                data = await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
             else:
                 data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
 
-            if (not data) or (not 'code_suggestions' in data) or (not data['code_suggestions']):
+            if not data or not data.get('code_suggestions'):
                 get_logger().error('No code suggestions found for PR.')
                 pr_body = "## PR Code Suggestions ✨\n\nNo code suggestions found for PR."
                 get_logger().debug(f"PR output", artifact=pr_body)
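With `_prepare_prediction` now returning the parsed suggestions, `run()` receives `data` directly from `retry_with_fallback_models`, and the old three-clause emptiness check collapses into a single `data.get('code_suggestions')` guard. A minimal sketch, with hypothetical `data` values, of what the new guard accepts and rejects:

```python
# Hypothetical inputs covering the three failure shapes the old check spelled
# out explicitly: no dict at all, key missing, and key present but empty.
for data in (None, {}, {"code_suggestions": []}, {"code_suggestions": [{"summary": "x"}]}):
    empty = not data or not data.get("code_suggestions")
    print(data, "->", "skip" if empty else "publish")
```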
@@ -148,7 +148,7 @@ class PRCodeSuggestions:
         except Exception as e:
             pass
 
-    async def _prepare_prediction(self, model: str):
+    async def _prepare_prediction(self, model: str) -> dict:
         self.patches_diff = get_pr_diff(self.git_provider,
                                         self.token_handler,
                                         model,
@@ -162,7 +162,10 @@ class PRCodeSuggestions:
             get_logger().error(f"Error getting PR diff")
             self.prediction = None
 
-    async def _get_prediction(self, model: str, patches_diff: str):
+        data = self._prepare_pr_code_suggestions(self.prediction)
+        return data
+
+    async def _get_prediction(self, model: str, patches_diff: str) -> dict:
         variables = copy.deepcopy(self.vars)
         variables["diff"] = patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -171,7 +174,21 @@ class PRCodeSuggestions:
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
 
-        return response
+        # load suggestions from the AI response
+        data = self._prepare_pr_code_suggestions(response)
+
+        # self-reflect on suggestions
+        if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
+            response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"], patches_diff)
+            if response_reflect:
+                response_reflect_yaml = load_yaml(response_reflect)
+                code_suggestions_feedback = response_reflect_yaml["code_suggestions"]
+                if len(code_suggestions_feedback) == len(data["code_suggestions"]):
+                    for i, suggestion in enumerate(data["code_suggestions"]):
+                        suggestion["score"] = code_suggestions_feedback[i]["suggestion_score"]
+                        suggestion["score_why"] = code_suggestions_feedback[i]["why"]
+
+        return data
 
     @staticmethod
     def _truncate_if_needed(suggestion):
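This hunk is the heart of the commit: after parsing the suggestions, a second model call scores them, and `suggestion_score` / `why` are merged back into each entry. A standalone sketch of that merge, assuming the reflection response is YAML parallel to the suggestion list (PyYAML's `safe_load` stands in for pr-agent's `load_yaml`):

```python
import yaml  # stand-in for pr-agent's load_yaml helper

data = {"code_suggestions": [{"one_sentence_summary": "guard against None"},
                             {"one_sentence_summary": "rename variable"}]}
response_reflect = """\
code_suggestions:
- suggestion_score: 7
  why: prevents a likely crash
- suggestion_score: 2
  why: stylistic only
"""
feedback = yaml.safe_load(response_reflect)["code_suggestions"]
if len(feedback) == len(data["code_suggestions"]):  # merge only when lists align
    for i, suggestion in enumerate(data["code_suggestions"]):
        suggestion["score"] = feedback[i]["suggestion_score"]
        suggestion["score_why"] = feedback[i]["why"]
print(data["code_suggestions"])
```

Note the length check: if the reflection output does not line up one-to-one with the suggestions, the scores are silently skipped rather than misattributed.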
@@ -185,14 +202,13 @@ class PRCodeSuggestions:
                           f"characters to {max_code_suggestion_length} characters")
         return suggestion
 
-    def _prepare_pr_code_suggestions(self) -> Dict:
-        review = self.prediction.strip()
-        data = load_yaml(review,
+    def _prepare_pr_code_suggestions(self, predictions: str) -> Dict:
+        data = load_yaml(predictions.strip(),
                          keys_fix_yaml=["relevant_file", "suggestion_content", "existing_code", "improved_code"])
         if isinstance(data, list):
             data = {'code_suggestions': data}
 
-        # remove invalid suggestions
+        # remove or edit invalid suggestions
         suggestion_list = []
         one_sentence_summary_list = []
         for i, suggestion in enumerate(data['code_suggestions']):
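`_prepare_pr_code_suggestions` now parses whatever prediction string it is handed instead of reading `self.prediction`, which lets `_get_prediction` and the extended path reuse it. A sketch of the list-to-dict normalization it performs, again with `yaml.safe_load` standing in for `load_yaml` (which additionally repairs malformed keys):

```python
import yaml  # stand-in for load_yaml

predictions = """\
- relevant_file: src/app.py
  suggestion_content: handle empty input
"""
data = yaml.safe_load(predictions.strip())
if isinstance(data, list):  # a bare YAML list is wrapped into the expected dict shape
    data = {"code_suggestions": data}
print(data)
```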
@@ -210,15 +226,17 @@ class PRCodeSuggestions:
                     get_logger().debug(f"Skipping suggestion {i + 1}, because it uses 'const instead let': {suggestion}")
                     continue
 
-                if ('existing_code' in suggestion) and ('improved_code' in suggestion) and (
-                        suggestion['existing_code'] != suggestion['improved_code']):
+                if ('existing_code' in suggestion) and ('improved_code' in suggestion):
+                    if suggestion['existing_code'] == suggestion['improved_code']:
+                        get_logger().debug(f"skipping improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}")
+                        suggestion['existing_code'] = ""
                     suggestion = self._truncate_if_needed(suggestion)
                     if not get_settings().pr_code_suggestions.commitable_code_suggestions:
                         one_sentence_summary_list.append(suggestion['one_sentence_summary'])
                     suggestion_list.append(suggestion)
                 else:
-                    get_logger().debug(
-                        f"Skipping suggestion {i + 1}, because existing code is equal to improved code {suggestion['existing_code']}")
+                    get_logger().info(
+                        f"Skipping suggestion {i + 1}, because it does not contain 'existing_code' or 'improved_code': {suggestion}")
             except Exception as e:
                 get_logger().error(f"Error processing suggestion {i + 1}: {suggestion}, error: {e}")
         data['code_suggestions'] = suggestion_list
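Behavior change worth noting: a suggestion whose `improved_code` equals its `existing_code` used to be discarded; now its `existing_code` is blanked and the suggestion is kept, while the skip path is reserved for entries missing either field. A minimal sketch with hypothetical values:

```python
suggestion = {"existing_code": "x = 1", "improved_code": "x = 1",
              "one_sentence_summary": "tidy up"}
if "existing_code" in suggestion and "improved_code" in suggestion:
    if suggestion["existing_code"] == suggestion["improved_code"]:
        suggestion["existing_code"] = ""  # keep the suggestion, show only the improvement
print(suggestion)
```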
@@ -309,14 +327,24 @@ class PRCodeSuggestions:
                 prediction = await self._get_prediction(model, patches_diff)
                 prediction_list.append(prediction)
 
-            data = {}
-            for prediction in prediction_list:
-                self.prediction = prediction
-                data_per_chunk = self._prepare_pr_code_suggestions()
-                if "code_suggestions" in data:
-                    data["code_suggestions"].extend(data_per_chunk["code_suggestions"])
-                else:
-                    data.update(data_per_chunk)
+            data = {"code_suggestions": []}
+            for i, prediction in enumerate(prediction_list):
+                try:
+                    if "code_suggestions" in prediction:
+                        if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
+                            score = int(prediction["code_suggestions"][0]["score"])
+                            if score > 0:
+                                data["code_suggestions"].extend(prediction["code_suggestions"])
+                            else:
+                                get_logger().info(f"Skipping suggestions from call {i + 1}, because score is {score}")
+                        else:
+                            data["code_suggestions"].extend(prediction["code_suggestions"])
+                    else:
+                        get_logger().error(f"Error getting PR diff, no code suggestions found in call {i + 1}")
+                except Exception as e:
+                    get_logger().error(f"Error getting PR diff, error: {e}")
+                    data = None
+
             self.data = data
         else:
             get_logger().error(f"Error getting PR diff")
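In extended mode each chunk now arrives as an already-parsed dict, and when self-reflection is enabled a whole chunk is dropped if it scored 0; per this diff, only the first suggestion's score is inspected, apparently as a per-call signal. A sketch of that aggregation with hypothetical predictions:

```python
self_reflect = True  # mirrors get_settings().pr_code_suggestions.self_reflect_on_suggestions
prediction_list = [
    {"code_suggestions": [{"summary": "a", "score": 7}]},
    {"code_suggestions": [{"summary": "b", "score": 0}]},  # gated out by its score
    {},                                                    # no suggestions at all
]
data = {"code_suggestions": []}
for i, prediction in enumerate(prediction_list):
    suggestions = prediction.get("code_suggestions")
    if not suggestions:
        print(f"no code suggestions found in call {i + 1}")
        continue
    if self_reflect and int(suggestions[0]["score"]) <= 0:
        print(f"skipping suggestions from call {i + 1}")
        continue
    data["code_suggestions"].extend(suggestions)
print(data)
```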
@@ -397,10 +425,13 @@ class PRCodeSuggestions:
             pr_body = "## PR Code Suggestions ✨\n\n"
 
             pr_body += "<table>"
-            header = f"Suggestions"
-            delta = 76
+            header = f"Suggestion"
+            delta = 68
             header += "&nbsp; " * delta
-            pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td></tr></thead>"""
+            if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
+                pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td><td align=center>Score</td></tr>"""
+            else:
+                pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td></tr>"""
             pr_body += """<tbody>"""
             suggestions_labels = dict()
             # add all suggestions related to each label
@@ -410,6 +441,11 @@ class PRCodeSuggestions:
                     suggestions_labels[label] = []
                 suggestions_labels[label].append(suggestion)
 
+            # sort suggestions_labels by the suggestion with the highest score
+            if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
+                suggestions_labels = dict(sorted(suggestions_labels.items(), key=lambda x: max([s['score'] for s in x[1]]), reverse=True))
+
+
             for label, suggestions in suggestions_labels.items():
                 num_suggestions=len(suggestions)
                 pr_body += f"""<tr><td rowspan={num_suggestions}><strong>{label.capitalize()}</strong></td>\n"""
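A sketch of the new category ordering: labels are ranked by the single highest-scoring suggestion they contain, so the most important category renders first in the table (scores here are hypothetical):

```python
suggestions_labels = {
    "best practice": [{"score": 6}],
    "possible issue": [{"score": 4}, {"score": 9}],
}
suggestions_labels = dict(sorted(suggestions_labels.items(),
                                 key=lambda x: max(s["score"] for s in x[1]),
                                 reverse=True))
print(list(suggestions_labels))  # ['possible issue', 'best practice']
```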
@@ -423,8 +459,12 @@ class PRCodeSuggestions:
                         range_str = f"[{relevant_lines_start}]"
                     else:
                         range_str = f"[{relevant_lines_start}-{relevant_lines_end}]"
-                    code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start,
-                                                                        relevant_lines_end)
+
+                    try:
+                        code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start,
+                                                                            relevant_lines_end)
+                    except:
+                        code_snippet_link = ""
                     # add html table for each suggestion
 
                     suggestion_content = suggestion['suggestion_content'].rstrip().rstrip()
|
||||
pr_body += f"""<td>\n\n"""
|
||||
else:
|
||||
pr_body += f"""<tr><td>\n\n"""
|
||||
suggestion_summary = suggestion['one_sentence_summary'].strip()
|
||||
suggestion_summary = suggestion['one_sentence_summary'].strip().rstrip('.')
|
||||
if '`' in suggestion_summary:
|
||||
suggestion_summary = replace_code_tags(suggestion_summary)
|
||||
# suggestion_summary = suggestion_summary + max((77-len(suggestion_summary)), 0)*" "
|
||||
pr_body += f"""\n\n<details><summary>{suggestion_summary}</summary>\n\n___\n\n"""
|
||||
|
||||
pr_body += f"""\n\n<details><summary>{suggestion_summary}</summary>\n\n___\n\n"""
|
||||
pr_body += f"""
|
||||
**{suggestion_content}**
|
||||
|
||||
@@ -458,6 +497,15 @@ class PRCodeSuggestions:
 
 {example_code}
 """
+                    if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
+                        pr_body +=f"\n\n<details><summary><b>Suggestion importance[1-10]: {suggestion['score']}</b></summary>\n\n"
+                        pr_body += f"Why: {suggestion['score_why']}\n\n"
+                        pr_body += f"</details>"
+
+                    # # add another column for 'score'
+                    if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
+                        pr_body += f"</td><td align=center>{suggestion['score']}\n\n"
+
                     pr_body += f"</details>"
                     pr_body += f"</td></tr>"
 
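When self-reflection is on, each suggestion thus gains a collapsible importance note plus a dedicated Score cell. A sketch of the HTML fragment those four `pr_body` appends produce, with a hypothetical suggestion:

```python
suggestion = {"score": 8, "score_why": "fixes a potential None dereference"}
pr_body = ""
pr_body += f"\n\n<details><summary><b>Suggestion importance[1-10]: {suggestion['score']}</b></summary>\n\n"
pr_body += f"Why: {suggestion['score_why']}\n\n"
pr_body += "</details>"
pr_body += f"</td><td align=center>{suggestion['score']}\n\n"
print(pr_body)
```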
@@ -469,3 +517,30 @@ class PRCodeSuggestions:
         except Exception as e:
             get_logger().info(f"Failed to publish summarized code suggestions, error: {e}")
             return ""
+
+    async def self_reflect_on_suggestions(self, suggestion_list: List, patches_diff: str) -> str:
+        if not suggestion_list:
+            return ""
+
+        try:
+            suggestion_str = ""
+            for i, suggestion in enumerate(suggestion_list):
+                suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
+
+            variables = {'suggestion_list': suggestion_list,
+                         'suggestion_str': suggestion_str,
+                         "diff": patches_diff,
+                         'num_code_suggestions': len(suggestion_list),
+                         'commitable_code_suggestions_mode': get_settings().pr_code_suggestions.commitable_code_suggestions,}
+            model = get_settings().config.model
+            environment = Environment(undefined=StrictUndefined)
+            system_prompt_reflect = environment.from_string(get_settings().pr_code_suggestions_reflect_prompt.system).render(
+                variables)
+            user_prompt_reflect = environment.from_string(get_settings().pr_code_suggestions_reflect_prompt.user).render(variables)
+            response_reflect, finish_reason_reflect = await self.ai_handler.chat_completion(model=model,
+                                                                                            system=system_prompt_reflect,
+                                                                                            user=user_prompt_reflect)
+        except Exception as e:
+            get_logger().info(f"Could not reflect on suggestions, error: {e}")
+            return ""
+        return response_reflect