mirror of
https://github.com/qodo-ai/pr-agent.git
synced 2025-07-03 12:20:38 +08:00
Refactor AI handler instantiation to use lazy initialization in PR tools
This commit is contained in:
@ -1,4 +1,6 @@
|
|||||||
import shlex
|
import shlex
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
|
||||||
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
|
||||||
|
|
||||||
@ -41,8 +43,8 @@ command2class = {
|
|||||||
commands = list(command2class.keys())
|
commands = list(command2class.keys())
|
||||||
|
|
||||||
class PRAgent:
|
class PRAgent:
|
||||||
def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler # will be initialized in run_action
|
||||||
|
|
||||||
async def handle_request(self, pr_url, request, notify=None) -> bool:
|
async def handle_request(self, pr_url, request, notify=None) -> bool:
|
||||||
# First, apply repo specific settings if exists
|
# First, apply repo specific settings if exists
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from functools import partial
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
@ -17,14 +18,14 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
class PRAddDocs:
|
class PRAddDocs:
|
||||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
|
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
|
||||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||||
|
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_language = get_main_pr_language(
|
self.main_language = get_main_pr_language(
|
||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
self.cli_mode = cli_mode
|
self.cli_mode = cli_mode
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from functools import partial
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
class PRCodeSuggestions:
|
class PRCodeSuggestions:
|
||||||
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
|
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
|
||||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||||
|
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_language = get_main_pr_language(
|
self.main_language = get_main_pr_language(
|
||||||
@ -33,7 +34,7 @@ class PRCodeSuggestions:
|
|||||||
else:
|
else:
|
||||||
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
|
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
|
||||||
|
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
self.cli_mode = cli_mode
|
self.cli_mode = cli_mode
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
|
from functools import partial
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
@ -17,7 +18,7 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
class PRDescription:
|
class PRDescription:
|
||||||
def __init__(self, pr_url: str, args: list = None,
|
def __init__(self, pr_url: str, args: list = None,
|
||||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||||
"""
|
"""
|
||||||
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
|
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
|
||||||
using an AI model.
|
using an AI model.
|
||||||
@ -38,7 +39,7 @@ class PRDescription:
|
|||||||
get_settings().pr_description.enable_semantic_files_types = False
|
get_settings().pr_description.enable_semantic_files_types = False
|
||||||
|
|
||||||
# Initialize the AI handler
|
# Initialize the AI handler
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler()
|
||||||
|
|
||||||
# Initialize the variables dictionary
|
# Initialize the variables dictionary
|
||||||
self.vars = {
|
self.vars = {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
|
from functools import partial
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
@ -17,7 +18,7 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
class PRGenerateLabels:
|
class PRGenerateLabels:
|
||||||
def __init__(self, pr_url: str, args: list = None,
|
def __init__(self, pr_url: str, args: list = None,
|
||||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||||
"""
|
"""
|
||||||
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
|
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
|
||||||
corresponding to the PR using an AI model.
|
corresponding to the PR using an AI model.
|
||||||
@ -33,7 +34,7 @@ class PRGenerateLabels:
|
|||||||
self.pr_id = self.git_provider.get_pr_id()
|
self.pr_id = self.git_provider.get_pr_id()
|
||||||
|
|
||||||
# Initialize the AI handler
|
# Initialize the AI handler
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler()
|
||||||
|
|
||||||
# Initialize the variables dictionary
|
# Initialize the variables dictionary
|
||||||
self.vars = {
|
self.vars = {
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
@ -14,12 +15,12 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
class PRInformationFromUser:
|
class PRInformationFromUser:
|
||||||
def __init__(self, pr_url: str, args: list = None,
|
def __init__(self, pr_url: str, args: list = None,
|
||||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_pr_language = get_main_pr_language(
|
self.main_pr_language = get_main_pr_language(
|
||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
)
|
)
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler()
|
||||||
self.vars = {
|
self.vars = {
|
||||||
"title": self.git_provider.pr.title,
|
"title": self.git_provider.pr.title,
|
||||||
"branch": self.git_provider.get_pr_branch(),
|
"branch": self.git_provider.get_pr_branch(),
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import copy
|
import copy
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from jinja2 import Environment, StrictUndefined
|
from jinja2 import Environment, StrictUndefined
|
||||||
|
|
||||||
@ -13,13 +14,13 @@ from pr_agent.log import get_logger
|
|||||||
|
|
||||||
|
|
||||||
class PRQuestions:
|
class PRQuestions:
|
||||||
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||||
question_str = self.parse_args(args)
|
question_str = self.parse_args(args)
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_pr_language = get_main_pr_language(
|
self.main_pr_language = get_main_pr_language(
|
||||||
self.git_provider.get_languages(), self.git_provider.get_files()
|
self.git_provider.get_languages(), self.git_provider.get_files()
|
||||||
)
|
)
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler()
|
||||||
self.question_str = question_str
|
self.question_str = question_str
|
||||||
self.vars = {
|
self.vars = {
|
||||||
"title": self.git_provider.pr.title,
|
"title": self.git_provider.pr.title,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from functools import partial
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@ -24,7 +25,7 @@ class PRReviewer:
|
|||||||
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
|
||||||
"""
|
"""
|
||||||
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
|
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
|
||||||
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||||
"""
|
"""
|
||||||
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
|
||||||
|
|
||||||
@ -47,7 +48,7 @@ class PRReviewer:
|
|||||||
|
|
||||||
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
|
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
|
||||||
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
|
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import copy
|
import copy
|
||||||
from datetime import date
|
from datetime import date
|
||||||
|
from functools import partial
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ CHANGELOG_LINES = 50
|
|||||||
|
|
||||||
|
|
||||||
class PRUpdateChangelog:
|
class PRUpdateChangelog:
|
||||||
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
|
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
|
||||||
|
|
||||||
self.git_provider = get_git_provider()(pr_url)
|
self.git_provider = get_git_provider()(pr_url)
|
||||||
self.main_language = get_main_pr_language(
|
self.main_language = get_main_pr_language(
|
||||||
@ -26,7 +27,7 @@ class PRUpdateChangelog:
|
|||||||
)
|
)
|
||||||
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
|
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
|
||||||
self._get_changlog_file() # self.changelog_file_str
|
self._get_changlog_file() # self.changelog_file_str
|
||||||
self.ai_handler = ai_handler
|
self.ai_handler = ai_handler()
|
||||||
self.patches_diff = None
|
self.patches_diff = None
|
||||||
self.prediction = None
|
self.prediction = None
|
||||||
self.cli_mode = cli_mode
|
self.cli_mode = cli_mode
|
||||||
|
Reference in New Issue
Block a user