diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 97b02b56..981a0068 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -317,12 +317,6 @@ def _fix_key_value(key: str, value: str): def load_yaml(response_text: str) -> dict: - # remove everything before the first ```yaml - snipet_pattern = r'```(yaml)?[\s\S]*?```' - snipet = re.search(snipet_pattern, response_text) - if snipet: - response_text = snipet.group() - response_text = response_text.removeprefix('```yaml').rstrip('`') try: data = yaml.safe_load(response_text) @@ -349,7 +343,19 @@ def try_fix_yaml(response_text: str) -> dict: except: get_logger().info(f"Failed to parse AI prediction after adding |-\n") - # second fallback - try to remove last lines + # second fallback - try to extract only range from first ```yaml to ```` + snippet_pattern = r'```(yaml)?[\s\S]*?```' + snippet = re.search(snippet_pattern, '\n'.join(response_text_lines_copy)) + if snippet: + snippet_text = snippet.group() + try: + data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`')) + get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet") + return data + except: + pass + + # thrid fallback - try to remove last lines data = {} for i in range(1, len(response_text_lines)): response_text_lines_tmp = '\n'.join(response_text_lines[:-i]) @@ -360,7 +366,7 @@ def try_fix_yaml(response_text: str) -> dict: except: pass - # thrid fallback - try to remove leading and trailing curly brackets + # fourth fallback - try to remove leading and trailing curly brackets response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}') try: data = yaml.safe_load(response_text_copy,)