Add constants and improve token calculation logic

kkan9ma
2025-05-20 14:12:24 +09:00
parent e72bb28c4e
commit f198e6fa09


@@ -52,6 +52,10 @@ class TokenHandler:
     method.
     """
 
+    # Constants
+    CLAUDE_MODEL = "claude-3-7-sonnet-20250219"
+    CLAUDE_MAX_CONTENT_SIZE = 9_000_000  # Maximum allowed content size (9MB) for Claude API
+
     def __init__(self, pr=None, vars: dict = {}, system="", user=""):
         """
         Initializes the TokenHandler object.
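
The model name and size limit were previously inlined at their call sites; hoisting them to class level gives one place to update them. A minimal sketch of how such class attributes behave, using only the names from this diff:

```python
# Minimal sketch: class attributes are shared by all instances and are
# readable through either the class or an instance (hence self.CLAUDE_MODEL).
class TokenHandler:
    CLAUDE_MODEL = "claude-3-7-sonnet-20250219"
    CLAUDE_MAX_CONTENT_SIZE = 9_000_000  # bytes (~9MB)

handler = TokenHandler()
assert handler.CLAUDE_MODEL == TokenHandler.CLAUDE_MODEL
assert handler.CLAUDE_MAX_CONTENT_SIZE == 9_000_000
```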
@@ -102,15 +106,14 @@ class TokenHandler:
         client = anthropic.Anthropic(api_key=self.settings.get('anthropic.key'))
         max_tokens = MAX_TOKENS[self.settings.config.model]
 
-        # Check if the content size is too large (9MB limit)
-        if len(patch.encode('utf-8')) > 9_000_000:
+        if len(patch.encode('utf-8')) > self.CLAUDE_MAX_CONTENT_SIZE:
             get_logger().warning(
                 "Content too large for Anthropic token counting API, falling back to local tokenizer"
             )
             return max_tokens
 
         response = client.messages.count_tokens(
-            model="claude-3-7-sonnet-20250219",
+            model=self.CLAUDE_MODEL,
             system="system",
             messages=[{
                 "role": "user",
@@ -126,9 +129,20 @@ class TokenHandler:
     def apply_estimation_factor(self, model_name: str, default_estimate: int) -> int:
         factor = 1 + self.settings.get('config.model_token_count_estimate_factor', 0)
         get_logger().warning(f"{model_name}'s token count cannot be accurately estimated. Using factor of {factor}")
         return ceil(factor * default_estimate)
 
     def get_token_count_by_model_type(self, patch: str, default_estimate: int) -> int:
+        """
+        Get token count based on model type.
+
+        Args:
+            patch: The text to count tokens for.
+            default_estimate: The default token count estimate.
+
+        Returns:
+            int: The calculated token count.
+        """
+
         model_name = self.settings.config.model.lower()
 
         if self.model_validator.is_claude_model(model_name) and self.settings.get('anthropic.key'):
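
The factor math in `apply_estimation_factor` pads a local estimate for models whose tokenizer is unknown. A worked example, with an assumed config value of 0.3 (the real value comes from settings):

```python
# Assumed setting: config.model_token_count_estimate_factor = 0.3
from math import ceil

factor = 1 + 0.3          # -> 1.3
default_estimate = 1000   # tokens from the local encoder
print(ceil(factor * default_estimate))  # -> 1300, a padded upper bound
```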
@@ -152,8 +166,8 @@ class TokenHandler:
         """
         encoder_estimate = len(self.encoder.encode(patch, disallowed_special=()))
 
-        # If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
-        if not force_accurate:
-            return encoder_estimate
-        else:
-            return self.get_token_count_by_model_type(patch, encoder_estimate=encoder_estimate)
+        if force_accurate:
+            return self.get_token_count_by_model_type(patch, default_estimate=encoder_estimate)
+
+        # If an estimate is enough (for example, in cases where the maximal allowed tokens is way below the known limits), return it.
+        return encoder_estimate
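
End to end, the reordered logic keeps the cheap local path as the default and makes the provider-specific path opt-in via `force_accurate`. A self-contained sketch of that dispatch, with the provider path stubbed out and `tiktoken`'s `o200k_base` encoding assumed in place of the class's configured encoder:

```python
import tiktoken

encoder = tiktoken.get_encoding("o200k_base")  # assumed encoding

def model_specific_count(patch: str, default_estimate: int) -> int:
    # Stub for the provider-specific path (e.g. Anthropic's counting API);
    # here it just echoes the local estimate.
    return default_estimate

def count_tokens(patch: str, force_accurate: bool = False) -> int:
    encoder_estimate = len(encoder.encode(patch, disallowed_special=()))
    if force_accurate:
        return model_specific_count(patch, default_estimate=encoder_estimate)
    # The local estimate is enough when the allowed budget is far below the limit.
    return encoder_estimate

print(count_tokens("def foo():\n    return 42\n"))
```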