feat: add ticket compliance check

- Implement ticket compliance check logic in `utils.py` and `ticket_pr_compliance_check.py`
- Add functions to extract and cache PR tickets, and check ticket relevancy
This commit is contained in:
mrT23
2024-10-10 08:48:37 +03:00
parent 014ea884d2
commit 76d95bb6d7
12 changed files with 365 additions and 86 deletions

View File

@ -20,6 +20,8 @@ from pr_agent.git_providers import get_git_provider, GithubProvider, get_git_pro
from pr_agent.git_providers.git_provider import get_main_pr_language
from pr_agent.log import get_logger
from pr_agent.servers.help import HelpMessage
from pr_agent.tools.ticket_pr_compliance_check import extract_ticket_links_from_pr_description, extract_tickets, \
extract_and_cache_pr_tickets
class PRDescription:
@ -38,6 +40,7 @@ class PRDescription:
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.pr_id = self.git_provider.get_pr_id()
self.keys_fix = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"]
if get_settings().pr_description.enable_semantic_files_types and not self.git_provider.is_supported(
"gfm_markdown"):
@ -60,6 +63,7 @@ class PRDescription:
"enable_custom_labels": get_settings().config.enable_custom_labels,
"custom_labels_class": "", # will be filled if necessary in 'set_custom_labels' function
"enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
"related_tickets": "",
}
self.user_description = self.git_provider.get_user_description()
@ -87,6 +91,9 @@ class PRDescription:
if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
self.git_provider.publish_comment("Preparing PR description...", is_temporary=True)
# ticket extraction if exists
await extract_and_cache_pr_tickets(self.git_provider, self.vars)
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
if self.prediction:
@ -226,7 +233,7 @@ class PRDescription:
file_description_str_list = []
for i, result in enumerate(results):
prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
if load_yaml(prediction_files) and prediction_files.startswith('pr_files'):
if load_yaml(prediction_files, keys_fix_yaml=self.keys_fix) and prediction_files.startswith('pr_files'):
prediction_files = prediction_files.removeprefix('pr_files:').strip()
file_description_str_list.append(prediction_files)
else:
@ -304,16 +311,16 @@ extra_file_yaml =
# final processing
self.prediction = prediction_headers + "\n" + "pr_files:\n" + files_walkthrough
if not load_yaml(self.prediction):
if not load_yaml(self.prediction, keys_fix_yaml=self.keys_fix):
get_logger().error(f"Error getting valid YAML in large PR handling for describe {self.pr_id}")
if load_yaml(prediction_headers):
if load_yaml(prediction_headers, keys_fix_yaml=self.keys_fix):
get_logger().debug(f"Using only headers for describe {self.pr_id}")
self.prediction = prediction_headers
async def extend_additional_files(self, remaining_files_list) -> str:
prediction = self.prediction
try:
original_prediction_dict = load_yaml(self.prediction)
original_prediction_dict = load_yaml(self.prediction, keys_fix_yaml=self.keys_fix)
prediction_extra = "pr_files:"
for file in remaining_files_list:
extra_file_yaml = f"""\
@ -327,12 +334,12 @@ extra_file_yaml =
additional files (token-limit)
"""
prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
prediction_extra_dict = load_yaml(prediction_extra)
prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
# merge the two dictionaries
if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
new_yaml = yaml.dump(original_prediction_dict)
if load_yaml(new_yaml):
if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
prediction = new_yaml
return prediction
except Exception as e:
@ -361,7 +368,7 @@ extra_file_yaml =
def _prepare_data(self):
# Load the AI prediction data into a dictionary
self.data = load_yaml(self.prediction.strip())
self.data = load_yaml(self.prediction.strip(), keys_fix_yaml=self.keys_fix)
if get_settings().pr_description.add_original_user_description and self.user_description:
self.data["User Description"] = self.user_description

View File

