From fdc776887d79360d17b3d9748fcf1d2c06ef033c Mon Sep 17 00:00:00 2001 From: "Hussam.lawen" Date: Mon, 11 Dec 2023 16:47:38 +0200 Subject: [PATCH] Refactor labels --- pr_agent/algo/utils.py | 9 +++------ pr_agent/git_providers/bitbucket_provider.py | 2 +- pr_agent/git_providers/bitbucket_server_provider.py | 2 +- pr_agent/git_providers/codecommit_provider.py | 2 +- pr_agent/git_providers/gerrit_provider.py | 2 +- pr_agent/git_providers/git_provider.py | 5 ++++- pr_agent/git_providers/github_provider.py | 6 +++++- pr_agent/git_providers/gitlab_provider.py | 2 +- pr_agent/git_providers/local_git_provider.py | 2 +- pr_agent/tools/pr_description.py | 4 ++-- pr_agent/tools/pr_generate_labels.py | 4 ++-- pr_agent/tools/pr_reviewer.py | 2 +- 12 files changed, 23 insertions(+), 19 deletions(-) diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 8b2f21d7..6d0d3731 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -364,7 +364,7 @@ def try_fix_yaml(response_text: str) -> dict: pass -def set_custom_labels(variables): +def set_custom_labels(variables, git_provider=None): if not get_settings().config.enable_custom_labels: return @@ -376,11 +376,8 @@ def set_custom_labels(variables): labels_list = f" - {labels_list}" if labels_list else "" variables["custom_labels"] = labels_list return - #final_labels = "" - #for k, v in labels.items(): - # final_labels += f" - {k} ({v['description']})\n" - #variables["custom_labels"] = final_labels - #variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}" + + # Set custom labels variables["custom_labels_class"] = "class Label(str, Enum):" for k, v in labels.items(): description = v['description'].strip('\n').replace('\n', '\\n') diff --git a/pr_agent/git_providers/bitbucket_provider.py b/pr_agent/git_providers/bitbucket_provider.py index ee8ad48f..23173f8e 100644 --- a/pr_agent/git_providers/bitbucket_provider.py +++ b/pr_agent/git_providers/bitbucket_provider.py @@ -354,5 +354,5 @@ class BitbucketProvider(GitProvider): pass # bitbucket does not support labels - def get_labels(self): + def get_pr_labels(self): pass diff --git a/pr_agent/git_providers/bitbucket_server_provider.py b/pr_agent/git_providers/bitbucket_server_provider.py index 44347850..902beb16 100644 --- a/pr_agent/git_providers/bitbucket_server_provider.py +++ b/pr_agent/git_providers/bitbucket_server_provider.py @@ -344,7 +344,7 @@ class BitbucketServerProvider(GitProvider): pass # bitbucket does not support labels - def get_labels(self): + def get_pr_labels(self): pass def _get_pr_comments_url(self): diff --git a/pr_agent/git_providers/codecommit_provider.py b/pr_agent/git_providers/codecommit_provider.py index 64cfc70a..286444c5 100644 --- a/pr_agent/git_providers/codecommit_provider.py +++ b/pr_agent/git_providers/codecommit_provider.py @@ -216,7 +216,7 @@ class CodeCommitProvider(GitProvider): def publish_labels(self, labels): return [""] # not implemented yet - def get_labels(self): + def get_pr_labels(self): return [""] # not implemented yet def remove_initial_comment(self): diff --git a/pr_agent/git_providers/gerrit_provider.py b/pr_agent/git_providers/gerrit_provider.py index d286b1bf..dbdbe82f 100644 --- a/pr_agent/git_providers/gerrit_provider.py +++ b/pr_agent/git_providers/gerrit_provider.py @@ -207,7 +207,7 @@ class GerritProvider(GitProvider): Comment = namedtuple('Comment', ['body']) return Comments([Comment(c['message']) for c in reversed(comments)]) - def get_labels(self): + def get_pr_labels(self): raise NotImplementedError( 'Getting labels is not implemented for the gerrit provider') diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index deb5df3d..4c4684c3 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -135,7 +135,10 @@ class GitProvider(ABC): pass @abstractmethod - def get_labels(self): + def get_pr_labels(self): + pass + + def get_repo_labels(self): pass @abstractmethod diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 3ae97742..f365db84 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -461,13 +461,17 @@ class GithubProvider(GitProvider): except Exception as e: get_logger().exception(f"Failed to publish labels, error: {e}") - def get_labels(self): + def get_pr_labels(self): try: return [label.name for label in self.pr.labels] except Exception as e: get_logger().exception(f"Failed to get labels, error: {e}") return [] + def get_repo_labels(self): + labels = self.repo_obj.get_labels() + return [label for label in labels] + def get_commit_messages(self): """ Retrieves the commit messages of a pull request. diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index c5e77d07..618cebc0 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -408,7 +408,7 @@ class GitLabProvider(GitProvider): def publish_inline_comments(self, comments: list[dict]): pass - def get_labels(self): + def get_pr_labels(self): return self.mr.labels def get_commit_messages(self): diff --git a/pr_agent/git_providers/local_git_provider.py b/pr_agent/git_providers/local_git_provider.py index 0ef11413..b3fad772 100644 --- a/pr_agent/git_providers/local_git_provider.py +++ b/pr_agent/git_providers/local_git_provider.py @@ -178,5 +178,5 @@ class LocalGitProvider(GitProvider): def get_issue_comments(self): raise NotImplementedError('Getting issue comments is not implemented for the local git provider') - def get_labels(self): + def get_pr_labels(self): raise NotImplementedError('Getting labels is not implemented for the local git provider') diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 64acaab3..05fb63f8 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -106,7 +106,7 @@ class PRDescription: else: self.git_provider.publish_description(pr_title, pr_body) if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"): - current_labels = self.git_provider.get_labels() + current_labels = self.git_provider.get_pr_labels() user_labels = get_user_labels(current_labels) self.git_provider.publish_labels(pr_labels + user_labels) @@ -158,7 +158,7 @@ class PRDescription: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - set_custom_labels(variables) + set_custom_labels(variables, self.git_provider) system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables) diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py index 6ea322a4..fc90ed44 100644 --- a/pr_agent/tools/pr_generate_labels.py +++ b/pr_agent/tools/pr_generate_labels.py @@ -82,7 +82,7 @@ class PRGenerateLabels: if get_settings().config.publish_output: get_logger().info(f"Pushing labels {self.pr_id}") - current_labels = self.git_provider.get_labels() + current_labels = self.git_provider.get_pr_labels() user_labels = get_user_labels(current_labels) pr_labels = pr_labels + user_labels @@ -132,7 +132,7 @@ class PRGenerateLabels: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - set_custom_labels(variables) + set_custom_labels(variables, self.git_provider) system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables) diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 13fc1717..5a6f720a 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -392,7 +392,7 @@ class PRReviewer: if security_concerns_bool: review_labels.append('Possible security concern') - current_labels = self.git_provider.get_labels() + current_labels = self.git_provider.get_pr_labels() current_labels_filtered = [label for label in current_labels if not label.lower().startswith('review effort [1-5]:') and not label.lower().startswith( 'possible security concern')]