mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-05 05:10:38 +08:00
Merge pull request #48 from Codium-ai/hl/gitlab_fix
Inline suggestion refactor + supporting GitLab
This commit is contained in:
@ -13,6 +13,9 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
|
|||||||
if not patch_str or num_lines == 0:
|
if not patch_str or num_lines == 0:
|
||||||
return patch_str
|
return patch_str
|
||||||
|
|
||||||
|
if type(original_file_str) == bytes:
|
||||||
|
original_file_str = original_file_str.decode('utf-8')
|
||||||
|
|
||||||
original_lines = original_file_str.splitlines()
|
original_lines = original_file_str.splitlines()
|
||||||
patch_lines = patch_str.splitlines()
|
patch_lines = patch_str.splitlines()
|
||||||
extended_patch_lines = []
|
extended_patch_lines = []
|
||||||
|
@ -1,6 +1,13 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
|
||||||
|
from enum import Enum
|
||||||
|
class EDIT_TYPE(Enum):
    """Kind of change a file underwent in a pull/merge request."""

    ADDED = 1      # file is new in the change set
    DELETED = 2    # file was removed
    MODIFIED = 3   # file content changed in place
    RENAMED = 4    # file was moved to a different path
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FilePatchInfo:
|
class FilePatchInfo:
|
||||||
@ -9,6 +16,8 @@ class FilePatchInfo:
|
|||||||
patch: str
|
patch: str
|
||||||
filename: str
|
filename: str
|
||||||
tokens: int = -1
|
tokens: int = -1
|
||||||
|
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
|
||||||
|
old_filename: str = None
|
||||||
|
|
||||||
|
|
||||||
class GitProvider(ABC):
|
class GitProvider(ABC):
|
||||||
@ -24,6 +33,10 @@ class GitProvider(ABC):
|
|||||||
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
    """Publish ``body`` as a comment tied to one line of a changed file.

    Args:
        body: Text of the comment.
        relevant_file: Name of the changed file the comment refers to.
        relevant_line_in_file: Text of the line the comment refers to.

    Concrete providers must override this method; the base implementation
    does nothing and returns ``None``.
    """
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def remove_initial_comment(self):
|
def remove_initial_comment(self):
|
||||||
pass
|
pass
|
||||||
|
@ -18,8 +18,10 @@ class GithubProvider:
|
|||||||
self.pr_num = None
|
self.pr_num = None
|
||||||
self.pr = None
|
self.pr = None
|
||||||
self.github_user_id = None
|
self.github_user_id = None
|
||||||
|
self.diff_files = None
|
||||||
if pr_url:
|
if pr_url:
|
||||||
self.set_pr(pr_url)
|
self.set_pr(pr_url)
|
||||||
|
self.last_commit_id = list(self.pr.get_commits())[-1]
|
||||||
|
|
||||||
def set_pr(self, pr_url: str):
|
def set_pr(self, pr_url: str):
|
||||||
self.repo, self.pr_num = self._parse_pr_url(pr_url)
|
self.repo, self.pr_num = self._parse_pr_url(pr_url)
|
||||||
@ -35,6 +37,7 @@ class GithubProvider:
|
|||||||
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
|
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
|
||||||
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
|
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
|
||||||
diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename))
|
diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename))
|
||||||
|
self.diff_files = diff_files
|
||||||
return diff_files
|
return diff_files
|
||||||
|
|
||||||
def publish_description(self, pr_title: str, pr_body: str):
|
def publish_description(self, pr_title: str, pr_body: str):
|
||||||
@ -50,6 +53,29 @@ class GithubProvider:
|
|||||||
self.pr.comments_list = []
|
self.pr.comments_list = []
|
||||||
self.pr.comments_list.append(response)
|
self.pr.comments_list.append(response)
|
||||||
|
|
||||||
|
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
    """Post ``body`` as a review comment on the patch line matching ``relevant_line_in_file``.

    The patch of the file named ``relevant_file`` is scanned line by line; the
    index of the first matching patch line is passed to
    ``self.pr.create_review_comment`` as ``position``.

    Args:
        body: Text of the review comment.
        relevant_file: File name to look up among ``self.diff_files``.
        relevant_line_in_file: Text to search for inside the file's patch.

    If no patch line matches, nothing is published; a message is logged when
    ``settings.config.verbosity_level`` is at least 2.
    """
    # Reuse the cached diff files when available to avoid fetching them again.
    if not self.diff_files:
        self.diff_files = self.get_diff_files()

    position = -1
    for file in self.diff_files:
        if file.filename.strip() != relevant_file:
            continue
        # A file entry may carry no patch text at all (e.g. no textual diff);
        # treat that as "nothing to search" instead of failing on None.
        for i, line in enumerate((file.patch or "").splitlines()):
            if relevant_line_in_file in line:
                position = i
                break
            if relevant_line_in_file.startswith('+') and relevant_line_in_file[1:] in line:
                # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
                # it's a context line
                position = i
                break
        if position != -1:
            # First match wins; do not let a later entry overwrite it.
            break

    if position == -1:
        if settings.config.verbosity_level >= 2:
            logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
        return

    path = relevant_file.strip()
    self.pr.create_review_comment(body=body, commit_id=self.last_commit_id, path=path, position=position)
