Rename AiHandler to LiteLLMAiHandler

This commit is contained in:
Brian Pham
2023-12-11 16:56:49 +08:00
parent 7e47baa9db
commit 523a896465
8 changed files with 15 additions and 15 deletions

View File

@ -11,7 +11,7 @@ from pr_agent.algo.base_ai_handler import BaseAiHandler
OPENAI_RETRIES = 5 OPENAI_RETRIES = 5
class AiHandler(BaseAiHandler): class LiteLLMAiHandler(BaseAiHandler):
""" """
This class handles interactions with the OpenAI API for chat completions. This class handles interactions with the OpenAI API for chat completions.
It initializes the API key and other settings from a configuration file, It initializes the API key and other settings from a configuration file,

View File

@ -4,7 +4,7 @@ from typing import Dict
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml from pr_agent.algo.utils import load_yaml
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
class PRAddDocs: class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(

View File

@ -4,7 +4,7 @@ from typing import Dict, List
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml from pr_agent.algo.utils import load_yaml
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
class PRCodeSuggestions: class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler() ): def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler() ):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(

View File

@ -4,7 +4,7 @@ from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml from pr_agent.algo.utils import load_yaml
@ -15,7 +15,7 @@ from pr_agent.log import get_logger
class PRDescription: class PRDescription:
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()):
""" """
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model. using an AI model.

View File

@ -2,7 +2,7 @@ import copy
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -12,7 +12,7 @@ from pr_agent.log import get_logger
class PRInformationFromUser: class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()

View File

@ -2,7 +2,7 @@ import copy
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -12,7 +12,7 @@ from pr_agent.log import get_logger
class PRQuestions: class PRQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAiHandler()):
question_str = self.parse_args(args) question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(

View File

@ -6,7 +6,7 @@ import yaml
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from yaml import SafeLoader from yaml import SafeLoader
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml
@ -21,7 +21,7 @@ class PRReviewer:
""" """
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
""" """
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = LiteLLMAiHandler()):
""" """
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.

View File

@ -5,7 +5,7 @@ from typing import Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler from pr_agent.algo.ai_handler import BaseAiHandler, LiteLLMAiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -17,7 +17,7 @@ CHANGELOG_LINES = 50
class PRUpdateChangelog: class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = AiHandler()): def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAiHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(