diff --git a/docs/docs/usage-guide/changing_a_model.md b/docs/docs/usage-guide/changing_a_model.md
index 9eec3b64..36480d21 100644
--- a/docs/docs/usage-guide/changing_a_model.md
+++ b/docs/docs/usage-guide/changing_a_model.md
@@ -37,6 +37,17 @@ model="" # the OpenAI model you've deployed on Azure (e.g. gpt-4o)
 fallback_models=["..."]
 ```
+To use Azure AD (Entra ID) based authentication, set in your `.secrets.toml` (working from CLI), or in the GitHub `Settings > Secrets and variables` (working from GitHub App or GitHub Action):
+
+```toml
+[azure_ad]
+client_id = "" # Your Azure AD application client ID
+client_secret = "" # Your Azure AD application client secret
+tenant_id = "" # Your Azure AD tenant ID
+api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.com/)
+```
+
+
 
 Passing custom headers to the underlying LLM Model API can be done by setting extra_headers parameter to litellm.
 
 ```toml
diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
index 2ca04ea3..d717c087 100644
--- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -1,5 +1,4 @@
 import os
-
 import litellm
 import openai
 import requests
@@ -31,6 +30,7 @@ class LiteLLMAIHandler(BaseAiHandler):
         self.azure = False
         self.api_base = None
         self.repetition_penalty = None
+
         if get_settings().get("OPENAI.KEY", None):
             openai.api_key = get_settings().openai.key
             litellm.openai_key = get_settings().openai.key
@@ -97,6 +97,19 @@ class LiteLLMAIHandler(BaseAiHandler):
         if get_settings().get("DEEPINFRA.KEY", None):
             os.environ['DEEPINFRA_API_KEY'] = get_settings().get("DEEPINFRA.KEY")
 
+        # Check for Azure AD configuration
+        if get_settings().get("AZURE_AD.CLIENT_ID", None):
+            self.azure = True
+            # Generate access token using Azure AD credentials from settings
+            access_token = self._get_azure_ad_token()
+            litellm.api_key = access_token
+            openai.api_key = access_token
+
+            # Set API base from settings
+            self.api_base = get_settings().azure_ad.api_base
+            litellm.api_base = self.api_base
+            openai.api_base = self.api_base
+
         # Models that only use user meessage
         self.user_message_only_models = USER_MESSAGE_ONLY_MODELS
 
@@ -109,6 +122,26 @@ class LiteLLMAIHandler(BaseAiHandler):
         # Models that support extended thinking
         self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS
 
+    def _get_azure_ad_token(self):
+        """
+        Generates an access token using Azure AD credentials from settings.
+        Returns:
+            str: The access token
+        """
+        from azure.identity import ClientSecretCredential
+        try:
+            credential = ClientSecretCredential(
+                tenant_id=get_settings().azure_ad.tenant_id,
+                client_id=get_settings().azure_ad.client_id,
+                client_secret=get_settings().azure_ad.client_secret
+            )
+            # Get token for Azure OpenAI service
+            token = credential.get_token("https://cognitiveservices.azure.com/.default")
+            return token.token
+        except Exception as e:
+            get_logger().error(f"Failed to get Azure AD token: {e}")
+            raise
+
     def prepare_logs(self, response, system, user, resp, finish_reason):
         response_log = response.dict().copy()
         response_log['system'] = system
diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml
index f1bb30d4..05f7bc0e 100644
--- a/pr_agent/settings/.secrets_template.toml
+++ b/pr_agent/settings/.secrets_template.toml
@@ -101,3 +101,10 @@ key = ""
 
 [deepinfra]
 key = ""
+
+[azure_ad]
+# Azure AD authentication for OpenAI services
+client_id = "" # Your Azure AD application client ID
+client_secret = "" # Your Azure AD application client secret
+tenant_id = "" # Your Azure AD tenant ID
+api_base = "" # Your Azure OpenAI service base URL (e.g., https://openai.xyz.com/)
\ No newline at end of file