Refactor AI handler instantiation to use lazy initialization in PR tools

This commit is contained in:
mrT23
2023-12-17 16:52:03 +02:00
parent 54891ad1d2
commit 5fb373b212
9 changed files with 28 additions and 18 deletions

View File

@ -1,4 +1,6 @@
import shlex
from functools import partial
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
@ -41,8 +43,8 @@ command2class = {
commands = list(command2class.keys())
class PRAgent:
def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.ai_handler = ai_handler
def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.ai_handler = ai_handler # will be initialized in run_action
async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists

View File

@ -1,5 +1,6 @@
import copy
import textwrap
from functools import partial
from typing import Dict
from jinja2 import Environment, StrictUndefined
@ -17,14 +18,14 @@ from pr_agent.log import get_logger
class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode

View File

@ -1,5 +1,6 @@
import copy
import textwrap
from functools import partial
from typing import Dict, List
from jinja2 import Environment, StrictUndefined
@ -16,7 +17,7 @@ from pr_agent.log import get_logger
class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
@ -33,7 +34,7 @@ class PRCodeSuggestions:
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode

View File

@ -1,5 +1,6 @@
import copy
import re
from functools import partial
from typing import List, Tuple
from jinja2 import Environment, StrictUndefined
@ -17,7 +18,7 @@ from pr_agent.log import get_logger
class PRDescription:
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
@ -38,7 +39,7 @@ class PRDescription:
get_settings().pr_description.enable_semantic_files_types = False
# Initialize the AI handler
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
# Initialize the variables dictionary
self.vars = {

View File

@ -1,5 +1,6 @@
import copy
import re
from functools import partial
from typing import List, Tuple
from jinja2 import Environment, StrictUndefined
@ -17,7 +18,7 @@ from pr_agent.log import get_logger
class PRGenerateLabels:
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model.
@ -33,7 +34,7 @@ class PRGenerateLabels:
self.pr_id = self.git_provider.get_pr_id()
# Initialize the AI handler
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
# Initialize the variables dictionary
self.vars = {

View File

@ -1,4 +1,5 @@
import copy
from functools import partial
from jinja2 import Environment, StrictUndefined
@ -14,12 +15,12 @@ from pr_agent.log import get_logger
class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),

View File

@ -1,4 +1,5 @@
import copy
from functools import partial
from jinja2 import Environment, StrictUndefined
@ -13,13 +14,13 @@ from pr_agent.log import get_logger
class PRQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.question_str = question_str
self.vars = {
"title": self.git_provider.pr.title,

View File

@ -1,6 +1,7 @@
import copy
import datetime
from collections import OrderedDict
from functools import partial
from typing import List, Tuple
import yaml
@ -24,7 +25,7 @@ class PRReviewer:
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
@ -47,7 +48,7 @@ class PRReviewer:
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None

View File

@ -1,5 +1,6 @@
import copy
from datetime import date
from functools import partial
from time import sleep
from typing import Tuple
@ -18,7 +19,7 @@ CHANGELOG_LINES = 50
class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
@ -26,7 +27,7 @@ class PRUpdateChangelog:
)
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self._get_changlog_file() # self.changelog_file_str
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode