Merge pull request #514 from brianpham93/abstract-BaseAiHandler

Abstract AiHandler to BaseAiHandler
This commit is contained in:
mrT23
2023-12-14 07:54:13 -08:00
committed by GitHub
15 changed files with 198 additions and 32 deletions

View File

@ -1,8 +1,11 @@
import shlex import shlex
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.utils import update_settings_from_args from pr_agent.algo.utils import update_settings_from_args
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.git_providers.utils import apply_repo_settings from pr_agent.git_providers.utils import apply_repo_settings
from pr_agent.log import get_logger
from pr_agent.tools.pr_add_docs import PRAddDocs from pr_agent.tools.pr_add_docs import PRAddDocs
from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions from pr_agent.tools.pr_code_suggestions import PRCodeSuggestions
from pr_agent.tools.pr_config import PRConfig from pr_agent.tools.pr_config import PRConfig
@ -38,8 +41,8 @@ command2class = {
commands = list(command2class.keys()) commands = list(command2class.keys())
class PRAgent: class PRAgent:
def __init__(self): def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
pass self.ai_handler = ai_handler
async def handle_request(self, pr_url, request, notify=None) -> bool: async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists # First, apply repo specific settings if exists
@ -61,13 +64,14 @@ class PRAgent:
if action == "answer": if action == "answer":
if notify: if notify:
notify() notify()
await PRReviewer(pr_url, is_answer=True, args=args).run() await PRReviewer(pr_url, is_answer=True, args=args, ai_handler=self.ai_handler).run()
elif action == "auto_review": elif action == "auto_review":
await PRReviewer(pr_url, is_auto=True, args=args).run() await PRReviewer(pr_url, is_auto=True, args=args, ai_handler=self.ai_handler).run()
elif action in command2class: elif action in command2class:
if notify: if notify:
notify() notify()
await command2class[action](pr_url, args=args).run()
await command2class[action](pr_url, ai_handler=self.ai_handler, args=args).run()
else: else:
return False return False
return True return True

View File

@ -0,0 +1,28 @@
from abc import ABC, abstractmethod
class BaseAiHandler(ABC):
"""
This class defines the interface for an AI handler to be used by the PR Agents.
"""
@abstractmethod
def __init__(self):
pass
@property
@abstractmethod
def deployment_id(self):
pass
@abstractmethod
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
"""
This method should be implemented to return a chat completion from the AI model.
Args:
model (str): the name of the model to use for the chat completion
system (str): the system message string to use for the chat completion
user (str): the user message string to use for the chat completion
temperature (float): the temperature to use for the chat completion
"""
pass

View File

@ -0,0 +1,49 @@
try:
from langchain.chat_models import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage
except: # we don't enforce langchain as a dependency, so if it's not installed, just move on
pass
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry
OPENAI_RETRIES = 5
class LangChainOpenAIHandler(BaseAiHandler):
def __init__(self):
# Initialize OpenAIHandler specific attributes here
try:
super().__init__()
self._chat = ChatOpenAI(openai_api_key=get_settings().openai.key)
except AttributeError as e:
raise ValueError("OpenAI key is required") from e
@property
def chat(self):
return self._chat
@property
def deployment_id(self):
"""
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try:
messages=[SystemMessage(content=system), HumanMessage(content=user)]
# get a chat completion from the formatted messages
resp = self.chat(messages, model=model, temperature=temperature)
finish_reason="completed"
return resp.content, finish_reason
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise e

View File

@ -6,13 +6,14 @@ import openai
from litellm import acompletion from litellm import acompletion
from openai.error import APIError, RateLimitError, Timeout, TryAgain from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry from retry import retry
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger from pr_agent.log import get_logger
OPENAI_RETRIES = 5 OPENAI_RETRIES = 5
class AiHandler: class LiteLLMAIHandler(BaseAiHandler):
""" """
This class handles interactions with the OpenAI API for chat completions. This class handles interactions with the OpenAI API for chat completions.
It initializes the API key and other settings from a configuration file, It initializes the API key and other settings from a configuration file,
@ -134,4 +135,4 @@ class AiHandler:
usage = response.get("usage") usage = response.get("usage")
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage) model=model, usage=usage)
return resp, finish_reason return resp, finish_reason

View File

@ -0,0 +1,67 @@
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
import openai
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry
from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger
OPENAI_RETRIES = 5
class OpenAIHandler(BaseAiHandler):
def __init__(self):
# Initialize OpenAIHandler specific attributes here
try:
super().__init__()
openai.api_key = get_settings().openai.key
if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org
if get_settings().get("OPENAI.API_TYPE", None):
if get_settings().openai.api_type == "azure":
self.azure = True
openai.azure_key = get_settings().openai.key
if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base
except AttributeError as e:
raise ValueError("OpenAI key is required") from e
@property
def deployment_id(self):
"""
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
try:
deployment_id = self.deployment_id
get_logger().info("System: ", system)
get_logger().info("User: ", user)
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
chat_completion = await openai.ChatCompletion.acreate(
model=model,
deployment_id=deployment_id,
messages=messages,
temperature=temperature,
)
resp = chat_completion["choices"][0]['message']['content']
finish_reason = chat_completion["choices"][0]["finish_reason"]
usage = chat_completion.get("usage")
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage)
return resp, finish_reason
except (APIError, Timeout, TryAgain) as e:
get_logger().error("Error during OpenAI inference: ", e)
raise
except (RateLimitError) as e:
get_logger().error("Rate limit error during OpenAI inference: ", e)
raise
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e

View File

