diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py
index d7eff9d7..9cc3b41f 100644
--- a/pr_agent/algo/token_handler.py
+++ b/pr_agent/algo/token_handler.py
@@ -1,12 +1,25 @@
 from jinja2 import Environment, StrictUndefined
 from tiktoken import encoding_for_model, get_encoding
-
 from pr_agent.config_loader import get_settings
+from threading import Lock
 
 
-def get_token_encoder():
-    return encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding(
-        "cl100k_base")
+class TokenEncoder:
+    _encoder_instance = None
+    _model = None
+    _lock = Lock()  # Create a lock object
+
+    @classmethod
+    def get_token_encoder(cls):
+        model = get_settings().config.model
+        if cls._encoder_instance is None or model != cls._model:  # Check without acquiring the lock for performance
+            with cls._lock:  # Lock acquisition to ensure thread safety
+                if cls._encoder_instance is None or model != cls._model:
+                    cls._model = model
+                    cls._encoder_instance = encoding_for_model(cls._model) if "gpt" in cls._model else get_encoding(
+                        "cl100k_base")
+        return cls._encoder_instance
+
 
 class TokenHandler:
     """
@@ -31,7 +44,7 @@ class TokenHandler:
         - system: The system string.
         - user: The user string.
         """
-        self.encoder = get_token_encoder()
+        self.encoder = TokenEncoder.get_token_encoder()
         if pr is not None:
             self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
 
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index c2b6323c..b017f0aa 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -12,7 +12,7 @@ import yaml
 from starlette_context import context
 
 from pr_agent.algo import MAX_TOKENS
-from pr_agent.algo.token_handler import get_token_encoder
+from pr_agent.algo.token_handler import TokenEncoder
 from pr_agent.config_loader import get_settings, global_settings
 from pr_agent.algo.types import FilePatchInfo
 from pr_agent.log import get_logger
@@ -566,7 +566,7 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str:
         return text
 
     try:
-        encoder = get_token_encoder()
+        encoder = TokenEncoder.get_token_encoder()
         num_input_tokens = len(encoder.encode(text))
         if num_input_tokens <= max_tokens:
             return text
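
Context for reviewers: the diff replaces the module-level get_token_encoder() helper with a lazily initialized, model-aware singleton that uses double-checked locking — the cached tiktoken encoder is reused until the configured model changes, and the lock is only acquired when a rebuild may actually be needed. A minimal caller-side sketch of how code such as clip_tokens() obtains the shared encoder after this change (the sample text is illustrative and not part of the PR; it assumes pr-agent's settings are already loaded so get_settings().config.model resolves):

    # Hypothetical usage sketch, mirroring the clip_tokens() call site above.
    from pr_agent.algo.token_handler import TokenEncoder

    encoder = TokenEncoder.get_token_encoder()   # first call builds and caches the encoder
    num_tokens = len(encoder.encode("Example PR description text"))  # later calls reuse the cached instance

Because only the classmethod is exposed, every TokenHandler instance and every clip_tokens() call shares one encoder per configured model instead of re-running encoding_for_model() on each call, and the inner re-check under the lock keeps concurrent first calls from building the encoder twice.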