Add support for fallback models

Ori Kotek
2023-07-23 16:16:36 +03:00
parent e34f9d8d1c
commit 02a1d8dbfc
8 changed files with 75 additions and 52 deletions

pr_agent/algo/pr_processing.py

@@ -1,8 +1,9 @@
 from __future__ import annotations
 import logging
-from typing import Tuple, Union
+from typing import Tuple, Union, Callable, List
+from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
 from pr_agent.algo.language_handler import sort_files_by_main_languages
 from pr_agent.algo.token_handler import TokenHandler
@@ -10,7 +11,6 @@ from pr_agent.algo.utils import load_large_diff
 from pr_agent.config_loader import settings
 from pr_agent.git_providers.git_provider import GitProvider
 
 DELETED_FILES_ = "Deleted files:\n"
 MORE_MODIFIED_FILES_ = "More modified files:\n"
@@ -20,14 +20,15 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
 PATCH_EXTRA_LINES = 3
 
-def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
-                add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str:
+def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
+                add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
     """
     Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
 
     Args:
         git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
         token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+        model (str): The name of the model used for tokenization.
         add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
         disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
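
Note: the new model parameter threads the target model name through to the token-budget checks, so the limit is looked up per model via MAX_TOKENS rather than read from the token handler. A minimal call-site sketch, assuming an already-constructed git provider and token handler (the variable names here are illustrative, not from this commit):

model = settings.config.model  # e.g. the primary model from configuration
diff = get_pr_diff(git_provider, token_handler, model,
                   add_line_numbers_to_hunks=True)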
@@ -49,7 +50,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
                                                              add_line_numbers_to_hunks)
 
     # if we are under the limit, return the full diff
-    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < token_handler.limit:
+    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
         return "\n".join(patches_extended)
 
     # if we are over the limit, start pruning
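
Note: this check reserves a soft output buffer; the fully extended diff is used only when its token count plus OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD still fits inside the model's context window. A hypothetical helper restating the condition (not part of the commit; the constant and dict names are from this diff):

def fits_soft_limit(total_tokens: int, model: str) -> bool:
    # True when the extended diff plus the reserved output buffer
    # still fits in the model's context window.
    return total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]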
@@ -110,13 +111,14 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
     return patches_extended, total_tokens
 
-def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
+def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
                                 convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
     """
     Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
 
     Args:
         top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
         token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+        model (str): The name of the model used for tokenization.
         convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
 
     Returns:
         Tuple[list, list, list]: A tuple containing the following lists:
@@ -131,7 +133,6 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
     3. Minimize deleted hunks
     4. Minimize all remaining files when you reach token limit
     """
 
     patches = []
     modified_files_list = []
@@ -166,12 +167,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
         new_patch_tokens = token_handler.count_tokens(patch)
 
         # Hard Stop, no more tokens
-        if total_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
+        if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
             logging.warning(f"File was fully skipped, no more tokens: {file.filename}.")
             continue
 
         # If the patch is too large, just show the file name
-        if total_tokens + new_patch_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+        if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
            # Current logic is to skip the patch if it's too large
            # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
            # until we meet the requirements
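
Note: two thresholds drive the pruning. Past the hard threshold the file is skipped outright; between the soft and hard thresholds the patch body is dropped but the filename is still listed. An illustrative restatement of the rule (the helper itself is not part of the commit; the constant and dict names are from this diff):

def patch_action(total_tokens: int, new_patch_tokens: int, model: str) -> str:
    budget = MAX_TOKENS[model]
    if total_tokens > budget - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
        return "skip file entirely"      # hard stop: token budget exhausted
    if total_tokens + new_patch_tokens > budget - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
        return "list filename only"      # patch too large to include in full
    return "include full patch"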
@@ -196,3 +197,14 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
     return patches, modified_files_list, deleted_files_list
 
 
+async def retry_with_fallback_models(f: Callable):
+    model = settings.config.model
+    fallback_models = settings.config.fallback_models
+    if not isinstance(fallback_models, list):
+        fallback_models = [fallback_models]
+    all_models = [model] + fallback_models
+    for model in all_models:
+        try:
+            return await f(model)
+        except Exception as e:
+            logging.warning(f"Failed to generate prediction with {model}: {e}")

pr_agent/algo/token_handler.py

@@ -26,7 +26,6 @@ class TokenHandler:
         - user: The user string.
         """
         self.encoder = encoding_for_model(settings.config.model)
-        self.limit = MAX_TOKENS[settings.config.model]
         self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
 
     def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
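
Note: with the cached limit removed, TokenHandler only counts tokens; limit checks now happen at each call site through MAX_TOKENS[model], which lets a single handler serve both the primary model and any fallback. A sketch of the lookup, with illustrative table entries (the real table lives in pr_agent.algo and its values may differ):

MAX_TOKENS = {
    "gpt-3.5-turbo": 4096,  # illustrative context-window sizes
    "gpt-4": 8192,
}

limit = MAX_TOKENS[settings.config.model]  # looked up per model at each check, no longer cached on the handler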