From 025a14014af2aa9cc4b112ceab4e70852381c2cd Mon Sep 17 00:00:00 2001 From: mrT23 Date: Wed, 19 Jun 2024 09:36:37 +0300 Subject: [PATCH] Add context-aware git provider retrieval and refactor related functions --- pr_agent/git_providers/__init__.py | 32 ++++++++++++++++++++ pr_agent/git_providers/git_provider.py | 7 +++++ pr_agent/git_providers/github_provider.py | 15 +++++---- pr_agent/git_providers/utils.py | 4 +-- pr_agent/servers/github_app.py | 37 ++++++++++++----------- pr_agent/tools/pr_code_suggestions.py | 5 +-- pr_agent/tools/pr_description.py | 4 +-- pr_agent/tools/pr_reviewer.py | 7 +++-- 8 files changed, 80 insertions(+), 31 deletions(-) diff --git a/pr_agent/git_providers/__init__.py b/pr_agent/git_providers/__init__.py index 8af6b7c3..c7e3e6e8 100644 --- a/pr_agent/git_providers/__init__.py +++ b/pr_agent/git_providers/__init__.py @@ -2,11 +2,13 @@ from pr_agent.config_loader import get_settings from pr_agent.git_providers.bitbucket_provider import BitbucketProvider from pr_agent.git_providers.bitbucket_server_provider import BitbucketServerProvider from pr_agent.git_providers.codecommit_provider import CodeCommitProvider +from pr_agent.git_providers.git_provider import GitProvider from pr_agent.git_providers.github_provider import GithubProvider from pr_agent.git_providers.gitlab_provider import GitLabProvider from pr_agent.git_providers.local_git_provider import LocalGitProvider from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider from pr_agent.git_providers.gerrit_provider import GerritProvider +from starlette_context import context _GIT_PROVIDERS = { 'github': GithubProvider, @@ -28,3 +30,33 @@ def get_git_provider(): if provider_id not in _GIT_PROVIDERS: raise ValueError(f"Unknown git provider: {provider_id}") return _GIT_PROVIDERS[provider_id] + + +def get_git_provider_with_context(pr_url) -> GitProvider: + """ + Get a GitProvider instance for the given PR URL. If the GitProvider instance is already in the context, return it. + """ + + is_context_env = None + try: + is_context_env = context.get("settings", None) + except Exception: + pass # we are not in a context environment (CLI) + + # check if context["git_provider"]["pr_url"] exists + if is_context_env and context.get("git_provider", {}).get("pr_url", {}): + git_provider = context["git_provider"]["pr_url"] + # possibly check if the git_provider is still valid, or if some reset is needed + # ... + return git_provider + else: + try: + provider_id = get_settings().config.git_provider + if provider_id not in _GIT_PROVIDERS: + raise ValueError(f"Unknown git provider: {provider_id}") + git_provider = _GIT_PROVIDERS[provider_id](pr_url) + if is_context_env: + context["git_provider"] = {pr_url: git_provider} + return git_provider + except Exception as e: + raise ValueError(f"Failed to get git provider for {pr_url}") from e diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py index 0ff5caf1..d5808409 100644 --- a/pr_agent/git_providers/git_provider.py +++ b/pr_agent/git_providers/git_provider.py @@ -13,10 +13,17 @@ class GitProvider(ABC): def is_supported(self, capability: str) -> bool: pass + @abstractmethod + def get_files(self) -> list: + pass + @abstractmethod def get_diff_files(self) -> list[FilePatchInfo]: pass + def get_incremental_commits(self, is_incremental): + pass + @abstractmethod def publish_description(self, pr_title: str, pr_body: str): pass diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 0f8727e0..991a92cb 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -19,7 +19,7 @@ from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo class GithubProvider(GitProvider): - def __init__(self, pr_url: Optional[str] = None, incremental=IncrementalPR(False)): + def __init__(self, pr_url: Optional[str] = None): self.repo_obj = None try: self.installation_id = context.get("installation_id", None) @@ -34,18 +34,21 @@ class GithubProvider(GitProvider): self.github_user_id = None self.diff_files = None self.git_files = None - self.incremental = incremental + self.incremental = IncrementalPR(False) if pr_url and 'pull' in pr_url: self.set_pr(pr_url) self.pr_commits = list(self.pr.get_commits()) - if self.incremental.is_incremental: - self.unreviewed_files_set = dict() - self.get_incremental_commits() - self.last_commit_id = self.pr_commits[-1] self.pr_url = self.get_pr_url() # pr_url for github actions can be as api.github.com, so we need to get the url from the pr object else: self.pr_commits = None + def get_incremental_commits(self, incremental=IncrementalPR(False)): + self.incremental = incremental + if self.incremental.is_incremental: + self.unreviewed_files_set = dict() + self.get_incremental_commits() + self.last_commit_id = self.pr_commits[-1] + def is_supported(self, capability: str) -> bool: return True diff --git a/pr_agent/git_providers/utils.py b/pr_agent/git_providers/utils.py index cb0d554f..a0d65b66 100644 --- a/pr_agent/git_providers/utils.py +++ b/pr_agent/git_providers/utils.py @@ -5,12 +5,13 @@ import tempfile from dynaconf import Dynaconf from pr_agent.config_loader import get_settings -from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers import get_git_provider, get_git_provider_with_context from pr_agent.log import get_logger from starlette_context import context def apply_repo_settings(pr_url): + git_provider = get_git_provider_with_context(pr_url) if get_settings().config.use_repo_settings_file: repo_settings_file = None try: @@ -20,7 +21,6 @@ def apply_repo_settings(pr_url): repo_settings = None pass if repo_settings is None: # None is different from "", which is a valid value - git_provider = get_git_provider()(pr_url) repo_settings = git_provider.get_repo_settings() try: context["repo_settings"] = repo_settings diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index 9361d595..cb188f5b 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -15,7 +15,7 @@ from starlette_context.middleware import RawContextMiddleware from pr_agent.agent.pr_agent import PRAgent from pr_agent.algo.utils import update_settings_from_args from pr_agent.config_loader import get_settings, global_settings -from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers import get_git_provider, get_git_provider_with_context from pr_agent.git_providers.git_provider import IncrementalPR from pr_agent.git_providers.utils import apply_repo_settings from pr_agent.identity_providers import get_identity_provider @@ -48,7 +48,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req installation_id = body.get("installation", {}).get("id") context["installation_id"] = installation_id context["settings"] = copy.deepcopy(global_settings) - + context["git_provider"] = None background_tasks.add_task(handle_request, body, event=request.headers.get("X-GitHub-Event", None)) return {} @@ -111,7 +111,7 @@ async def handle_comments_on_pr(body: Dict[str, Any], return {} log_context["api_url"] = api_url comment_id = body.get("comment", {}).get("id") - provider = get_git_provider()(pr_url=api_url) + provider = get_git_provider_with_context(pr_url=api_url) with get_logger().contextualize(**log_context): if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE: get_logger().info(f"Processing comment on PR {api_url=}, comment_body={comment_body}") @@ -143,7 +143,7 @@ async def handle_new_pr_opened(body: Dict[str, Any], return {} if action in get_settings().github_app.handle_pr_actions: # ['opened', 'reopened', 'ready_for_review'] if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE: - await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context) + await _perform_auto_commands_github("pr_commands", agent, body, api_url, log_context) else: get_logger().info(f"User {sender=} is not eligible to process PR {api_url=}") @@ -201,13 +201,12 @@ async def handle_push_trigger_for_new_commits(body: Dict[str, Any], try: if get_settings().github_app.push_trigger_wait_for_initial_review and not get_git_provider()(api_url, - incremental=IncrementalPR( - True)).previous_review: + incremental=IncrementalPR(True)).previous_review: get_logger().info(f"Skipping incremental review because there was no initial review for {api_url=} yet") return {} if get_identity_provider().verify_eligibility("github", sender_id, api_url) is not Eligibility.NOT_ELIGIBLE: - get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}") - await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context) + get_logger().info(f"Performing incremental review for {api_url=} because of {event=} and {action=}") + await _perform_auto_commands_github("push_commands", agent, body, api_url, log_context) finally: # release the waiting task block @@ -241,7 +240,7 @@ def get_log_context(body, event, action, build_number): app_name = get_settings().get("CONFIG.APP_NAME", "Unknown") log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app", "request_id": uuid.uuid4().hex, "build_number": build_number, "app_name": app_name, - "repo": repo, "git_org": git_org, "installation_id": installation_id} + "repo": repo, "git_org": git_org, "installation_id": installation_id} except Exception as e: get_logger().error("Failed to get log context", e) log_context = {} @@ -268,18 +267,23 @@ async def handle_request(body: Dict[str, Any], event: str): get_logger().info(f"Ignoring PR from '{sender=}' because it is a bot") return {} + if 'check_run' in body: # handle failed checks + # get_logger().debug(f'Request body', artifact=body, event=event) # added inside handle_checks + pass # handle comments on PRs - if action == 'created': + elif action == 'created': get_logger().debug(f'Request body', artifact=body, event=event) await handle_comments_on_pr(body, event, sender, sender_id, action, log_context, agent) # handle new PRs elif event == 'pull_request' and action != 'synchronize' and action != 'closed': get_logger().debug(f'Request body', artifact=body, event=event) await handle_new_pr_opened(body, event, sender, sender_id, action, log_context, agent) + elif event == "issue_comment" and 'edited' in action: + pass # handle_checkbox_clicked # handle pull_request event with synchronize action - "push trigger" for new commits elif event == 'pull_request' and action == 'synchronize': - get_logger().debug(f'Request body', artifact=body, event=event) - await handle_push_trigger_for_new_commits(body, event, sender, sender_id, action, log_context, agent) + # get_logger().debug(f'Request body', artifact=body, event=event) # added inside handle_push_trigger_for_new_commits + await handle_push_trigger_for_new_commits(body, event, sender,sender_id, action, log_context, agent) elif event == 'pull_request' and action == 'closed': if get_settings().get("CONFIG.ANALYTICS_FOLDER", ""): handle_closed_pr(body, event, action, log_context) @@ -326,8 +330,7 @@ async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body apply_repo_settings(api_url) commands = get_settings().get(f"github_app.{commands_conf}") if not commands: - with get_logger().contextualize(**log_context): - get_logger().info(f"New PR, but no auto commands configured") + get_logger().info(f"New PR, but no auto commands configured") return for command in commands: split_command = command.split(" ") @@ -335,9 +338,8 @@ async def _perform_auto_commands_github(commands_conf: str, agent: PRAgent, body args = split_command[1:] other_args = update_settings_from_args(args) new_command = ' '.join([command] + other_args) - with get_logger().contextualize(**log_context): - get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}") - await agent.handle_request(api_url, new_command) + get_logger().info(f"{commands_conf}. Performing auto command '{new_command}', for {api_url=}") + await agent.handle_request(api_url, new_command) @router.get("/") @@ -357,5 +359,6 @@ app.include_router(router) def start(): uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "3000"))) + if __name__ == '__main__': start() diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index c92859c0..e8054d3d 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -11,18 +11,19 @@ from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_w from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType, show_relevant_configurations from pr_agent.config_loader import get_settings -from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers import get_git_provider, get_git_provider_with_context from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.log import get_logger from pr_agent.servers.help import HelpMessage from pr_agent.tools.pr_description import insert_br_after_x_chars import difflib + class PRCodeSuggestions: def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler): - self.git_provider = get_git_provider()(pr_url) + self.git_provider = get_git_provider_with_context(pr_url) self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index d05dc348..02ebf28d 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -11,7 +11,7 @@ from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels, ModelType, show_relevant_configurations from pr_agent.config_loader import get_settings -from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers import get_git_provider, get_git_provider_with_context from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.log import get_logger from pr_agent.servers.help import HelpMessage @@ -28,7 +28,7 @@ class PRDescription: args (list, optional): List of arguments passed to the PRDescription class. Defaults to None. """ # Initialize the git provider and main PR language - self.git_provider = get_git_provider()(pr_url) + self.git_provider = get_git_provider_with_context(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() ) diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 68267232..ecc8f3e9 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -11,7 +11,7 @@ from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, github_action_output, load_yaml, ModelType, \ show_relevant_configurations from pr_agent.config_loader import get_settings -from pr_agent.git_providers import get_git_provider +from pr_agent.git_providers import get_git_provider, get_git_provider_with_context from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language from pr_agent.log import get_logger from pr_agent.servers.help import HelpMessage @@ -37,7 +37,10 @@ class PRReviewer: self.args = args self.parse_args(args) # -i command - self.git_provider = get_git_provider()(pr_url, incremental=self.incremental) + self.git_provider = get_git_provider_with_context(pr_url) + if self.incremental and self.incremental.is_incremental: + self.git_provider.get_incremental_commits(self.incremental) + self.main_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() )