diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index 32f6e315..67735ee0 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -57,6 +57,10 @@ class GitProvider(ABC): relevant_lines_start: int, relevant_lines_end: int): pass + @abstractmethod + def publish_labels(self, labels): + pass + @abstractmethod def remove_initial_comment(self): pass diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 1f2ffec7..2f19487b 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -13,7 +13,7 @@ from ..algo.utils import load_large_diff class GithubProvider(GitProvider): - def __init__(self, pr_url: Optional[str] = None, incremental: Optional[IncrementalPR] = False): + def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)): self.repo_obj = None self.installation_id = settings.get("GITHUB.INSTALLATION_ID") self.github_client = self._get_github_client() @@ -306,3 +306,21 @@ class GithubProvider(GitProvider): except Exception: file_content_str = "" return file_content_str + + def publish_labels(self, pr_types): + try: + if type(pr_types) is not list: + pr_types = [pr_types] + colors = ["1d76db", "e99695", "c5def5", "bfdadc", "bfd4f2", "d4c5f9", "d1bcf9"] + labels = ["Bug fix", "Tests", "Bug fix with tests", "Refactoring", "Enhancement", "Documentation", "Other"] + post_parameters = [] + for p in pr_types: + ind = 0 + if p in labels: + ind = labels.index(p) + post_parameters.append({"name": p, "color": colors[ind]}) + headers, data = self.pr._requester.requestJsonAndCheck( + "PUT", f"{self.pr.issue_url}/labels", input=post_parameters + ) + except: + logging.exception("Failed to publish labels") diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 95246647..089d3fca 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -254,3 +254,6 @@ class GitLabProvider(GitProvider): def get_user_id(self): return None + + def publish_labels(self, labels): + pass \ No newline at end of file diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index 2ec31b1c..8c8df966 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -10,9 +10,9 @@ You must use the following JSON schema to format your answer: "type": "string", "description": "an informative title for the PR, describing its main theme" }, - "Type of PR": { + "PR Type": { "type": "string", - "enum": ["Bug fix", "Tests", "Bug fix with tests", "Refactoring", "Enhancement", "Documentation", "Other"] + "description": possible values are: ["Bug fix", "Tests", "Bug fix with tests", "Refactoring", "Enhancement", "Documentation", "Other"] }, "PR Description": { "type": "string", diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 53e57a8c..a8647a83 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -42,13 +42,14 @@ class PRDescription: logging.info('Getting AI prediction...') self.prediction = await self._get_prediction() logging.info('Preparing answer...') - pr_title, pr_body, markdown_text = self._prepare_pr_answer() + pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer() if settings.config.publish_output: logging.info('Pushing answer...') if settings.pr_description.publish_description_as_comment: self.git_provider.publish_comment(markdown_text) else: self.git_provider.publish_description(pr_title, pr_body) + self.git_provider.publish_labels(pr_types) self.git_provider.remove_initial_comment() return "" @@ -73,6 +74,9 @@ class PRDescription: markdown_text += f"## {key}\n\n" markdown_text += f"{value}\n\n" pr_body = "" + pr_types = [] + if 'PR Type' in data: + pr_types = data['PR Type'].split(',') title = data['PR Title'] del data['PR Title'] for key, value in data.items(): @@ -83,4 +87,4 @@ class PRDescription: pr_body += f"**{value}**\n\n___\n" if settings.config.verbosity_level >= 2: logging.info(f"title:\n{title}\n{pr_body}") - return title, pr_body, markdown_text + return title, pr_body, pr_types, markdown_text