async calls

This commit is contained in:
mrT23
2024-06-30 17:33:48 +03:00
parent 6a5f43f8ce
commit 3e6263e1cc
2 changed files with 27 additions and 13 deletions

View File

@ -76,7 +76,8 @@ use_description_markers=false
include_generated_by_header=true include_generated_by_header=true
# large pr mode 💎 # large pr mode 💎
enable_large_pr_handling=true enable_large_pr_handling=true
max_ai_calls=3 max_ai_calls=4
async_ai_calls=true
mention_extra_files=true mention_extra_files=true
#custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other'] #custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other']

View File

@ -1,3 +1,4 @@
import asyncio
import copy import copy
import re import re
from functools import partial from functools import partial
@ -168,8 +169,7 @@ class PRDescription:
return None return None
large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings() large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling)
large_pr_handling=large_pr_handling)
if not large_pr_handling or patches_diff: if not large_pr_handling or patches_diff:
self.patches_diff = patches_diff self.patches_diff = patches_diff
if patches_diff: if patches_diff:
@ -192,13 +192,27 @@ class PRDescription:
self.git_provider, token_handler_only_files_prompt, model) self.git_provider, token_handler_only_files_prompt, model)
# get the files prediction for each patch # get the files prediction for each patch
if not get_settings().pr_description.async_ai_calls:
results = []
for i, patches in enumerate(patches_compressed_list): # sync calls
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
prediction_files = await self._get_prediction(model, patches_diff,
prompt="pr_description_only_files_prompts")
results.append(prediction_files)
else: # async calls
tasks = []
for i, patches in enumerate(patches_compressed_list):
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
task = asyncio.create_task(
self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts"))
tasks.append(task)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks)
file_description_str_list = [] file_description_str_list = []
for i, patches in enumerate(patches_compressed_list): for i, result in enumerate(results):
patches_diff = "\n".join(patches) prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
get_logger().debug(f"PR diff number {i + 1} for describe files")
prediction_files = await self._get_prediction(model, patches_diff,
prompt="pr_description_only_files_prompts")
prediction_files = prediction_files.strip().removeprefix('```yaml').strip('`').strip()
if load_yaml(prediction_files) and prediction_files.startswith('pr_files'): if load_yaml(prediction_files) and prediction_files.startswith('pr_files'):
prediction_files = prediction_files.removeprefix('pr_files:').strip() prediction_files = prediction_files.removeprefix('pr_files:').strip()
file_description_str_list.append(prediction_files) file_description_str_list.append(prediction_files)
@ -248,7 +262,7 @@ class PRDescription:
changes_title: | changes_title: |
... ...
label: | label: |
not processed (token-limit) additional files (token-limit)
""" """
files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip() files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip()
# final processing # final processing
@ -259,7 +273,6 @@ class PRDescription:
get_logger().debug(f"Using only headers for describe {self.pr_id}") get_logger().debug(f"Using only headers for describe {self.pr_id}")
self.prediction = prediction_headers self.prediction = prediction_headers
async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str: async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
variables = copy.deepcopy(self.vars) variables = copy.deepcopy(self.vars)
variables["diff"] = patches_diff # update diff variables["diff"] = patches_diff # update diff
@ -481,9 +494,9 @@ class PRDescription:
filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>" filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
diff_plus_minus = "" diff_plus_minus = ""
delta_nbsp = "" delta_nbsp = ""
diff_files = self.git_provider.diff_files diff_files = self.git_provider.get_diff_files()
for f in diff_files: for f in diff_files:
if f.filename.lower() == filename.lower(): if f.filename.lower().strip('/') == filename.lower().strip('/'):
num_plus_lines = f.num_plus_lines num_plus_lines = f.num_plus_lines
num_minus_lines = f.num_minus_lines num_minus_lines = f.num_minus_lines
diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}" diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}"