mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-04 12:50:38 +08:00
Refactor labels
This commit is contained in:
@ -364,7 +364,7 @@ def try_fix_yaml(response_text: str) -> dict:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def set_custom_labels(variables):
|
def set_custom_labels(variables, git_provider=None):
|
||||||
if not get_settings().config.enable_custom_labels:
|
if not get_settings().config.enable_custom_labels:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -376,11 +376,8 @@ def set_custom_labels(variables):
|
|||||||
labels_list = f" - {labels_list}" if labels_list else ""
|
labels_list = f" - {labels_list}" if labels_list else ""
|
||||||
variables["custom_labels"] = labels_list
|
variables["custom_labels"] = labels_list
|
||||||
return
|
return
|
||||||
#final_labels = ""
|
|
||||||
#for k, v in labels.items():
|
# Set custom labels
|
||||||
# final_labels += f" - {k} ({v['description']})\n"
|
|
||||||
#variables["custom_labels"] = final_labels
|
|
||||||
#variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}"
|
|
||||||
variables["custom_labels_class"] = "class Label(str, Enum):"
|
variables["custom_labels_class"] = "class Label(str, Enum):"
|
||||||
for k, v in labels.items():
|
for k, v in labels.items():
|
||||||
description = v['description'].strip('\n').replace('\n', '\\n')
|
description = v['description'].strip('\n').replace('\n', '\\n')
|
||||||
|
@ -354,5 +354,5 @@ class BitbucketProvider(GitProvider):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# bitbucket does not support labels
|
# bitbucket does not support labels
|
||||||
def get_labels(self):
|
def get_pr_labels(self):
|
||||||
pass
|
pass
|
||||||
|
@ -344,7 +344,7 @@ class BitbucketServerProvider(GitProvider):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# bitbucket does not support labels
|
# bitbucket does not support labels
|
||||||
def get_labels(self):
|
def get_pr_labels(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _get_pr_comments_url(self):
|
def _get_pr_comments_url(self):
|
||||||
|
@ -216,7 +216,7 @@ class CodeCommitProvider(GitProvider):
|
|||||||
def publish_labels(self, labels):
|
def publish_labels(self, labels):
|
||||||
return [""] # not implemented yet
|
return [""] # not implemented yet
|
||||||
|
|
||||||
def get_labels(self):
|
def get_pr_labels(self):
|
||||||
return [""] # not implemented yet
|
return [""] # not implemented yet
|
||||||
|
|
||||||
def remove_initial_comment(self):
|
def remove_initial_comment(self):
|
||||||
|
@ -207,7 +207,7 @@ class GerritProvider(GitProvider):
|
|||||||
Comment = namedtuple('Comment', ['body'])
|
Comment = namedtuple('Comment', ['body'])
|
||||||
return Comments([Comment(c['message']) for c in reversed(comments)])
|
return Comments([Comment(c['message']) for c in reversed(comments)])
|
||||||
|
|
||||||
def get_labels(self):
|
def get_pr_labels(self):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'Getting labels is not implemented for the gerrit provider')
|
'Getting labels is not implemented for the gerrit provider')
|
||||||
|
|
||||||
|
@ -135,7 +135,10 @@ class GitProvider(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_labels(self):
|
def get_pr_labels(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_repo_labels(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -461,13 +461,17 @@ class GithubProvider(GitProvider):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to publish labels, error: {e}")
|
get_logger().exception(f"Failed to publish labels, error: {e}")
|
||||||
|
|
||||||
def get_labels(self):
|
def get_pr_labels(self):
|
||||||
try:
|
try:
|
||||||
return [label.name for label in self.pr.labels]
|
return [label.name for label in self.pr.labels]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
get_logger().exception(f"Failed to get labels, error: {e}")
|
get_logger().exception(f"Failed to get labels, error: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def get_repo_labels(self):
|
||||||
|
labels = self.repo_obj.get_labels()
|
||||||
|
return [label for label in labels]
|
||||||
|
|
||||||
def get_commit_messages(self):
|
def get_commit_messages(self):
|
||||||
"""
|
"""
|
||||||
Retrieves the commit messages of a pull request.
|
Retrieves the commit messages of a pull request.
|
||||||
|
@ -408,7 +408,7 @@ class GitLabProvider(GitProvider):
|
|||||||
def publish_inline_comments(self, comments: list[dict]):
|
def publish_inline_comments(self, comments: list[dict]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_labels(self):
|
def get_pr_labels(self):
|
||||||
return self.mr.labels
|
return self.mr.labels
|
||||||
|
|
||||||
def get_commit_messages(self):
|
def get_commit_messages(self):
|
||||||
|
@ -178,5 +178,5 @@ class LocalGitProvider(GitProvider):
|
|||||||
def get_issue_comments(self):
|
def get_issue_comments(self):
|
||||||
raise NotImplementedError('Getting issue comments is not implemented for the local git provider')
|
raise NotImplementedError('Getting issue comments is not implemented for the local git provider')
|
||||||
|
|
||||||
def get_labels(self):
|
def get_pr_labels(self):
|
||||||
raise NotImplementedError('Getting labels is not implemented for the local git provider')
|
raise NotImplementedError('Getting labels is not implemented for the local git provider')
|
||||||
|
@ -106,7 +106,7 @@ class PRDescription:
|
|||||||
else:
|
else:
|
||||||
self.git_provider.publish_description(pr_title, pr_body)
|
self.git_provider.publish_description(pr_title, pr_body)
|
||||||
if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"):
|
if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"):
|
||||||
current_labels = self.git_provider.get_labels()
|
current_labels = self.git_provider.get_pr_labels()
|
||||||
user_labels = get_user_labels(current_labels)
|
user_labels = get_user_labels(current_labels)
|
||||||
self.git_provider.publish_labels(pr_labels + user_labels)
|
self.git_provider.publish_labels(pr_labels + user_labels)
|
||||||
|
|
||||||
@ -158,7 +158,7 @@ class PRDescription:
|
|||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
|
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
set_custom_labels(variables)
|
set_custom_labels(variables, self.git_provider)
|
||||||
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ class PRGenerateLabels:
|
|||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
get_logger().info(f"Pushing labels {self.pr_id}")
|
get_logger().info(f"Pushing labels {self.pr_id}")
|
||||||
|
|
||||||
current_labels = self.git_provider.get_labels()
|
current_labels = self.git_provider.get_pr_labels()
|
||||||
user_labels = get_user_labels(current_labels)
|
user_labels = get_user_labels(current_labels)
|
||||||
pr_labels = pr_labels + user_labels
|
pr_labels = pr_labels + user_labels
|
||||||
|
|
||||||
@ -132,7 +132,7 @@ class PRGenerateLabels:
|
|||||||
variables["diff"] = self.patches_diff # update diff
|
variables["diff"] = self.patches_diff # update diff
|
||||||
|
|
||||||
environment = Environment(undefined=StrictUndefined)
|
environment = Environment(undefined=StrictUndefined)
|
||||||
set_custom_labels(variables)
|
set_custom_labels(variables, self.git_provider)
|
||||||
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
|
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
|
||||||
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)
|
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)
|
||||||
|
|
||||||
|
@ -392,7 +392,7 @@ class PRReviewer:
|
|||||||
if security_concerns_bool:
|
if security_concerns_bool:
|
||||||
review_labels.append('Possible security concern')
|
review_labels.append('Possible security concern')
|
||||||
|
|
||||||
current_labels = self.git_provider.get_labels()
|
current_labels = self.git_provider.get_pr_labels()
|
||||||
current_labels_filtered = [label for label in current_labels if
|
current_labels_filtered = [label for label in current_labels if
|
||||||
not label.lower().startswith('review effort [1-5]:') and not label.lower().startswith(
|
not label.lower().startswith('review effort [1-5]:') and not label.lower().startswith(
|
||||||
'possible security concern')]
|
'possible security concern')]
|
||||||
|
Reference in New Issue
Block a user