self_reflect

This commit is contained in:
mrT23
2024-05-10 19:44:26 +03:00
parent 38058ea714
commit 1ebc20b761
6 changed files with 267 additions and 84 deletions

View File

@ -22,53 +22,54 @@ class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
if pr_url:
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
# limit context specifically for the improve command, which has hard input to parse:
if get_settings().pr_code_suggestions.max_context_tokens:
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
# limit context specifically for the improve command, which has hard input to parse:
if get_settings().pr_code_suggestions.max_context_tokens:
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
# extended mode
try:
self.is_extended = self._get_is_extended(args or [])
except:
self.is_extended = False
if self.is_extended:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
# extended mode
try:
self.is_extended = self._get_is_extended(args or [])
except:
self.is_extended = False
if self.is_extended:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_language
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"num_code_suggestions": num_code_suggestions,
"commitable_code_suggestions_mode": get_settings().pr_code_suggestions.commitable_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_code_suggestions_prompt.system,
get_settings().pr_code_suggestions_prompt.user)
self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_language
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"num_code_suggestions": num_code_suggestions,
"commitable_code_suggestions_mode": get_settings().pr_code_suggestions.commitable_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_code_suggestions_prompt.system,
get_settings().pr_code_suggestions_prompt.user)
self.progress = f"## Generating PR code suggestions\n\n"
self.progress += f"""\nWork in progress ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
self.progress_response = None
self.progress = f"## Generating PR code suggestions\n\n"
self.progress += f"""\nWork in progress ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
self.progress_response = None
async def run(self):
try:
@ -83,12 +84,11 @@ class PRCodeSuggestions:
self.git_provider.publish_comment("Preparing suggestions...", is_temporary=True)
if not self.is_extended:
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
data = self._prepare_pr_code_suggestions()
data = await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
if (not data) or (not 'code_suggestions' in data) or (not data['code_suggestions']):
if not data or not data.get('code_suggestions'):
get_logger().error('No code suggestions found for PR.')
pr_body = "## PR Code Suggestions ✨\n\nNo code suggestions found for PR."
get_logger().debug(f"PR output", artifact=pr_body)
@ -148,7 +148,7 @@ class PRCodeSuggestions:
except Exception as e:
pass
async def _prepare_prediction(self, model: str):
async def _prepare_prediction(self, model: str) -> dict:
self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
model,
@ -162,7 +162,10 @@ class PRCodeSuggestions:
get_logger().error(f"Error getting PR diff")
self.prediction = None
async def _get_prediction(self, model: str, patches_diff: str):
data = self._prepare_pr_code_suggestions(self.prediction)
return data
async def _get_prediction(self, model: str, patches_diff: str) -> dict:
variables = copy.deepcopy(self.vars)
variables["diff"] = patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
@ -171,7 +174,21 @@ class PRCodeSuggestions:
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
return response
# load suggestions from the AI response
data = self._prepare_pr_code_suggestions(response)
# self-reflect on suggestions
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"], patches_diff)
if response_reflect:
response_reflect_yaml = load_yaml(response_reflect)
code_suggestions_feedback = response_reflect_yaml["code_suggestions"]
if len(code_suggestions_feedback) == len(data["code_suggestions"]):
for i, suggestion in enumerate(data["code_suggestions"]):
suggestion["score"] = code_suggestions_feedback[i]["suggestion_score"]
suggestion["score_why"] = code_suggestions_feedback[i]["why"]
return data
@staticmethod
def _truncate_if_needed(suggestion):
@ -185,14 +202,13 @@ class PRCodeSuggestions:
f"characters to {max_code_suggestion_length} characters")
return suggestion
def _prepare_pr_code_suggestions(self) -> Dict:
review = self.prediction.strip()
data = load_yaml(review,
def _prepare_pr_code_suggestions(self, predictions: str) -> Dict:
data = load_yaml(predictions.strip(),
keys_fix_yaml=["relevant_file", "suggestion_content", "existing_code", "improved_code"])
if isinstance(data, list):
data = {'code_suggestions': data}
# remove invalid suggestions
# remove or edit invalid suggestions
suggestion_list = []
one_sentence_summary_list = []
for i, suggestion in enumerate(data['code_suggestions']):
@ -210,15 +226,17 @@ class PRCodeSuggestions:
get_logger().debug(f"Skipping suggestion {i + 1}, because it uses 'const instead let': {suggestion}")
continue
if ('existing_code' in suggestion) and ('improved_code' in suggestion) and (
suggestion['existing_code'] != suggestion['improved_code']):
if ('existing_code' in suggestion) and ('improved_code' in suggestion):
if suggestion['existing_code'] == suggestion['improved_code']:
get_logger().debug(f"skipping improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}")
suggestion['existing_code'] = ""
suggestion = self._truncate_if_needed(suggestion)
if not get_settings().pr_code_suggestions.commitable_code_suggestions:
one_sentence_summary_list.append(suggestion['one_sentence_summary'])
suggestion_list.append(suggestion)
else:
get_logger().debug(
f"Skipping suggestion {i + 1}, because existing code is equal to improved code {suggestion['existing_code']}")
get_logger().info(
f"Skipping suggestion {i + 1}, because it does not contain 'existing_code' or 'improved_code': {suggestion}")
except Exception as e:
get_logger().error(f"Error processing suggestion {i + 1}: {suggestion}, error: {e}")
data['code_suggestions'] = suggestion_list
@ -309,14 +327,24 @@ class PRCodeSuggestions:
prediction = await self._get_prediction(model, patches_diff)
prediction_list.append(prediction)
data = {}
for prediction in prediction_list:
self.prediction = prediction
data_per_chunk = self._prepare_pr_code_suggestions()
if "code_suggestions" in data:
data["code_suggestions"].extend(data_per_chunk["code_suggestions"])
else:
data.update(data_per_chunk)
data = {"code_suggestions": []}
for i, prediction in enumerate(prediction_list):
try:
if "code_suggestions" in prediction:
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
score = int(prediction["code_suggestions"][0]["score"])
if score > 0:
data["code_suggestions"].extend(prediction["code_suggestions"])
else:
get_logger().info(f"Skipping suggestions from call {i + 1}, because score is {score}")
else:
data["code_suggestions"].extend(prediction["code_suggestions"])
else:
get_logger().error(f"Error getting PR diff, no code suggestions found in call {i + 1}")
except Exception as e:
get_logger().error(f"Error getting PR diff, error: {e}")
data = None
self.data = data
else:
get_logger().error(f"Error getting PR diff")
@ -397,10 +425,13 @@ class PRCodeSuggestions:
pr_body = "## PR Code Suggestions ✨\n\n"
pr_body += "<table>"
header = f"Suggestions"
delta = 76
header = f"Suggestion"
delta = 68
header += "&nbsp; " * delta
pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td></tr></thead>"""
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td><td align=center>Score</td></tr>"""
else:
pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td></tr>"""
pr_body += """<tbody>"""
suggestions_labels = dict()
# add all suggestions related to each label
@ -410,6 +441,11 @@ class PRCodeSuggestions:
suggestions_labels[label] = []
suggestions_labels[label].append(suggestion)
# sort suggestions_labels by the suggestion with the highest score
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
suggestions_labels = dict(sorted(suggestions_labels.items(), key=lambda x: max([s['score'] for s in x[1]]), reverse=True))
for label, suggestions in suggestions_labels.items():
num_suggestions=len(suggestions)
pr_body += f"""<tr><td rowspan={num_suggestions}><strong>{label.capitalize()}</strong></td>\n"""
@ -423,8 +459,12 @@ class PRCodeSuggestions:
range_str = f"[{relevant_lines_start}]"
else:
range_str = f"[{relevant_lines_start}-{relevant_lines_end}]"
code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start,
relevant_lines_end)
try:
code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start,
relevant_lines_end)
except:
code_snippet_link = ""
# add html table for each suggestion
suggestion_content = suggestion['suggestion_content'].rstrip().rstrip()
@ -445,12 +485,11 @@ class PRCodeSuggestions:
pr_body += f"""<td>\n\n"""
else:
pr_body += f"""<tr><td>\n\n"""
suggestion_summary = suggestion['one_sentence_summary'].strip()
suggestion_summary = suggestion['one_sentence_summary'].strip().rstrip('.')
if '`' in suggestion_summary:
suggestion_summary = replace_code_tags(suggestion_summary)
# suggestion_summary = suggestion_summary + max((77-len(suggestion_summary)), 0)*"&nbsp;"
pr_body += f"""\n\n<details><summary>{suggestion_summary}</summary>\n\n___\n\n"""
pr_body += f"""\n\n<details><summary>{suggestion_summary}</summary>\n\n___\n\n"""
pr_body += f"""
**{suggestion_content}**
@ -458,6 +497,15 @@ class PRCodeSuggestions:
{example_code}
"""
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
pr_body +=f"\n\n<details><summary><b>Suggestion importance[1-10]: {suggestion['score']}</b></summary>\n\n"
pr_body += f"Why: {suggestion['score_why']}\n\n"
pr_body += f"</details>"
# # add another column for 'score'
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
pr_body += f"</td><td align=center>{suggestion['score']}\n\n"
pr_body += f"</details>"
pr_body += f"</td></tr>"
@ -469,3 +517,30 @@ class PRCodeSuggestions:
except Exception as e:
get_logger().info(f"Failed to publish summarized code suggestions, error: {e}")
return ""
async def self_reflect_on_suggestions(self, suggestion_list: List, patches_diff: str) -> str:
if not suggestion_list:
return ""
try:
suggestion_str = ""
for i, suggestion in enumerate(suggestion_list):
suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
variables = {'suggestion_list': suggestion_list,
'suggestion_str': suggestion_str,
"diff": patches_diff,
'num_code_suggestions': len(suggestion_list),
'commitable_code_suggestions_mode': get_settings().pr_code_suggestions.commitable_code_suggestions,}
model = get_settings().config.model
environment = Environment(undefined=StrictUndefined)
system_prompt_reflect = environment.from_string(get_settings().pr_code_suggestions_reflect_prompt.system).render(
variables)
user_prompt_reflect = environment.from_string(get_settings().pr_code_suggestions_reflect_prompt.user).render(variables)
response_reflect, finish_reason_reflect = await self.ai_handler.chat_completion(model=model,
system=system_prompt_reflect,
user=user_prompt_reflect)
except Exception as e:
get_logger().info(f"Could not reflect on suggestions, error: {e}")
return ""
return response_reflect