Support Code Suggestion in Gitlab

This commit is contained in:
Hussam.lawen
2023-07-17 01:44:40 +03:00
parent fe98f67e08
commit fc309f69b9
4 changed files with 110 additions and 53 deletions

View File

@ -77,7 +77,7 @@ To set up your own PR-Agent, see the [Quickstart](#Quickstart) section
| | ⮑ Inline review | ✓ | ✓ | | | | ⮑ Inline review | ✓ | ✓ | |
| | Ask | ✓ | ✓ | | | | Ask | ✓ | ✓ | |
| | Auto-Description | ✓ | | | | | Auto-Description | ✓ | | |
| | Improve Code | ✓ | | | | | Improve Code | ✓ | | |
| | | | | | | | | | | |
| USAGE | CLI | ✓ | ✓ | ✓ | | USAGE | CLI | ✓ | ✓ | ✓ |
| | Tagging bot | ✓ | ✓ | | | | Tagging bot | ✓ | ✓ | |

View File

@ -61,18 +61,24 @@ def parse_code_suggestion(code_suggestions: dict) -> str:
return markdown_text return markdown_text
def try_fix_json(review, max_iter=10): def try_fix_json(review, max_iter=10, code_suggestions=False):
if review.endswith("}"):
return fix_json_escape_char(review)
# Try to fix JSON if it is broken/incomplete: parse until the last valid code suggestion # Try to fix JSON if it is broken/incomplete: parse until the last valid code suggestion
data = {} data = {}
if code_suggestions:
closing_bracket = "]}"
else:
closing_bracket = "]}}"
if review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0: if review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0:
last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1
valid_json = False valid_json = False
iter_count = 0 iter_count = 0
while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter: while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter:
try: try:
data = json.loads(review[:last_code_suggestion_ind] + "]}}") data = json.loads(review[:last_code_suggestion_ind] + closing_bracket)
valid_json = True valid_json = True
review = review[:last_code_suggestion_ind].strip() + "]}}" review = review[:last_code_suggestion_ind].strip() + closing_bracket
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
review = review[:last_code_suggestion_ind] review = review[:last_code_suggestion_ind]
# Use regular expression to find the last occurrence of "}," with any number of whitespaces or newlines # Use regular expression to find the last occurrence of "}," with any number of whitespaces or newlines
@ -82,3 +88,17 @@ def try_fix_json(review, max_iter=10):
logging.error("Unable to decode JSON response from AI") logging.error("Unable to decode JSON response from AI")
data = {} data = {}
return data return data
def fix_json_escape_char(json_message=None):
result = None
try:
result = json.loads(json_message)
except Exception as e:
# Find the offending character index:
idx_to_replace = int(str(e).split(' ')[-1].replace(')', ''))
# Remove the offending character:
json_message = list(json_message)
json_message[idx_to_replace] = ' '
new_message = ''.join(json_message)
return fix_JSON(json_message=new_message)
return result

View File

@ -28,6 +28,8 @@ class GitLabProvider(GitProvider):
self.diff_files = None self.diff_files = None
self.temp_comments = [] self.temp_comments = []
self._set_merge_request(merge_request_url) self._set_merge_request(merge_request_url)
self.RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
@property @property
def pr(self): def pr(self):
@ -84,25 +86,26 @@ class GitLabProvider(GitProvider):
self.diff_files = self.diff_files if self.diff_files else self.get_diff_files() self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file, edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
relevant_line_in_file) relevant_line_in_file)
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
target_file, target_line_no)
def send_inline_comment(self, body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
target_file, target_line_no):
if not found: if not found:
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}") logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
else: else:
if edit_type == 'addition':
position = target_line_no - 1
else:
position = source_line_no - 1
d = self.last_diff d = self.last_diff
pos_obj = {'position_type': 'text', pos_obj = {'position_type': 'text',
'new_path': target_file.filename, 'new_path': target_file.filename,
'old_path': target_file.old_filename if target_file.old_filename else target_file.filename, 'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha} 'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha}
if edit_type == 'deletion': if edit_type == 'deletion':
pos_obj['old_line'] = position pos_obj['old_line'] = source_line_no - 1
elif edit_type == 'addition': elif edit_type == 'addition':
pos_obj['new_line'] = position pos_obj['new_line'] = target_line_no - 1
else: else:
pos_obj['new_line'] = position pos_obj['new_line'] = target_line_no - 1
pos_obj['old_line'] = position pos_obj['old_line'] = source_line_no - 1
self.mr.discussions.create({'body': body, self.mr.discussions.create({'body': body,
'position': pos_obj}) 'position': pos_obj})
@ -110,24 +113,58 @@ class GitLabProvider(GitProvider):
relevant_file: str, relevant_file: str,
relevant_lines_start: int, relevant_lines_start: int,
relevant_lines_end: int): relevant_lines_end: int):
raise "not implemented yet for gitlab" self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
target_file = None
for file in self.diff_files:
if file.filename == relevant_file:
if file.filename == relevant_file:
target_file = file
break
range = relevant_lines_end - relevant_lines_start + 1
body = body.replace('```suggestion', f'```suggestion:-0+{range}')
d = self.last_diff
#
# pos_obj = {'position_type': 'text',
# 'new_path': target_file.filename,
# 'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
# 'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha}
lines = target_file.head_file.splitlines()
relevant_line_in_file = lines[relevant_lines_start - 1]
edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(target_file, relevant_line_in_file)
self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no,
target_file, target_line_no)
# if lines[relevant_lines_start][0] == '-':
# pos_obj['old_line'] = relevant_lines_start
# elif lines[relevant_lines_start][0] == '+':
# pos_obj['new_line'] = relevant_lines_start
# else:
# pos_obj['new_line'] = relevant_lines_start
# pos_obj['old_line'] = relevant_lines_start
# self.mr.discussions.create({'body': body,
# 'position': pos_obj})
def search_line(self, relevant_file, relevant_line_in_file): def search_line(self, relevant_file, relevant_line_in_file):
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
target_file = None target_file = None
source_line_no = 0
target_line_no = 0
found = False
edit_type = self.get_edit_type(relevant_line_in_file) edit_type = self.get_edit_type(relevant_line_in_file)
for file in self.diff_files: for file in self.diff_files:
if file.filename == relevant_file: if file.filename == relevant_file:
edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file,
relevant_line_in_file)
return edit_type, found, source_line_no, target_file, target_line_no
def find_in_file(self, file, relevant_line_in_file):
edit_type = 'context'
source_line_no = 0
target_line_no = 0
found = False
target_file = file target_file = file
patch = file.patch patch = file.patch
patch_lines = patch.splitlines() patch_lines = patch.splitlines()
for i, line in enumerate(patch_lines): for i, line in enumerate(patch_lines):
if line.startswith('@@'): if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line) match = self.RE_HUNK_HEADER.match(line)
if not match: if not match:
continue continue
start_old, size_old, start_new, size_new, _ = match.groups() start_old, size_old, start_new, size_new, _ = match.groups()

View File

@ -10,7 +10,7 @@ from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider, GithubProvider from pr_agent.git_providers import get_git_provider, BitbucketProvider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
@ -39,7 +39,7 @@ class PRCodeSuggestions:
settings.pr_code_suggestions_prompt.user) settings.pr_code_suggestions_prompt.user)
async def suggest(self): async def suggest(self):
assert type(self.git_provider) == GithubProvider, "Only Github is supported for now" assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
logging.info('Generating code suggestions for PR...') logging.info('Generating code suggestions for PR...')
if settings.config.publish_review: if settings.config.publish_review:
@ -86,7 +86,7 @@ class PRCodeSuggestions:
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
if settings.config.verbosity_level >= 2: if settings.config.verbosity_level >= 2:
logging.info(f"Could not parse json response: {review}") logging.info(f"Could not parse json response: {review}")
data = try_fix_json(review) data = try_fix_json(review, code_suggestions=True)
return data return data
def push_inline_code_suggestions(self, data): def push_inline_code_suggestions(self, data):