|
||||||
|
|
||||||
def remove_initial_comment(self):
|
def remove_initial_comment(self):
|
||||||
try:
|
try:
|
||||||
for comment in self.pr.comments_list:
|
for comment in self.pr.comments_list:
|
||||||
@ -150,9 +176,9 @@ class GithubProvider:
|
|||||||
def _get_pr(self):
|
def _get_pr(self):
|
||||||
return self._get_repo().get_pull(self.pr_num)
|
return self._get_repo().get_pull(self.pr_num)
|
||||||
|
|
||||||
def _get_pr_file_content(self, file: FilePatchInfo, sha: str):
|
def _get_pr_file_content(self, file: FilePatchInfo, sha: str) -> str:
|
||||||
try:
|
try:
|
||||||
file_content_str = self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode()
|
file_content_str = str(self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode())
|
||||||
except Exception:
|
except Exception:
|
||||||
file_content_str = ""
|
file_content_str = ""
|
||||||
return file_content_str
|
return file_content_str
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
@ -6,7 +7,7 @@ import gitlab
|
|||||||
|
|
||||||
from pr_agent.config_loader import settings
|
from pr_agent.config_loader import settings
|
||||||
|
|
||||||
from .git_provider import FilePatchInfo, GitProvider
|
from .git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
|
||||||
|
|
||||||
|
|
||||||
class GitLabProvider(GitProvider):
|
class GitLabProvider(GitProvider):
|
||||||
@ -24,6 +25,7 @@ class GitLabProvider(GitProvider):
|
|||||||
self.id_project = None
|
self.id_project = None
|
||||||
self.id_mr = None
|
self.id_mr = None
|
||||||
self.mr = None
|
self.mr = None
|
||||||
|
self.diff_files = None
|
||||||
self.temp_comments = []
|
self.temp_comments = []
|
||||||
self._set_merge_request(merge_request_url)
|
self._set_merge_request(merge_request_url)
|
||||||
|
|
||||||
@ -35,10 +37,35 @@ class GitLabProvider(GitProvider):
|
|||||||
def _set_merge_request(self, merge_request_url: str):
|
def _set_merge_request(self, merge_request_url: str):
|
||||||
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
|
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
|
||||||
self.mr = self._get_merge_request()
|
self.mr = self._get_merge_request()
|
||||||
|
self.last_diff = self.mr.diffs.list()[-1]
|
||||||
|
|
||||||
|
def _get_pr_file_content(self, file_path: str, branch: str) -> bytes:
    """Fetch the raw content of ``file_path`` as stored on ``branch``.

    Args:
        file_path: Path of the file inside the project.
        branch: Ref to read the file from.

    Returns:
        The decoded file content as bytes, or ``b''`` when the file does not
        exist on ``branch`` (the case for the old side of an added file or
        the new side of a deleted file).

    Raises:
        gitlab.exceptions.GitlabGetError: For any failure other than 404.
    """
    try:
        return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode()
    except gitlab.exceptions.GitlabGetError as err:
        # Only "not found" means the file is absent on this ref; anything
        # else (permissions, server errors) must stay visible to the caller.
        if getattr(err, 'response_code', None) != 404:
            raise
        return b''
