mirror of https://github.com/qodo-ai/pr-agent.git, synced 2025-07-03 04:10:49 +08:00

moving the 'improve' command to turbo mode, with auto_extended=true
@@ -9,6 +9,7 @@ MAX_TOKENS = {
     'gpt-4-0613': 8000,
     'gpt-4-32k': 32000,
     'gpt-4-1106-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
+    'gpt-4-0125-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
     'claude-instant-1': 100000,
     'claude-2': 100000,
     'command-nightly': 4096,
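The comment on the new entry matters: the table gives each model's nominal context window, and pr-agent's get_max_tokens (in pr_agent/algo/utils.py) can further cap it with config.max_model_tokens. A minimal sketch of that capping behavior; the explicit max_model_tokens parameter is illustrative, as the real function reads the cap from settings:

    MAX_TOKENS = {
        'gpt-4-0125-preview': 128000,  # 128K, but may be limited by config.max_model_tokens
    }

    def get_max_tokens(model: str, max_model_tokens: int = 0) -> int:
        # Assumed behavior: a configured cap, when set, wins over the table value.
        max_tokens_model = MAX_TOKENS[model]
        if max_model_tokens:
            return min(max_tokens_model, max_model_tokens)
        return max_tokens_model

    assert get_max_tokens('gpt-4-0125-preview', max_model_tokens=8000) == 8000
    assert get_max_tokens('gpt-4-0125-preview') == 128000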
@@ -11,7 +11,7 @@ from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbe
 from pr_agent.algo.language_handler import sort_files_by_main_languages
 from pr_agent.algo.file_filter import filter_ignored
 from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import get_max_tokens
+from pr_agent.algo.utils import get_max_tokens, ModelType
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
 from pr_agent.log import get_logger
@@ -220,8 +220,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
     return patches, modified_files_list, deleted_files_list, added_files_list
 
 
-async def retry_with_fallback_models(f: Callable):
-    all_models = _get_all_models()
+async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
+    all_models = _get_all_models(model_type)
     all_deployments = _get_all_deployments(all_models)
     # try each (model, deployment_id) pair until one is successful, otherwise raise exception
     for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
@@ -243,7 +243,10 @@ async def retry_with_fallback_models(f: Callable):
                 raise  # Re-raise the last exception
 
 
-def _get_all_models() -> List[str]:
-    model = get_settings().config.model
+def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
+    if model_type == ModelType.TURBO:
+        model = get_settings().config.model_turbo
+    else:
+        model = get_settings().config.model
     fallback_models = get_settings().config.fallback_models
     if not isinstance(fallback_models, list):
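Taken together, the two hunks above thread a model_type argument from retry_with_fallback_models down to _get_all_models, so a caller can opt into the turbo model while keeping the same fallback chain. A self-contained sketch of the selection logic; the SimpleNamespace settings object stands in for pr-agent's get_settings() and is illustrative only:

    from enum import Enum
    from types import SimpleNamespace
    from typing import List

    class ModelType(str, Enum):
        REGULAR = "regular"
        TURBO = "turbo"

    # Illustrative stand-in for get_settings(); not pr-agent's real config loader.
    settings = SimpleNamespace(config=SimpleNamespace(
        model="gpt-4",
        model_turbo="gpt-4-0125-preview",
        fallback_models=["gpt-3.5-turbo-16k"],
    ))

    def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
        # Pick the primary model by type, then append the fallback chain.
        if model_type == ModelType.TURBO:
            model = settings.config.model_turbo
        else:
            model = settings.config.model
        fallback_models = settings.config.fallback_models
        if not isinstance(fallback_models, list):
            fallback_models = [fallback_models]
        return [model] + fallback_models

    assert _get_all_models(ModelType.TURBO) == ["gpt-4-0125-preview", "gpt-3.5-turbo-16k"]
    assert _get_all_models() == ["gpt-4", "gpt-3.5-turbo-16k"]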
@@ -5,6 +5,7 @@ import json
 import re
 import textwrap
 from datetime import datetime
+from enum import Enum
 from typing import Any, List
 
 import yaml
@@ -15,6 +16,9 @@ from pr_agent.algo.token_handler import get_token_encoder
 from pr_agent.config_loader import get_settings, global_settings
 from pr_agent.log import get_logger
 
+class ModelType(str, Enum):
+    REGULAR = "regular"
+    TURBO = "turbo"
 
 def get_setting(key: str) -> Any:
     try:
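Deriving from both str and Enum makes each member compare equal to its string value, which keeps the new type friendly to settings files and logging:

    from enum import Enum

    class ModelType(str, Enum):
        REGULAR = "regular"
        TURBO = "turbo"

    assert ModelType.TURBO == "turbo"                 # plain string comparison works
    assert ModelType("regular") is ModelType.REGULAR  # round-trips from a string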
@@ -1,5 +1,6 @@
 [config]
 model="gpt-4" # "gpt-4-0125-preview"
+model_turbo="gpt-4-0125-preview"
 fallback_models=["gpt-3.5-turbo-16k"]
 git_provider="github"
 publish_output=true
@@ -68,17 +69,18 @@ enable_help_text=true
 
 
 [pr_code_suggestions] # /improve #
+max_context_tokens=8000
 num_code_suggestions=4
 summarize = true
 extra_instructions = ""
 rank_suggestions = false
 enable_help_text=true
 # params for '/improve --extended' mode
-auto_extended_mode=false
-num_code_suggestions_per_chunk=8
-rank_extended_suggestions = true
-max_number_of_calls = 5
-final_clip_factor = 0.9
+auto_extended_mode=true
+num_code_suggestions_per_chunk=5
+rank_extended_suggestions = false
+max_number_of_calls = 3
+final_clip_factor = 0.8
 
 [pr_add_docs] # /add_docs #
 extra_instructions = ""
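With auto_extended_mode=true, the chunked flow becomes the default for /improve, and the tightened knobs bound its cost: fewer model calls (max_number_of_calls), fewer suggestions per chunk, no extended ranking pass, and a smaller final clip. A rough sketch of how two of these knobs could interact, using hypothetical helper names rather than pr-agent's actual internals:

    # Hypothetical illustration of the extended-mode knobs; not pr-agent code.
    max_number_of_calls = 3                # cap on model calls per /improve run
    num_code_suggestions_per_chunk = 5     # suggestions requested per chunk
    final_clip_factor = 0.8                # keep the top 80% after ordering

    def plan_calls(diff_chunks: list) -> list:
        # Only the first max_number_of_calls chunks are sent to the model.
        return diff_chunks[:max_number_of_calls]

    def clip_suggestions(ordered_suggestions: list) -> list:
        keep = max(int(len(ordered_suggestions) * final_clip_factor), 1)
        return ordered_suggestions[:keep]

    assert len(plan_calls(["c1", "c2", "c3", "c4"])) == 3
    assert len(clip_suggestions(list(range(10)))) == 8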
@@ -8,7 +8,7 @@ from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
 from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
-from pr_agent.algo.utils import load_yaml, replace_code_tags
+from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers import get_git_provider
 from pr_agent.git_providers.git_provider import get_main_pr_language
@@ -26,6 +26,14 @@ class PRCodeSuggestions:
             self.git_provider.get_languages(), self.git_provider.get_files()
         )
 
+        # limit context specifically for the improve command, which has hard input to parse:
+        if get_settings().pr_code_suggestions.max_context_tokens:
+            MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
+            if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
+                get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
+                get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE
+
+
         # extended mode
         try:
             self.is_extended = self._get_is_extended(args or [])
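The constructor block above lowers config.max_model_tokens to max_context_tokens for the improve command only, since its structured output is hard to parse from very long inputs. Extracted as a standalone function (the name is hypothetical), the clamp is simply:

    def clamp_max_model_tokens(max_model_tokens: int, max_context_tokens: int) -> int:
        # Never let /improve see more context than its own configured cap.
        if max_context_tokens and max_model_tokens > max_context_tokens:
            return max_context_tokens
        return max_model_tokens

    assert clamp_max_model_tokens(128000, 8000) == 8000  # clamped
    assert clamp_max_model_tokens(4000, 8000) == 4000    # already under the cap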
@@ -64,10 +72,10 @@ class PRCodeSuggestions:
 
         get_logger().info('Preparing PR code suggestions...')
         if not self.is_extended:
-            await retry_with_fallback_models(self._prepare_prediction)
+            await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
             data = self._prepare_pr_code_suggestions()
         else:
-            data = await retry_with_fallback_models(self._prepare_prediction_extended)
+            data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)
 
 
         if (not data) or (not 'code_suggestions' in data):
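Both branches of the run path now request the turbo model explicitly. A minimal sketch of the call shape, with an in-memory model table and a placeholder prediction coroutine standing in for pr-agent's real settings lookup and handlers:

    import asyncio
    from enum import Enum
    from typing import Callable

    class ModelType(str, Enum):
        REGULAR = "regular"
        TURBO = "turbo"

    # Illustrative model table; pr-agent resolves these from configuration.
    MODELS = {
        ModelType.REGULAR: ["gpt-4", "gpt-3.5-turbo-16k"],
        ModelType.TURBO: ["gpt-4-0125-preview", "gpt-3.5-turbo-16k"],
    }

    async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
        # Try each model in order; re-raise only if every one fails.
        last_exc = None
        for model in MODELS[model_type]:
            try:
                return await f(model)
            except Exception as e:
                last_exc = e
        raise last_exc

    async def predict(model: str) -> str:
        # Placeholder for self._prepare_prediction / _prepare_prediction_extended.
        return f"suggestions from {model}"

    print(asyncio.run(retry_with_fallback_models(predict, ModelType.TURBO)))
    # -> suggestions from gpt-4-0125-preview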