From 727b08fde3fe7b1a8fdaefa006195f70132b3501 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Wed, 9 Oct 2024 08:53:34 +0300 Subject: [PATCH] feat: add support for O1 model by combining system and user prompts in litellm_ai_handler --- .../algo/ai_handlers/litellm_ai_handler.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py index 43812386..e1587f1d 100644 --- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py +++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py @@ -171,6 +171,7 @@ class LiteLLMAIHandler(BaseAiHandler): get_logger().warning( "Empty system prompt for claude model. Adding a newline character to prevent OpenAI API error.") messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + if img_path: try: # check if the image link is alive @@ -185,14 +186,28 @@ class LiteLLMAIHandler(BaseAiHandler): messages[1]["content"] = [{"type": "text", "text": messages[1]["content"]}, {"type": "image_url", "image_url": {"url": img_path}}] - kwargs = { - "model": model, - "deployment_id": deployment_id, - "messages": messages, - "temperature": temperature, - "timeout": get_settings().config.ai_timeout, - "api_base": self.api_base, - } + # Currently O1 does not support separate system and user prompts + if model.startswith('o1-'): + user = f"{system}\n\n\n{user}" + system = "" + get_logger().info(f"Using O1 model, combining system and user prompts") + messages = [{"role": "user", "content": user}] + kwargs = { + "model": model, + "deployment_id": deployment_id, + "messages": messages, + "timeout": get_settings().config.ai_timeout, + "api_base": self.api_base, + } + else: + kwargs = { + "model": model, + "deployment_id": deployment_id, + "messages": messages, + "temperature": temperature, + "timeout": get_settings().config.ai_timeout, + "api_base": self.api_base, + } if get_settings().litellm.get("enable_callbacks", False): kwargs = self.add_litellm_callbacks(kwargs)