Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 16 additions & 27 deletions examples/mcp_with_auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
:func:`~adcp.server.auth_context_factory` onto a multi-tenant sales
agent. The SDK owns the security-critical plumbing (constant-time
token compare, discovery bypass, ``ContextVar`` reset-in-finally);
the seller supplies only ``validate_token`` and the handler logic.
the seller supplies only the token → principal map and the handler
logic.

Run::

Expand All @@ -14,12 +15,14 @@

Production note: ``mcp.run()`` is used here for brevity. Real
deployments should mount the Starlette app behind uvicorn + a reverse
proxy that terminates TLS and handles rate limiting.
proxy that terminates TLS and handles rate limiting. Production
agents also typically load tokens from a database — swap
``validator_from_token_map`` for an ``async def validate_token`` that
hits your token store.
"""

from __future__ import annotations

import hashlib
from typing import Any

from adcp.server import (
Expand All @@ -28,37 +31,23 @@
Principal,
ToolContext,
auth_context_factory,
constant_time_token_match,
create_mcp_server,
validator_from_token_map,
)
from adcp.server.responses import capabilities_response, products_response

# Real agents look tokens up in Postgres / Vault / an identity provider.
# Keyed by SHA-256 so the comparison uses ``hmac.compare_digest`` rather
# than raw string equality — never compare raw bearer tokens with ``==``.
_TOKEN_HASHES: dict[str, Principal] = {
hashlib.sha256(raw.encode()).hexdigest(): principal
for raw, principal in {
"token-acme": Principal(
caller_identity="principal-acme-ops",
tenant_id="tenant-acme",
),
# ``validator_from_token_map`` hashes the raw tokens at construction and
# does ``hmac.compare_digest`` lookups — same security properties as a
# hand-rolled validator, one line instead of a dozen.
validate_token = validator_from_token_map(
{
"token-acme": Principal(caller_identity="principal-acme-ops", tenant_id="tenant-acme"),
"token-globex": Principal(
caller_identity="principal-globex-ops",
tenant_id="tenant-globex",
caller_identity="principal-globex-ops", tenant_id="tenant-globex"
),
}.items()
}


def validate_token(token: str) -> Principal | None:
"""Seller-supplied token validator.

``constant_time_token_match`` iterates every stored hash with
:func:`hmac.compare_digest`, avoiding the prefix-match timing leak
that a plain ``dict`` lookup would have.
"""
return constant_time_token_match(token, _TOKEN_HASHES)
}
)


class MultiTenantSalesAgent(ADCPHandler):
Expand Down
100 changes: 99 additions & 1 deletion src/adcp/migrate/v3_to_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,44 @@ def _format_text_report(report: Report, *, apply_changes: bool) -> str:
return "\n".join(lines)


REPORT_SCHEMA_VERSION = 1
"""Version of the JSON report shape. CI scripts / editors parsing the
migrate output key on this so a future shape change (adding a summary
block, renaming fields) doesn't silently break them.

Bump the minor SDK version AND this constant when changing the JSON
shape in a non-additive way. Additive changes (new optional keys)
stay at the same version.

**v1 shape:**

.. code-block:: json

{
"schema_version": 1,
"scanned_files": int,
"rewritten_files": int,
"applied": [
{"kind": "rename", "path": str, "line": int, "column": int,
"before": str, "after": str, "hint": null, "migration_anchor": null}
],
"flagged": [
{"kind": "flag_removed" | "flag_numbered" | "flag_private" | "flag_attribute",
"path": str, "line": int, "column": int, "before": str,
"after": null, "hint": str | null, "migration_anchor": str | null}
]
}
"""


def _format_json_report(report: Report) -> str:
"""JSON report for programmatic consumption (CI, editors)."""
"""JSON report for programmatic consumption (CI, editors).

