Throw descriptive error message if model is not in MAX_TOKENS array

Phill Zarfos
2023-12-03 21:06:55 -05:00
parent 526ad00812
commit 317fec0536
2 changed files with 8 additions and 2 deletions


@@ -262,7 +262,7 @@ MAX_TOKENS = {
 e.g.
 MAX_TOKENS={
     ...,
-    "llama2": 4096
+    "ollama/llama2": 4096
 }
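
The point of this docs change is that MAX_TOKENS is keyed by the exact model string from the configuration, provider prefix included. A minimal sketch of the lookup (the dictionary below is illustrative, not the full table from pr_agent/algo/__init__.py):

```python
# Illustrative: the key must match config.model exactly, so the
# "ollama/" prefix is part of the key.
MAX_TOKENS = {
    "ollama/llama2": 4096,  # matches model = "ollama/llama2"
}

model = "ollama/llama2"
print(model in MAX_TOKENS)      # True
print("llama2" in MAX_TOKENS)   # False: the bare name no longer matches
```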
@@ -271,6 +271,8 @@ model = "ollama/llama2"
 [ollama] # in .secrets.toml
 api_base = ... # the base url for your huggingface inference endpoint
+# e.g. if running Ollama locally, you may use:
+api_base = "http://localhost:11434/"
 ```
 **Inference Endpoints**
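
As a quick sanity check of the config shown above, here is a minimal sketch, assuming Python 3.11+ (tomllib in the standard library), of what the [ollama] block parses to; this is not pr-agent's actual settings loader:

```python
import tomllib  # stdlib TOML parser, Python 3.11+

# The [ollama] block from .secrets.toml as shown in the docs above.
secrets = tomllib.loads('''
[ollama]
api_base = "http://localhost:11434/"
''')
print(secrets["ollama"]["api_base"])  # -> http://localhost:11434/
```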


@@ -383,7 +383,11 @@ def get_user_labels(current_labels: List[str] = None):
 def get_max_tokens(model):
     settings = get_settings()
-    max_tokens_model = MAX_TOKENS[model]
+    if model in MAX_TOKENS:
+        max_tokens_model = MAX_TOKENS[model]
+    else:
+        raise Exception(f"MAX_TOKENS must be set for model {model} in ./pr_agent/algo/__init__.py")
     if settings.config.max_model_tokens:
         max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
     # get_logger().debug(f"limiting max tokens to {max_tokens_model}")
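
To illustrate the behavioral change: a registered model resolves as before, while an unregistered one now raises a descriptive Exception instead of a bare KeyError. A self-contained sketch of the new lookup (the table and the settings handling are simplified stand-ins for the real module):

```python
# Standalone sketch of the new get_max_tokens behavior; entries are
# illustrative, not the full table from pr_agent/algo/__init__.py.
MAX_TOKENS = {
    "gpt-4": 8000,
    "ollama/llama2": 4096,
}

def get_max_tokens(model: str) -> int:
    if model in MAX_TOKENS:
        return MAX_TOKENS[model]
    # Previously this lookup raised a bare KeyError; now the message
    # says exactly where to register the missing model.
    raise Exception(f"MAX_TOKENS must be set for model {model} in ./pr_agent/algo/__init__.py")

print(get_max_tokens("ollama/llama2"))  # 4096
try:
    get_max_tokens("llama2")  # bare name is not a key
except Exception as e:
    print(e)
```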