diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 0ec1397b..97b02b56 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -317,6 +317,12 @@ def _fix_key_value(key: str, value: str): def load_yaml(response_text: str) -> dict: + # remove everything before the first ```yaml + snipet_pattern = r'```(yaml)?[\s\S]*?```' + snipet = re.search(snipet_pattern, response_text) + if snipet: + response_text = snipet.group() + response_text = response_text.removeprefix('```yaml').rstrip('`') try: data = yaml.safe_load(response_text) diff --git a/tests/unittest/test_load_yaml.py b/tests/unittest/test_load_yaml.py index a77c847b..34beee35 100644 --- a/tests/unittest/test_load_yaml.py +++ b/tests/unittest/test_load_yaml.py @@ -15,6 +15,18 @@ class TestLoadYaml: expected_output = {'name': 'John Smith', 'age': 35} assert load_yaml(yaml_str) == expected_output + def test_load_valid_yaml_with_description(self): + yaml_str = '''\ +Here is the answer in YAML format: + +```yaml +name: John Smith +age: 35 +``` +''' + expected_output = {'name': 'John Smith', 'age': 35} + assert load_yaml(yaml_str) == expected_output + def test_load_invalid_yaml1(self): yaml_str = \ '''\