feat: add ticket compliance check and rate limit validation

- Implement ticket compliance check logic in `utils.py` and `ticket_pr_compliance_check.py`
- Add functions to extract and cache PR tickets, and check ticket relevancy
- Introduce rate limit validation for GitHub API requests
- Update `pr_reviewer_prompts.toml` and `pr_description_prompts.toml` to include ticket compliance fields
- Modify configuration to require ticket analysis review
This commit is contained in:
mrT23
2024-10-10 08:48:37 +03:00
parent 014ea884d2
commit 4eef3e9190
5 changed files with 324 additions and 58 deletions

View File

@ -1,18 +1,22 @@
from __future__ import annotations from __future__ import annotations
import html2text
import html
import copy import copy
import difflib import difflib
import hashlib
import html
import json import json
import os import os
import re import re
import textwrap import textwrap
import time import time
import traceback
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, List, Tuple from typing import Any, List, Tuple
import html2text
import requests
import yaml import yaml
from pydantic import BaseModel from pydantic import BaseModel
from starlette_context import context from starlette_context import context
@ -110,6 +114,7 @@ def convert_to_markdown_v2(output_data: dict,
"Insights from user's answers": "📝", "Insights from user's answers": "📝",
"Code feedback": "🤖", "Code feedback": "🤖",
"Estimated effort to review [1-5]": "⏱️", "Estimated effort to review [1-5]": "⏱️",
"Ticket compliance check": "🎫",
} }
markdown_text = "" markdown_text = ""
if not incremental_review: if not incremental_review:
@ -165,6 +170,8 @@ def convert_to_markdown_v2(output_data: dict,
markdown_text += f'### {emoji} No relevant tests\n\n' markdown_text += f'### {emoji} No relevant tests\n\n'
else: else:
markdown_text += f"### PR contains tests\n\n" markdown_text += f"### PR contains tests\n\n"
elif 'ticket compliance check' in key_nice.lower():
markdown_text = ticket_markdown_logic(emoji, markdown_text, value, gfm_supported)
elif 'security concerns' in key_nice.lower(): elif 'security concerns' in key_nice.lower():
if gfm_supported: if gfm_supported:
markdown_text += f"<tr><td>" markdown_text += f"<tr><td>"
@ -254,6 +261,52 @@ def convert_to_markdown_v2(output_data: dict,
return markdown_text return markdown_text
def ticket_markdown_logic(emoji, markdown_text, value, gfm_supported) -> str:
ticket_compliance_str = ""
final_compliance_level = -1
if isinstance(value, list):
for v in value:
ticket_url = v.get('ticket_url', '').strip()
compliance_level = v.get('overall_compliance_level', '').strip()
# add emojis, if 'Fully compliant' ✅, 'Partially compliant' 🔶, or 'Not compliant' ❌
if compliance_level.lower() == 'fully compliant':
# compliance_level = '✅ Fully compliant'
final_compliance_level = 2 if final_compliance_level == -1 else 1
elif compliance_level.lower() == 'partially compliant':
# compliance_level = '🔶 Partially compliant'
final_compliance_level = 1
elif compliance_level.lower() == 'not compliant':
# compliance_level = '❌ Not compliant'
final_compliance_level = 0 if final_compliance_level < 1 else 1
# explanation = v.get('compliance_analysis', '').strip()
explanation = ''
fully_compliant_str = v.get('fully_compliant_requirements', '').strip()
not_compliant_str = v.get('not_compliant_requirements', '').strip()
if fully_compliant_str:
explanation += f"Fully compliant requirements:\n{fully_compliant_str}\n\n"
if not_compliant_str:
explanation += f"Not compliant requirements:\n{not_compliant_str}\n\n"
ticket_compliance_str += f"\n\n**[{ticket_url.split('/')[-1]}]({ticket_url}) - {compliance_level}**\n\n{explanation}\n\n"
if final_compliance_level == 2:
compliance_level = ''
elif final_compliance_level == 1:
compliance_level = '🔶'
else:
compliance_level = ''
if gfm_supported:
markdown_text += f"<tr><td>\n\n"
markdown_text += f"**{emoji} Ticket compliance analysis {compliance_level}**\n\n"
markdown_text += ticket_compliance_str
markdown_text += f"</td></tr>\n"
else:
markdown_text += f"### {emoji} Ticket compliance analysis {compliance_level}\n\n"
markdown_text += ticket_compliance_str+"\n\n"
return markdown_text
def process_can_be_split(emoji, value): def process_can_be_split(emoji, value):
try: try:
# key_nice = "Can this PR be split?" # key_nice = "Can this PR be split?"
@ -554,7 +607,8 @@ def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", l
get_logger().warning(f"Initial failure to parse AI prediction: {e}") get_logger().warning(f"Initial failure to parse AI prediction: {e}")
data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key) data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml, first_key=first_key, last_key=last_key)
if not data: if not data:
get_logger().error(f"Failed to parse AI prediction after fallbacks", artifact={'response_text': response_text}) get_logger().error(f"Failed to parse AI prediction after fallbacks",
artifact={'response_text': response_text})
else: else:
get_logger().info(f"Successfully parsed AI prediction after fallbacks", get_logger().info(f"Successfully parsed AI prediction after fallbacks",
artifact={'response_text': response_text}) artifact={'response_text': response_text})
@ -841,56 +895,64 @@ def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
break break
return position, absolute_position return position, absolute_position
def validate_and_await_rate_limit(rate_limit_status=None, git_provider=None, get_rate_limit_status_func=None): def get_rate_limit_status(github_token) -> dict:
if git_provider and not rate_limit_status: GITHUB_API_URL = get_settings(use_context=False).get("GITHUB.BASE_URL", "https://api.github.com").rstrip("/") # "https://api.github.com"
rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data} # GITHUB_API_URL = "https://api.github.com"
RATE_LIMIT_URL = f"{GITHUB_API_URL}/rate_limit"
HEADERS = {
"Accept": "application/vnd.github.v3+json",
"Authorization": f"token {github_token}"
}
if not rate_limit_status: response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
rate_limit_status = get_rate_limit_status_func() try:
rate_limit_info = response.json()
if rate_limit_info.get('message') == 'Rate limiting is not enabled.': # for github enterprise
return {'resources': {}}
response.raise_for_status() # Check for HTTP errors
except: # retry
time.sleep(0.1)
response = requests.get(RATE_LIMIT_URL, headers=HEADERS)
return response.json()
return rate_limit_info
def validate_rate_limit_github(github_token, installation_id=None, threshold=0.1) -> bool:
try:
rate_limit_status = get_rate_limit_status(github_token)
if installation_id:
get_logger().debug(f"installation_id: {installation_id}, Rate limit status: {rate_limit_status['rate']}")
# validate that the rate limit is not exceeded
# validate that the rate limit is not exceeded # validate that the rate limit is not exceeded
is_rate_limit = False
for key, value in rate_limit_status['resources'].items(): for key, value in rate_limit_status['resources'].items():
if value['remaining'] == 0: if value['remaining'] < value['limit'] * threshold:
print(f"key: {key}, value: {value}") get_logger().error(f"key: {key}, value: {value}")
is_rate_limit = True return False
return True
except Exception as e:
get_logger().error(f"Error in rate limit {e}",
artifact={"traceback": traceback.format_exc()})
return True
def validate_and_await_rate_limit(github_token):
try:
rate_limit_status = get_rate_limit_status(github_token)
# validate that the rate limit is not exceeded
for key, value in rate_limit_status['resources'].items():
if value['remaining'] < value['limit'] // 80:
get_logger().error(f"key: {key}, value: {value}")
sleep_time_sec = value['reset'] - datetime.now().timestamp() sleep_time_sec = value['reset'] - datetime.now().timestamp()
sleep_time_hour = sleep_time_sec / 3600.0 sleep_time_hour = sleep_time_sec / 3600.0
print(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours") get_logger().error(f"Rate limit exceeded. Sleeping for {sleep_time_hour} hours")
if sleep_time_sec > 0: if sleep_time_sec > 0:
time.sleep(sleep_time_sec + 1) time.sleep(sleep_time_sec + 1)
rate_limit_status = get_rate_limit_status(github_token)
if git_provider: return rate_limit_status
rate_limit_status = {'resources': git_provider.github_client.get_rate_limit().raw_data}
else:
rate_limit_status = get_rate_limit_status_func()
return is_rate_limit
def get_largest_component(pr_url):
from pr_agent.tools.pr_analyzer import PRAnalyzer
publish_output = get_settings().config.publish_output
get_settings().config.publish_output = False # disable publish output
analyzer = PRAnalyzer(pr_url)
methods_dict_files = analyzer.run_sync()
get_settings().config.publish_output = publish_output
max_lines_changed = 0
file_b = ""
component_name_b = ""
for file in methods_dict_files:
for method in methods_dict_files[file]:
try:
if methods_dict_files[file][method]['num_plus_lines'] > max_lines_changed:
max_lines_changed = methods_dict_files[file][method]['num_plus_lines']
file_b = file
component_name_b = method
except: except:
pass get_logger().error("Error in rate limit")
if component_name_b: return None
get_logger().info(f"Using the largest changed component: '{component_name_b}'")
return component_name_b, file_b
else:
return None, None
def github_action_output(output_data: dict, key_name: str): def github_action_output(output_data: dict, key_name: str):
try: try:
@ -906,7 +968,7 @@ def github_action_output(output_data: dict, key_name: str):
def show_relevant_configurations(relevant_section: str) -> str: def show_relevant_configurations(relevant_section: str) -> str:
skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", skip_keys = ['ai_disclaimer', 'ai_disclaimer_title', 'ANALYTICS_FOLDER', 'secret_provider', "skip_keys", "app_id", "redirect",
'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME'] 'trial_prefix_message', 'no_eligible_message', 'identity_provider', 'ALLOWED_REPOS','APP_NAME']
extra_skip_keys = get_settings().config.get('config.skip_keys', []) extra_skip_keys = get_settings().config.get('config.skip_keys', [])
if extra_skip_keys: if extra_skip_keys:
@ -939,6 +1001,25 @@ def is_value_no(value):
return False return False
def set_pr_string(repo_name, pr_number):
return f"{repo_name}#{pr_number}"
def string_to_uniform_number(s: str) -> float:
"""
Convert a string to a uniform number in the range [0, 1].
The uniform distribution is achieved by the nature of the SHA-256 hash function, which produces a uniformly distributed hash value over its output space.
"""
# Generate a hash of the string
hash_object = hashlib.sha256(s.encode())
# Convert the hash to an integer
hash_int = int(hash_object.hexdigest(), 16)
# Normalize the integer to the range [0, 1]
max_hash_int = 2 ** 256 - 1
uniform_number = float(hash_int) / max_hash_int
return uniform_number
def process_description(description_full: str) -> Tuple[str, List]: def process_description(description_full: str) -> Tuple[str, List]:
if not description_full: if not description_full:
return "", [] return "", []
@ -997,6 +1078,9 @@ def process_description(description_full: str) -> Tuple[str, List]:
'short_summary': short_summary, 'short_summary': short_summary,
'long_summary': long_summary 'long_summary': long_summary
}) })
else:
if '<code>...</code>' in file_data:
pass # PR with many files. some did not get analyzed
else: else:
get_logger().error(f"Failed to parse description", artifact={'description': file_data}) get_logger().error(f"Failed to parse description", artifact={'description': file_data})
except Exception as e: except Exception as e:

