Update AI handler instantiation in server files

This commit is contained in:
Brian Pham
2023-12-13 08:16:02 +08:00
parent ca1ccd7b91
commit 8fb4a42ef1
16 changed files with 63 additions and 37 deletions

View File

@ -18,7 +18,9 @@ from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.secret_providers import get_secret_provider
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON)
router = APIRouter()
secret_provider = get_secret_provider()
@ -84,7 +86,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
context['bitbucket_bearer_token'] = bearer_token
context["settings"] = copy.deepcopy(global_settings)
event = data["event"]
agent = PRAgent()
agent = PRAgent(ai_handler=litellm_ai_handler)
if event == "pullrequest:created":
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
log_context["api_url"] = pr_url

View File

@ -12,7 +12,9 @@ from starlette_context.middleware import RawContextMiddleware
from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger, setup_logger
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger()
router = APIRouter()
@ -43,7 +45,7 @@ async def handle_gerrit_request(action: Action, item: Item):
status_code=400,
detail="msg is required for ask command"
)
await PRAgent().handle_request(
await PRAgent(ai_handler=litellm_ai_handler).handle_request(
f"{item.project}:{item.refspec}",
f"/{item.msg.strip()}"
)

View File

@ -8,7 +8,8 @@ from pr_agent.git_providers import get_git_provider
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_description import PRDescription
from pr_agent.tools.pr_reviewer import PRReviewer
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
async def run_action():
# Get environment variables
@ -83,9 +84,9 @@ async def run_action():
comment_id = event_payload.get("comment", {}).get("id")
provider = get_git_provider()(pr_url=url)
if is_pr:
await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
else:
await PRAgent().handle_request(url, body)
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body)
if __name__ == '__main__':

View File

@ -16,7 +16,9 @@ from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.servers.utils import verify_signature
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON)
router = APIRouter()
@ -75,7 +77,7 @@ async def handle_request(body: Dict[str, Any], event: str):
action = body.get("action")
if not action:
return {}
agent = PRAgent()
agent = PRAgent(ai_handler=litellm_ai_handler)
bot_user = get_settings().github_app.bot_user
sender = body.get("sender", {}).get("login")
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"}

View File

@ -8,7 +8,9 @@ from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.servers.help import bot_help_text
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON)
NOTIFICATION_URL = "https://api.github.com/notifications"
@ -34,7 +36,7 @@ async def polling_loop():
last_modified = [None]
git_provider = get_git_provider()()
user_id = git_provider.get_user_id()
agent = PRAgent()
agent = PRAgent(ai_handler=litellm_ai_handler)
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
try:

View File

@ -14,7 +14,9 @@ from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import LoggingFormat, get_logger, setup_logger
from pr_agent.secret_providers import get_secret_provider
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
litellm_ai_handler = LiteLLMAiHandler()
setup_logger(fmt=LoggingFormat.JSON)
router = APIRouter()
@ -26,7 +28,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c
log_context["event"] = "pull_request" if body == "/review" else "comment"
log_context["api_url"] = url
with get_logger().contextualize(**log_context):
background_tasks.add_task(PRAgent().handle_request, url, body)
background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body)
@router.post("/webhook")