Versioned via :data:`REPORT_SCHEMA_VERSION` — parsers should check
the top-level ``schema_version`` key before reading the rest.
"""
payload = {
"schema_version": REPORT_SCHEMA_VERSION,
"scanned_files": report.scanned_files,
"rewritten_files": report.rewritten_files,
"applied": [asdict(f) for f in report.applied],
Expand All @@ -407,6 +442,46 @@ def _format_json_report(report: Report) -> str:
return json.dumps(payload, indent=2)


def _is_dirty_tree(path: Path) -> bool:
    """Report whether ``path`` sits inside a git repo with uncommitted changes.

    Runs ``git status --porcelain`` (fast, stable output) and treats any
    output as "dirty". Every failure mode — git missing from PATH, not a
    repository, subprocess error, timeout — yields ``False``: absence of
    git is deliberately not a reason to block ``--apply`` (sellers may run
    in sandboxed or read-only environments without git). A ``True`` result
    therefore means definite uncommitted state was observed.
    """
    import shutil
    import subprocess

    if not shutil.which("git"):
        return False

    resolved = path.resolve()
    workdir = resolved if resolved.is_dir() else resolved.parent
    try:
        proc = subprocess.run(
            ["git", "status", "--porcelain"],
            cwd=workdir,
            check=False,
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        return False
    # git exits 128 outside a repo; any non-zero status is treated as
    # "clean" so --apply keeps working where git cannot run.
    if proc.returncode != 0:
        return False
    return bool(proc.stdout.strip())


def main(argv: list[str] | None = None) -> int:
"""CLI entry point for ``python -m adcp.migrate v3-to-v4``."""
parser = argparse.ArgumentParser(
Expand All @@ -429,6 +504,18 @@ def main(argv: list[str] | None = None) -> int:
"Commit your tree first so `git diff` is your review view."
),
)
parser.add_argument(
"--allow-dirty",
action="store_true",
help=(
"Allow --apply even when the git working tree has "
"uncommitted changes. Default is to refuse so `git diff` "
"after the migration shows only the codemod's rewrites, "
"not a mix of the seller's in-progress work and the "
"codemod. Pass --allow-dirty when you know what you're "
"doing (e.g. applying to a staged change deliberately)."
),
)
parser.add_argument(
"--json",
action="store_true",
Expand All @@ -440,6 +527,17 @@ def main(argv: list[str] | None = None) -> int:
print(f"error: path does not exist: {args.path}", file=sys.stderr)
return 2

if args.apply and not args.allow_dirty and _is_dirty_tree(args.path):
print(
"error: --apply refused on a dirty git working tree.\n"
" Commit your changes first so `git diff` after the\n"
" migration shows only the codemod's rewrites. Pass\n"
" --allow-dirty to override (e.g. you're deliberately\n"
" applying on top of staged changes).",
file=sys.stderr,
)
return 2

report = run(args.path, apply_changes=args.apply)

if args.json:
Expand Down
6 changes: 6 additions & 0 deletions src/adcp/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@ async def get_products(params, context=None):
from adcp.capabilities import validate_capabilities
from adcp.server.a2a_server import ADCPAgentExecutor, MessageParser, create_a2a_server
from adcp.server.auth import (
AsyncTokenValidator,
BearerTokenAuthMiddleware,
Principal,
SyncTokenValidator,
TokenValidator,
auth_context_factory,
constant_time_token_match,
validator_from_token_map,
)
from adcp.server.base import (
AccountAwareToolContext,
Expand Down Expand Up @@ -172,11 +175,14 @@ async def get_products(params, context=None):
"SkillMiddleware",
"create_a2a_server",
# Bearer-token auth middleware (seller-facing recipe)
"AsyncTokenValidator",
"BearerTokenAuthMiddleware",
"Principal",
"SyncTokenValidator",
"TokenValidator",
"auth_context_factory",
"constant_time_token_match",
"validator_from_token_map",
# Idempotency middleware (AdCP #2315 seller side)
"IdempotencyStore",
"MemoryBackend",
Expand Down
73 changes: 65 additions & 8 deletions src/adcp/server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,12 @@ async def validate_token(token: str) -> Principal | None:
import inspect
import json
import logging
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Mapping
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Protocol, TypeVar

_V = TypeVar("_V")

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
Expand Down Expand Up @@ -138,16 +140,32 @@ class Principal:
metadata: dict[str, Any] = field(default_factory=dict)


TokenValidator = Callable[[str], "Principal | None | Awaitable[Principal | None]"]
class SyncTokenValidator(Protocol):
    """Synchronous token validator — ``def validate_token(token) -> Principal | None``.

    Called with the raw bearer token string; returns a
    :class:`Principal` on success or ``None`` to reject the request.
    """

    def __call__(self, token: str) -> Principal | None: ...


class AsyncTokenValidator(Protocol):
    """Asynchronous token validator —
    ``async def validate_token(token) -> Principal | None``.

    Called with the raw bearer token string; the returned awaitable
    resolves to a :class:`Principal` on success or ``None`` to reject.
    """

    def __call__(self, token: str) -> Awaitable[Principal | None]: ...


TokenValidator = SyncTokenValidator | AsyncTokenValidator
"""Seller-supplied callable that validates a bearer token.

