2023-07-06 00:21:08 +03:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2023-07-18 23:14:47 +03:00
|
|
|
import difflib
|
2023-07-11 22:11:42 +03:00
|
|
|
import json
|
|
|
|
import re
|
2023-07-06 00:21:08 +03:00
|
|
|
import textwrap
|
2023-08-01 14:43:26 +03:00
|
|
|
from datetime import datetime
|
2024-02-01 09:46:04 +02:00
|
|
|
from enum import Enum
|
2024-02-05 09:20:36 +02:00
|
|
|
from typing import Any, List, Tuple
|
2023-07-06 00:21:08 +03:00
|
|
|
|
2023-08-09 08:50:15 +03:00
|
|
|
import yaml
|
2023-08-01 14:43:26 +03:00
|
|
|
from starlette_context import context
|
2023-11-07 14:28:41 +02:00
|
|
|
|
|
|
|
from pr_agent.algo import MAX_TOKENS
|
2023-11-26 08:29:47 +02:00
|
|
|
from pr_agent.algo.token_handler import get_token_encoder
|
2023-08-01 14:43:26 +03:00
|
|
|
from pr_agent.config_loader import get_settings, global_settings
|
2024-02-05 09:20:36 +02:00
|
|
|
from pr_agent.algo.types import FilePatchInfo
|
2023-10-16 14:56:00 +03:00
|
|
|
from pr_agent.log import get_logger
|
2023-08-01 14:43:26 +03:00
|
|
|
|
2024-02-01 09:46:04 +02:00
|
|
|
class ModelType(str, Enum):
    """String-valued enum selecting which model tier a request should use."""
    REGULAR = "regular"
    TURBO = "turbo"
|
2023-08-01 14:43:26 +03:00
|
|
|
|
|
|
|
def get_setting(key: str) -> Any:
    """
    Look up a configuration value by key (upper-cased before lookup).

    The request-scoped settings stored on the starlette context take
    precedence; `global_settings` is the fallback both for a missing key
    and when no request context is available at all.
    """
    try:
        key = key.upper()
        scoped_settings = context.get("settings", global_settings)
        return scoped_settings.get(key, global_settings.get(key, None))
    except Exception:
        # No request context (e.g. CLI invocation) - globals only.
        return global_settings.get(key, None)
|
2023-07-06 00:21:08 +03:00
|
|
|
|
2023-09-12 07:43:15 +03:00
|
|
|
def convert_to_markdown(output_data: dict, gfm_supported: bool=True) -> str:
    """
    Render the AI review output dictionary as markdown text.

    Args:
        output_data (dict): Parsed review data; the 'review' key holds the
            per-topic findings, 'code_feedback' holds code suggestions.
        gfm_supported (bool): When True, emit GitHub-flavored markdown
            (HTML table / collapsible details); otherwise plain markdown.

    Returns:
        str: The rendered markdown ('' when there is no review content).
    """
    emojis = {
        "Possible issues": "🔍",
        "Score": "🏅",
        "Relevant tests": "🧪",
        "Focused PR": "✨",
        "Security concerns": "🔒",
        "Insights from user's answers": "📝",
        "Code feedback": "🤖",
        "Estimated effort to review [1-5]": "⏱️",
    }

    md = "## PR Review\n\n"
    if gfm_supported:
        md += "<table>\n<tr>\n"
        md += """<td> <strong>PR feedback</strong> </td> <td></td></tr>"""

    # Nothing to render without review content.
    if not output_data or not output_data.get('review', {}):
        return ""

    for field, content in output_data['review'].items():
        if content is None or content == '' or content == {} or content == []:
            continue
        field_nice = field.replace('_', ' ').capitalize()
        field_emoji = emojis.get(field_nice, "")
        if gfm_supported:
            md += f"<tr><td> {field_emoji} {field_nice}</td><td>\n\n{content}\n\n</td></tr>\n"
        elif len(content.split()) > 1:
            # Multi-word values go on their own paragraph.
            md += f"{field_emoji} **{field_nice}:**\n\n {content}\n\n"
        else:
            md += f"{field_emoji} **{field_nice}:** {content}\n\n"

    if gfm_supported:
        md += "</table>\n"

    if 'code_feedback' in output_data:
        if gfm_supported:
            md += f"\n\n"
            md += f"<details><summary> <strong>Code feedback:</strong></summary>\n\n"
        else:
            md += f"\n\n** Code feedback:**\n\n"
        md += "<hr>"
        for idx, item in enumerate(output_data['code_feedback']):
            if item is None or item == '' or item == {} or item == []:
                continue
            md += parse_code_suggestion(item, idx, gfm_supported) + "\n\n"
        # Drop a trailing separator left by the loop.
        if md.endswith('<hr>'):
            md = md[:-4]
        if gfm_supported:
            md += f"</details>"

    return md
