bug fixes and updates

Krrish Dholakia
2023-08-03 16:05:46 -07:00
parent 102edcdcf1
commit ed8554699b
7 changed files with 42 additions and 23 deletions

View File

@@ -7,4 +7,8 @@ MAX_TOKENS = {
     'gpt-4': 8000,
     'gpt-4-0613': 8000,
     'gpt-4-32k': 32000,
+    'claude-instant-1': 100000,
+    'claude-2': 100000,
+    'command-nightly': 4096,
+    'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
 }

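The four added entries extend the context-window table beyond OpenAI models to Anthropic, Cohere, and Replicate ones. As a minimal usage sketch of how such a table is typically consulted (the `token_budget` helper below is hypothetical, not part of this commit), a caller can cap a requested completion budget at the model's window:

```python
# Hypothetical helper illustrating a MAX_TOKENS lookup; the table
# excerpt mirrors entries added in this commit.
MAX_TOKENS = {
    'claude-instant-1': 100000,
    'claude-2': 100000,
    'command-nightly': 4096,
}

def token_budget(model: str, requested: int) -> int:
    """Cap a requested token budget at the model's context window."""
    limit = MAX_TOKENS.get(model)
    if limit is None:
        raise ValueError(f"Unknown model: {model}")
    return min(requested, limit)

assert token_budget('command-nightly', 16000) == 4096
```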
View File

@@ -3,6 +3,7 @@ import logging
 import openai
 from openai.error import APIError, RateLimitError, Timeout, TryAgain
 from retry import retry
+import litellm
 from litellm import acompletion
 from pr_agent.config_loader import get_settings
@@ -22,6 +23,7 @@ class AiHandler:
         """
         try:
             openai.api_key = get_settings().openai.key
+            litellm.openai_key = get_settings().openai.key
             if get_settings().get("OPENAI.ORG", None):
                 openai.organization = get_settings().openai.org
             self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@@ -31,6 +33,9 @@ class AiHandler:
                 openai.api_version = get_settings().openai.api_version
             if get_settings().get("OPENAI.API_BASE", None):
                 openai.api_base = get_settings().openai.api_base
+                litellm.api_base = get_settings().openai.api_base
+            if get_settings().get("LITE.KEY", None):
+                self.llm_api_key = get_settings().lite.key
         except AttributeError as e:
             raise ValueError("OpenAI key is required") from e
@@ -65,6 +70,7 @@ class AiHandler:
                     {"role": "user", "content": user}
                 ],
                 temperature=temperature,
+                api_key=self.llm_api_key
             )
         except (APIError, Timeout, TryAgain) as e:
             logging.error("Error during OpenAI inference: ", e)
@@ -75,8 +81,9 @@ class AiHandler:
         except (Exception) as e:
             logging.error("Unknown error during OpenAI inference: ", e)
             raise TryAgain from e
-        if response is None or len(response.choices) == 0:
+        if response is None or len(response["choices"]) == 0:
             raise TryAgain
-        resp = response.choices[0]['message']['content']
-        finish_reason = response.choices[0].finish_reason
+        resp = response["choices"][0]['message']['content']
+        finish_reason = response["choices"][0]["finish_reason"]
+        print(resp, finish_reason)
         return resp, finish_reason

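Taken together, these hunks route the completion call through litellm's OpenAI-style interface, pass an optional per-provider key via `api_key`, and switch from attribute access (`response.choices`) to dict-style access (`response["choices"]`) to match the dict-like responses litellm returns. A minimal self-contained sketch of the resulting call pattern; the model name and strings are placeholders, and `llm_api_key` defaults to `None` here since the diff only sets it when a LITE.KEY is configured:

```python
import asyncio
from typing import Optional

from litellm import acompletion

async def review(system: str, user: str, llm_api_key: Optional[str] = None):
    # litellm fans one OpenAI-style interface out to OpenAI, Anthropic,
    # Cohere, Replicate, etc.; api_key passes a per-call provider key,
    # overriding any globally configured one (None falls back to config).
    response = await acompletion(
        model="command-nightly",  # placeholder model name
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.2,
        api_key=llm_api_key,
    )
    # Dict-style access, matching the updated handler above.
    return (response["choices"][0]["message"]["content"],
            response["choices"][0]["finish_reason"])

if __name__ == "__main__":
    # Requires provider credentials to actually run.
    print(asyncio.run(review("You review pull requests.", "Summarize this diff.")))
```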
View File

@@ -1,5 +1,5 @@
 from __future__ import annotations
+import traceback
 import logging
 from typing import Callable, Tuple
@@ -221,6 +221,6 @@ async def retry_with_fallback_models(f: Callable):
         try:
             return await f(model)
         except Exception as e:
-            logging.warning(f"Failed to generate prediction with {model}: {e}")
+            logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
             if i == len(all_models) - 1:  # If it's the last iteration
                 raise  # Re-raise the last exception

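The change replaces `{e}` with `traceback.format_exc()`, so each failed fallback attempt logs the full stack trace rather than only the exception's message, which is often empty for errors like `TryAgain`. A standalone sketch of the surrounding loop; `all_models` is assumed here to be the ordered fallback list the real function derives from configuration:

```python
import logging
import traceback
from typing import Callable

async def retry_with_fallback_models_sketch(f: Callable, all_models: list):
    for i, model in enumerate(all_models):
        try:
            return await f(model)
        except Exception:
            # format_exc() captures the full stack trace of the pending
            # exception, which str(e) alone would drop.
            logging.warning(
                f"Failed to generate prediction with {model}: "
                f"{traceback.format_exc()}"
            )
            if i == len(all_models) - 1:  # last fallback exhausted
                raise  # re-raise the last exception
```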
View File

@@ -1,5 +1,5 @@
 from jinja2 import Environment, StrictUndefined
-from tiktoken import encoding_for_model
+from tiktoken import encoding_for_model, get_encoding
 from pr_agent.config_loader import get_settings
@@ -27,7 +27,7 @@ class TokenHandler:
         - system: The system string.
         - user: The user string.
         """
-        self.encoder = encoding_for_model(get_settings().config.model)
+        self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base")
         self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)

     def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
@@ -47,7 +47,6 @@ class TokenHandler:
         environment = Environment(undefined=StrictUndefined)
         system_prompt = environment.from_string(system).render(vars)
         user_prompt = environment.from_string(user).render(vars)
-
         system_prompt_tokens = len(encoder.encode(system_prompt))
         user_prompt_tokens = len(encoder.encode(user_prompt))
         return system_prompt_tokens + user_prompt_tokens
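`encoding_for_model` only knows OpenAI model names and raises `KeyError` for the newly supported ones ('claude-2', 'command-nightly', ...), so the constructor now falls back to the `cl100k_base` encoding, giving an approximate token count for non-GPT models. A standalone sketch of that selection logic (`pick_encoder` is an illustrative name, not a function in this commit):

```python
from tiktoken import encoding_for_model, get_encoding

def pick_encoder(model: str):
    # encoding_for_model() raises KeyError for names tiktoken does not
    # recognize, so non-GPT models use cl100k_base as an approximation.
    if "gpt" in model:
        return encoding_for_model(model)
    return get_encoding("cl100k_base")

print(len(pick_encoder("claude-2").encode("Hello, world!")))  # approximate
print(len(pick_encoder("gpt-4").encode("Hello, world!")))     # exact
```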