|
||||||
|
|
||||||
def get_diff_files(self) -> list[FilePatchInfo]:
|
def get_diff_files(self) -> list[FilePatchInfo]:
|
||||||
diffs = self.mr.changes()['changes']
|
diffs = self.mr.changes()['changes']
|
||||||
diff_files = [FilePatchInfo("", "", diff['diff'], diff['new_path']) for diff in diffs]
|
diff_files = []
|
||||||
|
for diff in diffs:
|
||||||
|
original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.target_branch)
|
||||||
|
new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.source_branch)
|
||||||
|
edit_type = EDIT_TYPE.MODIFIED
|
||||||
|
if diff['new_file']:
|
||||||
|
edit_type = EDIT_TYPE.ADDED
|
||||||
|
elif diff['deleted_file']:
|
||||||
|
edit_type = EDIT_TYPE.DELETED
|
||||||
|
elif diff['renamed_file']:
|
||||||
|
edit_type = EDIT_TYPE.RENAMED
|
||||||
|
try:
|
||||||
|
original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
|
||||||
|
new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
logging.warning(
|
||||||
|
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")
|
||||||
|
diff_files.append(
|
||||||
|
FilePatchInfo(original_file_content_str, new_file_content_str, diff['diff'], diff['new_path'],
|
||||||
|
edit_type=edit_type,
|
||||||
|
old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path']))
|
||||||
|
self.diff_files = diff_files
|
||||||
return diff_files
|
return diff_files
|
||||||
|
|
||||||
def get_files(self):
|
def get_files(self):
|
||||||
@ -53,6 +80,81 @@ class GitLabProvider(GitProvider):
|
|||||||
if is_temporary:
|
if is_temporary:
|
||||||
self.temp_comments.append(comment)
|
self.temp_comments.append(comment)
|
||||||
|
|
||||||
|
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
    """Open a merge-request discussion on the diff line matching ``relevant_line_in_file``.

    The line is located with ``self.search_line``; the resulting old/new line
    numbers and the SHAs of ``self.last_diff`` are sent as the discussion's
    ``position``.

    Args:
        body: Text of the comment.
        relevant_file: File name to look up among ``self.diff_files``.
        relevant_line_in_file: Text of the diff line to attach the comment to.

    If the line cannot be found, nothing is published and a message is logged.
    """
    # Reuse the cached diff files when available to avoid fetching them again.
    if not self.diff_files:
        self.diff_files = self.get_diff_files()

    edit_type, found, source_line_no, target_file, target_line_no = self.search_line(
        relevant_file, relevant_line_in_file)
    if not found:
        logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
        return

    # search_line advances its counters before testing a line for a match,
    # so the matched line's own number is the counter minus one.
    old_line = source_line_no - 1
    new_line = target_line_no - 1

    d = self.last_diff
    pos_obj = {'position_type': 'text',
               'new_path': target_file.filename,
               'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
               'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha}
    if edit_type == 'deletion':
        pos_obj['old_line'] = old_line
    elif edit_type == 'addition':
        pos_obj['new_line'] = new_line
    else:
        # A context line exists on both sides, and its number generally
        # differs between the old and the new file.
        pos_obj['new_line'] = new_line
        pos_obj['old_line'] = old_line
    self.mr.discussions.create({'body': body,
                                'position': pos_obj})
|
||||||
|
|
||||||
|
def search_line(self, relevant_file, relevant_line_in_file):
    """Locate ``relevant_line_in_file`` inside the patch of ``relevant_file``.

    Walks the patch of every entry in ``self.diff_files`` whose filename
    equals ``relevant_file``, tracking old-side and new-side line counters
    from the hunk headers and the leading ``-`` / ``+`` / space markers.

    Returns:
        A tuple ``(edit_type, found, source_line_no, target_file, target_line_no)``.
        ``edit_type`` is classified from the matched patch line when found,
        otherwise from ``relevant_line_in_file``. The counters have already
        been advanced for the matched line, so the line's own number is the
        counter minus one on the side(s) the line exists on. ``target_file``
        is the last entry whose filename matched, or ``None``.
    """
    hunk_header = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
    old_no = 0
    new_no = 0
    matched_file = None
    is_found = False
    kind = self.get_edit_type(relevant_line_in_file)

    for candidate in self.diff_files:
        if candidate.filename != relevant_file:
            continue
        matched_file = candidate
        for patch_line in candidate.patch.splitlines():
            if patch_line.startswith('@@'):
                header = hunk_header.match(patch_line)
                if header:
                    # Restart both counters at the hunk's first line numbers.
                    old_no = int(header.group(1))
                    new_no = int(header.group(3))
                continue

            marker = patch_line[:1]
            if marker == '-':
                old_no += 1
            elif marker == '+':
                new_no += 1
            elif marker == ' ':
                old_no += 1
                new_no += 1

            # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
            # it's a context line, hence the second, prefix-less comparison.
            if relevant_line_in_file in patch_line or (
                    relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in patch_line):
                is_found = True
                kind = self.get_edit_type(patch_line)
                break

    return kind, is_found, old_no, matched_file, new_no
|
||||||
|
|
||||||
|
def get_edit_type(self, relevant_line_in_file):
    """Classify a diff line by its leading marker.

    Args:
        relevant_line_in_file: A line of text, possibly empty.

    Returns:
        ``'deletion'`` if the line starts with ``-``, ``'addition'`` if it
        starts with ``+``, and ``'context'`` otherwise (including for an
        empty string).
    """
    # startswith() is safe on an empty string, unlike indexing [0].
    if relevant_line_in_file.startswith('-'):
        return 'deletion'
    if relevant_line_in_file.startswith('+'):
        return 'addition'
    return 'context'
|
||||||
|
|
||||||
def remove_initial_comment(self):
|
def remove_initial_comment(self):
|
||||||
try:
|
try:
|
||||||
for comment in self.temp_comments:
|
for comment in self.temp_comments:
|
||||||
@ -94,3 +196,6 @@ class GitLabProvider(GitProvider):
|
|||||||
def _get_merge_request(self):
|
def _get_merge_request(self):
|
||||||
mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
|
mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
|
||||||
return mr
|
return mr
|
||||||
|
|
||||||
|
def get_user_id(self):
    """Return the user id for this provider; this implementation always yields ``None``."""
    return None
|
||||||
|
@ -111,39 +111,15 @@ class PRReviewer:
|
|||||||
return markdown_text
|
return markdown_text
|
||||||
|
|
||||||
def _publish_inline_code_comments(self):
|
def _publish_inline_code_comments(self):
|
||||||
if settings.config.git_provider != 'github': # inline comments are currently only supported for github
|
|
||||||
return
|
|
||||||
|
|
||||||
review = self.prediction.strip()
|
review = self.prediction.strip()
|
||||||
try:
|
try:
|
||||||
data = json.loads(review)
|
data = json.loads(review)
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
data = try_fix_json(review)
|
data = try_fix_json(review)
|
||||||
|
|
||||||
pr = self.git_provider.pr
|
|
||||||
last_commit_id = list(pr.get_commits())[-1]
|
|
||||||
if hasattr(pr, 'diff_files'): # prevent bringing all the files again
|
|
||||||
diff_files = pr.diff_files
|
|
||||||
else:
|
|
||||||
diff_files = list(self.git_provider.get_diff_files())
|
|
||||||
|
|
||||||
for d in data['PR Feedback']['Code suggestions']:
|
for d in data['PR Feedback']['Code suggestions']:
|
||||||
relevant_file = d['relevant file'].strip()
|
relevant_file = d['relevant file'].strip()
|
||||||
relevant_line_in_file = d['relevant line in file'].strip()
|
relevant_line_in_file = d['relevant line in file'].strip()
|
||||||
content = d['suggestion content']
|
content = d['suggestion content']
|
||||||
position = -1
|
|
||||||
for file in diff_files:
|
self.git_provider.publish_inline_comment(content, relevant_file, relevant_line_in_file)
|
||||||
if file.filename.strip() == relevant_file:
|
|
||||||
patch = file.patch
|
|
||||||
patch_lines = patch.splitlines()
|
|
||||||
for i, line in enumerate(patch_lines):
|
|
||||||
if relevant_line_in_file in line:
|
|
||||||
position = i
|
|
||||||
break
|
|
||||||
if position == -1:
|
|
||||||
if settings.config.verbosity_level >= 2:
|
|
||||||
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
|
|
||||||
else:
|
|
||||||
body = content
|
|
||||||
path = relevant_file.strip()
|
|
||||||
pr.create_review_comment(body=body, commit_id=last_commit_id, path=path, position=position)
|
|
Reference in New Issue
Block a user