mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
working example
This commit is contained in:
@ -1,15 +1,17 @@
|
||||
from pr_agent.config_loader import settings
|
||||
from pr_agent.git_providers.github_provider import GithubProvider
|
||||
from pr_agent.git_providers.gitlab_provider import GitLabProvider
|
||||
|
||||
_GIT_PROVIDERS = {
|
||||
'github': GithubProvider
|
||||
'github': GithubProvider,
|
||||
'gitlab': GitLabProvider,
|
||||
}
|
||||
|
||||
def get_git_provider():
|
||||
try:
|
||||
provider_id = settings.config.git_provider
|
||||
except AttributeError as e:
|
||||
raise ValueError("github_provider is a required attribute in the configuration file") from e
|
||||
raise ValueError("git_provider is a required attribute in the configuration file") from e
|
||||
if provider_id not in _GIT_PROVIDERS:
|
||||
raise ValueError(f"Unknown git provider: {provider_id}")
|
||||
return _GIT_PROVIDERS[provider_id]
|
||||
|
38
pr_agent/git_providers/git_provider.py
Normal file
38
pr_agent/git_providers/git_provider.py
Normal file
@ -0,0 +1,38 @@
|
||||
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilePatchInfo:
|
||||
base_file: str
|
||||
head_file: str
|
||||
patch: str
|
||||
filename: str
|
||||
tokens: int = -1
|
||||
|
||||
|
||||
class GitProvider(ABC):
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
pass
|
||||
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
pass
|
||||
|
||||
def remove_initial_comment(self):
|
||||
pass
|
||||
|
||||
def get_languages(self):
|
||||
pass
|
||||
|
||||
def get_main_pr_language(self) -> str:
|
||||
pass
|
||||
|
||||
def get_pr_branch(self):
|
||||
pass
|
||||
|
||||
def get_user_id(self):
|
||||
pass
|
||||
|
||||
def get_pr_description():
|
||||
pass
|
@ -1,25 +1,17 @@
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from github import AppAuthentication, File, Github
|
||||
from github import AppAuthentication, Github
|
||||
|
||||
from pr_agent.config_loader import settings
|
||||
from .git_provider import FilePatchInfo
|
||||
|
||||
@dataclass
|
||||
class FilePatchInfo:
|
||||
base_file: str
|
||||
head_file: str
|
||||
patch: str
|
||||
filename: str
|
||||
tokens: int = -1
|
||||
|
||||
class GithubProvider:
|
||||
def __init__(self, pr_url: Optional[str] = None, installation_id: Optional[int] = None):
|
||||
self.installation_id = installation_id
|
||||
def __init__(self, pr_url: Optional[str] = None):
|
||||
self.installation_id = settings.get("GITHUB.INSTALLATION_ID")
|
||||
self.github_client = self._get_github_client()
|
||||
self.repo = None
|
||||
self.pr_num = None
|
||||
@ -65,7 +57,8 @@ class GithubProvider:
|
||||
return self.pr.body
|
||||
|
||||
def get_languages(self):
|
||||
return self._get_repo().get_languages()
|
||||
languages = self._get_repo().get_languages()
|
||||
return languages
|
||||
|
||||
def get_main_pr_language(self) -> str:
|
||||
"""
|
||||
@ -112,6 +105,9 @@ class GithubProvider:
|
||||
def get_pr_branch(self):
|
||||
return self.pr.head.ref
|
||||
|
||||
def get_pr_description(self):
|
||||
return self.pr.body
|
||||
|
||||
def get_user_id(self):
|
||||
if not self.github_user_id:
|
||||
try:
|
||||
|
57
pr_agent/git_providers/gitlab_provider.py
Normal file
57
pr_agent/git_providers/gitlab_provider.py
Normal file
@ -0,0 +1,57 @@
|
||||
import gitlab
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class GitLabProvider:
|
||||
def __init__(self, merge_request_url: Optional[str] = None, personal_access_token: Optional[str] = None):
|
||||
self.gl = gitlab.Gitlab('https://your.gitlab.com', private_token=personal_access_token)
|
||||
self.project = None
|
||||
self.mr_iid = None
|
||||
self.mr = None
|
||||
if merge_request_url:
|
||||
self.set_merge_request(merge_request_url)
|
||||
|
||||
def set_merge_request(self, merge_request_url: str):
|
||||
self.project, self.mr_iid = self._parse_merge_request_url(merge_request_url)
|
||||
self.mr = self._get_merge_request()
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
diffs = self.mr.diffs.list()
|
||||
diff_files = []
|
||||
for diff in diffs:
|
||||
# GitLab doesn't provide base and head files. Only diffs are available.
|
||||
diff_files.append(FilePatchInfo("", "", diff['diff'], diff['new_path']))
|
||||
return diff_files
|
||||
|
||||
def publish_comment(self, mr_comment: str):
|
||||
self.mr.notes.create({'body': mr_comment})
|
||||
|
||||
def get_title(self):
|
||||
return self.mr.title
|
||||
|
||||
def get_description(self):
|
||||
return self.mr.description
|
||||
|
||||
def get_languages(self):
|
||||
# GitLab does not have a direct equivalent to get_languages().
|
||||
# An alternative could be to manually parse all the repository files and determine the language from the file extensions.
|
||||
raise NotImplementedError
|
||||
|
||||
def get_main_pr_language(self) -> str:
|
||||
# Similar issue as get_languages().
|
||||
raise NotImplementedError
|
||||
|
||||
def get_pr_branch(self):
|
||||
return self.mr.source_branch
|
||||
|
||||
def get_notifications(self):
|
||||
# GitLab doesn't provide a notifications API similar to GitHub's.
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _parse_merge_request_url(merge_request_url: str) -> Tuple[str, int]:
|
||||
# This function will depend on your GitLab setup and URL structure
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_merge_request(self):
|
||||
return self.gl.projects.get(self.project).mergerequests.get(self.mr_iid)
|
@ -25,3 +25,10 @@ private_key = """\
|
||||
"""
|
||||
app_id = 123456 # The GitHub App ID, replace with your own.
|
||||
webhook_secret = "<WEBHOOK SECRET>" # Optional, may be commented out.
|
||||
|
||||
[gitlab]
|
||||
# The type of deployment to create. Valid values are 'app' or 'user'.
|
||||
personal_access_token = ""
|
||||
|
||||
# URL to the gitlab service
|
||||
gitlab_url = "https://gitlab.com"
|
@ -16,9 +16,8 @@ from pr_agent.git_providers import get_git_provider
|
||||
class PRReviewer:
|
||||
def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False):
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url, installation_id)
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_language = self.git_provider.get_main_pr_language()
|
||||
self.installation_id = installation_id
|
||||
self.ai_handler = AiHandler()
|
||||
self.patches_diff = None
|
||||
self.prediction = None
|
||||
@ -26,7 +25,7 @@ class PRReviewer:
|
||||
self.vars = {
|
||||
"title": self.git_provider.pr.title,
|
||||
"branch": self.git_provider.get_pr_branch(),
|
||||
"description": self.git_provider.pr.body,
|
||||
"description": self.git_provider.get_pr_description(),
|
||||
"language": self.main_language,
|
||||
"diff": "", # empty diff for initial calculation
|
||||
"require_tests": settings.pr_reviewer.require_tests_review,
|
||||
|
@ -6,3 +6,4 @@ openai==0.27.8
|
||||
Jinja2==3.1.2
|
||||
tiktoken==0.4.0
|
||||
uvicorn==0.22.0
|
||||
python-gitlab==3.15.0
|
||||
|
Reference in New Issue
Block a user