diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index 8bfaac50..fe629993 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
 import logging
-from typing import Tuple, Union
+from typing import Tuple, Union, Callable, List
 
+from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
 from pr_agent.algo.language_handler import sort_files_by_main_languages
 from pr_agent.algo.token_handler import TokenHandler
@@ -10,7 +11,6 @@ from pr_agent.algo.utils import load_large_diff
 from pr_agent.config_loader import settings
 from pr_agent.git_providers.git_provider import GitProvider
 
-
 DELETED_FILES_ = "Deleted files:\n"
 MORE_MODIFIED_FILES_ = "More modified files:\n"
 
@@ -20,14 +20,15 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
 PATCH_EXTRA_LINES = 3
 
 
-def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
-                add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool =False) -> str:
+def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
+                add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
     """
     Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
 
     Args:
         git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
         token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+        model (str): The name of the model used for tokenization.
         add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
         disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
 
@@ -49,7 +50,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
                                                                add_line_numbers_to_hunks)
 
     # if we are under the limit, return the full diff
-    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < token_handler.limit:
+    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
         return "\n".join(patches_extended)
 
     # if we are over the limit, start pruning
@@ -110,13 +111,14 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
     return patches_extended, total_tokens
 
 
-def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
+def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
                                 convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
     """
     Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
     Args:
         top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
         token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+        model (str): The model used for tokenization.
         convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
     Returns:
         Tuple[list, list, list]: A tuple containing the following lists:
@@ -131,7 +133,6 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
     3. Minimize deleted hunks
     4. Minimize all remaining files when you reach token limit
     """
-
     patches = []
     modified_files_list = []
 
@@ -166,12 +167,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
         new_patch_tokens = token_handler.count_tokens(patch)
 
         # Hard Stop, no more tokens
-        if total_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
+        if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
             logging.warning(f"File was fully skipped, no more tokens: {file.filename}.")
             continue
 
         # If the patch is too large, just show the file name
-        if total_tokens + new_patch_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+        if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
             # Current logic is to skip the patch if it's too large
             # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
             # until we meet the requirements
@@ -196,3 +197,14 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
     return patches, modified_files_list, deleted_files_list
 
 
+async def retry_with_fallback_models(f: Callable):
+    model = settings.config.model
+    fallback_models = settings.config.fallback_models
+    if not isinstance(fallback_models, list):
+        fallback_models = [fallback_models]
+    all_models = [model] + fallback_models
+    for model in all_models:
+        try:
+            return await f(model)
+        except Exception as e:
+            logging.warning(f"Failed to generate prediction with {model}: {e}")
diff --git a/pr_agent/algo/token_handler.py b/pr_agent/algo/token_handler.py
index 19d03df3..66659824 100644
--- a/pr_agent/algo/token_handler.py
+++ b/pr_agent/algo/token_handler.py
@@ -26,7 +26,6 @@ class TokenHandler:
         - user: The user string.
         """
         self.encoder = encoding_for_model(settings.config.model)
-        self.limit = MAX_TOKENS[settings.config.model]
         self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
 
     def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index 5f9de595..3861df13 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -1,5 +1,6 @@
 [config]
-model="gpt-4-0613"
+model="gpt-4"
+fallback-models=["gpt-3.5-turbo-16k", "gpt-3.5-turbo"]
 git_provider="github"
 publish_output=true
 publish_output_progress=true
diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py
index bfc06b5c..f9d0efe4 100644
--- a/pr_agent/tools/pr_code_suggestions.py
+++ b/pr_agent/tools/pr_code_suggestions.py
@@ -6,7 +6,7 @@ import textwrap
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import try_fix_json
 from pr_agent.config_loader import settings
