mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 03:40:38 +08:00
Add context-aware git provider retrieval and refactor related functions
This commit is contained in:
@ -2,11 +2,13 @@ from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers.bitbucket_provider import BitbucketProvider
|
||||
from pr_agent.git_providers.bitbucket_server_provider import BitbucketServerProvider
|
||||
from pr_agent.git_providers.codecommit_provider import CodeCommitProvider
|
||||
from pr_agent.git_providers.git_provider import GitProvider
|
||||
from pr_agent.git_providers.github_provider import GithubProvider
|
||||
from pr_agent.git_providers.gitlab_provider import GitLabProvider
|
||||
from pr_agent.git_providers.local_git_provider import LocalGitProvider
|
||||
from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
|
||||
from pr_agent.git_providers.gerrit_provider import GerritProvider
|
||||
from starlette_context import context
|
||||
|
||||
_GIT_PROVIDERS = {
|
||||
'github': GithubProvider,
|
||||
@ -28,3 +30,33 @@ def get_git_provider():
|
||||
if provider_id not in _GIT_PROVIDERS:
|
||||
raise ValueError(f"Unknown git provider: {provider_id}")
|
||||
return _GIT_PROVIDERS[provider_id]
|
||||
|
||||
|
||||
def get_git_provider_with_context(pr_url) -> GitProvider:
|
||||
"""
|
||||
Get a GitProvider instance for the given PR URL. If the GitProvider instance is already in the context, return it.
|
||||
"""
|
||||
|
||||
is_context_env = None
|
||||
try:
|
||||
is_context_env = context.get("settings", None)
|
||||
except Exception:
|
||||
pass # we are not in a context environment (CLI)
|
||||
|
||||
# check if context["git_provider"]["pr_url"] exists
|
||||
if is_context_env and context.get("git_provider", {}).get("pr_url", {}):
|
||||
git_provider = context["git_provider"]["pr_url"]
|
||||
# possibly check if the git_provider is still valid, or if some reset is needed
|
||||
# ...
|
||||
return git_provider
|
||||
else:
|
||||
try:
|
||||
provider_id = get_settings().config.git_provider
|
||||
if provider_id not in _GIT_PROVIDERS:
|
||||
raise ValueError(f"Unknown git provider: {provider_id}")
|
||||
git_provider = _GIT_PROVIDERS[provider_id](pr_url)
|
||||
if is_context_env:
|
||||
context["git_provider"] = {pr_url: git_provider}
|
||||
return git_provider
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to get git provider for {pr_url}") from e
|
||||
|
@ -13,10 +13,17 @@ class GitProvider(ABC):
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_files(self) -> list:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
pass
|
||||
|
||||
def get_incremental_commits(self, is_incremental):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
pass
|
||||
|
@ -19,7 +19,7 @@ from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
|
||||
|
||||
class GithubProvider(GitProvider):
|
||||
def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)):
|
||||
def __init__(self, pr_url: Optional[str] = None):
|
||||
self.repo_obj = None
|
||||
try:
|
||||
self.installation_id = context.get("installation_id", None)
|
||||
@ -34,18 +34,21 @@ class GithubProvider(GitProvider):
|
||||
self.github_user_id = None
|
||||
self.diff_files = None
|
||||
self.git_files = None
|
||||
self.incremental = incremental
|
||||
self.incremental = IncrementalPR(False)
|
||||
if pr_url and 'pull' in pr_url:
|
||||
self.set_pr(pr_url)
|
||||
self.pr_commits = list(self.pr.get_commits())
|
||||
if self.incremental.is_incremental:
|
||||
self.unreviewed_files_set = dict()
|
||||
self.get_incremental_commits()
|
||||
self.last_commit_id = self.pr_commits[-1]
|
||||
self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object
|
||||
else:
|
||||
self.pr_commits = None
|
||||
|
||||
def get_incremental_commits(self, incremental=IncrementalPR(False)):
|
||||
self.incremental = incremental
|
||||
if self.incremental.is_incremental:
|
||||
self.unreviewed_files_set = dict()
|
||||
self.get_incremental_commits()
|
||||
self.last_commit_id = self.pr_commits[-1]
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
return True
|
||||
|
||||
|
@ -5,12 +5,13 @@ import tempfile
|
||||
from dynaconf import Dynaconf
|
||||
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers import get_git_provider, get_git_provider_with_context
|
||||
from pr_agent.log import get_logger
|
||||
from starlette_context import context
|
||||
|
||||
|
||||
def apply_repo_settings(pr_url):
|
||||
git_provider = get_git_provider_with_context(pr_url)
|
||||
if get_settings().config.use_repo_settings_file:
|
||||
repo_settings_file = None
|
||||
try:
|
||||
@ -20,7 +21,6 @@ def apply_repo_settings(pr_url):
|
||||
repo_settings = None
|
||||
pass
|
||||
if repo_settings is None: # None is different from "", which is a valid value
|
||||
git_provider = get_git_provider()(pr_url)
|
||||
repo_settings = git_provider.get_repo_settings()
|
||||
try:
|
||||
context["repo_settings"] = repo_settings
|
||||
|
@ -15,7 +15,7 @@ from starlette_context.middleware import RawContextMiddleware
|
||||
from pr_agent.agent.pr_agent import PRAgent
|
||||
from pr_agent.algo.utils import update_settings_from_args
|
||||
from pr_agent.config_loader import get_settings, global_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers import get_git_provider, get_git_provider_with_context
|
||||
from pr_agent.git_providers.git_provider import IncrementalPR
|
||||
from pr_agent.git_providers.utils import apply_repo_settings
|
||||
from pr_agent.identity_providers import get_identity_provider
|
||||
@ -48,7 +48,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
|
||||
installation_id = body.get("installation", {}).get("id")
|
||||
context["installation_id"] = installation_id
|
||||
context["settings"] = copy.deepcopy(global_settings)
|
||||
|
||||
context["git_provider"] = None
|
||||
background_tasks.add_task(handle_request, body, event=request.headers.get("X-GitHub-Event", None))
|
||||
return {}
|
||||
|
||||
@ -111,7 +111,7 @@ async def handle_comments_on_pr(body: Dict[str, Any],
|
||||
return {}
|
||||
log_context["api_url"] = api_url
|
||||
comment_id = body.get("comment", {}).get("id")
|
||||
provider = get_git_provider()(pr_url=api_url)
|
||||
provider = get_git_provider_with_context(pr_url=api_url)
|
||||
with get_logger().contextualize(**log_context):
|
||||
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
|
||||
get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}")
|
||||
@ -143,7 +143,7 @@ async def handle_new_pr_opened(body: Dict[str, Any],
|
||||
return {}
|
||||
if action in get_settings().github_app.handle_pr_actions: # ['opened', 'reopened', 'ready_for_review']
|
||||
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
|
||||
await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context)
|
||||
await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context)
|
||||
else:
|
||||
get_logger().info(f"User {sender=} is not eligible to process PR {api_url=}")
|
||||
|
||||
@ -201,13 +201,12 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any],
|
||||
|
||||
try:
|
||||
if get_settings().github_app.push_trigger_wait_for_initial_review and not get_git_provider()(api_url,
|
||||
incremental=IncrementalPR(
|
||||
True)).previous_review:
|
||||
incremental=IncrementalPR(True)).previous_review:
|
||||
get_logger().info(f"Skipping incremental review because there was no initial review for {api_url=} yet")
|
||||
return {}
|
||||
if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE:
|
||||
get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}")
|
||||
await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context)
|
||||
get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}")
|
||||
await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context)
|
||||
|
||||
finally:
|
||||
# release the waiting task block
|
||||
@ -241,7 +240,7 @@ def get_log_context(body, event, action, build_number):
|
||||
app_name = get_settings().get("CONFIG.APP_NAME", "Unknown")
|
||||
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app",
|
||||
"request_id": uuid.uuid4().hex, "build_number": build_number, "app_name": app_name,
|
||||
"repo": repo, "git_org": git_org, "installation_id": installation_id}
|
||||
"repo": repo, "git_org": git_org, "installation_id": installation_id}
|
||||
except Exception as e:
|
||||
get_logger().error("Failed to get log context", e)
|
||||
log_context = {}
|
||||
@ -268,18 +267,23 @@ async def handle_request(body: Dict[str, Any], event: str):
|
||||
get_logger().info(f"Ignoring PR from '{sender=}' because it is a bot")
|
||||
return {}
|
||||
|
||||
if 'check_run' in body: # handle failed checks
|
||||
# get_logger().debug(f'Request body', artifact=body, event=event) # added inside handle_checks
|
||||
pass
|
||||
# handle comments on PRs
|
||||
if action == 'created':
|
||||
elif action == 'created':
|
||||
get_logger().debug(f'Request body', artifact=body, event=event)
|
||||
await handle_comments_on_pr(body, event, sender, sender_id, action, log_context, agent)
|
||||
# handle new PRs
|
||||
elif event == 'pull_request' and action != 'synchronize' and action != 'closed':
|
||||
get_logger().debug(f'Request body', artifact=body, event=event)
|
||||
await handle_new_pr_opened(body, event, sender, sender_id, action, log_context, agent)
|
||||
elif event == "issue_comment" and 'edited' in action:
|
||||
pass # handle_checkbox_clicked
|
||||
# handle pull_request event with synchronize action - "push trigger" for new commits
|
||||
elif event == 'pull_request' and action == 'synchronize':
|
||||
get_logger().debug(f'Request body', artifact=body, event=event)
|
||||
await handle_push_trigger_for_new_commits(body, event, sender, sender_id, action, log_context, agent)
|
||||
# get_logger().debug(f'Request body', artifact=body, event=event) # added inside handle_push_trigger_for_new_commits
|
||||
await handle_push_trigger_for_new_commits(body, event, sender,sender_id, action, log_context, agent)
|
||||
elif event == 'pull_request' and action == 'closed':
|
||||
if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""):
|
||||
handle_closed_pr(body, event, action, log_context)
|
||||
@ -326,8 +330,7 @@ async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body
|
||||
apply_repo_settings(api_url)
|
||||
commands = get_settings().get(f"github_app.{commands_conf}")
|
||||
if not commands:
|
||||
with get_logger().contextualize(**log_context):
|
||||
get_logger().info(f"New PR, but no auto commands configured")
|
||||
get_logger().info(f"New PR, but no auto commands configured")
|
||||
return
|
||||
for command in commands:
|
||||
split_command = command.split(" ")
|
||||
@ -335,9 +338,8 @@ async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body
|
||||
args = split_command[1:]
|
||||
other_args = update_settings_from_args(args)
|
||||
new_command = ' '.join([command] + other_args)
|
||||
with get_logger().contextualize(**log_context):
|
||||
get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}")
|
||||
await agent.handle_request(api_url, new_command)
|
||||
get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}")
|
||||
await agent.handle_request(api_url, new_command)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@ -357,5 +359,6 @@ app.include_router(router)
|
||||
def start():
|
||||
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000")))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start()
|
||||
|
@ -11,18 +11,19 @@ from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_w
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType, show_relevant_configurations
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers import get_git_provider, get_git_provider_with_context
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
from pr_agent.log import get_logger
|
||||
from pr_agent.servers.help import HelpMessage
|
||||
from pr_agent.tools.pr_description import insert_br_after_x_chars
|
||||
import difflib
|
||||
|
||||
|
||||
class PRCodeSuggestions:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.git_provider = get_git_provider_with_context(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels, ModelType, show_relevant_configurations
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers import get_git_provider, get_git_provider_with_context
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
from pr_agent.log import get_logger
|
||||
from pr_agent.servers.help import HelpMessage
|
||||
@ -28,7 +28,7 @@ class PRDescription:
|
||||
args (list, optional): List of arguments passed to the PRDescription class. Defaults to None.
|
||||
"""
|
||||
# Initialize the git provider and main PR language
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.git_provider = get_git_provider_with_context(pr_url)
|
||||
self.main_pr_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import convert_to_markdown, github_action_output, load_yaml, ModelType, \
|
||||
show_relevant_configurations
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers import get_git_provider, get_git_provider_with_context
|
||||
from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language
|
||||
from pr_agent.log import get_logger
|
||||
from pr_agent.servers.help import HelpMessage
|
||||
@ -37,7 +37,10 @@ class PRReviewer:
|
||||
self.args = args
|
||||
self.parse_args(args) # -i command
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
|
||||
self.git_provider = get_git_provider_with_context(pr_url)
|
||||
if self.incremental and self.incremental.is_incremental:
|
||||
self.git_provider.get_incremental_commits(self.incremental)
|
||||
|
||||
self.main_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
|
Reference in New Issue
Block a user