extended improve

This commit is contained in:
mrT23
2023-08-21 09:07:21 +03:00
parent fda98643c2
commit fb9335f424
10 changed files with 406 additions and 41 deletions

View File

@ -11,6 +11,7 @@ from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_information_from_user import PRInformationFromUser
from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_extended_code_suggestions import PRExtendedCodeSuggestions
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
from pr_agent.tools.pr_config import PRConfig
@ -25,6 +26,7 @@ command2class = {
"describe_pr": PRDescription,
"improve": PRCodeSuggestions,
"improve_code": PRCodeSuggestions,
"extended_improve": PRExtendedCodeSuggestions,
"ask": PRQuestions,
"ask_question": PRQuestions,
"update_changelog": PRUpdateChangelog,

View File

@ -55,7 +55,7 @@ class AiHandler:
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, temperature: float, system: str, user: str):
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
"""
Performs a chat completion using the OpenAI ChatCompletion API.
Retries in case of API errors or timeouts.

View File

@ -176,7 +176,7 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
...
"""
patch_with_lines_str = f"## {file.filename}\n"
patch_with_lines_str = f"\n\n## {file.filename}\n"
import re
patch_lines = patch.splitlines()
RE_HUNK_HEADER = re.compile(

View File

@ -57,7 +57,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
# generate a standard diff string, with patch extension
patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler,
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(pr_languages, token_handler,
add_line_numbers_to_hunks)
# if we are under the limit, return the full diff
@ -78,9 +78,9 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
return final_diff
def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> \
Tuple[list, int]:
def pr_generate_extended_diff(pr_languages: list,
token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> Tuple[list, int, list]:
"""
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
minimization techniques if needed.
@ -90,13 +90,10 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
files.
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
Returns:
- patches_extended: A list of extended patches for each file in the pull request.
- total_tokens: The total number of tokens used in the extended patches.
"""
total_tokens = token_handler.prompt_tokens # initial tokens
patches_extended = []
patches_extended_tokens = []
for lang in pr_languages:
for file in lang['files']:
original_file_content_str = file.base_file
@ -108,15 +105,16 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES)
full_extended_patch = f"## {file.filename}\n\n{extended_patch}\n"
if add_line_numbers_to_hunks:
if add_line_numbers_to_hunks and PATCH_EXTRA_LINES > 0:
full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
patch_tokens = token_handler.count_tokens(full_extended_patch)
file.tokens = patch_tokens
total_tokens += patch_tokens
patches_extended_tokens.append(patch_tokens)
patches_extended.append(full_extended_patch)
return patches_extended, total_tokens
return patches_extended, total_tokens, patches_extended_tokens
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
@ -338,3 +336,83 @@ def clip_tokens(text: str, max_tokens: int) -> str:
except Exception as e:
logging.warning(f"Failed to clip tokens: {e}")
return text
def get_pr_multi_diffs(git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
max_calls: int = 5) -> List[str]:
"""
Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
Args:
git_provider (GitProvider): An object that provides access to Git provider APIs.
token_handler (TokenHandler): An object that handles tokens in the context of a pull request.
model (str): The name of the model.
max_calls (int, optional): The maximum number of calls to retrieve diff files. Defaults to 5.
Returns:
List[str]: A list of final diff strings, split into multiple groups based on the maximum number of tokens allowed for the given model.
Raises:
RateLimitExceededException: If the rate limit for the Git provider API is exceeded.
"""
try:
diff_files = git_provider.get_diff_files()
except RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for git provider API. original message {e}")
raise
# Sort files by main language
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
# Sort files within each language group by tokens in descending order
sorted_files = []
for lang in pr_languages:
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
patches = []
final_diff_list = []
total_tokens = token_handler.prompt_tokens
call_number = 1
for file in sorted_files:
if call_number > max_calls:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Reached max calls ({max_calls})")
break
original_file_content_str = file.base_file
new_file_content_str = file.head_file
patch = file.patch
if not patch:
continue
# Remove delete-only hunks
patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename)
if patch is None:
continue
patch = convert_to_hunks_with_lines_numbers(patch, file)
new_patch_tokens = token_handler.count_tokens(patch)
if patch and (total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)
patches = []
total_tokens = token_handler.prompt_tokens
call_number += 1
if get_settings().config.verbosity_level >= 2:
logging.info(f"Call number: {call_number}")
if patch:
patches.append(patch)
total_tokens += new_patch_tokens
if get_settings().config.verbosity_level >= 2:
logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")
# Add the last chunk
if patches:
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)
return final_diff_list

