mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-03 12:20:38 +08:00
Merge branch 'main' into zmeir-automatic_github_app_options
This commit is contained in:
@ -54,6 +54,10 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
pr_agent_job:
|
pr_agent_job:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
contents: write
|
||||||
name: Run pr agent on every pull request, respond to user comments
|
name: Run pr agent on every pull request, respond to user comments
|
||||||
steps:
|
steps:
|
||||||
- name: PR Agent action step
|
- name: PR Agent action step
|
||||||
@ -72,6 +76,10 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
pr_agent_job:
|
pr_agent_job:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
|
contents: write
|
||||||
name: Run pr agent on every pull request, respond to user comments
|
name: Run pr agent on every pull request, respond to user comments
|
||||||
steps:
|
steps:
|
||||||
- name: PR Agent action step
|
- name: PR Agent action step
|
||||||
|
@ -55,7 +55,7 @@ class AiHandler:
|
|||||||
|
|
||||||
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
|
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
|
||||||
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
|
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
|
||||||
async def chat_completion(self, model: str, temperature: float, system: str, user: str):
|
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
|
||||||
"""
|
"""
|
||||||
Performs a chat completion using the OpenAI ChatCompletion API.
|
Performs a chat completion using the OpenAI ChatCompletion API.
|
||||||
Retries in case of API errors or timeouts.
|
Retries in case of API errors or timeouts.
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -157,7 +156,7 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
|||||||
|
|
||||||
example output:
|
example output:
|
||||||
## src/file.ts
|
## src/file.ts
|
||||||
--new hunk--
|
__new hunk__
|
||||||
881 line1
|
881 line1
|
||||||
882 line2
|
882 line2
|
||||||
883 line3
|
883 line3
|
||||||
@ -166,7 +165,7 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
|||||||
889 line6
|
889 line6
|
||||||
890 line7
|
890 line7
|
||||||
...
|
...
|
||||||
--old hunk--
|
__old hunk__
|
||||||
line1
|
line1
|
||||||
line2
|
line2
|
||||||
- line3
|
- line3
|
||||||
@ -176,8 +175,7 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
|||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
patch_with_lines_str = f"## {file.filename}\n"
|
patch_with_lines_str = f"\n\n## {file.filename}\n"
|
||||||
import re
|
|
||||||
patch_lines = patch.splitlines()
|
patch_lines = patch.splitlines()
|
||||||
RE_HUNK_HEADER = re.compile(
|
RE_HUNK_HEADER = re.compile(
|
||||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||||
@ -185,23 +183,30 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
|||||||
old_content_lines = []
|
old_content_lines = []
|
||||||
match = None
|
match = None
|
||||||
start1, size1, start2, size2 = -1, -1, -1, -1
|
start1, size1, start2, size2 = -1, -1, -1, -1
|
||||||
|
prev_header_line = []
|
||||||
|
header_line =[]
|
||||||
for line in patch_lines:
|
for line in patch_lines:
|
||||||
if 'no newline at end of file' in line.lower():
|
if 'no newline at end of file' in line.lower():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if line.startswith('@@'):
|
if line.startswith('@@'):
|
||||||
|
header_line = line
|
||||||
match = RE_HUNK_HEADER.match(line)
|
match = RE_HUNK_HEADER.match(line)
|
||||||
if match and new_content_lines: # found a new hunk, split the previous lines
|
if match and new_content_lines: # found a new hunk, split the previous lines
|
||||||
if new_content_lines:
|
if new_content_lines:
|
||||||
patch_with_lines_str += '\n--new hunk--\n'
|
if prev_header_line:
|
||||||
|
patch_with_lines_str += f'\n{prev_header_line}\n'
|
||||||
|
patch_with_lines_str += '__new hunk__\n'
|
||||||
for i, line_new in enumerate(new_content_lines):
|
for i, line_new in enumerate(new_content_lines):
|
||||||
patch_with_lines_str += f"{start2 + i} {line_new}\n"
|
patch_with_lines_str += f"{start2 + i} {line_new}\n"
|
||||||
if old_content_lines:
|
if old_content_lines:
|
||||||
patch_with_lines_str += '--old hunk--\n'
|
patch_with_lines_str += '__old hunk__\n'
|
||||||
for line_old in old_content_lines:
|
for line_old in old_content_lines:
|
||||||
patch_with_lines_str += f"{line_old}\n"
|
patch_with_lines_str += f"{line_old}\n"
|
||||||
new_content_lines = []
|
new_content_lines = []
|
||||||
old_content_lines = []
|
old_content_lines = []
|
||||||
|
if match:
|
||||||
|
prev_header_line = header_line
|
||||||
try:
|
try:
|
||||||
start1, size1, start2, size2 = map(int, match.groups()[:4])
|
start1, size1, start2, size2 = map(int, match.groups()[:4])
|
||||||
except: # '@@ -0,0 +1 @@' case
|
except: # '@@ -0,0 +1 @@' case
|
||||||
@ -219,12 +224,13 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
|
|||||||
# finishing last hunk
|
# finishing last hunk
|
||||||
if match and new_content_lines:
|
if match and new_content_lines:
|
||||||
if new_content_lines:
|
if new_content_lines:
|
||||||
patch_with_lines_str += '\n--new hunk--\n'
|
patch_with_lines_str += f'\n{header_line}\n'
|
||||||
|
patch_with_lines_str += '\n__new hunk__\n'
|
||||||
for i, line_new in enumerate(new_content_lines):
|
for i, line_new in enumerate(new_content_lines):
|
||||||
patch_with_lines_str += f"{start2 + i} {line_new}\n"
|
patch_with_lines_str += f"{start2 + i} {line_new}\n"
|
||||||
if old_content_lines:
|
if old_content_lines:
|
||||||
patch_with_lines_str += '\n--old hunk--\n'
|
patch_with_lines_str += '\n__old hunk__\n'
|
||||||
for line_old in old_content_lines:
|
for line_old in old_content_lines:
|
||||||
patch_with_lines_str += f"{line_old}\n"
|
patch_with_lines_str += f"{line_old}\n"
|
||||||
|
|
||||||
return patch_with_lines_str.strip()
|
return patch_with_lines_str.rstrip()
|
||||||
|
@ -24,7 +24,7 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
|
|||||||
PATCH_EXTRA_LINES = 3
|
PATCH_EXTRA_LINES = 3
|
||||||
|
|
||||||
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
||||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
|
add_line_numbers_to_hunks: bool = True, disable_extra_lines: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
|
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
|
|||||||
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
|
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
|
||||||
|
|
||||||
# generate a standard diff string, with patch extension
|
# generate a standard diff string, with patch extension
|
||||||
patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler,
|
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(pr_languages, token_handler,
|
||||||
add_line_numbers_to_hunks)
|
add_line_numbers_to_hunks)
|
||||||
|
|
||||||
# if we are under the limit, return the full diff
|
# if we are under the limit, return the full diff
|
||||||
@ -78,9 +78,9 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
|
|||||||
return final_diff
|
return final_diff
|
||||||
|
|
||||||
|
|
||||||
def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
def pr_generate_extended_diff(pr_languages: list,
|
||||||
add_line_numbers_to_hunks: bool) -> \
|
token_handler: TokenHandler,
|
||||||
Tuple[list, int]:
|
add_line_numbers_to_hunks: bool) -> Tuple[list, int, list]:
|
||||||
"""
|
"""
|
||||||
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
|
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
|
||||||
minimization techniques if needed.
|
minimization techniques if needed.
|
||||||
@ -90,13 +90,10 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
|||||||
files.
|
files.
|
||||||
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||||
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
|
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
|
||||||
|
|
||||||
Returns:
|
|
||||||
- patches_extended: A list of extended patches for each file in the pull request.
|
|
||||||
- total_tokens: The total number of tokens used in the extended patches.
|
|
||||||
"""
|
"""
|
||||||
total_tokens = token_handler.prompt_tokens # initial tokens
|
total_tokens = token_handler.prompt_tokens # initial tokens
|
||||||
patches_extended = []
|
patches_extended = []
|
||||||
|
patches_extended_tokens = []
|
||||||
for lang in pr_languages:
|
for lang in pr_languages:
|
||||||
for file in lang['files']:
|
for file in lang['files']:
|
||||||
original_file_content_str = file.base_file
|
original_file_content_str = file.base_file
|
||||||
@ -106,7 +103,7 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
|||||||
|
|
||||||
# extend each patch with extra lines of context
|
# extend each patch with extra lines of context
|
||||||
extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES)
|
extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES)
|
||||||
full_extended_patch = f"## {file.filename}\n\n{extended_patch}\n"
|
full_extended_patch = f"\n\n## {file.filename}\n\n{extended_patch}\n"
|
||||||
|
|
||||||
if add_line_numbers_to_hunks:
|
if add_line_numbers_to_hunks:
|
||||||
full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
|
full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)
|
||||||
@ -114,9 +111,10 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
|||||||
patch_tokens = token_handler.count_tokens(full_extended_patch)
|
patch_tokens = token_handler.count_tokens(full_extended_patch)
|
||||||
file.tokens = patch_tokens
|
file.tokens = patch_tokens
|
||||||
total_tokens += patch_tokens
|
total_tokens += patch_tokens
|
||||||
|
patches_extended_tokens.append(patch_tokens)
|
||||||
patches_extended.append(full_extended_patch)
|
patches_extended.append(full_extended_patch)
|
||||||
|
|
||||||
return patches_extended, total_tokens
|
return patches_extended, total_tokens, patches_extended_tokens
|
||||||
|
|
||||||
|
|
||||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
||||||
@ -324,7 +322,9 @@ def clip_tokens(text: str, max_tokens: int) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
str: The clipped string.
|
str: The clipped string.
|
||||||
"""
|
"""
|
||||||
# We'll estimate the number of tokens by hueristically assuming 2.5 tokens per word
|
if not text:
|
||||||
|
return text
|
||||||
|
|
||||||
try:
|
try:
|
||||||
encoder = get_token_encoder()
|
encoder = get_token_encoder()
|
||||||
num_input_tokens = len(encoder.encode(text))
|
num_input_tokens = len(encoder.encode(text))
|
||||||
@ -337,4 +337,84 @@ def clip_tokens(text: str, max_tokens: int) -> str:
|
|||||||
return clipped_text
|
return clipped_text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to clip tokens: {e}")
|
logging.warning(f"Failed to clip tokens: {e}")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def get_pr_multi_diffs(git_provider: GitProvider,
|
||||||
|
token_handler: TokenHandler,
|
||||||
|
model: str,
|
||||||
|
max_calls: int = 5) -> List[str]:
|
||||||
|
"""
|
||||||
|
Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
|
||||||
|
The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
git_provider (GitProvider): An object that provides access to Git provider APIs.
|
||||||
|
token_handler (TokenHandler): An object that handles tokens in the context of a pull request.
|
||||||
|
model (str): The name of the model.
|
||||||
|
max_calls (int, optional): The maximum number of calls to retrieve diff files. Defaults to 5.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: A list of final diff strings, split into multiple groups based on the maximum number of tokens allowed for the given model.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RateLimitExceededException: If the rate limit for the Git provider API is exceeded.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
diff_files = git_provider.get_diff_files()
|
||||||
|
except RateLimitExceededException as e:
|
||||||
|
logging.error(f"Rate limit exceeded for git provider API. original message {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Sort files by main language
|
||||||
|
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)
|
||||||
|
|
||||||
|
# Sort files within each language group by tokens in descending order
|
||||||
|
sorted_files = []
|
||||||
|
for lang in pr_languages:
|
||||||
|
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
|
||||||
|
|
||||||
|
patches = []
|
||||||
|
final_diff_list = []
|
||||||
|
total_tokens = token_handler.prompt_tokens
|
||||||
|
call_number = 1
|
||||||
|
for file in sorted_files:
|
||||||
|
if call_number > max_calls:
|
||||||
|
if get_settings().config.verbosity_level >= 2:
|
||||||
|
logging.info(f"Reached max calls ({max_calls})")
|
||||||
|
break
|
||||||
|
|
||||||
|
original_file_content_str = file.base_file
|
||||||
|
new_file_content_str = file.head_file
|
||||||
|
patch = file.patch
|
||||||
|
if not patch:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Remove delete-only hunks
|
||||||
|
patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename)
|
||||||
|
if patch is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
patch = convert_to_hunks_with_lines_numbers(patch, file)
|
||||||
|
new_patch_tokens = token_handler.count_tokens(patch)
|
||||||
|
if patch and (total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
|
||||||
|
final_diff = "\n".join(patches)
|
||||||
|
final_diff_list.append(final_diff)
|
||||||
|
patches = []
|
||||||
|
total_tokens = token_handler.prompt_tokens
|
||||||
|
call_number += 1
|
||||||
|
if get_settings().config.verbosity_level >= 2:
|
||||||
|
logging.info(f"Call number: {call_number}")
|
||||||
|
|
||||||
|
if patch:
|
||||||
|
patches.append(patch)
|
||||||
|
total_tokens += new_patch_tokens
|
||||||
|
if get_settings().config.verbosity_level >= 2:
|
||||||
|
logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")
|
||||||
|
|
||||||
|
# Add the last chunk
|
||||||
|
if patches:
|
||||||
|
final_diff = "\n".join(patches)
|
||||||
|
final_diff_list.append(final_diff)
|
||||||
|
|
||||||
|
return final_diff_list
|
||||||
|
@ -247,7 +247,8 @@ def update_settings_from_args(args: List[str]) -> List[str]:
|
|||||||
arg = arg.strip('-').strip()
|
arg = arg.strip('-').strip()
|
||||||
vals = arg.split('=', 1)
|
vals = arg.split('=', 1)
|
||||||
if len(vals) != 2:
|
if len(vals) != 2:
|
||||||
logging.error(f'Invalid argument format: {arg}')
|
if len(vals) > 2: # --extended is a valid argument
|
||||||
|
logging.error(f'Invalid argument format: {arg}')
|
||||||
other_args.append(arg)
|
other_args.append(arg)
|
||||||
continue
|
continue
|
||||||
key, value = _fix_key_value(*vals)
|
key, value = _fix_key_value(*vals)
|
||||||
|
@ -19,13 +19,21 @@ For example:
|
|||||||
- cli.py --pr_url=... reflect
|
- cli.py --pr_url=... reflect
|
||||||
|
|
||||||
Supported commands:
|
Supported commands:
|
||||||
review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
|
-review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
|
||||||
ask / ask_question [question] - Ask a question about the PR.
|
|
||||||
describe / describe_pr - Modify the PR title and description based on the PR's contents.
|
|
||||||
improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
|
|
||||||
reflect - Ask the PR author questions about the PR.
|
|
||||||
update_changelog - Update the changelog based on the PR's contents.
|
|
||||||
|
|
||||||
|
-ask / ask_question [question] - Ask a question about the PR.
|
||||||
|
|
||||||
|
-describe / describe_pr - Modify the PR title and description based on the PR's contents.
|
||||||
|
|
||||||
|
-improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
|
||||||
|
Extended mode ('improve --extended') employs several calls, and provides a more thorough feedback
|
||||||
|
|
||||||
|
-reflect - Ask the PR author questions about the PR.
|
||||||
|
|
||||||
|
-update_changelog - Update the changelog based on the PR's contents.
|
||||||
|
|
||||||
|
|
||||||
|
Configuration:
|
||||||
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
|
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
|
||||||
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
|
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
|
||||||
""")
|
""")
|
||||||
|
@ -19,6 +19,7 @@ global_settings = Dynaconf(
|
|||||||
"settings/pr_questions_prompts.toml",
|
"settings/pr_questions_prompts.toml",
|
||||||
"settings/pr_description_prompts.toml",
|
"settings/pr_description_prompts.toml",
|
||||||
"settings/pr_code_suggestions_prompts.toml",
|
"settings/pr_code_suggestions_prompts.toml",
|
||||||
|
"settings/pr_sort_code_suggestions_prompts.toml",
|
||||||
"settings/pr_information_from_user_prompts.toml",
|
"settings/pr_information_from_user_prompts.toml",
|
||||||
"settings/pr_update_changelog_prompts.toml",
|
"settings/pr_update_changelog_prompts.toml",
|
||||||
"settings_prod/.secrets.toml"
|
"settings_prod/.secrets.toml"
|
||||||
|
@ -6,12 +6,11 @@ from urllib.parse import urlparse
|
|||||||
import requests
|
import requests
|
||||||
from atlassian.bitbucket import Cloud
|
from atlassian.bitbucket import Cloud
|
||||||
|
|
||||||
from ..algo.pr_processing import clip_tokens
|
|
||||||
from ..config_loader import get_settings
|
from ..config_loader import get_settings
|
||||||
from .git_provider import FilePatchInfo
|
from .git_provider import FilePatchInfo, GitProvider
|
||||||
|
|
||||||
|
|
||||||
class BitbucketProvider:
|
class BitbucketProvider(GitProvider):
|
||||||
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
|
def __init__(self, pr_url: Optional[str] = None, incremental: Optional[bool] = False):
|
||||||
s = requests.Session()
|
s = requests.Session()
|
||||||
s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
|
s.headers['Authorization'] = f'Bearer {get_settings().get("BITBUCKET.BEARER_TOKEN", None)}'
|
||||||
@ -36,7 +35,7 @@ class BitbucketProvider:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def publish_code_suggestions(self, code_suggestions: list):
|
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||||
"""
|
"""
|
||||||
Publishes code suggestions as comments on the PR.
|
Publishes code suggestions as comments on the PR.
|
||||||
"""
|
"""
|
||||||
@ -156,10 +155,7 @@ class BitbucketProvider:
|
|||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.pr.source_branch
|
return self.pr.source_branch
|
||||||
|
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self):
|
||||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
|
||||||
if max_tokens:
|
|
||||||
return clip_tokens(self.pr.description, max_tokens)
|
|
||||||
return self.pr.description
|
return self.pr.description
|
||||||
|
|
||||||
def get_user_id(self):
|
def get_user_id(self):
|
||||||
|
@ -54,7 +54,7 @@ class GitProvider(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def publish_code_suggestions(self, code_suggestions: list):
|
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -82,9 +82,30 @@ class GitProvider(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_pr_description(self) -> str:
|
||||||
|
from pr_agent.config_loader import get_settings
|
||||||
|
from pr_agent.algo.pr_processing import clip_tokens
|
||||||
|
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
||||||
|
description = self.get_pr_description_full()
|
||||||
|
if max_tokens:
|
||||||
|
return clip_tokens(description, max_tokens)
|
||||||
|
return description
|
||||||
|
|
||||||
|
def get_user_description(self) -> str:
|
||||||
|
description = (self.get_pr_description_full() or "").strip()
|
||||||
|
# if the existing description wasn't generated by the pr-agent, just return it as-is
|
||||||
|
if not description.startswith("## PR Type"):
|
||||||
|
return description
|
||||||
|
# if the existing description was generated by the pr-agent, but it doesn't contain the user description,
|
||||||
|
# return nothing (empty string) because it means there is no user description
|
||||||
|
if "## User Description:" not in description:
|
||||||
|
return ""
|
||||||
|
# otherwise, extract the original user description from the existing pr-agent description and return it
|
||||||
|
return description.split("## User Description:", 1)[1].strip()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_issue_comments(self):
|
def get_issue_comments(self):
|
||||||
pass
|
pass
|
||||||
|
@ -166,7 +166,7 @@ class GithubProvider(GitProvider):
|
|||||||
def publish_inline_comments(self, comments: list[dict]):
|
def publish_inline_comments(self, comments: list[dict]):
|
||||||
self.pr.create_review(commit=self.last_commit_id, comments=comments)
|
self.pr.create_review(commit=self.last_commit_id, comments=comments)
|
||||||
|
|
||||||
def publish_code_suggestions(self, code_suggestions: list):
|
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||||
"""
|
"""
|
||||||
Publishes code suggestions as comments on the PR.
|
Publishes code suggestions as comments on the PR.
|
||||||
"""
|
"""
|
||||||
@ -233,10 +233,7 @@ class GithubProvider(GitProvider):
|
|||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.pr.head.ref
|
return self.pr.head.ref
|
||||||
|
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self):
|
||||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
|
||||||
if max_tokens:
|
|
||||||
return clip_tokens(self.pr.body, max_tokens)
|
|
||||||
return self.pr.body
|
return self.pr.body
|
||||||
|
|
||||||
def get_user_id(self):
|
def get_user_id(self):
|
||||||
|
@ -195,7 +195,7 @@ class GitLabProvider(GitProvider):
|
|||||||
f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.')
|
f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.')
|
||||||
return self.last_diff # fallback to last_diff if no relevant diff is found
|
return self.last_diff # fallback to last_diff if no relevant diff is found
|
||||||
|
|
||||||
def publish_code_suggestions(self, code_suggestions: list):
|
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||||
for suggestion in code_suggestions:
|
for suggestion in code_suggestions:
|
||||||
try:
|
try:
|
||||||
body = suggestion['body']
|
body = suggestion['body']
|
||||||
@ -299,10 +299,7 @@ class GitLabProvider(GitProvider):
|
|||||||
def get_pr_branch(self):
|
def get_pr_branch(self):
|
||||||
return self.mr.source_branch
|
return self.mr.source_branch
|
||||||
|
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self):
|
||||||
max_tokens = get_settings().get("CONFIG.MAX_DESCRIPTION_TOKENS", None)
|
|
||||||
if max_tokens:
|
|
||||||
return clip_tokens(self.mr.description, max_tokens)
|
|
||||||
return self.mr.description
|
return self.mr.description
|
||||||
|
|
||||||
def get_issue_comments(self):
|
def get_issue_comments(self):
|
||||||
|
@ -130,7 +130,7 @@ class LocalGitProvider(GitProvider):
|
|||||||
relevant_lines_start: int, relevant_lines_end: int):
|
relevant_lines_start: int, relevant_lines_end: int):
|
||||||
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
|
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
|
||||||
|
|
||||||
def publish_code_suggestions(self, code_suggestions: list):
|
def publish_code_suggestions(self, code_suggestions: list) -> bool:
|
||||||
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
|
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')
|
||||||
|
|
||||||
def publish_labels(self, labels):
|
def publish_labels(self, labels):
|
||||||
@ -158,7 +158,7 @@ class LocalGitProvider(GitProvider):
|
|||||||
def get_user_id(self):
|
def get_user_id(self):
|
||||||
return -1 # Not used anywhere for the local provider, but required by the interface
|
return -1 # Not used anywhere for the local provider, but required by the interface
|
||||||
|
|
||||||
def get_pr_description(self):
|
def get_pr_description_full(self):
|
||||||
commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
|
commits_diff = list(self.repo.iter_commits(self.target_branch_name + '..HEAD'))
|
||||||
# Get the commit messages and concatenate
|
# Get the commit messages and concatenate
|
||||||
commit_messages = " ".join([commit.message for commit in commits_diff])
|
commit_messages = " ".join([commit.message for commit in commits_diff])
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
commands_text = "> **/review [-i]**: Request a review of your Pull Request. For an incremental review, which only " \
|
commands_text = "> **/review [-i]**: Request a review of your Pull Request. For an incremental review, which only " \
|
||||||
"considers changes since the last review, include the '-i' option.\n" \
|
"considers changes since the last review, include the '-i' option.\n" \
|
||||||
"> **/describe**: Modify the PR title and description based on the contents of the PR.\n" \
|
"> **/describe**: Modify the PR title and description based on the contents of the PR.\n" \
|
||||||
"> **/improve**: Suggest improvements to the code in the PR. \n" \
|
"> **/improve [--extended]**: Suggest improvements to the code in the PR. Extended mode employs several calls, and provides a more thorough feedback. \n" \
|
||||||
"> **/ask \\<QUESTION\\>**: Pose a question about the PR.\n" \
|
"> **/ask \\<QUESTION\\>**: Pose a question about the PR.\n" \
|
||||||
"> **/update_changelog**: Update the changelog based on the PR's contents.\n\n" \
|
"> **/update_changelog**: Update the changelog based on the PR's contents.\n\n" \
|
||||||
">To edit any configuration parameter from **configuration.toml**, add --config_path=new_value\n" \
|
">To edit any configuration parameter from **configuration.toml**, add --config_path=new_value\n" \
|
||||||
|
@ -24,6 +24,8 @@ extra_instructions = ""
|
|||||||
|
|
||||||
[pr_description] # /describe #
|
[pr_description] # /describe #
|
||||||
publish_description_as_comment=false
|
publish_description_as_comment=false
|
||||||
|
add_original_user_description=false
|
||||||
|
keep_original_user_title=false
|
||||||
extra_instructions = ""
|
extra_instructions = ""
|
||||||
|
|
||||||
[pr_questions] # /ask #
|
[pr_questions] # /ask #
|
||||||
@ -31,6 +33,12 @@ extra_instructions = ""
|
|||||||
[pr_code_suggestions] # /improve #
|
[pr_code_suggestions] # /improve #
|
||||||
num_code_suggestions=4
|
num_code_suggestions=4
|
||||||
extra_instructions = ""
|
extra_instructions = ""
|
||||||
|
rank_suggestions = false
|
||||||
|
# params for '/improve --extended' mode
|
||||||
|
num_code_suggestions_per_chunk=8
|
||||||
|
rank_extended_suggestions = true
|
||||||
|
max_number_of_calls = 5
|
||||||
|
final_clip_factor = 0.9
|
||||||
|
|
||||||
[pr_update_changelog] # /update_changelog #
|
[pr_update_changelog] # /update_changelog #
|
||||||
push_changelog_changes=false
|
push_changelog_changes=false
|
||||||
|
@ -1,19 +1,49 @@
|
|||||||
[pr_code_suggestions_prompt]
|
[pr_code_suggestions_prompt]
|
||||||
system="""You are a language model called CodiumAI-PR-Code-Reviewer.
|
system="""You are a language model called PR-Code-Reviewer.
|
||||||
Your task is to provide meaningfull non-trivial code suggestions to improve the new code in a PR (the '+' lines).
|
Your task is to provide meaningful actionable code suggestions, to improve the new code presented in a PR.
|
||||||
- Try to give important suggestions like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningfull code improvements, like performance, vulnerability, modularity, and best practices.
|
|
||||||
- Suggestions should refer only to the 'new hunk' code, and focus on improving the new added code lines, with '+'.
|
Example PR Diff input:
|
||||||
|
'
|
||||||
|
## src/file1.py
|
||||||
|
|
||||||
|
@@ -12,3 +12,5 @@ def func1():
|
||||||
|
__new hunk__
|
||||||
|
12 code line that already existed in the file...
|
||||||
|
13 code line that already existed in the file....
|
||||||
|
14 +new code line added in the PR
|
||||||
|
15 code line that already existed in the file...
|
||||||
|
16 code line that already existed in the file...
|
||||||
|
__old hunk__
|
||||||
|
code line that already existed in the file...
|
||||||
|
-code line that was removed in the PR
|
||||||
|
code line that already existed in the file...
|
||||||
|
|
||||||
|
|
||||||
|
@@ ... @@ def func2():
|
||||||
|
__new hunk__
|
||||||
|
...
|
||||||
|
__old hunk__
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
## src/file2.py
|
||||||
|
...
|
||||||
|
'
|
||||||
|
|
||||||
|
Specific instructions:
|
||||||
|
- Focus on important suggestions like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningful code improvements, like performance, vulnerability, modularity, and best practices.
|
||||||
|
- Suggestions should refer only to code from the '__new hunk__' sections, and focus on new lines of code (lines starting with '+').
|
||||||
- Provide the exact line number range (inclusive) for each issue.
|
- Provide the exact line number range (inclusive) for each issue.
|
||||||
- Assume there is additional code in the relevant file that is not included in the diff.
|
- Assume there is additional relevant code, that is not included in the diff.
|
||||||
- Provide up to {{ num_code_suggestions }} code suggestions.
|
- Provide up to {{ num_code_suggestions }} code suggestions.
|
||||||
- Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines).
|
- Avoid making suggestions that have already been implemented in the PR code. For example, if you want to add logs or change a variable to const, or anything else, make sure it isn't already in the '__new hunk__' code.
|
||||||
- Don't output line numbers in the 'improved code' snippets.
|
- Don't suggest to add docstring or type hints.
|
||||||
|
|
||||||
{%- if extra_instructions %}
|
{%- if extra_instructions %}
|
||||||
|
|
||||||
Extra instructions from the user:
|
Extra instructions from the user:
|
||||||
{{ extra_instructions }}
|
{{ extra_instructions }}
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
|
|
||||||
You must use the following JSON schema to format your answer:
|
You must use the following JSON schema to format your answer:
|
||||||
```json
|
```json
|
||||||
@ -30,39 +60,26 @@ You must use the following JSON schema to format your answer:
|
|||||||
},
|
},
|
||||||
"suggestion content": {
|
"suggestion content": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "a concrete suggestion for meaningfully improving the new PR code."
|
"description": "a concrete suggestion for meaningfully improving the new PR code (lines from the '__new hunk__' sections, starting with '+')."
|
||||||
},
|
},
|
||||||
"existing code": {
|
"existing code": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "a code snippet showing authentic relevant code lines from a 'new hunk' section. It must be continuous, correctly formatted and indented, and without line numbers."
|
"description": "a code snippet showing the relevant code lines from a '__new hunk__' section. It must be continuous, correctly formatted and indented, and without line numbers."
|
||||||
},
|
},
|
||||||
"relevant lines": {
|
"relevant lines": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "the relevant lines in the 'new hunk' sections, in the format of 'start_line-end_line'. For example: '10-15'. They should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above."
|
"description": "the relevant lines from a '__new hunk__' section, in the format of 'start_line-end_line'. For example: '10-15'. They should be derived from the hunk line numbers, and correspond to the 'existing code' snippet above."
|
||||||
},
|
},
|
||||||
"improved code": {
|
"improved code": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "a new code snippet that can be used to replace the relevant lines in 'new hunk' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers."
|
"description": "a new code snippet that can be used to replace the relevant lines in '__new hunk__' code. Replacement suggestions should be complete, correctly formatted and indented, and without line numbers."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Example input:
|
Don't output line numbers in the 'improved code' snippets.
|
||||||
'
|
|
||||||
## src/file1.py
|
|
||||||
---new_hunk---
|
|
||||||
```
|
|
||||||
[new hunk code, annotated with line numbers]
|
|
||||||
```
|
|
||||||
---old_hunk---
|
|
||||||
```
|
|
||||||
[old hunk code]
|
|
||||||
```
|
|
||||||
...
|
|
||||||
'
|
|
||||||
|
|
||||||
Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
|
Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
[pr_review_prompt]
|
[pr_review_prompt]
|
||||||
system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests.
|
system="""You are CodiumAI-PR-Reviewer, a language model designed to review git pull requests.
|
||||||
Your task is to provide constructive and concise feedback for the PR, and also provide meaningfull code suggestions to improve the new PR code (the '+' lines).
|
Your task is to provide constructive and concise feedback for the PR, and also provide meaningful code suggestions to improve the new PR code (the '+' lines).
|
||||||
{%- if num_code_suggestions > 0 %}
|
{%- if num_code_suggestions > 0 %}
|
||||||
- Provide up to {{ num_code_suggestions }} code suggestions.
|
- Provide up to {{ num_code_suggestions }} code suggestions.
|
||||||
- Try to focus on the most important suggestions, like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningfull code improvements, like performance, vulnerability, modularity, and best practices.
|
- Try to focus on the most important suggestions, like fixing code problems, issues and bugs. As a second priority, provide suggestions for meaningful code improvements, like performance, vulnerability, modularity, and best practices.
|
||||||
- Suggestions should focus on improving the new added code lines.
|
- Suggestions should focus on improving the new added code lines.
|
||||||
- Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines).
|
- Make sure not to provide suggestions repeating modifications already implemented in the new PR code (the '+' lines).
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
|
46
pr_agent/settings/pr_sort_code_suggestions_prompts.toml
Normal file
46
pr_agent/settings/pr_sort_code_suggestions_prompts.toml
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
[pr_sort_code_suggestions_prompt]
|
||||||
|
system="""
|
||||||
|
"""
|
||||||
|
|
||||||
|
user="""You are given a list of code suggestions to improve a PR:
|
||||||
|
|
||||||
|
{{ suggestion_str|trim }}
|
||||||
|
|
||||||
|
|
||||||
|
Your task is to sort the code suggestions by their order of importance, and return a list with sorting order.
|
||||||
|
The sorting order is a list of pairs, where each pair contains the index of the suggestion in the original list.
|
||||||
|
Rank the suggestions based on their importance to improving the PR, with critical issues first and minor issues last.
|
||||||
|
|
||||||
|
You must use the following YAML schema to format your answer:
|
||||||
|
```yaml
|
||||||
|
Sort Order:
|
||||||
|
type: array
|
||||||
|
maxItems: {{ suggestion_list|length }}
|
||||||
|
uniqueItems: true
|
||||||
|
items:
|
||||||
|
suggestion number:
|
||||||
|
type: integer
|
||||||
|
minimum: 1
|
||||||
|
maximum: {{ suggestion_list|length }}
|
||||||
|
importance order:
|
||||||
|
type: integer
|
||||||
|
minimum: 1
|
||||||
|
maximum: {{ suggestion_list|length }}
|
||||||
|
```
|
||||||
|
|
||||||
|
Example output:
|
||||||
|
```yaml
|
||||||
|
Sort Order:
|
||||||
|
- suggestion number: 1
|
||||||
|
importance order: 2
|
||||||
|
- suggestion number: 2
|
||||||
|
importance order: 3
|
||||||
|
- suggestion number: 3
|
||||||
|
importance order: 1
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure to output a valid YAML. Use multi-line block scalar ('|') if needed.
|
||||||
|
Don't repeat the prompt in the answer, and avoid outputting the 'type' and 'description' fields.
|
||||||
|
Response (should be a valid YAML, and nothing else):
|
||||||
|
```yaml
|
||||||
|
"""
|
@ -2,11 +2,13 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import yaml
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handler import AiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models, get_pr_multi_diffs
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import try_fix_json
|
from pr_agent.algo.utils import try_fix_json
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
@ -22,6 +24,13 @@ class PRCodeSuggestions:
|
|||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# extended mode
|
||||||
|
self.is_extended = any(["extended" in arg for arg in args])
|
||||||
|
if self.is_extended:
|
||||||
|
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions_per_chunk
|
||||||
|
else:
|
||||||
|
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
|
||||||
|
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = AiHandler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
@ -32,7 +41,7 @@ class PRCodeSuggestions:
|
|||||||
"description": self.git_provider.get_pr_description(),
|
"description": self.git_provider.get_pr_description(),
|
||||||
"language": self.main_language,
|
"language": self.main_language,
|
||||||
"diff": "", # empty diff for initial calculation
|
"diff": "", # empty diff for initial calculation
|
||||||
"num_code_suggestions": get_settings().pr_code_suggestions.num_code_suggestions,
|
"num_code_suggestions": num_code_suggestions,
|
||||||
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
|
"extra_instructions": get_settings().pr_code_suggestions.extra_instructions,
|
||||||
"commit_messages_str": self.git_provider.get_commit_messages(),
|
"commit_messages_str": self.git_provider.get_commit_messages(),
|
||||||
}
|
}
|
||||||
@ -42,18 +51,26 @@ class PRCodeSuggestions:
|
|||||||
get_settings().pr_code_suggestions_prompt.user)
|
get_settings().pr_code_suggestions_prompt.user)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
# assert type(self.git_provider) != BitbucketProvider, "Bitbucket is not supported for now"
|
|
||||||
|
|
||||||
logging.info('Generating code suggestions for PR...')
|
logging.info('Generating code suggestions for PR...')
|
||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
|
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
|
||||||
await retry_with_fallback_models(self._prepare_prediction)
|
|
||||||
logging.info('Preparing PR review...')
|
logging.info('Preparing PR review...')
|
||||||
data = self._prepare_pr_code_suggestions()
|
if not self.is_extended:
|
||||||
|
await retry_with_fallback_models(self._prepare_prediction)
|
||||||
|
data = self._prepare_pr_code_suggestions()
|
||||||
|
else:
|
||||||
|
data = await retry_with_fallback_models(self._prepare_prediction_extended)
|
||||||
|
|
||||||
|
if (not self.is_extended and get_settings().pr_code_suggestions.rank_suggestions) or \
|
||||||
|
(self.is_extended and get_settings().pr_code_suggestions.rank_extended_suggestions):
|
||||||
|
logging.info('Ranking Suggestions...')
|
||||||
|
data['Code suggestions'] = await self.rank_suggestions(data['Code suggestions'])
|
||||||
|
|
||||||
if get_settings().config.publish_output:
|
if get_settings().config.publish_output:
|
||||||
logging.info('Pushing PR review...')
|
logging.info('Pushing PR review...')
|
||||||
self.git_provider.remove_initial_comment()
|
self.git_provider.remove_initial_comment()
|
||||||
logging.info('Pushing inline code comments...')
|
logging.info('Pushing inline code suggestions...')
|
||||||
self.push_inline_code_suggestions(data)
|
self.push_inline_code_suggestions(data)
|
||||||
|
|
||||||
async def _prepare_prediction(self, model: str):
|
async def _prepare_prediction(self, model: str):
|
||||||
@ -121,7 +138,11 @@ class PRCodeSuggestions:
|
|||||||
if get_settings().config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"Could not parse suggestion: {d}")
|
logging.info(f"Could not parse suggestion: {d}")
|
||||||
|
|
||||||
self.git_provider.publish_code_suggestions(code_suggestions)
|
is_successful = self.git_provider.publish_code_suggestions(code_suggestions)
|
||||||
|
if not is_successful:
|
||||||
|
logging.info("Failed to publish code suggestions, trying to publish each suggestion separately")
|
||||||
|
for code_suggestion in code_suggestions:
|
||||||
|
self.git_provider.publish_code_suggestions([code_suggestion])
|
||||||
|
|
||||||
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
|
def dedent_code(self, relevant_file, relevant_lines_start, new_code_snippet):
|
||||||
try: # dedent code snippet
|
try: # dedent code snippet
|
||||||
@ -145,3 +166,81 @@ class PRCodeSuggestions:
|
|||||||
|
|
||||||
return new_code_snippet
|
return new_code_snippet
|
||||||
|
|
||||||
|
async def _prepare_prediction_extended(self, model: str) -> dict:
|
||||||
|
logging.info('Getting PR diff...')
|
||||||
|
patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
|
||||||
|
max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
|
||||||
|
|
||||||
|
logging.info('Getting multi AI predictions...')
|
||||||
|
prediction_list = []
|
||||||
|
for i, patches_diff in enumerate(patches_diff_list):
|
||||||
|
logging.info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
|
||||||
|
self.patches_diff = patches_diff
|
||||||
|
prediction = await self._get_prediction(model)
|
||||||
|
prediction_list.append(prediction)
|
||||||
|
self.prediction_list = prediction_list
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
for prediction in prediction_list:
|
||||||
|
self.prediction = prediction
|
||||||
|
data_per_chunk = self._prepare_pr_code_suggestions()
|
||||||
|
if "Code suggestions" in data:
|
||||||
|
data["Code suggestions"].extend(data_per_chunk["Code suggestions"])
|
||||||
|
else:
|
||||||
|
data.update(data_per_chunk)
|
||||||
|
self.data = data
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def rank_suggestions(self, data: List) -> List:
|
||||||
|
"""
|
||||||
|
Call a model to rank (sort) code suggestions based on their importance order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (List): A list of code suggestions to be ranked.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List: The ranked list of code suggestions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
suggestion_list = []
|
||||||
|
# remove invalid suggestions
|
||||||
|
for i, suggestion in enumerate(data):
|
||||||
|
if suggestion['existing code'] != suggestion['improved code']:
|
||||||
|
suggestion_list.append(suggestion)
|
||||||
|
|
||||||
|
data_sorted = [[]] * len(suggestion_list)
|
||||||
|
|
||||||
|
try:
|
||||||
|
suggestion_str = ""
|
||||||
|
for i, suggestion in enumerate(suggestion_list):
|
||||||
|
suggestion_str += f"suggestion {i + 1}: " + str(suggestion) + '\n\n'
|
||||||
|
|
||||||
|
variables = {'suggestion_list': suggestion_list, 'suggestion_str': suggestion_str}
|
||||||
|
model = get_settings().config.model
|
||||||
|
environment = Environment(undefined=StrictUndefined)
|
||||||
|
system_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.system).render(
|
||||||
|
variables)
|
||||||
|
user_prompt = environment.from_string(get_settings().pr_sort_code_suggestions_prompt.user).render(variables)
|
||||||
|
if get_settings().config.verbosity_level >= 2:
|
||||||
|
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||||
|
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||||
|
response, finish_reason = await self.ai_handler.chat_completion(model=model, system=system_prompt,
|
||||||
|
user=user_prompt)
|
||||||
|
|
||||||
|
sort_order = yaml.safe_load(response)
|
||||||
|
for s in sort_order['Sort Order']:
|
||||||
|
suggestion_number = s['suggestion number']
|
||||||
|
importance_order = s['importance order']
|
||||||
|
data_sorted[importance_order - 1] = suggestion_list[suggestion_number - 1]
|
||||||
|
|
||||||
|
if get_settings().pr_code_suggestions.final_clip_factor != 1:
|
||||||
|
new_len = int(0.5 + len(data_sorted) * get_settings().pr_code_suggestions.final_clip_factor)
|
||||||
|
data_sorted = data_sorted[:new_len]
|
||||||
|
except Exception as e:
|
||||||
|
if get_settings().config.verbosity_level >= 1:
|
||||||
|
logging.info(f"Could not sort suggestions, error: {e}")
|
||||||
|
data_sorted = suggestion_list
|
||||||
|
|
||||||
|
return data_sorted
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,6 +42,8 @@ class PRDescription:
|
|||||||
"extra_instructions": get_settings().pr_description.extra_instructions,
|
"extra_instructions": get_settings().pr_description.extra_instructions,
|
||||||
"commit_messages_str": self.git_provider.get_commit_messages()
|
"commit_messages_str": self.git_provider.get_commit_messages()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.user_description = self.git_provider.get_user_description()
|
||||||
|
|
||||||
# Initialize the token handler
|
# Initialize the token handler
|
||||||
self.token_handler = TokenHandler(
|
self.token_handler = TokenHandler(
|
||||||
@ -145,6 +147,9 @@ class PRDescription:
|
|||||||
# Load the AI prediction data into a dictionary
|
# Load the AI prediction data into a dictionary
|
||||||
data = load_yaml(self.prediction.strip())
|
data = load_yaml(self.prediction.strip())
|
||||||
|
|
||||||
|
if get_settings().pr_description.add_original_user_description and self.user_description:
|
||||||
|
data["User Description"] = self.user_description
|
||||||
|
|
||||||
# Initialization
|
# Initialization
|
||||||
pr_types = []
|
pr_types = []
|
||||||
|
|
||||||
@ -161,13 +166,19 @@ class PRDescription:
|
|||||||
elif type(data['PR Type']) == str:
|
elif type(data['PR Type']) == str:
|
||||||
pr_types = data['PR Type'].split(',')
|
pr_types = data['PR Type'].split(',')
|
||||||
|
|
||||||
# Assign the value of the 'PR Title' key to 'title' variable and remove it from the dictionary
|
# Remove the 'PR Title' key from the dictionary
|
||||||
title = data.pop('PR Title')
|
ai_title = data.pop('PR Title')
|
||||||
|
if get_settings().pr_description.keep_original_user_title:
|
||||||
|
# Assign the original PR title to the 'title' variable
|
||||||
|
title = self.vars["title"]
|
||||||
|
else:
|
||||||
|
# Assign the value of the 'PR Title' key to 'title' variable
|
||||||
|
title = ai_title
|
||||||
|
|
||||||
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
|
# Iterate over the remaining dictionary items and append the key and value to 'pr_body' in a markdown format,
|
||||||
# except for the items containing the word 'walkthrough'
|
# except for the items containing the word 'walkthrough'
|
||||||
pr_body = ""
|
pr_body = ""
|
||||||
for key, value in data.items():
|
for idx, (key, value) in enumerate(data.items()):
|
||||||
pr_body += f"## {key}:\n"
|
pr_body += f"## {key}:\n"
|
||||||
if 'walkthrough' in key.lower():
|
if 'walkthrough' in key.lower():
|
||||||
# for filename, description in value.items():
|
# for filename, description in value.items():
|
||||||
@ -179,7 +190,9 @@ class PRDescription:
|
|||||||
# if the value is a list, join its items by comma
|
# if the value is a list, join its items by comma
|
||||||
if type(value) == list:
|
if type(value) == list:
|
||||||
value = ', '.join(v for v in value)
|
value = ', '.join(v for v in value)
|
||||||
pr_body += f"{value}\n\n___\n"
|
pr_body += f"{value}\n"
|
||||||
|
if idx < len(data) - 1:
|
||||||
|
pr_body += "\n___\n"
|
||||||
|
|
||||||
if get_settings().config.verbosity_level >= 2:
|
if get_settings().config.verbosity_level >= 2:
|
||||||
logging.info(f"title:\n{title}\n{pr_body}")
|
logging.info(f"title:\n{title}\n{pr_body}")
|
||||||
|
Reference in New Issue
Block a user