diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py
index 99dbc635..8bd3f115 100644
--- a/pr_agent/algo/token_handler.py
+++ b/pr_agent/algo/token_handler.py
@@ -52,6 +52,10 @@ class TokenHandler:
       method.
     """
 
+    # Constants
+    CLAUDE_MODEL = "claude-3-7-sonnet-20250219"
+    CLAUDE_MAX_CONTENT_SIZE = 9_000_000  # Maximum allowed content size (9MB) for Claude API
+
     def __init__(self, pr=None, vars: dict = {}, system="", user=""):
         """
         Initializes the TokenHandler object.
@@ -102,15 +106,14 @@ class TokenHandler:
             client = anthropic.Anthropic(api_key=self.settings.get('anthropic.key'))
             max_tokens = MAX_TOKENS[self.settings.config.model]
 
-            # Check if the content size is too large (9MB limit)
-            if len(patch.encode('utf-8')) > 9_000_000:
+            if len(patch.encode('utf-8')) > self.CLAUDE_MAX_CONTENT_SIZE:
                 get_logger().warning(
                     "Content too large for Anthropic token counting API, falling back to local tokenizer"
                 )
                 return max_tokens
 
             response = client.messages.count_tokens(
-                model="claude-3-7-sonnet-20250219",
+                model=self.CLAUDE_MODEL,
                 system="system",
                 messages=[{
                     "role": "user",
@@ -126,9 +129,20 @@ class TokenHandler:
     def apply_estimation_factor(self, model_name: str, default_estimate: int) -> int:
         factor = 1 + self.settings.get('config.model_token_count_estimate_factor', 0)
         get_logger().warning(f"{model_name}'s token count cannot be accurately estimated. Using factor of {factor}")
+
         return ceil(factor * default_estimate)
 
     def get_token_count_by_model_type(self, patch: str, default_estimate: int) -> int:
+        """
+        Get token count based on model type.
+
+        Args:
+            patch: The text to count tokens for.
+            default_estimate: The default token count estimate.
+
+        Returns:
+            int: The calculated token count.
+        """
         model_name = self.settings.config.model.lower()
 
         if self.model_validator.is_claude_model(model_name) and self.settings.get('anthropic.key'):
@@ -152,8 +166,8 @@ class TokenHandler:
         """
         encoder_estimate = len(self.encoder.encode(patch, disallowed_special=()))
 
-        #If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
-        if not force_accurate:
-            return encoder_estimate
-        else:
+        if force_accurate:
             return self.get_token_count_by_model_type(patch, encoder_estimate)
+
+        # If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
+        return encoder_estimate