From 8fb4a42ef1e45aa5ac67a5eca42eeabe5c3c999a Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Wed, 13 Dec 2023 08:16:02 +0800 Subject: [PATCH] Update AI handler instantiation in server files --- pr_agent/agent/pr_agent.py | 22 ++++++++++++++++---- pr_agent/algo/ai_handlers/base_ai_handler.py | 18 ++++++++-------- pr_agent/cli.py | 6 ++++-- pr_agent/servers/bitbucket_app.py | 4 +++- pr_agent/servers/gerrit_server.py | 4 +++- pr_agent/servers/github_action_runner.py | 7 ++++--- pr_agent/servers/github_app.py | 4 +++- pr_agent/servers/github_polling.py | 4 +++- pr_agent/servers/gitlab_webhook.py | 4 +++- pr_agent/tools/pr_add_docs.py | 4 ++-- pr_agent/tools/pr_code_suggestions.py | 4 ++-- pr_agent/tools/pr_description.py | 4 ++-- pr_agent/tools/pr_information_from_user.py | 3 +-- pr_agent/tools/pr_questions.py | 3 +-- pr_agent/tools/pr_reviewer.py | 6 ++++-- pr_agent/tools/pr_update_changelog.py | 3 +-- 16 files changed, 63 insertions(+), 37 deletions(-) diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index cd2bf2cc..a94984ac 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -1,4 +1,5 @@ import shlex +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.utils import update_settings_from_args from pr_agent.config_loader import get_settings @@ -12,6 +13,7 @@ from pr_agent.tools.pr_questions import PRQuestions from pr_agent.tools.pr_reviewer import PRReviewer from pr_agent.tools.pr_similar_issue import PRSimilarIssue from pr_agent.tools.pr_update_changelog import PRUpdateChangelog +import inspect command2class = { "auto_review": PRReviewer, @@ -36,8 +38,16 @@ command2class = { commands = list(command2class.keys()) class PRAgent: - def __init__(self): + def __init__(self, ai_handler: BaseAiHandler = None): + self.ai_handler = ai_handler pass + + def has_ai_handler_param(cls): + constructor = getattr(cls, "__init__", None) + if constructor is not None: + parameters = inspect.signature(constructor).parameters + return "ai_handler" in parameters + return False async def handle_request(self, pr_url, request, notify=None) -> bool: # First, apply repo specific settings if exists @@ -56,13 +66,17 @@ class PRAgent: if action == "answer": if notify: notify() - await PRReviewer(pr_url, is_answer=True, args=args).run() + await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run() elif action == "auto_review": - await PRReviewer(pr_url, is_auto=True, args=args).run() + await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run() elif action in command2class: if notify: notify() - await command2class[action](pr_url, args=args).run() + + if(not self.has_ai_handler_param(command2class[action])): + await command2class[action](pr_url, args=args).run() + else + await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run() else: return False return True diff --git a/pr_agent/algo/ai_handlers/base_ai_handler.py b/pr_agent/algo/ai_handlers/base_ai_handler.py index 7c6c3ddf..c8473fb3 100644 --- a/pr_agent/algo/ai_handlers/base_ai_handler.py +++ b/pr_agent/algo/ai_handlers/base_ai_handler.py @@ -14,15 +14,15 @@ class BaseAiHandler(ABC): def deployment_id(self): pass - @abstractmethod - """ - This method should be implemented to return a chat completion from the AI model. - params: - model: the name of the model to use for the chat completion - system: the system message string to use for the chat completion - user: the user message string to use for the chat completion - temperature: the temperature to use for the chat completion - """ + @abstractmethod async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + """ + This method should be implemented to return a chat completion from the AI model. + Args: + model (str): the name of the model to use for the chat completion + system (str): the system message string to use for the chat completion + user (str): the user message string to use for the chat completion + temperature (float): the temperature to use for the chat completion + """ pass diff --git a/pr_agent/cli.py b/pr_agent/cli.py index 6728db9f..422185ab 100644 --- a/pr_agent/cli.py +++ b/pr_agent/cli.py @@ -5,7 +5,9 @@ import os from pr_agent.agent.pr_agent import PRAgent, commands from pr_agent.config_loader import get_settings from pr_agent.log import setup_logger +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger() def run(inargs=None): @@ -51,9 +53,9 @@ For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions command = args.command.lower() get_settings().set("CONFIG.CLI_MODE", True) if args.issue_url: - result = asyncio.run(PRAgent().handle_request(args.issue_url, command + " " + " ".join(args.rest))) + result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.issue_url, command + " " + " ".join(args.rest))) else: - result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest))) + result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.pr_url, command + " " + " ".join(args.rest))) if not result: parser.print_help() diff --git a/pr_agent/servers/bitbucket_app.py b/pr_agent/servers/bitbucket_app.py index e147fbdd..739ba162 100644 --- a/pr_agent/servers/bitbucket_app.py +++ b/pr_agent/servers/bitbucket_app.py @@ -18,7 +18,9 @@ from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.secret_providers import get_secret_provider +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() secret_provider = get_secret_provider() @@ -84,7 +86,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req context['bitbucket_bearer_token'] = bearer_token context["settings"] = copy.deepcopy(global_settings) event = data["event"] - agent = PRAgent() + agent = PRAgent(ai_handler=litellm_ai_handler) if event == "pullrequest:created": pr_url = data["data"]["pullrequest"]["links"]["html"]["href"] log_context["api_url"] = pr_url diff --git a/pr_agent/servers/gerrit_server.py b/pr_agent/servers/gerrit_server.py index 1783f6b9..b8b90670 100644 --- a/pr_agent/servers/gerrit_server.py +++ b/pr_agent/servers/gerrit_server.py @@ -12,7 +12,9 @@ from starlette_context.middleware import RawContextMiddleware from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger, setup_logger +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger() router = APIRouter() @@ -43,7 +45,7 @@ async def handle_gerrit_request(action: Action, item: Item): status_code=400, detail="msg is required for ask command" ) - await PRAgent().handle_request( + await PRAgent(ai_handler=litellm_ai_handler).handle_request( f"{item.project}:{item.refspec}", f"/{item.msg.strip()}" ) diff --git a/pr_agent/servers/github_action_runner.py b/pr_agent/servers/github_action_runner.py index 714e7297..030e3f2c 100644 --- a/pr_agent/servers/github_action_runner.py +++ b/pr_agent/servers/github_action_runner.py @@ -8,7 +8,8 @@ from pr_agent.git_providers import get_git_provider from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_description import PRDescription from pr_agent.tools.pr_reviewer import PRReviewer - +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() async def run_action(): # Get environment variables @@ -83,9 +84,9 @@ async def run_action(): comment_id = event_payload.get("comment", {}).get("id") provider = get_git_provider()(pr_url=url) if is_pr: - await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) + await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id)) else: - await PRAgent().handle_request(url, body) + await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body) if __name__ == '__main__': diff --git a/pr_agent/servers/github_app.py b/pr_agent/servers/github_app.py index 37f96e2d..bb595a89 100644 --- a/pr_agent/servers/github_app.py +++ b/pr_agent/servers/github_app.py @@ -16,7 +16,9 @@ from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.utils import apply_repo_settings from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.servers.utils import verify_signature +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() @@ -75,7 +77,7 @@ async def handle_request(body: Dict[str, Any], event: str): action = body.get("action") if not action: return {} - agent = PRAgent() + agent = PRAgent(ai_handler=litellm_ai_handler) bot_user = get_settings().github_app.bot_user sender = body.get("sender", {}).get("login") log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"} diff --git a/pr_agent/servers/github_polling.py b/pr_agent/servers/github_polling.py index 1363b941..b473b8fa 100644 --- a/pr_agent/servers/github_polling.py +++ b/pr_agent/servers/github_polling.py @@ -8,7 +8,9 @@ from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.servers.help import bot_help_text +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) NOTIFICATION_URL = "https://api.github.com/notifications" @@ -34,7 +36,7 @@ async def polling_loop(): last_modified = [None] git_provider = get_git_provider()() user_id = git_provider.get_user_id() - agent = PRAgent() + agent = PRAgent(ai_handler=litellm_ai_handler) get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False) try: diff --git a/pr_agent/servers/gitlab_webhook.py b/pr_agent/servers/gitlab_webhook.py index 63bf99ce..b6921684 100644 --- a/pr_agent/servers/gitlab_webhook.py +++ b/pr_agent/servers/gitlab_webhook.py @@ -14,7 +14,9 @@ from pr_agent.agent.pr_agent import PRAgent from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import LoggingFormat, get_logger, setup_logger from pr_agent.secret_providers import get_secret_provider +from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler +litellm_ai_handler = LiteLLMAiHandler() setup_logger(fmt=LoggingFormat.JSON) router = APIRouter() @@ -26,7 +28,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c log_context["event"] = "pull_request" if body == "/review" else "comment" log_context["api_url"] = url with get_logger().contextualize(**log_context): - background_tasks.add_task(PRAgent().handle_request, url, body) + background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body) @router.post("/webhook") diff --git a/pr_agent/tools/pr_add_docs.py b/pr_agent/tools/pr_add_docs.py index 916f479f..70dd66c2 100644 --- a/pr_agent/tools/pr_add_docs.py +++ b/pr_agent/tools/pr_add_docs.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, get_ai_handler +from pr_agent.algo.utils import load_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRAddDocs: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py index 61a382a5..a85bab5f 100644 --- a/pr_agent/tools/pr_code_suggestions.py +++ b/pr_agent/tools/pr_code_suggestions.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, get_ai_handler +from pr_agent.algo.utils import load_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRCodeSuggestions: - def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler() ): + def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None ): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_description.py b/pr_agent/tools/pr_description.py index c3db0cef..d377d75c 100644 --- a/pr_agent/tools/pr_description.py +++ b/pr_agent/tools/pr_description.py @@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import load_yaml, get_ai_handler +from pr_agent.algo.utils import load_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -15,7 +15,7 @@ from pr_agent.log import get_logger class PRDescription: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): """ Initialize the PRDescription object with the necessary attributes and objects for generating a PR description using an AI model. diff --git a/pr_agent/tools/pr_information_from_user.py b/pr_agent/tools/pr_information_from_user.py index c4240723..e52765f7 100644 --- a/pr_agent/tools/pr_information_from_user.py +++ b/pr_agent/tools/pr_information_from_user.py @@ -5,7 +5,6 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -13,7 +12,7 @@ from pr_agent.log import get_logger class PRInformationFromUser: - def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( self.git_provider.get_languages(), self.git_provider.get_files() diff --git a/pr_agent/tools/pr_questions.py b/pr_agent/tools/pr_questions.py index ecaf4d8d..79edfd6a 100644 --- a/pr_agent/tools/pr_questions.py +++ b/pr_agent/tools/pr_questions.py @@ -5,7 +5,6 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -13,7 +12,7 @@ from pr_agent.log import get_logger class PRQuestions: - def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = None): question_str = self.parse_args(args) self.git_provider = get_git_provider()(pr_url) self.main_pr_language = get_main_pr_language( diff --git a/pr_agent/tools/pr_reviewer.py b/pr_agent/tools/pr_reviewer.py index 138ad5ad..becd2191 100644 --- a/pr_agent/tools/pr_reviewer.py +++ b/pr_agent/tools/pr_reviewer.py @@ -9,7 +9,7 @@ from yaml import SafeLoader from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import convert_to_markdown, get_ai_handler, load_yaml, try_fix_yaml +from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language @@ -21,13 +21,15 @@ class PRReviewer: """ The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. """ - def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = None): """ Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Args: pr_url (str): The URL of the pull request to be reviewed. is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False. + is_auto (bool, optional): Indicates whether the review is being done in automatic mode. Defaults to False. + ai_handler (BaseAiHandler): The AI handler to be used for the review. Defaults to None. args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. """ self.parse_args(args) # -i command diff --git a/pr_agent/tools/pr_update_changelog.py b/pr_agent/tools/pr_update_changelog.py index 33ba941d..07130749 100644 --- a/pr_agent/tools/pr_update_changelog.py +++ b/pr_agent/tools/pr_update_changelog.py @@ -8,7 +8,6 @@ from jinja2 import Environment, StrictUndefined from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import get_ai_handler from pr_agent.config_loader import get_settings from pr_agent.git_providers import get_git_provider from pr_agent.git_providers.git_provider import get_main_pr_language @@ -18,7 +17,7 @@ CHANGELOG_LINES = 50 class PRUpdateChangelog: - def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = get_ai_handler()): + def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = None): self.git_provider = get_git_provider()(pr_url) self.main_language = get_main_pr_language(