mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
Merge pull request #48 from Codium-ai/hl/gitlab_fix
Inline suggestion refactor + supporting GitLab
This commit is contained in:
@ -13,6 +13,9 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
||||
if not patch_str or num_lines == 0:
|
||||
return patch_str
|
||||
|
||||
if type(original_file_str) == bytes:
|
||||
original_file_str = original_file_str.decode('utf-8')
|
||||
|
||||
original_lines = original_file_str.splitlines()
|
||||
patch_lines = patch_str.splitlines()
|
||||
extended_patch_lines = []
|
||||
|
@ -1,6 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
|
||||
from enum import Enum
|
||||
class EDIT_TYPE(Enum):
|
||||
ADDED = 1
|
||||
DELETED = 2
|
||||
MODIFIED = 3
|
||||
RENAMED = 4
|
||||
|
||||
@dataclass
|
||||
class FilePatchInfo:
|
||||
@ -9,6 +16,8 @@ class FilePatchInfo:
|
||||
patch: str
|
||||
filename: str
|
||||
tokens: int = -1
|
||||
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
|
||||
old_filename: str = None
|
||||
|
||||
|
||||
class GitProvider(ABC):
|
||||
@ -24,6 +33,10 @@ class GitProvider(ABC):
|
||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_initial_comment(self):
|
||||
pass
|
||||
|
@ -18,8 +18,10 @@ class GithubProvider:
|
||||
self.pr_num = None
|
||||
self.pr = None
|
||||
self.github_user_id = None
|
||||
self.diff_files = None
|
||||
if pr_url:
|
||||
self.set_pr(pr_url)
|
||||
self.last_commit_id = list(self.pr.get_commits())[-1]
|
||||
|
||||
def set_pr(self, pr_url: str):
|
||||
self.repo, self.pr_num = self._parse_pr_url(pr_url)
|
||||
@ -35,6 +37,7 @@ class GithubProvider:
|
||||
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
|
||||
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
|
||||
diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename))
|
||||
self.diff_files = diff_files
|
||||
return diff_files
|
||||
|
||||
def publish_description(self, pr_title: str, pr_body: str):
|
||||
@ -50,6 +53,29 @@ class GithubProvider:
|
||||
self.pr.comments_list = []
|
||||
self.pr.comments_list.append(response)
|
||||
|
||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
|
||||
self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
|
||||
position = -1
|
||||
for file in self.diff_files:
|
||||
if file.filename.strip() == relevant_file:
|
||||
patch = file.patch
|
||||
patch_lines = patch.splitlines()
|
||||
for i, line in enumerate(patch_lines):
|
||||
if relevant_line_in_file in line:
|
||||
position = i
|
||||
break
|
||||
elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line:
|
||||
# The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
|
||||
# it's a context line
|
||||
position = i
|
||||
break
|
||||
if position == -1:
|
||||
if settings.config.verbosity_level >= 2:
|
||||
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
||||
else:
|
||||
path = relevant_file.strip()
|
||||
self.pr.create_review_comment(body=body, commit_id=self.last_commit_id, path=path, position=position)
|
||||
|
||||
def remove_initial_comment(self):
|
||||
try:
|
||||
for comment in self.pr.comments_list:
|
||||
@ -150,9 +176,9 @@ class GithubProvider:
|
||||
def _get_pr(self):
|
||||
return self._get_repo().get_pull(self.pr_num)
|
||||
|
||||
def _get_pr_file_content(self, file: FilePatchInfo, sha: str):
|
||||
def _get_pr_file_content(self, file: FilePatchInfo, sha: str) -> str:
|
||||
try:
|
||||
file_content_str = self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode()
|
||||
file_content_str = str(self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode())
|
||||
except Exception:
|
||||
file_content_str = ""
|
||||
return file_content_str
|
||||
|
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@ -6,7 +7,7 @@ import gitlab
|
||||
|
||||
from pr_agent.config_loader import settings
|
||||
|
||||
from .git_provider import FilePatchInfo, GitProvider
|
||||
from .git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
|
||||
|
||||
|
||||
class GitLabProvider(GitProvider):
|
||||
@ -24,6 +25,7 @@ class GitLabProvider(GitProvider):
|
||||
self.id_project = None
|
||||
self.id_mr = None
|
||||
self.mr = None
|
||||
self.diff_files = None
|
||||
self.temp_comments = []
|
||||
self._set_merge_request(merge_request_url)
|
||||
|
||||
@ -35,10 +37,35 @@ class GitLabProvider(GitProvider):
|
||||
def _set_merge_request(self, merge_request_url: str):
|
||||
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
|
||||
self.mr = self._get_merge_request()
|
||||
self.last_diff = self.mr.diffs.list()[-1]
|
||||
|
||||
def _get_pr_file_content(self, file_path: str, branch: str) -> str:
|
||||
return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode()
|
||||
|
||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||
diffs = self.mr.changes()['changes']
|
||||
diff_files = [FilePatchInfo("", "", diff['diff'], diff['new_path']) for diff in diffs]
|
||||
diff_files = []
|
||||
for diff in diffs:
|
||||
original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.target_branch)
|
||||
new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.source_branch)
|
||||
edit_type = EDIT_TYPE.MODIFIED
|
||||
if diff['new_file']:
|
||||
edit_type = EDIT_TYPE.ADDED
|
||||
elif diff['deleted_file']:
|
||||
edit_type = EDIT_TYPE.DELETED
|
||||
elif diff['renamed_file']:
|
||||
edit_type = EDIT_TYPE.RENAMED
|
||||
try:
|
||||
original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
|
||||
new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
|
||||
except UnicodeDecodeError:
|
||||
logging.warning(
|
||||
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")
|
||||
diff_files.append(
|
||||
FilePatchInfo(original_file_content_str, new_file_content_str, diff['diff'], diff['new_path'],
|
||||
edit_type=edit_type,
|
||||
old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path']))
|
||||
self.diff_files = diff_files
|
||||
return diff_files
|
||||
|
||||
def get_files(self):
|
||||
@ -53,6 +80,81 @@ class GitLabProvider(GitProvider):
|
||||
if is_temporary:
|
||||
self.temp_comments.append(comment)
|
||||
|
||||
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
|
||||
self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
|
||||
edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
|
||||
relevant_line_in_file)
|
||||
if not found:
|
||||
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
||||
else:
|
||||
if edit_type == 'addition':
|
||||
position = target_line_no - 1
|
||||
else:
|
||||
position = source_line_no - 1
|
||||
d = self.last_diff
|
||||
pos_obj = {'position_type': 'text',
|
||||
'new_path': target_file.filename,
|
||||
'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
|
||||
'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha}
|
||||
if edit_type == 'deletion':
|
||||
pos_obj['old_line'] = position
|
||||
elif edit_type == 'addition':
|
||||
pos_obj['new_line'] = position
|
||||
else:
|
||||
pos_obj['new_line'] = position
|
||||
pos_obj['old_line'] = position
|
||||
self.mr.discussions.create({'body': body,
|
||||
'position': pos_obj})
|
||||
|
||||
def search_line(self, relevant_file, relevant_line_in_file):
|
||||
RE_HUNK_HEADER = re.compile(
|
||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||
target_file = None
|
||||
source_line_no = 0
|
||||
target_line_no = 0
|
||||
found = False
|
||||
edit_type = self.get_edit_type(relevant_line_in_file)
|
||||
for file in self.diff_files:
|
||||
if file.filename == relevant_file:
|
||||
target_file = file
|
||||
patch = file.patch
|
||||
patch_lines = patch.splitlines()
|
||||
for i, line in enumerate(patch_lines):
|
||||
if line.startswith('@@'):
|
||||
match = RE_HUNK_HEADER.match(line)
|
||||
if not match:
|
||||
continue
|
||||
start_old, size_old, start_new, size_new, _ = match.groups()
|
||||
source_line_no = int(start_old)
|
||||
target_line_no = int(start_new)
|
||||
continue
|
||||
if line.startswith('-'):
|
||||
source_line_no += 1
|
||||
elif line.startswith('+'):
|
||||
target_line_no += 1
|
||||
elif line.startswith(' '):
|
||||
source_line_no += 1
|
||||
target_line_no += 1
|
||||
if relevant_line_in_file in line:
|
||||
found = True
|
||||
edit_type = self.get_edit_type(line)
|
||||
break
|
||||
elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line:
|
||||
# The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
|
||||
# it's a context line
|
||||
found = True
|
||||
edit_type = self.get_edit_type(line)
|
||||
break
|
||||
return edit_type, found, source_line_no, target_file, target_line_no
|
||||
|
||||
def get_edit_type(self, relevant_line_in_file):
|
||||
edit_type = 'context'
|
||||
if relevant_line_in_file[0] == '-':
|
||||
edit_type = 'deletion'
|
||||
elif relevant_line_in_file[0] == '+':
|
||||
edit_type = 'addition'
|
||||
return edit_type
|
||||
|
||||
def remove_initial_comment(self):
|
||||
try:
|
||||
for comment in self.temp_comments:
|
||||
@ -94,3 +196,6 @@ class GitLabProvider(GitProvider):
|
||||
def _get_merge_request(self):
|
||||
mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
|
||||
return mr
|
||||
|
||||
def get_user_id(self):
|
||||
return None
|
||||
|
@ -111,39 +111,15 @@ class PRReviewer:
|
||||
return markdown_text
|
||||
|
||||
def _publish_inline_code_comments(self):
|
||||
if settings.config.git_provider != 'github': # inline comments are currently only supported for github
|
||||
return
|
||||
|
||||
review = self.prediction.strip()
|
||||
try:
|
||||
data = json.loads(review)
|
||||
except json.decoder.JSONDecodeError:
|
||||
data = try_fix_json(review)
|
||||
|
||||
pr = self.git_provider.pr
|
||||
last_commit_id = list(pr.get_commits())[-1]
|
||||
if hasattr(pr, 'diff_files'): # prevent bringing all the files again
|
||||
diff_files = pr.diff_files
|
||||
else:
|
||||
diff_files = list(self.git_provider.get_diff_files())
|
||||
|
||||
for d in data['PR Feedback']['Code suggestions']:
|
||||
relevant_file = d['relevant file'].strip()
|
||||
relevant_line_in_file = d['relevant line in file'].strip()
|
||||
content = d['suggestion content']
|
||||
position = -1
|
||||
for file in diff_files:
|
||||
if file.filename.strip() == relevant_file:
|
||||
patch = file.patch
|
||||
patch_lines = patch.splitlines()
|
||||
for i, line in enumerate(patch_lines):
|
||||
if relevant_line_in_file in line:
|
||||
position = i
|
||||
break
|
||||
if position == -1:
|
||||
if settings.config.verbosity_level >= 2:
|
||||
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
||||
else:
|
||||
body = content
|
||||
path = relevant_file.strip()
|
||||
pr.create_review_comment(body=body, commit_id=last_commit_id, path=path, position=position)
|
||||
|
||||
self.git_provider.publish_inline_comment(content, relevant_file, relevant_line_in_file)
|
Reference in New Issue
Block a user