Compare commits

..

13 Commits

Author SHA1 Message Date
b3e79ed677 cli.py - modify command line for a more coherent command invokation 2023-07-16 13:18:29 +03:00
5d2fe07bf7 Merge pull request #47 from Codium-ai/feature/github_action
Github custom action development
2023-07-16 12:54:40 +03:00
84bf95e9ab Merge pull request #50 from Codium-ai/tr/numbered_hunks
Adding numbered hunks and code suggestions feature
2023-07-16 12:27:29 +03:00
4f4989af8c full code suggestions
full code suggestions
2023-07-16 09:01:57 +03:00
23a249ccdb Merge pull request #48 from Codium-ai/hl/gitlab_fix
Inline suggestion refactor + supporting GitLab
2023-07-14 22:53:52 +03:00
4a6bf4c55a Merge branch 'main' into hl/gitlab_fix 2023-07-14 22:48:13 +03:00
3f75b14ba3 small addition 2023-07-14 22:45:07 +03:00
ae9cedd50d Merge pull request #46 from Codium-ai/tr/description_tool
Add PR Description Tool
2023-07-13 21:00:50 +03:00
ae63833043 Merge commit '055a8ea8590fbe9078cdc6af6398df2f053b9ce7' into hl/gitlab_fix 2023-07-13 20:44:26 +03:00
da6828ad87 Inline suggestion refactor + Gitlab WORKS 2023-07-13 20:43:49 +03:00
4e59693c76 diff_files 2023-07-13 18:26:35 +03:00
055a8ea859 Merge pull request #44 from zmeir/patch-1
Typo when setting `openai.api_version`
2023-07-13 17:52:33 +03:00
f57d58ee7d Typo when setting openai.api_version 2023-07-13 10:22:57 +03:00
13 changed files with 564 additions and 50 deletions

View File

@ -18,7 +18,7 @@ class AiHandler:
if settings.get("OPENAI.API_TYPE", None): if settings.get("OPENAI.API_TYPE", None):
openai.api_type = settings.openai.api_type openai.api_type = settings.openai.api_type
if settings.get("OPENAI.API_VERSION", None): if settings.get("OPENAI.API_VERSION", None):
openai.engine = settings.openai.api_version openai.api_version = settings.openai.api_version
if settings.get("OPENAI.API_BASE", None): if settings.get("OPENAI.API_BASE", None):
openai.api_base = settings.openai.api_base openai.api_base = settings.openai.api_base
except AttributeError as e: except AttributeError as e:

View File

@ -13,6 +13,9 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
if not patch_str or num_lines == 0: if not patch_str or num_lines == 0:
return patch_str return patch_str
if type(original_file_str) == bytes:
original_file_str = original_file_str.decode('utf-8')
original_lines = original_file_str.splitlines() original_lines = original_file_str.splitlines()
patch_lines = patch_str.splitlines() patch_lines = patch_str.splitlines()
extended_patch_lines = [] extended_patch_lines = []
@ -105,3 +108,78 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
logging.info(f"Processing file: {file_name}, hunks were deleted") logging.info(f"Processing file: {file_name}, hunks were deleted")
patch = patch_new patch = patch_new
return patch return patch
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
# toDO: (maybe remove '-' and '+' from the beginning of the line)
"""
## src/file.ts
--new hunk--
881 line1
882 line2
883 line3
884 line4
885 line6
886 line7
887 + line8
888 + line9
889 line10
890 line11
...
--old hunk--
line1
line2
- line3
- line4
line5
line6
...
"""
patch_with_lines_str = f"## {file.filename}\n"
import re
patch_lines = patch.splitlines()
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
new_content_lines = []
old_content_lines = []
match = None
start1, size1, start2, size2 = -1, -1, -1, -1
for line in patch_lines:
if 'no newline at end of file' in line.lower():
continue
if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line)
if match and new_content_lines: # found a new hunk, split the previous lines
if new_content_lines:
patch_with_lines_str += '\n--new hunk--\n'
for i, line_new in enumerate(new_content_lines):
patch_with_lines_str += f"{start2 + i} {line_new}\n"
if old_content_lines:
patch_with_lines_str += '--old hunk--\n'
for i, line_old in enumerate(old_content_lines):
patch_with_lines_str += f"{line_old}\n"
new_content_lines = []
old_content_lines = []
start1, size1, start2, size2 = map(int, match.groups()[:4])
elif line.startswith('+'):
new_content_lines.append(line)
elif line.startswith('-'):
old_content_lines.append(line)
else:
new_content_lines.append(line)
old_content_lines.append(line)
# finishing last hunk
if match and new_content_lines:
if new_content_lines:
patch_with_lines_str += '\n--new hunk--\n'
for i, line_new in enumerate(new_content_lines):
patch_with_lines_str += f"{start2 + i} {line_new}\n"
if old_content_lines:
patch_with_lines_str += '\n--old hunk--\n'
for i, line_old in enumerate(old_content_lines):
patch_with_lines_str += f"{line_old}\n"
return patch_with_lines_str.strip()

