diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 17ea2356..1e96d2f1 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -10,8 +10,7 @@ from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_large_diff from pr_agent.config_loader import settings from pr_agent.git_providers.git_provider import GitProvider -from github import GithubException -from retry import retry + DELETED_FILES_ = "Deleted files:\n" @@ -21,10 +20,6 @@ OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000 OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600 PATCH_EXTRA_LINES = 3 -GITHUB_RETRIES=1 - -@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError, GithubException.RateLimitExceededException), - tries=GITHUB_RETRIES, delay=2, backoff=2, jitter=(1, 3)) def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str, add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str: """ diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 10a50412..64e4ab46 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -3,14 +3,14 @@ from datetime import datetime from typing import Optional, Tuple from urllib.parse import urlparse -from github import AppAuthentication, Github, Auth +from github import AppAuthentication, Github, Auth, GithubException from pr_agent.config_loader import settings from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from ..algo.language_handler import is_valid_file from ..algo.utils import load_large_diff - +from retry import retry class GithubProvider(GitProvider): def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)): @@ -78,6 +78,8 @@ class GithubProvider(GitProvider): return self.file_set.values() return self.pr.get_files() + @retry(exceptions=(GithubException.RateLimitExceededException), + tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) def get_diff_files(self) -> list[FilePatchInfo]: files = self.get_files() diff_files = []