diff --git a/examples/mcp_with_auth_middleware.py b/examples/mcp_with_auth_middleware.py index 52bab061..f4bd862b 100644 --- a/examples/mcp_with_auth_middleware.py +++ b/examples/mcp_with_auth_middleware.py @@ -4,7 +4,8 @@ :func:`~adcp.server.auth_context_factory` onto a multi-tenant sales agent. The SDK owns the security-critical plumbing (constant-time token compare, discovery bypass, ``ContextVar`` reset-in-finally); -the seller supplies only ``validate_token`` and the handler logic. +the seller supplies only the token → principal map and the handler +logic. Run:: @@ -14,12 +15,14 @@ Production note: ``mcp.run()`` is used here for brevity. Real deployments should mount the Starlette app behind uvicorn + a reverse -proxy that terminates TLS and handles rate limiting. +proxy that terminates TLS and handles rate limiting. Production +agents also typically load tokens from a database — swap +``validator_from_token_map`` for an ``async def validate_token`` that +hits your token store. """ from __future__ import annotations -import hashlib from typing import Any from adcp.server import ( @@ -28,37 +31,23 @@ Principal, ToolContext, auth_context_factory, - constant_time_token_match, create_mcp_server, + validator_from_token_map, ) from adcp.server.responses import capabilities_response, products_response # Real agents look tokens up in Postgres / Vault / an identity provider. -# Keyed by SHA-256 so the comparison uses ``hmac.compare_digest`` rather -# than raw string equality — never compare raw bearer tokens with ``==``. -_TOKEN_HASHES: dict[str, Principal] = { - hashlib.sha256(raw.encode()).hexdigest(): principal - for raw, principal in { - "token-acme": Principal( - caller_identity="principal-acme-ops", - tenant_id="tenant-acme", - ), +# ``validator_from_token_map`` hashes the raw tokens at construction and +# does ``hmac.compare_digest`` lookups — same security properties as a +# hand-rolled validator, one line instead of a dozen. +validate_token = validator_from_token_map( + { + "token-acme": Principal(caller_identity="principal-acme-ops", tenant_id="tenant-acme"), "token-globex": Principal( - caller_identity="principal-globex-ops", - tenant_id="tenant-globex", + caller_identity="principal-globex-ops", tenant_id="tenant-globex" ), - }.items() -} - - -def validate_token(token: str) -> Principal | None: - """Seller-supplied token validator. - - ``constant_time_token_match`` iterates every stored hash with - :func:`hmac.compare_digest`, avoiding the prefix-match timing leak - that a plain ``dict`` lookup would have. - """ - return constant_time_token_match(token, _TOKEN_HASHES) + } +) class MultiTenantSalesAgent(ADCPHandler): diff --git a/src/adcp/migrate/v3_to_v4.py b/src/adcp/migrate/v3_to_v4.py index b1bcdb62..9fc411c3 100644 --- a/src/adcp/migrate/v3_to_v4.py +++ b/src/adcp/migrate/v3_to_v4.py @@ -396,9 +396,44 @@ def _format_text_report(report: Report, *, apply_changes: bool) -> str: return "\n".join(lines) +REPORT_SCHEMA_VERSION = 1 +"""Version of the JSON report shape. CI scripts / editors parsing the +migrate output key on this so a future shape change (adding a summary +block, renaming fields) doesn't silently break them. + +Bump the minor SDK version AND this constant when changing the JSON +shape in a non-additive way. Additive changes (new optional keys) +stay at the same version. + +**v1 shape:** + +.. code-block:: json + + { + "schema_version": 1, + "scanned_files": int, + "rewritten_files": int, + "applied": [ + {"kind": "rename", "path": str, "line": int, "column": int, + "before": str, "after": str, "hint": null, "migration_anchor": null} + ], + "flagged": [ + {"kind": "flag_removed" | "flag_numbered" | "flag_private" | "flag_attribute", + "path": str, "line": int, "column": int, "before": str, + "after": null, "hint": str | null, "migration_anchor": str | null} + ] + } +""" + + def _format_json_report(report: Report) -> str: - """JSON report for programmatic consumption (CI, editors).""" + """JSON report for programmatic consumption (CI, editors). + + Versioned via :data:`REPORT_SCHEMA_VERSION` — parsers should check + the top-level ``schema_version`` key before reading the rest. + """ payload = { + "schema_version": REPORT_SCHEMA_VERSION, "scanned_files": report.scanned_files, "rewritten_files": report.rewritten_files, "applied": [asdict(f) for f in report.applied], @@ -407,6 +442,46 @@ def _format_json_report(report: Report) -> str: return json.dumps(payload, indent=2) +def _is_dirty_tree(path: Path) -> bool: + """True when ``path`` is inside a git repo with uncommitted changes. + + Uses ``git status --porcelain`` for speed and stability. Returns + ``False`` when git isn't installed, the path isn't in a repo, or + the repo is clean — any non-clean state returns ``True`` so the + ``--apply`` guard fails safe. + + The check is best-effort: absence of git isn't a reason to block + the rewrite (sellers may run in sandboxed or read-only environments + where git isn't available). A ``True`` result means we saw + definite uncommitted state. + """ + import shutil + import subprocess + + if shutil.which("git") is None: + return False + + target = path.resolve() + cwd = target if target.is_dir() else target.parent + try: + result = subprocess.run( + ["git", "status", "--porcelain"], + cwd=cwd, + check=False, + capture_output=True, + text=True, + timeout=5, + ) + except (OSError, subprocess.SubprocessError): + return False + # Exit 128 = not a git repo; anything non-zero → treat as clean + # (not blocking — we don't want `--apply` in a sandboxed env to + # break because git can't run). + if result.returncode != 0: + return False + return bool(result.stdout.strip()) + + def main(argv: list[str] | None = None) -> int: """CLI entry point for ``python -m adcp.migrate v3-to-v4``.""" parser = argparse.ArgumentParser( @@ -429,6 +504,18 @@ def main(argv: list[str] | None = None) -> int: "Commit your tree first so `git diff` is your review view." ), ) + parser.add_argument( + "--allow-dirty", + action="store_true", + help=( + "Allow --apply even when the git working tree has " + "uncommitted changes. Default is to refuse so `git diff` " + "after the migration shows only the codemod's rewrites, " + "not a mix of the seller's in-progress work and the " + "codemod. Pass --allow-dirty when you know what you're " + "doing (e.g. applying to a staged change deliberately)." + ), + ) parser.add_argument( "--json", action="store_true", @@ -440,6 +527,17 @@ def main(argv: list[str] | None = None) -> int: print(f"error: path does not exist: {args.path}", file=sys.stderr) return 2 + if args.apply and not args.allow_dirty and _is_dirty_tree(args.path): + print( + "error: --apply refused on a dirty git working tree.\n" + " Commit your changes first so `git diff` after the\n" + " migration shows only the codemod's rewrites. Pass\n" + " --allow-dirty to override (e.g. you're deliberately\n" + " applying on top of staged changes).", + file=sys.stderr, + ) + return 2 + report = run(args.path, apply_changes=args.apply) if args.json: diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py index 885fd829..eaa7f001 100644 --- a/src/adcp/server/__init__.py +++ b/src/adcp/server/__init__.py @@ -55,11 +55,14 @@ async def get_products(params, context=None): from adcp.capabilities import validate_capabilities from adcp.server.a2a_server import ADCPAgentExecutor, MessageParser, create_a2a_server from adcp.server.auth import ( + AsyncTokenValidator, BearerTokenAuthMiddleware, Principal, + SyncTokenValidator, TokenValidator, auth_context_factory, constant_time_token_match, + validator_from_token_map, ) from adcp.server.base import ( AccountAwareToolContext, @@ -172,11 +175,14 @@ async def get_products(params, context=None): "SkillMiddleware", "create_a2a_server", # Bearer-token auth middleware (seller-facing recipe) + "AsyncTokenValidator", "BearerTokenAuthMiddleware", "Principal", + "SyncTokenValidator", "TokenValidator", "auth_context_factory", "constant_time_token_match", + "validator_from_token_map", # Idempotency middleware (AdCP #2315 seller side) "IdempotencyStore", "MemoryBackend", diff --git a/src/adcp/server/auth.py b/src/adcp/server/auth.py index ad577df6..b745d0b7 100644 --- a/src/adcp/server/auth.py +++ b/src/adcp/server/auth.py @@ -72,10 +72,12 @@ async def validate_token(token: str) -> Principal | None: import inspect import json import logging -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Mapping from contextvars import ContextVar from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol, TypeVar + +_V = TypeVar("_V") from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse @@ -138,16 +140,32 @@ class Principal: metadata: dict[str, Any] = field(default_factory=dict) -TokenValidator = Callable[[str], "Principal | None | Awaitable[Principal | None]"] +class SyncTokenValidator(Protocol): + """Synchronous token validator — ``def validate_token(token) -> Principal | None``.""" + + def __call__(self, token: str) -> Principal | None: ... + + +class AsyncTokenValidator(Protocol): + """Asynchronous token validator — + ``async def validate_token(token) -> Principal | None``.""" + + def __call__(self, token: str) -> Awaitable[Principal | None]: ... + + +TokenValidator = SyncTokenValidator | AsyncTokenValidator """Seller-supplied callable that validates a bearer token. Called with the raw token string (``Authorization: Bearer `` with the prefix already stripped). Return a :class:`Principal` on -success, ``None`` to reject. +success, ``None`` to reject. Sync and async callables are both +accepted — the middleware awaits the result when it's awaitable. -Sync and async callables are both accepted — the middleware awaits the -result when it's awaitable, so plain ``def validate_token(...)`` and -``async def validate_token(...)`` both work. +Declared as a union of two Protocols (rather than a +``Callable[[str], Principal | None | Awaitable[...]]`` alias) +because mypy narrows Protocol unions per-call-site: downstream code +using ``async def validate_token`` gets the async branch without +``type: ignore`` noise. Either protocol is a valid ``TokenValidator``. **Do not raise on invalid tokens.** Exceptions become ``500 Internal Server Error`` responses, which leak the presence of an auth path @@ -367,7 +385,7 @@ def auth_context_factory(meta: RequestMetadata) -> ToolContext: # ------------------------------------------------------------------ -def constant_time_token_match(token: str, stored_hashes: dict[str, Any]) -> Any: +def constant_time_token_match(token: str, stored_hashes: Mapping[str, _V]) -> _V | None: """Look up a token in a dict of SHA-256 hashes using :func:`hmac.compare_digest` rather than dict-containment. @@ -393,3 +411,42 @@ def constant_time_token_match(token: str, stored_hashes: dict[str, Any]) -> Any: if hmac.compare_digest(candidate, stored_hash): return value return None + + +def validator_from_token_map( + token_map: Mapping[str, Principal], +) -> SyncTokenValidator: + """Build a :data:`TokenValidator` from a ``{raw_token: Principal}`` map. + + The shape most demo/test agents actually need — a fixed set of + tokens mapped to principals — without having to write the + constant-time plumbing. The returned validator hashes each raw + token at construction time and does constant-time lookups via + :func:`hmac.compare_digest` on every call, matching the security + properties of a hand-rolled validator:: + + validate_token = validator_from_token_map({ + "token-acme": Principal(caller_identity="p-acme", tenant_id="acme"), + "token-globex": Principal(caller_identity="p-globex", tenant_id="globex"), + }) + app.add_middleware(BearerTokenAuthMiddleware, validate_token=validate_token) + + Production agents looking tokens up in Postgres / Redis / Vault + should write their own async validator instead — this helper is + for the small-fixed-set case (demo, test, CI fixtures). + + :param token_map: Mapping of raw bearer tokens to their resolved + :class:`Principal`. Tokens are hashed at construction; the + plaintext is not retained. + :returns: A :data:`SyncTokenValidator` (which satisfies + :data:`TokenValidator`). + """ + stored_hashes: dict[str, Principal] = { + hashlib.sha256(token.encode()).hexdigest(): principal + for token, principal in token_map.items() + } + + def _validate(token: str) -> Principal | None: + return constant_time_token_match(token, stored_hashes) + + return _validate diff --git a/tests/test_auth_middleware.py b/tests/test_auth_middleware.py index 3f2ee7e3..af83be6b 100644 --- a/tests/test_auth_middleware.py +++ b/tests/test_auth_middleware.py @@ -28,6 +28,7 @@ Principal, auth_context_factory, constant_time_token_match, + validator_from_token_map, ) from adcp.server.auth import ( current_principal, @@ -63,6 +64,76 @@ def test_constant_time_token_match_empty_token() -> None: assert constant_time_token_match("", stored) is None +# --------------------------------------------------------------------------- +# validator_from_token_map +# --------------------------------------------------------------------------- + + +def test_validator_from_token_map_returns_principal_on_match() -> None: + """Happy path: the map's raw token resolves to its Principal.""" + alice = Principal(caller_identity="alice", tenant_id="t1") + validate = validator_from_token_map({"s3cret-token": alice}) + assert validate("s3cret-token") == alice + + +def test_validator_from_token_map_returns_none_on_miss() -> None: + """Unknown token → ``None``, not exception.""" + alice = Principal(caller_identity="alice") + validate = validator_from_token_map({"known": alice}) + assert validate("unknown") is None + + +def test_validator_from_token_map_constant_time_compare() -> None: + """The helper MUST use ``constant_time_token_match`` under the + hood — not raw dict lookup — so timing doesn't leak prefix match. + Test by confirming both a known-prefix miss and a full miss + return the same (None) result without blowing up.""" + validate = validator_from_token_map( + { + "alpha-beta-gamma": Principal(caller_identity="alice"), + "zulu-yankee-xray": Principal(caller_identity="bob"), + } + ) + # Same-length miss with partial prefix overlap + assert validate("alpha-beta-nope-") is None + # Completely different token + assert validate("mno-pqr-stu-vwx") is None + # Actual match still works + assert validate("alpha-beta-gamma").caller_identity == "alice" + + +def test_validator_from_token_map_empty_map_always_returns_none() -> None: + """Degenerate case: empty map → every token rejects. No crashes, + no AttributeErrors.""" + validate = validator_from_token_map({}) + assert validate("anything") is None + assert validate("") is None + + +def test_validator_from_token_map_does_not_retain_plaintext() -> None: + """Security invariant: the plaintext tokens MUST NOT be + retrievable from the returned validator's closure. They're hashed + at construction; only hashes live in the closure.""" + import gc + + raw_token = "plaintext-should-not-persist-here" + validate = validator_from_token_map({raw_token: Principal(caller_identity="alice")}) + + # Walk the closure's referents, flatten one level. The raw token + # SHOULD NOT appear — only its SHA-256 hex digest. + referents = gc.get_referents(validate.__closure__[0].cell_contents) + flat_strings: list[str] = [] + for ref in referents: + if isinstance(ref, str): + flat_strings.append(ref) + elif isinstance(ref, dict): + flat_strings.extend(k for k in ref.keys() if isinstance(k, str)) + + assert ( + raw_token not in flat_strings + ), f"raw token leaked into validator closure: {flat_strings!r}" + + # --------------------------------------------------------------------------- # Middleware-in-isolation tests via a minimal Starlette harness # --------------------------------------------------------------------------- diff --git a/tests/test_migrate_v3_to_v4.py b/tests/test_migrate_v3_to_v4.py index 87ad85d7..34e3ce6d 100644 --- a/tests/test_migrate_v3_to_v4.py +++ b/tests/test_migrate_v3_to_v4.py @@ -409,6 +409,7 @@ def test_cli_json_output(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> out = capsys.readouterr().out payload = json.loads(out) + assert payload["schema_version"] == v3_to_v4.REPORT_SCHEMA_VERSION assert payload["scanned_files"] == 1 assert payload["rewritten_files"] == 0 assert len(payload["applied"]) == 1 @@ -418,6 +419,15 @@ def test_cli_json_output(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> assert any(f["before"] == "BrandManifest" for f in removed) +def test_json_report_schema_version_is_declared() -> None: + """The v1 JSON shape is a wire contract with CI scripts and + editors. A non-additive change (renaming a field, removing one) + MUST bump ``REPORT_SCHEMA_VERSION`` AND the SDK minor version — + this test pins the current version so a change is a deliberate + choice, not an accident.""" + assert v3_to_v4.REPORT_SCHEMA_VERSION == 1 + + def test_cli_apply_rewrites_and_reports(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None: """The happy end-to-end path: scan + rewrite + human-readable summary.""" path = _write(tmp_path, "code.py", "from adcp.types import AudioAsset\n") @@ -439,3 +449,80 @@ def test_unreadable_file_does_not_crash(tmp_path: Path) -> None: report = v3_to_v4.run(tmp_path, apply_changes=False) assert report.scanned_files == 1 assert report.applied == [] + + +# --------------------------------------------------------------------------- +# --apply safety: refuse on dirty git tree, allow with --allow-dirty +# --------------------------------------------------------------------------- + + +def _init_git_repo(path: Path) -> None: + """Initialize a git repo at ``path`` and commit one file so the + default branch exists.""" + import subprocess + + subprocess.run(["git", "init", "-q", "-b", "main"], cwd=path, check=True) + subprocess.run(["git", "config", "user.email", "test@example.com"], cwd=path, check=True) + subprocess.run(["git", "config", "user.name", "Test"], cwd=path, check=True) + subprocess.run(["git", "config", "commit.gpgsign", "false"], cwd=path, check=True) + (path / ".gitkeep").write_text("") + subprocess.run(["git", "add", ".gitkeep"], cwd=path, check=True) + subprocess.run(["git", "commit", "-q", "-m", "initial"], cwd=path, check=True) + + +def test_apply_refuses_on_dirty_git_tree( + tmp_path: Path, capsys: pytest.CaptureFixture[str] +) -> None: + """``--apply`` MUST refuse when the working tree is dirty. Otherwise + the seller's in-progress work gets mixed into the codemod's rewrite + diff and ``git diff`` review stops being useful.""" + import shutil + + if shutil.which("git") is None: + pytest.skip("git not available") + + _init_git_repo(tmp_path) + # Create an uncommitted file — this makes the tree dirty. + _write(tmp_path, "code.py", "from adcp.types import AudioAsset\n") + + rc = v3_to_v4.main([str(tmp_path), "--apply"]) + err = capsys.readouterr().err + + assert rc == 2 + assert "dirty git working tree" in err + assert "--allow-dirty" in err + # File NOT rewritten (the guard short-circuits before `run`). + assert (tmp_path / "code.py").read_text() == "from adcp.types import AudioAsset\n" + + +def test_apply_allow_dirty_overrides_guard(tmp_path: Path) -> None: + """``--allow-dirty`` lets sellers deliberately run the codemod on + top of staged changes (e.g. batched with a related refactor).""" + import shutil + + if shutil.which("git") is None: + pytest.skip("git not available") + + _init_git_repo(tmp_path) + path = _write(tmp_path, "code.py", "from adcp.types import AudioAsset\n") + + rc = v3_to_v4.main([str(tmp_path), "--apply", "--allow-dirty"]) + + # Renames applied; exit code reflects no flagged findings. + assert rc == 0 + assert "AudioContent" in path.read_text() + + +def test_apply_proceeds_when_not_in_git_repo(tmp_path: Path) -> None: + """Running in a non-git directory (CI sandbox, scratch env) must + not block --apply. The guard fails-safe: if git can't verify + dirty state, we proceed. This is already implicitly tested by + other --apply tests (they use tmp_path which isn't a repo), + but pin it explicitly so a future tightening of the guard breaks + here, not silently in seller CI environments.""" + path = _write(tmp_path, "code.py", "from adcp.types import AudioAsset\n") + + rc = v3_to_v4.main([str(tmp_path), "--apply"]) + + assert rc == 0 + assert "AudioContent" in path.read_text()