Merge pull request #1805 from group-3-sPRinter/improve/token_handler

Refactor count_tokens method structure in token_handler.py for better extensibility
This commit is contained in:
Tal
2025-05-25 12:11:41 +03:00
committed by GitHub

View File

@ -1,4 +1,6 @@
from threading import Lock from threading import Lock
from math import ceil
import re
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model, get_encoding from tiktoken import encoding_for_model, get_encoding
@ -7,6 +9,16 @@ from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger from pr_agent.log import get_logger
class ModelTypeValidator:
@staticmethod
def is_openai_model(model_name: str) -> bool:
return 'gpt' in model_name or re.match(r"^o[1-9](-mini|-preview)?$", model_name)
@staticmethod
def is_anthropic_model(model_name: str) -> bool:
return 'claude' in model_name
class TokenEncoder: class TokenEncoder:
_encoder_instance = None _encoder_instance = None
_model = None _model = None
@ -40,6 +52,10 @@ class TokenHandler:
method. method.
""" """
# Constants
CLAUDE_MODEL = "claude-3-7-sonnet-20250219"
CLAUDE_MAX_CONTENT_SIZE = 9_000_000 # Maximum allowed content size (9MB) for Claude API
def __init__(self, pr=None, vars: dict = {}, system="", user=""): def __init__(self, pr=None, vars: dict = {}, system="", user=""):
""" """
Initializes the TokenHandler object. Initializes the TokenHandler object.
@ -51,6 +67,7 @@ class TokenHandler:
- user: The user string. - user: The user string.
""" """
self.encoder = TokenEncoder.get_token_encoder() self.encoder = TokenEncoder.get_token_encoder()
if pr is not None: if pr is not None:
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user) self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
@ -79,22 +96,22 @@ class TokenHandler:
get_logger().error(f"Error in _get_system_user_tokens: {e}") get_logger().error(f"Error in _get_system_user_tokens: {e}")
return 0 return 0
def calc_claude_tokens(self, patch): def _calc_claude_tokens(self, patch: str) -> int:
try: try:
import anthropic import anthropic
from pr_agent.algo import MAX_TOKENS from pr_agent.algo import MAX_TOKENS
client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))
MaxTokens = MAX_TOKENS[get_settings().config.model]
# Check if the content size is too large (9MB limit) client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))
if len(patch.encode('utf-8')) > 9_000_000: max_tokens = MAX_TOKENS[get_settings().config.model]
if len(patch.encode('utf-8')) > self.CLAUDE_MAX_CONTENT_SIZE:
get_logger().warning( get_logger().warning(
"Content too large for Anthropic token counting API, falling back to local tokenizer" "Content too large for Anthropic token counting API, falling back to local tokenizer"
) )
return MaxTokens return max_tokens
response = client.messages.count_tokens( response = client.messages.count_tokens(
model="claude-3-7-sonnet-20250219", model=self.CLAUDE_MODEL,
system="system", system="system",
messages=[{ messages=[{
"role": "user", "role": "user",
@ -105,27 +122,42 @@ class TokenHandler:
except Exception as e: except Exception as e:
get_logger().error(f"Error in Anthropic token counting: {e}") get_logger().error(f"Error in Anthropic token counting: {e}")
return MaxTokens return max_tokens
def estimate_token_count_for_non_anth_claude_models(self, model, default_encoder_estimate): def _apply_estimation_factor(self, model_name: str, default_estimate: int) -> int:
from math import ceil factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
import re get_logger().warning(f"{model_name}'s token count cannot be accurately estimated. Using factor of {factor}")
model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model) return ceil(factor * default_estimate)
if ('gpt' in get_settings().config.model.lower() or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
return default_encoder_estimate
#else: Model is not an OpenAI one - therefore, cannot provide an accurate token count and instead, return a higher number as best effort.
elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0) def _get_token_count_by_model_type(self, patch: str, default_estimate: int) -> int:
get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. Using {elbow_factor} of encoder output as best effort estimate") """
return ceil(elbow_factor * default_encoder_estimate) Get token count based on model type.
def count_tokens(self, patch: str, force_accurate=False) -> int: Args:
patch: The text to count tokens for.
default_estimate: The default token count estimate.
Returns:
int: The calculated token count.
"""
model_name = get_settings().config.model.lower()
if ModelTypeValidator.is_openai_model(model_name) and get_settings(use_context=False).get('openai.key'):
return default_estimate
if ModelTypeValidator.is_anthropic_model(model_name) and get_settings(use_context=False).get('anthropic.key'):
return self._calc_claude_tokens(patch)
return self._apply_estimation_factor(model_name, default_estimate)
def count_tokens(self, patch: str, force_accurate: bool = False) -> int:
""" """
Counts the number of tokens in a given patch string. Counts the number of tokens in a given patch string.
Args: Args:
- patch: The patch string. - patch: The patch string.
- force_accurate: If True, uses a more precise calculation method.
Returns: Returns:
The number of tokens in the patch string. The number of tokens in the patch string.
@ -136,10 +168,4 @@ class TokenHandler:
if not force_accurate: if not force_accurate:
return encoder_estimate return encoder_estimate
#else, force_accurate==True: User requested providing an accurate estimation: return self._get_token_count_by_model_type(patch, encoder_estimate)
model = get_settings().config.model.lower()
if 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
return self.calc_claude_tokens(patch) # API call to Anthropic for accurate token counting for Claude models
#else: Non Anthropic provided model:
return self.estimate_token_count_for_non_anth_claude_models(model, encoder_estimate)