From ed8cf27b05a5326cdd0eddccd5d31eefa1444a64 Mon Sep 17 00:00:00 2001 From: Albert Achtenberg Date: Fri, 7 Jul 2023 15:02:40 +0300 Subject: [PATCH] working example --- pr_agent/git_providers/__init__.py | 6 ++- pr_agent/git_providers/git_provider.py | 38 +++++++++++++++ pr_agent/git_providers/github_provider.py | 22 ++++----- pr_agent/git_providers/gitlab_provider.py | 57 +++++++++++++++++++++++ pr_agent/settings/.secrets_template.toml | 7 +++ pr_agent/tools/pr_reviewer.py | 5 +- requirements.txt | 1 + 7 files changed, 118 insertions(+), 18 deletions(-) create mode 100644 pr_agent/git_providers/git_provider.py create mode 100644 pr_agent/git_providers/gitlab_provider.py diff --git a/pr_agent/git_providers/__init__.py b/pr_agent/git_providers/__init__.py index 54e52767..15f95ba7 100644 --- a/pr_agent/git_providers/__init__.py +++ b/pr_agent/git_providers/__init__.py @@ -1,15 +1,17 @@ from pr_agent.config_loader import settings from pr_agent.git_providers.github_provider import GithubProvider +from pr_agent.git_providers.gitlab_provider import GitLabProvider _GIT_PROVIDERS = { - 'github': GithubProvider + 'github': GithubProvider, + 'gitlab': GitLabProvider, } def get_git_provider(): try: provider_id = settings.config.git_provider except AttributeError as e: - raise ValueError("github_provider is a required attribute in the configuration file") from e + raise ValueError("git_provider is a required attribute in the configuration file") from e if provider_id not in _GIT_PROVIDERS: raise ValueError(f"Unknown git provider: {provider_id}") return _GIT_PROVIDERS[provider_id] diff --git a/pr_agent/git_providers/git_provider.py b/pr_agent/git_providers/git_provider.py new file mode 100644 index 00000000..d3733b6d --- /dev/null +++ b/pr_agent/git_providers/git_provider.py @@ -0,0 +1,38 @@ + +from abc import ABC +from dataclasses import dataclass + + +@dataclass +class FilePatchInfo: + base_file: str + head_file: str + patch: str + filename: str + tokens: int = -1 + + +class GitProvider(ABC): + def get_diff_files(self) -> list[FilePatchInfo]: + pass + + def publish_comment(self, pr_comment: str, is_temporary: bool = False): + pass + + def remove_initial_comment(self): + pass + + def get_languages(self): + pass + + def get_main_pr_language(self) -> str: + pass + + def get_pr_branch(self): + pass + + def get_user_id(self): + pass + + def get_pr_description(): + pass \ No newline at end of file diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index a03d0bee..bd5576c4 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -1,25 +1,17 @@ import logging -from collections import namedtuple -from dataclasses import dataclass from datetime import datetime from typing import Optional, Tuple from urllib.parse import urlparse -from github import AppAuthentication, File, Github +from github import AppAuthentication, Github from pr_agent.config_loader import settings +from .git_provider import FilePatchInfo -@dataclass -class FilePatchInfo: - base_file: str - head_file: str - patch: str - filename: str - tokens: int = -1 class GithubProvider: - def __init__(self, pr_url: Optional[str] = None, installation_id: Optional[int] = None): - self.installation_id = installation_id + def __init__(self, pr_url: Optional[str] = None): + self.installation_id = settings.get("GITHUB.INSTALLATION_ID") self.github_client = self._get_github_client() self.repo = None self.pr_num = None @@ -65,7 +57,8 @@ class GithubProvider: return self.pr.body def get_languages(self): - return self._get_repo().get_languages() + languages = self._get_repo().get_languages() + return languages def get_main_pr_language(self) -> str: """ @@ -112,6 +105,9 @@ class GithubProvider: def get_pr_branch(self): return self.pr.head.ref + def get_pr_description(self): + return self.pr.body + def get_user_id(self): if not self.github_user_id: try: diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py new file mode 100644 index 00000000..54c27ad3 --- /dev/null +++ b/pr_agent/git_providers/gitlab_provider.py @@ -0,0 +1,57 @@ +import gitlab +from typing import Optional, Tuple + + +class GitLabProvider: + def __init__(self, merge_request_url: Optional[str] = None, personal_access_token: Optional[str] = None): + self.gl = gitlab.Gitlab('https://your.gitlab.com', private_token=personal_access_token) + self.project = None + self.mr_iid = None + self.mr = None + if merge_request_url: + self.set_merge_request(merge_request_url) + + def set_merge_request(self, merge_request_url: str): + self.project, self.mr_iid = self._parse_merge_request_url(merge_request_url) + self.mr = self._get_merge_request() + + def get_diff_files(self) -> list[FilePatchInfo]: + diffs = self.mr.diffs.list() + diff_files = [] + for diff in diffs: + # GitLab doesn't provide base and head files. Only diffs are available. + diff_files.append(FilePatchInfo("", "", diff['diff'], diff['new_path'])) + return diff_files + + def publish_comment(self, mr_comment: str): + self.mr.notes.create({'body': mr_comment}) + + def get_title(self): + return self.mr.title + + def get_description(self): + return self.mr.description + + def get_languages(self): + # GitLab does not have a direct equivalent to get_languages(). + # An alternative could be to manually parse all the repository files and determine the language from the file extensions. + raise NotImplementedError + + def get_main_pr_language(self) -> str: + # Similar issue as get_languages(). + raise NotImplementedError + + def get_pr_branch(self): + return self.mr.source_branch + + def get_notifications(self): + # GitLab doesn't provide a notifications API similar to GitHub's. + raise NotImplementedError + + @staticmethod + def _parse_merge_request_url(merge_request_url: str) -> Tuple[str, int]: + # This function will depend on your GitLab setup and URL structure + raise NotImplementedError + + def _get_merge_request(self): + return self.gl.projects.get(self.project).mergerequests.get(self.mr_iid) diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index 59f4625b..82812734 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -25,3 +25,10 @@ private_key = """\ """ app_id = 123456 # The GitHub App ID, replace with your own. webhook_secret = "" # Optional, may be commented out. + +[gitlab] +# The type of deployment to create. Valid values are 'app' or 'user'. +personal_access_token = "" + +# URL to the gitlab service +gitlab_url = "https://gitlab.com" \ No newline at end of file diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index b39a231e..214056a1 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -16,9 +16,8 @@ from pr_agent.git_providers import get_git_provider class PRReviewer: def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False): - self.git_provider = get_git_provider()(pr_url, installation_id) + self.git_provider = get_git_provider()(pr_url) self.main_language = self.git_provider.get_main_pr_language() - self.installation_id = installation_id self.ai_handler = AiHandler() self.patches_diff = None self.prediction = None @@ -26,7 +25,7 @@ class PRReviewer: self.vars = { "title": self.git_provider.pr.title, "branch": self.git_provider.get_pr_branch(), - "description": self.git_provider.pr.body, + "description": self.git_provider.get_pr_description(), "language": self.main_language, "diff": "", # empty diff for initial calculation "require_tests": settings.pr_reviewer.require_tests_review, diff --git a/requirements.txt b/requirements.txt index d7610000..4a6a6255 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ openai==0.27.8 Jinja2==3.1.2 tiktoken==0.4.0 uvicorn==0.22.0 +python-gitlab==3.15.0