Merge remote-tracking branch 'origin/main'

This commit is contained in:
mrT23
2024-07-31 13:32:51 +03:00
10 changed files with 232 additions and 169 deletions

View File

@ -52,4 +52,10 @@ MAX_TOKENS = {
'groq/llama-3.1-70b-versatile': 131072,
'groq/llama-3.1-405b-reasoning': 131072,
'ollama/llama3': 4096,
'watsonx/meta-llama/llama-3-8b-instruct': 4096,
"watsonx/meta-llama/llama-3-70b-instruct": 4096,
"watsonx/meta-llama/llama-3-405b-instruct": 16384,
"watsonx/ibm/granite-13b-chat-v2": 8191,
"watsonx/ibm/granite-34b-code-instruct": 8191,
"watsonx/mistralai/mistral-large": 32768,
}

View File

@ -557,7 +557,7 @@ def _fix_key_value(key: str, value: str):
def load_yaml(response_text: str, keys_fix_yaml: List[str] = [], first_key="", last_key="") -> dict:
response_text = response_text.removeprefix('```yaml').rstrip('`')
response_text = response_text.strip('\n').removeprefix('```yaml').rstrip('`')
try:
data = yaml.safe_load(response_text)
except Exception as e:
@ -693,15 +693,25 @@ def get_user_labels(current_labels: List[str] = None):
def get_max_tokens(model):
"""
Get the maximum number of tokens allowed for a model.
logic:
(1) If the model is in './pr_agent/algo/__init__.py', use the value from there.
(2) else, the user needs to define explicitly 'config.custom_model_max_tokens'
For both cases, we further limit the number of tokens to 'config.max_model_tokens' if it is set.
This aims to improve the algorithmic quality, as the AI model degrades in performance when the input is too long.
"""
settings = get_settings()
if model in MAX_TOKENS:
max_tokens_model = MAX_TOKENS[model]
elif settings.config.custom_model_max_tokens > 0:
max_tokens_model = settings.config.custom_model_max_tokens
else:
raise Exception(f"MAX_TOKENS must be set for model {model} in ./pr_agent/algo/__init__.py")
raise Exception(f"Ensure {model} is defined in MAX_TOKENS in ./pr_agent/algo/__init__.py or set a positive value for it in config.custom_model_max_tokens")
if settings.config.max_model_tokens:
if settings.config.max_model_tokens and settings.config.max_model_tokens > 0:
max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
# get_logger().debug(f"limiting max tokens to {max_tokens_model}")
return max_tokens_model