mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 03:40:38 +08:00
working example
This commit is contained in:
@ -1,15 +1,17 @@
|
|||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import settings
|
||||||
from pr_agent.git_providers.github_provider import GithubProvider
|
from pr_agent.git_providers.github_provider import GithubProvider
|
||||||
|
from pr_agent.git_providers.gitlab_provider import GitLabProvider
|
||||||
|
|
||||||
_GIT_PROVIDERS = {
|
_GIT_PROVIDERS = {
|
||||||
'github': GithubProvider
|
'github': GithubProvider,
|
||||||
|
'gitlab': GitLabProvider,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_git_provider():
|
def get_git_provider():
|
||||||
try:
|
try:
|
||||||
provider_id = settings.config.git_provider
|
provider_id = settings.config.git_provider
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
raise ValueError("github_provider is a required attribute in the configuration file") from e
|
raise ValueError("git_provider is a required attribute in the configuration file") from e
|
||||||
if provider_id not in _GIT_PROVIDERS:
|
if provider_id not in _GIT_PROVIDERS:
|
||||||
raise ValueError(f"Unknown git provider: {provider_id}")
|
raise ValueError(f"Unknown git provider: {provider_id}")
|
||||||
return _GIT_PROVIDERS[provider_id]
|
return _GIT_PROVIDERS[provider_id]
|
||||||
|
38
pr_agent/git_providers/git_provider.py
Normal file
38
pr_agent/git_providers/git_provider.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FilePatchInfo:
|
||||||
|
base_file: str
|
||||||
|
head_file: str
|
||||||
|
patch: str
|
||||||
|
filename: str
|
||||||
|
tokens: int = -1
|
||||||
|
|
||||||
|
|
||||||
|
class GitProvider(ABC):
|
||||||
|
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def remove_initial_comment(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_languages(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_main_pr_language(self) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_pr_branch(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_user_id(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_pr_description():
|
||||||
|
pass
|
@ -1,25 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from github import AppAuthentication, File, Github
|
from github import AppAuthentication, Github
|
||||||
|
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import settings
|
||||||
|
from .git_provider import FilePatchInfo
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FilePatchInfo:
|
|
||||||
base_file: str
|
|
||||||
head_file: str
|
|
||||||
patch: str
|
|
||||||
filename: str
|
|
||||||
tokens: int = -1
|
|
||||||
|
|
||||||
class GithubProvider:
|
class GithubProvider:
|
||||||
def __init__(self, pr_url: Optional[str] = None, installation_id: Optional[int] = None):
|
def __init__(self, pr_url: Optional[str] = None):
|
||||||
self.installation_id = installation_id
|
self.installation_id = settings.get("GITHUB.INSTALLATION_ID")
|
||||||
self.github_client = self._get_github_client()
|
self.github_client = self._get_github_client()
|
||||||
self.repo = None
|
self.repo = None
|
||||||
self.pr_num = None
|
self.pr_num = None
|
||||||
@ -65,7 +57,8 @@ class GithubProvider:
|
|||||||
return self.pr.body
|
return self.pr.body
|
||||||
|
|
||||||
def get_languages(self):
|
def get_languages(self):
|
||||||
return self._get_repo().get_languages()
|
languages = self._get_repo().get_languages()
|
||||||
|
return languages
|
||||||
|
|
||||||
def get_main_pr_language(self) -> str:
|
def get_main_pr_language(self) -> str:
|
||||||
"""
|
"""
|
||||||
@ -112,6 +105,9 @@ class GithubProvider:
|
|||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.pr.head.ref
|
return self.pr.head.ref
|
||||||
|
|
||||||
|
def get_pr_description(self):
|
||||||
|
return self.pr.body
|
||||||
|
|
||||||
def get_user_id(self):
|
def get_user_id(self):
|
||||||
if not self.github_user_id:
|
if not self.github_user_id:
|
||||||
try:
|
try:
|
||||||
|
57
pr_agent/git_providers/gitlab_provider.py
Normal file
57
pr_agent/git_providers/gitlab_provider.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import gitlab
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class GitLabProvider:
|
||||||
|
def __init__(self, merge_request_url: Optional[str] = None, personal_access_token: Optional[str] = None):
|
||||||
|
self.gl = gitlab.Gitlab('https://your.gitlab.com', private_token=personal_access_token)
|
||||||
|
self.project = None
|
||||||
|
self.mr_iid = None
|
||||||
|
self.mr = None
|
||||||
|
if merge_request_url:
|
||||||
|
self.set_merge_request(merge_request_url)
|
||||||
|
|
||||||
|
def set_merge_request(self, merge_request_url: str):
|
||||||
|
self.project, self.mr_iid = self._parse_merge_request_url(merge_request_url)
|
||||||
|
self.mr = self._get_merge_request()
|
||||||
|
|
||||||
|
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||||
|
diffs = self.mr.diffs.list()
|
||||||
|
diff_files = []
|
||||||
|
for diff in diffs:
|
||||||
|
# GitLab doesn't provide base and head files. Only diffs are available.
|
||||||
|
diff_files.append(FilePatchInfo("", "", diff['diff'], diff['new_path']))
|
||||||
|
return diff_files
|
||||||
|
|
||||||
|
def publish_comment(self, mr_comment: str):
|
||||||
|
self.mr.notes.create({'body': mr_comment})
|
||||||
|
|
||||||
|
def get_title(self):
|
||||||
|
return self.mr.title
|
||||||
|
|
||||||
|
def get_description(self):
|
||||||
|
return self.mr.description
|
||||||
|
|
||||||
|
def get_languages(self):
|
||||||
|
# GitLab does not have a direct equivalent to get_languages().
|
||||||
|
# An alternative could be to manually parse all the repository files and determine the language from the file extensions.
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_main_pr_language(self) -> str:
|
||||||
|
# Similar issue as get_languages().
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_pr_branch(self):
|
||||||
|
return self.mr.source_branch
|
||||||
|
|
||||||
|
def get_notifications(self):
|
||||||
|
# GitLab doesn't provide a notifications API similar to GitHub's.
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_merge_request_url(merge_request_url: str) -> Tuple[str, int]:
|
||||||
|
# This function will depend on your GitLab setup and URL structure
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _get_merge_request(self):
|
||||||
|
return self.gl.projects.get(self.project).mergerequests.get(self.mr_iid)
|
@ -25,3 +25,10 @@ private_key = """\
|
|||||||
"""
|
"""
|
||||||
app_id = 123456 # The GitHub App ID, replace with your own.
|
app_id = 123456 # The GitHub App ID, replace with your own.
|
||||||
webhook_secret = "<WEBHOOK SECRET>" # Optional, may be commented out.
|
webhook_secret = "<WEBHOOK SECRET>" # Optional, may be commented out.
|
||||||
|
|
||||||
|
[gitlab]
|
||||||
|
# The type of deployment to create. Valid values are 'app' or 'user'.
|
||||||
|
personal_access_token = ""
|
||||||
|
|
||||||
|
# URL to the gitlab service
|
||||||
|
gitlab_url = "https://gitlab.com"
|
@ -16,9 +16,8 @@ from pr_agent.git_providers import get_git_provider
|
|||||||
class PRReviewer:
|
class PRReviewer:
|
||||||
def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False):
|
def __init__(self, pr_url: str, installation_id: Optional[int] = None, cli_mode=False):
|
||||||
|
|
||||||
self.git_provider = get_git_provider()(pr_url, installation_id)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_language = self.git_provider.get_main_pr_language()
|
self.main_language = self.git_provider.get_main_pr_language()
|
||||||
self.installation_id = installation_id
|
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = AiHandler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
@ -26,7 +25,7 @@ class PRReviewer:
|
|||||||
self.vars = {
|
self.vars = {
|
||||||
"title": self.git_provider.pr.title,
|
"title": self.git_provider.pr.title,
|
||||||
"branch": self.git_provider.get_pr_branch(),
|
"branch": self.git_provider.get_pr_branch(),
|
||||||
"description": self.git_provider.pr.body,
|
"description": self.git_provider.get_pr_description(),
|
||||||
"language": self.main_language,
|
"language": self.main_language,
|
||||||
"diff": "", # empty diff for initial calculation
|
"diff": "", # empty diff for initial calculation
|
||||||
"require_tests": settings.pr_reviewer.require_tests_review,
|
"require_tests": settings.pr_reviewer.require_tests_review,
|
||||||
|
@ -6,3 +6,4 @@ openai==0.27.8
|
|||||||
Jinja2==3.1.2
|
Jinja2==3.1.2
|
||||||
tiktoken==0.4.0
|
tiktoken==0.4.0
|
||||||
uvicorn==0.22.0
|
uvicorn==0.22.0
|
||||||
|
python-gitlab==3.15.0
|
||||||
|
Reference in New Issue
Block a user