From 3e6263e1cc80d5e25d6b0dc1f24a0aa2d79fdf41 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Sun, 30 Jun 2024 17:33:48 +0300 Subject: [PATCH] async calls --- pr_agent/settings/configuration.toml | 3 ++- pr_agent/tools/pr_description.py | 37 +++++++++++++++++++--------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 76521939..fe5e325e 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -76,7 +76,8 @@ use_description_markers=false include_generated_by_header=true # large pr mode 💎 enable_large_pr_handling=true -max_ai_calls=3 +max_ai_calls=4 +async_ai_calls=true mention_extra_files=true #custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other'] diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 866cb55f..9fd45417 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -1,3 +1,4 @@ +import asyncio import copy import re from functools import partial @@ -168,8 +169,7 @@ class PRDescription: return None large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings() - patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, - large_pr_handling=large_pr_handling) + patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling) if not large_pr_handling or patches_diff: self.patches_diff = patches_diff if patches_diff: @@ -192,13 +192,27 @@ class PRDescription: self.git_provider, token_handler_only_files_prompt, model) # get the files prediction for each patch + if not get_settings().pr_description.async_ai_calls: + results = [] + for i, patches in enumerate(patches_compressed_list): # sync calls + patches_diff = "\n".join(patches) + get_logger().debug(f"PR diff number {i + 1} for describe files") + prediction_files = await self._get_prediction(model, patches_diff, + prompt="pr_description_only_files_prompts") + results.append(prediction_files) + else: # async calls + tasks = [] + for i, patches in enumerate(patches_compressed_list): + patches_diff = "\n".join(patches) + get_logger().debug(f"PR diff number {i + 1} for describe files") + task = asyncio.create_task( + self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts")) + tasks.append(task) + # Wait for all tasks to complete + results = await asyncio.gather(*tasks) file_description_str_list = [] - for i, patches in enumerate(patches_compressed_list): - patches_diff = "\n".join(patches) - get_logger().debug(f"PR diff number {i + 1} for describe files") - prediction_files = await self._get_prediction(model, patches_diff, - prompt="pr_description_only_files_prompts") - prediction_files = prediction_files.strip().removeprefix('```yaml').strip('`').strip() + for i, result in enumerate(results): + prediction_files = result.strip().removeprefix('```yaml').strip('`').strip() if load_yaml(prediction_files) and prediction_files.startswith('pr_files'): prediction_files = prediction_files.removeprefix('pr_files:').strip() file_description_str_list.append(prediction_files) @@ -248,7 +262,7 @@ class PRDescription: changes_title: | ... label: | - not processed (token-limit) + additional files (token-limit) """ files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip() # final processing @@ -259,7 +273,6 @@ class PRDescription: get_logger().debug(f"Using only headers for describe {self.pr_id}") self.prediction = prediction_headers - async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str: variables = copy.deepcopy(self.vars) variables["diff"] = patches_diff # update diff @@ -481,9 +494,9 @@ class PRDescription: filename_publish = f"{filename_publish}
{file_changes_title_code_br}
" diff_plus_minus = "" delta_nbsp = "" - diff_files = self.git_provider.diff_files + diff_files = self.git_provider.get_diff_files() for f in diff_files: - if f.filename.lower() == filename.lower(): + if f.filename.lower().strip('/') == filename.lower().strip('/'): num_plus_lines = f.num_plus_lines num_minus_lines = f.num_minus_lines diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}"