View File

@ -51,9 +51,7 @@ require_tests_review=true
require_estimate_effort_to_review=true require_estimate_effort_to_review=true
require_can_be_split_review=false require_can_be_split_review=false
require_security_review=true require_security_review=true
# soc2 require_ticket_analysis_review=true
require_soc2_ticket=false
soc2_ticket_prompt="Does the PR description include a link to ticket in a project management system (e.g., Jira, Asana, Trello, etc.) ?"
# general options # general options
num_code_suggestions=0 num_code_suggestions=0
inline_code_comments = false inline_code_comments = false

View File

@ -78,9 +78,9 @@ pr_files:
... ...
... ...
{%- endif %} {%- endif %}
description: |- description: |
... ...
title: |- title: |
... ...
{%- if enable_custom_labels %} {%- if enable_custom_labels %}
labels: labels:
@ -94,7 +94,26 @@ labels:
Answer should be a valid YAML, and nothing else. Each YAML output MUST be after a newline, with proper indent, and block scalar indicator ('|') Answer should be a valid YAML, and nothing else. Each YAML output MUST be after a newline, with proper indent, and block scalar indicator ('|')
""" """
user="""PR Info: user="""
{%- if related_tickets %}
Related Ticket Info:
{% for ticket in related_tickets %}
=====
Ticket Title: '{{ ticket.title }}'
{%- if ticket.labels %}
Ticket Labels: {{ ticket.labels }}
{%- endif %}
{%- if ticket.body %}
Ticket Description:
#####
{{ ticket.body }}
#####
{%- endif %}
=====
{% endfor %}
{%- endif %}
PR Info:
Previous title: '{{title}}' Previous title: '{{title}}'

