mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-06 05:40:38 +08:00
Refactor AI handler instantiation in PRAgent and related classes
This commit is contained in:
@ -1,5 +1,6 @@
|
|||||||
import shlex
|
import shlex
|
||||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
|
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
||||||
|
|
||||||
from pr_agent.algo.utils import update_settings_from_args
|
from pr_agent.algo.utils import update_settings_from_args
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
@ -48,10 +49,8 @@ def has_ai_handler_param(cls: object):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
class PRAgent:
|
class PRAgent:
|
||||||
def __init__(self, ai_handler: BaseAiHandler = None):
|
def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_request(self, pr_url, request, notify=None) -> bool:
|
async def handle_request(self, pr_url, request, notify=None) -> bool:
|
||||||
# First, apply repo specific settings if exists
|
# First, apply repo specific settings if exists
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
class BaseAiHandler(ABC):
|
|
||||||
"""
|
|
||||||
This class defines the interface for an AI handler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def deployment_id(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
|
|
||||||
pass
|
|
||||||
|
|
@ -5,9 +5,6 @@ import os
|
|||||||
from pr_agent.agent.pr_agent import PRAgent, commands
|
from pr_agent.agent.pr_agent import PRAgent, commands
|
||||||
from pr_agent.config_loader import get_settings
|
from pr_agent.config_loader import get_settings
|
||||||
from pr_agent.log import setup_logger
|
from pr_agent.log import setup_logger
|
||||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
|
||||||
|
|
||||||
litellm_ai_handler = LiteLLMAiHandler()
|
|
||||||
setup_logger()
|
setup_logger()
|
||||||
|
|
||||||
|
|
||||||
@ -59,9 +56,9 @@ For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions
|
|||||||
command = args.command.lower()
|
command = args.command.lower()
|
||||||
get_settings().set("CONFIG.CLI_MODE", True)
|
get_settings().set("CONFIG.CLI_MODE", True)
|
||||||
if args.issue_url:
|
if args.issue_url:
|
||||||
result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.issue_url, [command] + args.rest))
|
result = asyncio.run(PRAgent().handle_request(args.issue_url, [command] + args.rest))
|
||||||
else:
|
else:
|
||||||
result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.pr_url, [command] + args.rest))
|
result = asyncio.run(PRAgent().handle_request(args.pr_url, [command] + args.rest))
|
||||||
if not result:
|
if not result:
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
|
|
||||||
|
@ -23,9 +23,7 @@ from pr_agent.servers.github_action_runner import get_setting_or_env, is_true
|
|||||||
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
||||||
from pr_agent.tools.pr_description import PRDescription
|
from pr_agent.tools.pr_description import PRDescription
|
||||||
from pr_agent.tools.pr_reviewer import PRReviewer
|
from pr_agent.tools.pr_reviewer import PRReviewer
|
||||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
|
||||||
|
|
||||||
litellm_ai_handler = LiteLLMAiHandler()
|
|
||||||
setup_logger(fmt=LoggingFormat.JSON)
|
setup_logger(fmt=LoggingFormat.JSON)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
secret_provider = get_secret_provider()
|
secret_provider = get_secret_provider()
|
||||||
@ -91,7 +89,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
|
|||||||
context['bitbucket_bearer_token'] = bearer_token
|
context['bitbucket_bearer_token'] = bearer_token
|
||||||
context["settings"] = copy.deepcopy(global_settings)
|
context["settings"] = copy.deepcopy(global_settings)
|
||||||
event = data["event"]
|
event = data["event"]
|
||||||
agent = PRAgent(ai_handler=litellm_ai_handler)
|
agent = PRAgent()
|
||||||
if event == "pullrequest:created":
|
if event == "pullrequest:created":
|
||||||
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
|
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
|
||||||
log_context["api_url"] = pr_url
|
log_context["api_url"] = pr_url
|
||||||
|
@ -12,9 +12,7 @@ from starlette_context.middleware import RawContextMiddleware
|
|||||||
from pr_agent.agent.pr_agent import PRAgent
|
from pr_agent.agent.pr_agent import PRAgent
|
||||||
from pr_agent.config_loader import get_settings, global_settings
|
from pr_agent.config_loader import get_settings, global_settings
|
||||||
from pr_agent.log import get_logger, setup_logger
|
from pr_agent.log import get_logger, setup_logger
|
||||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
|
||||||
|
|
||||||
litellm_ai_handler = LiteLLMAiHandler()
|
|
||||||
setup_logger()
|
setup_logger()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -45,7 +43,7 @@ async def handle_gerrit_request(action: Action, item: Item):
|
|||||||
status_code=400,
|
status_code=400,
|
||||||
detail="msg is required for ask command"
|
detail="msg is required for ask command"
|
||||||
)
|
)
|
||||||
await PRAgent(ai_handler=litellm_ai_handler).handle_request(
|
await PRAgent().handle_request(
|
||||||
f"{item.project}:{item.refspec}",
|
f"{item.project}:{item.refspec}",
|
||||||
f"/{item.msg.strip()}"
|
f"/{item.msg.strip()}"
|
||||||
)
|
)
|
||||||
|
@ -11,8 +11,6 @@ from pr_agent.log import get_logger
|
|||||||
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
||||||
from pr_agent.tools.pr_description import PRDescription
|
from pr_agent.tools.pr_description import PRDescription
|
||||||
from pr_agent.tools.pr_reviewer import PRReviewer
|
from pr_agent.tools.pr_reviewer import PRReviewer
|
||||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
|
||||||
litellm_ai_handler = LiteLLMAiHandler()
|
|
||||||
|
|
||||||
def is_true(value: Union[str, bool]) -> bool:
|
def is_true(value: Union[str, bool]) -> bool:
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
@ -111,9 +109,9 @@ async def run_action():
|
|||||||
comment_id = event_payload.get("comment", {}).get("id")
|
comment_id = event_payload.get("comment", {}).get("id")
|
||||||
provider = get_git_provider()(pr_url=url)
|
provider = get_git_provider()(pr_url=url)
|
||||||
if is_pr:
|
if is_pr:
|
||||||
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
|
await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
|
||||||
else:
|
else:
|
||||||
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body)
|
await PRAgent().handle_request(url, body)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -17,9 +17,7 @@ from pr_agent.git_providers.utils import apply_repo_settings
|
|||||||
from pr_agent.git_providers.git_provider import IncrementalPR
|
from pr_agent.git_providers.git_provider import IncrementalPR
|
||||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||||
from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout
|
from pr_agent.servers.utils import verify_signature, DefaultDictWithTimeout
|
||||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
|
||||||
|
|
||||||
litellm_ai_handler = LiteLLMAiHandler()
|
|
||||||
setup_logger(fmt=LoggingFormat.JSON)
|
setup_logger(fmt=LoggingFormat.JSON)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -81,7 +79,7 @@ async def handle_request(body: Dict[str, Any], event: str):
|
|||||||
action = body.get("action")
|
action = body.get("action")
|
||||||
if not action:
|
if not action:
|
||||||
return {}
|
return {}
|
||||||
agent = PRAgent(ai_handler=litellm_ai_handler)
|
agent = PRAgent()
|
||||||
bot_user = get_settings().github_app.bot_user
|
bot_user = get_settings().github_app.bot_user
|
||||||
sender = body.get("sender", {}).get("login")
|
sender = body.get("sender", {}).get("login")
|
||||||
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"}
|
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"}
|
||||||
|
@ -8,9 +8,7 @@ from pr_agent.config_loader import get_settings
|
|||||||
from pr_agent.git_providers import get_git_provider
|
from pr_agent.git_providers import get_git_provider
|
||||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||||
from pr_agent.servers.help import bot_help_text
|
from pr_agent.servers.help import bot_help_text
|
||||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
|
||||||
|
|
||||||
litellm_ai_handler = LiteLLMAiHandler()
|
|
||||||
setup_logger(fmt=LoggingFormat.JSON)
|
setup_logger(fmt=LoggingFormat.JSON)
|
||||||
NOTIFICATION_URL = "https://api.github.com/notifications"
|
NOTIFICATION_URL = "https://api.github.com/notifications"
|
||||||
|
|
||||||
@ -36,7 +34,7 @@ async def polling_loop():
|
|||||||
last_modified = [None]
|
last_modified = [None]
|
||||||
git_provider = get_git_provider()()
|
git_provider = get_git_provider()()
|
||||||
user_id = git_provider.get_user_id()
|
user_id = git_provider.get_user_id()
|
||||||
agent = PRAgent(ai_handler=litellm_ai_handler)
|
agent = PRAgent()
|
||||||
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
|
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -14,9 +14,7 @@ from pr_agent.agent.pr_agent import PRAgent
|
|||||||
from pr_agent.config_loader import get_settings, global_settings
|
from pr_agent.config_loader import get_settings, global_settings
|
||||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||||
from pr_agent.secret_providers import get_secret_provider
|
from pr_agent.secret_providers import get_secret_provider
|
||||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
|
||||||
|
|
||||||
litellm_ai_handler = LiteLLMAiHandler()
|
|
||||||
setup_logger(fmt=LoggingFormat.JSON)
|
setup_logger(fmt=LoggingFormat.JSON)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -28,7 +26,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c
|
|||||||
log_context["event"] = "pull_request" if body == "/review" else "comment"
|
log_context["event"] = "pull_request" if body == "/review" else "comment"
|
||||||
log_context["api_url"] = url
|
log_context["api_url"] = url
|
||||||
with get_logger().contextualize(**log_context):
|
with get_logger().contextualize(**log_context):
|
||||||
background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body)
|
background_tasks.add_task(PRAgent().handle_request, url, body)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/webhook")
|
@router.post("/webhook")
|
||||||
|
@ -4,7 +4,7 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
from pr_agent.algo.ai_handler import AiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||||
from pr_agent.algo.token_handler import TokenHandler
|
from pr_agent.algo.token_handler import TokenHandler
|
||||||
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
|
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
|
||||||
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
|
|
||||||
class PRGenerateLabels:
|
class PRGenerateLabels:
|
||||||
def __init__(self, pr_url: str, args: list = None):
|
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None):
|
||||||
"""
|
"""
|
||||||
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
|
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
|
||||||
corresponding to the PR using an AI model.
|
corresponding to the PR using an AI model.
|
||||||
@ -31,7 +31,7 @@ class PRGenerateLabels:
|
|||||||
self.pr_id = self.git_provider.get_pr_id()
|
self.pr_id = self.git_provider.get_pr_id()
|
||||||
|
|
||||||
# Initialize the AI handler
|
# Initialize the AI handler
|
||||||
self.ai_handler = AiHandler()
|
self.ai_handler = ai_handler
|
||||||
|
|
||||||
# Initialize the variables dictionary
|
# Initialize the variables dictionary
|
||||||
self.vars = {
|
self.vars = {
|
||||||
|
Reference in New Issue
Block a user