diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..0a905f38 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,6 @@ +## 2023-07-26 + +### Added +- New feature for updating the CHANGELOG.md based on the contents of a PR. +- Added support for this feature for the Github provider. +- New configuration settings and prompts for the changelog update feature. \ No newline at end of file diff --git a/README.md b/README.md index 173b7844..33c1b763 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull - [Usage and tools](#usage-and-tools) - [Configuration](./CONFIGURATION.md) - [How it works](#how-it-works) +- [Why use PR-Agent](#why-use-pr-agent) - [Roadmap](#roadmap) - [Similar projects](#similar-projects) @@ -81,6 +82,7 @@ CodiumAI `PR-Agent` is an open-source tool aiming to help developers review pull | | Auto-Description | :white_check_mark: | :white_check_mark: | | | | Improve Code | :white_check_mark: | :white_check_mark: | | | | Reflect and Review | :white_check_mark: | | | +| | Update CHANGELOG.md | :white_check_mark: | | | | | | | | | | USAGE | CLI | :white_check_mark: | :white_check_mark: | :white_check_mark: | | | App / webhook | :white_check_mark: | :white_check_mark: | | @@ -98,6 +100,7 @@ Examples for invoking the different tools via the CLI: - **Improve**: python cli.py --pr-url= improve - **Ask**: python cli.py --pr-url= ask "Write me a poem about this PR" - **Reflect**: python cli.py --pr-url= reflect +- **Update changelog**: python cli.py --pr-url= update_changelog "" is the url of the relevant PR (for example: https://github.com/Codium-ai/pr-agent/pull/50). @@ -146,6 +149,19 @@ There are several ways to use PR-Agent: Check out the [PR Compression strategy](./PR_COMPRESSION.md) page for more details on how we convert a code diff to a manageable LLM prompt +## Why use PR-Agent? + +A reasonable question that can be asked is: `"Why use PR-Agent? What make it stand out from existing tools?"` + +Here are some of the reasons why: + +- We emphasize **real-life practical usage**. Each tool (review, improve, ask, ...) has a single GPT-4 call, no more. We feel that this is critical for realistic team usage - obtaining an answer quickly (~30 seconds) and affordably. +- Our [PR Compression strategy](./PR_COMPRESSION.md) is a core ability that enables to effectively tackle both short and long PRs. +- Our JSON prompting strategy enables to have **modular, customizable tools**. For example, the '/review' tool categories can be controlled via the configuration file. Adding additional categories is easy and accessible. +- We support **multiple git providers** (GitHub, Gitlab, Bitbucket), and multiple ways to use the tool (CLI, GitHub Action, Docker, ...). +- We are open-source, and welcome contributions from the community. + + ## Roadmap - [ ] Support open-source models, as a replacement for OpenAI models. (Note - a minimal requirement for each open-source model is to have 8k+ context, and good support for generating JSON as an output) diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 7aa61c03..a30c411b 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -6,6 +6,7 @@ from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer +from pr_agent.tools.pr_update_changelog import PRUpdateChangelog class PRAgent: @@ -26,7 +27,9 @@ class PRAgent: elif any(cmd == action for cmd in ["/improve", "/improve_code"]): await PRCodeSuggestions(pr_url).suggest() elif any(cmd == action for cmd in ["/ask", "/ask_question"]): - await PRQuestions(pr_url, args).answer() + await PRQuestions(pr_url, args=args).answer() + elif any(cmd == action for cmd in ["/update_changelog"]): + await PRUpdateChangelog(pr_url, args=args).update_changelog() else: return False diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 20933d51..45ef40b2 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging from typing import Tuple, Union, Callable, List +from github import RateLimitExceededException + from pr_agent.algo import MAX_TOKENS from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.language_handler import sort_files_by_main_languages @@ -19,7 +21,6 @@ OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000 OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600 PATCH_EXTRA_LINES = 3 - def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str, add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str: """ @@ -40,7 +41,11 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s global PATCH_EXTRA_LINES PATCH_EXTRA_LINES = 0 - diff_files = list(git_provider.get_diff_files()) + try: + diff_files = list(git_provider.get_diff_files()) + except RateLimitExceededException as e: + logging.error(f"Rate limit exceeded for git provider API. original message {e}") + raise # get pr languages pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files) diff --git a/pr_agent/cli.py b/pr_agent/cli.py index 4477016c..f04e51d7 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -8,6 +8,7 @@ from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_information_from_user import PRInformationFromUser from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer +from pr_agent.tools.pr_update_changelog import PRUpdateChangelog def run(args=None): @@ -27,13 +28,15 @@ ask / ask_question [question] - Ask a question about the PR. describe / describe_pr - Modify the PR title and description based on the PR's contents. improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit. reflect - Ask the PR author questions about the PR. +update_changelog - Update the changelog based on the PR's contents. """) parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True) parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr', 'ask', 'ask_question', 'describe', 'describe_pr', 'improve', 'improve_code', - 'reflect', 'review_after_reflect'], + 'reflect', 'review_after_reflect', + 'update_changelog'], default='review') parser.add_argument('rest', nargs=argparse.REMAINDER, default=[]) args = parser.parse_args(args) @@ -49,7 +52,8 @@ reflect - Ask the PR author questions about the PR. 'review': _handle_review_command, 'review_pr': _handle_review_command, 'reflect': _handle_reflect_command, - 'review_after_reflect': _handle_review_after_reflect_command + 'review_after_reflect': _handle_review_after_reflect_command, + 'update_changelog': _handle_update_changelog, } if command in commands: commands[command](args.pr_url, args.rest) @@ -96,6 +100,10 @@ def _handle_review_after_reflect_command(pr_url: str, rest: list): reviewer = PRReviewer(pr_url, cli_mode=True, is_answer=True) asyncio.run(reviewer.review()) +def _handle_update_changelog(pr_url: str, rest: list): + print(f"Updating changlog for: {pr_url}") + reviewer = PRUpdateChangelog(pr_url, cli_mode=True, args=rest) + asyncio.run(reviewer.update_changelog()) if __name__ == '__main__': run() diff --git a/pr_agent/config_loader.py b/pr_agent/config_loader.py index 7841f0b7..69d20d88 100644 --- a/pr_agent/config_loader.py +++ b/pr_agent/config_loader.py @@ -19,6 +19,7 @@ settings = Dynaconf( "settings/pr_description_prompts.toml", "settings/pr_code_suggestions_prompts.toml", "settings/pr_information_from_user_prompts.toml", + "settings/pr_update_changelog.toml", "settings_prod/.secrets.toml" ]] ) diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index 3f7c1ef2..677c2eb1 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -136,3 +136,4 @@ class IncrementalPR: self.commits_range = None self.first_new_commit_sha = None self.last_seen_commit_sha = None + diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 10a50412..7f617937 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -3,13 +3,15 @@ from datetime import datetime from typing import Optional, Tuple from urllib.parse import urlparse -from github import AppAuthentication, Github, Auth +from github import AppAuthentication, Auth, Github, GithubException +from retry import retry from pr_agent.config_loader import settings -from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from ..algo.language_handler import is_valid_file from ..algo.utils import load_large_diff +from .git_provider import FilePatchInfo, GitProvider, IncrementalPR +from ..servers.utils import RateLimitExceeded class GithubProvider(GitProvider): @@ -78,27 +80,34 @@ class GithubProvider(GitProvider): return self.file_set.values() return self.pr.get_files() + @retry(exceptions=RateLimitExceeded, + tries=settings.github.ratelimit_retries, delay=2, backoff=2, jitter=(1, 3)) def get_diff_files(self) -> list[FilePatchInfo]: - files = self.get_files() - diff_files = [] - for file in files: - if is_valid_file(file.filename): - new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) - patch = file.patch - if self.incremental.is_incremental and self.file_set: - original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha) - patch = load_large_diff(file, - new_file_content_str, - original_file_content_str, - None) - self.file_set[file.filename] = patch - else: - original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) + try: + files = self.get_files() + diff_files = [] + for file in files: + if is_valid_file(file.filename): + new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) + patch = file.patch + if self.incremental.is_incremental and self.file_set: + original_file_content_str = self._get_pr_file_content(file, + self.incremental.last_seen_commit_sha) + patch = load_large_diff(file, + new_file_content_str, + original_file_content_str, + None) + self.file_set[file.filename] = patch + else: + original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) - diff_files.append( - FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename)) - self.diff_files = diff_files - return diff_files + diff_files.append( + FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename)) + self.diff_files = diff_files + return diff_files + except GithubException.RateLimitExceededException as e: + logging.error(f"Rate limit exceeded for GitHub API. Original message: {e}") + raise RateLimitExceeded("Rate limit exceeded for GitHub API.") from e def publish_description(self, pr_title: str, pr_body: str): self.pr.edit(title=pr_title, body=pr_body) diff --git a/pr_agent/git_providers/local_git_provider.py b/pr_agent/git_providers/local_git_provider.py index 4a7775ac..304417ea 100644 --- a/pr_agent/git_providers/local_git_provider.py +++ b/pr_agent/git_providers/local_git_provider.py @@ -30,6 +30,8 @@ class LocalGitProvider(GitProvider): def __init__(self, target_branch_name, incremental=False): self.repo_path = _find_repository_root() + if self.repo_path is None: + raise ValueError('Could not find repository root') self.repo = Repo(self.repo_path) self.head_branch_name = self.repo.head.ref.name self.target_branch_name = target_branch_name @@ -167,7 +169,7 @@ class LocalGitProvider(GitProvider): """ Substitutes the branch-name as the PR-mimic title. """ - return self.target_branch_name + return self.head_branch_name def get_issue_comments(self): raise NotImplementedError('Getting issue comments is not implemented for the local git provider') diff --git a/pr_agent/servers/utils.py b/pr_agent/servers/utils.py index 942ac449..c24b880c 100644 --- a/pr_agent/servers/utils.py +++ b/pr_agent/servers/utils.py @@ -21,3 +21,7 @@ def verify_signature(payload_body, secret_token, signature_header): if not hmac.compare_digest(expected_signature, signature_header): raise HTTPException(status_code=403, detail="Request signatures didn't match!") + +class RateLimitExceeded(Exception): + """Raised when the git provider API rate limit has been exceeded.""" + pass diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index fbf8ffec..58f4ba32 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -1,6 +1,6 @@ [config] model="gpt-4" -fallback-models=["gpt-3.5-turbo-16k", "gpt-3.5-turbo"] +fallback_models=["gpt-3.5-turbo-16k"] git_provider="github" publish_output=true publish_output_progress=true @@ -24,9 +24,13 @@ publish_description_as_comment=false [pr_code_suggestions] num_code_suggestions=4 +[pr_update_changelog] +push_changelog_changes=false + [github] # The type of deployment to create. Valid values are 'app' or 'user'. deployment_type = "user" +ratelimit_retries = 5 [gitlab] # URL to the gitlab service diff --git a/pr_agent/settings/pr_update_changelog.toml b/pr_agent/settings/pr_update_changelog.toml new file mode 100644 index 00000000..91413010 --- /dev/null +++ b/pr_agent/settings/pr_update_changelog.toml @@ -0,0 +1,34 @@ +[pr_update_changelog_prompt] +system="""You are a language model called CodiumAI-PR-Changlog-summarizer. +Your task is to update the CHANGELOG.md file of the project, to shortly summarize important changes introduced in this PR (the '+' lines). +- The output should match the existing CHANGELOG.md format, style and conventions, so it will look like a natural part of the file. For example, if previous changes were summarized in a single line, you should do the same. +- Don't repeat previous changes. Generate only new content, that is not already in the CHANGELOG.md file. +- Be general, and avoid specific details, files, etc. The output should be minimal, no more than 3-4 short lines. Ignore non-relevant subsections. +""" + +user="""PR Info: +Title: '{{title}}' +Branch: '{{branch}}' +Description: '{{description}}' +{%- if language %} +Main language: {{language}} +{%- endif %} + + +The PR Diff: +``` +{{diff}} +``` + +Current date: +``` +{{today}} +``` + +The current CHANGELOG.md: +``` +{{ changelog_file_str }} +``` + +Response: +""" diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 0bf952dd..318b3c5e 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -2,6 +2,7 @@ import copy import json import logging from collections import OrderedDict +from typing import Tuple, List from jinja2 import Environment, StrictUndefined @@ -16,7 +17,19 @@ from pr_agent.servers.help import actions_help_text, bot_help_text class PRReviewer: - def __init__(self, pr_url: str, cli_mode=False, is_answer: bool = False, args=None): + """ + The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. + """ + def __init__(self, pr_url: str, cli_mode: bool = False, is_answer: bool = False, args: list = None): + """ + Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. + + Args: + pr_url (str): The URL of the pull request to be reviewed. + cli_mode (bool, optional): Indicates whether the review is being done in command-line interface mode. Defaults to False. + is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False. + args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. + """ self.parse_args(args) self.git_provider = get_git_provider()(pr_url, incremental=self.incremental) @@ -25,13 +38,15 @@ class PRReviewer: ) self.pr_url = pr_url self.is_answer = is_answer + if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now") - answer_str, question_str = self._get_user_answers() self.ai_handler = AiHandler() self.patches_diff = None self.prediction = None self.cli_mode = cli_mode + + answer_str, question_str = self._get_user_answers() self.vars = { "title": self.git_provider.pr.title, "branch": self.git_provider.get_pr_branch(), @@ -43,16 +58,27 @@ class PRReviewer: "require_security": settings.pr_reviewer.require_security_review, "require_focused": settings.pr_reviewer.require_focused_review, 'num_code_suggestions': settings.pr_reviewer.num_code_suggestions, - # 'question_str': question_str, 'answer_str': answer_str, } - self.token_handler = TokenHandler(self.git_provider.pr, - self.vars, - settings.pr_review_prompt.system, - settings.pr_review_prompt.user) - def parse_args(self, args): + self.token_handler = TokenHandler( + self.git_provider.pr, + self.vars, + settings.pr_review_prompt.system, + settings.pr_review_prompt.user + ) + + def parse_args(self, args: List[str]) -> None: + """ + Parse the arguments passed to the PRReviewer class and set the 'incremental' attribute accordingly. + + Args: + args: A list of arguments passed to the PRReviewer class. + + Returns: + None + """ is_incremental = False if args and len(args) >= 1: arg = args[0] @@ -60,60 +86,93 @@ class PRReviewer: is_incremental = True self.incremental = IncrementalPR(is_incremental) - async def review(self): + async def review(self) -> None: + """ + Review the pull request and generate feedback. + """ logging.info('Reviewing PR...') + if settings.config.publish_output: self.git_provider.publish_comment("Preparing review...", is_temporary=True) + await retry_with_fallback_models(self._prepare_prediction) + logging.info('Preparing PR review...') pr_comment = self._prepare_pr_review() + if settings.config.publish_output: logging.info('Pushing PR review...') self.git_provider.publish_comment(pr_comment) self.git_provider.remove_initial_comment() + if settings.pr_reviewer.inline_code_comments: logging.info('Pushing inline code comments...') self._publish_inline_code_comments() - return "" - async def _prepare_prediction(self, model: str): + async def _prepare_prediction(self, model: str) -> None: + """ + Prepare the AI prediction for the pull request review. + + Args: + model: A string representing the AI model to be used for the prediction. + + Returns: + None + """ logging.info('Getting PR diff...') self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model) logging.info('Getting AI prediction...') self.prediction = await self._get_prediction(model) - async def _get_prediction(self, model: str): + async def _get_prediction(self, model: str) -> str: + """ + Generate an AI prediction for the pull request review. + + Args: + model: A string representing the AI model to be used for the prediction. + + Returns: + A string representing the AI prediction for the pull request review. + """ variables = copy.deepcopy(self.vars) variables["diff"] = self.patches_diff # update diff + environment = Environment(undefined=StrictUndefined) system_prompt = environment.from_string(settings.pr_review_prompt.system).render(variables) user_prompt = environment.from_string(settings.pr_review_prompt.user).render(variables) + if settings.config.verbosity_level >= 2: logging.info(f"\nSystem prompt:\n{system_prompt}") logging.info(f"\nUser prompt:\n{user_prompt}") - response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, - system=system_prompt, user=user_prompt) + + response, finish_reason = await self.ai_handler.chat_completion( + model=model, + temperature=0.2, + system=system_prompt, + user=user_prompt + ) return response def _prepare_pr_review(self) -> str: + """ + Prepare the PR review by processing the AI prediction and generating a markdown-formatted text that summarizes the feedback. + """ review = self.prediction.strip() + try: data = json.loads(review) except json.decoder.JSONDecodeError: data = try_fix_json(review) - # reordering for nicer display - if 'PR Feedback' in data: - if 'Security concerns' in data['PR Feedback']: - val = data['PR Feedback']['Security concerns'] - del data['PR Feedback']['Security concerns'] - data['PR Analysis']['Security concerns'] = val + # Move 'Security concerns' key to 'PR Analysis' section for better display + if 'PR Feedback' in data and 'Security concerns' in data['PR Feedback']: + val = data['PR Feedback']['Security concerns'] + del data['PR Feedback']['Security concerns'] + data['PR Analysis']['Security concerns'] = val - if settings.config.git_provider != 'bitbucket' and \ - settings.pr_reviewer.inline_code_comments and \ - 'Code suggestions' in data['PR Feedback']: - # keeping only code suggestions that can't be submitted as inline comments + # Filter out code suggestions that can be submitted as inline comments + if settings.config.git_provider != 'bitbucket' and settings.pr_reviewer.inline_code_comments and 'Code suggestions' in data['PR Feedback']: data['PR Feedback']['Code suggestions'] = [ d for d in data['PR Feedback']['Code suggestions'] if any(key not in d for key in ('relevant file', 'relevant line in file', 'suggestion content')) @@ -121,8 +180,8 @@ class PRReviewer: if not data['PR Feedback']['Code suggestions']: del data['PR Feedback']['Code suggestions'] + # Add incremental review section if self.incremental.is_incremental: - # Rename title when incremental review - Add to the beginning of the dict last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}" data = OrderedDict(data) data.update({'Incremental PR Review': { @@ -132,6 +191,7 @@ class PRReviewer: markdown_text = convert_to_markdown(data) user = self.git_provider.get_user_id() + # Add help text if not in CLI mode if not self.cli_mode: markdown_text += "\n### How to use\n" if user and '[bot]' not in user: @@ -139,11 +199,16 @@ class PRReviewer: else: markdown_text += actions_help_text + # Log markdown response if verbosity level is high if settings.config.verbosity_level >= 2: logging.info(f"Markdown response:\n{markdown_text}") + return markdown_text - def _publish_inline_code_comments(self): + def _publish_inline_code_comments(self) -> None: + """ + Publishes inline comments on a pull request with code suggestions generated by the AI model. + """ if settings.pr_reviewer.num_code_suggestions == 0: return @@ -153,11 +218,11 @@ class PRReviewer: except json.decoder.JSONDecodeError: data = try_fix_json(review) - comments = [] - for d in data['PR Feedback']['Code suggestions']: - relevant_file = d.get('relevant file', '').strip() - relevant_line_in_file = d.get('relevant line in file', '').strip() - content = d.get('suggestion content', '') + comments: List[str] = [] + for suggestion in data.get('PR Feedback', {}).get('Code suggestions', []): + relevant_file = suggestion.get('relevant file', '').strip() + relevant_line_in_file = suggestion.get('relevant line in file', '').strip() + content = suggestion.get('suggestion content', '') if not relevant_file or not relevant_line_in_file or not content: logging.info("Skipping inline comment with missing file/line/content") continue @@ -172,15 +237,26 @@ class PRReviewer: if comments: self.git_provider.publish_inline_comments(comments) - def _get_user_answers(self): - answer_str = question_str = "" + def _get_user_answers(self) -> Tuple[str, str]: + """ + Retrieves the question and answer strings from the discussion messages related to a pull request. + + Returns: + A tuple containing the question and answer strings. + """ + question_str = "" + answer_str = "" + if self.is_answer: discussion_messages = self.git_provider.get_issue_comments() - for message in discussion_messages.reversed: + + for message in reversed(discussion_messages): if "Questions to better understand the PR:" in message.body: question_str = message.body elif '/answer' in message.body: answer_str = message.body + if answer_str and question_str: break + return question_str, answer_str diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py new file mode 100644 index 00000000..1b06c381 --- /dev/null +++ b/pr_agent/tools/pr_update_changelog.py @@ -0,0 +1,171 @@ +import copy +import logging +from datetime import date +from time import sleep +from typing import Tuple + +from jinja2 import Environment, StrictUndefined + +from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models +from pr_agent.algo.token_handler import TokenHandler +from pr_agent.config_loader import settings +from pr_agent.git_providers import get_git_provider, GithubProvider +from pr_agent.git_providers.git_provider import get_main_pr_language + +CHANGELOG_LINES = 50 + + +class PRUpdateChangelog: + def __init__(self, pr_url: str, cli_mode=False, args=None): + + self.git_provider = get_git_provider()(pr_url) + self.main_language = get_main_pr_language( + self.git_provider.get_languages(), self.git_provider.get_files() + ) + self.commit_changelog = self._parse_args(args, settings) + self._get_changlog_file() # self.changelog_file_str + self.ai_handler = AiHandler() + self.patches_diff = None + self.prediction = None + self.cli_mode = cli_mode + self.vars = { + "title": self.git_provider.pr.title, + "branch": self.git_provider.get_pr_branch(), + "description": self.git_provider.get_pr_description(), + "language": self.main_language, + "diff": "", # empty diff for initial calculation + "changelog_file_str": self.changelog_file_str, + "today": date.today(), + } + self.token_handler = TokenHandler(self.git_provider.pr, + self.vars, + settings.pr_update_changelog_prompt.system, + settings.pr_update_changelog_prompt.user) + + async def update_changelog(self): + assert type(self.git_provider) == GithubProvider, "Currently only Github is supported" + + logging.info('Updating the changelog...') + if settings.config.publish_output: + self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True) + await retry_with_fallback_models(self._prepare_prediction) + logging.info('Preparing PR changelog updates...') + new_file_content, answer = self._prepare_changelog_update() + if settings.config.publish_output: + self.git_provider.remove_initial_comment() + logging.info('Publishing changelog updates...') + if self.commit_changelog: + logging.info('Pushing PR changelog updates to repo...') + self._push_changelog_update(new_file_content, answer) + else: + logging.info('Publishing PR changelog as comment...') + self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}") + + async def _prepare_prediction(self, model: str): + logging.info('Getting PR diff...') + self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model) + logging.info('Getting AI prediction...') + self.prediction = await self._get_prediction(model) + + async def _get_prediction(self, model: str): + variables = copy.deepcopy(self.vars) + variables["diff"] = self.patches_diff # update diff + environment = Environment(undefined=StrictUndefined) + system_prompt = environment.from_string(settings.pr_update_changelog_prompt.system).render(variables) + user_prompt = environment.from_string(settings.pr_update_changelog_prompt.user).render(variables) + if settings.config.verbosity_level >= 2: + logging.info(f"\nSystem prompt:\n{system_prompt}") + logging.info(f"\nUser prompt:\n{user_prompt}") + response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2, + system=system_prompt, user=user_prompt) + + return response + + def _prepare_changelog_update(self) -> Tuple[str, str]: + answer = self.prediction.strip().strip("```").strip() + if hasattr(self, "changelog_file"): + existing_content = self.changelog_file.decoded_content.decode() + else: + existing_content = "" + if existing_content: + new_file_content = answer + "\n\n" + self.changelog_file.decoded_content.decode() + else: + new_file_content = answer + + if not self.commit_changelog: + answer += "\n\n\n>to commit the new content to the CHANGELOG.md file, please type:" \ + "\n>'/update_changelog -commit'\n" + + if settings.config.verbosity_level >= 2: + logging.info(f"answer:\n{answer}") + + return new_file_content, answer + + def _push_changelog_update(self, new_file_content, answer): + self.git_provider.repo_obj.update_file(path=self.changelog_file.path, + message="Update CHANGELOG.md", + content=new_file_content, + sha=self.changelog_file.sha, + branch=self.git_provider.get_pr_branch()) + d = dict(body="CHANGELOG.md update", + path=self.changelog_file.path, + line=max(2, len(answer.splitlines())), + start_line=1) + + sleep(5) # wait for the file to be updated + last_commit_id = list(self.git_provider.pr.get_commits())[-1] + try: + self.git_provider.pr.create_review(commit=last_commit_id, comments=[d]) + except: + # we can't create a review for some reason, let's just publish a comment + self.git_provider.publish_comment(f"**Changelog updates:**\n\n{answer}") + + + def _get_default_changelog(self): + example_changelog = \ +""" +Example: +## + +### Added +... +### Changed +... +### Fixed +... +""" + return example_changelog + + def _parse_args(self, args, setting): + commit_changelog = False + if args and len(args) >= 1: + try: + if args[0] == "-commit": + commit_changelog = True + except: + pass + else: + commit_changelog = setting.pr_update_changelog.push_changelog_changes + + return commit_changelog + + def _get_changlog_file(self): + try: + self.changelog_file = self.git_provider.repo_obj.get_contents("CHANGELOG.md", + ref=self.git_provider.get_pr_branch()) + changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines() + changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES] + self.changelog_file_str = "\n".join(changelog_file_lines) + except: + self.changelog_file_str = "" + if self.commit_changelog: + logging.info("No CHANGELOG.md file found in the repository. Creating one...") + changelog_file = self.git_provider.repo_obj.create_file(path="CHANGELOG.md", + message='add CHANGELOG.md', + content="", + branch=self.git_provider.get_pr_branch()) + self.changelog_file = changelog_file['content'] + + if not self.changelog_file_str: + self.changelog_file_str = self._get_default_changelog()