import os

import requests

import litellm
import openai
from litellm import acompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt

from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger

OPENAI_RETRIES = 5


class LiteLLMAIHandler(BaseAiHandler):
    """
    This class handles chat completions through the LiteLLM library, which routes
    requests to OpenAI and other supported LLM providers.
    It initializes API keys and other settings from the configuration file,
    and provides a method for performing chat completions via litellm's acompletion API.
    """

    def __init__(self):
        """
        Initializes provider API keys and other settings from the configuration file.
        Provider-specific settings are applied only when present in the configuration.
        """
        self.azure = False
        self.api_base = None
        self.repetition_penalty = None
        if get_settings().get("OPENAI.KEY", None):
            openai.api_key = get_settings().openai.key
            litellm.openai_key = get_settings().openai.key
        elif 'OPENAI_API_KEY' not in os.environ:
            litellm.api_key = "dummy_key"
        if get_settings().get("aws.AWS_ACCESS_KEY_ID"):
            assert get_settings().aws.AWS_SECRET_ACCESS_KEY and get_settings().aws.AWS_REGION_NAME, "AWS credentials are incomplete"
            os.environ["AWS_ACCESS_KEY_ID"] = get_settings().aws.AWS_ACCESS_KEY_ID
            os.environ["AWS_SECRET_ACCESS_KEY"] = get_settings().aws.AWS_SECRET_ACCESS_KEY
            os.environ["AWS_REGION_NAME"] = get_settings().aws.AWS_REGION_NAME
        if get_settings().get("litellm.use_client"):
            litellm_token = get_settings().get("litellm.LITELLM_TOKEN")
            assert litellm_token, "LITELLM_TOKEN is required"
            os.environ["LITELLM_TOKEN"] = litellm_token
            litellm.use_client = True
        if get_settings().get("LITELLM.DROP_PARAMS", None):
            litellm.drop_params = get_settings().litellm.drop_params
        if get_settings().get("LITELLM.SUCCESS_CALLBACK", None):
            litellm.success_callback = get_settings().litellm.success_callback
        if get_settings().get("LITELLM.FAILURE_CALLBACK", None):
            litellm.failure_callback = get_settings().litellm.failure_callback
        if get_settings().get("LITELLM.SERVICE_CALLBACK", None):
            litellm.service_callback = get_settings().litellm.service_callback
        if get_settings().get("OPENAI.ORG", None):
            litellm.organization = get_settings().openai.org
        if get_settings().get("OPENAI.API_TYPE", None):
            if get_settings().openai.api_type == "azure":
                self.azure = True
                litellm.azure_key = get_settings().openai.key
        if get_settings().get("OPENAI.API_VERSION", None):
            litellm.api_version = get_settings().openai.api_version
        if get_settings().get("OPENAI.API_BASE", None):
            litellm.api_base = get_settings().openai.api_base
        if get_settings().get("ANTHROPIC.KEY", None):
            litellm.anthropic_key = get_settings().anthropic.key
        if get_settings().get("COHERE.KEY", None):
            litellm.cohere_key = get_settings().cohere.key
        if get_settings().get("GROQ.KEY", None):
            litellm.api_key = get_settings().groq.key
        if get_settings().get("REPLICATE.KEY", None):
            litellm.replicate_key = get_settings().replicate.key
        if get_settings().get("HUGGINGFACE.KEY", None):
            litellm.huggingface_key = get_settings().huggingface.key
        if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
            litellm.api_base = get_settings().huggingface.api_base
            self.api_base = get_settings().huggingface.api_base
        if get_settings().get("OLLAMA.API_BASE", None):
            litellm.api_base = get_settings().ollama.api_base
            self.api_base = get_settings().ollama.api_base
        if get_settings().get("HUGGINGFACE.REPETITION_PENALTY", None):
            self.repetition_penalty = float(get_settings().huggingface.repetition_penalty)
        if get_settings().get("VERTEXAI.VERTEX_PROJECT", None):
            litellm.vertex_project = get_settings().vertexai.vertex_project
            litellm.vertex_location = get_settings().get(
                "VERTEXAI.VERTEX_LOCATION", None
            )
        # Google AI Studio
        # SEE https://docs.litellm.ai/docs/providers/gemini
        if get_settings().get("GOOGLE_AI_STUDIO.GEMINI_API_KEY", None):
            os.environ["GEMINI_API_KEY"] = get_settings().google_ai_studio.gemini_api_key

    def prepare_logs(self, response, system, user, resp, finish_reason):
        response_log = response.dict().copy()
        response_log['system'] = system
        response_log['user'] = user
        response_log['output'] = resp
        response_log['finish_reason'] = finish_reason
        if hasattr(self, 'main_pr_language'):
            response_log['main_pr_language'] = self.main_pr_language
        else:
            response_log['main_pr_language'] = 'unknown'
        return response_log

    def add_litellm_callbacks(self, kwargs) -> dict:
        captured_extra = []

        def capture_logs(message):
            # Parse the log message and its bound context
            record = message.record
            log_entry = {}
            if record.get('extra', {}).get('command', None) is not None:
                log_entry.update({"command": record['extra']["command"]})
            if record.get('extra', {}).get('pr_url', None) is not None:
                log_entry.update({"pr_url": record['extra']["pr_url"]})
            # Append the log entry to the captured_extra list
            captured_extra.append(log_entry)

        # Temporarily add a custom sink to Loguru: the debug call below triggers it,
        # capturing the current logging context (command, PR URL), then the sink is removed
        handler_id = get_logger().add(capture_logs)
        get_logger().debug("Capturing logs for litellm callbacks")
        get_logger().remove(handler_id)

        context = captured_extra[0] if len(captured_extra) > 0 else {}
        command = context.get("command", "unknown")
        pr_url = context.get("pr_url", "unknown")
        git_provider = get_settings().config.git_provider
        metadata = dict()
        callbacks = litellm.success_callback + litellm.failure_callback + litellm.service_callback
        if "langfuse" in callbacks:
            metadata.update({
                "trace_name": command,
                "tags": [git_provider, command],
                "trace_metadata": {
                    "command": command,
                    "pr_url": pr_url,
                },
            })
        if "langsmith" in callbacks:
            metadata.update({
                "run_name": command,
                "tags": [git_provider, command],
                "extra": {
                    "metadata": {
                        "command": command,
                        "pr_url": pr_url,
                    }
                },
            })

        # Adding the captured context to the kwargs
        kwargs["metadata"] = metadata
        return kwargs
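
    # A hypothetical settings fragment that would activate the callback metadata
    # above (names mirror the LITELLM.* keys read in __init__ and the
    # "enable_callbacks" flag checked in chat_completion; values are illustrative):
    #
    #     [litellm]
    #     enable_callbacks = true
    #     success_callback = ["langfuse"]
    #     failure_callback = ["langfuse"]
    #     service_callback = []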

    @property
    def deployment_id(self):
        """
        Returns the deployment ID for Azure OpenAI, if one is configured.
        """
        return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

    @retry(
        retry=retry_if_exception_type((openai.APIError, openai.APIConnectionError, openai.APITimeoutError)),  # No retry on RateLimitError
        stop=stop_after_attempt(OPENAI_RETRIES)
    )
    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
        try:
            resp, finish_reason = None, None
            deployment_id = self.deployment_id
            if self.azure:
                model = 'azure/' + model
            if 'claude' in model and not system:
                system = "No system prompt provided"
                get_logger().warning(
                    "Empty system prompt for claude model. Adding a placeholder system prompt to prevent an API error.")
            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]

            if img_path:
                try:
                    # check if the image link is alive
                    r = requests.head(img_path, allow_redirects=True)
                    if r.status_code == 404:
                        error_msg = f"The image link is not [alive]({img_path}).\nPlease repost the original image as a comment, and send the question again with 'quote reply' (see [instructions](https://pr-agent-docs.codium.ai/tools/ask/#ask-on-images-using-the-pr-code-as-context))."
                        get_logger().error(error_msg)
                        return error_msg, "error"
                except Exception as e:
                    get_logger().error(f"Error fetching image: {img_path}: {e}")
                    return f"Error fetching image: {img_path}", "error"
                messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]},
                                          {"type": "image_url", "image_url": {"url": img_path}}]

            # Currently O1 does not support separate system and user prompts
            O1_MODEL_PREFIX = 'o1-'
            model_type = model.split('/')[-1] if '/' in model else model
            if model_type.startswith(O1_MODEL_PREFIX):
                user = f"{system}\n\n\n{user}"
                system = ""
                get_logger().info("Using O1 model, combining system and user prompts")
                messages = [{"role": "user", "content": user}]
                # note: temperature is intentionally omitted for O1 models
                kwargs = {
                    "model": model,
                    "deployment_id": deployment_id,
                    "messages": messages,
                    "timeout": get_settings().config.ai_timeout,
                    "api_base": self.api_base,
                }
            else:
                kwargs = {
                    "model": model,
                    "deployment_id": deployment_id,
                    "messages": messages,
                    "temperature": temperature,
                    "timeout": get_settings().config.ai_timeout,
                    "api_base": self.api_base,
                }

            if get_settings().litellm.get("enable_callbacks", False):
                kwargs = self.add_litellm_callbacks(kwargs)

            seed = get_settings().config.get("seed", -1)
            if temperature > 0 and seed >= 0:
                raise ValueError(f"Seed ({seed}) is not supported with temperature ({temperature}) > 0")
            elif seed >= 0:
                get_logger().info(f"Using fixed seed of {seed}")
                kwargs["seed"] = seed

            if self.repetition_penalty:
                kwargs["repetition_penalty"] = self.repetition_penalty

            get_logger().debug("Prompts", artifact={"system": system, "user": user})

            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"\nSystem prompt:\n{system}")
                get_logger().info(f"\nUser prompt:\n{user}")

            response = await acompletion(**kwargs)
        except (openai.APIError, openai.APITimeoutError) as e:
            get_logger().warning(f"Error during LLM inference: {e}")
            raise
        except openai.RateLimitError as e:
            get_logger().error(f"Rate limit error during LLM inference: {e}")
            raise
        except Exception as e:
            get_logger().warning(f"Unknown error during LLM inference: {e}")
            raise openai.APIError from e
        if response is None or len(response["choices"]) == 0:
            raise openai.APIError
        else:
            resp = response["choices"][0]['message']['content']
            finish_reason = response["choices"][0]["finish_reason"]
            get_logger().debug(f"\nAI response:\n{resp}")

            # log the full response for debugging
            response_log = self.prepare_logs(response, system, user, resp, finish_reason)
            get_logger().debug("Full_response", artifact=response_log)

            # for CLI debugging
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"\nAI response:\n{resp}")

        return resp, finish_reason