diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py
index 60cf2c84..cb313f02 100644
--- a/pr_agent/algo/token_handler.py
+++ b/pr_agent/algo/token_handler.py
@@ -1,4 +1,6 @@
 from threading import Lock
+from math import ceil
+import re
 
 from jinja2 import Environment, StrictUndefined
 from tiktoken import encoding_for_model, get_encoding
@@ -7,6 +9,16 @@ from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
 
 
+class ModelTypeValidator:
+    @staticmethod
+    def is_openai_model(model_name: str) -> bool:
+        return 'gpt' in model_name or bool(re.match(r"^o[1-9](-mini|-preview)?$", model_name))
+
+    @staticmethod
+    def is_anthropic_model(model_name: str) -> bool:
+        return 'claude' in model_name
+
+
 class TokenEncoder:
     _encoder_instance = None
     _model = None
@@ -40,6 +52,10 @@ class TokenHandler:
     method.
     """
 
+    # Constants
+    CLAUDE_MODEL = "claude-3-7-sonnet-20250219"
+    CLAUDE_MAX_CONTENT_SIZE = 9_000_000  # Maximum allowed content size (9MB) for the Claude API
+
     def __init__(self, pr=None, vars: dict = {}, system="", user=""):
         """
         Initializes the TokenHandler object.
@@ -51,6 +67,7 @@ class TokenHandler:
         - user: The user string.
         """
         self.encoder = TokenEncoder.get_token_encoder()
+
         if pr is not None:
             self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
 
@@ -79,22 +96,22 @@ class TokenHandler:
             get_logger().error(f"Error in _get_system_user_tokens: {e}")
             return 0
 
-    def calc_claude_tokens(self, patch):
+    def _calc_claude_tokens(self, patch: str) -> int:
         try:
             import anthropic
             from pr_agent.algo import MAX_TOKENS
+
             client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))
-            MaxTokens = MAX_TOKENS[get_settings().config.model]
+            max_tokens = MAX_TOKENS[get_settings().config.model]
 
-            # Check if the content size is too large (9MB limit)
-            if len(patch.encode('utf-8')) > 9_000_000:
+            if len(patch.encode('utf-8')) > self.CLAUDE_MAX_CONTENT_SIZE:
                 get_logger().warning(
                     "Content too large for Anthropic token counting API, falling back to local tokenizer"
                 )
-                return MaxTokens
+                return max_tokens
 
             response = client.messages.count_tokens(
-                model="claude-3-7-sonnet-20250219",
+                model=self.CLAUDE_MODEL,
                 system="system",
                 messages=[{
                     "role": "user",
@@ -104,42 +121,51 @@ class TokenHandler:
             return response.input_tokens
 
         except Exception as e:
-            get_logger().error( f"Error in Anthropic token counting: {e}")
-            return MaxTokens
+            get_logger().error(f"Error in Anthropic token counting: {e}")
+            return max_tokens
 
-    def estimate_token_count_for_non_anth_claude_models(self, model, default_encoder_estimate):
-        from math import ceil
-        import re
+    def _apply_estimation_factor(self, model_name: str, default_estimate: int) -> int:
+        factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
+        get_logger().warning(f"{model_name}'s token count cannot be accurately estimated. Using factor of {factor}")
+
+        return ceil(factor * default_estimate)
 
-        model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model)
-        if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
-            return default_encoder_estimate
-        #else: Model is not an OpenAI one - therefore, cannot provide an accurate token count and instead, return a higher number as best effort.
+    def _get_token_count_by_model_type(self, patch: str, default_estimate: int) -> int:
+        """
+        Get token count based on model type.
 
-        elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
-        get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate")
-        return ceil(elbow_factor * default_encoder_estimate)
+        Args:
+            patch: The text to count tokens for.
+            default_estimate: The default token count estimate.
 
-    def count_tokens(self, patch: str, force_accurate=False) -> int:
+        Returns:
+            int: The calculated token count.
+        """
+        model_name = get_settings().config.model.lower()
+
+        if ModelTypeValidator.is_openai_model(model_name) and get_settings(use_context=False).get('openai.key'):
+            return default_estimate
+
+        if ModelTypeValidator.is_anthropic_model(model_name) and get_settings(use_context=False).get('anthropic.key'):
+            return self._calc_claude_tokens(patch)
+
+        return self._apply_estimation_factor(model_name, default_estimate)
+
+    def count_tokens(self, patch: str, force_accurate: bool = False) -> int:
         """
         Counts the number of tokens in a given patch string.
 
         Args:
         - patch: The patch string.
+        - force_accurate: If True, uses a more precise calculation method.
 
         Returns:
         The number of tokens in the patch string.
         """
         encoder_estimate = len(self.encoder.encode(patch, disallowed_special=()))
 
-        #If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
+        # If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
         if not force_accurate:
             return encoder_estimate
 
-        #else, force_accurate==True: User requested providing an accurate estimation:
-        model = get_settings().config.model.lower()
-        if 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
-            return self.calc_claude_tokens(patch) # API call to Anthropic for accurate token counting for Claude models
-
-        #else: Non Anthropic provided model:
-        return self.estimate_token_count_for_non_anth_claude_models(model, encoder_estimate)
+        return self._get_token_count_by_model_type(patch, encoder_estimate)
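A quick way for reviewers to sanity-check the new routing is to exercise the `ModelTypeValidator` heuristics in isolation. The sketch below is a minimal, self-contained approximation, not part of the patch: the settings lookups and API keys are stubbed out, and the model names are hypothetical examples chosen to hit each branch of `_get_token_count_by_model_type`.

```python
import re


def is_openai_model(model_name: str) -> bool:
    # Same heuristic as ModelTypeValidator.is_openai_model in the diff.
    return 'gpt' in model_name or bool(re.match(r"^o[1-9](-mini|-preview)?$", model_name))


def is_anthropic_model(model_name: str) -> bool:
    # Same heuristic as ModelTypeValidator.is_anthropic_model in the diff.
    return 'claude' in model_name


# Hypothetical model names; the real code lowercases get_settings().config.model
# and also checks that the matching provider's API key is configured.
for name in ["gpt-4o", "o3-mini", "claude-3-7-sonnet-20250219", "deepseek-chat"]:
    if is_openai_model(name):
        branch = "trust the local tiktoken estimate"
    elif is_anthropic_model(name):
        branch = "call Anthropic's token-counting API (_calc_claude_tokens)"
    else:
        branch = "inflate the local estimate by the configured safety factor"
    print(f"{name:35} -> {branch}")
```

Note that the fallback branch only triggers when `force_accurate=True` is passed to `count_tokens`; otherwise the raw encoder estimate is returned before any of this dispatch runs.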