mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-16 10:40:16 +08:00
Add context-aware git provider retrieval and refactor related functions
This commit is contained in:
@ -2,11 +2,13 @@ from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers.bitbucket_provider import BitbucketProvider
|
||||
from pr_agent.git_providers.bitbucket_server_provider import BitbucketServerProvider
|
||||
from pr_agent.git_providers.codecommit_provider import CodeCommitProvider
|
||||
from pr_agent.git_providers.git_provider import GitProvider
|
||||
from pr_agent.git_providers.github_provider import GithubProvider
|
||||
from pr_agent.git_providers.gitlab_provider import GitLabProvider
|
||||
from pr_agent.git_providers.local_git_provider import LocalGitProvider
|
||||
from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
|
||||
from pr_agent.git_providers.gerrit_provider import GerritProvider
|
||||
from starlette_context import context
|
||||
|
||||
_GIT_PROVIDERS = {
|
||||
'github': GithubProvider,
|
||||
@ -28,3 +30,33 @@ def get_git_provider():
|
||||
if provider_id not in _GIT_PROVIDERS:
|
||||
raise ValueError(f"Unknown git provider: {provider_id}")
|
||||
return _GIT_PROVIDERS[provider_id]
|
||||
|
||||
|
||||
def get_git_provider_with_context(pr_url) -> GitProvider:
|
||||
"""
|
||||
Get a GitProvider instance for the given PR URL. If the GitProvider instance is already in the context, return it.
|
||||
"""
|
||||
|
||||
is_context_env = None
|
||||
try:
|
||||
is_context_env = context.get("settings", None)
|
||||
except Exception:
|
||||
pass # we are not in a context environment (CLI)
|
||||
|
||||
# check if context["git_provider"]["pr_url"] exists
|
||||
if is_context_env and context.get("git_provider", {}).get("pr_url", {}):
|
||||
git_provider = context["git_provider"]["pr_url"]
|
||||
# possibly check if the git_provider is still valid, or if some reset is needed
|
||||
# ...
|
||||
return git_provider
|
||||
else:
|
||||
try:
|
||||
provider_id = get_settings().config.git_provider
|
||||
if provider_id not in _GIT_PROVIDERS:
|
||||
raise ValueError(f"Unknown git provider: {provider_id}")
|
||||
git_provider = _GIT_PROVIDERS[provider_id](pr_url)
|
||||
if is_context_env:
|
||||
context["git_provider"] = {pr_url: git_provider}
|
||||
return git_provider
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to get git provider for {pr_url}") from e
|
||||
|
@ -13,10 +13,17 @@ class GitProvider(ABC):
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_files(self) -> list:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
pass
|
||||
|
||||
def get_incremental_commits(self, is_incremental):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
pass
|
||||
|
@ -19,7 +19,7 @@ from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
|
||||
|
||||
class GithubProvider(GitProvider):
|
||||
def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)):
|
||||
def __init__(self, pr_url: Optional[str] = None):
|
||||
self.repo_obj = None
|
||||
try:
|
||||
self.installation_id = context.get("installation_id", None)
|
||||
@ -34,18 +34,21 @@ class GithubProvider(GitProvider):
|
||||
self.github_user_id = None
|
||||
self.diff_files = None
|
||||
self.git_files = None
|
||||
self.incremental = incremental
|
||||
self.incremental = IncrementalPR(False)
|
||||
if pr_url and 'pull' in pr_url:
|
||||
self.set_pr(pr_url)
|
||||
self.pr_commits = list(self.pr.get_commits())
|
||||
if self.incremental.is_incremental:
|
||||
self.unreviewed_files_set = dict()
|
||||
self.get_incremental_commits()
|
||||
self.last_commit_id = self.pr_commits[-1]
|
||||
self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
|
||||
else:
|
||||
self.pr_commits = None
|
||||
|
||||
def get_incremental_commits(self, incremental=IncrementalPR(False)):
|
||||
self.incremental = incremental
|
||||
if self.incremental.is_incremental:
|
||||
self.unreviewed_files_set = dict()
|
||||
self.get_incremental_commits()
|
||||
self.last_commit_id = self.pr_commits[-1]
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
return True
|
||||
|
||||
|
@ -5,12 +5,13 @@ import tempfile
|
||||
from dynaconf import Dynaconf
|
||||
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers import get_git_provider, get_git_provider_with_context
|
||||
from pr_agent.log import get_logger
|
||||
from starlette_context import context
|
||||
|
||||
|
||||
def apply_repo_settings(pr_url):
|
||||
git_provider = get_git_provider_with_context(pr_url)
|
||||
if get_settings().config.use_repo_settings_file:
|
||||
repo_settings_file = None
|
||||
try:
|
||||
@ -20,7 +21,6 @@ def apply_repo_settings(pr_url):
|
||||
repo_settings = None
|
||||
pass
|
||||
if repo_settings is None: # None is different from "", which is a valid value
|
||||
git_provider = get_git_provider()(pr_url)
|
||||
repo_settings = git_provider.get_repo_settings()
|
||||
try:
|
||||
context["repo_settings"] = repo_settings
|
||||
|
Reference in New Issue
Block a user