diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 20933d51..45ef40b2 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging from typing import Tuple, Union, Callable, List +from github import RateLimitExceededException + from pr_agent.algo import MAX_TOKENS from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.language_handler import sort_files_by_main_languages @@ -19,7 +21,6 @@ OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000 OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600 PATCH_EXTRA_LINES = 3 - def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str, add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str: """ @@ -40,7 +41,11 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s global PATCH_EXTRA_LINES PATCH_EXTRA_LINES = 0 - diff_files = list(git_provider.get_diff_files()) + try: + diff_files = list(git_provider.get_diff_files()) + except RateLimitExceededException as e: + logging.error(f"Rate limit exceeded for git provider API. original message {e}") + raise # get pr languages pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files) diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index 3f7c1ef2..677c2eb1 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -136,3 +136,4 @@ class IncrementalPR: self.commits_range = None self.first_new_commit_sha = None self.last_seen_commit_sha = None + diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 10a50412..7f617937 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -3,13 +3,15 @@ from datetime import datetime from typing import Optional, Tuple from urllib.parse import urlparse -from github import AppAuthentication, Github, Auth +from github import AppAuthentication, Auth, Github, GithubException +from retry import retry from pr_agent.config_loader import settings -from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from ..algo.language_handler import is_valid_file from ..algo.utils import load_large_diff +from .git_provider import FilePatchInfo, GitProvider, IncrementalPR +from ..servers.utils import RateLimitExceeded class GithubProvider(GitProvider): @@ -78,27 +80,34 @@ class GithubProvider(GitProvider): return self.file_set.values() return self.pr.get_files() + @retry(exceptions=RateLimitExceeded, + tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) def get_diff_files(self) -> list[FilePatchInfo]: - files = self.get_files() - diff_files = [] - for file in files: - if is_valid_file(file.filename): - new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) - patch = file.patch - if self.incremental.is_incremental and self.file_set: - original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha) - patch = load_large_diff(file, - new_file_content_str, - original_file_content_str, - None) - self.file_set[file.filename] = patch - else: - original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) + try: + files = self.get_files() + diff_files = [] + for file in files: + if is_valid_file(file.filename): + new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) + patch = file.patch + if self.incremental.is_incremental and self.file_set: + original_file_content_str = self._get_pr_file_content(file, + self.incremental.last_seen_commit_sha) + patch = load_large_diff(file, + new_file_content_str, + original_file_content_str, + None) + self.file_set[file.filename] = patch + else: + original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) - diff_files.append( - FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename)) - self.diff_files = diff_files - return diff_files + diff_files.append( + FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename)) + self.diff_files = diff_files + return diff_files + except GithubException.RateLimitExceededException as e: + logging.error(f"Rate limit exceeded for GitHub API. Original message: {e}") + raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e def publish_description(self, pr_title: str, pr_body: str): self.pr.edit(title=pr_title, body=pr_body) diff --git a/pr_agent/servers/utils.py b/pr_agent/servers/utils.py index 942ac449..c24b880c 100644 --- a/pr_agent/servers/utils.py +++ b/pr_agent/servers/utils.py @@ -21,3 +21,7 @@ def verify_signature(payload_body, secret_token, signature_header): if not hmac.compare_digest(expected_signature, signature_header): raise HTTPException(status_code=403, detail="Request signatures didn't match!") + +class RateLimitExceeded(Exception): + """Raised when the git provider API rate limit has been exceeded.""" + pass diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index f951c648..58f4ba32 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -30,6 +30,7 @@ push_changelog_changes=false [github] # The type of deployment to create. Valid values are 'app' or 'user'. deployment_type = "user" +ratelimit_retries = 5 [gitlab] # URL to the gitlab service