mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
Gerrit support
This commit is contained in:
@ -5,6 +5,8 @@ from pr_agent.git_providers.github_provider import GithubProvider
|
||||
from pr_agent.git_providers.gitlab_provider import GitLabProvider
|
||||
from pr_agent.git_providers.local_git_provider import LocalGitProvider
|
||||
from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
|
||||
from pr_agent.git_providers.gerrit_provider import GerritProvider
|
||||
|
||||
|
||||
_GIT_PROVIDERS = {
|
||||
'github': GithubProvider,
|
||||
@ -12,7 +14,8 @@ _GIT_PROVIDERS = {
|
||||
'bitbucket': BitbucketProvider,
|
||||
'azure': AzureDevopsProvider,
|
||||
'codecommit': CodeCommitProvider,
|
||||
'local' : LocalGitProvider
|
||||
'local' : LocalGitProvider,
|
||||
'gerrit': GerritProvider,
|
||||
}
|
||||
|
||||
def get_git_provider():
|
||||
|
397
pr_agent/git_providers/gerrit_provider.py
Normal file
397
pr_agent/git_providers/gerrit_provider.py
Normal file
@ -0,0 +1,397 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import subprocess
|
||||
from collections import Counter, namedtuple
|
||||
from pathlib import Path
|
||||
from tempfile import mkdtemp, NamedTemporaryFile
|
||||
|
||||
import requests
|
||||
import urllib3.util
|
||||
from git import Repo
|
||||
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers.git_provider import GitProvider, FilePatchInfo, \
|
||||
EDIT_TYPE
|
||||
from pr_agent.git_providers.local_git_provider import PullRequestMimic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _call(*command, **kwargs) -> (int, str, str):
|
||||
res = subprocess.run(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
check=True,
|
||||
**kwargs,
|
||||
)
|
||||
return res.stdout.decode()
|
||||
|
||||
|
||||
def clone(url, directory):
|
||||
logger.info("Cloning %s to %s", url, directory)
|
||||
stdout = _call('git', 'clone', "--depth", "1", url, directory)
|
||||
logger.info(stdout)
|
||||
|
||||
|
||||
def fetch(url, refspec, cwd):
|
||||
logger.info("Fetching %s %s", url, refspec)
|
||||
stdout = _call(
|
||||
'git', 'fetch', '--depth', '2', url, refspec,
|
||||
cwd=cwd
|
||||
)
|
||||
logger.info(stdout)
|
||||
|
||||
|
||||
def checkout(cwd):
|
||||
logger.info("Checking out")
|
||||
stdout = _call('git', 'checkout', "FETCH_HEAD", cwd=cwd)
|
||||
logger.info(stdout)
|
||||
|
||||
|
||||
def show(*args, cwd=None):
|
||||
logger.info("Show")
|
||||
return _call('git', 'show', *args, cwd=cwd)
|
||||
|
||||
|
||||
def diff(*args, cwd=None):
|
||||
logger.info("Diff")
|
||||
patch = _call('git', 'diff', *args, cwd=cwd)
|
||||
if not patch:
|
||||
logger.warning("No changes found")
|
||||
return
|
||||
return patch
|
||||
|
||||
|
||||
def reset_local_changes(cwd):
|
||||
logger.info("Reset local changes")
|
||||
_call('git', 'checkout', "--force", cwd=cwd)
|
||||
|
||||
|
||||
def add_comment(url: urllib3.util.Url, refspec, message):
|
||||
*_, patchset, changenum = refspec.rsplit("/")
|
||||
message = "'" + message.replace("'", "'\"'\"'") + "'"
|
||||
return _call(
|
||||
"ssh",
|
||||
"-p", str(url.port),
|
||||
f"{url.auth}@{url.host}",
|
||||
"gerrit", "review",
|
||||
"--message", message,
|
||||
# "--code-review", score,
|
||||
f"{patchset},{changenum}",
|
||||
)
|
||||
|
||||
|
||||
def list_comments(url: urllib3.util.Url, refspec):
|
||||
*_, patchset, _ = refspec.rsplit("/")
|
||||
stdout = _call(
|
||||
"ssh",
|
||||
"-p", str(url.port),
|
||||
f"{url.auth}@{url.host}",
|
||||
"gerrit", "query",
|
||||
"--comments",
|
||||
"--current-patch-set", patchset,
|
||||
"--format", "JSON",
|
||||
)
|
||||
change_set, *_ = stdout.splitlines()
|
||||
return json.loads(change_set)["currentPatchSet"]["comments"]
|
||||
|
||||
|
||||
def prepare_repo(url: urllib3.util.Url, project, refspec):
|
||||
repo_url = (f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}")
|
||||
|
||||
directory = pathlib.Path(mkdtemp())
|
||||
clone(repo_url, directory),
|
||||
fetch(repo_url, refspec, cwd=directory)
|
||||
checkout(cwd=directory)
|
||||
return directory
|
||||
|
||||
|
||||
def adopt_to_gerrit_message(message):
|
||||
lines = message.splitlines()
|
||||
buf = []
|
||||
for line in lines:
|
||||
line = line.replace("*", "").replace("``", "`")
|
||||
line = line.strip()
|
||||
if line.startswith('#'):
|
||||
buf.append("\n" +
|
||||
line.replace('#', '').removesuffix(":").strip() +
|
||||
":")
|
||||
continue
|
||||
elif line.startswith('-'):
|
||||
buf.append(line.removeprefix('-').strip())
|
||||
continue
|
||||
else:
|
||||
buf.append(line)
|
||||
return "\n".join(buf).strip()
|
||||
|
||||
|
||||
def add_suggestion(src_filename, context: str, start, end: int):
|
||||
with (
|
||||
NamedTemporaryFile("w", delete=False) as tmp,
|
||||
open(src_filename, "r") as src
|
||||
):
|
||||
lines = src.readlines()
|
||||
tmp.writelines(lines[:start - 1])
|
||||
if context:
|
||||
tmp.write(context)
|
||||
tmp.writelines(lines[end:])
|
||||
|
||||
shutil.copy(tmp.name, src_filename)
|
||||
os.remove(tmp.name)
|
||||
|
||||
|
||||
def upload_patch(patch, path):
|
||||
patch_server_endpoint = get_settings().get(
|
||||
'gerrit.patch_server_endpoint')
|
||||
patch_server_token = get_settings().get(
|
||||
'gerrit.patch_server_token')
|
||||
|
||||
response = requests.post(
|
||||
patch_server_endpoint,
|
||||
json={
|
||||
"content": patch,
|
||||
"path": path,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {patch_server_token}",
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
patch_server_endpoint = patch_server_endpoint.rstrip("/")
|
||||
return patch_server_endpoint + "/" + path
|
||||
|
||||
|
||||
class GerritProvider(GitProvider):
|
||||
|
||||
def __init__(self, key: str, incremental=False):
|
||||
self.project, self.refspec = key.split(':')
|
||||
assert self.project, "Project name is required"
|
||||
assert self.refspec, "Refspec is required"
|
||||
base_url = get_settings().get('gerrit.url')
|
||||
assert base_url, "Gerrit URL is required"
|
||||
user = get_settings().get('gerrit.user')
|
||||
assert user, "Gerrit user is required"
|
||||
|
||||
parsed = urllib3.util.parse_url(base_url)
|
||||
self.parsed_url = urllib3.util.parse_url(
|
||||
f"{parsed.scheme}://{user}@{parsed.host}:{parsed.port}"
|
||||
)
|
||||
|
||||
self.repo_path = prepare_repo(
|
||||
self.parsed_url, self.project, self.refspec
|
||||
)
|
||||
self.repo = Repo(self.repo_path)
|
||||
assert self.repo
|
||||
|
||||
self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())
|
||||
|
||||
def get_pr_title(self):
|
||||
"""
|
||||
Substitutes the branch-name as the PR-mimic title.
|
||||
"""
|
||||
return self.repo.branches[0].name
|
||||
|
||||
def get_issue_comments(self):
|
||||
# raise NotImplementedError(
|
||||
# 'Getting comments is not implemented for the gerrit provider'
|
||||
# )
|
||||
# unclear how to get right comments from gerrit
|
||||
# all user's comments like this "'Patch Set 1: (1 comment)'"
|
||||
# i.e. it's not clear how to get the comment body
|
||||
|
||||
comments = list_comments(self.parsed_url, self.refspec)
|
||||
Comments = namedtuple('Comments', ['reversed'])
|
||||
Comment = namedtuple('Comment', ['body'])
|
||||
return Comments([Comment(c['message']) for c in reversed(comments)])
|
||||
|
||||
def get_labels(self):
|
||||
raise NotImplementedError(
|
||||
'Getting labels is not implemented for the gerrit provider')
|
||||
|
||||
def add_eyes_reaction(self, issue_comment_id: int):
|
||||
raise NotImplementedError(
|
||||
'Adding reactions is not implemented for the gerrit provider')
|
||||
|
||||
def remove_reaction(self, issue_comment_id: int, reaction_id: int):
|
||||
raise NotImplementedError(
|
||||
'Removing reactions is not implemented for the gerrit provider')
|
||||
|
||||
def get_commit_messages(self):
|
||||
return [self.repo.head.commit.message]
|
||||
|
||||
def get_repo_settings(self):
|
||||
"""
|
||||
TODO: Implement support of .pr_agent.toml
|
||||
"""
|
||||
return ""
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
diffs = self.repo.head.commit.diff(
|
||||
self.repo.head.commit.parents[0], # previous commit
|
||||
create_patch=True,
|
||||
R=True
|
||||
)
|
||||
|
||||
diff_files = []
|
||||
for diff_item in diffs:
|
||||
if diff_item.a_blob is not None:
|
||||
original_file_content_str = (
|
||||
diff_item.a_blob.data_stream.read().decode('utf-8')
|
||||
)
|
||||
else:
|
||||
original_file_content_str = "" # empty file
|
||||
if diff_item.b_blob is not None:
|
||||
new_file_content_str = diff_item.b_blob.data_stream.read(). \
|
||||
decode('utf-8')
|
||||
else:
|
||||
new_file_content_str = "" # empty file
|
||||
edit_type = EDIT_TYPE.MODIFIED
|
||||
if diff_item.new_file:
|
||||
edit_type = EDIT_TYPE.ADDED
|
||||
elif diff_item.deleted_file:
|
||||
edit_type = EDIT_TYPE.DELETED
|
||||
elif diff_item.renamed_file:
|
||||
edit_type = EDIT_TYPE.RENAMED
|
||||
diff_files.append(
|
||||
FilePatchInfo(
|
||||
original_file_content_str,
|
||||
new_file_content_str,
|
||||
diff_item.diff.decode('utf-8'),
|
||||
diff_item.b_path,
|
||||
edit_type=edit_type,
|
||||
old_filename=None
|
||||
if diff_item.a_path == diff_item.b_path
|
||||
else diff_item.a_path
|
||||
)
|
||||
)
|
||||
self.diff_files = diff_files
|
||||
return diff_files
|
||||
|
||||
def get_files(self):
|
||||
diff_index = self.repo.head.commit.diff(
|
||||
self.repo.head.commit.parents[0], # previous commit
|
||||
R=True
|
||||
)
|
||||
# Get the list of changed files
|
||||
diff_files = [item.a_path for item in diff_index]
|
||||
return diff_files
|
||||
|
||||
def get_languages(self):
|
||||
"""
|
||||
Calculate percentage of languages in repository. Used for hunk
|
||||
prioritisation.
|
||||
"""
|
||||
# Get all files in repository
|
||||
filepaths = [Path(item.path) for item in
|
||||
self.repo.tree().traverse() if item.type == 'blob']
|
||||
# Identify language by file extension and count
|
||||
lang_count = Counter(
|
||||
ext.lstrip('.') for filepath in filepaths for ext in
|
||||
[filepath.suffix.lower()])
|
||||
# Convert counts to percentages
|
||||
total_files = len(filepaths)
|
||||
lang_percentage = {lang: count / total_files * 100 for lang, count
|
||||
in lang_count.items()}
|
||||
return lang_percentage
|
||||
|
||||
def get_pr_description(self):
|
||||
return self.repo.head.commit.message
|
||||
|
||||
def get_user_id(self):
|
||||
return self.repo.head.commit.author.email
|
||||
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
if capability in [
|
||||
# 'get_issue_comments',
|
||||
'create_inline_comment',
|
||||
'publish_inline_comments',
|
||||
'get_labels'
|
||||
]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def split_suggestion(self, msg) -> tuple[str, str]:
|
||||
is_code_context = False
|
||||
description = []
|
||||
context = []
|
||||
for line in msg.splitlines():
|
||||
if line.startswith('```suggestion'):
|
||||
is_code_context = True
|
||||
continue
|
||||
if line.startswith('```'):
|
||||
is_code_context = False
|
||||
continue
|
||||
if is_code_context:
|
||||
context.append(line)
|
||||
else:
|
||||
description.append(
|
||||
line.replace('*', '')
|
||||
)
|
||||
|
||||
return (
|
||||
'\n'.join(description),
|
||||
'\n'.join(context) + '\n' if context else ''
|
||||
)
|
||||
|
||||
def publish_code_suggestions(self, code_suggestions: list):
|
||||
msg = []
|
||||
for i, suggestion in enumerate(code_suggestions):
|
||||
description, code = self.split_suggestion(suggestion['body'])
|
||||
add_suggestion(
|
||||
pathlib.Path(self.repo_path) / suggestion["relevant_file"],
|
||||
code,
|
||||
suggestion["relevant_lines_start"],
|
||||
suggestion["relevant_lines_end"],
|
||||
)
|
||||
patch = diff(cwd=self.repo_path)
|
||||
path = "/".join(["codium-ai", self.refspec, str(i)])
|
||||
full_path = upload_patch(patch, path)
|
||||
reset_local_changes(self.repo_path)
|
||||
msg.append(f'* {description}\n{full_path}')
|
||||
|
||||
if msg:
|
||||
add_comment(self.parsed_url, self.refspec, "---\n".join(msg))
|
||||
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
if not is_temporary:
|
||||
msg = adopt_to_gerrit_message(pr_comment)
|
||||
add_comment(self.parsed_url, self.refspec, msg)
|
||||
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
msg = adopt_to_gerrit_message(pr_body)
|
||||
add_comment(self.parsed_url, self.refspec, pr_title + '\n' + msg)
|
||||
|
||||
def publish_inline_comments(self, comments: list[dict]):
|
||||
raise NotImplementedError(
|
||||
'Publishing inline comments is not implemented for the gerrit '
|
||||
'provider')
|
||||
|
||||
def publish_inline_comment(self, body: str, relevant_file: str,
|
||||
relevant_line_in_file: str):
|
||||
raise NotImplementedError(
|
||||
'Publishing inline comments is not implemented for the gerrit '
|
||||
'provider')
|
||||
|
||||
def create_inline_comment(self, body: str, relevant_file: str,
|
||||
relevant_line_in_file: str):
|
||||
raise NotImplementedError(
|
||||
'Creating inline comments is not implemented for the gerrit '
|
||||
'provider')
|
||||
|
||||
def publish_labels(self, labels):
|
||||
# Not applicable to the local git provider,
|
||||
# but required by the interface
|
||||
pass
|
||||
|
||||
def remove_initial_comment(self):
|
||||
# remove repo, cloned in previous steps
|
||||
# shutil.rmtree(self.repo_path)
|
||||
pass
|
||||
|
||||
def get_pr_branch(self):
|
||||
return self.repo.head
|
81
pr_agent/servers/gerrit_server.py
Normal file
81
pr_agent/servers/gerrit_server.py
Normal file
@ -0,0 +1,81 @@
|
||||
import copy
|
||||
import logging
|
||||
import sys
|
||||
from enum import Enum
|
||||
from json import JSONDecodeError
|
||||
|
||||
import uvicorn
|
||||
from fastapi import APIRouter, FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware import Middleware
|
||||
from starlette_context import context
|
||||
from starlette_context.middleware import RawContextMiddleware
|
||||
|
||||
from pr_agent.agent.pr_agent import PRAgent
|
||||
from pr_agent.config_loader import global_settings, get_settings
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class Action(str, Enum):
|
||||
review = "review"
|
||||
describe = "describe"
|
||||
ask = "ask"
|
||||
improve = "improve"
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
refspec: str
|
||||
project: str
|
||||
msg: str = None
|
||||
|
||||
|
||||
@router.post("/api/v1/gerrit/{action}")
|
||||
async def handle_gerrit_request(action: Action, item: Item):
|
||||
logging.debug("Received a Gerrit request")
|
||||
context["settings"] = copy.deepcopy(global_settings)
|
||||
|
||||
agent = PRAgent()
|
||||
pr_url = f"{item.project}:{item.refspec}"
|
||||
if action == Action.review:
|
||||
await agent.handle_request(pr_url, "/review")
|
||||
elif action == Action.describe:
|
||||
await agent.handle_request(pr_url, "/describe")
|
||||
elif action == Action.improve:
|
||||
await agent.handle_request(pr_url, "/improve")
|
||||
elif action == Action.ask:
|
||||
if not item.msg:
|
||||
return HTTPException(
|
||||
status_code=400,
|
||||
detail="msg is required for ask command"
|
||||
)
|
||||
await agent.handle_request(pr_url, f"/ask {item.msg.strip()}")
|
||||
|
||||
|
||||
async def get_body(request):
|
||||
try:
|
||||
body = await request.json()
|
||||
except JSONDecodeError as e:
|
||||
logging.error("Error parsing request body", e)
|
||||
return {}
|
||||
return body
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
def start():
|
||||
# to prevent adding help messages with the output
|
||||
get_settings().set("CONFIG.CLI_MODE", True)
|
||||
middleware = [Middleware(RawContextMiddleware)]
|
||||
app = FastAPI(middleware=middleware)
|
||||
app.include_router(router)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=3000)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
start()
|
@ -1,7 +1,7 @@
|
||||
[config]
|
||||
model="gpt-4"
|
||||
fallback_models=["gpt-3.5-turbo-16k"]
|
||||
git_provider="github"
|
||||
git_provider="gerrit"
|
||||
publish_output=true
|
||||
publish_output_progress=true
|
||||
verbosity_level=0 # 0,1,2
|
||||
@ -85,3 +85,13 @@ polling_interval_seconds = 30
|
||||
# LocalGitProvider settings - uncomment to use paths other than default
|
||||
# description_path= "path/to/description.md"
|
||||
# review_path= "path/to/review.md"
|
||||
|
||||
[gerrit]
|
||||
# endpoint to the gerrit service
|
||||
# url = "ssh://gerrit.example.com:29418"
|
||||
# user for gerrit authentication
|
||||
# user = "ai-reviewer"
|
||||
# patch server where patches will be saved
|
||||
# patch_server_endpoint = "http://127.0.0.1:5000/patch"
|
||||
# token to authenticate in the patch server
|
||||
# patch_server_token = ""
|
||||
|
Reference in New Issue
Block a user