Merge pull request #437 from Codium-ai/tr/new_gpt4

Introduce support for 'gpt-4-1106-preview' model and dynamic token limit calculation
This commit is contained in:
mrT23
2023-11-07 04:49:50 -08:00
committed by GitHub
5 changed files with 21 additions and 8 deletions

View File

@ -8,6 +8,7 @@ MAX_TOKENS = {
'gpt-4': 8000,
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,

View File

@ -7,11 +7,11 @@ from typing import Any, Callable, List, Tuple
from github import RateLimitExceededException
from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler, get_token_encoder
from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider
from pr_agent.log import get_logger
@ -64,7 +64,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES)
# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
return "\n".join(patches_extended)
# if we are over the limit, start pruning
@ -179,12 +179,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
new_patch_tokens = token_handler.count_tokens(patch)
# Hard Stop, no more tokens
if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
if total_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
get_logger().warning(f"File was fully skipped, no more tokens: {file.filename}.")
continue
# If the patch is too large, just show the file name
if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
if total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
# Current logic is to skip the patch if it's too large
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
# until we meet the requirements
@ -403,7 +403,7 @@ def get_pr_multi_diffs(git_provider: GitProvider,
patch = convert_to_hunks_with_lines_numbers(patch, file)
new_patch_tokens = token_handler.count_tokens(patch)
if patch and (total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
if patch and (total_tokens + new_patch_tokens > get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)
patches = []

View File

@ -9,6 +9,8 @@ from typing import Any, List
import yaml
from starlette_context import context
from pr_agent.algo import MAX_TOKENS
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger
@ -341,3 +343,12 @@ def get_user_labels(current_labels):
if user_labels:
get_logger().info(f"Keeping user labels: {user_labels}")
return user_labels
def get_max_tokens(model):
settings = get_settings()
max_tokens_model = MAX_TOKENS[model]
if settings.config.max_model_tokens:
max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model)
# get_logger().debug(f"limiting max tokens to {max_tokens_model}")
return max_tokens_model

View File

@ -1,5 +1,5 @@
[config]
model="gpt-4"
model="gpt-4" # "gpt-4-1106-preview"
fallback_models=["gpt-3.5-turbo-16k"]
git_provider="github"
publish_output=true
@ -10,6 +10,7 @@ use_repo_settings_file=true
ai_timeout=180
max_description_tokens = 500
max_commits_tokens = 500
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
patch_extra_lines = 3
secret_provider="google_cloud_storage"
cli_mode=false

View File

@ -8,8 +8,8 @@ import pinecone
from pinecone_datasets import Dataset, DatasetMetadata
from pydantic import BaseModel, Field
from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.log import get_logger
@ -197,7 +197,7 @@ class PRSimilarIssue:
username = issue.user.login
created_at = str(issue.created_at)
if len(issue_str) < 8000 or \
self.token_handler.count_tokens(issue_str) < MAX_TOKENS[MODEL]: # fast reject first
self.token_handler.count_tokens(issue_str) < get_max_tokens(MODEL): # fast reject first
issue_record = Record(
id=issue_key + "." + "issue",
text=issue_str,