|
|
|
|
|
|
|
|
|
2024-02-08 17:08:42 +02:00
|
|
|
def parse_code_suggestion(code_suggestion: dict, i: int = 0, gfm_supported: bool = True) -> str:
    """
    Render one code-suggestion dictionary as markdown.

    Args:
        code_suggestion (dict): A single suggestion entry (relevant file /
            line, suggestion text, optional before/after code example).
        i (int): Index of this suggestion in the feedback list (kept for
            interface compatibility; not used in rendering).
        gfm_supported (bool): When True (and a 'relevant_line' key exists)
            emit an HTML table; otherwise a plain bulleted layout.

    Returns:
        str: The markdown for this suggestion.
    """
    markdown_text = ""
    if gfm_supported and 'relevant_line' in code_suggestion:
        markdown_text += '<table>'
        for sub_key, sub_value in code_suggestion.items():
            try:
                key_lower = sub_key.lower()
                if key_lower == 'relevant_file':
                    relevant_file = sub_value.strip('`').strip('"').strip("'")
                    markdown_text += f"<tr><td>relevant file</td><td>{relevant_file}</td></tr>"
                elif key_lower == 'suggestion':
                    markdown_text += (f"<tr><td>{sub_key} </td>"
                                      f"<td>\n\n<strong>\n\n{sub_value.strip()}\n\n</strong>\n</td></tr>")
                elif key_lower == 'relevant_line':
                    markdown_text += f"<tr><td>relevant line</td>"
                    # Value may be a markdown link: "[line](url)".
                    sub_value_list = sub_value.split('](')
                    relevant_line = sub_value_list[0].lstrip('`').lstrip('[')
                    if len(sub_value_list) > 1:
                        link = sub_value_list[1].rstrip(')').strip('`')
                        markdown_text += f"<td><a href='{link}'>{relevant_line}</a></td>"
                    else:
                        markdown_text += f"<td>{relevant_line}</td>"
                    markdown_text += "</tr>"
            except Exception as e:
                get_logger().exception(f"Failed to parse code suggestion: {e}")
        markdown_text += '</table>'
        markdown_text += "<hr>"
    else:
        for sub_key, sub_value in code_suggestion.items():
            if isinstance(sub_value, dict):  # "code example" with 'before'/'after'
                markdown_text += f" - **{sub_key}:**\n"
                for code_key, code_value in sub_value.items():
                    code_str = f"```\n{code_value}\n```"
                    code_str_indented = textwrap.indent(code_str, ' ')
                    markdown_text += f" - **{code_key}:**\n{code_str_indented}\n"
            else:
                if "relevant_file" in sub_key.lower():
                    markdown_text += f"\n - **{sub_key}:** {sub_value} \n"
                else:
                    markdown_text += f" **{sub_key}:** {sub_value} \n"
                if not gfm_supported:
                    if "relevant_line" not in sub_key.lower():  # nicer presentation
                        # markdown_text = markdown_text.rstrip('\n') + "\\\n" # works for gitlab
                        markdown_text = markdown_text.rstrip('\n') + " \n"  # works for gitlab and bitbucker
        markdown_text += "\n"
    return markdown_text
