Refactor AI handler instantiation in PRAgent and related classes

This commit is contained in:
mrT23
2023-12-14 08:53:22 +02:00
parent e37598fdca
commit 3531016a2c
11 changed files with 15 additions and 51 deletions

View File

@ -1,5 +1,6 @@
import shlex import shlex
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.utils import update_settings_from_args from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -48,10 +49,8 @@ def has_ai_handler_param(cls: object):
return False return False
class PRAgent: class PRAgent:
def __init__(self, ai_handler: BaseAiHandler = None): def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.ai_handler = ai_handler self.ai_handler = ai_handler
pass
async def handle_request(self, pr_url, request, notify=None) -> bool: async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists # First, apply repo specific settings if exists

View File

@ -1,20 +0,0 @@
from abc import ABC, abstractmethod
class BaseAiHandler(ABC):
"""
This class defines the interface for an AI handler.
"""
@abstractmethod
def __init__(self):
pass
@property
@abstractmethod
def deployment_id(self):
pass
@abstractmethod
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
pass

View File

@ -5,9 +5,6 @@ import os
from pr_agent.agent.pr_agent import PRAgent, commands from pr_agent.agent.pr_agent import PRAgent, commands
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.log import setup_logger from pr_agent.log import setup_logger
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger() setup_logger()
@ -59,9 +56,9 @@ For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions
command = args.command.lower() command = args.command.lower()
get_settings().set("CONFIG.CLI_MODE", True) get_settings().set("CONFIG.CLI_MODE", True)
if args.issue_url: if args.issue_url:
result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.issue_url, [command] + args.rest)) result = asyncio.run(PRAgent().handle_request(args.issue_url, [command] + args.rest))
else: else:
result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.pr_url, [command] + args.rest)) result = asyncio.run(PRAgent().handle_request(args.pr_url, [command] + args.rest))
if not result: if not result:
parser.print_help() parser.print_help()

View File

@ -23,9 +23,7 @@ from pr_agent.servers.github_action_runner import get_setting_or_env, is_true
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON) setup_logger(fmt=LoggingFormat.JSON)
router = APIRouter() router = APIRouter()
secret_provider = get_secret_provider() secret_provider = get_secret_provider()
@ -91,7 +89,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
context['bitbucket_bearer_token'] = bearer_token context['bitbucket_bearer_token'] = bearer_token
context["settings"] = copy.deepcopy(global_settings) context["settings"] = copy.deepcopy(global_settings)
event = data["event"] event = data["event"]
agent = PRAgent(ai_handler=litellm_ai_handler) agent = PRAgent()
if event == "pullrequest:created": if event == "pullrequest:created":
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"] pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
log_context["api_url"] = pr_url log_context["api_url"] = pr_url

View File

@ -12,9 +12,7 @@ from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger, setup_logger from pr_agent.log import get_logger, setup_logger
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger() setup_logger()
router = APIRouter() router = APIRouter()
@ -45,7 +43,7 @@ async def handle_gerrit_request(action: Action, item: Item):
status_code=400, status_code=400,
detail="msg is required for ask command" detail="msg is required for ask command"
) )
await PRAgent(ai_handler=litellm_ai_handler).handle_request( await PRAgent().handle_request(
f"{item.project}:{item.refspec}", f"{item.project}:{item.refspec}",
f"/{item.msg.strip()}" f"/{item.msg.strip()}"
) )

View File

@ -11,8 +11,6 @@ from pr_agent.log import get_logger
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
def is_true(value: Union[str, bool]) -> bool: def is_true(value: Union[str, bool]) -> bool:
if isinstance(value, bool): if isinstance(value, bool):
@ -111,9 +109,9 @@ async def run_action():
comment_id = event_payload.get("comment", {}).get("id") comment_id = event_payload.get("comment", {}).get("id")
provider = get_git_provider()(pr_url=url) provider = get_git_provider()(pr_url=url)
if is_pr: if is_pr:
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
else: else:
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body) await PRAgent().handle_request(url, body)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -17,9 +17,7 @@ from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.git_providers.git_provider import IncrementalPR from pr_agent.git_providers.git_provider import IncrementalPR
from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON) setup_logger(fmt=LoggingFormat.JSON)
router = APIRouter() router = APIRouter()
@ -81,7 +79,7 @@ async def handle_request(body: Dict[str, Any], event: str):
action = body.get("action") action = body.get("action")
if not action: if not action:
return {} return {}
agent = PRAgent(ai_handler=litellm_ai_handler) agent = PRAgent()
bot_user = get_settings().github_app.bot_user bot_user = get_settings().github_app.bot_user
sender = body.get("sender", {}).get("login") sender = body.get("sender", {}).get("login")
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"} log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"}

View File

@ -8,9 +8,7 @@ from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider from pr_agent.git_providers import get_git_provider
from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.servers.help import bot_help_text from pr_agent.servers.help import bot_help_text
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON) setup_logger(fmt=LoggingFormat.JSON)
NOTIFICATION_URL = "https://api.github.com/notifications" NOTIFICATION_URL = "https://api.github.com/notifications"
@ -36,7 +34,7 @@ async def polling_loop():
last_modified = [None] last_modified = [None]
git_provider = get_git_provider()() git_provider = get_git_provider()()
user_id = git_provider.get_user_id() user_id = git_provider.get_user_id()
agent = PRAgent(ai_handler=litellm_ai_handler) agent = PRAgent()
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False) get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
try: try:

View File

@ -14,9 +14,7 @@ from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.secret_providers import get_secret_provider from pr_agent.secret_providers import get_secret_provider
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON) setup_logger(fmt=LoggingFormat.JSON)
router = APIRouter() router = APIRouter()
@ -28,7 +26,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c
log_context["event"] = "pull_request" if body == "/review" else "comment" log_context["event"] = "pull_request" if body == "/review" else "comment"
log_context["api_url"] = url log_context["api_url"] = url
with get_logger().contextualize(**log_context): with get_logger().contextualize(**log_context):
background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body) background_tasks.add_task(PRAgent().handle_request, url, body)
@router.post("/webhook") @router.post("/webhook")

View File

@ -14,7 +14,7 @@ from pr_agent.log import get_logger
class PRCodeSuggestions: class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None ): def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(

View File

@ -4,7 +4,7 @@ from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
class PRGenerateLabels: class PRGenerateLabels:
def __init__(self, pr_url: str, args: list = None): def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None):
""" """
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model. corresponding to the PR using an AI model.
@ -31,7 +31,7 @@ class PRGenerateLabels:
self.pr_id = self.git_provider.get_pr_id() self.pr_id = self.git_provider.get_pr_id()
# Initialize the AI handler # Initialize the AI handler
self.ai_handler = AiHandler() self.ai_handler = ai_handler
# Initialize the variables dictionary # Initialize the variables dictionary
self.vars = { self.vars = {