From 317fec05367936ad16a043e2a6c14d8d5665bf89 Mon Sep 17 00:00:00 2001
From: Phill Zarfos
Date: Sun, 3 Dec 2023 21:06:55 -0500
Subject: [PATCH] Throw descriptive error message if model is not in MAX_TOKENS array

---
 Usage.md               | 4 +++-
 pr_agent/algo/utils.py | 6 +++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/Usage.md b/Usage.md
index 548249d3..9cd7b16f 100644
--- a/Usage.md
+++ b/Usage.md
@@ -262,7 +262,7 @@ MAX_TOKENS = {
 e.g.
 MAX_TOKENS={
     ...,
-    "llama2": 4096
+    "ollama/llama2": 4096
 }
 
 
@@ -271,6 +271,8 @@ model = "ollama/llama2"
 
 [ollama] # in .secrets.toml
 api_base = ... # the base url for your huggingface inference endpoint
+# e.g. if running Ollama locally, you may use:
+api_base = "http://localhost:11434/"
 ```
 
 **Inference Endpoints**
diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py
index 1599f056..ded1b52c 100644
--- a/pr_agent/algo/utils.py
+++ b/pr_agent/algo/utils.py
@@ -383,7 +383,11 @@ def get_user_labels(current_labels: List[str] = None):
 
 def get_max_tokens(model):
     settings = get_settings()
-    max_tokens_model = MAX_TOKENS[model]
+    if model in MAX_TOKENS:
+        max_tokens_model = MAX_TOKENS[model]
+    else:
+        raise Exception(f"MAX_TOKENS must be set for model {model} in ./pr_agent/algo/__init__.py")
+
     if settings.config.max_model_tokens:
         max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
         # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
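
With this patch, a model missing from `MAX_TOKENS` fails with a descriptive exception rather than a bare `KeyError` from the dict lookup. Below is a minimal, self-contained sketch of the patched lookup, not the real pr_agent code: the `MAX_TOKENS` contents and the `max_model_tokens` parameter are illustrative stand-ins for the dict in `pr_agent/algo/__init__.py` and for `settings.config.max_model_tokens`.

```python
# Minimal sketch of the lookup logic get_max_tokens() has after this patch,
# shown without wiring up pr_agent's settings object.

# Stand-in for the MAX_TOKENS dict defined in pr_agent/algo/__init__.py;
# the single entry here is illustrative only.
MAX_TOKENS = {
    "ollama/llama2": 4096,
}


def get_max_tokens(model, max_model_tokens=None):
    """Mirror the patched lookup; `max_model_tokens` stands in for
    settings.config.max_model_tokens."""
    if model in MAX_TOKENS:
        max_tokens_model = MAX_TOKENS[model]
    else:
        raise Exception(f"MAX_TOKENS must be set for model {model} in ./pr_agent/algo/__init__.py")
    if max_model_tokens:
        # An explicit cap in the configuration can only lower the limit.
        max_tokens_model = min(max_model_tokens, max_tokens_model)
    return max_tokens_model


print(get_max_tokens("ollama/llama2"))         # 4096
print(get_max_tokens("ollama/llama2", 2000))   # 2000 (capped by config)
# get_max_tokens("some/other-model")  # would raise: MAX_TOKENS must be set for model ...
```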