support Azure AD authentication for OpenAI services for litellm implemetation

This commit is contained in:
arpit-at
2025-04-16 11:19:04 +05:30
parent 73b3e2520c
commit 0da667d179
2 changed files with 41 additions and 1 deletions

View File

@ -1,5 +1,5 @@
import os
from azure.identity import ClientSecretCredential
import litellm
import openai
import requests
@ -31,6 +31,7 @@ class LiteLLMAIHandler(BaseAiHandler):
self.azure = False
self.api_base = None
self.repetition_penalty = None
if get_settings().get("OPENAI.KEY", None):
openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key
@ -97,6 +98,19 @@ class LiteLLMAIHandler(BaseAiHandler):
if get_settings().get("DEEPINFRA.KEY", None):
os.environ['DEEPINFRA_API_KEY'] = get_settings().get("DEEPINFRA.KEY")
# Check for Azure AD configuration
if get_settings().get("AZURE_AD.CLIENT_ID", None):
self.azure = True
# Generate access token using Azure AD credentials from settings
access_token = self._get_azure_ad_token()
litellm.api_key = access_token
openai.api_key = access_token
# Set API base from settings
self.api_base = get_settings().azure_ad.api_base
litellm.api_base = self.api_base
openai.api_base = self.api_base
# Models that only use user meessage
self.user_message_only_models = USER_MESSAGE_ONLY_MODELS
@ -109,6 +123,25 @@ class LiteLLMAIHandler(BaseAiHandler):
# Models that support extended thinking
self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS
def _get_azure_ad_token(self):
"""
Generates an access token using Azure AD credentials from settings.
Returns:
str: The access token
"""
try:
credential = ClientSecretCredential(
tenant_id=get_settings().azure_ad.tenant_id,
client_id=get_settings().azure_ad.client_id,
client_secret=get_settings().azure_ad.client_secret
)
# Get token for Azure OpenAI service
token = credential.get_token("https://cognitiveservices.azure.com/.default")
return token.token
except Exception as e:
get_logger().error(f"Failed to get Azure AD token: {e}")
raise
def prepare_logs(self, response, system, user, resp, finish_reason):
response_log = response.dict().copy()
response_log['system'] = system

View File

@ -101,3 +101,10 @@ key = ""
[deepinfra]
key = ""
[azure_ad]
# Azure AD authentication for OpenAI services
client_id = "" # Your Azure AD application client ID
client_secret = "" # Your Azure AD application client secret
tenant_id = "" # Your Azure AD tenant ID
api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.com/)