mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-05 05:10:38 +08:00
Support fallback deployments to accompany fallback models
This is useful, for example, with Azure OpenAI, where each model has its own deployment: the current fallback implementation doesn't work there because it still uses the same deployment for every fallback attempt.
This commit is contained in:
@ -27,7 +27,6 @@ class AiHandler:
|
|||||||
self.azure = False
|
self.azure = False
|
||||||
if get_settings().get("OPENAI.ORG", None):
|
if get_settings().get("OPENAI.ORG", None):
|
||||||
litellm.organization = get_settings().openai.org
|
litellm.organization = get_settings().openai.org
|
||||||
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
|
|
||||||
if get_settings().get("OPENAI.API_TYPE", None):
|
if get_settings().get("OPENAI.API_TYPE", None):
|
||||||
if get_settings().openai.api_type == "azure":
|
if get_settings().openai.api_type == "azure":
|
||||||
self.azure = True
|
self.azure = True
|
||||||
@ -45,6 +44,13 @@ class AiHandler:
|
|||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise ValueError("OpenAI key is required") from e
|
raise ValueError("OpenAI key is required") from e
|
||||||
|
|
||||||
|
@property
|
||||||
|
def deployment_id(self):
|
||||||
|
"""
|
||||||
|
Returns the deployment ID for the OpenAI API.
|
||||||
|
"""
|
||||||
|
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
|
||||||
|
|
||||||
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
|
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
|
||||||
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
|
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
|
||||||
async def chat_completion(self, model: str, temperature: float, system: str, user: str):
|
async def chat_completion(self, model: str, temperature: float, system: str, user: str):
|
||||||
|
@ -208,13 +208,26 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
|
|||||||
|
|
||||||
|
|
||||||
async def retry_with_fallback_models(f: Callable):
|
async def retry_with_fallback_models(f: Callable):
|
||||||
|
# getting all models
|
||||||
model = get_settings().config.model
|
model = get_settings().config.model
|
||||||
fallback_models = get_settings().config.fallback_models
|
fallback_models = get_settings().config.fallback_models
|
||||||
if not isinstance(fallback_models, list):
|
if not isinstance(fallback_models, list):
|
||||||
fallback_models = [fallback_models]
|
fallback_models = [m.strip() for m in fallback_models.split(",")]
|
||||||
all_models = [model] + fallback_models
|
all_models = [model] + fallback_models
|
||||||
for i, model in enumerate(all_models):
|
|
||||||
|
# getting all deployments
|
||||||
|
deployment_id = get_settings().get("openai.deployment_id", None)
|
||||||
|
fallback_deployments = get_settings().get("openai.fallback_deployments", [])
|
||||||
|
if not isinstance(fallback_deployments, list) and fallback_deployments:
|
||||||
|
fallback_deployments = [d.strip() for d in fallback_deployments.split(",")]
|
||||||
|
if fallback_deployments:
|
||||||
|
all_deployments = [deployment_id] + fallback_deployments
|
||||||
|
else:
|
||||||
|
all_deployments = [deployment_id] * len(all_models)
|
||||||
|
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
|
||||||
|
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
|
||||||
try:
|
try:
|
||||||
|
get_settings().set("openai.deployment_id", deployment_id)
|
||||||
return await f(model)
|
return await f(model)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
|
logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
|
||||||
|
@ -14,6 +14,7 @@ key = "" # Acquire through https://platform.openai.com
|
|||||||
#api_version = '2023-05-15' # Check Azure documentation for the current API version
|
#api_version = '2023-05-15' # Check Azure documentation for the current API version
|
||||||
#api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
|
#api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
|
||||||
#deployment_id = "" # The deployment name you chose when you deployed the engine
|
#deployment_id = "" # The deployment name you chose when you deployed the engine
|
||||||
|
#fallback_deployments = [] # Deployment IDs for the fallback models in configuration.toml: one per fallback model, listed in the same order
|
||||||
|
|
||||||
[anthropic]
|
[anthropic]
|
||||||
key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/
|
key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/
|
||||||
|
Reference in New Issue
Block a user