diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index cd2bf2cc..6e76c5e0 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -7,6 +7,7 @@ from pr_agent.tools.pr_add_docs import PRAddDocs from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_config import PRConfig from pr_agent.tools.pr_description import PRDescription +from pr_agent.tools.pr_generate_labels import PRGenerateLabels from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer @@ -31,6 +32,7 @@ command2class = { "settings": PRConfig, "similar_issue": PRSimilarIssue, "add_docs": PRAddDocs, + "generate_labels": PRGenerateLabels, } commands = list(command2class.keys()) diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 76981f33..7de31fd4 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -306,7 +306,7 @@ def try_fix_yaml(review_text: str) -> dict: return data -async def set_custom_labels(variables): +def set_custom_labels(variables): labels = get_settings().custom_labels if not labels: # set default labels @@ -319,3 +319,4 @@ async def set_custom_labels(variables): for k, v in labels.items(): final_labels += f" - {k} ({v['description']})\n" variables["custom_labels"] = final_labels + variables["custom_labels_examples"] = f" - {list(labels.keys())[0]}" diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py index 3b0b0360..a160bb1a 100644 --- a/pr_agent/config_loader.py +++ b/pr_agent/config_loader.py @@ -23,8 +23,10 @@ global_settings = Dynaconf( "settings/pr_sort_code_suggestions_prompts.toml", "settings/pr_information_from_user_prompts.toml", "settings/pr_update_changelog_prompts.toml", + "settings/pr_custom_labels.toml", "settings/pr_add_docs.toml", - "settings_prod/.secrets.toml" + "settings_prod/.secrets.toml", + "settings/custom_labels.toml" ]] ) diff --git 
a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index 0a45e55a..dcf90c23 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -37,23 +37,7 @@ extra_instructions = "" # markers use_description_markers=false include_generated_by_header=true - -[custom_labels."Bug fix"] -description = "Fixes a bug in the code" -[custom_labels."Tests"] -description = "Adds or modifies tests" -[custom_labels."Bug fix with tests"] -description = "Fixes a bug in the code and adds or modifies tests" -[custom_labels."Refactoring"] -description = "Refactors the code without changing its functionality" -[custom_labels."Enhancement"] -description = "Adds new features or functionality" -[custom_labels."Documentation"] -description = "Adds or modifies documentation" -[custom_labels."SQL modifications"] -description = "Adds or modifies SQL queries" -[custom_labels."Other"] -description = "Other changes that do not fit in any of the above categories" +enable_custom_labels=true #custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other'] diff --git a/pr_agent/settings/custom_labels.toml b/pr_agent/settings/custom_labels.toml new file mode 100644 index 00000000..5891daa7 --- /dev/null +++ b/pr_agent/settings/custom_labels.toml @@ -0,0 +1,14 @@ +[custom_labels."Bug fix"] +description = "Fixes a bug in the code" +[custom_labels."Tests"] +description = "Adds or modifies tests" +[custom_labels."Bug fix with tests"] +description = "Fixes a bug in the code and adds or modifies tests" +[custom_labels."Refactoring"] +description = "Code refactoring without changing functionality" +[custom_labels."Enhancement"] +description = "Adds new features or functionality" +[custom_labels."Documentation"] +description = "Adds or modifies documentation" +[custom_labels."Other"] +description = "Other changes that do not fit in any of the above categories" \ No newline at end of file diff --git 
a/pr_agent/settings/pr_custom_labels.toml b/pr_agent/settings/pr_custom_labels.toml new file mode 100644 index 00000000..09b89842 --- /dev/null +++ b/pr_agent/settings/pr_custom_labels.toml @@ -0,0 +1,72 @@ +[pr_custom_labels_prompt] +system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests. +Your task is to label the type of the PR content. +- Make sure not to focus the new PR code (the '+' lines). +- If needed, each YAML output should be in block scalar format ('|-') +{%- if extra_instructions %} + +Extra instructions from the user: +' +{{ extra_instructions }} +' +{% endif %} + +You must use the following YAML schema to format your answer: +```yaml +PR Type: + type: array +{%- if enable_custom_labels %} + description: One or more labels that describe the PR type. Don't output the description in the parentheses. +{%- endif %} + items: + type: string + enum: +{%- if enable_custom_labels %} +{{ custom_labels }} +{%- else %} + - Bug fix + - Tests + - Refactoring + - Enhancement + - Documentation + - Other +{%- endif %} + +Example output: +```yaml +PR Type: +{%- if enable_custom_labels %} +{{ custom_labels_examples }} +{%- else %} + - Bug fix + - Tests +{%- endif %} +``` + +Make sure to output a valid YAML. Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields. +""" + +user="""PR Info: +Previous title: '{{title}}' +Previous description: '{{description}}' +Branch: '{{branch}}' +{%- if language %} + +Main language: {{language}} +{%- endif %} +{%- if commit_messages_str %} + +Commit messages: +{{commit_messages_str}} +{%- endif %} + + +The PR Git Diff: +``` +{{diff}} +``` +Note that lines in the diff body are prefixed with a symbol that represents the type of change: '-' for deletions, '+' for additions, and ' ' (a space) for unchanged lines. 
+ +Response (should be a valid YAML, and nothing else): +```yaml +""" diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml index 14f66532..3d9c3f0b 100644 --- a/pr_agent/settings/pr_description_prompts.toml +++ b/pr_agent/settings/pr_description_prompts.toml @@ -19,11 +19,22 @@ PR Title: description: an informative title for the PR, describing its main theme PR Type: type: array +{%- if enable_custom_labels %} description: One or more labels that describe the PR type. Don't output the description in the parentheses. +{%- endif %} items: type: string enum: +{%- if enable_custom_labels %} {{ custom_labels }} +{%- else %} + - Bug fix + - Tests + - Refactoring + - Enhancement + - Documentation + - Other +{%- endif %} PR Description: type: string description: an informative and concise description of the PR. @@ -47,7 +58,11 @@ Example output: PR Title: |- ... PR Type: +{%- if enable_custom_labels %} +{{ custom_labels_examples }} +{%- else %} - Bug fix +{%- endif %} PR Description: |- ... PR Main Files Walkthrough: diff --git a/pr_agent/settings/pr_reviewer_prompts.toml b/pr_agent/settings/pr_reviewer_prompts.toml index 4841947b..b717ec3d 100644 --- a/pr_agent/settings/pr_reviewer_prompts.toml +++ b/pr_agent/settings/pr_reviewer_prompts.toml @@ -51,8 +51,22 @@ PR Analysis: description: summary of the PR in 2-3 sentences. Type of PR: type: string +{%- if enable_custom_labels %} + description: One or more labels that describe the PR type. Don't output the description in the parentheses. 
+{%- endif %} + items: + type: string enum: +{%- if enable_custom_labels %} {{ custom_labels }} +{%- else %} + - Bug fix + - Tests + - Refactoring + - Enhancement + - Documentation + - Other +{%- endif %} {%- if require_score %} Score: type: int diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index c54b5131..cacfdad6 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -43,7 +43,9 @@ class PRDescription: "use_bullet_points": get_settings().pr_description.use_bullet_points, "extra_instructions": get_settings().pr_description.extra_instructions, "commit_messages_str": self.git_provider.get_commit_messages(), - "custom_labels": "" + "enable_custom_labels": get_settings().pr_description.enable_custom_labels, + "custom_labels": "", + "custom_labels_examples": "", } self.user_description = self.git_provider.get_user_description() @@ -141,7 +143,7 @@ class PRDescription: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - await set_custom_labels(variables) + set_custom_labels(variables) system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables) diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py new file mode 100644 index 00000000..81b1d040 --- /dev/null +++ b/pr_agent/tools/pr_generate_labels.py @@ -0,0 +1,163 @@ +import copy +import re +from typing import List, Tuple + +from jinja2 import Environment, StrictUndefined + +from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models +from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.utils import load_yaml, set_custom_labels +from pr_agent.config_loader import get_settings +from pr_agent.git_providers import get_git_provider +from 
pr_agent.git_providers.git_provider import get_main_pr_language +from pr_agent.log import get_logger + + +class PRGenerateLabels: + def __init__(self, pr_url: str, args: list = None): + """ + Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels + corresponding to the PR using an AI model. + Args: + pr_url (str): The URL of the pull request. + args (list, optional): List of arguments passed to the PRGenerateLabels class. Defaults to None. + """ + # Initialize the git provider and main PR language + self.git_provider = get_git_provider()(pr_url) + self.main_pr_language = get_main_pr_language( + self.git_provider.get_languages(), self.git_provider.get_files() + ) + self.pr_id = self.git_provider.get_pr_id() + + # Initialize the AI handler + self.ai_handler = AiHandler() + + # Initialize the variables dictionary + self.vars = { + "title": self.git_provider.pr.title, + "branch": self.git_provider.get_pr_branch(), + "description": self.git_provider.get_pr_description(full=False), + "language": self.main_pr_language, + "diff": "", # empty diff for initial calculation + "use_bullet_points": get_settings().pr_description.use_bullet_points, + "extra_instructions": get_settings().pr_description.extra_instructions, + "commit_messages_str": self.git_provider.get_commit_messages(), + "custom_labels": "", + "custom_labels_examples": "", + "enable_custom_labels": get_settings().pr_description.enable_custom_labels, + } + + # Initialize the token handler + self.token_handler = TokenHandler( + self.git_provider.pr, + self.vars, + get_settings().pr_custom_labels_prompt.system, + get_settings().pr_custom_labels_prompt.user, + ) + + # Initialize patches_diff and prediction attributes + self.patches_diff = None + self.prediction = None + + async def run(self): + """ + Generates a PR labels using an AI model and publishes it to the PR. 
+ """ + + try: + get_logger().info(f"Generating a PR labels {self.pr_id}") + if get_settings().config.publish_output: + self.git_provider.publish_comment("Preparing PR labels...", is_temporary=True) + + await retry_with_fallback_models(self._prepare_prediction) + + get_logger().info(f"Preparing answer {self.pr_id}") + if self.prediction: + self._prepare_data() + else: + return None + + pr_labels = self._prepare_labels() + + if get_settings().config.publish_output: + get_logger().info(f"Pushing labels {self.pr_id}") + if self.git_provider.is_supported("get_labels"): + current_labels = self.git_provider.get_labels() + if current_labels is None: + current_labels = [] + self.git_provider.publish_labels(pr_labels + current_labels) + self.git_provider.remove_initial_comment() + except Exception as e: + get_logger().error(f"Error generating PR labels {self.pr_id}: {e}") + + return "" + + async def _prepare_prediction(self, model: str) -> None: + """ + Prepare the AI prediction for the PR labels based on the provided model. + + Args: + model (str): The name of the model to be used for generating the prediction. + + Returns: + None + + Raises: + Any exceptions raised by the 'get_pr_diff' and '_get_prediction' functions. + + """ + + get_logger().info(f"Getting PR diff {self.pr_id}") + self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model) + get_logger().info(f"Getting AI prediction {self.pr_id}") + self.prediction = await self._get_prediction(model) + + async def _get_prediction(self, model: str) -> str: + """ + Generate an AI prediction for the PR labels based on the provided model. + + Args: + model (str): The name of the model to be used for generating the prediction. + + Returns: + str: The generated AI prediction. 
+ """ + variables = copy.deepcopy(self.vars) + variables["diff"] = self.patches_diff # update diff + + environment = Environment(undefined=StrictUndefined) + set_custom_labels(variables) + system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables) + user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables) + + if get_settings().config.verbosity_level >= 2: + get_logger().info(f"\nSystem prompt:\n{system_prompt}") + get_logger().info(f"\nUser prompt:\n{user_prompt}") + + response, finish_reason = await self.ai_handler.chat_completion( + model=model, + temperature=0.2, + system=system_prompt, + user=user_prompt + ) + + return response + + def _prepare_data(self): + # Load the AI prediction data into a dictionary + self.data = load_yaml(self.prediction.strip()) + + + + def _prepare_labels(self) -> List[str]: + pr_types = [] + + # If the 'PR Type' key is present in the dictionary, split its value by comma and assign it to 'pr_types' + if 'PR Type' in self.data: + if type(self.data['PR Type']) == list: + pr_types = self.data['PR Type'] + elif type(self.data['PR Type']) == str: + pr_types = self.data['PR Type'].split(',') + + return pr_types diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 7c722df0..b9aa5481 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -150,7 +150,7 @@ class PRReviewer: variables["diff"] = self.patches_diff # update diff environment = Environment(undefined=StrictUndefined) - await set_custom_labels(variables) + set_custom_labels(variables) system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables) user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)