diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 825338e6..a30c411b 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -27,9 +27,9 @@ class PRAgent: elif any(cmd == action for cmd in ["/improve", "/improve_code"]): await PRCodeSuggestions(pr_url).suggest() elif any(cmd == action for cmd in ["/ask", "/ask_question"]): - await PRQuestions(pr_url, args).answer() + await PRQuestions(pr_url, args=args).answer() elif any(cmd == action for cmd in ["/update_changelog"]): - await PRUpdateChangelog(pr_url, args).update_changelog() + await PRUpdateChangelog(pr_url, args=args).update_changelog() else: return False diff --git a/pr_agent/cli.py b/pr_agent/cli.py index 2bae6bd4..f04e51d7 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -102,7 +102,7 @@ def _handle_review_after_reflect_command(pr_url: str, rest: list): def _handle_update_changelog(pr_url: str, rest: list): print(f"Updating changlog for: {pr_url}") - reviewer = PRUpdateChangelog(pr_url, cli_mode=True) + reviewer = PRUpdateChangelog(pr_url, cli_mode=True, args=rest) asyncio.run(reviewer.update_changelog()) if __name__ == '__main__': diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index aa24afb6..901c5b45 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -17,35 +17,14 @@ CHANGELOG_LINES = 50 class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False): + def __init__(self, pr_url: str, cli_mode=False, args=None): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) - try: - self.changelog_file = self.git_provider.repo_obj.get_contents("CHANGELOG.md", - ref=self.git_provider.get_pr_branch()) - changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines() - changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES] - self.changelog_file_str = "\n".join(changelog_file_lines) - except: - self.changelog_file_str = "" - if settings.config.publish_output and settings.pr_update_changelog.push_changelog_changes: - logging.info("No CHANGELOG.md file found in the repository. Creating one...") - changelog_file = self.git_provider.repo_obj.create_file(path="CHANGELOG.md", - message='add CHANGELOG.md', - content="", - branch=self.git_provider.get_pr_branch()) - self.changelog_file = changelog_file['content'] - - if not self.changelog_file_str: - self.changelog_file_str = self._get_default_changelog() - - - today = date.today() - print("Today's date:", today) - + self.commit_changelog = self._parse_args(args, settings) + self._get_changlog_file() # self.changelog_file_str self.ai_handler = AiHandler() self.patches_diff = None self.prediction = None @@ -57,7 +36,7 @@ class PRUpdateChangelog: "language": self.main_language, "diff": "", # empty diff for initial calculation "changelog_file_str": self.changelog_file_str, - "today": today, + "today": date.today(), } self.token_handler = TokenHandler(self.git_provider.pr, self.vars, @@ -76,7 +55,7 @@ class PRUpdateChangelog: if settings.config.publish_output: self.git_provider.remove_initial_comment() logging.info('Publishing changelog updates...') - if settings.pr_update_changelog.push_changelog_changes: + if self.commit_changelog: logging.info('Pushing PR changelog updates to repo...') self._push_changelog_update(new_file_content, answer) else: @@ -113,8 +92,14 @@ class PRUpdateChangelog: new_file_content = answer + "\n\n" + self.changelog_file.decoded_content.decode() else: new_file_content = answer + + if not self.cli_mode and self.commit_changelog: + answer += "\n\n\n>to commit the new contnet to the CHANGELOG.md file, please type:" \ + "\n>'/update_changelog -commit'\n" + if settings.config.verbosity_level >= 2: logging.info(f"answer:\n{answer}") + return new_file_content, answer def _push_changelog_update(self, new_file_content, answer): @@ -151,3 +136,36 @@ Example: ... """ return example_changelog + + def _parse_args(self, args, setting): + commit_changelog = False + if args and len(args) >= 1: + try: + if args[0] == "-commit": + commit_changelog = True + except: + pass + else: + commit_changelog = setting.pr_update_changelog.push_changelog_changes + + return commit_changelog + + def _get_changlog_file(self): + try: + self.changelog_file = self.git_provider.repo_obj.get_contents("CHANGELOG.md", + ref=self.git_provider.get_pr_branch()) + changelog_file_lines = self.changelog_file.decoded_content.decode().splitlines() + changelog_file_lines = changelog_file_lines[:CHANGELOG_LINES] + self.changelog_file_str = "\n".join(changelog_file_lines) + except: + self.changelog_file_str = "" + if self.commit_changelog: + logging.info("No CHANGELOG.md file found in the repository. Creating one...") + changelog_file = self.git_provider.repo_obj.create_file(path="CHANGELOG.md", + message='add CHANGELOG.md', + content="", + branch=self.git_provider.get_pr_branch()) + self.changelog_file = changelog_file['content'] + + if not self.changelog_file_str: + self.changelog_file_str = self._get_default_changelog()