@ -1,5 +1,6 @@
import copy
import datetime
import traceback
from collections import OrderedDict
from functools import partial
from typing import List, Tuple
@ -15,6 +16,7 @@ from pr_agent.git_providers import get_git_provider, get_git_provider_with_conte
from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language
from pr_agent.log import get_logger
from pr_agent.servers.help import HelpMessage
from pr_agent.tools.ticket_pr_compliance_check import extract_tickets, extract_and_cache_pr_tickets
class PRReviewer:
@ -84,6 +86,7 @@ class PRReviewer:
"custom_labels": "",
"enable_custom_labels": get_settings().config.enable_custom_labels,
"is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
"related_tickets": get_settings().get('related_tickets', []),
}
self.token_handler = TokenHandler(
@ -121,6 +124,9 @@ class PRReviewer:
'config': dict(get_settings().config)}
get_logger().debug("Relevant configs", artifacts=relevant_configs)
# ticket extraction if exists
await extract_and_cache_pr_tickets(self.git_provider, self.vars)
if self.incremental.is_incremental and hasattr(self.git_provider, "unreviewed_files_set") and not self.git_provider.unreviewed_files_set:
get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new files")
previous_review_url = ""
@ -207,7 +213,7 @@ class PRReviewer:
first_key = 'review'
last_key = 'security_concerns'
data = load_yaml(self.prediction.strip(),
keys_fix_yaml=["estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
"relevant_file:", "relevant_line:", "suggestion:"],
first_key=first_key, last_key=last_key)
github_action_output(data, 'review')
@ -282,7 +288,7 @@ class PRReviewer:
first_key = 'review'
last_key = 'security_concerns'
data = load_yaml(self.prediction.strip(),
keys_fix_yaml=["estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
"relevant_file:", "relevant_line:", "suggestion:"],
first_key=first_key, last_key=last_key)
comments: List[str] = []
@ -401,7 +407,16 @@ class PRReviewer:
review_labels = []
if get_settings().pr_reviewer.enable_review_labels_effort:
estimated_effort = data['review']['estimated_effort_to_review_[1-5]']
estimated_effort_number = int(estimated_effort.split(',')[0])
estimated_effort_number = 0
if isinstance(estimated_effort, str):
try:
estimated_effort_number = int(estimated_effort.split(',')[0])
except ValueError:
get_logger().warning(f"Invalid estimated_effort value: {estimated_effort}")
elif isinstance(estimated_effort, int):
estimated_effort_number = estimated_effort
else:
get_logger().warning(f"Unexpected type for estimated_effort: {type(estimated_effort)}")
if 1 <= estimated_effort_number <= 5: # 1, because ...
review_labels.append(f'Review effort [1-5]: {estimated_effort_number}')
if get_settings().pr_reviewer.enable_review_labels_security and get_settings().pr_reviewer.require_security_review:

View File

@ -0,0 +1,113 @@
import re
import traceback
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import GithubProvider
from pr_agent.log import get_logger
def find_jira_tickets(text):
    """Return the unique JIRA ticket keys mentioned in *text*.

    A ticket key is 2-10 uppercase letters, a dash and 1-7 digits
    (e.g. ``PROJ-123``). Keys are recognised both on their own and at the
    end of a ``http(s)://<host>/browse/<KEY>`` URL; in the latter case only
    the key is returned, not the URL.

    Args:
        text: Free text to scan. ``None`` or an empty string yields ``[]``.

    Returns:
        A sorted list of distinct ticket keys, so the result is stable
        between calls and between interpreter runs.
    """
    if not text:
        return []

    patterns = (
        r'\b[A-Z]{2,10}-\d{1,7}\b',  # Standard JIRA ticket format (e.g., PROJ-123)
        r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b',  # JIRA URL or just the ticket
    )

    tickets = set()
    for pattern in patterns:
        # Both patterns have at most one capturing group, so re.findall
        # yields plain strings (never tuples) and the captured group of the
        # second pattern is mandatory, hence never empty.
        tickets.update(re.findall(pattern, text))
    return sorted(tickets)
def extract_ticket_links_from_pr_description(pr_description, repo_path):
    """Extract all GitHub issue links from a PR description.

    Example of a link that is matched:
    https://github.com/Codium-ai/pr-agent-pro/issues/525

    Hosts that merely start with ``github`` are accepted as well, so GitHub
    server installations such as
    https://github.company.ai/Codium-ai/pr-agent-pro/issues/525 are supported.

    Args:
        pr_description: The PR description text. ``None`` or an empty string
            yields ``[]``.
        repo_path: Currently unused; kept so existing callers keep working.

    Returns:
        The issue URLs in order of first appearance, without duplicates.
    """
    if not pr_description:
        return []

    pattern = r'https://github[^/]+/[^/]+/[^/]+/issues/\d+'
    # Short references such as "#123" are deliberately not expanded here:
    # per the original author's note, they already appear as full links when
    # the actual comment is pulled.
    github_tickets = re.findall(pattern, pr_description)

    # A link repeated in the description must be reported once only;
    # dict.fromkeys keeps the first-seen order.
    return list(dict.fromkeys(github_tickets))
