diff --git a/pr_agent/cli.py b/pr_agent/cli.py index 7ab78a0e..98b493de 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -4,7 +4,7 @@ import os from pr_agent.agent.pr_agent import PRAgent, commands from pr_agent.config_loader import get_settings -from pr_agent.log import setup_logger +from pr_agent.log import setup_logger, get_logger log_level = os.environ.get("LOG_LEVEL", "INFO") setup_logger(log_level) @@ -71,10 +71,21 @@ def run(inargs=None, args=None): command = args.command.lower() get_settings().set("CONFIG.CLI_MODE", True) - if args.issue_url: - result = asyncio.run(PRAgent().handle_request(args.issue_url, [command] + args.rest)) - else: - result = asyncio.run(PRAgent().handle_request(args.pr_url, [command] + args.rest)) + + async def inner(): + if args.issue_url: + result = await asyncio.create_task(PRAgent().handle_request(args.issue_url, [command] + args.rest)) + else: + result = await asyncio.create_task(PRAgent().handle_request(args.pr_url, [command] + args.rest)) + + if get_settings().litellm.get("enable_callbacks", False): + # There may be additional events on the event queue from the run above. If there are give them time to complete. + get_logger().debug("Waiting for event queue to complete") + await asyncio.wait([task for task in asyncio.all_tasks() if task is not asyncio.current_task()]) + + return result + + result = asyncio.run(inner()) if not result: parser.print_help()