mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-05 13:20:39 +08:00
support Amazon Bedrock
This commit is contained in:
@ -7,6 +7,7 @@ from openai.error import APIError, RateLimitError, Timeout, TryAgain
|
||||
from retry import retry
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.log import get_logger
|
||||
import boto3
|
||||
|
||||
OPENAI_RETRIES = 5
|
||||
|
||||
@ -24,6 +25,7 @@ class AiHandler:
|
||||
Raises a ValueError if the OpenAI key is missing.
|
||||
"""
|
||||
self.azure = False
|
||||
self.aws_bedrock_client = None
|
||||
|
||||
if get_settings().get("OPENAI.KEY", None):
|
||||
openai.api_key = get_settings().openai.key
|
||||
@ -60,6 +62,11 @@ class AiHandler:
|
||||
litellm.vertex_location = get_settings().get(
|
||||
"VERTEXAI.VERTEX_LOCATION", None
|
||||
)
|
||||
if get_settings().get("AWS.BEDROCK_REGION", None):
|
||||
self.aws_bedrock_client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
region_name=get_settings().aws.bedrock_region,
|
||||
)
|
||||
|
||||
@property
|
||||
def deployment_id(self):
|
||||
@ -100,13 +107,16 @@ class AiHandler:
|
||||
if self.azure:
|
||||
model = 'azure/' + model
|
||||
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
|
||||
response = await acompletion(
|
||||
model=model,
|
||||
deployment_id=deployment_id,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
force_timeout=get_settings().config.ai_timeout
|
||||
)
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"deployment_id": deployment_id,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"force_timeout": get_settings().config.ai_timeout,
|
||||
}
|
||||
if self.aws_bedrock_client:
|
||||
kwargs["aws_bedrock_client"] = self.aws_bedrock_client
|
||||
response = await acompletion(**kwargs)
|
||||
except (APIError, Timeout, TryAgain) as e:
|
||||
get_logger().error("Error during OpenAI inference: ", e)
|
||||
raise
|
||||
|
Reference in New Issue
Block a user