mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
feat: Implement label case conversion and update label descriptions in settings files
This commit is contained in:
@ -379,9 +379,15 @@ def set_custom_labels(variables, git_provider=None):
|
|||||||
|
|
||||||
# Set custom labels
|
# Set custom labels
|
||||||
variables["custom_labels_class"] = "class Label(str, Enum):"
|
variables["custom_labels_class"] = "class Label(str, Enum):"
|
||||||
|
counter = 0
|
||||||
|
labels_minimal_to_labels_dict = {}
|
||||||
for k, v in labels.items():
|
for k, v in labels.items():
|
||||||
description = v['description'].strip('\n').replace('\n', '\\n')
|
description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
|
||||||
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
|
# variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
|
||||||
|
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}"
|
||||||
|
labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
|
||||||
|
counter += 1
|
||||||
|
variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict
|
||||||
|
|
||||||
def get_user_labels(current_labels: List[str] = None):
|
def get_user_labels(current_labels: List[str] = None):
|
||||||
"""
|
"""
|
||||||
|
@ -30,7 +30,7 @@ class Label(str, Enum):
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
|
|
||||||
class Labels(BaseModel):
|
class Labels(BaseModel):
|
||||||
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.")
|
labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
|
||||||
======
|
======
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,6 +37,7 @@ class FileWalkthrough(BaseModel):
|
|||||||
{%- endif %}
|
{%- endif %}
|
||||||
|
|
||||||
{%- if enable_semantic_files_types %}
|
{%- if enable_semantic_files_types %}
|
||||||
|
|
||||||
Class FileDescription(BaseModel):
|
Class FileDescription(BaseModel):
|
||||||
filename: str = Field(description="the relevant file full path")
|
filename: str = Field(description="the relevant file full path")
|
||||||
changes_summary: str = Field(description="minimal and concise summary of the changes in the relevant file")
|
changes_summary: str = Field(description="minimal and concise summary of the changes in the relevant file")
|
||||||
@ -48,7 +49,7 @@ Class PRDescription(BaseModel):
|
|||||||
type: List[PRType] = Field(description="one or more types that describe the PR type. Return the label value, not the name.")
|
type: List[PRType] = Field(description="one or more types that describe the PR type. Return the label value, not the name.")
|
||||||
description: str = Field(description="an informative and concise description of the PR. {%- if use_bullet_points %} Use bullet points.{% endif %}")
|
description: str = Field(description="an informative and concise description of the PR. {%- if use_bullet_points %} Use bullet points.{% endif %}")
|
||||||
{%- if enable_custom_labels %}
|
{%- if enable_custom_labels %}
|
||||||
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.")
|
labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
{%- if enable_file_walkthrough %}
|
{%- if enable_file_walkthrough %}
|
||||||
main_files_walkthrough: List[FileWalkthrough] = Field(max_items=10)
|
main_files_walkthrough: List[FileWalkthrough] = Field(max_items=10)
|
||||||
@ -69,8 +70,10 @@ type:
|
|||||||
- ...
|
- ...
|
||||||
{%- if enable_custom_labels %}
|
{%- if enable_custom_labels %}
|
||||||
labels:
|
labels:
|
||||||
- ...
|
- |
|
||||||
- ...
|
...
|
||||||
|
- |
|
||||||
|
...
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
description: |-
|
description: |-
|
||||||
...
|
...
|
||||||
|
@ -162,6 +162,7 @@ class PRDescription:
|
|||||||
|
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
set_custom_labels(variables, self.git_provider)
|
set_custom_labels(variables, self.git_provider)
|
||||||
|
self.variables = variables
|
||||||
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
|
||||||
|
|
||||||
@ -203,6 +204,16 @@ class PRDescription:
|
|||||||
pr_types = self.data['type']
|
pr_types = self.data['type']
|
||||||
elif type(self.data['type']) == str:
|
elif type(self.data['type']) == str:
|
||||||
pr_types = self.data['type'].split(',')
|
pr_types = self.data['type'].split(',')
|
||||||
|
|
||||||
|
# convert lowercase labels to original case
|
||||||
|
try:
|
||||||
|
if "labels_minimal_to_labels_dict" in self.variables:
|
||||||
|
d: dict = self.variables["labels_minimal_to_labels_dict"]
|
||||||
|
for i, label_i in enumerate(pr_types):
|
||||||
|
if label_i in d:
|
||||||
|
pr_types[i] = d[label_i]
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
|
||||||
return pr_types
|
return pr_types
|
||||||
|
|
||||||
def _prepare_pr_answer_with_markers(self) -> Tuple[str, str]:
|
def _prepare_pr_answer_with_markers(self) -> Tuple[str, str]:
|
||||||
|
@ -135,6 +135,7 @@ class PRGenerateLabels:
|
|||||||
|
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
set_custom_labels(variables, self.git_provider)
|
set_custom_labels(variables, self.git_provider)
|
||||||
|
self.variables = variables
|
||||||
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)
|
||||||
|
|
||||||
@ -170,4 +171,14 @@ class PRGenerateLabels:
|
|||||||
elif type(self.data['labels']) == str:
|
elif type(self.data['labels']) == str:
|
||||||
pr_types = self.data['labels'].split(',')
|
pr_types = self.data['labels'].split(',')
|
||||||
|
|
||||||
|
# convert lowercase labels to original case
|
||||||
|
try:
|
||||||
|
if "labels_minimal_to_labels_dict" in self.variables:
|
||||||
|
d: dict = self.variables["labels_minimal_to_labels_dict"]
|
||||||
|
for i, label_i in enumerate(pr_types):
|
||||||
|
if label_i in d:
|
||||||
|
pr_types[i] = d[label_i]
|
||||||
|
except Exception as e:
|
||||||
|
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
|
||||||
|
|
||||||
return pr_types
|
return pr_types
|
||||||
|
Reference in New Issue
Block a user