-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_mcp_middleware_composition.py
More file actions
556 lines (446 loc) · 19.9 KB
/
test_mcp_middleware_composition.py
File metadata and controls
556 lines (446 loc) · 19.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
"""Integration test: custom HTTP middleware composes with SDK-registered tools.
Downstream agents (salesagent, creative agents) need to wire their own
auth middleware around tools registered by ``create_mcp_server()``. This
test proves the composition path works end-to-end:
1. ``mcp.streamable_http_app()`` returns a Starlette app that accepts
``.add_middleware()``.
2. The middleware fires before tool dispatch and can reject requests
(401 Unauthorized) or let them through.
3. When the middleware lets the request through, a ``context_factory``
passed to ``create_mcp_server()`` builds a :class:`ToolContext` the
handler receives — populated from the middleware's side-channel
(``contextvars.ContextVar``).
4. Tools in :data:`adcp.server.DISCOVERY_TOOLS` are callable without
auth (the spec-mandated handshake path).
5. JSON-RPC methods in :data:`adcp.server.DISCOVERY_METHODS`
(``initialize``, ``notifications/initialized``, ``tools/list``) are
callable pre-auth — MCP treats handshake + inventory as discovery.
If any of this regresses, salesagent and every other downstream has to
keep their wrapper layer (``mcp_context_wrapper.py``, custom
``@mcp.tool()`` scaffolding) forever. Failing here is the signal to fix
the integration, not the test.
"""
from __future__ import annotations
from contextvars import ContextVar
from typing import Any
import httpx
import pytest
from asgi_lifespan import LifespanManager
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from adcp.server import (
DISCOVERY_METHODS,
DISCOVERY_TOOLS,
ADCPHandler,
RequestMetadata,
ToolContext,
create_mcp_server,
)
# Side-channel from ``_AuthMiddleware`` to ``_build_context``: the middleware
# sets both vars for the duration of each request (``None`` on the discovery
# path) and resets them once the downstream app returns; the context factory
# reads them when it builds the handler's ToolContext.
_current_principal: ContextVar[str | None] = ContextVar("test_current_principal", default=None)
_current_tenant: ContextVar[str | None] = ContextVar("test_current_tenant", default=None)
class _RecordingHandler(ADCPHandler):
    """ADCPHandler double that logs the ``ToolContext`` of every call.

    Only ``get_adcp_capabilities`` and ``get_products`` are overridden.
    Each appends the context it received to ``calls`` (oldest first), so
    tests can inspect what ``context_factory`` produced for a request.
    """

    def __init__(self) -> None:
        # Contexts in arrival order; ``calls[-1]`` is the most recent one.
        self.calls: list[ToolContext | None] = []

    async def get_adcp_capabilities(
        self, params: Any, context: ToolContext | None = None
    ) -> dict[str, Any]:
        self.calls.append(context)
        response: dict[str, Any] = {"adcp": {"major_versions": [3]}}
        return response

    async def get_products(
        self, params: Any, context: ToolContext | None = None
    ) -> dict[str, Any]:
        self.calls.append(context)
        response: dict[str, Any] = {"products": []}
        return response
class _AuthMiddleware(BaseHTTPMiddleware):
    """Bearer-token gate in front of the MCP app.

    Requests on the MCP discovery layer pass through without a token and
    with an anonymous identity. The discovery layer is ``DISCOVERY_METHODS``
    plus ``tools/call`` for a tool in ``DISCOVERY_TOOLS``.

    Every other request must carry ``Authorization: Bearer <token>`` with a
    token from ``VALID_TOKENS``. Otherwise it is answered with a 401 before
    reaching the app.

    On success, principal + tenant are stashed in ContextVars so the
    handler-side ``context_factory`` can read them. They are reset once the
    downstream app returns.
    """

    # token -> (principal, tenant)
    VALID_TOKENS: dict[str, tuple[str, str]] = {
        "token-acme": ("principal-acme-1", "tenant-acme"),
        "token-beta": ("principal-beta-9", "tenant-beta"),
    }

    async def dispatch(self, request: Request, call_next: Any) -> Any:
        method, tool_name = await _peek_jsonrpc(request)
        is_discovery = method in DISCOVERY_METHODS or (
            method == "tools/call" and tool_name in DISCOVERY_TOOLS
        )
        principal: str | None = None
        tenant: str | None = None
        if not is_discovery:
            token = self._bearer_token(request.headers.get("authorization", ""))
            identity = self.VALID_TOKENS.get(token) if token is not None else None
            if identity is None:
                return JSONResponse({"error": "unauthenticated"}, status_code=401)
            principal, tenant = identity
        # Always set (None on the discovery path) so an identity can never
        # leak in from an enclosing context.
        principal_reset = _current_principal.set(principal)
        tenant_reset = _current_tenant.set(tenant)
        try:
            return await call_next(request)
        finally:
            _current_tenant.reset(tenant_reset)
            _current_principal.reset(principal_reset)

    @staticmethod
    def _bearer_token(header: str) -> str | None:
        """Return the credentials of a ``Bearer`` Authorization header.

        The scheme is matched case-insensitively, as RFC 7235 requires.
        Returns ``None`` when the header is empty, uses another scheme, or
        has no credentials. A bare ``removeprefix("Bearer ")`` would instead
        accept a raw token sent with no scheme at all.
        """
        scheme, _, credentials = header.strip().partition(" ")
        if scheme.lower() != "bearer":
            return None
        return credentials.strip() or None
