diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 304c2200..440b6615 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -324,3 +324,20 @@ def set_custom_labels(variables): final_labels += f" - {k} ({v['description']})\n" variables["custom_labels"] = final_labels variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}" + + +def get_user_labels(current_labels): + ## Only keep labels that has been added by the user + if current_labels is None: + current_labels = [] + user_labels = [] + for label in current_labels: + if label in ['Bug fix', 'Tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other']: + continue + if get_settings().config.enable_custom_labels: + if label in get_settings().custom_labels: + continue + user_labels.append(label) + if user_labels: + get_logger().info(f"Keeping user labels: {user_labels}") + return user_labels diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index ae07e71f..cfb42948 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -63,7 +63,8 @@ PR Type: ... {%- if enable_custom_labels %} PR Labels: -{{ custom_labels_examples }} +- ... +- ... {%- endif %} PR Description: |- ... diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index c2c6ba98..611523ea 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, set_custom_labels +from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -98,9 +98,9 @@ class PRDescription: self.git_provider.publish_description(pr_title, pr_body) if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"): current_labels = self.git_provider.get_labels() - if current_labels is None: - current_labels = [] - self.git_provider.publish_labels(pr_labels + current_labels) + user_labels = get_user_labels(current_labels) + + self.git_provider.publish_labels(pr_labels + user_labels) self.git_provider.remove_initial_comment() except Exception as e: get_logger().error(f"Error generating PR description {self.pr_id}: {e}") diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py index 94dc2815..e413e96f 100644 --- a/pr_agent/tools/pr_generate_labels.py +++ b/pr_agent/tools/pr_generate_labels.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, set_custom_labels +from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -84,9 +84,8 @@ class PRGenerateLabels: get_logger().info(f"Pushing labels {self.pr_id}") current_labels = self.git_provider.get_labels() - if current_labels is None: - current_labels = [] - pr_labels = pr_labels + current_labels + user_labels = get_user_labels(current_labels) + pr_labels = pr_labels + user_labels if self.git_provider.is_supported("get_labels"): self.git_provider.publish_labels(pr_labels)