bug fixes

This commit is contained in:
Krrish Dholakia
2023-08-05 22:50:41 -07:00
parent ed8554699b
commit 0f975ccf4a
3 changed files with 17 additions and 10 deletions

View File

@ -6,7 +6,7 @@ from retry import retry
import litellm import litellm
from litellm import acompletion from litellm import acompletion
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
import traceback
OPENAI_RETRIES=5 OPENAI_RETRIES=5
class AiHandler: class AiHandler:
@ -24,18 +24,24 @@ class AiHandler:
try: try:
openai.api_key = get_settings().openai.key openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key litellm.openai_key = get_settings().openai.key
self.azure = False
if get_settings().get("OPENAI.ORG", None): if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org litellm.organization = get_settings().openai.org
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None) self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
if get_settings().get("OPENAI.API_TYPE", None): if get_settings().get("OPENAI.API_TYPE", None):
openai.api_type = get_settings().openai.api_type if get_settings().openai.api_type == "azure":
self.azure = True
litellm.azure_key = get_settings().openai.key
if get_settings().get("OPENAI.API_VERSION", None): if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = get_settings().openai.api_version litellm.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None): if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base
litellm.api_base = get_settings().openai.api_base litellm.api_base = get_settings().openai.api_base
if get_settings().get("LITE.KEY", None): if get_settings().get("ANTHROPIC.KEY", None):
self.llm_api_key = get_settings().lite.key litellm.anthropic_key = get_settings().anthropic.key
if get_settings().get("COHERE.KEY", None):
litellm.cohere_key = get_settings().cohere.key
if get_settings().get("REPLICATE.KEY", None):
litellm.replicate_key = get_settings().replicate.key
except AttributeError as e: except AttributeError as e:
raise ValueError("OpenAI key is required") from e raise ValueError("OpenAI key is required") from e
@ -70,7 +76,7 @@ class AiHandler:
{"role": "user", "content": user} {"role": "user", "content": user}
], ],
temperature=temperature, temperature=temperature,
api_key=self.llm_api_key azure=self.azure
) )
except (APIError, Timeout, TryAgain) as e: except (APIError, Timeout, TryAgain) as e:
logging.error("Error during OpenAI inference: ", e) logging.error("Error during OpenAI inference: ", e)

View File

@ -42,7 +42,7 @@ dependencies = [
"atlassian-python-api==3.39.0", "atlassian-python-api==3.39.0",
"GitPython~=3.1.32", "GitPython~=3.1.32",
"starlette-context==0.3.6", "starlette-context==0.3.6",
"litellm==0.1.2291" "litellm~=0.1.351"
] ]
[project.urls] [project.urls]

View File

@ -11,3 +11,4 @@ pytest~=7.4.0
aiohttp~=3.8.4 aiohttp~=3.8.4
atlassian-python-api==3.39.0 atlassian-python-api==3.39.0
GitPython~=3.1.32 GitPython~=3.1.32
litellm~=0.1.351