Compare commits

...

11 Commits

3 changed files with 159 additions and 68 deletions

View File: pr_agent/algo/__init__.py

@@ -45,6 +45,7 @@ MAX_TOKENS = {
     'command-nightly': 4096,
     'deepseek/deepseek-chat': 128000,  # 128K, but may be limited by config.max_model_tokens
     'deepseek/deepseek-reasoner': 64000,  # 64K, but may be limited by config.max_model_tokens
+    'openai/qwq-plus': 131072,  # 131K context length, but may be limited by config.max_model_tokens
     'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
     'meta-llama/Llama-2-7b-chat-hf': 4096,
     'vertex_ai/codechat-bison': 6144,
@@ -193,3 +194,8 @@ CLAUDE_EXTENDED_THINKING_MODELS = [
     "anthropic/claude-3-7-sonnet-20250219",
     "claude-3-7-sonnet-20250219"
 ]
+
+# Models that require streaming mode
+STREAMING_REQUIRED_MODELS = [
+    "openai/qwq-plus"
+]
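
Below, a minimal sketch (not part of the diff) of how the two registries touched above are meant to be consumed; must_stream is a hypothetical helper name, while the import path and values come from the change itself:

    from pr_agent.algo import MAX_TOKENS, STREAMING_REQUIRED_MODELS

    def must_stream(model: str) -> bool:
        # models in this registry only support streamed responses (e.g. qwq-plus)
        return model in STREAMING_REQUIRED_MODELS

    assert must_stream("openai/qwq-plus")
    assert MAX_TOKENS["openai/qwq-plus"] == 131072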

View File: pr_agent/algo/ai_handlers/litellm_ai_handler.py

@@ -5,14 +5,16 @@ import requests
 from litellm import acompletion
 from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt

-from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS
+from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS, STREAMING_REQUIRED_MODELS
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
+from pr_agent.algo.ai_handlers.litellm_helpers import _handle_streaming_response, MockResponse, _get_azure_ad_token, \
+    _process_litellm_extra_body
 from pr_agent.algo.utils import ReasoningEffort, get_version
 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
 import json

-OPENAI_RETRIES = 5
+MODEL_RETRIES = 2

 class LiteLLMAIHandler(BaseAiHandler):
@@ -110,7 +112,7 @@ class LiteLLMAIHandler(BaseAiHandler):
         if get_settings().get("AZURE_AD.CLIENT_ID", None):
             self.azure = True
             # Generate access token using Azure AD credentials from settings
-            access_token = self._get_azure_ad_token()
+            access_token = _get_azure_ad_token()
             litellm.api_key = access_token
             openai.api_key = access_token
@@ -143,25 +145,8 @@ class LiteLLMAIHandler(BaseAiHandler):
         # Models that support extended thinking
         self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS
+        # Models that require streaming
+        self.streaming_required_models = STREAMING_REQUIRED_MODELS

-    def _get_azure_ad_token(self):
-        """
-        Generates an access token using Azure AD credentials from settings.
-
-        Returns:
-            str: The access token
-        """
-        from azure.identity import ClientSecretCredential
-        try:
-            credential = ClientSecretCredential(
-                tenant_id=get_settings().azure_ad.tenant_id,
-                client_id=get_settings().azure_ad.client_id,
-                client_secret=get_settings().azure_ad.client_secret
-            )
-            # Get token for Azure OpenAI service
-            token = credential.get_token("https://cognitiveservices.azure.com/.default")
-            return token.token
-        except Exception as e:
-            get_logger().error(f"Failed to get Azure AD token: {e}")
-            raise

     def prepare_logs(self, response, system, user, resp, finish_reason):
         response_log = response.dict().copy()
@@ -175,37 +160,6 @@ class LiteLLMAIHandler(BaseAiHandler):
             response_log['main_pr_language'] = 'unknown'
         return response_log

-    def _process_litellm_extra_body(self, kwargs: dict) -> dict:
-        """
-        Process LITELLM.EXTRA_BODY configuration and update kwargs accordingly.
-
-        Args:
-            kwargs: The current kwargs dictionary to update
-
-        Returns:
-            Updated kwargs dictionary
-
-        Raises:
-            ValueError: If extra_body contains invalid JSON, unsupported keys, or colliding keys
-        """
-        allowed_extra_body_keys = {"processing_mode", "service_tier"}
-        extra_body = getattr(getattr(get_settings(), "litellm", None), "extra_body", None)
-        if extra_body:
-            try:
-                litellm_extra_body = json.loads(extra_body)
-                if not isinstance(litellm_extra_body, dict):
-                    raise ValueError("LITELLM.EXTRA_BODY must be a JSON object")
-                unsupported_keys = set(litellm_extra_body.keys()) - allowed_extra_body_keys
-                if unsupported_keys:
-                    raise ValueError(f"LITELLM.EXTRA_BODY contains unsupported keys: {', '.join(unsupported_keys)}. Allowed keys: {', '.join(allowed_extra_body_keys)}")
-                colliding_keys = kwargs.keys() & litellm_extra_body.keys()
-                if colliding_keys:
-                    raise ValueError(f"LITELLM.EXTRA_BODY cannot override existing parameters: {', '.join(colliding_keys)}")
-                kwargs.update(litellm_extra_body)
-            except json.JSONDecodeError as e:
-                raise ValueError(f"LITELLM.EXTRA_BODY contains invalid JSON: {str(e)}")
-        return kwargs

     def _configure_claude_extended_thinking(self, model: str, kwargs: dict) -> dict:
         """
         Configure Claude extended thinking parameters if applicable.
@@ -306,7 +260,7 @@ class LiteLLMAIHandler(BaseAiHandler):
     @retry(
         retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError),
-        stop=stop_after_attempt(OPENAI_RETRIES),
+        stop=stop_after_attempt(MODEL_RETRIES),
     )
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
         try:
@@ -396,7 +350,7 @@ class LiteLLMAIHandler(BaseAiHandler):
                 kwargs["extra_headers"] = litellm_extra_headers

             # Support for custom OpenAI body fields (e.g., Flex Processing)
-            kwargs = self._process_litellm_extra_body(kwargs)
+            kwargs = _process_litellm_extra_body(kwargs)

             get_logger().debug("Prompts", artifact={"system": system, "user": user})
@@ -404,7 +358,9 @@ class LiteLLMAIHandler(BaseAiHandler):
                 get_logger().info(f"\nSystem prompt:\n{system}")
                 get_logger().info(f"\nUser prompt:\n{user}")

-            response = await acompletion(**kwargs)
+            # Get completion with automatic streaming detection
+            resp, finish_reason, response_obj = await self._get_completion(**kwargs)
+
         except openai.RateLimitError as e:
             get_logger().error(f"Rate limit error during LLM inference: {e}")
             raise
@@ -414,19 +370,36 @@ class LiteLLMAIHandler(BaseAiHandler):
         except Exception as e:
             get_logger().warning(f"Unknown error during LLM inference: {e}")
             raise openai.APIError from e
-        if response is None or len(response["choices"]) == 0:
-            raise openai.APIError
-        else:
-            resp = response["choices"][0]['message']['content']
-            finish_reason = response["choices"][0]["finish_reason"]
-            get_logger().debug(f"\nAI response:\n{resp}")
-            # log the full response for debugging
-            response_log = self.prepare_logs(response, system, user, resp, finish_reason)
-            get_logger().debug("Full_response", artifact=response_log)
-            # for CLI debugging
-            if get_settings().config.verbosity_level >= 2:
-                get_logger().info(f"\nAI response:\n{resp}")
+
+        get_logger().debug(f"\nAI response:\n{resp}")
+
+        # log the full response for debugging
+        response_log = self.prepare_logs(response_obj, system, user, resp, finish_reason)
+        get_logger().debug("Full_response", artifact=response_log)
+
+        # for CLI debugging
+        if get_settings().config.verbosity_level >= 2:
+            get_logger().info(f"\nAI response:\n{resp}")

         return resp, finish_reason

+    async def _get_completion(self, **kwargs):
+        """
+        Wrapper that automatically handles streaming for required models.
+        """
+        model = kwargs["model"]
+        if model in self.streaming_required_models:
+            kwargs["stream"] = True
+            get_logger().info(f"Using streaming mode for model {model}")
+            response = await acompletion(**kwargs)
+            resp, finish_reason = await _handle_streaming_response(response)
+            # Create MockResponse for streaming since we don't have the full response object
+            mock_response = MockResponse(resp, finish_reason)
+            return resp, finish_reason, mock_response
+        else:
+            response = await acompletion(**kwargs)
+            if response is None or len(response["choices"]) == 0:
+                raise openai.APIError
+            return (response["choices"][0]['message']['content'],
+                    response["choices"][0]["finish_reason"],
+                    response)
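
A self-contained sketch of the dispatch pattern _get_completion introduces: force stream=True for registered models and collect the chunks, otherwise use the plain request/response path. fake_acompletion is a stand-in for litellm.acompletion and the model list is inlined here, so this runs without network access:

    import asyncio

    STREAMING_REQUIRED_MODELS = ["openai/qwq-plus"]  # inlined from pr_agent.algo

    async def fake_acompletion(**kwargs):
        # stand-in for litellm.acompletion: an async generator of text chunks
        # when stream=True, a dict-shaped completion otherwise
        if kwargs.get("stream"):
            async def chunks():
                for piece in ("Hello", ", world"):
                    yield piece
            return chunks()
        return {"choices": [{"message": {"content": "Hello, world"},
                             "finish_reason": "stop"}]}

    async def get_completion(**kwargs):
        if kwargs["model"] in STREAMING_REQUIRED_MODELS:
            kwargs["stream"] = True
            stream = await fake_acompletion(**kwargs)
            text = "".join([piece async for piece in stream])  # collect the stream
            return text, "stop"
        response = await fake_acompletion(**kwargs)
        choice = response["choices"][0]
        return choice["message"]["content"], choice["finish_reason"]

    print(asyncio.run(get_completion(model="openai/qwq-plus")))  # streamed path
    print(asyncio.run(get_completion(model="gpt-4o")))           # non-streaming path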

View File: pr_agent/algo/ai_handlers/litellm_helpers.py (new file)

@@ -0,0 +1,112 @@
import json

import openai

from pr_agent.config_loader import get_settings
from pr_agent.log import get_logger


async def _handle_streaming_response(response):
    """
    Handle streaming response from acompletion and collect the full response.

    Args:
        response: The streaming response object from acompletion

    Returns:
        tuple: (full_response_content, finish_reason)
    """
    full_response = ""
    finish_reason = None
    try:
        async for chunk in response:
            if chunk.choices and len(chunk.choices) > 0:
                choice = chunk.choices[0]
                delta = choice.delta
                content = getattr(delta, 'content', None)
                if content:
                    full_response += content
                if choice.finish_reason:
                    finish_reason = choice.finish_reason
    except Exception as e:
        get_logger().error(f"Error handling streaming response: {e}")
        raise
    if not full_response and finish_reason is None:
        get_logger().warning("Streaming response resulted in empty content with no finish reason")
        raise openai.APIError("Empty streaming response received without proper completion")
    elif not full_response and finish_reason:
        get_logger().debug(f"Streaming response resulted in empty content but completed with finish_reason: {finish_reason}")
        raise openai.APIError(f"Streaming response completed with finish_reason '{finish_reason}' but no content received")
    return full_response, finish_reason
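
# Illustration only, not part of this file: feeding _handle_streaming_response a
# hand-rolled async iterator of OpenAI-style chunks (SimpleNamespace stands in
# for the chunk/choice/delta objects litellm yields):
#
#     import asyncio
#     from types import SimpleNamespace
#
#     async def fake_stream():
#         for text, reason in [("Hi", None), (" there", "stop")]:
#             choice = SimpleNamespace(delta=SimpleNamespace(content=text),
#                                      finish_reason=reason)
#             yield SimpleNamespace(choices=[choice])
#
#     print(asyncio.run(_handle_streaming_response(fake_stream())))
#     # -> ('Hi there', 'stop')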
class MockResponse:
    """Mock response object for streaming models to enable consistent logging."""

    def __init__(self, resp, finish_reason):
        self._data = {
            "choices": [
                {
                    "message": {"content": resp},
                    "finish_reason": finish_reason
                }
            ]
        }

    def dict(self):
        return self._data
def _get_azure_ad_token():
    """
    Generates an access token using Azure AD credentials from settings.

    Returns:
        str: The access token
    """
    from azure.identity import ClientSecretCredential
    try:
        credential = ClientSecretCredential(
            tenant_id=get_settings().azure_ad.tenant_id,
            client_id=get_settings().azure_ad.client_id,
            client_secret=get_settings().azure_ad.client_secret
        )
        # Get token for Azure OpenAI service
        token = credential.get_token("https://cognitiveservices.azure.com/.default")
        return token.token
    except Exception as e:
        get_logger().error(f"Failed to get Azure AD token: {e}")
        raise
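
# Illustration only, not part of this file: the settings _get_azure_ad_token
# reads (key names taken from the code above; the values are placeholders):
#
#     [azure_ad]
#     tenant_id = "00000000-0000-0000-0000-000000000000"
#     client_id = "11111111-1111-1111-1111-111111111111"
#     client_secret = "<secret>"
#
# The returned bearer token is assigned to litellm.api_key / openai.api_key by
# LiteLLMAIHandler when AZURE_AD.CLIENT_ID is configured (see the handler diff).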
def _process_litellm_extra_body(kwargs: dict) -> dict:
    """
    Process LITELLM.EXTRA_BODY configuration and update kwargs accordingly.

    Args:
        kwargs: The current kwargs dictionary to update

    Returns:
        Updated kwargs dictionary

    Raises:
        ValueError: If extra_body contains invalid JSON, unsupported keys, or colliding keys
    """
    allowed_extra_body_keys = {"processing_mode", "service_tier"}
    extra_body = getattr(getattr(get_settings(), "litellm", None), "extra_body", None)
    if extra_body:
        try:
            litellm_extra_body = json.loads(extra_body)
            if not isinstance(litellm_extra_body, dict):
                raise ValueError("LITELLM.EXTRA_BODY must be a JSON object")
            unsupported_keys = set(litellm_extra_body.keys()) - allowed_extra_body_keys
            if unsupported_keys:
                raise ValueError(f"LITELLM.EXTRA_BODY contains unsupported keys: {', '.join(unsupported_keys)}. Allowed keys: {', '.join(allowed_extra_body_keys)}")
            colliding_keys = kwargs.keys() & litellm_extra_body.keys()
            if colliding_keys:
                raise ValueError(f"LITELLM.EXTRA_BODY cannot override existing parameters: {', '.join(colliding_keys)}")
            kwargs.update(litellm_extra_body)
        except json.JSONDecodeError as e:
            raise ValueError(f"LITELLM.EXTRA_BODY contains invalid JSON: {str(e)}")
    return kwargs
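
Finally, a sketch of the validation contract _process_litellm_extra_body enforces, with the three checks inlined so it runs standalone; validate_extra_body is a hypothetical name, and the "service_tier": "flex" value is illustrative (Flex Processing is the use case named in the handler comment):

    import json

    ALLOWED_EXTRA_BODY_KEYS = {"processing_mode", "service_tier"}

    def validate_extra_body(extra_body: str, kwargs: dict) -> dict:
        # same checks as _process_litellm_extra_body, without the settings lookup
        parsed = json.loads(extra_body)  # invalid JSON raises here
        if not isinstance(parsed, dict):
            raise ValueError("must be a JSON object")
        if set(parsed) - ALLOWED_EXTRA_BODY_KEYS:
            raise ValueError("unsupported keys")
        if kwargs.keys() & parsed.keys():
            raise ValueError("cannot override existing parameters")
        return {**kwargs, **parsed}

    print(validate_extra_body('{"service_tier": "flex"}', {"model": "gpt-4o"}))
    # {'model': 'gpt-4o', 'service_tier': 'flex'}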