feat: enhance GitHubProvider with improved error handling and URL parsing

- Add traceback logging for exceptions in diff file retrieval
- Improve URL parsing to handle '/api/v3' paths and validate GitHub URLs
- Modify `publish_comment` to return None for temporary comments
- Update constructor to accept an optional GitHub client parameter
This commit is contained in:
mrT23
2024-10-14 09:18:06 +03:00
parent 49f8d86c77
commit f6ba49819a

View File

@ -1,36 +1,35 @@
import itertools
import time
import hashlib
import traceback
from datetime import datetime
from typing import Optional, Tuple
from urllib.parse import urlparse
from github import AppAuthentication, Auth, Github, GithubException
from github import AppAuthentication, Auth, Github
from retry import retry
from starlette_context import context
from ..algo.file_filter import filter_ignored
from ..algo.language_handler import is_valid_file
from ..algo.utils import PRReviewHeader, load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file, Range
from ..algo.types import EDIT_TYPE
from ..algo.utils import PRReviewHeader, load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file
from ..config_loader import get_settings
from ..log import get_logger
from ..servers.utils import RateLimitExceeded
from .git_provider import GitProvider, IncrementalPR, MAX_FILES_ALLOWED_FULL
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR, MAX_FILES_ALLOWED_FULL
class GithubProvider(GitProvider):
def __init__(self, pr_url: Optional[str] = None):
def __init__(self, pr_url: Optional[str] = None, github_client = None):
self.auto = None
self.repo_obj = None
try:
self.installation_id = context.get("installation_id", None)
except Exception:
self.installation_id = None
self.max_comment_chars = 65000
self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")
self.base_url = get_settings().get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com"
self.base_url_html = self.base_url.split("api/")[0].rstrip("/") if "api/" in self.base_url else "https://github.com"
self.base_domain = self.base_url.replace("https://", "").replace("http://", "")
self.base_domain_html = self.base_url_html.replace("https://", "").replace("http://", "")
self.github_client = self._get_github_client()
self.repo = None
self.pr_num = None
@ -233,8 +232,9 @@ class GithubProvider(GitProvider):
return diff_files
except GithubException.RateLimitExceededException as e:
get_logger().error(f"Rate limit exceeded for GitHub API. Original message: {e}")
except Exception as e:
get_logger().error(f"Failing to get diff files: {e}",
artifact={"traceback": traceback.format_exc()})
raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e
def publish_description(self, pr_title: str, pr_body: str):
@ -256,7 +256,7 @@ class GithubProvider(GitProvider):
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
if is_temporary and not get_settings().config.publish_output_progress:
get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}")
return
return None
pr_comment = self.limit_output_characters(pr_comment, self.max_comment_chars)
response = self.pr.create_issue_comment(pr_comment)
if hasattr(response, "user") and hasattr(response.user, "login"):
@ -611,8 +611,11 @@ class GithubProvider(GitProvider):
def _parse_pr_url(self, pr_url: str) -> Tuple[str, int]:
parsed_url = urlparse(pr_url)
if parsed_url.path.startswith('/api/v3'):
parsed_url = urlparse(pr_url.replace("/api/v3", ""))
path_parts = parsed_url.path.strip('/').split('/')
if self.base_domain in parsed_url.netloc:
if 'api.github.com' in parsed_url.netloc or '/api/v3' in pr_url:
if len(path_parts) < 5 or path_parts[3] != 'pulls':
raise ValueError("The provided URL does not appear to be a GitHub PR URL")
repo_name = '/'.join(path_parts[1:3])
@ -635,8 +638,12 @@ class GithubProvider(GitProvider):
def _parse_issue_url(self, issue_url: str) -> Tuple[str, int]:
parsed_url = urlparse(issue_url)
if 'github.com' not in parsed_url.netloc:
raise ValueError("The provided URL is not a valid GitHub URL")
path_parts = parsed_url.path.strip('/').split('/')
if self.base_domain in parsed_url.netloc:
if 'api.github.com' in parsed_url.netloc:
if len(path_parts) < 5 or path_parts[3] != 'issues':
raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL")
repo_name = '/'.join(path_parts[1:3])