mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-05 21:30:40 +08:00
Support cloning repo
Support forcing accurate token calculation (claude) Help docs: Add desired branch in case of user supplied git repo, with default set to "main" Better documentation for getting canonical url parts
This commit is contained in:
@ -76,7 +76,35 @@ class TokenHandler:
|
||||
get_logger().error(f"Error in _get_system_user_tokens: {e}")
|
||||
return 0
|
||||
|
||||
def count_tokens(self, patch: str) -> int:
|
||||
def calc_claude_tokens(self, patch):
|
||||
try:
|
||||
import anthropic
|
||||
from pr_agent.algo import MAX_TOKENS
|
||||
client = anthropic.Anthropic(api_key=get_settings(use_context=False).get('anthropic.key'))
|
||||
MaxTokens = MAX_TOKENS[get_settings().config.model]
|
||||
|
||||
# Check if the content size is too large (9MB limit)
|
||||
if len(patch.encode('utf-8')) > 9_000_000:
|
||||
get_logger().warning(
|
||||
"Content too large for Anthropic token counting API, falling back to local tokenizer"
|
||||
)
|
||||
return MaxTokens
|
||||
|
||||
response = client.messages.count_tokens(
|
||||
model="claude-3-7-sonnet-20250219",
|
||||
system="system",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": patch
|
||||
}],
|
||||
)
|
||||
return response.input_tokens
|
||||
|
||||
except Exception as e:
|
||||
get_logger().error( f"Error in Anthropic token counting: {e}")
|
||||
return MaxTokens
|
||||
|
||||
def count_tokens(self, patch: str, force_accurate=False) -> int:
|
||||
"""
|
||||
Counts the number of tokens in a given patch string.
|
||||
|
||||
@ -86,4 +114,6 @@ class TokenHandler:
|
||||
Returns:
|
||||
The number of tokens in the patch string.
|
||||
"""
|
||||
return len(self.encoder.encode(patch, disallowed_special=()))
|
||||
if force_accurate and 'claude' in get_settings().config.model.lower() and get_settings(use_context=False).get('anthropic.key'):
|
||||
return self.calc_claude_tokens(patch) # API call to Anthropic for accurate token counting for Claude models
|
||||
return len(self.encoder.encode(patch, disallowed_special=()))
|
Reference in New Issue
Block a user