Add large_patch_policy configuration and implement patch clipping logic

commit 911c1268fc
parent 17f46bb53b
Author: mrT23
Date:   2024-05-29 13:52:44 +03:00
2 changed files with 80 additions and 14 deletions

View File

@@ -383,8 +383,8 @@ def get_pr_multi_diffs(git_provider: GitProvider,
                 get_logger().warning(f"Patch too large, skipping: {file.filename}")
                 continue
             elif get_settings().config.get('large_patch_policy') == 'clip':
-                delta_tokens = int(0.9*(get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens))
-                patch_clipped = clip_tokens(patch,delta_tokens, delete_last_line=True)
+                delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
+                patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
                 new_patch_tokens = token_handler.count_tokens(patch_clipped)
                 if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
                         model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
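For context, this hunk implements the skip-vs-clip decision for oversized patches. A minimal, self-contained sketch of the same logic (the constants and the 4-chars-per-token estimate are illustrative stand-ins, not pr-agent's real values):

    MAX_MODEL_TOKENS = 8000    # stand-in for get_max_tokens(model)
    OUTPUT_BUFFER = 1000       # stand-in for OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD

    def count_tokens(text: str) -> int:
        # crude stand-in for token_handler.count_tokens
        return max(1, len(text) // 4)

    def handle_patch(patch: str, prompt_tokens: int, policy: str):
        patch_tokens = count_tokens(patch)
        if prompt_tokens + patch_tokens <= MAX_MODEL_TOKENS - OUTPUT_BUFFER:
            return patch                      # fits, keep unchanged
        if policy == 'skip':
            return None                       # drop the file from the prompt
        # 'clip': cut the patch down to the remaining token budget
        delta_tokens = MAX_MODEL_TOKENS - OUTPUT_BUFFER - prompt_tokens
        clipped = patch[:delta_tokens * 4].rsplit('\n', 1)[0]  # delete_last_line
        return clipped + '\n...(truncated)'

    print(handle_patch('x\n' * 30000, prompt_tokens=2000, policy='clip')[-15:])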

View File

@@ -5,6 +5,7 @@ import json
 import os
 import re
 import textwrap
+import time
 from datetime import datetime
 from enum import Enum
 from typing import Any, List, Tuple
@@ -76,6 +77,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
         "Score": "🏅",
         "Relevant tests": "🧪",
         "Focused PR": "",
+        "Relevant ticket": "🎫",
         "Security concerns": "🔒",
         "Insights from user's answers": "📝",
         "Code feedback": "🤖",
@@ -85,7 +87,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
     if not incremental_review:
         markdown_text += f"## PR Review 🔍\n\n"
     else:
-        markdown_text += f"## Incremental PR Review 🔍 \n\n"
+        markdown_text += f"## Incremental PR Review 🔍\n\n"
         markdown_text += f"⏮️ Review for commits since previous PR-Agent review {incremental_review}.\n\n"
     if gfm_supported:
         markdown_text += "<table>\n<tr>\n"
@ -470,7 +472,8 @@ def try_fix_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
except: except:
pass pass
# third fallback - try to remove leading and trailing curly brackets
# third fallback - try to remove leading and trailing curly brackets
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n') response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
try: try:
data = yaml.safe_load(response_text_copy) data = yaml.safe_load(response_text_copy)
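A quick illustration of what this fallback rescues; the input mimics a hypothetical model response that wrapped plain YAML in curly brackets (requires PyYAML):

    import yaml

    response_text = '{\nrelevant_file: utils.py\nsuggestion: add tests\n}'
    copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
    print(yaml.safe_load(copy))  # {'relevant_file': 'utils.py', 'suggestion': 'add tests'}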
@@ -552,7 +555,7 @@ def get_max_tokens(model):
     return max_tokens_model


-def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_line=False) -> str:
+def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str:
     """
     Clip the number of tokens in a string to a maximum number of tokens.
@@ -567,18 +570,30 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_lin
         return text

     try:
-        encoder = TokenEncoder.get_token_encoder()
-        num_input_tokens = len(encoder.encode(text))
+        if num_input_tokens is None:
+            encoder = TokenEncoder.get_token_encoder()
+            num_input_tokens = len(encoder.encode(text))
         if num_input_tokens <= max_tokens:
             return text
+        if max_tokens < 0:
+            return ""
+
+        # calculate the number of characters to keep
         num_chars = len(text)
         chars_per_token = num_chars / num_input_tokens
-        num_output_chars = int(chars_per_token * max_tokens)
-        clipped_text = text[:num_output_chars]
-        if delete_last_line:
-            clipped_text = clipped_text.rsplit('\n', 1)[0]
-        if add_three_dots:
-            clipped_text += "\n...(truncated)"
+        factor = 0.9  # reduce by 10% to be safe
+        num_output_chars = int(factor * chars_per_token * max_tokens)
+
+        # clip the text
+        if num_output_chars > 0:
+            clipped_text = text[:num_output_chars]
+            if delete_last_line:
+                clipped_text = clipped_text.rsplit('\n', 1)[0]
+            if add_three_dots:
+                clipped_text += "\n...(truncated)"
+        else:  # if the text is empty
+            clipped_text = ""
+
         return clipped_text
     except Exception as e:
         get_logger().warning(f"Failed to clip tokens: {e}")
@@ -665,11 +680,62 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                 break
     return position, absolute_position


+def validate_and_await_rate_limit(rate_limit_status=None, git_provider=None, get_rate_limit_status_func=None):
+    if git_provider and not rate_limit_status:
+        rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
+    if not rate_limit_status:
+        rate_limit_status = get_rate_limit_status_func()
+    # validate that the rate limit is not exceeded
+    is_rate_limit = False
+    for key, value in rate_limit_status['resources'].items():
+        if value['remaining'] == 0:
+            print(f"key: {key}, value: {value}")
+            is_rate_limit = True
+            sleep_time_sec = value['reset'] - datetime.now().timestamp()
+            sleep_time_hour = sleep_time_sec / 3600.0
+            print(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
+            if sleep_time_sec > 0:
+                time.sleep(sleep_time_sec+1)
+            if git_provider:
+                rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
+            else:
+                rate_limit_status = get_rate_limit_status_func()
+    return is_rate_limit
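A hypothetical standalone call of this new helper, with a stubbed status function returning a payload shaped the way the loop expects (no GitHub client needed, assuming the helper above is importable):

    from datetime import datetime

    def fake_status():
        return {'resources': {'core': {'remaining': 42,
                                       'reset': datetime.now().timestamp() + 3600}}}

    hit = validate_and_await_rate_limit(get_rate_limit_status_func=fake_status)
    print(hit)  # False: 'remaining' is nonzero, so no sleep was taken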
+
+
+def get_largest_component(pr_url):
+    from pr_agent.tools.pr_analyzer import PRAnalyzer
+    publish_output = get_settings().config.publish_output
+    get_settings().config.publish_output = False  # disable publish output
+    analyzer = PRAnalyzer(pr_url)
+    methods_dict_files = analyzer.run_sync()
+    get_settings().config.publish_output = publish_output
+    max_lines_changed = 0
+    file_b = ""
+    component_name_b = ""
+    for file in methods_dict_files:
+        for method in methods_dict_files[file]:
+            try:
+                if methods_dict_files[file][method]['num_plus_lines'] > max_lines_changed:
+                    max_lines_changed = methods_dict_files[file][method]['num_plus_lines']
+                    file_b = file
+                    component_name_b = method
+            except:
+                pass
+    if component_name_b:
+        get_logger().info(f"Using the largest changed component: '{component_name_b}'")
+        return component_name_b, file_b
+    else:
+        return None, None
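Hypothetical usage of this helper; it needs a configured pr-agent environment, since PRAnalyzer runs the analyze tool under the hood (the PR URL is a placeholder):

    component, path = get_largest_component('https://github.com/org/repo/pull/1')
    if component:
        print(f"focusing on {component} in {path}")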
+
+
 def github_action_output(output_data: dict, key_name: str):
     try:
         if not get_settings().get('github_action_config.enable_output', False):
             return
         key_data = output_data.get(key_name, {})
         with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
             print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh)