mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 03:40:38 +08:00
support Azure AD authentication for OpenAI services for litellm implemetation
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
import os
|
||||
|
||||
from azure.identity import ClientSecretCredential
|
||||
import litellm
|
||||
import openai
|
||||
import requests
|
||||
@ -31,6 +31,7 @@ class LiteLLMAIHandler(BaseAiHandler):
|
||||
self.azure = False
|
||||
self.api_base = None
|
||||
self.repetition_penalty = None
|
||||
|
||||
if get_settings().get("OPENAI.KEY", None):
|
||||
openai.api_key = get_settings().openai.key
|
||||
litellm.openai_key = get_settings().openai.key
|
||||
@ -97,6 +98,19 @@ class LiteLLMAIHandler(BaseAiHandler):
|
||||
if get_settings().get("DEEPINFRA.KEY", None):
|
||||
os.environ['DEEPINFRA_API_KEY'] = get_settings().get("DEEPINFRA.KEY")
|
||||
|
||||
# Check for Azure AD configuration
|
||||
if get_settings().get("AZURE_AD.CLIENT_ID", None):
|
||||
self.azure = True
|
||||
# Generate access token using Azure AD credentials from settings
|
||||
access_token = self._get_azure_ad_token()
|
||||
litellm.api_key = access_token
|
||||
openai.api_key = access_token
|
||||
|
||||
# Set API base from settings
|
||||
self.api_base = get_settings().azure_ad.api_base
|
||||
litellm.api_base = self.api_base
|
||||
openai.api_base = self.api_base
|
||||
|
||||
# Models that only use user meessage
|
||||
self.user_message_only_models = USER_MESSAGE_ONLY_MODELS
|
||||
|
||||
@ -109,6 +123,25 @@ class LiteLLMAIHandler(BaseAiHandler):
|
||||
# Models that support extended thinking
|
||||
self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS
|
||||
|
||||
def _get_azure_ad_token(self):
|
||||
"""
|
||||
Generates an access token using Azure AD credentials from settings.
|
||||
Returns:
|
||||
str: The access token
|
||||
"""
|
||||
try:
|
||||
credential = ClientSecretCredential(
|
||||
tenant_id=get_settings().azure_ad.tenant_id,
|
||||
client_id=get_settings().azure_ad.client_id,
|
||||
client_secret=get_settings().azure_ad.client_secret
|
||||
)
|
||||
# Get token for Azure OpenAI service
|
||||
token = credential.get_token("https://cognitiveservices.azure.com/.default")
|
||||
return token.token
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to get Azure AD token: {e}")
|
||||
raise
|
||||
|
||||
def prepare_logs(self, response, system, user, resp, finish_reason):
|
||||
response_log = response.dict().copy()
|
||||
response_log['system'] = system
|
||||
|
@ -101,3 +101,10 @@ key = ""
|
||||
|
||||
[deepinfra]
|
||||
key = ""
|
||||
|
||||
[azure_ad]
|
||||
# Azure AD authentication for OpenAI services
|
||||
client_id = "" # Your Azure AD application client ID
|
||||
client_secret = "" # Your Azure AD application client secret
|
||||
tenant_id = "" # Your Azure AD tenant ID
|
||||
api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.com/)
|
Reference in New Issue
Block a user