From b85679e5e44e212a1e3ea3a171492f85477cdb9a Mon Sep 17 00:00:00 2001
From: mrT23
Date: Tue, 22 Aug 2023 09:42:59 +0300
Subject: [PATCH] improve --extend

---
 pr_agent/agent/pr_agent.py                 |   2 -
 pr_agent/algo/utils.py                     |   3 +-
 pr_agent/cli.py                            |  20 +-
 pr_agent/servers/help.py                   |   2 +-
 pr_agent/settings/configuration.toml       |  10 +-
 pr_agent/tools/pr_code_suggestions.py      | 107 ++++++++-
 .../tools/pr_extended_code_suggestions.py  | 215 ------------------
 7 files changed, 122 insertions(+), 237 deletions(-)
 delete mode 100644 pr_agent/tools/pr_extended_code_suggestions.py

diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py
index f722695c..70121f3c 100644
--- a/pr_agent/agent/pr_agent.py
+++ b/pr_agent/agent/pr_agent.py
@@ -11,7 +11,6 @@
 from pr_agent.tools.pr_description import PRDescription
 from pr_agent.tools.pr_information_from_user import PRInformationFromUser
 from pr_agent.tools.pr_questions import PRQuestions
 from pr_agent.tools.pr_reviewer import PRReviewer
-from pr_agent.tools.pr_extended_code_suggestions import PRExtendedCodeSuggestions
 from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
 from pr_agent.tools.pr_config import PRConfig