async def _peek_jsonrpc(request: Request) -> tuple[str | None, str | None]:
"""Extract ``(method, tool_name)`` from the incoming JSON-RPC body
without consuming it for downstream handlers. ``tool_name`` is set
only for ``tools/call``."""
# Starlette caches ``request._body`` on first read, so subsequent
# reads inside the app still see the bytes.
body = await request.body()
if not body:
return None, None
try:
import json
payload = json.loads(body)
except ValueError:
return None, None
# JSON-RPC 2.0 batch arrays fall through to auth (fail closed).
if not isinstance(payload, dict):
return None, None
method = payload.get("method")
method = method if isinstance(method, str) else None
if method != "tools/call":
return method, None
params = payload.get("params") or {}
name = params.get("name")
return method, (name if isinstance(name, str) else None)
def _build_context(meta: RequestMetadata) -> ToolContext:
    """Context factory handed to ``create_mcp_server``.

    Combines the per-request metadata with whatever identity the auth
    middleware stashed in the ContextVars. Both identity fields are
    ``None`` when nothing was set.
    """
    principal = _current_principal.get()
    tenant = _current_tenant.get()
    extra = {"tool_name": meta.tool_name, "transport": meta.transport}
    return ToolContext(
        request_id=meta.request_id,
        caller_identity=principal,
        tenant_id=tenant,
        metadata=extra,
    )
@pytest.fixture
async def handler_and_client() -> Any:
    """Yield ``(handler, client)`` for an auth-gated MCP app.

    The app comes from ``create_mcp_server`` with ``_build_context`` as its
    context factory and ``_AuthMiddleware`` added on top. ``client`` is an
    in-process httpx client talking to it over ASGI.
    """
    handler = _RecordingHandler()
    mcp = create_mcp_server(handler, name="test-agent", context_factory=_build_context)

    settings = mcp.settings
    # Stateless JSON responses: production deployments mount the MCP app
    # behind a reverse proxy, and that is the shape exercised here.
    settings.stateless_http = True
    settings.json_response = True
    # MCP's DNS-rebinding protection rejects unknown Host headers when
    # enabled, so allow the in-process test host explicitly.
    settings.transport_security.allowed_hosts = ["localhost", "127.0.0.1"]

    app = mcp.streamable_http_app()
    app.add_middleware(_AuthMiddleware)

    # The streamable-HTTP session manager starts its TaskGroup in the app
    # lifespan, which httpx.ASGITransport does not run. LifespanManager
    # drives startup/shutdown and surfaces startup errors instead of hanging.
    async with LifespanManager(app), httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app),
        base_url="http://localhost",
        follow_redirects=True,
    ) as client:
        yield handler, client
@pytest.mark.asyncio
async def test_discovery_tool_is_callable_without_auth(handler_and_client: Any) -> None:
    """``get_adcp_capabilities`` runs without credentials and with an anonymous context."""
    handler, client = handler_and_client
    await _initialize_session(client)

    resp = await _call_tool(client, "get_adcp_capabilities", {})
    assert resp.status_code == 200, resp.text
    body = _parse_event_stream(resp.text)
    assert "result" in body, body

    assert handler.calls, "handler was not invoked"
    ctx = handler.calls[-1]
    # No token was sent, so the context must carry no principal/tenant.
    assert ctx is not None
    assert ctx.caller_identity is None
    assert ctx.tenant_id is None
