diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index 3c8e129f..57221518 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -6,7 +6,7 @@ from retry import retry import litellm from litellm import acompletion from pr_agent.config_loader import get_settings - +import traceback OPENAI_RETRIES=5 class AiHandler: @@ -24,18 +24,24 @@ class AiHandler: try: openai.api_key = get_settings().openai.key litellm.openai_key = get_settings().openai.key + self.azure = False if get_settings().get("OPENAI.ORG", None): - openai.organization = get_settings().openai.org + litellm.organization = get_settings().openai.org self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None) if get_settings().get("OPENAI.API_TYPE", None): - openai.api_type = get_settings().openai.api_type + if get_settings().openai.api_type == "azure": + self.azure = True + litellm.azure_key = get_settings().openai.key if get_settings().get("OPENAI.API_VERSION", None): - openai.api_version = get_settings().openai.api_version + litellm.api_version = get_settings().openai.api_version if get_settings().get("OPENAI.API_BASE", None): - openai.api_base = get_settings().openai.api_base litellm.api_base = get_settings().openai.api_base - if get_settings().get("LITE.KEY", None): - self.llm_api_key = get_settings().lite.key + if get_settings().get("ANTHROPIC.KEY", None): + litellm.anthropic_key = get_settings().anthropic.key + if get_settings().get("COHERE.KEY", None): + litellm.cohere_key = get_settings().cohere.key + if get_settings().get("REPLICATE.KEY", None): + litellm.replicate_key = get_settings().replicate.key except AttributeError as e: raise ValueError("OpenAI key is required") from e @@ -70,7 +76,7 @@ class AiHandler: {"role": "user", "content": user} ], temperature=temperature, - api_key=self.llm_api_key + azure=self.azure ) except (APIError, Timeout, TryAgain) as e: logging.error("Error during OpenAI inference: ", e) diff --git a/pyproject.toml b/pyproject.toml index 5dd3d103..4ca0c0b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ "atlassian-python-api==3.39.0", "GitPython~=3.1.32", "starlette-context==0.3.6", - "litellm==0.1.2291" + "litellm~=0.1.351" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 51dc6fee..07a33514 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ python-gitlab==3.15.0 pytest~=7.4.0 aiohttp~=3.8.4 atlassian-python-api==3.39.0 -GitPython~=3.1.32 \ No newline at end of file +GitPython~=3.1.32 +litellm~=0.1.351 \ No newline at end of file