diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
index 8e2e1617..a21e3a71 100644
--- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -7,29 +7,14 @@ from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type
 from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS, STREAMING_REQUIRED_MODELS
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
+from pr_agent.algo.ai_handlers.litellm_helpers import _handle_streaming_response, MockResponse, _get_azure_ad_token, \
+    _process_litellm_extra_body
 from pr_agent.algo.utils import ReasoningEffort, get_version
 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
 import json
 
 
-OPENAI_RETRIES = 5
-
-
-class MockResponse:
-    """Mock response object for streaming models to enable consistent logging."""
-
-    def __init__(self, resp, finish_reason):
-        self._data = {
-            "choices": [
-                {
-                    "message": {"content": resp},
-                    "finish_reason": finish_reason
-                }
-            ]
-        }
-
-    def dict(self):
-        return self._data
+MODEL_RETRIES = 2
 
 
 class LiteLLMAIHandler(BaseAiHandler):
@@ -127,7 +112,7 @@ class LiteLLMAIHandler(BaseAiHandler):
         if get_settings().get("AZURE_AD.CLIENT_ID", None):
             self.azure = True
             # Generate access token using Azure AD credentials from settings
-            access_token = self._get_azure_ad_token()
+            access_token = _get_azure_ad_token()
             litellm.api_key = access_token
             openai.api_key = access_token
 
@@ -163,26 +148,6 @@ class LiteLLMAIHandler(BaseAiHandler):
         # Models that require streaming
         self.streaming_required_models = STREAMING_REQUIRED_MODELS
 
-    def _get_azure_ad_token(self):
-        """
-        Generates an access token using Azure AD credentials from settings.
-        Returns:
-            str: The access token
-        """
-        from azure.identity import ClientSecretCredential
-        try:
-            credential = ClientSecretCredential(
-                tenant_id=get_settings().azure_ad.tenant_id,
-                client_id=get_settings().azure_ad.client_id,
-                client_secret=get_settings().azure_ad.client_secret
-            )
-            # Get token for Azure OpenAI service
-            token = credential.get_token("https://cognitiveservices.azure.com/.default")
-            return token.token
-        except Exception as e:
-            get_logger().error(f"Failed to get Azure AD token: {e}")
-            raise
-
     def prepare_logs(self, response, system, user, resp, finish_reason):
         response_log = response.dict().copy()
         response_log['system'] = system
@@ -195,37 +160,6 @@ class LiteLLMAIHandler(BaseAiHandler):
             response_log['main_pr_language'] = 'unknown'
         return response_log
 
-    def _process_litellm_extra_body(self, kwargs: dict) -> dict:
-        """
-        Process LITELLM.EXTRA_BODY configuration and update kwargs accordingly.
-
-        Args:
-            kwargs: The current kwargs dictionary to update
-
-        Returns:
-            Updated kwargs dictionary
-
-        Raises:
-            ValueError: If extra_body contains invalid JSON, unsupported keys, or colliding keys
-        """
-        allowed_extra_body_keys = {"processing_mode", "service_tier"}
-        extra_body = getattr(getattr(get_settings(), "litellm", None), "extra_body", None)
-        if extra_body:
-            try:
-                litellm_extra_body = json.loads(extra_body)
-                if not isinstance(litellm_extra_body, dict):
-                    raise ValueError("LITELLM.EXTRA_BODY must be a JSON object")
-                unsupported_keys = set(litellm_extra_body.keys()) - allowed_extra_body_keys
-                if unsupported_keys:
-                    raise ValueError(f"LITELLM.EXTRA_BODY contains unsupported keys: {', '.join(unsupported_keys)}. Allowed keys: {', '.join(allowed_extra_body_keys)}")
-                colliding_keys = kwargs.keys() & litellm_extra_body.keys()
-                if colliding_keys:
-                    raise ValueError(f"LITELLM.EXTRA_BODY cannot override existing parameters: {', '.join(colliding_keys)}")
-                kwargs.update(litellm_extra_body)
-            except json.JSONDecodeError as e:
-                raise ValueError(f"LITELLM.EXTRA_BODY contains invalid JSON: {str(e)}")
-        return kwargs
-
     def _configure_claude_extended_thinking(self, model: str, kwargs: dict) -> dict:
         """
         Configure Claude extended thinking parameters if applicable.
@@ -326,7 +260,7 @@ class LiteLLMAIHandler(BaseAiHandler):
 
     @retry(
         retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError),
-        stop=stop_after_attempt(OPENAI_RETRIES),
+        stop=stop_after_attempt(MODEL_RETRIES),
     )
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
         try:
@@ -416,7 +350,7 @@ class LiteLLMAIHandler(BaseAiHandler):
                 kwargs["extra_headers"] = litellm_extra_headers
 
             # Support for custom OpenAI body fields (e.g., Flex Processing)
-            kwargs = self._process_litellm_extra_body(kwargs)
+            kwargs = _process_litellm_extra_body(kwargs)
 
             get_logger().debug("Prompts", artifact={"system": system, "user": user})
 
@@ -449,41 +383,6 @@ class LiteLLMAIHandler(BaseAiHandler):
 
             return resp, finish_reason
 
-    async def _handle_streaming_response(self, response):
-        """
-        Handle streaming response from acompletion and collect the full response.
-
-        Args:
-            response: The streaming response object from acompletion
-
-        Returns:
-            tuple: (full_response_content, finish_reason)
-        """
-        full_response = ""
-        finish_reason = None
-
-        try:
-            async for chunk in response:
-                if chunk.choices and len(chunk.choices) > 0:
-                    choice = chunk.choices[0]
-                    delta = choice.delta
-                    content = getattr(delta, 'content', None)
-                    if content:
-                        full_response += content
-                    if choice.finish_reason:
-                        finish_reason = choice.finish_reason
-        except Exception as e:
-            get_logger().error(f"Error handling streaming response: {e}")
-            raise
-
-        if not full_response and finish_reason is None:
-            get_logger().warning("Streaming response resulted in empty content with no finish reason")
-            raise openai.APIError("Empty streaming response received without proper completion")
-        elif not full_response and finish_reason:
-            get_logger().debug(f"Streaming response resulted in empty content but completed with finish_reason: {finish_reason}")
-            raise openai.APIError(f"Streaming response completed with finish_reason '{finish_reason}' but no content received")
-        return full_response, finish_reason
-
     async def _get_completion(self, model, **kwargs):
         """
         Wrapper that automatically handles streaming for required models.
@@ -492,7 +391,7 @@ class LiteLLMAIHandler(BaseAiHandler):
             kwargs["stream"] = True
             get_logger().info(f"Using streaming mode for model {model}")
             response = await acompletion(**kwargs)
-            resp, finish_reason = await self._handle_streaming_response(response)
+            resp, finish_reason = await _handle_streaming_response(response)
             # Create MockResponse for streaming since we don't have the full response object
             mock_response = MockResponse(resp, finish_reason)
             return resp, finish_reason, mock_response
diff --git a/pr_agent/algo/ai_handlers/litellm_helpers.py b/pr_agent/algo/ai_handlers/litellm_helpers.py
new file mode 100644
index 00000000..5f30655d
--- /dev/null
+++ b/pr_agent/algo/ai_handlers/litellm_helpers.py
@@ -0,0 +1,113 @@
+import json
+
+import openai
+from azure.identity import ClientSecretCredential
+
+from pr_agent.config_loader import get_settings
+from pr_agent.log import get_logger
+
+
+async def _handle_streaming_response(response):
+    """
+    Handle streaming response from acompletion and collect the full response.
+
+    Args:
+        response: The streaming response object from acompletion
+
+    Returns:
+        tuple: (full_response_content, finish_reason)
+    """
+    full_response = ""
+    finish_reason = None
+
+    try:
+        async for chunk in response:
+            if chunk.choices and len(chunk.choices) > 0:
+                choice = chunk.choices[0]
+                delta = choice.delta
+                content = getattr(delta, 'content', None)
+                if content:
+                    full_response += content
+                if choice.finish_reason:
+                    finish_reason = choice.finish_reason
+    except Exception as e:
+        get_logger().error(f"Error handling streaming response: {e}")
+        raise
+
+    if not full_response and finish_reason is None:
+        get_logger().warning("Streaming response resulted in empty content with no finish reason")
+        raise openai.APIError("Empty streaming response received without proper completion")
+    elif not full_response and finish_reason:
+        get_logger().debug(f"Streaming response resulted in empty content but completed with finish_reason: {finish_reason}")
+        raise openai.APIError(f"Streaming response completed with finish_reason '{finish_reason}' but no content received")
+    return full_response, finish_reason
+
+
+class MockResponse:
+    """Mock response object for streaming models to enable consistent logging."""
+
+    def __init__(self, resp, finish_reason):
+        self._data = {
+            "choices": [
+                {
+                    "message": {"content": resp},
+                    "finish_reason": finish_reason
+                }
+            ]
+        }
+
+    def dict(self):
+        return self._data
+
+
+def _get_azure_ad_token():
+    """
+    Generates an access token using Azure AD credentials from settings.
+    Returns:
+        str: The access token
+    """
+    from azure.identity import ClientSecretCredential
+    try:
+        credential = ClientSecretCredential(
+            tenant_id=get_settings().azure_ad.tenant_id,
+            client_id=get_settings().azure_ad.client_id,
+            client_secret=get_settings().azure_ad.client_secret
+        )
+        # Get token for Azure OpenAI service
+        token = credential.get_token("https://cognitiveservices.azure.com/.default")
+        return token.token
+    except Exception as e:
+        get_logger().error(f"Failed to get Azure AD token: {e}")
+        raise
+
+
+def _process_litellm_extra_body(kwargs: dict) -> dict:
+    """
+    Process LITELLM.EXTRA_BODY configuration and update kwargs accordingly.
+
+    Args:
+        kwargs: The current kwargs dictionary to update
+
+    Returns:
+        Updated kwargs dictionary
+
+    Raises:
+        ValueError: If extra_body contains invalid JSON, unsupported keys, or colliding keys
+    """
+    allowed_extra_body_keys = {"processing_mode", "service_tier"}
+    extra_body = getattr(getattr(get_settings(), "litellm", None), "extra_body", None)
+    if extra_body:
+        try:
+            litellm_extra_body = json.loads(extra_body)
+            if not isinstance(litellm_extra_body, dict):
+                raise ValueError("LITELLM.EXTRA_BODY must be a JSON object")
+            unsupported_keys = set(litellm_extra_body.keys()) - allowed_extra_body_keys
+            if unsupported_keys:
+                raise ValueError(f"LITELLM.EXTRA_BODY contains unsupported keys: {', '.join(unsupported_keys)}. Allowed keys: {', '.join(allowed_extra_body_keys)}")
+            colliding_keys = kwargs.keys() & litellm_extra_body.keys()
+            if colliding_keys:
+                raise ValueError(f"LITELLM.EXTRA_BODY cannot override existing parameters: {', '.join(colliding_keys)}")
+            kwargs.update(litellm_extra_body)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"LITELLM.EXTRA_BODY contains invalid JSON: {str(e)}")
+    return kwargs
\ No newline at end of file
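
Note on the extracted `_process_litellm_extra_body` helper (illustration only, not part of the diff above): it enforces a whitelist-and-no-collision contract when merging the LITELLM.EXTRA_BODY setting into the completion kwargs. The sketch below re-implements that contract without the `get_settings()` dependency so the behavior can be tried in isolation; the name `merge_extra_body` and the example values are hypothetical, introduced here only for demonstration.

    import json

    # Mirrors the whitelist used by the real helper.
    ALLOWED_EXTRA_BODY_KEYS = {"processing_mode", "service_tier"}

    def merge_extra_body(kwargs: dict, extra_body: str) -> dict:
        """Validate a JSON extra-body string and merge it into kwargs (standalone sketch)."""
        if not extra_body:
            return kwargs
        parsed = json.loads(extra_body)  # raises json.JSONDecodeError on malformed input
        if not isinstance(parsed, dict):
            raise ValueError("extra body must be a JSON object")
        unsupported = set(parsed) - ALLOWED_EXTRA_BODY_KEYS
        if unsupported:
            raise ValueError(f"unsupported keys: {', '.join(unsupported)}")
        colliding = kwargs.keys() & parsed.keys()
        if colliding:
            raise ValueError(f"cannot override existing parameters: {', '.join(colliding)}")
        kwargs.update(parsed)
        return kwargs

    # merge_extra_body({"model": "gpt-4o"}, '{"service_tier": "flex"}')
    #   -> {"model": "gpt-4o", "service_tier": "flex"}
    # merge_extra_body({"model": "gpt-4o"}, '{"model": "other"}')
    #   -> ValueError: cannot override existing parameters: model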