From 4eef3e91905030ea843a64d9b54ba09980eb1f58 Mon Sep 17 00:00:00 2001
From: mrT23
Date: Thu, 10 Oct 2024 08:48:37 +0300
Subject: [PATCH 1/2] feat: add ticket compliance check and rate limit validation

- Implement ticket compliance check logic in `utils.py` and `ticket_pr_compliance_check.py`
- Add functions to extract and cache PR tickets, and to check ticket relevancy
- Introduce rate limit validation for GitHub API requests
- Update `pr_reviewer_prompts.toml` and `pr_description_prompts.toml` to include ticket compliance fields
- Modify the configuration to require ticket analysis review
---
 pr_agent/algo/utils.py                        | 186 +++++++++++++-----
 pr_agent/settings/configuration.toml          |   4 +-
 pr_agent/settings/pr_description_prompts.toml |  25 ++-
 pr_agent/settings/pr_reviewer_prompts.toml    |  54 ++++-
 pr_agent/tools/ticket_pr_compliance_check.py  | 113 +++++++++++
 5 files changed, 324 insertions(+), 58 deletions(-)
 create mode 100644 pr_agent/tools/ticket_pr_compliance_check.py

diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index bd0781ae..98c69d00 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -1,18 +1,22 @@
 from __future__ import annotations

-import html2text
-import html
 import copy
 import difflib
+import hashlib
+import html
 import json
 import os
 import re
 import textwrap
 import time
+import traceback
 from datetime import datetime
 from enum import Enum
 from typing import Any, List, Tuple
+
+import html2text
+import requests
 import yaml
 from pydantic import BaseModel
 from starlette_context import context
@@ -110,6 +114,7 @@ def convert_to_markdown_v2(output_data: dict,
         "Insights from user's answers": "📝",
         "Code feedback": "🤖",
         "Estimated effort to review [1-5]": "⏱️",
+        "Ticket compliance check": "🎫",
     }
     markdown_text = ""
     if not incremental_review:
@@ -165,6 +170,8 @@ def convert_to_markdown_v2(output_data: dict,
                 markdown_text += f'### {emoji} No relevant tests\n\n'
             else:
                 markdown_text += f"### PR contains tests\n\n"
+        elif 'ticket compliance check' in key_nice.lower():
+            markdown_text = ticket_markdown_logic(emoji, markdown_text, value, gfm_supported)
         elif 'security concerns' in key_nice.lower():
             if gfm_supported:
                 markdown_text += f"<tr><td>"
@@ -254,6 +261,52 @@
     return markdown_text


+def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
+    ticket_compliance_str = ""
+    final_compliance_level = -1
+    if isinstance(value, list):
+        for v in value:
+            ticket_url = v.get('ticket_url', '').strip()
+            compliance_level = v.get('overall_compliance_level', '').strip()
+            # add an emoji per level: 'Fully compliant' ✅, 'Partially compliant' 🔶, 'Not compliant' ❌
+            if compliance_level.lower() == 'fully compliant':
+                final_compliance_level = 2 if final_compliance_level == -1 else 1
+            elif compliance_level.lower() == 'partially compliant':
+                final_compliance_level = 1
+            elif compliance_level.lower() == 'not compliant':
+                final_compliance_level = 0 if final_compliance_level < 1 else 1
+
+            explanation = ''
+            fully_compliant_str = v.get('fully_compliant_requirements', '').strip()
+            not_compliant_str = v.get('not_compliant_requirements', '').strip()
+            if fully_compliant_str:
+                explanation += f"Fully compliant requirements:\n{fully_compliant_str}\n\n"
+            if not_compliant_str:
+                explanation += f"Not compliant requirements:\n{not_compliant_str}\n\n"
+
+            ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {compliance_level}**\n\n{explanation}\n\n"
+
+    if final_compliance_level == 2:
+        compliance_level = '✅'
+    elif final_compliance_level == 1:
+        compliance_level = '🔶'
+    else:
+        compliance_level = '❌'
+
+    if gfm_supported:
+        markdown_text += f"<tr><td>\n\n"
+        markdown_text += f"**{emoji} Ticket compliance analysis {compliance_level}**\n\n"
+        markdown_text += ticket_compliance_str
+        markdown_text += f"</td></tr>\n"
+    else:
+        markdown_text += f"### {emoji} Ticket compliance analysis {compliance_level}\n\n"
+        markdown_text += ticket_compliance_str + "\n\n"
+    return markdown_text
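
The compliance-level folding above is the subtlest part of this hunk. Here is a minimal, self-contained sketch of the same rule, with illustrative ticket dicts rather than real model output:

```python
# One fully compliant ticket plus one non-compliant ticket should fold to the
# intermediate 🔶 marker, never to ✅ or ❌.
tickets = [
    {'overall_compliance_level': 'Fully compliant'},
    {'overall_compliance_level': 'Not compliant'},
]

final_level = -1
for t in tickets:
    level = t['overall_compliance_level'].lower()
    if level == 'fully compliant':
        final_level = 2 if final_level == -1 else 1
    elif level == 'partially compliant':
        final_level = 1
    elif level == 'not compliant':
        final_level = 0 if final_level < 1 else 1

print({2: '✅', 1: '🔶', 0: '❌'}.get(final_level, '❌'))  # 🔶
```

Mixed results always collapse to 🔶, so one fully compliant ticket can never mask a non-compliant one.
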
+
+
 def process_can_be_split(emoji, value):
     try:
         # key_nice = "Can this PR be split?"
@@ -554,7 +607,8 @@ def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", l
         get_logger().warning(f"Initial failure to parse AI prediction: {e}")
         data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key)
         if not data:
-            get_logger().error(f"Failed to parse AI prediction after fallbacks", artifact={'response_text': response_text})
+            get_logger().error(f"Failed to parse AI prediction after fallbacks",
+                               artifact={'response_text': response_text})
         else:
             get_logger().info(f"Successfully parsed AI prediction after fallbacks",
                               artifact={'response_text': response_text})
@@ -841,56 +895,64 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                     break
     return position, absolute_position

-def validate_and_await_rate_limit(rate_limit_status=None, git_provider=None, get_rate_limit_status_func=None):
-    if git_provider and not rate_limit_status:
-        rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
+def get_rate_limit_status(github_token) -> dict:
+    GITHUB_API_URL = get_settings(use_context=False).get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")
+    RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit"
+    HEADERS = {
+        "Accept": "application/vnd.github.v3+json",
+        "Authorization": f"token {github_token}"
+    }

-    if not rate_limit_status:
-        rate_limit_status = get_rate_limit_status_func()
+    response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
+    try:
+        rate_limit_info = response.json()
+        if rate_limit_info.get('message') == 'Rate limiting is not enabled.':  # GitHub Enterprise
+            return {'resources': {}}
+        response.raise_for_status()  # check for HTTP errors
+    except Exception:  # retry once
+        time.sleep(0.1)
+        response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
+        return response.json()
+    return rate_limit_info
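
For reference, a hedged sketch of the (heavily truncated) `/rate_limit` payload shape that the validators below consume; the numbers are illustrative, not real API output:

```python
rate_limit_status = {
    'resources': {
        'core':   {'limit': 5000, 'used': 4800, 'remaining': 200, 'reset': 1728550000},
        'search': {'limit': 30,   'used': 2,    'remaining': 28,  'reset': 1728546400},
    }
}

threshold = 0.1  # default threshold used by validate_rate_limit_github below
for key, value in rate_limit_status['resources'].items():
    if value['remaining'] < value['limit'] * threshold:
        print(f"{key}: below {threshold:.0%} of quota")  # prints only 'core'
```
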
+
+
+def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1) -> bool:
+    try:
+        rate_limit_status = get_rate_limit_status(github_token)
+        if installation_id:
+            get_logger().debug(f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}")
-    # validate that the rate limit is not exceeded
-    is_rate_limit = False
-    for key, value in rate_limit_status['resources'].items():
-        if value['remaining'] == 0:
-            print(f"key: {key}, value: {value}")
-            is_rate_limit = True
-            sleep_time_sec = value['reset'] - datetime.now().timestamp()
-            sleep_time_hour = sleep_time_sec / 3600.0
-            print(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
-            if sleep_time_sec > 0:
-                time.sleep(sleep_time_sec+1)
-
-            if git_provider:
-                rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
-            else:
-                rate_limit_status = get_rate_limit_status_func()
-
-    return is_rate_limit
+        # validate that the rate limit is not exceeded
+        for key, value in rate_limit_status['resources'].items():
+            if value['remaining'] < value['limit'] * threshold:
+                get_logger().error(f"key: {key}, value: {value}")
+                return False
+        return True
+    except Exception as e:
+        get_logger().error(f"Error in rate limit {e}",
+                           artifact={"traceback": traceback.format_exc()})
+        return True

-def get_largest_component(pr_url):
-    from pr_agent.tools.pr_analyzer import PRAnalyzer
-    publish_output = get_settings().config.publish_output
-    get_settings().config.publish_output = False  # disable publish output
-    analyzer = PRAnalyzer(pr_url)
-    methods_dict_files = analyzer.run_sync()
-    get_settings().config.publish_output = publish_output
-    max_lines_changed = 0
-    file_b = ""
-    component_name_b = ""
-    for file in methods_dict_files:
-        for method in methods_dict_files[file]:
-            try:
-                if methods_dict_files[file][method]['num_plus_lines'] > max_lines_changed:
-                    max_lines_changed = methods_dict_files[file][method]['num_plus_lines']
-                    file_b = file
-                    component_name_b = method
-            except:
-                pass
-    if component_name_b:
-        get_logger().info(f"Using the largest changed component: '{component_name_b}'")
-        return component_name_b, file_b
-    else:
-        return None, None
+def validate_and_await_rate_limit(github_token):
+    try:
+        rate_limit_status = get_rate_limit_status(github_token)
+        # validate that the rate limit is not exceeded
+        for key, value in rate_limit_status['resources'].items():
+            if value['remaining'] < value['limit'] // 80:
+                get_logger().error(f"key: {key}, value: {value}")
+                sleep_time_sec = value['reset'] - datetime.now().timestamp()
+                sleep_time_hour = sleep_time_sec / 3600.0
+                get_logger().error(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
+                if sleep_time_sec > 0:
+                    time.sleep(sleep_time_sec + 1)
+                rate_limit_status = get_rate_limit_status(github_token)
+        return rate_limit_status
+    except Exception:
+        get_logger().error("Error in rate limit")
+        return None
+

 def github_action_output(output_data: dict, key_name: str):
     try:
@@ -906,7 +968,7 @@ def github_action_output(output_data: dict, key_name: str):


 def show_relevant_configurations(relevant_section: str) -> str:
-    skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys",
+    skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
                  'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME']
     extra_skip_keys = get_settings().config.get('config.skip_keys', [])
     if extra_skip_keys:
@@ -939,6 +1001,25 @@ def is_value_no(value):
     return False


+def set_pr_string(repo_name, pr_number):
+    return f"{repo_name}#{pr_number}"
+
+
+def string_to_uniform_number(s: str) -> float:
+    """
+    Convert a string to a uniform number in the range [0, 1].
+    The (approximately) uniform distribution follows from the SHA-256 hash function,
+    which produces a uniformly distributed hash value over its output space.
+    """
+    # Generate a hash of the string
+    hash_object = hashlib.sha256(s.encode())
+    # Convert the hash to an integer
+    hash_int = int(hash_object.hexdigest(), 16)
+    # Normalize the integer to the range [0, 1]
+    max_hash_int = 2 ** 256 - 1
+    uniform_number = float(hash_int) / max_hash_int
+    return uniform_number
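
A quick, hedged illustration of how `set_pr_string` and `string_to_uniform_number` combine for deterministic sampling; the repo name and rollout threshold are made up:

```python
import hashlib

def string_to_uniform_number(s: str) -> float:
    # same scheme as above: SHA-256 digest normalized into [0, 1]
    return int(hashlib.sha256(s.encode()).hexdigest(), 16) / (2 ** 256 - 1)

pr_string = "Codium-ai/pr-agent#1234"  # what set_pr_string(...) would return
x = string_to_uniform_number(pr_string)
print(x)         # deterministic: the same PR always maps to the same value
print(x < 0.25)  # e.g., gate a feature for ~25% of PRs (illustrative threshold)
```
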
+ """ + # Generate a hash of the string + hash_object = hashlib.sha256(s.encode()) + # Convert the hash to an integer + hash_int = int(hash_object.hexdigest(), 16) + # Normalize the integer to the range [0, 1] + max_hash_int = 2 ** 256 - 1 + uniform_number = float(hash_int) / max_hash_int + return uniform_number + + def process_description(description_full: str) -> Tuple[str, List]: if not description_full: return "", [] @@ -998,7 +1079,10 @@ def process_description(description_full: str) -> Tuple[str, List]: 'long_summary': long_summary }) else: - get_logger().error(f"Failed to parse description", artifact={'description': file_data}) + if '...' in file_data: + pass # PR with many files. some did not get analyzed + else: + get_logger().error(f"Failed to parse description", artifact={'description': file_data}) except Exception as e: get_logger().exception(f"Failed to process description: {e}", artifact={'description': file_data}) diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index add06bb6..d4ecf379 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -51,9 +51,7 @@ require_tests_review=true require_estimate_effort_to_review=true require_can_be_split_review=false require_security_review=true -# soc2 -require_soc2_ticket=false -soc2_ticket_prompt="Does the PR description include a link to ticket in a project management system (e.g., Jira, Asana, Trello, etc.) ?" +require_ticket_analysis_review=true # general options num_code_suggestions=0 inline_code_comments = false diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index de7c3d54..364dd9af 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -78,9 +78,9 @@ pr_files: ... ... {%- endif %} -description: |- +description: | ... -title: |- +title: | ... {%- if enable_custom_labels %} labels: @@ -94,7 +94,26 @@ labels: Answer should be a valid YAML, and nothing else. Each YAML output MUST be after a newline, with proper indent, and block scalar indicator ('|') """ -user="""PR Info: +user=""" +{%- if related_tickets %} +Related Ticket Info: +{% for ticket in related_tickets %} +===== +Ticket Title: '{{ ticket.title }}' +{%- if ticket.labels %} +Ticket Labels: {{ ticket.labels }} +{%- endif %} +{%- if ticket.body %} +Ticket Description: +##### +{{ ticket.body }} +##### +{%- endif %} +===== +{% endfor %} +{%- endif %} + +PR Info: Previous title: '{{title}}' diff --git a/pr_agent/settings/pr_reviewer_prompts.toml b/pr_agent/settings/pr_reviewer_prompts.toml index e3b4bfe4..0b61e8ea 100644 --- a/pr_agent/settings/pr_reviewer_prompts.toml +++ b/pr_agent/settings/pr_reviewer_prompts.toml @@ -85,7 +85,20 @@ class KeyIssuesComponentLink(BaseModel): start_line: int = Field(description="The start line that corresponds to this issue in the relevant file") end_line: int = Field(description="The end line that corresponds to this issue in the relevant file") +{%- if related_tickets %} + +class TicketCompliance(BaseModel): + ticket_url: str = Field(description="Ticket URL or ID") + ticket_requirements: str = Field(description="Repeat, in your own words, all ticket requirements, in bullet points") + fully_compliant_requirements: str = Field(description="A list, in bullet points, of which requirements are met by the PR code. Don't explain how the requirements are met, just list them shortly. 
Can be empty") + not_compliant_requirements: str = Field(description="A list, in bullet points, of which requirements are not met by the PR code. Don't explain how the requirements are not met, just list them shortly. Can be empty") + overall_compliance_level: str = Field(description="Overall give this PR one of these three values in relation to the ticket: 'Fully compliant', 'Partially compliant', or 'Not compliant'") +{%- endif %} + class Review(BaseModel): +{%- if related_tickets %} + ticket_compliance_check: List[TicketCompliance] = Field(description="A list of compliance checks for the related tickets") +{%- endif %} {%- if require_estimate_effort_to_review %} estimated_effort_to_review_[1-5]: int = Field(description="Estimate, on a scale of 1-5 (inclusive), the time and effort required to review this PR by an experienced and knowledgeable developer. 1 means short and easy review , 5 means long and hard review. Take into account the size, complexity, quality, and the needed changes of the PR code diff.") {%- endif %} @@ -130,6 +143,19 @@ class PRReview(BaseModel): Example output: ```yaml review: +{%- if related_tickets %} + ticket_compliance_check: + - ticket_url: | + ... + ticket_requirements: | + ... + fully_compliant_requirements: | + ... + not_compliant_requirements: | + ... + overall_compliance_level: | + ... +{%- endif %} {%- if require_estimate_effort_to_review %} estimated_effort_to_review_[1-5]: | 3 @@ -176,7 +202,33 @@ code_feedback: Answer should be a valid YAML, and nothing else. Each YAML output MUST be after a newline, with proper indent, and block scalar indicator ('|') """ -user="""--PR Info-- +user=""" +{%- if related_tickets %} +--PR Ticket Info-- +{%- for ticket in related_tickets %} +===== +Ticket URL: '{{ ticket.ticket_url }}' + +Ticket Title: '{{ ticket.title }}' + +{%- if ticket.labels %} + +Ticket Labels: {{ ticket.labels }} + +{%- endif %} +{%- if ticket.body %} + +Ticket Description: +##### +{{ ticket.body }} +##### +{%- endif %} +===== +{% endfor %} +{%- endif %} + + +--PR Info-- Title: '{{title}}' diff --git a/pr_agent/tools/ticket_pr_compliance_check.py b/pr_agent/tools/ticket_pr_compliance_check.py new file mode 100644 index 00000000..03fdc88b --- /dev/null +++ b/pr_agent/tools/ticket_pr_compliance_check.py @@ -0,0 +1,113 @@ +import re +import traceback + +from pr_agent.config_loader import get_settings +from pr_agent.git_providers import GithubProvider +from pr_agent.log import get_logger + + +def find_jira_tickets(text): + # Regular expression patterns for JIRA tickets + patterns = [ + r'\b[A-Z]{2,10}-\d{1,7}\b', # Standard JIRA ticket format (e.g., PROJ-123) + r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b' # JIRA URL or just the ticket + ] + + tickets = set() + for pattern in patterns: + matches = re.findall(pattern, text) + for match in matches: + if isinstance(match, tuple): + # If it's a tuple (from the URL pattern), take the last non-empty group + ticket = next((m for m in reversed(match) if m), None) + else: + ticket = match + if ticket: + tickets.add(ticket) + + return list(tickets) + + +def extract_ticket_links_from_pr_description(pr_description, repo_path): + """ + Extract all ticket links from PR description + """ + + # example link to search for: https://github.com/Codium-ai/pr-agent-pro/issues/525 + pattern = r'https://github[^/]+/[^/]+/[^/]+/issues/\d+' # should support also github server (for example 'https://github.company.ai/Codium-ai/pr-agent-pro/issues/525') + + # Find all matches in the text + github_tickets = 
+
+
+async def extract_tickets(git_provider):
+    MAX_TICKET_CHARACTERS = 10000
+    try:
+        if isinstance(git_provider, GithubProvider):
+            user_description = git_provider.get_user_description()
+            tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo)
+            tickets_content = []
+            if tickets:
+                for ticket in tickets:
+                    # extract the repo name and issue number from the link
+                    repo_name, original_issue_number = git_provider._parse_issue_url(ticket)
+
+                    # get the issue object
+                    issue_main = git_provider.repo_obj.get_issue(original_issue_number)
+
+                    # clip the issue body to MAX_TICKET_CHARACTERS
+                    issue_body = issue_main.body
+                    if issue_body and len(issue_body) > MAX_TICKET_CHARACTERS:
+                        issue_body = issue_body[:MAX_TICKET_CHARACTERS] + "..."
+
+                    # extract labels
+                    labels = []
+                    try:
+                        for label in issue_main.labels:
+                            if isinstance(label, str):
+                                labels.append(label)
+                            else:
+                                labels.append(label.name)
+                    except Exception as e:
+                        get_logger().error(f"Error extracting labels error= {e}",
+                                           artifact={"traceback": traceback.format_exc()})
+                    tickets_content.append(
+                        {'ticket_id': issue_main.number,
+                         'ticket_url': ticket, 'title': issue_main.title, 'body': issue_body,
+                         'labels': ", ".join(labels)})
+                return tickets_content
+
+    except Exception as e:
+        get_logger().error(f"Error extracting tickets error= {e}",
+                           artifact={"traceback": traceback.format_exc()})
+
+
+async def extract_and_cache_pr_tickets(git_provider, vars):
+    if not get_settings().get('config.require_ticket_analysis_review', False):
+        return
+    related_tickets = get_settings().get('related_tickets', [])
+    if not related_tickets:
+        tickets_content = await extract_tickets(git_provider)
+        if tickets_content:
+            get_logger().info("Extracted tickets from PR description", artifact={"tickets": tickets_content})
+            vars['related_tickets'] = tickets_content
+            get_settings().set('related_tickets', tickets_content)
+    else:  # tickets are already cached
+        get_logger().info("Using cached tickets", artifact={"tickets": related_tickets})
+        vars['related_tickets'] = related_tickets
+
+
+def check_tickets_relevancy():
+    return True
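
To make the caching contract above concrete, a hedged, self-contained sketch; `FakeSettings` is a stand-in for pr-agent's real settings object:

```python
class FakeSettings:
    def __init__(self):
        self._data = {}
    def get(self, key, default=None):
        return self._data.get(key, default)
    def set(self, key, value):
        self._data[key] = value

settings = FakeSettings()

def cache_tickets(vars_):
    cached = settings.get('related_tickets', [])
    if not cached:
        tickets = [{'ticket_id': 1, 'title': 'example ticket'}]  # stand-in for extract_tickets()
        vars_['related_tickets'] = tickets
        settings.set('related_tickets', tickets)
    else:
        vars_['related_tickets'] = cached

describe_vars, review_vars = {}, {}
cache_tickets(describe_vars)  # first tool extracts and caches
cache_tickets(review_vars)    # second tool reuses the cached list
print(review_vars['related_tickets'][0]['title'])  # example ticket
```
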
From 7db9a038055e37f10aadfaa4cd53a276a9b4b4a2 Mon Sep 17 00:00:00 2001
From: mrT23
Date: Thu, 10 Oct 2024 08:53:07 +0300
Subject: [PATCH 2/2] feat: integrate ticket extraction and enhance YAML handling in PR tools

- Add ticket extraction and caching functionality in `pr_description.py` and `pr_reviewer.py`.
- Introduce the `keys_fix` parameter to improve YAML loading robustness.
- Enhance error handling for estimated effort parsing in `pr_reviewer.py`.
---
 pr_agent/tools/pr_description.py | 21 ++++++++++++++-------
 pr_agent/tools/pr_reviewer.py    | 21 ++++++++++++++++++---
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py
index 9f7d79d3..c965d84e 100644
--- a/pr_agent/tools/pr_description.py
+++ b/pr_agent/tools/pr_description.py
@@ -20,6 +20,8 @@ from pr_agent.git_providers import get_git_provider, GithubProvider, get_git_pro
 from pr_agent.git_providers.git_provider import get_main_pr_language
 from pr_agent.log import get_logger
 from pr_agent.servers.help import HelpMessage
+from pr_agent.tools.ticket_pr_compliance_check import extract_ticket_links_from_pr_description, extract_tickets, \
+    extract_and_cache_pr_tickets


 class PRDescription:
@@ -38,6 +40,7 @@ class PRDescription:
             self.git_provider.get_languages(), self.git_provider.get_files()
         )
         self.pr_id = self.git_provider.get_pr_id()
+        self.keys_fix = ["filename:", "language:", "changes_summary:", "changes_title:", "description:", "title:"]

         if get_settings().pr_description.enable_semantic_files_types and not self.git_provider.is_supported(
                 "gfm_markdown"):
@@ -60,6 +63,7 @@ class PRDescription:
             "enable_custom_labels": get_settings().config.enable_custom_labels,
             "custom_labels_class": "",  # will be filled if necessary in 'set_custom_labels' function
             "enable_semantic_files_types": get_settings().pr_description.enable_semantic_files_types,
+            "related_tickets": "",
         }

         self.user_description = self.git_provider.get_user_description()
@@ -87,6 +91,9 @@ class PRDescription:
             if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
                 self.git_provider.publish_comment("Preparing PR description...", is_temporary=True)

+            # extract tickets, if they exist
+            await extract_and_cache_pr_tickets(self.git_provider, self.vars)
+
             await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)

             if self.prediction:
@@ -226,7 +233,7 @@ class PRDescription:
         file_description_str_list = []
         for i, result in enumerate(results):
             prediction_files = result.strip().removeprefix('```yaml').strip('`').strip()
-            if load_yaml(prediction_files) and prediction_files.startswith('pr_files'):
+            if load_yaml(prediction_files, keys_fix_yaml=self.keys_fix) and prediction_files.startswith('pr_files'):
                 prediction_files = prediction_files.removeprefix('pr_files:').strip()
                 file_description_str_list.append(prediction_files)
             else:
@@ -304,16 +311,16 @@ extra_file_yaml =
         # final processing
         self.prediction = prediction_headers + "\n" + "pr_files:\n" + files_walkthrough
-        if not load_yaml(self.prediction):
+        if not load_yaml(self.prediction, keys_fix_yaml=self.keys_fix):
             get_logger().error(f"Error getting valid YAML in large PR handling for describe {self.pr_id}")
-            if load_yaml(prediction_headers):
+            if load_yaml(prediction_headers, keys_fix_yaml=self.keys_fix):
                 get_logger().debug(f"Using only headers for describe {self.pr_id}")
                 self.prediction = prediction_headers

     async def extend_additional_files(self, remaining_files_list) -> str:
         prediction = self.prediction
         try:
-            original_prediction_dict = load_yaml(self.prediction)
+            original_prediction_dict = load_yaml(self.prediction, keys_fix_yaml=self.keys_fix)
             prediction_extra = "pr_files:"
             for file in remaining_files_list:
                 extra_file_yaml = f"""\
 - filename: |
     {file}
   changes_summary: |
     ...
   changes_title: |
     ...
   label: |
     additional files (token-limit)
 """
                 prediction_extra = prediction_extra + "\n" + extra_file_yaml.strip()
-            prediction_extra_dict = load_yaml(prediction_extra)
+            prediction_extra_dict = load_yaml(prediction_extra, keys_fix_yaml=self.keys_fix)
             # merge the two dictionaries
             if isinstance(original_prediction_dict, dict) and isinstance(prediction_extra_dict, dict):
                 original_prediction_dict["pr_files"].extend(prediction_extra_dict["pr_files"])
                 new_yaml = yaml.dump(original_prediction_dict)
-                if load_yaml(new_yaml):
+                if load_yaml(new_yaml, keys_fix_yaml=self.keys_fix):
                     prediction = new_yaml
             return prediction
         except Exception as e:
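
The merge step in `extend_additional_files` above is easy to gloss over; a hedged sketch of the same idea — both predictions are YAML documents with a top-level `pr_files` list, so merging is a list extend plus a re-dump:

```python
import yaml

base = yaml.safe_load("pr_files:\n- filename: a.py\n  changes_title: refactor\n")
extra = yaml.safe_load("pr_files:\n- filename: b.py\n  changes_title: additional files (token-limit)\n")
if isinstance(base, dict) and isinstance(extra, dict):
    base["pr_files"].extend(extra["pr_files"])
print(yaml.dump(base))  # one document, two pr_files entries
```
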
@@ -361,7 +368,7 @@ extra_file_yaml =

     def _prepare_data(self):
         # Load the AI prediction data into a dictionary
-        self.data = load_yaml(self.prediction.strip())
+        self.data = load_yaml(self.prediction.strip(), keys_fix_yaml=self.keys_fix)

         if get_settings().pr_description.add_original_user_description and self.user_description:
             self.data["User Description"] = self.user_description
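
Why `keys_fix_yaml` matters: model output routinely embeds unquoted colons, which breaks a naive YAML parse; the listed keys give pr-agent's `load_yaml`/`try_fix_yaml` fallbacks known anchors to repair around. A hedged illustration of the failure mode only:

```python
import yaml

broken = (
    "pr_files:\n"
    "- filename: src/app.py\n"
    "  changes_title: fix: handle empty input\n"  # unquoted inner colon
)
try:
    yaml.safe_load(broken)
except yaml.YAMLError as e:
    print("parse failed:", type(e).__name__)  # ScannerError
```
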
diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py
index 88799d98..f5f82e06 100644
--- a/pr_agent/tools/pr_reviewer.py
+++ b/pr_agent/tools/pr_reviewer.py
@@ -1,5 +1,6 @@
 import copy
 import datetime
+import traceback
 from collections import OrderedDict
 from functools import partial
 from typing import List, Tuple
@@ -15,6 +16,7 @@ from pr_agent.git_providers import get_git_provider, get_git_provider_with_conte
 from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language
 from pr_agent.log import get_logger
 from pr_agent.servers.help import HelpMessage
+from pr_agent.tools.ticket_pr_compliance_check import extract_tickets, extract_and_cache_pr_tickets


 class PRReviewer:
@@ -84,6 +86,7 @@ class PRReviewer:
             "custom_labels": "",
             "enable_custom_labels": get_settings().config.enable_custom_labels,
             "is_ai_metadata": get_settings().get("config.enable_ai_metadata", False),
+            "related_tickets": get_settings().get('related_tickets', []),
         }

         self.token_handler = TokenHandler(
@@ -121,6 +124,9 @@
                                 'config': dict(get_settings().config)}
             get_logger().debug("Relevant configs", artifacts=relevant_configs)

+            # extract tickets, if they exist
+            await extract_and_cache_pr_tickets(self.git_provider, self.vars)
+
             if self.incremental.is_incremental and hasattr(self.git_provider, "unreviewed_files_set") and not self.git_provider.unreviewed_files_set:
                 get_logger().info(f"Incremental review is enabled for {self.pr_url} but there are no new files")
                 previous_review_url = ""
@@ -207,7 +213,7 @@
         first_key = 'review'
         last_key = 'security_concerns'
         data = load_yaml(self.prediction.strip(),
-                         keys_fix_yaml=["estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
+                         keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
                                         "relevant_file:", "relevant_line:", "suggestion:"],
                          first_key=first_key, last_key=last_key)
         github_action_output(data, 'review')
@@ -282,7 +288,7 @@
         first_key = 'review'
         last_key = 'security_concerns'
         data = load_yaml(self.prediction.strip(),
-                         keys_fix_yaml=["estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
+                         keys_fix_yaml=["ticket_compliance_check", "estimated_effort_to_review_[1-5]:", "security_concerns:", "key_issues_to_review:",
                                         "relevant_file:", "relevant_line:", "suggestion:"],
                          first_key=first_key, last_key=last_key)
         comments: List[str] = []
@@ -401,7 +407,16 @@
         review_labels = []
         if get_settings().pr_reviewer.enable_review_labels_effort:
             estimated_effort = data['review']['estimated_effort_to_review_[1-5]']
-            estimated_effort_number = int(estimated_effort.split(',')[0])
+            estimated_effort_number = 0
+            if isinstance(estimated_effort, str):
+                try:
+                    estimated_effort_number = int(estimated_effort.split(',')[0])
+                except ValueError:
+                    get_logger().warning(f"Invalid estimated_effort value: {estimated_effort}")
+            elif isinstance(estimated_effort, int):
+                estimated_effort_number = estimated_effort
+            else:
+                get_logger().warning(f"Unexpected type for estimated_effort: {type(estimated_effort)}")
             if 1 <= estimated_effort_number <= 5:  # 1, because ...
                 review_labels.append(f'Review effort [1-5]: {estimated_effort_number}')
         if get_settings().pr_reviewer.enable_review_labels_security and get_settings().pr_reviewer.require_security_review:
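
The hardened effort parsing above accepts either a bare integer or the "3, because ..." style string the model sometimes returns. A self-contained sketch of the same logic:

```python
def parse_effort(estimated_effort) -> int:
    # mirrors the fallback chain above: string -> int, int passthrough, else 0
    if isinstance(estimated_effort, str):
        try:
            return int(estimated_effort.split(',')[0])
        except ValueError:
            return 0
    if isinstance(estimated_effort, int):
        return estimated_effort
    return 0

print(parse_effort("3, because the diff is small"))  # 3
print(parse_effort(4))                               # 4
print(parse_effort("unknown"))                       # 0
```
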