mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
support Azure AD authentication for OpenAI services for litellm implemetation
This commit is contained in:
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from azure.identity import ClientSecretCredential
|
||||||
import litellm
|
import litellm
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
@ -31,6 +31,7 @@ class LiteLLMAIHandler(BaseAiHandler):
|
|||||||
self.azure = False
|
self.azure = False
|
||||||
self.api_base = None
|
self.api_base = None
|
||||||
self.repetition_penalty = None
|
self.repetition_penalty = None
|
||||||
|
|
||||||
if get_settings().get("OPENAI.KEY", None):
|
if get_settings().get("OPENAI.KEY", None):
|
||||||
openai.api_key = get_settings().openai.key
|
openai.api_key = get_settings().openai.key
|
||||||
litellm.openai_key = get_settings().openai.key
|
litellm.openai_key = get_settings().openai.key
|
||||||
@ -97,6 +98,19 @@ class LiteLLMAIHandler(BaseAiHandler):
|
|||||||
if get_settings().get("DEEPINFRA.KEY", None):
|
if get_settings().get("DEEPINFRA.KEY", None):
|
||||||
os.environ['DEEPINFRA_API_KEY'] = get_settings().get("DEEPINFRA.KEY")
|
os.environ['DEEPINFRA_API_KEY'] = get_settings().get("DEEPINFRA.KEY")
|
||||||
|
|
||||||
|
# Check for Azure AD configuration
|
||||||
|
if get_settings().get("AZURE_AD.CLIENT_ID", None):
|
||||||
|
self.azure = True
|
||||||
|
# Generate access token using Azure AD credentials from settings
|
||||||
|
access_token = self._get_azure_ad_token()
|
||||||
|
litellm.api_key = access_token
|
||||||
|
openai.api_key = access_token
|
||||||
|
|
||||||
|
# Set API base from settings
|
||||||
|
self.api_base = get_settings().azure_ad.api_base
|
||||||
|
litellm.api_base = self.api_base
|
||||||
|
openai.api_base = self.api_base
|
||||||
|
|
||||||
# Models that only use user meessage
|
# Models that only use user meessage
|
||||||
self.user_message_only_models = USER_MESSAGE_ONLY_MODELS
|
self.user_message_only_models = USER_MESSAGE_ONLY_MODELS
|
||||||
|
|
||||||
@ -109,6 +123,25 @@ class LiteLLMAIHandler(BaseAiHandler):
|
|||||||
# Models that support extended thinking
|
# Models that support extended thinking
|
||||||
self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS
|
self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS
|
||||||
|
|
||||||
|
def _get_azure_ad_token(self):
|
||||||
|
"""
|
||||||
|
Generates an access token using Azure AD credentials from settings.
|
||||||
|
Returns:
|
||||||
|
str: The access token
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
credential = ClientSecretCredential(
|
||||||
|
tenant_id=get_settings().azure_ad.tenant_id,
|
||||||
|
client_id=get_settings().azure_ad.client_id,
|
||||||
|
client_secret=get_settings().azure_ad.client_secret
|
||||||
|
)
|
||||||
|
# Get token for Azure OpenAI service
|
||||||
|
token = credential.get_token("https://cognitiveservices.azure.com/.default")
|
||||||
|
return token.token
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().error(f"Failed to get Azure AD token: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
def prepare_logs(self, response, system, user, resp, finish_reason):
|
def prepare_logs(self, response, system, user, resp, finish_reason):
|
||||||
response_log = response.dict().copy()
|
response_log = response.dict().copy()
|
||||||
response_log['system'] = system
|
response_log['system'] = system
|
||||||
|
@ -101,3 +101,10 @@ key = ""
|
|||||||
|
|
||||||
[deepinfra]
|
[deepinfra]
|
||||||
key = ""
|
key = ""
|
||||||
|
|
||||||
|
[azure_ad]
|
||||||
|
# Azure AD authentication for OpenAI services
|
||||||
|
client_id = "" # Your Azure AD application client ID
|
||||||
|
client_secret = "" # Your Azure AD application client secret
|
||||||
|
tenant_id = "" # Your Azure AD tenant ID
|
||||||
|
api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.com/)
|
Reference in New Issue
Block a user