View File

@ -19,6 +19,7 @@ global_settings = Dynaconf(
"settings/pr_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml",
"settings/pr_sort_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",
"settings_prod/.secrets.toml"

View File

@ -32,6 +32,14 @@ extra_instructions = ""
num_code_suggestions=4
extra_instructions = ""
[pr_extendeted_code_suggestions] # /extended_improve #
num_code_suggestions_per_chunk=8
extra_instructions = ""
max_number_of_calls = 5
rank_suggestions = true
final_clip_factor = 0.5
[pr_update_changelog] # /update_changelog #
push_changelog_changes=false
extra_instructions = ""

View File

@ -1,19 +1,47 @@
[pr_code_suggestions_prompt]
system="""You are a language model called CodiumAI-PR-Code-Reviewer.
Your task is to provide meaningfull non-trivial code suggestions to improve the new code in a PR (the '+' lines).
- Try to give important suggestions like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningfull code improvements, like performance, vulnerability, modularity, and best practices.
- Suggestions should refer only to the 'new hunk' code, and focus on improving the new added code lines, with '+'.
system="""You are a language model called PR-Code-Reviewer.
Your task is to provide meaningful actionable code suggestions, to improve the new code presented in a PR.
Example PR Diff input:
'
## src/file1.py
--new hunk--
12 code line that already existed in the file...
13 code line that already existed in the file....
14 +new code line added in the PR
15 code line that already existed in the file...
16 code line that already existed in the file...
--old hunk--
code line that already existed in the file...
-code line that was removed in the PR
code line that already existed in the file...
--new hunk--
...
--old hunk--
...
## src/file2.py
...
'
Specific instructions:
- Focus on important suggestions like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningful code improvements, like performance, vulnerability, modularity, and best practices.
- Suggestions should refer only to code from the '--new hunk--' sections, and focus on new lines of code (lines starting with '+').
- Provide the exact line number range (inclusive) for each issue.
- Assume there is additional code in the relevant file that is not included in the diff.
- Assume there is additional relevant code, that is not included in the diff.
- Provide up to {{ num_code_suggestions }} code suggestions.
- Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines).
- Don't output line numbers in the 'improved code' snippets.
- Avoid making suggestions that have already been implemented in the PR code. For example, if you propose adding a docstring, type hint, or anything else, make sure it isn't already in the '--new hunk--' code.
{%- if extra_instructions %}
Extra instructions from the user:
{{ extra_instructions }}
{% endif %}
{%- endif %}
You must use the following JSON schema to format your answer:
```json
@ -30,39 +58,26 @@ You must use the following JSON schema to format your answer:
},
"suggestion content": {
"type": "string",
"description": "a concrete suggestion for meaningfully improving the new PR code."
"description": "a concrete suggestion for meaningfully improving the new PR code (lines from the '--new hunk--' sections, starting with '+')."
},
"existing code": {
"type": "string",
"description": "a code snippet showing authentic relevant code lines from a 'new hunk' section. It must be continuous, correctly formatted and indented, and without line numbers."
"description": "a code snippet showing the relevant code lines from a '--new hunk--' section. It must be continuous, correctly formatted and indented, and without line numbers."
},
"relevant lines": {
"type": "string",
"description": "the relevant lines in the 'new hunk' sections, in the format of 'start_line-end_line'. For example: '10-15'. They should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above."
"description": "the relevant lines from a '--new hunk--' section, in the format of 'start_line-end_line'. For example: '10-15'. They should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above."
},
"improved code": {
"type": "string",
"description": "a new code snippet that can be used to replace the relevant lines in 'new hunk' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers."
"description": "a new code snippet that can be used to replace the relevant lines in '--new hunk--' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers."
}
}
}
}
```
Example input:
'
## src/file1.py
---new_hunk---
```
[new hunk code, annotated with line numbers]
```
---old_hunk---
```
[old hunk code]
```
...
'
Don't output line numbers in the 'improved code' snippets.
Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
"""

View File

@ -1,9 +1,9 @@
[pr_review_prompt]
system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests.
Your task is to provide constructive and concise feedback for the PR, and also provide meaningfull code suggestions to improve the new PR code (the '+' lines).
Your task is to provide constructive and concise feedback for the PR, and also provide meaningful code suggestions to improve the new PR code (the '+' lines).
{%- if num_code_suggestions > 0 %}
- Provide up to {{ num_code_suggestions }} code suggestions.
- Try to focus on the most important suggestions, like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningfull code improvements, like performance, vulnerability, modularity, and best practices.
- Try to focus on the most important suggestions, like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningful code improvements, like performance, vulnerability, modularity, and best practices.
- Suggestions should focus on improving the new added code lines.
- Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines).
{%- endif %}

View File

@ -0,0 +1,46 @@
[pr_sort_code_suggestions_prompt]
system="""
"""
user="""You are given a list of code suggestions to improve a PR:
{{ suggestion_str|trim }}
Your task is to sort the code suggestions by their order of importance, and return a list with sorting order.
The sorting order is a list of pairs, where each pair contains the index of the suggestion in the original list.
Rank the suggestions based on their importance to improving the PR, with critical issues first and minor issues last.
You must use the following YAML schema to format your answer:
```yaml
Sort Order:
type: array
maxItems: {{ suggestion_list|length }}
uniqueItems: true
items:
suggestion number:
type: integer
minimum: 1
maximum: {{ suggestion_list|length }}
importance order:
type: integer
minimum: 1
maximum: {{ suggestion_list|length }}
```
Example output:
```yaml
Sort Order:
- suggestion number: 1
importance order: 2
- suggestion number: 2
importance order: 3
- suggestion number: 3
importance order: 1
```
Make sure to output a valid YAML. Use multi-line block scalar ('|') if needed.
Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
Response (should be a valid YAML, and nothing else):
```yaml
"""

View File

@ -0,0 +1,215 @@
from typing import List
import copy
import json
import logging
import textwrap
from typing import Dict, Any
import yaml
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import try_fix_json
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
class PRExtendedCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = AiHandler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"num_code_suggestions": get_settings().pr_extendeted_code_suggestions.num_code_suggestions_per_chunk,
"extra_instructions": get_settings().pr_extendeted_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_code_suggestions_prompt.system,
get_settings().pr_code_suggestions_prompt.user)
async def run(self):
logging.info('Generating code suggestions for PR...')
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
data = await retry_with_fallback_models(self._prepare_prediction)
if get_settings().pr_extendeted_code_suggestions.rank_suggestions:
logging.info('Ranking Suggestions...')
data['Code suggestions'] = await self.rank_suggestions(data['Code suggestions'])
logging.info('Preparing PR review...')
if get_settings().config.publish_output:
logging.info('Pushing PR review...')
self.git_provider.remove_initial_comment()
logging.info('Pushing inline code comments...')
self.push_inline_code_suggestions(data)
async def _prepare_prediction(self, model: str) -> dict:
logging.info('Getting PR diff...')
patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
max_calls=get_settings().pr_extendeted_code_suggestions.max_number_of_calls)
logging.info('Getting multi AI predictions...')
prediction_list = []
for i, patches_diff in enumerate(patches_diff_list):
logging.info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
self.patches_diff = patches_diff
prediction = await self._get_prediction(model)
prediction_list.append(prediction)
self.prediction_list = prediction_list
data = {}
for prediction in prediction_list:
self.prediction = prediction
data_per_chunk = self._prepare_pr_code_suggestions()
if "Code suggestions" in data:
data["Code suggestions"].extend(data_per_chunk["Code suggestions"])
else:
data.update(data_per_chunk)
self.data = data
return data
async def _get_prediction(self, model: str):
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
return response
def _prepare_pr_code_suggestions(self) -> str:
review = self.prediction.strip()
try:
data = json.loads(review)
except json.decoder.JSONDecodeError:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not parse json response: {review}")
data = try_fix_json(review, code_suggestions=True)
return data
def push_inline_code_suggestions(self, data):
code_suggestions = []
if not data['Code suggestions']:
return self.git_provider.publish_comment('No suggestions found to improve this PR.')
for d in data['Code suggestions']:
try:
if get_settings().config.verbosity_level >= 1:
logging.info(f"suggestion: {d}")
relevant_file = d['relevant file'].strip()
relevant_lines_str = d['relevant lines'].strip()
if ',' in relevant_lines_str: # handling 'relevant lines': '181, 190' or '178-184, 188-194'
relevant_lines_str = relevant_lines_str.split(',')[0]
relevant_lines_start = int(relevant_lines_str.split('-')[0]) # absolute position
relevant_lines_end = int(relevant_lines_str.split('-')[-1])
content = d['suggestion content']
new_code_snippet = d['improved code']
if new_code_snippet:
new_code_snippet = self.dedent_code(relevant_file, relevant_lines_start, new_code_snippet)
body = f"**Suggestion:** {content}\n```suggestion\n" + new_code_snippet + "\n```"
code_suggestions.append({'body': body, 'relevant_file': relevant_file,
'relevant_lines_start': relevant_lines_start,
'relevant_lines_end': relevant_lines_end})
except Exception:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not parse suggestion: {d}")
self.git_provider.publish_code_suggestions(code_suggestions)
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
try: # dedent code snippet
self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
else self.git_provider.get_diff_files()
original_initial_line = None
for file in self.diff_files:
if file.filename.strip() == relevant_file:
original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
break
if original_initial_line:
suggested_initial_line = new_code_snippet.splitlines()[0]
original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
delta_spaces = original_initial_spaces - suggested_initial_spaces
if delta_spaces > 0:
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
except Exception as e:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
return new_code_snippet
async def rank_suggestions(self, data: List) -> List:
"""
Call a model to rank (sort) code suggestions based on their importance order.
Args:
data (List): A list of code suggestions to be ranked.
Returns:
List: The ranked list of code suggestions.
"""
suggestion_list = []
# remove invalid suggestions
for i, suggestion in enumerate(data):
if suggestion['existing code'] != suggestion['improved code']:
suggestion_list.append(suggestion)
data_sorted = [[]] * len(suggestion_list)
try:
suggestion_str = ""
for i, suggestion in enumerate(suggestion_list):
suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
variables = {'suggestion_list': suggestion_list, 'suggestion_str': suggestion_str}
model = get_settings().config.model
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt, user=user_prompt)
sort_order = yaml.safe_load(response)
for s in sort_order['Sort Order']:
suggestion_number = s['suggestion number']
importance_order = s['importance order']
data_sorted[importance_order - 1] = suggestion_list[suggestion_number - 1]
if get_settings().pr_extendeted_code_suggestions.final_clip_factor != 1:
new_len = int(0.5 + len(data_sorted) * get_settings().pr_extendeted_code_suggestions.final_clip_factor)
data_sorted = data_sorted[:new_len]
except Exception as e:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not sort suggestions, error: {e}")
data_sorted = suggestion_list
return data_sorted