diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 26a05855..b0e3a4c2 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -16,7 +16,6 @@ from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_similar_issue import PRSimilarIssue from pr_agent.tools.pr_update_changelog import PRUpdateChangelog -import inspect command2class = { "auto_review": PRReviewer, @@ -41,13 +40,6 @@ command2class = { commands = list(command2class.keys()) -def has_ai_handler_param(cls: object): - constructor = getattr(cls, "__init__", None) - if constructor is not None: - parameters = inspect.signature(constructor).parameters - return "ai_handler" in parameters - return False - class PRAgent: def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.ai_handler = ai_handler @@ -80,10 +72,7 @@ class PRAgent: notify() get_logger().info(f"Class: {command2class[action]}") - if(not has_ai_handler_param(cls=command2class[action])): - await command2class[action](pr_url, args=args).run() - else: - await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run() + await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run() else: return False return True diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index 70dd66c2..a729233d 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -5,6 +5,7 @@ from typing import Dict from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -15,7 +16,8 @@ from pr_agent.log import get_logger class PRAddDocs: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 77e0c29f..81e1ceab 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -4,6 +4,7 @@ from typing import Dict, List from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml @@ -14,7 +15,8 @@ from pr_agent.log import get_logger class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index 8d9623cd..4915c5b6 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -5,6 +5,7 @@ from typing import List, Tuple from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels @@ -15,7 +16,8 @@ from pr_agent.log import get_logger class PRDescription: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): + def __init__(self, pr_url: str, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py index d59e6ccb..25e80a55 100644 --- a/pr_agent/tools/pr_generate_labels.py +++ b/pr_agent/tools/pr_generate_labels.py @@ -5,6 +5,7 @@ from typing import List, Tuple from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels @@ -15,7 +16,8 @@ from pr_agent.log import get_logger class PRGenerateLabels: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): + def __init__(self, pr_url: str, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): """ Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels corresponding to the PR using an AI model. diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index e52765f7..a47d511b 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -3,6 +3,7 @@ import copy from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,7 +13,8 @@ from pr_agent.log import get_logger class PRInformationFromUser: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): + def __init__(self, pr_url: str, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index 79edfd6a..5de3d776 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -3,6 +3,7 @@ import copy from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -12,7 +13,7 @@ from pr_agent.log import get_logger class PRQuestions: - def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = None): + def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()): question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index ca345cba..24a40af3 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -8,6 +8,7 @@ from jinja2 import Environment, StrictUndefined from yaml import SafeLoader from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels @@ -22,7 +23,8 @@ class PRReviewer: """ The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = None): + def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, + ai_handler: BaseAiHandler = LiteLLMAIHandler()): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index 07130749..b8c6187f 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -6,6 +6,7 @@ from typing import Tuple from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.config_loader import get_settings @@ -17,7 +18,7 @@ CHANGELOG_LINES = 50 class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = None): + def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language(