Called with the raw token string (``Authorization: Bearer <token>``
with the prefix already stripped). Return a :class:`Principal` on
success, ``None`` to reject.
success, ``None`` to reject. Sync and async callables are both
accepted — the middleware awaits the result when it's awaitable.

Sync and async callables are both accepted — the middleware awaits the
result when it's awaitable, so plain ``def validate_token(...)`` and
``async def validate_token(...)`` both work.
Declared as a union of two Protocols (rather than a
``Callable[[str], Principal | None | Awaitable[...]]`` alias)
because mypy narrows Protocol unions per-call-site: downstream code
using ``async def validate_token`` gets the async branch without
``type: ignore`` noise. Either protocol is a valid ``TokenValidator``.

**Do not raise on invalid tokens.** Exceptions become ``500 Internal
Server Error`` responses, which leak the presence of an auth path
Expand Down Expand Up @@ -367,7 +385,7 @@ def auth_context_factory(meta: RequestMetadata) -> ToolContext:
# ------------------------------------------------------------------


def constant_time_token_match(token: str, stored_hashes: dict[str, Any]) -> Any:
def constant_time_token_match(token: str, stored_hashes: Mapping[str, _V]) -> _V | None:
"""Look up a token in a dict of SHA-256 hashes using
:func:`hmac.compare_digest` rather than dict-containment.

Expand All @@ -393,3 +411,42 @@ def constant_time_token_match(token: str, stored_hashes: dict[str, Any]) -> Any:
if hmac.compare_digest(candidate, stored_hash):
return value
return None


def validator_from_token_map(
    token_map: Mapping[str, Principal],
) -> SyncTokenValidator:
    """Turn a ``{raw_token: Principal}`` mapping into a token validator.

    This covers the shape most demo/test agents need — a small fixed set
    of tokens mapped to principals — without hand-writing the
    constant-time plumbing. Each raw token is SHA-256 hashed once, here
    at construction time (the plaintext is not retained), and every call
    resolves through :func:`constant_time_token_match`, i.e. an
    ``hmac.compare_digest`` scan with the same security properties as a
    hand-rolled validator::

        validate_token = validator_from_token_map({
            "token-acme": Principal(caller_identity="p-acme", tenant_id="acme"),
            "token-globex": Principal(caller_identity="p-globex", tenant_id="globex"),
        })
        app.add_middleware(BearerTokenAuthMiddleware, validate_token=validate_token)

    Production agents that resolve tokens against Postgres / Redis /
    Vault should write their own async validator instead — this helper
    is for the small-fixed-set case (demo, test, CI fixtures).

    :param token_map: Mapping of raw bearer tokens to their resolved
        :class:`Principal`. Tokens are hashed at construction; the
        plaintext is not retained.
    :returns: A :data:`SyncTokenValidator` (which satisfies
        :data:`TokenValidator`).
    """
    hashed: dict[str, Principal] = {}
    for raw_token, principal in token_map.items():
        hashed[hashlib.sha256(raw_token.encode()).hexdigest()] = principal

    def _lookup(token: str) -> Principal | None:
        return constant_time_token_match(token, hashed)

    return _lookup
Loading
Loading