From 0da667d1793d0f57f65b551cf683d38be6735509 Mon Sep 17 00:00:00 2001 From: arpit-at Date: Wed, 16 Apr 2025 11:19:04 +0530 Subject: [PATCH 1/3] support Azure AD authentication for OpenAI services for litellm implemetation --- .../algo/ai_handlers/litellm_ai_handler.py | 35 ++++++++++++++++++- pr_agent/settings/.secrets_template.toml | 7 ++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index 2ca04ea3..d2286d3b 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -1,5 +1,5 @@ import os - +from azure.identity import ClientSecretCredential import litellm import openai import requests @@ -31,6 +31,7 @@ class LiteLLMAIHandler(BaseAiHandler): self.azure = False self.api_base = None self.repetition_penalty = None + if get_settings().get("OPENAI.KEY", None): openai.api_key = get_settings().openai.key litellm.openai_key = get_settings().openai.key @@ -97,6 +98,19 @@ class LiteLLMAIHandler(BaseAiHandler): if get_settings().get("DEEPINFRA.KEY", None): os.environ['DEEPINFRA_API_KEY'] = get_settings().get("DEEPINFRA.KEY") + # Check for Azure AD configuration + if get_settings().get("AZURE_AD.CLIENT_ID", None): + self.azure = True + # Generate access token using Azure AD credentials from settings + access_token = self._get_azure_ad_token() + litellm.api_key = access_token + openai.api_key = access_token + + # Set API base from settings + self.api_base = get_settings().azure_ad.api_base + litellm.api_base = self.api_base + openai.api_base = self.api_base + # Models that only use user meessage self.user_message_only_models = USER_MESSAGE_ONLY_MODELS @@ -109,6 +123,25 @@ class LiteLLMAIHandler(BaseAiHandler): # Models that support extended thinking self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS + def _get_azure_ad_token(self): + """ + Generates an access token using Azure AD credentials from settings. + Returns: + str: The access token + """ + try: + credential = ClientSecretCredential( + tenant_id=get_settings().azure_ad.tenant_id, + client_id=get_settings().azure_ad.client_id, + client_secret=get_settings().azure_ad.client_secret + ) + # Get token for Azure OpenAI service + token = credential.get_token("https://cognitiveservices.azure.com/.default") + return token.token + except Exception as e: + get_logger().error(f"Failed to get Azure AD token: {e}") + raise + def prepare_logs(self, response, system, user, resp, finish_reason): response_log = response.dict().copy() response_log['system'] = system diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index f1bb30d4..05f7bc0e 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -101,3 +101,10 @@ key = "" [deepinfra] key = "" + +[azure_ad] +# Azure AD authentication for OpenAI services +client_id = "" # Your Azure AD application client ID +client_secret = "" # Your Azure AD application client secret +tenant_id = "" # Your Azure AD tenant ID +api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.com/) \ No newline at end of file From dc46acb7626d976b46715032dc6ed72b3c1aa348 Mon Sep 17 00:00:00 2001 From: arpit-at Date: Wed, 16 Apr 2025 13:27:52 +0530 Subject: [PATCH 2/3] doc update and minor fix --- docs/docs/usage-guide/changing_a_model.md | 11 +++++++++++ pr_agent/algo/ai_handlers/litellm_ai_handler.py | 1 + 2 files changed, 12 insertions(+) diff --git a/docs/docs/usage-guide/changing_a_model.md b/docs/docs/usage-guide/changing_a_model.md index 9eec3b64..36480d21 100644 --- a/docs/docs/usage-guide/changing_a_model.md +++ b/docs/docs/usage-guide/changing_a_model.md @@ -37,6 +37,17 @@ model="" # the OpenAI model you've deployed on Azure (e.g. gpt-4o) fallback_models=["..."] ``` +To use Azure AD (Entra id) based authentication set in your `.secrets.toml` (working from CLI), or in the GitHub `Settings > Secrets and variables` (working from GitHub App or GitHub Action): + +```toml +[azure_ad] +client_id = "" # Your Azure AD application client ID +client_secret = "" # Your Azure AD application client secret +tenant_id = "" # Your Azure AD tenant ID +api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.com/) +``` + + Passing custom headers to the underlying LLM Model API can be done by setting extra_headers parameter to litellm. ```toml diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index d2286d3b..b34b4a0a 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -100,6 +100,7 @@ class LiteLLMAIHandler(BaseAiHandler): # Check for Azure AD configuration if get_settings().get("AZURE_AD.CLIENT_ID", None): + from azure.identity import ClientSecretCredential self.azure = True # Generate access token using Azure AD credentials from settings access_token = self._get_azure_ad_token() From 27a7c1a94f36dad384a7b9e5f650b283b94f7342 Mon Sep 17 00:00:00 2001 From: arpit-at Date: Wed, 16 Apr 2025 13:32:53 +0530 Subject: [PATCH 3/3] doc update and minor fix --- pr_agent/algo/ai_handlers/litellm_ai_handler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index b34b4a0a..d717c087 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -1,5 +1,4 @@ import os -from azure.identity import ClientSecretCredential import litellm import openai import requests @@ -100,7 +99,6 @@ class LiteLLMAIHandler(BaseAiHandler): # Check for Azure AD configuration if get_settings().get("AZURE_AD.CLIENT_ID", None): - from azure.identity import ClientSecretCredential self.azure = True # Generate access token using Azure AD credentials from settings access_token = self._get_azure_ad_token() @@ -130,6 +128,7 @@ class LiteLLMAIHandler(BaseAiHandler): Returns: str: The access token """ + from azure.identity import ClientSecretCredential try: credential = ClientSecretCredential( tenant_id=get_settings().azure_ad.tenant_id,