Refactor labels

This commit is contained in:
Hussam.lawen
2023-12-11 16:47:38 +02:00
parent f7a6348401
commit fdc776887d
12 changed files with 23 additions and 19 deletions

View File

@ -364,7 +364,7 @@ def try_fix_yaml(response_text: str) -> dict:
pass pass
def set_custom_labels(variables): def set_custom_labels(variables, git_provider=None):
if not get_settings().config.enable_custom_labels: if not get_settings().config.enable_custom_labels:
return return
@ -376,11 +376,8 @@ def set_custom_labels(variables):
labels_list = f" - {labels_list}" if labels_list else "" labels_list = f" - {labels_list}" if labels_list else ""
variables["custom_labels"] = labels_list variables["custom_labels"] = labels_list
return return
#final_labels = ""
#for k, v in labels.items(): # Set custom labels
# final_labels += f" - {k} ({v['description']})\n"
#variables["custom_labels"] = final_labels
#variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}"
variables["custom_labels_class"] = "class Label(str, Enum):" variables["custom_labels_class"] = "class Label(str, Enum):"
for k, v in labels.items(): for k, v in labels.items():
description = v['description'].strip('\n').replace('\n', '\\n') description = v['description'].strip('\n').replace('\n', '\\n')

View File

@ -354,5 +354,5 @@ class BitbucketProvider(GitProvider):
pass pass
# bitbucket does not support labels # bitbucket does not support labels
def get_labels(self): def get_pr_labels(self):
pass pass

View File

@ -344,7 +344,7 @@ class BitbucketServerProvider(GitProvider):
pass pass
# bitbucket does not support labels # bitbucket does not support labels
def get_labels(self): def get_pr_labels(self):
pass pass
def _get_pr_comments_url(self): def _get_pr_comments_url(self):

View File

@ -216,7 +216,7 @@ class CodeCommitProvider(GitProvider):
def publish_labels(self, labels): def publish_labels(self, labels):
return [""] # not implemented yet return [""] # not implemented yet
def get_labels(self): def get_pr_labels(self):
return [""] # not implemented yet return [""] # not implemented yet
def remove_initial_comment(self): def remove_initial_comment(self):

View File

@ -207,7 +207,7 @@ class GerritProvider(GitProvider):
Comment = namedtuple('Comment', ['body']) Comment = namedtuple('Comment', ['body'])
return Comments([Comment(c['message']) for c in reversed(comments)]) return Comments([Comment(c['message']) for c in reversed(comments)])
def get_labels(self): def get_pr_labels(self):
raise NotImplementedError( raise NotImplementedError(
'Getting labels is not implemented for the gerrit provider') 'Getting labels is not implemented for the gerrit provider')

View File

@ -135,7 +135,10 @@ class GitProvider(ABC):
pass pass
@abstractmethod @abstractmethod
def get_labels(self): def get_pr_labels(self):
pass
def get_repo_labels(self):
pass pass
@abstractmethod @abstractmethod

View File

@ -461,13 +461,17 @@ class GithubProvider(GitProvider):
except Exception as e: except Exception as e:
get_logger().exception(f"Failed to publish labels, error: {e}") get_logger().exception(f"Failed to publish labels, error: {e}")
def get_labels(self): def get_pr_labels(self):
try: try:
return [label.name for label in self.pr.labels] return [label.name for label in self.pr.labels]
except Exception as e: except Exception as e:
get_logger().exception(f"Failed to get labels, error: {e}") get_logger().exception(f"Failed to get labels, error: {e}")
return [] return []
def get_repo_labels(self):
labels = self.repo_obj.get_labels()
return [label for label in labels]
def get_commit_messages(self): def get_commit_messages(self):
""" """
Retrieves the commit messages of a pull request. Retrieves the commit messages of a pull request.

View File

@ -408,7 +408,7 @@ class GitLabProvider(GitProvider):
def publish_inline_comments(self, comments: list[dict]): def publish_inline_comments(self, comments: list[dict]):
pass pass
def get_labels(self): def get_pr_labels(self):
return self.mr.labels return self.mr.labels
def get_commit_messages(self): def get_commit_messages(self):

View File

@ -178,5 +178,5 @@ class LocalGitProvider(GitProvider):
def get_issue_comments(self): def get_issue_comments(self):
raise NotImplementedError('Getting issue comments is not implemented for the local git provider') raise NotImplementedError('Getting issue comments is not implemented for the local git provider')
def get_labels(self): def get_pr_labels(self):
raise NotImplementedError('Getting labels is not implemented for the local git provider') raise NotImplementedError('Getting labels is not implemented for the local git provider')

View File

@ -106,7 +106,7 @@ class PRDescription:
else: else:
self.git_provider.publish_description(pr_title, pr_body) self.git_provider.publish_description(pr_title, pr_body)
if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"): if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"):
current_labels = self.git_provider.get_labels() current_labels = self.git_provider.get_pr_labels()
user_labels = get_user_labels(current_labels) user_labels = get_user_labels(current_labels)
self.git_provider.publish_labels(pr_labels + user_labels) self.git_provider.publish_labels(pr_labels + user_labels)
@ -158,7 +158,7 @@ class PRDescription:
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables) set_custom_labels(variables, self.git_provider)
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)

View File

@ -82,7 +82,7 @@ class PRGenerateLabels:
if get_settings().config.publish_output: if get_settings().config.publish_output:
get_logger().info(f"Pushing labels {self.pr_id}") get_logger().info(f"Pushing labels {self.pr_id}")
current_labels = self.git_provider.get_labels() current_labels = self.git_provider.get_pr_labels()
user_labels = get_user_labels(current_labels) user_labels = get_user_labels(current_labels)
pr_labels = pr_labels + user_labels pr_labels = pr_labels + user_labels
@ -132,7 +132,7 @@ class PRGenerateLabels:
variables["diff"] = self.patches_diff # update diff variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined) environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables) set_custom_labels(variables, self.git_provider)
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables) system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables) user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)

View File

@ -392,7 +392,7 @@ class PRReviewer:
if security_concerns_bool: if security_concerns_bool:
review_labels.append('Possible security concern') review_labels.append('Possible security concern')
current_labels = self.git_provider.get_labels() current_labels = self.git_provider.get_pr_labels()
current_labels_filtered = [label for label in current_labels if current_labels_filtered = [label for label in current_labels if
not label.lower().startswith('review effort [1-5]:') and not label.lower().startswith( not label.lower().startswith('review effort [1-5]:') and not label.lower().startswith(
'possible security concern')] 'possible security concern')]