2023-09-09 17:35:45 +03:00
import os
2024-04-14 14:09:58 +03:00
import requests
2023-11-28 23:07:46 +09:00
import boto3
2023-08-07 13:26:28 +03:00
import litellm
2023-07-06 00:21:08 +03:00
import openai
2023-08-07 13:26:28 +03:00
from litellm import acompletion
2024-03-06 12:13:54 +02:00
from tenacity import retry , retry_if_exception_type , stop_after_attempt
2023-12-14 07:44:13 +08:00
from pr_agent . algo . ai_handlers . base_ai_handler import BaseAiHandler
2023-08-01 14:43:26 +03:00
from pr_agent . config_loader import get_settings
2023-10-16 14:56:00 +03:00
from pr_agent . log import get_logger
2023-08-07 13:26:28 +03:00
# Maximum number of attempts for transient API failures; consumed by the
# tenacity @retry decorator on LiteLLMAIHandler.chat_completion below.
OPENAI_RETRIES = 5
2023-07-06 00:21:08 +03:00
2023-12-14 07:44:13 +08:00
class LiteLLMAIHandler(BaseAiHandler):
    """
    This class handles interactions with the OpenAI API for chat completions.
    It initializes the API key and other settings from a configuration file,
    and provides a method for performing chat completions using the OpenAI ChatCompletion API.
    """

    def __init__(self):
        """
        Initializes the OpenAI API key and other settings from a configuration file.
        Raises a ValueError if the OpenAI key is missing.

        Side effects: mutates module-level `litellm`/`openai` globals and
        `os.environ` (AWS credentials, LITELLM_TOKEN) based on settings.
        """
        self.azure = False
        self.api_base = None
        self.repetition_penalty = None
        if get_settings().get("OPENAI.KEY", None):
            openai.api_key = get_settings().openai.key
            litellm.openai_key = get_settings().openai.key
        elif 'OPENAI_API_KEY' not in os.environ:
            # litellm insists on some key being present even for non-OpenAI providers
            litellm.api_key = "dummy_key"
        if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
            assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
            # litellm's bedrock support reads AWS credentials from the environment
            os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
            os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
            os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
        if get_settings().get("litellm.use_client"):
            litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
            assert litellm_token, "LITELLM_TOKEN is required"
            os.environ["LITELLM_TOKEN"] = litellm_token
            litellm.use_client = True
        if get_settings().get("LITELLM.DROP_PARAMS", None):
            litellm.drop_params = get_settings().litellm.drop_params
        if get_settings().get("OPENAI.ORG", None):
            litellm.organization = get_settings().openai.org
        if get_settings().get("OPENAI.API_TYPE", None):
            if get_settings().openai.api_type == "azure":
                self.azure = True
                litellm.azure_key = get_settings().openai.key
        if get_settings().get("OPENAI.API_VERSION", None):
            litellm.api_version = get_settings().openai.api_version
        if get_settings().get("OPENAI.API_BASE", None):
            litellm.api_base = get_settings().openai.api_base
        if get_settings().get("ANTHROPIC.KEY", None):
            litellm.anthropic_key = get_settings().anthropic.key
        if get_settings().get("COHERE.KEY", None):
            litellm.cohere_key = get_settings().cohere.key
        if get_settings().get("GROQ.KEY", None):
            litellm.api_key = get_settings().groq.key
        if get_settings().get("REPLICATE.KEY", None):
            litellm.replicate_key = get_settings().replicate.key
        if get_settings().get("HUGGINGFACE.KEY", None):
            litellm.huggingface_key = get_settings().huggingface.key
        if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
            litellm.api_base = get_settings().huggingface.api_base
            self.api_base = get_settings().huggingface.api_base
        if get_settings().get("OLLAMA.API_BASE", None):
            litellm.api_base = get_settings().ollama.api_base
            self.api_base = get_settings().ollama.api_base
        if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
            self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
        if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
            litellm.vertex_project = get_settings().vertexai.vertex_project
            litellm.vertex_location = get_settings().get(
                "VERTEXAI.VERTEX_LOCATION", None
            )

    def prepare_logs(self, response, system, user, resp, finish_reason):
        """
        Build a dict combining the raw model response with the prompts, the
        extracted answer and finish reason, for structured debug logging.
        """
        response_log = response.dict().copy()
        response_log['system'] = system
        response_log['user'] = user
        response_log['output'] = resp
        response_log['finish_reason'] = finish_reason
        if hasattr(self, 'main_pr_language'):
            response_log['main_pr_language'] = self.main_pr_language
        else:
            response_log['main_pr_language'] = 'unknown'
        return response_log

    def add_litellm_callbacks(self, kwargs) -> dict:
        """
        Attach PR metadata (command, pr_url) captured from the logger context
        to the litellm call kwargs, so litellm callbacks can see it.
        """
        pr_metadata = []

        def capture_logs(message):
            # Parsing the log message and context
            record = message.record
            log_entry = {}
            if record.get('extra', {}).get('command', None) is not None:
                log_entry.update({"command": record['extra']["command"]})
            if record.get('extra', {}).get('pr_url', None) is not None:
                log_entry.update({"pr_url": record['extra']["pr_url"]})
            # Append the log entry to the captured_logs list
            pr_metadata.append(log_entry)

        # Adding the custom sink to Loguru; the debug call below routes one
        # record (carrying the bound context) through capture_logs, after
        # which the sink is removed immediately.
        handler_id = get_logger().add(capture_logs)
        get_logger().debug("Capturing logs for litellm callbacks")
        get_logger().remove(handler_id)
        # Adding the captured logs to the kwargs
        kwargs["metadata"] = pr_metadata
        return kwargs

    @property
    def deployment_id(self):
        """
        Returns the deployment ID for the OpenAI API.
        """
        return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

    @retry(
        retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)),  # No retry on RateLimitError
        stop=stop_after_attempt(OPENAI_RETRIES)
    )
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
        """
        Perform a chat completion via litellm's acompletion.

        Args:
            model: model name; prefixed with 'azure/' when Azure is configured.
            system: system prompt.
            user: user prompt.
            temperature: sampling temperature (must be 0 when a fixed seed is set).
            img_path: optional URL of an image to attach to the user message.

        Returns:
            (response_text, finish_reason) tuple; on image-fetch failure returns
            (error_message, "error") instead of raising.

        Raises:
            ValueError: if a fixed seed is combined with temperature > 0.
            openai.APIError / openai.RateLimitError: propagated (or wrapped) API failures.
        """
        try:
            resp, finish_reason = None, None
            deployment_id = self.deployment_id
            if self.azure:
                model = 'azure/' + model
            if 'claude' in model and not system:
                # Claude rejects an empty system prompt; substitute a newline
                system = "\n"
                get_logger().warning(
                    "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.")
            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
            if img_path:
                try:
                    # check if the image link is alive
                    r = requests.head(img_path, allow_redirects=True)
                    if r.status_code == 404:
                        # FIX: interpolate the actual URL (was a literal "img_path")
                        error_msg = f"The image link is not [alive]({img_path}).\nPlease repost the original image as a comment, and send the question again with 'quote reply' (see [instructions](https://pr-agent-docs.codium.ai/tools/ask/#ask-on-images-using-the-pr-code-as-context))."
                        get_logger().error(error_msg)
                        return f"{error_msg}", "error"
                except Exception as e:
                    get_logger().error(f"Error fetching image: {img_path}", e)
                    return f"Error fetching image: {img_path}", "error"
                # Convert the user message into multimodal (text + image_url) form
                messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
                                          {"type": "image_url", "image_url": {"url": img_path}}]
            kwargs = {
                "model": model,
                "deployment_id": deployment_id,
                "messages": messages,
                "temperature": temperature,
                "force_timeout": get_settings().config.ai_timeout,
                "api_base": self.api_base,
            }

            if get_settings().litellm.get("enable_callbacks", False):
                kwargs = self.add_litellm_callbacks(kwargs)

            seed = get_settings().config.get("seed", -1)
            if temperature > 0 and seed >= 0:
                raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0")
            elif seed >= 0:
                get_logger().info(f"Using fixed seed of {seed}")
                kwargs["seed"] = seed

            if self.repetition_penalty:
                kwargs["repetition_penalty"] = self.repetition_penalty

            get_logger().debug("Prompts", artifact={"system": system, "user": user})

            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"\nSystem prompt:\n{system}")
                get_logger().info(f"\nUser prompt:\n{user}")

            response = await acompletion(**kwargs)
        except (openai.APIError, openai.APITimeoutError) as e:
            get_logger().warning("Error during OpenAI inference: ", e)
            raise
        except (openai.RateLimitError) as e:
            get_logger().error("Rate limit error during OpenAI inference: ", e)
            raise
        except (Exception) as e:
            get_logger().warning("Unknown error during OpenAI inference: ", e)
            # NOTE(review): raising the exception *class* relies on no-arg
            # instantiation; openai>=1.0 APIError requires arguments — confirm
            # the installed openai version tolerates this.
            raise openai.APIError from e
        if response is None or len(response["choices"]) == 0:
            raise openai.APIError
        else:
            resp = response["choices"][0]['message']['content']
            finish_reason = response["choices"][0]["finish_reason"]
            get_logger().debug(f"\nAI response:\n{resp}")

            # log the full response for debugging
            response_log = self.prepare_logs(response, system, user, resp, finish_reason)
            get_logger().debug("Full_response", artifact=response_log)
            # for CLI debugging
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"\nAI response:\n{resp}")

        return resp, finish_reason