improve --extend

This commit is contained in:
mrT23
2023-08-22 09:42:59 +03:00
parent fb9335f424
commit b85679e5e4
7 changed files with 122 additions and 237 deletions

View File

@ -11,7 +11,6 @@ from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_information_from_user import PRInformationFromUser
from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_extended_code_suggestions import PRExtendedCodeSuggestions
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
from pr_agent.tools.pr_config import PRConfig from pr_agent.tools.pr_config import PRConfig
@ -26,7 +25,6 @@ command2class = {
"describe_pr": PRDescription, "describe_pr": PRDescription,
"improve": PRCodeSuggestions, "improve": PRCodeSuggestions,
"improve_code": PRCodeSuggestions, "improve_code": PRCodeSuggestions,
"extended_improve": PRExtendedCodeSuggestions,
"ask": PRQuestions, "ask": PRQuestions,
"ask_question": PRQuestions, "ask_question": PRQuestions,
"update_changelog": PRUpdateChangelog, "update_changelog": PRUpdateChangelog,

View File

@ -247,7 +247,8 @@ def update_settings_from_args(args: List[str]) -> List[str]:
arg = arg.strip('-').strip() arg = arg.strip('-').strip()
vals = arg.split('=', 1) vals = arg.split('=', 1)
if len(vals) != 2: if len(vals) != 2:
logging.error(f'Invalid argument format: {arg}') if len(vals) > 2: # --extended is a valid argument
logging.error(f'Invalid argument format: {arg}')
other_args.append(arg) other_args.append(arg)
continue continue
key, value = _fix_key_value(*vals) key, value = _fix_key_value(*vals)

View File

@ -19,13 +19,21 @@ For example:
- cli.py --pr_url=... reflect - cli.py --pr_url=... reflect
Supported commands: Supported commands:
review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement. -review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
ask / ask_question [question] - Ask a question about the PR.
describe / describe_pr - Modify the PR title and description based on the PR's contents.
improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
reflect - Ask the PR author questions about the PR.
update_changelog - Update the changelog based on the PR's contents.
-ask / ask_question [question] - Ask a question about the PR.
-describe / describe_pr - Modify the PR title and description based on the PR's contents.
-improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
Extended mode ('improve --extended') employs several calls, and provides a more thorough feedback
-reflect - Ask the PR author questions about the PR.
-update_changelog - Update the changelog based on the PR's contents.
Configuration:
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>. To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."' For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""") """)

View File

@ -1,7 +1,7 @@
commands_text = "> **/review [-i]**: Request a review of your Pull Request. For an incremental review, which only " \ commands_text = "> **/review [-i]**: Request a review of your Pull Request. For an incremental review, which only " \
"considers changes since the last review, include the '-i' option.\n" \ "considers changes since the last review, include the '-i' option.\n" \
"> **/describe**: Modify the PR title and description based on the contents of the PR.\n" \ "> **/describe**: Modify the PR title and description based on the contents of the PR.\n" \
"> **/improve**: Suggest improvements to the code in the PR. \n" \ "> **/improve [--extended]**: Suggest improvements to the code in the PR. Extended mode employs several calls, and provides a more thorough feedback. \n" \
"> **/ask \\<QUESTION\\>**: Pose a question about the PR.\n" \ "> **/ask \\<QUESTION\\>**: Pose a question about the PR.\n" \
"> **/update_changelog**: Update the changelog based on the PR's contents.\n\n" \ "> **/update_changelog**: Update the changelog based on the PR's contents.\n\n" \
">To edit any configuration parameter from **configuration.toml**, add --config_path=new_value\n" \ ">To edit any configuration parameter from **configuration.toml**, add --config_path=new_value\n" \

View File

@ -31,14 +31,12 @@ extra_instructions = ""
[pr_code_suggestions] # /improve # [pr_code_suggestions] # /improve #
num_code_suggestions=4 num_code_suggestions=4
extra_instructions = "" extra_instructions = ""
rank_suggestions = false
[pr_extendeted_code_suggestions] # /extended_improve # # params for '/improve --extended' mode
num_code_suggestions_per_chunk=8 num_code_suggestions_per_chunk=8
extra_instructions = "" rank_extended_suggestions = true
max_number_of_calls = 5 max_number_of_calls = 5
rank_suggestions = true final_clip_factor = 0.9
final_clip_factor = 0.5
[pr_update_changelog] # /update_changelog # [pr_update_changelog] # /update_changelog #
push_changelog_changes=false push_changelog_changes=false

View File

@ -2,11 +2,13 @@ import copy
import json import json
import logging import logging
import textwrap import textwrap
from typing import List
import yaml
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, get_pr_multi_diffs
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import try_fix_json from pr_agent.algo.utils import try_fix_json
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -22,6 +24,13 @@ class PRCodeSuggestions:
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
# extended mode
self.is_extended = any(["extended" in arg for arg in args])
if self.is_extended:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
@ -32,7 +41,7 @@ class PRCodeSuggestions:
"description": self.git_provider.get_pr_description(), "description": self.git_provider.get_pr_description(),
"language": self.main_language, "language": self.main_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions, "num_code_suggestions": num_code_suggestions,
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions, "extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(), "commit_messages_str": self.git_provider.get_commit_messages(),
} }
@ -42,14 +51,22 @@ class PRCodeSuggestions:
get_settings().pr_code_suggestions_prompt.user) get_settings().pr_code_suggestions_prompt.user)
async def run(self): async def run(self):
# assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
logging.info('Generating code suggestions for PR...') logging.info('Generating code suggestions for PR...')
if get_settings().config.publish_output: if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True) self.git_provider.publish_comment("Preparing review...", is_temporary=True)
await retry_with_fallback_models(self._prepare_prediction)
logging.info('Preparing PR review...') logging.info('Preparing PR review...')
data = self._prepare_pr_code_suggestions() if not self.is_extended:
await retry_with_fallback_models(self._prepare_prediction)
data = self._prepare_pr_code_suggestions()
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended)
if (not self.is_extended and get_settings().pr_code_suggestions.rank_suggestions) or \
(self.is_extended and get_settings().pr_code_suggestions.rank_extended_suggestions):
logging.info('Ranking Suggestions...')
data['Code suggestions'] = await self.rank_suggestions(data['Code suggestions'])
if get_settings().config.publish_output: if get_settings().config.publish_output:
logging.info('Pushing PR review...') logging.info('Pushing PR review...')
self.git_provider.remove_initial_comment() self.git_provider.remove_initial_comment()
@ -145,3 +162,81 @@ class PRCodeSuggestions:
return new_code_snippet return new_code_snippet
async def _prepare_prediction_extended(self, model: str) -> dict:
logging.info('Getting PR diff...')
patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
logging.info('Getting multi AI predictions...')
prediction_list = []
for i, patches_diff in enumerate(patches_diff_list):
logging.info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
self.patches_diff = patches_diff
prediction = await self._get_prediction(model)
prediction_list.append(prediction)
self.prediction_list = prediction_list
data = {}
for prediction in prediction_list:
self.prediction = prediction
data_per_chunk = self._prepare_pr_code_suggestions()
if "Code suggestions" in data:
data["Code suggestions"].extend(data_per_chunk["Code suggestions"])
else:
data.update(data_per_chunk)
self.data = data
return data
async def rank_suggestions(self, data: List) -> List:
"""
Call a model to rank (sort) code suggestions based on their importance order.
Args:
data (List): A list of code suggestions to be ranked.
Returns:
List: The ranked list of code suggestions.
"""
suggestion_list = []
# remove invalid suggestions
for i, suggestion in enumerate(data):
if suggestion['existing code'] != suggestion['improved code']:
suggestion_list.append(suggestion)
data_sorted = [[]] * len(suggestion_list)
try:
suggestion_str = ""
for i, suggestion in enumerate(suggestion_list):
suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
variables = {'suggestion_list': suggestion_list, 'suggestion_str': suggestion_str}
model = get_settings().config.model
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render(
variables)
user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt,
user=user_prompt)
sort_order = yaml.safe_load(response)
for s in sort_order['Sort Order']:
suggestion_number = s['suggestion number']
importance_order = s['importance order']
data_sorted[importance_order - 1] = suggestion_list[suggestion_number - 1]
if get_settings().pr_extendeted_code_suggestions.final_clip_factor != 1:
new_len = int(0.5 + len(data_sorted) * get_settings().pr_extendeted_code_suggestions.final_clip_factor)
data_sorted = data_sorted[:new_len]
except Exception as e:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not sort suggestions, error: {e}")
data_sorted = suggestion_list
return data_sorted

View File

@ -1,215 +0,0 @@
from typing import List
import copy
import json
import logging
import textwrap
from typing import Dict, Any
import yaml
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import try_fix_json
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
class PRExtendedCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = AiHandler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
"num_code_suggestions": get_settings().pr_extendeted_code_suggestions.num_code_suggestions_per_chunk,
"extra_instructions": get_settings().pr_extendeted_code_suggestions.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
get_settings().pr_code_suggestions_prompt.system,
get_settings().pr_code_suggestions_prompt.user)
async def run(self):
logging.info('Generating code suggestions for PR...')
if get_settings().config.publish_output:
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
data = await retry_with_fallback_models(self._prepare_prediction)
if get_settings().pr_extendeted_code_suggestions.rank_suggestions:
logging.info('Ranking Suggestions...')
data['Code suggestions'] = await self.rank_suggestions(data['Code suggestions'])
logging.info('Preparing PR review...')
if get_settings().config.publish_output:
logging.info('Pushing PR review...')
self.git_provider.remove_initial_comment()
logging.info('Pushing inline code comments...')
self.push_inline_code_suggestions(data)
async def _prepare_prediction(self, model: str) -> dict:
logging.info('Getting PR diff...')
patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
max_calls=get_settings().pr_extendeted_code_suggestions.max_number_of_calls)
logging.info('Getting multi AI predictions...')
prediction_list = []
for i, patches_diff in enumerate(patches_diff_list):
logging.info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
self.patches_diff = patches_diff
prediction = await self._get_prediction(model)
prediction_list.append(prediction)
self.prediction_list = prediction_list
data = {}
for prediction in prediction_list:
self.prediction = prediction
data_per_chunk = self._prepare_pr_code_suggestions()
if "Code suggestions" in data:
data["Code suggestions"].extend(data_per_chunk["Code suggestions"])
else:
data.update(data_per_chunk)
self.data = data
return data
async def _get_prediction(self, model: str):
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
return response
def _prepare_pr_code_suggestions(self) -> str:
review = self.prediction.strip()
try:
data = json.loads(review)
except json.decoder.JSONDecodeError:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not parse json response: {review}")
data = try_fix_json(review, code_suggestions=True)
return data
def push_inline_code_suggestions(self, data):
code_suggestions = []
if not data['Code suggestions']:
return self.git_provider.publish_comment('No suggestions found to improve this PR.')
for d in data['Code suggestions']:
try:
if get_settings().config.verbosity_level >= 1:
logging.info(f"suggestion: {d}")
relevant_file = d['relevant file'].strip()
relevant_lines_str = d['relevant lines'].strip()
if ',' in relevant_lines_str: # handling 'relevant lines': '181, 190' or '178-184, 188-194'
relevant_lines_str = relevant_lines_str.split(',')[0]
relevant_lines_start = int(relevant_lines_str.split('-')[0]) # absolute position
relevant_lines_end = int(relevant_lines_str.split('-')[-1])
content = d['suggestion content']
new_code_snippet = d['improved code']
if new_code_snippet:
new_code_snippet = self.dedent_code(relevant_file, relevant_lines_start, new_code_snippet)
body = f"**Suggestion:** {content}\n```suggestion\n" + new_code_snippet + "\n```"
code_suggestions.append({'body': body, 'relevant_file': relevant_file,
'relevant_lines_start': relevant_lines_start,
'relevant_lines_end': relevant_lines_end})
except Exception:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not parse suggestion: {d}")
self.git_provider.publish_code_suggestions(code_suggestions)
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
try: # dedent code snippet
self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
else self.git_provider.get_diff_files()
original_initial_line = None
for file in self.diff_files:
if file.filename.strip() == relevant_file:
original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
break
if original_initial_line:
suggested_initial_line = new_code_snippet.splitlines()[0]
original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
delta_spaces = original_initial_spaces - suggested_initial_spaces
if delta_spaces > 0:
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
except Exception as e:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
return new_code_snippet
async def rank_suggestions(self, data: List) -> List:
"""
Call a model to rank (sort) code suggestions based on their importance order.
Args:
data (List): A list of code suggestions to be ranked.
Returns:
List: The ranked list of code suggestions.
"""
suggestion_list = []
# remove invalid suggestions
for i, suggestion in enumerate(data):
if suggestion['existing code'] != suggestion['improved code']:
suggestion_list.append(suggestion)
data_sorted = [[]] * len(suggestion_list)
try:
suggestion_str = ""
for i, suggestion in enumerate(suggestion_list):
suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
variables = {'suggestion_list': suggestion_list, 'suggestion_str': suggestion_str}
model = get_settings().config.model
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables)
if get_settings().config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt, user=user_prompt)
sort_order = yaml.safe_load(response)
for s in sort_order['Sort Order']:
suggestion_number = s['suggestion number']
importance_order = s['importance order']
data_sorted[importance_order - 1] = suggestion_list[suggestion_number - 1]
if get_settings().pr_extendeted_code_suggestions.final_clip_factor != 1:
new_len = int(0.5 + len(data_sorted) * get_settings().pr_extendeted_code_suggestions.final_clip_factor)
data_sorted = data_sorted[:new_len]
except Exception as e:
if get_settings().config.verbosity_level >= 1:
logging.info(f"Could not sort suggestions, error: {e}")
data_sorted = suggestion_list
return data_sorted