moving the 'improve' command to turbo mode, with auto_extended=true

This commit is contained in:
mrT23
2024-02-01 09:46:04 +02:00
parent 2112defa51
commit d04d8b616a
5 changed files with 31 additions and 13 deletions

View File

@ -9,6 +9,7 @@ MAX_TOKENS = {
'gpt-4-0613': 8000, 'gpt-4-0613': 8000,
'gpt-4-32k': 32000, 'gpt-4-32k': 32000,
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens 'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4-0125-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'claude-instant-1': 100000, 'claude-instant-1': 100000,
'claude-2': 100000, 'claude-2': 100000,
'command-nightly': 4096, 'command-nightly': 4096,

View File

@ -11,7 +11,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens from pr_agent.algo.utils import get_max_tokens, ModelType
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
from pr_agent.log import get_logger from pr_agent.log import get_logger
@ -220,8 +220,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
return patches, modified_files_list, deleted_files_list, added_files_list return patches, modified_files_list, deleted_files_list, added_files_list
async def retry_with_fallback_models(f: Callable): async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
all_models = _get_all_models() all_models = _get_all_models(model_type)
all_deployments = _get_all_deployments(all_models) all_deployments = _get_all_deployments(all_models)
# try each (model, deployment_id) pair until one is successful, otherwise raise exception # try each (model, deployment_id) pair until one is successful, otherwise raise exception
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)): for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
@ -243,7 +243,10 @@ async def retry_with_fallback_models(f: Callable):
raise # Re-raise the last exception raise # Re-raise the last exception
def _get_all_models() -> List[str]: def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
if model_type == ModelType.TURBO:
model = get_settings().config.model_turbo
else:
model = get_settings().config.model model = get_settings().config.model
fallback_models = get_settings().config.fallback_models fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list): if not isinstance(fallback_models, list):

View File

@ -5,6 +5,7 @@ import json
import re import re
import textwrap import textwrap
from datetime import datetime from datetime import datetime
from enum import Enum
from typing import Any, List from typing import Any, List
import yaml import yaml
@ -15,6 +16,9 @@ from pr_agent.algo.token_handler import get_token_encoder
from pr_agent.config_loader import get_settings, global_settings from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger from pr_agent.log import get_logger
class ModelType(str, Enum):
REGULAR = "regular"
TURBO = "turbo"
def get_setting(key: str) -> Any: def get_setting(key: str) -> Any:
try: try:

View File

@ -1,5 +1,6 @@
[config] [config]
model="gpt-4" # "gpt-4-0125-preview" model="gpt-4" # "gpt-4-0125-preview"
model_turbo="gpt-4-0125-preview"
fallback_models=["gpt-3.5-turbo-16k"] fallback_models=["gpt-3.5-turbo-16k"]
git_provider="github" git_provider="github"
publish_output=true publish_output=true
@ -68,17 +69,18 @@ enable_help_text=true
[pr_code_suggestions] # /improve # [pr_code_suggestions] # /improve #
max_context_tokens=8000
num_code_suggestions=4 num_code_suggestions=4
summarize = true summarize = true
extra_instructions = "" extra_instructions = ""
rank_suggestions = false rank_suggestions = false
enable_help_text=true enable_help_text=true
# params for '/improve --extended' mode # params for '/improve --extended' mode
auto_extended_mode=false auto_extended_mode=true
num_code_suggestions_per_chunk=8 num_code_suggestions_per_chunk=5
rank_extended_suggestions = true rank_extended_suggestions = false
max_number_of_calls = 5 max_number_of_calls = 3
final_clip_factor = 0.9 final_clip_factor = 0.8
[pr_add_docs] # /add_docs # [pr_add_docs] # /add_docs #
extra_instructions = "" extra_instructions = ""

View File

@ -8,7 +8,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, replace_code_tags from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language from pr_agent.git_providers.git_provider import get_main_pr_language
@ -26,6 +26,14 @@ class PRCodeSuggestions:
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
# limit context specifically for the improve command, which has hard input to parse:
if get_settings().pr_code_suggestions.max_context_tokens:
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
# extended mode # extended mode
try: try:
self.is_extended = self._get_is_extended(args or []) self.is_extended = self._get_is_extended(args or [])
@ -64,10 +72,10 @@ class PRCodeSuggestions:
get_logger().info('Preparing PR code suggestions...') get_logger().info('Preparing PR code suggestions...')
if not self.is_extended: if not self.is_extended:
await retry_with_fallback_models(self._prepare_prediction) await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
data = self._prepare_pr_code_suggestions() data = self._prepare_pr_code_suggestions()
else: else:
data = await retry_with_fallback_models(self._prepare_prediction_extended) data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
if (not data) or (not 'code_suggestions' in data): if (not data) or (not 'code_suggestions' in data):