View File

@ -4,7 +4,8 @@ import difflib
import logging import logging
from typing import Any, Tuple, Union from typing import Any, Tuple, Union
from pr_agent.algo.git_patch_processing import extend_patch, handle_patch_deletions from pr_agent.algo.git_patch_processing import extend_patch, handle_patch_deletions, \
convert_to_hunks_with_lines_numbers
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
@ -19,26 +20,33 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
PATCH_EXTRA_LINES = 3 PATCH_EXTRA_LINES = 3
def get_pr_diff(git_provider: Union[GithubProvider, Any], token_handler: TokenHandler) -> str: def get_pr_diff(git_provider: Union[GithubProvider, Any], token_handler: TokenHandler,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str:
""" """
Returns a string with the diff of the PR. Returns a string with the diff of the PR.
If needed, apply diff minimization techniques to reduce the number of tokens If needed, apply diff minimization techniques to reduce the number of tokens
""" """
git_provider.pr.files = list(git_provider.get_diff_files()) if disable_extra_lines:
global PATCH_EXTRA_LINES
PATCH_EXTRA_LINES = 0
git_provider.pr.diff_files = list(git_provider.get_diff_files())
# get pr languages # get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), git_provider.pr.files) pr_languages = sort_files_by_main_languages(git_provider.get_languages(), git_provider.pr.diff_files)
# generate a standard diff string, with patch extension # generate a standard diff string, with patch extension
patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler) patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler,
add_line_numbers_to_hunks)
# if we are under the limit, return the full diff # if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < token_handler.limit: if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < token_handler.limit:
return "\n".join(patches_extended) return "\n".join(patches_extended)
# if we are over the limit, start pruning # if we are over the limit, start pruning
patches_compressed, modified_file_names, deleted_file_names = pr_generate_compressed_diff(pr_languages, patches_compressed, modified_file_names, deleted_file_names = \
token_handler) pr_generate_compressed_diff(pr_languages, token_handler, add_line_numbers_to_hunks)
final_diff = "\n".join(patches_compressed) final_diff = "\n".join(patches_compressed)
if modified_file_names: if modified_file_names:
modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names) modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names)
@ -49,7 +57,8 @@ def get_pr_diff(git_provider: Union[GithubProvider, Any], token_handler: TokenHa
return final_diff return final_diff
def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler) -> \ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> \
Tuple[list, int]: Tuple[list, int]:
""" """
Generate a standard diff string, with patch extension Generate a standard diff string, with patch extension
@ -72,6 +81,9 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler) -
extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES) extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES)
full_extended_patch = f"## {file.filename}\n\n{extended_patch}\n" full_extended_patch = f"## {file.filename}\n\n{extended_patch}\n"
if add_line_numbers_to_hunks:
full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
patch_tokens = token_handler.count_tokens(full_extended_patch) patch_tokens = token_handler.count_tokens(full_extended_patch)
file.tokens = patch_tokens file.tokens = patch_tokens
total_tokens += patch_tokens total_tokens += patch_tokens
@ -80,7 +92,8 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler) -
return patches_extended, total_tokens return patches_extended, total_tokens
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler) -> Tuple[list, list, list]: def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
# Apply Diff Minimization techniques to reduce the number of tokens: # Apply Diff Minimization techniques to reduce the number of tokens:
# 0. Start from the largest diff patch to smaller ones # 0. Start from the largest diff patch to smaller ones
# 1. Don't use extend context lines around diff # 1. Don't use extend context lines around diff
@ -114,6 +127,10 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler) ->
deleted_files_list.append(file.filename) deleted_files_list.append(file.filename)
total_tokens += token_handler.count_tokens(file.filename) + 1 total_tokens += token_handler.count_tokens(file.filename) + 1
continue continue
if convert_hunks_to_line_numbers:
patch = convert_to_hunks_with_lines_numbers(patch, file)
new_patch_tokens = token_handler.count_tokens(patch) new_patch_tokens = token_handler.count_tokens(patch)
# Hard Stop, no more tokens # Hard Stop, no more tokens
@ -135,7 +152,10 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler) ->
continue continue
if patch: if patch:
if not convert_hunks_to_line_numbers:
patch_final = f"## {file.filename}\n\n{patch}\n" patch_final = f"## {file.filename}\n\n{patch}\n"
else:
patch_final = patch
patches.append(patch_final) patches.append(patch_final)
total_tokens += token_handler.count_tokens(patch_final) total_tokens += token_handler.count_tokens(patch_final)
if settings.config.verbosity_level >= 2: if settings.config.verbosity_level >= 2:

