feat: Enhance YAML parsing with additional fallbacks and key customization in load_yaml and try_fix_yaml functions

This commit is contained in:
mrT23
2023-12-21 08:21:34 +02:00
parent 37259e550f
commit 553dad0bee

View File

@ -316,19 +316,21 @@ def _fix_key_value(key: str, value: str):
return key, value
def load_yaml(response_text: str) -> dict:
def load_yaml(response_text: str, keys_fix_yaml: List[str] = []) -> dict:
response_text = response_text.removeprefix('```yaml').rstrip('`')
try:
data = yaml.safe_load(response_text)
except Exception as e:
get_logger().error(f"Failed to parse AI prediction: {e}")
data = try_fix_yaml(response_text)
data = try_fix_yaml(response_text, keys_fix_yaml=keys_fix_yaml)
return data
def try_fix_yaml(response_text: str) -> dict:
def try_fix_yaml(response_text: str, keys_fix_yaml: List[str]) -> dict:
response_text_lines = response_text.split('\n')
keys = ['relevant line:', 'suggestion content:', 'relevant file:', 'existing code:', 'improved code:']
keys = keys + keys_fix_yaml
# first fallback - try to convert 'relevant line: ...' to relevant line: |-\n ...'
response_text_lines_copy = response_text_lines.copy()
for i in range(0, len(response_text_lines_copy)):
@ -355,7 +357,16 @@ def try_fix_yaml(response_text: str) -> dict:
except:
pass
# thrid fallback - try to remove last lines
# third fallback - try to remove leading and trailing curly brackets
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}')
try:
data = yaml.safe_load(response_text_copy,)
get_logger().info(f"Successfully parsed AI prediction after removing curly brackets")
return data
except:
pass
# fourth fallback - try to remove last lines
data = {}
for i in range(1, len(response_text_lines)):
response_text_lines_tmp = '\n'.join(response_text_lines[:-i])
@ -365,15 +376,6 @@ def try_fix_yaml(response_text: str) -> dict:
return data
except:
pass
# fourth fallback - try to remove leading and trailing curly brackets
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}')
try:
data = yaml.safe_load(response_text_copy,)
get_logger().info(f"Successfully parsed AI prediction after removing curly brackets")
return data
except:
pass
def set_custom_labels(variables, git_provider=None):