mirror of https://github.com/qodo-ai/pr-agent.git (synced 2025-07-05 21:30:40 +08:00)

Add support for fallback models
pr_agent/algo/pr_processing.py

@@ -1,8 +1,9 @@
 from __future__ import annotations
 
 import logging
-from typing import Tuple, Union
+from typing import Tuple, Union, Callable, List
 
+from pr_agent.algo import MAX_TOKENS
 from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
 from pr_agent.algo.language_handler import sort_files_by_main_languages
 from pr_agent.algo.token_handler import TokenHandler
@@ -10,7 +11,6 @@ from pr_agent.algo.utils import load_large_diff
 from pr_agent.config_loader import settings
 from pr_agent.git_providers.git_provider import GitProvider
 
-
 DELETED_FILES_ = "Deleted files:\n"
 
 MORE_MODIFIED_FILES_ = "More modified files:\n"
@@ -20,7 +20,7 @@ OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD = 600
 PATCH_EXTRA_LINES = 3
 
 
-def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
+def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
                 add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
     """
     Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
@@ -28,6 +28,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
     Args:
         git_provider (GitProvider): An object of the GitProvider class representing the Git provider used for the pull request.
         token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+        model (str): The name of the model used for tokenization.
         add_line_numbers_to_hunks (bool, optional): A boolean indicating whether to add line numbers to the hunks in the diff. Defaults to False.
         disable_extra_lines (bool, optional): A boolean indicating whether to disable the extension of each patch with extra lines of context. Defaults to False.
 
@@ -49,7 +50,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
                                                   add_line_numbers_to_hunks)
 
     # if we are under the limit, return the full diff
-    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < token_handler.limit:
+    if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]:
         return "\n".join(patches_extended)
 
     # if we are over the limit, start pruning
@@ -110,13 +111,14 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
     return patches_extended, total_tokens
 
 
-def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
+def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
                                 convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list]:
     """
     Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used.
     Args:
         top_langs (list): A list of dictionaries representing the languages used in the pull request and their corresponding files.
         token_handler (TokenHandler): An object of the TokenHandler class used for handling tokens in the context of the pull request.
+        model (str): The model used for tokenization.
         convert_hunks_to_line_numbers (bool): A boolean indicating whether to convert hunks to line numbers in the diff.
     Returns:
         Tuple[list, list, list]: A tuple containing the following lists:
@@ -132,7 +134,6 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
     4. Minimize all remaining files when you reach token limit
     """
 
-
     patches = []
     modified_files_list = []
     deleted_files_list = []
@@ -166,12 +167,12 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
         new_patch_tokens = token_handler.count_tokens(patch)
 
         # Hard Stop, no more tokens
-        if total_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
+        if total_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD:
             logging.warning(f"File was fully skipped, no more tokens: {file.filename}.")
             continue
 
         # If the patch is too large, just show the file name
-        if total_tokens + new_patch_tokens > token_handler.limit - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
+        if total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD:
             # Current logic is to skip the patch if it's too large
             # TODO: Option for alternative logic to remove hunks from the patch to reduce the number of tokens
             # until we meet the requirements
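Both thresholds in pr_generate_compressed_diff, like the soft check in get_pr_diff above, now compare against MAX_TOKENS[model] instead of a limit cached on the token handler, so the pruning budget is recomputed for whichever model is currently being tried. A minimal sketch of the check follows; the context sizes are assumptions for illustration, not values from this commit (the real MAX_TOKENS table lives in the pr_agent.algo package and is not shown in this diff):

    # Sketch only: assumed context sizes, not taken from this commit.
    MAX_TOKENS = {"gpt-4": 8000, "gpt-3.5-turbo-16k": 16000}
    OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD = 1000  # assumed buffer size

    def diff_fits(total_tokens: int, model: str) -> bool:
        # Same shape as the checks above: leave room for the model's
        # output below that model's own context limit.
        return total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < MAX_TOKENS[model]

    # A diff that overflows gpt-4's budget can still fit a 16k fallback:
    assert not diff_fits(7500, "gpt-4")
    assert diff_fits(7500, "gpt-3.5-turbo-16k")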
@@ -196,3 +197,14 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler,
     return patches, modified_files_list, deleted_files_list
 
 
+async def retry_with_fallback_models(f: Callable):
+    model = settings.config.model
+    fallback_models = settings.config.fallback_models
+    if not isinstance(fallback_models, list):
+        fallback_models = [fallback_models]
+    all_models = [model] + fallback_models
+    for model in all_models:
+        try:
+            return await f(model)
+        except Exception as e:
+            logging.warning(f"Failed to generate prediction with {model}: {e}")
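The new retry_with_fallback_models helper is the heart of the commit: it tries the configured model first, then each fallback in order, returns the first successful result, and logs each failure. Note that if every model fails, the loop falls through and the helper implicitly returns None rather than raising. A self-contained sketch of the behavior; the model list is passed in directly here (the real helper builds it from settings), and make_prediction is a hypothetical stand-in for a tool's _prepare_prediction:

    import asyncio
    import logging
    from typing import Callable, List

    async def retry_with_fallback_models(f: Callable, all_models: List[str]):
        # Simplified: the real helper reads the model list from settings.
        for model in all_models:
            try:
                return await f(model)
            except Exception as e:
                logging.warning(f"Failed to generate prediction with {model}: {e}")

    async def make_prediction(model: str) -> str:
        # Hypothetical coroutine: the primary model fails, a fallback succeeds.
        if model == "gpt-4":
            raise RuntimeError("maximum context length exceeded")
        return f"prediction from {model}"

    async def main():
        result = await retry_with_fallback_models(
            make_prediction, ["gpt-4", "gpt-3.5-turbo-16k"])
        print(result)  # -> prediction from gpt-3.5-turbo-16k

    asyncio.run(main())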
pr_agent/algo/token_handler.py

@@ -26,7 +26,6 @@ class TokenHandler:
         - user: The user string.
         """
         self.encoder = encoding_for_model(settings.config.model)
-        self.limit = MAX_TOKENS[settings.config.model]
         self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)
 
     def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
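With self.limit removed, TokenHandler only counts tokens; every limit comparison now happens at the call site through MAX_TOKENS[model]. The encoder, however, is still built from settings.config.model, so prompt sizes for a fallback model are measured with the primary model's tokenizer; for gpt-4 and the gpt-3.5-turbo variants this is the same cl100k_base encoding, so the counts agree. A short sketch (model names assumed):

    from tiktoken import encoding_for_model

    from pr_agent.algo import MAX_TOKENS

    # One encoder, keyed to the primary configured model...
    encoder = encoding_for_model("gpt-4")  # i.e. settings.config.model
    patch_tokens = len(encoder.encode("some patch text"))

    # ...while the budget is looked up per attempted model:
    fits = patch_tokens < MAX_TOKENS["gpt-3.5-turbo-16k"]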
pr_agent/settings/configuration.toml

@@ -1,5 +1,6 @@
 [config]
-model="gpt-4-0613"
+model="gpt-4"
+fallback-models=["gpt-3.5-turbo-16k", "gpt-3.5-turbo"]
 git_provider="github"
 publish_output=true
 publish_output_progress=true
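The new setting accepts either a single model name or a list; retry_with_fallback_models normalizes both forms before prepending the primary model. A short sketch of the resulting lookup, assuming the settings loader exposes the hyphenated TOML key as fallback_models, as the code in pr_processing.py expects:

    from pr_agent.config_loader import settings

    # Mirrors the normalization inside retry_with_fallback_models:
    fallback_models = settings.config.fallback_models
    if not isinstance(fallback_models, list):
        fallback_models = [fallback_models]  # a bare string becomes a one-item list

    all_models = [settings.config.model] + fallback_models
    # With the configuration above: ["gpt-4", "gpt-3.5-turbo-16k", "gpt-3.5-turbo"]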
pr_agent/tools/pr_code_suggestions.py

@@ -6,7 +6,7 @@ import textwrap
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import try_fix_json
 from pr_agent.config_loader import settings
@@ -44,16 +44,7 @@ class PRCodeSuggestions:
         logging.info('Generating code suggestions for PR...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing review...", is_temporary=True)
-        logging.info('Getting PR diff...')
-
-        # we are using extended hunk with line numbers for code suggestions
-        self.patches_diff = get_pr_diff(self.git_provider,
-                                        self.token_handler,
-                                        add_line_numbers_to_hunks=True,
-                                        disable_extra_lines=True)
-
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing PR review...')
         data = self._prepare_pr_code_suggestions()
         if settings.config.publish_output:
@@ -62,7 +53,18 @@ class PRCodeSuggestions:
                 logging.info('Pushing inline code comments...')
                 self.push_inline_code_suggestions(data)
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model: str):
+        logging.info('Getting PR diff...')
+        # we are using extended hunk with line numbers for code suggestions
+        self.patches_diff = get_pr_diff(self.git_provider,
+                                        self.token_handler,
+                                        model,
+                                        add_line_numbers_to_hunks=True,
+                                        disable_extra_lines=True)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -71,7 +73,6 @@ class PRCodeSuggestions:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
 
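Each tool's run() used to fetch the diff and the prediction inline with the single configured model; that model-dependent tail is now extracted into _prepare_prediction(model), so the retry helper can re-run just that part per model. Because a bound method carries self, the wrapper only has to supply the model name. Below is a runnable skeleton of the pattern shared by all five refactored tools, with hypothetical stand-ins for get_pr_diff and the AI call:

    import asyncio
    import logging
    from typing import Callable

    async def retry_with_fallback_models(f: Callable):
        # Stub of the helper above, with a hard-coded model list for the demo.
        for model in ["gpt-4", "gpt-3.5-turbo-16k"]:
            try:
                return await f(model)
            except Exception as e:
                logging.warning(f"Failed to generate prediction with {model}: {e}")

    class PRTool:
        """Skeleton of the run()/_prepare_prediction split."""

        async def run(self):
            # Model-independent setup (publishing a placeholder comment, ...) stays here.
            await retry_with_fallback_models(self._prepare_prediction)
            return self.prediction

        async def _prepare_prediction(self, model: str):
            # Everything model-dependent moved here: diff sizing plus the AI call.
            self.patches_diff = f"<diff trimmed to fit {model}>"  # stand-in for get_pr_diff(...)
            self.prediction = f"<prediction from {model}>"        # stand-in for _get_prediction(model)

    print(asyncio.run(PRTool().run()))  # the bound method carries self; the wrapper supplies model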
pr_agent/tools/pr_description.py

@@ -5,7 +5,7 @@ import logging
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.config_loader import settings
 from pr_agent.git_providers import get_git_provider
@@ -37,10 +37,7 @@ class PRDescription:
         logging.info('Generating a PR description...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing pr description...", is_temporary=True)
-        logging.info('Getting PR diff...')
-        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing answer...')
         pr_title, pr_body, pr_types, markdown_text = self._prepare_pr_answer()
         if settings.config.publish_output:
@@ -53,7 +50,13 @@ class PRDescription:
             self.git_provider.remove_initial_comment()
         return ""
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model: str):
+        logging.info('Getting PR diff...')
+        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -62,7 +65,6 @@ class PRDescription:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
         return response
pr_agent/tools/pr_information_from_user.py

@@ -4,13 +4,15 @@ import logging
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.config_loader import settings
 from pr_agent.git_providers import get_git_provider
 from pr_agent.git_providers.git_provider import get_main_pr_language
 
 
+
+
 class PRInformationFromUser:
     def __init__(self, pr_url: str):
         self.git_provider = get_git_provider()(pr_url)
@@ -36,10 +38,7 @@ class PRInformationFromUser:
         logging.info('Generating question to the user...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing questions...", is_temporary=True)
-        logging.info('Getting PR diff...')
-        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing questions...')
         pr_comment = self._prepare_pr_answer()
         if settings.config.publish_output:
@@ -48,7 +47,13 @@ class PRInformationFromUser:
             self.git_provider.remove_initial_comment()
         return ""
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model):
+        logging.info('Getting PR diff...')
+        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -57,7 +62,6 @@ class PRInformationFromUser:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
         return response
pr_agent/tools/pr_questions.py

@@ -4,7 +4,7 @@ import logging
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.config_loader import settings
 from pr_agent.git_providers import get_git_provider
@@ -46,10 +46,7 @@ class PRQuestions:
         logging.info('Answering a PR question...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing answer...", is_temporary=True)
-        logging.info('Getting PR diff...')
-        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing answer...')
         pr_comment = self._prepare_pr_answer()
         if settings.config.publish_output:
@@ -58,7 +55,13 @@ class PRQuestions:
             self.git_provider.remove_initial_comment()
         return ""
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model: str):
+        logging.info('Getting PR diff...')
+        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -67,7 +70,6 @@ class PRQuestions:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
         return response
pr_agent/tools/pr_reviewer.py

@@ -6,7 +6,7 @@ from collections import OrderedDict
 from jinja2 import Environment, StrictUndefined
 
 from pr_agent.algo.ai_handler import AiHandler
-from pr_agent.algo.pr_processing import get_pr_diff
+from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
 from pr_agent.algo.token_handler import TokenHandler
 from pr_agent.algo.utils import convert_to_markdown, try_fix_json
 from pr_agent.config_loader import settings
@@ -64,10 +64,7 @@ class PRReviewer:
         logging.info('Reviewing PR...')
         if settings.config.publish_output:
             self.git_provider.publish_comment("Preparing review...", is_temporary=True)
-        logging.info('Getting PR diff...')
-        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler)
-        logging.info('Getting AI prediction...')
-        self.prediction = await self._get_prediction()
+        await retry_with_fallback_models(self._prepare_prediction)
         logging.info('Preparing PR review...')
         pr_comment = self._prepare_pr_review()
         if settings.config.publish_output:
@@ -79,7 +76,13 @@ class PRReviewer:
                 self._publish_inline_code_comments()
         return ""
 
-    async def _get_prediction(self):
+    async def _prepare_prediction(self, model: str):
+        logging.info('Getting PR diff...')
+        self.patches_diff = get_pr_diff(self.git_provider, self.token_handler, model)
+        logging.info('Getting AI prediction...')
+        self.prediction = await self._get_prediction(model)
+
+    async def _get_prediction(self, model: str):
         variables = copy.deepcopy(self.vars)
         variables["diff"] = self.patches_diff  # update diff
         environment = Environment(undefined=StrictUndefined)
@@ -88,7 +91,6 @@ class PRReviewer:
         if settings.config.verbosity_level >= 2:
             logging.info(f"\nSystem prompt:\n{system_prompt}")
             logging.info(f"\nUser prompt:\n{user_prompt}")
-        model = settings.config.model
         response, finish_reason = await self.ai_handler.chat_completion(model=model, temperature=0.2,
                                                                         system=system_prompt, user=user_prompt)
 