mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 03:40:38 +08:00
parallel_calls
This commit is contained in:
@ -82,8 +82,9 @@ enable_help_text=true
|
||||
# params for '/improve --extended' mode
|
||||
auto_extended_mode=true
|
||||
num_code_suggestions_per_chunk=5
|
||||
rank_extended_suggestions = false
|
||||
max_number_of_calls = 3
|
||||
parallel_calls = true
|
||||
rank_extended_suggestions = false
|
||||
final_clip_factor = 0.8
|
||||
|
||||
[pr_add_docs] # /add_docs #
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import textwrap
|
||||
from functools import partial
|
||||
@ -111,18 +112,18 @@ class PRCodeSuggestions:
|
||||
|
||||
async def _prepare_prediction(self, model: str):
|
||||
get_logger().info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider,
|
||||
patches_diff = get_pr_diff(self.git_provider,
|
||||
self.token_handler,
|
||||
model,
|
||||
add_line_numbers_to_hunks=True,
|
||||
disable_extra_lines=True)
|
||||
|
||||
get_logger().info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction(model)
|
||||
self.prediction = await self._get_prediction(model, patches_diff)
|
||||
|
||||
async def _get_prediction(self, model: str):
|
||||
async def _get_prediction(self, model: str, patches_diff: str):
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = self.patches_diff # update diff
|
||||
variables["diff"] = patches_diff # update diff
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
|
||||
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
|
||||
@ -229,14 +230,18 @@ class PRCodeSuggestions:
|
||||
patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
|
||||
max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
|
||||
|
||||
get_logger().info('Getting multi AI predictions...')
|
||||
prediction_list = []
|
||||
for i, patches_diff in enumerate(patches_diff_list):
|
||||
get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
|
||||
self.patches_diff = patches_diff
|
||||
prediction = await self._get_prediction(model) # toDo: parallelize
|
||||
prediction_list.append(prediction)
|
||||
self.prediction_list = prediction_list
|
||||
# parallelize calls to AI:
|
||||
if get_settings().pr_code_suggestions.parallel_calls:
|
||||
get_logger().info('Getting multi AI predictions in parallel...')
|
||||
prediction_list = await asyncio.gather(*[self._get_prediction(model, patches_diff) for patches_diff in patches_diff_list])
|
||||
self.prediction_list = prediction_list
|
||||
else:
|
||||
get_logger().info('Getting multi AI predictions...')
|
||||
prediction_list = []
|
||||
for i, patches_diff in enumerate(patches_diff_list):
|
||||
get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
|
||||
prediction = await self._get_prediction(model, patches_diff) # toDo: parallelize
|
||||
prediction_list.append(prediction)
|
||||
|
||||
data = {}
|
||||
for prediction in prediction_list:
|
||||
|
Reference in New Issue
Block a user