feat: Implement label case conversion and update label descriptions in settings files

This commit is contained in:
mrT23
2023-12-18 12:29:06 +02:00
parent 54891ad1d2
commit 1c4e64333c
5 changed files with 37 additions and 6 deletions

View File

@ -379,9 +379,15 @@ def set_custom_labels(variables, git_provider=None):
# Set custom labels # Set custom labels
variables["custom_labels_class"] = "class Label(str, Enum):" variables["custom_labels_class"] = "class Label(str, Enum):"
counter = 0
labels_minimal_to_labels_dict = {}
for k, v in labels.items(): for k, v in labels.items():
description = v['description'].strip('\n').replace('\n', '\\n') description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}" # variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}"
labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
counter += 1
variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict
def get_user_labels(current_labels: List[str] = None): def get_user_labels(current_labels: List[str] = None):
""" """

View File

@ -30,7 +30,7 @@ class Label(str, Enum):
{%- endif %} {%- endif %}
class Labels(BaseModel): class Labels(BaseModel):
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.") labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
====== ======

View File

@ -37,6 +37,7 @@ class FileWalkthrough(BaseModel):
{%- endif %} {%- endif %}
{%- if enable_semantic_files_types %} {%- if enable_semantic_files_types %}
Class FileDescription(BaseModel): Class FileDescription(BaseModel):
filename: str = Field(description="the relevant file full path") filename: str = Field(description="the relevant file full path")
changes_summary: str = Field(description="minimal and concise summary of the changes in the relevant file") changes_summary: str = Field(description="minimal and concise summary of the changes in the relevant file")
@ -48,7 +49,7 @@ Class PRDescription(BaseModel):
type: List[PRType] = Field(description="one or more types that describe the PR type. Return the label value, not the name.") type: List[PRType] = Field(description="one or more types that describe the PR type. Return the label value, not the name.")
description: str = Field(description="an informative and concise description of the PR. {%- if use_bullet_points %} Use bullet points.{% endif %}") description: str = Field(description="an informative and concise description of the PR. {%- if use_bullet_points %} Use bullet points.{% endif %}")
{%- if enable_custom_labels %} {%- if enable_custom_labels %}
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.") labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
{%- endif %} {%- endif %}
{%- if enable_file_walkthrough %} {%- if enable_file_walkthrough %}
main_files_walkthrough: List[FileWalkthrough] = Field(max_items=10) main_files_walkthrough: List[FileWalkthrough] = Field(max_items=10)
@ -69,8 +70,10 @@ type:
- ... - ...
{%- if enable_custom_labels %} {%- if enable_custom_labels %}
labels: labels:
- ... - |
- ... ...
- |
...
{%- endif %} {%- endif %}
description: |- description: |-
... ...

View File

@ -162,6 +162,7 @@ class PRDescription:
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables, self.git_provider) set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
@ -203,6 +204,16 @@ class PRDescription:
pr_types = self.data['type'] pr_types = self.data['type']
elif type(self.data['type']) == str: elif type(self.data['type']) == str:
pr_types = self.data['type'].split(',') pr_types = self.data['type'].split(',')
# convert lowercase labels to original case
try:
if "labels_minimal_to_labels_dict" in self.variables:
d: dict = self.variables["labels_minimal_to_labels_dict"]
for i, label_i in enumerate(pr_types):
if label_i in d:
pr_types[i] = d[label_i]
except Exception as e:
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
return pr_types return pr_types
def _prepare_pr_answer_with_markers(self) -> Tuple[str, str]: def _prepare_pr_answer_with_markers(self) -> Tuple[str, str]:

View File

@ -135,6 +135,7 @@ class PRGenerateLabels:
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables, self.git_provider) set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)
@ -170,4 +171,14 @@ class PRGenerateLabels:
elif type(self.data['labels']) == str: elif type(self.data['labels']) == str:
pr_types = self.data['labels'].split(',') pr_types = self.data['labels'].split(',')
# convert lowercase labels to original case
try:
if "labels_minimal_to_labels_dict" in self.variables:
d: dict = self.variables["labels_minimal_to_labels_dict"]
for i, label_i in enumerate(pr_types):
if label_i in d:
pr_types[i] = d[label_i]
except Exception as e:
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
return pr_types return pr_types