mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-03 04:10:49 +08:00
Add large_patch_policy configuration and implement patch clipping logic
This commit is contained in:
@ -383,8 +383,8 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
||||
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
||||
continue
|
||||
elif get_settings().config.get('large_patch_policy') == 'clip':
|
||||
delta_tokens = int(0.9*(get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens))
|
||||
patch_clipped = clip_tokens(patch,delta_tokens, delete_last_line=True)
|
||||
delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
|
||||
patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
|
||||
new_patch_tokens = token_handler.count_tokens(patch_clipped)
|
||||
if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
|
||||
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
|
@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
import time
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, List, Tuple
|
||||
@ -76,6 +77,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
|
||||
"Score": "🏅",
|
||||
"Relevant tests": "🧪",
|
||||
"Focused PR": "✨",
|
||||
"Relevant ticket": "🎫",
|
||||
"Security concerns": "🔒",
|
||||
"Insights from user's answers": "📝",
|
||||
"Code feedback": "🤖",
|
||||
@ -85,7 +87,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
|
||||
if not incremental_review:
|
||||
markdown_text += f"## PR Review 🔍\n\n"
|
||||
else:
|
||||
markdown_text += f"## Incremental PR Review 🔍 \n\n"
|
||||
markdown_text += f"## Incremental PR Review 🔍\n\n"
|
||||
markdown_text += f"⏮️ Review for commits since previous PR-Agent review {incremental_review}.\n\n"
|
||||
if gfm_supported:
|
||||
markdown_text += "<table>\n<tr>\n"
|
||||
@ -470,7 +472,8 @@ def try_fix_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
|
||||
except:
|
||||
pass
|
||||
|
||||
# third fallback - try to remove leading and trailing curly brackets
|
||||
|
||||
# third fallback - try to remove leading and trailing curly brackets
|
||||
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
|
||||
try:
|
||||
data = yaml.safe_load(response_text_copy)
|
||||
@ -552,7 +555,7 @@ def get_max_tokens(model):
|
||||
return max_tokens_model
|
||||
|
||||
|
||||
def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_line=False) -> str:
|
||||
def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str:
|
||||
"""
|
||||
Clip the number of tokens in a string to a maximum number of tokens.
|
||||
|
||||
@ -567,18 +570,30 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_lin
|
||||
return text
|
||||
|
||||
try:
|
||||
encoder = TokenEncoder.get_token_encoder()
|
||||
num_input_tokens = len(encoder.encode(text))
|
||||
if num_input_tokens is None:
|
||||
encoder = TokenEncoder.get_token_encoder()
|
||||
num_input_tokens = len(encoder.encode(text))
|
||||
if num_input_tokens <= max_tokens:
|
||||
return text
|
||||
if max_tokens < 0:
|
||||
return ""
|
||||
|
||||
# calculate the number of characters to keep
|
||||
num_chars = len(text)
|
||||
chars_per_token = num_chars / num_input_tokens
|
||||
num_output_chars = int(chars_per_token * max_tokens)
|
||||
clipped_text = text[:num_output_chars]
|
||||
if delete_last_line:
|
||||
clipped_text = clipped_text.rsplit('\n', 1)[0]
|
||||
if add_three_dots:
|
||||
clipped_text += "\n...(truncated)"
|
||||
factor = 0.9 # reduce by 10% to be safe
|
||||
num_output_chars = int(factor * chars_per_token * max_tokens)
|
||||
|
||||
# clip the text
|
||||
if num_output_chars > 0:
|
||||
clipped_text = text[:num_output_chars]
|
||||
if delete_last_line:
|
||||
clipped_text = clipped_text.rsplit('\n', 1)[0]
|
||||
if add_three_dots:
|
||||
clipped_text += "\n...(truncated)"
|
||||
else: # if the text is empty
|
||||
clipped_text = ""
|
||||
|
||||
return clipped_text
|
||||
except Exception as e:
|
||||
get_logger().warning(f"Failed to clip tokens: {e}")
|
||||
@ -665,11 +680,62 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
|
||||
break
|
||||
return position, absolute_position
|
||||
|
||||
def validate_and_await_rate_limit(rate_limit_status=None, git_provider=None, get_rate_limit_status_func=None):
|
||||
if git_provider and not rate_limit_status:
|
||||
rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
|
||||
|
||||
if not rate_limit_status:
|
||||
rate_limit_status = get_rate_limit_status_func()
|
||||
# validate that the rate limit is not exceeded
|
||||
is_rate_limit = False
|
||||
for key, value in rate_limit_status['resources'].items():
|
||||
if value['remaining'] == 0:
|
||||
print(f"key: {key}, value: {value}")
|
||||
is_rate_limit = True
|
||||
sleep_time_sec = value['reset'] - datetime.now().timestamp()
|
||||
sleep_time_hour = sleep_time_sec / 3600.0
|
||||
print(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
|
||||
if sleep_time_sec > 0:
|
||||
time.sleep(sleep_time_sec+1)
|
||||
|
||||
if git_provider:
|
||||
rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
|
||||
else:
|
||||
rate_limit_status = get_rate_limit_status_func()
|
||||
|
||||
return is_rate_limit
|
||||
|
||||
|
||||
def get_largest_component(pr_url):
|
||||
from pr_agent.tools.pr_analyzer import PRAnalyzer
|
||||
publish_output = get_settings().config.publish_output
|
||||
get_settings().config.publish_output = False # disable publish output
|
||||
analyzer = PRAnalyzer(pr_url)
|
||||
methods_dict_files = analyzer.run_sync()
|
||||
get_settings().config.publish_output = publish_output
|
||||
max_lines_changed = 0
|
||||
file_b = ""
|
||||
component_name_b = ""
|
||||
for file in methods_dict_files:
|
||||
for method in methods_dict_files[file]:
|
||||
try:
|
||||
if methods_dict_files[file][method]['num_plus_lines'] > max_lines_changed:
|
||||
max_lines_changed = methods_dict_files[file][method]['num_plus_lines']
|
||||
file_b = file
|
||||
component_name_b = method
|
||||
except:
|
||||
pass
|
||||
if component_name_b:
|
||||
get_logger().info(f"Using the largest changed component: '{component_name_b}'")
|
||||
return component_name_b, file_b
|
||||
else:
|
||||
return None, None
|
||||
|
||||
def github_action_output(output_data: dict, key_name: str):
|
||||
try:
|
||||
if not get_settings().get('github_action_config.enable_output', False):
|
||||
return
|
||||
|
||||
|
||||
key_data = output_data.get(key_name, {})
|
||||
with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
|
||||
print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh)
|
||||
|
Reference in New Issue
Block a user