From db5138dc428575c0c2245f2f04b58921b2ecc825 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Sat, 17 May 2025 20:38:05 +0300 Subject: [PATCH] Improve YAML parsing with additional fallback strategies for AI predictions --- pr_agent/algo/utils.py | 69 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 4a341386..780c7953 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -731,8 +731,9 @@ def try_fix_yaml(response_text: str, response_text_original="") -> dict: response_text_lines = response_text.split('\n') - keys_yaml = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:'] + keys_yaml = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:', 'label:'] keys_yaml = keys_yaml + keys_fix_yaml + # first fallback - try to convert 'relevant line: ...' to relevant line: |-\n ...' response_text_lines_copy = response_text_lines.copy() for i in range(0, len(response_text_lines_copy)): @@ -747,8 +748,29 @@ def try_fix_yaml(response_text: str, except: pass - # second fallback - try to extract only range from first ```yaml to ```` - snippet_pattern = r'```(yaml)?[\s\S]*?```' + # 1.5 fallback - try to convert '|' to '|2'. Will solve cases of indent decreasing during the code + response_text_copy = copy.deepcopy(response_text) + response_text_copy = response_text_copy.replace('|\n', '|2\n') + try: + data = yaml.safe_load(response_text_copy) + get_logger().info(f"Successfully parsed AI prediction after replacing | with |2") + return data + except: + # if it fails, we can try to add spaces to the lines that are not indented properly, and contain '}'. + response_text_lines_copy = response_text_copy.split('\n') + for i in range(0, len(response_text_lines_copy)): + initial_space = len(response_text_lines_copy[i]) - len(response_text_lines_copy[i].lstrip()) + if initial_space == 2 and '|2' not in response_text_lines_copy[i] and '}' in response_text_lines_copy[i]: + response_text_lines_copy[i] = ' ' + response_text_lines_copy[i].lstrip() + try: + data = yaml.safe_load('\n'.join(response_text_lines_copy)) + get_logger().info(f"Successfully parsed AI prediction after replacing | with |2 and adding spaces") + return data + except: + pass + + # second fallback - try to extract only range from first ```yaml to the last ``` + snippet_pattern = r'```yaml([\s\S]*?)```(?=\s*$|")' snippet = re.search(snippet_pattern, '\n'.join(response_text_lines_copy)) if not snippet: snippet = re.search(snippet_pattern, response_text_original) # before we removed the "```" @@ -803,16 +825,47 @@ def try_fix_yaml(response_text: str, except: pass - # sixth fallback - try to remove last lines - for i in range(1, len(response_text_lines)): - response_text_lines_tmp = '\n'.join(response_text_lines[:-i]) + # sixth fallback - replace tabs with spaces + if '\t' in response_text: + response_text_copy = copy.deepcopy(response_text) + response_text_copy = response_text_copy.replace('\t', ' ') try: - data = yaml.safe_load(response_text_lines_tmp) - get_logger().info(f"Successfully parsed AI prediction after removing {i} lines") + data = yaml.safe_load(response_text_copy) + get_logger().info(f"Successfully parsed AI prediction after replacing tabs with spaces") return data except: pass + # seventh fallback - add indent for sections of code blocks + response_text_copy = copy.deepcopy(response_text) + response_text_copy_lines = response_text_copy.split('\n') + start_line = -1 + for i, line in enumerate(response_text_copy_lines): + if 'existing_code:' in line or 'improved_code:' in line: + start_line = i + elif line.endswith(': |') or line.endswith(': |-') or line.endswith(': |2') or line.endswith(':'): + start_line = -1 + elif start_line != -1: + response_text_copy_lines[i] = ' ' + line + response_text_copy = '\n'.join(response_text_copy_lines) + try: + data = yaml.safe_load(response_text_copy) + get_logger().info(f"Successfully parsed AI prediction after adding indent for sections of code blocks") + return data + except: + pass + + # # sixth fallback - try to remove last lines + # for i in range(1, len(response_text_lines)): + # response_text_lines_tmp = '\n'.join(response_text_lines[:-i]) + # try: + # data = yaml.safe_load(response_text_lines_tmp) + # get_logger().info(f"Successfully parsed AI prediction after removing {i} lines") + # return data + # except: + # pass + + def set_custom_labels(variables, git_provider=None): if not get_settings().config.enable_custom_labels: