moving the 'improve' command to turbo mode, with auto_extended=true

This commit is contained in:
mrT23
2024-02-01 09:46:04 +02:00
parent 2112defa51
commit d04d8b616a
5 changed files with 31 additions and 13 deletions

View File

@ -9,6 +9,7 @@ MAX_TOKENS = {
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4-0125-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,

View File

@ -11,7 +11,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens
from pr_agent.algo.utils import get_max_tokens, ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
from pr_agent.log import get_logger
@ -220,8 +220,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
return patches, modified_files_list, deleted_files_list, added_files_list
async def retry_with_fallback_models(f: Callable):
all_models = _get_all_models()
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
all_models = _get_all_models(model_type)
all_deployments = _get_all_deployments(all_models)
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
@ -243,8 +243,11 @@ async def retry_with_fallback_models(f: Callable):
raise # Re-raise the last exception
def _get_all_models() -> List[str]:
model = get_settings().config.model
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
if model_type == ModelType.TURBO:
model = get_settings().config.model_turbo
else:
model = get_settings().config.model
fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list):
fallback_models = [m.strip() for m in fallback_models.split(",")]

View File

@ -5,6 +5,7 @@ import json
import re
import textwrap
from datetime import datetime
from enum import Enum
from typing import Any, List
import yaml
@ -15,6 +16,9 @@ from pr_agent.algo.token_handler import get_token_encoder
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger
class ModelType(str, Enum):
REGULAR = "regular"
TURBO = "turbo"
def get_setting(key: str) -> Any:
try:

View File

@ -1,5 +1,6 @@
[config]
model="gpt-4" # "gpt-4-0125-preview"
model_turbo="gpt-4-0125-preview"
fallback_models=["gpt-3.5-turbo-16k"]
git_provider="github"
publish_output=true
@ -68,17 +69,18 @@ enable_help_text=true
[pr_code_suggestions] # /improve #
max_context_tokens=8000
num_code_suggestions=4
summarize = true
extra_instructions = ""
rank_suggestions = false
enable_help_text=true
# params for '/improve --extended' mode
auto_extended_mode=false
num_code_suggestions_per_chunk=8
rank_extended_suggestions = true
max_number_of_calls = 5
final_clip_factor = 0.9
auto_extended_mode=true
num_code_suggestions_per_chunk=5
rank_extended_suggestions = false
max_number_of_calls = 3
final_clip_factor = 0.8
[pr_add_docs] # /add_docs #
extra_instructions = ""

View File

@ -8,7 +8,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, replace_code_tags
from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
@ -26,6 +26,14 @@ class PRCodeSuggestions:
self.git_provider.get_languages(), self.git_provider.get_files()
)
# limit context specifically for the improve command, which has hard input to parse:
if get_settings().pr_code_suggestions.max_context_tokens:
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
# extended mode
try:
self.is_extended = self._get_is_extended(args or [])
@ -64,10 +72,10 @@ class PRCodeSuggestions:
get_logger().info('Preparing PR code suggestions...')
if not self.is_extended:
await retry_with_fallback_models(self._prepare_prediction)
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
data = self._prepare_pr_code_suggestions()
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended)
data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
if (not data) or (not 'code_suggestions' in data):