@@ -44,16 +44,7 @@ class PRCodeSuggestions:
         logging.info('Generating code suggestions for PR...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing review...", is_temporary=True)
-        logging.info('Getting PR diff...')
-
-        # we are using extended hunk with line numbers for code suggestions
-        self.patches_diff = get_pr_diff(self.git_provider,
-                                        self.token_handler,
-                                        add_line_numbers_to_hunks=True,
-                                        disable_extra_lines=True)
-
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing PR review...')
         data = self._prepare_pr_code_suggestions()
         if settings.config.publish_output:
@@ -62,7 +53,18 @@ class PRCodeSuggestions:
             logging.info('Pushing inline code comments...')
             self.push_inline_code_suggestions(data)
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model: str):
+        logging.info('Getting PR diff...')
+        # we are using extended hunk with line numbers for code suggestions
+        self.patches_diff = get_pr_diff(self.git_provider,
+                                        self.token_handler,
+                                        model,
+                                        add_line_numbers_to_hunks=True,
+                                        disable_extra_lines=True)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -71,7 +73,6 @@ class PRCodeSuggestions:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
 
diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py
index a8647a83..bf5fde17 100644
--- a/pr_agent/tools/pr_description.py
+++ b/pr_agent/tools/pr_description.py
@@ -5,7 +5,7 @@ import logging
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.config_loader import settings
 from pr_agent.git_providers import get_git_provider
@@ -37,10 +37,7 @@ class PRDescription:
         logging.info('Generating a PR description...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
-        logging.info('Getting PR diff...')
-        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing answer...')
         pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
         if settings.config.publish_output:
@@ -53,7 +50,13 @@ class PRDescription:
             self.git_provider.remove_initial_comment()
         return ""
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model: str):
+        logging.info('Getting PR diff...')
+        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -62,7 +65,6 @@ class PRDescription:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
         return response
diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py
index b1cdc3b8..463b434e 100644
--- a/pr_agent/tools/pr_information_from_user.py
+++ b/pr_agent/tools/pr_information_from_user.py
@@ -4,13 +4,15 @@ import logging
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.config_loader import settings
 from pr_agent.git_providers import get_git_provider
 from pr_agent.git_providers.git_provider import get_main_pr_language
 
 
+
+
 class PRInformationFromUser:
     def __init__(self, pr_url: str):
         self.git_provider = get_git_provider()(pr_url)
@@ -36,10 +38,7 @@ class PRInformationFromUser:
         logging.info('Generating question to the user...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing questions...", is_temporary=True)
-        logging.info('Getting PR diff...')
-        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing questions...')
         pr_comment = self._prepare_pr_answer()
         if settings.config.publish_output:
@@ -48,7 +47,13 @@ class PRInformationFromUser:
             self.git_provider.remove_initial_comment()
         return ""
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model):
+        logging.info('Getting PR diff...')
+        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -57,7 +62,6 @@ class PRInformationFromUser:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
         return response
diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py
index 8d24c04c..589cf3e9 100644
--- a/pr_agent/tools/pr_questions.py
+++ b/pr_agent/tools/pr_questions.py
@@ -4,7 +4,7 @@ import logging
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.config_loader import settings
 from pr_agent.git_providers import get_git_provider
@@ -46,10 +46,7 @@ class PRQuestions:
         logging.info('Answering a PR question...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
-        logging.info('Getting PR diff...')
-        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing answer...')
         pr_comment = self._prepare_pr_answer()
         if settings.config.publish_output:
@@ -58,7 +55,13 @@ class PRQuestions:
             self.git_provider.remove_initial_comment()
         return ""
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model: str):
+        logging.info('Getting PR diff...')
+        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -67,7 +70,6 @@ class PRQuestions:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
         return response
diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py
index c9446029..0bf952dd 100644
--- a/pr_agent/tools/pr_reviewer.py
+++ b/pr_agent/tools/pr_reviewer.py
@@ -6,7 +6,7 @@ from collections import OrderedDict
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import convert_to_markdown, try_fix_json
 from pr_agent.config_loader import settings
@@ -64,10 +64,7 @@ class PRReviewer:
         logging.info('Reviewing PR...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing review...", is_temporary=True)
-        logging.info('Getting PR diff...')
-        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing PR review...')
         pr_comment = self._prepare_pr_review()
         if settings.config.publish_output:
@@ -79,7 +76,13 @@ class PRReviewer:
             self._publish_inline_code_comments()
         return ""
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model: str):
+        logging.info('Getting PR diff...')
+        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -88,7 +91,6 @@ class PRReviewer:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)