diff --git a/docs/docs/usage-guide/changing_a_model.md b/docs/docs/usage-guide/changing_a_model.md
index c4cc0224..fd491012 100644
--- a/docs/docs/usage-guide/changing_a_model.md
+++ b/docs/docs/usage-guide/changing_a_model.md
@@ -30,6 +30,14 @@
 model="" # the OpenAI model you've deployed on Azure (e.g. gpt-4o)
 fallback_models=["..."]
 ```
+Custom headers can be passed to the underlying LLM API by setting the `extra_headers` parameter in the `[litellm]` section:
+```
+[litellm]
+extra_headers='{"projectId": "", ...}' # the value must be a JSON string representing the desired headers; a ValueError is raised otherwise
+```
+This makes it possible to pass authorization tokens or API keys when routing requests through an API management gateway.
+
+
 ### Ollama
 
 You can run models locally through either [VLLM](https://docs.litellm.ai/docs/providers/vllm) or [Ollama](https://docs.litellm.ai/docs/providers/ollama)
diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
index e3d9a94e..ba557979 100644
--- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -11,6 +11,7 @@
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from pr_agent.algo.utils import ReasoningEffort, get_version
 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
+import json
 
 OPENAI_RETRIES = 5
@@ -254,12 +255,22 @@ class LiteLLMAIHandler(BaseAiHandler):
             if self.repetition_penalty:
                 kwargs["repetition_penalty"] = self.repetition_penalty
 
+            # Support passing custom headers (e.g. authorization tokens or API keys) to the underlying model API when routing litellm calls through an API management gateway
+            if get_settings().get("LITELLM.EXTRA_HEADERS", None):
+                try:
+                    litellm_extra_headers = json.loads(get_settings().litellm.extra_headers)
+                    if not isinstance(litellm_extra_headers, dict):
+                        raise ValueError("LITELLM.EXTRA_HEADERS must be a JSON object")
+                except json.JSONDecodeError as e:
+                    raise ValueError(f"LITELLM.EXTRA_HEADERS contains invalid JSON: {str(e)}")
+                kwargs["extra_headers"] = litellm_extra_headers
+
             get_logger().debug("Prompts", artifact={"system": system, "user": user})
-        
+
             if get_settings().config.verbosity_level >= 2:
                 get_logger().info(f"\nSystem prompt:\n{system}")
                 get_logger().info(f"\nUser prompt:\n{user}")
-        
+
             response = await acompletion(**kwargs)
         except (openai.APIError, openai.APITimeoutError) as e:
             get_logger().warning(f"Error during LLM inference: {e}")
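
For reference, here is a minimal standalone sketch of the parse-and-validate behavior the patch adds. `parse_extra_headers` is a hypothetical helper for illustration; the patch inlines the equivalent logic in `LiteLLMAIHandler`:

```python
# Minimal sketch (not part of the patch): the same logic the handler applies to
# LITELLM.EXTRA_HEADERS, extracted into a hypothetical helper for clarity.
import json


def parse_extra_headers(raw: str) -> dict:
    """Parse a JSON string of HTTP headers; reject anything that is not a JSON object."""
    try:
        headers = json.loads(raw)
    except json.JSONDecodeError as e:
        raise ValueError(f"LITELLM.EXTRA_HEADERS contains invalid JSON: {e}")
    if not isinstance(headers, dict):
        raise ValueError("LITELLM.EXTRA_HEADERS must be a JSON object")
    return headers


if __name__ == "__main__":
    # A well-formed object parses to a dict, which the handler then forwards
    # as kwargs["extra_headers"]. Header names here are illustrative only.
    print(parse_extra_headers('{"projectId": "my-project", "x-api-key": "secret"}'))

    # Malformed JSON and non-object JSON both surface as ValueError.
    for bad in ("{not json}", '["not", "an", "object"]'):
        try:
            parse_extra_headers(bad)
        except ValueError as e:
            print(f"rejected {bad!r}: {e}")
```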
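
A note on the design choice: encoding the headers as a JSON string keeps the setting a single scalar value, which works uniformly whether it comes from a TOML file or an environment-variable override. Validating eagerly and raising `ValueError` fails fast with a clear message before any request is sent, rather than letting a malformed header set reach the gateway.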