Compare commits

..

1 Commits

6 changed files with 25 additions and 46 deletions

View File

@ -3,8 +3,6 @@ from __future__ import annotations
import logging import logging
from typing import Tuple, Union, Callable, List from typing import Tuple, Union, Callable, List
from github import RateLimitExceededException
from pr_agent.algo import MAX_TOKENS from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
@ -21,6 +19,7 @@ OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000
OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600 OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
PATCH_EXTRA_LINES = 3 PATCH_EXTRA_LINES = 3
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str, def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str: add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
""" """
@ -41,11 +40,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
global PATCH_EXTRA_LINES global PATCH_EXTRA_LINES
PATCH_EXTRA_LINES = 0 PATCH_EXTRA_LINES = 0
try: diff_files = list(git_provider.get_diff_files())
diff_files = list(git_provider.get_diff_files())
except RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for git provider API. original message {e}")
raise
# get pr languages # get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files) pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)

View File

@ -136,4 +136,3 @@ class IncrementalPR:
self.commits_range = None self.commits_range = None
self.first_new_commit_sha = None self.first_new_commit_sha = None
self.last_seen_commit_sha = None self.last_seen_commit_sha = None

View File

@ -3,15 +3,13 @@ from datetime import datetime
from typing import Optional, Tuple from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from github import AppAuthentication, Auth, Github, GithubException from github import AppAuthentication, Github, Auth
from retry import retry
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..algo.language_handler import is_valid_file from ..algo.language_handler import is_valid_file
from ..algo.utils import load_large_diff from ..algo.utils import load_large_diff
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR
from ..servers.utils import RateLimitExceeded
class GithubProvider(GitProvider): class GithubProvider(GitProvider):
@ -80,34 +78,27 @@ class GithubProvider(GitProvider):
return self.file_set.values() return self.file_set.values()
return self.pr.get_files() return self.pr.get_files()
@retry(exceptions=RateLimitExceeded,
tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3))
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
try: files = self.get_files()
files = self.get_files() diff_files = []
diff_files = [] for file in files:
for file in files: if is_valid_file(file.filename):
if is_valid_file(file.filename): new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) patch = file.patch
patch = file.patch if self.incremental.is_incremental and self.file_set:
if self.incremental.is_incremental and self.file_set: original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha)
original_file_content_str = self._get_pr_file_content(file, patch = load_large_diff(file,
self.incremental.last_seen_commit_sha) new_file_content_str,
patch = load_large_diff(file, original_file_content_str,
new_file_content_str, None)
original_file_content_str, self.file_set[file.filename] = patch
None) else:
self.file_set[file.filename] = patch original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
else:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
diff_files.append( diff_files.append(
FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename)) FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename))
self.diff_files = diff_files self.diff_files = diff_files
return diff_files return diff_files
except GithubException.RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for GitHub API. Original message: {e}")
raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
def publish_description(self, pr_title: str, pr_body: str): def publish_description(self, pr_title: str, pr_body: str):
self.pr.edit(title=pr_title, body=pr_body) self.pr.edit(title=pr_title, body=pr_body)

View File

@ -14,7 +14,6 @@ from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
class GitLabProvider(GitProvider): class GitLabProvider(GitProvider):
def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False): def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False):
gitlab_url = settings.get("GITLAB.URL", None) gitlab_url = settings.get("GITLAB.URL", None)
if not gitlab_url: if not gitlab_url:
@ -23,8 +22,8 @@ class GitLabProvider(GitProvider):
if not gitlab_access_token: if not gitlab_access_token:
raise ValueError("GitLab personal access token is not set in the config file") raise ValueError("GitLab personal access token is not set in the config file")
self.gl = gitlab.Gitlab( self.gl = gitlab.Gitlab(
gitlab_url, url=gitlab_url,
gitlab_access_token oauth_token=gitlab_access_token
) )
self.id_project = None self.id_project = None
self.id_mr = None self.id_mr = None

View File

@ -21,7 +21,3 @@ def verify_signature(payload_body, secret_token, signature_header):
if not hmac.compare_digest(expected_signature, signature_header): if not hmac.compare_digest(expected_signature, signature_header):
raise HTTPException(status_code=403, detail="Request signatures didn't match!") raise HTTPException(status_code=403, detail="Request signatures didn't match!")
class RateLimitExceeded(Exception):
"""Raised when the git provider API rate limit has been exceeded."""
pass

View File

@ -27,7 +27,6 @@ num_code_suggestions=4
[github] [github]
# The type of deployment to create. Valid values are 'app' or 'user'. # The type of deployment to create. Valid values are 'app' or 'user'.
deployment_type = "user" deployment_type = "user"
ratelimit_retries = 5
[gitlab] [gitlab]
# URL to the gitlab service # URL to the gitlab service