mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
Initial commit - PR-Agent OSS release
This commit is contained in:
24
pr_agent/algo/token_handler.py
Normal file
24
pr_agent/algo/token_handler.py
Normal file
@ -0,0 +1,24 @@
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from pr_agent.algo import MAX_TOKENS
|
||||
from pr_agent.config_loader import settings
|
||||
|
||||
|
||||
class TokenHandler:
|
||||
def __init__(self, pr, vars: dict, system, user):
|
||||
self.encoder = encoding_for_model(settings.config.model)
|
||||
self.limit = MAX_TOKENS[settings.config.model]
|
||||
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
||||
|
||||
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
system_prompt = environment.from_string(system).render(vars)
|
||||
user_prompt = environment.from_string(user).render(vars)
|
||||
|
||||
system_prompt_tokens = len(encoder.encode(system_prompt))
|
||||
user_prompt_tokens = len(encoder.encode(user_prompt))
|
||||
return system_prompt_tokens + user_prompt_tokens
|
||||
|
||||
def count_tokens(self, patch: str) -> int:
|
||||
return len(self.encoder.encode(patch))
|
Reference in New Issue
Block a user