diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index dfffcb6b..26a05855 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -1,5 +1,6 @@ import shlex from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.utils import update_settings_from_args from pr_agent.config_loader import get_settings @@ -48,10 +49,8 @@ def has_ai_handler_param(cls: object): return False class PRAgent: - def __init__(self, ai_handler: BaseAiHandler = None): + def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()): self.ai_handler = ai_handler - pass - async def handle_request(self, pr_url, request, notify=None) -> bool: # First, apply repo specific settings if exists diff --git a/pr_agent/algo/base_ai_handler.py b/pr_agent/algo/base_ai_handler.py deleted file mode 100644 index 21050db9..00000000 --- a/pr_agent/algo/base_ai_handler.py +++ /dev/null @@ -1,20 +0,0 @@ -from abc import ABC, abstractmethod - -class BaseAiHandler(ABC): - """ - This class defines the interface for an AI handler. - """ - - @abstractmethod - def __init__(self): - pass - - @property - @abstractmethod - def deployment_id(self): - pass - - @abstractmethod - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): - pass - diff --git a/pr_agent/cli.py b/pr_agent/cli.py index caa7feed..c955e3c5 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -5,9 +5,6 @@ import os from pr_agent.agent.pr_agent import PRAgent, commands from pr_agent.config_loader import get_settings from pr_agent.log import setup_logger -from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler - -litellm_ai_handler = LiteLLMAiHandler() setup_logger() @@ -59,9 +56,9 @@ For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions command = args.command.lower() get_settings().set("CONFIG.CLI_MODE", True) if args.issue_url: - result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.issue_url, [command] + args.rest)) + result = asyncio.run(PRAgent().handle_request(args.issue_url, [command] + args.rest)) else: - result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.pr_url, [command] + args.rest)) + result = asyncio.run(PRAgent().handle_request(args.pr_url, [command] + args.rest)) if not result: parser.print_help() diff --git a/pr_agent/servers/bitbucket_app.py b/pr_agent/servers/bitbucket_app.py index 9fba4030..b5e5e7d7 100644 --- a/pr_agent/servers/bitbucket_app.py +++ b/pr_agent/servers/bitbucket_app.py @@ -23,9 +23,7 @@ from pr_agent.servers.github_action_runner import get_setting_or_env, is_true from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_reviewer import PRReviewer -from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler -litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() secret_provider = get_secret_provider() @@ -91,7 +89,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req context['bitbucket_bearer_token'] = bearer_token context["settings"] = copy.deepcopy(global_settings) event = data["event"] - agent = PRAgent(ai_handler=litellm_ai_handler) + agent = PRAgent() if event == "pullrequest:created": pr_url = data["data"]["pullrequest"]["links"]["html"]["href"] log_context["api_url"] = pr_url diff --git a/pr_agent/servers/gerrit_server.py b/pr_agent/servers/gerrit_server.py index b8b90670..1783f6b9 100644 --- a/pr_agent/servers/gerrit_server.py +++ b/pr_agent/servers/gerrit_server.py @@ -12,9 +12,7 @@ from starlette_context.middleware import RawContextMiddleware from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger, setup_logger -from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler -litellm_ai_handler = LiteLLMAiHandler() setup_logger() router = APIRouter() @@ -45,7 +43,7 @@ async def handle_gerrit_request(action: Action, item: Item): status_code=400, detail="msg is required for ask command" ) - await PRAgent(ai_handler=litellm_ai_handler).handle_request( + await PRAgent().handle_request( f"{item.project}:{item.refspec}", f"/{item.msg.strip()}" ) diff --git a/pr_agent/servers/github_action_runner.py b/pr_agent/servers/github_action_runner.py index e420a61e..1d1f8a56 100644 --- a/pr_agent/servers/github_action_runner.py +++ b/pr_agent/servers/github_action_runner.py @@ -11,8 +11,6 @@ from pr_agent.log import get_logger from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_reviewer import PRReviewer -from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler -litellm_ai_handler = LiteLLMAiHandler() def is_true(value: Union[str, bool]) -> bool: if isinstance(value, bool): @@ -111,9 +109,9 @@ async def run_action(): comment_id = event_payload.get("comment", {}).get("id") provider = get_git_provider()(pr_url=url) if is_pr: - await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) + await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) else: - await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body) + await PRAgent().handle_request(url, body) if __name__ == '__main__': diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index bdcd78f5..32b3305d 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -17,9 +17,7 @@ from pr_agent.git_providers.utils import apply_repo_settings from pr_agent.git_providers.git_provider import IncrementalPR from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout -from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler -litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() @@ -81,7 +79,7 @@ async def handle_request(body: Dict[str, Any], event: str): action = body.get("action") if not action: return {} - agent = PRAgent(ai_handler=litellm_ai_handler) + agent = PRAgent() bot_user = get_settings().github_app.bot_user sender = body.get("sender", {}).get("login") log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"} diff --git a/pr_agent/servers/github_polling.py b/pr_agent/servers/github_polling.py index b473b8fa..1363b941 100644 --- a/pr_agent/servers/github_polling.py +++ b/pr_agent/servers/github_polling.py @@ -8,9 +8,7 @@ from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.servers.help import bot_help_text -from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler -litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) NOTIFICATION_URL = "https://api.github.com/notifications" @@ -36,7 +34,7 @@ async def polling_loop(): last_modified = [None] git_provider = get_git_provider()() user_id = git_provider.get_user_id() - agent = PRAgent(ai_handler=litellm_ai_handler) + agent = PRAgent() get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False) try: diff --git a/pr_agent/servers/gitlab_webhook.py b/pr_agent/servers/gitlab_webhook.py index 91956262..a5d5a115 100644 --- a/pr_agent/servers/gitlab_webhook.py +++ b/pr_agent/servers/gitlab_webhook.py @@ -14,9 +14,7 @@ from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.secret_providers import get_secret_provider -from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler -litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() @@ -28,7 +26,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c log_context["event"] = "pull_request" if body == "/review" else "comment" log_context["api_url"] = url with get_logger().contextualize(**log_context): - background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body) + background_tasks.add_task(PRAgent().handle_request, url, body) @router.post("/webhook") diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 8bded7de..77e0c29f 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -14,7 +14,7 @@ from pr_agent.log import get_logger class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None ): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_generate_labels.py b/pr_agent/tools/pr_generate_labels.py index fc90ed44..d59e6ccb 100644 --- a/pr_agent/tools/pr_generate_labels.py +++ b/pr_agent/tools/pr_generate_labels.py @@ -4,7 +4,7 @@ from typing import List, Tuple from jinja2 import Environment, StrictUndefined -from pr_agent.algo.ai_handler import AiHandler +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRGenerateLabels: - def __init__(self, pr_url: str, args: list = None): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): """ Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels corresponding to the PR using an AI model. @@ -31,7 +31,7 @@ class PRGenerateLabels: self.pr_id = self.git_provider.get_pr_id() # Initialize the AI handler - self.ai_handler = AiHandler() + self.ai_handler = ai_handler # Initialize the variables dictionary self.vars = {