diff --git a/README.md b/README.md index d5f44f96..cd428e53 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ To set up your own PR-Agent, see the [Quickstart](#Quickstart) section | | ⮑ Inline review | ✓ | ✓ | | | | Ask | ✓ | ✓ | | | | Auto-Description | ✓ | | | -| | Improve Code | ✓ | | | +| | Improve Code | ✓ | ✓ | | | | | | | | | USAGE | CLI | ✓ | ✓ | ✓ | | | Tagging bot | ✓ | ✓ | | diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index ff4bbdac..1f8c175f 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -61,18 +61,24 @@ def parse_code_suggestion(code_suggestions: dict) -> str: return markdown_text -def try_fix_json(review, max_iter=10): +def try_fix_json(review, max_iter=10, code_suggestions=False): + if review.endswith("}"): + return fix_json_escape_char(review) # Try to fix JSON if it is broken/incomplete: parse until the last valid code suggestion data = {} + if code_suggestions: + closing_bracket = "]}" + else: + closing_bracket = "]}}" if review.rfind("'Code suggestions': [") > 0 or review.rfind('"Code suggestions": [') > 0: last_code_suggestion_ind = [m.end() for m in re.finditer(r"\}\s*,", review)][-1] - 1 valid_json = False iter_count = 0 while last_code_suggestion_ind > 0 and not valid_json and iter_count < max_iter: try: - data = json.loads(review[:last_code_suggestion_ind] + "]}}") + data = json.loads(review[:last_code_suggestion_ind] + closing_bracket) valid_json = True - review = review[:last_code_suggestion_ind].strip() + "]}}" + review = review[:last_code_suggestion_ind].strip() + closing_bracket except json.decoder.JSONDecodeError: review = review[:last_code_suggestion_ind] # Use regular expression to find the last occurrence of "}," with any number of whitespaces or newlines @@ -82,3 +88,17 @@ def try_fix_json(review, max_iter=10): logging.error("Unable to decode JSON response from AI") data = {} return data + +def fix_json_escape_char(json_message=None): + result = None + try: + result = json.loads(json_message) + except Exception as e: + # Find the offending character index: + idx_to_replace = int(str(e).split(' ')[-1].replace(')', '')) + # Remove the offending character: + json_message = list(json_message) + json_message[idx_to_replace] = ' ' + new_message = ''.join(json_message) + return fix_JSON(json_message=new_message) + return result \ No newline at end of file diff --git a/pr_agent/git_providers/gitlab_provider.py b/pr_agent/git_providers/gitlab_provider.py index 2c0d6863..d6bc3ee4 100644 --- a/pr_agent/git_providers/gitlab_provider.py +++ b/pr_agent/git_providers/gitlab_provider.py @@ -28,6 +28,8 @@ class GitLabProvider(GitProvider): self.diff_files = None self.temp_comments = [] self._set_merge_request(merge_request_url) + self.RE_HUNK_HEADER = re.compile( + r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") @property def pr(self): @@ -84,25 +86,26 @@ class GitLabProvider(GitProvider): self.diff_files = self.diff_files if self.diff_files else self.get_diff_files() edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file, relevant_line_in_file) + self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no, + target_file, target_line_no) + + def send_inline_comment(self, body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no, + target_file, target_line_no): if not found: logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}") else: - if edit_type == 'addition': - position = target_line_no - 1 - else: - position = source_line_no - 1 d = self.last_diff pos_obj = {'position_type': 'text', - 'new_path': target_file.filename, - 'old_path': target_file.old_filename if target_file.old_filename else target_file.filename, - 'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha} + 'new_path': target_file.filename, + 'old_path': target_file.old_filename if target_file.old_filename else target_file.filename, + 'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha} if edit_type == 'deletion': - pos_obj['old_line'] = position + pos_obj['old_line'] = source_line_no - 1 elif edit_type == 'addition': - pos_obj['new_line'] = position + pos_obj['new_line'] = target_line_no - 1 else: - pos_obj['new_line'] = position - pos_obj['old_line'] = position + pos_obj['new_line'] = target_line_no - 1 + pos_obj['old_line'] = source_line_no - 1 self.mr.discussions.create({'body': body, 'position': pos_obj}) @@ -110,47 +113,81 @@ class GitLabProvider(GitProvider): relevant_file: str, relevant_lines_start: int, relevant_lines_end: int): - raise "not implemented yet for gitlab" + self.diff_files = self.diff_files if self.diff_files else self.get_diff_files() + target_file = None + for file in self.diff_files: + if file.filename == relevant_file: + if file.filename == relevant_file: + target_file = file + break + range = relevant_lines_end - relevant_lines_start + 1 + body = body.replace('```suggestion', f'```suggestion:-0+{range}') + + d = self.last_diff + # + # pos_obj = {'position_type': 'text', + # 'new_path': target_file.filename, + # 'old_path': target_file.old_filename if target_file.old_filename else target_file.filename, + # 'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha} + lines = target_file.head_file.splitlines() + relevant_line_in_file = lines[relevant_lines_start - 1] + edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(target_file, relevant_line_in_file) + self.send_inline_comment(body, edit_type, found, relevant_file, relevant_line_in_file, source_line_no, + target_file, target_line_no) + # if lines[relevant_lines_start][0] == '-': + # pos_obj['old_line'] = relevant_lines_start + # elif lines[relevant_lines_start][0] == '+': + # pos_obj['new_line'] = relevant_lines_start + # else: + # pos_obj['new_line'] = relevant_lines_start + # pos_obj['old_line'] = relevant_lines_start + # self.mr.discussions.create({'body': body, + # 'position': pos_obj}) def search_line(self, relevant_file, relevant_line_in_file): - RE_HUNK_HEADER = re.compile( - r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)") target_file = None - source_line_no = 0 - target_line_no = 0 - found = False + edit_type = self.get_edit_type(relevant_line_in_file) for file in self.diff_files: if file.filename == relevant_file: - target_file = file - patch = file.patch - patch_lines = patch.splitlines() - for i, line in enumerate(patch_lines): - if line.startswith('@@'): - match = RE_HUNK_HEADER.match(line) - if not match: - continue - start_old, size_old, start_new, size_new, _ = match.groups() - source_line_no = int(start_old) - target_line_no = int(start_new) - continue - if line.startswith('-'): - source_line_no += 1 - elif line.startswith('+'): - target_line_no += 1 - elif line.startswith(' '): - source_line_no += 1 - target_line_no += 1 - if relevant_line_in_file in line: - found = True - edit_type = self.get_edit_type(line) - break - elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line: - # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally - # it's a context line - found = True - edit_type = self.get_edit_type(line) - break + edit_type, found, source_line_no, target_file, target_line_no = self.find_in_file(file, + relevant_line_in_file) + return edit_type, found, source_line_no, target_file, target_line_no + + def find_in_file(self, file, relevant_line_in_file): + edit_type = 'context' + source_line_no = 0 + target_line_no = 0 + found = False + target_file = file + patch = file.patch + patch_lines = patch.splitlines() + for i, line in enumerate(patch_lines): + if line.startswith('@@'): + match = self.RE_HUNK_HEADER.match(line) + if not match: + continue + start_old, size_old, start_new, size_new, _ = match.groups() + source_line_no = int(start_old) + target_line_no = int(start_new) + continue + if line.startswith('-'): + source_line_no += 1 + elif line.startswith('+'): + target_line_no += 1 + elif line.startswith(' '): + source_line_no += 1 + target_line_no += 1 + if relevant_line_in_file in line: + found = True + edit_type = self.get_edit_type(line) + break + elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line: + # The model often adds a '+' to the beginning of the relevant_line_in_file even if originally + # it's a context line + found = True + edit_type = self.get_edit_type(line) + break return edit_type, found, source_line_no, target_file, target_line_no def get_edit_type(self, relevant_line_in_file): diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index d55326bc..c008368a 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -10,7 +10,7 @@ from pr_agent.algo.pr_processing import get_pr_diff from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, try_fix_json from pr_agent.config_loader import settings -from pr_agent.git_providers import get_git_provider, GithubProvider +from pr_agent.git_providers import get_git_provider, BitbucketProvider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -39,7 +39,7 @@ class PRCodeSuggestions: settings.pr_code_suggestions_prompt.user) async def suggest(self): - assert type(self.git_provider) == GithubProvider, "Only Github is supported for now" + assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now" logging.info('Generating code suggestions for PR...') if settings.config.publish_review: @@ -86,7 +86,7 @@ class PRCodeSuggestions: except json.decoder.JSONDecodeError: if settings.config.verbosity_level >= 2: logging.info(f"Could not parse json response: {review}") - data = try_fix_json(review) + data = try_fix_json(review, code_suggestions=True) return data def push_inline_code_suggestions(self, data):