parallel_calls

mrT23
2024-02-07 08:00:01 +02:00
parent ef1b0ce3e3
commit a7ce2b11b4
2 changed files with 19 additions and 13 deletions


@@ -82,8 +82,9 @@ enable_help_text=true
 # params for '/improve --extended' mode
 auto_extended_mode=true
 num_code_suggestions_per_chunk=5
-rank_extended_suggestions = false
 max_number_of_calls = 3
+parallel_calls = true
+rank_extended_suggestions = false
 final_clip_factor = 0.8

 [pr_add_docs] # /add_docs #
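The new parallel_calls flag defaults to true and sits with the other '/improve --extended' settings; the code reads it through get_settings().pr_code_suggestions.parallel_calls, as the Python hunks below show. A minimal sketch of the gating pattern (run_chunks, chunks and predict are hypothetical stand-ins, not names from this commit, and the import path for get_settings is assumed to be the usual pr-agent one):

import asyncio

from pr_agent.config_loader import get_settings  # assumed import path for the settings object


async def run_chunks(chunks, predict):
    # predict(chunk) stands in for a per-chunk AI prediction coroutine
    if get_settings().pr_code_suggestions.parallel_calls:
        # fire all chunk predictions at once and wait for all of them
        return await asyncio.gather(*[predict(chunk) for chunk in chunks])
    # sequential fallback, one chunk at a time
    return [await predict(chunk) for chunk in chunks]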


@@ -1,3 +1,4 @@
+import asyncio
 import copy
 import textwrap
 from functools import partial
@@ -111,18 +112,18 @@ class PRCodeSuggestions:
     async def _prepare_prediction(self, model: str):
         get_logger().info('Getting PR diff...')
-        self.patches_diff = get_pr_diff(self.git_provider,
+        patches_diff = get_pr_diff(self.git_provider,
                                         self.token_handler,
                                         model,
                                         add_line_numbers_to_hunks=True,
                                         disable_extra_lines=True)
         get_logger().info('Getting AI prediction...')
-        self.prediction = await self._get_prediction(model)
+        self.prediction = await self._get_prediction(model, patches_diff)

-    async def _get_prediction(self, model: str):
+    async def _get_prediction(self, model: str, patches_diff: str):
         variables = copy.deepcopy(self.vars)
-        variables["diff"] = self.patches_diff  # update diff
+        variables["diff"] = patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
         system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
         user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
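Passing patches_diff into _get_prediction instead of stashing it on self.patches_diff is what makes the parallel path in the next hunk safe: with shared instance state, concurrent calls would overwrite each other's diff before the awaited API call resumes. A self-contained sketch of that hazard (Demo, predict_from_state and predict_from_arg are illustrative names, not code from this commit):

import asyncio


class Demo:
    def __init__(self):
        self.patches_diff = None  # shared instance state, like the pre-change code

    async def predict_from_state(self):
        await asyncio.sleep(0)    # yield to the event loop, as a real API call would
        return self.patches_diff  # may see another task's chunk by now

    async def predict_from_arg(self, patches_diff):
        await asyncio.sleep(0)
        return patches_diff       # always the chunk this call was given


async def main():
    demo = Demo()

    async def via_state(chunk):
        demo.patches_diff = chunk  # last writer wins once both tasks have started
        return await demo.predict_from_state()

    print(await asyncio.gather(*[via_state(c) for c in ("chunk-1", "chunk-2")]))
    # typically ['chunk-2', 'chunk-2'] -- the race the refactor avoids
    print(await asyncio.gather(*[demo.predict_from_arg(c) for c in ("chunk-1", "chunk-2")]))
    # ['chunk-1', 'chunk-2']


asyncio.run(main())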
@@ -229,14 +230,18 @@ class PRCodeSuggestions:
         patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
                                                max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
-        get_logger().info('Getting multi AI predictions...')
-        prediction_list = []
-        for i, patches_diff in enumerate(patches_diff_list):
-            get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
-            self.patches_diff = patches_diff
-            prediction = await self._get_prediction(model)  # toDo: parallelize
-            prediction_list.append(prediction)
-        self.prediction_list = prediction_list
+        # parallelize calls to AI:
+        if get_settings().pr_code_suggestions.parallel_calls:
+            get_logger().info('Getting multi AI predictions in parallel...')
+            prediction_list = await asyncio.gather(*[self._get_prediction(model, patches_diff) for patches_diff in patches_diff_list])
+            self.prediction_list = prediction_list
+        else:
+            get_logger().info('Getting multi AI predictions...')
+            prediction_list = []
+            for i, patches_diff in enumerate(patches_diff_list):
+                get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
+                prediction = await self._get_prediction(model, patches_diff)  # toDo: parallelize
+                prediction_list.append(prediction)

         data = {}
         for prediction in prediction_list:
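asyncio.gather returns results in the same order as the awaitables it was given, so the parallel branch keeps prediction_list aligned with patches_diff_list just as the old sequential loop did; by default it also propagates the first exception, so one failed chunk aborts the whole gather (return_exceptions=True would be the usual way to tolerate partial failures, which this commit does not do). A runnable sketch of the fan-out, with fake_prediction as a hypothetical stand-in for self._get_prediction(model, patches_diff):

import asyncio


async def fake_prediction(chunk):
    # stand-in for one AI call on a single diff chunk
    await asyncio.sleep(0.01)
    return f"suggestions for {chunk}"


async def main():
    patches_diff_list = ["chunk-1", "chunk-2", "chunk-3"]
    # results come back in input order, so index i still maps to chunk i
    prediction_list = await asyncio.gather(
        *[fake_prediction(patches_diff) for patches_diff in patches_diff_list])
    print(prediction_list)


asyncio.run(main())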