Merge pull request #1605 from KennyDizi/main

Support extended thinking for model `claude-3-7-sonnet-20250219`
This commit is contained in:
Tal
2025-03-09 17:03:37 +02:00
committed by GitHub
4 changed files with 65 additions and 4 deletions

View File

@ -6,7 +6,7 @@ import requests
from litellm import acompletion
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from pr_agent.algo import NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS
from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.utils import ReasoningEffort, get_version
from pr_agent.config_loader import get_settings
@ -109,6 +109,9 @@ class LiteLLMAIHandler(BaseAiHandler):
# Models that support reasoning effort
self.support_reasoning_models = SUPPORT_REASONING_EFFORT_MODELS
# Models that support extended thinking
self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS
def prepare_logs(self, response, system, user, resp, finish_reason):
response_log = response.dict().copy()
response_log['system'] = system
@ -121,6 +124,43 @@ class LiteLLMAIHandler(BaseAiHandler):
response_log['main_pr_language'] = 'unknown'
return response_log
def _configure_claude_extended_thinking(self, model: str, kwargs: dict) -> dict:
"""
Configure Claude extended thinking parameters if applicable.
Args:
model (str): The AI model being used
kwargs (dict): The keyword arguments for the model call
Returns:
dict: Updated kwargs with extended thinking configuration
"""
extended_thinking_budget_tokens = get_settings().config.get("extended_thinking_budget_tokens", 2048)
extended_thinking_max_output_tokens = get_settings().config.get("extended_thinking_max_output_tokens", 2048)
# Validate extended thinking parameters
if not isinstance(extended_thinking_budget_tokens, int) or extended_thinking_budget_tokens <= 0:
raise ValueError(f"extended_thinking_budget_tokens must be a positive integer, got {extended_thinking_budget_tokens}")
if not isinstance(extended_thinking_max_output_tokens, int) or extended_thinking_max_output_tokens <= 0:
raise ValueError(f"extended_thinking_max_output_tokens must be a positive integer, got {extended_thinking_max_output_tokens}")
if extended_thinking_max_output_tokens < extended_thinking_budget_tokens:
raise ValueError(f"extended_thinking_max_output_tokens ({extended_thinking_max_output_tokens}) must be greater than or equal to extended_thinking_budget_tokens ({extended_thinking_budget_tokens})")
kwargs["thinking"] = {
"type": "enabled",
"budget_tokens": extended_thinking_budget_tokens
}
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Adding max output tokens {extended_thinking_max_output_tokens} to model {model}, extended thinking budget tokens: {extended_thinking_budget_tokens}")
kwargs["max_tokens"] = extended_thinking_max_output_tokens
# temperature may only be set to 1 when thinking is enabled
if get_settings().config.verbosity_level >= 2:
get_logger().info("Temperature may only be set to 1 when thinking is enabled with claude models.")
kwargs["temperature"] = 1
return kwargs
def add_litellm_callbacks(selfs, kwargs) -> dict:
captured_extra = []
@ -246,6 +286,10 @@ class LiteLLMAIHandler(BaseAiHandler):
get_logger().info(f"Adding reasoning_effort with value {reasoning_effort} to model {model}.")
kwargs["reasoning_effort"] = reasoning_effort
# https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking
if (model in self.claude_extended_thinking_models) and get_settings().config.get("enable_claude_extended_thinking", False):
kwargs = self._configure_claude_extended_thinking(model, kwargs)
if get_settings().litellm.get("enable_callbacks", False):
kwargs = self.add_litellm_callbacks(kwargs)
@ -268,13 +312,13 @@ class LiteLLMAIHandler(BaseAiHandler):
except json.JSONDecodeError as e:
raise ValueError(f"LITELLM.EXTRA_HEADERS contains invalid JSON: {str(e)}")
kwargs["extra_headers"] = litellm_extra_headers
get_logger().debug("Prompts", artifact={"system": system, "user": user})
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"\nSystem prompt:\n{system}")
get_logger().info(f"\nUser prompt:\n{user}")
response = await acompletion(**kwargs)
except (openai.APIError, openai.APITimeoutError) as e:
get_logger().warning(f"Error during LLM inference: {e}")