Refactor AI handler instantiation to use lazy initialization in PR tools

This commit is contained in:
mrT23
2023-12-17 16:52:03 +02:00
parent 54891ad1d2
commit 5fb373b212
9 changed files with 28 additions and 18 deletions

View File

@ -1,4 +1,6 @@
import shlex import shlex
from functools import partial
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
@ -41,8 +43,8 @@ command2class = {
commands = list(command2class.keys()) commands = list(command2class.keys())
class PRAgent: class PRAgent:
def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()): def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.ai_handler = ai_handler self.ai_handler = ai_handler # will be initialized in run_action
async def handle_request(self, pr_url, request, notify=None) -> bool: async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists # First, apply repo specific settings if exists

View File

@ -1,5 +1,6 @@
import copy import copy
import textwrap import textwrap
from functools import partial
from typing import Dict from typing import Dict
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -17,14 +18,14 @@ from pr_agent.log import get_logger
class PRAddDocs: class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None, def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()): ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.ai_handler = ai_handler self.ai_handler = ai_handler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode self.cli_mode = cli_mode

View File

@ -1,5 +1,6 @@
import copy import copy
import textwrap import textwrap
from functools import partial
from typing import Dict, List from typing import Dict, List
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -16,7 +17,7 @@ from pr_agent.log import get_logger
class PRCodeSuggestions: class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None, def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()): ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
@ -33,7 +34,7 @@ class PRCodeSuggestions:
else: else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
self.ai_handler = ai_handler self.ai_handler = ai_handler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode self.cli_mode = cli_mode

View File

@ -1,5 +1,6 @@
import copy import copy
import re import re
from functools import partial
from typing import List, Tuple from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -17,7 +18,7 @@ from pr_agent.log import get_logger
class PRDescription: class PRDescription:
def __init__(self, pr_url: str, args: list = None, def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()): ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
""" """
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model. using an AI model.
@ -38,7 +39,7 @@ class PRDescription:
get_settings().pr_description.enable_semantic_files_types = False get_settings().pr_description.enable_semantic_files_types = False
# Initialize the AI handler # Initialize the AI handler
self.ai_handler = ai_handler self.ai_handler = ai_handler()
# Initialize the variables dictionary # Initialize the variables dictionary
self.vars = { self.vars = {

View File

@ -1,5 +1,6 @@
import copy import copy
import re import re
from functools import partial
from typing import List, Tuple from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -17,7 +18,7 @@ from pr_agent.log import get_logger
class PRGenerateLabels: class PRGenerateLabels:
def __init__(self, pr_url: str, args: list = None, def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()): ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
""" """
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model. corresponding to the PR using an AI model.
@ -33,7 +34,7 @@ class PRGenerateLabels:
self.pr_id = self.git_provider.get_pr_id() self.pr_id = self.git_provider.get_pr_id()
# Initialize the AI handler # Initialize the AI handler
self.ai_handler = ai_handler self.ai_handler = ai_handler()
# Initialize the variables dictionary # Initialize the variables dictionary
self.vars = { self.vars = {

View File

@ -1,4 +1,5 @@
import copy import copy
from functools import partial
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -14,12 +15,12 @@ from pr_agent.log import get_logger
class PRInformationFromUser: class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None, def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()): ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.ai_handler = ai_handler self.ai_handler = ai_handler()
self.vars = { self.vars = {
"title": self.git_provider.pr.title, "title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(), "branch": self.git_provider.get_pr_branch(),

View File

@ -1,4 +1,5 @@
import copy import copy
from functools import partial
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
@ -13,13 +14,13 @@ from pr_agent.log import get_logger
class PRQuestions: class PRQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()): def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
question_str = self.parse_args(args) question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.ai_handler = ai_handler self.ai_handler = ai_handler()
self.question_str = question_str self.question_str = question_str
self.vars = { self.vars = {
"title": self.git_provider.pr.title, "title": self.git_provider.pr.title,

View File

@ -1,6 +1,7 @@
import copy import copy
import datetime import datetime
from collections import OrderedDict from collections import OrderedDict
from functools import partial
from typing import List, Tuple from typing import List, Tuple
import yaml import yaml
@ -24,7 +25,7 @@ class PRReviewer:
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
""" """
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()): ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
""" """
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
@ -47,7 +48,7 @@ class PRReviewer:
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now") raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
self.ai_handler = ai_handler self.ai_handler = ai_handler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None

View File

@ -1,5 +1,6 @@
import copy import copy
from datetime import date from datetime import date
from functools import partial
from time import sleep from time import sleep
from typing import Tuple from typing import Tuple
@ -18,7 +19,7 @@ CHANGELOG_LINES = 50
class PRUpdateChangelog: class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()): def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
@ -26,7 +27,7 @@ class PRUpdateChangelog:
) )
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self._get_changlog_file() # self.changelog_file_str self._get_changlog_file() # self.changelog_file_str
self.ai_handler = ai_handler self.ai_handler = ai_handler()
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode self.cli_mode = cli_mode