- Replaced two dot diff with three dot diff. Cleaned up obsolete code linked to double dot diff.

- Moved target_branch_existence assertion to _prepare_repo method
- Renamed branch_name -> target_branch_name
- Simplified get_files method
This commit is contained in:
Patryk Kowalski
2023-07-25 13:07:21 +02:00
parent 918549a4fc
commit 0815e2024c

View File

@ -1,10 +1,9 @@
import logging
import uuid
from collections import Counter
from pathlib import Path
from typing import List
from git import GitCommandError, Repo
from git import Repo
from pr_agent.config_loader import _find_repository_root, settings
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
@ -29,12 +28,11 @@ class LocalGitProvider(GitProvider):
For the MVP it only supports the /review and /describe capabilities.
"""
def __init__(self, branch_name, incremental=False):
def __init__(self, target_branch_name, incremental=False):
self.repo_path = _find_repository_root()
self.repo = Repo(self.repo_path)
self.head_branch_name = self.repo.head.ref.name
self.branch_name = branch_name
self.tmp_branch_name = f'pr_agent_{uuid.uuid4()}'
self.target_branch_name = target_branch_name
self._prepare_repo()
self.diff_files = None
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
@ -45,20 +43,6 @@ class LocalGitProvider(GitProvider):
# inline code comments are not supported for local git repositories
settings.pr_reviewer.inline_code_comments = False
def __del__(self):
logging.debug('Deleting temporary branch...')
self.repo.git.checkout(self.head_branch_name) # switch back to the original branch
# delete the temporary branch
if self.tmp_branch_name not in self.repo.heads:
return
try:
self.repo.delete_head(self.tmp_branch_name, force=True)
except GitCommandError as e:
raise ValueError(
'Error while trying to delete the temporary branch.'
'Ensure the branch exists.'
) from e
def _prepare_repo(self):
"""
Prepare the repository for PR-mimic generation.
@ -66,15 +50,8 @@ class LocalGitProvider(GitProvider):
logging.debug('Preparing repository for PR-mimic generation...')
if self.repo.is_dirty():
raise ValueError('The repository is not in a clean state. Please commit or stash pending changes.')
if self.tmp_branch_name in self.repo.heads:
self.repo.delete_head(self.tmp_branch_name, force=True)
self.repo.git.checkout('HEAD', b=self.tmp_branch_name)
try:
logging.debug('Rebasing the temporary branch on the main branch...')
self.repo.git.rebase(self.branch_name)
except GitCommandError as e:
raise ValueError('Error while rebasing. Resolve conflicts before retrying.') from e
if self.target_branch_name not in self.repo.heads:
raise KeyError(f'Branch: {self.target_branch_name} does not exist')
def is_supported(self, capability: str) -> bool:
if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels']:
@ -83,7 +60,7 @@ class LocalGitProvider(GitProvider):
def get_diff_files(self) -> list[FilePatchInfo]:
diffs = self.repo.head.commit.diff(
self.repo.branches[self.branch_name].commit,
self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
create_patch=True,
R=True
)
@ -120,13 +97,10 @@ class LocalGitProvider(GitProvider):
"""
Returns a list of files with changes in the diff.
"""
# Assert existence of specific branch
branch_names = [ref.name for ref in self.repo.branches]
if self.branch_name not in branch_names:
raise KeyError(f"Branch: {self.branch_name} does not exist")
branch = self.repo.branches[self.branch_name]
# Compare the two branches
diff_index = self.repo.head.commit.diff(branch.commit)
diff_index = self.repo.head.commit.diff(
self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]),
R=True
)
# Get the list of changed files
diff_files = [item.a_path for item in diff_index]
return diff_files
@ -183,7 +157,7 @@ class LocalGitProvider(GitProvider):
return -1 # Not used anywhere for the local provider, but required by the interface
def get_pr_description(self):
commits_diff = list(self.repo.iter_commits(self.branch_name + '..HEAD'))
commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
# Get the commit messages and concatenate
commit_messages = " ".join([commit.message for commit in commits_diff])
# TODO Handle the description better - maybe use gpt-3.5 summarisation here?
@ -193,7 +167,7 @@ class LocalGitProvider(GitProvider):
"""
Substitutes the branch-name as the PR-mimic title.
"""
return self.branch_name
return self.target_branch_name
def get_issue_comments(self):
raise NotImplementedError('Getting issue comments is not implemented for the local git provider')