diff --git a/docs/GENERATE_CUSTOM_LABELS.md b/docs/GENERATE_CUSTOM_LABELS.md
new file mode 100644
index 00000000..5c1743f4
--- /dev/null
+++ b/docs/GENERATE_CUSTOM_LABELS.md
@@ -0,0 +1,35 @@
+# Generate Custom Labels
+The `generate_labels` tool scans the PR code changes, and given a list of labels and their descriptions, it automatically suggests labels that match the PR code changes.
+
+It can be invoked manually by commenting on any PR:
+```
+/generate_labels
+```
+For example:
+
+If we wish to detect changes to SQL queries in a given PR, we can add the following custom label along with its description:
+
+![Custom labels list](./../pics/custom_labels_list.png)
+
+When running the `generate_labels` tool on a PR that includes changes in SQL queries, it will automatically suggest the custom label:
+
+![Custom label published](./../pics/custom_label_published.png)
+
+### Configuration options
+To enable custom labels, you need to add the following configuration to the [custom_labels file](./../pr_agent/settings/custom_labels.toml):
+ - Set `enable_custom_labels` to `true`: this turns off the default labels and enables the custom labels provided in the custom_labels.toml file.
+ - Add the custom labels to the custom_labels.toml file. It should be formatted as follows:
+ ```
+[custom_labels."Custom Label Name"]
+description = "Description of when AI should suggest this label"
+```
+ - You can modify the list to include all the custom labels you wish to use in your repository.
+
+#### Github Action
+To use the `generate_labels` tool with Github Action:
+
+- Add the following variables to the `env` section of `.github/workflows/pr_agent.yml` in your repository.
+- Each variable is a comma-separated list: one holds the custom labels, the other holds their descriptions.
+- The number of labels and descriptions should be the same and in the same order (empty descriptions are allowed):
+```
+CUSTOM_LABELS: "label1, label2, ..."
+CUSTOM_LABELS_DESCRIPTIONS: "label1 description, label2 description, ..."
+```
\ No newline at end of file
diff --git a/pics/custom_label_published.png b/pics/custom_label_published.png
new file mode 100644
index 00000000..7dfffcf6
Binary files /dev/null and b/pics/custom_label_published.png differ
diff --git a/pics/custom_labels_list.png b/pics/custom_labels_list.png
new file mode 100644
index 00000000..4e00caad
Binary files /dev/null and b/pics/custom_labels_list.png differ
diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py
index cd2bf2cc..6e76c5e0 100644
--- a/pr_agent/agent/pr_agent.py
+++ b/pr_agent/agent/pr_agent.py
@@ -7,6 +7,7 @@ from pr_agent.tools.pr_add_docs import PRAddDocs
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_config import PRConfig
from pr_agent.tools.pr_description import PRDescription
+from pr_agent.tools.pr_generate_labels import PRGenerateLabels
from pr_agent.tools.pr_information_from_user import PRInformationFromUser
from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer
@@ -31,6 +32,7 @@ command2class = {
"settings": PRConfig,
"similar_issue": PRSimilarIssue,
"add_docs": PRAddDocs,
+ "generate_labels": PRGenerateLabels,
}
commands = list(command2class.keys())
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 4e88b33e..7de31fd4 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -304,3 +304,19 @@ def try_fix_yaml(review_text: str) -> dict:
except:
pass
return data
+
+
+def set_custom_labels(variables):
+    """
+    Fill the prompt variables with the labels the model is allowed to choose from.
+
+    Reads `custom_labels` from the settings and updates `variables` in place:
+      - "custom_labels": the labels rendered as a bulleted list, one label per line
+      - "custom_labels_examples": a single bullet holding the first label, used by the
+        "Example output" part of the prompts
+
+    Args:
+        variables (dict): The prompt variables dictionary to update.
+
+    Returns:
+        None
+    """
+    labels = get_settings().custom_labels
+    if not labels:
+        # Nothing configured: fall back to the default label names (no descriptions)
+        labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other']
+        variables["custom_labels"] = " - " + "\n - ".join(labels)
+        # Keep the example in sync with the list, otherwise the prompt example would be empty
+        variables["custom_labels_examples"] = f" - {labels[0]}"
+        return
+    if not hasattr(labels, "items"):
+        # A plain list of label names was configured: treat every label as having no description
+        labels = {name: {} for name in labels}
+    rendered = []
+    for name, details in labels.items():
+        # Descriptions are optional, so a missing or empty one must not raise
+        description = details.get('description', '') if hasattr(details, 'get') else ''
+        if description:
+            rendered.append(f" - {name} ({description})\n")
+        else:
+            rendered.append(f" - {name}\n")
+    variables["custom_labels"] = "".join(rendered)
+    variables["custom_labels_examples"] = f" - {next(iter(labels))}"
diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py
index 3b0b0360..a160bb1a 100644
--- a/pr_agent/config_loader.py
+++ b/pr_agent/config_loader.py
@@ -23,8 +23,10 @@ global_settings = Dynaconf(
"settings/pr_sort_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",
+ "settings/pr_custom_labels.toml",
"settings/pr_add_docs.toml",
- "settings_prod/.secrets.toml"
+ "settings_prod/.secrets.toml",
+ "settings/custom_labels.toml"
]]
)
diff --git a/pr_agent/servers/github_action_runner.py b/pr_agent/servers/github_action_runner.py
index 714e7297..e1a5e56a 100644
--- a/pr_agent/servers/github_action_runner.py
+++ b/pr_agent/servers/github_action_runner.py
@@ -17,8 +17,11 @@ async def run_action():
OPENAI_KEY = os.environ.get('OPENAI_KEY') or os.environ.get('OPENAI.KEY')
OPENAI_ORG = os.environ.get('OPENAI_ORG') or os.environ.get('OPENAI.ORG')
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
- get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
+ CUSTOM_LABELS = os.environ.get('CUSTOM_LABELS')
+ CUSTOM_LABELS_DESCRIPTIONS = os.environ.get('CUSTOM_LABELS_DESCRIPTIONS')
+ # CUSTOM_LABELS is a comma separated list of labels (string), convert to list and strip spaces
+ get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
# Check if required environment variables are set
if not GITHUB_EVENT_NAME:
@@ -33,6 +36,7 @@ async def run_action():
if not GITHUB_TOKEN:
print("GITHUB_TOKEN not set")
return
+ CUSTOM_LABELS_DICT = handle_custom_labels(CUSTOM_LABELS, CUSTOM_LABELS_DESCRIPTIONS)
# Set the environment variables in the settings
get_settings().set("OPENAI.KEY", OPENAI_KEY)
@@ -40,6 +44,7 @@ async def run_action():
get_settings().set("OPENAI.ORG", OPENAI_ORG)
get_settings().set("GITHUB.USER_TOKEN", GITHUB_TOKEN)
get_settings().set("GITHUB.DEPLOYMENT_TYPE", "user")
+ get_settings().set("CUSTOM_LABELS", CUSTOM_LABELS_DICT)
# Load the event payload
try:
@@ -88,5 +93,31 @@ async def run_action():
await PRAgent().handle_request(url, body)
+def handle_custom_labels(CUSTOM_LABELS, CUSTOM_LABELS_DESCRIPTIONS):
+    """
+    Build the custom labels dictionary from the GitHub Action environment variables.
+
+    Args:
+        CUSTOM_LABELS (str | None): Comma separated list of label names. When empty or None,
+            the default labels are used.
+        CUSTOM_LABELS_DESCRIPTIONS (str | None): Comma separated list of descriptions, in the
+            same order as the labels. Descriptions themselves cannot contain commas.
+
+    Returns:
+        dict: Maps each label name to {'description': <description>}. A label without a
+        matching description gets an empty description; surplus descriptions are ignored.
+    """
+    use_default_labels = not CUSTOM_LABELS
+    if use_default_labels:
+        # Set default labels
+        labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation',
+                  'Other']
+        print(f"Using default labels: {labels}")
+    else:
+        labels = [x.strip() for x in CUSTOM_LABELS.split(',')]
+
+    if CUSTOM_LABELS_DESCRIPTIONS:
+        descriptions = [x.strip() for x in CUSTOM_LABELS_DESCRIPTIONS.split(',')]
+    elif use_default_labels:
+        # The default descriptions only make sense together with the default labels
+        descriptions = ['Fixes a bug in the code', 'Adds or modifies tests',
+                        'Fixes a bug in the code and adds or modifies tests',
+                        'Refactors the code without changing its functionality',
+                        'Adds new features or functionality',
+                        'Adds or modifies documentation',
+                        'Other changes that do not fit in any of the above categories']
+        print(f"Using default label descriptions: {descriptions}")
+    else:
+        # Custom labels without descriptions: never borrow the default labels' descriptions
+        descriptions = []
+
+    if len(descriptions) != len(labels):
+        print(f"Got {len(labels)} labels but {len(descriptions)} descriptions; "
+              f"missing descriptions are left empty and extra ones are ignored")
+        # Pad so that every label has a (possibly empty) description
+        descriptions += [''] * (len(labels) - len(descriptions))
+
+    # create a dictionary of labels and descriptions, skipping empty names (e.g. a trailing comma)
+    return {label: {'description': description}
+            for label, description in zip(labels, descriptions) if label}
+
+
if __name__ == '__main__':
asyncio.run(run_action())
\ No newline at end of file
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index 97d92261..9486c740 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -33,10 +33,13 @@ add_original_user_description=false
keep_original_user_title=false
use_bullet_points=true
extra_instructions = ""
+
# markers
use_description_markers=false
include_generated_by_header=true
+#custom_labels = ['Bug fix', 'Tests', 'Bug fix with tests', 'Refactoring', 'Enhancement', 'Documentation', 'Other']
+
[pr_questions] # /ask #
[pr_code_suggestions] # /improve #
diff --git a/pr_agent/settings/custom_labels.toml b/pr_agent/settings/custom_labels.toml
new file mode 100644
index 00000000..8b1340f2
--- /dev/null
+++ b/pr_agent/settings/custom_labels.toml
@@ -0,0 +1,16 @@
+enable_custom_labels=false
+
+[custom_labels."Bug fix"]
+description = "Fixes a bug in the code"
+[custom_labels."Tests"]
+description = "Adds or modifies tests"
+[custom_labels."Bug fix with tests"]
+description = "Fixes a bug in the code and adds or modifies tests"
+[custom_labels."Refactoring"]
+description = "Code refactoring without changing functionality"
+[custom_labels."Enhancement"]
+description = "Adds new features or functionality"
+[custom_labels."Documentation"]
+description = "Adds or modifies documentation"
+[custom_labels."Other"]
+description = "Other changes that do not fit in any of the above categories"
\ No newline at end of file
diff --git a/pr_agent/settings/pr_custom_labels.toml b/pr_agent/settings/pr_custom_labels.toml
new file mode 100644
index 00000000..96cee17b
--- /dev/null
+++ b/pr_agent/settings/pr_custom_labels.toml
@@ -0,0 +1,72 @@
+[pr_custom_labels_prompt]
+system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests.
+Your task is to label the type of the PR content.
+- Make sure to focus on the new PR code (the '+' lines).
+- If needed, each YAML output should be in block scalar format ('|-')
+{%- if extra_instructions %}
+
+Extra instructions from the user:
+'
+{{ extra_instructions }}
+'
+{% endif %}
+
+You must use the following YAML schema to format your answer:
+```yaml
+PR Type:
+ type: array
+{%- if enable_custom_labels %}
+ description: One or more labels that describe the PR type. Don't output the description in the parentheses.
+{%- endif %}
+ items:
+ type: string
+ enum:
+{%- if enable_custom_labels %}
+{{ custom_labels }}
+{%- else %}
+ - Bug fix
+ - Tests
+ - Refactoring
+ - Enhancement
+ - Documentation
+ - Other
+{%- endif %}
+```
+
+Example output:
+```yaml
+PR Type:
+{%- if enable_custom_labels %}
+{{ custom_labels_examples }}
+{%- else %}
+ - Bug fix
+ - Tests
+{%- endif %}
+```
+
+Make sure to output a valid YAML. Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
+"""
+
+user="""PR Info:
+Previous title: '{{title}}'
+Previous description: '{{description}}'
+Branch: '{{branch}}'
+{%- if language %}
+
+Main language: {{language}}
+{%- endif %}
+{%- if commit_messages_str %}
+
+Commit messages:
+{{commit_messages_str}}
+{%- endif %}
+
+
+The PR Git Diff:
+```
+{{diff}}
+```
+Note that lines in the diff body are prefixed with a symbol that represents the type of change: '-' for deletions, '+' for additions, and ' ' (a space) for unchanged lines.
+
+Response (should be a valid YAML, and nothing else):
+```yaml
+"""
diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml
index 7d1621b0..3d9c3f0b 100644
--- a/pr_agent/settings/pr_description_prompts.toml
+++ b/pr_agent/settings/pr_description_prompts.toml
@@ -19,16 +19,22 @@ PR Title:
description: an informative title for the PR, describing its main theme
PR Type:
type: array
+{%- if enable_custom_labels %}
+ description: One or more labels that describe the PR type. Don't output the description in the parentheses.
+{%- endif %}
items:
type: string
enum:
+{%- if enable_custom_labels %}
+{{ custom_labels }}
+{%- else %}
- Bug fix
- Tests
- - Bug fix with tests
- Refactoring
- Enhancement
- Documentation
- Other
+{%- endif %}
PR Description:
type: string
description: an informative and concise description of the PR.
@@ -52,7 +58,11 @@ Example output:
PR Title: |-
...
PR Type:
+{%- if enable_custom_labels %}
+{{ custom_labels_examples }}
+{%- else %}
- Bug fix
+{%- endif %}
PR Description: |-
...
PR Main Files Walkthrough:
diff --git a/pr_agent/settings/pr_reviewer_prompts.toml b/pr_agent/settings/pr_reviewer_prompts.toml
index cb49b5d0..b717ec3d 100644
--- a/pr_agent/settings/pr_reviewer_prompts.toml
+++ b/pr_agent/settings/pr_reviewer_prompts.toml
@@ -51,13 +51,22 @@ PR Analysis:
description: summary of the PR in 2-3 sentences.
Type of PR:
type: string
+{%- if enable_custom_labels %}
+ description: One or more labels that describe the PR type. Don't output the description in the parentheses.
+{%- endif %}
+ items:
+ type: string
enum:
+{%- if enable_custom_labels %}
+{{ custom_labels }}
+{%- else %}
- Bug fix
- Tests
- Refactoring
- Enhancement
- Documentation
- Other
+{%- endif %}
{%- if require_score %}
Score:
type: int
diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py
index c1bd03fd..a88ff336 100644
--- a/pr_agent/tools/pr_description.py
+++ b/pr_agent/tools/pr_description.py
@@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import load_yaml
+from pr_agent.algo.utils import load_yaml, set_custom_labels
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
@@ -42,7 +42,10 @@ class PRDescription:
"diff": "", # empty diff for initial calculation
"use_bullet_points": get_settings().pr_description.use_bullet_points,
"extra_instructions": get_settings().pr_description.extra_instructions,
- "commit_messages_str": self.git_provider.get_commit_messages()
+ "commit_messages_str": self.git_provider.get_commit_messages(),
+ "enable_custom_labels": get_settings().enable_custom_labels,
+ "custom_labels": "",
+ "custom_labels_examples": "",
}
self.user_description = self.git_provider.get_user_description()
@@ -140,6 +143,7 @@ class PRDescription:
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
+ set_custom_labels(variables)
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
@@ -156,7 +160,6 @@ class PRDescription:
return response
-
def _prepare_data(self):
# Load the AI prediction data into a dictionary
self.data = load_yaml(self.prediction.strip())
diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py
new file mode 100644
index 00000000..bf5b5f98
--- /dev/null
+++ b/pr_agent/tools/pr_generate_labels.py
@@ -0,0 +1,163 @@
+import copy
+import re
+from typing import List, Tuple
+
+from jinja2 import Environment, StrictUndefined
+
+from pr_agent.algo.ai_handler import AiHandler
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
+from pr_agent.algo.token_handler import TokenHandler
+from pr_agent.algo.utils import load_yaml, set_custom_labels
+from pr_agent.config_loader import get_settings
+from pr_agent.git_providers import get_git_provider
+from pr_agent.git_providers.git_provider import get_main_pr_language
+from pr_agent.log import get_logger
+
+
+class PRGenerateLabels:
+    """
+    Tool that asks an AI model which labels match a pull request and publishes them to the PR.
+    """
+
+    def __init__(self, pr_url: str, args: list = None):
+        """
+        Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
+        corresponding to the PR using an AI model.
+        Args:
+            pr_url (str): The URL of the pull request.
+            args (list, optional): List of arguments passed to the PRGenerateLabels class. Defaults to None.
+        """
+        # Initialize the git provider and main PR language
+        self.git_provider = get_git_provider()(pr_url)
+        self.main_pr_language = get_main_pr_language(
+            self.git_provider.get_languages(), self.git_provider.get_files()
+        )
+        self.pr_id = self.git_provider.get_pr_id()
+
+        # Initialize the AI handler
+        self.ai_handler = AiHandler()
+
+        # Initialize the variables dictionary (the label entries are filled in right before rendering)
+        self.vars = {
+            "title": self.git_provider.pr.title,
+            "branch": self.git_provider.get_pr_branch(),
+            "description": self.git_provider.get_pr_description(full=False),
+            "language": self.main_pr_language,
+            "diff": "",  # empty diff for initial calculation
+            "use_bullet_points": get_settings().pr_description.use_bullet_points,
+            "extra_instructions": get_settings().pr_description.extra_instructions,
+            "commit_messages_str": self.git_provider.get_commit_messages(),
+            "custom_labels": "",
+            "custom_labels_examples": "",
+            "enable_custom_labels": get_settings().enable_custom_labels,
+        }
+
+        # Initialize the token handler
+        self.token_handler = TokenHandler(
+            self.git_provider.pr,
+            self.vars,
+            get_settings().pr_custom_labels_prompt.system,
+            get_settings().pr_custom_labels_prompt.user,
+        )
+
+        # Initialize patches_diff, prediction and parsed data attributes
+        self.patches_diff = None
+        self.prediction = None
+        self.data = None
+
+    async def run(self):
+        """
+        Generate the PR labels using an AI model and publish them to the PR.
+
+        Returns:
+            None when the model produced no prediction, otherwise an empty string.
+        """
+
+        try:
+            get_logger().info(f"Generating a PR labels {self.pr_id}")
+            if get_settings().config.publish_output:
+                self.git_provider.publish_comment("Preparing PR labels...", is_temporary=True)
+
+            await retry_with_fallback_models(self._prepare_prediction)
+
+            get_logger().info(f"Preparing answer {self.pr_id}")
+            if not self.prediction:
+                # Nothing to publish: do not leave the temporary progress comment behind
+                if get_settings().config.publish_output:
+                    self.git_provider.remove_initial_comment()
+                return None
+            self._prepare_data()
+
+            pr_labels = self._prepare_labels()
+
+            if get_settings().config.publish_output:
+                get_logger().info(f"Pushing labels {self.pr_id}")
+                if self.git_provider.is_supported("get_labels"):
+                    current_labels = self.git_provider.get_labels()
+                    if current_labels is None:
+                        current_labels = []
+                    # Keep the existing labels and avoid publishing the same label twice
+                    self.git_provider.publish_labels(list(dict.fromkeys(pr_labels + current_labels)))
+                self.git_provider.remove_initial_comment()
+        except Exception as e:
+            get_logger().error(f"Error generating PR labels {self.pr_id}: {e}")
+
+        return ""
+
+    async def _prepare_prediction(self, model: str) -> None:
+        """
+        Prepare the AI prediction for the PR labels based on the provided model.
+
+        Args:
+            model (str): The name of the model to be used for generating the prediction.
+
+        Returns:
+            None
+
+        Raises:
+            Any exceptions raised by the 'get_pr_diff' and '_get_prediction' functions.
+
+        """
+
+        get_logger().info(f"Getting PR diff {self.pr_id}")
+        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
+        get_logger().info(f"Getting AI prediction {self.pr_id}")
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str) -> str:
+        """
+        Generate an AI prediction for the PR labels based on the provided model.
+
+        Args:
+            model (str): The name of the model to be used for generating the prediction.
+
+        Returns:
+            str: The generated AI prediction.
+        """
+        # Work on a copy so the per-call diff and labels do not leak into self.vars
+        variables = copy.deepcopy(self.vars)
+        variables["diff"] = self.patches_diff  # update diff
+
+        environment = Environment(undefined=StrictUndefined)
+        set_custom_labels(variables)
+        system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
+        user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)
+
+        if get_settings().config.verbosity_level >= 2:
+            get_logger().info(f"\nSystem prompt:\n{system_prompt}")
+            get_logger().info(f"\nUser prompt:\n{user_prompt}")
+
+        response, _ = await self.ai_handler.chat_completion(
+            model=model,
+            temperature=0.2,
+            system=system_prompt,
+            user=user_prompt
+        )
+
+        return response
+
+    def _prepare_data(self):
+        """Load the AI prediction (YAML text) into `self.data`."""
+        self.data = load_yaml(self.prediction.strip())
+
+    def _prepare_labels(self) -> List[str]:
+        """
+        Extract the suggested labels from the parsed prediction.
+
+        Returns:
+            List[str]: The non-empty labels found under the 'PR Type' key, with surrounding
+            whitespace removed. Empty when the key is missing or the data could not be parsed.
+        """
+        if not isinstance(self.data, dict):
+            return []
+
+        pr_types = self.data.get('PR Type')
+        if isinstance(pr_types, str):
+            # The model may answer with a single comma separated string instead of a YAML list
+            pr_types = pr_types.split(',')
+        if not isinstance(pr_types, list):
+            return []
+
+        # Strip so that "Bug fix, Tests" does not yield a label with a leading space
+        return [label.strip() for label in pr_types if isinstance(label, str) and label.strip()]
diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py
index ed99ddf6..0eeb5578 100644
--- a/pr_agent/tools/pr_reviewer.py
+++ b/pr_agent/tools/pr_reviewer.py
@@ -9,7 +9,7 @@ from yaml import SafeLoader
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml
+from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language
@@ -63,6 +63,8 @@ class PRReviewer:
'answer_str': answer_str,
"extra_instructions": get_settings().pr_reviewer.extra_instructions,
"commit_messages_str": self.git_provider.get_commit_messages(),
+ "custom_labels": "",
+ "enable_custom_labels": get_settings().enable_custom_labels,
}
self.token_handler = TokenHandler(
@@ -149,6 +151,7 @@ class PRReviewer:
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
+ set_custom_labels(variables)
system_prompt = environment.from_string(get_settings().pr_review_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_review_prompt.user).render(variables)