re-implemented YAML extraction as a fallback

This commit is contained in:
koid
2023-12-21 10:48:33 +09:00
parent 16b61eb4e8
commit e2797ad09a

View File

@ -317,12 +317,6 @@ def _fix_key_value(key: str, value: str):
def load_yaml(response_text: str) -> dict:
# remove everything before the first ```yaml
snipet_pattern = r'```(yaml)?[\s\S]*?```'
snipet = re.search(snipet_pattern, response_text)
if snipet:
response_text = snipet.group()
response_text = response_text.removeprefix('```yaml').rstrip('`')
try:
data = yaml.safe_load(response_text)
@ -349,7 +343,19 @@ def try_fix_yaml(response_text: str) -> dict:
except:
get_logger().info(f"Failed to parse AI prediction after adding |-\n")
# second fallback - try to remove last lines
# second fallback - try to extract only range from first ```yaml to ````
snippet_pattern = r'```(yaml)?[\s\S]*?```'
snippet = re.search(snippet_pattern, '\n'.join(response_text_lines_copy))
if snippet:
snippet_text = snippet.group()
try:
data = yaml.safe_load(snippet_text.removeprefix('```yaml').rstrip('`'))
get_logger().info(f"Successfully parsed AI prediction after extracting yaml snippet")
return data
except:
pass
# thrid fallback - try to remove last lines
data = {}
for i in range(1, len(response_text_lines)):
response_text_lines_tmp = '\n'.join(response_text_lines[:-i])
@ -360,7 +366,7 @@ def try_fix_yaml(response_text: str) -> dict:
except:
pass
# thrid fallback - try to remove leading and trailing curly brackets
# fourth fallback - try to remove leading and trailing curly brackets
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}')
try:
data = yaml.safe_load(response_text_copy,)