Merge pull request #483 from tmokmss/add-bedrock-support

Add Amazon Bedrock support
Authored by mrT23 on 2023-11-29 03:08:01 -08:00; committed by GitHub
6 changed files with 51 additions and 9 deletions

View File

@@ -328,6 +328,23 @@ Your [application default credentials](https://cloud.google.com/docs/authenticat
If you do want to set explicit credentials then you can use the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to a path to a json credentials file.
#### Amazon Bedrock
To use Amazon Bedrock and its foundation models, add the configuration below:
```
[config] # in configuration.toml
model = "anthropic.claude-v2"
fallback_models="anthropic.claude-instant-v1"
[aws] # in .secrets.toml
bedrock_region = "us-east-1"
```
Note that you have to request access to the foundation models before using them. Please refer to [this document](https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html) for more details.
The AWS session is authenticated automatically from your environment, but you can also set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables explicitly.
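If you want a quick, standalone sanity check of the region and model access, a minimal sketch along the following lines (assuming the `us-east-1` region and granted access to `anthropic.claude-v2`; this snippet is not part of the commit) exercises the same `bedrock-runtime` client that the handler constructs:
```
# Standalone sanity check (not part of this commit): verify that your AWS
# credentials and region can invoke an Anthropic model on Bedrock.
import json
import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")  # assumed region
body = json.dumps({
    "prompt": "\n\nHuman: Say hello.\n\nAssistant:",  # Claude text-completion format
    "max_tokens_to_sample": 50,
})
response = client.invoke_model(modelId="anthropic.claude-v2", body=body)
print(json.loads(response["body"].read())["completion"])
```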
### Working with large PRs
The default mode of CodiumAI is to have a single call per tool, using GPT-4, which has a token limit of 8000 tokens.

View File

@@ -18,4 +18,7 @@ MAX_TOKENS = {
    'vertex_ai/codechat-bison-32k': 32000,
    'codechat-bison': 6144,
    'codechat-bison-32k': 32000,
    'anthropic.claude-v2': 100000,
    'anthropic.claude-instant-v1': 100000,
    'anthropic.claude-v1': 100000,
}
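For context, the mapping is keyed by model name with values in tokens. A hedged sketch of how such a table might be consumed when sizing a request (the helper below is hypothetical, not the project's API):
```
# Hypothetical helper (not in this commit): look up a model's context window,
# falling back to a conservative default for unknown models.
def get_max_tokens(model: str, default: int = 4000) -> int:
    return MAX_TOKENS.get(model, default)

assert get_max_tokens('anthropic.claude-v2') == 100000
```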

View File

@@ -1,5 +1,6 @@
import os
import boto3
import litellm
import openai
from litellm import acompletion
@@ -24,6 +25,7 @@ class AiHandler:
        Raises a ValueError if the OpenAI key is missing.
        """
        self.azure = False
        self.aws_bedrock_client = None
        if get_settings().get("OPENAI.KEY", None):
            openai.api_key = get_settings().openai.key
@@ -60,6 +62,12 @@ class AiHandler:
            litellm.vertex_location = get_settings().get(
                "VERTEXAI.VERTEX_LOCATION", None
            )
        if get_settings().get("AWS.BEDROCK_REGION", None):
            litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000
            self.aws_bedrock_client = boto3.client(
                service_name="bedrock-runtime",
                region_name=get_settings().aws.bedrock_region,
            )

    @property
    def deployment_id(self):
@@ -100,13 +108,16 @@ class AiHandler:
            if self.azure:
                model = 'azure/' + model
            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
            kwargs = {
                "model": model,
                "deployment_id": deployment_id,
                "messages": messages,
                "temperature": temperature,
                "force_timeout": get_settings().config.ai_timeout,
            }
            if self.aws_bedrock_client:
                kwargs["aws_bedrock_client"] = self.aws_bedrock_client
            response = await acompletion(**kwargs)
        except (APIError, Timeout, TryAgain) as e:
            get_logger().error("Error during OpenAI inference: ", e)
            raise

View File

@@ -325,7 +325,15 @@ def try_fix_yaml(response_text: str) -> dict:
            break
        except:
            pass

    # third fallback - try to remove leading and trailing curly brackets
    response_text_copy = response_text.strip().removeprefix('{').removesuffix('}')
    try:
        data = yaml.safe_load(response_text_copy)
        get_logger().info("Successfully parsed AI prediction after removing curly brackets")
        return data
    except:
        pass


def set_custom_labels(variables):
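To see what the new fallback buys, a small illustrative sketch (standalone, with a made-up model reply; not part of the commit):
```
import yaml

# A reply wrapped in curly brackets reads as a YAML flow mapping, which
# requires commas between entries, so parsing this as-is raises an error.
response_text = "{\nreview: looks good\nscore: 9\n}"

# Stripping the brackets leaves a plain block mapping that parses cleanly.
cleaned = response_text.strip().removeprefix('{').removesuffix('}')
print(yaml.safe_load(cleaned))  # {'review': 'looks good', 'score': 9}
```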

View File

@@ -40,6 +40,9 @@ api_base = "" # the base url for your local Llama 2, Code Llama, and other model
vertex_project = "" # the google cloud platform project name for your vertexai deployment
vertex_location = "" # the google cloud platform location for your vertexai deployment

[aws]
bedrock_region = "" # the AWS region to call Bedrock APIs

[github]
# ---- Set the following only for deployment type == "user"
user_token = "" # A GitHub personal access token with 'repo' scope.

View File

@@ -14,7 +14,7 @@ GitPython==3.1.32
PyYAML==6.0
starlette-context==0.3.6
litellm==0.12.5
boto3==1.33.1
google-cloud-storage==2.10.0
ujson==5.8.0
azure-devops==7.1.0b3