From 1373ca23fca4634c2719a7796cd93624afcc32a2 Mon Sep 17 00:00:00 2001
From: tmokmss
Date: Tue, 28 Nov 2023 20:11:40 +0900
Subject: [PATCH 1/5] support Amazon Bedrock

---
 Usage.md                                 | 17 +++++++++++++++++
 pr_agent/algo/__init__.py                |  3 +++
 pr_agent/algo/ai_handler.py              | 24 +++++++++++++++++-------
 pr_agent/algo/utils.py                   |  1 +
 pr_agent/settings/.secrets_template.toml |  3 +++
 requirements.txt                         |  2 +-
 6 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/Usage.md b/Usage.md
index 95707773..37fd61e1 100644
--- a/Usage.md
+++ b/Usage.md
@@ -328,6 +328,23 @@ Your [application default credentials](https://cloud.google.com/docs/authenticat
 If you do want to set explicit credentials then you can use the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to a path to a json credentials file.
 
+#### Amazon Bedrock
+
+To use Amazon Bedrock and its foundation models, add the following configuration:
+
+```
+[config] # in configuration.toml
+model = "anthropic.claude-v2"
+fallback_models = "anthropic.claude-instant-v1"
+
+[aws] # in .secrets.toml
+bedrock_region = "us-east-1"
+```
+
+Note that you must first request access to the foundation models before you can use them. Please refer to [this document](https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html) for more details.
+
+The AWS session is authenticated automatically from your environment, but you can also set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables explicitly.
+
 ### Working with large PRs
 
 The default mode of CodiumAI is to have a single call per tool, using GPT-4, which has a token limit of 8000 tokens.
diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index 5fe82ee5..63a628a5 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -18,4 +18,7 @@ MAX_TOKENS = {
     'vertex_ai/codechat-bison-32k': 32000,
     'codechat-bison': 6144,
     'codechat-bison-32k': 32000,
+    'anthropic.claude-v2': 100000,
+    'anthropic.claude-instant-v1': 100000,
+    'anthropic.claude-v1': 100000,
 }
diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index 9a48cdc3..6d873cd9 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -7,6 +7,7 @@ from openai.error import APIError, RateLimitError, Timeout, TryAgain
 from retry import retry
 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
+import boto3
 
 OPENAI_RETRIES = 5
 
@@ -24,6 +25,7 @@ class AiHandler:
         Raises a ValueError if the OpenAI key is missing.
""" self.azure = False + self.aws_bedrock_client = None if get_settings().get("OPENAI.KEY", None): openai.api_key = get_settings().openai.key @@ -60,6 +62,11 @@ class AiHandler: litellm.vertex_location = get_settings().get( "VERTEXAI.VERTEX_LOCATION", None ) + if get_settings().get("AWS.BEDROCK_REGION", None): + self.aws_bedrock_client = boto3.client( + service_name="bedrock-runtime", + region_name=get_settings().aws.bedrock_region, + ) @property def deployment_id(self): @@ -100,13 +107,16 @@ class AiHandler: if self.azure: model = 'azure/' + model messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] - response = await acompletion( - model=model, - deployment_id=deployment_id, - messages=messages, - temperature=temperature, - force_timeout=get_settings().config.ai_timeout - ) + kwargs = { + "model": model, + "deployment_id": deployment_id, + "messages": messages, + "temperature": temperature, + "force_timeout": get_settings().config.ai_timeout, + } + if self.aws_bedrock_client: + kwargs["aws_bedrock_client"] = self.aws_bedrock_client + response = await acompletion(**kwargs) except (APIError, Timeout, TryAgain) as e: get_logger().error("Error during OpenAI inference: ", e) raise diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 7a6e666c..d3c5e828 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -290,6 +290,7 @@ def _fix_key_value(key: str, value: str): def load_yaml(response_text: str) -> dict: response_text = response_text.removeprefix('```yaml').rstrip('`') + response_text = response_text.strip().rstrip().removeprefix('{').removesuffix('}') try: data = yaml.safe_load(response_text) except Exception as e: diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index ba51382c..e7ca4057 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -40,6 +40,9 @@ api_base = "" # the base url for your local Llama 2, Code Llama, and other model vertex_project = "" # the google cloud platform project name for your vertexai deployment vertex_location = "" # the google cloud platform location for your vertexai deployment +[aws] +bedrock_region = "" # the AWS region to call Bedrock APIs + [github] # ---- Set the following only for deployment type == "user" user_token = "" # A GitHub personal access token with 'repo' scope. 
diff --git a/requirements.txt b/requirements.txt
index eae08f4c..678cafd6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,7 +14,7 @@ GitPython==3.1.32
 PyYAML==6.0
 starlette-context==0.3.6
 litellm==0.12.5
-boto3==1.28.25
+boto3==1.33.1
 google-cloud-storage==2.10.0
 ujson==5.8.0
 azure-devops==7.1.0b3

From 97d6fb999a515ea6459c7474db886a593d4f325a Mon Sep 17 00:00:00 2001
From: tmokmss
Date: Tue, 28 Nov 2023 20:58:57 +0900
Subject: [PATCH 2/5] set max_tokens_to_sample

---
 pr_agent/algo/ai_handler.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index 6d873cd9..24273db6 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -63,6 +63,9 @@ class AiHandler:
                 "VERTEXAI.VERTEX_LOCATION", None
             )
         if get_settings().get("AWS.BEDROCK_REGION", None):
+            litellm.AmazonAnthropicConfig.max_tokens_to_sample = int(get_settings().get(
+                "AWS.CLAUDE_MAX_TOKENS_TO_SAMPLE", '2000'
+            ))
             self.aws_bedrock_client = boto3.client(
                 service_name="bedrock-runtime",
                 region_name=get_settings().aws.bedrock_region,

From 917f4b6a012f19d40e8b847dab7ad228f633b41e Mon Sep 17 00:00:00 2001
From: tmokmss
Date: Tue, 28 Nov 2023 20:59:21 +0900
Subject: [PATCH 3/5] hard code value

---
 pr_agent/algo/ai_handler.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index 24273db6..7f071ef3 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -63,9 +63,7 @@ class AiHandler:
                 "VERTEXAI.VERTEX_LOCATION", None
             )
         if get_settings().get("AWS.BEDROCK_REGION", None):
-            litellm.AmazonAnthropicConfig.max_tokens_to_sample = int(get_settings().get(
-                "AWS.CLAUDE_MAX_TOKENS_TO_SAMPLE", '2000'
-            ))
+            litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000
             self.aws_bedrock_client = boto3.client(
                 service_name="bedrock-runtime",
                 region_name=get_settings().aws.bedrock_region,

From f8f57419c40ad2d777200a936d541bf20c868ee1 Mon Sep 17 00:00:00 2001
From: tmokmss
Date: Tue, 28 Nov 2023 23:07:46 +0900
Subject: [PATCH 4/5] Update ai_handler.py

---
 pr_agent/algo/ai_handler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index 7f071ef3..5b6a05f4 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -1,5 +1,6 @@
 import os
 
+import boto3
 import litellm
 import openai
 from litellm import acompletion
@@ -7,7 +8,6 @@ from openai.error import APIError, RateLimitError, Timeout, TryAgain
 from retry import retry
 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
-import boto3
 
 OPENAI_RETRIES = 5
 

From 5e642c10fae9d5289a7d9dde8b82554d245cd7e9 Mon Sep 17 00:00:00 2001
From: tmokmss
Date: Wed, 29 Nov 2023 17:57:54 +0900
Subject: [PATCH 5/5] fallback to try_fix_yaml

---
 pr_agent/algo/utils.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index d3c5e828..1599f056 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -290,7 +290,6 @@ def _fix_key_value(key: str, value: str):
 
 def load_yaml(response_text: str) -> dict:
     response_text = response_text.removeprefix('```yaml').rstrip('`')
-    response_text = response_text.strip().removeprefix('{').removesuffix('}')
     try:
         data = yaml.safe_load(response_text)
     except Exception as e:
@@ -326,7 +325,16 @@ def try_fix_yaml(response_text: str) -> dict:
             break
         except:
             pass
+
+    # third fallback - try to remove leading and trailing curly brackets
+    response_text_copy = response_text.strip().removeprefix('{').removesuffix('}')
+    try:
+        data = yaml.safe_load(response_text_copy)
+        get_logger().info("Successfully parsed AI prediction after removing curly brackets")
+        return data
+    except:
+        pass
     return data
 
 
 def set_custom_labels(variables):
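The brace-stripping fallback in the last patch exists because models sometimes wrap otherwise-valid block-style YAML in `{...}`, and YAML forbids block scalars such as `|-` inside a flow mapping. Below is a self-contained sketch of that technique under this assumption; the sample reply is invented for illustration, and the real logic lives in `try_fix_yaml` above.

```python
import yaml

# An AI reply that wraps block-style YAML in curly brackets. The block
# scalar indicator '|-' is not allowed inside a '{...}' flow mapping,
# so parsing the text as-is raises a YAMLError.
response_text = """{
review: |-
  The change looks reasonable.
  Consider adding a unit test.
}"""

try:
    data = yaml.safe_load(response_text)
except yaml.YAMLError:
    # The fallback: drop the surrounding brackets and parse the
    # remaining block-style YAML instead.
    stripped = response_text.strip().removeprefix('{').removesuffix('}')
    data = yaml.safe_load(stripped)

print(data)
# {'review': 'The change looks reasonable.\nConsider adding a unit test.'}
```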