diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 8d008e92..1bf37518 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -316,19 +316,21 @@ def _fix_key_value(key: str, value: str): return key, value -def load_yaml(response_text: str) -> dict: +def load_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict: response_text = response_text.removeprefix('```yaml').rstrip('`') try: data = yaml.safe_load(response_text) except Exception as e: get_logger().error(f"Failed to parse AI prediction: {e}") - data = try_fix_yaml(response_text) + data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml) return data -def try_fix_yaml(response_text: str) -> dict: + +def try_fix_yaml(response_text: str, keys_fix_yaml: List[str]) -> dict: response_text_lines = response_text.split('\n') keys = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:'] + keys = keys + keys_fix_yaml # first fallback - try to convert 'relevant line: ...' to relevant line: |-\n ...' response_text_lines_copy = response_text_lines.copy() for i in range(0, len(response_text_lines_copy)): @@ -355,7 +357,16 @@ def try_fix_yaml(response_text: str) -> dict: except: pass - # thrid fallback - try to remove last lines + # third fallback - try to remove leading and trailing curly brackets + response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}') + try: + data = yaml.safe_load(response_text_copy,) + get_logger().info(f"Successfully parsed AI prediction after removing curly brackets") + return data + except: + pass + + # fourth fallback - try to remove last lines data = {} for i in range(1, len(response_text_lines)): response_text_lines_tmp = '\n'.join(response_text_lines[:-i]) @@ -365,15 +376,6 @@ def try_fix_yaml(response_text: str) -> dict: return data except: pass - - # fourth fallback - try to remove leading and trailing curly brackets - response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}') - try: - data = yaml.safe_load(response_text_copy,) - get_logger().info(f"Successfully parsed AI prediction after removing curly brackets") - return data - except: - pass def set_custom_labels(variables, git_provider=None):