Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ async def get_prompt(
async def cleanup(self):
"""Cleanup the server."""
async with self._cleanup_lock:
cleanup_cancelled = False
# Only raise HTTP errors if we're cleaning up after a failed connection.
# During normal teardown (via __aexit__), log but don't raise to avoid
# masking the original exception.
Expand All @@ -646,6 +647,7 @@ async def cleanup(self):
try:
await self.exit_stack.aclose()
except asyncio.CancelledError as e:
cleanup_cancelled = True
logger.debug(f"Cleanup cancelled for MCP server '{self.name}': {e}")
raise
except BaseExceptionGroup as eg:
Expand Down Expand Up @@ -709,7 +711,12 @@ async def cleanup(self):
else:
logger.error(f"Error cleaning up server: {e}")
finally:
self.session = None
if not cleanup_cancelled:
# Reset stack state only after a completed cleanup. If cleanup is cancelled,
# keep the existing stack so a follow-up cleanup can finish unwinding it.
self.exit_stack = AsyncExitStack()
self.session = None
self.server_initialize_result = None


class MCPServerStdioParams(TypedDict):
Expand Down
13 changes: 13 additions & 0 deletions src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from dataclasses import InitVar, dataclass, field
from typing import Any, Literal, TypeVar, cast

from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

from .agent import Agent
from .agent_output import AgentOutputSchemaBase
from .exceptions import (
Expand Down Expand Up @@ -124,6 +127,16 @@ class RunResultBase(abc.ABC):
_trace_state: TraceState | None = field(default=None, init=False, repr=False)
"""Serialized trace metadata captured during the run."""

@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
# RunResult objects are runtime values; schema generation should treat them as instances
# instead of recursively traversing internal dataclass annotations.
return core_schema.is_instance_schema(cls)

@property
@abc.abstractmethod
def last_agent(self) -> Agent[Any]:
Expand Down
82 changes: 82 additions & 0 deletions tests/mcp/test_connect_disconnect.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from unittest.mock import AsyncMock, patch

import pytest
Expand All @@ -8,6 +9,28 @@
from .helpers import DummyStreamsContextManager, tee


class CountingStreamsContextManager:
def __init__(self, counter: dict[str, int]):
self.counter = counter

async def __aenter__(self):
self.counter["enter"] += 1
return (object(), object())

async def __aexit__(self, exc_type, exc_val, exc_tb):
self.counter["exit"] += 1


class CancelThenCloseExitStack:
def __init__(self):
self.close_calls = 0

async def aclose(self):
self.close_calls += 1
if self.close_calls == 1:
raise asyncio.CancelledError("first cleanup interrupted")


@pytest.mark.asyncio
@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager())
@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
Expand Down Expand Up @@ -67,3 +90,62 @@ async def test_manual_connect_disconnect_works(

await server.cleanup()
assert server.session is None, "Server should be disconnected"


@pytest.mark.asyncio
@patch("agents.mcp.server.ClientSession.initialize", new_callable=AsyncMock, return_value=None)
@patch("agents.mcp.server.stdio_client")
async def test_cleanup_resets_exit_stack_and_reconnects(
mock_stdio_client: AsyncMock, mock_initialize: AsyncMock
):
counter = {"enter": 0, "exit": 0}
mock_stdio_client.side_effect = lambda params: CountingStreamsContextManager(counter)

server = MCPServerStdio(
params={
"command": tee,
},
cache_tools_list=True,
)

await server.connect()
original_exit_stack = server.exit_stack

await server.cleanup()
assert server.session is None
assert server.exit_stack is not original_exit_stack
assert server.server_initialize_result is None
assert counter == {"enter": 1, "exit": 1}

await server.connect()
await server.cleanup()
assert counter == {"enter": 2, "exit": 2}


@pytest.mark.asyncio
async def test_cleanup_cancellation_preserves_exit_stack_for_retry():
server = MCPServerStdio(
params={
"command": tee,
},
cache_tools_list=True,
)
cancelled_exit_stack = CancelThenCloseExitStack()

server.exit_stack = cancelled_exit_stack # type: ignore[assignment]
server.session = object() # type: ignore[assignment]
server.server_initialize_result = object() # type: ignore[assignment]

with pytest.raises(asyncio.CancelledError):
await server.cleanup()

assert id(server.exit_stack) == id(cancelled_exit_stack)
assert server.session is not None
assert server.server_initialize_result is not None

await server.cleanup()

assert cancelled_exit_stack.close_calls == 2
assert id(server.exit_stack) != id(cancelled_exit_stack)
assert server.session is None
assert server.server_initialize_result is None
12 changes: 11 additions & 1 deletion tests/test_result_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

from agents import (
Agent,
Expand Down Expand Up @@ -45,6 +45,16 @@ class Foo(BaseModel):
bar: int


def test_run_result_streaming_supports_pydantic_model_rebuild() -> None:
class StreamingRunContainer(BaseModel):
query_id: str
run_stream: RunResultStreaming | None

model_config = ConfigDict(arbitrary_types_allowed=True)

StreamingRunContainer.model_rebuild()


def _create_message(text: str) -> ResponseOutputMessage:
return ResponseOutputMessage(
id="msg",
Expand Down