Mirror of https://github.com/qodo-ai/pr-agent.git (synced 2025-07-05 05:10:38 +08:00)

Commit: Support Code Suggestion in Gitlab
@@ -77,7 +77,7 @@ To set up your own PR-Agent, see the [Quickstart](#Quickstart) section
 | | ⮑ Inline review | ✓ | ✓ | |
 | | Ask | ✓ | ✓ | |
 | | Auto-Description | ✓ | | |
-| | Improve Code | ✓ | | |
+| | Improve Code | ✓ | ✓ | |
 | | | | | |
 | USAGE | CLI | ✓ | ✓ | ✓ |
 | | Tagging bot | ✓ | ✓ | |
@@ -61,18 +61,24 @@ def parse_code_suggestion(code_suggestions: dict) -> str:
     return markdown_text


-def try_fix_json(review, max_iter=10):
+def try_fix_json(review, max_iter=10, code_suggestions=False):
+    if review.endswith("}"):
+        return fix_json_escape_char(review)
     # Try to fix JSON if it is broken/incomplete: parse until the last valid code suggestion
     data = {}
+    if code_suggestions:
+        closing_bracket = "]}"
+    else:
+        closing_bracket = "]}}"
     if review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0:
         last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
         valid_json = False
         iter_count = 0
         while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter:
             try:
-                data = json.loads(review[:last_code_suggestion_ind] + "]}}")
+                data = json.loads(review[:last_code_suggestion_ind] + closing_bracket)
                 valid_json = True
-                review = review[:last_code_suggestion_ind].strip() + "]}}"
+                review = review[:last_code_suggestion_ind].strip() + closing_bracket
             except json.decoder.JSONDecodeError:
                 review = review[:last_code_suggestion_ind]
                 # Use regular expression to find the last occurrence of "}," with any number of whitespaces or newlines
@@ -82,3 +88,17 @@ def try_fix_json(review, max_iter=10):
             logging.error("Unable to decode JSON response from AI")
             data = {}
     return data
+
+def fix_json_escape_char(json_message=None):
+    result = None
+    try:
+        result = json.loads(json_message)
+    except Exception as e:
+        # Find the offending character index:
+        idx_to_replace = int(str(e).split(' ')[-1].replace(')', ''))
+        # Remove the offending character:
+        json_message = list(json_message)
+        json_message[idx_to_replace] = ' '
+        new_message = ''.join(json_message)
+        return fix_JSON(json_message=new_message)
+    return result
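Note on the try_fix_json change: the new code_suggestions flag selects how a truncated model response is re-closed, because the review and /improve outputs nest the "Code suggestions" list at different depths. A minimal sketch of the difference, using simplified, invented JSON shapes rather than the exact prompt schema:

    from pr_agent.algo.utils import try_fix_json

    # Review-style output nests the list one level deeper, so "]}}" re-closes it:
    truncated_review = '{"review": {"Code suggestions": [{"suggestion": "a"}, {"sugg'
    data = try_fix_json(truncated_review)                          # closing_bracket = "]}}"

    # The /improve output keeps the list at the top level, so "]}" is enough:
    truncated_improve = '{"Code suggestions": [{"suggestion": "a"}, {"sugg'
    data = try_fix_json(truncated_improve, code_suggestions=True)  # closing_bracket = "]}"

In both cases the incomplete trailing element is dropped and everything up to the last complete suggestion is parsed.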
@@ -28,6 +28,8 @@ class GitLabProvider(GitProvider):
         self.diff_files = None
         self.temp_comments = []
         self._set_merge_request(merge_request_url)
+        self.RE_HUNK_HEADER = re.compile(
+            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

     @property
     def pr(self):
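For reference, the hunk-header pattern now stored in self.RE_HUNK_HEADER parses standard unified-diff headers. A quick, self-contained illustration of what its groups capture (the sample header is arbitrary):

    import re

    RE_HUNK_HEADER = re.compile(r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

    match = RE_HUNK_HEADER.match("@@ -84,25 +86,26 @@ class GitLabProvider(GitProvider):")
    start_old, size_old, start_new, size_new, section = match.groups()
    # -> ('84', '25', '86', '26', 'class GitLabProvider(GitProvider):')
    # The sizes are optional: "@@ -5 +5 @@" yields ('5', None, '5', None, '')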
@@ -84,25 +86,26 @@ class GitLabProvider(GitProvider):
         self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
         edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
                                                                                          relevant_line_in_file)
+        self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
+                                 target_file, target_line_no)
+
+    def send_inline_comment(self, body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
+                            target_file, target_line_no):
         if not found:
             logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
         else:
-            if edit_type == 'addition':
-                position = target_line_no - 1
-            else:
-                position = source_line_no - 1
             d = self.last_diff
             pos_obj = {'position_type': 'text',
                        'new_path': target_file.filename,
                        'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
                        'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha}
             if edit_type == 'deletion':
-                pos_obj['old_line'] = position
+                pos_obj['old_line'] = source_line_no - 1
             elif edit_type == 'addition':
-                pos_obj['new_line'] = position
+                pos_obj['new_line'] = target_line_no - 1
             else:
-                pos_obj['new_line'] = position
-                pos_obj['old_line'] = position
+                pos_obj['new_line'] = target_line_no - 1
+                pos_obj['old_line'] = source_line_no - 1
             self.mr.discussions.create({'body': body,
                                         'position': pos_obj})

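The send_inline_comment refactor also drops the shared position variable in favour of explicit source/target line numbers, which matches how the GitLab discussions API distinguishes old_line (the old side of the diff) from new_line (the new side). A rough sketch of the payload built for a comment on an added line; every value below is an invented placeholder:

    pos_obj = {
        'position_type': 'text',
        'new_path': 'some_file.py',                 # placeholder path
        'old_path': 'some_file.py',
        'base_sha': 'aaa111', 'start_sha': 'bbb222', 'head_sha': 'ccc333',
        'new_line': 42,   # edit_type == 'addition': only new_line is set
    }
    # self.mr.discussions.create({'body': "comment text", 'position': pos_obj})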
@@ -110,47 +113,81 @@ class GitLabProvider(GitProvider):
                                 relevant_file: str,
                                 relevant_lines_start: int,
                                 relevant_lines_end: int):
-        raise "not implemented yet for gitlab"
+        self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
+        target_file = None
+        for file in self.diff_files:
+            if file.filename == relevant_file:
+                if file.filename == relevant_file:
+                    target_file = file
+                    break
+        range = relevant_lines_end - relevant_lines_start + 1
+        body = body.replace('```suggestion', f'```suggestion:-0+{range}')
+
+        d = self.last_diff
+        #
+        # pos_obj = {'position_type': 'text',
+        #            'new_path': target_file.filename,
+        #            'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
+        #            'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha}
+        lines = target_file.head_file.splitlines()
+        relevant_line_in_file = lines[relevant_lines_start - 1]
+        edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(target_file, relevant_line_in_file)
+        self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
+                                 target_file, target_line_no)
+        # if lines[relevant_lines_start][0] == '-':
+        #     pos_obj['old_line'] = relevant_lines_start
+        # elif lines[relevant_lines_start][0] == '+':
+        #     pos_obj['new_line'] = relevant_lines_start
+        # else:
+        #     pos_obj['new_line'] = relevant_lines_start
+        #     pos_obj['old_line'] = relevant_lines_start
+        # self.mr.discussions.create({'body': body,
+        #                             'position': pos_obj})

     def search_line(self, relevant_file, relevant_line_in_file):
-        RE_HUNK_HEADER = re.compile(
-            r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
         target_file = None
-        source_line_no = 0
-        target_line_no = 0
-        found = False
         edit_type = self.get_edit_type(relevant_line_in_file)
         for file in self.diff_files:
             if file.filename == relevant_file:
-                target_file = file
-                patch = file.patch
-                patch_lines = patch.splitlines()
-                for i, line in enumerate(patch_lines):
-                    if line.startswith('@@'):
-                        match = RE_HUNK_HEADER.match(line)
-                        if not match:
-                            continue
-                        start_old, size_old, start_new, size_new, _ = match.groups()
-                        source_line_no = int(start_old)
-                        target_line_no = int(start_new)
-                        continue
-                    if line.startswith('-'):
-                        source_line_no += 1
-                    elif line.startswith('+'):
-                        target_line_no += 1
-                    elif line.startswith(' '):
-                        source_line_no += 1
-                        target_line_no += 1
-                    if relevant_line_in_file in line:
-                        found = True
-                        edit_type = self.get_edit_type(line)
-                        break
-                    elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line:
-                        # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
-                        # it's a context line
-                        found = True
-                        edit_type = self.get_edit_type(line)
-                        break
+                edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file,
+                                                                                                   relevant_line_in_file)
+        return edit_type, found, source_line_no, target_file, target_line_no
+
+    def find_in_file(self, file, relevant_line_in_file):
+        edit_type = 'context'
+        source_line_no = 0
+        target_line_no = 0
+        found = False
+        target_file = file
+        patch = file.patch
+        patch_lines = patch.splitlines()
+        for i, line in enumerate(patch_lines):
+            if line.startswith('@@'):
+                match = self.RE_HUNK_HEADER.match(line)
+                if not match:
+                    continue
+                start_old, size_old, start_new, size_new, _ = match.groups()
+                source_line_no = int(start_old)
+                target_line_no = int(start_new)
+                continue
+            if line.startswith('-'):
+                source_line_no += 1
+            elif line.startswith('+'):
+                target_line_no += 1
+            elif line.startswith(' '):
+                source_line_no += 1
+                target_line_no += 1
+            if relevant_line_in_file in line:
+                found = True
+                edit_type = self.get_edit_type(line)
+                break
+            elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line:
+                # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
+                # it's a context line
+                found = True
+                edit_type = self.get_edit_type(line)
+                break
         return edit_type, found, source_line_no, target_file, target_line_no

     def get_edit_type(self, relevant_line_in_file):
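A note on the body.replace call in publish_code_suggestion above: GitLab encodes multi-line suggestions in the code-fence info string, so rewriting ```suggestion into ```suggestion:-0+{range} lets a comment anchored on the first relevant line propose a replacement that also spans lines below it. A small sketch of the transformation (the suggestion text is invented):

    relevant_lines_start, relevant_lines_end = 10, 11
    range = relevant_lines_end - relevant_lines_start + 1   # 2

    body = "```suggestion\nimproved line 1\nimproved line 2\n```"
    body = body.replace('```suggestion', f'```suggestion:-0+{range}')
    # body now opens with ```suggestion:-0+2, i.e. on the GitLab side the suggestion
    # applies to the commented line plus the 2 lines below it.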
@@ -10,7 +10,7 @@ from pr_agent.algo.pr_processing import get_pr_diff
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import convert_to_markdown, try_fix_json
 from pr_agent.config_loader import settings
-from pr_agent.git_providers import get_git_provider, GithubProvider
+from pr_agent.git_providers import get_git_provider, BitbucketProvider
 from pr_agent.git_providers.git_provider import get_main_pr_language


@@ -39,7 +39,7 @@ class PRCodeSuggestions:
                                           settings.pr_code_suggestions_prompt.user)

     async def suggest(self):
-        assert type(self.git_provider) == GithubProvider, "Only Github is supported for now"
+        assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"

         logging.info('Generating code suggestions for PR...')
         if settings.config.publish_review:
@@ -86,7 +86,7 @@ class PRCodeSuggestions:
         except json.decoder.JSONDecodeError:
             if settings.config.verbosity_level >= 2:
                 logging.info(f"Could not parse json response: {review}")
-            data = try_fix_json(review)
+            data = try_fix_json(review, code_suggestions=True)
         return data

     def push_inline_code_suggestions(self, data):
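Taken together, these changes let the /improve flow post suggestions on GitLab: try_fix_json is told the response carries a top-level "Code suggestions" list, and the provider can now publish each one as an inline discussion. Purely as an illustration, a hypothetical parsed payload as push_inline_code_suggestions might receive it; the inner key names are assumptions, not the actual prompt schema:

    data = {
        "Code suggestions": [
            {
                "relevant file": "pr_agent/algo/utils.py",        # assumed key names
                "suggestion content": "Extract the closing bracket into a constant",
                "improved code": "CLOSING_BRACKET = ']}'",
                "relevant lines start": 61,
                "relevant lines end": 63,
            }
        ]
    }

Each entry would then be wrapped in a ```suggestion block and handed to publish_code_suggestion on the configured git provider.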