mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
Update AI handler instantiation in server files
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import shlex
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
|
||||
from pr_agent.algo.utils import update_settings_from_args
|
||||
from pr_agent.config_loader import get_settings
|
||||
@ -12,6 +13,7 @@ from pr_agent.tools.pr_questions import PRQuestions
|
||||
from pr_agent.tools.pr_reviewer import PRReviewer
|
||||
from pr_agent.tools.pr_similar_issue import PRSimilarIssue
|
||||
from pr_agent.tools.pr_update_changelog import PRUpdateChangelog
|
||||
import inspect
|
||||
|
||||
command2class = {
|
||||
"auto_review": PRReviewer,
|
||||
@ -36,8 +38,16 @@ command2class = {
|
||||
commands = list(command2class.keys())
|
||||
|
||||
class PRAgent:
|
||||
def __init__(self):
|
||||
def __init__(self, ai_handler: BaseAiHandler = None):
|
||||
self.ai_handler = ai_handler
|
||||
pass
|
||||
|
||||
def has_ai_handler_param(cls):
|
||||
constructor = getattr(cls, "__init__", None)
|
||||
if constructor is not None:
|
||||
parameters = inspect.signature(constructor).parameters
|
||||
return "ai_handler" in parameters
|
||||
return False
|
||||
|
||||
async def handle_request(self, pr_url, request, notify=None) -> bool:
|
||||
# First, apply repo specific settings if exists
|
||||
@ -56,13 +66,17 @@ class PRAgent:
|
||||
if action == "answer":
|
||||
if notify:
|
||||
notify()
|
||||
await PRReviewer(pr_url, is_answer=True, args=args).run()
|
||||
await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
|
||||
elif action == "auto_review":
|
||||
await PRReviewer(pr_url, is_auto=True, args=args).run()
|
||||
await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
|
||||
elif action in command2class:
|
||||
if notify:
|
||||
notify()
|
||||
await command2class[action](pr_url, args=args).run()
|
||||
|
||||
if(not self.has_ai_handler_param(command2class[action])):
|
||||
await command2class[action](pr_url, args=args).run()
|
||||
else
|
||||
await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
@ -14,15 +14,15 @@ class BaseAiHandler(ABC):
|
||||
def deployment_id(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
"""
|
||||
This method should be implemented to return a chat completion from the AI model.
|
||||
params:
|
||||
model: the name of the model to use for the chat completion
|
||||
system: the system message string to use for the chat completion
|
||||
user: the user message string to use for the chat completion
|
||||
temperature: the temperature to use for the chat completion
|
||||
"""
|
||||
@abstractmethod
|
||||
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
|
||||
"""
|
||||
This method should be implemented to return a chat completion from the AI model.
|
||||
Args:
|
||||
model (str): the name of the model to use for the chat completion
|
||||
system (str): the system message string to use for the chat completion
|
||||
user (str): the user message string to use for the chat completion
|
||||
temperature (float): the temperature to use for the chat completion
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -5,7 +5,9 @@ import os
|
||||
from pr_agent.agent.pr_agent import PRAgent, commands
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.log import setup_logger
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||
|
||||
litellm_ai_handler = LiteLLMAiHandler()
|
||||
setup_logger()
|
||||
|
||||
def run(inargs=None):
|
||||
@ -51,9 +53,9 @@ For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions
|
||||
command = args.command.lower()
|
||||
get_settings().set("CONFIG.CLI_MODE", True)
|
||||
if args.issue_url:
|
||||
result = asyncio.run(PRAgent().handle_request(args.issue_url, command + " " + " ".join(args.rest)))
|
||||
result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.issue_url, command + " " + " ".join(args.rest)))
|
||||
else:
|
||||
result = asyncio.run(PRAgent().handle_request(args.pr_url, command + " " + " ".join(args.rest)))
|
||||
result = asyncio.run(PRAgent(ai_handler=litellm_ai_handler).handle_request(args.pr_url, command + " " + " ".join(args.rest)))
|
||||
if not result:
|
||||
parser.print_help()
|
||||
|
||||
|
@ -18,7 +18,9 @@ from pr_agent.agent.pr_agent import PRAgent
|
||||
from pr_agent.config_loader import get_settings, global_settings
|
||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
from pr_agent.secret_providers import get_secret_provider
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||
|
||||
litellm_ai_handler = LiteLLMAiHandler()
|
||||
setup_logger(fmt=LoggingFormat.JSON)
|
||||
router = APIRouter()
|
||||
secret_provider = get_secret_provider()
|
||||
@ -84,7 +86,7 @@ async def handle_github_webhooks(background_tasks: BackgroundTasks, request: Req
|
||||
context['bitbucket_bearer_token'] = bearer_token
|
||||
context["settings"] = copy.deepcopy(global_settings)
|
||||
event = data["event"]
|
||||
agent = PRAgent()
|
||||
agent = PRAgent(ai_handler=litellm_ai_handler)
|
||||
if event == "pullrequest:created":
|
||||
pr_url = data["data"]["pullrequest"]["links"]["html"]["href"]
|
||||
log_context["api_url"] = pr_url
|
||||
|
@ -12,7 +12,9 @@ from starlette_context.middleware import RawContextMiddleware
|
||||
from pr_agent.agent.pr_agent import PRAgent
|
||||
from pr_agent.config_loader import get_settings, global_settings
|
||||
from pr_agent.log import get_logger, setup_logger
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||
|
||||
litellm_ai_handler = LiteLLMAiHandler()
|
||||
setup_logger()
|
||||
router = APIRouter()
|
||||
|
||||
@ -43,7 +45,7 @@ async def handle_gerrit_request(action: Action, item: Item):
|
||||
status_code=400,
|
||||
detail="msg is required for ask command"
|
||||
)
|
||||
await PRAgent().handle_request(
|
||||
await PRAgent(ai_handler=litellm_ai_handler).handle_request(
|
||||
f"{item.project}:{item.refspec}",
|
||||
f"/{item.msg.strip()}"
|
||||
)
|
||||
|
@ -8,7 +8,8 @@ from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
|
||||
from pr_agent.tools.pr_description import PRDescription
|
||||
from pr_agent.tools.pr_reviewer import PRReviewer
|
||||
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||
litellm_ai_handler = LiteLLMAiHandler()
|
||||
|
||||
async def run_action():
|
||||
# Get environment variables
|
||||
@ -83,9 +84,9 @@ async def run_action():
|
||||
comment_id = event_payload.get("comment", {}).get("id")
|
||||
provider = get_git_provider()(pr_url=url)
|
||||
if is_pr:
|
||||
await PRAgent().handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
|
||||
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body, notify=lambda: provider.add_eyes_reaction(comment_id))
|
||||
else:
|
||||
await PRAgent().handle_request(url, body)
|
||||
await PRAgent(ai_handler=litellm_ai_handler).handle_request(url, body)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -16,7 +16,9 @@ from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.utils import apply_repo_settings
|
||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
from pr_agent.servers.utils import verify_signature
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||
|
||||
litellm_ai_handler = LiteLLMAiHandler()
|
||||
setup_logger(fmt=LoggingFormat.JSON)
|
||||
|
||||
router = APIRouter()
|
||||
@ -75,7 +77,7 @@ async def handle_request(body: Dict[str, Any], event: str):
|
||||
action = body.get("action")
|
||||
if not action:
|
||||
return {}
|
||||
agent = PRAgent()
|
||||
agent = PRAgent(ai_handler=litellm_ai_handler)
|
||||
bot_user = get_settings().github_app.bot_user
|
||||
sender = body.get("sender", {}).get("login")
|
||||
log_context = {"action": action, "event": event, "sender": sender, "server_type": "github_app"}
|
||||
|
@ -8,7 +8,9 @@ from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
from pr_agent.servers.help import bot_help_text
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||
|
||||
litellm_ai_handler = LiteLLMAiHandler()
|
||||
setup_logger(fmt=LoggingFormat.JSON)
|
||||
NOTIFICATION_URL = "https://api.github.com/notifications"
|
||||
|
||||
@ -34,7 +36,7 @@ async def polling_loop():
|
||||
last_modified = [None]
|
||||
git_provider = get_git_provider()()
|
||||
user_id = git_provider.get_user_id()
|
||||
agent = PRAgent()
|
||||
agent = PRAgent(ai_handler=litellm_ai_handler)
|
||||
get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)
|
||||
|
||||
try:
|
||||
|
@ -14,7 +14,9 @@ from pr_agent.agent.pr_agent import PRAgent
|
||||
from pr_agent.config_loader import get_settings, global_settings
|
||||
from pr_agent.log import LoggingFormat, get_logger, setup_logger
|
||||
from pr_agent.secret_providers import get_secret_provider
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler
|
||||
|
||||
litellm_ai_handler = LiteLLMAiHandler()
|
||||
setup_logger(fmt=LoggingFormat.JSON)
|
||||
router = APIRouter()
|
||||
|
||||
@ -26,7 +28,7 @@ def handle_request(background_tasks: BackgroundTasks, url: str, body: str, log_c
|
||||
log_context["event"] = "pull_request" if body == "/review" else "comment"
|
||||
log_context["api_url"] = url
|
||||
with get_logger().contextualize(**log_context):
|
||||
background_tasks.add_task(PRAgent().handle_request, url, body)
|
||||
background_tasks.add_task(PRAgent(ai_handler=litellm_ai_handler).handle_request, url, body)
|
||||
|
||||
|
||||
@router.post("/webhook")
|
||||
|
@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import load_yaml, get_ai_handler
|
||||
from pr_agent.algo.utils import load_yaml
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
|
||||
|
||||
|
||||
class PRAddDocs:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()):
|
||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None):
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
|
@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import load_yaml, get_ai_handler
|
||||
from pr_agent.algo.utils import load_yaml
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
|
||||
|
||||
|
||||
class PRCodeSuggestions:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler() ):
|
||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = None ):
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
|
@ -7,7 +7,7 @@ from jinja2 import Environment, StrictUndefined
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import load_yaml, get_ai_handler
|
||||
from pr_agent.algo.utils import load_yaml
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
|
||||
|
||||
|
||||
class PRDescription:
|
||||
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()):
|
||||
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None):
|
||||
"""
|
||||
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
|
||||
using an AI model.
|
||||
|
@ -5,7 +5,6 @@ from jinja2 import Environment, StrictUndefined
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import get_ai_handler
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
@ -13,7 +12,7 @@ from pr_agent.log import get_logger
|
||||
|
||||
|
||||
class PRInformationFromUser:
|
||||
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()):
|
||||
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = None):
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_pr_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
|
@ -5,7 +5,6 @@ from jinja2 import Environment, StrictUndefined
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import get_ai_handler
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
@ -13,7 +12,7 @@ from pr_agent.log import get_logger
|
||||
|
||||
|
||||
class PRQuestions:
|
||||
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = get_ai_handler()):
|
||||
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = None):
|
||||
question_str = self.parse_args(args)
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_pr_language = get_main_pr_language(
|
||||
|
@ -9,7 +9,7 @@ from yaml import SafeLoader
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import convert_to_markdown, get_ai_handler, load_yaml, try_fix_yaml
|
||||
from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import IncrementalPR, get_main_pr_language
|
||||
@ -21,13 +21,15 @@ class PRReviewer:
|
||||
"""
|
||||
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
||||
"""
|
||||
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = get_ai_handler()):
|
||||
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = None):
|
||||
"""
|
||||
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
||||
|
||||
Args:
|
||||
pr_url (str): The URL of the pull request to be reviewed.
|
||||
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
|
||||
is_auto (bool, optional): Indicates whether the review is being done in automatic mode. Defaults to False.
|
||||
ai_handler (BaseAiHandler): The AI handler to be used for the review. Defaults to None.
|
||||
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
|
||||
"""
|
||||
self.parse_args(args) # -i command
|
||||
|
@ -8,7 +8,6 @@ from jinja2 import Environment, StrictUndefined
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
|
||||
from pr_agent.algo.token_handler import TokenHandler
|
||||
from pr_agent.algo.utils import get_ai_handler
|
||||
from pr_agent.config_loader import get_settings
|
||||
from pr_agent.git_providers import get_git_provider
|
||||
from pr_agent.git_providers.git_provider import get_main_pr_language
|
||||
@ -18,7 +17,7 @@ CHANGELOG_LINES = 50
|
||||
|
||||
|
||||
class PRUpdateChangelog:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = get_ai_handler()):
|
||||
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = None):
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
|
Reference in New Issue
Block a user