mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-04 21:00:40 +08:00
Add model name validation
```diff
@@ -1,4 +1,6 @@
 from threading import Lock
+from math import ceil
+import re
 
 from jinja2 import Environment, StrictUndefined
 from tiktoken import encoding_for_model, get_encoding
```
```diff
@@ -7,6 +9,16 @@ from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
 
 
+class ModelTypeValidator:
+    @staticmethod
+    def is_openai_model(model_name: str) -> bool:
+        return 'gpt' in model_name or re.match(r"^o[1-9](-mini|-preview)?$", model_name)
+
+    @staticmethod
+    def is_claude_model(model_name: str) -> bool:
+        return 'claude' in model_name
+
+
 class TokenEncoder:
     _encoder_instance = None
     _model = None
```
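A note on the new validator: `re.match` returns `Optional[re.Match]`, so despite the `-> bool` annotation, `is_openai_model` actually returns a `Match` object or `None` whenever the regex branch decides. Truthiness still works at the call sites, but wrapping the match in `bool()` would make the annotation accurate. A minimal sketch of that stricter variant (not part of this commit), with hypothetical inputs:

```python
import re

def is_openai_model(model_name: str) -> bool:
    # bool() keeps the -> bool contract: re.match yields a Match object or None.
    return 'gpt' in model_name or bool(re.match(r"^o[1-9](-mini|-preview)?$", model_name))

print(is_openai_model("gpt-4o"))   # True  ('gpt' substring)
print(is_openai_model("o3-mini"))  # True  (regex match)
print(is_openai_model("llama-3"))  # False
```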
```diff
@@ -51,6 +63,9 @@ class TokenHandler:
         - user: The user string.
         """
         self.encoder = TokenEncoder.get_token_encoder()
+        self.settings = get_settings()
+        self.model_validator = ModelTypeValidator()
+
         if pr is not None:
             self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
 
```
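The constructor now caches the settings object and a validator instance, so the methods below can read `self.settings` instead of calling `get_settings()` repeatedly. A small usage sketch (hypothetical; assumes the constructor's existing defaults allow argument-free construction):

```python
# Hypothetical usage of the members cached in __init__ above.
handler = TokenHandler()
factor = handler.settings.get('config.model_token_count_estimate_factor', 0)
print(handler.model_validator.is_claude_model('claude-3-7-sonnet'))  # True
```

One subtlety worth flagging: the old code fetched the Anthropic key via `get_settings(use_context=False)`, while the cached `self.settings` comes from a plain `get_settings()` call, so `calc_claude_tokens` below now reads the context-bound settings.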
```diff
@@ -79,19 +94,20 @@ class TokenHandler:
             get_logger().error(f"Error in _get_system_user_tokens: {e}")
             return 0
 
-    def calc_claude_tokens(self, patch):
+    def calc_claude_tokens(self, patch: str) -> int:
         try:
             import anthropic
             from pr_agent.algo import MAX_TOKENS
-            client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))
-            MaxTokens = MAX_TOKENS[get_settings().config.model]
+
+            client = anthropic.Anthropic(api_key=self.settings.get('anthropic.key'))
+            max_tokens = MAX_TOKENS[self.settings.config.model]
 
             # Check if the content size is too large (9MB limit)
             if len(patch.encode('utf-8')) > 9_000_000:
                 get_logger().warning(
                     "Content too large for Anthropic token counting API, falling back to local tokenizer"
                 )
-                return MaxTokens
+                return max_tokens
 
             response = client.messages.count_tokens(
                 model="claude-3-7-sonnet-20250219",
```
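The counting path keeps the 9 MB guard, so oversized patches fall back to the model's configured maximum instead of hitting the API. For reference, a standalone sketch of the Anthropic counting call used here (assumes `ANTHROPIC_API_KEY` is set in the environment rather than read from pr-agent settings):

```python
import anthropic

client = anthropic.Anthropic()  # picks up ANTHROPIC_API_KEY from the environment
response = client.messages.count_tokens(
    model="claude-3-7-sonnet-20250219",
    messages=[{"role": "user", "content": "example patch text"}],
)
print(response.input_tokens)  # token count computed server-side
```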
```diff
@@ -104,29 +120,21 @@ class TokenHandler:
             return response.input_tokens
 
         except Exception as e:
-            get_logger().error( f"Error in Anthropic token counting: {e}")
-            return MaxTokens
+            get_logger().error(f"Error in Anthropic token counting: {e}")
+            return max_tokens
 
-    def is_openai_model(self, model_name):
-        from re import match
-
-        return 'gpt' in model_name or match(r"^o[1-9](-mini|-preview)?$", model_name)
-
-    def apply_estimation_factor(self, model_name, default_estimate):
-        from math import ceil
-
-        factor = 1 + get_settings().get('config.model_token_count_estimate_factor', 0)
+    def apply_estimation_factor(self, model_name: str, default_estimate: int) -> int:
+        factor = 1 + self.settings.get('config.model_token_count_estimate_factor', 0)
         get_logger().warning(f"{model_name}'s token count cannot be accurately estimated. Using factor of {factor}")
-
         return ceil(factor * default_estimate)
 
     def get_token_count_by_model_type(self, patch: str, default_estimate: int) -> int:
         model_name = get_settings().config.model.lower()
 
-        if 'claude' in model_name and get_settings(use_context=False).get('anthropic.key'):
+        if self.model_validator.is_claude_model(model_name) and get_settings(use_context=False).get('anthropic.key'):
             return self.calc_claude_tokens(patch)
 
-        if self.is_openai_model(model_name) and get_settings(use_context=False).get('openai.key'):
+        if self.model_validator.is_openai_model(model_name) and get_settings(use_context=False).get('openai.key'):
             return default_estimate
 
         return self.apply_estimation_factor(model_name, default_estimate)
```
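The dispatch now reads: Claude models with an Anthropic key get the exact API count, OpenAI-style models with an OpenAI key keep the tiktoken estimate as-is, and everything else has its estimate inflated by the configured factor and rounded up. A worked example of that last branch, assuming a hypothetical `model_token_count_estimate_factor` of 0.5:

```python
from math import ceil

factor = 1 + 0.5             # config.model_token_count_estimate_factor = 0.5 (hypothetical)
print(ceil(factor * 1000))   # 1500
print(ceil(factor * 333))    # ceil(499.5) -> 500
```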