from math import ceil
from threading import Lock

from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model, get_encoding

from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger


class TokenEncoder:
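    """
    A thread-safe, lazily created singleton holding the tiktoken encoder for the
    currently configured model. The cached encoder is rebuilt whenever the model changes.
    """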
    _encoder_instance = None
    _model = None
    _lock = Lock()  # Guards encoder (re)initialization across threads

    @classmethod
    def get_token_encoder(cls):
        model = get_settings().config.model
        if cls._encoder_instance is None or model != cls._model:  # Check without acquiring the lock, for performance
            with cls._lock:  # Acquire the lock to ensure thread safety
                if cls._encoder_instance is None or model != cls._model:
                    cls._model = model
                    cls._encoder_instance = encoding_for_model(cls._model) if "gpt" in cls._model else get_encoding(
                        "cl100k_base")
        return cls._encoder_instance
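
# Illustrative usage (a minimal sketch; model selection comes from pr_agent settings):
#
#   encoder = TokenEncoder.get_token_encoder()
#   n_tokens = len(encoder.encode("Hello, world"))
#
# Repeated calls reuse the cached encoder until config.model changes; the
# double-checked locking above keeps the rebuild race-free across threads.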


class TokenHandler:
"""
A class for handling tokens in the context of a pull request .
Attributes :
2023-08-01 14:43:26 +03:00
- encoder : An object of the encoding_for_model class from the tiktoken module . Used to encode strings and count the
number of tokens in them .
- limit : The maximum number of tokens allowed for the given model , as defined in the MAX_TOKENS dictionary in the
pr_agent . algo module .
- prompt_tokens : The number of tokens in the system and user strings , as calculated by the _get_system_user_tokens
method .
2023-07-20 10:51:21 +03:00
"""

    def __init__(self, pr=None, vars: dict = {}, system="", user=""):
        """
        Initializes the TokenHandler object.

        Args:
        - pr: The pull request object.
        - vars: A dictionary of variables.
        - system: The system string.
        - user: The user string.
        """
        self.encoder = TokenEncoder.get_token_encoder()
        if pr is not None:
            self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)

    def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
        """
        Calculates the number of tokens in the system and user strings.

        Args:
        - pr: The pull request object.
        - encoder: A tiktoken Encoding object.
        - vars: A dictionary of variables.
        - system: The system string.
        - user: The user string.

        Returns:
        The sum of the number of tokens in the system and user strings.
        """
        try:
            environment = Environment(undefined=StrictUndefined)
            system_prompt = environment.from_string(system).render(vars)
            user_prompt = environment.from_string(user).render(vars)
            system_prompt_tokens = len(encoder.encode(system_prompt))
            user_prompt_tokens = len(encoder.encode(user_prompt))
            return system_prompt_tokens + user_prompt_tokens
        except Exception as e:
            get_logger().error(f"Error in _get_system_user_tokens: {e}")
            return 0
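
    # Illustrative usage (a minimal sketch; `my_pr` and the template variables are
    # hypothetical). StrictUndefined makes rendering raise if the system/user templates
    # reference a variable missing from `vars`, so prompt tokens are never counted
    # against a silently incomplete prompt:
    #
    #   handler = TokenHandler(pr=my_pr, vars={"title": "Fix race condition"},
    #                          system="You are a PR reviewer.",
    #                          user="Review this PR: {{ title }}")
    #   print(handler.prompt_tokens)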

    def calc_claude_tokens(self, patch):
        # Initialize the fallback before the try block, so the except branch cannot hit
        # an UnboundLocalError when the failure happens before max_tokens is assigned.
        max_tokens = 0
        try:
            import anthropic
            from pr_agent.algo import MAX_TOKENS
            client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))
            max_tokens = MAX_TOKENS[get_settings().config.model]

            # Check if the content size is too large (9MB limit)
            if len(patch.encode('utf-8')) > 9_000_000:
                get_logger().warning(
                    "Content too large for Anthropic token counting API, falling back to local tokenizer"
                )
                return max_tokens

            response = client.messages.count_tokens(
                model="claude-3-7-sonnet-20250219",
                system="system",
                messages=[{
                    "role": "user",
                    "content": patch
                }],
            )
            return response.input_tokens
        except Exception as e:
            get_logger().error(f"Error in Anthropic token counting: {e}")
            return max_tokens
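
    # Note: client.messages.count_tokens() is Anthropic's server-side token-counting
    # endpoint; on any failure (missing key, network error, unknown model) the method
    # falls back to the model's configured MAX_TOKENS as a conservative upper bound.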

    def count_tokens(self, patch: str, force_accurate=False) -> int:
        """
        Counts the number of tokens in a given patch string.

        Args:
        - patch: The patch string.
        - force_accurate: If True, try to return a provider-accurate count rather than the local tiktoken estimate.

        Returns:
        The number of tokens in the patch string.
        """
        encoder_estimate = len(self.encoder.encode(patch, disallowed_special=()))
        if not force_accurate:
            return encoder_estimate

        # Otherwise, an accurate count was requested:
        model = get_settings().config.model.lower()
        if 'claude' in model and get_settings(use_context=False).get('anthropic.key'):
            return self.calc_claude_tokens(patch)  # API call to Anthropic for an accurate token count on Claude models

        # Not an Anthropic model; tiktoken is already accurate for OpenAI models.
        import re
        model_is_from_o_series = re.match(r"^o[1-9](-mini|-preview)?$", model)
        if ('gpt' in model or model_is_from_o_series) and get_settings(use_context=False).get('openai.key'):
            return encoder_estimate

        # The model is neither an OpenAI nor an Anthropic model, so an accurate count is
        # unavailable. As a best effort, inflate the local estimate by a configurable factor.
        elbow_factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
        get_logger().warning(f"{model}'s expected token count cannot be accurately estimated. "
                             f"Using {elbow_factor} of encoder output as best effort estimate")
        return ceil(elbow_factor * encoder_estimate)
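

# Illustrative usage (a minimal sketch; `patch_text` is hypothetical):
#
#   handler = TokenHandler()
#   quick = handler.count_tokens(patch_text)                       # local tiktoken estimate
#   exact = handler.count_tokens(patch_text, force_accurate=True)  # provider-accurate when possible
#
# For models that are neither OpenAI nor Anthropic, the "accurate" path inflates the
# local estimate by config.model_token_count_estimate_factor (default 0, i.e. a factor of 1).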