mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-03 12:20:38 +08:00
Add support for fallback models
This commit is contained in:
@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple, Union, Callable, List
|
||||
|
||||
from pr_agent.algo import MAX_TOKENS
|
||||
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
|
||||
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
@ -10,7 +11,6 @@ from pr_agent.algo.utils import load_large_diff
|
||||
from pr_agent.config_loader import settings
|
||||
from pr_agent.git_providers.git_provider import GitProvider
|
||||
|
||||
|
||||
DELETED_FILES_ = "Deleted files:\n"
|
||||
|
||||
MORE_MODIFIED_FILES_ = "More modified files:\n"
|
||||
@ -20,14 +20,15 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
|
||||
PATCH_EXTRA_LINES = 3
|
||||
|
||||
|
||||
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str:
|
||||
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
|
||||
"""
|
||||
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
|
||||
|
||||
Args:
|
||||
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
|
||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||
model (str): The name of the model used for tokenization.
|
||||
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
|
||||
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
|
||||
|
||||
@ -49,7 +50,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||
add_line_numbers_to_hunks)
|
||||
|
||||
# if we are under the limit, return the full diff
|
||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < token_handler.limit:
|
||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
|
||||
return "\n".join(patches_extended)
|
||||
|
||||
# if we are over the limit, start pruning
|
||||
@ -110,13 +111,14 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
||||
return patches_extended, total_tokens
|
||||
|
||||
|
||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
|
||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
||||
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
|
||||
"""
|
||||
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
|
||||
Args:
|
||||
top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
|
||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||
model (str): The model used for tokenization.
|
||||
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
|
||||
Returns:
|
||||
Tuple[list, list, list]: A tuple containing the following lists:
|
||||
@ -131,7 +133,6 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
|
||||
3. Minimize deleted hunks
|
||||
4. Minimize all remaining files when you reach token limit
|
||||
"""
|
||||
|
||||
|
||||
patches = []
|
||||
modified_files_list = []
|
||||
@ -166,12 +167,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
|
||||
new_patch_tokens = token_handler.count_tokens(patch)
|
||||
|
||||
# Hard Stop, no more tokens
|
||||
if total_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
||||
if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
||||
logging.warning(f"File was fully skipped, no more tokens: {file.filename}.")
|
||||
continue
|
||||
|
||||
# If the patch is too large, just show the file name
|
||||
if total_tokens + new_patch_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
# Current logic is to skip the patch if it's too large
|
||||
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
|
||||
# until we meet the requirements
|
||||
@ -196,3 +197,14 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
|
||||
return patches, modified_files_list, deleted_files_list
|
||||
|
||||
|
||||
async def retry_with_fallback_models(f: Callable):
|
||||
model = settings.config.model
|
||||
fallback_models = settings.config.fallback_models
|
||||
if not isinstance(fallback_models, list):
|
||||
fallback_models = [fallback_models]
|
||||
all_models = [model] + fallback_models
|
||||
for model in all_models:
|
||||
try:
|
||||
return await f(model)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to generate prediction with {model}: {e}")
|
||||
|
@ -26,7 +26,6 @@ class TokenHandler:
|
||||
- user: The user string.
|
||||
"""
|
||||
self.encoder = encoding_for_model(settings.config.model)
|
||||
self.limit = MAX_TOKENS[settings.config.model]
|
||||
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
||||
|
||||
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
||||
|
Reference in New Issue
Block a user