Add large_patch_policy configuration and implement patch clipping logic

This commit is contained in:
mrT23
2024-05-29 13:52:44 +03:00
parent 17f46bb53b
commit 911c1268fc
2 changed files with 80 additions and 14 deletions

View File

@ -383,8 +383,8 @@ def get_pr_multi_diffs(git_provider: GitProvider,
get_logger().warning(f"Patch too large, skipping: {file.filename}")
continue
elif get_settings().config.get('large_patch_policy') == 'clip':
delta_tokens = int(0.9*(get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens))
patch_clipped = clip_tokens(patch,delta_tokens, delete_last_line=True)
delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
new_patch_tokens = token_handler.count_tokens(patch_clipped)
if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:

View File

@ -5,6 +5,7 @@ import json
import os
import re
import textwrap
import time
from datetime import datetime
from enum import Enum
from typing import Any, List, Tuple
@ -76,6 +77,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
"Score": "🏅",
"Relevant tests": "🧪",
"Focused PR": "",
"Relevant ticket": "🎫",
"Security concerns": "🔒",
"Insights from user's answers": "📝",
"Code feedback": "🤖",
@ -85,7 +87,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
if not incremental_review:
markdown_text += f"## PR Review 🔍\n\n"
else:
markdown_text += f"## Incremental PR Review 🔍 \n\n"
markdown_text += f"## Incremental PR Review 🔍\n\n"
markdown_text += f"⏮️ Review for commits since previous PR-Agent review {incremental_review}.\n\n"
if gfm_supported:
markdown_text += "<table>\n<tr>\n"
@ -470,7 +472,8 @@ def try_fix_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
except:
pass
# third fallback - try to remove leading and trailing curly brackets
# third fallback - try to remove leading and trailing curly brackets
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
try:
data = yaml.safe_load(response_text_copy)
@ -552,7 +555,7 @@ def get_max_tokens(model):
return max_tokens_model
def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_line=False) -> str:
def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str:
"""
Clip the number of tokens in a string to a maximum number of tokens.
@ -567,18 +570,30 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_lin
return text
try:
encoder = TokenEncoder.get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens is None:
encoder = TokenEncoder.get_token_encoder()
num_input_tokens = len(encoder.encode(text))
if num_input_tokens <= max_tokens:
return text
if max_tokens < 0:
return ""
# calculate the number of characters to keep
num_chars = len(text)
chars_per_token = num_chars / num_input_tokens
num_output_chars = int(chars_per_token * max_tokens)
clipped_text = text[:num_output_chars]
if delete_last_line:
clipped_text = clipped_text.rsplit('\n', 1)[0]
if add_three_dots:
clipped_text += "\n...(truncated)"
factor = 0.9 # reduce by 10% to be safe
num_output_chars = int(factor * chars_per_token * max_tokens)
# clip the text
if num_output_chars > 0:
clipped_text = text[:num_output_chars]
if delete_last_line:
clipped_text = clipped_text.rsplit('\n', 1)[0]
if add_three_dots:
clipped_text += "\n...(truncated)"
else: # if the text is empty
clipped_text = ""
return clipped_text
except Exception as e:
get_logger().warning(f"Failed to clip tokens: {e}")
@ -665,11 +680,62 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
break
return position, absolute_position
def validate_and_await_rate_limit(rate_limit_status=None, git_provider=None, get_rate_limit_status_func=None):
if git_provider and not rate_limit_status:
rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
if not rate_limit_status:
rate_limit_status = get_rate_limit_status_func()
# validate that the rate limit is not exceeded
is_rate_limit = False
for key, value in rate_limit_status['resources'].items():
if value['remaining'] == 0:
print(f"key: {key}, value: {value}")
is_rate_limit = True
sleep_time_sec = value['reset'] - datetime.now().timestamp()
sleep_time_hour = sleep_time_sec / 3600.0
print(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
if sleep_time_sec > 0:
time.sleep(sleep_time_sec+1)
if git_provider:
rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
else:
rate_limit_status = get_rate_limit_status_func()
return is_rate_limit
def get_largest_component(pr_url):
from pr_agent.tools.pr_analyzer import PRAnalyzer
publish_output = get_settings().config.publish_output
get_settings().config.publish_output = False # disable publish output
analyzer = PRAnalyzer(pr_url)
methods_dict_files = analyzer.run_sync()
get_settings().config.publish_output = publish_output
max_lines_changed = 0
file_b = ""
component_name_b = ""
for file in methods_dict_files:
for method in methods_dict_files[file]:
try:
if methods_dict_files[file][method]['num_plus_lines'] > max_lines_changed:
max_lines_changed = methods_dict_files[file][method]['num_plus_lines']
file_b = file
component_name_b = method
except:
pass
if component_name_b:
get_logger().info(f"Using the largest changed component: '{component_name_b}'")
return component_name_b, file_b
else:
return None, None
def github_action_output(output_data: dict, key_name: str):
try:
if not get_settings().get('github_action_config.enable_output', False):
return
key_data = output_data.get(key_name, {})
with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh)