commit 0f920bcc5b (parent 55a82382ef)
mrT23, 2024-06-26 20:11:20 +03:00
3 changed files with 333 additions and 123 deletions


@@ -24,26 +24,10 @@ ADDED_FILES_ = "Additional added files (insufficient token budget to process):\n
OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
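# token-budget margins: a patch is skipped once adding it would bring the diff within the soft
# threshold of the model limit, and nothing more is added once within the hard threshold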
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
"""
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
Args:
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull
request.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
pull request.
model (str): The name of the model used for tokenization.
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the
diff. Defaults to False.
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with
extra lines of context. Defaults to False.
Returns:
str: A string with the diff of the pull request, applying diff minimization techniques if needed.
"""
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False, large_pr_handling=False) -> str:
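# When large_pr_handling is set and the pruned diff still splits into multiple patches, an empty
# string is returned so the caller can fall back to the multi-patch flow (get_pr_diff_multiple_patchs).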
if disable_extra_lines:
PATCH_EXTRA_LINES = 0
else:
@@ -87,39 +71,99 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
# if we are over the limit, start pruning
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
f"pruning diff.")
patches_compressed, modified_file_names, deleted_file_names, added_file_names, total_tokens_new = \
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
if large_pr_handling and len(patches_compressed_list) > 1:
get_logger().info(f"Large PR handling mode: found {len(patches_compressed_list)} patches for the original diff.")
return "" # return empty string, as we generate multiple patches with a different prompt
# return the first patch
patches_compressed = patches_compressed_list[0]
total_tokens_new = total_tokens_list[0]
files_in_patch = files_in_patches_list[0]
# Insert additional information about added, modified, and deleted files if there is enough space
max_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
curr_token = total_tokens_new # == token_handler.count_tokens(final_diff)+token_handler.prompt_tokens
final_diff = "\n".join(patches_compressed)
delta_tokens = 10
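# only append the extra file-name lists below if at least delta_tokens of budget remain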
if added_file_names and (max_tokens - curr_token) > delta_tokens:
added_list_str = ADDED_FILES_ + "\n".join(added_file_names)
added_list_str = clip_tokens(added_list_str, max_tokens - curr_token)
if added_list_str:
final_diff = final_diff + "\n\n" + added_list_str
curr_token += token_handler.count_tokens(added_list_str) + 2
if modified_file_names and (max_tokens - curr_token) > delta_tokens:
modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names)
modified_list_str = clip_tokens(modified_list_str, max_tokens - curr_token)
if modified_list_str:
final_diff = final_diff + "\n\n" + modified_list_str
curr_token += token_handler.count_tokens(modified_list_str) + 2
if deleted_file_names and (max_tokens - curr_token) > delta_tokens:
deleted_list_str = DELETED_FILES_ + "\n".join(deleted_file_names)
deleted_list_str = clip_tokens(deleted_list_str, max_tokens - curr_token)
if deleted_list_str:
final_diff = final_diff + "\n\n" + deleted_list_str
try:
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
f"deleted_list_str: {deleted_list_str}")
except Exception as e:
pass
added_list_str = modified_list_str = deleted_list_str = ""
unprocessed_files = []
# generate the added, modified, and deleted files lists
if (max_tokens - curr_token) > delta_tokens:
for filename, file_values in file_dict.items():
if filename in files_in_patch:
continue
if file_values['edit_type'] == EDIT_TYPE.ADDED:
unprocessed_files.append(filename)
if not added_list_str:
added_list_str = ADDED_FILES_ + f"\n{filename}"
else:
added_list_str = added_list_str + f"\n{filename}"
elif file_values['edit_type'] in (EDIT_TYPE.MODIFIED, EDIT_TYPE.RENAMED):
unprocessed_files.append(filename)
if not modified_list_str:
modified_list_str = MORE_MODIFIED_FILES_ + f"\n{filename}"
else:
modified_list_str = modified_list_str + f"\n{filename}"
elif file_values['edit_type'] == EDIT_TYPE.DELETED:
# unprocessed_files.append(filename) # not needed here, because the file was deleted, so no need to process it
if not deleted_list_str:
deleted_list_str = DELETED_FILES_ + f"\n{filename}"
else:
deleted_list_str = deleted_list_str + f"\n{filename}"
# prune the added, modified, and deleted files lists, and add them to the final diff
added_list_str = clip_tokens(added_list_str, max_tokens - curr_token)
if added_list_str:
final_diff = final_diff + "\n\n" + added_list_str
curr_token += token_handler.count_tokens(added_list_str) + 2
modified_list_str = clip_tokens(modified_list_str, max_tokens - curr_token)
if modified_list_str:
final_diff = final_diff + "\n\n" + modified_list_str
curr_token += token_handler.count_tokens(modified_list_str) + 2
deleted_list_str = clip_tokens(deleted_list_str, max_tokens - curr_token)
if deleted_list_str:
final_diff = final_diff + "\n\n" + deleted_list_str
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
f"deleted_list_str: {deleted_list_str}")
return final_diff
def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler, model: str,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False):
try:
diff_files_original = git_provider.get_diff_files()
except RateLimitExceededException as e:
get_logger().error(f"Rate limit exceeded for git provider API. original message {e}")
raise
diff_files = filter_ignored(diff_files_original)
if diff_files != diff_files_original:
try:
get_logger().info(f"Filtered out {len(diff_files_original) - len(diff_files)} files")
new_names = set([a.filename for a in diff_files])
orig_names = set([a.filename for a in diff_files_original])
get_logger().info(f"Filtered out files: {orig_names - new_names}")
except Exception as e:
pass
# get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
if pr_languages:
try:
get_logger().info(f"PR main language: {pr_languages[0]['language']}")
except Exception as e:
pass
patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)
return patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
def pr_generate_extended_diff(pr_languages: list,
token_handler: TokenHandler,
add_line_numbers_to_hunks: bool,
@@ -164,41 +208,16 @@ def pr_generate_extended_diff(pr_languages: list,
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, int]:
"""
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of
tokens used.
Args:
top_langs (list): A list of dictionaries representing the languages used in the pull request and their
corresponding files.
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the
pull request.
model (str): The model used for tokenization.
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
Returns:
Tuple[list, list, list]: A tuple containing the following lists:
- patches: A list of compressed diff patches for each file in the pull request.
- modified_files_list: A list of file names that were skipped due to large patch size.
- deleted_files_list: A list of file names that were deleted in the pull request.
Minimization techniques to reduce the number of tokens:
0. Start from the largest diff patch to smaller ones
1. Don't use extend context lines around diff
2. Minimize deleted files
3. Minimize deleted hunks
4. Minimize all remaining files when you reach token limit
"""
patches = []
added_files_list = []
modified_files_list = []
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, dict, list]:
deleted_files_list = []
# sort each one of the languages in top_langs by the number of tokens in the diff
sorted_files = []
for lang in top_langs:
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
total_tokens = token_handler.prompt_tokens
# generate patches for each file, and count tokens
file_dict = {}
for file in sorted_files:
original_file_content_str = file.base_file
new_file_content_str = file.head_file
@@ -210,55 +229,85 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
patch = handle_patch_deletions(patch, original_file_content_str,
new_file_content_str, file.filename, file.edit_type)
if patch is None:
# if not deleted_files_list:
# total_tokens += token_handler.count_tokens(DELETED_FILES_)
if file.filename not in deleted_files_list:
deleted_files_list.append(file.filename)
# total_tokens += token_handler.count_tokens(file.filename) + 1
continue
if convert_hunks_to_line_numbers:
patch = convert_to_hunks_with_lines_numbers(patch, file)
new_patch_tokens = token_handler.count_tokens(patch)
file_dict[file.filename] = {'patch': patch, 'tokens': new_patch_tokens, 'edit_type': file.edit_type}
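# cache each file's patch, token count, and edit type, so the packing iterations below can
# reuse them without regenerating patches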
max_tokens_model = get_max_tokens(model)
# first iteration
files_in_patches_list = []
remaining_files_list = [file.filename for file in sorted_files]
patches_list =[]
total_tokens_list = []
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers, file_dict,
max_tokens_model, remaining_files_list, token_handler)
patches_list.append(patches)
total_tokens_list.append(total_tokens)
files_in_patches_list.append(files_in_patch_list)
# additional iterations (if needed)
NUMBER_OF_ALLOWED_ITERATIONS = get_settings().pr_description.max_ai_calls - 1 # one more call is to summarize
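# e.g. with the default max_ai_calls=3: one initial patch plus at most one extra iteration below,
# leaving the last allowed call for the final summary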
for i in range(NUMBER_OF_ALLOWED_ITERATIONS-1):
if remaining_files_list:
total_tokens, patches, remaining_files_list, files_in_patch_list = generate_full_patch(convert_hunks_to_line_numbers,
file_dict,
max_tokens_model,
remaining_files_list, token_handler)
patches_list.append(patches)
total_tokens_list.append(total_tokens)
files_in_patches_list.append(files_in_patch_list)
else:
break
return patches_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list
def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_model,remaining_files_list_prev, token_handler):
total_tokens = token_handler.prompt_tokens # initial tokens
patches = []
remaining_files_list_new = []
files_in_patch_list = []
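# Greedy packing: walk the token-sorted files, append each patch that fits under the soft threshold,
# defer oversized files to remaining_files_list_new, and add nothing more once the hard threshold is reached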
for filename, data in file_dict.items():
if filename not in remaining_files_list_prev:
continue
patch = data['patch']
new_patch_tokens = data['tokens']
edit_type = data['edit_type']
# Hard Stop, no more tokens
if total_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
get_logger().warning(f"File was fully skipped, no more tokens: {file.filename}.")
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
get_logger().warning(f"File was fully skipped, no more tokens: {filename}.")
continue
# If the patch is too large, just show the file name
if total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if total_tokens + new_patch_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
if get_settings().config.verbosity_level >= 2:
get_logger().warning(f"Patch too large, minimizing it, {file.filename}")
if file.edit_type == EDIT_TYPE.ADDED:
# if not added_files_list:
# total_tokens += token_handler.count_tokens(ADDED_FILES_)
if file.filename not in added_files_list:
added_files_list.append(file.filename)
# total_tokens += token_handler.count_tokens(file.filename) + 1
else:
# if not modified_files_list:
# total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
if file.filename not in modified_files_list:
modified_files_list.append(file.filename)
# total_tokens += token_handler.count_tokens(file.filename) + 1
get_logger().warning(f"Patch too large, skipping it, {filename}")
remaining_files_list_new.append(filename)
continue
if patch:
if not convert_hunks_to_line_numbers:
patch_final = f"\n\n## file: '{file.filename.strip()}\n\n{patch.strip()}\n'"
patch_final = f"\n\n## file: '{filename.strip()}\n\n{patch.strip()}\n'"
else:
patch_final = "\n\n" + patch.strip()
patches.append(patch_final)
total_tokens += token_handler.count_tokens(patch_final)
files_in_patch_list.append(filename)
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")
return patches, modified_files_list, deleted_files_list, added_files_list, total_tokens
get_logger().info(f"Tokens: {total_tokens}, last filename: {filename}")
return total_tokens, patches, remaining_files_list_new, files_in_patch_list
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
@@ -418,3 +467,46 @@ def get_pr_multi_diffs(git_provider: GitProvider,
final_diff_list.append(final_diff)
return final_diff_list
def prune_context(token_handler, curr_component_str, component_context_str, minium_output_tokens, max_tokens=None) -> Tuple[str, str]:
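# Two-stage clipping to leave room for at least minium_output_tokens of output: first the context
# string is clipped, then (if still over budget, allowing a small delta margin) the component itself.
# On any error, both strings are returned empty.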
try:
# Get the max tokens possible
if not max_tokens:
get_logger().error(f"Max tokens not provided, using default value")
max_tokens = get_max_tokens(get_settings().config.model_turbo)
# Check if the component + context are too long
component_tokens = token_handler.count_tokens(curr_component_str)
context_tokens = token_handler.count_tokens(component_context_str)
total_tokens = component_tokens + context_tokens + token_handler.prompt_tokens
get_logger().info(
f"Total tokens: {total_tokens}, context_tokens: {context_tokens}, component_tokens: {component_tokens}, prompt_tokens: {token_handler.prompt_tokens}, max_tokens: {max_tokens}")
# clip the context to fit the max tokens
if total_tokens > max_tokens - minium_output_tokens:
# clip the context to fit the max tokens
max_context_tokens = max_tokens - (minium_output_tokens) - component_tokens - token_handler.prompt_tokens
component_context_str = clip_tokens(component_context_str,
max_context_tokens, num_input_tokens=context_tokens)
context_tokens_old = context_tokens
context_tokens = token_handler.count_tokens(component_context_str)
total_tokens = component_tokens + context_tokens + token_handler.prompt_tokens
get_logger().info(f"Clipped context from {context_tokens_old} to {context_tokens} tokens, total tokens: {total_tokens}")
# clip the class itself to fit the max tokens, if needed
delta = 50 # extra tokens to prevent clipping the component if not necessary
if total_tokens > (max_tokens - minium_output_tokens-delta):
max_context_tokens = max_tokens - minium_output_tokens - context_tokens - token_handler.prompt_tokens # notice 'context_tokens'
curr_component_str= clip_tokens(curr_component_str,
max_context_tokens, num_input_tokens=component_tokens)
component_tokens_new = token_handler.count_tokens(curr_component_str)
total_tokens = component_tokens_new + context_tokens + token_handler.prompt_tokens
get_logger().info(f"Clipped component to fit the max tokens, from {component_tokens} to {component_tokens_new} tokens, total tokens: {total_tokens}")
except Exception as e:
component_context_str = ''
curr_component_str = ''
return curr_component_str, component_context_str


@@ -74,7 +74,10 @@ inline_file_summary=false # false, true, 'table'
# markers
use_description_markers=false
include_generated_by_header=true
# large pr mode
enable_large_pr_handling=true
max_ai_calls=3
mention_extra_files=true
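# example (hypothetical values): these keys can also be overridden per repository, e.g. in a local .pr_agent.toml:
#   [pr_description]
#   max_ai_calls=5            # allow more file-summary calls for very large PRs
#   mention_extra_files=false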
#custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other']
[pr_questions] # /ask #
@@ -105,16 +108,34 @@ final_clip_factor = 0.8
demand_code_suggestions_self_review=false # add a checkbox for the author to self-review the code suggestions
code_suggestions_self_review_text= "**Author self-review**: I have reviewed the PR code suggestions, and addressed the relevant ones."
approve_pr_on_self_review=false # Pro feature. if true, the PR will be auto-approved after the author clicks on the self-review checkbox
# Suggestion impact
publish_post_process_suggestion_impact=true
[pr_custom_prompt] # /custom_prompt #
prompt = """\
The code suggestions should focus only on the following:
- ...
- ...
...
"""
suggestions_score_threshold=0
num_code_suggestions_per_chunk=4
self_reflect_on_custom_suggestions=true
enable_help_text=false
[pr_add_docs] # /add_docs #
extra_instructions = ""
docs_style = "Sphinx Style" # "Google Style with Args, Returns, Attributes...etc", "Numpy Style", "Sphinx Style", "PEP257", "reStructuredText"
docs_style = "Sphinx" # "Google Style with Args, Returns, Attributes...etc", "Numpy Style", "Sphinx Style", "PEP257", "reStructuredText"
file = "" # in case there are several components with the same name, you can specify the relevant file
class_name = "" # in case there are several methods with the same name in the same file, you can specify the relevant class name
[pr_update_changelog] # /update_changelog #
push_changelog_changes=false
extra_instructions = ""
[pr_analyze] # /analyze #
enable_help_text=true
[pr_test] # /test #
extra_instructions = ""
@@ -129,13 +150,14 @@ enable_help_text=false
num_code_suggestions=4
extra_instructions = ""
file = "" # in case there are several components with the same name, you can specify the relevant file
class_name = ""
class_name = "" # in case there are several methods with the same name in the same file, you can specify the relevant class name
[checks] # /checks (pro feature) #
enable_auto_checks_feedback=true
excluded_checks_list=["lint"] # list of checks to exclude, for example: ["check1", "check2"]
persistent_comment=true
enable_help_text=true
final_update_message = false
[pr_help] # /help #
@@ -148,15 +170,16 @@ ratelimit_retries = 5
base_url = "https://api.github.com"
publish_inline_comments_fallback_with_verification = true
try_fix_invalid_inline_comments = true
app_name = "pr-agent"
[github_action_config]
# auto_review = true # set as env var in .github/workflows/pr-agent.yaml
# auto_describe = true # set as env var in .github/workflows/pr-agent.yaml
# auto_improve = true # set as env var in .github/workflows/pr-agent.yaml
# enable_output = true # set as env var in .github/workflows/pr-agent.yaml
[github_app]
# these toggles allow running the github app from custom deployments
bot_user = "github-actions[bot]"
override_deployment_type = true
# settings for "pull_request" event
handle_pr_actions = ['opened', 'reopened', 'ready_for_review']
@@ -180,7 +203,14 @@ ignore_pr_title = []
ignore_bot_pr = true
[gitlab]
url = "https://gitlab.com" # URL to the gitlab service
# URL to the gitlab service
url = "https://gitlab.com"
# Projects to poll (either project id or namespace/project_name syntax can be used)
projects_to_monitor = ['org_name/repo_name']
# Polling trigger
magic_word = "AutoReview"
# Polling interval
polling_interval_seconds = 30
pr_commands = [
"/describe",
"/review --pr_reviewer.num_code_suggestions=0",
@@ -229,6 +259,14 @@ force_update_dataset = false
max_issues_to_scan = 500
vectordb = "pinecone"
[pr_find_similar_component]
class_name = ""
file = ""
search_from_org = false
allow_fallback_less_words = true
number_of_keywords = 5
number_of_results = 5
[pinecone]
# fill and place in .secrets.toml
#api_key = ...


@@ -7,11 +7,14 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, get_pr_diff_multiple_patchs, \
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels, ModelType, show_relevant_configurations
from pr_agent.algo.utils import set_custom_labels
from pr_agent.algo.utils import load_yaml, get_user_labels, ModelType, show_relevant_configurations, get_max_tokens, \
clip_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider, get_git_provider_with_context
from pr_agent.git_providers import get_git_provider, GithubProvider, get_git_provider_with_context
from pr_agent.git_providers.git_provider import get_main_pr_language
from pr_agent.log import get_logger
from pr_agent.servers.help import HelpMessage
@@ -56,6 +59,7 @@ class PRDescription:
"custom_labels_class": "", # will be filled if necessary in 'set_custom_labels' function
"enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
}
self.user_description = self.git_provider.get_user_description()
# Initialize the token handler
@@ -163,32 +167,105 @@ class PRDescription:
if get_settings().pr_description.use_description_markers and 'pr_agent:' not in self.user_description:
return None
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
if self.patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff)
self.prediction = await self._get_prediction(model)
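# Large-PR flow: if the compressed diff fits in a single prompt (or large handling is disabled),
# one describe call is made as before; otherwise the files are summarized over several smaller calls,
# a separate call generates the PR header, and the results are merged into a single YAML prediction.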
large_pr_handling = get_settings().pr_description.enable_large_pr_handling and "pr_description_only_files_prompts" in get_settings()
patches_diff = get_pr_diff(self.git_provider, self.token_handler, model, large_pr_handling=large_pr_handling)
if not large_pr_handling or patches_diff:
self.patches_diff = patches_diff
if patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff)
self.prediction = await self._get_prediction(model, patches_diff, prompt="pr_description_prompt")
else:
get_logger().error(f"Error getting PR diff {self.pr_id}")
self.prediction = None
else:
get_logger().error(f"Error getting PR diff {self.pr_id}")
self.prediction = None
# get the diff split into multiple patches, using a token handler built for the files-only prompt
get_logger().debug('large_pr_handling for describe')
token_handler_only_files_prompt = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_description_only_files_prompts.system,
get_settings().pr_description_only_files_prompts.user,
)
(patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict,
files_in_patches_list) = get_pr_diff_multiple_patchs(
self.git_provider, token_handler_only_files_prompt, model)
async def _get_prediction(self, model: str) -> str:
"""
Generate an AI prediction for the PR description based on the provided model.
# get the files prediction for each patch
file_description_str_list = []
for i, patches in enumerate(patches_compressed_list):
patches_diff = "\n".join(patches)
get_logger().debug(f"PR diff number {i + 1} for describe files")
prediction_files = await self._get_prediction(model, patches_diff,
prompt="pr_description_only_files_prompts")
prediction_files = prediction_files.strip().removeprefix('```yaml').strip('`').strip()
if load_yaml(prediction_files) and prediction_files.startswith('pr_files'):
prediction_files = prediction_files.removeprefix('pr_files:').strip()
file_description_str_list.append(prediction_files)
else:
get_logger().debug(f"failed to generate predictions in iteration {i + 1} for describe files")
Args:
model (str): The name of the model to be used for generating the prediction.
# generate files_walkthrough string, with proper token handling
token_handler_only_description_prompt = TokenHandler(
self.git_provider.pr,
self.vars,
get_settings().pr_description_only_description_prompts.system,
get_settings().pr_description_only_description_prompts.user)
files_walkthrough = "\n".join(file_description_str_list)
if remaining_files_list:
files_walkthrough += "\n\nNo more token budget. Additional unprocessed files:"
for file in remaining_files_list:
files_walkthrough += f"\n- {file}"
if deleted_files_list:
files_walkthrough += "\n\nAdditional deleted files:"
for file in deleted_files_list:
files_walkthrough += f"\n- {file}"
tokens_files_walkthrough = len(token_handler_only_description_prompt.encoder.encode(files_walkthrough))
total_tokens = token_handler_only_description_prompt.prompt_tokens + tokens_files_walkthrough
max_tokens_model = get_max_tokens(model)
if total_tokens > max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
# clip files_walkthrough to fit within the token limit
files_walkthrough = clip_tokens(files_walkthrough,
max_tokens_model - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD - token_handler_only_description_prompt.prompt_tokens,
num_input_tokens=tokens_files_walkthrough)
Returns:
str: The generated AI prediction.
"""
# PR header inference
# TODO: add deleted and unprocessed files to the prompt ('files_walkthrough') as extra data
get_logger().debug(f"PR diff only description", artifact=files_walkthrough)
prediction_headers = await self._get_prediction(model, patches_diff=files_walkthrough,
prompt="pr_description_only_description_prompts")
prediction_headers = prediction_headers.strip().removeprefix('```yaml').strip('`').strip()
if get_settings().pr_description.mention_extra_files:
for file in remaining_files_list:
extra_file_yaml = f"""\
- filename: |
{file}
changes_summary: |
...
changes_title: |
...
label: |
not processed (token-limit)
"""
files_walkthrough = files_walkthrough.strip() + "\n" + extra_file_yaml.strip()
# final processing
self.prediction = prediction_headers + "\n" + "pr_files:\n" + files_walkthrough
if not load_yaml(self.prediction):
get_logger().error(f"Error getting valid YAML in large PR handling for describe {self.pr_id}")
if load_yaml(prediction_headers):
get_logger().debug(f"Using only headers for describe {self.pr_id}")
self.prediction = prediction_headers
async def _get_prediction(self, model: str, patches_diff: str, prompt="pr_description_prompt") -> str:
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
variables["diff"] = patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
system_prompt = environment.from_string(get_settings().get(prompt, {}).get("system", "")).render(variables)
user_prompt = environment.from_string(get_settings().get(prompt, {}).get("user", "")).render(variables)
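# 'prompt' selects which prompt pair from the settings is rendered: the default pr_description_prompt,
# or the files-only / description-only prompts used by the large-PR flow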
response, finish_reason = await self.ai_handler.chat_completion(
model=model,
@@ -351,7 +428,7 @@ class PRDescription:
filename = file['filename'].replace("'", "`").replace('"', '`')
changes_summary = file['changes_summary']
changes_title = file['changes_title'].strip()
label = file.get('label')
label = file.get('label').strip().lower()
if label not in file_label_dict:
file_label_dict[label] = []
file_label_dict[label].append((filename, changes_title, changes_summary))
@@ -392,6 +469,7 @@ class PRDescription:
for filename, file_changes_title, file_change_description in list_tuples:
filename = filename.replace("'", "`").rstrip()
filename_publish = filename.split("/")[-1]
file_changes_title_code = f"<code>{file_changes_title}</code>"
file_changes_title_code_br = insert_br_after_x_chars(file_changes_title_code, x=(delta - 5)).strip()
if len(file_changes_title_code_br) < (delta - 5):
@@ -423,6 +501,7 @@ class PRDescription:
<hr>
{filename}
{file_change_description_br}
@@ -431,6 +510,7 @@ class PRDescription:
</td>
<td><a href="{link}">{diff_plus_minus}</a>{delta_nbsp}</td>
</tr>
"""
if use_collapsible_file_list: