mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-02 11:50:37 +08:00
Refactor AI handler instantiation to use lazy initialization in PR tools
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
import shlex
|
||||
from functools import partial
|
||||
|
||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
||||
|
||||
@ -41,8 +43,8 @@ command2class = {
|
||||
commands = list(command2class.keys())
|
||||
|
||||
class PRAgent:
|
||||
def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||
self.ai_handler = ai_handler
|
||||
def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
self.ai_handler = ai_handler # will be initialized in run_action
|
||||
|
||||
async def handle_request(self, pr_url, request, notify=None) -> bool:
|
||||
# First, apply repo specific settings if exists
|
||||
|
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import textwrap
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
@ -17,14 +18,14 @@ from pr_agent.log import get_logger
|
||||
|
||||
class PRAddDocs:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
|
||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
|
||||
self.ai_handler = ai_handler
|
||||
self.ai_handler = ai_handler()
|
||||
self.patches_diff = None
|
||||
self.prediction = None
|
||||
self.cli_mode = cli_mode
|
||||
|
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import textwrap
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
@ -16,7 +17,7 @@ from pr_agent.log import get_logger
|
||||
|
||||
class PRCodeSuggestions:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
|
||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
@ -33,7 +34,7 @@ class PRCodeSuggestions:
|
||||
else:
|
||||
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
|
||||
|
||||
self.ai_handler = ai_handler
|
||||
self.ai_handler = ai_handler()
|
||||
self.patches_diff = None
|
||||
self.prediction = None
|
||||
self.cli_mode = cli_mode
|
||||
|
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
@ -17,7 +18,7 @@ from pr_agent.log import get_logger
|
||||
|
||||
class PRDescription:
|
||||
def __init__(self, pr_url: str, args: list = None,
|
||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
"""
|
||||
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
|
||||
using an AI model.
|
||||
@ -38,7 +39,7 @@ class PRDescription:
|
||||
get_settings().pr_description.enable_semantic_files_types = False
|
||||
|
||||
# Initialize the AI handler
|
||||
self.ai_handler = ai_handler
|
||||
self.ai_handler = ai_handler()
|
||||
|
||||
# Initialize the variables dictionary
|
||||
self.vars = {
|
||||
|
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
@ -17,7 +18,7 @@ from pr_agent.log import get_logger
|
||||
|
||||
class PRGenerateLabels:
|
||||
def __init__(self, pr_url: str, args: list = None,
|
||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
"""
|
||||
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
|
||||
corresponding to the PR using an AI model.
|
||||
@ -33,7 +34,7 @@ class PRGenerateLabels:
|
||||
self.pr_id = self.git_provider.get_pr_id()
|
||||
|
||||
# Initialize the AI handler
|
||||
self.ai_handler = ai_handler
|
||||
self.ai_handler = ai_handler()
|
||||
|
||||
# Initialize the variables dictionary
|
||||
self.vars = {
|
||||
|
@ -1,4 +1,5 @@
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
@ -14,12 +15,12 @@ from pr_agent.log import get_logger
|
||||
|
||||
class PRInformationFromUser:
|
||||
def __init__(self, pr_url: str, args: list = None,
|
||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_pr_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
self.ai_handler = ai_handler
|
||||
self.ai_handler = ai_handler()
|
||||
self.vars = {
|
||||
"title": self.git_provider.pr.title,
|
||||
"branch": self.git_provider.get_pr_branch(),
|
||||
|
@ -1,4 +1,5 @@
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
from jinja2 import Environment, StrictUndefined
|
||||
|
||||
@ -13,13 +14,13 @@ from pr_agent.log import get_logger
|
||||
|
||||
|
||||
class PRQuestions:
|
||||
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
question_str = self.parse_args(args)
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_pr_language = get_main_pr_language(
|
||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||
)
|
||||
self.ai_handler = ai_handler
|
||||
self.ai_handler = ai_handler()
|
||||
self.question_str = question_str
|
||||
self.vars = {
|
||||
"title": self.git_provider.pr.title,
|
||||
|
@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import yaml
|
||||
@ -24,7 +25,7 @@ class PRReviewer:
|
||||
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
||||
"""
|
||||
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
|
||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
"""
|
||||
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
||||
|
||||
@ -47,7 +48,7 @@ class PRReviewer:
|
||||
|
||||
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
|
||||
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
|
||||
self.ai_handler = ai_handler
|
||||
self.ai_handler = ai_handler()
|
||||
self.patches_diff = None
|
||||
self.prediction = None
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import copy
|
||||
from datetime import date
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
from typing import Tuple
|
||||
|
||||
@ -18,7 +19,7 @@ CHANGELOG_LINES = 50
|
||||
|
||||
|
||||
class PRUpdateChangelog:
|
||||
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
||||
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||
|
||||
self.git_provider = get_git_provider()(pr_url)
|
||||
self.main_language = get_main_pr_language(
|
||||
@ -26,7 +27,7 @@ class PRUpdateChangelog:
|
||||
)
|
||||
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
|
||||
self._get_changlog_file() # self.changelog_file_str
|
||||
self.ai_handler = ai_handler
|
||||
self.ai_handler = ai_handler()
|
||||
self.patches_diff = None
|
||||
self.prediction = None
|
||||
self.cli_mode = cli_mode
|
||||
|
Reference in New Issue
Block a user