Set LiteLLMAIHandler as default AI handler in all PR tools and simplify AI handler injection in PRAgent

This commit is contained in:
mrT23
2023-12-14 09:00:14 +02:00
parent 3531016a2c
commit 246be6147f
9 changed files with 23 additions and 20 deletions

View File

@ -16,7 +16,6 @@ from pr_agent.tools.pr_questions import PRQuestions
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.tools.pr_similar_issue import PRSimilarIssue from pr_agent.tools.pr_similar_issue import PRSimilarIssue
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
import inspect
command2class = { command2class = {
"auto_review": PRReviewer, "auto_review": PRReviewer,
@ -41,13 +40,6 @@ command2class = {
commands = list(command2class.keys()) commands = list(command2class.keys())
def has_ai_handler_param(cls: object):
constructor = getattr(cls, "__init__", None)
if constructor is not None:
parameters = inspect.signature(constructor).parameters
return "ai_handler" in parameters
return False
class PRAgent: class PRAgent:
def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()): def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.ai_handler = ai_handler self.ai_handler = ai_handler
@ -80,10 +72,7 @@ class PRAgent:
notify() notify()
get_logger().info(f"Class: {command2class[action]}") get_logger().info(f"Class: {command2class[action]}")
if(not has_ai_handler_param(cls=command2class[action])): await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
await command2class[action](pr_url, args=args).run()
else:
await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
else: else:
return False return False
return True return True

View File

@ -5,6 +5,7 @@ from typing import Dict
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml from pr_agent.algo.utils import load_yaml
@ -15,7 +16,8 @@ from pr_agent.log import get_logger
class PRAddDocs: class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None): def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(

View File

@ -4,6 +4,7 @@ from typing import Dict, List
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml from pr_agent.algo.utils import load_yaml
@ -14,7 +15,8 @@ from pr_agent.log import get_logger
class PRCodeSuggestions: class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None): def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(

View File

@ -5,6 +5,7 @@ from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
@ -15,7 +16,8 @@ from pr_agent.log import get_logger
class PRDescription: class PRDescription:
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
""" """
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model. using an AI model.

View File

@ -5,6 +5,7 @@ from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
@ -15,7 +16,8 @@ from pr_agent.log import get_logger
class PRGenerateLabels: class PRGenerateLabels:
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
""" """
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model. corresponding to the PR using an AI model.

View File

@ -3,6 +3,7 @@ import copy
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -12,7 +13,8 @@ from pr_agent.log import get_logger
class PRInformationFromUser: class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()

View File

@ -3,6 +3,7 @@ import copy
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -12,7 +13,7 @@ from pr_agent.log import get_logger
class PRQuestions: class PRQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = None): def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
question_str = self.parse_args(args) question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(

View File

@ -8,6 +8,7 @@ from jinja2 import Environment, StrictUndefined
from yaml import SafeLoader from yaml import SafeLoader
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels
@ -22,7 +23,8 @@ class PRReviewer:
""" """
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
""" """
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = None): def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
""" """
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.

View File

@ -6,6 +6,7 @@ from typing import Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -17,7 +18,7 @@ CHANGELOG_LINES = 50
class PRUpdateChangelog: class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = None): def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(