mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-04 04:40:38 +08:00
@ -26,8 +26,12 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False, large_pr_handling=False) -> str:
|
model: str,
|
||||||
|
add_line_numbers_to_hunks: bool = False,
|
||||||
|
disable_extra_lines: bool = False,
|
||||||
|
large_pr_handling=False,
|
||||||
|
return_remaining_files=False):
|
||||||
if disable_extra_lines:
|
if disable_extra_lines:
|
||||||
PATCH_EXTRA_LINES = 0
|
PATCH_EXTRA_LINES = 0
|
||||||
else:
|
else:
|
||||||
@ -72,7 +76,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
|
|||||||
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
|
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
|
||||||
f"pruning diff.")
|
f"pruning diff.")
|
||||||
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
|
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
|
||||||
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
|
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling)
|
||||||
|
|
||||||
if large_pr_handling and len(patches_compressed_list) > 1:
|
if large_pr_handling and len(patches_compressed_list) > 1:
|
||||||
get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
|
get_logger().info(f"Large PR handling mode, and found {len(patches_compressed_list)} patches with original diff.")
|
||||||
@ -129,7 +133,10 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
|
|||||||
|
|
||||||
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
|
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
|
||||||
f"deleted_list_str: {deleted_list_str}")
|
f"deleted_list_str: {deleted_list_str}")
|
||||||
return final_diff
|
if not return_remaining_files:
|
||||||
|
return final_diff
|
||||||
|
else:
|
||||||
|
return final_diff, remaining_files_list
|
||||||
|
|
||||||
|
|
||||||
def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
||||||
@ -159,7 +166,7 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenH
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
|
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
|
||||||
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
|
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks, large_pr_handling=True)
|
||||||
|
|
||||||
return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
|
return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
|
||||||
|
|
||||||
@ -208,7 +215,8 @@ def pr_generate_extended_diff(pr_languages: list,
|
|||||||
|
|
||||||
|
|
||||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
||||||
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, dict, list]:
|
convert_hunks_to_line_numbers: bool,
|
||||||
|
large_pr_handling: bool) -> Tuple[list, list, list, list, dict, list]:
|
||||||
deleted_files_list = []
|
deleted_files_list = []
|
||||||
|
|
||||||
# sort each one of the languages in top_langs by the number of tokens in the diff
|
# sort each one of the languages in top_langs by the number of tokens in the diff
|
||||||
@ -253,18 +261,20 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
|
|||||||
files_in_patches_list.append(files_in_patch_list)
|
files_in_patches_list.append(files_in_patch_list)
|
||||||
|
|
||||||
# additional iterations (if needed)
|
# additional iterations (if needed)
|
||||||
NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize
|
if large_pr_handling:
|
||||||
for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
|
NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize
|
||||||
if remaining_files_list:
|
for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
|
||||||
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
|
if remaining_files_list:
|
||||||
file_dict,
|
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
|
||||||
max_tokens_model,
|
file_dict,
|
||||||
remaining_files_list, token_handler)
|
max_tokens_model,
|
||||||
patches_list.append(patches)
|
remaining_files_list, token_handler)
|
||||||
total_tokens_list.append(total_tokens)
|
if patches:
|
||||||
files_in_patches_list.append(files_in_patch_list)
|
patches_list.append(patches)
|
||||||
else:
|
total_tokens_list.append(total_tokens)
|
||||||
break
|
files_in_patches_list.append(files_in_patch_list)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
|
return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
|
||||||
|
|
||||||
|
@ -76,7 +76,8 @@ use_description_markers=false
|
|||||||
include_generated_by_header=true
|
include_generated_by_header=true
|
||||||
# large pr mode 💎
|
# large pr mode 💎
|
||||||
enable_large_pr_handling=true
|
enable_large_pr_handling=true
|
||||||
max_ai_calls=3
|
max_ai_calls=4
|
||||||
|
async_ai_calls=true
|
||||||
mention_extra_files=true
|
mention_extra_files=true
|
||||||
#custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other']
|
#custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other']
|
||||||
|
|
||||||
|
@ -71,10 +71,6 @@ class PRAddDocs:
|
|||||||
async def _prepare_prediction(self, model: str):
|
async def _prepare_prediction(self, model: str):
|
||||||
get_logger().info('Getting PR diff...')
|
get_logger().info('Getting PR diff...')
|
||||||
|
|
||||||
# Disable adding docs to scripts and other non-relevant text files
|
|
||||||
from pr_agent.algo.language_handler import bad_extensions
|
|
||||||
bad_extensions += get_settings().docs_blacklist_extensions.docs_blacklist
|
|
||||||
|
|
||||||
self.patches_diff = get_pr_diff(self.git_provider,
|
self.patches_diff = get_pr_diff(self.git_provider,
|
||||||
self.token_handler,
|
self.token_handler,
|
||||||
model,
|
model,
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import yaml
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
@ -168,13 +170,21 @@ class PRDescription:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
|
large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
|
||||||
patches_diff = get_pr_diff(self.git_provider, self.token_handler, model,
|
output = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling, return_remaining_files=True)
|
||||||
large_pr_handling=large_pr_handling)
|
if isinstance(output, tuple):
|
||||||
|
patches_diff, remaining_files_list = output
|
||||||
|
else:
|
||||||
|
patches_diff = output
|
||||||
|
remaining_files_list = []
|
||||||
if not large_pr_handling or patches_diff:
|
if not large_pr_handling or patches_diff:
|
||||||
self.patches_diff = patches_diff
|
self.patches_diff = patches_diff
|
||||||
if patches_diff:
|
if patches_diff:
|
||||||
get_logger().debug(f"PR diff", artifact=self.patches_diff)
|
get_logger().debug(f"PR diff", artifact=self.patches_diff)
|
||||||
self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
|
self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
|
||||||
|
if (remaining_files_list and 'pr_files' in self.prediction and 'label:' in self.prediction and
|
||||||
|
get_settings().pr_description.mention_extra_files):
|
||||||
|
get_logger().debug(f"Extending additional files, {len(remaining_files_list)} files")
|
||||||
|
self.prediction = await self.extend_additional_files(remaining_files_list)
|
||||||
else:
|
else:
|
||||||
get_logger().error(f"Error getting PR diff {self.pr_id}")
|
get_logger().error(f"Error getting PR diff {self.pr_id}")
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
@ -192,13 +202,27 @@ class PRDescription:
|
|||||||
self.git_provider, token_handler_only_files_prompt, model)
|
self.git_provider, token_handler_only_files_prompt, model)
|
||||||
|
|
||||||
# get the files prediction for each patch
|
# get the files prediction for each patch
|
||||||
|
if not get_settings().pr_description.async_ai_calls:
|
||||||
|
results = []
|
||||||
|
for i, patches in enumerate(patches_compressed_list): # sync calls
|
||||||
|
patches_diff = "\n".join(patches)
|
||||||
|
get_logger().debug(f"PR diff number {i + 1} for describe files")
|
||||||
|
prediction_files = await self._get_prediction(model, patches_diff,
|
||||||
|
prompt="pr_description_only_files_prompts")
|
||||||
|
results.append(prediction_files)
|
||||||
|
else: # async calls
|
||||||
|
tasks = []
|
||||||
|
for i, patches in enumerate(patches_compressed_list):
|
||||||
|
patches_diff = "\n".join(patches)
|
||||||
|
get_logger().debug(f"PR diff number {i + 1} for describe files")
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._get_prediction(model, patches_diff, prompt="pr_description_only_files_prompts"))
|
||||||
|
tasks.append(task)
|
||||||
|
# Wait for all tasks to complete
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
file_description_str_list = []
|
file_description_str_list = []
|
||||||
for i, patches in enumerate(patches_compressed_list):
|
for i, result in enumerate(results):
|
||||||
patches_diff = "\n".join(patches)
|
prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
|
||||||
get_logger().debug(f"PR diff number {i + 1} for describe files")
|
|
||||||
prediction_files = await self._get_prediction(model, patches_diff,
|
|
||||||
prompt="pr_description_only_files_prompts")
|
|
||||||
prediction_files = prediction_files.strip().removeprefix('```yaml').strip('`').strip()
|
|
||||||
if load_yaml(prediction_files) and prediction_files.startswith('pr_files'):
|
if load_yaml(prediction_files) and prediction_files.startswith('pr_files'):
|
||||||
prediction_files = prediction_files.removeprefix('pr_files:').strip()
|
prediction_files = prediction_files.removeprefix('pr_files:').strip()
|
||||||
file_description_str_list.append(prediction_files)
|
file_description_str_list.append(prediction_files)
|
||||||
@ -248,7 +272,7 @@ class PRDescription:
|
|||||||
changes_title: |
|
changes_title: |
|
||||||
...
|
...
|
||||||
label: |
|
label: |
|
||||||
not processed (token-limit)
|
additional files (token-limit)
|
||||||
"""
|
"""
|
||||||
files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip()
|
files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip()
|
||||||
# final processing
|
# final processing
|
||||||
@ -259,6 +283,34 @@ class PRDescription:
|
|||||||
get_logger().debug(f"Using only headers for describe {self.pr_id}")
|
get_logger().debug(f"Using only headers for describe {self.pr_id}")
|
||||||
self.prediction = prediction_headers
|
self.prediction = prediction_headers
|
||||||
|
|
||||||
|
async def extend_additional_files(self, remaining_files_list) -> str:
    """
    Append placeholder `pr_files` entries for files that were left out of the diff.

    Each file in `remaining_files_list` is rendered as a YAML list item labelled
    "additional files (token-limit)", the items are parsed with `load_yaml`, and
    the resulting entries are appended to the `pr_files` list of the prediction
    already stored in `self.prediction`.

    Args:
        remaining_files_list: filenames to add as placeholder entries.

    Returns:
        The re-serialized YAML prediction containing the extra entries, or the
        unchanged `self.prediction` when there is nothing to add, when either
        document does not parse to a dict, or when any step fails.
    """
    # Nothing to append: skip the YAML round-trip entirely. Without this guard an
    # empty list yields a bare "pr_files:" document, and extending with its parsed
    # value fails and is reported as an error even though nothing went wrong.
    if not remaining_files_list:
        return self.prediction

    prediction = self.prediction
    try:
        original_prediction_dict = load_yaml(self.prediction)

        # One YAML list item per remaining file; literal block scalars ("|") keep
        # the filename from being interpreted as YAML syntax.
        extra_entries = []
        for file in remaining_files_list:
            extra_file_yaml = f"""\
- filename: |
    {file}
  changes_summary: |
    ...
  changes_title: |
    ...
  label: |
    additional files (token-limit)
"""
            extra_entries.append(extra_file_yaml.strip())
        prediction_extra = "\n".join(["pr_files:", *extra_entries])
        prediction_extra_dict = load_yaml(prediction_extra)

        # merge the two dictionaries
        if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
            original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
            new_yaml = yaml.dump(original_prediction_dict)
            # Only adopt the merged document if it still parses.
            if load_yaml(new_yaml):
                prediction = new_yaml
        return prediction
    except Exception as e:
        # Best effort: extending the file list must never lose the original prediction.
        get_logger().error(f"Error extending additional files {self.pr_id}: {e}")
        return self.prediction
|
||||||
|
|
||||||
async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
|
async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
|
||||||
variables = copy.deepcopy(self.vars)
|
variables = copy.deepcopy(self.vars)
|
||||||
@ -481,9 +533,9 @@ class PRDescription:
|
|||||||
filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
|
filename_publish = f"<strong>{filename_publish}</strong><dd>{file_changes_title_code_br}</dd>"
|
||||||
diff_plus_minus = ""
|
diff_plus_minus = ""
|
||||||
delta_nbsp = ""
|
delta_nbsp = ""
|
||||||
diff_files = self.git_provider.diff_files
|
diff_files = self.git_provider.get_diff_files()
|
||||||
for f in diff_files:
|
for f in diff_files:
|
||||||
if f.filename.lower() == filename.lower():
|
if f.filename.lower().strip('/') == filename.lower().strip('/'):
|
||||||
num_plus_lines = f.num_plus_lines
|
num_plus_lines = f.num_plus_lines
|
||||||
num_minus_lines = f.num_minus_lines
|
num_minus_lines = f.num_minus_lines
|
||||||
diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}"
|
diff_plus_minus += f"+{num_plus_lines}/-{num_minus_lines}"
|
||||||
|
Reference in New Issue
Block a user