View File

@ -3,30 +3,55 @@ import asyncio
import logging import logging
import os import os
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
def run(): def run():
parser = argparse.ArgumentParser(description='AI based pull request analyzer') parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage="""\
Usage: cli.py --pr-url <URL on supported git hosting service> <command> [<args>].
Supported commands:
review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
ask / ask_question [question] - Ask a question about the PR.
describe / describe_pr - Modify the PR title and description based on the PR's contents.
improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
""")
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True) parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
parser.add_argument('--question', type=str, help='Optional question to ask', required=False) parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr',
parser.add_argument('--pr_description', action='store_true', help='Optional question to ask', required=False) 'ask', 'ask_question',
'describe', 'describe_pr',
'improve', 'improve_code'], default='review')
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO")) logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
if args.question: command = args.command.lower()
print(f"Question: {args.question} about PR {args.pr_url}") if command in ['ask', 'ask_question']:
reviewer = PRQuestions(args.pr_url, args.question) question = ' '.join(args.rest).strip()
if len(question) == 0:
print("Please specify a question")
parser.print_help()
return
print(f"Question: {question} about PR {args.pr_url}")
reviewer = PRQuestions(args.pr_url, question)
asyncio.run(reviewer.answer()) asyncio.run(reviewer.answer())
elif args.pr_description: elif command in ['describe', 'describe_pr']:
print(f"PR description: {args.pr_url}") print(f"PR description: {args.pr_url}")
reviewer = PRDescription(args.pr_url) reviewer = PRDescription(args.pr_url)
asyncio.run(reviewer.describe()) asyncio.run(reviewer.describe())
else: elif command in ['improve', 'improve_code']:
print(f"PR code suggestions: {args.pr_url}")
reviewer = PRCodeSuggestions(args.pr_url)
asyncio.run(reviewer.suggest())
elif command in ['review', 'review_pr']:
print(f"Reviewing PR: {args.pr_url}") print(f"Reviewing PR: {args.pr_url}")
reviewer = PRReviewer(args.pr_url, cli_mode=True) reviewer = PRReviewer(args.pr_url, cli_mode=True)
asyncio.run(reviewer.review()) asyncio.run(reviewer.review())
else:
print(f"Unknown command: {command}")
parser.print_help()
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -12,6 +12,7 @@ settings = Dynaconf(
"settings/pr_reviewer_prompts.toml", "settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml", "settings/pr_questions_prompts.toml",
"settings/pr_description_prompts.toml", "settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml",
"settings_prod/.secrets.toml" "settings_prod/.secrets.toml"
]] ]]
) )

View File

@ -1,6 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
from enum import Enum
class EDIT_TYPE(Enum):
ADDED = 1
DELETED = 2
MODIFIED = 3
RENAMED = 4
@dataclass @dataclass
class FilePatchInfo: class FilePatchInfo:
@ -9,6 +16,8 @@ class FilePatchInfo:
patch: str patch: str
filename: str filename: str
tokens: int = -1 tokens: int = -1
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
old_filename: str = None
class GitProvider(ABC): class GitProvider(ABC):
@ -24,6 +33,15 @@ class GitProvider(ABC):
def publish_comment(self, pr_comment: str, is_temporary: bool = False): def publish_comment(self, pr_comment: str, is_temporary: bool = False):
pass pass
@abstractmethod
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
pass
@abstractmethod
def publish_code_suggestion(self, body: str, relevant_file: str,
relevant_lines_start: int, relevant_lines_end: int):
pass
@abstractmethod @abstractmethod
def remove_initial_comment(self): def remove_initial_comment(self):
pass pass

