diff --git a/docs/handler-authoring.md b/docs/handler-authoring.md index f4d5f00a..0f400a41 100644 --- a/docs/handler-authoring.md +++ b/docs/handler-authoring.md @@ -167,18 +167,48 @@ authenticated requests. The SDK trusts the proxy's decision. Simplest, and the right choice when your identity provider and tool endpoints run behind the same gateway. -### Pattern 2 — in-process HTTP middleware +### Pattern 2 — in-process HTTP middleware (recommended) -Call `mcp.streamable_http_app()` to get the Starlette ASGI app, then -`app.add_middleware(YourAuthMiddleware)`. The middleware validates -credentials, stashes the resolved principal + tenant somewhere the -`context_factory` can read (ContextVars are recommended), and calls -`context_factory=` on `create_mcp_server()` to inject a typed -`ToolContext` per call. +Use `BearerTokenAuthMiddleware` and `auth_context_factory` from +`adcp.server`. The SDK owns the four security-critical concerns +(ContextVar carrier, `hmac.compare_digest`, discovery-method bypass, +reset-in-finally); you supply only `validate_token`: + +```python +from adcp.server import ( + BearerTokenAuthMiddleware, + Principal, + auth_context_factory, + create_mcp_server, +) + +async def validate_token(token: str) -> Principal | None: + row = await db.fetch_token(token) + if row is None or row.revoked: + return None + return Principal(caller_identity=row.principal_id, tenant_id=row.tenant_id) + +mcp = create_mcp_server(MyAgent(), context_factory=auth_context_factory) +app = mcp.streamable_http_app() +app.add_middleware(BearerTokenAuthMiddleware, validate_token=validate_token) +``` + +`validate_token` may be sync or async — whichever matches your token +store. Return `None` to reject; don't raise on routine invalid tokens. +The middleware fails closed (a validator exception is caught, logged, +and answered with the same 401 as a `None` return), but raising there +drowns real validator bugs in expected auth-failure noise. Full worked example: `examples/mcp_with_auth_middleware.py`. Integration test proving the composition: `tests/test_mcp_middleware_composition.py`.
+#### Pattern 2a — custom middleware (when the shipped one doesn't fit) + +Subclass `BearerTokenAuthMiddleware` to tighten the discovery bypass, +add extra headers, or customise the 401 response. For non-bearer auth +(mTLS, signed requests, API key via header), write a Starlette +middleware that populates `adcp.server.auth.current_principal` / +`current_tenant` yourself and keep using `auth_context_factory` — the +`ContextVar`s are the contract, not the middleware class. + ### Discovery tools bypass auth Per AdCP spec, `get_adcp_capabilities` is the handshake — clients MUST @@ -242,6 +272,38 @@ locks the default posture with a positive assertion that `tools/list` returns 200 without credentials and a negative control that the gate still lets it through when an invalid bearer is present. +## Custom tools alongside ADCP tools + +Some agents need to expose vendor-specific tools (an internal +`list_publishers` endpoint, a custom storyboard hook) that aren't part +of the AdCP spec. `create_mcp_server()` returns a bare FastMCP +instance — register custom tools on it with FastMCP's standard +`@mcp.tool()` decorator: + +```python +from adcp.server import create_mcp_server + +mcp = create_mcp_server(MyAgent(), name="my-agent") + +@mcp.tool() +async def list_publishers(region: str) -> list[dict]: + """Vendor-specific — not in the AdCP spec.""" + return await my_db.publishers_in(region) + +mcp.run(transport="streamable-http") +``` + +Custom tools appear in `tools/list` alongside the ADCP tools, carry +whatever schema FastMCP generates from the function signature, and do +**not** run through ADCP's spec-driven validation or the `SkillMiddleware` +chain — they're off-spec by construction. Use them for genuinely +vendor-specific surfaces; don't use them to "extend" AdCP operations +(that's what discriminated-union request subclasses are for). + +`tools/list` consumers that validate against the ADCP spec will flag +custom tools as unknown. 
Set expectations accordingly with clients +your agent talks to. + ## Request-body size cap `serve()` installs an ASGI middleware that caps incoming request @@ -519,10 +581,12 @@ ContextVar — treat it as a P0. ### Per-skill middleware (audit, activity feeds, rate limiting, tracing) -Every A2A skill dispatch can be wrapped in a chain of middleware -callables. Pass them as `middleware=[...]` to `create_a2a_server` / -`serve` / `ADCPAgentExecutor` — first entry wraps outermost, matching -Starlette/ASGI ordering: +Every skill dispatch — on **both** the MCP and A2A transports — can be +wrapped in a chain of middleware callables. Pass them as +`middleware=[...]` to `create_mcp_server` / `create_a2a_server` / +`serve` — first entry wraps outermost, matching Starlette/ASGI +ordering. The same list works across transports; write once, apply to +both: ```python from adcp.server import SkillMiddleware, ToolContext, serve @@ -546,6 +610,10 @@ async def audit_middleware( ) return result +# Works on MCP: +serve(MyAgent(), middleware=[audit_middleware]) + +# Same middleware list, A2A transport: serve(MyAgent(), transport="a2a", middleware=[audit_middleware]) ``` @@ -598,8 +666,37 @@ refs, proposal text, PII in message parts. `context` carries the complete skill surface. Treat it as a data processor under your GDPR/CCPA controller-processor agreements. -MCP transport has its own middleware story (see "Pattern 2 — -in-process HTTP middleware" above); `SkillMiddleware` is A2A-only. +`SkillMiddleware` applies on both transports — pass the same list to +`create_mcp_server(middleware=...)` and `create_a2a_server(middleware=...)`, +or to `serve(middleware=...)`. Per-transport HTTP middleware (the +`BearerTokenAuthMiddleware` from Pattern 2 above, for instance) is a +separate concern — HTTP middleware runs before JSON-RPC decode; +`SkillMiddleware` runs after skill dispatch is resolved. 
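The outermost-first composition can be illustrated with a stdlib-only sketch. This is not the SDK's dispatcher — just a minimal model of the `(skill_name, params, context, call_next)` chain the docs describe, to make the ordering concrete:

```python
import asyncio

async def dispatch(middlewares, skill_name, params, context, handler):
    """Compose middlewares recursively: entry 0 wraps outermost."""
    async def step(index):
        if index >= len(middlewares):
            return await handler(params, context)
        async def call_next():
            return await step(index + 1)
        return await middlewares[index](skill_name, params, context, call_next)
    return await step(0)

log = []

def tag(name):
    # Middleware factory: records before/after around the inner chain.
    async def mw(skill, params, ctx, call_next):
        log.append(f"{name}:before")
        result = await call_next()
        log.append(f"{name}:after")
        return result
    return mw

async def handler(params, ctx):
    log.append("handler")
    return {"ok": True}

result = asyncio.run(
    dispatch([tag("outer"), tag("inner")], "get_products", {}, None, handler)
)
# First list entry wraps outermost, matching Starlette/ASGI ordering:
# log == ["outer:before", "inner:before", "handler", "inner:after", "outer:after"]
```

The same ordering intuition transfers directly: the first entry in `middleware=[...]` sees every call before the later entries and before the handler, on either transport.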
+ +### Alternative A2A wire formats + +The default `ADCPAgentExecutor` parses incoming messages expecting +`DataPart(data={"skill": ..., "parameters": {...}})` with a +TextPart JSON fallback. Sellers fronting clients that send a +different shape (JSON-RPC 2.0 bodies, vendor-specific DataParts, bare +TextPart with a different skill layout) can pass a custom +`message_parser`: + +```python +from adcp.server import MessageParser, create_a2a_server + +def my_parser(context): + # Parse your wire shape; return (skill_name, params) or (None, {}). + msg = context.message + ... + return skill_name, params + +app = create_a2a_server(MyAgent(), message_parser=my_parser) +``` + +Compose with the default when accepting both shapes — call +`ADCPAgentExecutor._default_parse_request` as a fallback when your own +parser returns `(None, {})`, so legacy clients keep working. ### Known gaps diff --git a/examples/mcp_with_auth_middleware.py b/examples/mcp_with_auth_middleware.py index 34b7c68b..52bab061 100644 --- a/examples/mcp_with_auth_middleware.py +++ b/examples/mcp_with_auth_middleware.py @@ -1,23 +1,10 @@ -"""Example: custom HTTP auth middleware + typed ToolContext via context_factory. - -This is the recipe for multi-tenant sales agents that need to: - -1. Validate bearer tokens (or any other credential) in front of - :func:`adcp.server.create_mcp_server`-registered tools. -2. Allow the MCP handshake (``initialize``, ``tools/list``) and the AdCP - discovery handshake (``get_adcp_capabilities``) to go through - unauthenticated — per :data:`adcp.server.DISCOVERY_METHODS` and - :data:`adcp.server.DISCOVERY_TOOLS`. -3. Pass the authenticated principal + tenant to handlers as a typed - :class:`adcp.server.ToolContext`. - -Pre-auth posture: ``tools/list`` returns the full tool inventory -(names, input schemas, descriptions) without authentication — per MCP -spec, discovery is a handshake concern. Tool metadata is treated as -non-sensitive by default.
Operators who consider their tool surface -sensitive should strip ``tools/list`` from ``DISCOVERY_METHODS`` in -their own middleware and gate it behind the same credential check as -``tools/call``. +"""Example: multi-tenant MCP server with bearer-token auth. + +Wires :class:`~adcp.server.BearerTokenAuthMiddleware` + +:func:`~adcp.server.auth_context_factory` onto a multi-tenant sales +agent. The SDK owns the security-critical plumbing (constant-time +token compare, discovery bypass, ``ContextVar`` reset-in-finally); +the seller supplies only ``validate_token`` and the handler logic. Run:: @@ -25,170 +12,53 @@ # → server on http://localhost:3001/mcp/ # curl -H 'Authorization: Bearer token-acme' ... -Production note: ``mcp.run()`` is used here for brevity. Real deployments -should mount the Starlette app behind uvicorn + a reverse proxy that -terminates TLS and handles rate limiting. +Production note: ``mcp.run()`` is used here for brevity. Real +deployments should mount the Starlette app behind uvicorn + a reverse +proxy that terminates TLS and handles rate limiting. """ from __future__ import annotations import hashlib -import hmac -from contextvars import ContextVar from typing import Any -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request -from starlette.responses import JSONResponse - from adcp.server import ( - DISCOVERY_METHODS, - DISCOVERY_TOOLS, ADCPHandler, - RequestMetadata, + BearerTokenAuthMiddleware, + Principal, ToolContext, + auth_context_factory, + constant_time_token_match, create_mcp_server, ) from adcp.server.responses import capabilities_response, products_response -# ---------------------------------------------------------------------- -# Per-request auth state — populated by middleware, read by context_factory. -# ContextVars are the recommended carrier: they compose cleanly with -# async tasks and don't leak across requests the way module globals do. 
-# IMPORTANT: always pair ``.set(x)`` with ``.reset(token)`` in a ``finally:`` -# block so the value doesn't linger in the current context past the -# response — otherwise a subsequent task reusing the same context reads a -# stale principal (cross-request confidentiality leak). -# ---------------------------------------------------------------------- - -_principal: ContextVar[str | None] = ContextVar("adcp_principal", default=None) -_tenant: ContextVar[str | None] = ContextVar("adcp_tenant", default=None) - - -# Real agents look tokens up in Postgres / Vault / an identity provider / -# etc. This dict is a stand-in: it stores a per-token SHA-256 so the -# example's token-compare path uses ``hmac.compare_digest`` (constant-time) -# against a hash rather than comparing raw bearer tokens with ``==`` or -# ``in``. Never ship plain-text token equality against a user-supplied -# bearer token — it leaks information via timing, and dict lookups short- -# circuit on hash mismatch. -_TOKEN_HASHES: dict[str, tuple[str, str]] = { - hashlib.sha256(raw.encode()).hexdigest(): (principal, tenant) - for raw, (principal, tenant) in { - "token-acme": ("principal-acme-ops", "tenant-acme"), - "token-globex": ("principal-globex-ops", "tenant-globex"), +# Real agents look tokens up in Postgres / Vault / an identity provider. +# Keyed by SHA-256 so the comparison uses ``hmac.compare_digest`` rather +# than raw string equality — never compare raw bearer tokens with ``==``. +_TOKEN_HASHES: dict[str, Principal] = { + hashlib.sha256(raw.encode()).hexdigest(): principal + for raw, principal in { + "token-acme": Principal( + caller_identity="principal-acme-ops", + tenant_id="tenant-acme", + ), + "token-globex": Principal( + caller_identity="principal-globex-ops", + tenant_id="tenant-globex", + ), }.items() } -def _lookup_token(token: str) -> tuple[str, str] | None: - """Constant-time bearer-token lookup. +def validate_token(token: str) -> Principal | None: + """Seller-supplied token validator. 
- Iterate all known hashes with ``hmac.compare_digest`` so the wall-clock - runtime doesn't depend on how much of the candidate matches any entry — - the dict-lookup-then-equality pattern leaks that. - """ - if not token: - return None - candidate = hashlib.sha256(token.encode()).hexdigest() - for stored_hash, identity in _TOKEN_HASHES.items(): - if hmac.compare_digest(candidate, stored_hash): - return identity - return None - - -# ---------------------------------------------------------------------- -# HTTP middleware — auth gate, honors DISCOVERY_METHODS + DISCOVERY_TOOLS. -# ---------------------------------------------------------------------- - - -class BearerAuthMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: Any) -> Any: - method, tool = await _peek_jsonrpc(request) - - principal_token = None - tenant_token = None - try: - # MCP spec: ``initialize`` + ``tools/list`` are the handshake / - # discovery layer — pre-auth by spec. - # AdCP spec: ``get_adcp_capabilities`` is the capability handshake — - # pre-auth by spec. - is_discovery = method in DISCOVERY_METHODS or ( - method == "tools/call" and tool in DISCOVERY_TOOLS - ) - if is_discovery: - principal_token = _principal.set(None) - tenant_token = _tenant.set(None) - return await call_next(request) - - # Everything else requires a bearer token. - auth_header = request.headers.get("authorization", "") - bearer = auth_header.removeprefix("Bearer ").strip() - identity = _lookup_token(bearer) - if identity is None: - return JSONResponse({"error": "unauthenticated"}, status_code=401) - - principal_id, tenant_id = identity - principal_token = _principal.set(principal_id) - tenant_token = _tenant.set(tenant_id) - return await call_next(request) - finally: - # Reset unconditionally. Without this, a later task running in - # the same context reads the leftover principal — a - # cross-request confidentiality leak. 
- if principal_token is not None: - _principal.reset(principal_token) - if tenant_token is not None: - _tenant.reset(tenant_token) - - -async def _peek_jsonrpc(request: Request) -> tuple[str | None, str | None]: - """Inspect the JSON-RPC body without consuming it for downstream handlers. - - Returns ``(method, tool_name)``. ``tool_name`` is set only when the - method is ``tools/call``; for ``initialize``, ``tools/list``, and - other methods it is ``None``. + ``constant_time_token_match`` iterates every stored hash with + :func:`hmac.compare_digest`, avoiding the prefix-match timing leak + that a plain ``dict`` lookup would have. """ - body = await request.body() - if not body: - return None, None - import json - - try: - payload = json.loads(body) - except ValueError: - return None, None - # JSON-RPC 2.0 allows batch arrays. Fail closed — let auth run — so a - # client can't smuggle a mutation past the gate inside a batch. - if not isinstance(payload, dict): - return None, None - method = payload.get("method") - method = method if isinstance(method, str) else None - if method != "tools/call": - return method, None - params = payload.get("params") or {} - name = params.get("name") - return method, (name if isinstance(name, str) else None) - - -# ---------------------------------------------------------------------- -# context_factory — runs per tool call, reads the ContextVars the -# middleware populated, returns a typed ToolContext. -# ---------------------------------------------------------------------- - - -def build_context(meta: RequestMetadata) -> ToolContext: - return ToolContext( - request_id=meta.request_id, - caller_identity=_principal.get(), - tenant_id=_tenant.get(), - metadata={"tool_name": meta.tool_name, "transport": meta.transport}, - ) - - -# ---------------------------------------------------------------------- -# Handler — reads caller_identity + tenant_id off the ToolContext. 
-# ---------------------------------------------------------------------- + return constant_time_token_match(token, _TOKEN_HASHES) class MultiTenantSalesAgent(ADCPHandler): @@ -200,11 +70,8 @@ async def get_adcp_capabilities( return capabilities_response(["media_buy"]) async def get_products(self, params: Any, context: ToolContext | None = None) -> dict[str, Any]: - # context.caller_identity is the authenticated principal; - # context.tenant_id is populated for multi-tenant agents. tenant = context.tenant_id if context is not None else None - catalog = _products_for_tenant(tenant) - return products_response(catalog) + return products_response(_products_for_tenant(tenant)) def _products_for_tenant(tenant_id: str | None) -> list[dict[str, Any]]: @@ -215,27 +82,14 @@ def _products_for_tenant(tenant_id: str | None) -> list[dict[str, Any]]: return [] -# ---------------------------------------------------------------------- -# Wiring — create_mcp_server with context_factory, then add middleware -# to the Starlette app. -# ---------------------------------------------------------------------- - - def main() -> None: mcp = create_mcp_server( MultiTenantSalesAgent(), name="multi-tenant-sales-agent", - context_factory=build_context, + context_factory=auth_context_factory, ) - - # Middleware must be added BEFORE the app runs. create_mcp_server - # returns a FastMCP instance; its ASGI app is accessed via - # streamable_http_app(), which is a standard Starlette app. app = mcp.streamable_http_app() - app.add_middleware(BearerAuthMiddleware) - - # mcp.run() hands control to FastMCP. In production, mount with - # uvicorn and a reverse proxy for TLS + rate limiting. 
+ app.add_middleware(BearerTokenAuthMiddleware, validate_token=validate_token) mcp.run(transport="streamable-http") diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py index be41dfed..885fd829 100644 --- a/src/adcp/server/__init__.py +++ b/src/adcp/server/__init__.py @@ -53,7 +53,14 @@ async def get_products(params, context=None): from __future__ import annotations from adcp.capabilities import validate_capabilities -from adcp.server.a2a_server import ADCPAgentExecutor, create_a2a_server +from adcp.server.a2a_server import ADCPAgentExecutor, MessageParser, create_a2a_server +from adcp.server.auth import ( + BearerTokenAuthMiddleware, + Principal, + TokenValidator, + auth_context_factory, + constant_time_token_match, +) from adcp.server.base import ( AccountAwareToolContext, ADCPHandler, @@ -161,8 +168,15 @@ async def get_products(params, context=None): "validate_discovery_set", # A2A integration "ADCPAgentExecutor", + "MessageParser", "SkillMiddleware", "create_a2a_server", + # Bearer-token auth middleware (seller-facing recipe) + "BearerTokenAuthMiddleware", + "Principal", + "TokenValidator", + "auth_context_factory", + "constant_time_token_match", # Idempotency middleware (AdCP #2315 seller side) "IdempotencyStore", "MemoryBackend", diff --git a/src/adcp/server/a2a_server.py b/src/adcp/server/a2a_server.py index e8a1834d..610baba1 100644 --- a/src/adcp/server/a2a_server.py +++ b/src/adcp/server/a2a_server.py @@ -56,7 +56,28 @@ from a2a.server.tasks.task_store import TaskStore from adcp.server.serve import ContextFactory, SkillMiddleware -from adcp.server.helpers import STANDARD_ERROR_CODES + +from collections.abc import Callable # noqa: E402 + +MessageParser = Callable[[RequestContext], tuple[str | None, dict[str, Any]]] +"""Callable that extracts ``(skill_name, params)`` from an incoming +A2A :class:`RequestContext`. + +The default parser handles ``DataPart(data={"skill": ..., +"parameters": ...})`` plus a TextPart JSON fallback. 
Override this +hook to accept alternative wire shapes — JSON-RPC 2.0 message bodies, +vendor-specific DataPart schemas, or text-only skill encodings. Return +``(None, {})`` to signal "no parseable skill"; the executor will emit +an error Task for the client. + +Pair with :meth:`ADCPAgentExecutor._default_parse_request` when you +want to accept a custom shape *in addition to* the built-in shapes — +call the default as a fallback after your own parser returns +``(None, {})``. +""" + + +from adcp.server.helpers import STANDARD_ERROR_CODES # noqa: E402 from adcp.server.mcp_tools import create_tool_caller, get_tools_for_handler from adcp.server.test_controller import TestControllerStore, _handle_test_controller @@ -81,6 +102,7 @@ def __init__( *, context_factory: ContextFactory | None = None, middleware: Sequence[SkillMiddleware] | None = None, + message_parser: MessageParser | None = None, advertise_all: bool = False, ) -> None: self._handler = handler @@ -91,6 +113,10 @@ def __init__( # ordering; first entry wraps outermost (see ``SkillMiddleware`` # docstring for the composition semantics). self._middleware: tuple[SkillMiddleware, ...] = tuple(middleware or ()) + # Seller-supplied parser for non-default wire shapes (JSON-RPC, + # bare TextPart with different skill layout, etc.). Falls back + # to the built-in parser when None. + self._message_parser: MessageParser | None = message_parser self._tool_callers: dict[str, Any] = {} # Build tool callers for all tools this handler supports. @@ -165,34 +191,24 @@ async def _dispatch_with_middleware( ) -> Any: """Run the handler wrapped in the configured middleware chain. - Middleware composes outermost-first: the first entry in - ``self._middleware`` sees every call *before* the later entries - and *before* the handler. This matches Starlette / ASGI - conventions so sellers porting from those stacks aren't - surprised. 
Composition is done via a small recursive dispatcher - (no mutable indices, no lambdas closing over loop variables) — - the chain reads the same whether you have zero or ten - middlewares. + Delegates to :func:`adcp.server.serve._dispatch_with_middleware` + so the composition semantics stay identical between transports — + middleware that works with ``create_a2a_server(middleware=...)`` + works unchanged with ``create_mcp_server(middleware=...)``. Middleware exceptions propagate to the executor's normal error handling path in ``execute()``; this method does no try/except so short-circuiting, transform, and exception-observation all work the same way they do for the underlying handler. """ - if not self._middleware: - return await self._tool_callers[skill_name](params, tool_context) - - async def _step(index: int) -> Any: - if index >= len(self._middleware): - return await self._tool_callers[skill_name](params, tool_context) - middleware = self._middleware[index] - - async def call_next() -> Any: - return await _step(index + 1) + from adcp.server.serve import _dispatch_with_middleware - return await middleware(skill_name, params, tool_context, call_next) + async def _call_handler() -> Any: + return await self._tool_callers[skill_name](params, tool_context) - return await _step(0) + return await _dispatch_with_middleware( + self._middleware, skill_name, params, tool_context, _call_handler + ) def _build_tool_context(self, skill_name: str, request: RequestContext) -> ToolContext: """Build the :class:`ToolContext` handed to the skill dispatcher. @@ -253,10 +269,26 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None def _parse_request(self, context: RequestContext) -> tuple[str | None, dict[str, Any]]: """Extract skill name and parameters from the A2A message. 
- Supports two formats: + Dispatches to the caller-supplied :data:`MessageParser` when the + executor was constructed with ``message_parser=``; otherwise + falls through to :meth:`_default_parse_request`, which supports + the standard shapes (DataPart with explicit skill + TextPart + JSON fallback). + """ + if self._message_parser is not None: + return self._message_parser(context) + return self._default_parse_request(context) + + def _default_parse_request(self, context: RequestContext) -> tuple[str | None, dict[str, Any]]: + """Built-in parser. Supports two formats: + 1. Explicit skill invocation via DataPart: DataPart(data={"skill": "get_products", "parameters": {...}}) 2. Natural language fallback via TextPart (best-effort parse) + + Exposed as a dedicated method on the executor so custom parsers + can compose it — e.g. "try my JSON-RPC parser first, fall through + to the default for legacy clients". """ msg = context.message if msg is None or not msg.parts: @@ -514,6 +546,7 @@ def create_a2a_server( task_store: TaskStore | None = None, push_config_store: PushNotificationConfigStore | None = None, middleware: Sequence[SkillMiddleware] | None = None, + message_parser: MessageParser | None = None, advertise_all: bool = False, ) -> Any: """Create an A2A Starlette application from an ADCP handler. @@ -570,6 +603,15 @@ def create_a2a_server( :data:`~adcp.server.SkillMiddleware` for the signature, composition semantics, and the exception-capture pattern audit hooks need. + message_parser: Optional :data:`MessageParser` for alternative + wire shapes. The default parser handles ``DataPart(data={ + "skill": ..., "parameters": ...})`` plus a TextPart JSON + fallback. Supply this to accept JSON-RPC 2.0 message bodies, + vendor-specific DataPart schemas, or other layouts.
The + callable returns ``(skill_name, params)`` or ``(None, {})`` + for "no parseable skill"; see :data:`MessageParser` and + :meth:`ADCPAgentExecutor._default_parse_request` for the + built-in fallback shape to delegate to for legacy clients. advertise_all: When True, advertise every tool the handler type supports — including ones whose method is still the SDK's ``not_supported`` default. Defaults to ``False``, which @@ -590,6 +632,7 @@ def create_a2a_server( test_controller=test_controller, context_factory=context_factory, middleware=middleware, + message_parser=message_parser, advertise_all=advertise_all, ) @@ -621,6 +664,19 @@ def create_a2a_server( http_handler=request_handler, ) + # Startup log lives on the create_a2a_server path (symmetric with + # MCP's _register_handler_tools). Moved out of + # ADCPAgentExecutor.__init__ so per-test executor constructions + # don't pollute caplog with repeated startup messages. + from adcp.server.serve import _log_advertised_tools + + _log_advertised_tools( + transport="a2a", + handler=handler, + advertise_all=advertise_all, + registered=list(executor.supported_skills), + ) + return a2a_app.build() diff --git a/src/adcp/server/auth.py b/src/adcp/server/auth.py new file mode 100644 index 00000000..ad577df6 --- /dev/null +++ b/src/adcp/server/auth.py @@ -0,0 +1,395 @@ +"""Bearer-token HTTP authentication middleware for ADCP MCP servers. + +`examples/mcp_with_auth_middleware.py` is the full, load-bearing +recipe for multi-tenant sellers. Four things have to be right at the +same time — a ContextVar carrier for the authenticated principal, +constant-time token compare, the AdCP/MCP discovery-method bypass, and +reset-in-finally to prevent cross-request leak. Getting any of them +wrong is a security incident. This module factors that recipe into a +middleware class + matching ``context_factory`` so sellers write four +lines of wiring instead of four pages of auth code. 
+ +Typical usage:: + + from adcp.server import create_mcp_server + from adcp.server.auth import ( + BearerTokenAuthMiddleware, + Principal, + auth_context_factory, + ) + + async def validate_token(token: str) -> Principal | None: + row = await db.fetch_token(token) + if row is None or row.revoked: + return None + return Principal( + caller_identity=row.principal_id, + tenant_id=row.tenant_id, + ) + + mcp = create_mcp_server(MyAgent(), context_factory=auth_context_factory) + app = mcp.streamable_http_app() + app.add_middleware(BearerTokenAuthMiddleware, validate_token=validate_token) + +The middleware populates module-level ``ContextVar``s that +``auth_context_factory`` reads to build a +:class:`~adcp.server.ToolContext` per call. The same module-level +vars compose with any other auth layer a seller writes on top — e.g., +an additional role-check middleware that reads +:data:`current_principal`. + +Security invariants this module upholds: + +* :func:`constant_time_token_match` compares tokens with + :func:`hmac.compare_digest` over SHA-256 hashes, never raw string + equality — plain ``==`` and dict lookups short-circuit and leak + timing. Validators that roll their own lookup must preserve this. +* ``initialize`` and ``tools/list`` (MCP handshake) plus + ``get_adcp_capabilities`` (AdCP handshake) are exempt per spec; + every other request requires a valid bearer token. +* ``ContextVar``s are reset in ``finally`` so a later task sharing the + context can't read a stale principal. +* The JSON-RPC body is peeked but not consumed — downstream handlers + still read the same bytes (Starlette caches the body via the + ``_body`` attribute on the request). + +What this middleware does NOT do: + +* **Token storage.** You supply ``validate_token``; where tokens live + (Postgres, Redis, Vault, an IdP) is yours to design. +* **Authorization.** The middleware answers "who is this?", not "can + they do X?". Authorization checks run on the authenticated principal + inside your handlers or as :data:`~adcp.server.SkillMiddleware`.
+* **A2A auth.** A2A uses a different transport; wire a2a-sdk's + ``ServerCallContext.user`` via a2a-sdk auth middleware on that side. + The ``Principal`` / ``ToolContext`` shape is the same, so handlers + work unchanged across transports. +""" + +from __future__ import annotations + +import hashlib +import hmac +import inspect +import json +import logging +from collections.abc import Awaitable, Callable +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse + +from adcp.server.base import ToolContext +from adcp.server.mcp_tools import DISCOVERY_METHODS, DISCOVERY_TOOLS + +logger = logging.getLogger("adcp.server.auth") + + +def _parse_bearer_header(header: str) -> str | None: + """Parse ``Authorization: Bearer <token>`` per RFC 7235. + + Scheme comparison is case-insensitive and tolerates folded + whitespace (any run of spaces, tabs, or newlines) between the + scheme and the token — some clients send ``bearer`` (lowercase), + ``Bearer\\t<token>``, or ``Bearer  <token>`` (double space). Returns + ``None`` when the scheme doesn't match or the token is empty / + whitespace-only. + """ + parts = header.split(maxsplit=1) + if len(parts) != 2: + return None + scheme, token = parts + if scheme.lower() != "bearer": + return None + return token.strip() or None + + +if TYPE_CHECKING: + from starlette.requests import Request + + from adcp.server.serve import RequestMetadata + + +@dataclass(frozen=True) +class Principal: + """An authenticated principal — the result of token validation. + + Returned by a :data:`TokenValidator` on success. Used to populate + the transport-layer ``ContextVar``s that :func:`auth_context_factory` + reads when building per-call :class:`~adcp.server.ToolContext`. + + :param caller_identity: Stable, globally-unique principal id within + the tenant.
See the + :class:`~adcp.server.ToolContext.caller_identity` docstring for + the stability contract and the failure mode when this is + reused across logical principals. + :param tenant_id: Tenant the principal belongs to. Populate unless + your principal ids are globally unique across tenants — the + server-side idempotency store scopes cache keys on + ``(tenant_id, caller_identity)``. See + :doc:`/multi-tenant-contract` for the full invariants. + :param metadata: Optional extra fields the context_factory should + propagate into :class:`~adcp.server.ToolContext.metadata`. + """ + + caller_identity: str + tenant_id: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +TokenValidator = Callable[[str], "Principal | None | Awaitable[Principal | None]"] +"""Seller-supplied callable that validates a bearer token. + +Called with the raw token string (``Authorization: Bearer <token>`` +with the prefix already stripped). Return a :class:`Principal` on +success, ``None`` to reject. + +Sync and async callables are both accepted — the middleware awaits the +result when it's awaitable, so plain ``def validate_token(...)`` and +``async def validate_token(...)`` both work. + +**Do not raise on invalid tokens.** The middleware fails closed — a +validator exception is caught, logged, and answered with the same +``401`` as a ``None`` return — but raising for the routine +invalid-token case drowns real validator bugs in expected auth noise. +Return ``None`` instead. +""" + + +# Module-level ``ContextVar``s populated by the middleware, read by the +# matching ``context_factory``. Exported so sellers can read them from +# their own composed middleware layers (rate-limiter keyed by +# principal, per-tenant feature flags, etc.) without re-authenticating. +# +# Named ``current_*`` to match the FastAPI / Starlette convention for +# per-request state carriers. Keep ``default=None`` so a pre-auth or +# discovery-exempt request reads ``None`` instead of raising +# ``LookupError``.
+current_principal: ContextVar[str | None] = ContextVar("adcp_auth_principal", default=None) +current_tenant: ContextVar[str | None] = ContextVar("adcp_auth_tenant", default=None) +current_principal_metadata: ContextVar[dict[str, Any] | None] = ContextVar( + "adcp_auth_principal_metadata", default=None +) + + +class BearerTokenAuthMiddleware(BaseHTTPMiddleware): + """Starlette HTTP middleware that gates every non-discovery JSON-RPC + request on a valid bearer token. + + Instantiate via ``app.add_middleware`` with a seller-supplied + :data:`TokenValidator`:: + + app.add_middleware( + BearerTokenAuthMiddleware, + validate_token=my_validate_token, + ) + + On success, populates :data:`current_principal`, + :data:`current_tenant`, and :data:`current_principal_metadata` + for the duration of the downstream call. On failure, returns + ``401`` without invoking the handler. + + **Discovery bypass.** ``initialize``, ``notifications/initialized``, + and ``tools/list`` (MCP handshake) plus ``get_adcp_capabilities`` + (AdCP handshake) are always exempt — these run before any client + has credentials. Operators who consider their tool surface + sensitive can subclass and override :meth:`is_discovery_request` + to tighten the bypass (e.g. require auth on ``tools/list``). + + **Body is peeked, not consumed.** The middleware reads the + JSON-RPC payload to identify the ``method`` / ``tool`` name for + the discovery gate; Starlette caches the body on the request so + handlers still read it normally. + + :param app: The inner ASGI app. Passed by Starlette — + ``app.add_middleware`` supplies it automatically. + :param validate_token: Your token lookup. See :data:`TokenValidator`. + :param unauthenticated_response: Optional override for the 401 + response body. Default is ``{"error": "unauthenticated"}``. 
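+
+    A tightening sketch (hypothetical subclass, not shipped by the
+    SDK)::
+
+        class StrictAuth(BearerTokenAuthMiddleware):
+            def is_discovery_request(self, method, tool):
+                # Require credentials even for tool listing.
+                if method == "tools/list":
+                    return False
+                return super().is_discovery_request(method, tool)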
+ """ + + def __init__( + self, + app: Any, + *, + validate_token: TokenValidator, + unauthenticated_response: dict[str, Any] | None = None, + ) -> None: + super().__init__(app) + self._validate_token = validate_token + self._unauth_body = unauthenticated_response or {"error": "unauthenticated"} + + async def dispatch(self, request: Request, call_next: Any) -> Any: + method, tool = await self._peek_jsonrpc(request) + + principal_token = None + tenant_token = None + metadata_token = None + try: + if self.is_discovery_request(method, tool): + principal_token = current_principal.set(None) + tenant_token = current_tenant.set(None) + metadata_token = current_principal_metadata.set(None) + return await call_next(request) + + bearer = _parse_bearer_header(request.headers.get("authorization", "")) + if not bearer: + return self._unauthenticated() + + try: + raw = self._validate_token(bearer) + principal: Principal | None + if inspect.isawaitable(raw): + principal = await raw + else: + principal = raw + except Exception: + # Validator failure must not leak stack info to the caller. + # Fail closed — a buggy validator is an auth failure, not a + # 500. Logged for operators. + logger.exception("token validator raised") + return self._unauthenticated() + + if principal is None: + return self._unauthenticated() + + principal_token = current_principal.set(principal.caller_identity) + tenant_token = current_tenant.set(principal.tenant_id) + metadata_token = current_principal_metadata.set( + dict(principal.metadata) if principal.metadata else None + ) + return await call_next(request) + finally: + # Reset unconditionally so a later task sharing this context + # doesn't read a stale principal. Matches the idempotency + # store's "fail fast on missing caller_identity" contract. 
+            if principal_token is not None:
+                current_principal.reset(principal_token)
+            if tenant_token is not None:
+                current_tenant.reset(tenant_token)
+            if metadata_token is not None:
+                current_principal_metadata.reset(metadata_token)
+
+    def is_discovery_request(self, method: str | None, tool: str | None) -> bool:
+        """True when the request should bypass auth.
+
+        Defaults to the spec-mandated discovery set. Subclass and
+        override to tighten (e.g. require auth on ``tools/list``) or
+        loosen (e.g. add a seller-specific unauthenticated ping
+        method).
+        """
+        if method in DISCOVERY_METHODS:
+            return True
+        return method == "tools/call" and tool in DISCOVERY_TOOLS
+
+    def _unauthenticated(self) -> JSONResponse:
+        return JSONResponse(self._unauth_body, status_code=401)
+
+    @staticmethod
+    async def _peek_jsonrpc(request: Request) -> tuple[str | None, str | None]:
+        """Inspect the JSON-RPC body without preventing handlers from
+        reading it downstream. Returns ``(method, tool_name)``.
+
+        Explicitly caches the body on ``request._body`` so downstream
+        handlers receive the same bytes. Starlette's ``Request`` caches
+        the first ``.body()`` call via this attribute, but relying on
+        that behavior implicitly is fragile — nested ASGI apps that
+        read the raw ``receive`` callable (as FastMCP's streamable-HTTP
+        transport does) will otherwise observe an empty body. The
+        explicit assignment matches the documented Starlette middleware
+        body-peek pattern.
+
+        Fails closed on batch arrays — the JSON-RPC 2.0 spec allows
+        them, but the handshake methods never come in batches and
+        permitting them here would let a client smuggle a mutation past
+        the discovery gate inside a batch.
+        """
+        body = await request.body()
+        # Ensure the body is cached for downstream reads. ``request.body()``
+        # already sets ``_body``; the explicit re-assignment is a belt-and-
+        # suspenders guard against Starlette internals changing and a
+        # pinned target for the body-round-trip test. 
+        request._body = body
+        if not body:
+            return None, None
+        try:
+            payload = json.loads(body)
+        except ValueError:
+            return None, None
+        if not isinstance(payload, dict):
+            return None, None
+        method = payload.get("method")
+        method = method if isinstance(method, str) else None
+        if method != "tools/call":
+            return method, None
+        params = payload.get("params") or {}
+        name = params.get("name") if isinstance(params, dict) else None
+        return method, (name if isinstance(name, str) else None)
+
+
+# ------------------------------------------------------------------
+# Matching context_factory — reads what the middleware populated.
+# ------------------------------------------------------------------
+
+
+def auth_context_factory(meta: RequestMetadata) -> ToolContext:
+    """Build a :class:`~adcp.server.ToolContext` from the ContextVars
+    :class:`BearerTokenAuthMiddleware` populates.
+
+    Pass this to :func:`~adcp.server.create_mcp_server` (or
+    :func:`~adcp.server.serve`) alongside the middleware so handlers
+    receive a typed context carrying the authenticated principal.
+
+    Populates ``caller_identity``, ``tenant_id``, and a ``metadata``
+    dict containing the transport + tool name plus anything the
+    :class:`Principal` provided. SDK-owned keys (``tool_name``,
+    ``transport``) take precedence over principal-supplied keys, so a
+    validator returning ``Principal(metadata={"tool_name": "x"})``
+    cannot shadow audit fields the SDK populates. Returns a bare
+    :class:`ToolContext` — agents that want a typed subclass
+    (e.g. :class:`~adcp.server.AccountAwareToolContext`) should copy
+    this factory's short body and return their own subclass instead. 
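+
+    A sketch of the subclass variant (assumes the subclass accepts the
+    same constructor fields as ``ToolContext``)::
+
+        def account_context_factory(meta: RequestMetadata) -> ToolContext:
+            return AccountAwareToolContext(
+                request_id=meta.request_id,
+                caller_identity=current_principal.get(),
+                tenant_id=current_tenant.get(),
+                metadata={"tool_name": meta.tool_name, "transport": meta.transport},
+            )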
+    """
+    principal_metadata = current_principal_metadata.get() or {}
+    combined_metadata: dict[str, Any] = {
+        **principal_metadata,
+        "tool_name": meta.tool_name,
+        "transport": meta.transport,
+    }
+    return ToolContext(
+        request_id=meta.request_id,
+        caller_identity=current_principal.get(),
+        tenant_id=current_tenant.get(),
+        metadata=combined_metadata,
+    )
+
+
+# ------------------------------------------------------------------
+# Helpers sellers sometimes need when building their own validator.
+# ------------------------------------------------------------------
+
+
+def constant_time_token_match(token: str, stored_hashes: dict[str, Any]) -> Any:
+    """Look up a token in a dict of SHA-256 hashes using
+    :func:`hmac.compare_digest` rather than dict-containment.
+
+    Dict lookup + equality (``candidate_hash in stored_hashes``) leaks
+    prefix-match timing because the hash comparison short-circuits on
+    first byte mismatch. Iterating every stored hash with
+    ``compare_digest`` makes the wall-clock runtime independent of
+    how much of the candidate matches any entry.
+
+    Use this when your token store is small enough to iterate linearly
+    (hundreds to low-thousands). For larger stores, use a database
+    column of hashed tokens with an equality index + one
+    ``compare_digest`` check on the single returned row.
+
+    :param token: Raw bearer token supplied by the client.
+    :param stored_hashes: ``{sha256_hex: value}`` dictionary.
+    :returns: The ``value`` of the matching entry, or ``None`` on no
+        match. 
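+
+    Usage sketch (``TOKENS`` is a hypothetical in-memory store)::
+
+        TOKENS = {hashlib.sha256(b"s3cret").hexdigest(): "alice"}
+        principal_id = constant_time_token_match(supplied_token, TOKENS)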
+ """ + if not token: + return None + candidate = hashlib.sha256(token.encode()).hexdigest() + for stored_hash, value in stored_hashes.items(): + if hmac.compare_digest(candidate, stored_hash): + return value + return None diff --git a/src/adcp/server/serve.py b/src/adcp/server/serve.py index a624abac..eb8ae361 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -18,11 +18,14 @@ async def get_adcp_capabilities(self, params, context=None): from __future__ import annotations +import logging import os from collections.abc import Awaitable, Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal +logger = logging.getLogger("adcp.server") + from adcp.server.base import ADCPHandler, ToolContext from adcp.server.mcp_tools import create_tool_caller, get_tools_for_handler @@ -34,6 +37,7 @@ async def get_adcp_capabilities(self, params, context=None): ) from a2a.server.tasks.task_store import TaskStore + from adcp.server.a2a_server import MessageParser from adcp.server.test_controller import TestControllerStore @@ -67,8 +71,11 @@ class RequestMetadata: [str, dict[str, Any], ToolContext, Callable[[], Awaitable[Any]]], Awaitable[Any], ] -"""Middleware that wraps A2A skill dispatch — the audit / activity-feed / -rate-limiter / tracing hook for the A2A transport. +"""Middleware that wraps skill dispatch on both the MCP and A2A +transports — the audit / activity-feed / rate-limiter / tracing hook. +Composition semantics are identical across transports (shared +composer); middleware written against one transport works unchanged +on the other. 
 Signature (conceptually a Protocol; declared as a ``Callable`` alias
 so it's importable and consistent with ``ContextFactory``)::
@@ -172,9 +179,87 @@ async def audit_middleware(
         return result
 
     create_a2a_server(MyAgent(), middleware=[audit_middleware])
+
+The same middleware list also composes on the MCP side — pass it to
+``create_mcp_server(middleware=...)`` or the transport-agnostic
+``serve(middleware=...)``.
 """
+
+
+def _log_advertised_tools(
+    *,
+    transport: Literal["mcp", "a2a"],
+    handler: ADCPHandler[Any],
+    advertise_all: bool,
+    registered: list[str],
+) -> None:
+    """Log which tools the server just advertised, plus the delta vs the
+    full spec surface the handler class could have supported.
+
+    Operators occasionally rename a handler method and silently drop it
+    from ``tools/list`` — discovering that during incident review is
+    the wrong time. Emitting the advertised set and the unadvertised
+    delta at startup turns a silent gap into a searchable log line.
+
+    The advertised count is logged at ``INFO`` because operators
+    routinely tail it; the unadvertised delta at ``DEBUG`` because it
+    runs long on sparsely-implemented handlers.
+    """
+    registered_set = set(registered)
+    full_defs = get_tools_for_handler(handler, advertise_all=True)
+    full_names = {t["name"] for t in full_defs}
+    unadvertised = sorted(full_names - registered_set)
+
+    logger.info(
+        "%s server advertising %d of %d tools%s",
+        transport,
+        len(registered_set),
+        len(full_names),
+        " (advertise_all=True)" if advertise_all else "",
+    )
+    if unadvertised and not advertise_all:
+        logger.debug("%s server unadvertised tools: %s", transport, ", ".join(unadvertised))
+
+
+async def _dispatch_with_middleware(
+    middleware: tuple[SkillMiddleware, ...] | Sequence[SkillMiddleware],
+    skill_name: str,
+    params: dict[str, Any],
+    context: ToolContext,
+    call_handler: Callable[[], Awaitable[Any]],
+) -> Any:
+    """Run ``call_handler`` wrapped in the supplied middleware chain. 
+ + Shared by the MCP and A2A dispatch paths so composition semantics + stay identical across transports — middleware porting between + ``create_mcp_server(middleware=...)`` and + ``create_a2a_server(middleware=...)`` needs zero changes. + + Outermost-first composition: the first entry in ``middleware`` sees + every call *before* later entries and *before* the handler. No + mutable indices, no loop-variable captures — a small recursive + dispatcher reads the same with zero or ten middlewares. + + Middleware exceptions propagate to the caller unchanged; this + function does no try/except so short-circuiting, transform, and + exception-observation behaviors are owned by the transport-level + executor, not the composer. + """ + if not middleware: + return await call_handler() + + async def _step(index: int) -> Any: + if index >= len(middleware): + return await call_handler() + mw = middleware[index] + + async def call_next() -> Any: + return await _step(index + 1) + + return await mw(skill_name, params, context, call_next) + + return await _step(0) + + ContextFactory = Callable[[RequestMetadata], ToolContext] """Factory invoked per tool call to build a :class:`ToolContext`. @@ -227,6 +312,7 @@ def serve( task_store: TaskStore | None = None, push_config_store: PushNotificationConfigStore | None = None, middleware: Sequence[SkillMiddleware] | None = None, + message_parser: MessageParser | None = None, advertise_all: bool = False, max_request_size: int | None = None, ) -> None: @@ -258,11 +344,19 @@ def serve( subscriptions at all. See ``examples/a2a_db_tasks.py`` for a durable reference implementation. middleware: Optional sequence of :data:`SkillMiddleware` callables - wrapping every A2A skill dispatch (A2A transport only). Use - for audit logging, activity-feed hooks, rate limiting, - tracing. Composes outermost-first. See + wrapping every skill dispatch on both the MCP and A2A + transports. Use for audit logging, activity-feed hooks, + rate limiting, tracing. 
Composes outermost-first. See :data:`SkillMiddleware` for the signature and composition semantics. + message_parser: Optional + :data:`~adcp.server.a2a_server.MessageParser` callable for + alternative A2A wire shapes (A2A transport only). The + default parser handles ``DataPart(data={"skill": ..., + "parameters": ...})`` plus a TextPart JSON fallback; supply + this hook to accept JSON-RPC 2.0 message bodies or vendor- + specific DataPart schemas. MCP does not use this kwarg + (FastMCP owns the wire shape). advertise_all: When True, advertise every tool the handler type supports even if the subclass didn't override the method. Defaults to ``False`` — ``tools/list`` only shows tools the @@ -325,6 +419,7 @@ async def force_account_status(self, account_id, status): task_store=task_store, push_config_store=push_config_store, middleware=middleware, + message_parser=message_parser, advertise_all=advertise_all, max_request_size=max_request_size, ) @@ -337,6 +432,7 @@ async def force_account_status(self, account_id, status): instructions=instructions, test_controller=test_controller, context_factory=context_factory, + middleware=middleware, advertise_all=advertise_all, max_request_size=max_request_size, ) @@ -431,6 +527,7 @@ def _serve_mcp( instructions: str | None, test_controller: TestControllerStore | None, context_factory: ContextFactory | None = None, + middleware: Sequence[SkillMiddleware] | None = None, advertise_all: bool = False, max_request_size: int | None = None, ) -> None: @@ -442,6 +539,7 @@ def _serve_mcp( instructions=instructions, include_test_controller=test_controller is not None, context_factory=context_factory, + middleware=middleware, advertise_all=advertise_all, ) @@ -503,6 +601,7 @@ def _serve_a2a( task_store: TaskStore | None = None, push_config_store: PushNotificationConfigStore | None = None, middleware: Sequence[SkillMiddleware] | None = None, + message_parser: MessageParser | None = None, advertise_all: bool = False, max_request_size: int | None 
= None, ) -> None: @@ -522,6 +621,7 @@ def _serve_a2a( task_store=task_store, push_config_store=push_config_store, middleware=middleware, + message_parser=message_parser, advertise_all=advertise_all, ) app = _wrap_with_size_limit(app, max_request_size) @@ -547,6 +647,7 @@ def create_mcp_server( instructions: str | None = None, include_test_controller: bool = False, context_factory: ContextFactory | None = None, + middleware: Sequence[SkillMiddleware] | None = None, advertise_all: bool = False, ) -> Any: """Create a FastMCP server from an ADCP handler without starting it. @@ -577,6 +678,12 @@ def create_mcp_server( :data:`ContextFactory` for the recommended contextvars pattern. When ``None``, handlers receive a bare ``ToolContext()`` (no caller identity, no tenant). + middleware: Optional sequence of :data:`SkillMiddleware` callables + wrapping every tool dispatch. Symmetric with A2A's + ``create_a2a_server(middleware=...)`` — the same list works + on both transports. Use for audit logging, rate limiting, + tracing, activity-feed hooks. See :data:`SkillMiddleware` + for signature and composition semantics. advertise_all: When True, advertise every tool the handler type supports — even those whose method is still the SDK's ``not_supported`` default. Defaults to ``False``, which @@ -643,6 +750,7 @@ def create_mcp_server( handler, include_test_controller=include_test_controller, context_factory=context_factory, + middleware=middleware, advertise_all=advertise_all, ) return mcp @@ -654,10 +762,17 @@ def _register_handler_tools( *, include_test_controller: bool = False, context_factory: ContextFactory | None = None, + middleware: Sequence[SkillMiddleware] | None = None, advertise_all: bool = False, ) -> None: """Register all ADCP tools from a handler onto a FastMCP server.""" + # Freeze middleware ordering at registration time. Tuple both guards + # against a mutable list being reshuffled mid-request and matches the + # A2A executor's handling. 
+ middleware_tuple: tuple[SkillMiddleware, ...] = tuple(middleware or ()) + tool_defs = get_tools_for_handler(handler, advertise_all=advertise_all) + registered: list[str] = [] for tool_def in tool_defs: tool_name = tool_def["name"] # Gate comply_test_controller on explicit opt-in. The handler base @@ -675,7 +790,16 @@ def _register_handler_tools( input_schema, caller, context_factory=context_factory, + middleware=middleware_tuple, ) + registered.append(tool_name) + + _log_advertised_tools( + transport="mcp", + handler=handler, + advertise_all=advertise_all, + registered=registered, + ) def _register_tool( @@ -686,6 +810,7 @@ def _register_tool( caller: Callable[..., Any], *, context_factory: ContextFactory | None = None, + middleware: tuple[SkillMiddleware, ...] = (), ) -> None: """Register a single ADCP tool on a FastMCP server. @@ -722,8 +847,23 @@ async def fn(**kwargs: Any) -> dict[str, Any]: f"context_factory for tool {name!r} returned " f"{type(context).__name__}, not a ToolContext instance" ) + + async def _call_handler() -> Any: + return await caller(kwargs, context=context) + try: - result = await caller(kwargs, context=context) + if middleware: + # Middleware requires a concrete ToolContext to match the + # declared SkillMiddleware signature; synthesise an empty + # one when no factory is configured so the chain still + # runs. Handler itself keeps receiving ``None`` semantics + # via ``context`` closed over by _call_handler. + mw_context = context if context is not None else ToolContext() + result = await _dispatch_with_middleware( + middleware, name, kwargs, mw_context, _call_handler + ) + else: + result = await _call_handler() except ADCPError as exc: # Translate AdCP-typed exceptions (IdempotencyConflictError, # ADCPTaskError with a spec code, etc.) 
into a ToolError so FastMCP diff --git a/tests/test_a2a_server.py b/tests/test_a2a_server.py index 2bb81e27..59f60e4a 100644 --- a/tests/test_a2a_server.py +++ b/tests/test_a2a_server.py @@ -1123,3 +1123,159 @@ async def enriching_middleware(skill_name, params, context, call_next): assert result["middleware_marker"] == "wrapped" # And the handler's original payload is still there. assert result["products"][0]["id"] == "p1" + + +# -------------------------------------------------------------------- +# Custom message_parser hook (alternative A2A wire formats) +# -------------------------------------------------------------------- + + +async def test_custom_message_parser_receives_request_context(): + """A custom parser is called with the RequestContext and owns the + (skill_name, params) extraction — enabling JSON-RPC, bare-text, or + vendor-specific DataPart layouts without subclassing the executor.""" + + class _ParserHandler(ADCPHandler): + async def get_products(self, params, context=None): + return {"products": [{"id": params.get("id", "?")}]} + + received: list[Any] = [] + + def my_parser(ctx: RequestContext) -> tuple[str | None, dict[str, Any]]: + received.append(ctx) + # Pretend the client sends ``{"operation": "get_products", "body": {...}}``. 
+ msg = ctx.message + assert msg is not None + for part in msg.parts: + inner = part.root if hasattr(part, "root") else part + if isinstance(inner, DataPart) and isinstance(inner.data, dict): + op = inner.data.get("operation") + body = inner.data.get("body") or {} + if op: + return str(op), body if isinstance(body, dict) else {} + return None, {} + + executor = ADCPAgentExecutor(_ParserHandler(), message_parser=my_parser) + msg = Message( + message_id="m-custom", + role=Role.user, + parts=[Part(root=DataPart(data={"operation": "get_products", "body": {"id": "p42"}}))], + ) + ctx = RequestContext(request=MessageSendParams(message=msg)) + queue = EventQueue() + await executor.execute(ctx, queue) + + assert len(received) == 1 + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "completed" + + +async def test_custom_parser_returning_none_yields_error_task(): + """A parser that returns (None, {}) must surface as an error Task + the same way an unparseable default message does.""" + + def bad_parser(ctx: RequestContext) -> tuple[str | None, dict[str, Any]]: + return None, {} + + class _Handler(ADCPHandler): + async def get_products(self, params, context=None): + return {"products": []} + + executor = ADCPAgentExecutor(_Handler(), message_parser=bad_parser) + msg = Message( + message_id="m-none", + role=Role.user, + parts=[Part(root=DataPart(data={"skill": "get_products", "parameters": {}}))], + ) + ctx = RequestContext(request=MessageSendParams(message=msg)) + queue = EventQueue() + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "failed" + + +async def test_default_parser_runs_when_no_message_parser_configured(): + """No ``message_parser=`` → the built-in ``_default_parse_request`` + runs. 
Pins backwards-compat for sellers who don't opt in.""" + + class _Handler(ADCPHandler): + async def get_products(self, params, context=None): + return {"products": [{"id": "default-path"}]} + + executor = ADCPAgentExecutor(_Handler()) + msg = Message( + message_id="m-default", + role=Role.user, + parts=[Part(root=DataPart(data={"skill": "get_products", "parameters": {}}))], + ) + ctx = RequestContext(request=MessageSendParams(message=msg)) + queue = EventQueue() + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "completed" + + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="a2a-sdk starlette integration requires Python 3.11+", +) +def test_create_a2a_server_threads_message_parser_into_executor(): + """The kwarg propagates from ``create_a2a_server`` → executor.""" + + def my_parser(ctx: RequestContext) -> tuple[str | None, dict[str, Any]]: + return None, {} + + app = create_a2a_server(_TestHandler(), name="parser-test", message_parser=my_parser) + handler = _extract_default_request_handler(app) + executor = handler.agent_executor + assert isinstance(executor, ADCPAgentExecutor) + assert executor._message_parser is my_parser + + +async def test_custom_parser_can_compose_with_default(): + """Typical pattern: seller's parser tries a custom shape first, + then falls through to the default parser for legacy clients.""" + + class _Handler(ADCPHandler): + async def get_products(self, params, context=None): + return {"products": [{"from_params": params.get("source", "unknown")}]} + + executor = ADCPAgentExecutor(_Handler()) + + def composed(ctx: RequestContext) -> tuple[str | None, dict[str, Any]]: + # Seller's custom shape: DataPart({"operation": ..., "body": ...}) + msg = ctx.message + if msg is not None: + for part in msg.parts: + inner = part.root if hasattr(part, "root") else part + if ( + isinstance(inner, DataPart) + and isinstance(inner.data, dict) + 
and "operation" in inner.data + ): + return str(inner.data["operation"]), { + "source": "custom", + **(inner.data.get("body") or {}), + } + # Fall through to the default for legacy clients. + return executor._default_parse_request(ctx) + + executor2 = ADCPAgentExecutor(_Handler(), message_parser=composed) + + # Legacy shape → default parser catches it. + legacy_msg = Message( + message_id="m-legacy", + role=Role.user, + parts=[Part(root=DataPart(data={"skill": "get_products", "parameters": {}}))], + ) + legacy_ctx = RequestContext(request=MessageSendParams(message=legacy_msg)) + queue = EventQueue() + await executor2.execute(legacy_ctx, queue) + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "completed" diff --git a/tests/test_auth_middleware.py b/tests/test_auth_middleware.py new file mode 100644 index 00000000..3f2ee7e3 --- /dev/null +++ b/tests/test_auth_middleware.py @@ -0,0 +1,500 @@ +"""Tests for BearerTokenAuthMiddleware + auth_context_factory. + +The middleware is load-bearing: a subtle bug here is a cross-tenant +confidentiality leak in production. Tests focus on the exact +invariants that matter for correctness — token compare, discovery +bypass, ContextVar reset, principal/tenant population. + +Composition with ``create_mcp_server(context_factory=auth_context_factory)`` +lives in ``test_mcp_middleware_composition.py`` — these tests +exercise the middleware class in isolation. 
+""" + +from __future__ import annotations + +import hashlib +from typing import Any + +import httpx +import pytest +from asgi_lifespan import LifespanManager +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse +from starlette.routing import Route + +from adcp.server import ( + BearerTokenAuthMiddleware, + Principal, + auth_context_factory, + constant_time_token_match, +) +from adcp.server.auth import ( + current_principal, + current_principal_metadata, + current_tenant, +) + +# --------------------------------------------------------------------------- +# Principal + validator plumbing +# --------------------------------------------------------------------------- + + +def test_principal_is_immutable() -> None: + """Principal is frozen so a middleware can't mutate it after the + validator returns — any re-scope must build a fresh Principal.""" + p = Principal(caller_identity="alice", tenant_id="t1") + with pytest.raises(AttributeError): + p.caller_identity = "bob" # type: ignore[misc] + + +def test_constant_time_token_match_returns_value() -> None: + stored = {hashlib.sha256(b"good").hexdigest(): "payload"} + assert constant_time_token_match("good", stored) == "payload" + + +def test_constant_time_token_match_returns_none_on_miss() -> None: + stored = {hashlib.sha256(b"good").hexdigest(): "payload"} + assert constant_time_token_match("wrong", stored) is None + + +def test_constant_time_token_match_empty_token() -> None: + stored = {hashlib.sha256(b"good").hexdigest(): "payload"} + assert constant_time_token_match("", stored) is None + + +# --------------------------------------------------------------------------- +# Middleware-in-isolation tests via a minimal Starlette harness +# --------------------------------------------------------------------------- + + +async def _echo_handler(request: Request) -> JSONResponse: + """Starlette handler that echoes back the per-request ContextVars. 
+
+    The middleware populates these for each successfully-authenticated
+    request; failures short-circuit before the handler runs.
+    """
+    return JSONResponse(
+        {
+            "principal": current_principal.get(),
+            "tenant": current_tenant.get(),
+            "metadata": current_principal_metadata.get(),
+        }
+    )
+
+
+def _build_app(validator: Any, routes: list[Route] | None = None) -> Starlette:
+    app = Starlette(routes=routes or [Route("/", _echo_handler, methods=["POST"])])
+    app.add_middleware(BearerTokenAuthMiddleware, validate_token=validator)
+    return app
+
+
+@pytest.mark.asyncio
+async def test_rejects_missing_bearer() -> None:
+    def validator(token: str) -> Principal | None:
+        return Principal(caller_identity="alice")
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post("/", json={"method": "tools/call"})
+            assert resp.status_code == 401
+
+
+@pytest.mark.asyncio
+async def test_rejects_invalid_bearer() -> None:
+    def validator(token: str) -> Principal | None:
+        return None  # always reject
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post(
+                "/",
+                json={"method": "tools/call", "params": {"name": "get_products"}},
+                headers={"Authorization": "Bearer bad-token"},
+            )
+            assert resp.status_code == 401
+
+
+@pytest.mark.asyncio
+async def test_populates_contextvars_on_valid_token() -> None:
+    expected = Principal(
+        caller_identity="alice",
+        tenant_id="t1",
+        metadata={"role": "admin"},
+    )
+
+    def validator(token: str) -> Principal | None:
+        return expected if token == "good" else None
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post(
+                "/",
+                json={"method": "tools/call", "params": {"name": "get_products"}},
+                headers={"Authorization": "Bearer good"},
+            )
+            assert resp.status_code == 200
+            body = resp.json()
+            assert body["principal"] == "alice"
+            assert body["tenant"] == "t1"
+            assert body["metadata"] == {"role": "admin"}
+
+
+@pytest.mark.asyncio
+async def test_async_validator_is_awaited() -> None:
+    """Validators can be `async def` — the middleware awaits them."""
+
+    async def validator(token: str) -> Principal | None:
+        return Principal(caller_identity="async-alice") if token == "good" else None
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post(
+                "/",
+                json={"method": "tools/call", "params": {"name": "get_products"}},
+                headers={"Authorization": "Bearer good"},
+            )
+            assert resp.status_code == 200
+            assert resp.json()["principal"] == "async-alice"
+
+
+@pytest.mark.asyncio
+async def test_discovery_methods_bypass_auth() -> None:
+    """``initialize`` / ``notifications/initialized`` / ``tools/list``
+    MUST go through without credentials — the MCP handshake has no
+    token yet."""
+    validator_calls: list[str] = []
+
+    def validator(token: str) -> Principal | None:
+        validator_calls.append(token)
+        return None
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            for method in ("initialize", "notifications/initialized", "tools/list"):
+                resp = await client.post("/", json={"method": method})
+                assert resp.status_code == 200, f"{method} should bypass auth"
+
+            # Validator MUST NOT have been called for any discovery method — bypass
+            # is composition-by-identity, not "call validator and ignore result".
+            assert validator_calls == []
+
+
+@pytest.mark.asyncio
+async def test_discovery_tools_bypass_auth() -> None:
+    """``tools/call`` on a DISCOVERY_TOOLS entry (``get_adcp_capabilities``)
+    bypasses auth per AdCP spec — the capability handshake."""
+
+    def validator(token: str) -> Principal | None:
+        return None
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post(
+                "/",
+                json={
+                    "method": "tools/call",
+                    "params": {"name": "get_adcp_capabilities"},
+                },
+            )
+            assert resp.status_code == 200
+
+
+@pytest.mark.asyncio
+async def test_contextvars_reset_after_request() -> None:
+    """The critical security invariant: after the response, the
+    ContextVars MUST be back to None — otherwise a later task sharing
+    the context reads a stale principal."""
+
+    def validator(token: str) -> Principal | None:
+        return Principal(caller_identity="alice", tenant_id="t1")
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post(
+                "/",
+                json={"method": "tools/call", "params": {"name": "get_products"}},
+                headers={"Authorization": "Bearer good"},
+            )
+            assert resp.status_code == 200
+
+    # The test's own context reads None — the middleware reset-in-finally
+    # fired before the test resumed. If this regresses, `.get()` would
+    # return "alice" from a leaked ContextVar.
+    assert current_principal.get() is None
+    assert current_tenant.get() is None
+    assert current_principal_metadata.get() is None
+
+
+@pytest.mark.asyncio
+async def test_batch_jsonrpc_fails_closed() -> None:
+    """JSON-RPC 2.0 allows batch arrays, but the discovery bypass must
+    NOT apply to batches — a client could smuggle a mutation past the
+    gate inside a batch. Batch → auth required → 401 without a bearer."""
+
+    def validator(token: str) -> Principal | None:
+        return None
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post(
+                "/",
+                json=[{"method": "tools/list"}, {"method": "tools/call"}],
+            )
+            # Without a bearer header, the batch cannot satisfy the auth gate.
+            assert resp.status_code == 401
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "header",
+    [
+        "Bearer good",  # canonical
+        "bearer good",  # RFC 7235: scheme is case-insensitive
+        "BEARER good",
+        "Bearer  good",  # folded double-space
+        "Bearer\tgood",  # tab-separator accepted
+        "Bearer good\n",  # trailing whitespace tolerated
+    ],
+)
+async def test_accepts_rfc7235_scheme_variants(header: str) -> None:
+    """RFC 7235 says the ``Bearer`` scheme is case-insensitive and
+    whitespace-folded. Clients that send lowercase or tab-separated
+    headers must not get a 401 — that's an interop bug that looks like
+    an auth bug."""
+
+    def validator(token: str) -> Principal | None:
+        return Principal(caller_identity="alice") if token == "good" else None
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post(
+                "/",
+                json={"method": "tools/call", "params": {"name": "get_products"}},
+                headers={"Authorization": header},
+            )
+            assert resp.status_code == 200, f"header {header!r} was rejected"
+
+
+@pytest.mark.asyncio
+async def test_non_bearer_scheme_is_rejected() -> None:
+    """Basic / Digest / other schemes MUST return 401 — the middleware
+    is bearer-only by design."""
+
+    def validator(token: str) -> Principal | None:
+        return Principal(caller_identity="alice")
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            # Placeholder non-bearer header — specific value is irrelevant,
+            # we only check the scheme gate rejects anything that isn't
+            # "Bearer". Kept as obvious placeholder text so secret scanners
+            # don't flag a real-looking base64 payload.
+            resp = await client.post(
+                "/",
+                json={"method": "tools/call", "params": {"name": "get_products"}},
+                headers={"Authorization": "Basic "},
+            )
+            assert resp.status_code == 401
+
+
+@pytest.mark.asyncio
+async def test_validator_exception_returns_401_not_500() -> None:
+    """A buggy validator (DB outage, bug) must fail closed with 401 —
+    a 500 leaks stack traces to the caller and signals the presence of
+    an auth path on the deployment. The docstring contract is "do not
+    raise"; we enforce fail-closed regardless."""
+
+    def validator(token: str) -> Principal | None:
+        raise RuntimeError("db down — leak-prone details here")
+
+    app = _build_app(validator)
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post(
+                "/",
+                json={"method": "tools/call", "params": {"name": "get_products"}},
+                headers={"Authorization": "Bearer token"},
+            )
+            assert resp.status_code == 401
+            # Body must NOT carry the exception text — exceptions go to logs, not clients.
+            assert "db down" not in resp.text
+
+
+@pytest.mark.asyncio
+async def test_principal_metadata_cannot_shadow_sdk_keys() -> None:
+    """A validator returning ``Principal(metadata={"tool_name": "x"})``
+    must NOT shadow the SDK-populated ``tool_name`` in
+    ``ToolContext.metadata``. SDK keys always win — otherwise an
+    attacker-controlled validator could inject arbitrary audit fields."""
+    from adcp.server import RequestMetadata
+
+    principal_token = current_principal.set("alice")
+    metadata_token = current_principal_metadata.set(
+        {"tool_name": "attacker-injected", "transport": "attacker"}
+    )
+    try:
+        meta = RequestMetadata(tool_name="get_products", transport="mcp")
+        ctx = auth_context_factory(meta)
+    finally:
+        current_principal.reset(principal_token)
+        current_principal_metadata.reset(metadata_token)
+
+    # SDK keys win over principal-supplied keys.
+    assert ctx.metadata["tool_name"] == "get_products"
+    assert ctx.metadata["transport"] == "mcp"
+
+
+@pytest.mark.asyncio
+async def test_body_peek_does_not_starve_downstream_handler() -> None:
+    """The middleware peeks the JSON-RPC body to identify the method.
+    Downstream handlers must still read the same bytes — otherwise
+    MCP's streamable-HTTP transport (nested ASGI app that reads from
+    ``receive`` directly) hangs or sees empty payloads.
+
+    This test runs the full request path: middleware peeks, downstream
+    reads ``request.body()``, asserts identical bytes."""
+    from starlette.requests import Request as _Request
+    from starlette.responses import JSONResponse as _JSONResponse
+
+    async def _echo_body(request: _Request) -> _JSONResponse:
+        body = await request.body()
+        return _JSONResponse({"body_len": len(body), "body_text": body.decode()})
+
+    def validator(token: str) -> Principal | None:
+        return Principal(caller_identity="alice")
+
+    app = Starlette(routes=[Route("/", _echo_body, methods=["POST"])])
+    app.add_middleware(BearerTokenAuthMiddleware, validate_token=validator)
+
+    payload = {
+        "method": "tools/call",
+        "params": {"name": "get_products", "arguments": {"brief": "x"}},
+    }
+    import json as _json
+
+    expected = _json.dumps(payload)
+
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp = await client.post(
+                "/",
+                content=expected,
+                headers={
+                    "Authorization": "Bearer good",
+                    "Content-Type": "application/json",
+                },
+            )
+            assert resp.status_code == 200
+            body = resp.json()
+            assert body["body_len"] == len(expected)
+            assert body["body_text"] == expected
+
+
+@pytest.mark.asyncio
+async def test_subclass_can_tighten_discovery_bypass() -> None:
+    """Operators tightening ``tools/list`` behind auth override
+    ``is_discovery_request``. Confirm the hook fires."""
+
+    class StricterMiddleware(BearerTokenAuthMiddleware):
+        def is_discovery_request(self, method: str | None, tool: str | None) -> bool:
+            # Only MCP initialize is bypassed; tools/list requires auth.
+            return method == "initialize"
+
+    def validator(token: str) -> Principal | None:
+        return None
+
+    app = Starlette(routes=[Route("/", _echo_handler, methods=["POST"])])
+    app.add_middleware(StricterMiddleware, validate_token=validator)
+
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app), base_url="http://test"
+        ) as client:
+            resp_init = await client.post("/", json={"method": "initialize"})
+            resp_list = await client.post("/", json={"method": "tools/list"})
+            assert resp_init.status_code == 200
+            assert resp_list.status_code == 401
+
+
+# ---------------------------------------------------------------------------
+# Composition: auth_context_factory reads the middleware's ContextVars
+# ---------------------------------------------------------------------------
+
+
+def test_auth_context_factory_reads_contextvars() -> None:
+    """The factory builds a ToolContext from current_principal /
+    current_tenant / current_principal_metadata. No middleware runs
+    here — set the vars directly and call the factory."""
+    from adcp.server import RequestMetadata
+
+    principal_token = current_principal.set("alice")
+    tenant_token = current_tenant.set("t1")
+    metadata_token = current_principal_metadata.set({"role": "admin"})
+    try:
+        meta = RequestMetadata(tool_name="get_products", transport="mcp")
+        ctx = auth_context_factory(meta)
+    finally:
+        current_principal.reset(principal_token)
+        current_tenant.reset(tenant_token)
+        current_principal_metadata.reset(metadata_token)
+
+    assert ctx.caller_identity == "alice"
+    assert ctx.tenant_id == "t1"
+    assert ctx.metadata["role"] == "admin"
+    assert ctx.metadata["tool_name"] == "get_products"
+    assert ctx.metadata["transport"] == "mcp"
+
+
+def test_auth_context_factory_with_no_principal() -> None:
+    """Discovery requests populate the ContextVars to None; the factory
+    returns a ToolContext with caller_identity=None (handshake is
+    pre-auth by design)."""
+    from adcp.server import RequestMetadata
+
+    meta = RequestMetadata(tool_name="get_adcp_capabilities", transport="mcp")
+    ctx = auth_context_factory(meta)
+
+    assert ctx.caller_identity is None
+    assert ctx.tenant_id is None
+
+
+# Full-stack composition (middleware + create_mcp_server + handler) is
+# covered by ``test_mcp_middleware_composition.py`` — that harness
+# already boots the FastMCP initialize/tools-call flow end-to-end. The
+# tests in this file stay focused on the middleware class itself so
+# failures localise to the auth logic, not the transport plumbing.
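The reset-in-finally invariant that `test_contextvars_reset_after_request` pins down can be sketched in isolation. This is a minimal generic-Python illustration of the pattern, not the SDK's actual middleware code; `request_principal` and `run_with_principal` are hypothetical names:

```python
import contextvars

# Hypothetical stand-in for the SDK's per-request principal carrier.
request_principal: contextvars.ContextVar = contextvars.ContextVar(
    "request_principal", default=None
)


def run_with_principal(principal, handler):
    # set() returns a Token; reset() in a finally guarantees no stale
    # principal leaks into a later task sharing this context, even if
    # the handler raises.
    token = request_principal.set(principal)
    try:
        return handler()
    finally:
        request_principal.reset(token)


seen = run_with_principal("alice", request_principal.get)
assert seen == "alice"  # the handler saw the principal...
assert request_principal.get() is None  # ...and the var is reset afterwards
```

Putting the reset in `finally` rather than after the handler call is the whole point: an exception path that skips the reset is exactly the stale-principal leak the test above guards against.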
diff --git a/tests/test_mcp_middleware_composition.py b/tests/test_mcp_middleware_composition.py
index 9e6c6713..fc6ce329 100644
--- a/tests/test_mcp_middleware_composition.py
+++ b/tests/test_mcp_middleware_composition.py
@@ -399,3 +399,158 @@ def _parse_event_stream(body: str) -> dict[str, Any]:
         if line.startswith("data: "):
             return json.loads(line.removeprefix("data: "))
     return json.loads(body) if body.strip() else {}
+
+
+# ----------------------------------------------------------------------
+# MCP middleware parity with A2A — ``create_mcp_server(middleware=[...])``
+# ----------------------------------------------------------------------
+
+
+@pytest.fixture
+async def middleware_events() -> list[str]:
+    return []
+
+
+@pytest.fixture
+async def middleware_handler_and_client(middleware_events: list[str]) -> Any:
+    """Fixture that wires a SkillMiddleware chain onto the MCP server.
+    Mirrors ``handler_and_client`` above but without the HTTP auth
+    layer so the middleware chain is the only thing under test."""
+    handler = _RecordingHandler()
+
+    async def outer(skill_name, params, context, call_next):
+        middleware_events.append(f"outer-pre:{skill_name}")
+        result = await call_next()
+        middleware_events.append(f"outer-post:{skill_name}")
+        return result
+
+    async def inner(skill_name, params, context, call_next):
+        middleware_events.append(f"inner-pre:{skill_name}")
+        result = await call_next()
+        middleware_events.append(f"inner-post:{skill_name}")
+        return result
+
+    mcp = create_mcp_server(
+        handler,
+        name="mw-test",
+        context_factory=_build_context,
+        middleware=[outer, inner],
+    )
+    mcp.settings.stateless_http = True
+    mcp.settings.json_response = True
+    mcp.settings.transport_security.allowed_hosts = ["localhost", "127.0.0.1"]
+    app = mcp.streamable_http_app()
+
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app),
+            base_url="http://localhost",
+            follow_redirects=True,
+        ) as client:
+            yield handler, client
+
+
+@pytest.mark.asyncio
+async def test_mcp_middleware_composes_outermost_first(
+    middleware_handler_and_client: Any,
+    middleware_events: list[str],
+) -> None:
+    """MCP ``middleware=[outer, inner]`` matches A2A semantics: outer
+    pre-event comes first, then inner pre-event, then handler, then
+    inner post, then outer post. Stale ordering or reversed composition
+    would regress cross-transport parity."""
+    _, client = middleware_handler_and_client
+
+    await _initialize_session(client)
+    resp = await _call_tool(client, "get_adcp_capabilities", {})
+
+    assert resp.status_code == 200, resp.text
+    assert middleware_events == [
+        "outer-pre:get_adcp_capabilities",
+        "inner-pre:get_adcp_capabilities",
+        "inner-post:get_adcp_capabilities",
+        "outer-post:get_adcp_capabilities",
+    ], middleware_events
+
+
+@pytest.mark.asyncio
+async def test_mcp_middleware_can_short_circuit() -> None:
+    """Middleware that returns without calling ``call_next()`` MUST
+    stop the chain — handler doesn't run. Rate limiters use this."""
+
+    handler_calls: list[str] = []
+
+    class _ShortCircuitTarget(ADCPHandler):
+        async def get_adcp_capabilities(self, params, context=None):
+            handler_calls.append("called")
+            return {"adcp": {"major_versions": [3]}}
+
+    async def rate_limiter(skill_name, params, context, call_next):
+        return {"error": "rate-limited", "skill": skill_name}
+
+    mcp = create_mcp_server(
+        _ShortCircuitTarget(),
+        name="sc-test",
+        middleware=[rate_limiter],
+    )
+    mcp.settings.stateless_http = True
+    mcp.settings.json_response = True
+    mcp.settings.transport_security.allowed_hosts = ["localhost", "127.0.0.1"]
+    app = mcp.streamable_http_app()
+
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app),
+            base_url="http://localhost",
+            follow_redirects=True,
+        ) as client:
+            await _initialize_session(client)
+            resp = await _call_tool(client, "get_adcp_capabilities", {})
+
+    assert resp.status_code == 200, resp.text
+    assert handler_calls == [], (
+        "middleware short-circuited but the handler still ran — MCP middleware "
+        "chain did not honour the 'skip call_next to skip handler' contract"
+    )
+
+
+@pytest.mark.asyncio
+async def test_mcp_middleware_sees_tool_context() -> None:
+    """Middleware gets the same ToolContext the handler will receive.
+
+    When no context_factory is configured, middleware sees a default
+    ToolContext (not None) so the typed signature holds."""
+
+    seen: list[ToolContext] = []
+
+    async def record_context(skill_name, params, context, call_next):
+        seen.append(context)
+        return await call_next()
+
+    class _Handler(ADCPHandler):
+        async def get_adcp_capabilities(self, params, context=None):
+            return {"adcp": {"major_versions": [3]}}
+
+    mcp = create_mcp_server(
+        _Handler(),
+        name="ctx-test",
+        middleware=[record_context],
+    )
+    mcp.settings.stateless_http = True
+    mcp.settings.json_response = True
+    mcp.settings.transport_security.allowed_hosts = ["localhost", "127.0.0.1"]
+    app = mcp.streamable_http_app()
+
+    async with LifespanManager(app):
+        async with httpx.AsyncClient(
+            transport=httpx.ASGITransport(app=app),
+            base_url="http://localhost",
+            follow_redirects=True,
+        ) as client:
+            await _initialize_session(client)
+            await _call_tool(client, "get_adcp_capabilities", {})
+
+    assert len(seen) == 1
+    # No context_factory configured → middleware receives a synthesised
+    # default ToolContext so the signature type holds. Verified
+    # explicitly so a future change that passes None instead breaks here.
+    assert isinstance(seen[0], ToolContext)
diff --git a/tests/test_server_startup_log.py b/tests/test_server_startup_log.py
new file mode 100644
index 00000000..f34f44af
--- /dev/null
+++ b/tests/test_server_startup_log.py
@@ -0,0 +1,97 @@
+"""Tests for the advertised-tools startup log.
+
+Operators occasionally rename a handler method and silently drop it
+from ``tools/list`` — discovering that during incident review is the
+wrong time. ``_log_advertised_tools`` turns a silent drop into a
+searchable INFO log on server boot. These tests verify the message is
+emitted with the right transport + counts on both MCP and A2A paths.
+""" + +from __future__ import annotations + +import logging +import re +import sys + +import pytest + +from adcp.server import ADCPHandler, create_mcp_server +from adcp.server.a2a_server import create_a2a_server + +_ADVERTISING_PATTERN = re.compile(r"advertising (\d+) of (\d+) tools") + + +class _MinimalHandler(ADCPHandler): + """Overrides a few tools; rest stay at the ``not_supported`` default.""" + + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + async def get_products(self, params, context=None): + return {"products": []} + + +def test_mcp_startup_log_emits_count(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.INFO, logger="adcp.server") + create_mcp_server(_MinimalHandler(), name="log-test") + messages = [r.message for r in caplog.records if r.name == "adcp.server"] + assert any( + "mcp server advertising" in m and "of" in m and "tools" in m for m in messages + ), f"expected MCP startup log in {messages}" + + +def test_mcp_startup_log_advertises_only_overridden_tools( + caplog: pytest.LogCaptureFixture, +) -> None: + caplog.set_level(logging.INFO, logger="adcp.server") + create_mcp_server(_MinimalHandler(), name="log-test") + log_line = next(r.message for r in caplog.records if "mcp server advertising" in r.message) + match = _ADVERTISING_PATTERN.search(log_line) + assert match is not None, f"log line did not match expected shape: {log_line!r}" + advertised = int(match.group(1)) + total = int(match.group(2)) + assert advertised == 2, f"expected 2 advertised, got {advertised} in: {log_line}" + assert total > advertised, ( + f"expected total > advertised; got {total} vs {advertised}. " + f"Full handler surface should exceed the 2 overridden methods." 
+ ) + + +def test_mcp_startup_log_notes_advertise_all_flag( + caplog: pytest.LogCaptureFixture, +) -> None: + caplog.set_level(logging.INFO, logger="adcp.server") + create_mcp_server(_MinimalHandler(), name="log-test", advertise_all=True) + log_line = next(r.message for r in caplog.records if "mcp server advertising" in r.message) + assert "advertise_all=True" in log_line + + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="a2a-sdk starlette integration requires Python 3.11+", +) +def test_a2a_startup_log_emits_count(caplog: pytest.LogCaptureFixture) -> None: + """A2A startup log fires from ``create_a2a_server`` (symmetric with + MCP's placement). Per-test executor constructions don't pollute + caplog, which keeps this test reliable alongside a large test suite + that instantiates executors.""" + caplog.set_level(logging.INFO, logger="adcp.server") + create_a2a_server(_MinimalHandler(), name="log-test") + messages = [r.message for r in caplog.records if r.name == "adcp.server"] + assert any( + "a2a server advertising" in m for m in messages + ), f"expected A2A startup log in {messages}" + + +def test_unadvertised_tools_at_debug_level( + caplog: pytest.LogCaptureFixture, +) -> None: + """The list of unadvertised tool names logs at DEBUG — INFO would + be noisy on fully-implemented handlers. Pin both the level and the + existence of the message so operators know where to look.""" + caplog.set_level(logging.DEBUG, logger="adcp.server") + create_mcp_server(_MinimalHandler(), name="log-test") + debug_lines = [r.message for r in caplog.records if r.levelno == logging.DEBUG] + assert any( + "unadvertised" in m for m in debug_lines + ), f"expected a DEBUG log listing unadvertised tools, got: {debug_lines}"
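The outermost-first ordering the middleware-parity tests assert falls out of a right-fold over the middleware list. The sketch below is a hypothetical illustration of those composition semantics in plain Python, not the SDK's actual wiring; `make_middleware`, `compose`, and `handler` are names invented for this example:

```python
import asyncio

events: list[str] = []


def make_middleware(name: str):
    async def mw(skill_name, params, context, call_next):
        events.append(f"{name}-pre:{skill_name}")
        result = await call_next()
        events.append(f"{name}-post:{skill_name}")
        return result

    return mw


async def handler(skill_name, params, context):
    events.append(f"handler:{skill_name}")
    return {"ok": True}


def compose(middleware, handler):
    """Fold right-to-left so middleware[0] ends up outermost."""

    async def call(skill_name, params, context):
        next_fn = lambda: handler(skill_name, params, context)
        for mw in reversed(middleware):
            # Bind via default args so each layer closes over its own
            # mw / next_fn pair instead of the loop's final values.
            next_fn = lambda mw=mw, nxt=next_fn: mw(skill_name, params, context, nxt)
        return await next_fn()

    return call


call = compose([make_middleware("outer"), make_middleware("inner")], handler)
asyncio.run(call("get_products", {}, None))
assert events == [
    "outer-pre:get_products",
    "inner-pre:get_products",
    "handler:get_products",
    "inner-post:get_products",
    "outer-post:get_products",
]
```

Under these semantics, a layer that returns without awaiting `call_next()` skips everything beneath it — which is exactly the short-circuit contract `test_mcp_middleware_can_short_circuit` pins down.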