mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-01 19:30:40 +08:00
feat: enhance GitHubProvider with improved error handling and URL parsing
- Add traceback logging for exceptions in diff file retrieval - Improve URL parsing to handle '/api/v3' paths and validate GitHub URLs - Modify `publish_comment` to return None for temporary comments - Update constructor to accept an optional GitHub client parameter
This commit is contained in:
@ -1,36 +1,35 @@
|
||||
import itertools
|
||||
import time
|
||||
import hashlib
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from github import AppAuthentication, Auth, Github, GithubException
|
||||
from github import AppAuthentication, Auth, Github
|
||||
from retry import retry
|
||||
from starlette_context import context
|
||||
|
||||
from ..algo.file_filter import filter_ignored
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.utils import PRReviewHeader, load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file, Range
|
||||
from ..algo.types import EDIT_TYPE
|
||||
from ..algo.utils import PRReviewHeader, load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
from ..servers.utils import RateLimitExceeded
|
||||
from .git_provider import GitProvider, IncrementalPR, MAX_FILES_ALLOWED_FULL
|
||||
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR, MAX_FILES_ALLOWED_FULL
|
||||
|
||||
|
||||
class GithubProvider(GitProvider):
|
||||
def __init__(self, pr_url: Optional[str] = None):
|
||||
def __init__(self, pr_url: Optional[str] = None, github_client = None):
|
||||
self.auto = None
|
||||
self.repo_obj = None
|
||||
try:
|
||||
self.installation_id = context.get("installation_id", None)
|
||||
except Exception:
|
||||
self.installation_id = None
|
||||
self.max_comment_chars = 65000
|
||||
self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")
|
||||
self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com"
|
||||
self.base_url_html = self.base_url.split("api/")[0].rstrip("/") if "api/" in self.base_url else "https://github.com"
|
||||
self.base_domain = self.base_url.replace("https://", "").replace("http://", "")
|
||||
self.base_domain_html = self.base_url_html.replace("https://", "").replace("http://", "")
|
||||
self.github_client = self._get_github_client()
|
||||
self.repo = None
|
||||
self.pr_num = None
|
||||
@ -233,8 +232,9 @@ class GithubProvider(GitProvider):
|
||||
|
||||
return diff_files
|
||||
|
||||
except GithubException.RateLimitExceededException as e:
|
||||
get_logger().error(f"Rate limit exceeded for GitHub API. Original message: {e}")
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failing to get diff files: {e}",
|
||||
artifact={"traceback": traceback.format_exc()})
|
||||
raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
|
||||
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
@ -256,7 +256,7 @@ class GithubProvider(GitProvider):
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
if is_temporary and not get_settings().config.publish_output_progress:
|
||||
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
|
||||
return
|
||||
return None
|
||||
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_chars)
|
||||
response = self.pr.create_issue_comment(pr_comment)
|
||||
if hasattr(response, "user") and hasattr(response.user, "login"):
|
||||
@ -611,8 +611,11 @@ class GithubProvider(GitProvider):
|
||||
def _parse_pr_url(self, pr_url: str) -> Tuple[str, int]:
|
||||
parsed_url = urlparse(pr_url)
|
||||
|
||||
if parsed_url.path.startswith('/api/v3'):
|
||||
parsed_url = urlparse(pr_url.replace("/api/v3", ""))
|
||||
|
||||
path_parts = parsed_url.path.strip('/').split('/')
|
||||
if self.base_domain in parsed_url.netloc:
|
||||
if 'api.github.com' in parsed_url.netloc or '/api/v3' in pr_url:
|
||||
if len(path_parts) < 5 or path_parts[3] != 'pulls':
|
||||
raise ValueError("The provided URL does not appear to be a GitHub PR URL")
|
||||
repo_name = '/'.join(path_parts[1:3])
|
||||
@ -635,8 +638,12 @@ class GithubProvider(GitProvider):
|
||||
|
||||
def _parse_issue_url(self, issue_url: str) -> Tuple[str, int]:
|
||||
parsed_url = urlparse(issue_url)
|
||||
|
||||
if 'github.com' not in parsed_url.netloc:
|
||||
raise ValueError("The provided URL is not a valid GitHub URL")
|
||||
|
||||
path_parts = parsed_url.path.strip('/').split('/')
|
||||
if self.base_domain in parsed_url.netloc:
|
||||
if 'api.github.com' in parsed_url.netloc:
|
||||
if len(path_parts) < 5 or path_parts[3] != 'issues':
|
||||
raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL")
|
||||
repo_name = '/'.join(path_parts[1:3])
|
||||
|
Reference in New Issue
Block a user