diff --git a/Usage.md b/Usage.md index f11b28df..19904fc4 100644 --- a/Usage.md +++ b/Usage.md @@ -303,6 +303,23 @@ key = ... Also review the [AiHandler](pr_agent/algo/ai_handler.py) file for instruction how to set keys for other models. +#### Vertex AI + +To use Google's Vertex AI platform and its associated models (chat-bison/codechat-bison), set: + +``` +[config] # in configuration.toml +model = "vertex_ai/codechat-bison" + +[vertexai] # in .secrets.toml +vertex_project = "my-google-cloud-project" +vertex_location = "" +``` + +Your [application default credentials](https://cloud.google.com/docs/authentication/application-default-credentials) will be used for authentication, so there is no need to set explicit credentials in most environments. + +If you do want to set explicit credentials, set the `GOOGLE_APPLICATION_CREDENTIALS` environment variable to the path of a JSON credentials file. + ### Working with large PRs The default mode of CodiumAI is to have a single call per tool, using GPT-4, which has a token limit of 8000 tokens. diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py index 5a253363..5fe82ee5 100644 --- a/pr_agent/algo/__init__.py +++ b/pr_agent/algo/__init__.py @@ -13,5 +13,9 @@ MAX_TOKENS = { 'claude-2': 100000, 'command-nightly': 4096, 'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096, - 'meta-llama/Llama-2-7b-chat-hf': 4096 + 'meta-llama/Llama-2-7b-chat-hf': 4096, + 'vertex_ai/codechat-bison': 6144, + 'vertex_ai/codechat-bison-32k': 32000, + 'codechat-bison': 6144, + 'codechat-bison-32k': 32000, } diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index c3989563..9a48cdc3 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -23,39 +23,43 @@ class AiHandler: Initializes the OpenAI API key and other settings from a configuration file. Raises a ValueError if the OpenAI key is missing. 
""" - try: + self.azure = False + + if get_settings().get("OPENAI.KEY", None): openai.api_key = get_settings().openai.key litellm.openai_key = get_settings().openai.key - if get_settings().get("litellm.use_client"): - litellm_token = get_settings().get("litellm.LITELLM_TOKEN") - assert litellm_token, "LITELLM_TOKEN is required" - os.environ["LITELLM_TOKEN"] = litellm_token - litellm.use_client = True - self.azure = False - if get_settings().get("OPENAI.ORG", None): - litellm.organization = get_settings().openai.org - if get_settings().get("OPENAI.API_TYPE", None): - if get_settings().openai.api_type == "azure": - self.azure = True - litellm.azure_key = get_settings().openai.key - if get_settings().get("OPENAI.API_VERSION", None): - litellm.api_version = get_settings().openai.api_version - if get_settings().get("OPENAI.API_BASE", None): - litellm.api_base = get_settings().openai.api_base - if get_settings().get("ANTHROPIC.KEY", None): - litellm.anthropic_key = get_settings().anthropic.key - if get_settings().get("COHERE.KEY", None): - litellm.cohere_key = get_settings().cohere.key - if get_settings().get("REPLICATE.KEY", None): - litellm.replicate_key = get_settings().replicate.key - if get_settings().get("REPLICATE.KEY", None): - litellm.replicate_key = get_settings().replicate.key - if get_settings().get("HUGGINGFACE.KEY", None): - litellm.huggingface_key = get_settings().huggingface.key - if get_settings().get("HUGGINGFACE.API_BASE", None): - litellm.api_base = get_settings().huggingface.api_base - except AttributeError as e: - raise ValueError("OpenAI key is required") from e + if get_settings().get("litellm.use_client"): + litellm_token = get_settings().get("litellm.LITELLM_TOKEN") + assert litellm_token, "LITELLM_TOKEN is required" + os.environ["LITELLM_TOKEN"] = litellm_token + litellm.use_client = True + if get_settings().get("OPENAI.ORG", None): + litellm.organization = get_settings().openai.org + if get_settings().get("OPENAI.API_TYPE", None): + if 
get_settings().openai.api_type == "azure": + self.azure = True + litellm.azure_key = get_settings().openai.key + if get_settings().get("OPENAI.API_VERSION", None): + litellm.api_version = get_settings().openai.api_version + if get_settings().get("OPENAI.API_BASE", None): + litellm.api_base = get_settings().openai.api_base + if get_settings().get("ANTHROPIC.KEY", None): + litellm.anthropic_key = get_settings().anthropic.key + if get_settings().get("COHERE.KEY", None): + litellm.cohere_key = get_settings().cohere.key + if get_settings().get("REPLICATE.KEY", None): + litellm.replicate_key = get_settings().replicate.key + if get_settings().get("REPLICATE.KEY", None): + litellm.replicate_key = get_settings().replicate.key + if get_settings().get("HUGGINGFACE.KEY", None): + litellm.huggingface_key = get_settings().huggingface.key + if get_settings().get("HUGGINGFACE.API_BASE", None): + litellm.api_base = get_settings().huggingface.api_base + if get_settings().get("VERTEXAI.VERTEX_PROJECT", None): + litellm.vertex_project = get_settings().vertexai.vertex_project + litellm.vertex_location = get_settings().get( + "VERTEXAI.VERTEX_LOCATION", None + ) @property def deployment_id(self): diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index b6b11cd4..ba51382c 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -36,6 +36,10 @@ api_base = "" # the base url for your huggingface inference endpoint [ollama] api_base = "" # the base url for your local Llama 2, Code Llama, and other models inference endpoint. Acquire through https://ollama.ai/ +[vertexai] +vertex_project = "" # the google cloud platform project name for your vertexai deployment +vertex_location = "" # the google cloud platform location for your vertexai deployment + [github] # ---- Set the following only for deployment type == "user" user_token = "" # A GitHub personal access token with 'repo' scope. 
diff --git a/requirements.txt b/requirements.txt index 8589b30b..eae08f4c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ atlassian-python-api==3.39.0 GitPython==3.1.32 PyYAML==6.0 starlette-context==0.3.6 -litellm~=0.1.574 +litellm==0.12.5 boto3==1.28.25 google-cloud-storage==2.10.0 ujson==5.8.0 @@ -22,3 +22,4 @@ msrest==0.7.1 pinecone-client pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main loguru==0.7.2 +google-cloud-aiplatform==1.35.0