@@ -26,7 +25,6 @@ command2class = {
     "describe_pr": PRDescription,
     "improve": PRCodeSuggestions,
     "improve_code": PRCodeSuggestions,
-    "extended_improve": PRExtendedCodeSuggestions,
     "ask": PRQuestions,
     "ask_question": PRQuestions,
     "update_changelog": PRUpdateChangelog,
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 87e206cf..2139203f 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -247,7 +247,8 @@ def update_settings_from_args(args: List[str]) -> List[str]:
                 arg = arg.strip('-').strip()
                 vals = arg.split('=', 1)
                 if len(vals) != 2:
-                    logging.error(f'Invalid argument format: {arg}')
+                    if len(vals) > 2:  # --extended is a valid argument
+                        logging.error(f'Invalid argument format: {arg}')
                     other_args.append(arg)
                     continue
                 key, value = _fix_key_value(*vals)
diff --git a/pr_agent/cli.py b/pr_agent/cli.py
index 0f871041..01c1a7ec 100644
--- a/pr_agent/cli.py
+++ b/pr_agent/cli.py
@@ -19,13 +19,21 @@ For example:
 - cli.py --pr_url=... reflect
 
 Supported commands:
-review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
-ask / ask_question [question] - Ask a question about the PR.
-describe / describe_pr - Modify the PR title and description based on the PR's contents.
-improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
-reflect - Ask the PR author questions about the PR.
-update_changelog - Update the changelog based on the PR's contents.
+-review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
+-ask / ask_question [question] - Ask a question about the PR.
+
+-describe / describe_pr - Modify the PR title and description based on the PR's contents.
+
+-improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
+Extended mode ('improve --extended') employs several calls and provides more thorough feedback.
+
+-reflect - Ask the PR author questions about the PR.
+
+-update_changelog - Update the changelog based on the PR's contents.
+
+
+Configuration:
 
 To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
 For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
 """)
diff --git a/pr_agent/servers/help.py b/pr_agent/servers/help.py
index 1ee7fc4d..0f3f3caa 100644
--- a/pr_agent/servers/help.py
+++ b/pr_agent/servers/help.py
@@ -1,7 +1,7 @@
 commands_text = "> **/review [-i]**: Request a review of your Pull Request. For an incremental review, which only " \
                 "considers changes since the last review, include the '-i' option.\n" \
                 "> **/describe**: Modify the PR title and description based on the contents of the PR.\n" \
-                "> **/improve**: Suggest improvements to the code in the PR. \n" \
+                "> **/improve [--extended]**: Suggest improvements to the code in the PR. Extended mode employs several calls and provides more thorough feedback. \n" \
                 "> **/ask \\<QUESTION\\>**: Pose a question about the PR.\n" \
                 "> **/update_changelog**: Update the changelog based on the PR's contents.\n\n" \
                 ">To edit any configuration parameter from **configuration.toml**, add --config_path=new_value\n" \
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index 00c9b453..b1d19f97 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -31,14 +31,12 @@ extra_instructions = ""
 [pr_code_suggestions] # /improve #
 num_code_suggestions=4
 extra_instructions = ""
-
-[pr_extendeted_code_suggestions] # /extended_improve #
+rank_suggestions = false
+# params for '/improve --extended' mode
 num_code_suggestions_per_chunk=8
-extra_instructions = ""
+rank_extended_suggestions = true
 max_number_of_calls = 5
-rank_suggestions = true
-final_clip_factor = 0.5
-
+final_clip_factor = 0.9
 
 [pr_update_changelog] # /update_changelog #
 push_changelog_changes=false
diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py
index ebb88b21..4dc2f400 100644
--- a/pr_agent/tools/pr_code_suggestions.py
+++ b/pr_agent/tools/pr_code_suggestions.py
@@ -2,11 +2,13 @@ import copy
 import json
 import logging
 import textwrap
+from typing import List
 
+import yaml
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, get_pr_multi_diffs
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import try_fix_json
 from pr_agent.config_loader import get_settings
@@ -22,6 +24,13 @@ class PRCodeSuggestions:
             self.git_provider.get_languages(), self.git_provider.get_files()
         )
 
+        # extended mode
+        self.is_extended = any(["extended" in arg for arg in args])
+        if self.is_extended:
+            num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
+        else:
+            num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
+
         self.ai_handler = AiHandler()
         self.patches_diff = None
         self.prediction = None
@@ -32,7 +41,7 @@ class PRCodeSuggestions:
             "description": self.git_provider.get_pr_description(),
             "language": self.main_language,
             "diff": "", # empty diff for initial calculation
-            "num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions,
+            "num_code_suggestions": num_code_suggestions,
             "extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
             "commit_messages_str": self.git_provider.get_commit_messages(),
         }
@@ -42,14 +51,22 @@ class PRCodeSuggestions:
                                           get_settings().pr_code_suggestions_prompt.user)
 
     async def run(self):
-        # assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
"Bitbucket is not supported for now" - logging.info('Generating code suggestions for PR...') if get_settings().config.publish_output: self.git_provider.publish_comment("Preparing review...", is_temporary=True) - await retry_with_fallback_models(self._prepare_prediction) + logging.info('Preparing PR review...') - data = self._prepare_pr_code_suggestions() + if not self.is_extended: + await retry_with_fallback_models(self._prepare_prediction) + data = self._prepare_pr_code_suggestions() + else: + data = await retry_with_fallback_models(self._prepare_prediction_extended) + + if (not self.is_extended and get_settings().pr_code_suggestions.rank_suggestions) or \ + (self.is_extended and get_settings().pr_code_suggestions.rank_extended_suggestions): + logging.info('Ranking Suggestions...') + data['Code suggestions'] = await self.rank_suggestions(data['Code suggestions']) + if get_settings().config.publish_output: logging.info('Pushing PR review...') self.git_provider.remove_initial_comment() @@ -145,3 +162,81 @@ class PRCodeSuggestions: return new_code_snippet + async def _prepare_prediction_extended(self, model: str) -> dict: + logging.info('Getting PR diff...') + patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model, + max_calls=get_settings().pr_code_suggestions.max_number_of_calls) + + logging.info('Getting multi AI predictions...') + prediction_list = [] + for i, patches_diff in enumerate(patches_diff_list): + logging.info(f"Processing chunk {i + 1} of {len(patches_diff_list)}") + self.patches_diff = patches_diff + prediction = await self._get_prediction(model) + prediction_list.append(prediction) + self.prediction_list = prediction_list + + data = {} + for prediction in prediction_list: + self.prediction = prediction + data_per_chunk = self._prepare_pr_code_suggestions() + if "Code suggestions" in data: + data["Code suggestions"].extend(data_per_chunk["Code suggestions"]) + else: + data.update(data_per_chunk) + self.data = data + return data + + async def rank_suggestions(self, data: List) -> List: + """ + Call a model to rank (sort) code suggestions based on their importance order. + + Args: + data (List): A list of code suggestions to be ranked. + + Returns: + List: The ranked list of code suggestions. 
+ """ + + suggestion_list = [] + # remove invalid suggestions + for i, suggestion in enumerate(data): + if suggestion['existing code'] != suggestion['improved code']: + suggestion_list.append(suggestion) + + data_sorted = [[]] * len(suggestion_list) + + try: + suggestion_str = "" + for i, suggestion in enumerate(suggestion_list): + suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n' + + variables = {'suggestion_list': suggestion_list, 'suggestion_str': suggestion_str} + model = get_settings().config.model + environment = Environment(undefined=StrictUndefined) + system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render( + variables) + user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables) + if get_settings().config.verbosity_level >= 2: + logging.info(f"\nSystem prompt:\n{system_prompt}") + logging.info(f"\nUser prompt:\n{user_prompt}") + response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt, + user=user_prompt) + + sort_order = yaml.safe_load(response) + for s in sort_order['Sort Order']: + suggestion_number = s['suggestion number'] + importance_order = s['importance order'] + data_sorted[importance_order - 1] = suggestion_list[suggestion_number - 1] + + if get_settings().pr_extendeted_code_suggestions.final_clip_factor != 1: + new_len = int(0.5 + len(data_sorted) * get_settings().pr_extendeted_code_suggestions.final_clip_factor) + data_sorted = data_sorted[:new_len] + except Exception as e: + if get_settings().config.verbosity_level >= 1: + logging.info(f"Could not sort suggestions, error: {e}") + data_sorted = suggestion_list + + return data_sorted + + diff --git a/pr_agent/tools/pr_extended_code_suggestions.py b/pr_agent/tools/pr_extended_code_suggestions.py deleted file mode 100644 index 17f7b570..00000000 --- a/pr_agent/tools/pr_extended_code_suggestions.py +++ /dev/null @@ -1,215 +0,0 @@ -from typing import List -import copy -import json -import logging -import textwrap -from typing import Dict, Any - -import yaml -from jinja2 import Environment, StrictUndefined - -from pr_agent.algo.ai_handler import AiHandler -from pr_agent.algo.pr_processing import get_pr_multi_diffs, retry_with_fallback_models -from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import try_fix_json -from pr_agent.config_loader import get_settings -from pr_agent.git_providers import get_git_provider -from pr_agent.git_providers.git_provider import get_main_pr_language - - -class PRExtendedCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None): - - self.git_provider = get_git_provider()(pr_url) - self.main_language = get_main_pr_language( - self.git_provider.get_languages(), self.git_provider.get_files() - ) - - self.ai_handler = AiHandler() - self.patches_diff = None - self.prediction = None - self.cli_mode = cli_mode - self.vars = { - "title": self.git_provider.pr.title, - "branch": self.git_provider.get_pr_branch(), - "description": self.git_provider.get_pr_description(), - "language": self.main_language, - "diff": "", # empty diff for initial calculation - "num_code_suggestions": get_settings().pr_extendeted_code_suggestions.num_code_suggestions_per_chunk, - "extra_instructions": get_settings().pr_extendeted_code_suggestions.extra_instructions, - "commit_messages_str": self.git_provider.get_commit_messages(), - } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - 
-                                          get_settings().pr_code_suggestions_prompt.system,
-                                          get_settings().pr_code_suggestions_prompt.user)
-
-    async def run(self):
-        logging.info('Generating code suggestions for PR...')
-        if get_settings().config.publish_output:
-            self.git_provider.publish_comment("Preparing review...", is_temporary=True)
-        data = await retry_with_fallback_models(self._prepare_prediction)
-
-        if get_settings().pr_extendeted_code_suggestions.rank_suggestions:
-            logging.info('Ranking Suggestions...')
-            data['Code suggestions'] = await self.rank_suggestions(data['Code suggestions'])
-
-        logging.info('Preparing PR review...')
-        if get_settings().config.publish_output:
-            logging.info('Pushing PR review...')
-            self.git_provider.remove_initial_comment()
-            logging.info('Pushing inline code comments...')
-            self.push_inline_code_suggestions(data)
-
-    async def _prepare_prediction(self, model: str) -> dict:
-        logging.info('Getting PR diff...')
-        patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
-                                               max_calls=get_settings().pr_extendeted_code_suggestions.max_number_of_calls)
-
-        logging.info('Getting multi AI predictions...')
-        prediction_list = []
-        for i, patches_diff in enumerate(patches_diff_list):
-            logging.info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
-            self.patches_diff = patches_diff
-            prediction = await self._get_prediction(model)
-            prediction_list.append(prediction)
-        self.prediction_list = prediction_list
-
-        data = {}
-        for prediction in prediction_list:
-            self.prediction = prediction
-            data_per_chunk = self._prepare_pr_code_suggestions()
-            if "Code suggestions" in data:
-                data["Code suggestions"].extend(data_per_chunk["Code suggestions"])
-            else:
-                data.update(data_per_chunk)
-        self.data = data
-        return data
-
-    async def _get_prediction(self, model: str):
-        variables = copy.deepcopy(self.vars)
-        variables["diff"] = self.patches_diff # update diff
-        environment = Environment(undefined=StrictUndefined)
-        system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
-        user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
-        if get_settings().config.verbosity_level >= 2:
-            logging.info(f"\nSystem prompt:\n{system_prompt}")
-            logging.info(f"\nUser prompt:\n{user_prompt}")
-        response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
-                                                                        system=system_prompt, user=user_prompt)
-
-        return response
-
-    def _prepare_pr_code_suggestions(self) -> str:
-        review = self.prediction.strip()
-        try:
-            data = json.loads(review)
-        except json.decoder.JSONDecodeError:
-            if get_settings().config.verbosity_level >= 1:
-                logging.info(f"Could not parse json response: {review}")
-            data = try_fix_json(review, code_suggestions=True)
-        return data
-
-    def push_inline_code_suggestions(self, data):
-        code_suggestions = []
-
-        if not data['Code suggestions']:
-            return self.git_provider.publish_comment('No suggestions found to improve this PR.')
-
-        for d in data['Code suggestions']:
-            try:
-                if get_settings().config.verbosity_level >= 1:
-                    logging.info(f"suggestion: {d}")
-                relevant_file = d['relevant file'].strip()
-                relevant_lines_str = d['relevant lines'].strip()
-                if ',' in relevant_lines_str: # handling 'relevant lines': '181, 190' or '178-184, 188-194'
-                    relevant_lines_str = relevant_lines_str.split(',')[0]
-                relevant_lines_start = int(relevant_lines_str.split('-')[0]) # absolute position
-                relevant_lines_end = int(relevant_lines_str.split('-')[-1])
-                content = d['suggestion content']
-                new_code_snippet = d['improved code']
-
-                if new_code_snippet:
-                    new_code_snippet = self.dedent_code(relevant_file, relevant_lines_start, new_code_snippet)
-
-                body = f"**Suggestion:** {content}\n```suggestion\n" + new_code_snippet + "\n```"
-                code_suggestions.append({'body': body, 'relevant_file': relevant_file,
-                                         'relevant_lines_start': relevant_lines_start,
-                                         'relevant_lines_end': relevant_lines_end})
-            except Exception:
-                if get_settings().config.verbosity_level >= 1:
-                    logging.info(f"Could not parse suggestion: {d}")
-
-        self.git_provider.publish_code_suggestions(code_suggestions)
-
-    def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
-        try: # dedent code snippet
-            self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
-                else self.git_provider.get_diff_files()
-            original_initial_line = None
-            for file in self.diff_files:
-                if file.filename.strip() == relevant_file:
-                    original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
-                    break
-            if original_initial_line:
-                suggested_initial_line = new_code_snippet.splitlines()[0]
-                original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
-                suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
-                delta_spaces = original_initial_spaces - suggested_initial_spaces
-                if delta_spaces > 0:
-                    new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
-        except Exception as e:
-            if get_settings().config.verbosity_level >= 1:
-                logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
-
-        return new_code_snippet
-
-    async def rank_suggestions(self, data: List) -> List:
-        """
-        Call a model to rank (sort) code suggestions based on their importance order.
-
-        Args:
-            data (List): A list of code suggestions to be ranked.
-
-        Returns:
-            List: The ranked list of code suggestions.
- """ - - suggestion_list = [] - # remove invalid suggestions - for i, suggestion in enumerate(data): - if suggestion['existing code'] != suggestion['improved code']: - suggestion_list.append(suggestion) - - data_sorted = [[]] * len(suggestion_list) - - try: - suggestion_str = "" - for i, suggestion in enumerate(suggestion_list): - suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n' - - variables = {'suggestion_list': suggestion_list, 'suggestion_str': suggestion_str} - model = get_settings().config.model - environment = Environment(undefined=StrictUndefined) - system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render(variables) - user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables) - if get_settings().config.verbosity_level >= 2: - logging.info(f"\nSystem prompt:\n{system_prompt}") - logging.info(f"\nUser prompt:\n{user_prompt}") - response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt, user=user_prompt) - - sort_order = yaml.safe_load(response) - for s in sort_order['Sort Order']: - suggestion_number = s['suggestion number'] - importance_order = s['importance order'] - data_sorted[importance_order - 1] = suggestion_list[suggestion_number - 1] - - if get_settings().pr_extendeted_code_suggestions.final_clip_factor != 1: - new_len = int(0.5 + len(data_sorted) * get_settings().pr_extendeted_code_suggestions.final_clip_factor) - data_sorted = data_sorted[:new_len] - except Exception as e: - if get_settings().config.verbosity_level >= 1: - logging.info(f"Could not sort suggestions, error: {e}") - data_sorted = suggestion_list - - return data_sorted