Merge branch 'main' into nocode_suggestions_config

This commit is contained in:
Tal
2024-11-04 07:50:22 +02:00
committed by GitHub
9 changed files with 165 additions and 112 deletions

View File

@ -1,6 +1,7 @@
import asyncio
import copy
import textwrap
import traceback
from functools import partial
from typing import Dict, List
from jinja2 import Environment, StrictUndefined
@ -44,7 +45,7 @@ class PRCodeSuggestions:
self.is_extended = self._get_is_extended(args or [])
except:
self.is_extended = False
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
num_code_suggestions = int(get_settings().pr_code_suggestions.num_code_suggestions_per_chunk)
self.ai_handler = ai_handler()
@ -69,6 +70,7 @@ class PRCodeSuggestions:
"description": self.pr_description,
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"diff_no_line_numbers": "", # empty diff for initial calculation
"num_code_suggestions": num_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
@ -110,18 +112,17 @@ class PRCodeSuggestions:
if not data:
data = {"code_suggestions": []}
if (data is None or 'code_suggestions' not in data or not data['code_suggestions']
and get_settings().config.publish_output):
if (data is None or 'code_suggestions' not in data or not data['code_suggestions']):
pr_body = "## PR Code Suggestions ✨\n\nNo code suggestions found for the PR."
get_logger().warning('No code suggestions found for the PR.')
if (get_settings().config.publish_output_no_suggestions):
pr_body = "## PR Code Suggestions ✨\n\nNo code suggestions found for the PR."
if get_settings().config.publish_output and get_settings().config.publish_output_no_suggestions:
get_logger().debug(f"PR output", artifact=pr_body)
if self.progress_response:
self.git_provider.edit_comment(self.progress_response, body=pr_body)
else:
self.git_provider.publish_comment(pr_body)
else:
get_settings().data = {"artifact": ""}
return
if (not self.is_extended and get_settings().pr_code_suggestions.rank_suggestions) or \
@ -198,8 +199,11 @@ class PRCodeSuggestions:
self.git_provider.remove_comment(self.progress_response)
else:
get_logger().info('Code suggestions generated for PR, but not published since publish_output is False.')
get_settings().data = {"artifact": data}
return
except Exception as e:
get_logger().error(f"Failed to generate code suggestions for PR, error: {e}")
get_logger().error(f"Failed to generate code suggestions for PR, error: {e}",
artifact={"traceback": traceback.format_exc()})
if get_settings().config.publish_output:
if self.progress_response:
self.progress_response.delete()
@ -331,7 +335,7 @@ class PRCodeSuggestions:
if self.patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff)
self.prediction = await self._get_prediction(model, self.patches_diff)
self.prediction = await self._get_prediction(model, self.patches_diff, self.patches_diff_no_line_number)
else:
get_logger().warning(f"Empty PR diff")
self.prediction = None
@ -339,54 +343,76 @@ class PRCodeSuggestions:
data = self.prediction
return data
async def _get_prediction(self, model: str, patches_diff: str) -> dict:
async def _get_prediction(self, model: str, patches_diff: str, patches_diff_no_line_number: str) -> dict:
variables = copy.deepcopy(self.vars)
variables["diff"] = patches_diff # update diff
variables["diff_no_line_numbers"] = patches_diff_no_line_number # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(self.pr_code_suggestions_prompt_system).render(variables)
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
response, finish_reason = await self.ai_handler.chat_completion(
model=model, temperature=get_settings().config.temperature, system=system_prompt, user=user_prompt)
if not get_settings().config.publish_output:
get_settings().system_prompt = system_prompt
get_settings().user_prompt = user_prompt
# load suggestions from the AI response
data = self._prepare_pr_code_suggestions(response)
# self-reflect on suggestions
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
model_turbo = get_settings().config.model_turbo # use turbo model for self-reflection, since it is an easier task
response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"],
patches_diff, model=model_turbo)
if response_reflect:
response_reflect_yaml = load_yaml(response_reflect)
code_suggestions_feedback = response_reflect_yaml["code_suggestions"]
if len(code_suggestions_feedback) == len(data["code_suggestions"]):
for i, suggestion in enumerate(data["code_suggestions"]):
try:
suggestion["score"] = code_suggestions_feedback[i]["suggestion_score"]
suggestion["score_why"] = code_suggestions_feedback[i]["why"]
except Exception as e: #
get_logger().error(f"Error processing suggestion score {i}",
artifact={"suggestion": suggestion,
"code_suggestions_feedback": code_suggestions_feedback[i]})
suggestion["score"] = 7
suggestion["score_why"] = ""
# if the before and after code is the same, clear one of them
try:
if suggestion['existing_code'] == suggestion['improved_code']:
get_logger().debug(
f"edited improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}")
if get_settings().pr_code_suggestions.commitable_code_suggestions:
suggestion['improved_code'] = "" # we need 'existing_code' to locate the code in the PR
else:
suggestion['existing_code'] = ""
except Exception as e:
get_logger().error(f"Error processing suggestion {i + 1}, error: {e}")
else:
# get_logger().error(f"Could not self-reflect on suggestions. using default score 7")
# self-reflect on suggestions (mandatory, since line numbers are generated now here)
model_reflection = get_settings().config.model
response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"],
patches_diff, model=model_reflection)
if response_reflect:
response_reflect_yaml = load_yaml(response_reflect)
code_suggestions_feedback = response_reflect_yaml["code_suggestions"]
if len(code_suggestions_feedback) == len(data["code_suggestions"]):
for i, suggestion in enumerate(data["code_suggestions"]):
suggestion["score"] = 7
suggestion["score_why"] = ""
try:
suggestion["score"] = code_suggestions_feedback[i]["suggestion_score"]
suggestion["score_why"] = code_suggestions_feedback[i]["why"]
if 'relevant_lines_start' not in suggestion:
relevant_lines_start = code_suggestions_feedback[i].get('relevant_lines_start', -1)
relevant_lines_end = code_suggestions_feedback[i].get('relevant_lines_end', -1)
suggestion['relevant_lines_start'] = relevant_lines_start
suggestion['relevant_lines_end'] = relevant_lines_end
if relevant_lines_start < 0 or relevant_lines_end < 0:
suggestion["score"] = 0
try:
if get_settings().config.publish_output:
suggestion_statistics_dict = {'score': int(suggestion["score"]),
'label': suggestion["label"].lower().strip()}
get_logger().info(f"PR-Agent suggestions statistics",
statistics=suggestion_statistics_dict, analytics=True)
except Exception as e:
get_logger().error(f"Failed to log suggestion statistics, error: {e}")
pass
except Exception as e: #
get_logger().error(f"Error processing suggestion score {i}",
artifact={"suggestion": suggestion,
"code_suggestions_feedback": code_suggestions_feedback[i]})
suggestion["score"] = 7
suggestion["score_why"] = ""
# if the before and after code is the same, clear one of them
try:
if suggestion['existing_code'] == suggestion['improved_code']:
get_logger().debug(
f"edited improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}")
if get_settings().pr_code_suggestions.commitable_code_suggestions:
suggestion['improved_code'] = "" # we need 'existing_code' to locate the code in the PR
else:
suggestion['existing_code'] = ""
except Exception as e:
get_logger().error(f"Error processing suggestion {i + 1}, error: {e}")
else:
# get_logger().error(f"Could not self-reflect on suggestions. using default score 7")
for i, suggestion in enumerate(data["code_suggestions"]):
suggestion["score"] = 7
suggestion["score_why"] = ""
return data
@ -396,10 +422,10 @@ class PRCodeSuggestions:
suggestion_truncation_message = get_settings().get("PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE", "")
if max_code_suggestion_length > 0:
if len(suggestion['improved_code']) > max_code_suggestion_length:
suggestion['improved_code'] = suggestion['improved_code'][:max_code_suggestion_length]
suggestion['improved_code'] += f"\n{suggestion_truncation_message}"
get_logger().info(f"Truncated suggestion from {len(suggestion['improved_code'])} "
f"characters to {max_code_suggestion_length} characters")
suggestion['improved_code'] = suggestion['improved_code'][:max_code_suggestion_length]
suggestion['improved_code'] += f"\n{suggestion_truncation_message}"
return suggestion
def _prepare_pr_code_suggestions(self, predictions: str) -> Dict:
@ -414,8 +440,7 @@ class PRCodeSuggestions:
one_sentence_summary_list = []
for i, suggestion in enumerate(data['code_suggestions']):
try:
needed_keys = ['one_sentence_summary', 'label', 'relevant_file', 'relevant_lines_start',
'relevant_lines_end']
needed_keys = ['one_sentence_summary', 'label', 'relevant_file']
is_valid_keys = True
for key in needed_keys:
if key not in suggestion:
@ -539,9 +564,33 @@ class PRCodeSuggestions:
return True
return False
def remove_line_numbers(self, patches_diff_list: List[str]) -> List[str]:
# create a copy of the patches_diff_list, without line numbers for '__new hunk__' sections
try:
self.patches_diff_list_no_line_numbers = []
for patches_diff in self.patches_diff_list:
patches_diff_lines = patches_diff.splitlines()
for i, line in enumerate(patches_diff_lines):
if line.strip():
if line[0].isdigit():
# find the first letter in the line that starts with a valid letter
for j, char in enumerate(line):
if not char.isdigit():
patches_diff_lines[i] = line[j + 1:]
break
self.patches_diff_list_no_line_numbers.append('\n'.join(patches_diff_lines))
return self.patches_diff_list_no_line_numbers
except Exception as e:
get_logger().error(f"Error removing line numbers from patches_diff_list, error: {e}")
return patches_diff_list
async def _prepare_prediction_extended(self, model: str) -> dict:
self.patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
# create a copy of the patches_diff_list, without line numbers for '__new hunk__' sections
self.patches_diff_list_no_line_numbers = self.remove_line_numbers(self.patches_diff_list)
if self.patches_diff_list:
get_logger().info(f"Number of PR chunk calls: {len(self.patches_diff_list)}")
get_logger().debug(f"PR diff:", artifact=self.patches_diff_list)
@ -549,12 +598,14 @@ class PRCodeSuggestions:
# parallelize calls to AI:
if get_settings().pr_code_suggestions.parallel_calls:
prediction_list = await asyncio.gather(
*[self._get_prediction(model, patches_diff) for patches_diff in self.patches_diff_list])
*[self._get_prediction(model, patches_diff, patches_diff_no_line_numbers) for
patches_diff, patches_diff_no_line_numbers in
zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers)])
self.prediction_list = prediction_list
else:
prediction_list = []
for i, patches_diff in enumerate(self.patches_diff_list):
prediction = await self._get_prediction(model, patches_diff)
for patches_diff, patches_diff_no_line_numbers in zip(self.patches_diff_list, self.patches_diff_list_no_line_numbers):
prediction = await self._get_prediction(model, patches_diff, patches_diff_no_line_numbers)
prediction_list.append(prediction)
data = {"code_suggestions": []}
@ -563,18 +614,16 @@ class PRCodeSuggestions:
score_threshold = max(1, int(get_settings().pr_code_suggestions.suggestions_score_threshold))
for i, prediction in enumerate(predictions["code_suggestions"]):
try:
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
score = int(prediction.get("score", 1))
if score >= score_threshold:
data["code_suggestions"].append(prediction)
else:
get_logger().info(
f"Removing suggestions {i} from call {j}, because score is {score}, and score_threshold is {score_threshold}",
artifact=prediction)
else:
score = int(prediction.get("score", 1))
if score >= score_threshold:
data["code_suggestions"].append(prediction)
else:
get_logger().info(
f"Removing suggestions {i} from call {j}, because score is {score}, and score_threshold is {score_threshold}",
artifact=prediction)
except Exception as e:
get_logger().error(f"Error getting PR diff for suggestion {i} in call {j}, error: {e}")
get_logger().error(f"Error getting PR diff for suggestion {i} in call {j}, error: {e}",
artifact={"prediction": prediction})
self.data = data
else:
get_logger().warning(f"Empty PR diff list")
@ -625,7 +674,7 @@ class PRCodeSuggestions:
if get_settings().pr_code_suggestions.final_clip_factor != 1:
max_len = max(
len(data_sorted),
get_settings().pr_code_suggestions.num_code_suggestions_per_chunk,
int(get_settings().pr_code_suggestions.num_code_suggestions_per_chunk),
)
new_len = int(0.5 + max_len * get_settings().pr_code_suggestions.final_clip_factor)
if new_len < len(data_sorted):
@ -658,10 +707,7 @@ class PRCodeSuggestions:
header = f"Suggestion"
delta = 66
header += "&nbsp; " * delta
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td><td align=center>Score</td></tr>"""
else:
pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td></tr>"""
pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td><td align=center>Score</td></tr>"""
pr_body += """<tbody>"""
suggestions_labels = dict()
# add all suggestions related to each label
@ -672,12 +718,11 @@ class PRCodeSuggestions:
suggestions_labels[label].append(suggestion)
# sort suggestions_labels by the suggestion with the highest score
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
suggestions_labels = dict(
sorted(suggestions_labels.items(), key=lambda x: max([s['score'] for s in x[1]]), reverse=True))
# sort the suggestions inside each label group by score
for label, suggestions in suggestions_labels.items():
suggestions_labels[label] = sorted(suggestions, key=lambda x: x['score'], reverse=True)
suggestions_labels = dict(
sorted(suggestions_labels.items(), key=lambda x: max([s['score'] for s in x[1]]), reverse=True))
# sort the suggestions inside each label group by score
for label, suggestions in suggestions_labels.items():
suggestions_labels[label] = sorted(suggestions, key=lambda x: x['score'], reverse=True)
counter_suggestions = 0
for label, suggestions in suggestions_labels.items():
@ -736,16 +781,14 @@ class PRCodeSuggestions:
{example_code.rstrip()}
"""
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
pr_body += f"<details><summary>Suggestion importance[1-10]: {suggestion['score']}</summary>\n\n"
pr_body += f"Why: {suggestion['score_why']}\n\n"
pr_body += f"</details>"
pr_body += f"<details><summary>Suggestion importance[1-10]: {suggestion['score']}</summary>\n\n"
pr_body += f"Why: {suggestion['score_why']}\n\n"
pr_body += f"</details>"
pr_body += f"</details>"
# # add another column for 'score'
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
pr_body += f"</td><td align=center>{suggestion['score']}\n\n"
pr_body += f"</td><td align=center>{suggestion['score']}\n\n"
pr_body += f"</td></tr>"
counter_suggestions += 1