mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-15 02:00:39 +08:00
Merge commit 'e4f177908b620e46740b03966fda9243473d979e' into hl/pr_review_table
This commit is contained in:
@ -45,6 +45,7 @@ commands = list(command2class.keys())
|
||||
class PRAgent:
|
||||
def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
self.ai_handler = ai_handler # will be initialized in run_action
|
||||
self.forbidden_cli_args = ['enable_auto_approval']
|
||||
|
||||
async def handle_request(self, pr_url, request, notify=None) -> bool:
|
||||
# First, apply repo specific settings if exists
|
||||
@ -58,6 +59,13 @@ class PRAgent:
|
||||
action, *args = list(lexer)
|
||||
else:
|
||||
action, *args = request
|
||||
|
||||
if args:
|
||||
for forbidden_arg in self.forbidden_cli_args:
|
||||
for arg in args:
|
||||
if forbidden_arg in arg:
|
||||
get_logger().error(f"CLI argument '{forbidden_arg}' is forbidden")
|
||||
return False
|
||||
args = update_settings_from_args(args)
|
||||
|
||||
action = action.lstrip("/").lower()
|
||||
|
@ -9,6 +9,7 @@ MAX_TOKENS = {
|
||||
'gpt-4-0613': 8000,
|
||||
'gpt-4-32k': 32000,
|
||||
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
|
||||
'gpt-4-0125-preview': 128000, # 128K, but may be limited by config.max_model_tokens
|
||||
'claude-instant-1': 100000,
|
||||
'claude-2': 100000,
|
||||
'command-nightly': 4096,
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import re
|
||||
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers.git_provider import EDIT_TYPE
|
||||
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from pr_agent.log import get_logger
|
||||
|
||||
|
||||
@ -181,7 +181,7 @@ __old hunk__
|
||||
...
|
||||
"""
|
||||
|
||||
patch_with_lines_str = f"\n\n## {file.filename}\n"
|
||||
patch_with_lines_str = f"\n\n## file: '{file.filename.strip()}'\n"
|
||||
patch_lines = patch.splitlines()
|
||||
RE_HUNK_HEADER = re.compile(
|
||||
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
|
||||
@ -202,11 +202,11 @@ __old hunk__
|
||||
if new_content_lines:
|
||||
if prev_header_line:
|
||||
patch_with_lines_str += f'\n{prev_header_line}\n'
|
||||
patch_with_lines_str += '__new hunk__\n'
|
||||
patch_with_lines_str = patch_with_lines_str.rstrip()+'\n__new hunk__\n'
|
||||
for i, line_new in enumerate(new_content_lines):
|
||||
patch_with_lines_str += f"{start2 + i} {line_new}\n"
|
||||
if old_content_lines:
|
||||
patch_with_lines_str += '__old hunk__\n'
|
||||
patch_with_lines_str = patch_with_lines_str.rstrip()+'\n__old hunk__\n'
|
||||
for line_old in old_content_lines:
|
||||
patch_with_lines_str += f"{line_old}\n"
|
||||
new_content_lines = []
|
||||
@ -236,11 +236,11 @@ __old hunk__
|
||||
if match and new_content_lines:
|
||||
if new_content_lines:
|
||||
patch_with_lines_str += f'\n{header_line}\n'
|
||||
patch_with_lines_str += '\n__new hunk__\n'
|
||||
patch_with_lines_str = patch_with_lines_str.rstrip()+ '\n__new hunk__\n'
|
||||
for i, line_new in enumerate(new_content_lines):
|
||||
patch_with_lines_str += f"{start2 + i} {line_new}\n"
|
||||
if old_content_lines:
|
||||
patch_with_lines_str += '\n__old hunk__\n'
|
||||
patch_with_lines_str = patch_with_lines_str.rstrip() + '\n__old hunk__\n'
|
||||
for line_old in old_content_lines:
|
||||
patch_with_lines_str += f"{line_old}\n"
|
||||
|
||||
|
@ -1,9 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any, Callable, List, Tuple
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
from github import RateLimitExceededException
|
||||
|
||||
@ -11,9 +9,10 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
|
||||
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||
from pr_agent.algo.file_filter import filter_ignored
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import get_max_tokens
|
||||
from pr_agent.algo.utils import get_max_tokens, ModelType
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
|
||||
from pr_agent.git_providers.git_provider import GitProvider
|
||||
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from pr_agent.log import get_logger
|
||||
|
||||
DELETED_FILES_ = "Deleted files:\n"
|
||||
@ -209,9 +208,9 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
|
||||
|
||||
if patch:
|
||||
if not convert_hunks_to_line_numbers:
|
||||
patch_final = f"## {file.filename}\n\n{patch}\n"
|
||||
patch_final = f"\n\n## file: '{file.filename.strip()}\n\n{patch.strip()}\n'"
|
||||
else:
|
||||
patch_final = patch
|
||||
patch_final = "\n\n" + patch.strip()
|
||||
patches.append(patch_final)
|
||||
total_tokens += token_handler.count_tokens(patch_final)
|
||||
if get_settings().config.verbosity_level >= 2:
|
||||
@ -220,8 +219,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
|
||||
return patches, modified_files_list, deleted_files_list, added_files_list
|
||||
|
||||
|
||||
async def retry_with_fallback_models(f: Callable):
|
||||
all_models = _get_all_models()
|
||||
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
|
||||
all_models = _get_all_models(model_type)
|
||||
all_deployments = _get_all_deployments(all_models)
|
||||
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
|
||||
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
|
||||
@ -243,8 +242,11 @@ async def retry_with_fallback_models(f: Callable):
|
||||
raise # Re-raise the last exception
|
||||
|
||||
|
||||
def _get_all_models() -> List[str]:
|
||||
model = get_settings().config.model
|
||||
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
|
||||
if model_type == ModelType.TURBO:
|
||||
model = get_settings().config.model_turbo
|
||||
else:
|
||||
model = get_settings().config.model
|
||||
fallback_models = get_settings().config.fallback_models
|
||||
if not isinstance(fallback_models, list):
|
||||
fallback_models = [m.strip() for m in fallback_models.split(",")]
|
||||
@ -267,78 +269,6 @@ def _get_all_deployments(all_models: List[str]) -> List[str]:
|
||||
return all_deployments
|
||||
|
||||
|
||||
def find_line_number_of_relevant_line_in_file(diff_files: "List[FilePatchInfo]",
                                              relevant_file: str,
                                              relevant_line_in_file: str,
                                              absolute_position: int = None) -> Tuple[int, int]:
    """Locate a line inside the patch of `relevant_file`.

    Two modes are supported:

    * `absolute_position` given: find the patch line whose new-file line
      number equals it and report that patch line's index.
    * `absolute_position` omitted: search the patch for `relevant_line_in_file`
      (after an optional fuzzy correction via difflib) and report both the
      patch index and the new-file line number of the first hit. Removed
      ('-') lines are never matched.

    Args:
        diff_files: objects exposing `filename` and `patch` attributes.
        relevant_file: file name to look for (compared to the stripped `filename`).
        relevant_line_in_file: text of the line to find; may carry a leading '+'.
        absolute_position: line number in the new file, or None to search by text.

    Returns:
        `(position, absolute_position)` where `position` is the 0-based index
        of the line within the patch and `absolute_position` is the new-file
        line number. Either value is -1 when it could not be determined.
    """
    position = -1
    if absolute_position is None:
        absolute_position = -1
    re_hunk_header = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

    for file in diff_files:
        if not (file.filename and file.filename.strip() == relevant_file):
            continue
        # A missing patch is treated as an empty one instead of crashing.
        patch_lines = (file.patch or "").splitlines()

        if absolute_position != -1:  # matching absolute to relative
            for i, _, new_line_number in _walk_patch_lines(patch_lines, re_hunk_header):
                if new_line_number == absolute_position:
                    position = i
                    break
            continue

        # try to find the line in the patch using difflib, with some margin of error
        matches_difflib = difflib.get_close_matches(relevant_line_in_file,
                                                    patch_lines, n=3, cutoff=0.93)
        if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
            relevant_line_in_file = matches_difflib[0]

        needles = [relevant_line_in_file]
        # startswith() instead of [0] so an empty string cannot raise IndexError.
        if relevant_line_in_file.startswith('+'):
            # The model might add a '+' to the beginning of the relevant_line_in_file
            # even if originally it's a context line, so retry without it.
            needles.append(relevant_line_in_file[1:].lstrip())

        for needle in needles:
            for i, line, new_line_number in _walk_patch_lines(patch_lines, re_hunk_header):
                if needle in line and not line.startswith('-'):
                    position = i
                    absolute_position = new_line_number
                    break
            if position != -1:
                break
    return position, absolute_position


def _walk_patch_lines(patch_lines, hunk_header_re):
    """Yield `(index, line, new_file_line_number)` for each line of a unified diff."""
    new_start = 0
    offset = 0
    for index, line in enumerate(patch_lines):
        if line.startswith('@@'):
            header = hunk_header_re.match(line)
            if header:
                # Only the new-file start is needed. The count groups are None
                # when the header omits them (e.g. "@@ -1 +1 @@"), so they must
                # not be passed through int().
                new_start = int(header.group(3))
                offset = 0
            # An '@@' line that does not parse leaves the counters untouched;
            # line numbers after it are best-effort.
        elif not line.startswith('-'):
            offset += 1
        yield index, line, new_start + offset - 1
|
||||
|
||||
|
||||
def get_pr_multi_diffs(git_provider: GitProvider,
|
||||
token_handler: TokenHandler,
|
||||
model: str,
|
||||
@ -375,6 +305,13 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
||||
for lang in pr_languages:
|
||||
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
|
||||
|
||||
|
||||
# try first a single run with standard diff string, with patch extension, and no deletions
|
||||
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
|
||||
pr_languages, token_handler, add_line_numbers_to_hunks=True)
|
||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
|
||||
return ["\n".join(patches_extended)]
|
||||
|
||||
patches = []
|
||||
final_diff_list = []
|
||||
total_tokens = token_handler.prompt_tokens
|
||||
@ -398,6 +335,11 @@ def get_pr_multi_diffs(git_provider: GitProvider,
|
||||
|
||||
patch = convert_to_hunks_with_lines_numbers(patch, file)
|
||||
new_patch_tokens = token_handler.count_tokens(patch)
|
||||
|
||||
if patch and (token_handler.prompt_tokens + new_patch_tokens) > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
get_logger().warning(f"Patch too large, skipping: {file.filename}")
|
||||
continue
|
||||
|
||||
if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
|
||||
final_diff = "\n".join(patches)
|
||||
final_diff_list.append(final_diff)
|
||||
|
23
pr_agent/algo/types.py
Normal file
23
pr_agent/algo/types.py
Normal file
@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class EDIT_TYPE(Enum):
    """Kind of change applied to a file in a pull request."""
    ADDED = 1
    DELETED = 2
    MODIFIED = 3
    RENAMED = 4
    UNKNOWN = 5  # default used by FilePatchInfo when the kind was not set


@dataclass
class FilePatchInfo:
    """Per-file record describing one changed file and its patch."""
    # presumably the file content before and after the change - confirm with the providers
    base_file: str
    head_file: str
    patch: str
    filename: str
    tokens: int = -1  # -1 until a value is assigned
    edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN
    # None is the default, so the annotation has to admit it explicitly.
    old_filename: Optional[str] = None
    num_plus_lines: int = -1  # -1 until a value is assigned
    num_minus_lines: int = -1  # -1 until a value is assigned
|
@ -5,7 +5,8 @@ import json
|
||||
import re
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
from enum import Enum
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import yaml
|
||||
from starlette_context import context
|
||||
@ -13,8 +14,12 @@ from starlette_context import context
|
||||
from pr_agent.algo import MAX_TOKENS
|
||||
from pr_agent.algo.token_handler import get_token_encoder
|
||||
from pr_agent.config_loader import get_settings, global_settings
|
||||
from pr_agent.algo.types import FilePatchInfo
|
||||
from pr_agent.log import get_logger
|
||||
|
||||
class ModelType(str, Enum):
    """Selector for which configured model a call should use.

    Members are also plain strings because the class mixes in `str`.
    """
    REGULAR = "regular"  # the default selector in `_get_all_models` / `retry_with_fallback_models`
    TURBO = "turbo"  # makes `_get_all_models` read `config.model_turbo` instead of `config.model`
|
||||
|
||||
def get_setting(key: str) -> Any:
|
||||
try:
|
||||
@ -489,4 +494,76 @@ def replace_code_tags(text):
|
||||
parts = text.split('`')
|
||||
for i in range(1, len(parts), 2):
|
||||
parts[i] = '<code>' + parts[i] + '</code>'
|
||||
return ''.join(parts)
|
||||
return ''.join(parts)
|
||||
|
||||
|
||||
def find_line_number_of_relevant_line_in_file(diff_files: "List[FilePatchInfo]",
                                              relevant_file: str,
                                              relevant_line_in_file: str,
                                              absolute_position: int = None) -> Tuple[int, int]:
    """Locate a line inside the patch of `relevant_file`.

    Two modes are supported:

    * `absolute_position` given: find the patch line whose new-file line
      number equals it and report that patch line's index.
    * `absolute_position` omitted: search the patch for `relevant_line_in_file`
      (after an optional fuzzy correction via difflib) and report both the
      patch index and the new-file line number of the first hit. Removed
      ('-') lines are never matched.

    Args:
        diff_files: objects exposing `filename` and `patch` attributes.
        relevant_file: file name to look for (compared to the stripped `filename`).
        relevant_line_in_file: text of the line to find; may carry a leading '+'.
        absolute_position: line number in the new file, or None to search by text.

    Returns:
        `(position, absolute_position)` where `position` is the 0-based index
        of the line within the patch and `absolute_position` is the new-file
        line number. Either value is -1 when it could not be determined.
    """
    position = -1
    if absolute_position is None:
        absolute_position = -1
    re_hunk_header = re.compile(
        r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")

    for file in diff_files:
        if not (file.filename and file.filename.strip() == relevant_file):
            continue
        # A missing patch is treated as an empty one instead of crashing.
        patch_lines = (file.patch or "").splitlines()

        if absolute_position != -1:  # matching absolute to relative
            for i, _, new_line_number in _walk_patch_lines(patch_lines, re_hunk_header):
                if new_line_number == absolute_position:
                    position = i
                    break
            continue

        # try to find the line in the patch using difflib, with some margin of error
        matches_difflib = difflib.get_close_matches(relevant_line_in_file,
                                                    patch_lines, n=3, cutoff=0.93)
        if len(matches_difflib) == 1 and matches_difflib[0].startswith('+'):
            relevant_line_in_file = matches_difflib[0]

        needles = [relevant_line_in_file]
        # startswith() instead of [0] so an empty string cannot raise IndexError.
        if relevant_line_in_file.startswith('+'):
            # The model might add a '+' to the beginning of the relevant_line_in_file
            # even if originally it's a context line, so retry without it.
            needles.append(relevant_line_in_file[1:].lstrip())

        for needle in needles:
            for i, line, new_line_number in _walk_patch_lines(patch_lines, re_hunk_header):
                if needle in line and not line.startswith('-'):
                    position = i
                    absolute_position = new_line_number
                    break
            if position != -1:
                break
    return position, absolute_position


def _walk_patch_lines(patch_lines, hunk_header_re):
    """Yield `(index, line, new_file_line_number)` for each line of a unified diff."""
    new_start = 0
    offset = 0
    for index, line in enumerate(patch_lines):
        if line.startswith('@@'):
            header = hunk_header_re.match(line)
            if header:
                # Only the new-file start is needed. The count groups are None
                # when the header omits them (e.g. "@@ -1 +1 @@"), so they must
                # not be passed through int().
                new_start = int(header.group(3))
                offset = 0
            # An '@@' line that does not parse leaves the counters untouched;
            # line numbers after it are best-effort.
        elif not line.startswith('-'):
            offset += 1
        yield index, line, new_start + offset - 1
|
||||
|
@ -6,7 +6,8 @@ from ..log import get_logger
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.utils import clip_tokens, load_large_diff
|
||||
from ..config_loader import get_settings
|
||||
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
||||
from .git_provider import GitProvider
|
||||
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
|
||||
AZURE_DEVOPS_AVAILABLE = True
|
||||
|
||||
|
@ -6,10 +6,11 @@ import requests
|
||||
from atlassian.bitbucket import Cloud
|
||||
from starlette_context import context
|
||||
|
||||
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file
|
||||
from pr_agent.algo.types import FilePatchInfo, EDIT_TYPE
|
||||
from ..algo.utils import find_line_number_of_relevant_line_in_file
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
from .git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
|
||||
from .git_provider import GitProvider
|
||||
|
||||
|
||||
class BitbucketProvider(GitProvider):
|
||||
|
@ -6,9 +6,9 @@ import requests
|
||||
from atlassian.bitbucket import Bitbucket
|
||||
from starlette_context import context
|
||||
|
||||
from .git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
|
||||
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file
|
||||
from ..algo.utils import load_large_diff
|
||||
from .git_provider import GitProvider
|
||||
from pr_agent.algo.types import FilePatchInfo
|
||||
from ..algo.utils import load_large_diff, find_line_number_of_relevant_line_in_file
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
|
||||
|
@ -5,9 +5,9 @@ from typing import List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pr_agent.git_providers.codecommit_client import CodeCommitClient
|
||||
|
||||
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from ..algo.utils import load_large_diff
|
||||
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
||||
from .git_provider import GitProvider
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
|
||||
|
@ -13,7 +13,8 @@ import urllib3.util
|
||||
from git import Repo
|
||||
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
||||
from pr_agent.git_providers.git_provider import GitProvider
|
||||
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from pr_agent.git_providers.local_git_provider import PullRequestMimic
|
||||
from pr_agent.log import get_logger
|
||||
|
||||
|
@ -1,35 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
# enum EDIT_TYPE (ADDED, DELETED, MODIFIED, RENAMED)
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.algo.types import FilePatchInfo
|
||||
from pr_agent.log import get_logger
|
||||
|
||||
|
||||
class EDIT_TYPE(Enum):
    """Kind of change applied to a file in a pull request."""
    ADDED = 1
    DELETED = 2
    MODIFIED = 3
    RENAMED = 4
    UNKNOWN = 5  # default used by FilePatchInfo when the kind was not set


@dataclass
class FilePatchInfo:
    """Per-file record describing one changed file and its patch."""
    # presumably the file content before and after the change - confirm with the providers
    base_file: str
    head_file: str
    patch: str
    filename: str
    tokens: int = -1  # -1 until a value is assigned
    edit_type: EDIT_TYPE = EDIT_TYPE.UNKNOWN
    # None is the default, so the annotation has to admit it explicitly.
    old_filename: Optional[str] = None
    num_plus_lines: int = -1  # -1 until a value is assigned
    num_minus_lines: int = -1  # -1 until a value is assigned
|
||||
|
||||
|
||||
class GitProvider(ABC):
|
||||
@abstractmethod
|
||||
def is_supported(self, capability: str) -> bool:
|
||||
@ -193,6 +171,11 @@ class GitProvider(ABC):
|
||||
def get_latest_commit_url(self) -> str:
|
||||
return ""
|
||||
|
||||
def auto_approve(self) -> bool:
    """Default implementation: performs no approval and returns False.

    GithubProvider overrides this with a version that submits an approving review.
    """
    return False
|
||||
|
||||
|
||||
|
||||
def get_main_pr_language(languages, files) -> str:
|
||||
"""
|
||||
Get the main language of the commit. Return an empty string if cannot determine.
|
||||
@ -261,7 +244,6 @@ def get_main_pr_language(languages, files) -> str:
|
||||
|
||||
return main_language_str
|
||||
|
||||
|
||||
class IncrementalPR:
|
||||
def __init__(self, is_incremental: bool = False):
|
||||
self.is_incremental = is_incremental
|
||||
|
@ -9,12 +9,12 @@ from retry import retry
|
||||
from starlette_context import context
|
||||
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file
|
||||
from ..algo.utils import load_large_diff, clip_tokens
|
||||
from ..algo.utils import load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file
|
||||
from ..config_loader import get_settings
|
||||
from ..log import get_logger
|
||||
from ..servers.utils import RateLimitExceeded
|
||||
from .git_provider import FilePatchInfo, GitProvider, IncrementalPR, EDIT_TYPE
|
||||
from .git_provider import GitProvider, IncrementalPR
|
||||
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
|
||||
|
||||
class GithubProvider(GitProvider):
|
||||
@ -643,3 +643,13 @@ class GithubProvider(GitProvider):
|
||||
return pr_id
|
||||
except:
|
||||
return ""
|
||||
|
||||
def auto_approve(self) -> bool:
    """Submit an "APPROVE" review on `self.pr`.

    Returns:
        True only when the created review reports the state "APPROVED".
        False for any other state, or when creating the review raised;
        the failure is logged and not propagated.
    """
    try:
        review = self.pr.create_review(event="APPROVE")
        # Reading `.state` stays inside the try so a failure there is also logged.
        approved = bool(review.state == "APPROVED")
    except Exception as e:
        get_logger().exception(f"Failed to auto-approve, error: {e}")
        approved = False
    return approved
|
@ -7,10 +7,10 @@ import gitlab
|
||||
from gitlab import GitlabGetError
|
||||
|
||||
from ..algo.language_handler import is_valid_file
|
||||
from ..algo.pr_processing import find_line_number_of_relevant_line_in_file
|
||||
from ..algo.utils import load_large_diff, clip_tokens
|
||||
from ..algo.utils import load_large_diff, clip_tokens, find_line_number_of_relevant_line_in_file
|
||||
from ..config_loader import get_settings
|
||||
from .git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
||||
from .git_provider import GitProvider
|
||||
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from ..log import get_logger
|
||||
|
||||
|
||||
|
@ -5,7 +5,8 @@ from typing import List
|
||||
from git import Repo
|
||||
|
||||
from pr_agent.config_loader import _find_repository_root, get_settings
|
||||
from pr_agent.git_providers.git_provider import EDIT_TYPE, FilePatchInfo, GitProvider
|
||||
from pr_agent.git_providers.git_provider import GitProvider
|
||||
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
|
||||
from pr_agent.log import get_logger
|
||||
|
||||
|
||||
|
@ -2,15 +2,18 @@ from pr_agent.config_loader import get_settings
|
||||
|
||||
|
||||
def get_secret_provider():
|
||||
try:
|
||||
provider_id = get_settings().config.secret_provider
|
||||
except AttributeError as e:
|
||||
raise ValueError("secret_provider is a required attribute in the configuration file") from e
|
||||
try:
|
||||
if provider_id == 'google_cloud_storage':
|
||||
if not get_settings().get("CONFIG.SECRET_PROVIDER"):
|
||||
return None
|
||||
|
||||
provider_id = get_settings().config.secret_provider
|
||||
if provider_id == 'google_cloud_storage':
|
||||
try:
|
||||
from pr_agent.secret_providers.google_cloud_storage_secret_provider import GoogleCloudStorageSecretProvider
|
||||
return GoogleCloudStorageSecretProvider()
|
||||
else:
|
||||
raise ValueError(f"Unknown secret provider: {provider_id}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to initialize secret provider {provider_id}") from e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to initialize google_cloud_storage secret provider {provider_id}") from e
|
||||
else:
|
||||
raise ValueError("Unknown SECRET_PROVIDER")
|
||||
|
||||
|
||||
|
||||
|
@ -26,7 +26,8 @@ from pr_agent.tools.pr_reviewer import PRReviewer
|
||||
|
||||
setup_logger(fmt=LoggingFormat.JSON)
|
||||
router = APIRouter()
|
||||
secret_provider = get_secret_provider()
|
||||
secret_provider = get_secret_provider() if get_settings().get("CONFIG.SECRET_PROVIDER") else None
|
||||
|
||||
|
||||
async def get_bearer_token(shared_secret: str, client_key: str):
|
||||
try:
|
||||
|
@ -82,14 +82,23 @@ async def run_action():
|
||||
if action in ["opened", "reopened"]:
|
||||
pr_url = event_payload.get("pull_request", {}).get("url")
|
||||
if pr_url:
|
||||
# legacy - supporting both GITHUB_ACTION and GITHUB_ACTION_CONFIG
|
||||
auto_review = get_setting_or_env("GITHUB_ACTION.AUTO_REVIEW", None)
|
||||
if auto_review is None:
|
||||
auto_review = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_REVIEW", None)
|
||||
auto_describe = get_setting_or_env("GITHUB_ACTION.AUTO_DESCRIBE", None)
|
||||
if auto_describe is None:
|
||||
auto_describe = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_DESCRIBE", None)
|
||||
auto_improve = get_setting_or_env("GITHUB_ACTION.AUTO_IMPROVE", None)
|
||||
if auto_improve is None:
|
||||
auto_improve = get_setting_or_env("GITHUB_ACTION_CONFIG.AUTO_IMPROVE", None)
|
||||
|
||||
# invoke by default all three tools
|
||||
if auto_describe is None or is_true(auto_describe):
|
||||
await PRDescription(pr_url).run()
|
||||
if auto_review is None or is_true(auto_review):
|
||||
await PRReviewer(pr_url).run()
|
||||
auto_describe = get_setting_or_env("GITHUB_ACTION.AUTO_DESCRIBE", None)
|
||||
if is_true(auto_describe):
|
||||
await PRDescription(pr_url).run()
|
||||
auto_improve = get_setting_or_env("GITHUB_ACTION.AUTO_IMPROVE", None)
|
||||
if is_true(auto_improve):
|
||||
if auto_improve is None or is_true(auto_improve):
|
||||
await PRCodeSuggestions(pr_url).run()
|
||||
|
||||
# Handle issue comment event
|
||||
|
@ -48,7 +48,7 @@ Examples for extra instructions:
|
||||
```
|
||||
[pr_reviewer] # /review #
|
||||
extra_instructions="""
|
||||
In the code feedback section, emphasize the following:
|
||||
In the 'general suggestions' section, emphasize the following:
|
||||
- Does the code logic cover relevant edge cases?
|
||||
- Is the code logic clear and easy to understand?
|
||||
- Is the code logic efficient?
|
||||
@ -71,14 +71,14 @@ Edit this field to enable/disable the tool, or to change the used configurations
|
||||
"""
|
||||
output += "\n\n</details></td></tr>\n\n"
|
||||
|
||||
# code feedback
|
||||
output += "<tr><td><details> <summary><strong> About the 'Code feedback' section</strong></summary><hr>\n\n"
|
||||
output+="""\
|
||||
The `review` tool provides several type of feedbacks, one of them is code suggestions.
|
||||
If you are interested **only** in the code suggestions, it is recommended to use the [`improve`](https://github.com/Codium-ai/pr-agent/blob/main/docs/IMPROVE.md) feature instead, since it dedicated only to code suggestions, and usually gives better results.
|
||||
Use the `review` tool if you want to get a more comprehensive feedback, which includes code suggestions as well.
|
||||
"""
|
||||
output += "\n\n</details></td></tr>\n\n"
|
||||
# # code feedback
|
||||
# output += "<tr><td><details> <summary><strong> About the 'Code feedback' section</strong></summary><hr>\n\n"
|
||||
# output+="""\
|
||||
# The `review` tool provides several type of feedbacks, one of them is code suggestions.
|
||||
# If you are interested **only** in the code suggestions, it is recommended to use the [`improve`](https://github.com/Codium-ai/pr-agent/blob/main/docs/IMPROVE.md) feature instead, since it dedicated only to code suggestions, and usually gives better results.
|
||||
# Use the `review` tool if you want to get a more comprehensive feedback, which includes code suggestions as well.
|
||||
# """
|
||||
# output += "\n\n</details></td></tr>\n\n"
|
||||
|
||||
# auto-labels
|
||||
output += "<tr><td><details> <summary><strong> Auto-labels</strong></summary><hr>\n\n"
|
||||
@ -99,6 +99,31 @@ Some of the feature that are disabled by default are quite useful, and should be
|
||||
"""
|
||||
output += "\n\n</details></td></tr>\n\n"
|
||||
|
||||
output += "<tr><td><details> <summary><strong> Auto-approve PRs</strong></summary><hr>\n\n"
|
||||
output += '''\
|
||||
By invoking:
|
||||
```
|
||||
/review auto_approve
|
||||
```
|
||||
The tool will automatically approve the PR, and add a comment with the approval.
|
||||
|
||||
|
||||
To ensure safety, the auto-approval feature is disabled by default. To enable auto-approval, you need to actively set in a pre-defined configuration file the following:
|
||||
```
|
||||
[pr_reviewer]
|
||||
enable_auto_approval = true
|
||||
```
|
||||
(this specific flag cannot be set with a command line argument, only in the configuration file, committed to the repository)
|
||||
|
||||
|
||||
You can also enable auto-approval only if the PR meets certain requirements, such as that the `estimated_review_effort` is equal or below a certain threshold, by adjusting the flag:
|
||||
```
|
||||
[pr_reviewer]
|
||||
maximal_review_effort = 5
|
||||
```
|
||||
'''
|
||||
output += "\n\n</details></td></tr>\n\n"
|
||||
|
||||
# general
|
||||
output += "\n\n<tr><td><details> <summary><strong> More PR-Agent commands</strong></summary><hr> \n\n"
|
||||
output += HelpMessage.get_general_bot_help_text()
|
||||
@ -186,6 +211,7 @@ To enable inline file summary, set `pr_description.inline_file_summary` in the c
|
||||
- `true`: A collapsable file comment with changes title and a changes summary for each file in the PR.
|
||||
- `false` (default): File changes walkthrough will be added only to the "Conversation" tab.
|
||||
"""
|
||||
|
||||
# extra instructions
|
||||
output += "<tr><td><details> <summary><strong> Utilizing extra instructions</strong></summary><hr>\n\n"
|
||||
output += '''\
|
||||
@ -309,8 +335,9 @@ Use triple quotes to write multi-line instructions. Use bullet points to make th
|
||||
output += """\
|
||||
- While the current AI for code is getting better and better (GPT-4), it's not flawless. Not all the suggestions will be perfect, and a user should not accept all of them automatically.
|
||||
- Suggestions are not meant to be simplistic. Instead, they aim to give deep feedback and raise questions, ideas and thoughts to the user, who can then use his judgment, experience, and understanding of the code base.
|
||||
- Recommended to use the 'extra_instructions' field to guide the model to suggestions that are more relevant to the specific needs of the project.
|
||||
- Best quality will be obtained by using 'improve --extended' mode.
|
||||
- Recommended to use the 'extra_instructions' field to guide the model to suggestions that are more relevant to the specific needs of the project, or use the [custom suggestions :gem:](https://github.com/Codium-ai/pr-agent/blob/main/docs/CUSTOM_SUGGESTIONS.md) tool
|
||||
- With large PRs, best quality will be obtained by using 'improve --extended' mode.
|
||||
|
||||
|
||||
"""
|
||||
output += "\n\n</details></td></tr>\n\n"\
|
||||
|
@ -1,5 +1,6 @@
|
||||
[config]
|
||||
model="gpt-4" # "gpt-4-1106-preview"
|
||||
model="gpt-4" # "gpt-4-0125-preview"
|
||||
model_turbo="gpt-4-0125-preview"
|
||||
fallback_models=["gpt-3.5-turbo-16k"]
|
||||
git_provider="github"
|
||||
publish_output=true
|
||||
@ -8,11 +9,11 @@ verbosity_level=0 # 0,1,2
|
||||
use_extra_bad_extensions=false
|
||||
use_repo_settings_file=true
|
||||
use_global_settings_file=true
|
||||
ai_timeout=180
|
||||
ai_timeout=90
|
||||
max_description_tokens = 500
|
||||
max_commits_tokens = 500
|
||||
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
|
||||
patch_extra_lines = 3
|
||||
patch_extra_lines = 1
|
||||
secret_provider="google_cloud_storage"
|
||||
cli_mode=false
|
||||
|
||||
@ -42,12 +43,16 @@ require_all_thresholds_for_incremental_review=false
|
||||
minimal_commits_for_incremental_review=0
|
||||
minimal_minutes_for_incremental_review=0
|
||||
enable_help_text=true # Determines whether to include help text in the PR review. Enabled by default.
|
||||
# auto approval
|
||||
enable_auto_approval=false
|
||||
maximal_review_effort=5
|
||||
|
||||
|
||||
[pr_description] # /describe #
|
||||
publish_labels=true
|
||||
publish_description_as_comment=false
|
||||
add_original_user_description=true
|
||||
keep_original_user_title=false
|
||||
keep_original_user_title=true
|
||||
use_bullet_points=true
|
||||
extra_instructions = ""
|
||||
enable_pr_type=true
|
||||
@ -68,17 +73,19 @@ enable_help_text=true
|
||||
|
||||
|
||||
[pr_code_suggestions] # /improve #
|
||||
max_context_tokens=8000
|
||||
num_code_suggestions=4
|
||||
summarize = true
|
||||
extra_instructions = ""
|
||||
rank_suggestions = false
|
||||
enable_help_text=true
|
||||
# params for '/improve --extended' mode
|
||||
auto_extended_mode=false
|
||||
num_code_suggestions_per_chunk=8
|
||||
rank_extended_suggestions = true
|
||||
max_number_of_calls = 5
|
||||
final_clip_factor = 0.9
|
||||
auto_extended_mode=true
|
||||
num_code_suggestions_per_chunk=5
|
||||
max_number_of_calls = 3
|
||||
parallel_calls = true
|
||||
rank_extended_suggestions = false
|
||||
final_clip_factor = 0.8
|
||||
|
||||
[pr_add_docs] # /add_docs #
|
||||
extra_instructions = ""
|
||||
@ -90,6 +97,15 @@ extra_instructions = ""
|
||||
|
||||
[pr_analyze] # /analyze #
|
||||
|
||||
[pr_test] # /test #
|
||||
extra_instructions = ""
|
||||
testing_framework = "" # specify the testing framework you want to use
|
||||
num_tests=3 # number of tests to generate. max 5.
|
||||
avoid_mocks=true # if true, the generated tests will prefer to use real objects instead of mocks
|
||||
file = "" # in case there are several components with the same name, you can specify the relevant file
|
||||
class_name = "" # in case there are several methods with the same name in the same file, you can specify the relevant class name
|
||||
enable_help_text=true
|
||||
|
||||
[pr_config] # /config #
|
||||
|
||||
[github]
|
||||
@ -100,7 +116,7 @@ base_url = "https://api.github.com"
|
||||
publish_inline_comments_fallback_with_verification = true
|
||||
try_fix_invalid_inline_comments = true
|
||||
|
||||
[github_action]
|
||||
[github_action_config]
|
||||
# auto_review = true # set as env var in .github/workflows/pr-agent.yaml
|
||||
# auto_describe = true # set as env var in .github/workflows/pr-agent.yaml
|
||||
# auto_improve = true # set as env var in .github/workflows/pr-agent.yaml
|
||||
|
@ -5,7 +5,7 @@ Your task is to generate {{ docs_for_language }} for code components in the PR D
|
||||
|
||||
Example for the PR Diff format:
|
||||
======
|
||||
## src/file1.py
|
||||
## file: 'src/file1.py'
|
||||
|
||||
@@ -12,3 +12,4 @@ def func1():
|
||||
__new hunk__
|
||||
@ -18,7 +18,6 @@ __old hunk__
|
||||
-code line that was removed in the PR
|
||||
code line2 that remained unchanged in the PR
|
||||
|
||||
|
||||
@@ ... @@ def func2():
|
||||
__new hunk__
|
||||
...
|
||||
@ -26,7 +25,7 @@ __old hunk__
|
||||
...
|
||||
|
||||
|
||||
## src/file2.py
|
||||
## file: 'src/file2.py'
|
||||
...
|
||||
======
|
||||
|
||||
|
@ -4,19 +4,18 @@ Your task is to provide meaningful and actionable code suggestions, to improve t
|
||||
|
||||
Example for the PR Diff format:
|
||||
======
|
||||
## src/file1.py
|
||||
## file: 'src/file1.py'
|
||||
|
||||
@@ ... @@ def func1():
|
||||
__new hunk__
|
||||
12 code line1 that remained unchanged in the PR
|
||||
13 +new code line2 added in the PR
|
||||
13 +new hunk code line2 added in the PR
|
||||
14 code line3 that remained unchanged in the PR
|
||||
__old hunk__
|
||||
code line1 that remained unchanged in the PR
|
||||
-old code line2 that was removed in the PR
|
||||
-old hunk code line2 that was removed in the PR
|
||||
code line3 that remained unchanged in the PR
|
||||
|
||||
|
||||
@@ ... @@ def func2():
|
||||
__new hunk__
|
||||
...
|
||||
@ -24,7 +23,7 @@ __old hunk__
|
||||
...
|
||||
|
||||
|
||||
## src/file2.py
|
||||
## file: 'src/file2.py'
|
||||
...
|
||||
======
|
||||
|
||||
@ -34,7 +33,7 @@ Specific instructions:
|
||||
- The suggestions should refer only to code from the '__new hunk__' sections, and focus on new lines of code (lines starting with '+').
|
||||
- Prioritize suggestions that address major problems, issues and bugs in the PR code. As a second priority, suggestions should focus on enhancement, best practice, performance, maintainability, and other aspects.
|
||||
- Don't suggest to add docstring, type hints, or comments, or to remove unused imports.
|
||||
- Avoid making suggestions that have already been implemented in the PR code. For example, if you want to add logs, or change a variable to const, or anything else, make sure it isn't already in the '__new hunk__' code.
|
||||
- Suggestions should not repeat code already present in the '__new hunk__' sections.
|
||||
- Provide the exact line numbers range (inclusive) for each suggestion.
|
||||
- When quoting variables or names from the code, use backticks (`) instead of single quote (').
|
||||
|
||||
@ -51,6 +50,7 @@ The output must be a YAML object equivalent to type $PRCodeSuggestions, accordin
|
||||
=====
|
||||
class CodeSuggestion(BaseModel):
|
||||
relevant_file: str = Field(description="the relevant file full path")
|
||||
language: str = Field(description="the code language of the relevant file")
|
||||
suggestion_content: str = Field(description="an actionable suggestion for meaningfully improving the new code introduced in the PR")
|
||||
{%- if summarize_mode %}
|
||||
existing_code: str = Field(description="a short code snippet from a '__new hunk__' section to illustrate the relevant existing code. Don't show the line numbers.")
|
||||
@ -72,44 +72,41 @@ class PRCodeSuggestions(BaseModel):
|
||||
Example output:
|
||||
```yaml
|
||||
code_suggestions:
|
||||
- relevant_file: |-
|
||||
- relevant_file: |
|
||||
src/file1.py
|
||||
suggestion_content: |-
|
||||
language: |
|
||||
python
|
||||
suggestion_content: |
|
||||
Add a docstring to func1()
|
||||
{%- if summarize_mode %}
|
||||
existing_code: |-
|
||||
existing_code: |
|
||||
def func1():
|
||||
improved_code: |-
|
||||
improved_code: |
|
||||
...
|
||||
one_sentence_summary: |-
|
||||
one_sentence_summary: |
|
||||
...
|
||||
relevant_lines_start: 12
|
||||
relevant_lines_end: 12
|
||||
{%- else %}
|
||||
existing_code: |-
|
||||
existing_code: |
|
||||
def func1():
|
||||
relevant_lines_start: 12
|
||||
relevant_lines_end: 12
|
||||
improved_code: |-
|
||||
improved_code: |
|
||||
...
|
||||
{%- endif %}
|
||||
label: |-
|
||||
label: |
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
Each YAML output MUST be after a newline, indented, with block scalar indicator ('|-').
|
||||
Each YAML output MUST be after a newline, indented, with block scalar indicator ('|').
|
||||
"""
|
||||
|
||||
user="""PR Info:
|
||||
|
||||
Title: '{{title}}'
|
||||
|
||||
{%- if language %}
|
||||
|
||||
Main PR language: '{{ language }}'
|
||||
{%- endif %}
|
||||
|
||||
|
||||
The PR Diff:
|
||||
======
|
||||
|
@ -39,6 +39,7 @@ class PRType(str, Enum):
|
||||
|
||||
Class FileDescription(BaseModel):
|
||||
filename: str = Field(description="the relevant file full path")
|
||||
language: str = Field(description="the relevant file language")
|
||||
changes_summary: str = Field(description="concise summary of the changes in the relevant file, in bullet points (1-4 bullet points).")
|
||||
changes_title: str = Field(description="an informative title for the changes in the files, describing its main theme (5-10 words).")
|
||||
label: str = Field(description="a single semantic label that represents a type of code changes that occurred in the File. Possible values (partial list): 'bug fix', 'tests', 'enhancement', 'documentation', 'error handling', 'configuration changes', 'dependencies', 'formatting', 'miscellaneous', ...")
|
||||
@ -67,6 +68,8 @@ type:
|
||||
pr_files:
|
||||
- filename: |
|
||||
...
|
||||
language: |
|
||||
...
|
||||
changes_summary: |
|
||||
...
|
||||
changes_title: |
|
||||
@ -104,10 +107,7 @@ Previous description:
|
||||
{%- endif %}
|
||||
|
||||
Branch: '{{branch}}'
|
||||
{%- if language %}
|
||||
|
||||
Main PR language: '{{ language }}'
|
||||
{%- endif %}
|
||||
{%- if commit_messages_str %}
|
||||
|
||||
Commit messages:
|
||||
|
@ -5,7 +5,7 @@ The review should focus on new code added in the PR diff (lines starting with '+
|
||||
|
||||
Example PR Diff:
|
||||
======
|
||||
## src/file1.py
|
||||
## file: 'src/file1.py'
|
||||
|
||||
@@ -12,5 +12,5 @@ def func1():
|
||||
code line 1 that remained unchanged in the PR
|
||||
@ -14,12 +14,11 @@ code line 2 that remained unchanged in the PR
|
||||
+code line added in the PR
|
||||
code line 3 that remained unchanged in the PR
|
||||
|
||||
|
||||
@@ ... @@ def func2():
|
||||
...
|
||||
|
||||
|
||||
## src/file2.py
|
||||
## file: 'src/file2.py'
|
||||
...
|
||||
======
|
||||
|
||||
@ -96,6 +95,9 @@ PR Feedback:
|
||||
relevant file:
|
||||
type: string
|
||||
description: the relevant file full path
|
||||
language:
|
||||
type: string
|
||||
description: the language of the relevant file
|
||||
suggestion:
|
||||
type: string
|
||||
description: |-
|
||||
@ -141,6 +143,8 @@ PR Feedback:
|
||||
Code feedback:
|
||||
- relevant file: |-
|
||||
directory/xxx.py
|
||||
language: |-
|
||||
python
|
||||
suggestion: |-
|
||||
xxx [important]
|
||||
relevant line: |-
|
||||
@ -170,10 +174,6 @@ Description:
|
||||
======
|
||||
{%- endif %}
|
||||
|
||||
{%- if language %}
|
||||
|
||||
Main PR language: '{{ language }}'
|
||||
{%- endif %}
|
||||
{%- if commit_messages_str %}
|
||||
|
||||
Commit messages:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import textwrap
|
||||
from functools import partial
|
||||
@ -8,7 +9,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import load_yaml, replace_code_tags
|
||||
from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
@ -26,6 +27,14 @@ class PRCodeSuggestions:
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
|
||||
# limit context specifically for the improve command, which has hard input to parse:
|
||||
if get_settings().pr_code_suggestions.max_context_tokens:
|
||||
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
|
||||
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
|
||||
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
|
||||
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
|
||||
|
||||
|
||||
# extended mode
|
||||
try:
|
||||
self.is_extended = self._get_is_extended(args or [])
|
||||
@ -64,10 +73,10 @@ class PRCodeSuggestions:
|
||||
|
||||
get_logger().info('Preparing PR code suggestions...')
|
||||
if not self.is_extended:
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
|
||||
data = self._prepare_pr_code_suggestions()
|
||||
else:
|
||||
data = await retry_with_fallback_models(self._prepare_prediction_extended)
|
||||
data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
|
||||
|
||||
|
||||
if (not data) or (not 'code_suggestions' in data):
|
||||
@ -103,18 +112,18 @@ class PRCodeSuggestions:
|
||||
|
||||
async def _prepare_prediction(self, model: str):
|
||||
get_logger().info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider,
|
||||
patches_diff = get_pr_diff(self.git_provider,
|
||||
self.token_handler,
|
||||
model,
|
||||
add_line_numbers_to_hunks=True,
|
||||
disable_extra_lines=True)
|
||||
|
||||
get_logger().info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction(model)
|
||||
self.prediction = await self._get_prediction(model, patches_diff)
|
||||
|
||||
async def _get_prediction(self, model: str):
|
||||
async def _get_prediction(self, model: str, patches_diff: str):
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = self.patches_diff # update diff
|
||||
variables["diff"] = patches_diff # update diff
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
|
||||
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
|
||||
@ -190,7 +199,8 @@ class PRCodeSuggestions:
|
||||
original_initial_line = None
|
||||
for file in self.diff_files:
|
||||
if file.filename.strip() == relevant_file:
|
||||
original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
|
||||
if file.head_file: # in bitbucket, head_file is empty. toDo: fix this
|
||||
original_initial_line = file.head_file.splitlines()[relevant_lines_start - 1]
|
||||
break
|
||||
if original_initial_line:
|
||||
suggested_initial_line = new_code_snippet.splitlines()[0]
|
||||
@ -220,14 +230,18 @@ class PRCodeSuggestions:
|
||||
patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
|
||||
max_calls=get_settings().pr_code_suggestions.max_number_of_calls)
|
||||
|
||||
get_logger().info('Getting multi AI predictions...')
|
||||
prediction_list = []
|
||||
for i, patches_diff in enumerate(patches_diff_list):
|
||||
get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
|
||||
self.patches_diff = patches_diff
|
||||
prediction = await self._get_prediction(model)
|
||||
prediction_list.append(prediction)
|
||||
self.prediction_list = prediction_list
|
||||
# parallelize calls to AI:
|
||||
if get_settings().pr_code_suggestions.parallel_calls:
|
||||
get_logger().info('Getting multi AI predictions in parallel...')
|
||||
prediction_list = await asyncio.gather(*[self._get_prediction(model, patches_diff) for patches_diff in patches_diff_list])
|
||||
self.prediction_list = prediction_list
|
||||
else:
|
||||
get_logger().info('Getting multi AI predictions...')
|
||||
prediction_list = []
|
||||
for i, patches_diff in enumerate(patches_diff_list):
|
||||
get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
|
||||
prediction = await self._get_prediction(model, patches_diff)
|
||||
prediction_list.append(prediction)
|
||||
|
||||
data = {}
|
||||
for prediction in prediction_list:
|
||||
@ -252,10 +266,15 @@ class PRCodeSuggestions:
|
||||
"""
|
||||
|
||||
suggestion_list = []
|
||||
if not data:
|
||||
return suggestion_list
|
||||
for suggestion in data:
|
||||
suggestion_list.append(suggestion)
|
||||
data_sorted = [[]] * len(suggestion_list)
|
||||
|
||||
if len(suggestion_list ) == 1:
|
||||
return suggestion_list
|
||||
|
||||
try:
|
||||
suggestion_str = ""
|
||||
for i, suggestion in enumerate(suggestion_list):
|
||||
@ -311,7 +330,7 @@ class PRCodeSuggestions:
|
||||
|
||||
pr_body += "<table>"
|
||||
header = f"Suggestions"
|
||||
delta = 77
|
||||
delta = 75
|
||||
header += " " * delta
|
||||
pr_body += f"""<thead><tr><th></th><th>{header}</th></tr></thead>"""
|
||||
pr_body += """<tbody>"""
|
||||
|
@ -9,7 +9,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
|
||||
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels, ModelType
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
@ -80,7 +80,7 @@ class PRDescription:
|
||||
if get_settings().config.publish_output:
|
||||
self.git_provider.publish_comment("Preparing PR description...", is_temporary=True)
|
||||
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO) # turbo model because larger context
|
||||
|
||||
get_logger().info(f"Preparing answer {self.pr_id}")
|
||||
if self.prediction:
|
||||
@ -113,22 +113,27 @@ class PRDescription:
|
||||
|
||||
if get_settings().config.publish_output:
|
||||
get_logger().info(f"Pushing answer {self.pr_id}")
|
||||
|
||||
# publish labels
|
||||
if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"):
|
||||
current_labels = self.git_provider.get_pr_labels()
|
||||
user_labels = get_user_labels(current_labels)
|
||||
self.git_provider.publish_labels(pr_labels + user_labels)
|
||||
|
||||
# publish description
|
||||
if get_settings().pr_description.publish_description_as_comment:
|
||||
get_logger().info(f"Publishing answer as comment")
|
||||
self.git_provider.publish_comment(full_markdown_description)
|
||||
else:
|
||||
self.git_provider.publish_description(pr_title, pr_body)
|
||||
if get_settings().pr_description.publish_labels and self.git_provider.is_supported("get_labels"):
|
||||
current_labels = self.git_provider.get_pr_labels()
|
||||
user_labels = get_user_labels(current_labels)
|
||||
self.git_provider.publish_labels(pr_labels + user_labels)
|
||||
|
||||
# publish final update message
|
||||
if (get_settings().pr_description.final_update_message and
|
||||
hasattr(self.git_provider, 'pr_url') and self.git_provider.pr_url):
|
||||
latest_commit_url = self.git_provider.get_latest_commit_url()
|
||||
if latest_commit_url:
|
||||
self.git_provider.publish_comment(
|
||||
f"**[PR Description]({self.git_provider.pr_url})** updated to latest commit ({latest_commit_url})")
|
||||
f"**[PR Description]({self.git_provider.get_pr_url()})** updated to latest commit ({latest_commit_url})")
|
||||
self.git_provider.remove_initial_comment()
|
||||
except Exception as e:
|
||||
get_logger().error(f"Error generating PR description {self.pr_id}: {e}")
|
||||
@ -358,7 +363,7 @@ class PRDescription:
|
||||
try:
|
||||
pr_body += "<table>"
|
||||
header = f"Relevant files"
|
||||
delta = 77
|
||||
delta = 75
|
||||
# header += " " * delta
|
||||
pr_body += f"""<thead><tr><th></th><th align="left">{header}</th></tr></thead>"""
|
||||
pr_body += """<tbody>"""
|
||||
@ -374,8 +379,7 @@ class PRDescription:
|
||||
for filename, file_changes_title, file_change_description in list_tuples:
|
||||
filename = filename.replace("'", "`")
|
||||
filename_publish = filename.split("/")[-1]
|
||||
file_changes_title_br = insert_br_after_x_chars(file_changes_title, x=(delta - 5),
|
||||
new_line_char="\n\n")
|
||||
file_changes_title_br = insert_br_after_x_chars(file_changes_title, x=(delta - 5))
|
||||
file_changes_title_extended = file_changes_title_br.strip() + "</code>"
|
||||
if len(file_changes_title_extended) < (delta - 5):
|
||||
file_changes_title_extended += " " * ((delta - 5) - len(file_changes_title_extended))
|
||||
@ -407,7 +411,11 @@ class PRDescription:
|
||||
|
||||
{filename}
|
||||
{file_change_description_br}
|
||||
</details>
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
</td>
|
||||
<td><a href="{link}">{diff_plus_minus}</a>{delta_nbsp}</td>
|
||||
</tr>
|
||||
@ -423,48 +431,74 @@ class PRDescription:
|
||||
pass
|
||||
return pr_body
|
||||
|
||||
def insert_br_after_x_chars(text, x=70, new_line_char="<br> "):
|
||||
def insert_br_after_x_chars(text, x=70):
|
||||
"""
|
||||
Insert <br> into a string after a word that increases its length above x characters.
|
||||
Use proper HTML tags for code and new lines.
|
||||
"""
|
||||
if len(text) < x:
|
||||
return text
|
||||
|
||||
lines = text.splitlines()
|
||||
# replace odd instances of ` with <code> and even instances of ` with </code>
|
||||
text = replace_code_tags(text)
|
||||
|
||||
# convert list items to <li>
|
||||
if text.startswith("- "):
|
||||
text = "<li>" + text[2:]
|
||||
text = text.replace("\n- ", '<br><li> ').replace("\n - ", '<br><li> ')
|
||||
|
||||
# convert new lines to <br>
|
||||
text = text.replace("\n", '<br>')
|
||||
|
||||
# split text into lines
|
||||
lines = text.split('<br>')
|
||||
words = []
|
||||
for i,line in enumerate(lines):
|
||||
for i, line in enumerate(lines):
|
||||
words += line.split(' ')
|
||||
if i<len(lines)-1:
|
||||
words[-1] += "\n"
|
||||
if i < len(lines) - 1:
|
||||
words[-1] += "<br>"
|
||||
|
||||
def count_chars_without_html(string):
|
||||
if '<' not in string:
|
||||
return len(string)
|
||||
no_html_string = re.sub('<[^>]+>', '', string)
|
||||
return len(no_html_string)
|
||||
|
||||
# words = text.split(' ')
|
||||
|
||||
new_text = ""
|
||||
current_length = 0
|
||||
new_text = []
|
||||
is_inside_code = False
|
||||
current_length = 0
|
||||
for word in words:
|
||||
# Check if adding this word exceeds x characters
|
||||
if current_length + len(word) > x:
|
||||
if not is_inside_code:
|
||||
new_text += f"{new_line_char} " # Insert line break
|
||||
current_length = 0 # Reset counter
|
||||
is_saved_word = False
|
||||
if word == "<code>" or word == "</code>" or word == "<li>" or word == "<br>":
|
||||
is_saved_word = True
|
||||
|
||||
len_word = count_chars_without_html(word)
|
||||
if not is_saved_word and (current_length + len_word > x):
|
||||
if is_inside_code:
|
||||
new_text.append("</code><br><code>")
|
||||
else:
|
||||
new_text += f"`{new_line_char} `"
|
||||
# check if inside <code> tag
|
||||
if word.startswith("`") and not is_inside_code and not word.endswith("`"):
|
||||
is_inside_code = True
|
||||
if word.endswith("`"):
|
||||
is_inside_code = False
|
||||
new_text.append("<br>")
|
||||
current_length = 0 # Reset counter
|
||||
new_text.append(word + " ")
|
||||
|
||||
# Add the word to the new text
|
||||
if word.endswith("\n"):
|
||||
new_text += word
|
||||
else:
|
||||
new_text += word + " "
|
||||
current_length += len(word) + 1 # Add 1 for the space
|
||||
if not is_saved_word:
|
||||
current_length += len_word + 1 # Add 1 for the space
|
||||
|
||||
|
||||
if word.endswith("\n"):
|
||||
if word == "<li>" or word == "<br>":
|
||||
current_length = 0
|
||||
return new_text.strip() # Remove trailing space
|
||||
|
||||
if "<code>" in word:
|
||||
is_inside_code = True
|
||||
if "</code>" in word:
|
||||
is_inside_code = False
|
||||
return ''.join(new_text).strip()
|
||||
|
||||
def replace_code_tags(text):
    """
    Convert markdown-style inline code (text wrapped in backticks) to HTML <code> tags.

    The text is cut at every backtick; the pieces at odd positions are the ones
    that sat between an opening and a closing backtick, so each of them is
    wrapped in <code>...</code>. A trailing piece after an unmatched opening
    backtick is wrapped as well, which keeps the emitted tags balanced.
    """
    segments = text.split('`')
    rendered = [
        '<code>' + segment + '</code>' if position % 2 else segment
        for position, segment in enumerate(segments)
    ]
    return ''.join(rendered)
|
||||
|
||||
|
@ -36,6 +36,7 @@ class PRReviewer:
|
||||
ai_handler (BaseAiHandler): The AI handler to be used for the review. Defaults to None.
|
||||
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
|
||||
"""
|
||||
self.args = args
|
||||
self.parse_args(args) # -i command
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url, incremental=self.incremental)
|
||||
@ -102,6 +103,11 @@ class PRReviewer:
|
||||
if self.incremental.is_incremental and not self._can_run_incremental_review():
|
||||
return None
|
||||
|
||||
if isinstance(self.args, list) and self.args and self.args[0] == 'auto_approve':
|
||||
get_logger().info(f'Auto approve flow PR: {self.pr_url} ...')
|
||||
self.auto_approve_logic()
|
||||
return None
|
||||
|
||||
get_logger().info(f'Reviewing PR: {self.pr_url} ...')
|
||||
|
||||
if get_settings().config.publish_output:
|
||||
@ -392,3 +398,30 @@ class PRReviewer:
|
||||
self.git_provider.publish_labels(review_labels + current_labels_filtered)
|
||||
except Exception as e:
|
||||
get_logger().error(f"Failed to set review labels, error: {e}")
|
||||
|
||||
def auto_approve_logic(self):
    """
    Auto-approve a pull request if it meets the conditions for auto-approval.

    Flow:
    1. If `pr_reviewer.enable_auto_approval` is off, report that (log + PR comment) and stop.
    2. If `pr_reviewer.maximal_review_effort` is below 5, look at the PR labels for
       'review effort [1-5]: N' labels. Approval is refused when N is greater than the
       configured maximum, or when N cannot be parsed as an integer.
    3. Otherwise ask the git provider to approve, and report the outcome.

    Returns:
        None. The outcome is reported through the logger and PR comments.
    """
    if not get_settings().pr_reviewer.enable_auto_approval:
        get_logger().info("Auto-approval option is disabled")
        self.git_provider.publish_comment("Auto-approval option for PR-Agent is disabled")
        return

    maximal_review_effort = get_settings().pr_reviewer.maximal_review_effort
    # The label encodes an effort in the 1-5 range, so a maximum of 5 (or more)
    # can never be exceeded and the label lookup is skipped entirely.
    if maximal_review_effort < 5:
        current_labels = self.git_provider.get_pr_labels()
        for label in current_labels:
            if not label.lower().startswith('review effort [1-5]:'):
                continue
            try:
                effort = int(label.split(':')[1].strip())
            except ValueError:
                # Labels can be edited by hand. Fail closed: an unreadable effort
                # label must not crash the flow, and must not lead to an approval.
                message = f"Auto-approve error: could not parse the PR review effort from label '{label}'"
                get_logger().error(message)
                self.git_provider.publish_comment(message)
                return
            if effort > maximal_review_effort:
                message = (
                    f"Auto-approve error: PR review effort ({effort}) is higher than the maximal review effort "
                    f"({maximal_review_effort}) allowed")
                get_logger().info(message)
                self.git_provider.publish_comment(message)
                return

    is_auto_approved = self.git_provider.auto_approve()
    if is_auto_approved:
        get_logger().info("Auto-approved PR")
        self.git_provider.publish_comment("Auto-approved PR")
    else:
        # Previously a falsy result from the provider was dropped silently.
        get_logger().error("Auto-approve error: the git provider did not approve the PR")
|
Reference in New Issue
Block a user