From 0da667d1793d0f57f65b551cf683d38be6735509 Mon Sep 17 00:00:00 2001 From: arpit-at Date: Wed, 16 Apr 2025 11:19:04 +0530 Subject: [PATCH] support Azure AD authentication for OpenAI services for litellm implemetation --- .../algo/ai_handlers/litellm_ai_handler.py | 35 ++++++++++++++++++- pr_agent/settings/.secrets_template.toml | 7 ++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index 2ca04ea3..d2286d3b 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -1,5 +1,5 @@ import os - +from azure.identity import ClientSecretCredential import litellm import openai import requests @@ -31,6 +31,7 @@ class LiteLLMAIHandler(BaseAiHandler): self.azure = False self.api_base = None self.repetition_penalty = None + if get_settings().get("OPENAI.KEY", None): openai.api_key = get_settings().openai.key litellm.openai_key = get_settings().openai.key @@ -97,6 +98,19 @@ class LiteLLMAIHandler(BaseAiHandler): if get_settings().get("DEEPINFRA.KEY", None): os.environ['DEEPINFRA_API_KEY'] = get_settings().get("DEEPINFRA.KEY") + # Check for Azure AD configuration + if get_settings().get("AZURE_AD.CLIENT_ID", None): + self.azure = True + # Generate access token using Azure AD credentials from settings + access_token = self._get_azure_ad_token() + litellm.api_key = access_token + openai.api_key = access_token + + # Set API base from settings + self.api_base = get_settings().azure_ad.api_base + litellm.api_base = self.api_base + openai.api_base = self.api_base + # Models that only use user meessage self.user_message_only_models = USER_MESSAGE_ONLY_MODELS @@ -109,6 +123,25 @@ class LiteLLMAIHandler(BaseAiHandler): # Models that support extended thinking self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS + def _get_azure_ad_token(self): + """ + Generates an access token using Azure AD credentials from settings. + Returns: + str: The access token + """ + try: + credential = ClientSecretCredential( + tenant_id=get_settings().azure_ad.tenant_id, + client_id=get_settings().azure_ad.client_id, + client_secret=get_settings().azure_ad.client_secret + ) + # Get token for Azure OpenAI service + token = credential.get_token("https://cognitiveservices.azure.com/.default") + return token.token + except Exception as e: + get_logger().error(f"Failed to get Azure AD token: {e}") + raise + def prepare_logs(self, response, system, user, resp, finish_reason): response_log = response.dict().copy() response_log['system'] = system diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index f1bb30d4..05f7bc0e 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -101,3 +101,10 @@ key = "" [deepinfra] key = "" + +[azure_ad] +# Azure AD authentication for OpenAI services +client_id = "" # Your Azure AD application client ID +client_secret = "" # Your Azure AD application client secret +tenant_id = "" # Your Azure AD tenant ID +api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.com/) \ No newline at end of file