mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
Add support for fallback models
This commit is contained in:
@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple, Union, Callable, List
|
||||
|
||||
from pr_agent.algo import MAX_TOKENS
|
||||
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
|
||||
from pr_agent.algo.language_handler import sort_files_by_main_languages
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
@ -10,7 +11,6 @@ from pr_agent.algo.utils import load_large_diff
|
||||
from pr_agent.config_loader import settings
|
||||
from pr_agent.git_providers.git_provider import GitProvider
|
||||
|
||||
|
||||
DELETED_FILES_ = "Deleted files:\n"
|
||||
|
||||
MORE_MODIFIED_FILES_ = "More modified files:\n"
|
||||
@ -20,14 +20,15 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
|
||||
PATCH_EXTRA_LINES = 3
|
||||
|
||||
|
||||
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str:
|
||||
def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
|
||||
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
|
||||
"""
|
||||
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
|
||||
|
||||
Args:
|
||||
git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
|
||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||
model (str): The name of the model used for tokenization.
|
||||
add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
|
||||
disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
|
||||
|
||||
@ -49,7 +50,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
|
||||
add_line_numbers_to_hunks)
|
||||
|
||||
# if we are under the limit, return the full diff
|
||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < token_handler.limit:
|
||||
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
|
||||
return "\n".join(patches_extended)
|
||||
|
||||
# if we are over the limit, start pruning
|
||||
@ -110,13 +111,14 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
|
||||
return patches_extended, total_tokens
|
||||
|
||||
|
||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
|
||||
def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
|
||||
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
|
||||
"""
|
||||
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
|
||||
Args:
|
||||
top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
|
||||
token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
|
||||
model (str): The model used for tokenization.
|
||||
convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
|
||||
Returns:
|
||||
Tuple[list, list, list]: A tuple containing the following lists:
|
||||
@ -131,7 +133,6 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
|
||||
3. Minimize deleted hunks
|
||||
4. Minimize all remaining files when you reach token limit
|
||||
"""
|
||||
|
||||
|
||||
patches = []
|
||||
modified_files_list = []
|
||||
@ -166,12 +167,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
|
||||
new_patch_tokens = token_handler.count_tokens(patch)
|
||||
|
||||
# Hard Stop, no more tokens
|
||||
if total_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
||||
if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
|
||||
logging.warning(f"File was fully skipped, no more tokens: {file.filename}.")
|
||||
continue
|
||||
|
||||
# If the patch is too large, just show the file name
|
||||
if total_tokens + new_patch_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
|
||||
# Current logic is to skip the patch if it's too large
|
||||
# TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
|
||||
# until we meet the requirements
|
||||
@ -196,3 +197,14 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
|
||||
return patches, modified_files_list, deleted_files_list
|
||||
|
||||
|
||||
async def retry_with_fallback_models(f: Callable):
|
||||
model = settings.config.model
|
||||
fallback_models = settings.config.fallback_models
|
||||
if not isinstance(fallback_models, list):
|
||||
fallback_models = [fallback_models]
|
||||
all_models = [model] + fallback_models
|
||||
for model in all_models:
|
||||
try:
|
||||
return await f(model)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to generate prediction with {model}: {e}")
|
||||
|
@ -26,7 +26,6 @@ class TokenHandler:
|
||||
- user: The user string.
|
||||
"""
|
||||
self.encoder = encoding_for_model(settings.config.model)
|
||||
self.limit = MAX_TOKENS[settings.config.model]
|
||||
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
|
||||
|
||||
def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
|
||||
|
@ -1,5 +1,6 @@
|
||||
[config]
|
||||
model="gpt-4-0613"
|
||||
model="gpt-4"
|
||||
fallback-models=["gpt-3.5-turbo-16k", "gpt-3.5-turbo"]
|
||||
git_provider="github"
|
||||
publish_output=true
|
||||
publish_output_progress=true
|
||||
|
@ -6,7 +6,7 @@ import textwrap
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
from pr_agent.algo.ai_handler import AiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import try_fix_json
|
||||
from pr_agent.config_loader import settings
|
||||
@ -44,16 +44,7 @@ class PRCodeSuggestions:
|
||||
logging.info('Generating code suggestions for PR...')
|
||||
if settings.config.publish_output:
|
||||
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
|
||||
logging.info('Getting PR diff...')
|
||||
|
||||
# we are using extended hunk with line numbers for code suggestions
|
||||
self.patches_diff = get_pr_diff(self.git_provider,
|
||||
self.token_handler,
|
||||
add_line_numbers_to_hunks=True,
|
||||
disable_extra_lines=True)
|
||||
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction()
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
logging.info('Preparing PR review...')
|
||||
data = self._prepare_pr_code_suggestions()
|
||||
if settings.config.publish_output:
|
||||
@ -62,7 +53,18 @@ class PRCodeSuggestions:
|
||||
logging.info('Pushing inline code comments...')
|
||||
self.push_inline_code_suggestions(data)
|
||||
|
||||
async def _get_prediction(self):
|
||||
async def _prepare_prediction(self, model: str):
|
||||
logging.info('Getting PR diff...')
|
||||
# we are using extended hunk with line numbers for code suggestions
|
||||
self.patches_diff = get_pr_diff(self.git_provider,
|
||||
self.token_handler,
|
||||
model,
|
||||
add_line_numbers_to_hunks=True,
|
||||
disable_extra_lines=True)
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction(model)
|
||||
|
||||
async def _get_prediction(self, model: str):
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = self.patches_diff # update diff
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
@ -71,7 +73,6 @@ class PRCodeSuggestions:
|
||||
if settings.config.verbosity_level >= 2:
|
||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||
model = settings.config.model
|
||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
||||
system=system_prompt, user=user_prompt)
|
||||
|
||||
|
@ -5,7 +5,7 @@ import logging
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
from pr_agent.algo.ai_handler import AiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.config_loader import settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
@ -37,10 +37,7 @@ class PRDescription:
|
||||
logging.info('Generating a PR description...')
|
||||
if settings.config.publish_output:
|
||||
self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
|
||||
logging.info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction()
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
logging.info('Preparing answer...')
|
||||
pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
|
||||
if settings.config.publish_output:
|
||||
@ -53,7 +50,13 @@ class PRDescription:
|
||||
self.git_provider.remove_initial_comment()
|
||||
return ""
|
||||
|
||||
async def _get_prediction(self):
|
||||
async def _prepare_prediction(self, model: str):
|
||||
logging.info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction(model)
|
||||
|
||||
async def _get_prediction(self, model: str):
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = self.patches_diff # update diff
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
@ -62,7 +65,6 @@ class PRDescription:
|
||||
if settings.config.verbosity_level >= 2:
|
||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||
model = settings.config.model
|
||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
||||
system=system_prompt, user=user_prompt)
|
||||
return response
|
||||
|
@ -4,13 +4,15 @@ import logging
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
from pr_agent.algo.ai_handler import AiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.config_loader import settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
|
||||
|
||||
|
||||
|
||||
class PRInformationFromUser:
|
||||
def __init__(self, pr_url: str):
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
@ -36,10 +38,7 @@ class PRInformationFromUser:
|
||||
logging.info('Generating question to the user...')
|
||||
if settings.config.publish_output:
|
||||
self.git_provider.publish_comment("Preparing questions...", is_temporary=True)
|
||||
logging.info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction()
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
logging.info('Preparing questions...')
|
||||
pr_comment = self._prepare_pr_answer()
|
||||
if settings.config.publish_output:
|
||||
@ -48,7 +47,13 @@ class PRInformationFromUser:
|
||||
self.git_provider.remove_initial_comment()
|
||||
return ""
|
||||
|
||||
async def _get_prediction(self):
|
||||
async def _prepare_prediction(self, model):
|
||||
logging.info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction(model)
|
||||
|
||||
async def _get_prediction(self, model: str):
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = self.patches_diff # update diff
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
@ -57,7 +62,6 @@ class PRInformationFromUser:
|
||||
if settings.config.verbosity_level >= 2:
|
||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||
model = settings.config.model
|
||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
||||
system=system_prompt, user=user_prompt)
|
||||
return response
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
from pr_agent.algo.ai_handler import AiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.config_loader import settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
@ -46,10 +46,7 @@ class PRQuestions:
|
||||
logging.info('Answering a PR question...')
|
||||
if settings.config.publish_output:
|
||||
self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
|
||||
logging.info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction()
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
logging.info('Preparing answer...')
|
||||
pr_comment = self._prepare_pr_answer()
|
||||
if settings.config.publish_output:
|
||||
@ -58,7 +55,13 @@ class PRQuestions:
|
||||
self.git_provider.remove_initial_comment()
|
||||
return ""
|
||||
|
||||
async def _get_prediction(self):
|
||||
async def _prepare_prediction(self, model: str):
|
||||
logging.info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction(model)
|
||||
|
||||
async def _get_prediction(self, model: str):
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = self.patches_diff # update diff
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
@ -67,7 +70,6 @@ class PRQuestions:
|
||||
if settings.config.verbosity_level >= 2:
|
||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||
model = settings.config.model
|
||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
||||
system=system_prompt, user=user_prompt)
|
||||
return response
|
||||
|
@ -6,7 +6,7 @@ from collections import OrderedDict
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
from pr_agent.algo.ai_handler import AiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import convert_to_markdown, try_fix_json
|
||||
from pr_agent.config_loader import settings
|
||||
@ -64,10 +64,7 @@ class PRReviewer:
|
||||
logging.info('Reviewing PR...')
|
||||
if settings.config.publish_output:
|
||||
self.git_provider.publish_comment("Preparing review...", is_temporary=True)
|
||||
logging.info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction()
|
||||
await retry_with_fallback_models(self._prepare_prediction)
|
||||
logging.info('Preparing PR review...')
|
||||
pr_comment = self._prepare_pr_review()
|
||||
if settings.config.publish_output:
|
||||
@ -79,7 +76,13 @@ class PRReviewer:
|
||||
self._publish_inline_code_comments()
|
||||
return ""
|
||||
|
||||
async def _get_prediction(self):
|
||||
async def _prepare_prediction(self, model: str):
|
||||
logging.info('Getting PR diff...')
|
||||
self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
|
||||
logging.info('Getting AI prediction...')
|
||||
self.prediction = await self._get_prediction(model)
|
||||
|
||||
async def _get_prediction(self, model: str):
|
||||
variables = copy.deepcopy(self.vars)
|
||||
variables["diff"] = self.patches_diff # update diff
|
||||
environment = Environment(undefined=StrictUndefined)
|
||||
@ -88,7 +91,6 @@ class PRReviewer:
|
||||
if settings.config.verbosity_level >= 2:
|
||||
logging.info(f"\nSystem prompt:\n{system_prompt}")
|
||||
logging.info(f"\nUser prompt:\n{user_prompt}")
|
||||
model = settings.config.model
|
||||
response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
|
||||
system=system_prompt, user=user_prompt)
|
||||
|
||||
|
Reference in New Issue
Block a user