working example

This commit is contained in:
Albert Achtenberg
2023-07-07 15:02:40 +03:00
parent 4b786b350e
commit ed8cf27b05
7 changed files with 118 additions and 18 deletions

View File

@ -1,15 +1,17 @@
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from pr_agent.git_providers.github_provider import GithubProvider from pr_agent.git_providers.github_provider import GithubProvider
from pr_agent.git_providers.gitlab_provider import GitLabProvider
_GIT_PROVIDERS = { _GIT_PROVIDERS = {
'github': GithubProvider 'github': GithubProvider,
'gitlab': GitLabProvider,
} }
def get_git_provider(): def get_git_provider():
try: try:
provider_id = settings.config.git_provider provider_id = settings.config.git_provider
except AttributeError as e: except AttributeError as e:
raise ValueError("github_provider is a required attribute in the configuration file") from e raise ValueError("git_provider is a required attribute in the configuration file") from e
if provider_id not in _GIT_PROVIDERS: if provider_id not in _GIT_PROVIDERS:
raise ValueError(f"Unknown git provider: {provider_id}") raise ValueError(f"Unknown git provider: {provider_id}")
return _GIT_PROVIDERS[provider_id] return _GIT_PROVIDERS[provider_id]

View File

@ -0,0 +1,38 @@
from abc import ABC
from dataclasses import dataclass
@dataclass
class FilePatchInfo:
base_file: str
head_file: str
patch: str
filename: str
tokens: int = -1
class GitProvider(ABC):
def get_diff_files(self) -> list[FilePatchInfo]:
pass
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
pass
def remove_initial_comment(self):
pass
def get_languages(self):
pass
def get_main_pr_language(self) -> str:
pass
def get_pr_branch(self):
pass
def get_user_id(self):
pass
def get_pr_description():
pass

View File

@ -1,25 +1,17 @@
import logging import logging
from collections import namedtuple
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Optional, Tuple from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from github import AppAuthentication, File, Github from github import AppAuthentication, Github
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from .git_provider import FilePatchInfo
@dataclass
class FilePatchInfo:
base_file: str
head_file: str
patch: str
filename: str
tokens: int = -1
class GithubProvider: class GithubProvider:
def __init__(self, pr_url: Optional[str] = None, installation_id: Optional[int] = None): def __init__(self, pr_url: Optional[str] = None):
self.installation_id = installation_id self.installation_id = settings.get("GITHUB.INSTALLATION_ID")
self.github_client = self._get_github_client() self.github_client = self._get_github_client()
self.repo = None self.repo = None
self.pr_num = None self.pr_num = None
@ -65,7 +57,8 @@ class GithubProvider:
return self.pr.body return self.pr.body
def get_languages(self): def get_languages(self):
return self._get_repo().get_languages() languages = self._get_repo().get_languages()
return languages
def get_main_pr_language(self) -> str: def get_main_pr_language(self) -> str:
""" """
@ -112,6 +105,9 @@ class GithubProvider:
def get_pr_branch(self): def get_pr_branch(self):
return self.pr.head.ref return self.pr.head.ref
def get_pr_description(self):
return self.pr.body
def get_user_id(self): def get_user_id(self):
if not self.github_user_id: if not self.github_user_id:
try: try:

View File

@ -0,0 +1,57 @@
import gitlab
from typing import Optional, Tuple
class GitLabProvider:
def __init__(self, merge_request_url: Optional[str] = None, personal_access_token: Optional[str] = None):
self.gl = gitlab.Gitlab('https://your.gitlab.com', private_token=personal_access_token)
self.project = None
self.mr_iid = None
self.mr = None
if merge_request_url:
self.set_merge_request(merge_request_url)
def set_merge_request(self, merge_request_url: str):
self.project, self.mr_iid = self._parse_merge_request_url(merge_request_url)
self.mr = self._get_merge_request()
def get_diff_files(self) -> list[FilePatchInfo]:
diffs = self.mr.diffs.list()
diff_files = []
for diff in diffs:
# GitLab doesn't provide base and head files. Only diffs are available.
diff_files.append(FilePatchInfo("", "", diff['diff'], diff['new_path']))
return diff_files
def publish_comment(self, mr_comment: str):
self.mr.notes.create({'body': mr_comment})
def get_title(self):
return self.mr.title
def get_description(self):
return self.mr.description
def get_languages(self):
# GitLab does not have a direct equivalent to get_languages().
# An alternative could be to manually parse all the repository files and determine the language from the file extensions.
raise NotImplementedError
def get_main_pr_language(self) -> str:
# Similar issue as get_languages().
raise NotImplementedError
def get_pr_branch(self):
return self.mr.source_branch
def get_notifications(self):
# GitLab doesn't provide a notifications API similar to GitHub's.
raise NotImplementedError
@staticmethod
def _parse_merge_request_url(merge_request_url: str) -> Tuple[str, int]:
# This function will depend on your GitLab setup and URL structure
raise NotImplementedError
def _get_merge_request(self):
return self.gl.projects.get(self.project).mergerequests.get(self.mr_iid)

View File

@ -25,3 +25,10 @@ private_key = """\
""" """
app_id = 123456 # The GitHub App ID, replace with your own. app_id = 123456 # The GitHub App ID, replace with your own.
webhook_secret = "<WEBHOOK SECRET>" # Optional, may be commented out. webhook_secret = "<WEBHOOK SECRET>" # Optional, may be commented out.
[gitlab]
# The type of deployment to create. Valid values are 'app' or 'user'.
personal_access_token = ""
# URL to the gitlab service
gitlab_url = "https://gitlab.com"

View File

@ -16,9 +16,8 @@ from pr_agent.git_providers import get_git_provider
class PRReviewer: class PRReviewer:
def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False): def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False):
self.git_provider = get_git_provider()(pr_url, installation_id) self.git_provider = get_git_provider()(pr_url)
self.main_language = self.git_provider.get_main_pr_language() self.main_language = self.git_provider.get_main_pr_language()
self.installation_id = installation_id
self.ai_handler = AiHandler() self.ai_handler = AiHandler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
@ -26,7 +25,7 @@ class PRReviewer:
self.vars = { self.vars = {
"title": self.git_provider.pr.title, "title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(), "branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.pr.body, "description": self.git_provider.get_pr_description(),
"language": self.main_language, "language": self.main_language,
"diff": "", # empty diff for initial calculation "diff": "", # empty diff for initial calculation
"require_tests": settings.pr_reviewer.require_tests_review, "require_tests": settings.pr_reviewer.require_tests_review,

View File

@ -6,3 +6,4 @@ openai==0.27.8
Jinja2==3.1.2 Jinja2==3.1.2
tiktoken==0.4.0 tiktoken==0.4.0
uvicorn==0.22.0 uvicorn==0.22.0
python-gitlab==3.15.0