From 6b4b16dcf98c3f96ec8487a9f0950e75a8e19ccd Mon Sep 17 00:00:00 2001
From: Rhys Tyers
Date: Tue, 7 Nov 2023 09:13:08 +0000
Subject: [PATCH] Support Google's Vertex AI

---
 pr_agent/algo/__init__.py   |  6 +++-
 pr_agent/algo/ai_handler.py | 66 ++++++++++++++++++++-----------------
 requirements.txt            |  3 +-
 3 files changed, 42 insertions(+), 33 deletions(-)

diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index 5a253363..5fe82ee5 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -13,5 +13,9 @@ MAX_TOKENS = {
     'claude-2': 100000,
     'command-nightly': 4096,
     'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
-    'meta-llama/Llama-2-7b-chat-hf': 4096
+    'meta-llama/Llama-2-7b-chat-hf': 4096,
+    'vertex_ai/codechat-bison': 6144,
+    'vertex_ai/codechat-bison-32k': 32000,
+    'codechat-bison': 6144,
+    'codechat-bison-32k': 32000,
 }
diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index c3989563..9a48cdc3 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -23,39 +23,43 @@ class AiHandler:
         Initializes the OpenAI API key and other settings from a configuration file.
         Raises a ValueError if the OpenAI key is missing.
         """
-        try:
+        self.azure = False
+
+        if get_settings().get("OPENAI.KEY", None):
             openai.api_key = get_settings().openai.key
             litellm.openai_key = get_settings().openai.key
-            if get_settings().get("litellm.use_client"):
-                litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
-                assert litellm_token, "LITELLM_TOKEN is required"
-                os.environ["LITELLM_TOKEN"] = litellm_token
-                litellm.use_client = True
-            self.azure = False
-            if get_settings().get("OPENAI.ORG", None):
-                litellm.organization = get_settings().openai.org
-            if get_settings().get("OPENAI.API_TYPE", None):
-                if get_settings().openai.api_type == "azure":
-                    self.azure = True
-                    litellm.azure_key = get_settings().openai.key
-            if get_settings().get("OPENAI.API_VERSION", None):
-                litellm.api_version = get_settings().openai.api_version
-            if get_settings().get("OPENAI.API_BASE", None):
-                litellm.api_base = get_settings().openai.api_base
-            if get_settings().get("ANTHROPIC.KEY", None):
-                litellm.anthropic_key = get_settings().anthropic.key
-            if get_settings().get("COHERE.KEY", None):
-                litellm.cohere_key = get_settings().cohere.key
-            if get_settings().get("REPLICATE.KEY", None):
-                litellm.replicate_key = get_settings().replicate.key
-            if get_settings().get("REPLICATE.KEY", None):
-                litellm.replicate_key = get_settings().replicate.key
-            if get_settings().get("HUGGINGFACE.KEY", None):
-                litellm.huggingface_key = get_settings().huggingface.key
-            if get_settings().get("HUGGINGFACE.API_BASE", None):
-                litellm.api_base = get_settings().huggingface.api_base
-        except AttributeError as e:
-            raise ValueError("OpenAI key is required") from e
+        if get_settings().get("litellm.use_client"):
+            litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
+            assert litellm_token, "LITELLM_TOKEN is required"
+            os.environ["LITELLM_TOKEN"] = litellm_token
+            litellm.use_client = True
+        if get_settings().get("OPENAI.ORG", None):
+            litellm.organization = get_settings().openai.org
+        if get_settings().get("OPENAI.API_TYPE", None):
+            if get_settings().openai.api_type == "azure":
+                self.azure = True
+                litellm.azure_key = get_settings().openai.key
+        if get_settings().get("OPENAI.API_VERSION", None):
+            litellm.api_version = get_settings().openai.api_version
+        if get_settings().get("OPENAI.API_BASE", None):
+            litellm.api_base = get_settings().openai.api_base
+        if get_settings().get("ANTHROPIC.KEY", None):
+            litellm.anthropic_key = get_settings().anthropic.key
+        if get_settings().get("COHERE.KEY", None):
+            litellm.cohere_key = get_settings().cohere.key
+        if get_settings().get("REPLICATE.KEY", None):
+            litellm.replicate_key = get_settings().replicate.key
+        if get_settings().get("REPLICATE.KEY", None):
+            litellm.replicate_key = get_settings().replicate.key
+        if get_settings().get("HUGGINGFACE.KEY", None):
+            litellm.huggingface_key = get_settings().huggingface.key
+        if get_settings().get("HUGGINGFACE.API_BASE", None):
+            litellm.api_base = get_settings().huggingface.api_base
+        if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
+            litellm.vertex_project = get_settings().vertexai.vertex_project
+            litellm.vertex_location = get_settings().get(
+                "VERTEXAI.VERTEX_LOCATION", None
+            )
 
     @property
     def deployment_id(self):
diff --git a/requirements.txt b/requirements.txt
index 8589b30b..eae08f4c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,7 +13,7 @@ atlassian-python-api==3.39.0
 GitPython==3.1.32
 PyYAML==6.0
 starlette-context==0.3.6
-litellm~=0.1.574
+litellm==0.12.5
 boto3==1.28.25
 google-cloud-storage==2.10.0
 ujson==5.8.0
@@ -22,3 +22,4 @@ msrest==0.7.1
 pinecone-client
 pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main
 loguru==0.7.2
+google-cloud-aiplatform==1.35.0
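
Reviewer note (not part of the patch): a minimal sketch of how the new Vertex AI path would be exercised once the settings read in ai_handler.py above are configured. It assumes Google Application Default Credentials are available (e.g. GOOGLE_APPLICATION_CREDENTIALS pointing at a service-account JSON file); the project id and region below are hypothetical placeholders, and the model name is one of the entries added to MAX_TOKENS.

# Sketch only: mirrors what AiHandler.__init__ now does when
# VERTEXAI.VERTEX_PROJECT is set, then calls litellm directly.
import litellm

litellm.vertex_project = "my-gcp-project"   # hypothetical GCP project id
litellm.vertex_location = "us-central1"     # hypothetical region (VERTEXAI.VERTEX_LOCATION)

response = litellm.completion(
    model="vertex_ai/codechat-bison",       # one of the models added to MAX_TOKENS
    messages=[
        {"role": "system", "content": "You review pull requests."},
        {"role": "user", "content": "Summarize the attached diff."},
    ],
)
print(response["choices"][0]["message"]["content"])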