diff --git a/docs/docs/usage-guide/changing_a_model.md b/docs/docs/usage-guide/changing_a_model.md
index 57e0787e..c86af096 100644
--- a/docs/docs/usage-guide/changing_a_model.md
+++ b/docs/docs/usage-guide/changing_a_model.md
@@ -5,7 +5,6 @@ To use a different model than the default (GPT-4), you need to edit in the [conf
 ```
 [config]
 model = "..."
-model_weak = "..."
 fallback_models = ["..."]
 ```
 
@@ -27,9 +26,8 @@ deployment_id = "" # The deployment name you chose when you deployed the engine
 and set in your configuration file:
 ```
 [config]
-model="" # the OpenAI model you've deployed on Azure (e.g. gpt-3.5-turbo)
-model_weak="" # the OpenAI model you've deployed on Azure (e.g. gpt-3.5-turbo)
-fallback_models=["..."] # the OpenAI model you've deployed on Azure (e.g. gpt-3.5-turbo)
+model="" # the OpenAI model you've deployed on Azure (e.g. gpt-4o)
+fallback_models=["..."]
 ```
 
 ### Hugging Face
@@ -52,7 +50,6 @@ MAX_TOKENS={
 
 [config] # in configuration.toml
 model = "ollama/llama2"
-model_weak = "ollama/llama2"
 fallback_models=["ollama/llama2"]
 
 [ollama] # in .secrets.toml
@@ -76,7 +73,6 @@ MAX_TOKENS={
 }
 [config] # in configuration.toml
 model = "huggingface/meta-llama/Llama-2-7b-chat-hf"
-model_weak = "huggingface/meta-llama/Llama-2-7b-chat-hf"
 fallback_models=["huggingface/meta-llama/Llama-2-7b-chat-hf"]
 
 [huggingface] # in .secrets.toml
@@ -91,7 +87,6 @@ To use Llama2 model with Replicate, for example, set:
 ```
 [config] # in configuration.toml
 model = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
-model_weak = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
 fallback_models=["replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"]
 [replicate] # in .secrets.toml
 key = ...
@@ -107,7 +102,6 @@ To use Llama3 model with Groq, for example, set:
 ```
 [config] # in configuration.toml
 model = "llama3-70b-8192"
-model_weak = "llama3-70b-8192"
 fallback_models = ["groq/llama3-70b-8192"]
 [groq] # in .secrets.toml
 key = ... # your Groq api key
@@ -121,7 +115,6 @@ To use Google's Vertex AI platform and its associated models (chat-bison/codecha
 ```
 [config] # in configuration.toml
 model = "vertex_ai/codechat-bison"
-model_weak = "vertex_ai/codechat-bison"
 fallback_models="vertex_ai/codechat-bison"
 
 [vertexai] # in .secrets.toml
@@ -140,7 +133,6 @@ To use [Google AI Studio](https://aistudio.google.com/) models, set the relevant
 ```toml
 [config] # in configuration.toml
 model="google_ai_studio/gemini-1.5-flash"
-model_weak="google_ai_studio/gemini-1.5-flash"
 fallback_models=["google_ai_studio/gemini-1.5-flash"]
 
 [google_ai_studio] # in .secrets.toml
@@ -156,7 +148,6 @@ To use Anthropic models, set the relevant models in the configuration section of
 ```
 [config]
 model="anthropic/claude-3-opus-20240229"
-model_weak="anthropic/claude-3-opus-20240229"
 fallback_models=["anthropic/claude-3-opus-20240229"]
 ```
 
@@ -173,7 +164,6 @@ To use Amazon Bedrock and its foundational models, add the below configuration:
 ```
 [config] # in configuration.toml
 model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
-model_weak="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
 fallback_models=["bedrock/anthropic.claude-v2:1"]
 ```
 
@@ -195,7 +185,6 @@ If the relevant model doesn't appear [here](https://github.com/Codium-ai/pr-agen
 ```
 [config]
 model="custom_model_name"
-model_weak="custom_model_name"
 fallback_models=["custom_model_name"]
 ```
 (2) Set the maximal tokens for the model:
diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index b37c3001..d45a4fab 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -11,7 +11,7 @@ from pr_agent.algo.git_patch_processing import (
 from pr_agent.algo.language_handler import sort_files_by_main_languages
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
-from pr_agent.algo.utils import ModelType, clip_tokens, get_max_tokens
+from pr_agent.algo.utils import ModelType, clip_tokens, get_max_tokens, get_weak_model
 from pr_agent.config_loader import get_settings
 from pr_agent.git_providers.git_provider import GitProvider
 from pr_agent.log import get_logger
@@ -333,7 +333,7 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
     return total_tokens, patches, remaining_files_list_new, files_in_patch_list
 
 
-async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.WEAK):
+async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
     all_models = _get_all_models(model_type)
     all_deployments = _get_all_deployments(all_models)
     # try each (model, deployment_id) pair until one is successful, otherwise raise exception
@@ -354,8 +354,8 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelT
 
 
 def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
-    if get_settings().config.get('model_weak') and model_type == ModelType.WEAK:
-        model = get_settings().config.model_weak
+    if model_type == ModelType.WEAK:
+        model = get_weak_model()
     else:
         model = get_settings().config.model
     fallback_models = get_settings().config.fallback_models
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 371a5df9..2310838b 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -27,6 +27,12 @@ from pr_agent.config_loader import get_settings, global_settings
 from pr_agent.log import get_logger
 
 
+def get_weak_model() -> str:
+    if get_settings().get("config.model_weak"):
+        return get_settings().config.model_weak
+    return get_settings().config.model
+
+
 class Range(BaseModel):
     line_start: int  # should be 0-indexed
     line_end: int
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index 4e282953..634c8acc 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -1,8 +1,8 @@
 [config]
 # models
-model_weak="gpt-4o-mini-2024-07-18"
 model="gpt-4o-2024-11-20"
 fallback_models=["gpt-4o-2024-08-06"]
+#model_weak="gpt-4o-mini-2024-07-18" # optional, a weaker model to use for some easier tasks
 # CLI
 git_provider="github"
 publish_output=true
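
To illustrate the behavior this patch introduces, here is a minimal standalone sketch of the `get_weak_model()` fallback added in `pr_agent/algo/utils.py`. It is not pr_agent's actual settings machinery: `Config` below is a hypothetical stand-in for the object returned by `get_settings()`, used only to show that leaving `model_weak` unset now falls back to the main model.

```python
# Minimal standalone sketch of the weak-model fallback (assumption: `Config`
# stands in for pr_agent's settings object; the real code lives in
# pr_agent/algo/utils.py and reads get_settings().config).
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    model: str
    model_weak: Optional[str] = None  # optional after this change


def get_weak_model(config: Config) -> str:
    # Prefer the explicitly configured weak model; otherwise fall back to
    # the main model, so an unset model_weak is always safe.
    return config.model_weak or config.model


# model_weak unset: easier tasks simply reuse the main model
assert get_weak_model(Config(model="gpt-4o-2024-11-20")) == "gpt-4o-2024-11-20"

# model_weak set: easier tasks are routed to the cheaper model
cfg = Config(model="gpt-4o-2024-11-20", model_weak="gpt-4o-mini-2024-07-18")
assert get_weak_model(cfg) == "gpt-4o-mini-2024-07-18"
```

Note also that `retry_with_fallback_models` now defaults to `ModelType.REGULAR`, so the weak model is only used when a caller explicitly passes `ModelType.WEAK`.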