Merge branch 'main' into of/fix-improve-notes

This commit is contained in:
ofir-frd
2025-02-21 10:38:21 +02:00
3 changed files with 254 additions and 20 deletions

@@ -5,6 +5,7 @@ import itertools
import re
import time
import traceback
import json
from datetime import datetime
from typing import Optional, Tuple
from urllib.parse import urlparse
@@ -887,6 +888,84 @@ class GithubProvider(GitProvider):
        except:
            return ""

    def fetch_sub_issues(self, issue_url):
        """
        Fetch sub-issues linked to the given GitHub issue URL using GraphQL via PyGitHub.
        """
        sub_issues = set()

        # Extract owner, repo, and issue number from the URL
        parts = issue_url.rstrip("/").split("/")
        owner, repo, issue_number = parts[-4], parts[-3], parts[-1]

        try:
            # Resolve the issue's GraphQL node ID from its number
            query = f"""
            query {{
                repository(owner: "{owner}", name: "{repo}") {{
                    issue(number: {issue_number}) {{
                        id
                    }}
                }}
            }}
            """
            response_tuple = self.github_client._Github__requester.requestJson("POST", "/graphql",
                                                                               input={"query": query})

            # The requester returns a (status, headers, body) tuple; parse the JSON body
            if isinstance(response_tuple, tuple) and len(response_tuple) == 3:
                response_json = json.loads(response_tuple[2])
            else:
                get_logger().error(f"Unexpected response format: {response_tuple}")
                return sub_issues

            issue_id = response_json.get("data", {}).get("repository", {}).get("issue", {}).get("id")
            if not issue_id:
                get_logger().warning(f"Issue ID not found for {issue_url}")
                return sub_issues

            # Fetch the sub-issues attached to that node
            sub_issues_query = f"""
            query {{
                node(id: "{issue_id}") {{
                    ... on Issue {{
                        subIssues(first: 10) {{
                            nodes {{
                                url
                            }}
                        }}
                    }}
                }}
            }}
            """
            sub_issues_response_tuple = self.github_client._Github__requester.requestJson(
                "POST", "/graphql", input={"query": sub_issues_query})

            # Again, parse the JSON body out of the (status, headers, body) tuple
            if isinstance(sub_issues_response_tuple, tuple) and len(sub_issues_response_tuple) == 3:
                sub_issues_response_json = json.loads(sub_issues_response_tuple[2])
            else:
                get_logger().error("Unexpected sub-issues response format",
                                   artifact={"response": sub_issues_response_tuple})
                return sub_issues

            if not sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues"):
                get_logger().error("Invalid sub-issues response structure")
                return sub_issues

            nodes = sub_issues_response_json.get("data", {}).get("node", {}).get("subIssues", {}).get("nodes", [])
            get_logger().info(f"GitHub sub-issues fetched: {len(nodes)}", artifact={"nodes": nodes})

            for sub_issue in nodes:
                if "url" in sub_issue:
                    sub_issues.add(sub_issue["url"])
        except Exception as e:
            get_logger().exception(f"Failed to fetch sub-issues. Error: {e}")

        return sub_issues

    def auto_approve(self) -> bool:
        try:
            res = self.pr.create_review(event="APPROVE")
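
For orientation, a minimal usage sketch of the new method (the provider construction and issue URL are illustrative; the response shapes in the comments follow GitHub's GraphQL schema for the node-ID and subIssues queries):

    provider = GithubProvider()  # hypothetical setup; assumes GitHub credentials are already configured
    sub_issue_urls = provider.fetch_sub_issues("https://github.com/owner/repo/issues/1")
    # The first GraphQL call returns e.g. {"data": {"repository": {"issue": {"id": "I_kwDO..."}}}}
    # The second returns e.g. {"data": {"node": {"subIssues": {"nodes": [{"url": "..."}]}}}}
    for url in sub_issue_urls:  # a set of sub-issue URLs; empty on any error
        print(url)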

@@ -70,41 +70,65 @@ async def extract_tickets(git_provider):
        user_description = git_provider.get_user_description()
        tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo, git_provider.base_url_html)
        tickets_content = []

        if tickets:
            for ticket in tickets:
                repo_name, original_issue_number = git_provider._parse_issue_url(ticket)
                try:
                    issue_main = git_provider.repo_obj.get_issue(original_issue_number)
                except Exception as e:
                    get_logger().error(f"Error getting main issue: {e}",
                                       artifact={"traceback": traceback.format_exc()})
                    continue

                issue_body_str = issue_main.body or ""
                if len(issue_body_str) > MAX_TICKET_CHARACTERS:
                    issue_body_str = issue_body_str[:MAX_TICKET_CHARACTERS] + "..."

                # Extract sub-issues
                sub_issues_content = []
                try:
                    sub_issues = git_provider.fetch_sub_issues(ticket)
                    for sub_issue_url in sub_issues:
                        try:
                            sub_repo, sub_issue_number = git_provider._parse_issue_url(sub_issue_url)
                            sub_issue = git_provider.repo_obj.get_issue(sub_issue_number)

                            sub_body = sub_issue.body or ""
                            if len(sub_body) > MAX_TICKET_CHARACTERS:
                                sub_body = sub_body[:MAX_TICKET_CHARACTERS] + "..."

                            sub_issues_content.append({
                                'ticket_url': sub_issue_url,
                                'title': sub_issue.title,
                                'body': sub_body
                            })
                        except Exception as e:
                            get_logger().warning(f"Failed to fetch sub-issue content for {sub_issue_url}: {e}")
                except Exception as e:
                    get_logger().warning(f"Failed to fetch sub-issues for {ticket}: {e}")

                # Extract labels
                labels = []
                try:
                    for label in issue_main.labels:
                        labels.append(label.name if hasattr(label, 'name') else label)
                except Exception as e:
                    get_logger().error(f"Error extracting labels: {e}",
                                       artifact={"traceback": traceback.format_exc()})

                tickets_content.append({
                    'ticket_id': issue_main.number,
                    'ticket_url': ticket,
                    'title': issue_main.title,
                    'body': issue_body_str,
                    'labels': ", ".join(labels),
                    'sub_issues': sub_issues_content  # Store sub-issues content
                })
        return tickets_content
    except Exception as e:
@@ -115,14 +139,27 @@ async def extract_tickets(git_provider):
async def extract_and_cache_pr_tickets(git_provider, vars):
    if not get_settings().get('pr_reviewer.require_ticket_analysis_review', False):
        return

    related_tickets = get_settings().get('related_tickets', [])
    if not related_tickets:
        tickets_content = await extract_tickets(git_provider)
        if tickets_content:
            # Store sub-issues along with main issues
            for ticket in tickets_content:
                if "sub_issues" in ticket and ticket["sub_issues"]:
                    for sub_issue in ticket["sub_issues"]:
                        related_tickets.append(sub_issue)  # Add sub-issue content
                related_tickets.append(ticket)

            get_logger().info("Extracted tickets and sub-issues from PR description",
                              artifact={"tickets": related_tickets})
            vars['related_tickets'] = related_tickets
            get_settings().set('related_tickets', related_tickets)
    else:  # tickets are already cached
        get_logger().info("Using cached tickets", artifact={"tickets": related_tickets})
        vars['related_tickets'] = related_tickets
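
After caching, the related_tickets structure consumed downstream looks roughly like this (values illustrative; note that sub-issue entries are appended as flat siblings before their parent ticket and carry only 'ticket_url', 'title', and 'body'):

    related_tickets = [
        {   # sub-issue entry, flattened into the list
            'ticket_url': 'https://github.com/owner/repo/issues/2',
            'title': 'Sub-task title',
            'body': 'Sub-task body...'
        },
        {   # main ticket entry
            'ticket_id': 1,
            'ticket_url': 'https://github.com/owner/repo/issues/1',
            'title': 'Parent issue title',
            'body': 'Parent issue body...',
            'labels': 'bug, high priority',
            'sub_issues': [...]  # the same sub-issue dicts, also kept nested here
        }
    ]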

@@ -0,0 +1,118 @@
import unittest
from unittest.mock import AsyncMock, patch

from pr_agent.tools.ticket_pr_compliance_check import extract_tickets, extract_and_cache_pr_tickets
from pr_agent.git_providers.github_provider import GithubProvider


# IsolatedAsyncioTestCase is required so the async test methods are actually awaited
class TestTicketCompliance(unittest.IsolatedAsyncioTestCase):

    @patch.object(GithubProvider, 'get_user_description', return_value="Fixes #1 and relates to #2")
    @patch.object(GithubProvider, '_parse_issue_url', side_effect=lambda url: ("WonOfAKind/KimchiBot", int(url.split('#')[-1])))
    @patch.object(GithubProvider, 'repo_obj')
    async def test_extract_tickets(self, mock_repo, mock_parse_issue_url, mock_user_desc):
        """
        Test extract_tickets() to ensure it extracts tickets correctly
        and fetches their content.
        """
        github_provider = GithubProvider()
        github_provider.repo = "WonOfAKind/KimchiBot"
        github_provider.base_url_html = "https://github.com"

        # Mock issue retrieval
        mock_issue = AsyncMock()
        mock_issue.number = 1
        mock_issue.title = "Sample Issue"
        mock_issue.body = "This is a test issue body."
        mock_issue.labels = ["bug", "high priority"]

        # Mock repo object
        mock_repo.get_issue.return_value = mock_issue

        tickets = await extract_tickets(github_provider)

        # Verify tickets were extracted correctly
        self.assertIsInstance(tickets, list)
        self.assertGreater(len(tickets), 0, "Expected at least one ticket!")

        # Verify ticket structure
        first_ticket = tickets[0]
        self.assertIn("ticket_id", first_ticket)
        self.assertIn("ticket_url", first_ticket)
        self.assertIn("title", first_ticket)
        self.assertIn("body", first_ticket)
        self.assertIn("labels", first_ticket)

        print("\n Test Passed: extract_tickets() successfully retrieved ticket info!")

    @patch.object(GithubProvider, 'get_user_description', return_value="Fixes #1 and relates to #2")
    @patch.object(GithubProvider, '_parse_issue_url', side_effect=lambda url: ("WonOfAKind/KimchiBot", int(url.split('#')[-1])))
    @patch.object(GithubProvider, 'repo_obj')
    async def test_extract_and_cache_pr_tickets(self, mock_repo, mock_parse_issue_url, mock_user_desc):
        """
        Test extract_and_cache_pr_tickets() to ensure tickets are extracted and cached correctly.
        """
        github_provider = GithubProvider()
        github_provider.repo = "WonOfAKind/KimchiBot"
        github_provider.base_url_html = "https://github.com"
        vars = {}  # Simulate the dictionary to store results

        # Mock issue retrieval
        mock_issue = AsyncMock()
        mock_issue.number = 1
        mock_issue.title = "Sample Issue"
        mock_issue.body = "This is a test issue body."
        mock_issue.labels = ["bug", "high priority"]

        # Mock repo object
        mock_repo.get_issue.return_value = mock_issue

        # Run function
        await extract_and_cache_pr_tickets(github_provider, vars)

        # Ensure tickets are cached
        self.assertIn("related_tickets", vars)
        self.assertIsInstance(vars["related_tickets"], list)
        self.assertGreater(len(vars["related_tickets"]), 0, "Expected at least one cached ticket!")

        print("\n Test Passed: extract_and_cache_pr_tickets() successfully cached ticket data!")

    def test_fetch_sub_issues(self):
        """
        Test fetch_sub_issues() to ensure sub-issues are correctly retrieved.
        """
        github_provider = GithubProvider()
        issue_url = "https://github.com/WonOfAKind/KimchiBot/issues/1"  # Known issue with sub-issues
        result = github_provider.fetch_sub_issues(issue_url)
        print("Fetched sub-issues:", result)
        self.assertIsInstance(result, set)  # Ensure result is a set
        self.assertGreater(len(result), 0, "Expected at least one sub-issue but found none!")
        print("\n Test Passed: fetch_sub_issues() retrieved sub-issues correctly!")

    def test_fetch_sub_issues_with_no_results(self):
        """
        Test fetch_sub_issues() to ensure an empty set is returned for an issue with no sub-issues.
        """
        github_provider = GithubProvider()
        issue_url = "https://github.com/qodo-ai/pr-agent/issues/1499"  # Issue without sub-issues
        result = github_provider.fetch_sub_issues(issue_url)
        print("Fetched sub-issues for issue without sub-issues:", result)
        self.assertIsInstance(result, set)  # Ensure result is a set
        self.assertEqual(len(result), 0, "Expected no sub-issues but some were found!")
        print("\n Test Passed: fetch_sub_issues_with_no_results() correctly returned an empty set!")


if __name__ == "__main__":
    unittest.main()
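
The suite can then be run with the standard unittest runner (file path hypothetical; note that the two fetch_sub_issues tests hit the live GitHub API, so they need network access and valid credentials):

    python -m unittest tests/unittest/test_ticket_compliance.py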