diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index db311dac..adab9506 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -208,24 +208,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo async def retry_with_fallback_models(f: Callable): - # getting all models - model = get_settings().config.model - fallback_models = get_settings().config.fallback_models - if not isinstance(fallback_models, list): - fallback_models = [m.strip() for m in fallback_models.split(",")] - all_models = [model] + fallback_models - # getting all deployments - deployment_id = get_settings().get("openai.deployment_id", None) - fallback_deployments = get_settings().get("openai.fallback_deployments", []) - if not isinstance(fallback_deployments, list) and fallback_deployments: - fallback_deployments = [d.strip() for d in fallback_deployments.split(",")] - if fallback_deployments: - all_deployments = [deployment_id] + fallback_deployments - if len(fallback_deployments) < len(fallback_models): - raise ValueError(f"The number of fallback deployments ({len(all_deployments)}) " - f"is less than the number of fallback models ({len(all_models)})") - else: - all_deployments = [deployment_id] * len(all_models) + all_models = _get_all_models() + all_deployments = _get_all_deployments(all_models) # try each (model, deployment_id) pair until one is successful, otherwise raise exception for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)): try: @@ -241,6 +225,30 @@ async def retry_with_fallback_models(f: Callable): raise # Re-raise the last exception +def _get_all_models() -> List[str]: + model = get_settings().config.model + fallback_models = get_settings().config.fallback_models + if not isinstance(fallback_models, list): + fallback_models = [m.strip() for m in fallback_models.split(",")] + all_models = [model] + fallback_models + return all_models + + +def _get_all_deployments(all_models: List[str]) -> List[str]: + deployment_id = get_settings().get("openai.deployment_id", None) + fallback_deployments = get_settings().get("openai.fallback_deployments", []) + if not isinstance(fallback_deployments, list) and fallback_deployments: + fallback_deployments = [d.strip() for d in fallback_deployments.split(",")] + if fallback_deployments: + all_deployments = [deployment_id] + fallback_deployments + if len(all_deployments) < len(all_models): + raise ValueError(f"The number of deployments ({len(all_deployments)}) " + f"is less than the number of models ({len(all_models)})") + else: + all_deployments = [deployment_id] * len(all_models) + return all_deployments + + def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], relevant_file: str, relevant_line_in_file: str) -> Tuple[int, int]: