# pr-agent/pr_agent/algo/ai_handlers/litellm_ai_handler.py
import os

import boto3
import litellm
import openai
import requests
from litellm import acompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger

OPENAI_RETRIES = 5


class LiteLLMAIHandler(BaseAiHandler):
"""
This class handles interactions with the OpenAI API for chat completions.
It initializes the API key and other settings from a configuration file,
and provides a method for performing chat completions using the OpenAI ChatCompletion API.
"""

    def __init__(self):
        """
        Initializes API keys and provider settings from the configuration.
        If no OpenAI key is configured, a placeholder key is set so that
        non-OpenAI providers can still be used.
        """
        self.azure = False
        self.api_base = None
        self.repetition_penalty = None
        if get_settings().get("OPENAI.KEY", None):
            openai.api_key = get_settings().openai.key
            litellm.openai_key = get_settings().openai.key
        elif 'OPENAI_API_KEY' not in os.environ:
            litellm.api_key = "dummy_key"
        if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
            os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
            os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
            os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
        if get_settings().get("litellm.use_client"):
            litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
            assert litellm_token, "LITELLM_TOKEN is required"
            os.environ["LITELLM_TOKEN"] = litellm_token
            litellm.use_client = True
        if get_settings().get("LITELLM.DROP_PARAMS", None):
            litellm.drop_params = get_settings().litellm.drop_params
        if get_settings().get("OPENAI.ORG", None):
            litellm.organization = get_settings().openai.org
        if get_settings().get("OPENAI.API_TYPE", None):
            if get_settings().openai.api_type == "azure":
                self.azure = True
                litellm.azure_key = get_settings().openai.key
        if get_settings().get("OPENAI.API_VERSION", None):
            litellm.api_version = get_settings().openai.api_version
        if get_settings().get("OPENAI.API_BASE", None):
            litellm.api_base = get_settings().openai.api_base
        if get_settings().get("ANTHROPIC.KEY", None):
            litellm.anthropic_key = get_settings().anthropic.key
        if get_settings().get("COHERE.KEY", None):
            litellm.cohere_key = get_settings().cohere.key
        if get_settings().get("GROQ.KEY", None):
            litellm.api_key = get_settings().groq.key
        if get_settings().get("REPLICATE.KEY", None):
            litellm.replicate_key = get_settings().replicate.key
        if get_settings().get("HUGGINGFACE.KEY", None):
            litellm.huggingface_key = get_settings().huggingface.key
        if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
            litellm.api_base = get_settings().huggingface.api_base
            self.api_base = get_settings().huggingface.api_base
        if get_settings().get("OLLAMA.API_BASE", None):
            litellm.api_base = get_settings().ollama.api_base
            self.api_base = get_settings().ollama.api_base
        if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
            self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
        if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
            litellm.vertex_project = get_settings().vertexai.vertex_project
            litellm.vertex_location = get_settings().get(
                "VERTEXAI.VERTEX_LOCATION", None
            )

    def prepare_logs(self, response, system, user, resp, finish_reason):
        response_log = response.dict().copy()
        response_log['system'] = system
        response_log['user'] = user
        response_log['output'] = resp
        response_log['finish_reason'] = finish_reason
        if hasattr(self, 'main_pr_language'):
            response_log['main_pr_language'] = self.main_pr_language
        else:
            response_log['main_pr_language'] = 'unknown'
        return response_log

    @property
    def deployment_id(self):
        """
        Returns the deployment ID for the OpenAI API.
        """
        return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

    @retry(
        retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)),  # No retry on RateLimitError
        stop=stop_after_attempt(OPENAI_RETRIES)
    )
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
        try:
            resp, finish_reason = None, None
            deployment_id = self.deployment_id
            if self.azure:
                model = 'azure/' + model
            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
            if img_path:
                try:
                    # check if the image link is alive
                    r = requests.head(img_path, allow_redirects=True)
                    if r.status_code == 404:
                        error_msg = f"The image link is not [alive]({img_path}).\nPlease repost the original image as a comment, and send the question again with 'quote reply' (see [instructions](https://pr-agent-docs.codium.ai/tools/ask/#ask-on-images-using-the-pr-code-as-context))."
                        get_logger().error(error_msg)
                        return f"{error_msg}", "error"
                except Exception as e:
                    get_logger().error(f"Error fetching image: {img_path}", e)
                    return f"Error fetching image: {img_path}", "error"
                messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
                                          {"type": "image_url", "image_url": {"url": img_path}}]

            kwargs = {
                "model": model,
                "deployment_id": deployment_id,
                "messages": messages,
                "temperature": temperature,
                "force_timeout": get_settings().config.ai_timeout,
                "api_base": self.api_base,
            }
            if self.repetition_penalty:
                kwargs["repetition_penalty"] = self.repetition_penalty

            get_logger().debug("Prompts", artifact={"system": system, "user": user})
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"\nSystem prompt:\n{system}")
                get_logger().info(f"\nUser prompt:\n{user}")

            response = await acompletion(**kwargs)
        except (openai.APIError, openai.APITimeoutError) as e:
            get_logger().error("Error during OpenAI inference: ", e)
            raise
        except openai.RateLimitError as e:
            get_logger().error("Rate limit error during OpenAI inference: ", e)
            raise
        except Exception as e:
            get_logger().error("Unknown error during OpenAI inference: ", e)
            raise openai.APIError from e
        if response is None or len(response["choices"]) == 0:
            raise openai.APIError
        else:
            resp = response["choices"][0]['message']['content']
            finish_reason = response["choices"][0]["finish_reason"]
            get_logger().debug(f"\nAI response:\n{resp}")

            # log the full response for debugging
            response_log = self.prepare_logs(response, system, user, resp, finish_reason)
            get_logger().debug("Full_response", artifact=response_log)

            # for CLI debugging
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"\nAI response:\n{resp}")

        return resp, finish_reason
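

# Usage sketch (not part of the original module): a minimal, hypothetical
# driver showing how chat_completion might be called once pr-agent's
# configuration is in place. The model name and prompts are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        handler = LiteLLMAIHandler()
        resp, finish_reason = await handler.chat_completion(
            model="gpt-4o",
            system="You are a helpful code reviewer.",
            user="Summarize the main risk in this diff: ...",
        )
        print(f"finish_reason={finish_reason}\n{resp}")

    asyncio.run(_demo())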