View File

@ -85,7 +85,20 @@ class KeyIssuesComponentLink(BaseModel):
start_line: int = Field(description="The start line that corresponds to this issue in the relevant file") start_line: int = Field(description="The start line that corresponds to this issue in the relevant file")
end_line: int = Field(description="The end line that corresponds to this issue in the relevant file") end_line: int = Field(description="The end line that corresponds to this issue in the relevant file")
{%- if related_tickets %}
class TicketCompliance(BaseModel):
ticket_url: str = Field(description="Ticket URL or ID")
ticket_requirements: str = Field(description="Repeat, in your own words, all ticket requirements, in bullet points")
fully_compliant_requirements: str = Field(description="A list, in bullet points, of which requirements are met by the PR code. Don't explain how the requirements are met, just list them shortly. Can be empty")
not_compliant_requirements: str = Field(description="A list, in bullet points, of which requirements are not met by the PR code. Don't explain how the requirements are not met, just list them shortly. Can be empty")
overall_compliance_level: str = Field(description="Overall give this PR one of these three values in relation to the ticket: 'Fully compliant', 'Partially compliant', or 'Not compliant'")
{%- endif %}
class Review(BaseModel): class Review(BaseModel):
{%- if related_tickets %}
ticket_compliance_check: List[TicketCompliance] = Field(description="A list of compliance checks for the related tickets")
{%- endif %}
{%- if require_estimate_effort_to_review %} {%- if require_estimate_effort_to_review %}
estimated_effort_to_review_[1-5]: int = Field(description="Estimate, on a scale of 1-5 (inclusive), the time and effort required to review this PR by an experienced and knowledgeable developer. 1 means short and easy review , 5 means long and hard review. Take into account the size, complexity, quality, and the needed changes of the PR code diff.") estimated_effort_to_review_[1-5]: int = Field(description="Estimate, on a scale of 1-5 (inclusive), the time and effort required to review this PR by an experienced and knowledgeable developer. 1 means short and easy review , 5 means long and hard review. Take into account the size, complexity, quality, and the needed changes of the PR code diff.")
{%- endif %} {%- endif %}
@ -130,6 +143,19 @@ class PRReview(BaseModel):
Example output: Example output:
```yaml ```yaml
review: review:
{%- if related_tickets %}
ticket_compliance_check:
- ticket_url: |
...
ticket_requirements: |
...
fully_compliant_requirements: |
...
not_compliant_requirements: |
...
overall_compliance_level: |
...
{%- endif %}
{%- if require_estimate_effort_to_review %} {%- if require_estimate_effort_to_review %}
estimated_effort_to_review_[1-5]: | estimated_effort_to_review_[1-5]: |
3 3
@ -176,7 +202,33 @@ code_feedback:
Answer should be a valid YAML, and nothing else. Each YAML output MUST be after a newline, with proper indent, and block scalar indicator ('|') Answer should be a valid YAML, and nothing else. Each YAML output MUST be after a newline, with proper indent, and block scalar indicator ('|')
""" """
user="""--PR Info-- user="""
{%- if related_tickets %}
--PR Ticket Info--
{%- for ticket in related_tickets %}
=====
Ticket URL: '{{ ticket.ticket_url }}'
Ticket Title: '{{ ticket.title }}'
{%- if ticket.labels %}
Ticket Labels: {{ ticket.labels }}
{%- endif %}
{%- if ticket.body %}
Ticket Description:
#####
{{ ticket.body }}
#####
{%- endif %}
=====
{% endfor %}
{%- endif %}
--PR Info--
Title: '{{title}}' Title: '{{title}}'

