diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index fe629993..1f6af30b 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -203,8 +203,10 @@ async def retry_with_fallback_models(f: Callable): if not isinstance(fallback_models, list): fallback_models = [fallback_models] all_models = [model] + fallback_models - for model in all_models: + for i, model in enumerate(all_models): try: return await f(model) except Exception as e: logging.warning(f"Failed to generate prediction with {model}: {e}") + if i == len(all_models) - 1: # If it's the last iteration + raise # Re-raise the last exception