From 3e6263e1cc80d5e25d6b0dc1f24a0aa2d79fdf41 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Sun, 30 Jun 2024 17:33:48 +0300 Subject: [PATCH 1/4] async calls --- pr_agent/settings/configuration.toml | 3 ++- pr_agent/tools/pr_description.py | 37 +++++++++++++++++++--------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 76521939..fe5e325e 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -76,7 +76,8 @@ use_description_markers=false include_generated_by_header=true # large pr mode 💎 enable_large_pr_handling=true -max_ai_calls=3 +max_ai_calls=4 +async_ai_calls=true mention_extra_files=true #custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other'] diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 866cb55f..9fd45417 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -1,3 +1,4 @@ +import asyncio import copy import re from functools import partial @@ -168,8 +169,7 @@ class PRDescription: return None large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings() - patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, - large_pr_handling=large_pr_handling) + patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling) if not large_pr_handling or patches_diff: self.patches_diff = patches_diff if patches_diff: @@ -192,13 +192,27 @@ class PRDescription: self.git_provider, token_handler_only_files_prompt, model) # get the files prediction for each patch + if not get_settings().pr_description.async_ai_calls: + results = [] + for i, patches in enumerate(patches_compressed_list): # sync calls + patches_diff = "\n".join(patches) + get_logger().debug(f"PR diff number {i + 1} for describe files") + prediction_files = await self._get_prediction(model, patches_diff, + prompt="pr_description_only_files_prompts") + results.append(prediction_files) + else: # async calls + tasks = [] + for i, patches in enumerate(patches_compressed_list): + patches_diff = "\n".join(patches) + get_logger().debug(f"PR diff number {i + 1} for describe files") + task = asyncio.create_task( + self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts")) + tasks.append(task) + # Wait for all tasks to complete + results = await asyncio.gather(*tasks) file_description_str_list = [] - for i, patches in enumerate(patches_compressed_list): - patches_diff = "\n".join(patches) - get_logger().debug(f"PR diff number {i + 1} for describe files") - prediction_files = await self._get_prediction(model, patches_diff, - prompt="pr_description_only_files_prompts") - prediction_files = prediction_files.strip().removeprefix('```yaml').strip('`').strip() + for i, result in enumerate(results): + prediction_files = result.strip().removeprefix('```yaml').strip('`').strip() if load_yaml(prediction_files) and prediction_files.startswith('pr_files'): prediction_files = prediction_files.removeprefix('pr_files:').strip() file_description_str_list.append(prediction_files) @@ -248,7 +262,7 @@ class PRDescription: changes_title: | ... label: | - not processed (token-limit) + additional files (token-limit) """ files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip() # final processing @@ -259,7 +273,6 @@ class PRDescription: get_logger().debug(f"Using only headers for describe {self.pr_id}") self.prediction = prediction_headers - async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str: variables = copy.deepcopy(self.vars) variables["diff"] = patches_diff # update diff @@ -481,9 +494,9 @@ class PRDescription: filename_publish = f"{filename_publish}
{file_changes_title_code_br}
" diff_plus_minus = "" delta_nbsp = "" - diff_files = self.git_provider.diff_files + diff_files = self.git_provider.get_diff_files() for f in diff_files: - if f.filename.lower() == filename.lower(): + if f.filename.lower().strip('/') == filename.lower().strip('/'): num_plus_lines = f.num_plus_lines num_minus_lines = f.num_minus_lines diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}" From f2cb70ea67ddb32c97f087a45b2cb8bd191cfdb9 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Sun, 30 Jun 2024 18:38:06 +0300 Subject: [PATCH 2/4] extend additional files --- pr_agent/algo/pr_processing.py | 43 +++++++++++++++++++------------- pr_agent/tools/pr_add_docs.py | 4 --- pr_agent/tools/pr_description.py | 41 +++++++++++++++++++++++++++++- 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 818e98ff..4e770be2 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -26,8 +26,12 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600 -def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str, - add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False, large_pr_handling=False) -> str: +def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, + model: str, + add_line_numbers_to_hunks: bool = False, + disable_extra_lines: bool = False, + large_pr_handling=False, + return_remaining_files=False): if disable_extra_lines: PATCH_EXTRA_LINES = 0 else: @@ -72,7 +76,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, " f"pruning diff.") patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \ - pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks) + pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling) if large_pr_handling and len(patches_compressed_list) > 1: get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.") @@ -129,7 +133,10 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, " f"deleted_list_str: {deleted_list_str}") - return final_diff + if not return_remaining_files: + return final_diff + else: + return final_diff, remaining_files_list def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str, @@ -208,7 +215,8 @@ def pr_generate_extended_diff(pr_languages: list, def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str, - convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, dict, list]: + convert_hunks_to_line_numbers: bool, + large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]: deleted_files_list = [] # sort each one of the languages in top_langs by the number of tokens in the diff @@ -253,18 +261,19 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo files_in_patches_list.append(files_in_patch_list) # additional iterations (if needed) - NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize - for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1): - if remaining_files_list: - total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, - file_dict, - max_tokens_model, - remaining_files_list, token_handler) - patches_list.append(patches) - total_tokens_list.append(total_tokens) - files_in_patches_list.append(files_in_patch_list) - else: - break + if large_pr_handling: + NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize + for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1): + if remaining_files_list: + total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, + file_dict, + max_tokens_model, + remaining_files_list, token_handler) + patches_list.append(patches) + total_tokens_list.append(total_tokens) + files_in_patches_list.append(files_in_patch_list) + else: + break return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index a671dd3b..52d60ef9 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -71,10 +71,6 @@ class PRAddDocs: async def _prepare_prediction(self, model: str): get_logger().info('Getting PR diff...') - # Disable adding docs to scripts and other non-relevant text files - from pr_agent.algo.language_handler import bad_extensions - bad_extensions += get_settings().docs_blacklist_extensions.docs_blacklist - self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 9fd45417..b16a3a32 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -4,6 +4,7 @@ import re from functools import partial from typing import List, Tuple +import yaml from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler @@ -169,12 +170,21 @@ class PRDescription: return None large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings() - patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling) + output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True) + if isinstance(output, tuple): + patches_diff, remaining_files_list = output + else: + patches_diff = output + remaining_files_list = [] if not large_pr_handling or patches_diff: self.patches_diff = patches_diff if patches_diff: get_logger().debug(f"PR diff", artifact=self.patches_diff) self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt") + if (remaining_files_list and 'pr_files' in self.prediction and 'label:' in self.prediction and + get_settings().pr_description.mention_extra_files): + get_logger().debug(f"Extending additional files, {len(remaining_files_list)} files") + self.prediction = await self.extend_additional_files(remaining_files_list) else: get_logger().error(f"Error getting PR diff {self.pr_id}") self.prediction = None @@ -273,6 +283,35 @@ class PRDescription: get_logger().debug(f"Using only headers for describe {self.pr_id}") self.prediction = prediction_headers + async def extend_additional_files(self, remaining_files_list) -> str: + prediction = self.prediction + try: + original_prediction_dict = load_yaml(self.prediction) + prediction_extra = "pr_files:" + for file in remaining_files_list: + extra_file_yaml = f"""\ +- filename: | + {file} + changes_summary: | + ... + changes_title: | + ... + label: | + additional files (token-limit) +""" + prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip() + prediction_extra_dict = load_yaml(prediction_extra) + # merge the two dictionaries + if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict): + original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"]) + new_yaml = yaml.dump(original_prediction_dict) + if load_yaml(new_yaml): + prediction = new_yaml + return prediction + except Exception as e: + get_logger().error(f"Error extending additional files {self.pr_id}: {e}") + return self.prediction + async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str: variables = copy.deepcopy(self.vars) variables["diff"] = patches_diff # update diff From f058c09a68c1f3b40ffae980e8ff36bc6b0ad1c2 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Sun, 30 Jun 2024 20:20:50 +0300 Subject: [PATCH 3/4] extend additional files --- pr_agent/algo/pr_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 4e770be2..fb59b0f3 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -166,7 +166,7 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH pass patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \ - pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks) + pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling=True) return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list From 8d87b41cf248ee630895fbbd95a855a9a04b5925 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Sun, 30 Jun 2024 20:28:32 +0300 Subject: [PATCH 4/4] extend additional files --- pr_agent/algo/pr_processing.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index fb59b0f3..96230d66 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -269,9 +269,10 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo file_dict, max_tokens_model, remaining_files_list, token_handler) - patches_list.append(patches) - total_tokens_list.append(total_tokens) - files_in_patches_list.append(files_in_patch_list) + if patches: + patches_list.append(patches) + total_tokens_list.append(total_tokens) + files_in_patches_list.append(files_in_patch_list) else: break