View File

@ -7,10 +7,10 @@ from github import AppAuthentication, Github
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from .git_provider import FilePatchInfo from .git_provider import FilePatchInfo, GitProvider
class GithubProvider: class GithubProvider(GitProvider):
def __init__(self, pr_url: Optional[str] = None): def __init__(self, pr_url: Optional[str] = None):
self.installation_id = settings.get("GITHUB.INSTALLATION_ID") self.installation_id = settings.get("GITHUB.INSTALLATION_ID")
self.github_client = self._get_github_client() self.github_client = self._get_github_client()
@ -18,16 +18,16 @@ class GithubProvider:
self.pr_num = None self.pr_num = None
self.pr = None self.pr = None
self.github_user_id = None self.github_user_id = None
self.diff_files = None
if pr_url: if pr_url:
self.set_pr(pr_url) self.set_pr(pr_url)
self.last_commit_id = list(self.pr.get_commits())[-1]
def set_pr(self, pr_url: str): def set_pr(self, pr_url: str):
self.repo, self.pr_num = self._parse_pr_url(pr_url) self.repo, self.pr_num = self._parse_pr_url(pr_url)
self.pr = self._get_pr() self.pr = self._get_pr()
def get_files(self): def get_files(self):
if hasattr(self.pr, 'files'):
return self.pr.files
return self.pr.get_files() return self.pr.get_files()
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
@ -37,6 +37,7 @@ class GithubProvider:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha) original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha) new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename)) diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename))
self.diff_files = diff_files
return diff_files return diff_files
def publish_description(self, pr_title: str, pr_body: str): def publish_description(self, pr_title: str, pr_body: str):
@ -52,6 +53,76 @@ class GithubProvider:
self.pr.comments_list = [] self.pr.comments_list = []
self.pr.comments_list.append(response) self.pr.comments_list.append(response)
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
position = -1
for file in self.diff_files:
if file.filename.strip() == relevant_file:
patch = file.patch
patch_lines = patch.splitlines()
for i, line in enumerate(patch_lines):
if relevant_line_in_file in line:
position = i
break
elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line:
# The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
# it's a context line
position = i
break
if position == -1:
if settings.config.verbosity_level >= 2:
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
else:
path = relevant_file.strip()
self.pr.create_review_comment(body=body, commit_id=self.last_commit_id, path=path, position=position)
def publish_code_suggestion(self, body: str,
relevant_file: str,
relevant_lines_start: int,
relevant_lines_end: int):
if not relevant_lines_start or relevant_lines_start == -1:
if settings.config.verbosity_level >= 2:
logging.exception(f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
return False
if relevant_lines_end<relevant_lines_start:
if settings.config.verbosity_level >= 2:
logging.exception(f"Failed to publish code suggestion, "
f"relevant_lines_end is {relevant_lines_end} and "
f"relevant_lines_start is {relevant_lines_start}")
return False
try:
import github.PullRequestComment
if relevant_lines_end > relevant_lines_start:
post_parameters = {
"body": body,
"commit_id": self.last_commit_id._identity,
"path": relevant_file,
"line": relevant_lines_end,
"start_line": relevant_lines_start,
"start_side": "RIGHT",
}
else: # API is different for single line comments
post_parameters = {
"body": body,
"commit_id": self.last_commit_id._identity,
"path": relevant_file,
"line": relevant_lines_start,
"side": "RIGHT",
}
headers, data = self.pr._requester.requestJsonAndCheck(
"POST", f"{self.pr.url}/comments", input=post_parameters
)
github.PullRequestComment.PullRequestComment(
self.pr._requester, headers, data, completed=True
)
return True
except Exception as e:
if settings.config.verbosity_level >= 2:
logging.error(f"Failed to publish code suggestion, error: {e}")
return False
def remove_initial_comment(self): def remove_initial_comment(self):
try: try:
for comment in self.pr.comments_list: for comment in self.pr.comments_list:
@ -152,9 +223,9 @@ class GithubProvider:
def _get_pr(self): def _get_pr(self):
return self._get_repo().get_pull(self.pr_num) return self._get_repo().get_pull(self.pr_num)
def _get_pr_file_content(self, file: FilePatchInfo, sha: str): def _get_pr_file_content(self, file: FilePatchInfo, sha: str) -> str:
try: try:
file_content_str = self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode() file_content_str = str(self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode())
except Exception: except Exception:
file_content_str = "" file_content_str = ""
return file_content_str return file_content_str

View File

@ -1,4 +1,5 @@
import logging import logging
import re
from typing import Optional, Tuple from typing import Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
@ -6,7 +7,7 @@ import gitlab
from pr_agent.config_loader import settings from pr_agent.config_loader import settings
from .git_provider import FilePatchInfo, GitProvider from .git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
class GitLabProvider(GitProvider): class GitLabProvider(GitProvider):
@ -24,6 +25,7 @@ class GitLabProvider(GitProvider):
self.id_project = None self.id_project = None
self.id_mr = None self.id_mr = None
self.mr = None self.mr = None
self.diff_files = None
self.temp_comments = [] self.temp_comments = []
self._set_merge_request(merge_request_url) self._set_merge_request(merge_request_url)
@ -35,10 +37,35 @@ class GitLabProvider(GitProvider):
def _set_merge_request(self, merge_request_url: str): def _set_merge_request(self, merge_request_url: str):
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url) self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
self.mr = self._get_merge_request() self.mr = self._get_merge_request()
self.last_diff = self.mr.diffs.list()[-1]
def _get_pr_file_content(self, file_path: str, branch: str) -> str:
return self.gl.projects.get(self.id_project).files.get(file_path, branch).decode()
def get_diff_files(self) -> list[FilePatchInfo]: def get_diff_files(self) -> list[FilePatchInfo]:
diffs = self.mr.changes()['changes'] diffs = self.mr.changes()['changes']
diff_files = [FilePatchInfo("", "", diff['diff'], diff['new_path']) for diff in diffs] diff_files = []
for diff in diffs:
original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.target_branch)
new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.source_branch)
edit_type = EDIT_TYPE.MODIFIED
if diff['new_file']:
edit_type = EDIT_TYPE.ADDED
elif diff['deleted_file']:
edit_type = EDIT_TYPE.DELETED
elif diff['renamed_file']:
edit_type = EDIT_TYPE.RENAMED
try:
original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
except UnicodeDecodeError:
logging.warning(
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")
diff_files.append(
FilePatchInfo(original_file_content_str, new_file_content_str, diff['diff'], diff['new_path'],
edit_type=edit_type,
old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path']))
self.diff_files = diff_files
return diff_files return diff_files
def get_files(self): def get_files(self):
@ -53,6 +80,87 @@ class GitLabProvider(GitProvider):
if is_temporary: if is_temporary:
self.temp_comments.append(comment) self.temp_comments.append(comment)
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
relevant_line_in_file)
if not found:
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
else:
if edit_type == 'addition':
position = target_line_no - 1
else:
position = source_line_no - 1
d = self.last_diff
pos_obj = {'position_type': 'text',
'new_path': target_file.filename,
'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha}
if edit_type == 'deletion':
pos_obj['old_line'] = position
elif edit_type == 'addition':
pos_obj['new_line'] = position
else:
pos_obj['new_line'] = position
pos_obj['old_line'] = position
self.mr.discussions.create({'body': body,
'position': pos_obj})
def publish_code_suggestion(self, body: str,
relevant_file: str,
relevant_lines_start: int,
relevant_lines_end: int):
raise "not implemented yet for gitlab"
def search_line(self, relevant_file, relevant_line_in_file):
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
target_file = None
source_line_no = 0
target_line_no = 0
found = False
edit_type = self.get_edit_type(relevant_line_in_file)
for file in self.diff_files:
if file.filename == relevant_file:
target_file = file
patch = file.patch
patch_lines = patch.splitlines()
for i, line in enumerate(patch_lines):
if line.startswith('@@'):
match = RE_HUNK_HEADER.match(line)
if not match:
continue
start_old, size_old, start_new, size_new, _ = match.groups()
source_line_no = int(start_old)
target_line_no = int(start_new)
continue
if line.startswith('-'):
source_line_no += 1
elif line.startswith('+'):
target_line_no += 1
elif line.startswith(' '):
source_line_no += 1
target_line_no += 1
if relevant_line_in_file in line:
found = True
edit_type = self.get_edit_type(line)
break
elif relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in line:
# The model often adds a '+' to the beginning of the relevant_line_in_file even if originally
# it's a context line
found = True
edit_type = self.get_edit_type(line)
break
return edit_type, found, source_line_no, target_file, target_line_no
def get_edit_type(self, relevant_line_in_file):
edit_type = 'context'
if relevant_line_in_file[0] == '-':
edit_type = 'deletion'
elif relevant_line_in_file[0] == '+':
edit_type = 'addition'
return edit_type
def remove_initial_comment(self): def remove_initial_comment(self):
try: try:
for comment in self.temp_comments: for comment in self.temp_comments:
@ -94,3 +202,6 @@ class GitLabProvider(GitProvider):
def _get_merge_request(self): def _get_merge_request(self):
mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr) mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
return mr return mr
def get_user_id(self):
return None

View File

@ -8,11 +8,14 @@ verbosity_level=0 # 0,1,2
require_focused_review=true require_focused_review=true
require_tests_review=true require_tests_review=true
require_security_review=true require_security_review=true
num_code_suggestions=4 num_code_suggestions=3
inline_code_comments = true inline_code_comments = true
[pr_questions] [pr_questions]
[pr_code_suggestions]
num_code_suggestions=4
[github] [github]
# The type of deployment to create. Valid values are 'app' or 'user'. # The type of deployment to create. Valid values are 'app' or 'user'.
deployment_type = "user" deployment_type = "user"

View File

@ -0,0 +1,79 @@
[pr_code_suggestions_prompt]
system="""You are a language model called CodiumAI-PR-Code-Reviewer.
Your task is to provide provide meaningfull non-trivial code suggestions to improve the new code in a PR (the '+' lines).
- Try to give important suggestions like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningfull code improvements, like performance, vulnerability, modularity, and best practices.
- Suggestions should refer only to the 'new hunk' code, and focus on improving the new added code lines, with '+'.
- Provide the exact line number range (inclusive) for each issue.
- Assume there is additional code in the relevant file that is not included in the diff.
- Provide up to {{ num_code_suggestions }} code suggestions.
- Make sure not to provide suggestion repeating modifications already implemented in the new PR code (the '+' lines).
- Don't output line numbers in the 'improved code' snippets.
You must use the following JSON schema to format your answer:
```json
{
"Code suggestions": {
"type": "array",
"minItems": 1,
"maxItems": {{ num_code_suggestions }},
"uniqueItems": "true",
"items": {
"relevant file": {
"type": "string",
"description": "the relevant file full path"
},
"suggestion content": {
"type": "string",
"description": "a concrete suggestion for meaningfully improving the new PR code."
},
"existing code": {
"type": "string",
"description": "a code snippet showing authentic relevant code lines from a 'new hunk' section. It must be continuous, correctly formatted and indented, and without line numbers."
},
"relevant lines": {
"type": "string",
"description": "the relevant lines in the 'new hunk' sections, in the format of 'start_line-end_line'. For example: '10-15'. They should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above."
},
"improved code": {
"type": "string",
"description": "a new code snippet that can be used to replace the relevant lines in 'new hunk' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers."
}
}
}
}
```
Example input:
'
## src/file1.py
---new_hunk---
```
[new hunk code, annotated with line numbers]
```
---old_hunk---
```
[old hunk code]
```
...
'
Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
"""
user="""PR Info:
Title: '{{title}}'
Branch: '{{branch}}'
Description: '{{description}}'
{%- if language %}
Main language: {{language}}
{%- endif %}
The PR Diff:
```
{{diff}}
```
Response (should be a valid JSON, and nothing else):
```json
"""

View File

@ -0,0 +1,127 @@
import copy
import json
import logging
import textwrap
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider, GithubProvider
from pr_agent.git_providers.git_provider import get_main_pr_language
class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = AiHandler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
"description": self.git_provider.get_pr_description(),
"language": self.main_language,
"diff": "", # empty diff for initial calculation
'num_code_suggestions': settings.pr_code_suggestions.num_code_suggestions,
}
self.token_handler = TokenHandler(self.git_provider.pr,
self.vars,
settings.pr_code_suggestions_prompt.system,
settings.pr_code_suggestions_prompt.user)
async def suggest(self):
assert type(self.git_provider) == GithubProvider, "Only Github is supported for now"
logging.info('Generating code suggestions for PR...')
if settings.config.publish_review:
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
logging.info('Getting PR diff...')
# we are using extended hunk with line numbers for code suggestions
self.patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
add_line_numbers_to_hunks=True,
disable_extra_lines=True)
logging.info('Getting AI prediction...')
self.prediction = await self._get_prediction()
logging.info('Preparing PR review...')
data = self._prepare_pr_code_suggestions()
if settings.config.publish_review:
logging.info('Pushing PR review...')
self.git_provider.remove_initial_comment()
logging.info('Pushing inline code comments...')
self.push_inline_code_suggestions(data)
async def _get_prediction(self):
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(settings.pr_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(settings.pr_code_suggestions_prompt.user).render(variables)
if settings.config.verbosity_level >= 2:
logging.info(f"\nSystem prompt:\n{system_prompt}")
logging.info(f"\nUser prompt:\n{user_prompt}")
model = settings.config.model
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
system=system_prompt, user=user_prompt)
return response
def _prepare_pr_code_suggestions(self) -> str:
review = self.prediction.strip()
data = None
try:
data = json.loads(review)
except json.decoder.JSONDecodeError:
if settings.config.verbosity_level >= 2:
logging.info(f"Could not parse json response: {review}")
data = try_fix_json(review)
return data
def push_inline_code_suggestions(self, data):
for d in data['Code suggestions']:
if settings.config.verbosity_level >= 2:
logging.info(f"suggestion: {d}")
relevant_file = d['relevant file'].strip()
relevant_lines_str = d['relevant lines'].strip()
relevant_lines_start = int(relevant_lines_str.split('-')[0]) # absolute position
relevant_lines_end = int(relevant_lines_str.split('-')[-1])
content = d['suggestion content']
existing_code_snippet = d['existing code']
new_code_snippet = d['improved code']
if new_code_snippet:
try: # dedent code snippet
self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files else self.git_provider.get_diff_files()
original_initial_line = None
for file in self.diff_files:
if file.filename.strip() == relevant_file:
original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
break
if original_initial_line:
suggested_initial_line = new_code_snippet.splitlines()[0]
original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
delta_spaces = original_initial_spaces - suggested_initial_spaces
if delta_spaces > 0:
new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
except Exception as e:
if settings.config.verbosity_level >= 2:
logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
body = f"**Suggestion:** {content}\n```suggestion\n" + new_code_snippet + "\n```"
success = self.git_provider.publish_code_suggestion(body=body,
relevant_file=relevant_file,
relevant_lines_start=relevant_lines_start,
relevant_lines_end=relevant_lines_end)

View File

@ -35,7 +35,7 @@ class PRDescription:
self.prediction = None self.prediction = None
async def describe(self): async def describe(self):
logging.info('Answering a PR question...') logging.info('Generating a PR description...')
if settings.config.publish_review: if settings.config.publish_review:
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True) self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
logging.info('Getting PR diff...') logging.info('Getting PR diff...')

View File

@ -111,34 +111,15 @@ class PRReviewer:
return markdown_text return markdown_text
def _publish_inline_code_comments(self): def _publish_inline_code_comments(self):
if settings.config.git_provider != 'github': # inline comments are currently only supported for github
return
review = self.prediction.strip() review = self.prediction.strip()
try: try:
data = json.loads(review) data = json.loads(review)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
data = try_fix_json(review) data = try_fix_json(review)
pr = self.git_provider.pr
last_commit_id = list(pr.get_commits())[-1]
files = list(self.git_provider.get_diff_files())
for d in data['PR Feedback']['Code suggestions']: for d in data['PR Feedback']['Code suggestions']:
relevant_file = d['relevant file'].strip() relevant_file = d['relevant file'].strip()
relevant_line_in_file = d['relevant line in file'].strip() relevant_line_in_file = d['relevant line in file'].strip()
content = d['suggestion content'] content = d['suggestion content']
position = -1
for file in files: self.git_provider.publish_inline_comment(content, relevant_file, relevant_line_in_file)
if file.filename.strip() == relevant_file:
patch = file.patch
patch_lines = patch.splitlines()
for i, line in enumerate(patch_lines):
if relevant_line_in_file in line:
position = i
if position == -1:
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
else:
body = content
path = relevant_file.strip()
pr.create_review_comment(body=body, commit_id=last_commit_id, path=path, position=position)