@ -447,4 +447,4 @@ def clip_tokens(text: str, max_tokens: int, add_three_dots=True) -> str:
return clipped_text return clipped_text
except Exception as e: except Exception as e:
get_logger().warning(f"Failed to clip tokens: {e}") get_logger().warning(f"Failed to clip tokens: {e}")
return text return text

View File

@ -4,7 +4,8 @@ from typing import Dict
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml from pr_agent.algo.utils import load_yaml
@ -15,14 +16,15 @@ from pr_agent.log import get_logger
class PRAddDocs: class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None): def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.ai_handler = AiHandler() self.ai_handler = ai_handler
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode self.cli_mode = cli_mode

View File

@ -3,7 +3,8 @@ import textwrap
from typing import Dict, List from typing import Dict, List
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml from pr_agent.algo.utils import load_yaml
@ -14,7 +15,8 @@ from pr_agent.log import get_logger
class PRCodeSuggestions: class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None): def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
@ -31,7 +33,7 @@ class PRCodeSuggestions:
else: else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions
self.ai_handler = AiHandler() self.ai_handler = ai_handler
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode self.cli_mode = cli_mode

View File

@ -4,7 +4,8 @@ from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
@ -15,7 +16,8 @@ from pr_agent.log import get_logger
class PRDescription: class PRDescription:
def __init__(self, pr_url: str, args: list = None): def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
""" """
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model. using an AI model.
@ -36,7 +38,7 @@ class PRDescription:
get_settings().pr_description.enable_semantic_files_types = False get_settings().pr_description.enable_semantic_files_types = False
# Initialize the AI handler # Initialize the AI handler
self.ai_handler = AiHandler() self.ai_handler = ai_handler
# Initialize the variables dictionary # Initialize the variables dictionary
self.vars = { self.vars = {

View File

@ -4,7 +4,8 @@ from typing import List, Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
@ -15,7 +16,8 @@ from pr_agent.log import get_logger
class PRGenerateLabels: class PRGenerateLabels:
def __init__(self, pr_url: str, args: list = None): def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
""" """
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model. corresponding to the PR using an AI model.
@ -31,7 +33,7 @@ class PRGenerateLabels:
self.pr_id = self.git_provider.get_pr_id() self.pr_id = self.git_provider.get_pr_id()
# Initialize the AI handler # Initialize the AI handler
self.ai_handler = AiHandler() self.ai_handler = ai_handler
# Initialize the variables dictionary # Initialize the variables dictionary
self.vars = { self.vars = {

View File

@ -2,7 +2,8 @@ import copy
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -12,12 +13,13 @@ from pr_agent.log import get_logger
class PRInformationFromUser: class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None): def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.ai_handler = AiHandler() self.ai_handler = ai_handler
self.vars = { self.vars = {
"title": self.git_provider.pr.title, "title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(), "branch": self.git_provider.get_pr_branch(),

View File

@ -2,7 +2,8 @@ import copy
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -12,13 +13,13 @@ from pr_agent.log import get_logger
class PRQuestions: class PRQuestions:
def __init__(self, pr_url: str, args=None): def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
question_str = self.parse_args(args) question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language( self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files() self.git_provider.get_languages(), self.git_provider.get_files()
) )
self.ai_handler = AiHandler() self.ai_handler = ai_handler
self.question_str = question_str self.question_str = question_str
self.vars = { self.vars = {
"title": self.git_provider.pr.title, "title": self.git_provider.pr.title,

View File

@ -7,7 +7,8 @@ import yaml
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from yaml import SafeLoader from yaml import SafeLoader
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels
@ -22,13 +23,16 @@ class PRReviewer:
""" """
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model. The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
""" """
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None): def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
""" """
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request. Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.
Args: Args:
pr_url (str): The URL of the pull request to be reviewed. pr_url (str): The URL of the pull request to be reviewed.
is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False. is_answer (bool, optional): Indicates whether the review is being done in answer mode. Defaults to False.
is_auto (bool, optional): Indicates whether the review is being done in automatic mode. Defaults to False.
ai_handler (BaseAiHandler): The AI handler to be used for the review. Defaults to None.
args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None. args (list, optional): List of arguments passed to the PRReviewer class. Defaults to None.
""" """
self.parse_args(args) # -i command self.parse_args(args) # -i command
@ -43,7 +47,7 @@ class PRReviewer:
if self.is_answer and not self.git_provider.is_supported("get_issue_comments"): if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now") raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
self.ai_handler = AiHandler() self.ai_handler = ai_handler
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None

View File

@ -5,7 +5,8 @@ from typing import Tuple
from jinja2 import Environment, StrictUndefined from jinja2 import Environment, StrictUndefined
from pr_agent.algo.ai_handler import AiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings from pr_agent.config_loader import get_settings
@ -17,7 +18,7 @@ CHANGELOG_LINES = 50
class PRUpdateChangelog: class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None): def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.git_provider = get_git_provider()(pr_url) self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language( self.main_language = get_main_pr_language(
@ -25,7 +26,7 @@ class PRUpdateChangelog:
) )
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self._get_changlog_file() # self.changelog_file_str self._get_changlog_file() # self.changelog_file_str
self.ai_handler = AiHandler() self.ai_handler = ai_handler
self.patches_diff = None self.patches_diff = None
self.prediction = None self.prediction = None
self.cli_mode = cli_mode self.cli_mode = cli_mode

View File

@ -23,3 +23,4 @@ starlette-context==0.3.6
tiktoken==0.5.2 tiktoken==0.5.2
ujson==5.8.0 ujson==5.8.0
uvicorn==0.22.0 uvicorn==0.22.0
# langchain==0.0.349 # uncomment this to support language LangChainOpenAIHandler