From ec3f2fb485ad318d6f3cea953da2369fcd2a7413 Mon Sep 17 00:00:00 2001 From: "Hussam.lawen" Date: Mon, 6 Nov 2023 15:08:29 +0200 Subject: [PATCH] Revert "generate labels keep user labels only" This reverts commit 94a2a5e527d8d0ec20e962cff73cc465894f3217. --- pr_agent/algo/utils.py | 17 ----------------- pr_agent/settings/pr_description_prompts.toml | 3 +-- pr_agent/tools/pr_description.py | 8 ++++---- pr_agent/tools/pr_generate_labels.py | 7 ++++--- 4 files changed, 9 insertions(+), 26 deletions(-) diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 440b6615..304c2200 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -324,20 +324,3 @@ def set_custom_labels(variables): final_labels += f" - {k} ({v['description']})\n" variables["custom_labels"] = final_labels variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}" - - -def get_user_labels(current_labels): - ## Only keep labels that has been added by the user - if current_labels is None: - current_labels = [] - user_labels = [] - for label in current_labels: - if label in ['Bug fix', 'Tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other']: - continue - if get_settings().config.enable_custom_labels: - if label in get_settings().custom_labels: - continue - user_labels.append(label) - if user_labels: - get_logger().info(f"Keeping user labels: {user_labels}") - return user_labels diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index cfb42948..ae07e71f 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -63,8 +63,7 @@ PR Type: ... {%- if enable_custom_labels %} PR Labels: -- ... -- ... +{{ custom_labels_examples }} {%- endif %} PR Description: |- ... diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 611523ea..c2c6ba98 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels +from pr_agent.algo.utils import load_yaml, set_custom_labels from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -98,9 +98,9 @@ class PRDescription: self.git_provider.publish_description(pr_title, pr_body) if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"): current_labels = self.git_provider.get_labels() - user_labels = get_user_labels(current_labels) - - self.git_provider.publish_labels(pr_labels + user_labels) + if current_labels is None: + current_labels = [] + self.git_provider.publish_labels(pr_labels + current_labels) self.git_provider.remove_initial_comment() except Exception as e: get_logger().error(f"Error generating PR description {self.pr_id}: {e}") diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py index e413e96f..94dc2815 100644 --- a/pr_agent/tools/pr_generate_labels.py +++ b/pr_agent/tools/pr_generate_labels.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels +from pr_agent.algo.utils import load_yaml, set_custom_labels from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -84,8 +84,9 @@ class PRGenerateLabels: get_logger().info(f"Pushing labels {self.pr_id}") current_labels = self.git_provider.get_labels() - user_labels = get_user_labels(current_labels) - pr_labels = pr_labels + user_labels + if current_labels is None: + current_labels = [] + pr_labels = pr_labels + current_labels if self.git_provider.is_supported("get_labels"): self.git_provider.publish_labels(pr_labels)