Mirror of https://github.com/qodo-ai/pr-agent.git (synced 2025-07-02 03:40:38 +08:00)
feat: add ticket compliance check
- Implement ticket compliance check logic in `utils.py` and `ticket_pr_compliance_check.py`
- Add functions to extract and cache PR tickets, and check ticket relevancy
@@ -1,18 +1,22 @@
 from __future__ import annotations
-import html2text
-import html
 import copy
 import difflib
 import hashlib
+import html
 import json
 import os
 import re
 import textwrap
 import time
+import traceback
 from datetime import datetime
 from enum import Enum
 from typing import Any, List, Tuple
 
+import html2text
 import requests
 import yaml
 from pydantic import BaseModel
 from starlette_context import context
@@ -110,6 +114,7 @@ def convert_to_markdown_v2(output_data: dict,
         "Insights from user's answers": "📝",
         "Code feedback": "🤖",
         "Estimated effort to review [1-5]": "⏱️",
+        "Ticket compliance check": "🎫",
     }
     markdown_text = ""
     if not incremental_review:
@@ -165,6 +170,8 @@ def convert_to_markdown_v2(output_data: dict,
                 markdown_text += f'### {emoji} No relevant tests\n\n'
             else:
                 markdown_text += f"### PR contains tests\n\n"
+        elif 'ticket compliance check' in key_nice.lower():
+            markdown_text = ticket_markdown_logic(emoji, markdown_text, value, gfm_supported)
         elif 'security concerns' in key_nice.lower():
             if gfm_supported:
                 markdown_text += f"<tr><td>"
@@ -254,6 +261,52 @@ def convert_to_markdown_v2(output_data: dict,
     return markdown_text
 
 
+def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
+    ticket_compliance_str = ""
+    final_compliance_level = -1
+    if isinstance(value, list):
+        for v in value:
+            ticket_url = v.get('ticket_url', '').strip()
+            compliance_level = v.get('overall_compliance_level', '').strip()
+            # add emojis, if 'Fully compliant' ✅, 'Partially compliant' 🔶, or 'Not compliant' ❌
+            if compliance_level.lower() == 'fully compliant':
+                # compliance_level = '✅ Fully compliant'
+                final_compliance_level = 2 if final_compliance_level == -1 else 1
+            elif compliance_level.lower() == 'partially compliant':
+                # compliance_level = '🔶 Partially compliant'
+                final_compliance_level = 1
+            elif compliance_level.lower() == 'not compliant':
+                # compliance_level = '❌ Not compliant'
+                final_compliance_level = 0 if final_compliance_level < 1 else 1
+
+            # explanation = v.get('compliance_analysis', '').strip()
+            explanation = ''
+            fully_compliant_str = v.get('fully_compliant_requirements', '').strip()
+            not_compliant_str = v.get('not_compliant_requirements', '').strip()
+            if fully_compliant_str:
+                explanation += f"Fully compliant requirements:\n{fully_compliant_str}\n\n"
+            if not_compliant_str:
+                explanation += f"Not compliant requirements:\n{not_compliant_str}\n\n"
+
+            ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {compliance_level}**\n\n{explanation}\n\n"
+    if final_compliance_level == 2:
+        compliance_level = '✅'
+    elif final_compliance_level == 1:
+        compliance_level = '🔶'
+    else:
+        compliance_level = '❌'
+
+    if gfm_supported:
+        markdown_text += f"<tr><td>\n\n"
+        markdown_text += f"**{emoji} Ticket compliance analysis {compliance_level}**\n\n"
+        markdown_text += ticket_compliance_str
+        markdown_text += f"</td></tr>\n"
+    else:
+        markdown_text += f"### {emoji} Ticket compliance analysis {compliance_level}\n\n"
+        markdown_text += ticket_compliance_str+"\n\n"
+    return markdown_text
+
+
 def process_can_be_split(emoji, value):
     try:
         # key_nice = "Can this PR be split?"
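For orientation, a minimal usage sketch of the new helper (not part of the patch). The dict keys match the ones read above; the ticket URLs and contents are invented:

from pr_agent.algo.utils import ticket_markdown_logic

value = [
    {'ticket_url': 'https://github.com/qodo-ai/pr-agent/issues/1',  # hypothetical
     'overall_compliance_level': 'Fully compliant',
     'fully_compliant_requirements': '- add a compliance check',
     'not_compliant_requirements': ''},
    {'ticket_url': 'https://github.com/qodo-ai/pr-agent/issues/2',  # hypothetical
     'overall_compliance_level': 'Partially compliant',
     'fully_compliant_requirements': '',
     'not_compliant_requirements': '- tests are missing'},
]

# One fully plus one partially compliant ticket collapses to 🔶 in the heading,
# while each ticket keeps its own level next to its link.
print(ticket_markdown_logic('🎫', '', value, gfm_supported=False))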
@@ -554,7 +607,8 @@ def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", l
         get_logger().warning(f"Initial failure to parse AI prediction: {e}")
         data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key)
         if not data:
-            get_logger().error(f"Failed to parse AI prediction after fallbacks", artifact={'response_text': response_text})
+            get_logger().error(f"Failed to parse AI prediction after fallbacks",
+                               artifact={'response_text': response_text})
         else:
             get_logger().info(f"Successfully parsed AI prediction after fallbacks",
                               artifact={'response_text': response_text})
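For context on the logging change, a sketch of how load_yaml is typically called, assuming pr_agent.algo.utils is importable; the YAML payload and key names are illustrative only:

from pr_agent.algo.utils import load_yaml

# A well-formed response parses directly; a malformed one goes through
# try_fix_yaml and produces the error/info logs adjusted in this hunk.
response_text = "review:\n  estimated_effort_to_review: 2\n  security_concerns: 'No'"
data = load_yaml(response_text, first_key="review", last_key="security_concerns")
print(data["review"]["security_concerns"])  # 'No'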
@@ -841,56 +895,64 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
             break
     return position, absolute_position
 
 
-def validate_and_await_rate_limit(rate_limit_status=None, git_provider=None, get_rate_limit_status_func=None):
-    if git_provider and not rate_limit_status:
-        rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
-
-    if not rate_limit_status:
-        rate_limit_status = get_rate_limit_status_func()
-    # validate that the rate limit is not exceeded
-    is_rate_limit = False
-    for key, value in rate_limit_status['resources'].items():
-        if value['remaining'] == 0:
-            print(f"key: {key}, value: {value}")
-            is_rate_limit = True
-            sleep_time_sec = value['reset'] - datetime.now().timestamp()
-            sleep_time_hour = sleep_time_sec / 3600.0
-            print(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
-            if sleep_time_sec > 0:
-                time.sleep(sleep_time_sec+1)
-
-            if git_provider:
-                rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
-            else:
-                rate_limit_status = get_rate_limit_status_func()
-
-    return is_rate_limit
-
-
-def get_largest_component(pr_url):
-    from pr_agent.tools.pr_analyzer import PRAnalyzer
-    publish_output = get_settings().config.publish_output
-    get_settings().config.publish_output = False  # disable publish output
-    analyzer = PRAnalyzer(pr_url)
-    methods_dict_files = analyzer.run_sync()
-    get_settings().config.publish_output = publish_output
-    max_lines_changed = 0
-    file_b = ""
-    component_name_b = ""
-    for file in methods_dict_files:
-        for method in methods_dict_files[file]:
-            try:
-                if methods_dict_files[file][method]['num_plus_lines'] > max_lines_changed:
-                    max_lines_changed = methods_dict_files[file][method]['num_plus_lines']
-                    file_b = file
-                    component_name_b = method
-            except:
-                pass
-    if component_name_b:
-        get_logger().info(f"Using the largest changed component: '{component_name_b}'")
-        return component_name_b, file_b
-    else:
-        return None, None
+def get_rate_limit_status(github_token) -> dict:
+    GITHUB_API_URL = get_settings(use_context=False).get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/")  # "https://api.github.com"
+    # GITHUB_API_URL = "https://api.github.com"
+    RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit"
+    HEADERS = {
+        "Accept": "application/vnd.github.v3+json",
+        "Authorization": f"token {github_token}"
+    }
+
+    response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
+    try:
+        rate_limit_info = response.json()
+        if rate_limit_info.get('message') == 'Rate limiting is not enabled.':  # for github enterprise
+            return {'resources': {}}
+        response.raise_for_status()  # Check for HTTP errors
+    except:  # retry
+        time.sleep(0.1)
+        response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
+        return response.json()
+    return rate_limit_info
+
+
+def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1) -> bool:
+    try:
+        rate_limit_status = get_rate_limit_status(github_token)
+        if installation_id:
+            get_logger().debug(f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}")
+        # validate that the rate limit is not exceeded
+        for key, value in rate_limit_status['resources'].items():
+            if value['remaining'] < value['limit'] * threshold:
+                get_logger().error(f"key: {key}, value: {value}")
+                return False
+        return True
+    except Exception as e:
+        get_logger().error(f"Error in rate limit {e}",
+                           artifact={"traceback": traceback.format_exc()})
+        return True
+
+
+def validate_and_await_rate_limit(github_token):
+    try:
+        rate_limit_status = get_rate_limit_status(github_token)
+        # validate that the rate limit is not exceeded
+        for key, value in rate_limit_status['resources'].items():
+            if value['remaining'] < value['limit'] // 80:
+                get_logger().error(f"key: {key}, value: {value}")
+                sleep_time_sec = value['reset'] - datetime.now().timestamp()
+                sleep_time_hour = sleep_time_sec / 3600.0
+                get_logger().error(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
+                if sleep_time_sec > 0:
+                    time.sleep(sleep_time_sec + 1)
+        rate_limit_status = get_rate_limit_status(github_token)
+        return rate_limit_status
+    except:
+        get_logger().error("Error in rate limit")
+        return None
 
 
 def github_action_output(output_data: dict, key_name: str):
     try:
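A hypothetical invocation of the rewritten rate-limit helpers; the token is a placeholder, and the calls hit the live GitHub API:

from pr_agent.algo.utils import (get_rate_limit_status,
                                 validate_and_await_rate_limit,
                                 validate_rate_limit_github)

github_token = "<personal-access-token>"  # placeholder

status = get_rate_limit_status(github_token)
print(status['resources'].get('core'))  # e.g. {'limit': 5000, 'remaining': 4997, ...}

# Returns False once any resource drops below 10% of its quota (threshold=0.1);
# on GitHub Enterprise with rate limiting disabled, 'resources' is empty and it passes.
ok = validate_rate_limit_github(github_token, threshold=0.1)

# Sleeps until the reset time instead of failing when a quota is nearly spent.
validate_and_await_rate_limit(github_token)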
@@ -906,7 +968,7 @@ def github_action_output(output_data: dict, key_name: str):
 
 
 def show_relevant_configurations(relevant_section: str) -> str:
-    skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys",
+    skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
                  'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME']
     extra_skip_keys = get_settings().config.get('config.skip_keys', [])
     if extra_skip_keys:
@@ -939,6 +1001,25 @@ def is_value_no(value):
     return False
 
 
+def set_pr_string(repo_name, pr_number):
+    return f"{repo_name}#{pr_number}"
+
+
+def string_to_uniform_number(s: str) -> float:
+    """
+    Convert a string to a uniform number in the range [0, 1].
+    The uniform distribution is achieved by the nature of the SHA-256 hash function, which produces a uniformly distributed hash value over its output space.
+    """
+    # Generate a hash of the string
+    hash_object = hashlib.sha256(s.encode())
+    # Convert the hash to an integer
+    hash_int = int(hash_object.hexdigest(), 16)
+    # Normalize the integer to the range [0, 1]
+    max_hash_int = 2 ** 256 - 1
+    uniform_number = float(hash_int) / max_hash_int
+    return uniform_number
+
+
 def process_description(description_full: str) -> Tuple[str, List]:
     if not description_full:
         return "", []
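A small sketch of why the hashing helper is useful: it maps a PR identifier to a stable number in [0, 1]. The repo name and PR number below are invented:

from pr_agent.algo.utils import set_pr_string, string_to_uniform_number

pr_string = set_pr_string("qodo-ai/pr-agent", 1234)  # -> "qodo-ai/pr-agent#1234"
x = string_to_uniform_number(pr_string)

# SHA-256 is deterministic, so the same PR always lands on the same point
# in [0, 1], which is handy for stable percentage-based rollouts or sampling.
assert x == string_to_uniform_number(pr_string)
assert 0.0 <= x <= 1.0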
@@ -998,7 +1079,10 @@ def process_description(description_full: str) -> Tuple[str, List]:
                     'long_summary': long_summary
                 })
             else:
-                get_logger().error(f"Failed to parse description", artifact={'description': file_data})
+                if '<code>...</code>' in file_data:
+                    pass # PR with many files. some did not get analyzed
+                else:
+                    get_logger().error(f"Failed to parse description", artifact={'description': file_data})
     except Exception as e:
         get_logger().exception(f"Failed to process description: {e}", artifact={'description': file_data})