diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index eea41a7f..7555c292 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -82,8 +82,9 @@ enable_help_text=true # params for '/improve --extended' mode auto_extended_mode=true num_code_suggestions_per_chunk=5 -rank_extended_suggestions = false max_number_of_calls = 3 +parallel_calls = true +rank_extended_suggestions = false final_clip_factor = 0.8 [pr_add_docs] # /add_docs # diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 5fc45619..de922730 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -1,3 +1,4 @@ +import asyncio import copy import textwrap from functools import partial @@ -111,18 +112,18 @@ class PRCodeSuggestions: async def _prepare_prediction(self, model: str): get_logger().info('Getting PR diff...') - self.patches_diff = get_pr_diff(self.git_provider, + patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, add_line_numbers_to_hunks=True, disable_extra_lines=True) get_logger().info('Getting AI prediction...') - self.prediction = await self._get_prediction(model) + self.prediction = await self._get_prediction(model, patches_diff) - async def _get_prediction(self, model: str): + async def _get_prediction(self, model: str, patches_diff: str): variables = copy.deepcopy(self.vars) - variables["diff"] = self.patches_diff # update diff + variables["diff"] = patches_diff # update diff environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables) @@ -229,14 +230,18 @@ class PRCodeSuggestions: patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model, max_calls=get_settings().pr_code_suggestions.max_number_of_calls) - get_logger().info('Getting multi AI predictions...') - prediction_list = [] - for i, patches_diff in enumerate(patches_diff_list): - get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}") - self.patches_diff = patches_diff - prediction = await self._get_prediction(model) # toDo: parallelize - prediction_list.append(prediction) - self.prediction_list = prediction_list + # parallelize calls to AI: + if get_settings().pr_code_suggestions.parallel_calls: + get_logger().info('Getting multi AI predictions in parallel...') + prediction_list = await asyncio.gather(*[self._get_prediction(model, patches_diff) for patches_diff in patches_diff_list]) + self.prediction_list = prediction_list + else: + get_logger().info('Getting multi AI predictions...') + prediction_list = [] + for i, patches_diff in enumerate(patches_diff_list): + get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}") + prediction = await self._get_prediction(model, patches_diff) # toDo: parallelize + prediction_list.append(prediction) data = {} for prediction in prediction_list: