mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-04 04:40:38 +08:00
Merge pull request #483 from tmokmss/add-bedrock-support
Add Amazon Bedrock support
This commit is contained in:
17
Usage.md
17
Usage.md
@ -328,6 +328,23 @@ Your [application default credentials](https://cloud.google.com/docs/authenticat
|
|||||||
|
|
||||||
If you do want to set explicit credentials then you can use the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to a path to a json credentials file.
|
If you do want to set explicit credentials then you can use the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to a path to a json credentials file.
|
||||||
|
|
||||||
|
#### Amazon Bedrock
|
||||||
|
|
||||||
|
To use Amazon Bedrock and its foundation models, add the configuration below:
|
||||||
|
|
||||||
|
```
|
||||||
|
[config] # in configuration.toml
|
||||||
|
model = "anthropic.claude-v2"
|
||||||
|
fallback_models="anthropic.claude-instant-v1"
|
||||||
|
|
||||||
|
[aws] # in .secrets.toml
|
||||||
|
bedrock_region = "us-east-1"
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that you have to enable access to the foundation models before using them. Please refer to [this document](https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html) for more details.
|
||||||
|
|
||||||
|
The AWS session is automatically authenticated from your environment, but you can also explicitly set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables.
|
||||||
|
|
||||||
### Working with large PRs
|
### Working with large PRs
|
||||||
|
|
||||||
The default mode of CodiumAI is to have a single call per tool, using GPT-4, which has a token limit of 8000 tokens.
|
The default mode of CodiumAI is to have a single call per tool, using GPT-4, which has a token limit of 8000 tokens.
|
||||||
|
@ -18,4 +18,7 @@ MAX_TOKENS = {
|
|||||||
'vertex_ai/codechat-bison-32k': 32000,
|
'vertex_ai/codechat-bison-32k': 32000,
|
||||||
'codechat-bison': 6144,
|
'codechat-bison': 6144,
|
||||||
'codechat-bison-32k': 32000,
|
'codechat-bison-32k': 32000,
|
||||||
|
'anthropic.claude-v2': 100000,
|
||||||
|
'anthropic.claude-instant-v1': 100000,
|
||||||
|
'anthropic.claude-v1': 100000,
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
import boto3
|
||||||
import litellm
|
import litellm
|
||||||
import openai
|
import openai
|
||||||
from litellm import acompletion
|
from litellm import acompletion
|
||||||
@ -24,6 +25,7 @@ class AiHandler:
|
|||||||
Raises a ValueError if the OpenAI key is missing.
|
Raises a ValueError if the OpenAI key is missing.
|
||||||
"""
|
"""
|
||||||
self.azure = False
|
self.azure = False
|
||||||
|
self.aws_bedrock_client = None
|
||||||
|
|
||||||
if get_settings().get("OPENAI.KEY", None):
|
if get_settings().get("OPENAI.KEY", None):
|
||||||
openai.api_key = get_settings().openai.key
|
openai.api_key = get_settings().openai.key
|
||||||
@ -60,6 +62,12 @@ class AiHandler:
|
|||||||
litellm.vertex_location = get_settings().get(
|
litellm.vertex_location = get_settings().get(
|
||||||
"VERTEXAI.VERTEX_LOCATION", None
|
"VERTEXAI.VERTEX_LOCATION", None
|
||||||
)
|
)
|
||||||
|
if get_settings().get("AWS.BEDROCK_REGION", None):
|
||||||
|
litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000
|
||||||
|
self.aws_bedrock_client = boto3.client(
|
||||||
|
service_name="bedrock-runtime",
|
||||||
|
region_name=get_settings().aws.bedrock_region,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def deployment_id(self):
|
def deployment_id(self):
|
||||||
@ -100,13 +108,16 @@ class AiHandler:
|
|||||||
if self.azure:
|
if self.azure:
|
||||||
model = 'azure/' + model
|
model = 'azure/' + model
|
||||||
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
|
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
|
||||||
response = await acompletion(
|
kwargs = {
|
||||||
model=model,
|
"model": model,
|
||||||
deployment_id=deployment_id,
|
"deployment_id": deployment_id,
|
||||||
messages=messages,
|
"messages": messages,
|
||||||
temperature=temperature,
|
"temperature": temperature,
|
||||||
force_timeout=get_settings().config.ai_timeout
|
"force_timeout": get_settings().config.ai_timeout,
|
||||||
)
|
}
|
||||||
|
if self.aws_bedrock_client:
|
||||||
|
kwargs["aws_bedrock_client"] = self.aws_bedrock_client
|
||||||
|
response = await acompletion(**kwargs)
|
||||||
except (APIError, Timeout, TryAgain) as e:
|
except (APIError, Timeout, TryAgain) as e:
|
||||||
get_logger().error("Error during OpenAI inference: ", e)
|
get_logger().error("Error during OpenAI inference: ", e)
|
||||||
raise
|
raise
|
||||||
|
@ -325,7 +325,15 @@ def try_fix_yaml(response_text: str) -> dict:
|
|||||||
break
|
break
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
return data
|
|
||||||
|
# third fallback - try to remove leading and trailing curly brackets
|
||||||
|
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}')
|
||||||
|
try:
|
||||||
|
data = yaml.safe_load(response_text_copy,)
|
||||||
|
get_logger().info(f"Successfully parsed AI prediction after removing curly brackets")
|
||||||
|
return data
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def set_custom_labels(variables):
|
def set_custom_labels(variables):
|
||||||
|
@ -40,6 +40,9 @@ api_base = "" # the base url for your local Llama 2, Code Llama, and other model
|
|||||||
vertex_project = "" # the google cloud platform project name for your vertexai deployment
|
vertex_project = "" # the google cloud platform project name for your vertexai deployment
|
||||||
vertex_location = "" # the google cloud platform location for your vertexai deployment
|
vertex_location = "" # the google cloud platform location for your vertexai deployment
|
||||||
|
|
||||||
|
[aws]
|
||||||
|
bedrock_region = "" # the AWS region to call Bedrock APIs
|
||||||
|
|
||||||
[github]
|
[github]
|
||||||
# ---- Set the following only for deployment type == "user"
|
# ---- Set the following only for deployment type == "user"
|
||||||
user_token = "" # A GitHub personal access token with 'repo' scope.
|
user_token = "" # A GitHub personal access token with 'repo' scope.
|
||||||
|
@ -14,7 +14,7 @@ GitPython==3.1.32
|
|||||||
PyYAML==6.0
|
PyYAML==6.0
|
||||||
starlette-context==0.3.6
|
starlette-context==0.3.6
|
||||||
litellm==0.12.5
|
litellm==0.12.5
|
||||||
boto3==1.28.25
|
boto3==1.33.1
|
||||||
google-cloud-storage==2.10.0
|
google-cloud-storage==2.10.0
|
||||||
ujson==5.8.0
|
ujson==5.8.0
|
||||||
azure-devops==7.1.0b3
|
azure-devops==7.1.0b3
|
||||||
|
Reference in New Issue
Block a user