From eeea38dab3e6e1b3f4a03f35997c5578cc3e77fa Mon Sep 17 00:00:00 2001 From: Nikolay Telepenin Date: Fri, 1 Sep 2023 12:24:20 +0100 Subject: [PATCH] Gerrit support --- pr_agent/git_providers/__init__.py | 5 +- pr_agent/git_providers/gerrit_provider.py | 397 ++++++++++++++++++++++ pr_agent/servers/gerrit_server.py | 81 +++++ pr_agent/settings/configuration.toml | 14 +- 4 files changed, 494 insertions(+), 3 deletions(-) create mode 100644 pr_agent/git_providers/gerrit_provider.py create mode 100644 pr_agent/servers/gerrit_server.py diff --git a/pr_agent/git_providers/__init__.py b/pr_agent/git_providers/__init__.py index 376d09f5..968f0dfc 100644 --- a/pr_agent/git_providers/__init__.py +++ b/pr_agent/git_providers/__init__.py @@ -5,6 +5,8 @@ from pr_agent.git_providers.github_provider import GithubProvider from pr_agent.git_providers.gitlab_provider import GitLabProvider from pr_agent.git_providers.local_git_provider import LocalGitProvider from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider +from pr_agent.git_providers.gerrit_provider import GerritProvider + _GIT_PROVIDERS = { 'github': GithubProvider, @@ -12,7 +14,8 @@ _GIT_PROVIDERS = { 'bitbucket': BitbucketProvider, 'azure': AzureDevopsProvider, 'codecommit': CodeCommitProvider, - 'local' : LocalGitProvider + 'local' : LocalGitProvider, + 'gerrit': GerritProvider, } def get_git_provider(): diff --git a/pr_agent/git_providers/gerrit_provider.py b/pr_agent/git_providers/gerrit_provider.py new file mode 100644 index 00000000..d2a80a65 --- /dev/null +++ b/pr_agent/git_providers/gerrit_provider.py @@ -0,0 +1,397 @@ +import json +import logging +import os +import pathlib +import shutil +import subprocess +from collections import Counter, namedtuple +from pathlib import Path +from tempfile import mkdtemp, NamedTemporaryFile + +import requests +import urllib3.util +from git import Repo + +from pr_agent.config_loader import get_settings +from pr_agent.git_providers.git_provider import GitProvider, FilePatchInfo, \ + EDIT_TYPE +from pr_agent.git_providers.local_git_provider import PullRequestMimic + +logger = logging.getLogger(__name__) + + +def _call(*command, **kwargs) -> (int, str, str): + res = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + **kwargs, + ) + return res.stdout.decode() + + +def clone(url, directory): + logger.info("Cloning %s to %s", url, directory) + stdout = _call('git', 'clone', "--depth", "1", url, directory) + logger.info(stdout) + + +def fetch(url, refspec, cwd): + logger.info("Fetching %s %s", url, refspec) + stdout = _call( + 'git', 'fetch', '--depth', '2', url, refspec, + cwd=cwd + ) + logger.info(stdout) + + +def checkout(cwd): + logger.info("Checking out") + stdout = _call('git', 'checkout', "FETCH_HEAD", cwd=cwd) + logger.info(stdout) + + +def show(*args, cwd=None): + logger.info("Show") + return _call('git', 'show', *args, cwd=cwd) + + +def diff(*args, cwd=None): + logger.info("Diff") + patch = _call('git', 'diff', *args, cwd=cwd) + if not patch: + logger.warning("No changes found") + return + return patch + + +def reset_local_changes(cwd): + logger.info("Reset local changes") + _call('git', 'checkout', "--force", cwd=cwd) + + +def add_comment(url: urllib3.util.Url, refspec, message): + *_, patchset, changenum = refspec.rsplit("/") + message = "'" + message.replace("'", "'\"'\"'") + "'" + return _call( + "ssh", + "-p", str(url.port), + f"{url.auth}@{url.host}", + "gerrit", "review", + "--message", message, + # "--code-review", score, + f"{patchset},{changenum}", + ) + + +def list_comments(url: urllib3.util.Url, refspec): + *_, patchset, _ = refspec.rsplit("/") + stdout = _call( + "ssh", + "-p", str(url.port), + f"{url.auth}@{url.host}", + "gerrit", "query", + "--comments", + "--current-patch-set", patchset, + "--format", "JSON", + ) + change_set, *_ = stdout.splitlines() + return json.loads(change_set)["currentPatchSet"]["comments"] + + +def prepare_repo(url: urllib3.util.Url, project, refspec): + repo_url = (f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}") + + directory = pathlib.Path(mkdtemp()) + clone(repo_url, directory), + fetch(repo_url, refspec, cwd=directory) + checkout(cwd=directory) + return directory + + +def adopt_to_gerrit_message(message): + lines = message.splitlines() + buf = [] + for line in lines: + line = line.replace("*", "").replace("``", "`") + line = line.strip() + if line.startswith('#'): + buf.append("\n" + + line.replace('#', '').removesuffix(":").strip() + + ":") + continue + elif line.startswith('-'): + buf.append(line.removeprefix('-').strip()) + continue + else: + buf.append(line) + return "\n".join(buf).strip() + + +def add_suggestion(src_filename, context: str, start, end: int): + with ( + NamedTemporaryFile("w", delete=False) as tmp, + open(src_filename, "r") as src + ): + lines = src.readlines() + tmp.writelines(lines[:start - 1]) + if context: + tmp.write(context) + tmp.writelines(lines[end:]) + + shutil.copy(tmp.name, src_filename) + os.remove(tmp.name) + + +def upload_patch(patch, path): + patch_server_endpoint = get_settings().get( + 'gerrit.patch_server_endpoint') + patch_server_token = get_settings().get( + 'gerrit.patch_server_token') + + response = requests.post( + patch_server_endpoint, + json={ + "content": patch, + "path": path, + }, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {patch_server_token}", + } + ) + response.raise_for_status() + patch_server_endpoint = patch_server_endpoint.rstrip("/") + return patch_server_endpoint + "/" + path + + +class GerritProvider(GitProvider): + + def __init__(self, key: str, incremental=False): + self.project, self.refspec = key.split(':') + assert self.project, "Project name is required" + assert self.refspec, "Refspec is required" + base_url = get_settings().get('gerrit.url') + assert base_url, "Gerrit URL is required" + user = get_settings().get('gerrit.user') + assert user, "Gerrit user is required" + + parsed = urllib3.util.parse_url(base_url) + self.parsed_url = urllib3.util.parse_url( + f"{parsed.scheme}://{user}@{parsed.host}:{parsed.port}" + ) + + self.repo_path = prepare_repo( + self.parsed_url, self.project, self.refspec + ) + self.repo = Repo(self.repo_path) + assert self.repo + + self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files()) + + def get_pr_title(self): + """ + Substitutes the branch-name as the PR-mimic title. + """ + return self.repo.branches[0].name + + def get_issue_comments(self): + # raise NotImplementedError( + # 'Getting comments is not implemented for the gerrit provider' + # ) + # unclear how to get right comments from gerrit + # all user's comments like this "'Patch Set 1: (1 comment)'" + # i.e. it's not clear how to get the comment body + + comments = list_comments(self.parsed_url, self.refspec) + Comments = namedtuple('Comments', ['reversed']) + Comment = namedtuple('Comment', ['body']) + return Comments([Comment(c['message']) for c in reversed(comments)]) + + def get_labels(self): + raise NotImplementedError( + 'Getting labels is not implemented for the gerrit provider') + + def add_eyes_reaction(self, issue_comment_id: int): + raise NotImplementedError( + 'Adding reactions is not implemented for the gerrit provider') + + def remove_reaction(self, issue_comment_id: int, reaction_id: int): + raise NotImplementedError( + 'Removing reactions is not implemented for the gerrit provider') + + def get_commit_messages(self): + return [self.repo.head.commit.message] + + def get_repo_settings(self): + """ + TODO: Implement support of .pr_agent.toml + """ + return "" + + def get_diff_files(self) -> list[FilePatchInfo]: + diffs = self.repo.head.commit.diff( + self.repo.head.commit.parents[0], # previous commit + create_patch=True, + R=True + ) + + diff_files = [] + for diff_item in diffs: + if diff_item.a_blob is not None: + original_file_content_str = ( + diff_item.a_blob.data_stream.read().decode('utf-8') + ) + else: + original_file_content_str = "" # empty file + if diff_item.b_blob is not None: + new_file_content_str = diff_item.b_blob.data_stream.read(). \ + decode('utf-8') + else: + new_file_content_str = "" # empty file + edit_type = EDIT_TYPE.MODIFIED + if diff_item.new_file: + edit_type = EDIT_TYPE.ADDED + elif diff_item.deleted_file: + edit_type = EDIT_TYPE.DELETED + elif diff_item.renamed_file: + edit_type = EDIT_TYPE.RENAMED + diff_files.append( + FilePatchInfo( + original_file_content_str, + new_file_content_str, + diff_item.diff.decode('utf-8'), + diff_item.b_path, + edit_type=edit_type, + old_filename=None + if diff_item.a_path == diff_item.b_path + else diff_item.a_path + ) + ) + self.diff_files = diff_files + return diff_files + + def get_files(self): + diff_index = self.repo.head.commit.diff( + self.repo.head.commit.parents[0], # previous commit + R=True + ) + # Get the list of changed files + diff_files = [item.a_path for item in diff_index] + return diff_files + + def get_languages(self): + """ + Calculate percentage of languages in repository. Used for hunk + prioritisation. + """ + # Get all files in repository + filepaths = [Path(item.path) for item in + self.repo.tree().traverse() if item.type == 'blob'] + # Identify language by file extension and count + lang_count = Counter( + ext.lstrip('.') for filepath in filepaths for ext in + [filepath.suffix.lower()]) + # Convert counts to percentages + total_files = len(filepaths) + lang_percentage = {lang: count / total_files * 100 for lang, count + in lang_count.items()} + return lang_percentage + + def get_pr_description(self): + return self.repo.head.commit.message + + def get_user_id(self): + return self.repo.head.commit.author.email + + def is_supported(self, capability: str) -> bool: + if capability in [ + # 'get_issue_comments', + 'create_inline_comment', + 'publish_inline_comments', + 'get_labels' + ]: + return False + return True + + def split_suggestion(self, msg) -> tuple[str, str]: + is_code_context = False + description = [] + context = [] + for line in msg.splitlines(): + if line.startswith('```suggestion'): + is_code_context = True + continue + if line.startswith('```'): + is_code_context = False + continue + if is_code_context: + context.append(line) + else: + description.append( + line.replace('*', '') + ) + + return ( + '\n'.join(description), + '\n'.join(context) + '\n' if context else '' + ) + + def publish_code_suggestions(self, code_suggestions: list): + msg = [] + for i, suggestion in enumerate(code_suggestions): + description, code = self.split_suggestion(suggestion['body']) + add_suggestion( + pathlib.Path(self.repo_path) / suggestion["relevant_file"], + code, + suggestion["relevant_lines_start"], + suggestion["relevant_lines_end"], + ) + patch = diff(cwd=self.repo_path) + path = "/".join(["codium-ai", self.refspec, str(i)]) + full_path = upload_patch(patch, path) + reset_local_changes(self.repo_path) + msg.append(f'* {description}\n{full_path}') + + if msg: + add_comment(self.parsed_url, self.refspec, "---\n".join(msg)) + + def publish_comment(self, pr_comment: str, is_temporary: bool = False): + if not is_temporary: + msg = adopt_to_gerrit_message(pr_comment) + add_comment(self.parsed_url, self.refspec, msg) + + def publish_description(self, pr_title: str, pr_body: str): + msg = adopt_to_gerrit_message(pr_body) + add_comment(self.parsed_url, self.refspec, pr_title + '\n' + msg) + + def publish_inline_comments(self, comments: list[dict]): + raise NotImplementedError( + 'Publishing inline comments is not implemented for the gerrit ' + 'provider') + + def publish_inline_comment(self, body: str, relevant_file: str, + relevant_line_in_file: str): + raise NotImplementedError( + 'Publishing inline comments is not implemented for the gerrit ' + 'provider') + + def create_inline_comment(self, body: str, relevant_file: str, + relevant_line_in_file: str): + raise NotImplementedError( + 'Creating inline comments is not implemented for the gerrit ' + 'provider') + + def publish_labels(self, labels): + # Not applicable to the local git provider, + # but required by the interface + pass + + def remove_initial_comment(self): + # remove repo, cloned in previous steps + # shutil.rmtree(self.repo_path) + pass + + def get_pr_branch(self): + return self.repo.head diff --git a/pr_agent/servers/gerrit_server.py b/pr_agent/servers/gerrit_server.py new file mode 100644 index 00000000..07ead55a --- /dev/null +++ b/pr_agent/servers/gerrit_server.py @@ -0,0 +1,81 @@ +import copy +import logging +import sys +from enum import Enum +from json import JSONDecodeError + +import uvicorn +from fastapi import APIRouter, FastAPI, HTTPException +from pydantic import BaseModel +from starlette.middleware import Middleware +from starlette_context import context +from starlette_context.middleware import RawContextMiddleware + +from pr_agent.agent.pr_agent import PRAgent +from pr_agent.config_loader import global_settings, get_settings + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +router = APIRouter() + + +class Action(str, Enum): + review = "review" + describe = "describe" + ask = "ask" + improve = "improve" + + +class Item(BaseModel): + refspec: str + project: str + msg: str = None + + +@router.post("/api/v1/gerrit/{action}") +async def handle_gerrit_request(action: Action, item: Item): + logging.debug("Received a Gerrit request") + context["settings"] = copy.deepcopy(global_settings) + + agent = PRAgent() + pr_url = f"{item.project}:{item.refspec}" + if action == Action.review: + await agent.handle_request(pr_url, "/review") + elif action == Action.describe: + await agent.handle_request(pr_url, "/describe") + elif action == Action.improve: + await agent.handle_request(pr_url, "/improve") + elif action == Action.ask: + if not item.msg: + return HTTPException( + status_code=400, + detail="msg is required for ask command" + ) + await agent.handle_request(pr_url, f"/ask {item.msg.strip()}") + + +async def get_body(request): + try: + body = await request.json() + except JSONDecodeError as e: + logging.error("Error parsing request body", e) + return {} + return body + + +@router.get("/") +async def root(): + return {"status": "ok"} + + +def start(): + # to prevent adding help messages with the output + get_settings().set("CONFIG.CLI_MODE", True) + middleware = [Middleware(RawContextMiddleware)] + app = FastAPI(middleware=middleware) + app.include_router(router) + + uvicorn.run(app, host="0.0.0.0", port=3000) + + +if __name__ == '__main__': + start() diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml index f8abd555..4dfc80c8 100644 --- a/pr_agent/settings/configuration.toml +++ b/pr_agent/settings/configuration.toml @@ -1,7 +1,7 @@ [config] model="gpt-4" fallback_models=["gpt-3.5-turbo-16k"] -git_provider="github" +git_provider="gerrit" publish_output=true publish_output_progress=true verbosity_level=0 # 0,1,2 @@ -84,4 +84,14 @@ polling_interval_seconds = 30 [local] # LocalGitProvider settings - uncomment to use paths other than default # description_path= "path/to/description.md" -# review_path= "path/to/review.md" \ No newline at end of file +# review_path= "path/to/review.md" + +[gerrit] +# endpoint to the gerrit service +# url = "ssh://gerrit.example.com:29418" +# user for gerrit authentication +# user = "ai-reviewer" +# patch server where patches will be saved +# patch_server_endpoint = "http://127.0.0.1:5000/patch" +# token to authenticate in the patch server +# patch_server_token = ""