feat: add ticket compliance check

- Implement ticket compliance check logic in `utils.py` and `ticket_pr_compliance_check.py`
- Add functions to extract and cache PR tickets, and check ticket relevancy
This commit is contained in:
mrT23
2024-10-10 08:48:37 +03:00
parent 014ea884d2
commit 76d95bb6d7
12 changed files with 365 additions and 86 deletions

View File

@ -1,18 +1,22 @@
from __future__ import annotations
import html2text
import html
import copy
import difflib
import hashlib
import html
import json
import os
import re
import textwrap
import time
import traceback
from datetime import datetime
from enum import Enum
from typing import Any, List, Tuple
import html2text
import requests
import yaml
from pydantic import BaseModel
from starlette_context import context
@ -110,6 +114,7 @@ def convert_to_markdown_v2(output_data: dict,
"Insights from user's answers": "📝",
"Code feedback": "🤖",
"Estimated effort to review [1-5]": "⏱️",
"Ticket compliance check": "🎫",
}
markdown_text = ""
if not incremental_review:
@ -165,6 +170,8 @@ def convert_to_markdown_v2(output_data: dict,
markdown_text += f'### {emoji} No relevant tests\n\n'
else:
markdown_text += f"### PR contains tests\n\n"
elif 'ticket compliance check' in key_nice.lower():
markdown_text = ticket_markdown_logic(emoji, markdown_text, value, gfm_supported)
elif 'security concerns' in key_nice.lower():
if gfm_supported:
markdown_text += f"<tr><td>"
@ -254,6 +261,52 @@ def convert_to_markdown_v2(output_data: dict,
return markdown_text
def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
ticket_compliance_str = ""
final_compliance_level = -1
if isinstance(value, list):
for v in value:
ticket_url = v.get('ticket_url', '').strip()
compliance_level = v.get('overall_compliance_level', '').strip()
# add emojis, if 'Fully compliant' ✅, 'Partially compliant' 🔶, or 'Not compliant' ❌
if compliance_level.lower() == 'fully compliant':
# compliance_level = '✅ Fully compliant'
final_compliance_level = 2 if final_compliance_level == -1 else 1
elif compliance_level.lower() == 'partially compliant':
# compliance_level = '🔶 Partially compliant'
final_compliance_level = 1
elif compliance_level.lower() == 'not compliant':
# compliance_level = '❌ Not compliant'
final_compliance_level = 0 if final_compliance_level < 1 else 1
# explanation = v.get('compliance_analysis', '').strip()
explanation = ''
fully_compliant_str = v.get('fully_compliant_requirements', '').strip()
not_compliant_str = v.get('not_compliant_requirements', '').strip()
if fully_compliant_str:
explanation += f"Fully compliant requirements:\n{fully_compliant_str}\n\n"
if not_compliant_str:
explanation += f"Not compliant requirements:\n{not_compliant_str}\n\n"
ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {compliance_level}**\n\n{explanation}\n\n"
if final_compliance_level == 2:
compliance_level = ''
elif final_compliance_level == 1:
compliance_level = '🔶'
else:
compliance_level = ''
if gfm_supported:
markdown_text += f"<tr><td>\n\n"
markdown_text += f"**{emoji} Ticket compliance analysis {compliance_level}**\n\n"
markdown_text += ticket_compliance_str
markdown_text += f"</td></tr>\n"
else:
markdown_text += f"### {emoji} Ticket compliance analysis {compliance_level}\n\n"
markdown_text += ticket_compliance_str+"\n\n"
return markdown_text
def process_can_be_split(emoji, value):
try:
# key_nice = "Can this PR be split?"
@ -554,7 +607,8 @@ def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", l
get_logger().warning(f"Initial failure to parse AI prediction: {e}")
data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key)
if not data:
get_logger().error(f"Failed to parse AI prediction after fallbacks", artifact={'response_text': response_text})
get_logger().error(f"Failed to parse AI prediction after fallbacks",
artifact={'response_text': response_text})
else:
get_logger().info(f"Successfully parsed AI prediction after fallbacks",
artifact={'response_text': response_text})
@ -841,56 +895,64 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
break
return position, absolute_position
def validate_and_await_rate_limit(rate_limit_status=None, git_provider=None, get_rate_limit_status_func=None):
if git_provider and not rate_limit_status:
rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
def get_rate_limit_status(github_token) -> dict:
GITHUB_API_URL = get_settings(use_context=False).get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com"
# GITHUB_API_URL = "https://api.github.com"
RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit"
HEADERS = {
"Accept": "application/vnd.github.v3+json",
"Authorization": f"token {github_token}"
}
if not rate_limit_status:
rate_limit_status = get_rate_limit_status_func()
response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
try:
rate_limit_info = response.json()
if rate_limit_info.get('message') == 'Rate limiting is not enabled.': # for github enterprise
return {'resources': {}}
response.raise_for_status() # Check for HTTP errors
except: # retry
time.sleep(0.1)
response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
return response.json()
return rate_limit_info
def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1) -> bool:
try:
rate_limit_status = get_rate_limit_status(github_token)
if installation_id:
get_logger().debug(f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}")
# validate that the rate limit is not exceeded
is_rate_limit = False
for key, value in rate_limit_status['resources'].items():
if value['remaining'] == 0:
print(f"key: {key}, value: {value}")
is_rate_limit = True
sleep_time_sec = value['reset'] - datetime.now().timestamp()
sleep_time_hour = sleep_time_sec / 3600.0
print(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
if sleep_time_sec > 0:
time.sleep(sleep_time_sec+1)
if git_provider:
rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
else:
rate_limit_status = get_rate_limit_status_func()
return is_rate_limit
# validate that the rate limit is not exceeded
for key, value in rate_limit_status['resources'].items():
if value['remaining'] < value['limit'] * threshold:
get_logger().error(f"key: {key}, value: {value}")
return False
return True
except Exception as e:
get_logger().error(f"Error in rate limit {e}",
artifact={"traceback": traceback.format_exc()})
return True
def get_largest_component(pr_url):
from pr_agent.tools.pr_analyzer import PRAnalyzer
publish_output = get_settings().config.publish_output
get_settings().config.publish_output = False # disable publish output
analyzer = PRAnalyzer(pr_url)
methods_dict_files = analyzer.run_sync()
get_settings().config.publish_output = publish_output
max_lines_changed = 0
file_b = ""
component_name_b = ""
for file in methods_dict_files:
for method in methods_dict_files[file]:
try:
if methods_dict_files[file][method]['num_plus_lines'] > max_lines_changed:
max_lines_changed = methods_dict_files[file][method]['num_plus_lines']
file_b = file
component_name_b = method
except:
pass
if component_name_b:
get_logger().info(f"Using the largest changed component: '{component_name_b}'")
return component_name_b, file_b
else:
return None, None
def validate_and_await_rate_limit(github_token):
try:
rate_limit_status = get_rate_limit_status(github_token)
# validate that the rate limit is not exceeded
for key, value in rate_limit_status['resources'].items():
if value['remaining'] < value['limit'] // 80:
get_logger().error(f"key: {key}, value: {value}")
sleep_time_sec = value['reset'] - datetime.now().timestamp()
sleep_time_hour = sleep_time_sec / 3600.0
get_logger().error(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
if sleep_time_sec > 0:
time.sleep(sleep_time_sec + 1)
rate_limit_status = get_rate_limit_status(github_token)
return rate_limit_status
except:
get_logger().error("Error in rate limit")
return None
def github_action_output(output_data: dict, key_name: str):
try:
@ -906,7 +968,7 @@ def github_action_output(output_data: dict, key_name: str):
def show_relevant_configurations(relevant_section: str) -> str:
skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys",
skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME']
extra_skip_keys = get_settings().config.get('config.skip_keys', [])
if extra_skip_keys:
@ -939,6 +1001,25 @@ def is_value_no(value):
return False
def set_pr_string(repo_name, pr_number):
return f"{repo_name}#{pr_number}"
def string_to_uniform_number(s: str) -> float:
"""
Convert a string to a uniform number in the range [0, 1].
The uniform distribution is achieved by the nature of the SHA-256 hash function, which produces a uniformly distributed hash value over its output space.
"""
# Generate a hash of the string
hash_object = hashlib.sha256(s.encode())
# Convert the hash to an integer
hash_int = int(hash_object.hexdigest(), 16)
# Normalize the integer to the range [0, 1]
max_hash_int = 2 ** 256 - 1
uniform_number = float(hash_int) / max_hash_int
return uniform_number
def process_description(description_full: str) -> Tuple[str, List]:
if not description_full:
return "", []
@ -998,7 +1079,10 @@ def process_description(description_full: str) -> Tuple[str, List]:
'long_summary': long_summary
})
else:
get_logger().error(f"Failed to parse description", artifact={'description': file_data})
if '<code>...</code>' in file_data:
pass # PR with many files. some did not get analyzed
else:
get_logger().error(f"Failed to parse description", artifact={'description': file_data})
except Exception as e:
get_logger().exception(f"Failed to process description: {e}", artifact={'description': file_data})