From b7225cc674d28e6c337768630223090ad78cc8ee Mon Sep 17 00:00:00 2001
From: Brian Pham
Date: Tue, 12 Dec 2023 23:52:50 +0800
Subject: [PATCH] update langchain

---
 pr_agent/algo/ai_handlers/langchain_ai_handler.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/pr_agent/algo/ai_handlers/langchain_ai_handler.py b/pr_agent/algo/ai_handlers/langchain_ai_handler.py
index bc26e624..5c793f2b 100644
--- a/pr_agent/algo/ai_handlers/langchain_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/langchain_ai_handler.py
@@ -1,9 +1,15 @@
-from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from langchain.chat_models import ChatOpenAI
 from langchain.schema import SystemMessage, HumanMessage
+
+from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
 
+from openai.error import APIError, RateLimitError, Timeout, TryAgain
+from retry import retry
+
+OPENAI_RETRIES = 5
+
 class LangChainOpenAIHandler(BaseAiHandler):
     def __init__(self):
         # Initialize OpenAIHandler specific attributes here
@@ -24,15 +30,14 @@ class LangChainOpenAIHandler(BaseAiHandler):
         Returns the deployment ID for the OpenAI API.
         """
         return get_settings().get("OPENAI.DEPLOYMENT_ID", None)
-    
+    @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
+           tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
         try:
-            get_logger().info("model: ", model)
             messages=[SystemMessage(content=system), HumanMessage(content=user)]
 
             # get a chat completion from the formatted messages
             resp = self.chat(messages, model=model, temperature=temperature)
-            get_logger().info("AI response: ", resp.content)
             finish_reason="completed"
             return resp.content, finish_reason
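
Note (reviewer sketch, not part of the patch): the @retry decorator added above comes from the retry package and is meant to re-attempt chat_completion on the listed OpenAI exceptions with exponential backoff. Below is a minimal, self-contained Python illustration of the same tries/delay/backoff/jitter parameters, using a hypothetical flaky_call() in place of the real handler method:

    # Standalone sketch, not part of the patch; flaky_call() is hypothetical.
    from retry import retry

    OPENAI_RETRIES = 5
    attempts = {"count": 0}

    @retry(exceptions=(RuntimeError,), tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
    def flaky_call():
        # Fails twice, then succeeds. The decorator sleeps 2 seconds before the first
        # retry, then a longer delay (backoff doubles the base and jitter adds 1-3 s).
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise RuntimeError("transient failure")
        return "ok"

    print(flaky_call())  # succeeds on the third attempt

In the patch itself these parameters mean up to 5 attempts at chat_completion, starting from a 2-second delay that grows between attempts, without changing the method's caller-facing signature.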