diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py index 0c8851e8..2781fc5c 100644 --- a/pr_agent/algo/token_handler.py +++ b/pr_agent/algo/token_handler.py @@ -1,4 +1,6 @@ from threading import Lock +from math import ceil +import re from jinja2 import Environment, StrictUndefined from tiktoken import encoding_for_model, get_encoding @@ -7,6 +9,16 @@ from pr_agent.config_loader import get_settings from pr_agent.log import get_logger +class ModelTypeValidator: + @staticmethod + def is_openai_model(model_name: str) -> bool: + return 'gpt' in model_name or re.match(r"^o[1-9](-mini|-preview)?$", model_name) + + @staticmethod + def is_claude_model(model_name: str) -> bool: + return 'claude' in model_name + + class TokenEncoder: _encoder_instance = None _model = None @@ -51,6 +63,9 @@ class TokenHandler: - user: The user string. """ self.encoder = TokenEncoder.get_token_encoder() + self.settings = get_settings() + self.model_validator = ModelTypeValidator() + if pr is not None: self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) @@ -79,19 +94,20 @@ class TokenHandler: get_logger().error(f"Error in _get_system_user_tokens: {e}") return 0 - def calc_claude_tokens(self, patch): + def calc_claude_tokens(self, patch: str) -> int: try: import anthropic from pr_agent.algo import MAX_TOKENS - client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key')) - MaxTokens = MAX_TOKENS[get_settings().config.model] + + client = anthropic.Anthropic(api_key=self.settings.get('anthropic.key')) + max_tokens = MAX_TOKENS[self.settings.config.model] # Check if the content size is too large (9MB limit) if len(patch.encode('utf-8')) > 9_000_000: get_logger().warning( "Content too large for Anthropic token counting API, falling back to local tokenizer" ) - return MaxTokens + return max_tokens response = client.messages.count_tokens( model="claude-3-7-sonnet-20250219", @@ -104,29 +120,21 @@ class TokenHandler: return response.input_tokens except Exception as e: - get_logger().error( f"Error in Anthropic token counting: {e}") - return MaxTokens + get_logger().error(f"Error in Anthropic token counting: {e}") + return max_tokens - def is_openai_model(self, model_name): - from re import match - - return 'gpt' in model_name or match(r"^o[1-9](-mini|-preview)?$", model_name) - - def apply_estimation_factor(self, model_name, default_estimate): - from math import ceil - - factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0) + def apply_estimation_factor(self, model_name: str, default_estimate: int) -> int: + factor = 1 + self.settings.get('config.model_token_count_estimate_factor', 0) get_logger().warning(f"{model_name}'s token count cannot be accurately estimated. Using factor of {factor}") - return ceil(factor * default_estimate) def get_token_count_by_model_type(self, patch: str, default_estimate: int) -> int: model_name = get_settings().config.model.lower() - if 'claude' in model_name and get_settings(use_context=False).get('anthropic.key'): + if self.model_validator.is_claude_model(model_name) and get_settings(use_context=False).get('anthropic.key'): return self.calc_claude_tokens(patch) - if self.is_openai_model(model_name) and get_settings(use_context=False).get('openai.key'): + if self.model_validator.is_openai_model(model_name) and get_settings(use_context=False).get('openai.key'): return default_estimate return self.apply_estimation_factor(model_name, default_estimate)