From 1373ca23fca4634c2719a7796cd93624afcc32a2 Mon Sep 17 00:00:00 2001
From: tmokmss
Date: Tue, 28 Nov 2023 20:11:40 +0900
Subject: [PATCH] support Amazon Bedrock

---
 Usage.md                                 | 17 +++++++++++++++++
 pr_agent/algo/__init__.py                |  3 +++
 pr_agent/algo/ai_handler.py              | 24 +++++++++++++++++-------
 pr_agent/algo/utils.py                   |  1 +
 pr_agent/settings/.secrets_template.toml |  3 +++
 requirements.txt                         |  2 +-
 6 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/Usage.md b/Usage.md
index 95707773..37fd61e1 100644
--- a/Usage.md
+++ b/Usage.md
@@ -328,6 +328,23 @@ Your [application default credentials](https://cloud.google.com/docs/authenticat
 
 If you do want to set explicit credentials then you can use the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to a path to a json credentials file.
 
+#### Amazon Bedrock
+
+To use Amazon Bedrock and its foundation models, add the following configuration:
+
+```
+[config] # in configuration.toml
+model = "anthropic.claude-v2"
+fallback_models = "anthropic.claude-instant-v1"
+
+[aws] # in .secrets.toml
+bedrock_region = "us-east-1"
+```
+
+Note that you have to request access to the foundation models before you can use them. Please refer to [this document](https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html) for more details.
+
+The AWS session is authenticated automatically from your environment, but you can also set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables explicitly.
+
 ### Working with large PRs
 
 The default mode of CodiumAI is to have a single call per tool, using GPT-4, which has a token limit of 8000 tokens.
diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index 5fe82ee5..63a628a5 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -18,4 +18,7 @@ MAX_TOKENS = {
     'vertex_ai/codechat-bison-32k': 32000,
     'codechat-bison': 6144,
     'codechat-bison-32k': 32000,
+    'anthropic.claude-v2': 100000,
+    'anthropic.claude-instant-v1': 100000,
+    'anthropic.claude-v1': 100000,
 }
diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index 9a48cdc3..6d873cd9 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -7,6 +7,7 @@ from openai.error import APIError, RateLimitError, Timeout, TryAgain
 from retry import retry
 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
+import boto3
 
 OPENAI_RETRIES = 5
 
@@ -24,6 +25,7 @@ class AiHandler:
         Raises a ValueError if the OpenAI key is missing.
""" self.azure = False + self.aws_bedrock_client = None if get_settings().get("OPENAI.KEY", None): openai.api_key = get_settings().openai.key @@ -60,6 +62,11 @@ class AiHandler: litellm.vertex_location = get_settings().get( "VERTEXAI.VERTEX_LOCATION", None ) + if get_settings().get("AWS.BEDROCK_REGION", None): + self.aws_bedrock_client = boto3.client( + service_name="bedrock-runtime", + region_name=get_settings().aws.bedrock_region, + ) @property def deployment_id(self): @@ -100,13 +107,16 @@ class AiHandler: if self.azure: model = 'azure/' + model messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] - response = await acompletion( - model=model, - deployment_id=deployment_id, - messages=messages, - temperature=temperature, - force_timeout=get_settings().config.ai_timeout - ) + kwargs = { + "model": model, + "deployment_id": deployment_id, + "messages": messages, + "temperature": temperature, + "force_timeout": get_settings().config.ai_timeout, + } + if self.aws_bedrock_client: + kwargs["aws_bedrock_client"] = self.aws_bedrock_client + response = await acompletion(**kwargs) except (APIError, Timeout, TryAgain) as e: get_logger().error("Error during OpenAI inference: ", e) raise diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 7a6e666c..d3c5e828 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -290,6 +290,7 @@ def _fix_key_value(key: str, value: str): def load_yaml(response_text: str) -> dict: response_text = response_text.removeprefix('```yaml').rstrip('`') + response_text = response_text.strip().rstrip().removeprefix('{').removesuffix('}') try: data = yaml.safe_load(response_text) except Exception as e: diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index ba51382c..e7ca4057 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -40,6 +40,9 @@ api_base = "" # the base url for your local Llama 2, Code Llama, and other model vertex_project = "" # the google cloud platform project name for your vertexai deployment vertex_location = "" # the google cloud platform location for your vertexai deployment +[aws] +bedrock_region = "" # the AWS region to call Bedrock APIs + [github] # ---- Set the following only for deployment type == "user" user_token = "" # A GitHub personal access token with 'repo' scope. diff --git a/requirements.txt b/requirements.txt index eae08f4c..678cafd6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ GitPython==3.1.32 PyYAML==6.0 starlette-context==0.3.6 litellm==0.12.5 -boto3==1.28.25 +boto3==1.33.1 google-cloud-storage==2.10.0 ujson==5.8.0 azure-devops==7.1.0b3