self_reflect

This commit is contained in:
mrT23
2024-05-10 19:44:26 +03:00
parent 38058ea714
commit 1ebc20b761
6 changed files with 267 additions and 84 deletions

8
format_importance.py Normal file
View File

@ -0,0 +1,8 @@
import numpy as np
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
data = np.load('/Users/talrid/Git/pr-agent/data.npy', allow_pickle=True).tolist()
cls=PRCodeSuggestions(pr_url=None)
res = cls.generate_summarized_suggestions(data)
print(res)

View File

@ -21,6 +21,7 @@ global_settings = Dynaconf(
"settings/pr_line_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml",
"settings/pr_code_suggestions_reflect_prompts.toml",
"settings/pr_sort_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",

View File

@ -85,9 +85,10 @@ extra_instructions = ""
rank_suggestions = false
enable_help_text=false
persistent_comment=false
self_reflect_on_suggestions=true
# params for '/improve --extended' mode
auto_extended_mode=true
num_code_suggestions_per_chunk=5
num_code_suggestions_per_chunk=4
max_number_of_calls = 3
parallel_calls = true
rank_extended_suggestions = false

View File

@ -1,8 +1,9 @@
[pr_code_suggestions_prompt]
system="""You are PR-Reviewer, a language model that specializes in suggesting code improvements for a Pull Request (PR).
Your task is to provide meaningful and actionable code suggestions, to improve the new code presented in a PR diff (lines starting with '+').
system="""You are PR-Reviewer, a language model that specializes in suggesting ways to improve for a Pull Request (PR) code.
Your task is to provide meaningful and actionable code suggestions, to improve the new code presented in a PR diff.
Example for the PR Diff format:
The format we will use to present the PR code diff:
======
## file: 'src/file1.py'
@ -26,22 +27,26 @@ __old hunk__
## file: 'src/file2.py'
...
======
- In this format, we separated each hunk of code to '__new hunk__' and '__old hunk__' sections. The '__new hunk__' section contains the new code of the chunk, and the '__old hunk__' section contains the old code that was removed.
- Code lines are prefixed symbols ('+', '-', ' '). The '+' symbol indicates new code added in the PR, the '-' symbol indicates code removed in the PR, and the ' ' symbol indicates unchanged code.
- We also added line numbers for the '__new hunk__' sections, to help you refer to the code lines in your suggestions. These line numbers are not part of the actual code, and are only used for reference.
Specific instructions:
Specific instructions for generating code suggestions:
- Provide up to {{ num_code_suggestions }} code suggestions. The suggestions should be diverse and insightful.
- The suggestions should refer only to code from the '__new hunk__' sections, and focus on new lines of code (lines starting with '+').
- Prioritize suggestions that address major problems, issues and bugs in the PR code. As a second priority, suggestions should focus on enhancement, best practice, performance, maintainability, and other aspects.
- The suggestions should focus on ways to improve the new code in the PR, meaning focusing on lines from '__new hunk__' sections, starting with '+'. Use the '__old hunk__' sections to understand the context of the code changes.
- Prioritize suggestions that address major problems, issues and possible bugs in the PR code.
- Don't suggest to add docstring, type hints, or comments, or to remove unused imports.
- Suggestions should not repeat code already present in the '__new hunk__' sections.
- Provide the exact line numbers range (inclusive) for each suggestion.
- Provide the exact line numbers range (inclusive) for each suggestion. Use the line numbers from the '__new hunk__' sections.
- When quoting variables or names from the code, use backticks (`) instead of single quote (').
- Be aware that you are reviewing a PR code diff, and that the entire codebase is not available for you as context. Hence, don't suggest changes that require knowledge of the entire codebase.
{%- if extra_instructions %}
Extra instructions from the user:
Extra instructions from the user, that should be taken into account with high priority:
======
{{ extra_instructions }}
======
@ -59,12 +64,12 @@ class CodeSuggestion(BaseModel):
improved_code: str = Field(description="a short code snippet to illustrate the improved code, after applying the suggestion.")
one_sentence_summary:str = Field(description="a short summary of the suggestion action, in a single sentence. Focus on the 'what'. Be general, and avoid method or variable names.")
{%- else %}
existing_code: str = Field(description="a code snippet, demonstrating the relevant code lines from a '__new hunk__' section. It must be contiguous, correctly formatted and indented, and without line numbers")
improved_code: str = Field(description="a new code snippet, that can be used to replace the relevant lines in '__new hunk__' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers")
existing_code: str = Field(description="a code snippet, demonstrating the relevant code lines from a '__new hunk__' section. It must be contiguous, correctly formatted and indented, and without line numbers. Use abbreviations if needed")
improved_code: str = Field(description="If relevant, a new code snippet, that can be used to replace the relevant lines in '__new hunk__' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers". Retrun '...' if not applicable")
{%- endif %}
relevant_lines_start: int = Field(description="The relevant line number, from a '__new hunk__' section, where the suggestion starts (inclusive). Should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above")
relevant_lines_end: int = Field(description="The relevant line number, from a '__new hunk__' section, where the suggestion ends (inclusive). Should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above")
label: str = Field(description="a single label for the suggestion, to help the user understand the suggestion type. For example: 'security', 'bug', 'performance', 'enhancement', 'possible issue', 'best practice', 'maintainability', etc. Other labels are also allowed")
label: str = Field(description="a single label for the suggestion, to help the user understand the suggestion type. For example: 'security', 'possible bug', 'possible issue', 'performance', 'enhancement', 'best practice', 'maintainability', etc. Other labels are also allowed")
class PRCodeSuggestions(BaseModel):
code_suggestions: List[CodeSuggestion]

View File

@ -0,0 +1,93 @@
[pr_code_suggestions_reflect_prompt]
system="""You are a language model that specializes in reviewing and evaluating suggestions for a Pull Request (PR) code.
Your input is a PR code, and a list of code suggestions that were generated for the PR.
Your goal is to inspect, review and score the suggestsions.
Be aware - the suggestions may not always be correct or accurate, and you should evaluate them in relation to the actual PR code diff presented. Sometimes the suggestion may ignore parts of the actual code diff, and in that case, you should give a score of 0.
Specific instructions:
- Carefully review both the suggestion content, and the related PR code diff. Mistakes in the suggestions can occur. Make sure the suggestions are correct, and properly derived from the PR code diff.
- In addition to the exact code lines mentioned in each suggestion, review the code around them, to ensure that the suggestions are contextually accurate.
- Also check that the 'existing_code' and 'improved_code' fields correctly reflect the suggested changes.
- High scores (8 to 10) should be given to correct suggestions that address major bugs and issues, or security concerns. Lower scores (3 to 7) should be for correct suggestions addressing minor issues, code style, code readability, maintainability, etc. Don't give high scores to suggestions that are not crucial, and bring only small improvement or optimization.
- Order the feedback the same way the suggestions are ordered in the input.
The format that is used to present the PR code diff is as follows:
======
## file: 'src/file1.py'
@@ ... @@ def func1():
__new hunk__
12 code line1 that remained unchanged in the PR
13 +new hunk code line2 added in the PR
14 code line3 that remained unchanged in the PR
__old hunk__
code line1 that remained unchanged in the PR
-old hunk code line2 that was removed in the PR
code line3 that remained unchanged in the PR
@@ ... @@ def func2():
__new hunk__
...
__old hunk__
...
## file: 'src/file2.py'
...
======
- In this format, we separated each hunk of code to '__new hunk__' and '__old hunk__' sections. The '__new hunk__' section contains the new code of the chunk, and the '__old hunk__' section contains the old code that was removed.
- Code lines are prefixed symbols ('+', '-', ' '). The '+' symbol indicates new code added in the PR, the '-' symbol indicates code removed in the PR, and the ' ' symbol indicates unchanged code.
- We also added line numbers for the '__new hunk__' sections, to help you refer to the code lines in your suggestions. These line numbers are not part of the actual code, and are only used for reference.
The output must be a YAML object equivalent to type $PRCodeSuggestionsFeedback, according to the following Pydantic definitions:
=====
class CodeSuggestionFeedback(BaseModel):
{%- if not commitable_code_suggestions_mode %}
suggestion_summary: str = Field(description="repeated from the input")
{%- endif %}
relevant_file: str = Field(description="repeated from the input")
suggestion_score: int = Field(description="The actual output - the score of the suggestion, from 0 to 10. Give 0 if the suggestion is plain wrong. Otherwise, give a score from 1 to 10 (inclusive), where 1 is the lowest and 10 is the highest.")
why: str = Field(description="Short and concise explanation of why the suggestion received the score (one to two sentences).")
class PRCodeSuggestionsFeedback(BaseModel):
code_suggestions: List[CodeSuggestionFeedback]
=====
Example output:
```yaml
code_suggestions:
{%- if not commitable_code_suggestions_mode %}
- suggestion_content: |
Use a more descriptive variable name here
relevant_file: "src/file1.py"
{%- else %}
- relevant_file: "src/file1.py"
{%- endif %}
suggestion_score: 6
why: |
The variable name 't' is not descriptive enough
```
Each YAML output MUST be after a newline, indented, with block scalar indicator ('|').
"""
user="""You are given a Pull Request (PR) code diff:
======
{{ diff|trim }}
======
And here is a list of corresponding {{ num_code_suggestions }} code suggestions to improve this Pull Request code:
======
{{ suggestion_str|trim }}
======
Response (should be a valid YAML, and nothing else):
```yaml
"""

View File

@ -22,53 +22,54 @@ class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
if pr_url:
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
# limit context specifically for the improve command, which has hard input to parse:
if get_settings().pr_code_suggestions.max_context_tokens:
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
# limit context specifically for the improve command, which has hard input to parse:
if get_settings().pr_code_suggestions.max_context_tokens:
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
# extended mode
try:
self.is_extended = self._get_is_extended(args or [])
except:
self.is_extended = False
if self.is_extended:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
# extended mode
try:
self.is_extended = self._get_is_extended(args or [])
except:
self.is_extended = False
if self.is_extended:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_language
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"num_code_suggestions": num_code_suggestions,
"commitable_code_suggestions_mode": get_settings().pr_code_suggestions.commitable_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_code_suggestions_prompt.system,
get_settings().pr_code_suggestions_prompt.user)
self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_language
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"num_code_suggestions": num_code_suggestions,
"commitable_code_suggestions_mode": get_settings().pr_code_suggestions.commitable_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_code_suggestions_prompt.system,
get_settings().pr_code_suggestions_prompt.user)
self.progress = f"## Generating PR code suggestions\n\n"
self.progress += f"""\nWork in progress ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
self.progress_response = None
self.progress = f"## Generating PR code suggestions\n\n"
self.progress += f"""\nWork in progress ...<br>\n<img src="https://codium.ai/images/pr_agent/dual_ball_loading-crop.gif" width=48>"""
self.progress_response = None
async def run(self):
try:
@ -83,12 +84,11 @@ class PRCodeSuggestions:
self.git_provider.publish_comment("Preparing suggestions...", is_temporary=True)
if not self.is_extended:
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
data = self._prepare_pr_code_suggestions()
data = await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
if (not data) or (not 'code_suggestions' in data) or (not data['code_suggestions']):
if not data or not data.get('code_suggestions'):
get_logger().error('No code suggestions found for PR.')
pr_body = "## PR Code Suggestions ✨\n\nNo code suggestions found for PR."
get_logger().debug(f"PR output", artifact=pr_body)
@ -148,7 +148,7 @@ class PRCodeSuggestions:
except Exception as e:
pass
async def _prepare_prediction(self, model: str):
async def _prepare_prediction(self, model: str) -> dict:
self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
model,
@ -162,7 +162,10 @@ class PRCodeSuggestions:
get_logger().error(f"Error getting PR diff")
self.prediction = None
async def _get_prediction(self, model: str, patches_diff: str):
data = self._prepare_pr_code_suggestions(self.prediction)
return data
async def _get_prediction(self, model: str, patches_diff: str) -> dict:
variables = copy.deepcopy(self.vars)
variables["diff"] = patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
@ -171,7 +174,21 @@ class PRCodeSuggestions:
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
return response
# load suggestions from the AI response
data = self._prepare_pr_code_suggestions(response)
# self-reflect on suggestions
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
response_reflect = await self.self_reflect_on_suggestions(data["code_suggestions"], patches_diff)
if response_reflect:
response_reflect_yaml = load_yaml(response_reflect)
code_suggestions_feedback = response_reflect_yaml["code_suggestions"]
if len(code_suggestions_feedback) == len(data["code_suggestions"]):
for i, suggestion in enumerate(data["code_suggestions"]):
suggestion["score"] = code_suggestions_feedback[i]["suggestion_score"]
suggestion["score_why"] = code_suggestions_feedback[i]["why"]
return data
@staticmethod
def _truncate_if_needed(suggestion):
@ -185,14 +202,13 @@ class PRCodeSuggestions:
f"characters to {max_code_suggestion_length} characters")
return suggestion
def _prepare_pr_code_suggestions(self) -> Dict:
review = self.prediction.strip()
data = load_yaml(review,
def _prepare_pr_code_suggestions(self, predictions: str) -> Dict:
data = load_yaml(predictions.strip(),
keys_fix_yaml=["relevant_file", "suggestion_content", "existing_code", "improved_code"])
if isinstance(data, list):
data = {'code_suggestions': data}
# remove invalid suggestions
# remove or edit invalid suggestions
suggestion_list = []
one_sentence_summary_list = []
for i, suggestion in enumerate(data['code_suggestions']):
@ -210,15 +226,17 @@ class PRCodeSuggestions:
get_logger().debug(f"Skipping suggestion {i + 1}, because it uses 'const instead let': {suggestion}")
continue
if ('existing_code' in suggestion) and ('improved_code' in suggestion) and (
suggestion['existing_code'] != suggestion['improved_code']):
if ('existing_code' in suggestion) and ('improved_code' in suggestion):
if suggestion['existing_code'] == suggestion['improved_code']:
get_logger().debug(f"skipping improved suggestion {i + 1}, because equal to existing code: {suggestion['existing_code']}")
suggestion['existing_code'] = ""
suggestion = self._truncate_if_needed(suggestion)
if not get_settings().pr_code_suggestions.commitable_code_suggestions:
one_sentence_summary_list.append(suggestion['one_sentence_summary'])
suggestion_list.append(suggestion)
else:
get_logger().debug(
f"Skipping suggestion {i + 1}, because existing code is equal to improved code {suggestion['existing_code']}")
get_logger().info(
f"Skipping suggestion {i + 1}, because it does not contain 'existing_code' or 'improved_code': {suggestion}")
except Exception as e:
get_logger().error(f"Error processing suggestion {i + 1}: {suggestion}, error: {e}")
data['code_suggestions'] = suggestion_list
@ -309,14 +327,24 @@ class PRCodeSuggestions:
prediction = await self._get_prediction(model, patches_diff)
prediction_list.append(prediction)
data = {}
for prediction in prediction_list:
self.prediction = prediction
data_per_chunk = self._prepare_pr_code_suggestions()
if "code_suggestions" in data:
data["code_suggestions"].extend(data_per_chunk["code_suggestions"])
else:
data.update(data_per_chunk)
data = {"code_suggestions": []}
for i, prediction in enumerate(prediction_list):
try:
if "code_suggestions" in prediction:
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
score = int(prediction["code_suggestions"][0]["score"])
if score > 0:
data["code_suggestions"].extend(prediction["code_suggestions"])
else:
get_logger().info(f"Skipping suggestions from call {i + 1}, because score is {score}")
else:
data["code_suggestions"].extend(prediction["code_suggestions"])
else:
get_logger().error(f"Error getting PR diff, no code suggestions found in call {i + 1}")
except Exception as e:
get_logger().error(f"Error getting PR diff, error: {e}")
data = None
self.data = data
else:
get_logger().error(f"Error getting PR diff")
@ -397,10 +425,13 @@ class PRCodeSuggestions:
pr_body = "## PR Code Suggestions ✨\n\n"
pr_body += "<table>"
header = f"Suggestions"
delta = 76
header = f"Suggestion"
delta = 68
header += "&nbsp; " * delta
pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td></tr></thead>"""
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td><td align=center>Score</td></tr>"""
else:
pr_body += f"""<thead><tr><td>Category</td><td align=left>{header}</td></tr>"""
pr_body += """<tbody>"""
suggestions_labels = dict()
# add all suggestions related to each label
@ -410,6 +441,11 @@ class PRCodeSuggestions:
suggestions_labels[label] = []
suggestions_labels[label].append(suggestion)
# sort suggestions_labels by the suggestion with the highest score
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
suggestions_labels = dict(sorted(suggestions_labels.items(), key=lambda x: max([s['score'] for s in x[1]]), reverse=True))
for label, suggestions in suggestions_labels.items():
num_suggestions=len(suggestions)
pr_body += f"""<tr><td rowspan={num_suggestions}><strong>{label.capitalize()}</strong></td>\n"""
@ -423,8 +459,12 @@ class PRCodeSuggestions:
range_str = f"[{relevant_lines_start}]"
else:
range_str = f"[{relevant_lines_start}-{relevant_lines_end}]"
code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start,
relevant_lines_end)
try:
code_snippet_link = self.git_provider.get_line_link(relevant_file, relevant_lines_start,
relevant_lines_end)
except:
code_snippet_link = ""
# add html table for each suggestion
suggestion_content = suggestion['suggestion_content'].rstrip().rstrip()
@ -445,12 +485,11 @@ class PRCodeSuggestions:
pr_body += f"""<td>\n\n"""
else:
pr_body += f"""<tr><td>\n\n"""
suggestion_summary = suggestion['one_sentence_summary'].strip()
suggestion_summary = suggestion['one_sentence_summary'].strip().rstrip('.')
if '`' in suggestion_summary:
suggestion_summary = replace_code_tags(suggestion_summary)
# suggestion_summary = suggestion_summary + max((77-len(suggestion_summary)), 0)*"&nbsp;"
pr_body += f"""\n\n<details><summary>{suggestion_summary}</summary>\n\n___\n\n"""
pr_body += f"""\n\n<details><summary>{suggestion_summary}</summary>\n\n___\n\n"""
pr_body += f"""
**{suggestion_content}**
@ -458,6 +497,15 @@ class PRCodeSuggestions:
{example_code}
"""
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
pr_body +=f"\n\n<details><summary><b>Suggestion importance[1-10]: {suggestion['score']}</b></summary>\n\n"
pr_body += f"Why: {suggestion['score_why']}\n\n"
pr_body += f"</details>"
# # add another column for 'score'
if get_settings().pr_code_suggestions.self_reflect_on_suggestions:
pr_body += f"</td><td align=center>{suggestion['score']}\n\n"
pr_body += f"</details>"
pr_body += f"</td></tr>"
@ -469,3 +517,30 @@ class PRCodeSuggestions:
except Exception as e:
get_logger().info(f"Failed to publish summarized code suggestions, error: {e}")
return ""
async def self_reflect_on_suggestions(self, suggestion_list: List, patches_diff: str) -> str:
if not suggestion_list:
return ""
try:
suggestion_str = ""
for i, suggestion in enumerate(suggestion_list):
suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
variables = {'suggestion_list': suggestion_list,
'suggestion_str': suggestion_str,
"diff": patches_diff,
'num_code_suggestions': len(suggestion_list),
'commitable_code_suggestions_mode': get_settings().pr_code_suggestions.commitable_code_suggestions,}
model = get_settings().config.model
environment = Environment(undefined=StrictUndefined)
system_prompt_reflect = environment.from_string(get_settings().pr_code_suggestions_reflect_prompt.system).render(
variables)
user_prompt_reflect = environment.from_string(get_settings().pr_code_suggestions_reflect_prompt.user).render(variables)
response_reflect, finish_reason_reflect = await self.ai_handler.chat_completion(model=model,
system=system_prompt_reflect,
user=user_prompt_reflect)
except Exception as e:
get_logger().info(f"Could not reflect on suggestions, error: {e}")
return ""
return response_reflect