Support fallback deployments to accompany fallback models

This is useful for example in Azure OpenAI deployments where you have a different deployment per model, so the current fallback implementation doesn't work (still uses the same deployment for each fallback attempt)
This commit is contained in:
zmeir
2023-08-07 16:17:06 +03:00
parent 43297b851f
commit 6c4a5bae52
3 changed files with 23 additions and 3 deletions

View File

@ -208,13 +208,26 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
async def retry_with_fallback_models(f: Callable):
# getting all models
model = get_settings().config.model
fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list):
fallback_models = [fallback_models]
fallback_models = [m.strip() for m in fallback_models.split(",")]
all_models = [model] + fallback_models
for i, model in enumerate(all_models):
# getting all deployments
deployment_id = get_settings().get("openai.deployment_id", None)
fallback_deployments = get_settings().get("openai.fallback_deployments", [])
if not isinstance(fallback_deployments, list) and fallback_deployments:
fallback_deployments = [d.strip() for d in fallback_deployments.split(",")]
if fallback_deployments:
all_deployments = [deployment_id] + fallback_deployments
else:
all_deployments = [deployment_id] * len(all_models)
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
try:
get_settings().set("openai.deployment_id", deployment_id)
return await f(model)
except Exception as e:
logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")