@pytest.mark.asyncio
async def test_authenticated_tool_call_populates_caller_identity(
    handler_and_client: Any,
) -> None:
    """A valid bearer token reaches the handler as caller identity + tenant.

    The middleware maps ``token-acme`` to its principal/tenant pair and
    stores it in ContextVars. ``_build_context`` then copies that pair into
    the ToolContext the handler receives.
    """
    handler, client = handler_and_client
    auth_headers = {"Authorization": "Bearer token-acme"}
    await _initialize_session(client, headers=auth_headers)
    response = await _call_tool(
        client,
        "get_products",
        {"brief": "coffee"},
        headers=auth_headers,
    )
    assert response.status_code == 200, response.text
    # Fail with a clear message rather than an IndexError on calls[-1].
    assert handler.calls, "handler was not invoked"
    call_context = handler.calls[-1]
    assert call_context is not None
    assert call_context.caller_identity == "principal-acme-1"
    assert call_context.tenant_id == "tenant-acme"
@pytest.mark.asyncio
async def test_missing_token_blocks_non_discovery_tool(handler_and_client: Any) -> None:
    """Without an Authorization header, a non-discovery tool gets a 401 before dispatch."""
    handler, client = handler_and_client
    resp = await _call_tool(client, "get_products", {"brief": "coffee"})
    assert resp.status_code == 401
    # An empty call log proves the rejection happened in the middleware.
    assert len(handler.calls) == 0, (
        "handler was invoked despite missing auth — middleware did NOT "
        "compose with the tool dispatch"
    )
@pytest.mark.asyncio
async def test_initialize_is_callable_without_auth(handler_and_client: Any) -> None:
    """``initialize`` must succeed with no credentials (MCP pre-auth handshake).

    Pinning it here means a future tightening of the gate fails in one
    obvious place instead of breaking every fixture that initializes first.
    """
    client = handler_and_client[1]
    init_response = await _initialize_session(client)
    assert init_response.status_code == 200, init_response.text
@pytest.mark.asyncio
async def test_tools_list_is_callable_without_auth(handler_and_client: Any) -> None:
    """An unauthenticated client can fetch the tool inventory via ``tools/list``.

    This locks in the default posture. Operators who treat the inventory as
    sensitive can drop ``tools/list`` from ``DISCOVERY_METHODS`` in their
    own middleware.
    """
    client = handler_and_client[1]
    await _initialize_session(client)
    resp = await _list_tools(client)
    assert resp.status_code == 200, resp.text

    body = _parse_event_stream(resp.text)
    assert "result" in body, body
    inventory = body["result"].get("tools", [])
    # The handler overrides only two tools, but the base class advertises
    # the full AdCP surface, so only shape and non-emptiness are checked.
    assert isinstance(inventory, list)
    assert inventory, "tools/list returned an empty inventory"
@pytest.mark.asyncio
async def test_tools_list_bypasses_gate_even_with_invalid_token(
    handler_and_client: Any,
) -> None:
    """Negative control: a bogus bearer token must not get ``tools/list`` rejected.

    If the gate only waved through requests with *no* Authorization
    header, this would 401. A 200 shows the decision is driven by
    :data:`DISCOVERY_METHODS`.
    """
    client = handler_and_client[1]
    await _initialize_session(client)
    bogus_auth = {"Authorization": "Bearer not-valid"}
    resp = await _list_tools(client, headers=bogus_auth)
    assert resp.status_code == 200, resp.text
def test_discovery_tools_frozenset_contract() -> None:
    """Pin the spec-mandated auth-optional tool set, including its immutability.

    This protects against accidental widening or narrowing. Callers extend
    the set with ``DISCOVERY_TOOLS | {...}`` rather than mutating it.
    """
    # A plain ``set`` compares equal to the frozenset below, so check the type
    # too: a mutable shared gate set could be widened in place by any consumer.
    assert isinstance(DISCOVERY_TOOLS, frozenset)
    assert DISCOVERY_TOOLS == frozenset({"get_adcp_capabilities"})
def test_discovery_methods_frozenset_contract() -> None:
    """Pin the pre-auth MCP JSON-RPC methods, including their immutability.

    The discovery layer consists of three methods:

    - ``initialize``: the session handshake.
    - ``notifications/initialized``: the handshake-completion notification.
    - ``tools/list``: the tool inventory.

    Widening this set silently lets mutations through the auth gate.
    Narrowing it breaks clients that expect pre-auth discovery.
    """
    # Equality alone would also accept a mutable ``set``; pin the type so the
    # shared gate set cannot be widened in place.
    assert isinstance(DISCOVERY_METHODS, frozenset)
    assert DISCOVERY_METHODS == frozenset({"initialize", "notifications/initialized", "tools/list"})
def test_validate_discovery_set_accepts_base_set() -> None:
    """The base ``DISCOVERY_TOOLS`` set must always pass validation.

    A failure here means a mutation tool crept into the spec-mandated
    handshake set.
    """
    from adcp.server import validate_discovery_set as validate

    validate(DISCOVERY_TOOLS)
def test_validate_discovery_set_accepts_read_only_extension() -> None:
    """Adding a read-only tool to the discovery set is allowed."""
    from adcp.server import validate_discovery_set as validate

    # ``list_creative_formats`` is annotated read-only, so a downstream that
    # wants public format listing may add it.
    public_tools = DISCOVERY_TOOLS | {"list_creative_formats"}
    validate(public_tools)
def test_validate_discovery_set_rejects_mutation_tool() -> None:
    """A mutating tool in the discovery set is refused with a ``ValueError``."""
    from adcp.server import validate_discovery_set as validate

    widened = DISCOVERY_TOOLS | {"create_media_buy"}
    with pytest.raises(ValueError, match="non-read-only"):
        validate(widened)
def test_validate_discovery_set_rejects_unknown_tool() -> None:
    """A tool name the SDK does not know is refused with a ``ValueError``."""
    from adcp.server import validate_discovery_set as validate

    with_unknown = DISCOVERY_TOOLS | {"not_a_real_tool"}
    with pytest.raises(ValueError, match="unknown tool"):
        validate(with_unknown)
# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
async def _initialize_session(
client: httpx.AsyncClient, *, headers: dict[str, str] | None = None
) -> httpx.Response:
"""Send an MCP ``initialize`` JSON-RPC call — FastMCP requires this
before ``tools/call`` even in stateless mode."""
request_headers = {
"content-type": "application/json",
"accept": "application/json, text/event-stream",
}
if headers:
request_headers.update(headers)
body = {
"jsonrpc": "2.0",
"id": 0,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0"},
},
}
return await client.post("/mcp/", json=body, headers=request_headers)
async def _call_tool(
client: httpx.AsyncClient,
tool_name: str,
arguments: dict[str, Any],
*,
headers: dict[str, str] | None = None,
) -> httpx.Response:
"""POST a JSON-RPC ``tools/call`` to the MCP endpoint."""
request_headers = {
"content-type": "application/json",
"accept": "application/json, text/event-stream",
}
if headers:
request_headers.update(headers)
body = {
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {"name": tool_name, "arguments": arguments},
}
return await client.post("/mcp/", json=body, headers=request_headers)
async def _list_tools(
client: httpx.AsyncClient,
*,
headers: dict[str, str] | None = None,
) -> httpx.Response:
"""POST a JSON-RPC ``tools/list`` to the MCP endpoint."""
request_headers = {
"content-type": "application/json",
"accept": "application/json, text/event-stream",
}
if headers:
request_headers.update(headers)
body = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
return await client.post("/mcp/", json=body, headers=request_headers)
def _parse_event_stream(body: str) -> dict[str, Any]:
"""Parse SSE event-stream body from FastMCP into a dict."""
import json
for line in body.splitlines():
line = line.strip()
if line.startswith("data: "):
return json.loads(line.removeprefix("data: "))
return json.loads(body) if body.strip() else {}
# ----------------------------------------------------------------------
# MCP middleware parity with A2A — ``create_mcp_server(middleware=[...])``
# ----------------------------------------------------------------------
@pytest.fixture
def middleware_events() -> list[str]:
    """Return a fresh per-test list that the tracing middlewares append to.

    This is a plain synchronous fixture because building an empty list
    awaits nothing. It therefore does not depend on how pytest-asyncio
    treats async fixtures.
    """
    return []
