diff --git a/pr_agent/cli.py b/pr_agent/cli.py index 66d00cd1..8b280065 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -3,6 +3,7 @@ import asyncio import logging import os +from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer @@ -11,12 +12,17 @@ def run(): parser = argparse.ArgumentParser(description='AI based pull request analyzer') parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True) parser.add_argument('--question', type=str, help='Optional question to ask', required=False) + parser.add_argument('--pr_description', action='store_true', help='Optional question to ask', required=False) args = parser.parse_args() logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) if args.question: print(f"Question: {args.question} about PR {args.pr_url}") reviewer = PRQuestions(args.pr_url, args.question) asyncio.run(reviewer.answer()) + elif args.pr_description: + print(f"PR description: {args.pr_url}") + reviewer = PRDescription(args.pr_url) + asyncio.run(reviewer.describe()) else: print(f"Reviewing PR: {args.pr_url}") reviewer = PRReviewer(args.pr_url, cli_mode=True) diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py index 550f743e..ed98fccd 100644 --- a/pr_agent/config_loader.py +++ b/pr_agent/config_loader.py @@ -11,6 +11,7 @@ settings = Dynaconf( "settings/configuration.toml", "settings/pr_reviewer_prompts.toml", "settings/pr_questions_prompts.toml", + "settings/pr_description_prompts.toml", "settings_prod/.secrets.toml" ]] ) diff --git a/pr_agent/settings/pr_description_prompts.toml b/pr_agent/settings/pr_description_prompts.toml new file mode 100644 index 00000000..76e8dd45 --- /dev/null +++ b/pr_agent/settings/pr_description_prompts.toml @@ -0,0 +1,45 @@ +[pr_description_prompt] +system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests. +Your task is to provide full description of the PR content. +- Make sure not to focus the new PR code (the '+' lines). + +You must use the following JSON schema to format your answer: +```json +{ + "PR Title": { + "type": "string", + "description": "an informative title for the PR, describing its main theme" + }, + "Type of PR": { + "type": "string", + "enum": ["Bug fix", "Tests", "Bug fix with tests", "Refactoring", "Enhancement", "Documentation", "Other"] + }, + "PR Description": { + "type": "string", + "description": "an informative and concise description of the PR" + }, + "PR Walkthrough": { + "type": "string", + "description": "a walkthrough of the PR changes. Review file by file, in bullet points, and shortly describe the changes in each file. Format: -`filename`: description of changes\n..." + } +} +""" + +user="""PR Info: +Title: '{{title}}' +Branch: '{{branch}}' +Description: '{{description}}' +{%- if language %} +Main language: {{language}} +{%- endif %} + + +The PR Git Diff: +``` +{{diff}} +``` +Note that lines in the diff body are prefixed with a symbol that represents the type of change: '-' for deletions, '+' for additions, and ' ' (a space) for unchanged lines. + +Response (should be a valid JSON, and nothing else): +```json +""" diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py new file mode 100644 index 00000000..be237254 --- /dev/null +++ b/pr_agent/tools/pr_description.py @@ -0,0 +1,81 @@ +import copy +import json +import logging + +from jinja2 import Environment, StrictUndefined + +from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.pr_processing import get_pr_diff +from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.utils import convert_to_markdown +from pr_agent.config_loader import settings +from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers.git_provider import get_main_pr_language + + +class PRDescription: + def __init__(self, pr_url: str): + self.git_provider = get_git_provider()(pr_url) + self.main_pr_language = get_main_pr_language( + self.git_provider.get_languages(), self.git_provider.get_files() + ) + self.ai_handler = AiHandler() + self.vars = { + "title": self.git_provider.pr.title, + "branch": self.git_provider.get_pr_branch(), + "description": self.git_provider.get_description(), + "language": self.main_pr_language, + "diff": "", # empty diff for initial calculation + } + self.token_handler = TokenHandler(self.git_provider.pr, + self.vars, + settings.pr_description_prompt.system, + settings.pr_description_prompt.user) + self.patches_diff = None + self.prediction = None + + async def describe(self): + logging.info('Answering a PR question...') + if settings.config.publish_review: + self.git_provider.publish_comment("Preparing pr description...", is_temporary=True) + logging.info('Getting PR diff...') + self.patches_diff = get_pr_diff(self.git_provider, self.token_handler) + logging.info('Getting AI prediction...') + self.prediction = await self._get_prediction() + logging.info('Preparing answer...') + pr_comment = self._prepare_pr_answer() + if settings.config.publish_review: + logging.info('Pushing answer...') + self.git_provider.publish_comment(pr_comment) + self.git_provider.remove_initial_comment() + return "" + + async def _get_prediction(self): + variables = copy.deepcopy(self.vars) + variables["diff"] = self.patches_diff # update diff + environment = Environment(undefined=StrictUndefined) + system_prompt = environment.from_string(settings.pr_description_prompt.system).render(variables) + user_prompt = environment.from_string(settings.pr_description_prompt.user).render(variables) + if settings.config.verbosity_level >= 2: + logging.info(f"\nSystem prompt:\n{system_prompt}") + logging.info(f"\nUser prompt:\n{user_prompt}") + model = settings.config.model + response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, + system=system_prompt, user=user_prompt) + return response + + def _prepare_pr_answer(self) -> str: + data = json.loads(self.prediction) + markdown_text = "" + # for key, value in data.items(): + # markdown_text += f"## {key}\n\n" + # markdown_text += f"{value}\n\n" + for key, value in data.items(): + markdown_text += f"{key}:\n" + if 'walkthrough' not in key.lower(): + markdown_text += f"**{value}**\n" + else: + markdown_text += f"{value}\n\n___\n" + if settings.config.verbosity_level >= 2: + logging.info(f"markdown_text:\n{markdown_text}") + return markdown_text