From 911c1268fcaf098b06f69ca4aaea6d1fa0eb48ea Mon Sep 17 00:00:00 2001
From: mrT23
Date: Wed, 29 May 2024 13:52:44 +0300
Subject: [PATCH] Add large_patch_policy configuration and implement patch clipping logic

---
 pr_agent/algo/pr_processing.py |  4 +-
 pr_agent/algo/utils.py         | 90 +++++++++++++++++++++++++++++-----
 2 files changed, 80 insertions(+), 14 deletions(-)

diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index bc17aedd..731af5a2 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -383,8 +383,8 @@ def get_pr_multi_diffs(git_provider: GitProvider,
             get_logger().warning(f"Patch too large, skipping: {file.filename}")
             continue
         elif get_settings().config.get('large_patch_policy') == 'clip':
-            delta_tokens = int(0.9*(get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens))
-            patch_clipped = clip_tokens(patch,delta_tokens, delete_last_line=True)
+            delta_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD - token_handler.prompt_tokens
+            patch_clipped = clip_tokens(patch, delta_tokens, delete_last_line=True, num_input_tokens=new_patch_tokens)
             new_patch_tokens = token_handler.count_tokens(patch_clipped)
             if patch_clipped and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(
                     model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index d24d4244..f9798277 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -5,6 +5,7 @@ import json
 import os
 import re
 import textwrap
+import time
 from datetime import datetime
 from enum import Enum
 from typing import Any, List, Tuple
@@ -76,6 +77,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
         "Score": "🏅",
         "Relevant tests": "🧪",
         "Focused PR": "✨",
+        "Relevant ticket": "🎫",
         "Security concerns": "🔒",
         "Insights from user's answers": "📝",
         "Code feedback": "🤖",
     }
@@ -85,7 +87,7 @@ def convert_to_markdown(output_data: dict, gfm_supported: bool = True, increment
     if not incremental_review:
         markdown_text += f"## PR Review 🔍\n\n"
     else:
-        markdown_text += f"## Incremental PR Review 🔍 \n\n"
+        markdown_text += f"## Incremental PR Review 🔍\n\n"
         markdown_text += f"⏮️ Review for commits since previous PR-Agent review {incremental_review}.\n\n"
     if gfm_supported:
         markdown_text += "\n\n"
@@ -470,7 +472,8 @@ def try_fix_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
     except:
         pass

-    # third fallback - try to remove leading and trailing curly brackets
+
+    # third fallback - try to remove leading and trailing curly brackets
     response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
     try:
         data = yaml.safe_load(response_text_copy)
@@ -552,7 +555,7 @@
     return max_tokens_model


-def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_line=False) -> str:
+def clip_tokens(text: str, max_tokens: int, add_three_dots=True, num_input_tokens=None, delete_last_line=False) -> str:
     """
     Clip the number of tokens in a string to a maximum number of tokens.

@@ -567,18 +570,30 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True, delete_last_lin
         return text

     try:
-        encoder = TokenEncoder.get_token_encoder()
-        num_input_tokens = len(encoder.encode(text))
+        if num_input_tokens is None:
+            encoder = TokenEncoder.get_token_encoder()
+            num_input_tokens = len(encoder.encode(text))
         if num_input_tokens <= max_tokens:
             return text
+        if max_tokens < 0:
+            return ""
+
+        # calculate the number of characters to keep
         num_chars = len(text)
         chars_per_token = num_chars / num_input_tokens
-        num_output_chars = int(chars_per_token * max_tokens)
-        clipped_text = text[:num_output_chars]
-        if delete_last_line:
-            clipped_text = clipped_text.rsplit('\n', 1)[0]
-        if add_three_dots:
-            clipped_text += "\n...(truncated)"
+        factor = 0.9 # reduce by 10% to be safe
+        num_output_chars = int(factor * chars_per_token * max_tokens)
+
+        # clip the text
+        if num_output_chars > 0:
+            clipped_text = text[:num_output_chars]
+            if delete_last_line:
+                clipped_text = clipped_text.rsplit('\n', 1)[0]
+            if add_three_dots:
+                clipped_text += "\n...(truncated)"
+        else: # if the text is empty
+            clipped_text = ""
+
         return clipped_text
     except Exception as e:
         get_logger().warning(f"Failed to clip tokens: {e}")
@@ -665,11 +680,62 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
             break
     return position, absolute_position

+def validate_and_await_rate_limit(rate_limit_status=None, git_provider=None, get_rate_limit_status_func=None):
+    if git_provider and not rate_limit_status:
+        rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
+
+    if not rate_limit_status:
+        rate_limit_status = get_rate_limit_status_func()
+    # validate that the rate limit is not exceeded
+    is_rate_limit = False
+    for key, value in rate_limit_status['resources'].items():
+        if value['remaining'] == 0:
+            print(f"key: {key}, value: {value}")
+            is_rate_limit = True
+            sleep_time_sec = value['reset'] - datetime.now().timestamp()
+            sleep_time_hour = sleep_time_sec / 3600.0
+            print(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
+            if sleep_time_sec > 0:
+                time.sleep(sleep_time_sec+1)
+
+    if git_provider:
+        rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
+    else:
+        rate_limit_status = get_rate_limit_status_func()
+
+    return is_rate_limit
+
+
+def get_largest_component(pr_url):
+    from pr_agent.tools.pr_analyzer import PRAnalyzer
+    publish_output = get_settings().config.publish_output
+    get_settings().config.publish_output = False # disable publish output
+    analyzer = PRAnalyzer(pr_url)
+    methods_dict_files = analyzer.run_sync()
+    get_settings().config.publish_output = publish_output
+    max_lines_changed = 0
+    file_b = ""
+    component_name_b = ""
+    for file in methods_dict_files:
+        for method in methods_dict_files[file]:
+            try:
+                if methods_dict_files[file][method]['num_plus_lines'] > max_lines_changed:
+                    max_lines_changed = methods_dict_files[file][method]['num_plus_lines']
+                    file_b = file
+                    component_name_b = method
+            except:
+                pass
+    if component_name_b:
+        get_logger().info(f"Using the largest changed component: '{component_name_b}'")
+        return component_name_b, file_b
+    else:
+        return None, None
+
 def github_action_output(output_data: dict, key_name: str):
     try:
         if not get_settings().get('github_action_config.enable_output', False):
             return
-
+
         key_data = output_data.get(key_name, {})
         with open(os.environ['GITHUB_OUTPUT'], 'a') as fh:
             print(f"{key_name}={json.dumps(key_data, indent=None, ensure_ascii=False)}", file=fh)
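
Usage sketch for the reworked clip_tokens (illustrative only; the token
counts below are made up). The new num_input_tokens parameter lets a caller
that has already counted tokens, as get_pr_multi_diffs does via
token_handler.count_tokens, skip the TokenEncoder pass; the 0.9 safety
factor that the caller previously applied to delta_tokens now lives inside
clip_tokens itself:

    from pr_agent.algo.utils import clip_tokens

    patch = "\n".join(f"+ line {i}" for i in range(10000))
    num_tokens = 40000  # made-up stand-in for token_handler.count_tokens(patch)
    clipped = clip_tokens(patch, max_tokens=8000, delete_last_line=True,
                          num_input_tokens=num_tokens)
    # Keeps roughly 0.9 * 8000/40000 of the characters, drops the possibly
    # mid-line-truncated last line, and appends "\n...(truncated)".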
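
The new validate_and_await_rate_limit helper accepts either a git provider
or a callback that returns GitHub's rate-limit payload; a sketch, where
github_provider is a hypothetical GithubProvider instance:

    from pr_agent.algo.utils import validate_and_await_rate_limit

    # Reads {'resources': ...} from the provider's github_client, sleeps past
    # the reset timestamp of any resource with 0 remaining calls, re-reads the
    # status, and returns True if a limit was hit.
    hit_limit = validate_and_await_rate_limit(git_provider=github_provider)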
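
Similarly, a sketch of get_largest_component, which runs PRAnalyzer with
publish_output temporarily disabled and returns the changed method with the
most added lines (the PR URL below is a placeholder):

    from pr_agent.algo.utils import get_largest_component

    component, filename = get_largest_component("https://github.com/org/repo/pull/1")
    if component:
        print(f"Largest changed component: {component} ({filename})")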