diff --git a/docs/GENERATE_CUSTOM_LABELS.md b/docs/GENERATE_CUSTOM_LABELS.md new file mode 100644 index 00000000..5c1743f4 --- /dev/null +++ b/docs/GENERATE_CUSTOM_LABELS.md @@ -0,0 +1,35 @@ +# Generate Custom Labels +The `generte_labels` tool scans the PR code changes, and given a list of labels and their descriptions, it automatically suggests labels that match the PR code changes. + +It can be invoked manually by commenting on any PR: +``` +/generte_labels +``` +For example: + +If we wish to add detect changes to SQL queries in a given PR, we can add the following custom label along with its description: + + +When running the `generte_labels` tool on a PR that includes changes in SQL queries, it will automatically suggest the custom label: + + +### Configuration options +To enable custom labels, you need to add the following configuration to the [custom_labels file](./../pr_agent/settings/custom_labels.toml): + - Change `enable_custom_labels` to True: This will turn off the default labels and enable the custom labels provided in the custom_labels.toml file. + - Add the custom labels to the custom_labels.toml file. It should be formatted as follows: + ``` +[custom_labels."Custom Label Name"] +description = "Description of when AI should suggest this label" +``` + - You can add modify the list to include all the custom labels you wish to use in your repository. + +#### Github Action +To use the `generte_labels` tool with Github Action: + +- Add the following file to your repository under `env` section in `.github/workflows/pr_agent.yml` +- Comma separated list of custom labels and their descriptions +- The number of labels and descriptions should be the same and in the same order (empty descriptions are allowed): +``` +CUSTOM_LABELS: "label1, label2, ..." +CUSTOM_LABELS_DESCRIPTION: "label1 description, label2 description, ..." +``` \ No newline at end of file diff --git a/pics/custom_label_published.png b/pics/custom_label_published.png new file mode 100644 index 00000000..7dfffcf6 Binary files /dev/null and b/pics/custom_label_published.png differ diff --git a/pics/custom_labels_list.png b/pics/custom_labels_list.png new file mode 100644 index 00000000..4e00caad Binary files /dev/null and b/pics/custom_labels_list.png differ diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index cd2bf2cc..6e76c5e0 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -7,6 +7,7 @@ from pr_agent.tools.pr_add_docs import PRAddDocs from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_config import PRConfig from pr_agent.tools.pr_description import PRDescription +from pr_agent.tools.pr_generate_labels import PRGenerateLabels from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer @@ -31,6 +32,7 @@ command2class = { "settings": PRConfig, "similar_issue": PRSimilarIssue, "add_docs": PRAddDocs, + "generate_labels": PRGenerateLabels, } commands = list(command2class.keys()) diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 4e88b33e..7de31fd4 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -304,3 +304,19 @@ def try_fix_yaml(review_text: str) -> dict: except: pass return data + + +def set_custom_labels(variables): + labels = get_settings().custom_labels + if not labels: + # set default labels + labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other'] + labels_list = "\n - ".join(labels) if labels else "" + labels_list = f" - {labels_list}" if labels_list else "" + variables["custom_labels"] = labels_list + return + final_labels = "" + for k, v in labels.items(): + final_labels += f" - {k} ({v['description']})\n" + variables["custom_labels"] = final_labels + variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}" diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py index 3b0b0360..a160bb1a 100644 --- a/pr_agent/config_loader.py +++ b/pr_agent/config_loader.py @@ -23,8 +23,10 @@ global_settings = Dynaconf( "settings/pr_sort_code_suggestions_prompts.toml", "settings/pr_information_from_user_prompts.toml", "settings/pr_update_changelog_prompts.toml", + "settings/pr_custom_labels.toml", "settings/pr_add_docs.toml", - "settings_prod/.secrets.toml" + "settings_prod/.secrets.toml", + "settings/custom_labels.toml" ]] ) diff --git a/pr_agent/servers/github_action_runner.py b/pr_agent/servers/github_action_runner.py index 714e7297..e1a5e56a 100644 --- a/pr_agent/servers/github_action_runner.py +++ b/pr_agent/servers/github_action_runner.py @@ -17,8 +17,11 @@ async def run_action(): OPENAI_KEY = os.environ.get('OPENAI_KEY') or os.environ.get('OPENAI.KEY') OPENAI_ORG = os.environ.get('OPENAI_ORG') or os.environ.get('OPENAI.ORG') GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN') - get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False) + CUSTOM_LABELS = os.environ.get('CUSTOM_LABELS') + CUSTOM_LABELS_DESCRIPTIONS = os.environ.get('CUSTOM_LABELS_DESCRIPTIONS') + # CUSTOM_LABELS is a comma separated list of labels (string), convert to list and strip spaces + get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False) # Check if required environment variables are set if not GITHUB_EVENT_NAME: @@ -33,6 +36,7 @@ async def run_action(): if not GITHUB_TOKEN: print("GITHUB_TOKEN not set") return + CUSTOM_LABELS_DICT = handle_custom_labels(CUSTOM_LABELS, CUSTOM_LABELS_DESCRIPTIONS) # Set the environment variables in the settings get_settings().set("OPENAI.KEY", OPENAI_KEY) @@ -40,6 +44,7 @@ async def run_action(): get_settings().set("OPENAI.ORG", OPENAI_ORG) get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN) get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user") + get_settings().set("CUSTOM_LABELS", CUSTOM_LABELS_DICT) # Load the event payload try: @@ -88,5 +93,31 @@ async def run_action(): await PRAgent().handle_request(url, body) +def handle_custom_labels(CUSTOM_LABELS, CUSTOM_LABELS_DESCRIPTIONS): + if CUSTOM_LABELS: + CUSTOM_LABELS = [x.strip() for x in CUSTOM_LABELS.split(',')] + else: + # Set default labels + CUSTOM_LABELS = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', + 'Other'] + print(f"Using default labels: {CUSTOM_LABELS}") + if CUSTOM_LABELS_DESCRIPTIONS: + CUSTOM_LABELS_DESCRIPTIONS = [x.strip() for x in CUSTOM_LABELS_DESCRIPTIONS.split(',')] + else: + # Set default labels + CUSTOM_LABELS_DESCRIPTIONS = ['Fixes a bug in the code', 'Adds or modifies tests', + 'Fixes a bug in the code and adds or modifies tests', + 'Refactors the code without changing its functionality', + 'Adds new features or functionality', + 'Adds or modifies documentation', + 'Other changes that do not fit in any of the above categories'] + print(f"Using default labels: {CUSTOM_LABELS_DESCRIPTIONS}") + # create a dictionary of labels and descriptions + CUSTOM_LABELS_DICT = dict() + for i in range(len(CUSTOM_LABELS)): + CUSTOM_LABELS_DICT[CUSTOM_LABELS[i]] = {'description': CUSTOM_LABELS_DESCRIPTIONS[i]} + return CUSTOM_LABELS_DICT + + if __name__ == '__main__': asyncio.run(run_action()) \ No newline at end of file diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 97d92261..9486c740 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -33,10 +33,13 @@ add_original_user_description=false keep_original_user_title=false use_bullet_points=true extra_instructions = "" + # markers use_description_markers=false include_generated_by_header=true +#custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other'] + [pr_questions] # /ask # [pr_code_suggestions] # /improve # diff --git a/pr_agent/settings/custom_labels.toml b/pr_agent/settings/custom_labels.toml new file mode 100644 index 00000000..8b1340f2 --- /dev/null +++ b/pr_agent/settings/custom_labels.toml @@ -0,0 +1,16 @@ +enable_custom_labels=false + +[custom_labels."Bug fix"] +description = "Fixes a bug in the code" +[custom_labels."Tests"] +description = "Adds or modifies tests" +[custom_labels."Bug fix with tests"] +description = "Fixes a bug in the code and adds or modifies tests" +[custom_labels."Refactoring"] +description = "Code refactoring without changing functionality" +[custom_labels."Enhancement"] +description = "Adds new features or functionality" +[custom_labels."Documentation"] +description = "Adds or modifies documentation" +[custom_labels."Other"] +description = "Other changes that do not fit in any of the above categories" \ No newline at end of file diff --git a/pr_agent/settings/pr_custom_labels.toml b/pr_agent/settings/pr_custom_labels.toml new file mode 100644 index 00000000..96cee17b --- /dev/null +++ b/pr_agent/settings/pr_custom_labels.toml @@ -0,0 +1,72 @@ +[pr_custom_labels_prompt] +system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests. +Your task is to label the type of the PR content. +- Make sure not to focus the new PR code (the '+' lines). +- If needed, each YAML output should be in block scalar format ('|-') +{%- if extra_instructions %} + +Extra instructions from the user: +' +{{ extra_instructions }} +' +{% endif %} + +You must use the following YAML schema to format your answer: +```yaml +PR Type: + type: array +{%- if enable_custom_labels %} + description: One or more labels that describe the PR type. Don't output the description in the parentheses. +{%- endif %} + items: + type: string + enum: +{%- if enable_custom_labels %} +{{ custom_labels }} +{%- else %} + - Bug fix + - Tests + - Refactoring + - Enhancement + - Documentation + - Other +{%- endif %} + +Example output: +```yaml +PR Type: +{%- if enable_custom_labels %} +{{ custom_labels_examples }} +{%- else %} + - Bug fix + - Tests +{%- endif %} +``` + +Make sure to output a valid YAML. Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields. +""" + +user="""PR Info: +Previous title: '{{title}}' +Previous description: '{{description}}' +Branch: '{{branch}}' +{%- if language %} + +Main language: {{language}} +{%- endif %} +{%- if commit_messages_str %} + +Commit messages: +{{commit_messages_str}} +{%- endif %} + + +The PR Git Diff: +``` +{{diff}} +``` +Note that lines in the diff body are prefixed with a symbol that represents the type of change: '-' for deletions, '+' for additions, and ' ' (a space) for unchanged lines. + +Response (should be a valid YAML, and nothing else): +```yaml +""" diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index 7d1621b0..3d9c3f0b 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -19,16 +19,22 @@ PR Title: description: an informative title for the PR, describing its main theme PR Type: type: array +{%- if enable_custom_labels %} + description: One or more labels that describe the PR type. Don't output the description in the parentheses. +{%- endif %} items: type: string enum: +{%- if enable_custom_labels %} +{{ custom_labels }} +{%- else %} - Bug fix - Tests - - Bug fix with tests - Refactoring - Enhancement - Documentation - Other +{%- endif %} PR Description: type: string description: an informative and concise description of the PR. @@ -52,7 +58,11 @@ Example output: PR Title: |- ... PR Type: +{%- if enable_custom_labels %} +{{ custom_labels_examples }} +{%- else %} - Bug fix +{%- endif %} PR Description: |- ... PR Main Files Walkthrough: diff --git a/pr_agent/settings/pr_reviewer_prompts.toml b/pr_agent/settings/pr_reviewer_prompts.toml index cb49b5d0..b717ec3d 100644 --- a/pr_agent/settings/pr_reviewer_prompts.toml +++ b/pr_agent/settings/pr_reviewer_prompts.toml @@ -51,13 +51,22 @@ PR Analysis: description: summary of the PR in 2-3 sentences. Type of PR: type: string +{%- if enable_custom_labels %} + description: One or more labels that describe the PR type. Don't output the description in the parentheses. +{%- endif %} + items: + type: string enum: +{%- if enable_custom_labels %} +{{ custom_labels }} +{%- else %} - Bug fix - Tests - Refactoring - Enhancement - Documentation - Other +{%- endif %} {%- if require_score %} Score: type: int diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index c1bd03fd..a88ff336 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml +from pr_agent.algo.utils import load_yaml, set_custom_labels from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -42,7 +42,10 @@ class PRDescription: "diff": "", # empty diff for initial calculation "use_bullet_points": get_settings().pr_description.use_bullet_points, "extra_instructions": get_settings().pr_description.extra_instructions, - "commit_messages_str": self.git_provider.get_commit_messages() + "commit_messages_str": self.git_provider.get_commit_messages(), + "enable_custom_labels": get_settings().enable_custom_labels, + "custom_labels": "", + "custom_labels_examples": "", } self.user_description = self.git_provider.get_user_description() @@ -140,6 +143,7 @@ class PRDescription: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) + set_custom_labels(variables) system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables) @@ -156,7 +160,6 @@ class PRDescription: return response - def _prepare_data(self): # Load the AI prediction data into a dictionary self.data = load_yaml(self.prediction.strip()) diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py new file mode 100644 index 00000000..bf5b5f98 --- /dev/null +++ b/pr_agent/tools/pr_generate_labels.py @@ -0,0 +1,163 @@ +import copy +import re +from typing import List, Tuple + +from jinja2 import Environment, StrictUndefined + +from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models +from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.utils import load_yaml, set_custom_labels +from pr_agent.config_loader import get_settings +from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers.git_provider import get_main_pr_language +from pr_agent.log import get_logger + + +class PRGenerateLabels: + def __init__(self, pr_url: str, args: list = None): + """ + Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels + corresponding to the PR using an AI model. + Args: + pr_url (str): The URL of the pull request. + args (list, optional): List of arguments passed to the PRGenerateLabels class. Defaults to None. + """ + # Initialize the git provider and main PR language + self.git_provider = get_git_provider()(pr_url) + self.main_pr_language = get_main_pr_language( + self.git_provider.get_languages(), self.git_provider.get_files() + ) + self.pr_id = self.git_provider.get_pr_id() + + # Initialize the AI handler + self.ai_handler = AiHandler() + + # Initialize the variables dictionary + self.vars = { + "title": self.git_provider.pr.title, + "branch": self.git_provider.get_pr_branch(), + "description": self.git_provider.get_pr_description(full=False), + "language": self.main_pr_language, + "diff": "", # empty diff for initial calculation + "use_bullet_points": get_settings().pr_description.use_bullet_points, + "extra_instructions": get_settings().pr_description.extra_instructions, + "commit_messages_str": self.git_provider.get_commit_messages(), + "custom_labels": "", + "custom_labels_examples": "", + "enable_custom_labels": get_settings().enable_custom_labels, + } + + # Initialize the token handler + self.token_handler = TokenHandler( + self.git_provider.pr, + self.vars, + get_settings().pr_custom_labels_prompt.system, + get_settings().pr_custom_labels_prompt.user, + ) + + # Initialize patches_diff and prediction attributes + self.patches_diff = None + self.prediction = None + + async def run(self): + """ + Generates a PR labels using an AI model and publishes it to the PR. + """ + + try: + get_logger().info(f"Generating a PR labels {self.pr_id}") + if get_settings().config.publish_output: + self.git_provider.publish_comment("Preparing PR labels...", is_temporary=True) + + await retry_with_fallback_models(self._prepare_prediction) + + get_logger().info(f"Preparing answer {self.pr_id}") + if self.prediction: + self._prepare_data() + else: + return None + + pr_labels = self._prepare_labels() + + if get_settings().config.publish_output: + get_logger().info(f"Pushing labels {self.pr_id}") + if self.git_provider.is_supported("get_labels"): + current_labels = self.git_provider.get_labels() + if current_labels is None: + current_labels = [] + self.git_provider.publish_labels(pr_labels + current_labels) + self.git_provider.remove_initial_comment() + except Exception as e: + get_logger().error(f"Error generating PR labels {self.pr_id}: {e}") + + return "" + + async def _prepare_prediction(self, model: str) -> None: + """ + Prepare the AI prediction for the PR labels based on the provided model. + + Args: + model (str): The name of the model to be used for generating the prediction. + + Returns: + None + + Raises: + Any exceptions raised by the 'get_pr_diff' and '_get_prediction' functions. + + """ + + get_logger().info(f"Getting PR diff {self.pr_id}") + self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model) + get_logger().info(f"Getting AI prediction {self.pr_id}") + self.prediction = await self._get_prediction(model) + + async def _get_prediction(self, model: str) -> str: + """ + Generate an AI prediction for the PR labels based on the provided model. + + Args: + model (str): The name of the model to be used for generating the prediction. + + Returns: + str: The generated AI prediction. + """ + variables = copy.deepcopy(self.vars) + variables["diff"] = self.patches_diff # update diff + + environment = Environment(undefined=StrictUndefined) + set_custom_labels(variables) + system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables) + user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables) + + if get_settings().config.verbosity_level >= 2: + get_logger().info(f"\nSystem prompt:\n{system_prompt}") + get_logger().info(f"\nUser prompt:\n{user_prompt}") + + response, finish_reason = await self.ai_handler.chat_completion( + model=model, + temperature=0.2, + system=system_prompt, + user=user_prompt + ) + + return response + + def _prepare_data(self): + # Load the AI prediction data into a dictionary + self.data = load_yaml(self.prediction.strip()) + + + + def _prepare_labels(self) -> List[str]: + pr_types = [] + + # If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types' + if 'PR Type' in self.data: + if type(self.data['PR Type']) == list: + pr_types = self.data['PR Type'] + elif type(self.data['PR Type']) == str: + pr_types = self.data['PR Type'].split(',') + + return pr_types diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index ed99ddf6..0eeb5578 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -9,7 +9,7 @@ from yaml import SafeLoader from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml +from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language @@ -63,6 +63,8 @@ class PRReviewer: 'answer_str': answer_str, "extra_instructions": get_settings().pr_reviewer.extra_instructions, "commit_messages_str": self.git_provider.get_commit_messages(), + "custom_labels": "", + "enable_custom_labels": get_settings().enable_custom_labels, } self.token_handler = TokenHandler( @@ -149,6 +151,7 @@ class PRReviewer: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) + set_custom_labels(variables) system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)