diff --git a/docs/testing-guide.md b/docs/testing-guide.md index b19a0c90..2ed0d1ca 100644 --- a/docs/testing-guide.md +++ b/docs/testing-guide.md @@ -430,6 +430,53 @@ def sample_product_json(): --- +## In-Process MCP Fixtures + +For compliance fleets and integration tests that need a full `ADCPClient` +exercising the real protocol path against an in-process server (no +loopback HTTP), wire an `InMemoryTransport` pair and pass the connected +session to `ADCPClient.from_mcp_client()`: + +```python +import contextlib + +import pytest +from mcp import ClientSession +from mcp.shared.memory import create_client_server_memory_streams + +from adcp import ADCPClient + + +@pytest.fixture +async def in_process_client(my_mcp_server): + """ADCPClient backed by an in-process MCP transport. + + Caller owns the session lifecycle — `close()` and `async with` exit + on the returned client are no-ops. + """ + async with contextlib.AsyncExitStack() as stack: + (c_read, c_write), (s_read, s_write) = await stack.enter_async_context( + create_client_server_memory_streams() + ) + # wire your in-process server to (s_read, s_write) here + my_mcp_server.connect(s_read, s_write) + + session = await stack.enter_async_context(ClientSession(c_read, c_write)) + await session.initialize() + + yield ADCPClient.from_mcp_client(session, agent_id="fixture") +``` + +Why this matters: a loopback HTTP server adds a port-allocation race per +test, dies under high parallelism, and obscures bugs that only surface +when a real protocol path is exercised. The factory bypasses that without +giving up any of the client surface (validation hooks, idempotency, the +capability cache). + +For parity, see JS `AgentClient.fromMCPClient()` (v5.19.0). + +--- + ## Anti-Patterns to Avoid ### ❌ Don't Import from Internal Modules diff --git a/src/adcp/client.py b/src/adcp/client.py index 478e9c3b..5428fc37 100644 --- a/src/adcp/client.py +++ b/src/adcp/client.py @@ -12,12 +12,14 @@ from collections.abc import Callable, Iterator from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, TypedDict, cast +from uuid import uuid4 from a2a.types import Task, TaskStatusUpdateEvent from pydantic import BaseModel if TYPE_CHECKING: import httpx + from mcp import ClientSession from adcp.capabilities import TASK_FEATURE_MAP, FeatureResolver from adcp.exceptions import ADCPError, ADCPWebhookSignatureError @@ -641,6 +643,108 @@ def from_checkpoint( client.adapter._restore_active_task_id(active_task_id) return client + @classmethod + def from_mcp_client( + cls, + client: ClientSession, + *, + agent_id: str | None = None, + validation: ValidationHookConfig | None = None, + capabilities_ttl: float = 3600.0, + validate_features: bool = False, + strict_idempotency: bool = False, + ) -> ADCPClient: + """Create an ADCPClient wrapping a pre-connected MCP ClientSession. + + Parity with JS ``AgentClient.fromMCPClient()`` (v5.19.0). The primary + use case is compliance test fleets that wire a full ``ADCPClient`` + against an in-process MCP server without standing up a loopback HTTP + server. + + Warning: + The returned client's ``close()`` and ``async with`` ``__aexit__`` + are **no-ops** — the caller owns the injected session and is + responsible for closing it. Code that relies on ``async with + ADCPClient.from_mcp_client(...) as c:`` to clean up the session + will leak the session. + + Webhook delivery and ``on_activity`` callbacks are **not wired** + on the in-process path — there is no HTTP transport for the + seller to call back through. Don't pass these to the factory + (they're absent from the signature on purpose). + + If the injected session has not been initialized + (``await session.initialize()``), the first tool call surfaces + as an opaque MCP protocol error in ``TaskResult.error``. The + factory does not initialize for you — verify before calling. + + **Session lifecycle:** the caller owns the session — ``close()`` and + ``async with`` exit on the returned client are no-ops. Use your own + ``AsyncExitStack`` to scope both the transport and the client:: + + import contextlib + from mcp import ClientSession + from mcp.shared.memory import create_client_server_memory_streams + + async with contextlib.AsyncExitStack() as stack: + (c_read, c_write), (s_read, s_write) = await stack.enter_async_context( + create_client_server_memory_streams() + ) + # wire your in-process server to (s_read, s_write) here + session = await stack.enter_async_context( + ClientSession(c_read, c_write) + ) + await session.initialize() + # close() is a no-op on injected sessions; no stack.enter_async_context needed. + adcp_client = ADCPClient.from_mcp_client(session, agent_id="test-seller") + result = await adcp_client.get_products(GetProductsRequest(...)) + + Note: + Request signing is not supported on the injected-session path — + the signing hook is wired into the HTTP transport layer that is + bypassed here. ``signing=`` is intentionally absent from this + factory's parameters. + + Args: + client: A pre-connected ``mcp.ClientSession`` whose + ``initialize()`` has already been awaited. + agent_id: Identifier for the wrapped agent used in log messages + and error objects. Defaults to a unique ``in-process-XXXXXXXX`` + token; set this explicitly when running multiple in-process + agents concurrently so log lines are distinguishable. + validation: Schema-validation modes (same as ``__init__``). + strict_idempotency: Verify seller declared idempotency support + before each mutating call (same as ``__init__``). + validate_features: Gate tool calls on fetched capability + declarations (same as ``__init__``). + capabilities_ttl: TTL for the capability cache in seconds + (same as ``__init__``). + + Returns: + A fully configured ``ADCPClient`` backed by the injected session. + """ + effective_id = agent_id if agent_id is not None else f"in-process-{uuid4().hex[:8]}" + config = AgentConfig( + id=effective_id, + # RFC 2606 .invalid TLD — passes the http:// validator, guaranteed + # not to route to a real host. Self-documenting in error messages. + agent_uri="http://in-process.invalid", + protocol=Protocol.MCP, + ) + instance = cls( + config, + validation=validation, + strict_idempotency=strict_idempotency, + validate_features=validate_features, + capabilities_ttl=capabilities_ttl, + ) + if not isinstance(instance.adapter, MCPAdapter): + raise RuntimeError( # pragma: no cover + "from_mcp_client: expected MCPAdapter but got " f"{type(instance.adapter).__name__}" + ) + instance.adapter._inject_session(client) + return instance + async def _ensure_idempotency_capability(self) -> None: """Verify the seller positively declares idempotency support in capabilities. diff --git a/src/adcp/protocols/mcp.py b/src/adcp/protocols/mcp.py index 2c6bae12..46f21e16 100644 --- a/src/adcp/protocols/mcp.py +++ b/src/adcp/protocols/mcp.py @@ -217,6 +217,18 @@ def __init__(self, *args: Any, **kwargs: Any): ) self._session: Any = None self._exit_stack: Any = None + # True when the session was injected by ADCPClient.from_mcp_client(). + # Caller owns the lifecycle — close() is a no-op on injected adapters. + self._session_is_injected: bool = False + + def _inject_session(self, session: ClientSession) -> None: + """Pre-wire a caller-owned session, bypassing URL-based connection. + + Used by ADCPClient.from_mcp_client(). Once injected, _get_session() + returns it immediately and close() is a no-op (caller owns lifecycle). + """ + self._session = session + self._session_is_injected = True async def _cleanup_failed_connection(self, context: str) -> None: """ @@ -821,6 +833,8 @@ async def get_agent_info(self) -> dict[str, Any]: async def close(self) -> None: """Close the MCP session and clean up resources.""" + if self._session_is_injected: + return # caller owns lifecycle; never close an injected session await self._cleanup_failed_connection("during close") # ======================================================================== diff --git a/tests/test_protocols.py b/tests/test_protocols.py index e7c016ec..3a7d946c 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -1688,3 +1688,112 @@ async def test_cleanup_handles_exception_group_with_cancelled_error(self, mcp_co mock_exit_stack.aclose.assert_called_once() assert adapter._exit_stack is None assert adapter._session is None + + +class TestFromMcpClientFactory: + """Tests for ADCPClient.from_mcp_client() factory method.""" + + def _make_mock_session(self) -> AsyncMock: + """Return a minimal mock ClientSession.""" + session = AsyncMock() + mock_result = MagicMock() + mock_result.isError = False + mock_result.structuredContent = {"products": []} + mock_result.content = [] + session.call_tool.return_value = mock_result + return session + + def test_factory_injects_session(self): + """Injected session is wired directly — no URL connection.""" + from adcp import ADCPClient + + session = self._make_mock_session() + client = ADCPClient.from_mcp_client(session, agent_id="test-seller") + adapter = client.adapter + assert isinstance(adapter, MCPAdapter) + assert adapter._session is session + assert adapter._session_is_injected is True + + def test_factory_default_agent_id_is_unique(self): + """Default agent_id gets a unique in-process token each call.""" + from adcp import ADCPClient + + session = self._make_mock_session() + c1 = ADCPClient.from_mcp_client(session) + c2 = ADCPClient.from_mcp_client(session) + assert c1.agent_config.id.startswith("in-process-") + assert c2.agent_config.id.startswith("in-process-") + assert c1.agent_config.id != c2.agent_config.id + + def test_factory_explicit_agent_id(self): + """Explicit agent_id is preserved on the AgentConfig.""" + from adcp import ADCPClient + + session = self._make_mock_session() + client = ADCPClient.from_mcp_client(session, agent_id="seller-abc") + assert client.agent_config.id == "seller-abc" + + @pytest.mark.asyncio + async def test_tool_call_routes_through_injected_session(self): + """Tool calls use the injected session, not a URL-based connection.""" + from adcp import ADCPClient + + session = self._make_mock_session() + client = ADCPClient.from_mcp_client(session, agent_id="test-seller") + + result = await client.adapter._call_mcp_tool("get_products", {"brief": "test"}) + + session.call_tool.assert_called_once() + assert result.success is True + assert result.data == {"products": []} + + @pytest.mark.asyncio + async def test_close_is_noop_for_injected_session(self): + """close() does not call _cleanup_failed_connection for injected sessions.""" + from adcp import ADCPClient + + session = self._make_mock_session() + client = ADCPClient.from_mcp_client(session, agent_id="test-seller") + + with patch.object( + client.adapter, "_cleanup_failed_connection", new_callable=AsyncMock + ) as mock_cleanup: + await client.close() + mock_cleanup.assert_not_called() + + @pytest.mark.asyncio + async def test_async_context_manager_exit_is_noop(self): + """async with exit does not close the injected session.""" + from adcp import ADCPClient + + session = self._make_mock_session() + client = ADCPClient.from_mcp_client(session, agent_id="test-seller") + + with patch.object( + client.adapter, "_cleanup_failed_connection", new_callable=AsyncMock + ) as mock_cleanup: + async with client: + pass + mock_cleanup.assert_not_called() + + def test_validation_config_is_wired(self): + """Validation modes from the factory are applied to the adapter.""" + from adcp import ADCPClient + from adcp.validation.client_hooks import ValidationHookConfig + + session = self._make_mock_session() + config = ValidationHookConfig(requests="strict", responses="strict") + client = ADCPClient.from_mcp_client(session, validation=config) + assert client.adapter.request_validation_mode == "strict" + assert client.adapter.response_validation_mode == "strict" + + def test_idempotency_token_is_set(self): + """Factory-created client has a unique idempotency token on the adapter.""" + from adcp import ADCPClient + + session = self._make_mock_session() + c1 = ADCPClient.from_mcp_client(session) + c2 = ADCPClient.from_mcp_client(session) + assert c1.adapter.idempotency_client_token is not None + assert c2.adapter.idempotency_client_token is not None + assert c1.adapter.idempotency_client_token != c2.adapter.idempotency_client_token