From fa24413201d82b8f37dc4eadd7e1622e857da7b5 Mon Sep 17 00:00:00 2001 From: "Hussam.lawen" Date: Mon, 23 Oct 2023 16:29:33 +0300 Subject: [PATCH] Custom Labels --- pr_agent/servers/github_action_runner.py | 10 ++++++++++ pr_agent/settings/configuration.toml | 3 +++ pr_agent/settings/pr_description_prompts.toml | 8 +------- pr_agent/settings/pr_reviewer_prompts.toml | 7 +------ pr_agent/tools/pr_description.py | 13 ++++++++++++- pr_agent/tools/pr_reviewer.py | 10 ++++++++++ 6 files changed, 37 insertions(+), 14 deletions(-) diff --git a/pr_agent/servers/github_action_runner.py b/pr_agent/servers/github_action_runner.py index 714e7297..7339cf3b 100644 --- a/pr_agent/servers/github_action_runner.py +++ b/pr_agent/servers/github_action_runner.py @@ -17,6 +17,9 @@ async def run_action(): OPENAI_KEY = os.environ.get('OPENAI_KEY') or os.environ.get('OPENAI.KEY') OPENAI_ORG = os.environ.get('OPENAI_ORG') or os.environ.get('OPENAI.ORG') GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN') + CUSTOM_LABELS = os.environ.get('CUSTOM_LABELS') + # CUSTOM_LABELS is a comma separated list of labels (string), convert to list and strip spaces + get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False) @@ -33,6 +36,12 @@ async def run_action(): if not GITHUB_TOKEN: print("GITHUB_TOKEN not set") return + if CUSTOM_LABELS: + CUSTOM_LABELS = [x.strip() for x in CUSTOM_LABELS.split(',')] + else: + # Set default labels + CUSTOM_LABELS = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other'] + print(f"Using default labels: {CUSTOM_LABELS}") # Set the environment variables in the settings get_settings().set("OPENAI.KEY", OPENAI_KEY) @@ -40,6 +49,7 @@ async def run_action(): get_settings().set("OPENAI.ORG", OPENAI_ORG) get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN) get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user") + get_settings().set("PR_DESCIPTION.CUSTOM_LABELS", CUSTOM_LABELS) # Load the event payload try: diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 6f44bc53..49b9317e 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -33,10 +33,13 @@ add_original_user_description=false keep_original_user_title=false use_bullet_points=true extra_instructions = "" + # markers use_description_markers=false include_generated_by_header=true +custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other'] + [pr_questions] # /ask # [pr_code_suggestions] # /improve # diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index 7d1621b0..1c018959 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -22,13 +22,7 @@ PR Type: items: type: string enum: - - Bug fix - - Tests - - Bug fix with tests - - Refactoring - - Enhancement - - Documentation - - Other +{{ custom_labels }} PR Description: type: string description: an informative and concise description of the PR. diff --git a/pr_agent/settings/pr_reviewer_prompts.toml b/pr_agent/settings/pr_reviewer_prompts.toml index cb49b5d0..4841947b 100644 --- a/pr_agent/settings/pr_reviewer_prompts.toml +++ b/pr_agent/settings/pr_reviewer_prompts.toml @@ -52,12 +52,7 @@ PR Analysis: Type of PR: type: string enum: - - Bug fix - - Tests - - Refactoring - - Enhancement - - Documentation - - Other +{{ custom_labels }} {%- if require_score %} Score: type: int diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index c1bd03fd..2c710aa0 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -42,7 +42,9 @@ class PRDescription: "diff": "", # empty diff for initial calculation "use_bullet_points": get_settings().pr_description.use_bullet_points, "extra_instructions": get_settings().pr_description.extra_instructions, - "commit_messages_str": self.git_provider.get_commit_messages() + "commit_messages_str": self.git_provider.get_commit_messages(), + "custom_labels": "" + } self.user_description = self.git_provider.get_user_description() @@ -140,6 +142,7 @@ class PRDescription: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) + await self.set_custom_labels(variables) system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables) @@ -156,6 +159,14 @@ class PRDescription: return response + async def set_custom_labels(self, variables): + labels = get_settings().pr_description.custom_labels + if not labels: + # set default labels + labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other'] + labels_list = "\n - ".join(labels) if labels else "" + labels_list = f" - {labels_list}" if labels_list else "" + variables["custom_labels"] = labels_list def _prepare_data(self): # Load the AI prediction data into a dictionary diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index ed99ddf6..dc563f23 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -149,6 +149,7 @@ class PRReviewer: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) + await self.set_custom_labels(variables) system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables) @@ -311,3 +312,12 @@ class PRReviewer: break return question_str, answer_str + + async def set_custom_labels(self, variables): + labels = get_settings().pr_description.custom_labels + if not labels: + # set default labels + labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other'] + labels_list = "\n - ".join(labels) if labels else "" + labels_list = f" - {labels_list}" if labels_list else "" + variables["custom_labels"] = labels_list \ No newline at end of file