Add context-aware git provider retrieval and refactor related functions

This commit is contained in:
mrT23
2024-06-19 09:36:37 +03:00
parent 5968db67b9
commit 025a14014a
8 changed files with 80 additions and 31 deletions

View File

@ -2,11 +2,13 @@ from pr_agent.config_loader import get_settings
from pr_agent.git_providers.bitbucket_provider import BitbucketProvider
from pr_agent.git_providers.bitbucket_server_provider import BitbucketServerProvider
from pr_agent.git_providers.codecommit_provider import CodeCommitProvider
from pr_agent.git_providers.git_provider import GitProvider
from pr_agent.git_providers.github_provider import GithubProvider
from pr_agent.git_providers.gitlab_provider import GitLabProvider
from pr_agent.git_providers.local_git_provider import LocalGitProvider
from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
from pr_agent.git_providers.gerrit_provider import GerritProvider
from starlette_context import context
_GIT_PROVIDERS = {
'github': GithubProvider,
@ -28,3 +30,33 @@ def get_git_provider():
if provider_id not in _GIT_PROVIDERS:
raise ValueError(f"Unknown git provider: {provider_id}")
return _GIT_PROVIDERS[provider_id]
def get_git_provider_with_context(pr_url) -> GitProvider:
"""
Get a GitProvider instance for the given PR URL. If the GitProvider instance is already in the context, return it.
"""
is_context_env = None
try:
is_context_env = context.get("settings", None)
except Exception:
pass # we are not in a context environment (CLI)
# check if context["git_provider"]["pr_url"] exists
if is_context_env and context.get("git_provider", {}).get("pr_url", {}):
git_provider = context["git_provider"]["pr_url"]
# possibly check if the git_provider is still valid, or if some reset is needed
# ...
return git_provider
else:
try:
provider_id = get_settings().config.git_provider
if provider_id not in _GIT_PROVIDERS:
raise ValueError(f"Unknown git provider: {provider_id}")
git_provider = _GIT_PROVIDERS[provider_id](pr_url)
if is_context_env:
context["git_provider"] = {pr_url: git_provider}
return git_provider
except Exception as e:
raise ValueError(f"Failed to get git provider for {pr_url}") from e

View File

@ -13,10 +13,17 @@ class GitProvider(ABC):
def is_supported(self, capability: str) -> bool:
pass
@abstractmethod
def get_files(self) -> list:
pass
@abstractmethod
def get_diff_files(self) -> list[FilePatchInfo]:
pass
def get_incremental_commits(self, is_incremental):
pass
@abstractmethod
def publish_description(self, pr_title: str, pr_body: str):
pass

View File

@ -19,7 +19,7 @@ from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
class GithubProvider(GitProvider):
def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)):
def __init__(self, pr_url: Optional[str] = None):
self.repo_obj = None
try:
self.installation_id = context.get("installation_id", None)
@ -34,18 +34,21 @@ class GithubProvider(GitProvider):
self.github_user_id = None
self.diff_files = None
self.git_files = None
self.incremental = incremental
self.incremental = IncrementalPR(False)
if pr_url and 'pull' in pr_url:
self.set_pr(pr_url)
self.pr_commits = list(self.pr.get_commits())
if self.incremental.is_incremental:
self.unreviewed_files_set = dict()
self.get_incremental_commits()
self.last_commit_id = self.pr_commits[-1]
self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
else:
self.pr_commits = None
def get_incremental_commits(self, incremental=IncrementalPR(False)):
self.incremental = incremental
if self.incremental.is_incremental:
self.unreviewed_files_set = dict()
self.get_incremental_commits()
self.last_commit_id = self.pr_commits[-1]
def is_supported(self, capability: str) -> bool:
return True

View File

@ -5,12 +5,13 @@ import tempfile
from dynaconf import Dynaconf
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers import get_git_provider, get_git_provider_with_context
from pr_agent.log import get_logger
from starlette_context import context
def apply_repo_settings(pr_url):
git_provider = get_git_provider_with_context(pr_url)
if get_settings().config.use_repo_settings_file:
repo_settings_file = None
try:
@ -20,7 +21,6 @@ def apply_repo_settings(pr_url):
repo_settings = None
pass
if repo_settings is None: # None is different from "", which is a valid value
git_provider = get_git_provider()(pr_url)
repo_settings = git_provider.get_repo_settings()
try:
context["repo_settings"] = repo_settings