diff --git a/pr_agent/git_providers/local_git_provider.py b/pr_agent/git_providers/local_git_provider.py index 7c157d81..4a7775ac 100644 --- a/pr_agent/git_providers/local_git_provider.py +++ b/pr_agent/git_providers/local_git_provider.py @@ -1,10 +1,9 @@ import logging -import uuid from collections import Counter from pathlib import Path from typing import List -from git import GitCommandError, Repo +from git import Repo from pr_agent.config_loader import _find_repository_root, settings from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider @@ -29,12 +28,11 @@ class LocalGitProvider(GitProvider): For the MVP it only supports the /review and /describe capabilities. """ - def __init__(self, branch_name, incremental=False): + def __init__(self, target_branch_name, incremental=False): self.repo_path = _find_repository_root() self.repo = Repo(self.repo_path) self.head_branch_name = self.repo.head.ref.name - self.branch_name = branch_name - self.tmp_branch_name = f'pr_agent_{uuid.uuid4()}' + self.target_branch_name = target_branch_name self._prepare_repo() self.diff_files = None self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files()) @@ -45,20 +43,6 @@ class LocalGitProvider(GitProvider): # inline code comments are not supported for local git repositories settings.pr_reviewer.inline_code_comments = False - def __del__(self): - logging.debug('Deleting temporary branch...') - self.repo.git.checkout(self.head_branch_name) # switch back to the original branch - # delete the temporary branch - if self.tmp_branch_name not in self.repo.heads: - return - try: - self.repo.delete_head(self.tmp_branch_name, force=True) - except GitCommandError as e: - raise ValueError( - 'Error while trying to delete the temporary branch.' - 'Ensure the branch exists.' - ) from e - def _prepare_repo(self): """ Prepare the repository for PR-mimic generation. @@ -66,15 +50,8 @@ class LocalGitProvider(GitProvider): logging.debug('Preparing repository for PR-mimic generation...') if self.repo.is_dirty(): raise ValueError('The repository is not in a clean state. Please commit or stash pending changes.') - if self.tmp_branch_name in self.repo.heads: - self.repo.delete_head(self.tmp_branch_name, force=True) - self.repo.git.checkout('HEAD', b=self.tmp_branch_name) - - try: - logging.debug('Rebasing the temporary branch on the main branch...') - self.repo.git.rebase(self.branch_name) - except GitCommandError as e: - raise ValueError('Error while rebasing. Resolve conflicts before retrying.') from e + if self.target_branch_name not in self.repo.heads: + raise KeyError(f'Branch: {self.target_branch_name} does not exist') def is_supported(self, capability: str) -> bool: if capability in ['get_issue_comments', 'create_inline_comment', 'publish_inline_comments', 'get_labels']: @@ -83,7 +60,7 @@ class LocalGitProvider(GitProvider): def get_diff_files(self) -> list[FilePatchInfo]: diffs = self.repo.head.commit.diff( - self.repo.branches[self.branch_name].commit, + self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]), create_patch=True, R=True ) @@ -120,13 +97,10 @@ class LocalGitProvider(GitProvider): """ Returns a list of files with changes in the diff. """ - # Assert existence of specific branch - branch_names = [ref.name for ref in self.repo.branches] - if self.branch_name not in branch_names: - raise KeyError(f"Branch: {self.branch_name} does not exist") - branch = self.repo.branches[self.branch_name] - # Compare the two branches - diff_index = self.repo.head.commit.diff(branch.commit) + diff_index = self.repo.head.commit.diff( + self.repo.merge_base(self.repo.head, self.repo.branches[self.target_branch_name]), + R=True + ) # Get the list of changed files diff_files = [item.a_path for item in diff_index] return diff_files @@ -183,7 +157,7 @@ class LocalGitProvider(GitProvider): return -1 # Not used anywhere for the local provider, but required by the interface def get_pr_description(self): - commits_diff = list(self.repo.iter_commits(self.branch_name + '..HEAD')) + commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD')) # Get the commit messages and concatenate commit_messages = " ".join([commit.message for commit in commits_diff]) # TODO Handle the description better - maybe use gpt-3.5 summarisation here? @@ -193,7 +167,7 @@ class LocalGitProvider(GitProvider): """ Substitutes the branch-name as the PR-mimic title. """ - return self.branch_name + return self.target_branch_name def get_issue_comments(self): raise NotImplementedError('Getting issue comments is not implemented for the local git provider')