mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 03:40:38 +08:00
Move the 'improve' command to turbo mode, with auto_extended_mode=true
This commit is contained in:
@ -9,6 +9,7 @@ MAX_TOKENS = {
|
||||
'gpt-4-0613': 8000,
|
||||
'gpt-4-32k': 32000,
|
||||
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
|
||||
'gpt-4-0125-preview': 128000, # 128K, but may be limited by config.max_model_tokens
|
||||
'claude-instant-1': 100000,
|
||||
'claude-2': 100000,
|
||||
'command-nightly': 4096,
|
||||
|
@ -11,7 +11,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
|
||||
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||
from pr_agent.algo.file_filter import filter_ignored
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import get_max_tokens
|
||||
from pr_agent.algo.utils import get_max_tokens, ModelType
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
|
||||
from pr_agent.log import get_logger
|
||||
@ -220,8 +220,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
|
||||
return patches, modified_files_list, deleted_files_list, added_files_list
|
||||
|
||||
|
||||
async def retry_with_fallback_models(f: Callable):
|
||||
all_models = _get_all_models()
|
||||
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
|
||||
all_models = _get_all_models(model_type)
|
||||
all_deployments = _get_all_deployments(all_models)
|
||||
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
|
||||
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
|
||||
@ -243,8 +243,11 @@ async def retry_with_fallback_models(f: Callable):
|
||||
raise # Re-raise the last exception
|
||||
|
||||
|
||||
def _get_all_models() -> List[str]:
|
||||
model = get_settings().config.model
|
||||
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
|
||||
if model_type == ModelType.TURBO:
|
||||
model = get_settings().config.model_turbo
|
||||
else:
|
||||
model = get_settings().config.model
|
||||
fallback_models = get_settings().config.fallback_models
|
||||
if not isinstance(fallback_models, list):
|
||||
fallback_models = [m.strip() for m in fallback_models.split(",")]
|
||||
|
@ -5,6 +5,7 @@ import json
|
||||
import re
|
||||
import textwrap
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, List
|
||||
|
||||
import yaml
|
||||
@ -15,6 +16,9 @@ from pr_agent.algo.token_handler import get_token_encoder
|
||||
from pr_agent.config_loader import get_settings, global_settings
|
||||
from pr_agent.log import get_logger
|
||||
|
||||
class ModelType(str, Enum):
|
||||
REGULAR = "regular"
|
||||
TURBO = "turbo"
|
||||
|
||||
def get_setting(key: str) -> Any:
|
||||
try:
|
||||
|
@ -1,5 +1,6 @@
|
||||
[config]
|
||||
model="gpt-4" # "gpt-4-0125-preview"
|
||||
model_turbo="gpt-4-0125-preview"
|
||||
fallback_models=["gpt-3.5-turbo-16k"]
|
||||
git_provider="github"
|
||||
publish_output=true
|
||||
@ -68,17 +69,18 @@ enable_help_text=true
|
||||
|
||||
|
||||
[pr_code_suggestions] # /improve #
|
||||
max_context_tokens=8000
|
||||
num_code_suggestions=4
|
||||
summarize = true
|
||||
extra_instructions = ""
|
||||
rank_suggestions = false
|
||||
enable_help_text=true
|
||||
# params for '/improve --extended' mode
|
||||
auto_extended_mode=false
|
||||
num_code_suggestions_per_chunk=8
|
||||
rank_extended_suggestions = true
|
||||
max_number_of_calls = 5
|
||||
final_clip_factor = 0.9
|
||||
auto_extended_mode=true
|
||||
num_code_suggestions_per_chunk=5
|
||||
rank_extended_suggestions = false
|
||||
max_number_of_calls = 3
|
||||
final_clip_factor = 0.8
|
||||
|
||||
[pr_add_docs] # /add_docs #
|
||||
extra_instructions = ""
|
||||
|
@ -8,7 +8,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import load_yaml, replace_code_tags
|
||||
from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
@ -26,6 +26,14 @@ class PRCodeSuggestions:
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
|
||||
# limit context specifically for the improve command, whose input is hard to parse:
|
||||
if get_settings().pr_code_suggestions.max_context_tokens:
|
||||
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
|
||||
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
|
||||
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
|
||||
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
|
||||
|
||||
|
||||
# extended mode
|
||||
try:
|
||||
self.is_extended = self._get_is_extended(args or [])
|
||||
@ -64,10 +72,10 @@ class PRCodeSuggestions:
|
||||
|
||||
get_logger().info('Preparing PR code suggestions...')
|
||||
if not self.is_extended:
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
|
||||
data = self._prepare_pr_code_suggestions()
|
||||
else:
|
||||
data = await retry_with_fallback_models(self._prepare_prediction_extended)
|
||||
data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
|
||||
|
||||
|
||||
if (not data) or (not 'code_suggestions' in data):
|
||||
|
Reference in New Issue
Block a user