243 changes: 144 additions & 99 deletions litellm/responses/mcp/mcp_streaming_iterator.py
@@ -269,7 +269,9 @@ def __init__(
self.should_auto_execute = self._should_auto_execute_tools()

# Streaming state management
self.phase = "mcp_discovery" # mcp_discovery -> initial_response -> tool_execution -> follow_up_response -> finished
# get_response_created: emit response.created first (OpenAI SDK expects it before other events)
# mcp_discovery -> initial_response -> tool_execution -> follow_up_response -> finished
self.phase = "get_response_created"
self.finished = False

# Event queues and generation flags
@@ -282,6 +284,7 @@ def __init__(
mcp_events # Store the initial MCP events for backward compatibility
)
self.tool_server_map = tool_server_map
self._sync_response_created_emitted = False

# Iterator references
self.base_iterator: Optional[
@@ -388,101 +391,123 @@ def __aiter__(self):
async def __anext__(self) -> ResponsesAPIStreamingResponse:
"""
Phase-based streaming:
1. mcp_discovery - Emit MCP discovery events
2. initial_response - Stream the first LLM response
3. tool_execution - Emit tool execution events
4. follow_up_response - Stream the follow-up response
5. finished - End iteration
1. get_response_created - Emit response.created first (OpenAI SDK expects it before other events)
2. mcp_discovery - Emit MCP discovery events
3. initial_response - Stream the first LLM response
4. tool_execution - Emit tool execution events
5. follow_up_response - Stream the follow-up response
6. finished - End iteration
"""

# Phase 1: MCP Discovery Events
if self.phase == "mcp_discovery":
# Generate MCP discovery events if not already done
# MCP discovery events are already generated and available

# Emit MCP discovery events
if self.mcp_discovery_events:
return self.mcp_discovery_events.pop(0)

# All MCP discovery events emitted, move to next phase
verbose_logger.debug(
"MCP discovery phase complete, transitioning to initial_response"
)
self.phase = "initial_response"
await self._create_initial_response_iterator()
# Fall through to process the initial response immediately

# Phase 2: Initial Response Stream
if self.phase == "initial_response":
if self.base_iterator:
# Check if base_iterator is actually iterable
if hasattr(self.base_iterator, "__anext__"):
try:
chunk = await cast(Any, self.base_iterator).__anext__() # type: ignore[attr-defined]

# If auto-execution is enabled, check for completed responses
if self.should_auto_execute and self._is_response_completed(
chunk
):
# Collect the response for tool execution
response_obj = getattr(chunk, "response", None)
if isinstance(response_obj, ResponsesAPIResponse):
self.collected_response = response_obj
# Move to tool execution phase after emitting this chunk
self.phase = "tool_execution"
await self._generate_tool_execution_events()

return chunk
except StopAsyncIteration:
# Initial response ended, move to next phase
if self.should_auto_execute and self.collected_response:
self.phase = "tool_execution"
await self._generate_tool_execution_events()
else:
self.phase = "finished"
raise
else:
# base_iterator is not async iterable (likely a ResponsesAPIResponse)
# Collect it for tool execution if needed
if self.should_auto_execute and isinstance(
self.base_iterator, ResponsesAPIResponse
):
self.collected_response = self.base_iterator
self.phase = "tool_execution"
await self._generate_tool_execution_events()
else:
self.phase = "finished"
raise StopAsyncIteration

# Phase 3: Tool Execution Events
if self.phase == "tool_execution":
# Emit any queued tool execution events
if self.tool_execution_events:
return self.tool_execution_events.pop(0)

# Move to follow-up response phase
self.phase = "follow_up_response"
await self._create_follow_up_iterator()

# Phase 4: Follow-up Response Stream
if self.phase == "follow_up_response":
if self.follow_up_iterator:
try:
return await cast(Any, self.follow_up_iterator).__anext__() # type: ignore[attr-defined]
except StopAsyncIteration:
self.phase = "finished"
raise
while True:
if self.phase == "get_response_created":
result = await self._handle_get_response_created_phase()
elif self.phase == "mcp_discovery":
result = await self._handle_mcp_discovery_phase()
elif self.phase == "initial_response":
result = await self._handle_initial_response_phase()
elif self.phase == "tool_execution":
result = await self._handle_tool_execution_phase()
elif self.phase == "follow_up_response":
result = await self._handle_follow_up_response_phase()
else:
self.phase = "finished"
raise StopAsyncIteration

# Phase 5: Finished
if result is not None:
return result

async def _handle_get_response_created_phase(
self,
) -> Optional[ResponsesAPIStreamingResponse]:
"""Emit response.created first (OpenAI SDK expects it before other events)."""
await self._create_initial_response_iterator()
if self.phase == "finished":
raise StopAsyncIteration
if self.base_iterator and hasattr(self.base_iterator, "__anext__"):
try:
first_chunk = await cast(Any, self.base_iterator).__anext__() # type: ignore[attr-defined]
self.phase = "mcp_discovery"
return first_chunk
except StopAsyncIteration:
self.phase = "finished"
raise
self.phase = "mcp_discovery"
return None

async def _handle_mcp_discovery_phase(
self,
) -> Optional[ResponsesAPIStreamingResponse]:
"""Emit MCP discovery events (after response.created)."""
if self.mcp_discovery_events:
return self.mcp_discovery_events.pop(0)
verbose_logger.debug(
"MCP discovery phase complete, transitioning to initial_response"
)
self.phase = "initial_response"
return None

async def _handle_initial_response_phase(
self,
) -> Optional[ResponsesAPIStreamingResponse]:
"""Stream the first LLM response."""
if not self.base_iterator:
self.phase = "finished"
raise StopAsyncIteration
if not hasattr(self.base_iterator, "__anext__"):
return await self._handle_initial_response_non_iterable()
try:
chunk = await cast(Any, self.base_iterator).__anext__() # type: ignore[attr-defined]
if self.should_auto_execute and self._is_response_completed(chunk):
response_obj = getattr(chunk, "response", None)
if isinstance(response_obj, ResponsesAPIResponse):
self.collected_response = response_obj
self.phase = "tool_execution"
await self._generate_tool_execution_events()
return chunk
except StopAsyncIteration:
if self.should_auto_execute and self.collected_response:
self.phase = "tool_execution"
await self._generate_tool_execution_events()
return None
self.phase = "finished"
raise

# Should not reach here
async def _handle_initial_response_non_iterable(
self,
) -> Optional[ResponsesAPIStreamingResponse]:
"""Handle base_iterator that is not async iterable (e.g. ResponsesAPIResponse)."""
if self.should_auto_execute and isinstance(
self.base_iterator, ResponsesAPIResponse
):
self.collected_response = self.base_iterator
self.phase = "tool_execution"
await self._generate_tool_execution_events()
return None
self.phase = "finished"
raise StopAsyncIteration

async def _handle_tool_execution_phase(
self,
) -> Optional[ResponsesAPIStreamingResponse]:
"""Emit tool execution events."""
if self.tool_execution_events:
return self.tool_execution_events.pop(0)
self.phase = "follow_up_response"
await self._create_follow_up_iterator()
return None

async def _handle_follow_up_response_phase(
self,
) -> Optional[ResponsesAPIStreamingResponse]:
"""Stream the follow-up response."""
if not self.follow_up_iterator:
self.phase = "finished"
raise StopAsyncIteration
try:
return await cast(Any, self.follow_up_iterator).__anext__() # type: ignore[attr-defined]
except StopAsyncIteration:
self.phase = "finished"
raise

def _is_response_completed(self, chunk: ResponsesAPIStreamingResponse) -> bool:
"""Check if this chunk indicates the response is completed"""
from litellm.types.llms.openai import ResponsesAPIStreamEvents
@@ -701,19 +726,39 @@ def __iter__(self):
return self

def __next__(self) -> ResponsesAPIStreamingResponse:
# First, emit any queued MCP events
if self.is_async:
raise RuntimeError("Cannot use sync iteration on async iterator")

# Emit response.created first (OpenAI SDK expects it before other events)
if not self._sync_response_created_emitted:
self._ensure_sync_base_iterator()
if self.base_iterator and hasattr(self.base_iterator, "__next__"):
try:
first_chunk = next(cast(Any, self.base_iterator)) # type: ignore[arg-type]
self._sync_response_created_emitted = True
return first_chunk
except StopIteration:
self.finished = True
raise
self._sync_response_created_emitted = True

# Then emit MCP discovery events
if self.mcp_events: # type: ignore[attr-defined]
return self.mcp_events.pop(0) # type: ignore[attr-defined]

# Then delegate to the base iterator
if not self.is_async:
try:
if self.base_iterator and hasattr(self.base_iterator, "__next__"):
return next(cast(Any, self.base_iterator)) # type: ignore[arg-type]
else:
raise StopIteration
except StopIteration:
self.finished = True
raise
else:
raise RuntimeError("Cannot use sync iteration on async iterator")
try:
if self.base_iterator and hasattr(self.base_iterator, "__next__"):
return next(cast(Any, self.base_iterator)) # type: ignore[arg-type]
raise StopIteration
except StopIteration:
self.finished = True
raise

def _ensure_sync_base_iterator(self) -> None:
"""Create base iterator synchronously when needed (for sync __next__ path)."""
if self.base_iterator is not None:
return
from litellm.litellm_core_utils.asyncify import run_async_function

run_async_function(self._create_initial_response_iterator)
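
For context, a minimal consumer-side sketch (not part of this diff) of the event ordering the new phase machine is meant to produce: response.created first, then MCP discovery events, then the initial response stream, tool execution events, and the follow-up response. The litellm.aresponses call shape and the MCP tool dict follow Responses API conventions; the model name and server values are placeholders, not values from this PR.

# Hypothetical usage sketch; assumes litellm.aresponses with stream=True returns
# the MCP streaming iterator patched in this diff.
import asyncio
import litellm

async def consume_mcp_stream():
    stream = await litellm.aresponses(
        model="gpt-4.1",
        input="What tools do you have?",
        tools=[{
            "type": "mcp",
            "server_label": "example",
            "server_url": "https://example.com/mcp",  # placeholder server
        }],
        stream=True,
    )
    async for event in stream:
        # With the reordering in this diff, the first event observed should be
        # response.created, followed by the MCP discovery (list_tools) events.
        print(getattr(event, "type", type(event).__name__))

if __name__ == "__main__":
    asyncio.run(consume_mcp_stream())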