From eed23a7aaa3e838585a5f2ab1e9a2b7e61613b5a Mon Sep 17 00:00:00 2001 From: mrT23 Date: Mon, 4 Mar 2024 08:16:05 +0200 Subject: [PATCH] Add truncation and summarization features to PR code suggestions --- .../settings/pr_code_suggestions_prompts.toml | 10 +++--- pr_agent/tools/pr_code_suggestions.py | 36 ++++++++++++++++--- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/pr_agent/settings/pr_code_suggestions_prompts.toml b/pr_agent/settings/pr_code_suggestions_prompts.toml index 8e29b5a9..aae955a3 100644 --- a/pr_agent/settings/pr_code_suggestions_prompts.toml +++ b/pr_agent/settings/pr_code_suggestions_prompts.toml @@ -40,12 +40,14 @@ Specific instructions: {%- if extra_instructions %} + Extra instructions from the user: ====== {{ extra_instructions }} ====== {%- endif %} + The output must be a YAML object equivalent to type $PRCodeSuggestions, according to the following Pydantic definitions: ===== class CodeSuggestion(BaseModel): @@ -80,18 +82,18 @@ code_suggestions: ... {%- if summarize_mode %} existing_code: | - def func1(): + ... improved_code: | ... one_sentence_summary: | ... relevant_lines_start: 12 - relevant_lines_end: 12 + relevant_lines_end: 13 {%- else %} existing_code: | - def func1(): + ... relevant_lines_start: 12 - relevant_lines_end: 12 + relevant_lines_end: 13 improved_code: | ... 
{%- endif %} diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index e96c31d3..b0e076b0 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -137,6 +137,7 @@ class PRCodeSuggestions: model, add_line_numbers_to_hunks=True, disable_extra_lines=True) + if self.patches_diff: get_logger().debug(f"PR diff", artifact=self.patches_diff) self.prediction = await self._get_prediction(model, self.patches_diff) @@ -150,12 +151,23 @@ class PRCodeSuggestions: environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables) - response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, system=system_prompt, user=user_prompt) return response + @staticmethod + def _truncate_if_needed(suggestion): + max_code_suggestion_length = get_settings().get("PR_CODE_SUGGESTIONS.MAX_CODE_SUGGESTION_LENGTH", 0) + suggestion_truncation_message = get_settings().get("PR_CODE_SUGGESTIONS.SUGGESTION_TRUNCATION_MESSAGE", "") + if max_code_suggestion_length > 0: + if len(suggestion['improved_code']) > max_code_suggestion_length: + get_logger().info(f"Truncating suggestion from {len(suggestion['improved_code'])} " + f"characters to {max_code_suggestion_length} characters") + suggestion['improved_code'] = suggestion['improved_code'][:max_code_suggestion_length] + suggestion['improved_code'] += f"\n{suggestion_truncation_message}" + return suggestion + def _prepare_pr_code_suggestions(self) -> Dict: review = self.prediction.strip() data = load_yaml(review, @@ -165,8 +177,22 @@ class PRCodeSuggestions: # remove invalid suggestions suggestion_list = [] + one_sentence_summary_list = [] for i, suggestion in enumerate(data['code_suggestions']): - if suggestion['existing_code'] != 
suggestion['improved_code']: + if get_settings().pr_code_suggestions.summarize: + if not suggestion or 'one_sentence_summary' not in suggestion or 'label' not in suggestion or 'relevant_file' not in suggestion: + get_logger().debug(f"Skipping suggestion {i + 1}, because it is invalid: {suggestion}") + continue + + if suggestion['one_sentence_summary'] in one_sentence_summary_list: + get_logger().debug(f"Skipping suggestion {i + 1}, because it is a duplicate: {suggestion}") + continue + + if ('existing_code' in suggestion) and ('improved_code' in suggestion) and ( + suggestion['existing_code'] != suggestion['improved_code']): + suggestion = self._truncate_if_needed(suggestion) + if get_settings().pr_code_suggestions.summarize: + one_sentence_summary_list.append(suggestion['one_sentence_summary']) suggestion_list.append(suggestion) else: get_logger().debug( @@ -244,13 +270,14 @@ class PRCodeSuggestions: async def _prepare_prediction_extended(self, model: str) -> dict: self.patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model, - max_calls=get_settings().pr_code_suggestions.max_number_of_calls) + max_calls=get_settings().pr_code_suggestions.max_number_of_calls) if self.patches_diff_list: get_logger().debug(f"PR diff", artifact=self.patches_diff_list) # parallelize calls to AI: if get_settings().pr_code_suggestions.parallel_calls: - prediction_list = await asyncio.gather(*[self._get_prediction(model, patches_diff) for patches_diff in self.patches_diff_list]) + prediction_list = await asyncio.gather( + *[self._get_prediction(model, patches_diff) for patches_diff in self.patches_diff_list]) self.prediction_list = prediction_list else: prediction_list = [] @@ -304,7 +331,6 @@ class PRCodeSuggestions: system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render( variables) user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables) - response, 
finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt, user=user_prompt)