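"""
Gerrit provider for pr-agent (descriptive summary of this module).

Wraps a Gerrit change in the GitProvider interface: the change is fetched into a
temporary clone, diffs are computed locally with GitPython, and comments are
published back over ssh with ``gerrit review`` / ``gerrit query``. Code
suggestions are uploaded as patches to a configurable patch server.
"""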
import json
import os
import pathlib
import shutil
import subprocess
import uuid
from collections import Counter, namedtuple
from pathlib import Path
from tempfile import NamedTemporaryFile, mkdtemp

import requests
import urllib3.util
from git import Repo

from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider
from pr_agent.git_providers.local_git_provider import PullRequestMimic
from pr_agent.log import get_logger


def _call(*command, **kwargs) -> str:
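    """
    Run a command in a subprocess and return its decoded stdout.
    Raises subprocess.CalledProcessError if the command exits with a non-zero status.
    """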
    res = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=True,
        **kwargs,
    )
    return res.stdout.decode()


def clone(url, directory):
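    """Shallow-clone (depth 1) the repository at ``url`` into ``directory``."""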
    get_logger().info("Cloning %s to %s", url, directory)
    stdout = _call('git', 'clone', "--depth", "1", url, directory)
    get_logger().info(stdout)


def fetch(url, refspec, cwd):
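    """
    Fetch ``refspec`` from ``url`` into the repository at ``cwd``
    (depth 2, so the parent commit is available for diffing).
    """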
    get_logger().info("Fetching %s %s", url, refspec)
    stdout = _call(
        'git', 'fetch', '--depth', '2', url, refspec,
        cwd=cwd
    )
    get_logger().info(stdout)


def checkout(cwd):
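    """Check out FETCH_HEAD, i.e. the change fetched by ``fetch``."""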
    get_logger().info("Checking out")
    stdout = _call('git', 'checkout', "FETCH_HEAD", cwd=cwd)
    get_logger().info(stdout)


def show(*args, cwd=None):
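    """Thin wrapper around ``git show``; returns its stdout."""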
    get_logger().info("Show")
    return _call('git', 'show', *args, cwd=cwd)


def diff(*args, cwd=None):
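    """Run ``git diff`` and return the patch text, or None (with a warning) if there are no changes."""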
    get_logger().info("Diff")
    patch = _call('git', 'diff', *args, cwd=cwd)
    if not patch:
        get_logger().warning("No changes found")
        return
    return patch


def reset_local_changes(cwd):
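    """Discard local modifications in the working tree via ``git checkout --force``."""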
    get_logger().info("Reset local changes")
    _call('git', 'checkout', "--force", cwd=cwd)


def add_comment(url: urllib3.util.Url, refspec, message):
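    """
    Post ``message`` as a review comment on the change via ``gerrit review`` over ssh.

    The target change/patch-set pair is derived from the last two components of
    the refspec, and the message is single-quoted so it survives the remote shell.
    """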
    *_, patchset, changenum = refspec.rsplit("/")
    message = "'" + message.replace("'", "'\"'\"'") + "'"
    return _call(
        "ssh",
        "-p", str(url.port),
        f"{url.auth}@{url.host}",
        "gerrit", "review",
        "--message", message,
        # "--code-review", score,
        f"{patchset},{changenum}",
    )


def list_comments(url: urllib3.util.Url, refspec):
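    """
    Return the comments on the current patch set, read from
    ``gerrit query --comments --format JSON`` over ssh.
    """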
    *_, patchset, _ = refspec.rsplit("/")
    stdout = _call(
        "ssh",
        "-p", str(url.port),
        f"{url.auth}@{url.host}",
        "gerrit", "query",
        "--comments",
        "--current-patch-set", patchset,
        "--format", "JSON",
    )
    change_set, *_ = stdout.splitlines()
    return json.loads(change_set)["currentPatchSet"]["comments"]


def prepare_repo(url: urllib3.util.Url, project, refspec):
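    """
    Clone ``project`` into a temporary directory, fetch ``refspec`` and check it
    out; returns the path to the checkout.
    """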
    repo_url = (f"{url.scheme}://{url.auth}@{url.host}:{url.port}/{project}")

    directory = pathlib.Path(mkdtemp())
    clone(repo_url, directory)
    fetch(repo_url, refspec, cwd=directory)
    checkout(cwd=directory)
    return directory


def adopt_to_gerrit_message(message):
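    """
    Convert a markdown comment into plain text that reads well in Gerrit:
    strips emphasis and <details>/<summary> tags, turns headers into
    "Title:" lines and bullet items into plain lines.
    """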
    lines = message.splitlines()
    buf = []
    for line in lines:
        # remove markdown formatting
        line = (line.replace("*", "")
                .replace("``", "`")
                .replace("<details>", "")
                .replace("</details>", "")
                .replace("<summary>", "")
                .replace("</summary>", ""))

        line = line.strip()
        if line.startswith('#'):
            buf.append("\n" +
                       line.replace('#', '').removesuffix(":").strip() +
                       ":")
            continue
        elif line.startswith('-'):
            buf.append(line.removeprefix('-').strip())
            continue
        else:
            buf.append(line)
    return "\n".join(buf).strip()


def add_suggestion(src_filename, context: str, start, end: int):
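    """
    Replace lines ``start``..``end`` (1-based, inclusive) of ``src_filename``
    with ``context`` (the code taken from a suggestion block), editing the file in place.
    """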
    with (
        NamedTemporaryFile("w", delete=False) as tmp,
        open(src_filename, "r") as src
    ):
        lines = src.readlines()
        tmp.writelines(lines[:start - 1])
        if context:
            tmp.write(context)
        tmp.writelines(lines[end:])

    shutil.copy(tmp.name, src_filename)
    os.remove(tmp.name)


def upload_patch(patch, path):
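    """
    POST the patch to the patch server configured under ``gerrit.patch_server_endpoint``
    (authenticated with ``gerrit.patch_server_token``) and return the URL it was stored at.
    """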
    patch_server_endpoint = get_settings().get(
        'gerrit.patch_server_endpoint')
    patch_server_token = get_settings().get(
        'gerrit.patch_server_token')

    response = requests.post(
        patch_server_endpoint,
        json={
            "content": patch,
            "path": path,
        },
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {patch_server_token}",
        }
    )
    response.raise_for_status()
    patch_server_endpoint = patch_server_endpoint.rstrip("/")
    return patch_server_endpoint + "/" + path


class GerritProvider(GitProvider):
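    """
    Git provider backed by a Gerrit change.

    ``key`` is ``"<project>:<refspec>"``; the change is cloned into a temporary
    directory and wrapped in a PullRequestMimic so the rest of pr-agent can
    treat it like a pull request.
    """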

    def __init__(self, key: str, incremental=False):
        self.project, self.refspec = key.split(':')
        assert self.project, "Project name is required"
        assert self.refspec, "Refspec is required"
        base_url = get_settings().get('gerrit.url')
        assert base_url, "Gerrit URL is required"
        user = get_settings().get('gerrit.user')
        assert user, "Gerrit user is required"

        parsed = urllib3.util.parse_url(base_url)
        self.parsed_url = urllib3.util.parse_url(
            f"{parsed.scheme}://{user}@{parsed.host}:{parsed.port}"
        )

        self.repo_path = prepare_repo(
            self.parsed_url, self.project, self.refspec
        )
        self.repo = Repo(self.repo_path)
        assert self.repo
        self.pr_url = base_url
        self.pr = PullRequestMimic(self.get_pr_title(), self.get_diff_files())

    def get_pr_title(self):
        """
        Uses the branch name as the PR-mimic title.
        """
        return self.repo.branches[0].name

    def get_issue_comments(self):
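        """
        Return past Gerrit comments in reverse order, wrapped in lightweight
        namedtuples (``.reversed`` holding ``.body`` items) to mimic the shape
        other providers return.
        """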
        comments = list_comments(self.parsed_url, self.refspec)
        Comments = namedtuple('Comments', ['reversed'])
        Comment = namedtuple('Comment', ['body'])
        return Comments([Comment(c['message']) for c in reversed(comments)])

    def get_pr_labels(self, update=False):
        raise NotImplementedError(
            'Getting labels is not implemented for the gerrit provider')

    def add_eyes_reaction(self, issue_comment_id: int, disable_eyes: bool = False):
        raise NotImplementedError(
            'Adding reactions is not implemented for the gerrit provider')

    def remove_reaction(self, issue_comment_id: int, reaction_id: int):
        raise NotImplementedError(
            'Removing reactions is not implemented for the gerrit provider')

    def get_commit_messages(self):
        return [self.repo.head.commit.message]

    def get_repo_settings(self):
        try:
            with open(self.repo_path / ".pr_agent.toml", 'rb') as f:
                contents = f.read()
            return contents
        except OSError:
            return b""

    def get_diff_files(self) -> list[FilePatchInfo]:
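        """
        Build FilePatchInfo objects by diffing HEAD against its parent commit
        (``R=True`` reverses the diff so it reads parent -> HEAD).
        """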
        diffs = self.repo.head.commit.diff(
            self.repo.head.commit.parents[0],  # previous commit
            create_patch=True,
            R=True
        )

        diff_files = []
        for diff_item in diffs:
            if diff_item.a_blob is not None:
                original_file_content_str = (
                    diff_item.a_blob.data_stream.read().decode('utf-8')
                )
            else:
                original_file_content_str = ""  # empty file
            if diff_item.b_blob is not None:
                new_file_content_str = diff_item.b_blob.data_stream.read().decode('utf-8')
            else:
                new_file_content_str = ""  # empty file
            edit_type = EDIT_TYPE.MODIFIED
            if diff_item.new_file:
                edit_type = EDIT_TYPE.ADDED
            elif diff_item.deleted_file:
                edit_type = EDIT_TYPE.DELETED
            elif diff_item.renamed_file:
                edit_type = EDIT_TYPE.RENAMED
            diff_files.append(
                FilePatchInfo(
                    original_file_content_str,
                    new_file_content_str,
                    diff_item.diff.decode('utf-8'),
                    diff_item.b_path,
                    edit_type=edit_type,
                    old_filename=None
                    if diff_item.a_path == diff_item.b_path
                    else diff_item.a_path
                )
            )
        self.diff_files = diff_files
        return diff_files

    def get_files(self):
        diff_index = self.repo.head.commit.diff(
            self.repo.head.commit.parents[0],  # previous commit
            R=True
        )
        # Get the list of changed files
        diff_files = [item.a_path for item in diff_index]
        return diff_files

    def get_languages(self):
        """
        Calculate percentage of languages in repository. Used for hunk
        prioritisation.
        """
        # Get all files in repository
        filepaths = [Path(item.path) for item in
                     self.repo.tree().traverse() if item.type == 'blob']
        # Identify language by file extension and count
        lang_count = Counter(
            ext.lstrip('.') for filepath in filepaths for ext in
            [filepath.suffix.lower()])
        # Convert counts to percentages
        total_files = len(filepaths)
        lang_percentage = {lang: count / total_files * 100 for lang, count
                           in lang_count.items()}
        return lang_percentage

    def get_pr_description_full(self):
        return self.repo.head.commit.message

    def get_user_id(self):
        return self.repo.head.commit.author.email

    def is_supported(self, capability: str) -> bool:
        if capability in [
            # 'get_issue_comments',
            'create_inline_comment',
            'publish_inline_comments',
            'get_labels',
            'gfm_markdown'
        ]:
            return False
        return True

    def split_suggestion(self, msg) -> tuple[str, str]:
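        """
        Split a suggestion comment into (description, code): everything inside a
        ```suggestion fenced block becomes the code part, the rest the description.
        """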
        is_code_context = False
        description = []
        context = []
        for line in msg.splitlines():
            if line.startswith('```suggestion'):
                is_code_context = True
                continue
            if line.startswith('```'):
                is_code_context = False
                continue
            if is_code_context:
                context.append(line)
            else:
                description.append(
                    line.replace('*', '')
                )

        return (
            '\n'.join(description),
            '\n'.join(context) + '\n' if context else ''
        )

    def publish_code_suggestions(self, code_suggestions: list):
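        """
        Apply each suggestion to the local checkout, upload the resulting patch to
        the patch server, reset the working tree, and post one Gerrit comment that
        lists every suggestion with a link to its patch.
        """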
        msg = []
        for suggestion in code_suggestions:
            description, code = self.split_suggestion(suggestion['body'])
            add_suggestion(
                pathlib.Path(self.repo_path) / suggestion["relevant_file"],
                code,
                suggestion["relevant_lines_start"],
                suggestion["relevant_lines_end"],
            )
            patch = diff(cwd=self.repo_path)
            patch_id = uuid.uuid4().hex[0:4]
            path = "/".join(["codium-ai", self.refspec, patch_id])
            full_path = upload_patch(patch, path)
            reset_local_changes(self.repo_path)
            msg.append(f'* {description}\n{full_path}')

        if msg:
            add_comment(self.parsed_url, self.refspec, "\n".join(msg))
            return True

    def publish_comment(self, pr_comment: str, is_temporary: bool = False):
        if not is_temporary:
            msg = adopt_to_gerrit_message(pr_comment)
            add_comment(self.parsed_url, self.refspec, msg)

    def publish_description(self, pr_title: str, pr_body: str):
        msg = adopt_to_gerrit_message(pr_body)
        add_comment(self.parsed_url, self.refspec, pr_title + '\n' + msg)

    def publish_inline_comments(self, comments: list[dict]):
        raise NotImplementedError(
            'Publishing inline comments is not implemented for the gerrit '
            'provider')

    def publish_inline_comment(self, body: str, relevant_file: str,
                               relevant_line_in_file: str, original_suggestion=None):
        raise NotImplementedError(
            'Publishing inline comments is not implemented for the gerrit '
            'provider')

    def publish_labels(self, labels):
        # Not applicable to the gerrit provider,
        # but required by the interface
        pass

    def remove_initial_comment(self):
        # remove repo, cloned in previous steps
        # shutil.rmtree(self.repo_path)
        pass

    def remove_comment(self, comment):
        pass

    def get_pr_branch(self):
        return self.repo.head