mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-04 12:50:38 +08:00
Add large_patch_policy configuration and implement patch clipping logic
This commit is contained in:
@ -383,8 +383,8 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
|||||||
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
||||||
continue
|
continue
|
||||||
elif get_settings().config.get('large_patch_policy') == 'clip':
|
elif get_settings().config.get('large_patch_policy') == 'clip':
|
||||||
delta_tokens = int(0.9*(get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens))
|
delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
|
||||||
patch_clipped = clip_tokens(patch,delta_tokens, delete_last_line=True)
|
patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
|
||||||
new_patch_tokens = token_handler.count_tokens(patch_clipped)
|
new_patch_tokens = token_handler.count_tokens(patch_clipped)
|
||||||
if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
|
if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
|
||||||
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||||
|
@ -5,6 +5,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, List, Tuple
|
from typing import Any, List, Tuple
|
||||||
@ -76,6 +77,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
|
|||||||
"Score": "🏅",
|
"Score": "🏅",
|
||||||
"Relevant tests": "🧪",
|
"Relevant tests": "🧪",
|
||||||
"Focused PR": "✨",
|
"Focused PR": "✨",
|
||||||
|
"Relevant ticket": "🎫",
|
||||||
"Security concerns": "🔒",
|
"Security concerns": "🔒",
|
||||||
"Insights from user's answers": "📝",
|
"Insights from user's answers": "📝",
|
||||||
"Code feedback": "🤖",
|
"Code feedback": "🤖",
|
||||||
@ -85,7 +87,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
|
|||||||
if not incremental_review:
|
if not incremental_review:
|
||||||
markdown_text += f"## PR Review 🔍\n\n"
|
markdown_text += f"## PR Review 🔍\n\n"
|
||||||
else:
|
else:
|
||||||
markdown_text += f"## Incremental PR Review 🔍 \n\n"
|
markdown_text += f"## Incremental PR Review 🔍\n\n"
|
||||||
markdown_text += f"⏮️ Review for commits since previous PR-Agent review {incremental_review}.\n\n"
|
markdown_text += f"⏮️ Review for commits since previous PR-Agent review {incremental_review}.\n\n"
|
||||||
if gfm_supported:
|
if gfm_supported:
|
||||||
markdown_text += "<table>\n<tr>\n"
|
markdown_text += "<table>\n<tr>\n"
|
||||||
@ -470,7 +472,8 @@ def try_fix_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# third fallback - try to remove leading and trailing curly brackets
|
|
||||||
|
# third fallback - try to remove leading and trailing curly brackets
|
||||||
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
|
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
|
||||||
try:
|
try:
|
||||||
data = yaml.safe_load(response_text_copy)
|
data = yaml.safe_load(response_text_copy)
|
||||||
@ -552,7 +555,7 @@ def get_max_tokens(model):
|
|||||||
return max_tokens_model
|
return max_tokens_model
|
||||||
|
|
||||||
|
|
||||||
def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_line=False) -> str:
|
def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str:
|
||||||
"""
|
"""
|
||||||
Clip the number of tokens in a string to a maximum number of tokens.
|
Clip the number of tokens in a string to a maximum number of tokens.
|
||||||
|
|
||||||
@ -567,18 +570,30 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_lin
|
|||||||
return text
|
return text
|
||||||
|
|
||||||
try:
|
try:
|
||||||
encoder = TokenEncoder.get_token_encoder()
|
if num_input_tokens is None:
|
||||||
num_input_tokens = len(encoder.encode(text))
|
encoder = TokenEncoder.get_token_encoder()
|
||||||
|
num_input_tokens = len(encoder.encode(text))
|
||||||
if num_input_tokens <= max_tokens:
|
if num_input_tokens <= max_tokens:
|
||||||
return text
|
return text
|
||||||
|
if max_tokens < 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# calculate the number of characters to keep
|
||||||
num_chars = len(text)
|
num_chars = len(text)
|
||||||
chars_per_token = num_chars / num_input_tokens
|
chars_per_token = num_chars / num_input_tokens
|
||||||
num_output_chars = int(chars_per_token * max_tokens)
|
factor = 0.9 # reduce by 10% to be safe
|
||||||
clipped_text = text[:num_output_chars]
|
num_output_chars = int(factor * chars_per_token * max_tokens)
|
||||||
if delete_last_line:
|
|
||||||
clipped_text = clipped_text.rsplit('\n', 1)[0]
|
# clip the text
|
||||||
if add_three_dots:
|
if num_output_chars > 0:
|
||||||
clipped_text += "\n...(truncated)"
|
clipped_text = text[:num_output_chars]
|
||||||
|
if delete_last_line:
|
||||||
|
clipped_text = clipped_text.rsplit('\n', 1)[0]
|
||||||
|
if add_three_dots:
|
||||||
|
clipped_text += "\n...(truncated)"
|
||||||
|
else: # if the text is empty
|
||||||
|
clipped_text = ""
|
||||||
|
|
||||||
return clipped_text
|
return clipped_text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().warning(f"Failed to clip tokens: {e}")
|
get_logger().warning(f"Failed to clip tokens: {e}")
|
||||||
@ -665,11 +680,62 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
|
|||||||
break
|
break
|
||||||
return position, absolute_position
|
return position, absolute_position
|
||||||
|
|
||||||
|
def validate_and_await_rate_limit(rate_limit_status=None, git_provider=None, get_rate_limit_status_func=None):
    """
    Check a rate-limit snapshot and wait until every exhausted resource has reset.

    Args:
        rate_limit_status: Optional snapshot shaped like
            ``{'resources': {name: {'remaining': int, 'reset': epoch_seconds}}}``.
            When falsy, a snapshot is fetched from ``git_provider`` or, failing
            that, from ``get_rate_limit_status_func``.
        git_provider: Optional provider exposing
            ``github_client.get_rate_limit().raw_data``; takes precedence over
            ``get_rate_limit_status_func`` whenever a snapshot has to be fetched.
        get_rate_limit_status_func: Optional zero-argument callable returning a
            snapshot in the shape described above.

    Returns:
        bool: True if at least one resource had ``remaining == 0`` (the call
        has then already slept until that resource's reset time), else False.

    Raises:
        ValueError: If no snapshot was given and none could be fetched.
    """

    def _fetch_status():
        # The provider wins over the plain callable, matching the lookup order
        # used for the initial snapshot.
        if git_provider:
            return {'resources': git_provider.github_client.get_rate_limit().raw_data}
        if get_rate_limit_status_func is not None:
            return get_rate_limit_status_func()
        return None

    if not rate_limit_status:
        rate_limit_status = _fetch_status()
    if not rate_limit_status:
        raise ValueError(
            "validate_and_await_rate_limit requires rate_limit_status, git_provider "
            "or get_rate_limit_status_func"
        )

    # validate that the rate limit is not exceeded
    is_rate_limit = False
    for key, value in rate_limit_status['resources'].items():
        if value['remaining'] != 0:
            continue

        print(f"key: {key}, value: {value}")
        is_rate_limit = True
        sleep_time_sec = value['reset'] - time.time()
        sleep_time_hour = sleep_time_sec / 3600.0
        print(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
        if sleep_time_sec > 0:
            # one extra second so we wake up strictly after the reset time
            time.sleep(sleep_time_sec + 1)

        # Re-query the source after handling an exhausted resource. When the
        # caller passed only a snapshot there is nothing to re-query, so the
        # refresh is skipped instead of calling None. The loop itself keeps
        # iterating over the snapshot taken before the wait.
        refreshed_status = _fetch_status()
        if refreshed_status:
            rate_limit_status = refreshed_status

    return is_rate_limit
|
||||||
|
|
||||||
|
|
||||||
|
def get_largest_component(pr_url):
    """
    Find the changed component with the most added lines in a PR.

    Runs ``PRAnalyzer`` on ``pr_url`` with ``config.publish_output`` temporarily
    set to False, then picks the entry whose ``'num_plus_lines'`` is highest.

    Args:
        pr_url: PR URL handed to ``PRAnalyzer``.

    Returns:
        tuple: ``(component_name, file)`` for the largest component, or
        ``(None, None)`` when no component has a positive ``'num_plus_lines'``.
    """
    # Imported inside the function, as in the original code
    # (presumably to avoid a circular import - TODO confirm).
    from pr_agent.tools.pr_analyzer import PRAnalyzer

    publish_output = get_settings().config.publish_output
    get_settings().config.publish_output = False  # disable publish output
    try:
        analyzer = PRAnalyzer(pr_url)
        methods_dict_files = analyzer.run_sync()
    finally:
        # Restore the caller's setting even when the analysis raises, so a
        # failure here cannot leave publishing disabled for the whole process.
        get_settings().config.publish_output = publish_output

    max_lines_changed = 0
    file_b = ""
    component_name_b = ""
    for file in methods_dict_files:
        for method in methods_dict_files[file]:
            try:
                num_plus_lines = methods_dict_files[file][method]['num_plus_lines']
                is_larger = num_plus_lines > max_lines_changed
            except Exception as e:
                # Best effort: a malformed entry is skipped, but no longer
                # silently, and Ctrl-C / SystemExit are not swallowed.
                get_logger().debug(f"Skipping component '{method}' in '{file}': {e}")
                continue
            if is_larger:
                max_lines_changed = num_plus_lines
                file_b = file
                component_name_b = method

    if component_name_b:
        get_logger().info(f"Using the largest changed component: '{component_name_b}'")
        return component_name_b, file_b
    return None, None
|
||||||
|
|
||||||
def github_action_output(output_data: dict, key_name: str):
|
def github_action_output(output_data: dict, key_name: str):
|
||||||
try:
|
try:
|
||||||
if not get_settings().get('github_action_config.enable_output', False):
|
if not get_settings().get('github_action_config.enable_output', False):
|
||||||
return
|
return
|
||||||
|
|
||||||
key_data = output_data.get(key_name, {})
|
key_data = output_data.get(key_name, {})
|
||||||
with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
|
with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
|
||||||
print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh)
|
print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh)
|
||||||
|
Reference in New Issue
Block a user