diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 4b7af70a..0eefcd74 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -13,23 +13,20 @@ class PRAgent: pass async def handle_request(self, pr_url, request) -> bool: - if any(cmd in request for cmd in ["/answer"]): + action, *args = request.split(" ") + if any(cmd == action for cmd in ["/answer"]): await PRReviewer(pr_url, is_answer=True).review() - elif any(cmd in request for cmd in ["/review", "/review_pr", "/reflect_and_review"]): + elif any(cmd == action for cmd in ["/review", "/review_pr", "/reflect_and_review"]): if settings.pr_reviewer.ask_and_reflect or "/reflect_and_review" in request: await PRInformationFromUser(pr_url).generate_questions() else: - await PRReviewer(pr_url).review() - elif any(cmd in request for cmd in ["/describe", "/describe_pr"]): + await PRReviewer(pr_url, args=args).review() + elif any(cmd == action for cmd in ["/describe", "/describe_pr"]): await PRDescription(pr_url).describe() - elif any(cmd in request for cmd in ["/improve", "/improve_code"]): + elif any(cmd == action for cmd in ["/improve", "/improve_code"]): await PRCodeSuggestions(pr_url).suggest() - elif any(cmd in request for cmd in ["/ask", "/ask_question"]): - pattern = r'(/ask|/ask_question)\s*(.*)' - matches = re.findall(pattern, request, re.IGNORECASE) - if matches: - question = matches[0][1] - await PRQuestions(pr_url, question).answer() + elif any(cmd == action for cmd in ["/ask", "/ask_question"]): + await PRQuestions(pr_url, args).answer() else: return False diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 165b7de5..11f16449 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -1,14 +1,15 @@ from __future__ import annotations -import difflib import logging -from typing import Any, Tuple, Union +from typing import Tuple, Union from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.token_handler import TokenHandler +from pr_agent.algo.utils import load_large_diff from pr_agent.config_loader import settings -from pr_agent.git_providers import GithubProvider +from pr_agent.git_providers.git_provider import GitProvider + DELETED_FILES_ = "Deleted files:\n" @@ -19,7 +20,7 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600 PATCH_EXTRA_LINES = 3 -def get_pr_diff(git_provider: Union[GithubProvider, Any], token_handler: TokenHandler, +def get_pr_diff(git_provider: Union[GitProvider], token_handler: TokenHandler, add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str: """ Returns a string with the diff of the PR. @@ -163,14 +164,3 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, return patches, modified_files_list, deleted_files_list -def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str: - if not patch: # to Do - also add condition for file extension - try: - diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True), - new_file_content_str.splitlines(keepends=True)) - if settings.config.verbosity_level >= 2: - logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.") - patch = ''.join(diff) - except Exception: - pass - return patch diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 866b160e..a4fb4f6e 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -1,10 +1,14 @@ from __future__ import annotations +import difflib +from datetime import datetime import json import logging import re import textwrap +from pr_agent.config_loader import settings + def convert_to_markdown(output_data: dict) -> str: markdown_text = "" @@ -18,7 +22,7 @@ def convert_to_markdown(output_data: dict) -> str: "Security concerns": "🔒", "General PR suggestions": "💡", "Insights from user's answers": "📝", - "Code suggestions": "🤖" + "Code suggestions": "🤖", } for key, value in output_data.items(): @@ -103,3 +107,21 @@ def fix_json_escape_char(json_message=None): new_message = ''.join(json_message) return fix_json_escape_char(json_message=new_message) return result + + +def convert_str_to_datetime(date_str): + datetime_format = '%a, %d %b %Y %H:%M:%S %Z' + return datetime.strptime(date_str, datetime_format) + + +def load_large_diff(file, new_file_content_str: str, original_file_content_str: str, patch: str) -> str: + if not patch: # to Do - also add condition for file extension + try: + diff = difflib.unified_diff(original_file_content_str.splitlines(keepends=True), + new_file_content_str.splitlines(keepends=True)) + if settings.config.verbosity_level >= 2: + logging.warning(f"File was modified, but no patch was found. Manually creating patch: {file.filename}.") + patch = ''.join(diff) + except Exception: + pass + return patch diff --git a/pr_agent/cli.py b/pr_agent/cli.py index d6af7cf6..803ede2f 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -57,7 +57,7 @@ reflect - Ask the PR author questions about the PR. asyncio.run(reviewer.suggest()) elif command in ['review', 'review_pr']: print(f"Reviewing PR: {args.pr_url}") - reviewer = PRReviewer(args.pr_url, cli_mode=True) + reviewer = PRReviewer(args.pr_url, cli_mode=True, args=args.rest) asyncio.run(reviewer.review()) elif command in ['reflect']: print(f"Asking the PR author questions: {args.pr_url}") diff --git a/pr_agent/git_providers/bitbucket_provider.py b/pr_agent/git_providers/bitbucket_provider.py index 62ec0a19..1470ce78 100644 --- a/pr_agent/git_providers/bitbucket_provider.py +++ b/pr_agent/git_providers/bitbucket_provider.py @@ -11,7 +11,7 @@ from .git_provider import FilePatchInfo class BitbucketProvider: - def __init__(self, pr_url: Optional[str] = None): + def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False): s = requests.Session() s.headers['Authorization'] = f'Bearer {settings.get("BITBUCKET.BEARER_TOKEN", None)}' self.bitbucket_client = Cloud(session=s) @@ -22,6 +22,7 @@ class BitbucketProvider: self.pr_num = None self.pr = None self.temp_comments = [] + self.incremental = incremental if pr_url: self.set_pr(pr_url) diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index f7c7fa98..32f6e315 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -121,3 +121,11 @@ def get_main_pr_language(languages, files) -> str: pass return main_language_str + + +class IncrementalPR: + def __init__(self, is_incremental: bool = False): + self.is_incremental = is_incremental + self.commits_range = None + self.first_new_commit_sha = None + self.last_seen_commit_sha = None diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index fea1ae69..1f2ffec7 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -7,12 +7,14 @@ from github import AppAuthentication, Github, Auth from pr_agent.config_loader import settings -from .git_provider import FilePatchInfo, GitProvider +from .git_provider import FilePatchInfo, GitProvider, IncrementalPR from ..algo.language_handler import is_valid_file +from ..algo.utils import load_large_diff class GithubProvider(GitProvider): - def __init__(self, pr_url: Optional[str] = None): + def __init__(self, pr_url: Optional[str] = None, incremental: Optional[IncrementalPR] = False): + self.repo_obj = None self.installation_id = settings.get("GITHUB.INSTALLATION_ID") self.github_client = self._get_github_client() self.repo = None @@ -20,6 +22,7 @@ class GithubProvider(GitProvider): self.pr = None self.github_user_id = None self.diff_files = None + self.incremental = incremental if pr_url: self.set_pr(pr_url) self.last_commit_id = list(self.pr.get_commits())[-1] @@ -27,21 +30,73 @@ class GithubProvider(GitProvider): def is_supported(self, capability: str) -> bool: return True + def get_pr_url(self) -> str: + return f"https://github.com/{self.repo}/pull/{self.pr_num}" + def set_pr(self, pr_url: str): self.repo, self.pr_num = self._parse_pr_url(pr_url) self.pr = self._get_pr() + if self.incremental.is_incremental: + self.get_incremental_commits() + + def get_incremental_commits(self): + self.commits = list(self.pr.get_commits()) + + self.get_previous_review() + if self.previous_review: + self.incremental.commits_range = self.get_commit_range() + # Get all files changed during the commit range + self.file_set = dict() + for commit in self.incremental.commits_range: + if commit.commit.message.startswith(f"Merge branch '{self._get_repo().default_branch}'"): + logging.info(f"Skipping merge commit {commit.commit.message}") + continue + self.file_set.update({file.filename: file for file in commit.files}) + + def get_commit_range(self): + last_review_time = self.previous_review.created_at + first_new_commit_index = 0 + for index in range(len(self.commits) - 1, -1, -1): + if self.commits[index].commit.author.date > last_review_time: + self.incremental.first_new_commit_sha = self.commits[index].sha + first_new_commit_index = index + else: + self.incremental.last_seen_commit_sha = self.commits[index].sha + break + return self.commits[first_new_commit_index:] + + def get_previous_review(self): + self.previous_review = None + self.comments = list(self.pr.get_issue_comments()) + for index in range(len(self.comments) - 1, -1, -1): + if self.comments[index].body.startswith("## PR Analysis"): + self.previous_review = self.comments[index] + break def get_files(self): + if self.incremental.is_incremental and self.file_set: + return self.file_set.values() return self.pr.get_files() def get_diff_files(self) -> list[FilePatchInfo]: - files = self.pr.get_files() + files = self.get_files() diff_files = [] for file in files: if is_valid_file(file.filename): - original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) - diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename)) + patch = file.patch + if self.incremental.is_incremental and self.file_set: + original_file_content_str = self._get_pr_file_content(file, self.incremental.last_seen_commit_sha) + patch = load_large_diff(file, + new_file_content_str, + original_file_content_str, + None) + self.file_set[file.filename] = patch + else: + original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) + + diff_files.append( + FilePatchInfo(original_file_content_str, new_file_content_str, patch, file.filename)) self.diff_files = diff_files return diff_files @@ -100,7 +155,7 @@ class GithubProvider(GitProvider): logging.exception(f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}") return False - if relevant_lines_end= 2: logging.exception(f"Failed to publish code suggestion, " f"relevant_lines_end is {relevant_lines_end} and " @@ -233,7 +288,14 @@ class GithubProvider(GitProvider): return Github(auth=Auth.Token(token)) def _get_repo(self): - return self.github_client.get_repo(self.repo) + if hasattr(self, 'repo_obj') and \ + hasattr(self.repo_obj, 'full_name') and \ + self.repo_obj.full_name == self.repo: + return self.repo_obj + else: + self.repo_obj = self.github_client.get_repo(self.repo) + return self.repo_obj + def _get_pr(self): return self._get_repo().get_pull(self.pr_num) diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 4cc8e9e0..95246647 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -13,7 +13,7 @@ from ..algo.language_handler import is_valid_file class GitLabProvider(GitProvider): - def __init__(self, merge_request_url: Optional[str] = None): + def __init__(self, merge_request_url: Optional[str] = None, incremental: Optional[bool] = False): gitlab_url = settings.get("GITLAB.URL", None) if not gitlab_url: raise ValueError("GitLab URL is not set in the config file") @@ -32,6 +32,7 @@ class GitLabProvider(GitProvider): self._set_merge_request(merge_request_url) self.RE_HUNK_HEADER = re.compile( r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") + self.incremental = incremental def is_supported(self, capability: str) -> bool: if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments']: diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 08af3797..8d24c04c 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -12,7 +12,8 @@ from pr_agent.git_providers.git_provider import get_main_pr_language class PRQuestions: - def __init__(self, pr_url: str, question_str: str): + def __init__(self, pr_url: str, args=None): + question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() @@ -34,6 +35,13 @@ class PRQuestions: self.patches_diff = None self.prediction = None + def parse_args(self, args): + if args and len(args) > 0: + question_str = " ".join(args) + else: + question_str = "" + return question_str + async def answer(self): logging.info('Answering a PR question...') if settings.config.publish_output: diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 8da08dfa..06b7f6bf 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -1,6 +1,7 @@ import copy import json import logging +from collections import OrderedDict from jinja2 import Environment, StrictUndefined @@ -10,17 +11,19 @@ from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, try_fix_json from pr_agent.config_loader import settings from pr_agent.git_providers import get_git_provider -from pr_agent.git_providers.git_provider import get_main_pr_language +from pr_agent.git_providers.git_provider import get_main_pr_language, IncrementalPR from pr_agent.servers.help import actions_help_text, bot_help_text class PRReviewer: - def __init__(self, pr_url: str, cli_mode=False, is_answer: bool = False): + def __init__(self, pr_url: str, cli_mode=False, is_answer: bool = False, args=None): + self.parse_args(args) - self.git_provider = get_git_provider()(pr_url) + self.git_provider = get_git_provider()(pr_url, incremental=self.incremental) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) + self.pr_url = pr_url self.is_answer = is_answer if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): raise Exception(f"Answer mode is not supported for {settings.config.git_provider} for now") @@ -48,6 +51,14 @@ class PRReviewer: settings.pr_review_prompt.system, settings.pr_review_prompt.user) + def parse_args(self, args): + is_incremental = False + if len(args) >= 1: + arg = args[0] + if arg == "-i": + is_incremental = True + self.incremental = IncrementalPR(is_incremental) + async def review(self): logging.info('Reviewing PR...') if settings.config.publish_output: @@ -107,6 +118,14 @@ class PRReviewer: if not data['PR Feedback']['Code suggestions']: del data['PR Feedback']['Code suggestions'] + if self.incremental.is_incremental: + # Rename title when incremental review - Add to the beginning of the dict + last_commit_url = f"{self.git_provider.get_pr_url()}/commits/{self.git_provider.incremental.first_new_commit_sha}" + data = OrderedDict(data) + data.update({'Incremental PR Review': { + "⏮️ Review for commits since previous PR-Agent review": f"Starting from commit {last_commit_url}"}}) + data.move_to_end('Incremental PR Review', last=False) + markdown_text = convert_to_markdown(data) user = self.git_provider.get_user_id()