View File

@ -0,0 +1,113 @@
import re
import traceback
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import GithubProvider
from pr_agent.log import get_logger
def find_jira_tickets(text):
# Regular expression patterns for JIRA tickets
patterns = [
r'\b[A-Z]{2,10}-\d{1,7}\b', # Standard JIRA ticket format (e.g., PROJ-123)
r'(?:https?://[^\s/]+/browse/)?([A-Z]{2,10}-\d{1,7})\b' # JIRA URL or just the ticket
]
tickets = set()
for pattern in patterns:
matches = re.findall(pattern, text)
for match in matches:
if isinstance(match, tuple):
# If it's a tuple (from the URL pattern), take the last non-empty group
ticket = next((m for m in reversed(match) if m), None)
else:
ticket = match
if ticket:
tickets.add(ticket)
return list(tickets)
def extract_ticket_links_from_pr_description(pr_description, repo_path):
"""
Extract all ticket links from PR description
"""
# example link to search for: https://github.com/Codium-ai/pr-agent-pro/issues/525
pattern = r'https://github[^/]+/[^/]+/[^/]+/issues/\d+' # should support also github server (for example 'https://github.company.ai/Codium-ai/pr-agent-pro/issues/525')
# Find all matches in the text
github_tickets = re.findall(pattern, pr_description)
# Find all issues referenced like #123 and add them as https://github.com/{repo_path}/issues/{issue_number}
# (unneeded, since when you pull the actual comment, it appears as a full link)
# issue_number_pattern = r'#\d+'
# issue_numbers = re.findall(issue_number_pattern, pr_description)
# for issue_number in issue_numbers:
# issue_number = issue_number[1:] # remove #
# # check if issue_number is a valid number and len(issue_number) < 5
# if issue_number.isdigit() and len(issue_number) < 5:
# github_tickets.append(f'https://github.com/{repo_path}/issues/{issue_number}')
return github_tickets
async def extract_tickets(git_provider):
MAX_TICKET_CHARACTERS = 10000
try:
if isinstance(git_provider, GithubProvider):
user_description = git_provider.get_user_description()
tickets = extract_ticket_links_from_pr_description(user_description, git_provider.repo)
tickets_content = []
if tickets:
for ticket in tickets:
# extract ticket number and repo name
repo_name, original_issue_number = git_provider._parse_issue_url(ticket)
# get the ticket object
issue_main = git_provider.repo_obj.get_issue(original_issue_number)
# clip issue_main.body max length
issue_body = issue_main.body
if len(issue_main.body) > MAX_TICKET_CHARACTERS:
issue_body = issue_main.body[:MAX_TICKET_CHARACTERS] + "..."
# extract labels
labels = []
try:
for label in issue_main.labels:
if isinstance(label, str):
labels.append(label)
else:
labels.append(label.name)
except Exception as e:
get_logger().error(f"Error extracting labels error= {e}",
artifact={"traceback": traceback.format_exc()})
tickets_content.append(
{'ticket_id': issue_main.number,
'ticket_url': ticket, 'title': issue_main.title, 'body': issue_body,
'labels': ", ".join(labels)})
return tickets_content
except Exception as e:
get_logger().error(f"Error extracting tickets error= {e}",
artifact={"traceback": traceback.format_exc()})
async def extract_and_cache_pr_tickets(git_provider, vars):
if get_settings().get('config.require_ticket_analysis_review', False):
return
related_tickets = get_settings().get('related_tickets', [])
if not related_tickets:
tickets_content = await extract_tickets(git_provider)
if tickets_content:
get_logger().info("Extracted tickets from PR description", artifact={"tickets": tickets_content})
vars['related_tickets'] = tickets_content
get_settings().set('related_tickets', tickets_content)
else: # if tickets are already cached
get_logger().info("Using cached tickets", artifact={"tickets": related_tickets})
vars['related_tickets'] = related_tickets
def check_tickets_relevancy():
return True