refactor(ai_handler): move streaming response handling and Azure token generation to helpers
Extracts _handle_streaming_response, MockResponse, _get_azure_ad_token, and _process_litellm_extra_body out of LiteLLMAIHandler into a new module, pr_agent/algo/ai_handlers/litellm_helpers.py, and renames the retry constant OPENAI_RETRIES (5 attempts) to MODEL_RETRIES (2).
pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -7,29 +7,14 @@ from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type
 
 from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS, STREAMING_REQUIRED_MODELS
 from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
+from pr_agent.algo.ai_handlers.litellm_helpers import _handle_streaming_response, MockResponse, _get_azure_ad_token, \
+    _process_litellm_extra_body
 from pr_agent.algo.utils import ReasoningEffort, get_version
 from pr_agent.config_loader import get_settings
 from pr_agent.log import get_logger
 import json
 
-OPENAI_RETRIES = 5
+MODEL_RETRIES = 2
 
 
-class MockResponse:
-    """Mock response object for streaming models to enable consistent logging."""
-
-    def __init__(self, resp, finish_reason):
-        self._data = {
-            "choices": [
-                {
-                    "message": {"content": resp},
-                    "finish_reason": finish_reason
-                }
-            ]
-        }
-
-    def dict(self):
-        return self._data
-
-
 class LiteLLMAIHandler(BaseAiHandler):
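
Note: the two-line helper import uses a backslash continuation; a parenthesized import, sketched below as a purely stylistic alternative, expresses the same thing without it:

    from pr_agent.algo.ai_handlers.litellm_helpers import (
        MockResponse,
        _get_azure_ad_token,
        _handle_streaming_response,
        _process_litellm_extra_body,
    )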
@@ -127,7 +112,7 @@ class LiteLLMAIHandler(BaseAiHandler):
         if get_settings().get("AZURE_AD.CLIENT_ID", None):
             self.azure = True
             # Generate access token using Azure AD credentials from settings
-            access_token = self._get_azure_ad_token()
+            access_token = _get_azure_ad_token()
             litellm.api_key = access_token
             openai.api_key = access_token
 
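
Note: with the self. prefix gone, _get_azure_ad_token now resolves through the module-level import added above. A minimal standalone sketch of the same wiring (all names come from this diff; nothing new):

    import litellm
    import openai

    from pr_agent.algo.ai_handlers.litellm_helpers import _get_azure_ad_token

    access_token = _get_azure_ad_token()  # reads azure_ad.* settings; raises if the credential fails
    litellm.api_key = access_token
    openai.api_key = access_token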
@@ -163,26 +148,6 @@ class LiteLLMAIHandler(BaseAiHandler):
         # Models that require streaming
         self.streaming_required_models = STREAMING_REQUIRED_MODELS
 
-    def _get_azure_ad_token(self):
-        """
-        Generates an access token using Azure AD credentials from settings.
-        Returns:
-            str: The access token
-        """
-        from azure.identity import ClientSecretCredential
-        try:
-            credential = ClientSecretCredential(
-                tenant_id=get_settings().azure_ad.tenant_id,
-                client_id=get_settings().azure_ad.client_id,
-                client_secret=get_settings().azure_ad.client_secret
-            )
-            # Get token for Azure OpenAI service
-            token = credential.get_token("https://cognitiveservices.azure.com/.default")
-            return token.token
-        except Exception as e:
-            get_logger().error(f"Failed to get Azure AD token: {e}")
-            raise
-
     def prepare_logs(self, response, system, user, resp, finish_reason):
         response_log = response.dict().copy()
         response_log['system'] = system
@@ -195,37 +160,6 @@ class LiteLLMAIHandler(BaseAiHandler):
             response_log['main_pr_language'] = 'unknown'
         return response_log
 
-    def _process_litellm_extra_body(self, kwargs: dict) -> dict:
-        """
-        Process LITELLM.EXTRA_BODY configuration and update kwargs accordingly.
-
-        Args:
-            kwargs: The current kwargs dictionary to update
-
-        Returns:
-            Updated kwargs dictionary
-
-        Raises:
-            ValueError: If extra_body contains invalid JSON, unsupported keys, or colliding keys
-        """
-        allowed_extra_body_keys = {"processing_mode", "service_tier"}
-        extra_body = getattr(getattr(get_settings(), "litellm", None), "extra_body", None)
-        if extra_body:
-            try:
-                litellm_extra_body = json.loads(extra_body)
-                if not isinstance(litellm_extra_body, dict):
-                    raise ValueError("LITELLM.EXTRA_BODY must be a JSON object")
-                unsupported_keys = set(litellm_extra_body.keys()) - allowed_extra_body_keys
-                if unsupported_keys:
-                    raise ValueError(f"LITELLM.EXTRA_BODY contains unsupported keys: {', '.join(unsupported_keys)}. Allowed keys: {', '.join(allowed_extra_body_keys)}")
-                colliding_keys = kwargs.keys() & litellm_extra_body.keys()
-                if colliding_keys:
-                    raise ValueError(f"LITELLM.EXTRA_BODY cannot override existing parameters: {', '.join(colliding_keys)}")
-                kwargs.update(litellm_extra_body)
-            except json.JSONDecodeError as e:
-                raise ValueError(f"LITELLM.EXTRA_BODY contains invalid JSON: {str(e)}")
-        return kwargs
-
     def _configure_claude_extended_thinking(self, model: str, kwargs: dict) -> dict:
         """
         Configure Claude extended thinking parameters if applicable.
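
Note: _process_litellm_extra_body moves unchanged into litellm_helpers.py (new file below). Its validation order matters: JSON parse, type check, unsupported keys, then key collisions. A minimal sketch of the error paths, using hypothetical LITELLM.EXTRA_BODY values:

    import json

    allowed = {"processing_mode", "service_tier"}
    for raw in ('{"service_tier": "flex"', '["service_tier"]', '{"temperature": 0}'):
        try:
            body = json.loads(raw)  # JSONDecodeError is a ValueError subclass
            if not isinstance(body, dict):
                raise ValueError("LITELLM.EXTRA_BODY must be a JSON object")
            unsupported = set(body) - allowed
            if unsupported:
                raise ValueError(f"LITELLM.EXTRA_BODY contains unsupported keys: {', '.join(unsupported)}")
        except ValueError as e:
            print(type(e).__name__, "->", e)  # each value above trips a different check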
@@ -326,7 +260,7 @@ class LiteLLMAIHandler(BaseAiHandler):
 
     @retry(
        retry=retry_if_exception_type(openai.APIError) & retry_if_not_exception_type(openai.RateLimitError),
-        stop=stop_after_attempt(OPENAI_RETRIES),
+        stop=stop_after_attempt(MODEL_RETRIES),
     )
     async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
         try:
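
Note: tenacity's stop_after_attempt counts the initial call, so moving from OPENAI_RETRIES = 5 to MODEL_RETRIES = 2 means the original attempt plus exactly one retry; and since RateLimitError subclasses APIError, the retry_if_not_exception_type condition carves it out so rate limits fail fast. A minimal sketch of the same decorator semantics (toy function, not the handler's code):

    from tenacity import retry, retry_if_exception_type, stop_after_attempt

    @retry(retry=retry_if_exception_type(RuntimeError), stop=stop_after_attempt(2))
    def flaky():
        print("attempt")
        raise RuntimeError("transient failure")

    flaky()  # prints "attempt" twice, then tenacity raises RetryError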
@@ -416,7 +350,7 @@ class LiteLLMAIHandler(BaseAiHandler):
                 kwargs["extra_headers"] = litellm_extra_headers
 
             # Support for custom OpenAI body fields (e.g., Flex Processing)
-            kwargs = self._process_litellm_extra_body(kwargs)
+            kwargs = _process_litellm_extra_body(kwargs)
 
             get_logger().debug("Prompts", artifact={"system": system, "user": user})
 
@@ -449,41 +383,6 @@ class LiteLLMAIHandler(BaseAiHandler):
 
         return resp, finish_reason
 
-    async def _handle_streaming_response(self, response):
-        """
-        Handle streaming response from acompletion and collect the full response.
-
-        Args:
-            response: The streaming response object from acompletion
-
-        Returns:
-            tuple: (full_response_content, finish_reason)
-        """
-        full_response = ""
-        finish_reason = None
-
-        try:
-            async for chunk in response:
-                if chunk.choices and len(chunk.choices) > 0:
-                    choice = chunk.choices[0]
-                    delta = choice.delta
-                    content = getattr(delta, 'content', None)
-                    if content:
-                        full_response += content
-                    if choice.finish_reason:
-                        finish_reason = choice.finish_reason
-        except Exception as e:
-            get_logger().error(f"Error handling streaming response: {e}")
-            raise
-
-        if not full_response and finish_reason is None:
-            get_logger().warning("Streaming response resulted in empty content with no finish reason")
-            raise openai.APIError("Empty streaming response received without proper completion")
-        elif not full_response and finish_reason:
-            get_logger().debug(f"Streaming response resulted in empty content but completed with finish_reason: {finish_reason}")
-            raise openai.APIError(f"Streaming response completed with finish_reason '{finish_reason}' but no content received")
-        return full_response, finish_reason
-
     async def _get_completion(self, model, **kwargs):
         """
         Wrapper that automatically handles streaming for required models.
@@ -492,7 +391,7 @@ class LiteLLMAIHandler(BaseAiHandler):
             kwargs["stream"] = True
             get_logger().info(f"Using streaming mode for model {model}")
             response = await acompletion(**kwargs)
-            resp, finish_reason = await self._handle_streaming_response(response)
+            resp, finish_reason = await _handle_streaming_response(response)
             # Create MockResponse for streaming since we don't have the full response object
             mock_response = MockResponse(resp, finish_reason)
             return resp, finish_reason, mock_response
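
Note: for streaming models, _get_completion returns (resp, finish_reason, mock_response) so the streaming branch shown here looks the same to callers as the non-streaming one; prepare_logs only needs an object with a dict() method. A sketch of the downstream use, grounded in this diff (the kwargs are illustrative):

    # inside an async method of the handler; kwargs are illustrative
    resp, finish_reason, response_obj = await self._get_completion(model, **kwargs)
    response_log = response_obj.dict().copy()  # same call for real responses and MockResponse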
pr_agent/algo/ai_handlers/litellm_helpers.py (new file, 113 lines)
@@ -0,0 +1,113 @@
+import json
+
+import openai
+from azure.identity import ClientSecretCredential
+
+from pr_agent.config_loader import get_settings
+from pr_agent.log import get_logger
+
+
+async def _handle_streaming_response(response):
+    """
+    Handle streaming response from acompletion and collect the full response.
+
+    Args:
+        response: The streaming response object from acompletion
+
+    Returns:
+        tuple: (full_response_content, finish_reason)
+    """
+    full_response = ""
+    finish_reason = None
+
+    try:
+        async for chunk in response:
+            if chunk.choices and len(chunk.choices) > 0:
+                choice = chunk.choices[0]
+                delta = choice.delta
+                content = getattr(delta, 'content', None)
+                if content:
+                    full_response += content
+                if choice.finish_reason:
+                    finish_reason = choice.finish_reason
+    except Exception as e:
+        get_logger().error(f"Error handling streaming response: {e}")
+        raise
+
+    if not full_response and finish_reason is None:
+        get_logger().warning("Streaming response resulted in empty content with no finish reason")
+        raise openai.APIError("Empty streaming response received without proper completion")
+    elif not full_response and finish_reason:
+        get_logger().debug(f"Streaming response resulted in empty content but completed with finish_reason: {finish_reason}")
+        raise openai.APIError(f"Streaming response completed with finish_reason '{finish_reason}' but no content received")
+    return full_response, finish_reason
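
Note: a minimal sketch of exercising this helper with a fake stream; the chunk shapes mimic litellm streaming deltas, and fake_stream is hypothetical test scaffolding, not part of this commit:

    import asyncio
    from types import SimpleNamespace as NS

    async def fake_stream():
        yield NS(choices=[NS(delta=NS(content="Hel"), finish_reason=None)])
        yield NS(choices=[NS(delta=NS(content="lo"), finish_reason="stop")])

    resp, reason = asyncio.run(_handle_streaming_response(fake_stream()))
    assert (resp, reason) == ("Hello", "stop")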
+
+
+class MockResponse:
+    """Mock response object for streaming models to enable consistent logging."""
+
+    def __init__(self, resp, finish_reason):
+        self._data = {
+            "choices": [
+                {
+                    "message": {"content": resp},
+                    "finish_reason": finish_reason
+                }
+            ]
+        }
+
+    def dict(self):
+        return self._data
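
Note: MockResponse implements only dict() — the single method prepare_logs relies on (response.dict().copy()). Quick sanity check:

    mock = MockResponse("full text", "stop")
    assert mock.dict()["choices"][0]["message"]["content"] == "full text"
    assert mock.dict()["choices"][0]["finish_reason"] == "stop"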
+
+
+def _get_azure_ad_token():
+    """
+    Generates an access token using Azure AD credentials from settings.
+    Returns:
+        str: The access token
+    """
+    from azure.identity import ClientSecretCredential
+    try:
+        credential = ClientSecretCredential(
+            tenant_id=get_settings().azure_ad.tenant_id,
+            client_id=get_settings().azure_ad.client_id,
+            client_secret=get_settings().azure_ad.client_secret
+        )
+        # Get token for Azure OpenAI service
+        token = credential.get_token("https://cognitiveservices.azure.com/.default")
+        return token.token
+    except Exception as e:
+        get_logger().error(f"Failed to get Azure AD token: {e}")
+        raise
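
Note: the function-local from azure.identity import duplicates the module-level import above and could be dropped in a follow-up. Also, the token is fetched once at handler construction with no refresh; azure-identity's AccessToken carries its expiry. A hedged sketch with placeholder credentials:

    from azure.identity import ClientSecretCredential

    credential = ClientSecretCredential(tenant_id="...", client_id="...", client_secret="...")
    token = credential.get_token("https://cognitiveservices.azure.com/.default")
    token.token       # bearer string handed to litellm.api_key / openai.api_key
    token.expires_on  # Unix timestamp; nothing in this handler refreshes the token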
+
+
+def _process_litellm_extra_body(kwargs: dict) -> dict:
+    """
+    Process LITELLM.EXTRA_BODY configuration and update kwargs accordingly.
+
+    Args:
+        kwargs: The current kwargs dictionary to update
+
+    Returns:
+        Updated kwargs dictionary
+
+    Raises:
+        ValueError: If extra_body contains invalid JSON, unsupported keys, or colliding keys
+    """
+    allowed_extra_body_keys = {"processing_mode", "service_tier"}
+    extra_body = getattr(getattr(get_settings(), "litellm", None), "extra_body", None)
+    if extra_body:
+        try:
+            litellm_extra_body = json.loads(extra_body)
+            if not isinstance(litellm_extra_body, dict):
+                raise ValueError("LITELLM.EXTRA_BODY must be a JSON object")
+            unsupported_keys = set(litellm_extra_body.keys()) - allowed_extra_body_keys
+            if unsupported_keys:
+                raise ValueError(f"LITELLM.EXTRA_BODY contains unsupported keys: {', '.join(unsupported_keys)}. Allowed keys: {', '.join(allowed_extra_body_keys)}")
+            colliding_keys = kwargs.keys() & litellm_extra_body.keys()
+            if colliding_keys:
+                raise ValueError(f"LITELLM.EXTRA_BODY cannot override existing parameters: {', '.join(colliding_keys)}")
+            kwargs.update(litellm_extra_body)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"LITELLM.EXTRA_BODY contains invalid JSON: {str(e)}")
+    return kwargs
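
Note: the happy path merges only whitelisted keys into the request kwargs. A minimal sketch, assuming a hypothetical LITELLM.EXTRA_BODY value of '{"service_tier": "flex"}':

    kwargs = {"model": "some-model", "temperature": 0.2}
    kwargs = _process_litellm_extra_body(kwargs)
    assert kwargs["service_tier"] == "flex"  # merged; existing keys are untouched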