From 6c4a5bae52f969ee1fedb4bd7e855d7c333b8ee0 Mon Sep 17 00:00:00 2001
From: zmeir
Date: Mon, 7 Aug 2023 16:17:06 +0300
Subject: [PATCH] Support fallback deployments to accompany fallback models

This is useful for example in Azure OpenAI deployments where you have a
different deployment per model, so the current fallback implementation
doesn't work (still uses the same deployment for each fallback attempt)
---
 pr_agent/algo/ai_handler.py              |  8 +++++++-
 pr_agent/algo/pr_processing.py           | 17 +++++++++++++++--
 pr_agent/settings/.secrets_template.toml |  1 +
 3 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index 57221518..cfca63f6 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -27,7 +27,6 @@ class AiHandler:
             self.azure = False
             if get_settings().get("OPENAI.ORG", None):
                 litellm.organization = get_settings().openai.org
-            self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
             if get_settings().get("OPENAI.API_TYPE", None):
                 if get_settings().openai.api_type == "azure":
                     self.azure = True
@@ -45,6 +44,13 @@ class AiHandler:
         except AttributeError as e:
             raise ValueError("OpenAI key is required") from e
 
+    @property
+    def deployment_id(self):
+        """
+        Returns the deployment ID for the OpenAI API.
+        """
+        return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
+
     @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
            tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
     async def chat_completion(self, model: str, temperature: float, system: str, user: str):
diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index 8b319446..fae2535a 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -208,13 +208,26 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
 
 
 async def retry_with_fallback_models(f: Callable):
+    # getting all models
     model = get_settings().config.model
     fallback_models = get_settings().config.fallback_models
     if not isinstance(fallback_models, list):
-        fallback_models = [fallback_models]
+        fallback_models = [m.strip() for m in fallback_models.split(",")]
     all_models = [model] + fallback_models
-    for i, model in enumerate(all_models):
+
+    # getting all deployments
+    deployment_id = get_settings().get("openai.deployment_id", None)
+    fallback_deployments = get_settings().get("openai.fallback_deployments", [])
+    if not isinstance(fallback_deployments, list) and fallback_deployments:
+        fallback_deployments = [d.strip() for d in fallback_deployments.split(",")]
+    if fallback_deployments:
+        all_deployments = [deployment_id] + fallback_deployments
+    else:
+        all_deployments = [deployment_id] * len(all_models)
+    # try each (model, deployment_id) pair until one is successful, otherwise raise exception
+    for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
         try:
+            get_settings().set("openai.deployment_id", deployment_id)
             return await f(model)
         except Exception as e:
             logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml
index 36b529a6..25a6562f 100644
--- a/pr_agent/settings/.secrets_template.toml
+++ b/pr_agent/settings/.secrets_template.toml
@@ -14,6 +14,7 @@ key = "" # Acquire through https://platform.openai.com
 #api_version = '2023-05-15' # Check Azure documentation for the current API version
 #api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
 #deployment_id = "" # The deployment name you chose when you deployed the engine
+#fallback_deployments = [] # Match your fallback models from configuration.toml with the appropriate deployment_id
 
 [anthropic]
 key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/
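
Usage sketch (illustrative only; the model and deployment names below are
hypothetical examples, not values required by this patch). Fallback deployments
are matched to fallback models by position, so assuming configuration.toml
defines

    [config]
    model = "gpt-4"                          # example primary model
    fallback_models = ["gpt-3.5-turbo-16k"]  # example fallback model

the corresponding Azure deployments could be listed in the secrets file (based
on .secrets_template.toml) as

    [openai]
    deployment_id = "my-gpt4-deployment"            # hypothetical deployment used with "gpt-4"
    fallback_deployments = ["my-gpt35-deployment"]  # hypothetical deployment used with "gpt-3.5-turbo-16k"

If fallback_deployments is left empty, the original deployment_id is reused for
every fallback attempt, preserving the previous behavior. Because the pairing
uses zip(), which stops at the shorter list, fallback_deployments should contain
one entry per fallback model.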