diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 1fd474e0..58b72f17 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -5,6 +5,7 @@ import itertools import re import time import traceback +import json from datetime import datetime from typing import Optional, Tuple from urllib.parse import urlparse @@ -887,6 +888,84 @@ class GithubProvider(GitProvider): except: return "" + def fetch_sub_issues(self, issue_url): + """ + Fetch sub-issues linked to the given GitHub issue URL using GraphQL via PyGitHub. + """ + sub_issues = set() + + # Extract owner, repo, and issue number from URL + parts = issue_url.rstrip("/").split("/") + owner, repo, issue_number = parts[-4], parts[-3], parts[-1] + + try: + # Gets Issue ID from Issue Number + query = f""" + query {{ + repository(owner: "{owner}", name: "{repo}") {{ + issue(number: {issue_number}) {{ + id + }} + }} + }} + """ + response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql", + input={"query": query}) + + # Extract the JSON response from the tuple and parses it + if isinstance(response_tuple, tuple) and len(response_tuple) == 3: + response_json = json.loads(response_tuple[2]) + else: + get_logger().error(f"Unexpected response format: {response_tuple}") + return sub_issues + + + issue_id = response_json.get("data", {}).get("repository", {}).get("issue", {}).get("id") + + if not issue_id: + get_logger().warning(f"Issue ID not found for {issue_url}") + return sub_issues + + # Fetch Sub-Issues + sub_issues_query = f""" + query {{ + node(id: "{issue_id}") {{ + ... 
on Issue {{ + subIssues(first: 10) {{ + nodes {{ + url + }} + }} + }} + }} + }} + """ + sub_issues_response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql", input={ + "query": sub_issues_query}) + + # Extract the JSON response from the tuple and parses it + if isinstance(sub_issues_response_tuple, tuple) and len(sub_issues_response_tuple) == 3: + sub_issues_response_json = json.loads(sub_issues_response_tuple[2]) + else: + get_logger().error("Unexpected sub-issues response format", artifact={"response": sub_issues_response_tuple}) + return sub_issues + + if not sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues"): + get_logger().error("Invalid sub-issues response structure") + return sub_issues + + nodes = sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues", {}).get("nodes", []) + get_logger().info(f"Github Sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes}) + + for sub_issue in nodes: + if "url" in sub_issue: + sub_issues.add(sub_issue["url"]) + + except Exception as e: + get_logger().exception(f"Failed to fetch sub-issues. 
Error: {e}") + + return sub_issues + def auto_approve(self) -> bool: try: res = self.pr.create_review(event="APPROVE") diff --git a/pr_agent/tools/ticket_pr_compliance_check.py b/pr_agent/tools/ticket_pr_compliance_check.py index 54c72eb9..45baa0d2 100644 --- a/pr_agent/tools/ticket_pr_compliance_check.py +++ b/pr_agent/tools/ticket_pr_compliance_check.py @@ -70,41 +70,65 @@ async def extract_tickets(git_provider): user_description = git_provider.get_user_description() tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo, git_provider.base_url_html) tickets_content = [] + if tickets: + for ticket in tickets: - # extract ticket number and repo name repo_name, original_issue_number = git_provider._parse_issue_url(ticket) - # get the ticket object try: issue_main = git_provider.repo_obj.get_issue(original_issue_number) except Exception as e: - get_logger().error(f"Error getting issue_main error= {e}", + get_logger().error(f"Error getting main issue: {e}", artifact={"traceback": traceback.format_exc()}) continue - # clip issue_main.body max length - issue_body_str = issue_main.body - if not issue_body_str: - issue_body_str = "" + issue_body_str = issue_main.body or "" if len(issue_body_str) > MAX_TICKET_CHARACTERS: issue_body_str = issue_body_str[:MAX_TICKET_CHARACTERS] + "..." - # extract labels + # Extract sub-issues + sub_issues_content = [] + try: + sub_issues = git_provider.fetch_sub_issues(ticket) + for sub_issue_url in sub_issues: + try: + sub_repo, sub_issue_number = git_provider._parse_issue_url(sub_issue_url) + sub_issue = git_provider.repo_obj.get_issue(sub_issue_number) + + sub_body = sub_issue.body or "" + if len(sub_body) > MAX_TICKET_CHARACTERS: + sub_body = sub_body[:MAX_TICKET_CHARACTERS] + "..." 
+ + sub_issues_content.append({ + 'ticket_url': sub_issue_url, + 'title': sub_issue.title, + 'body': sub_body + }) + except Exception as e: + get_logger().warning(f"Failed to fetch sub-issue content for {sub_issue_url}: {e}") + + except Exception as e: + get_logger().warning(f"Failed to fetch sub-issues for {ticket}: {e}") + + # Extract labels labels = [] try: for label in issue_main.labels: - if isinstance(label, str): - labels.append(label) - else: - labels.append(label.name) + labels.append(label.name if hasattr(label, 'name') else label) except Exception as e: get_logger().error(f"Error extracting labels error= {e}", artifact={"traceback": traceback.format_exc()}) - tickets_content.append( - {'ticket_id': issue_main.number, - 'ticket_url': ticket, 'title': issue_main.title, 'body': issue_body_str, - 'labels': ", ".join(labels)}) + + tickets_content.append({ + 'ticket_id': issue_main.number, + 'ticket_url': ticket, + 'title': issue_main.title, + 'body': issue_body_str, + 'labels': ", ".join(labels), + 'sub_issues': sub_issues_content # Store sub-issues content + }) + return tickets_content except Exception as e: @@ -115,14 +139,27 @@ async def extract_tickets(git_provider): async def extract_and_cache_pr_tickets(git_provider, vars): if not get_settings().get('pr_reviewer.require_ticket_analysis_review', False): return + related_tickets = get_settings().get('related_tickets', []) + if not related_tickets: tickets_content = await extract_tickets(git_provider) + if tickets_content: - get_logger().info("Extracted tickets from PR description", artifact={"tickets": tickets_content}) - vars['related_tickets'] = tickets_content - get_settings().set('related_tickets', tickets_content) - else: # if tickets are already cached + # Store sub-issues along with main issues + for ticket in tickets_content: + if "sub_issues" in ticket and ticket["sub_issues"]: + for sub_issue in ticket["sub_issues"]: + related_tickets.append(sub_issue) # Add sub-issues content + + 
related_tickets.append(ticket)
+
+        get_logger().info("Extracted tickets and sub-issues from PR description",
+                          artifact={"tickets": related_tickets})
+
+        vars['related_tickets'] = related_tickets
+        get_settings().set('related_tickets', related_tickets)
+    else:
         get_logger().info("Using cached tickets", artifact={"tickets": related_tickets})
         vars['related_tickets'] = related_tickets
diff --git a/tests/unittest/test_fetching_sub_issues.py b/tests/unittest/test_fetching_sub_issues.py
new file mode 100644
index 00000000..48795866
--- /dev/null
+++ b/tests/unittest/test_fetching_sub_issues.py
@@ -0,0 +1,118 @@
+import unittest
+import asyncio
+from unittest.mock import AsyncMock, patch
+from pr_agent.tools.ticket_pr_compliance_check import extract_tickets, extract_and_cache_pr_tickets
+from pr_agent.git_providers.github_provider import GithubProvider
+
+
+class TestTicketCompliance(unittest.IsolatedAsyncioTestCase):
+
+    @patch.object(GithubProvider, 'get_user_description', return_value="Fixes #1 and relates to #2")
+    @patch.object(GithubProvider, '_parse_issue_url', side_effect=lambda url: ("WonOfAKind/KimchiBot", int(url.split('#')[-1])))
+    @patch.object(GithubProvider, 'repo_obj')
+    async def test_extract_tickets(self, mock_repo, mock_parse_issue_url, mock_user_desc):
+        """
+        Test extract_tickets() to ensure it extracts tickets correctly
+        and fetches their content.
+        """
+        github_provider = GithubProvider()
+        github_provider.repo = "WonOfAKind/KimchiBot"
+        github_provider.base_url_html = "https://github.com"
+
+        # Mock issue retrieval
+        mock_issue = AsyncMock()
+        mock_issue.number = 1
+        mock_issue.title = "Sample Issue"
+        mock_issue.body = "This is a test issue body."
+ mock_issue.labels = ["bug", "high priority"] + + # Mock repo object + mock_repo.get_issue.return_value = mock_issue + + tickets = await extract_tickets(github_provider) + + # Verify tickets were extracted correctly + self.assertIsInstance(tickets, list) + self.assertGreater(len(tickets), 0, "Expected at least one ticket!") + + # Verify ticket structure + first_ticket = tickets[0] + self.assertIn("ticket_id", first_ticket) + self.assertIn("ticket_url", first_ticket) + self.assertIn("title", first_ticket) + self.assertIn("body", first_ticket) + self.assertIn("labels", first_ticket) + + print("\n Test Passed: extract_tickets() successfully retrieved ticket info!") + + @patch.object(GithubProvider, 'get_user_description', return_value="Fixes #1 and relates to #2") + @patch.object(GithubProvider, '_parse_issue_url', side_effect=lambda url: ("WonOfAKind/KimchiBot", int(url.split('#')[-1]))) + @patch.object(GithubProvider, 'repo_obj') + async def test_extract_and_cache_pr_tickets(self, mock_repo, mock_parse_issue_url, mock_user_desc): + """ + Test extract_and_cache_pr_tickets() to ensure tickets are extracted and cached correctly. + """ + github_provider = GithubProvider() + github_provider.repo = "WonOfAKind/KimchiBot" + github_provider.base_url_html = "https://github.com" + + vars = {} # Simulate the dictionary to store results + + # Mock issue retrieval + mock_issue = AsyncMock() + mock_issue.number = 1 + mock_issue.title = "Sample Issue" + mock_issue.body = "This is a test issue body." 
+        mock_issue.labels = ["bug", "high priority"]
+
+        # Mock repo object
+        mock_repo.get_issue.return_value = mock_issue
+
+        # Run function
+        await extract_and_cache_pr_tickets(github_provider, vars)
+
+        # Ensure tickets are cached
+        self.assertIn("related_tickets", vars)
+        self.assertIsInstance(vars["related_tickets"], list)
+        self.assertGreater(len(vars["related_tickets"]), 0, "Expected at least one cached ticket!")
+
+        print("\n Test Passed: extract_and_cache_pr_tickets() successfully cached ticket data!")
+
+    def test_fetch_sub_issues(self):
+        """
+        Test fetch_sub_issues() to ensure sub-issues are correctly retrieved.
+        """
+        github_provider = GithubProvider()
+        issue_url = "https://github.com/WonOfAKind/KimchiBot/issues/1"  # Known issue with sub-issues
+        result = github_provider.fetch_sub_issues(issue_url)
+
+        print("Fetched sub-issues:", result)
+
+        self.assertIsInstance(result, set)  # Ensure result is a set
+        self.assertGreater(len(result), 0, "Expected at least one sub-issue but found none!")
+
+        print("\n Test Passed: fetch_sub_issues() retrieved sub-issues correctly!")
+
+    def test_fetch_sub_issues_with_no_results(self):
+        """
+        Test fetch_sub_issues() to ensure an empty set is returned for an issue with no sub-issues.
+        """
+        github_provider = GithubProvider()
+        issue_url = "https://github.com/qodo-ai/pr-agent/issues/1499"  # Likely non-existent issue
+        result = github_provider.fetch_sub_issues(issue_url)
+
+        print("Fetched sub-issues for non-existent issue:", result)
+
+        self.assertIsInstance(result, set)  # Ensure result is a set
+        self.assertEqual(len(result), 0, "Expected no sub-issues but some were found!")
+
+        print("\n Test Passed: fetch_sub_issues_with_no_results() correctly returned an empty set!")
+
+
+if __name__ == "__main__":
+    unittest.main()