From 6c7beccb4f11dcd495522c3d6a9cd65ef5c87c4b Mon Sep 17 00:00:00 2001 From: Brian Pham Date: Tue, 12 Dec 2023 23:03:49 +0800 Subject: [PATCH] add LangChain AI Handler --- .../algo/ai_handlers/langchain_ai_handler.py | 45 +++++++++++++++++++ pr_agent/algo/utils.py | 3 +- 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 pr_agent/algo/ai_handlers/langchain_ai_handler.py diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py new file mode 100644 index 00000000..406c6c40 --- /dev/null +++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py @@ -0,0 +1,45 @@ +from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain.schema import SystemMessage, HumanMessage + + +from pr_agent.config_loader import get_settings +from pr_agent.log import get_logger + +OPENAI_RETRIES = 5 +chat = ChatOpenAI(openai_api_key = get_settings().openai.key, model="gpt-4") + +class LangChainAIHandler(BaseAiHandler): + def __init__(self): + # Initialize OpenAIHandler specific attributes here + try: + super().__init__() + + except AttributeError as e: + raise ValueError("OpenAI key is required") from e + @property + def deployment_id(self): + """ + Returns the deployment ID for the OpenAI API. + """ + return get_settings().get("OPENAI.DEPLOYMENT_ID", None) + + async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + try: + + messages=[SystemMessage(content=system), HumanMessage(content=user)] + + # get a chat completion from the formatted messages + resp = chat(messages) + get_logger().info("AI response: ", resp.content) + finish_reason="completed" + return resp.content, finish_reason + + except (Exception) as e: + get_logger().error("Unknown error during OpenAI inference: ", e) + raise e \ No newline at end of file diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 824e4b70..d0b86b63 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -11,6 +11,7 @@ import yaml from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAiHandler from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler from pr_agent.algo.ai_handlers.openai_ai_handler import OpenAIHandler +from pr_agent.algo.ai_handlers.langchain_ai_handler import LangChainAIHandler from starlette_context import context from pr_agent.config_loader import get_settings, global_settings from pr_agent.log import get_logger @@ -309,4 +310,4 @@ def try_fix_yaml(review_text: str) -> dict: return data def get_ai_handler() -> BaseAiHandler: - return OpenAIHandler() \ No newline at end of file + return LangChainAIHandler() \ No newline at end of file