from __future__ import annotations import difflib import json import logging import re import textwrap from dataclasses import dataclass from datetime import datetime from enum import Enum from typing import Any, List, Tuple, Optional import yaml from starlette_context import context from pr_agent.algo.token_handler import get_token_encoder from pr_agent.config_loader import get_settings, global_settings class EDIT_TYPE(Enum): ADDED = 1 DELETED = 2 MODIFIED = 3 RENAMED = 4 def get_setting(key: str) -> Any: try: key = key.upper() return context.get("settings", global_settings).get(key, global_settings.get(key, None)) except Exception: return global_settings.get(key, None) def convert_to_markdown(output_data: dict) -> str: """ Convert a dictionary of data into markdown format. Args: output_data (dict): A dictionary containing data to be converted to markdown format. Returns: str: The markdown formatted text generated from the input dictionary. """ markdown_text = "" emojis = { "Main theme": "๐ŸŽฏ", "PR summary": "๐Ÿ“", "Type of PR": "๐Ÿ“Œ", "Score": "๐Ÿ…", "Relevant tests added": "๐Ÿงช", "Unrelated changes": "โš ๏ธ", "Focused PR": "โœจ", "Security concerns": "๐Ÿ”’", "General suggestions": "๐Ÿ’ก", "Insights from user's answers": "๐Ÿ“", "Code feedback": "๐Ÿค–", } for key, value in output_data.items(): if value is None or value == '' or value == {}: continue if isinstance(value, dict): markdown_text += f"## {key}\n\n" markdown_text += convert_to_markdown(value) elif isinstance(value, list): emoji = emojis.get(key, "") if key.lower() == 'code feedback': markdown_text += f"\n\n- **
{ emoji } Code feedback:**\n\n" else: markdown_text += f"- {emoji} **{key}:**\n\n" for item in value: if isinstance(item, dict) and key.lower() == 'code feedback': markdown_text += parse_code_suggestion(item) elif item: markdown_text += f" - {item}\n" if key.lower() == 'code feedback': markdown_text += "
\n\n" elif value != 'n/a': emoji = emojis.get(key, "") markdown_text += f"- {emoji} **{key}:** {value}\n" return markdown_text def parse_code_suggestion(code_suggestions: dict) -> str: """ Convert a dictionary of data into markdown format. Args: code_suggestions (dict): A dictionary containing data to be converted to markdown format. Returns: str: A string containing the markdown formatted text generated from the input dictionary. """ markdown_text = "" for sub_key, sub_value in code_suggestions.items(): if isinstance(sub_value, dict): # "code example" markdown_text += f" - **{sub_key}:**\n" for code_key, code_value in sub_value.items(): # 'before' and 'after' code code_str = f"```\n{code_value}\n```" code_str_indented = textwrap.indent(code_str, ' ') markdown_text += f" - **{code_key}:**\n{code_str_indented}\n" else: if "relevant file" in sub_key.lower(): markdown_text += f"\n - **{sub_key}:** {sub_value}\n" else: markdown_text += f" **{sub_key}:** {sub_value}\n" markdown_text += "\n" return markdown_text def try_fix_json(review, max_iter=10, code_suggestions=False): """ Fix broken or incomplete JSON messages and return the parsed JSON data. Args: - review: A string containing the JSON message to be fixed. - max_iter: An integer representing the maximum number of iterations to try and fix the JSON message. - code_suggestions: A boolean indicating whether to try and fix JSON messages with code feedback. Returns: - data: A dictionary containing the parsed JSON data. The function attempts to fix broken or incomplete JSON messages by parsing until the last valid code suggestion. If the JSON message ends with a closing bracket, the function calls the fix_json_escape_char function to fix the message. If code_suggestions is True and the JSON message contains code feedback, the function tries to fix the JSON message by parsing until the last valid code suggestion. The function uses regular expressions to find the last occurrence of "}," with any number of whitespaces or newlines. It tries to parse the JSON message with the closing bracket and checks if it is valid. If the JSON message is valid, the parsed JSON data is returned. If the JSON message is not valid, the last code suggestion is removed and the process is repeated until a valid JSON message is obtained or the maximum number of iterations is reached. If a valid JSON message is not obtained, an error is logged and an empty dictionary is returned. """ if review.endswith("}"): return fix_json_escape_char(review) data = {} if code_suggestions: closing_bracket = "]}" else: closing_bracket = "]}}" if (review.rfind("'Code feedback': [") > 0 or review.rfind('"Code feedback": [') > 0) or \ (review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0) : last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 valid_json = False iter_count = 0 while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter: try: data = json.loads(review[:last_code_suggestion_ind] + closing_bracket) valid_json = True review = review[:last_code_suggestion_ind].strip() + closing_bracket except json.decoder.JSONDecodeError: review = review[:last_code_suggestion_ind] last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 iter_count += 1 if not valid_json: logging.error("Unable to decode JSON response from AI") data = {} return data def fix_json_escape_char(json_message=None): """ Fix broken or incomplete JSON messages and return the parsed JSON data. Args: json_message (str): A string containing the JSON message to be fixed. Returns: dict: A dictionary containing the parsed JSON data. Raises: None """ try: result = json.loads(json_message) except Exception as e: # Find the offending character index: idx_to_replace = int(str(e).split(' ')[-1].replace(')', '')) # Remove the offending character: json_message = list(json_message) json_message[idx_to_replace] = ' ' new_message = ''.join(json_message) return fix_json_escape_char(json_message=new_message) return result def convert_str_to_datetime(date_str): """ Convert a string representation of a date and time into a datetime object. Args: date_str (str): A string representation of a date and time in the format '%a, %d %b %Y %H:%M:%S %Z' Returns: datetime: A datetime object representing the input date and time. Example: >>> convert_str_to_datetime('Mon, 01 Jan 2022 12:00:00 UTC') datetime.datetime(2022, 1, 1, 12, 0, 0) """ datetime_format = '%a, %d %b %Y %H:%M:%S %Z' return datetime.strptime(date_str, datetime_format) def load_large_diff(filename, new_file_content_str: str, original_file_content_str: str) -> str: """ Generate a patch for a modified file by comparing the original content of the file with the new content provided as input. Args: new_file_content_str: The new content of the file as a string. original_file_content_str: The original content of the file as a string. Returns: The generated or provided patch string. Raises: None. """ patch = "" try: diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True), new_file_content_str.splitlines(keepends=True)) if get_settings().config.verbosity_level >= 2: logging.warning(f"File was modified, but no patch was found. Manually creating patch: {filename}.") patch = ''.join(diff) except Exception: pass return patch def update_settings_from_args(args: List[str]) -> List[str]: """ Update the settings of the Dynaconf object based on the arguments passed to the function. Args: args: A list of arguments passed to the function. Example args: ['--pr_code_suggestions.extra_instructions="be funny', '--pr_code_suggestions.num_code_suggestions=3'] Returns: None Raises: ValueError: If the argument is not in the correct format. """ other_args = [] if args: for arg in args: arg = arg.strip() if arg.startswith('--'): arg = arg.strip('-').strip() vals = arg.split('=', 1) if len(vals) != 2: if len(vals) > 2: # --extended is a valid argument logging.error(f'Invalid argument format: {arg}') other_args.append(arg) continue key, value = _fix_key_value(*vals) get_settings().set(key, value) logging.info(f'Updated setting {key} to: "{value}"') else: other_args.append(arg) return other_args def _fix_key_value(key: str, value: str): key = key.strip().upper() value = value.strip() try: value = yaml.safe_load(value) except Exception as e: logging.error(f"Failed to parse YAML for config override {key}={value}", exc_info=e) return key, value def load_yaml(review_text: str) -> dict: review_text = review_text.removeprefix('```yaml').rstrip('`') try: data = yaml.safe_load(review_text) except Exception as e: logging.error(f"Failed to parse AI prediction: {e}") data = try_fix_yaml(review_text) return data def try_fix_yaml(review_text: str) -> dict: review_text_lines = review_text.split('\n') data = {} for i in range(1, len(review_text_lines)): review_text_lines_tmp = '\n'.join(review_text_lines[:-i]) try: data = yaml.load(review_text_lines_tmp, Loader=yaml.SafeLoader) logging.info(f"Successfully parsed AI prediction after removing {i} lines") break except: pass return data def clip_tokens(text: str, max_tokens: int) -> str: """ Clip the number of tokens in a string to a maximum number of tokens. Args: text (str): The string to clip. max_tokens (int): The maximum number of tokens allowed in the string. Returns: str: The clipped string. """ if not text: return text try: encoder = get_token_encoder() num_input_tokens = len(encoder.encode(text)) if num_input_tokens <= max_tokens: return text num_chars = len(text) chars_per_token = num_chars / num_input_tokens num_output_chars = int(chars_per_token * max_tokens) clipped_text = text[:num_output_chars] return clipped_text except Exception as e: logging.warning(f"Failed to clip tokens: {e}") return text @dataclass class FilePatchInfo: base_file: str head_file: str patch: str filename: str tokens: int = -1 edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED old_filename: str = None language: Optional[str] = None def find_line_number_of_relevant_line_in_file(diff_files: List[FilePatchInfo], relevant_file: str, relevant_line_in_file: str) -> Tuple[int, int]: """ Find the line number and absolute position of a relevant line in a file. Args: diff_files (List[FilePatchInfo]): A list of FilePatchInfo objects representing the patches of files. relevant_file (str): The name of the file where the relevant line is located. relevant_line_in_file (str): The content of the relevant line. Returns: Tuple[int, int]: A tuple containing the line number and absolute position of the relevant line in the file. """ position = -1 absolute_position = -1 re_hunk_header = re.compile( r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") for file in diff_files: if file.filename.strip() == relevant_file: patch = file.patch patch_lines = patch.splitlines() # try to find the line in the patch using difflib, with some margin of error matches_difflib: list[str | Any] = difflib.get_close_matches(relevant_line_in_file, patch_lines, n=3, cutoff=0.93) if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'): relevant_line_in_file = matches_difflib[0] delta = 0 start1, size1, start2, size2 = 0, 0, 0, 0 for i, line in enumerate(patch_lines): if line.startswith('@@'): delta = 0 match = re_hunk_header.match(line) start1, size1, start2, size2 = map(int, match.groups()[:4]) elif not line.startswith('-'): delta += 1 if relevant_line_in_file in line and line[0] != '-': position = i absolute_position = start2 + delta - 1 break if position == -1 and relevant_line_in_file[0] == '+': no_plus_line = relevant_line_in_file[1:].lstrip() for i, line in enumerate(patch_lines): if line.startswith('@@'): delta = 0 match = re_hunk_header.match(line) start1, size1, start2, size2 = map(int, match.groups()[:4]) elif not line.startswith('-'): delta += 1 if no_plus_line in line and line[0] != '-': # The model might add a '+' to the beginning of the relevant_line_in_file even if originally # it's a context line position = i absolute_position = start2 + delta - 1 break return position, absolute_position