|
|
|
|
|
2023-07-11 22:11:42 +03:00
|
|
|
|
2023-07-17 01:44:40 +03:00
|
|
|
def try_fix_json(review, max_iter=10, code_suggestions=False):
|
2023-07-20 10:51:21 +03:00
|
|
|
"""
|
|
|
|
Fix broken or incomplete JSON messages and return the parsed JSON data.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
- review: A string containing the JSON message to be fixed.
|
|
|
|
- max_iter: An integer representing the maximum number of iterations to try and fix the JSON message.
|
2023-08-05 10:34:09 +03:00
|
|
|
- code_suggestions: A boolean indicating whether to try and fix JSON messages with code feedback.
|
2023-07-20 10:51:21 +03:00
|
|
|
|
|
|
|
Returns:
|
|
|
|
- data: A dictionary containing the parsed JSON data.
|
|
|
|
|
|
|
|
The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion.
|
2023-08-01 14:43:26 +03:00
|
|
|
If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the
|
|
|
|
message.
|
2023-08-05 10:34:09 +03:00
|
|
|
If code_suggestions is True and the JSON message contains code feedback, the function tries to fix the JSON
|
2023-08-01 14:43:26 +03:00
|
|
|
message by parsing until the last valid code suggestion.
|
|
|
|
The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or
|
|
|
|
newlines.
|
2023-07-20 10:51:21 +03:00
|
|
|
It tries to parse the JSON message with the closing bracket and checks if it is valid.
|
|
|
|
If the JSON message is valid, the parsed JSON data is returned.
|
2023-08-01 14:43:26 +03:00
|
|
|
If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON
|
|
|
|
message is obtained or the maximum number of iterations is reached.
|
2023-07-20 10:51:21 +03:00
|
|
|
If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned.
|
|
|
|
"""
|
|
|
|
|
2023-07-17 01:44:40 +03:00
|
|
|
if review.endswith("}"):
|
|
|
|
return fix_json_escape_char(review)
|
2023-07-20 10:51:21 +03:00
|
|
|
|
2023-07-11 22:11:42 +03:00
|
|
|
data = {}
|
2023-07-17 01:44:40 +03:00
|
|
|
if code_suggestions:
|
|
|
|
closing_bracket = "]}"
|
|
|
|
else:
|
|
|
|
closing_bracket = "]}}"
|
2023-07-20 10:51:21 +03:00
|
|
|
|
2023-08-05 10:34:09 +03:00
|
|
|
if (review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0) or \
|
|
|
|
(review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0) :
|
2023-07-11 22:11:42 +03:00
|
|
|
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
|
|
|
|
valid_json = False
|
2023-07-11 22:22:08 +03:00
|
|
|
iter_count = 0
|
2023-07-20 10:51:21 +03:00
|
|
|
|
2023-07-11 22:22:08 +03:00
|
|
|
while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter:
|
2023-07-11 22:11:42 +03:00
|
|
|
try:
|
2023-07-17 01:44:40 +03:00
|
|
|
data = json.loads(review[:last_code_suggestion_ind] + closing_bracket)
|
2023-07-11 22:11:42 +03:00
|
|
|
valid_json = True
|
2023-07-17 01:44:40 +03:00
|
|
|
review = review[:last_code_suggestion_ind].strip() + closing_bracket
|
2023-07-11 22:11:42 +03:00
|
|
|
except json.decoder.JSONDecodeError:
|
|
|
|
review = review[:last_code_suggestion_ind]
|
|
|
|
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
|
2023-07-11 22:22:08 +03:00
|
|
|
iter_count += 1
|
2023-07-20 10:51:21 +03:00
|
|
|
|
2023-07-11 22:11:42 +03:00
|
|
|
if not valid_json:
|
2023-10-16 14:56:00 +03:00
|
|
|
get_logger().error("Unable to decode JSON response from AI")
|
2023-07-11 22:11:42 +03:00
|
|
|
data = {}
|
2023-07-20 10:51:21 +03:00
|
|
|
|
2023-07-11 22:11:42 +03:00
|
|
|
return data
|
2023-07-17 01:44:40 +03:00
|
|
|
|
2023-07-18 11:34:57 +03:00
|
|
|
|
2023-07-17 01:44:40 +03:00
|
|
|
def fix_json_escape_char(json_message=None):
    """
    Fix broken or incomplete JSON messages and return the parsed JSON data.

    Parses the message with json.loads; when the decoder reports an
    offending character (its error text ends with "(char <idx>)"), that
    character is replaced with a space and parsing is retried recursively.

    Args:
        json_message (str): A string containing the JSON message to be fixed.

    Returns:
        dict: A dictionary containing the parsed JSON data.
    """
    try:
        return json.loads(json_message)
    except Exception as e:
        # Extract the offending character index from the decoder message,
        # e.g. "Invalid control character at: line 1 column 10 (char 9)".
        offending_idx = int(str(e).split(' ')[-1].replace(')', ''))
        patched = json_message[:offending_idx] + ' ' + json_message[offending_idx + 1:]
        return fix_json_escape_char(json_message=patched)
|
2023-07-18 23:14:47 +03:00
|
|
|
|
|
|
|
|
|
|
|
def convert_str_to_datetime(date_str):
    """
    Convert a string representation of a date and time into a datetime object.

    Args:
        date_str (str): Date/time in the format '%a, %d %b %Y %H:%M:%S %Z',
            e.g. 'Mon, 01 Jan 2022 12:00:00 UTC'.

    Returns:
        datetime: The parsed (naive) datetime.

    Example:
        >>> convert_str_to_datetime('Mon, 01 Jan 2022 12:00:00 UTC')
        datetime.datetime(2022, 1, 1, 12, 0)
    """
    return datetime.strptime(date_str, '%a, %d %b %Y %H:%M:%S %Z')
|
|
|
|
|
|
|
|
|
2023-08-03 22:14:05 +03:00
|
|
|
def load_large_diff(filename, new_file_content_str: str, original_file_content_str: str) -> str:
    """
    Generate a unified-diff patch for a modified file by comparing the
    original content of the file with the new content provided as input.

    Args:
        filename: Name of the file (used only in the diagnostic log line).
        new_file_content_str: The new content of the file as a string.
        original_file_content_str: The original content of the file as a string.

    Returns:
        The generated patch string ('' when diffing fails).
    """
    patch = ""
    try:
        diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True),
                                    new_file_content_str.splitlines(keepends=True))
        if get_settings().config.verbosity_level >= 2:
            # Bug fix: the message previously did not interpolate `filename`
            # (the parameter was otherwise unused).
            get_logger().warning(f"File was modified, but no patch was found. Manually creating patch: {filename}.")
        patch = ''.join(diff)
    except Exception:
        # Best-effort: callers treat an empty patch as "no diff available".
        pass
    return patch
|
2023-07-30 11:43:44 +03:00
|
|
|
|
|
|
|
|
2023-08-01 14:43:26 +03:00
|
|
|
def update_settings_from_args(args: List[str]) -> List[str]:
    """
    Update the settings of the Dynaconf object from '--key=value' arguments.

    Args:
        args: A list of arguments passed to the function.
        Example args: ['--pr_code_suggestions.extra_instructions="be funny',
                       '--pr_code_suggestions.num_code_suggestions=3']

    Returns:
        The arguments that were NOT consumed as settings overrides: plain
        arguments, and bare '--' flags without '=' such as '--extended'.
        (The previous docstring incorrectly said None.)
    """
    other_args = []
    if args:
        for arg in args:
            arg = arg.strip()
            if not arg.startswith('--'):
                other_args.append(arg)
                continue
            arg = arg.strip('-').strip()
            vals = arg.split('=', 1)
            # With maxsplit=1 there are at most two parts; a single part is a
            # bare flag (e.g. '--extended') and is passed through unchanged.
            # The old `len(vals) > 2` error branch was unreachable dead code.
            if len(vals) != 2:
                other_args.append(arg)
                continue
            key, value = _fix_key_value(*vals)
            get_settings().set(key, value)
            get_logger().info(f'Updated setting {key} to: "{value}"')
    return other_args
|
2023-08-09 08:50:15 +03:00
|
|
|
|
|
|
|
|
2023-08-20 10:03:57 +03:00
|
|
|
def _fix_key_value(key: str, value: str):
    """
    Normalize a settings-override pair: upper-case the key and YAML-parse
    the value (so '3' -> 3, 'true' -> True); keep the raw stripped string
    when the value is not valid YAML.
    """
    key = key.strip().upper()
    value = value.strip()
    try:
        parsed = yaml.safe_load(value)
    except Exception as e:
        get_logger().debug(f"Failed to parse YAML for config override {key}={value}", exc_info=e)
        parsed = value
    return key, parsed
|
|
|
|
|
|
|
|
|
2023-12-21 08:21:34 +02:00
|
|
|
def load_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
    """
    Parse the model's YAML response into a dict.

    A leading '```yaml' fence and trailing backticks are stripped first;
    when parsing fails, the repair heuristics in try_fix_yaml are applied.
    """
    cleaned = response_text.removeprefix('```yaml').rstrip('`')
    try:
        return yaml.safe_load(cleaned)
    except Exception as e:
        get_logger().error(f"Failed to parse AI prediction: {e}")
        return try_fix_yaml(cleaned, keys_fix_yaml=keys_fix_yaml)
|
|
|
|
|
2023-12-21 08:21:34 +02:00
|
|
|
|
2023-12-21 08:24:07 +02:00
|
|
|
def try_fix_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
    """
    Attempt to recover a dict from malformed YAML produced by the model.

    Four fallbacks are tried in order:
      1. rewrite known free-text keys ('relevant line:', ...) as block
         scalars ('key: |-') so unescaped characters cannot break parsing;
      2. extract only the first ```yaml ... ``` fenced snippet;
      3. strip surrounding curly brackets (JSON-ish output);
      4. drop trailing lines one at a time until the remainder parses.

    Args:
        response_text: The raw model output.
        keys_fix_yaml: Extra keys to treat as free-text for fallback 1.

    Returns:
        dict: The parsed data; {} when every fallback fails (previously the
        function fell off the end and returned None despite its annotation).
    """
    response_text_lines = response_text.split('\n')

    keys = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:']
    keys = keys + keys_fix_yaml
    # first fallback - try to convert 'relevant line: ...' to relevant line: |-\n ...'
    response_text_lines_copy = response_text_lines.copy()
    for i in range(0, len(response_text_lines_copy)):
        for key in keys:
            if key in response_text_lines_copy[i] and '|-' not in response_text_lines_copy[i]:
                response_text_lines_copy[i] = response_text_lines_copy[i].replace(f'{key}',
                                                                                  f'{key} |-\n ')
    try:
        data = yaml.safe_load('\n'.join(response_text_lines_copy))
        get_logger().info(f"Successfully parsed AI prediction after adding |-\n")
        return data
    except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
        get_logger().info(f"Failed to parse AI prediction after adding |-\n")

    # second fallback - try to extract only range from first ```yaml to ````
    snippet_pattern = r'```(yaml)?[\s\S]*?```'
    snippet = re.search(snippet_pattern, '\n'.join(response_text_lines_copy))
    if snippet:
        snippet_text = snippet.group()
        try:
            data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`'))
            get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet")
            return data
        except Exception:
            pass

    # third fallback - try to remove leading and trailing curly brackets
    response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}').rstrip(':\n')
    try:
        data = yaml.safe_load(response_text_copy)
        get_logger().info(f"Successfully parsed AI prediction after removing curly brackets")
        return data
    except Exception:
        pass

    # fourth fallback - try to remove last lines
    data = {}
    for i in range(1, len(response_text_lines)):
        response_text_lines_tmp = '\n'.join(response_text_lines[:-i])
        try:
            data = yaml.safe_load(response_text_lines_tmp)
            get_logger().info(f"Successfully parsed AI prediction after removing {i} lines")
            return data
        except Exception:
            pass
    return data
|
2023-10-24 22:28:57 +03:00
|
|
|
|
|
|
|
|
2023-12-14 07:44:13 +08:00
|
|
|
def set_custom_labels(variables, git_provider=None):
    """
    Inject label data into the prompt `variables` dict (mutated in place).

    When custom labels are disabled this is a no-op. When enabled but no
    custom labels are configured, a default label list is rendered into
    variables["custom_labels"]. Otherwise a `Label` enum class source and a
    minimal-name -> label-name mapping are generated.

    Args:
        variables: The prompt-variables dict to populate.
        git_provider: Unused here; kept for interface compatibility.
    """
    if not get_settings().config.enable_custom_labels:
        return

    labels = get_settings().custom_labels
    if not labels:
        # set default labels
        labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Enhancement', 'Documentation', 'Other']
        labels_list = "\n - ".join(labels) if labels else ""
        labels_list = f" - {labels_list}" if labels_list else ""
        variables["custom_labels"] = labels_list
        return

    # Set custom labels: build Enum-class source whose member names are the
    # snake_cased label names and whose values are the label descriptions.
    # (Removed the unused `counter` local and a commented-out variant line.)
    variables["custom_labels_class"] = "class Label(str, Enum):"
    labels_minimal_to_labels_dict = {}
    for k, v in labels.items():
        description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
        variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}"
        labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
    variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict
|
2023-11-06 15:14:08 +02:00
|
|
|
|
2023-11-08 14:46:11 +02:00
|
|
|
def get_user_labels(current_labels: List[str] = None):
    """
    Filter a label list down to labels the user added manually, dropping the
    default review labels and any configured custom labels. On any error the
    input list is returned unchanged.
    """
    default_labels = ['bug fix', 'tests', 'enhancement', 'documentation', 'other']
    try:
        current_labels = [] if current_labels is None else current_labels
        user_labels = []
        for label in current_labels:
            if label.lower() in default_labels:
                continue
            # Short-circuit keeps behavior identical to the nested checks:
            # custom labels are only consulted when the feature is enabled.
            if get_settings().config.enable_custom_labels and label in get_settings().custom_labels:
                continue
            user_labels.append(label)
        if user_labels:
            get_logger().info(f"Keeping user labels: {user_labels}")
    except Exception as e:
        get_logger().exception(f"Failed to get user labels: {e}")
        return current_labels
    return user_labels
|
2023-11-07 14:28:41 +02:00
|
|
|
|
2023-11-07 14:38:37 +02:00
|
|
|
|
2023-11-07 14:28:41 +02:00
|
|
|
def get_max_tokens(model):
    """
    Return the effective max-token budget for `model`: the value from the
    MAX_TOKENS table, optionally capped by config.max_model_tokens.

    Raises:
        Exception: when the model has no entry in MAX_TOKENS.
    """
    settings = get_settings()
    if model not in MAX_TOKENS:
        raise Exception(f"MAX_TOKENS must be set for model {model} in ./pr_agent/algo/__init__.py")
    max_tokens_model = MAX_TOKENS[model]

    # An explicit config cap wins when it is smaller than the model limit.
    if settings.config.max_model_tokens:
        max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
    return max_tokens_model
|
2023-11-26 08:29:47 +02:00
|
|
|
|
|
|
|
|
|
|
|
def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str:
    """
    Clip a string so that its token count does not exceed max_tokens.

    Args:
        text (str): The string to clip.
        max_tokens (int): The maximum number of tokens allowed.
        add_three_dots (bool, optional): Append a "...(truncated)" marker
            when the text is actually clipped.

    Returns:
        str: The clipped string. The input is returned unchanged when it is
        empty, already within budget, or when clipping fails.
    """
    if not text:
        return text

    try:
        encoder = get_token_encoder()
        token_count = len(encoder.encode(text))
        if token_count <= max_tokens:
            return text
        # Estimate average characters-per-token and cut at that boundary.
        chars_per_token = len(text) / token_count
        cutoff = int(chars_per_token * max_tokens)
        clipped = text[:cutoff]
        if add_three_dots:
            clipped += "...(truncated)"
        return clipped
    except Exception as e:
        get_logger().warning(f"Failed to clip tokens: {e}")
        return text
|
|
|
|
|
|
|
|
def replace_code_tags(text):
    """
    Replace odd instances of ` with <code> and even instances with </code>,
    i.e. wrap each backtick-delimited span in HTML code tags.
    """
    segments = text.split('`')
    # Odd-indexed segments are the ones that were between backticks.
    wrapped = [
        f'<code>{segment}</code>' if idx % 2 else segment
        for idx, segment in enumerate(segments)
    ]
    return ''.join(wrapped)
|
|
|
|
|
|
|
|
|
|
|
|
def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo],
                                              relevant_file: str,
                                              relevant_line_in_file: str,
                                              absolute_position: int = None) -> Tuple[int, int]:
    """
    Locate `relevant_line_in_file` inside the patch of `relevant_file`.

    Returns a tuple (position, absolute_position):
      - position: 0-based index of the matching line within the patch text;
      - absolute_position: line number of that line in the NEW file
        (derived from the '+' side of the hunk header plus an offset);
    both are -1 when no match is found.

    When `absolute_position` is given, the patch is scanned to translate it
    into a patch-relative position instead of searching by line content.
    """
    position = -1
    if absolute_position is None:
        absolute_position = -1
    # Unified-diff hunk header: "@@ -start1,size1 +start2,size2 @@ ..."
    re_hunk_header = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

    for file in diff_files:
        if file.filename and (file.filename.strip() == relevant_file):
            patch = file.patch
            patch_lines = patch.splitlines()
            # delta counts lines present in the NEW file since the last hunk header.
            delta = 0
            start1, size1, start2, size2 = 0, 0, 0, 0
            if absolute_position != -1: # matching absolute to relative
                for i, line in enumerate(patch_lines):
                    # new hunk
                    if line.startswith('@@'):
                        delta = 0
                        match = re_hunk_header.match(line)
                        start1, size1, start2, size2 = map(int, match.groups()[:4])
                    elif not line.startswith('-'):
                        # context ('+' or ' ') lines advance the new-file counter.
                        delta += 1

                    # new-file line number corresponding to the current patch line
                    absolute_position_curr = start2 + delta - 1

                    if absolute_position_curr == absolute_position:
                        position = i
                        break
            else:
                # try to find the line in the patch using difflib, with some margin of error
                matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file,
                                                                             patch_lines, n=3, cutoff=0.93)
                if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
                    relevant_line_in_file = matches_difflib[0]

                # First pass: substring match on non-removed lines.
                for i, line in enumerate(patch_lines):
                    if line.startswith('@@'):
                        delta = 0
                        match = re_hunk_header.match(line)
                        start1, size1, start2, size2 = map(int, match.groups()[:4])
                    elif not line.startswith('-'):
                        delta += 1

                    if relevant_line_in_file in line and line[0] != '-':
                        position = i
                        absolute_position = start2 + delta - 1
                        break

                # Second pass: retry without the leading '+' marker.
                if position == -1 and relevant_line_in_file[0] == '+':
                    no_plus_line = relevant_line_in_file[1:].lstrip()
                    for i, line in enumerate(patch_lines):
                        if line.startswith('@@'):
                            delta = 0
                            match = re_hunk_header.match(line)
                            start1, size1, start2, size2 = map(int, match.groups()[:4])
                        elif not line.startswith('-'):
                            delta += 1

                        if no_plus_line in line and line[0] != '-':
                            # The model might add a '+' to the beginning of the relevant_line_in_file even if originally
                            # it's a context line
                            position = i
                            absolute_position = start2 + delta - 1
                            break
    return position, absolute_position
|