bug fixes and updates

This commit is contained in:
Krrish Dholakia
2023-08-03 16:05:46 -07:00
parent 102edcdcf1
commit ed8554699b
7 changed files with 42 additions and 23 deletions

View File

@ -1,5 +1,5 @@
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model
from tiktoken import encoding_for_model, get_encoding
from pr_agent.config_loader import get_settings
@ -27,7 +27,7 @@ class TokenHandler:
- system: The system string.
- user: The user string.
"""
self.encoder = encoding_for_model(get_settings().config.model)
self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base")
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
@ -47,7 +47,6 @@ class TokenHandler:
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(system).render(vars)
user_prompt = environment.from_string(user).render(vars)
system_prompt_tokens = len(encoder.encode(system_prompt))
user_prompt_tokens = len(encoder.encode(user_prompt))
return system_prompt_tokens + user_prompt_tokens