diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index feee84a0..3a6264eb 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -10,6 +10,7 @@ from datetime import datetime from typing import Optional, Tuple from urllib.parse import urlparse +from github.Issue import Issue from github import AppAuthentication, Auth, Github, GithubException from retry import retry from starlette_context import context @@ -42,6 +43,7 @@ class GithubProvider(GitProvider): self.repo = None self.pr_num = None self.pr = None + self.issue_main = None self.github_user_id = None self.diff_files = None self.git_files = None @@ -51,9 +53,29 @@ class GithubProvider(GitProvider): self.pr_commits = list(self.pr.get_commits()) self.last_commit_id = self.pr_commits[-1] self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object - else: + elif pr_url and 'issue' in pr_url: #url is an issue + self.issue_main = self._get_issue_handle(pr_url) + else: #Instantiated the provider without a PR / Issue self.pr_commits = None + def _get_issue_handle(self, issue_url) -> Optional[Issue]: + repo_name, issue_number = self._parse_issue_url(issue_url) + if not repo_name or not issue_number: + get_logger().error(f"Given url: {issue_url} is not a valid issue.") + return None + # else: Check if can get a valid Repo handle: + try: + repo_obj = self.github_client.get_repo(repo_name) + if not repo_obj: + get_logger().error(f"Given url: {issue_url}, belonging to owner/repo: {repo_name} does " + f"not have a valid repository: {self.get_git_repo_url(issue_url)}") + return None + # else: Valid repo handle: + return repo_obj.get_issue(issue_number) + except Exception as e: + get_logger().exception(f"Failed to get an issue object for issue: {issue_url}, belonging to owner/repo: {repo_name}") + return None + def get_incremental_commits(self, incremental=IncrementalPR(False)): self.incremental = incremental if self.incremental.is_incremental: @@ -344,10 +366,19 @@ class GithubProvider(GitProvider): self.publish_persistent_comment_full(pr_comment, initial_header, update_header, name, final_update_message) def publish_comment(self, pr_comment: str, is_temporary: bool = False): + if not self.pr and not self.issue_main: + get_logger().error("Cannot publish a comment if missing PR/Issue context") + return None + if is_temporary and not get_settings().config.publish_output_progress: get_logger().debug(f"Skipping publish_comment for temporary comment: {pr_comment}") return None pr_comment = self.limit_output_characters(pr_comment, self.max_comment_chars) + + # In case this is an issue, can publish the comment on the issue. + if self.issue_main: + return self.issue_main.create_comment(pr_comment) + response = self.pr.create_issue_comment(pr_comment) if hasattr(response, "user") and hasattr(response.user, "login"): self.github_user_id = response.user.login @@ -731,11 +762,11 @@ class GithubProvider(GitProvider): def _parse_issue_url(self, issue_url: str) -> Tuple[str, int]: parsed_url = urlparse(issue_url) - if 'github.com' not in parsed_url.netloc: - raise ValueError("The provided URL is not a valid GitHub URL") + if parsed_url.path.startswith('/api/v3'): #Check if came from github app + parsed_url = urlparse(issue_url.replace("/api/v3", "")) path_parts = parsed_url.path.strip('/').split('/') - if 'api.github.com' in parsed_url.netloc: + if 'api.github.com' in parsed_url.netloc or '/api/v3' in issue_url: #Check if came from github app if len(path_parts) < 5 or path_parts[3] != 'issues': raise ValueError("The provided URL does not appear to be a GitHub ISSUE URL") repo_name = '/'.join(path_parts[1:3]) diff --git a/pr_agent/tools/pr_help_docs.py b/pr_agent/tools/pr_help_docs.py index 5695065a..ddd42509 100644 --- a/pr_agent/tools/pr_help_docs.py +++ b/pr_agent/tools/pr_help_docs.py @@ -101,6 +101,7 @@ def aggregate_documentation_files_for_prompt_contents(base_path: str, doc_files: def format_markdown_q_and_a_response(question_str: str, response_str: str, relevant_sections: List[Dict[str, str]], supported_suffixes: List[str], base_url_prefix: str, base_url_suffix: str="") -> str: + base_url_prefix = base_url_prefix.strip('/') #Sanitize base_url_prefix answer_str = "" answer_str += f"### Question: \n{question_str}\n\n" answer_str += f"### Answer:\n{response_str.strip()}\n\n" @@ -114,9 +115,9 @@ def format_markdown_q_and_a_response(question_str: str, response_str: str, relev if str(section['relevant_section_header_string']).strip(): markdown_header = format_markdown_header(section['relevant_section_header_string']) if base_url_prefix: - answer_str += f"> - {base_url_prefix}{file}{base_url_suffix}#{markdown_header}\n" + answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}#{markdown_header}\n" else: - answer_str += f"> - {base_url_prefix}{file}{base_url_suffix}\n" + answer_str += f"> - {base_url_prefix}/{file}{base_url_suffix}\n" return answer_str def format_markdown_header(header: str) -> str: