Compare commits

..

13 Commits

Author SHA1 Message Date
b3e79ed677 cli.py - modify command line for a more coherent command invokation 2023-07-16 13:18:29 +03:00
5d2fe07bf7 Merge pull request #47 from Codium-ai/feature/github_action
Github custom action development
2023-07-16 12:54:40 +03:00
84bf95e9ab Merge pull request #50 from Codium-ai/tr/numbered_hunks
Adding numbered hunks and code suggestions feature
2023-07-16 12:27:29 +03:00
4f4989af8c full code suggestions
full code suggestions
2023-07-16 09:01:57 +03:00
23a249ccdb Merge pull request #48 from Codium-ai/hl/gitlab_fix
Inline suggestion refactor + supporting GitLab
2023-07-14 22:53:52 +03:00
4a6bf4c55a Merge branch 'main' into hl/gitlab_fix 2023-07-14 22:48:13 +03:00
3f75b14ba3 small addition 2023-07-14 22:45:07 +03:00
ae9cedd50d Merge pull request #46 from Codium-ai/tr/description_tool
Add PR Description Tool
2023-07-13 21:00:50 +03:00
ae63833043 Merge commit '055a8ea8590fbe9078cdc6af6398df2f053b9ce7' into hl/gitlab_fix 2023-07-13 20:44:26 +03:00
da6828ad87 Inline suggestion refactor + Gitlab WORKS 2023-07-13 20:43:49 +03:00
4e59693c76 diff_files 2023-07-13 18:26:35 +03:00
055a8ea859 Merge pull request #44 from zmeir/patch-1
Typo when setting `openai.api_version`
2023-07-13 17:52:33 +03:00
f57d58ee7d Typo when setting openai.api_version 2023-07-13 10:22:57 +03:00
13 changed files with 564 additions and 50 deletions

View File

@ -18,7 +18,7 @@ class AiHandler:
if settings.get("OPENAI.API_TYPE", None):
openai.api_type = settings.openai.api_type
if settings.get("OPENAI.API_VERSION", None):
openai.engine = settings.openai.api_version
openai.api_version = settings.openai.api_version
if settings.get("OPENAI.API_BASE", None):
openai.api_base = settings.openai.api_base
except AttributeError as e:

View File