@pytest.fixture
async def middleware_handler_and_client(middleware_events: list[str]) -> Any:
    """Yield ``(handler, client)`` for an MCP app with a SkillMiddleware chain.

    This mirrors ``handler_and_client`` but has no HTTP auth layer, so the
    ``middleware=[outer, inner]`` chain is the only thing under test. Each
    middleware logs a pre- and a post-event into ``middleware_events``
    around its ``call_next()``.
    """

    async def outer(skill_name, params, context, call_next):
        middleware_events.append(f"outer-pre:{skill_name}")
        outcome = await call_next()
        middleware_events.append(f"outer-post:{skill_name}")
        return outcome

    async def inner(skill_name, params, context, call_next):
        middleware_events.append(f"inner-pre:{skill_name}")
        outcome = await call_next()
        middleware_events.append(f"inner-post:{skill_name}")
        return outcome

    handler = _RecordingHandler()
    mcp = create_mcp_server(
        handler,
        name="mw-test",
        context_factory=_build_context,
        middleware=[outer, inner],
    )
    settings = mcp.settings
    settings.stateless_http = True
    settings.json_response = True
    settings.transport_security.allowed_hosts = ["localhost", "127.0.0.1"]
    app = mcp.streamable_http_app()

    async with LifespanManager(app), httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app),
        base_url="http://localhost",
        follow_redirects=True,
    ) as client:
        yield handler, client
@pytest.mark.asyncio
async def test_mcp_middleware_composes_outermost_first(
    middleware_handler_and_client: Any,
    middleware_events: list[str],
) -> None:
    """``middleware=[outer, inner]`` nests outermost-first, matching A2A.

    The expected trace is outer-pre, inner-pre, handler, inner-post,
    outer-post. Reversed or stale composition would break cross-transport
    parity.
    """
    client = middleware_handler_and_client[1]
    await _initialize_session(client)
    resp = await _call_tool(client, "get_adcp_capabilities", {})
    assert resp.status_code == 200, resp.text

    expected_trace = [
        "outer-pre:get_adcp_capabilities",
        "inner-pre:get_adcp_capabilities",
        "inner-post:get_adcp_capabilities",
        "outer-post:get_adcp_capabilities",
    ]
    assert middleware_events == expected_trace, middleware_events
@pytest.mark.asyncio
async def test_mcp_middleware_can_short_circuit() -> None:
    """A middleware that returns without awaiting ``call_next()`` stops the chain.

    The handler must not run; rate limiters rely on this.
    """
    invocations: list[str] = []

    class _ShortCircuitTarget(ADCPHandler):
        async def get_adcp_capabilities(self, params, context=None):
            invocations.append("called")
            return {"adcp": {"major_versions": [3]}}

    async def rate_limiter(skill_name, params, context, call_next):
        # Deliberately never calls call_next().
        return {"error": "rate-limited", "skill": skill_name}

    mcp = create_mcp_server(_ShortCircuitTarget(), name="sc-test", middleware=[rate_limiter])
    settings = mcp.settings
    settings.stateless_http = True
    settings.json_response = True
    settings.transport_security.allowed_hosts = ["localhost", "127.0.0.1"]
    app = mcp.streamable_http_app()

    async with LifespanManager(app), httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app),
        base_url="http://localhost",
        follow_redirects=True,
    ) as client:
        await _initialize_session(client)
        resp = await _call_tool(client, "get_adcp_capabilities", {})

    assert resp.status_code == 200, resp.text
    assert invocations == [], (
        "middleware short-circuited but the handler still ran — MCP middleware "
        "chain did not honour the 'skip call_next to skip handler' contract"
    )
@pytest.mark.asyncio
async def test_mcp_middleware_sees_tool_context() -> None:
    """Middleware receives the ToolContext the handler will get.

    With no ``context_factory`` configured, middleware still receives a
    default ``ToolContext`` (not ``None``), so its typed signature holds.
    """
    contexts: list[ToolContext] = []

    async def record_context(skill_name, params, context, call_next):
        contexts.append(context)
        return await call_next()

    class _Handler(ADCPHandler):
        async def get_adcp_capabilities(self, params, context=None):
            return {"adcp": {"major_versions": [3]}}

    mcp = create_mcp_server(_Handler(), name="ctx-test", middleware=[record_context])
    settings = mcp.settings
    settings.stateless_http = True
    settings.json_response = True
    settings.transport_security.allowed_hosts = ["localhost", "127.0.0.1"]
    app = mcp.streamable_http_app()

    async with LifespanManager(app), httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app),
        base_url="http://localhost",
        follow_redirects=True,
    ) as client:
        await _initialize_session(client)
        await _call_tool(client, "get_adcp_capabilities", {})

    assert len(contexts) == 1
    # No context_factory → a synthesised default ToolContext. Checked
    # explicitly so a future change that passes None instead fails here.
    assert isinstance(contexts[0], ToolContext)