diff --git a/pr_agent/algo/ai_handlers/litellm_ai_handler.py b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
index 596ff7f9..380a7893 100644
--- a/pr_agent/algo/ai_handlers/litellm_ai_handler.py
+++ b/pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -390,20 +390,8 @@ class LiteLLMAIHandler(BaseAiHandler):
                 get_logger().info(f"\nSystem prompt:\n{system}")
                 get_logger().info(f"\nUser prompt:\n{user}")
 
-            # Check if model requires streaming
-            if model in self.streaming_required_models:
-                kwargs["stream"] = True
-                get_logger().info(f"Using streaming mode for model {model}")
-                response = await acompletion(**kwargs)
-                # Handle streaming response
-                resp, finish_reason = await self._handle_streaming_response(response)
-            else:
-                response = await acompletion(**kwargs)
-                # Handle non-streaming response
-                if response is None or len(response["choices"]) == 0:
-                    raise openai.APIError
-                resp = response["choices"][0]['message']['content']
-                finish_reason = response["choices"][0]["finish_reason"]
+            # Get completion with automatic streaming detection
+            resp, finish_reason, response_obj = await self._get_completion(model, **kwargs)
 
         except openai.RateLimitError as e:
             get_logger().error(f"Rate limit error during LLM inference: {e}")
@@ -418,12 +406,7 @@ class LiteLLMAIHandler(BaseAiHandler):
         get_logger().debug(f"\nAI response:\n{resp}")
 
         # log the full response for debugging
-        if model in self.streaming_required_models:
-            # for streaming, we don't have the full response object, so we create a mock one
-            mock_response = MockResponse(resp, finish_reason)
-            response_log = self.prepare_logs(mock_response, system, user, resp, finish_reason)
-        else:
-            response_log = self.prepare_logs(response, system, user, resp, finish_reason)
+        response_log = self.prepare_logs(response_obj, system, user, resp, finish_reason)
         get_logger().debug("Full_response", artifact=response_log)
 
         # for CLI debugging
@@ -466,3 +449,23 @@ class LiteLLMAIHandler(BaseAiHandler):
             get_logger().debug(f"Streaming response resulted in empty content but completed with finish_reason: {finish_reason}")
 
         return full_response, finish_reason
+
+    async def _get_completion(self, model, **kwargs):
+        """
+        Wrapper that automatically handles streaming for required models.
+        """
+        if model in self.streaming_required_models:
+            kwargs["stream"] = True
+            get_logger().info(f"Using streaming mode for model {model}")
+            response = await acompletion(**kwargs)
+            resp, finish_reason = await self._handle_streaming_response(response)
+            # Create MockResponse for streaming since we don't have the full response object
+            mock_response = MockResponse(resp, finish_reason)
+            return resp, finish_reason, mock_response
+        else:
+            response = await acompletion(**kwargs)
+            if response is None or len(response["choices"]) == 0:
+                raise openai.APIError
+            return (response["choices"][0]['message']['content'],
+                    response["choices"][0]["finish_reason"],
+                    response)
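
For reference, the MockResponse helper that the streaming branch hands to prepare_logs() is defined elsewhere in litellm_ai_handler.py and is not part of this diff. A minimal sketch of such a stand-in, assuming prepare_logs() only needs a single choice exposing the message content and finish reason (the attribute names below are assumptions, not the repository's actual class), might look like:

from types import SimpleNamespace

class MockResponse:
    """Hypothetical stand-in mimicking the parts of a chat-completion
    response that the logging path reads; the real class in the repo
    may differ in shape and fields."""

    def __init__(self, resp, finish_reason):
        # Expose one choice carrying the assistant message content and the
        # finish reason, mirroring how the non-streaming path indexes
        # response["choices"][0].
        self.choices = [SimpleNamespace(
            finish_reason=finish_reason,
            message=SimpleNamespace(role="assistant", content=resp),
        )]

With a stand-in like this, the streaming and non-streaming branches of _get_completion() can return a response-shaped object in both cases, which is what lets the call site collapse to a single prepare_logs(response_obj, ...) call.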