@ -13,6 +13,9 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
if not patch_str or num_lines == 0:
return patch_str
if type(original_file_str) == bytes:
original_file_str = original_file_str.decode('utf-8')
original_lines = original_file_str.splitlines()
patch_lines = patch_str.splitlines()
extended_patch_lines = []
@ -105,3 +108,78 @@ def handle_patch_deletions(patch: str, original_file_content_str: str,
logging.info(f"Processing file: {file_name}, hunks were deleted")
patch = patch_new
return patch
def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
    """
    Re-render a unified-diff patch as per-hunk 'new' / 'old' sections.

    Every hunk is emitted as a '--new hunk--' section holding the context and
    '+' lines, each prefixed with its line number in the new file, followed
    (when the hunk has any old-side lines) by an '--old hunk--' section holding
    the context and '-' lines without numbers. Lines keep their original
    leading diff character (' ', '+' or '-'). For example:

        ## src/file.ts

        --new hunk--
        881  line1
        882 +line2
        --old hunk--
         line1
        -line0

    Args:
        patch: unified diff text of one file; hunks start with '@@ -a[,b] +c[,d] @@'.
        file: object exposing a ``filename`` attribute (used for the '## ...' header).

    Returns:
        The formatted text, stripped of leading/trailing whitespace. Hunks
        without any new-side line (pure deletions without context) are omitted.
    """
    # toDO: (maybe remove '-' and '+' from the beginning of the line)
    import re

    # The ',size' parts are optional in a hunk header ('@@ -3 +3 @@'), so only
    # the start numbers may be converted unconditionally.
    hunk_header_re = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

    def _render_hunk(new_lines, old_lines, first_new_line_no, old_title):
        # One '--new hunk--' section with numbered lines, then the optional old section.
        parts = ['\n--new hunk--\n']
        parts.extend(f"{first_new_line_no + i} {text}\n" for i, text in enumerate(new_lines))
        if old_lines:
            parts.append(old_title)
            parts.extend(f"{text}\n" for text in old_lines)
        return "".join(parts)

    rendered = [f"## {file.filename}\n"]
    new_content_lines = []
    old_content_lines = []
    seen_header = False
    start2 = -1

    for line in patch.splitlines():
        if 'no newline at end of file' in line.lower():
            continue

        if line.startswith('@@'):
            header = hunk_header_re.match(line)
            if header is None:
                # Unparseable header: ignore it rather than crash on `None.groups()`.
                continue
            if new_content_lines:  # found a new hunk, flush the previous one
                rendered.append(_render_hunk(new_content_lines, old_content_lines,
                                             start2, '--old hunk--\n'))
            # Always reset, so old lines of a deletion-only hunk cannot leak
            # into the hunk that follows it.
            new_content_lines = []
            old_content_lines = []
            start2 = int(header.group(3))
            seen_header = True
        elif line.startswith('+'):
            new_content_lines.append(line)
        elif line.startswith('-'):
            old_content_lines.append(line)
        else:
            # context line: belongs to both sides
            new_content_lines.append(line)
            old_content_lines.append(line)

    # finishing last hunk (its old section is preceded by a blank line, as before)
    if seen_header and new_content_lines:
        rendered.append(_render_hunk(new_content_lines, old_content_lines,
                                     start2, '\n--old hunk--\n'))

    return "".join(rendered).strip()

View File

@ -4,7 +4,8 @@ import difflib
import logging
from typing import Any, Tuple, Union
from pr_agent.algo.git_patch_processing import extend_patch, handle_patch_deletions
from pr_agent.algo.git_patch_processing import extend_patch, handle_patch_deletions, \
convert_to_hunks_with_lines_numbers
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import settings
@ -19,26 +20,33 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
PATCH_EXTRA_LINES = 3
def get_pr_diff(git_provider: Union[GithubProvider, Any], token_handler: TokenHandler) -> str:
def get_pr_diff(git_provider: Union[GithubProvider, Any], token_handler: TokenHandler,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str:
"""
Returns a string with the diff of the PR.
If needed, apply diff minimization techniques to reduce the number of tokens
"""
git_provider.pr.files = list(git_provider.get_diff_files())
if disable_extra_lines:
global PATCH_EXTRA_LINES
PATCH_EXTRA_LINES = 0
git_provider.pr.diff_files = list(git_provider.get_diff_files())
# get pr languages
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), git_provider.pr.files)
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), git_provider.pr.diff_files)
# generate a standard diff string, with patch extension
patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler)
patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler,
add_line_numbers_to_hunks)
# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < token_handler.limit:
return "\n".join(patches_extended)
# if we are over the limit, start pruning
patches_compressed, modified_file_names, deleted_file_names = pr_generate_compressed_diff(pr_languages,
token_handler)
patches_compressed, modified_file_names, deleted_file_names = \
pr_generate_compressed_diff(pr_languages, token_handler, add_line_numbers_to_hunks)
final_diff = "\n".join(patches_compressed)
if modified_file_names:
modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names)
@ -49,7 +57,8 @@ def get_pr_diff(git_provider: Union[GithubProvider, Any], token_handler: TokenHa
return final_diff
def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler) -> \
def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> \
Tuple[list, int]:
"""
Generate a standard diff string, with patch extension
@ -72,6 +81,9 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler) -
extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES)
full_extended_patch = f"## {file.filename}\n\n{extended_patch}\n"
if add_line_numbers_to_hunks:
full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
patch_tokens = token_handler.count_tokens(full_extended_patch)
file.tokens = patch_tokens
total_tokens += patch_tokens
@ -80,7 +92,8 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler) -
return patches_extended, total_tokens
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler) -> Tuple[list, list, list]:
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
# Apply Diff Minimization techniques to reduce the number of tokens:
# 0. Start from the largest diff patch to smaller ones
# 1. Don't use extend context lines around diff
@ -114,6 +127,10 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler) ->
deleted_files_list.append(file.filename)
total_tokens += token_handler.count_tokens(file.filename) + 1
continue
if convert_hunks_to_line_numbers:
patch = convert_to_hunks_with_lines_numbers(patch, file)
new_patch_tokens = token_handler.count_tokens(patch)
# Hard Stop, no more tokens
@ -135,7 +152,10 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler) ->
continue
if patch:
if not convert_hunks_to_line_numbers:
patch_final = f"## {file.filename}\n\n{patch}\n"
else:
patch_final = patch
patches.append(patch_final)
total_tokens += token_handler.count_tokens(patch_final)
if settings.config.verbosity_level >= 2:

View File

@ -3,30 +3,55 @@ import asyncio
import logging
import os
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer
def run():
parser = argparse.ArgumentParser(description='AI based pull request analyzer')
parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage="""\
Usage: cli.py --pr-url <URL on supported git hosting service> <command> [<args>].
Supported commands:
review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
ask / ask_question [question] - Ask a question about the PR.
describe / describe_pr - Modify the PR title and description based on the PR's contents.
improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
""")
parser.add_argument('--pr_url', type=str, help='The URL of the PR to review', required=True)
parser.add_argument('--question', type=str, help='Optional question to ask', required=False)
parser.add_argument('--pr_description', action='store_true', help='Optional question to ask', required=False)
parser.add_argument('command', type=str, help='The', choices=['review', 'review_pr',
'ask', 'ask_question',
'describe', 'describe_pr',
'improve', 'improve_code'], default='review')
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
args = parser.parse_args()
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
if args.question:
print(f"Question: {args.question} about PR {args.pr_url}")
reviewer = PRQuestions(args.pr_url, args.question)
command = args.command.lower()
if command in ['ask', 'ask_question']:
question = ' '.join(args.rest).strip()
if len(question) == 0:
print("Please specify a question")
parser.print_help()
return
print(f"Question: {question} about PR {args.pr_url}")
reviewer = PRQuestions(args.pr_url, question)
asyncio.run(reviewer.answer())
elif args.pr_description:
elif command in ['describe', 'describe_pr']:
print(f"PR description: {args.pr_url}")
reviewer = PRDescription(args.pr_url)
asyncio.run(reviewer.describe())
else:
elif command in ['improve', 'improve_code']:
print(f"PR code suggestions: {args.pr_url}")
reviewer = PRCodeSuggestions(args.pr_url)
asyncio.run(reviewer.suggest())
elif command in ['review', 'review_pr']:
print(f"Reviewing PR: {args.pr_url}")
reviewer = PRReviewer(args.pr_url, cli_mode=True)
asyncio.run(reviewer.review())
else:
print(f"Unknown command: {command}")
parser.print_help()
if __name__ == '__main__':

View File

@ -12,6 +12,7 @@ settings = Dynaconf(
"settings/pr_reviewer_prompts.toml",
"settings/pr_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml",
"settings_prod/.secrets.toml"
]]
)

View File

@ -1,6 +1,13 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
from enum import Enum
class EDIT_TYPE(Enum):
    """Kind of change applied to a file within a pull/merge request."""

    ADDED = 1     # file is new in the PR
    DELETED = 2   # file was removed by the PR
    MODIFIED = 3  # file content changed
    RENAMED = 4   # file path changed
@dataclass
class FilePatchInfo:
@ -9,6 +16,8 @@ class FilePatchInfo:
patch: str
filename: str
tokens: int = -1
edit_type: EDIT_TYPE = EDIT_TYPE.MODIFIED
old_filename: str = None
class GitProvider(ABC):
@ -24,6 +33,15 @@ class GitProvider(ABC):
def publish_comment(self, pr_comment: str, is_temporary: bool = False):
pass
@abstractmethod
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
pass
@abstractmethod
def publish_code_suggestion(self, body: str, relevant_file: str,
relevant_lines_start: int, relevant_lines_end: int):
pass
@abstractmethod
def remove_initial_comment(self):
pass

View File

@ -7,10 +7,10 @@ from github import AppAuthentication, Github
from pr_agent.config_loader import settings
from .git_provider import FilePatchInfo
from .git_provider import FilePatchInfo, GitProvider
class GithubProvider:
class GithubProvider(GitProvider):
def __init__(self, pr_url: Optional[str] = None):
self.installation_id = settings.get("GITHUB.INSTALLATION_ID")
self.github_client = self._get_github_client()
@ -18,16 +18,16 @@ class GithubProvider:
self.pr_num = None
self.pr = None
self.github_user_id = None
self.diff_files = None
if pr_url:
self.set_pr(pr_url)
self.last_commit_id = list(self.pr.get_commits())[-1]
def set_pr(self, pr_url: str):
self.repo, self.pr_num = self._parse_pr_url(pr_url)
self.pr = self._get_pr()
def get_files(self):
if hasattr(self.pr, 'files'):
return self.pr.files
return self.pr.get_files()
def get_diff_files(self) -> list[FilePatchInfo]:
@ -37,6 +37,7 @@ class GithubProvider:
original_file_content_str = self._get_pr_file_content(file, self.pr.base.sha)
new_file_content_str = self._get_pr_file_content(file, self.pr.head.sha)
diff_files.append(FilePatchInfo(original_file_content_str, new_file_content_str, file.patch, file.filename))
self.diff_files = diff_files
return diff_files
def publish_description(self, pr_title: str, pr_body: str):
@ -52,6 +53,76 @@ class GithubProvider:
self.pr.comments_list = []
self.pr.comments_list.append(response)
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
    """
    Post a review comment on the patch line that contains ``relevant_line_in_file``.

    The position passed to the API is the index of the first matching line
    within the file's patch text. If no line matches, nothing is published
    (a message is logged when verbosity_level >= 2).

    Args:
        body: comment text.
        relevant_file: filename as it appears in the diff files.
        relevant_line_in_file: text of the line to attach the comment to.
    """
    self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
    position = -1

    # An empty needle is contained in every string and would "match" the first
    # patch line; treat it as not found instead.
    if relevant_line_in_file:
        # The model often adds a '+' to the beginning of the relevant_line_in_file
        # even if originally it's a context line, so also try without it.
        without_plus = relevant_line_in_file[1:] if relevant_line_in_file.startswith('+') else None
        for file in self.diff_files:
            if file.filename.strip() != relevant_file:
                continue
            if not file.patch:
                # Some entries carry no textual patch; there is nothing to search.
                continue
            for i, line in enumerate(file.patch.splitlines()):
                if relevant_line_in_file in line or (without_plus is not None and without_plus in line):
                    position = i
                    break

    if position == -1:
        if settings.config.verbosity_level >= 2:
            logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
    else:
        path = relevant_file.strip()
        self.pr.create_review_comment(body=body, commit_id=self.last_commit_id, path=path, position=position)
def publish_code_suggestion(self, body: str,
                            relevant_file: str,
                            relevant_lines_start: int,
                            relevant_lines_end: int):
    """
    Post ``body`` as a review comment on a line range of ``relevant_file``.

    Args:
        body: comment text (expected to contain a ```suggestion block).
        relevant_file: path of the file within the PR.
        relevant_lines_start: first line of the range.
        relevant_lines_end: last line of the range (inclusive); equal to the
            start for a single-line comment.

    Returns:
        True when the comment was posted, False on invalid input or API failure.
        Failures are logged only when verbosity_level >= 2.
    """
    if not relevant_lines_start or relevant_lines_start == -1:
        if settings.config.verbosity_level >= 2:
            # logging.error, not logging.exception: there is no active exception here.
            logging.error(f"Failed to publish code suggestion, relevant_lines_start is {relevant_lines_start}")
        return False

    if relevant_lines_end < relevant_lines_start:
        if settings.config.verbosity_level >= 2:
            logging.error(f"Failed to publish code suggestion, "
                          f"relevant_lines_end is {relevant_lines_end} and "
                          f"relevant_lines_start is {relevant_lines_start}")
        return False

    try:
        import github.PullRequestComment

        # NOTE(review): `_identity` is a private attribute of the commit object — confirm
        # whether the public `sha` attribute can be used instead.
        post_parameters = {
            "body": body,
            "commit_id": self.last_commit_id._identity,
            "path": relevant_file,
        }
        if relevant_lines_end > relevant_lines_start:
            post_parameters.update({
                "line": relevant_lines_end,
                "start_line": relevant_lines_start,
                "start_side": "RIGHT",
            })
        else:  # API is different for single line comments
            post_parameters.update({
                "line": relevant_lines_start,
                "side": "RIGHT",
            })

        headers, data = self.pr._requester.requestJsonAndCheck(
            "POST", f"{self.pr.url}/comments", input=post_parameters
        )
        github.PullRequestComment.PullRequestComment(
            self.pr._requester, headers, data, completed=True
        )
        return True
    except Exception as e:
        if settings.config.verbosity_level >= 2:
            logging.error(f"Failed to publish code suggestion, error: {e}")
        return False
def remove_initial_comment(self):
try:
for comment in self.pr.comments_list:
@ -152,9 +223,9 @@ class GithubProvider:
def _get_pr(self):
return self._get_repo().get_pull(self.pr_num)
def _get_pr_file_content(self, file: FilePatchInfo, sha: str):
def _get_pr_file_content(self, file: FilePatchInfo, sha: str) -> str:
try:
file_content_str = self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode()
file_content_str = str(self._get_repo().get_contents(file.filename, ref=sha).decoded_content.decode())
except Exception:
file_content_str = ""
return file_content_str

View File

@ -1,4 +1,5 @@
import logging
import re
from typing import Optional, Tuple
from urllib.parse import urlparse
@ -6,7 +7,7 @@ import gitlab
from pr_agent.config_loader import settings
from .git_provider import FilePatchInfo, GitProvider
from .git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
class GitLabProvider(GitProvider):
@ -24,6 +25,7 @@ class GitLabProvider(GitProvider):
self.id_project = None
self.id_mr = None
self.mr = None
self.diff_files = None
self.temp_comments = []
self._set_merge_request(merge_request_url)
@ -35,10 +37,35 @@ class GitLabProvider(GitProvider):
def _set_merge_request(self, merge_request_url: str):
self.id_project, self.id_mr = self._parse_merge_request_url(merge_request_url)
self.mr = self._get_merge_request()
self.last_diff = self.mr.diffs.list()[-1]
def _get_pr_file_content(self, file_path: str, branch: str) -> bytes:
    """
    Fetch the raw content of ``file_path`` at ``branch`` from the project.

    Returns the decoded file payload (the caller decodes it as UTF-8 bytes).
    A file that does not exist on that branch — e.g. the old path of a newly
    added file, or the new path of a deleted one — yields empty content
    instead of aborting the whole diff collection.
    """
    project = self.gl.projects.get(self.id_project)
    try:
        return project.files.get(file_path, branch).decode()
    except gitlab.exceptions.GitlabGetError as e:
        if e.response_code != 404:
            raise  # only "file not found" is an expected outcome
        return b""
def get_diff_files(self) -> list[FilePatchInfo]:
diffs = self.mr.changes()['changes']
diff_files = [FilePatchInfo("", "", diff['diff'], diff['new_path']) for diff in diffs]
diff_files = []
for diff in diffs:
original_file_content_str = self._get_pr_file_content(diff['old_path'], self.mr.target_branch)
new_file_content_str = self._get_pr_file_content(diff['new_path'], self.mr.source_branch)
edit_type = EDIT_TYPE.MODIFIED
if diff['new_file']:
edit_type = EDIT_TYPE.ADDED
elif diff['deleted_file']:
edit_type = EDIT_TYPE.DELETED
elif diff['renamed_file']:
edit_type = EDIT_TYPE.RENAMED
try:
original_file_content_str = bytes.decode(original_file_content_str, 'utf-8')
new_file_content_str = bytes.decode(new_file_content_str, 'utf-8')
except UnicodeDecodeError:
logging.warning(
f"Cannot decode file {diff['old_path']} or {diff['new_path']} in merge request {self.id_mr}")
diff_files.append(
FilePatchInfo(original_file_content_str, new_file_content_str, diff['diff'], diff['new_path'],
edit_type=edit_type,
old_filename=None if diff['old_path'] == diff['new_path'] else diff['old_path']))
self.diff_files = diff_files
return diff_files
def get_files(self):
@ -53,6 +80,87 @@ class GitLabProvider(GitProvider):
if is_temporary:
self.temp_comments.append(comment)
def publish_inline_comment(self, body: str, relevant_file: str, relevant_line_in_file: str):
    """
    Open a discussion on the diff line that contains ``relevant_line_in_file``.

    The line is located with ``search_line``; if it is not found, only an info
    message is logged. The discussion position carries ``old_line`` for a
    deleted line, ``new_line`` for an added line, and both for a context line.

    Args:
        body: comment text.
        relevant_file: filename as it appears in the diff files.
        relevant_line_in_file: text of the line to attach the comment to.
    """
    self.diff_files = self.diff_files if self.diff_files else self.get_diff_files()
    edit_type, found, source_line_no, target_file, target_line_no = self.search_line(relevant_file,
                                                                                     relevant_line_in_file)
    if not found:
        logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
        return

    # search_line advances its counters past the matched line before testing it,
    # so the matched line is `counter - 1` on the respective side.
    old_line = source_line_no - 1
    new_line = target_line_no - 1

    d = self.last_diff
    pos_obj = {'position_type': 'text',
               'new_path': target_file.filename,
               'old_path': target_file.old_filename if target_file.old_filename else target_file.filename,
               'base_sha': d.base_commit_sha, 'start_sha': d.start_commit_sha, 'head_sha': d.head_commit_sha}
    if edit_type == 'deletion':
        pos_obj['old_line'] = old_line
    elif edit_type == 'addition':
        pos_obj['new_line'] = new_line
    else:
        # A context line exists on both sides, and the two numbers differ as soon
        # as earlier hunks added or removed lines — send each side's own number.
        pos_obj['new_line'] = new_line
        pos_obj['old_line'] = old_line
    self.mr.discussions.create({'body': body,
                                'position': pos_obj})
def publish_code_suggestion(self, body: str,
                            relevant_file: str,
                            relevant_lines_start: int,
                            relevant_lines_end: int):
    """
    Publishing committable code suggestions is not supported for GitLab yet.

    Raises:
        NotImplementedError: always.
    """
    # Raising a plain string is itself a TypeError in Python 3 ("exceptions must
    # derive from BaseException"), which would hide the intended message.
    raise NotImplementedError("not implemented yet for gitlab")
def search_line(self, relevant_file, relevant_line_in_file):
    """
    Locate ``relevant_line_in_file`` inside the patch of ``relevant_file``.

    Walks the patch while tracking old-side and new-side line counters, which
    are reset at each parsable hunk header and advanced *before* a line is
    tested for a match.

    Returns:
        Tuple ``(edit_type, found, source_line_no, target_file, target_line_no)``.
        ``edit_type`` is derived from the matched patch line when found,
        otherwise from ``relevant_line_in_file`` itself; ``target_file`` is the
        diff-file entry whose filename matched (None when no file matched).
    """
    hunk_header = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

    target_file = None
    source_line_no, target_line_no = 0, 0
    found = False
    edit_type = self.get_edit_type(relevant_line_in_file)

    for candidate in self.diff_files:
        if candidate.filename != relevant_file:
            continue
        target_file = candidate

        for patch_line in candidate.patch.splitlines():
            if patch_line.startswith('@@'):
                header = hunk_header.match(patch_line)
                if header:
                    source_line_no = int(header.group(1))
                    target_line_no = int(header.group(3))
                continue

            marker = patch_line[:1]
            if marker == '-':
                source_line_no += 1
            elif marker == '+':
                target_line_no += 1
            elif marker == ' ':
                source_line_no += 1
                target_line_no += 1

            # The model often adds a '+' to the beginning of the relevant_line_in_file
            # even if originally it's a context line, hence the second alternative.
            if relevant_line_in_file in patch_line or (
                    relevant_line_in_file[0] == '+' and relevant_line_in_file[1:] in patch_line):
                found = True
                edit_type = self.get_edit_type(patch_line)
                break

    return edit_type, found, source_line_no, target_file, target_line_no
def get_edit_type(self, relevant_line_in_file):
    """
    Classify a diff line by its leading character.

    Returns 'deletion' for a line starting with '-', 'addition' for a line
    starting with '+', and 'context' for anything else (including an empty
    string, which previously raised IndexError).
    """
    if relevant_line_in_file.startswith('-'):
        return 'deletion'
    if relevant_line_in_file.startswith('+'):
        return 'addition'
    return 'context'
def remove_initial_comment(self):
try:
for comment in self.temp_comments:
@ -94,3 +202,6 @@ class GitLabProvider(GitProvider):
def _get_merge_request(self):
mr = self.gl.projects.get(self.id_project).mergerequests.get(self.id_mr)
return mr
def get_user_id(self):
    """Return the user id for this provider; this implementation always returns None."""
    user_id = None
    return user_id

View File

@ -8,11 +8,14 @@ verbosity_level=0 # 0,1,2
require_focused_review=true
require_tests_review=true
require_security_review=true
num_code_suggestions=4
num_code_suggestions=3
inline_code_comments = true
[pr_questions]
[pr_code_suggestions]
num_code_suggestions=4
[github]
# The type of deployment to create. Valid values are 'app' or 'user'.
deployment_type = "user"

View File

@ -0,0 +1,79 @@
[pr_code_suggestions_prompt]
system="""You are a language model called CodiumAI-PR-Code-Reviewer.
Your task is to provide meaningful, non-trivial code suggestions to improve the new code in a PR (the '+' lines).
- Try to give important suggestions like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningful code improvements, like performance, vulnerability, modularity, and best practices.
- Suggestions should refer only to the 'new hunk' code, and focus on improving the new added code lines, with '+'.
- Provide the exact line number range (inclusive) for each issue.
- Assume there is additional code in the relevant file that is not included in the diff.
- Provide up to {{ num_code_suggestions }} code suggestions.
- Make sure not to provide suggestion repeating modifications already implemented in the new PR code (the '+' lines).
- Don't output line numbers in the 'improved code' snippets.
You must use the following JSON schema to format your answer:
```json
{
"Code suggestions": {
"type": "array",
"minItems": 1,
"maxItems": {{ num_code_suggestions }},
"uniqueItems": "true",
"items": {
"relevant file": {
"type": "string",
"description": "the relevant file full path"
},
"suggestion content": {
"type": "string",
"description": "a concrete suggestion for meaningfully improving the new PR code."
},
"existing code": {
"type": "string",
"description": "a code snippet showing authentic relevant code lines from a 'new hunk' section. It must be continuous, correctly formatted and indented, and without line numbers."
},
"relevant lines": {
"type": "string",
"description": "the relevant lines in the 'new hunk' sections, in the format of 'start_line-end_line'. For example: '10-15'. They should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above."
},
"improved code": {
"type": "string",
"description": "a new code snippet that can be used to replace the relevant lines in 'new hunk' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers."
}
}
}
}
```
Example input:
'
## src/file1.py
--new hunk--
```
[new hunk code, annotated with line numbers]
```
--old hunk--
```
[old hunk code]
```
...
'
Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
"""
user="""PR Info:
Title: '{{title}}'
Branch: '{{branch}}'
Description: '{{description}}'
{%- if language %}
Main language: {{language}}
{%- endif %}
The PR Diff:
```
{{diff}}
```
Response (should be a valid JSON, and nothing else):
```json
"""

View File

@ -0,0 +1,127 @@
import copy
import json
import logging
import textwrap
from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.pr_processing import get_pr_diff
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, try_fix_json
from pr_agent.config_loader import settings
from pr_agent.git_providers import get_git_provider, GithubProvider
from pr_agent.git_providers.git_provider import get_main_pr_language
class PRCodeSuggestions:
    """
    Ask the model for code suggestions on a PR and publish them as committable
    inline suggestions.

    Flow (see ``suggest``): build the PR diff with numbered hunks, render the
    prompts, query the model, parse its JSON answer and publish one code
    suggestion per returned item.
    """

    def __init__(self, pr_url: str, cli_mode=False):
        """
        Args:
            pr_url: URL of the pull request to analyze.
            cli_mode: stored on the instance; not read by this class.
        """
        self.git_provider = get_git_provider()(pr_url)
        self.main_language = get_main_pr_language(
            self.git_provider.get_languages(), self.git_provider.get_files()
        )
        self.ai_handler = AiHandler()
        self.patches_diff = None
        self.prediction = None
        self.cli_mode = cli_mode
        # Variables rendered into the prompt templates.
        self.vars = {
            "title": self.git_provider.pr.title,
            "branch": self.git_provider.get_pr_branch(),
            "description": self.git_provider.get_pr_description(),
            "language": self.main_language,
            "diff": "",  # empty diff for initial calculation
            'num_code_suggestions': settings.pr_code_suggestions.num_code_suggestions,
        }
        self.token_handler = TokenHandler(self.git_provider.pr,
                                          self.vars,
                                          settings.pr_code_suggestions_prompt.system,
                                          settings.pr_code_suggestions_prompt.user)

    async def suggest(self):
        """Run the whole tool; publishes results only when config.publish_review is set."""
        assert isinstance(self.git_provider, GithubProvider), "Only Github is supported for now"

        logging.info('Generating code suggestions for PR...')
        if settings.config.publish_review:
            self.git_provider.publish_comment("Preparing review...", is_temporary=True)
        logging.info('Getting PR diff...')

        # we are using extended hunk with line numbers for code suggestions
        self.patches_diff = get_pr_diff(self.git_provider,
                                        self.token_handler,
                                        add_line_numbers_to_hunks=True,
                                        disable_extra_lines=True)

        logging.info('Getting AI prediction...')
        self.prediction = await self._get_prediction()
        logging.info('Preparing PR review...')
        data = self._prepare_pr_code_suggestions()
        if settings.config.publish_review:
            logging.info('Pushing PR review...')
            self.git_provider.remove_initial_comment()
            logging.info('Pushing inline code comments...')
            self.push_inline_code_suggestions(data)

    async def _get_prediction(self):
        """Render the system/user prompts with the current diff and return the model response."""
        variables = copy.deepcopy(self.vars)
        variables["diff"] = self.patches_diff  # update diff
        environment = Environment(undefined=StrictUndefined)
        system_prompt = environment.from_string(settings.pr_code_suggestions_prompt.system).render(variables)
        user_prompt = environment.from_string(settings.pr_code_suggestions_prompt.user).render(variables)
        if settings.config.verbosity_level >= 2:
            logging.info(f"\nSystem prompt:\n{system_prompt}")
            logging.info(f"\nUser prompt:\n{user_prompt}")
        model = settings.config.model
        response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                        system=system_prompt, user=user_prompt)

        return response

    def _prepare_pr_code_suggestions(self):
        """
        Parse the model response as JSON.

        Returns the parsed object; when the response is not valid JSON, returns
        whatever ``try_fix_json`` produces for it.
        """
        review = self.prediction.strip()
        try:
            data = json.loads(review)
        except json.decoder.JSONDecodeError:
            if settings.config.verbosity_level >= 2:
                logging.info(f"Could not parse json response: {review}")
            data = try_fix_json(review)
        return data

    def push_inline_code_suggestions(self, data):
        """
        Publish every entry of ``data['Code suggestions']`` as a code suggestion.

        Entries that are malformed (missing keys, unparsable 'relevant lines')
        or that carry no improved code are skipped, so one bad item does not
        prevent the remaining suggestions from being published.
        """
        verbose = settings.config.verbosity_level >= 2
        suggestions = data.get('Code suggestions') if isinstance(data, dict) else None
        if not suggestions:
            if verbose:
                logging.info("No code suggestions found in the prediction")
            return

        for d in suggestions:
            if verbose:
                logging.info(f"suggestion: {d}")
            try:
                relevant_file = d['relevant file'].strip()
                line_range = d['relevant lines'].strip().split('-')
                relevant_lines_start = int(line_range[0])  # absolute position
                relevant_lines_end = int(line_range[-1])
                content = d['suggestion content']
                new_code_snippet = d['improved code']
            except (KeyError, ValueError, AttributeError, TypeError) as e:
                if verbose:
                    logging.info(f"Skipping malformed code suggestion {d}, error: {e}")
                continue

            if not new_code_snippet:
                # An empty ```suggestion block would propose deleting the lines.
                if verbose:
                    logging.info(f"Skipping code suggestion without improved code for file {relevant_file}")
                continue

            new_code_snippet = self._match_original_indentation(relevant_file,
                                                                relevant_lines_start,
                                                                new_code_snippet)

            body = f"**Suggestion:** {content}\n```suggestion\n" + new_code_snippet + "\n```"
            self.git_provider.publish_code_suggestion(body=body,
                                                      relevant_file=relevant_file,
                                                      relevant_lines_start=relevant_lines_start,
                                                      relevant_lines_end=relevant_lines_end)

    def _match_original_indentation(self, relevant_file, relevant_lines_start, new_code_snippet):
        """
        Best-effort: indent the snippet so its first line is at least as indented
        as line ``relevant_lines_start`` of the file's head version.

        Returns the (possibly re-indented) snippet; on any failure the snippet is
        returned unchanged.
        """
        try:  # dedent code snippet
            self.diff_files = self.git_provider.diff_files if self.git_provider.diff_files \
                else self.git_provider.get_diff_files()
            original_initial_line = None
            for file in self.diff_files:
                if file.filename.strip() == relevant_file:
                    original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
                    break
            if original_initial_line:
                suggested_initial_line = new_code_snippet.splitlines()[0]
                original_initial_spaces = len(original_initial_line) - len(original_initial_line.lstrip())
                suggested_initial_spaces = len(suggested_initial_line) - len(suggested_initial_line.lstrip())
                delta_spaces = original_initial_spaces - suggested_initial_spaces
                if delta_spaces > 0:
                    new_code_snippet = textwrap.indent(new_code_snippet, delta_spaces * " ").rstrip('\n')
        except Exception as e:
            if settings.config.verbosity_level >= 2:
                logging.info(f"Could not dedent code snippet for file {relevant_file}, error: {e}")
        return new_code_snippet

View File

@ -35,7 +35,7 @@ class PRDescription:
self.prediction = None
async def describe(self):
logging.info('Answering a PR question...')
logging.info('Generating a PR description...')
if settings.config.publish_review:
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
logging.info('Getting PR diff...')

View File

@ -111,34 +111,15 @@ class PRReviewer:
return markdown_text
def _publish_inline_code_comments(self):
if settings.config.git_provider != 'github': # inline comments are currently only supported for github
return
review = self.prediction.strip()
try:
data = json.loads(review)
except json.decoder.JSONDecodeError:
data = try_fix_json(review)
pr = self.git_provider.pr
last_commit_id = list(pr.get_commits())[-1]
files = list(self.git_provider.get_diff_files())
for d in data['PR Feedback']['Code suggestions']:
relevant_file = d['relevant file'].strip()
relevant_line_in_file = d['relevant line in file'].strip()
content = d['suggestion content']
position = -1
for file in files:
if file.filename.strip() == relevant_file:
patch = file.patch
patch_lines = patch.splitlines()
for i, line in enumerate(patch_lines):
if relevant_line_in_file in line:
position = i
if position == -1:
logging.info(f"Could not find position for {relevant_file} {relevant_line_in_file}")
else:
body = content
path = relevant_file.strip()
pr.create_review_comment(body=body, commit_id=last_commit_id, path=path, position=position)
self.git_provider.publish_inline_comment(content, relevant_file, relevant_line_in_file)