async def extract_tickets(git_provider):
    """Fetch the GitHub issues linked from the PR's user description.

    Only ``GithubProvider`` instances are handled; for any other provider the
    function falls through and returns ``None``.

    Args:
        git_provider: The git provider of the PR being processed.

    Returns:
        A list with one dict per linked issue, holding ``ticket_id``,
        ``ticket_url``, ``title``, ``body`` (clipped to
        ``MAX_TICKET_CHARACTERS``) and ``labels`` (comma-separated names).
        The list is empty when the description links no issues. ``None`` is
        returned for non-GitHub providers and when extraction fails; the
        failure is logged, not raised.
    """
    MAX_TICKET_CHARACTERS = 10000
    try:
        if isinstance(git_provider, GithubProvider):
            user_description = git_provider.get_user_description()
            tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo)
            tickets_content = []
            for ticket in tickets:
                # extract ticket number and repo name
                # NOTE(review): repo_name is parsed but not used - the issue is
                # always looked up on git_provider.repo_obj, so a link to an
                # issue in another repository presumably resolves against the
                # PR's own repository. Confirm whether that is intended.
                repo_name, original_issue_number = git_provider._parse_issue_url(ticket)

                # get the ticket object
                issue_main = git_provider.repo_obj.get_issue(original_issue_number)

                # An issue without a description can report its body as None;
                # treat it as empty text so that one such issue does not abort
                # the extraction of every ticket.
                issue_body = issue_main.body or ""
                # clip the body to the maximum length
                if len(issue_body) > MAX_TICKET_CHARACTERS:
                    issue_body = issue_body[:MAX_TICKET_CHARACTERS] + "..."

                # extract labels (best effort: a failure here only loses labels)
                labels = []
                try:
                    for label in issue_main.labels:
                        if isinstance(label, str):
                            labels.append(label)
                        else:
                            labels.append(label.name)
                except Exception as e:
                    get_logger().error(f"Error extracting labels error= {e}",
                                       artifact={"traceback": traceback.format_exc()})

                tickets_content.append(
                    {'ticket_id': issue_main.number,
                     'ticket_url': ticket, 'title': issue_main.title, 'body': issue_body,
                     'labels': ", ".join(labels)})
            return tickets_content
    except Exception as e:
        get_logger().error(f"Error extracting tickets error= {e}",
                           artifact={"traceback": traceback.format_exc()})
async def extract_and_cache_pr_tickets(git_provider, vars):
    """Populate ``vars['related_tickets']`` from the settings cache or the PR.

    When the ``related_tickets`` setting already holds tickets they are
    reused. Otherwise the tickets are extracted through ``extract_tickets``
    and, if any were found, stored both in *vars* and in the
    ``related_tickets`` setting.

    Args:
        git_provider: The git provider of the PR being processed.
        vars: Dict of prompt variables; updated in place.
    """
    # NOTE(review): extraction is skipped when
    # 'config.require_ticket_analysis_review' is truthy, which looks inverted
    # relative to the setting's name - confirm the intended polarity.
    if get_settings().get('config.require_ticket_analysis_review', False):
        return

    cached_tickets = get_settings().get('related_tickets', [])
    if cached_tickets:
        # tickets are already cached
        get_logger().info("Using cached tickets", artifact={"tickets": cached_tickets})
        vars['related_tickets'] = cached_tickets
        return

    extracted_tickets = await extract_tickets(git_provider)
    if not extracted_tickets:
        return

    get_logger().info("Extracted tickets from PR description", artifact={"tickets": extracted_tickets})
    vars['related_tickets'] = extracted_tickets
    get_settings().set('related_tickets', extracted_tickets)
def check_tickets_relevancy():
    """Report whether the tickets are relevant.

    Stub: no check is performed and the answer is unconditionally ``True``.
    """
    is_relevant = True
    return is_relevant