Merge branch 'main' into zmeir-automatic_github_app_options

zmeir committed 2023-08-22 21:09:23 +03:00
19 changed files with 392 additions and 94 deletions

View File

@@ -54,6 +54,10 @@ on:
 jobs:
   pr_agent_job:
     runs-on: ubuntu-latest
+    permissions:
+      issues: write
+      pull-requests: write
+      contents: write
     name: Run pr agent on every pull request, respond to user comments
     steps:
       - name: PR Agent action step
@@ -72,6 +76,10 @@ on:
 jobs:
   pr_agent_job:
     runs-on: ubuntu-latest
+    permissions:
+      issues: write
+      pull-requests: write
+      contents: write
     name: Run pr agent on every pull request, respond to user comments
     steps:
       - name: PR Agent action step

View File

@@ -55,7 +55,7 @@ class AiHandler:
     @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
            tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
-    async def chat_completion(self, model: str, temperature: float, system: str, user: str):
+    async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
         """
         Performs a chat completion using the OpenAI ChatCompletion API.
         Retries in case of API errors or timeouts.
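With the reordered signature, `temperature` becomes an optional keyword with a 0.2 default. A minimal sketch of a call site after this change; the prompts are placeholders, and the import path assumes the handler lives in `pr_agent.algo.ai_handler`:

```python
from pr_agent.algo.ai_handler import AiHandler


async def demo():
    handler = AiHandler()  # assumes an OpenAI key is already configured in the settings
    # 'temperature' now defaults to 0.2, so existing call sites can simply drop it;
    # callers that want a different value pass it by keyword.
    response, finish_reason = await handler.chat_completion(
        model="gpt-4",
        system="You are a helpful reviewer.",  # placeholder prompt
        user="Summarize this diff ...",        # placeholder prompt
        temperature=0.2,
    )
    return response

# asyncio.run(demo())  # requires a configured API key
```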

View File

@@ -1,5 +1,4 @@
 from __future__ import annotations
-
 import logging
 import re
@@ -157,7 +156,7 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
     example output:
     ## src/file.ts
-    --new hunk--
+    __new hunk__
     881  line1
     882  line2
     883  line3
@@ -166,7 +165,7 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
     889  line6
     890  line7
     ...
-    --old hunk--
+    __old hunk__
     line1
     line2
     - line3
@@ -176,8 +175,7 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
     ...
     """
-    patch_with_lines_str = f"## {file.filename}\n"
-    import re
+    patch_with_lines_str = f"\n\n## {file.filename}\n"
     patch_lines = patch.splitlines()
     RE_HUNK_HEADER = re.compile(
         r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
@@ -185,23 +183,30 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
     old_content_lines = []
     match = None
     start1, size1, start2, size2 = -1, -1, -1, -1
+    prev_header_line = []
+    header_line = []
     for line in patch_lines:
         if 'no newline at end of file' in line.lower():
             continue

         if line.startswith('@@'):
+            header_line = line
             match = RE_HUNK_HEADER.match(line)
             if match and new_content_lines:  # found a new hunk, split the previous lines
                 if new_content_lines:
-                    patch_with_lines_str += '\n--new hunk--\n'
+                    if prev_header_line:
+                        patch_with_lines_str += f'\n{prev_header_line}\n'
+                    patch_with_lines_str += '__new hunk__\n'
                     for i, line_new in enumerate(new_content_lines):
                         patch_with_lines_str += f"{start2 + i} {line_new}\n"
                 if old_content_lines:
-                    patch_with_lines_str += '--old hunk--\n'
+                    patch_with_lines_str += '__old hunk__\n'
                     for line_old in old_content_lines:
                         patch_with_lines_str += f"{line_old}\n"
                 new_content_lines = []
                 old_content_lines = []
+            if match:
+                prev_header_line = header_line
             try:
                 start1, size1, start2, size2 = map(int, match.groups()[:4])
             except:  # '@@ -0,0 +1 @@' case
@@ -219,12 +224,13 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
     # finishing last hunk
     if match and new_content_lines:
         if new_content_lines:
-            patch_with_lines_str += '\n--new hunk--\n'
+            patch_with_lines_str += f'\n{header_line}\n'
+            patch_with_lines_str += '\n__new hunk__\n'
             for i, line_new in enumerate(new_content_lines):
                 patch_with_lines_str += f"{start2 + i} {line_new}\n"
         if old_content_lines:
-            patch_with_lines_str += '\n--old hunk--\n'
+            patch_with_lines_str += '\n__old hunk__\n'
             for line_old in old_content_lines:
                 patch_with_lines_str += f"{line_old}\n"
-    return patch_with_lines_str.strip()
+    return patch_with_lines_str.rstrip()
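For reference, a toy invocation of the relabelled function, assuming it is importable from `pr_agent.algo.git_patch_processing`; the stub file object only carries the `filename` attribute the function reads:

```python
from types import SimpleNamespace

from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers

patch = """@@ -10,3 +10,4 @@ def greet():
 print("a")
+print("b")
 print("c")
 print("d")"""

file = SimpleNamespace(filename="src/demo.py")  # stub; only .filename is used here
print(convert_to_hunks_with_lines_numbers(patch, file))
# The output contains a '## src/demo.py' header, the original hunk header, a
# '__new hunk__' section with lines numbered from 10, and an '__old hunk__'
# section without numbers, as in the docstring example above.
```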

View File

@@ -24,7 +24,7 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
 PATCH_EXTRA_LINES = 3


 def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
-                add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
+                add_line_numbers_to_hunks: bool = True, disable_extra_lines: bool = True) -> str:
     """
     Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
@@ -57,7 +57,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
     pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)

     # generate a standard diff string, with patch extension
-    patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler,
+    patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(pr_languages, token_handler,
                                                                add_line_numbers_to_hunks)

     # if we are under the limit, return the full diff
@@ -78,9 +78,9 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
     return final_diff


-def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
-                              add_line_numbers_to_hunks: bool) -> \
-        Tuple[list, int]:
+def pr_generate_extended_diff(pr_languages: list,
+                              token_handler: TokenHandler,
+                              add_line_numbers_to_hunks: bool) -> Tuple[list, int, list]:
     """
     Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
     minimization techniques if needed.
@@ -90,13 +90,10 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
       files.
     - token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
     - add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
-
-    Returns:
-    - patches_extended: A list of extended patches for each file in the pull request.
-    - total_tokens: The total number of tokens used in the extended patches.
     """
     total_tokens = token_handler.prompt_tokens  # initial tokens
     patches_extended = []
+    patches_extended_tokens = []
     for lang in pr_languages:
         for file in lang['files']:
             original_file_content_str = file.base_file
@@ -106,7 +103,7 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
             # extend each patch with extra lines of context
             extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES)
-            full_extended_patch = f"## {file.filename}\n\n{extended_patch}\n"
+            full_extended_patch = f"\n\n## {file.filename}\n\n{extended_patch}\n"

             if add_line_numbers_to_hunks:
                 full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
@@ -114,9 +111,10 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
             patch_tokens = token_handler.count_tokens(full_extended_patch)
             file.tokens = patch_tokens
             total_tokens += patch_tokens
+            patches_extended_tokens.append(patch_tokens)
             patches_extended.append(full_extended_patch)

-    return patches_extended, total_tokens
+    return patches_extended, total_tokens, patches_extended_tokens


 def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
@@ -324,7 +322,9 @@ def clip_tokens(text: str, max_tokens: int) -> str:
     Returns:
         str: The clipped string.
     """
-    # We'll estimate the number of tokens by hueristically assuming 2.5 tokens per word
+    if not text:
+        return text
+
     try:
         encoder = get_token_encoder()
         num_input_tokens = len(encoder.encode(text))
@@ -337,4 +337,84 @@ def clip_tokens(text: str, max_tokens: int) -> str:
         return clipped_text
     except Exception as e:
         logging.warning(f"Failed to clip tokens: {e}")
         return text
+
+
+def get_pr_multi_diffs(git_provider: GitProvider,
+                       token_handler: TokenHandler,
+                       model: str,
+                       max_calls: int = 5) -> List[str]:
+    """
+    Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
+    The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
+
+    Args:
+        git_provider (GitProvider): An object that provides access to Git provider APIs.
+        token_handler (TokenHandler): An object that handles tokens in the context of a pull request.
+        model (str): The name of the model.
+        max_calls (int, optional): The maximum number of calls to retrieve diff files. Defaults to 5.
+
+    Returns:
+        List[str]: A list of final diff strings, split into multiple groups based on the maximum number of tokens allowed for the given model.
+
+    Raises:
+        RateLimitExceededException: If the rate limit for the Git provider API is exceeded.
+    """
+    try:
+        diff_files = git_provider.get_diff_files()
+    except RateLimitExceededException as e:
+        logging.error(f"Rate limit exceeded for git provider API. original message {e}")
+        raise
+
+    # Sort files by main language
+    pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
+
+    # Sort files within each language group by tokens in descending order
+    sorted_files = []
+    for lang in pr_languages:
+        sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
+
+    patches = []
+    final_diff_list = []
+    total_tokens = token_handler.prompt_tokens
+    call_number = 1
+    for file in sorted_files:
+        if call_number > max_calls:
+            if get_settings().config.verbosity_level >= 2:
+                logging.info(f"Reached max calls ({max_calls})")
+            break
+
+        original_file_content_str = file.base_file
+        new_file_content_str = file.head_file
+        patch = file.patch
+        if not patch:
+            continue
+
+        # Remove delete-only hunks
+        patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename)
+        if patch is None:
+            continue
+
+        patch = convert_to_hunks_with_lines_numbers(patch, file)
+        new_patch_tokens = token_handler.count_tokens(patch)
+        if patch and (total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
+            final_diff = "\n".join(patches)
+            final_diff_list.append(final_diff)
+            patches = []
+            total_tokens = token_handler.prompt_tokens
+            call_number += 1
+            if get_settings().config.verbosity_level >= 2:
+                logging.info(f"Call number: {call_number}")
+
+        if patch:
+            patches.append(patch)
+            total_tokens += new_patch_tokens
+            if get_settings().config.verbosity_level >= 2:
+                logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")
+
+    # Add the last chunk
+    if patches:
+        final_diff = "\n".join(patches)
+        final_diff_list.append(final_diff)
+
+    return final_diff_list
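One way a tool might consume `get_pr_multi_diffs` is with one model call per returned chunk, mirroring the extended flow added later in this commit. A rough sketch; the provider, token handler and AI handler are assumed to be already configured, and the prompt rendering is elided:

```python
import logging

from pr_agent.algo.pr_processing import get_pr_multi_diffs


async def collect_predictions(git_provider, token_handler, ai_handler, model: str):
    # Each returned string is a self-contained, line-numbered diff chunk that fits
    # within the model's token budget; one completion is requested per chunk.
    patches_diff_list = get_pr_multi_diffs(git_provider, token_handler, model, max_calls=5)
    predictions = []
    for i, patches_diff in enumerate(patches_diff_list):
        logging.info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
        response, _ = await ai_handler.chat_completion(
            model=model,
            system="...",  # prompt rendering elided in this sketch
            user=patches_diff,
        )
        predictions.append(response)
    return predictions
```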

View File

@@ -247,7 +247,8 @@ def update_settings_from_args(args: List[str]) -> List[str]:
             arg = arg.strip('-').strip()
             vals = arg.split('=', 1)
             if len(vals) != 2:
-                logging.error(f'Invalid argument format: {arg}')
+                if len(vals) > 2:  # --extended is a valid argument
+                    logging.error(f'Invalid argument format: {arg}')
                 other_args.append(arg)
                 continue
             key, value = _fix_key_value(*vals)
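With this guard, a bare flag such as `--extended` is forwarded to the tool instead of being reported as a malformed argument, while `key=value` pairs still update the settings. A small sketch, assuming the function is exposed from `pr_agent.algo.utils`:

```python
from pr_agent.algo.utils import update_settings_from_args

# '--extended' has no '=', so it is forwarded untouched instead of being logged as
# an invalid argument, while the key=value pair below updates the setting in place.
other_args = update_settings_from_args(
    ["--extended", "--pr_code_suggestions.num_code_suggestions=6"]
)
print(other_args)  # expected: ['extended']
```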

View File

@@ -19,13 +19,21 @@ For example:
 - cli.py --pr_url=... reflect

 Supported commands:
-review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
-ask / ask_question [question] - Ask a question about the PR.
-describe / describe_pr - Modify the PR title and description based on the PR's contents.
-improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
-reflect - Ask the PR author questions about the PR.
-update_changelog - Update the changelog based on the PR's contents.
+-review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
+-ask / ask_question [question] - Ask a question about the PR.
+-describe / describe_pr - Modify the PR title and description based on the PR's contents.
+-improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
+Extended mode ('improve --extended') employs several calls, and provides a more thorough feedback
+-reflect - Ask the PR author questions about the PR.
+-update_changelog - Update the changelog based on the PR's contents.
+
+Configuration:
 To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
 For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
 """)

View File

@@ -19,6 +19,7 @@ global_settings = Dynaconf(
         "settings/pr_questions_prompts.toml",
         "settings/pr_description_prompts.toml",
         "settings/pr_code_suggestions_prompts.toml",
+        "settings/pr_sort_code_suggestions_prompts.toml",
         "settings/pr_information_from_user_prompts.toml",
         "settings/pr_update_changelog_prompts.toml",
         "settings_prod/.secrets.toml"

View File

@@ -6,12 +6,11 @@ from urllib.parse import urlparse
 import requests
 from atlassian.bitbucket import Cloud

-from ..algo.pr_processing import clip_tokens
 from ..config_loader import get_settings
-from .git_provider import FilePatchInfo
+from .git_provider import FilePatchInfo, GitProvider


-class BitbucketProvider:
+class BitbucketProvider(GitProvider):
     def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
         s = requests.Session()
         s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
@@ -36,7 +35,7 @@ class BitbucketProvider:
         except Exception:
             return ""

-    def publish_code_suggestions(self, code_suggestions: list):
+    def publish_code_suggestions(self, code_suggestions: list) -> bool:
         """
         Publishes code suggestions as comments on the PR.
         """
@@ -156,10 +155,7 @@ class BitbucketProvider:
     def get_pr_branch(self):
         return self.pr.source_branch

-    def get_pr_description(self):
-        max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
-        if max_tokens:
-            return clip_tokens(self.pr.description, max_tokens)
+    def get_pr_description_full(self):
         return self.pr.description

     def get_user_id(self):

View File

@@ -54,7 +54,7 @@ class GitProvider(ABC):
         pass

     @abstractmethod
-    def publish_code_suggestions(self, code_suggestions: list):
+    def publish_code_suggestions(self, code_suggestions: list) -> bool:
         pass

     @abstractmethod
@@ -82,9 +82,30 @@ class GitProvider(ABC):
         pass

     @abstractmethod
-    def get_pr_description(self):
+    def get_pr_description_full(self) -> str:
         pass

+    def get_pr_description(self) -> str:
+        from pr_agent.config_loader import get_settings
+        from pr_agent.algo.pr_processing import clip_tokens
+        max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
+        description = self.get_pr_description_full()
+        if max_tokens:
+            return clip_tokens(description, max_tokens)
+        return description
+
+    def get_user_description(self) -> str:
+        description = (self.get_pr_description_full() or "").strip()
+        # if the existing description wasn't generated by the pr-agent, just return it as-is
+        if not description.startswith("## PR Type"):
+            return description
+        # if the existing description was generated by the pr-agent, but it doesn't contain the user description,
+        # return nothing (empty string) because it means there is no user description
+        if "## User Description:" not in description:
+            return ""
+        # otherwise, extract the original user description from the existing pr-agent description and return it
+        return description.split("## User Description:", 1)[1].strip()
+
     @abstractmethod
     def get_issue_comments(self):
         pass
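The snippet below is a standalone replica of the `get_user_description` logic above, included only to illustrate its three branches on made-up descriptions:

```python
def extract_user_description(description: str) -> str:
    """Standalone replica of GitProvider.get_user_description() for illustration."""
    description = (description or "").strip()
    if not description.startswith("## PR Type"):
        return description            # not generated by pr-agent: return as-is
    if "## User Description:" not in description:
        return ""                     # agent-generated, but no user text was preserved
    return description.split("## User Description:", 1)[1].strip()


assert extract_user_description("Fix a typo in the README.") == "Fix a typo in the README."
assert extract_user_description("## PR Type:\nEnhancement") == ""
assert extract_user_description(
    "## PR Type:\nEnhancement\n\n## User Description:\nAdd retry logic."
) == "Add retry logic."
```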

View File

@@ -166,7 +166,7 @@ class GithubProvider(GitProvider):
     def publish_inline_comments(self, comments: list[dict]):
         self.pr.create_review(commit=self.last_commit_id, comments=comments)

-    def publish_code_suggestions(self, code_suggestions: list):
+    def publish_code_suggestions(self, code_suggestions: list) -> bool:
         """
         Publishes code suggestions as comments on the PR.
         """
@@ -233,10 +233,7 @@ class GithubProvider(GitProvider):
     def get_pr_branch(self):
         return self.pr.head.ref

-    def get_pr_description(self):
-        max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
-        if max_tokens:
-            return clip_tokens(self.pr.body, max_tokens)
+    def get_pr_description_full(self):
         return self.pr.body

     def get_user_id(self):

View File

@@ -195,7 +195,7 @@ class GitLabProvider(GitProvider):
                 f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.')
             return self.last_diff  # fallback to last_diff if no relevant diff is found

-    def publish_code_suggestions(self, code_suggestions: list):
+    def publish_code_suggestions(self, code_suggestions: list) -> bool:
         for suggestion in code_suggestions:
             try:
                 body = suggestion['body']
@@ -299,10 +299,7 @@ class GitLabProvider(GitProvider):
     def get_pr_branch(self):
         return self.mr.source_branch

-    def get_pr_description(self):
-        max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
-        if max_tokens:
-            return clip_tokens(self.mr.description, max_tokens)
+    def get_pr_description_full(self):
         return self.mr.description

     def get_issue_comments(self):

View File

@@ -130,7 +130,7 @@ class LocalGitProvider(GitProvider):
                                      relevant_lines_start: int, relevant_lines_end: int):
         raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')

-    def publish_code_suggestions(self, code_suggestions: list):
+    def publish_code_suggestions(self, code_suggestions: list) -> bool:
         raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')

     def publish_labels(self, labels):
@@ -158,7 +158,7 @@ class LocalGitProvider(GitProvider):
     def get_user_id(self):
         return -1  # Not used anywhere for the local provider, but required by the interface

-    def get_pr_description(self):
+    def get_pr_description_full(self):
         commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
         # Get the commit messages and concatenate
         commit_messages = " ".join([commit.message for commit in commits_diff])

View File

@@ -1,7 +1,7 @@
 commands_text = "> **/review [-i]**: Request a review of your Pull Request. For an incremental review, which only " \
                 "considers changes since the last review, include the '-i' option.\n" \
                 "> **/describe**: Modify the PR title and description based on the contents of the PR.\n" \
-                "> **/improve**: Suggest improvements to the code in the PR. \n" \
+                "> **/improve [--extended]**: Suggest improvements to the code in the PR. Extended mode employs several calls, and provides a more thorough feedback. \n" \
                 "> **/ask \\<QUESTION\\>**: Pose a question about the PR.\n" \
                 "> **/update_changelog**: Update the changelog based on the PR's contents.\n\n" \
                 ">To edit any configuration parameter from **configuration.toml**, add --config_path=new_value\n" \

View File

@@ -24,6 +24,8 @@ extra_instructions = ""

 [pr_description] # /describe #
 publish_description_as_comment=false
+add_original_user_description=false
+keep_original_user_title=false
 extra_instructions = ""

 [pr_questions] # /ask #
@@ -31,6 +33,12 @@ extra_instructions = ""
 [pr_code_suggestions] # /improve #
 num_code_suggestions=4
 extra_instructions = ""
+rank_suggestions = false
+# params for '/improve --extended' mode
+num_code_suggestions_per_chunk=8
+rank_extended_suggestions = true
+max_number_of_calls = 5
+final_clip_factor = 0.9

 [pr_update_changelog] # /update_changelog #
 push_changelog_changes=false

View File

@@ -1,19 +1,49 @@
 [pr_code_suggestions_prompt]
-system="""You are a language model called CodiumAI-PR-Code-Reviewer.
-Your task is to provide meaningfull non-trivial code suggestions to improve the new code in a PR (the '+' lines).
-- Try to give important suggestions like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningfull code improvements, like performance, vulnerability, modularity, and best practices.
-- Suggestions should refer only to the 'new hunk' code, and focus on improving the new added code lines, with '+'.
+system="""You are a language model called PR-Code-Reviewer.
+Your task is to provide meaningful actionable code suggestions, to improve the new code presented in a PR.
+
+Example PR Diff input:
+'
+## src/file1.py
+
+@@ -12,3 +12,5 @@ def func1():
+__new hunk__
+12  code line that already existed in the file...
+13  code line that already existed in the file....
+14 +new code line added in the PR
+15  code line that already existed in the file...
+16  code line that already existed in the file...
+__old hunk__
+ code line that already existed in the file...
+-code line that was removed in the PR
+ code line that already existed in the file...
+
+@@ ... @@ def func2():
+__new hunk__
+...
+__old hunk__
+...
+
+## src/file2.py
+...
+'
+
+Specific instructions:
+- Focus on important suggestions like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningful code improvements, like performance, vulnerability, modularity, and best practices.
+- Suggestions should refer only to code from the '__new hunk__' sections, and focus on new lines of code (lines starting with '+').
 - Provide the exact line number range (inclusive) for each issue.
-- Assume there is additional code in the relevant file that is not included in the diff.
+- Assume there is additional relevant code, that is not included in the diff.
 - Provide up to {{ num_code_suggestions }} code suggestions.
-- Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines).
-- Don't output line numbers in the 'improved code' snippets.
+- Avoid making suggestions that have already been implemented in the PR code. For example, if you want to add logs or change a variable to const, or anything else, make sure it isn't already in the '__new hunk__' code.
+- Don't suggest to add docstring or type hints.

 {%- if extra_instructions %}

 Extra instructions from the user:
 {{ extra_instructions }}
-{% endif %}
+{%- endif %}

 You must use the following JSON schema to format your answer:

 ```json
@@ -30,39 +60,26 @@ You must use the following JSON schema to format your answer:
             },
             "suggestion content": {
                 "type": "string",
-                "description": "a concrete suggestion for meaningfully improving the new PR code."
+                "description": "a concrete suggestion for meaningfully improving the new PR code (lines from the '__new hunk__' sections, starting with '+')."
             },
             "existing code": {
                 "type": "string",
-                "description": "a code snippet showing authentic relevant code lines from a 'new hunk' section. It must be continuous, correctly formatted and indented, and without line numbers."
+                "description": "a code snippet showing the relevant code lines from a '__new hunk__' section. It must be continuous, correctly formatted and indented, and without line numbers."
             },
             "relevant lines": {
                 "type": "string",
-                "description": "the relevant lines in the 'new hunk' sections, in the format of 'start_line-end_line'. For example: '10-15'. They should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above."
+                "description": "the relevant lines from a '__new hunk__' section, in the format of 'start_line-end_line'. For example: '10-15'. They should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above."
             },
             "improved code": {
                 "type": "string",
-                "description": "a new code snippet that can be used to replace the relevant lines in 'new hunk' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers."
+                "description": "a new code snippet that can be used to replace the relevant lines in '__new hunk__' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers."
             }
         }
     }
 }
 ```

-Example input:
-'
-## src/file1.py
----new_hunk---
-```
-[new hunk code, annotated with line numbers]
-```
----old_hunk---
-```
-[old hunk code]
-```
-...
-'
+Don't output line numbers in the 'improved code' snippets.

 Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
 """

View File

@@ -1,9 +1,9 @@
 [pr_review_prompt]
 system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests.
-Your task is to provide constructive and concise feedback for the PR, and also provide meaningfull code suggestions to improve the new PR code (the '+' lines).
+Your task is to provide constructive and concise feedback for the PR, and also provide meaningful code suggestions to improve the new PR code (the '+' lines).
 {%- if num_code_suggestions > 0 %}
 - Provide up to {{ num_code_suggestions }} code suggestions.
-- Try to focus on the most important suggestions, like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningfull code improvements, like performance, vulnerability, modularity, and best practices.
+- Try to focus on the most important suggestions, like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningful code improvements, like performance, vulnerability, modularity, and best practices.
 - Suggestions should focus on improving the new added code lines.
 - Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines).
 {%- endif %}

View File

@@ -0,0 +1,46 @@
+[pr_sort_code_suggestions_prompt]
+system="""
+"""
+
+user="""You are given a list of code suggestions to improve a PR:
+
+{{ suggestion_str|trim }}
+
+Your task is to sort the code suggestions by their order of importance, and return a list with sorting order.
+The sorting order is a list of pairs, where each pair contains the index of the suggestion in the original list.
+Rank the suggestions based on their importance to improving the PR, with critical issues first and minor issues last.
+
+You must use the following YAML schema to format your answer:
+```yaml
+Sort Order:
+  type: array
+  maxItems: {{ suggestion_list|length }}
+  uniqueItems: true
+  items:
+    suggestion number:
+      type: integer
+      minimum: 1
+      maximum: {{ suggestion_list|length }}
+    importance order:
+      type: integer
+      minimum: 1
+      maximum: {{ suggestion_list|length }}
+```
+
+Example output:
+```yaml
+Sort Order:
+  - suggestion number: 1
+    importance order: 2
+  - suggestion number: 2
+    importance order: 3
+  - suggestion number: 3
+    importance order: 1
+```
+
+Make sure to output a valid YAML. Use multi-line block scalar ('|') if needed.
+Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
+
+Response (should be a valid YAML, and nothing else):
+```yaml
+"""

View File

@@ -2,11 +2,13 @@ import copy
 import json
 import logging
 import textwrap
+from typing import List
+import yaml
 from jinja2 import Environment, StrictUndefined

 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, get_pr_multi_diffs
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import try_fix_json
 from pr_agent.config_loader import get_settings
@@ -22,6 +24,13 @@ class PRCodeSuggestions:
             self.git_provider.get_languages(), self.git_provider.get_files()
         )

+        # extended mode
+        self.is_extended = any(["extended" in arg for arg in args])
+        if self.is_extended:
+            num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
+        else:
+            num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
+
         self.ai_handler = AiHandler()
         self.patches_diff = None
         self.prediction = None
@@ -32,7 +41,7 @@ class PRCodeSuggestions:
             "description": self.git_provider.get_pr_description(),
             "language": self.main_language,
             "diff": "",  # empty diff for initial calculation
-            "num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions,
+            "num_code_suggestions": num_code_suggestions,
             "extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
             "commit_messages_str": self.git_provider.get_commit_messages(),
         }
@@ -42,18 +51,26 @@ class PRCodeSuggestions:
                                           get_settings().pr_code_suggestions_prompt.user)

     async def run(self):
-        # assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
         logging.info('Generating code suggestions for PR...')
         if get_settings().config.publish_output:
             self.git_provider.publish_comment("Preparing review...", is_temporary=True)
-        await retry_with_fallback_models(self._prepare_prediction)
+
         logging.info('Preparing PR review...')
-        data = self._prepare_pr_code_suggestions()
+        if not self.is_extended:
+            await retry_with_fallback_models(self._prepare_prediction)
+            data = self._prepare_pr_code_suggestions()
+        else:
+            data = await retry_with_fallback_models(self._prepare_prediction_extended)
+
+        if (not self.is_extended and get_settings().pr_code_suggestions.rank_suggestions) or \
+                (self.is_extended and get_settings().pr_code_suggestions.rank_extended_suggestions):
+            logging.info('Ranking Suggestions...')
+            data['Code suggestions'] = await self.rank_suggestions(data['Code suggestions'])
+
         if get_settings().config.publish_output:
             logging.info('Pushing PR review...')
             self.git_provider.remove_initial_comment()
-            logging.info('Pushing inline code comments...')
+            logging.info('Pushing inline code suggestions...')
             self.push_inline_code_suggestions(data)

     async def _prepare_prediction(self, model: str):
@@ -121,7 +138,11 @@ class PRCodeSuggestions:
                 if get_settings().config.verbosity_level >= 2:
                     logging.info(f"Could not parse suggestion: {d}")

-        self.git_provider.publish_code_suggestions(code_suggestions)
+        is_successful = self.git_provider.publish_code_suggestions(code_suggestions)
+        if not is_successful:
+            logging.info("Failed to publish code suggestions, trying to publish each suggestion separately")
+            for code_suggestion in code_suggestions:
+                self.git_provider.publish_code_suggestions([code_suggestion])

     def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
         try:  # dedent code snippet
@@ -145,3 +166,81 @@ class PRCodeSuggestions:

         return new_code_snippet

+    async def _prepare_prediction_extended(self, model: str) -> dict:
+        logging.info('Getting PR diff...')
+        patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
+                                               max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
+
+        logging.info('Getting multi AI predictions...')
+        prediction_list = []
+        for i, patches_diff in enumerate(patches_diff_list):
+            logging.info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
+            self.patches_diff = patches_diff
+            prediction = await self._get_prediction(model)
+            prediction_list.append(prediction)
+        self.prediction_list = prediction_list
+
+        data = {}
+        for prediction in prediction_list:
+            self.prediction = prediction
+            data_per_chunk = self._prepare_pr_code_suggestions()
+            if "Code suggestions" in data:
+                data["Code suggestions"].extend(data_per_chunk["Code suggestions"])
+            else:
+                data.update(data_per_chunk)
+        self.data = data
+        return data
+
+    async def rank_suggestions(self, data: List) -> List:
+        """
+        Call a model to rank (sort) code suggestions based on their importance order.
+
+        Args:
+            data (List): A list of code suggestions to be ranked.
+
+        Returns:
+            List: The ranked list of code suggestions.
+        """
+        suggestion_list = []
+        # remove invalid suggestions
+        for i, suggestion in enumerate(data):
+            if suggestion['existing code'] != suggestion['improved code']:
+                suggestion_list.append(suggestion)
+
+        data_sorted = [[]] * len(suggestion_list)
+
+        try:
+            suggestion_str = ""
+            for i, suggestion in enumerate(suggestion_list):
+                suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
+
+            variables = {'suggestion_list': suggestion_list, 'suggestion_str': suggestion_str}
+            model = get_settings().config.model
+            environment = Environment(undefined=StrictUndefined)
+            system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render(
+                variables)
+            user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables)
+            if get_settings().config.verbosity_level >= 2:
+                logging.info(f"\nSystem prompt:\n{system_prompt}")
+                logging.info(f"\nUser prompt:\n{user_prompt}")
+            response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt,
+                                                                            user=user_prompt)
+
+            sort_order = yaml.safe_load(response)
+            for s in sort_order['Sort Order']:
+                suggestion_number = s['suggestion number']
+                importance_order = s['importance order']
+                data_sorted[importance_order - 1] = suggestion_list[suggestion_number - 1]
+
+            if get_settings().pr_code_suggestions.final_clip_factor != 1:
+                new_len = int(0.5 + len(data_sorted) * get_settings().pr_code_suggestions.final_clip_factor)
+                data_sorted = data_sorted[:new_len]
+        except Exception as e:
+            if get_settings().config.verbosity_level >= 1:
+                logging.info(f"Could not sort suggestions, error: {e}")
+            data_sorted = suggestion_list
+
+        return data_sorted

View File

@@ -42,6 +42,8 @@ class PRDescription:
             "extra_instructions": get_settings().pr_description.extra_instructions,
             "commit_messages_str": self.git_provider.get_commit_messages()
         }
+
+        self.user_description = self.git_provider.get_user_description()

         # Initialize the token handler
         self.token_handler = TokenHandler(
@@ -145,6 +147,9 @@ class PRDescription:
         # Load the AI prediction data into a dictionary
         data = load_yaml(self.prediction.strip())

+        if get_settings().pr_description.add_original_user_description and self.user_description:
+            data["User Description"] = self.user_description
+
         # Initialization
         pr_types = []
@@ -161,13 +166,19 @@ class PRDescription:
         elif type(data['PR Type']) == str:
             pr_types = data['PR Type'].split(',')

-        # Assign the value of the 'PR Title' key to 'title' variable and remove it from the dictionary
-        title = data.pop('PR Title')
+        # Remove the 'PR Title' key from the dictionary
+        ai_title = data.pop('PR Title')
+        if get_settings().pr_description.keep_original_user_title:
+            # Assign the original PR title to the 'title' variable
+            title = self.vars["title"]
+        else:
+            # Assign the value of the 'PR Title' key to 'title' variable
+            title = ai_title

         # Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
         # except for the items containing the word 'walkthrough'
         pr_body = ""
-        for key, value in data.items():
+        for idx, (key, value) in enumerate(data.items()):
             pr_body += f"## {key}:\n"
             if 'walkthrough' in key.lower():
                 # for filename, description in value.items():
@@ -179,7 +190,9 @@ class PRDescription:
             # if the value is a list, join its items by comma
             if type(value) == list:
                 value = ', '.join(v for v in value)
-            pr_body += f"{value}\n\n___\n"
+            pr_body += f"{value}\n"
+            if idx < len(data) - 1:
+                pr_body += "\n___\n"

         if get_settings().config.verbosity_level >= 2:
             logging.info(f"title:\n{title}\n{pr_body}")