diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7c6e523..638cb22 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -44,10 +44,11 @@ uv run opencode-a2a serve Run the default validation baseline before opening a PR: ```bash -uv run pre-commit run --all-files -uv run pytest +bash ./scripts/doctor.sh ``` +`doctor.sh` is the primary repository validation entrypoint. It currently runs locked-environment sync, dependency compatibility checks, `pre-commit`, `mypy`, `pytest`, the repository coverage gate, package build, and a built-wheel smoke test. + If you change shell scripts, also run `bash -n` on each modified script, for example: ```bash @@ -63,7 +64,7 @@ bash ./scripts/conformance.sh Treat that output as investigation input. Do not fold it into `doctor.sh` or the default CI quality gate unless the repository explicitly decides to promote a specific experiment into a maintained policy. -If you change extension methods, extension metadata, or Agent Card/OpenAPI contract surfaces, also run: +If you change extension methods, extension metadata, or Agent Card/OpenAPI contract surfaces, also make sure the targeted contract checks stay green: ```bash uv run pytest tests/contracts/test_extension_contract_consistency.py @@ -107,7 +108,7 @@ Update docs together with code whenever you change: - user-facing request or response shapes - operational scripts -Keep compatibility guidance centralized in [docs/guide.md](docs/guide.md) unless a new standalone document is clearly necessary. +Keep usage details in [docs/guide.md](docs/guide.md) and compatibility-sensitive stability guidance in [docs/compatibility.md](docs/compatibility.md). 
When changing extension contracts, update [`src/opencode_a2a/contracts/extensions.py`](src/opencode_a2a/contracts/extensions.py) first and keep these generated/documented surfaces aligned: diff --git a/README.md b/README.md index be3f3af..9d94cc5 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ curl http://127.0.0.1:8000/.well-known/agent-card.json - Declared supported protocol lines: `0.3`, `1.0` - `0.3` is the stable interoperability baseline for the current runtime surface. - `1.0` currently covers version negotiation plus protocol-aware JSON-RPC and REST error shaping, while transport payloads, enums, pagination, signatures, and interface-level protocol declarations still follow the shipped SDK baseline. -- The detailed compatibility matrix and machine-readable support boundary are documented in [`docs/guide.md`](docs/guide.md). +- The detailed compatibility matrix and machine-readable support boundary are documented in [`docs/guide.md`](docs/guide.md) and [`docs/compatibility.md`](docs/compatibility.md). ## Peering Node / Outbound Access @@ -162,7 +162,11 @@ Read before deployment: ## Further Reading +- [docs/architecture.md](docs/architecture.md) Service responsibility boundaries and request flow. +- [docs/maintainer-architecture.md](docs/maintainer-architecture.md) Internal module boundaries and maintainer call chains. +- [docs/compatibility.md](docs/compatibility.md) Compatibility-sensitive surface and contract-honesty guidance. - [docs/guide.md](docs/guide.md) Usage guide, transport details, streaming behavior, extensions, and examples. +- [docs/conformance.md](docs/conformance.md) External TCK experiment workflow and artifact handling. - [SECURITY.md](SECURITY.md) Threat model, deployment caveats, and vulnerability disclosure guidance. 
## Development diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..711183f --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,101 @@ +# Architecture Guide + +This document explains what `opencode-a2a` is responsible for, what remains inside OpenCode, and how requests move through the adapter boundary. + +## System Role + +`opencode-a2a` is an adapter layer between A2A clients and an OpenCode runtime. + +It is responsible for: + +- exposing A2A-facing HTTP+JSON and JSON-RPC endpoints +- normalizing transport, streaming, session, and interrupt contracts +- applying authentication, logging, persistence, and deployment-side guardrails +- hosting an outbound A2A client for peer calls triggered by CLI usage or `a2a_call` + +It is not responsible for: + +- replacing OpenCode's own provider or model selection internals +- hard multi-tenant isolation inside one shared deployment by default +- acting as a general OpenCode process supervisor + +## Adapter Layers + +```mermaid +flowchart LR + Client["A2A Client / Hub / Gateway"] --> Gateway["opencode-a2a"] + Gateway --> Contracts["A2A contracts\ntransport/session/interrupt"] + Gateway --> Ops["Auth / logging / persistence / deployment boundary"] + Contracts --> OpenCode["OpenCode runtime"] + Ops --> OpenCode + Gateway --> Peer["Embedded outbound A2A client"] +``` + +This view emphasizes service responsibility boundaries rather than internal module structure. The root [README](../README.md) keeps the overview path for first-time readers, while [maintainer-architecture.md](./maintainer-architecture.md) covers module boundaries and request call chains for contributors. + +## Request Flow + +### Standard send/stream flow + +1. A client calls the REST or JSON-RPC endpoint. +2. FastAPI middleware validates auth, request size, protocol version, and logging policy. +3. The adapter maps transport payloads into the OpenCode-facing execution path. +4. 
The execution layer calls the upstream OpenCode runtime and consumes its stream or unary response. +5. The service maps the result back into shared A2A-facing task, message, and stream event contracts. + +### Streaming flow + +For streaming requests, the adapter does more than simple passthrough: + +- classifies stream blocks into shared types such as `text`, `reasoning`, and `tool_call` +- preserves stable event and message identity where possible +- emits shared interrupt lifecycle state +- avoids duplicate final snapshots when streaming already produced the final text + +Detailed streaming contract: [Usage Guide](guide.md) + +### Session flow + +The service keeps a shared session continuation contract around `metadata.shared.session.id` and adapter-derived `contextId`, so clients can continue work without binding directly to raw OpenCode session IDs. + +### Outbound peer flow + +The same process can also act as an embedded A2A client: + +- CLI calls route through the local client facade +- server-side `a2a_call` tool execution uses the same outbound client settings +- outbound auth, timeouts, and transport preferences are configured independently from inbound auth + +## Boundary Model + +The adapter improves the runtime boundary, but it is not a full trust-boundary replacement. + +### What the adapter boundary helps with + +- stable A2A-facing contract shape +- auth enforcement on A2A entrypoints +- payload logging controls +- lightweight persistence for SDK task rows plus adapter-managed runtime state +- compatibility metadata for clients and operators + +### What still belongs to the OpenCode runtime boundary + +- provider credential consumption +- workspace side effects +- upstream model/provider behavior +- host-level process supervision and isolation + +That is why deployments should still be treated as trusted or controlled unless stronger isolation is added outside this repository. 
+ +## Documentation Split + +Use the docs by responsibility: + +- [README](../README.md): project overview, install/start path, and entry navigation +- [Compatibility Guide](compatibility.md): compatibility-sensitive surface and stability expectations +- [Usage Guide](guide.md): runtime configuration, transport contracts, extensions, and examples +- [Maintainer Architecture Guide](maintainer-architecture.md): internal module boundaries, request call chains, and persistence touchpoints +- [Extension Specifications](extension-specifications.md): stable extension URI/spec index and disclosure policy +- [Conformance Notes](conformance.md): external TCK experiment workflow +- [Contributing Guide](../CONTRIBUTING.md): contributor workflow and validation +- [Security Policy](../SECURITY.md): threat model and disclosure guidance diff --git a/docs/compatibility.md b/docs/compatibility.md new file mode 100644 index 0000000..38ef7f0 --- /dev/null +++ b/docs/compatibility.md @@ -0,0 +1,147 @@ +# Compatibility Guide + +This document explains the compatibility promises `opencode-a2a` currently tries to uphold for A2A consumers, operators, and maintainers. + +## Runtime Support + +- Python versions: 3.11, 3.12, 3.13 +- A2A SDK line: `0.3.x` +- Default advertised protocol line: `0.3` +- Declared supported protocol lines: `0.3`, `1.0` + +The repository pins the SDK version in `pyproject.toml`. Upgrade the SDK deliberately rather than relying on floating dependency resolution. + +## Contract Honesty + +Machine-readable discovery surfaces must reflect actual runtime behavior: + +- public Agent Card +- authenticated extended card +- OpenAPI metadata +- JSON-RPC wire contract +- compatibility profile + +If runtime support is not actually implemented, do not publish it as a supported machine-readable capability. + +Consumer guidance: + +- Treat the core A2A send / stream / task methods as the portable baseline. 
+- Treat `urn:a2a:*` entries in this repository as shared repo-family conventions, not as a claim that they are part of the A2A core baseline. +- Treat `opencode.*` methods and `metadata.opencode.*` fields as provider-private OpenCode control and discovery surfaces layered on top of the portable A2A baseline. +- Treat [extension-specifications.md](./extension-specifications.md) as the stable URI/spec index, not as the main usage guide. + +## Normative Sources + +When docs or reference material disagree, treat these as normative in this order: + +- runtime behavior validated by tests +- machine-readable discovery output such as Agent Card, authenticated extended card, and OpenAPI metadata +- repository-owned docs in `README.md`, `docs/`, and `CONTRIBUTING.md` + +External TCK runs and local conformance experiments are investigation inputs. They do not override the repository's declared contract by themselves. + +## Compatibility-Sensitive Surface + +This repository still ships as an alpha project. Within that alpha line, these declared surfaces should not drift silently: + +- core A2A send / stream / task methods +- version negotiation and protocol-aware error shaping +- shared session-binding metadata +- shared model-selection metadata +- shared streaming metadata +- declared custom JSON-RPC extension methods +- authenticated extended card and OpenAPI wire-contract metadata + +Changes to those surfaces should be treated as compatibility-sensitive and should include corresponding test updates. + +Service-level behavior layered on top of those core methods should also be declared explicitly when interoperability depends on it. 
Current examples: + +- `tasks/resubscribe` replay-once behavior for terminal updates +- first-terminal-state-wins task persistence policy +- task-scoped `acceptedOutputModes` negotiation persistence across send / stream / get / resubscribe +- request-body rejection behavior for oversized transport payloads + +## Deployment Profile + +The current service profile is intentionally: + +- single-tenant +- shared-workspace +- adapter boundary around one OpenCode deployment + +One deployed instance should be treated as a single-tenant trust boundary, not as a secure multi-tenant runtime boundary. + +Execution-environment boundary fields published through the runtime profile are declarative deployment metadata. They are not promises that every host-side approval, sandbox escalation, or filesystem change will be reflected live per request. + +## Persistence Compatibility + +Task durability is deployment-dependent: + +- `A2A_TASK_STORE_BACKEND=database` preserves SDK task rows plus adapter-managed session and interrupt state across restarts +- `A2A_TASK_STORE_BACKEND=memory` keeps the service in an ephemeral development profile + +Task-store behavior that should remain stable for clients: + +- once a task reaches a terminal state, later conflicting writes are dropped on a first-terminal-state-wins basis +- task-store I/O failures are surfaced as stable service errors instead of leaking backend-specific exceptions +- accepted output-mode negotiation for a task is persisted with the task so later reads keep the same filtered output contract +- adapter-managed migrations only own adapter state tables; SDK-managed task schema remains SDK-owned + +The default SQLite-first profile is intended for local or controlled single-instance deployments. Wider SQLAlchemy dialect compatibility should be treated as implementation latitude rather than a strong public promise unless explicitly documented later. 
+ +## Extension Stability + +- Shared metadata and extension contracts should stay synchronized across Agent Card, OpenAPI, and runtime behavior. +- Public Agent Card should stay intentionally minimal. Detailed extension params belong in the authenticated extended card and OpenAPI, not back in the anonymous discovery surface. +- Deployment-conditional methods must be declared as conditional rather than silently disappearing. +- `opencode.sessions.prompt_async` input-part passthrough is compatibility-sensitive. Changes to supported part types, passthrough field semantics, or rejection behavior should be treated as wire-level changes. +- `opencode.sessions.shell` is compatibility-sensitive as a deployment-conditional shell snapshot surface. It should not silently widen into a general interactive shell API. +- `opencode.workspaces.*` and `opencode.worktrees.*` are boundary-sensitive and should remain explicitly provider-private, operator-scoped, and deployment-conditional where applicable. +- Interrupt callback and recovery methods are compatibility-sensitive because clients may depend on request ID lifecycle, expiry semantics, and identity scoping. +- Agent Card media modes and `acceptedOutputModes` handling are compatibility-sensitive. Changes to declared chat modes, to task-scoped negotiation persistence, or to `DataPart` -> `TextPart` downgrade behavior should be treated as wire-level changes. +- Agent Card and OpenAPI publication of `protocol_compatibility`, `service_behaviors`, and runtime feature toggles is compatibility-sensitive discoverability surface. + +## Extension Boundary Governance + +When evaluating or evolving `opencode.*` methods, this repository uses the following rules: + +- The adapter may document, validate, route, and normalize stable upstream-facing behavior, but it should not grow into a general replacement for upstream private runtime internals or host-level control planes. 
+- New `opencode.*` methods default to provider-private extension status. +- Read-only discovery, compatibility-preserving projections, and low-risk control methods are preferred over stronger mutating or destructive provider controls. +- A2A core object mappings should be used only for stable, low-ambiguity read projections. +- Subtask/subagent fan-out, task-tool internals, and similar upstream execution mechanisms should stay framed as upstream runtime behavior even when passthrough compatibility exists. + +Each new extension proposal should answer: + +- what client value exists beyond the current chat/session flow? +- is the upstream behavior stable enough to carry as a maintained contract? +- should the surface be provider-private, deployment-conditional, or excluded? +- are authorization and destructive-side-effect boundaries enforceable? +- can the result shape avoid overfitting OpenCode internals into fake A2A core semantics? + +## Extension Taxonomy + +This repository distinguishes between three layers: + +- core A2A surface + - standard send / stream / task methods +- shared extensions + - repo-family conventions such as session binding, model selection, stream hints, and interrupt callbacks +- OpenCode-specific extensions + - `opencode.*` JSON-RPC methods plus `metadata.opencode.*` + +Important note: + +- `urn:a2a:*` extension URIs used here should be read as shared conventions in this repository family. +- They are not a claim that those extensions are part of the A2A core baseline. +- `opencode.*` methods are intentionally product-specific. They improve OpenCode-aware workflows but should not be assumed to transfer unchanged to unrelated A2A agents. 
+ +## Non-Goals + +This repository does not currently promise: + +- hard multi-tenant isolation inside one instance +- generic provider-auth orchestration on behalf of OpenCode +- a claim that all declared `1.0` protocol surfaces are fully implemented beyond the documented compatibility matrix + +Those areas may evolve later, but they should not be implied by current machine-readable discovery output. diff --git a/docs/extension-specifications.md b/docs/extension-specifications.md index 17a4555..dc5e102 100644 --- a/docs/extension-specifications.md +++ b/docs/extension-specifications.md @@ -1,6 +1,6 @@ # Extension Specifications -This document is the stable specification surface referenced by the extension URIs published in the Agent Card. It is intentionally a compact URI/spec index, not the main consumer guide. For runtime behavior, request/response examples, and client integration guidance, see [`guide.md`](./guide.md). +This document is the stable specification surface referenced by the extension URIs published in the Agent Card. It is intentionally a compact URI/spec index, not the main consumer guide. For runtime behavior, request/response examples, and client integration guidance, see [`guide.md`](./guide.md). For compatibility-sensitive surface and contract-honesty guidance, see [`compatibility.md`](./compatibility.md). ## SDK Compatibility Note diff --git a/docs/guide.md b/docs/guide.md index d8bbb70..e50476e 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -1,6 +1,6 @@ # Usage Guide -This guide covers configuration, authentication, API behavior, streaming re-subscription, and A2A client examples. It is the canonical document for implementation-level protocol contracts and JSON-RPC extension details; README stays at overview level. +This guide covers configuration, authentication, API behavior, streaming re-subscription, and A2A client examples. 
It is the canonical document for implementation-level protocol contracts and JSON-RPC extension details; [README](../README.md) stays at overview level, [architecture.md](./architecture.md) explains the service boundary, [maintainer-architecture.md](./maintainer-architecture.md) covers the internal module view for contributors, and [compatibility.md](./compatibility.md) defines the compatibility-sensitive surface. ## Transport Contracts @@ -199,6 +199,8 @@ If one deployment works while another fails against the same upstream provider, - Main chat requests that explicitly send `configuration.acceptedOutputModes` must stay compatible with the declared chat output modes. - Current main chat requests must continue accepting `text/plain`; requests that only accept `application/json` or other incompatible modes are rejected before execution starts. - `application/json` is additive structured-output support for incremental `tool_call` payloads. It does not guarantee that ordinary assistant prose can always be losslessly represented as JSON, so consumers that expect normal chat text should keep accepting `text/plain`. +- When a client accepts `text/plain` but not `application/json`, structured `tool_call` payloads are downgraded to compact JSON text instead of being silently dropped. +- Accepted output-mode negotiation is persisted as task-scoped metadata so later `tasks/get` and `tasks/resubscribe` reads keep the same filtered response contract as the original `message/send` or `message:stream` request. - Main chat input supports structured A2A `parts` passthrough: - `TextPart` is forwarded as an OpenCode text part. - `FilePart(FileWithBytes)` is forwarded as a `file` part with a `data:` URL. @@ -220,6 +222,8 @@ If one deployment works while another fails against the same upstream provider, - `message.part.delta` and `message.part.updated` are merged per `part_id`; out-of-order deltas are buffered and replayed when the corresponding `part.updated` arrives. 
- Structured `tool` parts are emitted as `tool_call` blocks backed by `DataPart(data={...})`, while `text` and `reasoning` continue to use `TextPart`. - `tool_call` block payloads are normalized structured objects that may expose fields such as `call_id`, `tool`, `status`, `title`, `subtitle`, `input`, `output`, and `error`. +- If `application/json` is not accepted but `text/plain` is still accepted, those `tool_call` blocks are downgraded to stable compact JSON text so text-only clients retain the same observable state transitions. +- When a request restricts `acceptedOutputModes`, the stream applies the same output filtering before persistence so later task snapshots do not re-expose filtered structured blocks. - Final status event metadata may include normalized token usage at `metadata.shared.usage` with fields such as `input_tokens`, `output_tokens`, `total_tokens`, optional `reasoning_tokens`, optional `cache_tokens.read_tokens` / `cache_tokens.write_tokens`, and optional `cost`. - Usage is extracted from documented info payloads and supported usage parts such as `step-finish`; non-usage parts with similar fields are ignored. - Interrupt events (`permission.asked` / `question.asked`) are mapped to `TaskStatusUpdateEvent(final=false, state=input-required)` with details at `metadata.shared.interrupt`, including `request_id`, interrupt `type`, `phase=asked`, and a normalized minimal callback payload. diff --git a/docs/maintainer-architecture.md b/docs/maintainer-architecture.md new file mode 100644 index 0000000..835fc11 --- /dev/null +++ b/docs/maintainer-architecture.md @@ -0,0 +1,132 @@ +# Maintainer Architecture Guide + +This document describes the internal structure, module boundaries, and main request call chains of `opencode-a2a`. It is intended for maintainers and contributors. Use [architecture.md](./architecture.md) for the higher-level service boundary view and [guide.md](./guide.md) for deployment-facing runtime configuration. 
+ +## Core Component Map + +```mermaid +flowchart TD + subgraph Inbound["Server Layer (src/opencode_a2a/server/)"] + App["application.py (FastAPI assembly)"] + Middleware["middleware.py"] + AgentCard["agent_card.py / openapi.py"] + end + + subgraph Execution["Execution Layer (src/opencode_a2a/execution/)"] + Executor["executor.py"] + Coordinator["coordinator.py"] + StreamRuntime["stream_runtime.py"] + ToolOrch["tool_orchestration.py"] + end + + subgraph Upstream["OpenCode Upstream Layer"] + UpstreamClient["opencode_upstream_client.py"] + Invocation["invocation.py"] + end + + subgraph Extensions["Extension / Contract Layer"] + Jsonrpc["jsonrpc/"] + Contracts["contracts/extensions.py"] + Profile["profile/runtime.py"] + end + + subgraph Persistence["Persistence Layer"] + TaskStore["server/task_store.py"] + StateStore["server/state_store.py"] + Migrations["server/migrations.py"] + end + + Inbound --> Execution + Inbound --> Extensions + Execution --> Upstream + Inbound --> Persistence + Execution --> Persistence +``` + +## Request Call Chain + +### Inbound Message Send / Stream + +1. `server/application.py` assembles the FastAPI app, SDK adapters, and middleware stack. +2. `server/middleware.py` handles auth, request sizing, protocol negotiation, logging, and response headers. +3. The SDK-backed handler delegates execution to `execution/executor.py`. +4. The execution layer coordinates session continuity, stream handling, tool calls, and error translation through: + - `execution/coordinator.py` + - `execution/stream_runtime.py` + - `execution/tool_orchestration.py` +5. `opencode_upstream_client.py` sends requests to the upstream OpenCode runtime. +6. The adapter maps upstream responses back into A2A tasks, messages, artifacts, and stream events. + +### JSON-RPC Extension Path + +1. `jsonrpc/application.py` owns the adapter-specific JSON-RPC application boundary. +2. `jsonrpc/dispatch.py` and handler modules under `jsonrpc/handlers/` route provider-private methods. +3. 
`contracts/extensions.py` remains the SSOT for extension metadata exposed through Agent Card and OpenAPI. +4. Tests under `tests/contracts/` and `tests/jsonrpc/` guard contract drift. + +### Outbound Peer Call Path + +1. CLI or server-side tool execution asks `server/client_manager.py` for an outbound client. +2. `client/` builds and configures the embedded A2A client facade. +3. Outbound peer responses are normalized before being reintroduced into the local runtime surface. + +## Module Responsibilities + +### Server Layer + +- `server/application.py`: app assembly, route wiring, request handler customization, and top-level lifecycle integration +- `server/middleware.py`: auth, protocol negotiation, payload/body guards, logging, and response decoration +- `server/agent_card.py` / `server/openapi.py`: machine-readable contract publication +- `server/rest_tasks.py`: SDK-owned REST task routes plus adapter-specific list behavior + +### Execution Layer + +- `execution/executor.py`: main orchestration entrypoint +- `execution/coordinator.py`: OpenCode session coordination and request shaping +- `execution/stream_runtime.py` / `execution/stream_events.py`: stream normalization and event conversion +- `execution/tool_orchestration.py`: embedded peer-call tool handling +- `execution/upstream_error_translator.py` / `execution/tool_error_mapping.py`: upstream-facing error normalization + +### Extension and Contract Layer + +- `contracts/extensions.py`: SSOT for extension metadata, compatibility profile, and wire-contract payloads +- `jsonrpc/`: provider-private JSON-RPC extension surface +- `profile/runtime.py`: runtime profile that feeds Agent Card, OpenAPI, and compatibility metadata +- `protocol_versions.py`: protocol normalization and negotiation helpers + +### Persistence Layer + +- `server/task_store.py`: SDK task store construction plus adapter policy wrappers +- `server/state_store.py`: session binding and interrupt repositories +- `server/migrations.py`: 
adapter-managed state schema migrations + +### Client Layer + +- `client/`: outbound peer card discovery, request context, auth handling, polling fallback, and error mapping + +## Key Persistence Points + +- SDK task rows stored through the configured task store backend +- adapter-managed session binding / ownership state +- interrupt request bindings and tombstones +- pending preferred-session claims + +The custom migration runner owns only adapter-managed state tables; SDK-managed task schema still follows the SDK path. + +## Configuration Layering + +Configuration is handled in [config.py](../src/opencode_a2a/config.py) with `pydantic-settings`. + +- `A2A_*`: inbound runtime, outbound peer client, protocol, persistence, and deployment metadata +- `OPENCODE_*`: upstream OpenCode connection and request behavior + +## Practical Reading Order + +For maintainers new to the codebase, this order usually gives the fastest payoff: + +1. `README.md` +2. `docs/architecture.md` +3. `src/opencode_a2a/server/application.py` +4. `src/opencode_a2a/execution/executor.py` +5. `src/opencode_a2a/contracts/extensions.py` +6. `docs/guide.md` diff --git a/scripts/README.md b/scripts/README.md index c8c98b9..14ef8e0 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -11,7 +11,7 @@ Executable scripts live in this directory. 
This file is the entry index for the ## Other Scripts -- [`doctor.sh`](./doctor.sh): primary local development regression entrypoint (uv sync + lint + tests + coverage) +- [`doctor.sh`](./doctor.sh): primary local development regression entrypoint (uv sync + dependency compatibility + lint + mypy + tests + coverage + built-wheel smoke test) - [`conformance.sh`](./conformance.sh): local/manual external A2A conformance experiment entrypoint; caches the official TCK, can launch a dummy-backed local SUT, and preserves raw artifacts under `run/conformance/` - [`dependency_health.sh`](./dependency_health.sh): development dependency review entrypoint (`sync`/`pip check` + outdated + dev audit), while blocking CI/publish audits focus on runtime dependencies - [`check_coverage.py`](./check_coverage.py): enforces the overall coverage floor and per-file minimums for critical modules @@ -21,5 +21,6 @@ Executable scripts live in this directory. This file is the entry index for the ## Notes - `doctor.sh` and `dependency_health.sh` intentionally remain separate entrypoints and share common prerequisites through [`health_common.sh`](./health_common.sh). +- `doctor.sh` covers the default local validation baseline, while `dependency_health.sh` remains focused on standalone dependency review and audit flow. - [`.github/dependabot.yml`](../.github/dependabot.yml) prefers a single weekly grouped Dependabot PR for `uv`, while `dependency_health.sh` remains the explicit review/audit entrypoint. - External conformance experiments remain intentionally separate from the default regression path. See [`../docs/conformance.md`](../docs/conformance.md). 
diff --git a/scripts/doctor.sh b/scripts/doctor.sh index 3968943..f1e254d 100755 --- a/scripts/doctor.sh +++ b/scripts/doctor.sh @@ -9,8 +9,18 @@ run_shared_repo_health_prerequisites "doctor" echo "[doctor] run lint" uv run pre-commit run --all-files +echo "[doctor] run type checks" +uv run mypy src/opencode_a2a + echo "[doctor] run tests" uv run pytest echo "[doctor] enforce coverage policy" uv run python ./scripts/check_coverage.py + +echo "[doctor] build release artifacts" +rm -f dist/opencode_a2a-*.whl dist/opencode_a2a-*.tar.gz +uv build --no-sources + +echo "[doctor] smoke test built wheel" +bash ./scripts/smoke_test_built_cli.sh dist/opencode_a2a-*.whl diff --git a/src/opencode_a2a/client/polling.py b/src/opencode_a2a/client/polling.py index 1f38b0c..7a1fb5a 100644 --- a/src/opencode_a2a/client/polling.py +++ b/src/opencode_a2a/client/polling.py @@ -6,14 +6,8 @@ from a2a.types import TaskState -_TERMINAL_TASK_STATES = frozenset( - { - TaskState.completed, - TaskState.canceled, - TaskState.failed, - TaskState.rejected, - } -) +from ..task_states import TERMINAL_TASK_STATES + _AUTO_POLLING_TASK_STATES = frozenset( { TaskState.submitted, @@ -37,7 +31,7 @@ def should_poll_state(self, state: TaskState) -> bool: return state in _AUTO_POLLING_TASK_STATES def is_terminal_state(self, state: TaskState) -> bool: - return state in _TERMINAL_TASK_STATES + return state in TERMINAL_TASK_STATES def next_interval_seconds(self, current_interval_seconds: float) -> float: return min( diff --git a/src/opencode_a2a/execution/executor.py b/src/opencode_a2a/execution/executor.py index e218375..045503a 100644 --- a/src/opencode_a2a/execution/executor.py +++ b/src/opencode_a2a/execution/executor.py @@ -73,7 +73,7 @@ _StreamOutputState, _TTLCache, ) -from .tool_orchestration import handle_a2a_call_tool, maybe_handle_tools, merge_streamed_tool_output +from .tool_orchestration import maybe_handle_tools, merge_streamed_tool_output from .upstream_error_translator import ( 
_await_stream_terminal_signal, _extract_upstream_error_detail, @@ -186,12 +186,6 @@ async def _maybe_handle_tools( a2a_client_manager=self._a2a_client_manager, ) - async def _handle_a2a_call_tool(self, part: dict[str, Any]) -> dict[str, Any]: - return await handle_a2a_call_tool( - part, - a2a_client_manager=self._a2a_client_manager, - ) - @staticmethod def _merge_streamed_tool_output(current: str, incoming: str) -> str: return merge_streamed_tool_output(current, incoming) diff --git a/src/opencode_a2a/execution/request_context.py b/src/opencode_a2a/execution/request_context.py index f70dcac..815b81e 100644 --- a/src/opencode_a2a/execution/request_context.py +++ b/src/opencode_a2a/execution/request_context.py @@ -6,6 +6,8 @@ from a2a.server.agent_execution import RequestContext from a2a.types import Message +from ..metadata_access import extract_first_namespaced_string + def _build_history(context: RequestContext) -> list[Message]: if context.current_task and context.current_task.history: @@ -17,23 +19,20 @@ def _build_history(context: RequestContext) -> list[Message]: return history -def _iter_metadata_maps(context: RequestContext, namespace: str): +def _metadata_sources(context: RequestContext) -> tuple[Mapping[str, Any] | None, ...]: try: meta = context.metadata except Exception: meta = None + sources: list[Mapping[str, Any] | None] = [] if isinstance(meta, Mapping): - namespaced_meta = meta.get(namespace) - if isinstance(namespaced_meta, Mapping): - yield namespaced_meta - + sources.append(meta) if context.message is not None: msg_meta = getattr(context.message, "metadata", None) or {} if isinstance(msg_meta, Mapping): - namespaced_meta = msg_meta.get(namespace) - if isinstance(namespaced_meta, Mapping): - yield namespaced_meta + sources.append(msg_meta) + return tuple(sources) def _extract_namespaced_string_metadata( @@ -42,21 +41,11 @@ def _extract_namespaced_string_metadata( namespace: str, path: tuple[str, ...], ) -> str | None: - for namespaced_meta in 
_iter_metadata_maps(context, namespace): - current: Any = namespaced_meta - for part in path[:-1]: - if not isinstance(current, Mapping): - current = None - break - current = current.get(part) - if not isinstance(current, Mapping): - continue - candidate = current.get(path[-1]) - if isinstance(candidate, str): - value = candidate.strip() - if value: - return value - return None + return extract_first_namespaced_string( + _metadata_sources(context), + namespace=namespace, + path=path, + ) def _extract_shared_session_id(context: RequestContext) -> str | None: diff --git a/src/opencode_a2a/execution/stream_runtime.py b/src/opencode_a2a/execution/stream_runtime.py index 721b2cf..94b7a08 100644 --- a/src/opencode_a2a/execution/stream_runtime.py +++ b/src/opencode_a2a/execution/stream_runtime.py @@ -18,6 +18,7 @@ ) from ..invocation import call_with_supported_kwargs +from ..output_modes import part_text_fallback from .event_helpers import _enqueue_artifact_update from .stream_events import ( BlockType, @@ -87,7 +88,20 @@ async def consume( async def _emit_chunks(chunks: list[_NormalizedStreamChunk]) -> None: for chunk in chunks: if not allow_structured_output and getattr(chunk.part.root, "kind", None) == "data": - continue + fallback_text = part_text_fallback(chunk.part.root) + if fallback_text is None: + continue + chunk = _NormalizedStreamChunk( + part=Part(root=TextPart(text=fallback_text)), + content_key=fallback_text, + accumulate_content=False, + append=chunk.append, + block_type=chunk.block_type, + internal_source=chunk.internal_source, + shared_source=chunk.shared_source, + message_id=chunk.message_id, + role=chunk.role, + ) resolved_message_id = stream_state.resolve_message_id(chunk.message_id) chunk_text = getattr(chunk.part.root, "text", "") if stream_state.should_drop_initial_user_echo( diff --git a/src/opencode_a2a/jsonrpc/handlers/common.py b/src/opencode_a2a/jsonrpc/handlers/common.py index 0b23d51..6854144 100644 --- 
a/src/opencode_a2a/jsonrpc/handlers/common.py +++ b/src/opencode_a2a/jsonrpc/handlers/common.py @@ -1,13 +1,16 @@ from __future__ import annotations import logging +from collections.abc import Awaitable, Callable from typing import Any +import httpx from a2a.types import A2AError, InternalError from starlette.responses import Response from ...contracts.extensions import SESSION_QUERY_ERROR_BUSINESS_CODES -from ...opencode_upstream_client import UpstreamConcurrencyLimitError +from ...metadata_access import extract_namespaced_value +from ...opencode_upstream_client import UpstreamConcurrencyLimitError, UpstreamContractError from ..dispatch import ExtensionHandlerContext from ..error_responses import ( authorization_forbidden_error, @@ -23,6 +26,69 @@ logger = logging.getLogger(__name__) +class SessionClaimGuard: + def __init__( + self, + context: ExtensionHandlerContext, + *, + identity: str | None, + session_id: str | None, + logger: logging.Logger, + ) -> None: + self._context = context + self._identity = identity + self._session_id = session_id + self._logger = logger + self._pending = False + self._finalized = False + + async def __aenter__(self) -> SessionClaimGuard: + if self._identity and self._session_id: + self._pending = await self._context.session_claim( + identity=self._identity, + session_id=self._session_id, + ) + return self + + async def finalize(self) -> None: + if self._pending and not self._finalized and self._identity and self._session_id: + await self._context.session_claim_finalize( + identity=self._identity, + session_id=self._session_id, + ) + self._finalized = True + + async def __aexit__(self, exc_type, exc, tb) -> bool: # noqa: ANN001 + del exc_type, exc, tb + if self._pending and not self._finalized and self._identity and self._session_id: + try: + await self._context.session_claim_release( + identity=self._identity, + session_id=self._session_id, + ) + except Exception: + self._logger.exception( + "Failed to release pending session claim 
for session_id=%s", + self._session_id, + ) + return False + + +def claim_session( + context: ExtensionHandlerContext, + *, + identity: str | None, + session_id: str | None, + logger: logging.Logger, +) -> SessionClaimGuard: + return SessionClaimGuard( + context, + identity=identity, + session_id=session_id, + logger=logger, + ) + + def build_success_response( context: ExtensionHandlerContext, request_id: str | int | None, @@ -33,6 +99,30 @@ def build_success_response( return context.success_response(request_id, result) +def reject_unknown_fields( + context: ExtensionHandlerContext, + request_id: str | int | None, + payload: dict[str, Any], + *, + allowed_fields: set[str] | frozenset[str], + field_prefix: str = "", + message_prefix: str = "Unsupported fields", +) -> Response | None: + unknown_fields = sorted(set(payload) - set(allowed_fields)) + if not unknown_fields: + return None + reported_fields = ( + [f"{field_prefix}{field}" for field in unknown_fields] if field_prefix else unknown_fields + ) + return context.error_response( + request_id, + invalid_params_error( + f"{message_prefix}: {', '.join(reported_fields)}", + data={"type": "INVALID_FIELD", "fields": reported_fields}, + ), + ) + + def build_session_forbidden_response( context: ExtensionHandlerContext, request_id: str | int | None, @@ -65,14 +155,18 @@ def build_authorization_forbidden_response( ) -def extract_directory_from_metadata( +def _parse_metadata_objects( context: ExtensionHandlerContext, *, request_id: str | int | None, params: dict[str, Any], -) -> tuple[str | None, Response | None]: + strict_top_level: bool = False, + validate_shared_object: bool = False, +) -> tuple[dict[str, Any] | None, Response | None]: metadata = params.get("metadata") - if metadata is not None and not isinstance(metadata, dict): + if metadata is None: + return None, None + if not isinstance(metadata, dict): return None, context.error_response( request_id, invalid_params_error( @@ -81,8 +175,7 @@ def 
extract_directory_from_metadata( ), ) - opencode_metadata: dict[str, Any] | None = None - if isinstance(metadata, dict): + if strict_top_level: unknown_metadata_fields = sorted(set(metadata) - {"opencode", "shared"}) if unknown_metadata_fields: prefixed_fields = [f"metadata.{field}" for field in unknown_metadata_fields] @@ -93,17 +186,18 @@ def extract_directory_from_metadata( data={"type": "INVALID_FIELD", "fields": prefixed_fields}, ), ) - raw_opencode_metadata = metadata.get("opencode") - if raw_opencode_metadata is not None and not isinstance(raw_opencode_metadata, dict): - return None, context.error_response( - request_id, - invalid_params_error( - "metadata.opencode must be an object", - data={"type": "INVALID_FIELD", "field": "metadata.opencode"}, - ), - ) - if isinstance(raw_opencode_metadata, dict): - opencode_metadata = raw_opencode_metadata + + raw_opencode_metadata = metadata.get("opencode") + if raw_opencode_metadata is not None and not isinstance(raw_opencode_metadata, dict): + return None, context.error_response( + request_id, + invalid_params_error( + "metadata.opencode must be an object", + data={"type": "INVALID_FIELD", "field": "metadata.opencode"}, + ), + ) + + if validate_shared_object: raw_shared_metadata = metadata.get("shared") if raw_shared_metadata is not None and not isinstance(raw_shared_metadata, dict): return None, context.error_response( @@ -114,9 +208,35 @@ def extract_directory_from_metadata( ), ) + return ( + raw_opencode_metadata if isinstance(raw_opencode_metadata, dict) else None, + None, + ) + + +def extract_directory_from_metadata( + context: ExtensionHandlerContext, + *, + request_id: str | int | None, + params: dict[str, Any], +) -> tuple[str | None, Response | None]: + opencode_metadata, metadata_error = _parse_metadata_objects( + context, + request_id=request_id, + params=params, + strict_top_level=True, + validate_shared_object=True, + ) + if metadata_error is not None: + return None, metadata_error + directory = None if 
opencode_metadata is not None: - directory = opencode_metadata.get("directory") + directory = extract_namespaced_value( + {"opencode": opencode_metadata}, + namespace="opencode", + path=("directory",), + ) if directory is not None and not isinstance(directory, str): return None, context.error_response( request_id, @@ -135,31 +255,21 @@ def extract_workspace_id_from_metadata( request_id: str | int | None, params: dict[str, Any], ) -> tuple[str | None, Response | None]: - metadata = params.get("metadata") - if metadata is None: - return None, None - if not isinstance(metadata, dict): - return None, context.error_response( - request_id, - invalid_params_error( - "metadata must be an object", - data={"type": "INVALID_FIELD", "field": "metadata"}, - ), - ) - - raw_opencode_metadata = metadata.get("opencode") + raw_opencode_metadata, metadata_error = _parse_metadata_objects( + context, + request_id=request_id, + params=params, + ) + if metadata_error is not None: + return None, metadata_error if raw_opencode_metadata is None: return None, None - if not isinstance(raw_opencode_metadata, dict): - return None, context.error_response( - request_id, - invalid_params_error( - "metadata.opencode must be an object", - data={"type": "INVALID_FIELD", "field": "metadata.opencode"}, - ), - ) - raw_workspace = raw_opencode_metadata.get("workspace") + raw_workspace = extract_namespaced_value( + {"opencode": raw_opencode_metadata}, + namespace="opencode", + path=("workspace",), + ) if raw_workspace is None: return None, None if not isinstance(raw_workspace, dict): @@ -171,7 +281,11 @@ def extract_workspace_id_from_metadata( ), ) - raw_workspace_id = raw_workspace.get("id") + raw_workspace_id = extract_namespaced_value( + {"workspace": raw_workspace}, + namespace="workspace", + path=("id",), + ) if raw_workspace_id is None: return None, None if not isinstance(raw_workspace_id, str): @@ -270,6 +384,112 @@ def extract_interrupt_callback_directory_hint( ) +async def 
invoke_upstream_or_error( + context: ExtensionHandlerContext, + request_id: str | int | None, + *, + invoke: Callable[[], Awaitable[Any]], + upstream_http_error_code: int, + upstream_unreachable_error_code: int, + internal_log_message: str, + method: str | None = None, + session_id: str | None = None, + interrupt_request_id: str | None = None, + on_not_found: Callable[[], Response] | None = None, +) -> tuple[Any | None, Response | None]: + try: + return await invoke(), None + except Exception as exc: + return None, build_upstream_exception_response( + context, + request_id, + exc=exc, + upstream_http_error_code=upstream_http_error_code, + upstream_unreachable_error_code=upstream_unreachable_error_code, + internal_log_message=internal_log_message, + method=method, + session_id=session_id, + interrupt_request_id=interrupt_request_id, + on_not_found=on_not_found, + ) + + +def build_upstream_exception_response( + context: ExtensionHandlerContext, + request_id: str | int | None, + *, + exc: Exception, + upstream_http_error_code: int, + upstream_unreachable_error_code: int, + internal_log_message: str, + method: str | None = None, + session_id: str | None = None, + interrupt_request_id: str | None = None, + upstream_payload_error_code: int | None = None, + on_not_found: Callable[[], Response] | None = None, + on_permission_error: Callable[[], Response] | None = None, + payload_warning_message: str | None = None, +) -> Response: + if isinstance(exc, httpx.HTTPStatusError): + if exc.response.status_code == 404 and on_not_found is not None: + return on_not_found() + return build_upstream_http_error_response( + context, + request_id, + upstream_http_error_code, + upstream_status=exc.response.status_code, + method=method, + session_id=session_id, + interrupt_request_id=interrupt_request_id, + ) + + if isinstance(exc, httpx.HTTPError): + return build_upstream_unreachable_error_response( + context, + request_id, + upstream_unreachable_error_code, + method=method, + 
session_id=session_id, + interrupt_request_id=interrupt_request_id, + ) + + if isinstance(exc, UpstreamConcurrencyLimitError): + return build_upstream_concurrency_error_response( + context, + request_id, + upstream_unreachable_error_code, + exc=exc, + method=method, + session_id=session_id, + interrupt_request_id=interrupt_request_id, + ) + + if upstream_payload_error_code is not None and isinstance( + exc, (UpstreamContractError, ValueError) + ): + if isinstance(exc, ValueError) and payload_warning_message is not None: + logger.warning("%s: %s", payload_warning_message, exc) + return build_upstream_payload_error_response( + context, + request_id, + upstream_payload_error_code, + detail=str(exc), + method=method, + session_id=session_id, + interrupt_request_id=interrupt_request_id, + ) + + if isinstance(exc, PermissionError) and on_permission_error is not None: + return on_permission_error() + + return build_internal_error_response( + context, + request_id, + log_message=internal_log_message, + exc=exc, + ) + + def build_upstream_http_error_response( context: ExtensionHandlerContext, request_id: str | int | None, diff --git a/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py index 18de2d3..8c4ab45 100644 --- a/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py +++ b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py @@ -26,6 +26,7 @@ build_upstream_unreachable_error_response, extract_interrupt_callback_directory_hint, extract_workspace_id_from_metadata, + reject_unknown_fields, ) logger = logging.getLogger(__name__) @@ -139,15 +140,14 @@ async def handle_interrupt_callback_request( allowed_fields = {"request_id", "answers", "metadata"} else: allowed_fields = {"request_id", "metadata"} - unknown_fields = sorted(set(params) - allowed_fields) - if unknown_fields: - return context.error_response( - base_request.id, - invalid_params_error( - f"Unsupported fields: {', '.join(unknown_fields)}", 
- data={"type": "INVALID_FIELD", "fields": unknown_fields}, - ), - ) + unknown_fields_error = reject_unknown_fields( + context, + base_request.id, + params, + allowed_fields=allowed_fields, + ) + if unknown_fields_error is not None: + return unknown_fields_error try: result: dict[str, Any] = { diff --git a/src/opencode_a2a/jsonrpc/handlers/interrupt_queries.py b/src/opencode_a2a/jsonrpc/handlers/interrupt_queries.py index cfd0d61..bd20f76 100644 --- a/src/opencode_a2a/jsonrpc/handlers/interrupt_queries.py +++ b/src/opencode_a2a/jsonrpc/handlers/interrupt_queries.py @@ -7,8 +7,7 @@ from starlette.responses import Response from ..dispatch import ExtensionHandlerContext -from ..error_responses import invalid_params_error -from .common import build_internal_error_response, build_success_response +from .common import build_internal_error_response, build_success_response, reject_unknown_fields def _binding_to_result_item(binding: Any) -> dict[str, Any]: @@ -29,15 +28,14 @@ async def handle_interrupt_query_request( params: dict[str, Any], request: Request, ) -> Response: - unknown_fields = sorted(params) - if unknown_fields: - return context.error_response( - base_request.id, - invalid_params_error( - f"Unsupported fields: {', '.join(unknown_fields)}", - data={"type": "INVALID_FIELD", "fields": unknown_fields}, - ), - ) + unknown_fields_error = reject_unknown_fields( + context, + base_request.id, + params, + allowed_fields=set(), + ) + if unknown_fields_error is not None: + return unknown_fields_error request_identity = getattr(request.state, "user_identity", None) identity = request_identity.strip() if isinstance(request_identity, str) else "" diff --git a/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py b/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py index ce4f508..fecb474 100644 --- a/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py +++ b/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py @@ -3,14 +3,12 @@ import logging from typing import 
Any -import httpx from a2a.types import JSONRPCRequest from starlette.requests import Request from starlette.responses import Response from ...contracts.extensions import PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES from ...invocation import call_with_supported_kwargs -from ...opencode_upstream_client import UpstreamConcurrencyLimitError from ..dispatch import ExtensionHandlerContext from ..error_responses import invalid_params_error from ..methods import ( @@ -19,12 +17,10 @@ _normalize_provider_summaries, ) from .common import ( - build_internal_error_response, build_success_response, - build_upstream_concurrency_error_response, - build_upstream_http_error_response, build_upstream_payload_error_response, - build_upstream_unreachable_error_response, + invoke_upstream_or_error, + reject_unknown_fields, resolve_routing_context, ) @@ -47,16 +43,16 @@ async def handle_provider_discovery_request( allowed_fields = {"metadata"} if base_request.method == context.method_list_models: allowed_fields.add("provider_id") - unknown_fields = sorted(set(params) - allowed_fields) - if unknown_fields: - prefixed_fields = [f"params.{field}" for field in unknown_fields] - return context.error_response( - base_request.id, - invalid_params_error( - f"Unsupported params fields: {', '.join(prefixed_fields)}", - data={"type": "INVALID_FIELD", "fields": prefixed_fields}, - ), - ) + unknown_fields_error = reject_unknown_fields( + context, + base_request.id, + params, + allowed_fields=allowed_fields, + field_prefix="params.", + message_prefix="Unsupported params fields", + ) + if unknown_fields_error is not None: + return unknown_fields_error provider_id: str | None = None if base_request.method == context.method_list_models: @@ -80,43 +76,22 @@ async def handle_provider_discovery_request( if routing_error is not None: return routing_error - try: - raw_result = await call_with_supported_kwargs( + raw_result, upstream_error = await invoke_upstream_or_error( + context, + base_request.id, + 
invoke=lambda: call_with_supported_kwargs( context.upstream_client.list_provider_catalog, directory=directory, workspace_id=workspace_id, - ) - except httpx.HTTPStatusError as exc: - upstream_status = exc.response.status_code - return build_upstream_http_error_response( - context, - base_request.id, - ERR_DISCOVERY_UPSTREAM_HTTP_ERROR, - upstream_status=upstream_status, - method=base_request.method, - ) - except httpx.HTTPError: - return build_upstream_unreachable_error_response( - context, - base_request.id, - ERR_DISCOVERY_UPSTREAM_UNREACHABLE, - method=base_request.method, - ) - except UpstreamConcurrencyLimitError as exc: - return build_upstream_concurrency_error_response( - context, - base_request.id, - ERR_DISCOVERY_UPSTREAM_UNREACHABLE, - exc=exc, - method=base_request.method, - ) - except Exception as exc: - return build_internal_error_response( - context, - base_request.id, - log_message="OpenCode provider discovery JSON-RPC method failed", - exc=exc, - ) + ), + upstream_http_error_code=ERR_DISCOVERY_UPSTREAM_HTTP_ERROR, + upstream_unreachable_error_code=ERR_DISCOVERY_UPSTREAM_UNREACHABLE, + internal_log_message="OpenCode provider discovery JSON-RPC method failed", + method=base_request.method, + ) + if upstream_error is not None: + return upstream_error + assert raw_result is not None try: raw_providers, default_by_provider, connected = _extract_provider_catalog(raw_result) diff --git a/src/opencode_a2a/jsonrpc/handlers/session_control.py b/src/opencode_a2a/jsonrpc/handlers/session_control.py index e3cf5b1..f5fafcc 100644 --- a/src/opencode_a2a/jsonrpc/handlers/session_control.py +++ b/src/opencode_a2a/jsonrpc/handlers/session_control.py @@ -26,13 +26,11 @@ ) from .common import ( build_authorization_forbidden_response, - build_internal_error_response, build_session_forbidden_response, build_success_response, - build_upstream_concurrency_error_response, - build_upstream_http_error_response, - build_upstream_payload_error_response, - 
build_upstream_unreachable_error_response, + build_upstream_exception_response, + claim_session, + reject_unknown_fields, resolve_routing_context, ) @@ -44,22 +42,36 @@ ERR_UPSTREAM_PAYLOAD_ERROR = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_PAYLOAD_ERROR"] +def _shell_audit_outcome(exc: Exception) -> str: + if isinstance(exc, httpx.HTTPStatusError): + if exc.response.status_code == 404: + return "upstream_404" + return "upstream_http_error" + if isinstance(exc, httpx.HTTPError): + return "upstream_unreachable" + if isinstance(exc, UpstreamConcurrencyLimitError): + return "upstream_backpressure" + if isinstance(exc, (UpstreamContractError, ValueError)): + return "upstream_payload_error" + if isinstance(exc, PermissionError): + return "forbidden" + return "internal_error" + + async def handle_session_control_request( context: ExtensionHandlerContext, base_request: JSONRPCRequest, params: dict[str, Any], request: Request, ) -> Response: - allowed_fields = {"session_id", "request", "metadata"} - unknown_fields = sorted(set(params) - allowed_fields) - if unknown_fields: - return context.error_response( - base_request.id, - invalid_params_error( - f"Unsupported fields: {', '.join(unknown_fields)}", - data={"type": "INVALID_FIELD", "fields": unknown_fields}, - ), - ) + unknown_fields_error = reject_unknown_fields( + context, + base_request.id, + params, + allowed_fields={"session_id", "request", "metadata"}, + ) + if unknown_fields_error is not None: + return unknown_fields_error session_id = params.get("session_id") if not isinstance(session_id, str) or not session_id.strip(): @@ -150,143 +162,77 @@ def _log_shell_audit(outcome: str) -> None: if routing_error is not None: return routing_error - pending_claim = False - claim_finalized = False - if identity: - try: - pending_claim = await context.session_claim( - identity=identity, - session_id=session_id, - ) - except PermissionError: - _log_shell_audit("forbidden") - return build_session_forbidden_response( - context, 
- base_request.id, - session_id=session_id, - ) - try: - result: dict[str, Any] - if base_request.method == context.method_prompt_async: - await call_with_supported_kwargs( - context.upstream_client.session_prompt_async, - session_id, - request=dict(raw_request), - directory=directory, - workspace_id=workspace_id, - ) - result = {"ok": True, "session_id": session_id} - elif base_request.method == context.method_command: - raw_result = await call_with_supported_kwargs( - context.upstream_client.session_command, - session_id, - request=dict(raw_request), - directory=directory, - workspace_id=workspace_id, - ) - item = _as_a2a_message(session_id, raw_result) - if item is None: - raise UpstreamContractError( - "OpenCode /session/{sessionID}/command response could not be mapped " - "to A2A Message" - ) - result = {"item": item} - else: - raw_result = await call_with_supported_kwargs( - context.upstream_client.session_shell, - session_id, - request=dict(raw_request), - directory=directory, - workspace_id=workspace_id, - ) - item = _as_a2a_message(session_id, raw_result) - if item is None: - raise UpstreamContractError( - "OpenCode /session/{sessionID}/shell response could not be mapped " - "to A2A Message" - ) - result = {"item": item} - - if pending_claim and identity: - await context.session_claim_finalize( - identity=identity, - session_id=session_id, - ) - claim_finalized = True - _log_shell_audit("success") - except httpx.HTTPStatusError as exc: - upstream_status = exc.response.status_code - if upstream_status == 404: - _log_shell_audit("upstream_404") - return context.error_response( - base_request.id, - session_not_found_error(ERR_SESSION_NOT_FOUND, session_id=session_id), - ) - _log_shell_audit("upstream_http_error") - return build_upstream_http_error_response( + async with claim_session( context, - base_request.id, - ERR_UPSTREAM_HTTP_ERROR, - upstream_status=upstream_status, - method=base_request.method, + identity=identity, session_id=session_id, - ) - except 
httpx.HTTPError: - _log_shell_audit("upstream_unreachable") - return build_upstream_unreachable_error_response( - context, - base_request.id, - ERR_UPSTREAM_UNREACHABLE, - method=base_request.method, - session_id=session_id, - ) - except UpstreamConcurrencyLimitError as exc: - _log_shell_audit("upstream_backpressure") - return build_upstream_concurrency_error_response( + logger=logger, + ) as session_claim: + result: dict[str, Any] + if base_request.method == context.method_prompt_async: + await call_with_supported_kwargs( + context.upstream_client.session_prompt_async, + session_id, + request=dict(raw_request), + directory=directory, + workspace_id=workspace_id, + ) + result = {"ok": True, "session_id": session_id} + elif base_request.method == context.method_command: + raw_result = await call_with_supported_kwargs( + context.upstream_client.session_command, + session_id, + request=dict(raw_request), + directory=directory, + workspace_id=workspace_id, + ) + item = _as_a2a_message(session_id, raw_result) + if item is None: + raise UpstreamContractError( + "OpenCode /session/{sessionID}/command response could not be mapped " + "to A2A Message" + ) + result = {"item": item} + else: + raw_result = await call_with_supported_kwargs( + context.upstream_client.session_shell, + session_id, + request=dict(raw_request), + directory=directory, + workspace_id=workspace_id, + ) + item = _as_a2a_message(session_id, raw_result) + if item is None: + raise UpstreamContractError( + "OpenCode /session/{sessionID}/shell response could not be mapped " + "to A2A Message" + ) + result = {"item": item} + + await session_claim.finalize() + _log_shell_audit("success") + except Exception as exc: + _log_shell_audit(_shell_audit_outcome(exc)) + return build_upstream_exception_response( context, base_request.id, - ERR_UPSTREAM_UNREACHABLE, exc=exc, + upstream_http_error_code=ERR_UPSTREAM_HTTP_ERROR, + upstream_unreachable_error_code=ERR_UPSTREAM_UNREACHABLE, + 
upstream_payload_error_code=ERR_UPSTREAM_PAYLOAD_ERROR, + internal_log_message="OpenCode session control JSON-RPC method failed", method=base_request.method, session_id=session_id, + on_not_found=lambda: context.error_response( + base_request.id, + session_not_found_error(ERR_SESSION_NOT_FOUND, session_id=session_id), + ), + on_permission_error=lambda: build_session_forbidden_response( + context, + base_request.id, + session_id=session_id, + ), ) - except UpstreamContractError as exc: - _log_shell_audit("upstream_payload_error") - return build_upstream_payload_error_response( - context, - base_request.id, - ERR_UPSTREAM_PAYLOAD_ERROR, - detail=str(exc), - method=base_request.method, - session_id=session_id, - ) - except PermissionError: - _log_shell_audit("forbidden") - return build_session_forbidden_response( - context, - base_request.id, - session_id=session_id, - ) - except Exception as exc: - _log_shell_audit("internal_error") - return build_internal_error_response( - context, - base_request.id, - log_message="OpenCode session control JSON-RPC method failed", - exc=exc, - ) - finally: - if pending_claim and not claim_finalized and identity: - try: - await context.session_claim_release( - identity=identity, - session_id=session_id, - ) - except Exception: - logger.exception( - "Failed to release pending session claim for session_id=%s", - session_id, - ) return build_success_response(context, base_request.id, result) diff --git a/src/opencode_a2a/jsonrpc/handlers/session_lifecycle.py b/src/opencode_a2a/jsonrpc/handlers/session_lifecycle.py index e64a8f9..b558a7c 100644 --- a/src/opencode_a2a/jsonrpc/handlers/session_lifecycle.py +++ b/src/opencode_a2a/jsonrpc/handlers/session_lifecycle.py @@ -3,14 +3,13 @@ import logging from typing import Any -import httpx from a2a.types import JSONRPCRequest from starlette.requests import Request from starlette.responses import Response from ...contracts.extensions import SESSION_QUERY_ERROR_BUSINESS_CODES from ...invocation 
import call_with_supported_kwargs -from ...opencode_upstream_client import UpstreamConcurrencyLimitError, UpstreamContractError +from ...opencode_upstream_client import UpstreamContractError from ..dispatch import ExtensionHandlerContext from ..error_responses import invalid_params_error, session_not_found_error from ..methods import ( @@ -23,13 +22,11 @@ _normalize_todo_items, ) from .common import ( - build_internal_error_response, build_session_forbidden_response, build_success_response, - build_upstream_concurrency_error_response, - build_upstream_http_error_response, - build_upstream_payload_error_response, - build_upstream_unreachable_error_response, + build_upstream_exception_response, + claim_session, + reject_unknown_fields, resolve_routing_context, ) @@ -113,15 +110,15 @@ def _parse_fork_request( request_id, _invalid_field_error("request", "params.request must be an object"), ) - unknown_fields = sorted(set(raw_request) - {"messageID"}) - if unknown_fields: - return {}, context.error_response( - request_id, - invalid_params_error( - f"Unsupported fields: {', '.join(f'request.{field}' for field in unknown_fields)}", - data={"type": "INVALID_FIELD", "fields": unknown_fields}, - ), - ) + unknown_fields_error = reject_unknown_fields( + context, + request_id, + raw_request, + allowed_fields={"messageID"}, + field_prefix="request.", + ) + if unknown_fields_error is not None: + return {}, unknown_fields_error message_id = raw_request.get("messageID") if message_id is None: return {}, None @@ -146,15 +143,15 @@ def _parse_summarize_request( request_id, _invalid_field_error("request", "params.request must be an object"), ) - unknown_fields = sorted(set(raw_request) - {"providerID", "modelID", "auto"}) - if unknown_fields: - return None, context.error_response( - request_id, - invalid_params_error( - f"Unsupported fields: {', '.join(f'request.{field}' for field in unknown_fields)}", - data={"type": "INVALID_FIELD", "fields": unknown_fields}, - ), - ) + 
unknown_fields_error = reject_unknown_fields( + context, + request_id, + raw_request, + allowed_fields={"providerID", "modelID", "auto"}, + field_prefix="request.", + ) + if unknown_fields_error is not None: + return None, unknown_fields_error provider_id = raw_request.get("providerID") model_id = raw_request.get("modelID") auto = raw_request.get("auto") @@ -195,15 +192,15 @@ def _parse_revert_request( request_id, _invalid_field_error("request", "params.request must be an object"), ) - unknown_fields = sorted(set(raw_request) - {"messageID", "partID"}) - if unknown_fields: - return {}, context.error_response( - request_id, - invalid_params_error( - f"Unsupported fields: {', '.join(f'request.{field}' for field in unknown_fields)}", - data={"type": "INVALID_FIELD", "fields": unknown_fields}, - ), - ) + unknown_fields_error = reject_unknown_fields( + context, + request_id, + raw_request, + allowed_fields={"messageID", "partID"}, + field_prefix="request.", + ) + if unknown_fields_error is not None: + return {}, unknown_fields_error message_id = raw_request.get("messageID") if not isinstance(message_id, str) or not message_id.strip(): return {}, context.error_response( @@ -261,15 +258,14 @@ async def handle_session_lifecycle_request( }: allowed_fields.add("request") - unknown_fields = sorted(set(params) - allowed_fields) - if unknown_fields: - return context.error_response( - base_request.id, - invalid_params_error( - f"Unsupported fields: {', '.join(unknown_fields)}", - data={"type": "INVALID_FIELD", "fields": unknown_fields}, - ), - ) + unknown_fields_error = reject_unknown_fields( + context, + base_request.id, + params, + allowed_fields=allowed_fields, + ) + if unknown_fields_error is not None: + return unknown_fields_error directory, directory_error = _parse_directory_hint(context, base_request.id, params) if directory_error is not None: @@ -338,235 +334,189 @@ async def handle_session_lifecycle_request( context.method_unrevert_session, } - pending_claim = False - 
-    claim_finalized = False
-    if method in mutating_methods and session_id is not None and identity:
-        try:
-            pending_claim = await context.session_claim(identity=identity, session_id=session_id)
-        except PermissionError:
-            return build_session_forbidden_response(
-                context,
-                base_request.id,
-                session_id=session_id,
-            )
-
+    claim_identity = identity if method in mutating_methods else None
     try:
-        result: dict[str, Any]
-        forked_session_id: str | None = None
-
-        if method == context.method_session_status:
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.session_status,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            result = {"items": _normalize_session_status_items(raw_result)}
-        elif method == context.method_get_session:
-            assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.get_session,
-                session_id,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            item = _as_a2a_session_task(raw_result)
-            if item is None:
-                raise UpstreamContractError(
-                    "OpenCode /session/{sessionID} response could not be mapped to A2A Task"
+        async with claim_session(
+            context,
+            identity=claim_identity,
+            session_id=session_id,
+            logger=logger,
+        ) as session_claim:
+            result: dict[str, Any]
+            forked_session_id: str | None = None
+
+            if method == context.method_session_status:
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.session_status,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
                 )
-            result = {"item": item}
-        elif method == context.method_get_session_children:
-            assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.list_child_sessions,
-                session_id,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            raw_items = _extract_raw_items(raw_result, kind="child sessions")
-            result = {
-                "items": [
-                    task for item in raw_items if (task := _as_a2a_session_task(item)) is not None
-                ]
-            }
-        elif method == context.method_get_session_todo:
-            assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.get_session_todo,
-                session_id,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            result = {"items": _normalize_todo_items(raw_result)}
-        elif method == context.method_get_session_diff:
-            assert session_id is not None
-            query = {"messageID": message_id} if message_id else None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.get_session_diff,
-                session_id,
-                params=query,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            result = {"items": _normalize_diff_items(raw_result)}
-        elif method == context.method_get_session_message:
-            assert session_id is not None
-            assert message_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.get_message,
-                session_id,
-                message_id,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            item = _as_a2a_message(session_id, raw_result)
-            if item is None:
-                raise UpstreamContractError(
-                    "OpenCode /session/{sessionID}/message/{messageID} response could not be "
-                    "mapped to A2A Message"
+                result = {"items": _normalize_session_status_items(raw_result)}
+            elif method == context.method_get_session:
+                assert session_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.get_session,
+                    session_id,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
                 )
-            result = {"item": item}
-        elif method == context.method_fork_session:
-            assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.fork_session,
-                session_id,
-                request=fork_request,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            item = _normalize_session_summary(raw_result)
-            forked_session_id = item["id"]
-            result = {"item": item}
-        elif method == context.method_share_session:
-            assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.share_session,
-                session_id,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            result = {"item": _normalize_session_summary(raw_result)}
-        elif method == context.method_summarize_session:
-            assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.summarize_session,
-                session_id,
-                request=summarize_request,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            if not isinstance(raw_result, bool):
-                raise ValueError("Upstream summarize response must be a boolean")
-            result = {"ok": raw_result, "session_id": session_id}
-        elif method == context.method_revert_session:
-            assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.revert_session,
-                session_id,
-                request=revert_request,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            result = {"item": _normalize_session_summary(raw_result)}
-        elif method == context.method_unrevert_session:
-            assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.unrevert_session,
-                session_id,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            result = {"item": _normalize_session_summary(raw_result)}
-        else:
-            assert method == context.method_unshare_session
+                item = _as_a2a_session_task(raw_result)
+                if item is None:
+                    raise UpstreamContractError(
+                        "OpenCode /session/{sessionID} response could not be mapped to A2A Task"
+                    )
+                result = {"item": item}
+            elif method == context.method_get_session_children:
+                assert session_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.list_child_sessions,
+                    session_id,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                raw_items = _extract_raw_items(raw_result, kind="child sessions")
+                result = {
+                    "items": [
+                        task
+                        for item in raw_items
+                        if (task := _as_a2a_session_task(item)) is not None
+                    ]
+                }
+            elif method == context.method_get_session_todo:
+                assert session_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.get_session_todo,
+                    session_id,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                result = {"items": _normalize_todo_items(raw_result)}
+            elif method == context.method_get_session_diff:
+                assert session_id is not None
+                query = {"messageID": message_id} if message_id else None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.get_session_diff,
+                    session_id,
+                    params=query,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                result = {"items": _normalize_diff_items(raw_result)}
+            elif method == context.method_get_session_message:
+                assert session_id is not None
+                assert message_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.get_message,
+                    session_id,
+                    message_id,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                item = _as_a2a_message(session_id, raw_result)
+                if item is None:
+                    raise UpstreamContractError(
+                        "OpenCode /session/{sessionID}/message/{messageID} response could not "
+                        "be mapped to A2A Message"
+                    )
+                result = {"item": item}
+            elif method == context.method_fork_session:
+                assert session_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.fork_session,
+                    session_id,
+                    request=fork_request,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                item = _normalize_session_summary(raw_result)
+                forked_session_id = item["id"]
+                result = {"item": item}
+            elif method == context.method_share_session:
+                assert session_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.share_session,
+                    session_id,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                result = {"item": _normalize_session_summary(raw_result)}
+            elif method == context.method_summarize_session:
+                assert session_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.summarize_session,
+                    session_id,
+                    request=summarize_request,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                if not isinstance(raw_result, bool):
+                    raise ValueError("Upstream summarize response must be a boolean")
+                result = {"ok": raw_result, "session_id": session_id}
+            elif method == context.method_revert_session:
+                assert session_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.revert_session,
+                    session_id,
+                    request=revert_request,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                result = {"item": _normalize_session_summary(raw_result)}
+            elif method == context.method_unrevert_session:
+                assert session_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.unrevert_session,
+                    session_id,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                result = {"item": _normalize_session_summary(raw_result)}
+            else:
+                assert method == context.method_unshare_session
+                assert session_id is not None
+                raw_result = await call_with_supported_kwargs(
+                    context.upstream_client.unshare_session,
+                    session_id,
+                    directory=resolved_directory,
+                    workspace_id=workspace_id,
+                )
+                result = {"item": _normalize_session_summary(raw_result)}
+
+            await session_claim.finalize()
+            if forked_session_id is not None and identity:
+                await context.session_claim_finalize(
+                    identity=identity, session_id=forked_session_id
+                )
+    except Exception as exc:
+
+        def _session_not_found_response() -> Response:
             assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.unshare_session,
-                session_id,
-                directory=resolved_directory,
-                workspace_id=workspace_id,
-            )
-            result = {"item": _normalize_session_summary(raw_result)}
-
-        if pending_claim and identity and session_id is not None:
-            await context.session_claim_finalize(identity=identity, session_id=session_id)
-            claim_finalized = True
-        if forked_session_id is not None and identity:
-            await context.session_claim_finalize(identity=identity, session_id=forked_session_id)
-    except httpx.HTTPStatusError as exc:
-        upstream_status = exc.response.status_code
-        if upstream_status == 404 and session_id is not None:
             return context.error_response(
                 base_request.id,
                 session_not_found_error(ERR_SESSION_NOT_FOUND, session_id=session_id),
             )
-        return build_upstream_http_error_response(
-            context,
-            base_request.id,
-            ERR_UPSTREAM_HTTP_ERROR,
-            upstream_status=upstream_status,
-            method=method,
-            session_id=session_id,
-        )
-    except httpx.HTTPError:
-        return build_upstream_unreachable_error_response(
-            context,
-            base_request.id,
-            ERR_UPSTREAM_UNREACHABLE,
-            method=method,
-            session_id=session_id,
-        )
-    except UpstreamConcurrencyLimitError as exc:
-        return build_upstream_concurrency_error_response(
+
+        return build_upstream_exception_response(
            context,
            base_request.id,
-            ERR_UPSTREAM_UNREACHABLE,
            exc=exc,
+            upstream_http_error_code=ERR_UPSTREAM_HTTP_ERROR,
+            upstream_unreachable_error_code=ERR_UPSTREAM_UNREACHABLE,
+            upstream_payload_error_code=ERR_UPSTREAM_PAYLOAD_ERROR,
+            internal_log_message="OpenCode session lifecycle JSON-RPC method failed",
            method=method,
            session_id=session_id,
-        )
-    except UpstreamContractError as exc:
-        return build_upstream_payload_error_response(
-            context,
-            base_request.id,
-            ERR_UPSTREAM_PAYLOAD_ERROR,
-            detail=str(exc),
-            method=method,
-            session_id=session_id,
-        )
-    except ValueError as exc:
-        logger.warning("Upstream OpenCode payload mismatch: %s", exc)
-        return build_upstream_payload_error_response(
-            context,
-            base_request.id,
-            ERR_UPSTREAM_PAYLOAD_ERROR,
-            detail=str(exc),
-            method=method,
-            session_id=session_id,
-        )
-    except PermissionError:
-        assert session_id is not None
-        return build_session_forbidden_response(
-            context,
-            base_request.id,
-            session_id=session_id,
-        )
-    except Exception as exc:
-        return build_internal_error_response(
-            context,
-            base_request.id,
-            log_message="OpenCode session lifecycle JSON-RPC method failed",
-            exc=exc,
-        )
-    finally:
-        if pending_claim and not claim_finalized and identity and session_id is not None:
-            try:
-                await context.session_claim_release(identity=identity, session_id=session_id)
-            except Exception:
-                logger.exception(
-                    "Failed to release pending session claim for session_id=%s",
-                    session_id,
+            on_not_found=_session_not_found_response if session_id is not None else None,
+            on_permission_error=(
+                lambda: build_session_forbidden_response(
+                    context,
+                    base_request.id,
+                    session_id=session_id,
                )
+            )
+            if session_id is not None
+            else None,
+            payload_warning_message="Upstream OpenCode payload mismatch",
+        )
 
     return build_success_response(context, base_request.id, result)
diff --git a/src/opencode_a2a/jsonrpc/handlers/session_queries.py b/src/opencode_a2a/jsonrpc/handlers/session_queries.py
index 35247d4..dc73847 100644
--- a/src/opencode_a2a/jsonrpc/handlers/session_queries.py
+++ b/src/opencode_a2a/jsonrpc/handlers/session_queries.py
@@ -3,14 +3,12 @@
 import logging
 from typing import Any
 
-import httpx
 from a2a.types import JSONRPCRequest
 from starlette.requests import Request
 from starlette.responses import Response
 
 from ...contracts.extensions import SESSION_QUERY_ERROR_BUSINESS_CODES
 from ...invocation import call_with_supported_kwargs
-from ...opencode_upstream_client import UpstreamConcurrencyLimitError
 from ..dispatch import ExtensionHandlerContext
 from ..error_responses import invalid_params_error, session_not_found_error
 from ..methods import (
@@ -25,12 +23,9 @@
     parse_list_sessions_params,
 )
 from .common import (
-    build_internal_error_response,
     build_success_response,
-    build_upstream_concurrency_error_response,
-    build_upstream_http_error_response,
     build_upstream_payload_error_response,
-    build_upstream_unreachable_error_response,
+    invoke_upstream_or_error,
     resolve_routing_context,
 )
 
@@ -89,57 +84,47 @@
     )
     if routing_error is not None:
         return routing_error
-    try:
+
+    def _session_not_found_response() -> Response:
+        assert session_id is not None
+        return context.error_response(
+            base_request.id,
+            session_not_found_error(ERR_SESSION_NOT_FOUND, session_id=session_id),
+        )
+
+    async def _invoke_session_query() -> Any:
         if base_request.method == context.method_list_sessions:
-            raw_result = await call_with_supported_kwargs(
+            return await call_with_supported_kwargs(
                 context.upstream_client.list_sessions,
                 params=query,
                 directory=directory,
                 workspace_id=workspace_id,
             )
-        else:
-            assert session_id is not None
-            raw_result = await call_with_supported_kwargs(
-                context.upstream_client.list_messages,
-                session_id,
-                params=query,
-                workspace_id=workspace_id,
-            )
-    except httpx.HTTPStatusError as exc:
-        upstream_status = exc.response.status_code
-        if upstream_status == 404 and base_request.method == context.method_get_session_messages:
-            assert session_id is not None
-            return context.error_response(
-                base_request.id,
-                session_not_found_error(ERR_SESSION_NOT_FOUND, session_id=session_id),
-            )
-        return build_upstream_http_error_response(
-            context,
-            base_request.id,
-            ERR_UPSTREAM_HTTP_ERROR,
-            upstream_status=upstream_status,
-        )
-    except httpx.HTTPError:
-        return build_upstream_unreachable_error_response(
-            context,
-            base_request.id,
-            ERR_UPSTREAM_UNREACHABLE,
-        )
-    except UpstreamConcurrencyLimitError as exc:
-        return build_upstream_concurrency_error_response(
-            context,
-            base_request.id,
-            ERR_UPSTREAM_UNREACHABLE,
-            exc=exc,
-        )
-    except Exception as exc:
-        return build_internal_error_response(
-            context,
-            base_request.id,
-            log_message="OpenCode session query JSON-RPC method failed",
-            exc=exc,
+        assert session_id is not None
+        return await call_with_supported_kwargs(
+            context.upstream_client.list_messages,
+            session_id,
+            params=query,
+            workspace_id=workspace_id,
        )
+
+    raw_result, upstream_error = await invoke_upstream_or_error(
+        context,
+        base_request.id,
+        invoke=_invoke_session_query,
+        upstream_http_error_code=ERR_UPSTREAM_HTTP_ERROR,
+        upstream_unreachable_error_code=ERR_UPSTREAM_UNREACHABLE,
+        internal_log_message="OpenCode session query JSON-RPC method failed",
+        on_not_found=(
+            _session_not_found_response
+            if base_request.method == context.method_get_session_messages
+            else None
+        ),
+    )
+    if upstream_error is not None:
+        return upstream_error
+    assert raw_result is not None
+
     try:
         if base_request.method == context.method_list_sessions:
             raw_items = _extract_raw_items(raw_result, kind="sessions")
diff --git a/src/opencode_a2a/jsonrpc/handlers/workspace_control.py b/src/opencode_a2a/jsonrpc/handlers/workspace_control.py
index 5111a26..ccae7d2 100644
--- a/src/opencode_a2a/jsonrpc/handlers/workspace_control.py
+++ b/src/opencode_a2a/jsonrpc/handlers/workspace_control.py
@@ -3,7 +3,6 @@
 import logging
 from typing import Any
 
-import httpx
 from a2a.types import JSONRPCRequest
 from starlette.requests import Request
 from starlette.responses import Response
@@ -13,17 +12,14 @@
     request_has_capability,
 )
 from ...contracts.extensions import WORKSPACE_CONTROL_ERROR_BUSINESS_CODES
-from ...opencode_upstream_client import UpstreamConcurrencyLimitError
 from ..dispatch import ExtensionHandlerContext
 from ..error_responses import invalid_params_error
 from .common import (
     build_authorization_forbidden_response,
-    build_internal_error_response,
     build_success_response,
-    build_upstream_concurrency_error_response,
-    build_upstream_http_error_response,
     build_upstream_payload_error_response,
-    build_upstream_unreachable_error_response,
+    invoke_upstream_or_error,
+    reject_unknown_fields,
 )
 
 logger = logging.getLogger(__name__)
@@ -33,74 +29,145 @@
 ERR_UPSTREAM_PAYLOAD_ERROR = WORKSPACE_CONTROL_ERROR_BUSINESS_CODES["UPSTREAM_PAYLOAD_ERROR"]
 
 
-def _parse_optional_request_object(
-    params: dict[str, Any],
+def _invalid_field_response(
+    context: ExtensionHandlerContext,
+    request_id: str | int | None,
+    *,
+    field: str,
+    message: str,
+) -> Response:
+    return context.error_response(
+        request_id,
+        invalid_params_error(message, data={"type": "INVALID_FIELD", "field": field}),
+    )
+
+
+def _missing_field_response(
+    context: ExtensionHandlerContext,
+    request_id: str | int | None,
     *,
-    required: bool,
-) -> dict[str, Any] | None:
+    field: str,
+    error_field: str | None = None,
+) -> Response:
+    return context.error_response(
+        request_id,
+        invalid_params_error(
+            f"Missing required params.{field}",
+            data={"type": "MISSING_FIELD", "field": error_field or field},
+        ),
+    )
+
+
+def _parse_required_request_object(
+    context: ExtensionHandlerContext,
+    request_id: str | int | None,
+    params: dict[str, Any],
+) -> tuple[dict[str, Any] | None, Response | None]:
     value = params.get("request")
     if value is None:
-        if required:
-            raise ValueError("Missing required params.request")
-        return None
+        return None, _missing_field_response(context, request_id, field="request")
     if not isinstance(value, dict):
-        raise TypeError("params.request must be an object")
-    return dict(value)
+        return None, _invalid_field_response(
+            context,
+            request_id,
+            field="request",
+            message="params.request must be an object",
+        )
+    return dict(value), None
 
 
-def _parse_workspace_id(params: dict[str, Any]) -> str:
+def _parse_workspace_id(
+    context: ExtensionHandlerContext,
+    request_id: str | int | None,
+    params: dict[str, Any],
+) -> tuple[str | None, Response | None]:
     raw_workspace_id = params.get("workspace_id")
     if not isinstance(raw_workspace_id, str) or not raw_workspace_id.strip():
-        raise ValueError("Missing required params.workspace_id")
-        return raw_workspace_id.strip()
+        return None, _missing_field_response(context, request_id, field="workspace_id")
+    return raw_workspace_id.strip(), None
 
 
-def _validate_workspace_request(method: str, request: dict[str, Any]) -> None:
+def _validate_workspace_request(
+    context: ExtensionHandlerContext,
+    request_id: str | int | None,
+    method: str,
+    request: dict[str, Any],
+) -> Response | None:
     if method == "create_workspace":
         allowed_fields = {"id", "type", "branch", "extra"}
         if "type" not in request:
-            raise ValueError("Missing required params.request.type")
+            return _missing_field_response(
+                context,
+                request_id,
+                field="request.type",
+                error_field="request",
+            )
         request_type = request.get("type")
         if not isinstance(request_type, str) or not request_type.strip():
-            raise TypeError("params.request.type must be a non-empty string")
+            return _invalid_field_response(
+                context,
+                request_id,
+                field="request.type",
+                message="params.request.type must be a non-empty string",
+            )
     elif method == "create_worktree":
         allowed_fields = {"name", "startCommand"}
     elif method in {"remove_worktree", "reset_worktree"}:
         allowed_fields = {"directory"}
         directory = request.get("directory")
         if not isinstance(directory, str) or not directory.strip():
-            raise TypeError("params.request.directory must be a non-empty string")
+            return _invalid_field_response(
+                context,
+                request_id,
+                field="request.directory",
+                message="params.request.directory must be a non-empty string",
+            )
     else:
         allowed_fields = set()
 
-    unknown_fields = sorted(set(request) - allowed_fields)
-    if unknown_fields:
-        raise ValueError(
-            "Unsupported request fields: "
-            + ", ".join(f"request.{field}" for field in unknown_fields)
-        )
+    unknown_fields_error = reject_unknown_fields(
+        context,
+        request_id,
+        request,
+        allowed_fields=allowed_fields,
+        field_prefix="request.",
+        message_prefix="Unsupported request fields",
+    )
+    if unknown_fields_error is not None:
+        return unknown_fields_error
 
     for field in ("id", "type", "branch", "name", "startCommand", "directory"):
         if field not in request:
             continue
         value = request[field]
         if value is not None and not isinstance(value, str):
-            raise TypeError(f"params.request.{field} must be a string")
+            return _invalid_field_response(
+                context,
+                request_id,
+                field=f"request.{field}",
+                message=f"params.request.{field} must be a string",
+            )
+    return None
 
 
 def _validate_allowed_fields(
+    context: ExtensionHandlerContext,
+    request_id: str | int | None,
     method: str,
     params: dict[str, Any],
-) -> None:
+) -> Response | None:
     allowed_fields = {"metadata"}
     if method in {"create_workspace", "create_worktree", "remove_worktree", "reset_worktree"}:
         allowed_fields.add("request")
     if method == "remove_workspace":
         allowed_fields.add("workspace_id")
-    unknown_fields = sorted(set(params) - allowed_fields)
-    if unknown_fields:
-        raise ValueError("Unsupported fields: " + ", ".join(unknown_fields))
+    return reject_unknown_fields(
+        context,
+        request_id,
+        params,
+        allowed_fields=allowed_fields,
+    )
 
 
 def _validate_response_payload(method: str, payload: Any) -> dict[str, Any]:
@@ -168,86 +235,71 @@
         error_code=WORKSPACE_CONTROL_ERROR_BUSINESS_CODES["AUTHORIZATION_FORBIDDEN"],
     )
 
-    try:
-        _validate_allowed_fields(method_key, params)
-        request_body: dict[str, Any] | None = None
-        workspace_id: str | None = None
-        if method_key == "remove_workspace":
-            workspace_id = _parse_workspace_id(params)
-        elif method_key in {
-            "create_workspace",
-            "create_worktree",
-            "remove_worktree",
-            "reset_worktree",
-        }:
-            request_body = _parse_optional_request_object(
-                params,
-                required=True,
-            )
-            assert request_body is not None
-            _validate_workspace_request(method_key, request_body)
-    except ValueError as exc:
-        field = "workspace_id" if "workspace_id" in str(exc) else "request"
-        return context.error_response(
-            base_request.id,
-            invalid_params_error(str(exc), data={"type": "INVALID_FIELD", "field": field}),
-        )
-    except TypeError as exc:
-        return context.error_response(
-            base_request.id,
-            invalid_params_error(str(exc), data={"type": "INVALID_FIELD"}),
-        )
+    allowed_fields_error = _validate_allowed_fields(context, base_request.id, method_key, params)
+    if allowed_fields_error is not None:
+        return allowed_fields_error
 
-    try:
-        if method_key == "list_projects":
-            raw_result = await context.upstream_client.list_projects()
-        elif method_key == "get_current_project":
-            raw_result = await context.upstream_client.get_current_project()
-        elif method_key == "list_workspaces":
-            raw_result = await context.upstream_client.list_workspaces()
-        elif method_key == "create_workspace":
-            raw_result = await context.upstream_client.create_workspace(request_body or {})
-        elif method_key == "remove_workspace":
-            assert workspace_id is not None
-            raw_result = await context.upstream_client.remove_workspace(workspace_id)
-        elif method_key == "list_worktrees":
-            raw_result = await context.upstream_client.list_worktrees()
-        elif method_key == "create_worktree":
-            raw_result = await context.upstream_client.create_worktree(request_body or {})
-        elif method_key == "remove_worktree":
-            raw_result = await context.upstream_client.remove_worktree(request_body or {})
-        else:
-            raw_result = await context.upstream_client.reset_worktree(request_body or {})
-    except httpx.HTTPStatusError as exc:
-        return build_upstream_http_error_response(
-            context,
-            base_request.id,
-            ERR_UPSTREAM_HTTP_ERROR,
-            upstream_status=exc.response.status_code,
-            method=base_request.method,
-        )
-    except httpx.HTTPError:
-        return build_upstream_unreachable_error_response(
+    request_body: dict[str, Any] | None = None
+    workspace_id: str | None = None
+    if method_key == "remove_workspace":
+        workspace_id, workspace_error = _parse_workspace_id(context, base_request.id, params)
+        if workspace_error is not None:
+            return workspace_error
+    elif method_key in {
+        "create_workspace",
+        "create_worktree",
+        "remove_worktree",
+        "reset_worktree",
+    }:
+        request_body, request_error = _parse_required_request_object(
             context,
             base_request.id,
-            ERR_UPSTREAM_UNREACHABLE,
-            method=base_request.method,
+            params,
         )
-    except UpstreamConcurrencyLimitError as exc:
-        return build_upstream_concurrency_error_response(
+        if request_error is not None:
+            return request_error
+        assert request_body is not None
+        request_validation_error = _validate_workspace_request(
            context,
            base_request.id,
-            ERR_UPSTREAM_UNREACHABLE,
-            exc=exc,
-            method=base_request.method,
-        )
-    except Exception as exc:
-        return build_internal_error_response(
-            context,
-            base_request.id,
-            log_message="OpenCode workspace control JSON-RPC method failed",
-            exc=exc,
+            method_key,
+            request_body,
        )
+        if request_validation_error is not None:
+            return request_validation_error
+
+    async def _invoke_workspace_method() -> Any:
+        if method_key == "list_projects":
+            return await context.upstream_client.list_projects()
+        if method_key == "get_current_project":
+            return await context.upstream_client.get_current_project()
+        if method_key == "list_workspaces":
+            return await context.upstream_client.list_workspaces()
+        if method_key == "create_workspace":
+            return await context.upstream_client.create_workspace(request_body or {})
+        if method_key == "remove_workspace":
+            assert workspace_id is not None
+            return await context.upstream_client.remove_workspace(workspace_id)
+        if method_key == "list_worktrees":
+            return await context.upstream_client.list_worktrees()
+        if method_key == "create_worktree":
+            return await context.upstream_client.create_worktree(request_body or {})
+        if method_key == "remove_worktree":
+            return await context.upstream_client.remove_worktree(request_body or {})
+        return await context.upstream_client.reset_worktree(request_body or {})
+
+    raw_result, upstream_error = await invoke_upstream_or_error(
+        context,
+        base_request.id,
+        invoke=_invoke_workspace_method,
+        upstream_http_error_code=ERR_UPSTREAM_HTTP_ERROR,
+        upstream_unreachable_error_code=ERR_UPSTREAM_UNREACHABLE,
+        internal_log_message="OpenCode workspace control JSON-RPC method failed",
+        method=base_request.method,
+    )
+    if upstream_error is not None:
+        return upstream_error
+    assert raw_result is not None
     try:
         result = _validate_response_payload(method_key, raw_result)
diff --git a/src/opencode_a2a/jsonrpc/methods.py b/src/opencode_a2a/jsonrpc/methods.py
index ba67043..6f29d14 100644
--- a/src/opencode_a2a/jsonrpc/methods.py
+++ b/src/opencode_a2a/jsonrpc/methods.py
@@ -57,6 +57,72 @@
 def _raise_prompt_async_validation_error(*, field: str, message: str) -> None:
     raise _PromptAsyncValidationError(field=field, message=message)
 
 
+def _validate_allowed_request_fields(
+    value: dict[str, Any],
+    *,
+    allowed_fields: tuple[str, ...] | frozenset[str],
+) -> None:
+    unknown_fields = sorted(set(value) - set(allowed_fields))
+    if not unknown_fields:
+        return
+    joined = ", ".join(f"request.{field}" for field in unknown_fields)
+    _raise_prompt_async_validation_error(
+        field="request",
+        message=f"Unsupported fields: {joined}",
+    )
+
+
+def _validate_optional_message_id(value: Any, *, field: str) -> None:
+    if value is not None and (not isinstance(value, str) or not value.startswith("msg")):
+        _raise_prompt_async_validation_error(
+            field=field,
+            message=f"{field} must be a string starting with 'msg'",
+        )
+
+
+def _validate_optional_string_fields(
+    value: dict[str, Any],
+    *,
+    field_names: tuple[str, ...],
+) -> None:
+    for field_name in field_names:
+        field_value = value.get(field_name)
+        if field_value is not None and not isinstance(field_value, str):
+            _raise_prompt_async_validation_error(
+                field=f"request.{field_name}",
+                message=f"request.{field_name} must be a string",
+            )
+
+
+def _validate_required_non_empty_string_fields(
+    value: dict[str, Any],
+    *,
+    field_names: tuple[str, ...],
+) -> None:
+    for field_name in field_names:
+        field_value = value.get(field_name)
+        if not isinstance(field_value, str) or not field_value.strip():
+            _raise_prompt_async_validation_error(
+                field=f"request.{field_name}",
+                message=f"request.{field_name} must be a non-empty string",
+            )
+
+
+def _validate_parts_array(
+    value: Any,
+    *,
+    field: str,
+    part_validator: Any,
+) -> None:
+    if not isinstance(value, list):
+        _raise_prompt_async_validation_error(
+            field=field,
+            message=f"{field} must be an array",
+        )
+    for index, part in enumerate(cast(list[Any], value)):
+        part_validator(part, field=f"{field}[{index}]")
+
+
 def _validate_model_ref(value: Any, *, field: str) -> None:
     if not isinstance(value, dict):
         _raise_prompt_async_validation_error(field=field, message=f"{field} must be an object")
@@ -159,34 +225,13 @@
 
 def _validate_prompt_async_request_payload(value: dict[str, Any]) -> None:
-    allowed_fields = set(PROMPT_ASYNC_REQUEST_ALLOWED_FIELDS)
-    unknown_fields = sorted(set(value) - allowed_fields)
-    if unknown_fields:
-        joined = ", ".join(f"request.{field}" for field in unknown_fields)
-        _raise_prompt_async_validation_error(
-            field="request",
-            message=f"Unsupported fields: {joined}",
-        )
+    _validate_allowed_request_fields(value, allowed_fields=PROMPT_ASYNC_REQUEST_ALLOWED_FIELDS)
+    _validate_optional_message_id(value.get("messageID"), field="request.messageID")
 
-    message_id = value.get("messageID")
-    if message_id is not None:
-        if not isinstance(message_id, str) or not message_id.startswith("msg"):
-            _raise_prompt_async_validation_error(
-                field="request.messageID",
-                message="request.messageID must be a string starting with 'msg'",
-            )
-
-    model = value.get("model")
-    if model is not None:
+    if (model := value.get("model")) is not None:
         _validate_model_ref(model, field="request.model")
 
-    for key in ("agent", "system", "variant"):
-        data = value.get(key)
-        if data is not None and not isinstance(data, str):
-            _raise_prompt_async_validation_error(
-                field=f"request.{key}",
-                message=f"request.{key} must be a string",
-            )
+    _validate_optional_string_fields(value, field_names=("agent", "system", "variant"))
 
     no_reply = value.get("noReply")
     if no_reply is not None and not isinstance(no_reply, bool):
@@ -218,15 +263,11 @@ def _validate_prompt_async_request_payload(value: dict[str, Any]) -> None:
     if fmt is not None:
         _validate_prompt_async_format(fmt, field="request.format")
 
-    parts = value.get("parts")
-    if not isinstance(parts, list):
-        _raise_prompt_async_validation_error(
-            field="request.parts",
-            message="request.parts must be an array",
-        )
-    parts_list = cast(list[Any], parts)
-    for index, part in enumerate(parts_list):
-        _validate_prompt_async_part(part, field=f"request.parts[{index}]")
+    _validate_parts_array(
+        value.get("parts"),
+        field="request.parts",
+        part_validator=_validate_prompt_async_part,
+    )
 
 
 def _validate_command_part(value: Any, *, field: str) -> None:
@@ -248,75 +289,24 @@
 
 
 def _validate_command_request_payload(value: dict[str, Any]) -> None:
-    allowed_fields = set(COMMAND_REQUEST_ALLOWED_FIELDS)
-    unknown_fields = sorted(set(value) - allowed_fields)
-    if unknown_fields:
-        joined = ", ".join(f"request.{field}" for field in unknown_fields)
-        _raise_prompt_async_validation_error(
-            field="request",
-            message=f"Unsupported fields: {joined}",
-        )
+    _validate_allowed_request_fields(value, allowed_fields=COMMAND_REQUEST_ALLOWED_FIELDS)
+    _validate_required_non_empty_string_fields(value, field_names=("command", "arguments"))
+    _validate_optional_message_id(value.get("messageID"), field="request.messageID")
 
-    for key in ("command", "arguments"):
-        item = value.get(key)
-        if not isinstance(item, str) or not item.strip():
-            _raise_prompt_async_validation_error(
-                field=f"request.{key}",
-                message=f"request.{key} must be a non-empty string",
-            )
-
-    message_id = value.get("messageID")
-    if message_id is not None:
-        if not isinstance(message_id, str) or not message_id.startswith("msg"):
-            _raise_prompt_async_validation_error(
-                field="request.messageID",
-                message="request.messageID must be a string starting with 'msg'",
-            )
-
-    model = value.get("model")
-    if model is not None:
+    if (model := value.get("model")) is not None:
         _validate_model_ref(model, field="request.model")
 
-    for key in ("agent", "variant"):
-        data = value.get(key)
-        if data is not None and not isinstance(data, str):
-            _raise_prompt_async_validation_error(
-                field=f"request.{key}",
-                message=f"request.{key} must be a string",
-            )
+    _validate_optional_string_fields(value, field_names=("agent", "variant"))
 
     parts = value.get("parts")
     if parts is not None:
-        if not isinstance(parts, list):
-            _raise_prompt_async_validation_error(
-                field="request.parts",
-                message="request.parts must be an array",
-            )
-        parts_list = cast(list[Any], parts)
-        for index, part in enumerate(parts_list):
-            _validate_command_part(part, field=f"request.parts[{index}]")
+        _validate_parts_array(parts, field="request.parts", part_validator=_validate_command_part)
 
 
 def _validate_shell_request_payload(value: dict[str, Any]) -> None:
-    allowed_fields = set(SHELL_REQUEST_ALLOWED_FIELDS)
-    unknown_fields = sorted(set(value) - allowed_fields)
-    if unknown_fields:
-        joined = ", ".join(f"request.{field}" for field in unknown_fields)
-        _raise_prompt_async_validation_error(
-            field="request",
-            message=f"Unsupported fields: {joined}",
-        )
-
-    for key in ("agent", "command"):
-        item = value.get(key)
-        if not isinstance(item, str) or not item.strip():
-            _raise_prompt_async_validation_error(
-                field=f"request.{key}",
-                message=f"request.{key} must be a non-empty string",
-            )
-
-    model = value.get("model")
-    if model is not None:
+    _validate_allowed_request_fields(value, allowed_fields=SHELL_REQUEST_ALLOWED_FIELDS)
+    _validate_required_non_empty_string_fields(value, field_names=("agent", "command"))
+    if (model := value.get("model")) is not None:
         _validate_model_ref(model, field="request.model")
diff --git a/src/opencode_a2a/jsonrpc/params.py b/src/opencode_a2a/jsonrpc/params.py
index ed08cfd..3c3f1e7 100644
--- a/src/opencode_a2a/jsonrpc/params.py
+++ b/src/opencode_a2a/jsonrpc/params.py
@@ -7,6 +7,15 @@
     SESSION_QUERY_MAX_LIMIT,
     SESSION_QUERY_PAGINATION_UNSUPPORTED,
 )
+from ..parsing import (
+    parse_bool_field as parse_shared_bool_field,
+)
+from ..parsing import (
+    parse_int_field as parse_shared_int_field,
+)
+from ..parsing import (
+    parse_string_field as parse_shared_string_field,
+)
 
 
 class JsonRpcParamsValidationError(ValueError):
@@ -15,94 +24,44 @@ def __init__(self, *, message: str, data: dict[str, Any]) -> None:
         self.data = data
 
 
+def _validation_error(field: str, message: str) -> JsonRpcParamsValidationError:
+    return JsonRpcParamsValidationError(
+        message=message,
+        data={"type": "INVALID_FIELD", "field": field},
+    )
+
+
 def _parse_positive_int(value: Any, *, field: str) -> int | None:
-    if value is None:
-        return None
-    if isinstance(value, bool):
-        raise JsonRpcParamsValidationError(
-            message=f"{field} must be an integer",
-            data={"type": "INVALID_FIELD", "field": field},
-        )
-    if isinstance(value, int):
-        parsed = value
-    elif isinstance(value, str):
-        try:
-            parsed = int(value)
-        except ValueError as exc:
-            raise JsonRpcParamsValidationError(
-                message=f"{field} must be an integer",
-                data={"type": "INVALID_FIELD", "field": field},
-            ) from exc
-    else:
-        raise JsonRpcParamsValidationError(
-            message=f"{field} must be an integer",
-            data={"type": "INVALID_FIELD", "field": field},
-        )
-    if parsed < 1:
-        raise JsonRpcParamsValidationError(
-            message=f"{field} must be >= 1",
-            data={"type": "INVALID_FIELD", "field": field},
-        )
-    return parsed
+    return parse_shared_int_field(
+        value,
+        field=field,
+        error_factory=_validation_error,
+        minimum=1,
+    )
 
 
 def _parse_non_negative_int(value: Any, *, field: str) -> int | None:
-    if value is None:
-        return None
-    if isinstance(value, bool):
-        raise JsonRpcParamsValidationError(
-            message=f"{field} must be an integer",
-            data={"type": "INVALID_FIELD", "field": field},
-        )
-    if isinstance(value, int):
-        parsed = value
-    elif isinstance(value, str):
-        try:
-            parsed = int(value)
-        except ValueError as exc:
-            raise JsonRpcParamsValidationError(
-                message=f"{field} must be an integer",
-                data={"type": 
"INVALID_FIELD", "field": field}, - ) from exc - else: - raise JsonRpcParamsValidationError( - message=f"{field} must be an integer", - data={"type": "INVALID_FIELD", "field": field}, - ) - if parsed < 0: - raise JsonRpcParamsValidationError( - message=f"{field} must be >= 0", - data={"type": "INVALID_FIELD", "field": field}, - ) - return parsed + return parse_shared_int_field( + value, + field=field, + error_factory=_validation_error, + minimum=0, + ) def _parse_string_field(value: Any, *, field: str) -> str | None: - if value is None: - return None - if not isinstance(value, str): - raise JsonRpcParamsValidationError( - message=f"{field} must be a string", - data={"type": "INVALID_FIELD", "field": field}, - ) - normalized = value.strip() - return normalized or None + return parse_shared_string_field( + value, + field=field, + error_factory=_validation_error, + ) def _parse_bool_field(value: Any, *, field: str) -> bool | None: - if value is None: - return None - if isinstance(value, bool): - return value - if isinstance(value, str): - normalized = value.strip().lower() - if normalized in {"true", "1", "yes", "on"}: - return True - if normalized in {"false", "0", "no", "off"}: - return False - raise JsonRpcParamsValidationError( - message=f"{field} must be a boolean", - data={"type": "INVALID_FIELD", "field": field}, + return parse_shared_bool_field( + value, + field=field, + error_factory=_validation_error, ) diff --git a/src/opencode_a2a/metadata_access.py b/src/opencode_a2a/metadata_access.py new file mode 100644 index 0000000..f655e8f --- /dev/null +++ b/src/opencode_a2a/metadata_access.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from typing import Any + + +def extract_namespaced_value( + source: Mapping[str, Any] | None, + *, + namespace: str, + path: tuple[str, ...], +) -> Any | None: + if not isinstance(source, Mapping): + return None + + current: Any = source.get(namespace) + if not 
isinstance(current, Mapping): + return None + + for part in path: + if not isinstance(current, Mapping): + return None + current = current.get(part) + return current + + +def extract_first_namespaced_string( + sources: Iterable[Mapping[str, Any] | None], + *, + namespace: str, + path: tuple[str, ...], +) -> str | None: + for source in sources: + candidate = extract_namespaced_value(source, namespace=namespace, path=path) + if isinstance(candidate, str): + value = candidate.strip() + if value: + return value + return None diff --git a/src/opencode_a2a/output_modes.py b/src/opencode_a2a/output_modes.py index 323cd01..5a488b6 100644 --- a/src/opencode_a2a/output_modes.py +++ b/src/opencode_a2a/output_modes.py @@ -1,14 +1,52 @@ from __future__ import annotations -from collections.abc import Collection -from typing import Any +import asyncio +import json +from collections.abc import Collection, Iterable +from typing import Any, cast +from a2a.server.events import EventConsumer +from a2a.server.tasks import ResultAggregator, TaskManager +from a2a.types import ( + Artifact, + DataPart, + FilePart, + Message, + Part, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatusUpdateEvent, + TextPart, +) + +OUTPUT_NEGOTIATION_METADATA_KEY = "output_negotiation" +OUTPUT_NEGOTIATION_ACCEPTED_OUTPUT_MODES_FIELD = "accepted_output_modes" +_OPENCODE_METADATA_KEY = "opencode" +_APPLICATION_JSON_MEDIA_TYPE = "application/json" +_TEXT_PLAIN_MEDIA_TYPE = "text/plain" + + +def _accepted_output_modes_source(source: Any) -> Iterable[str] | None: + if source is None: + return None -def normalize_accepted_output_modes(source: Any) -> tuple[str, ...] 
| None: accepted = getattr(source, "accepted_output_modes", None) or getattr( source, "acceptedOutputModes", None ) - if not isinstance(accepted, list): + if accepted is not None: + source = accepted + + if isinstance(source, str | bytes | bytearray | dict): + return None + if not isinstance(source, Iterable): + return None + return cast(Iterable[str], source) + + +def normalize_accepted_output_modes(source: Any) -> tuple[str, ...] | None: + accepted = _accepted_output_modes_source(source) + if accepted is None: return None normalized: list[str] = [] @@ -18,6 +56,8 @@ def normalize_accepted_output_modes(source: Any) -> tuple[str, ...] | None: mode = value.strip().lower() if not mode or mode in normalized: continue + if mode in {"*", "*/*"}: + return None normalized.append(mode) return tuple(normalized) or None @@ -27,3 +67,320 @@ def accepts_output_mode( media_type: str, ) -> bool: return accepted_output_modes is None or media_type in accepted_output_modes + + +def part_text_fallback(part: Any) -> str | None: + if isinstance(part, TextPart): + return part.text + if isinstance(part, DataPart): + return json.dumps(part.data, ensure_ascii=True, sort_keys=True, separators=(",", ":")) + return None + + +def build_output_negotiation_metadata( + accepted_output_modes: Iterable[str] | None, +) -> dict[str, Any] | None: + normalized = normalize_accepted_output_modes(accepted_output_modes) + if normalized is None: + return None + return { + _OPENCODE_METADATA_KEY: { + OUTPUT_NEGOTIATION_METADATA_KEY: { + OUTPUT_NEGOTIATION_ACCEPTED_OUTPUT_MODES_FIELD: sorted(normalized), + } + } + } + + +def merge_output_negotiation_metadata( + metadata: dict[str, Any] | None, + accepted_output_modes: Iterable[str] | None, +) -> dict[str, Any] | None: + negotiation_metadata = build_output_negotiation_metadata(accepted_output_modes) + if negotiation_metadata is None: + return metadata + + merged = dict(metadata) if isinstance(metadata, dict) else {} + opencode_metadata = 
merged.get(_OPENCODE_METADATA_KEY) + if not isinstance(opencode_metadata, dict): + opencode_metadata = {} + else: + opencode_metadata = dict(opencode_metadata) + + opencode_metadata[OUTPUT_NEGOTIATION_METADATA_KEY] = dict( + cast( + dict[str, Any], + negotiation_metadata[_OPENCODE_METADATA_KEY][OUTPUT_NEGOTIATION_METADATA_KEY], + ) + ) + merged[_OPENCODE_METADATA_KEY] = opencode_metadata + return merged + + +def extract_accepted_output_modes_from_metadata( + metadata: dict[str, Any] | None, +) -> tuple[str, ...] | None: + if not isinstance(metadata, dict): + return None + opencode_metadata = metadata.get(_OPENCODE_METADATA_KEY) + if not isinstance(opencode_metadata, dict): + return None + negotiation_metadata = opencode_metadata.get(OUTPUT_NEGOTIATION_METADATA_KEY) + if not isinstance(negotiation_metadata, dict): + return None + accepted_output_modes = negotiation_metadata.get(OUTPUT_NEGOTIATION_ACCEPTED_OUTPUT_MODES_FIELD) + return normalize_accepted_output_modes(accepted_output_modes) + + +def annotate_output_negotiation_metadata( + payload: Any, + accepted_output_modes: Iterable[str] | None, +) -> Any: + normalized = normalize_accepted_output_modes(accepted_output_modes) + if normalized is None: + return payload + + if isinstance(payload, Task): + return payload.model_copy( + update={"metadata": merge_output_negotiation_metadata(payload.metadata, normalized)} + ) + + if isinstance(payload, TaskStatusUpdateEvent): + return payload.model_copy( + update={"metadata": merge_output_negotiation_metadata(payload.metadata, normalized)} + ) + + if isinstance(payload, TaskArtifactUpdateEvent): + return payload.model_copy( + update={"metadata": merge_output_negotiation_metadata(payload.metadata, normalized)} + ) + + return payload + + +def apply_accepted_output_modes( + payload: Any, + accepted_output_modes: Iterable[str] | None, +) -> Any | None: + normalized = normalize_accepted_output_modes(accepted_output_modes) + if normalized is None: + return payload + + if 
isinstance(payload, TaskArtifactUpdateEvent): + artifact = _filter_artifact(payload.artifact, normalized) + if artifact is None: + return None + return payload.model_copy(update={"artifact": artifact}) + + if isinstance(payload, TaskStatusUpdateEvent): + status = payload.status + return payload.model_copy( + update={ + "status": status.model_copy( + update={"message": _filter_optional_message(status.message, normalized)} + ) + } + ) + + if isinstance(payload, Task): + return _filter_task(payload, normalized) + + if isinstance(payload, Message): + filtered = _filter_message(payload, normalized) + if filtered is not None: + return filtered + return payload.model_copy(update={"parts": []}) + + return payload + + +class NegotiatingResultAggregator(ResultAggregator): + def __init__( + self, + task_manager: TaskManager, + accepted_output_modes: Iterable[str] | None, + ) -> None: + super().__init__(task_manager) + self._accepted_output_modes = normalize_accepted_output_modes(accepted_output_modes) + + def _transform_event(self, event: Any) -> Any | None: + negotiated_event = apply_accepted_output_modes(event, self._accepted_output_modes) + if negotiated_event is None: + return None + return annotate_output_negotiation_metadata(negotiated_event, self._accepted_output_modes) + + async def _persist_output_negotiation_metadata(self, event: Any) -> None: + if not isinstance(event, TaskArtifactUpdateEvent): + return + + accepted_output_modes = extract_accepted_output_modes_from_metadata(event.metadata) + if accepted_output_modes is None: + return + + task = await self.task_manager.ensure_task(event) + merged_metadata = merge_output_negotiation_metadata(task.metadata, accepted_output_modes) + if merged_metadata == task.metadata: + return + task.metadata = merged_metadata + await self.task_manager._save_task(task) + + async def consume_and_emit(self, consumer: EventConsumer): # noqa: ANN201 + async for event in consumer.consume_all(): + transformed_event = 
self._transform_event(event) + if transformed_event is None: + continue + await self._persist_output_negotiation_metadata(transformed_event) + await self.task_manager.process(transformed_event) + yield transformed_event + + async def consume_all(self, consumer: EventConsumer) -> Task | Message | None: + async for event in consumer.consume_all(): + transformed_event = self._transform_event(event) + if transformed_event is None: + continue + if isinstance(transformed_event, Message): + self._message = transformed_event + return transformed_event + await self._persist_output_negotiation_metadata(transformed_event) + await self.task_manager.process(transformed_event) + return await self.task_manager.get_task() + + async def consume_and_break_on_interrupt( + self, + consumer: EventConsumer, + blocking: bool = True, + event_callback=None, # noqa: ANN001 + ) -> tuple[Task | Message | None, bool, asyncio.Task | None]: + event_stream = consumer.consume_all() + interrupted = False + bg_task: asyncio.Task | None = None + async for event in event_stream: + transformed_event = self._transform_event(event) + if transformed_event is None: + continue + if isinstance(transformed_event, Message): + self._message = transformed_event + return transformed_event, False, None + await self._persist_output_negotiation_metadata(transformed_event) + await self.task_manager.process(transformed_event) + + should_interrupt = False + is_auth_required = ( + isinstance(transformed_event, Task | TaskStatusUpdateEvent) + and transformed_event.status.state == TaskState.auth_required + ) + if is_auth_required or not blocking: + should_interrupt = True + + if should_interrupt: + bg_task = asyncio.create_task( + self._continue_consuming(event_stream, event_callback) + ) + interrupted = True + break + + return await self.task_manager.get_task(), interrupted, bg_task + + async def _continue_consuming( + self, + event_stream, + event_callback=None, # noqa: ANN001 + ) -> None: + async for event in 
event_stream: + transformed_event = self._transform_event(event) + if transformed_event is None: + continue + await self._persist_output_negotiation_metadata(transformed_event) + await self.task_manager.process(transformed_event) + if event_callback: + await event_callback() + + +def _filter_task(task: Task, accepted_output_modes: Collection[str]) -> Task: + status = task.status.model_copy( + update={"message": _filter_optional_message(task.status.message, accepted_output_modes)} + ) + history = None + if task.history is not None: + history = [ + message + for filtered in ( + _filter_message(message, accepted_output_modes) for message in task.history + ) + if filtered is not None + for message in [filtered] + ] + artifacts = None + if task.artifacts is not None: + artifacts = [ + artifact + for filtered in ( + _filter_artifact(artifact, accepted_output_modes) for artifact in task.artifacts + ) + if filtered is not None + for artifact in [filtered] + ] + + return task.model_copy(update={"status": status, "history": history, "artifacts": artifacts}) + + +def _filter_optional_message( + message: Message | None, + accepted_output_modes: Collection[str], +) -> Message | None: + if message is None: + return None + return _filter_message(message, accepted_output_modes) + + +def _filter_message( + message: Message, + accepted_output_modes: Collection[str], +) -> Message | None: + parts = _filter_parts(message.parts, accepted_output_modes) + if not parts: + return None + return message.model_copy(update={"parts": parts}) + + +def _filter_artifact( + artifact: Artifact, + accepted_output_modes: Collection[str], +) -> Artifact | None: + parts = _filter_parts(artifact.parts, accepted_output_modes) + if not parts: + return None + return artifact.model_copy(update={"parts": parts}) + + +def _filter_parts( + parts: list[Part], + accepted_output_modes: Collection[str], +) -> list[Part]: + filtered: list[Part] = [] + for part in parts: + media_type = _part_media_type(part) + if 
media_type is None or accepts_output_mode(accepted_output_modes, media_type): + filtered.append(part) + continue + if accepts_output_mode(accepted_output_modes, _TEXT_PLAIN_MEDIA_TYPE): + fallback_text = part_text_fallback(part.root) + if fallback_text is not None: + filtered.append(Part(root=TextPart(text=fallback_text))) + return filtered + + +def _part_media_type(part: Part) -> str | None: + payload = part.root + if isinstance(payload, TextPart): + return _TEXT_PLAIN_MEDIA_TYPE + if isinstance(payload, DataPart): + return _APPLICATION_JSON_MEDIA_TYPE + if isinstance(payload, FilePart): + file_value = payload.file + return ( + getattr(file_value, "mime_type", None) + or getattr(file_value, "mimeType", None) + or "application/octet-stream" + ) + return None diff --git a/src/opencode_a2a/parsing.py b/src/opencode_a2a/parsing.py new file mode 100644 index 0000000..b651ae1 --- /dev/null +++ b/src/opencode_a2a/parsing.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from collections.abc import Callable, Collection +from datetime import UTC, datetime +from typing import Any + + +def parse_int_field( + value: Any, + *, + field: str, + error_factory: Callable[[str, str], Exception], + minimum: int | None = None, +) -> int | None: + if value is None: + return None + if isinstance(value, bool): + raise error_factory(field, f"{field} must be an integer") + if isinstance(value, int): + parsed = value + elif isinstance(value, str): + try: + parsed = int(value) + except ValueError as exc: + raise error_factory(field, f"{field} must be an integer") from exc + else: + raise error_factory(field, f"{field} must be an integer") + + if minimum is not None and parsed < minimum: + raise error_factory(field, f"{field} must be >= {minimum}") + return parsed + + +def parse_string_field( + value: Any, + *, + field: str, + error_factory: Callable[[str, str], Exception], +) -> str | None: + if value is None: + return None + if not isinstance(value, str): + raise 
error_factory(field, f"{field} must be a string") + normalized = value.strip() + return normalized or None + + +def parse_bool_field( + value: Any, + *, + field: str, + error_factory: Callable[[str, str], Exception], + true_values: Collection[str] = ("true", "1", "yes", "on"), + false_values: Collection[str] = ("false", "0", "no", "off"), +) -> bool | None: + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in true_values: + return True + if normalized in false_values: + return False + raise error_factory(field, f"{field} must be a boolean") + + +def parse_timestamp_field( + value: Any, + *, + field: str, + error_factory: Callable[[str, str], Exception], +) -> datetime: + if not isinstance(value, str): + raise error_factory(field, f"{field} must be a valid ISO 8601 timestamp.") + normalized = value.strip() + if normalized.endswith("Z"): + normalized = normalized[:-1] + "+00:00" + try: + parsed = datetime.fromisoformat(normalized) + except ValueError as exc: + raise error_factory(field, f"{field} must be a valid ISO 8601 timestamp.") from exc + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=UTC) + return parsed.astimezone(UTC) diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index edec8ab..1e21b9a 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -64,7 +64,12 @@ OpencodeSessionManagementJSONRPCApplication, ) from ..opencode_upstream_client import OpencodeUpstreamClient -from ..output_modes import normalize_accepted_output_modes +from ..output_modes import ( + NegotiatingResultAggregator, + apply_accepted_output_modes, + extract_accepted_output_modes_from_metadata, + normalize_accepted_output_modes, +) from ..profile.runtime import build_runtime_profile from ..trace_context import install_log_record_factory from .agent_card import ( @@ -260,6 +265,33 @@ 
def _extract_accepted_output_modes(params) -> list[str] | None: # noqa: ANN001 normalized = normalize_accepted_output_modes(configuration) return list(normalized) if normalized is not None else None + @staticmethod + def _apply_task_output_negotiation(task: Task) -> Task: + negotiated = apply_accepted_output_modes( + task, + extract_accepted_output_modes_from_metadata(task.metadata), + ) + if isinstance(negotiated, Task): + return negotiated + return task + + async def _setup_message_execution(self, params, context=None): # noqa: ANN001 + ( + task_manager, + task_id, + queue, + _result_aggregator, + producer_task, + ) = await super()._setup_message_execution(params, context) + accepted_output_modes = self._extract_accepted_output_modes(params) + return ( + task_manager, + task_id, + queue, + NegotiatingResultAggregator(task_manager, accepted_output_modes), + producer_task, + ) + @classmethod def _validate_chat_output_modes(cls, params) -> None: # noqa: ANN001 accepted_output_modes = cls._extract_accepted_output_modes(params) @@ -298,7 +330,10 @@ async def on_get_task( context=None, ) -> Task | None: try: - return await super().on_get_task(params, context) + task = await super().on_get_task(params, context) + if task is None: + return None + return self._apply_task_output_negotiation(task) except TaskStoreOperationError as exc: raise self._task_store_server_error(exc) from exc @@ -347,11 +382,16 @@ async def on_resubscribe_to_task( # Subscribe contract: terminal tasks replay once and then close stream. 
if task.status.state in TERMINAL_TASK_STATES: - yield task + yield self._apply_task_output_negotiation(task) return async for event in super().on_resubscribe_to_task(params, context): - yield event + negotiated = apply_accepted_output_modes( + event, + extract_accepted_output_modes_from_metadata(getattr(event, "metadata", None)), + ) + if negotiated is not None: + yield negotiated except TaskStoreOperationError as exc: raise self._task_store_server_error(exc) from exc diff --git a/src/opencode_a2a/server/rest_tasks.py b/src/opencode_a2a/server/rest_tasks.py index 278b968..caf2f34 100644 --- a/src/opencode_a2a/server/rest_tasks.py +++ b/src/opencode_a2a/server/rest_tasks.py @@ -12,6 +12,19 @@ from fastapi.responses import JSONResponse from ..jsonrpc.error_responses import build_http_error_body +from ..output_modes import ( + apply_accepted_output_modes, + extract_accepted_output_modes_from_metadata, +) +from ..parsing import ( + parse_bool_field as parse_shared_bool_field, +) +from ..parsing import ( + parse_int_field as parse_shared_int_field, +) +from ..parsing import ( + parse_timestamp_field as parse_shared_timestamp_field, +) from .task_store import TaskStoreOperationError, list_stored_tasks logger = logging.getLogger(__name__) @@ -44,6 +57,10 @@ def __init__(self, *, field: str, message: str) -> None: self.message = message +def _validation_error(field: str, message: str) -> _ListTasksValidationError: + return _ListTasksValidationError(field=field, message=message) + + def build_list_tasks_route( *, task_store: TaskStore, @@ -141,6 +158,13 @@ def _serialize_task( history_length: int, include_artifacts: bool, ) -> dict: + negotiated = apply_accepted_output_modes( + task, + extract_accepted_output_modes_from_metadata(task.metadata), + ) + if isinstance(negotiated, Task): + task = negotiated + payload = task.model_dump(mode="json", by_alias=True, exclude_none=True) history = payload.get("history") @@ -216,43 +240,38 @@ def _parse_list_tasks_query(request: 
Request) -> _ListTasksQuery: def _parse_int(raw_value: str, *, field: str) -> int: - try: - return int(raw_value) - except ValueError as exc: - raise _ListTasksValidationError( - field=field, - message=f"{field} must be an integer.", - ) from exc + parsed = parse_shared_int_field( + raw_value, + field=field, + error_factory=lambda error_field, _message: _validation_error( + error_field, + f"{error_field} must be an integer.", + ), + ) + assert parsed is not None + return parsed def _parse_bool(raw_value: str | None, *, field: str, default: bool) -> bool: - if raw_value is None: - return default - normalized = raw_value.strip().lower() - if normalized in {"true", "1"}: - return True - if normalized in {"false", "0"}: - return False - raise _ListTasksValidationError( + parsed = parse_shared_bool_field( + raw_value, field=field, - message=f"{field} must be a boolean.", + error_factory=lambda error_field, _message: _validation_error( + error_field, + f"{error_field} must be a boolean.", + ), + true_values=("true", "1"), + false_values=("false", "0"), ) + return default if parsed is None else parsed def _parse_timestamp(raw_value: str, *, field: str) -> datetime: - normalized = raw_value.strip() - if normalized.endswith("Z"): - normalized = normalized[:-1] + "+00:00" - try: - parsed = datetime.fromisoformat(normalized) - except ValueError as exc: - raise _ListTasksValidationError( - field=field, - message=f"{field} must be a valid ISO 8601 timestamp.", - ) from exc - if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=UTC) - return parsed.astimezone(UTC) + return parse_shared_timestamp_field( + raw_value, + field=field, + error_factory=_validation_error, + ) def _task_status_timestamp(task: Task) -> datetime: diff --git a/src/opencode_a2a/server/task_store.py b/src/opencode_a2a/server/task_store.py index a0d03c8..3de453a 100644 --- a/src/opencode_a2a/server/task_store.py +++ b/src/opencode_a2a/server/task_store.py @@ -10,13 +10,14 @@ from 
a2a.server.tasks.database_task_store import DatabaseTaskStore from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.server.tasks.task_store import TaskStore -from a2a.types import Task, TaskState +from a2a.types import Task from sqlalchemy import event, or_, select from sqlalchemy.dialects.postgresql import insert as postgresql_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.engine import make_url from ..config import Settings +from ..task_states import TERMINAL_TASK_STATES if TYPE_CHECKING: from a2a.server.context import ServerCallContext @@ -24,15 +25,7 @@ logger = logging.getLogger(__name__) -_TERMINAL_TASK_STATES = frozenset( - { - TaskState.completed, - TaskState.canceled, - TaskState.failed, - TaskState.rejected, - } -) -_TERMINAL_TASK_STATE_VALUES = tuple(state.value for state in _TERMINAL_TASK_STATES) +_TERMINAL_TASK_STATE_VALUES = tuple(state.value for state in TERMINAL_TASK_STATES) _ATOMIC_TERMINAL_GUARD_DIALECTS = frozenset({"postgresql", "sqlite"}) _SQLITE_JOURNAL_MODE = "WAL" _SQLITE_BUSY_TIMEOUT_MS = 30_000 @@ -70,7 +63,7 @@ def evaluate( existing: Task | None, incoming: Task, ) -> TaskPersistenceDecision: - if existing is None or existing.status.state not in _TERMINAL_TASK_STATES: + if existing is None or existing.status.state not in TERMINAL_TASK_STATES: return TaskPersistenceDecision(persist=True) if incoming.status.state != existing.status.state: return TaskPersistenceDecision( @@ -224,7 +217,7 @@ async def _save_database_task( return if ( existing is not None - and existing.status.state in _TERMINAL_TASK_STATES + and existing.status.state in TERMINAL_TASK_STATES and existing.model_dump(mode="json") == task.model_dump(mode="json") ): return @@ -272,7 +265,7 @@ def _log_terminal_persistence_decision( incoming: Task, decision: TaskPersistenceDecision, ) -> None: - if existing is None or existing.status.state not in _TERMINAL_TASK_STATES: + if existing is None or existing.status.state not in 
TERMINAL_TASK_STATES: return logger.warning( "Received task persistence after terminal state task_id=%s existing_state=%s " diff --git a/src/opencode_a2a/task_states.py b/src/opencode_a2a/task_states.py new file mode 100644 index 0000000..2c220c2 --- /dev/null +++ b/src/opencode_a2a/task_states.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from a2a.types import TaskState + +TERMINAL_TASK_STATES = frozenset( + { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, + } +) diff --git a/tests/execution/test_streaming_output_contract_blocks.py b/tests/execution/test_streaming_output_contract_blocks.py index d2d8cac..8deeb22 100644 --- a/tests/execution/test_streaming_output_contract_blocks.py +++ b/tests/execution/test_streaming_output_contract_blocks.py @@ -123,7 +123,7 @@ async def test_streaming_emits_structured_tool_part_updates() -> None: @pytest.mark.asyncio -async def test_streaming_suppresses_structured_tool_updates_when_json_output_not_accepted() -> None: +async def test_streaming_downgrades_structured_tool_updates_when_json_output_not_accepted() -> None: client = DummyStreamingClient( stream_events_payload=[ _event( @@ -157,7 +157,9 @@ async def test_streaming_suppresses_structured_tool_updates_when_json_output_not updates = _artifact_updates(queue) tool_updates = [ev for ev in updates if _artifact_stream_meta(ev)["block_type"] == "tool_call"] - assert tool_updates == [] + assert len(tool_updates) == 1 + assert _part_text(tool_updates[0]) == '{"call_id":"call-1","status":"running","tool":"bash"}' + assert getattr(tool_updates[0].artifact.parts[0].root, "kind", None) == "text" assert any(_artifact_stream_meta(ev)["block_type"] == "text" for ev in updates) diff --git a/tests/scripts/test_script_health_contract.py b/tests/scripts/test_script_health_contract.py index 6543332..483ab2b 100644 --- a/tests/scripts/test_script_health_contract.py +++ b/tests/scripts/test_script_health_contract.py @@ -23,8 +23,11 @@ def 
test_doctor_keeps_local_regression_scope() -> None: assert 'source "$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/health_common.sh"' in DOCTOR_TEXT assert 'run_shared_repo_health_prerequisites "doctor"' in DOCTOR_TEXT assert "uv run pre-commit run --all-files" in DOCTOR_TEXT + assert "uv run mypy src/opencode_a2a" in DOCTOR_TEXT assert "uv run pytest" in DOCTOR_TEXT assert "uv run python ./scripts/check_coverage.py" in DOCTOR_TEXT + assert "uv build --no-sources" in DOCTOR_TEXT + assert "bash ./scripts/smoke_test_built_cli.sh dist/opencode_a2a-*.whl" in DOCTOR_TEXT assert "uv pip list --outdated" not in DOCTOR_TEXT assert "uv run pip-audit" not in DOCTOR_TEXT @@ -46,6 +49,7 @@ def test_scripts_index_documents_split_health_entrypoints() -> None: assert "external A2A conformance experiment entrypoint" in SCRIPTS_INDEX_TEXT assert "dependency review entrypoint" in SCRIPTS_INDEX_TEXT assert "health_common.sh" in SCRIPTS_INDEX_TEXT + assert "built-wheel smoke test" in SCRIPTS_INDEX_TEXT assert "single weekly grouped Dependabot PR for `uv`" in SCRIPTS_INDEX_TEXT diff --git a/tests/server/test_output_negotiation.py b/tests/server/test_output_negotiation.py new file mode 100644 index 0000000..06e4e06 --- /dev/null +++ b/tests/server/test_output_negotiation.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest +from a2a.server.events import EventConsumer, EventQueue +from a2a.server.tasks import TaskManager +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import ( + Artifact, + DataPart, + Message, + Part, + Role, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskQueryParams, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, + TextPart, +) + +from opencode_a2a.output_modes import ( + NegotiatingResultAggregator, + apply_accepted_output_modes, + build_output_negotiation_metadata, + extract_accepted_output_modes_from_metadata, + normalize_accepted_output_modes, + 
part_text_fallback, +) +from opencode_a2a.server.application import OpencodeRequestHandler + + +def _message(*, message_id: str, text: str, task_id: str, context_id: str) -> Message: + return Message( + message_id=message_id, + role=Role.agent, + parts=[Part(root=TextPart(text=text))], + task_id=task_id, + context_id=context_id, + ) + + +def _task_with_negotiated_outputs(*, task_id: str, context_id: str) -> Task: + metadata = build_output_negotiation_metadata(["text/plain"]) + assert metadata is not None + return Task( + id=task_id, + context_id=context_id, + status=TaskStatus( + state=TaskState.completed, + message=_message( + message_id=f"{task_id}:status", + text="done", + task_id=task_id, + context_id=context_id, + ), + ), + history=[ + _message( + message_id=f"{task_id}:history", + text="history", + task_id=task_id, + context_id=context_id, + ) + ], + artifacts=[ + Artifact( + artifact_id=f"{task_id}:text", + parts=[Part(root=TextPart(text="plain result"))], + ), + Artifact( + artifact_id=f"{task_id}:json", + parts=[Part(root=DataPart(data={"tool": "bash", "status": "completed"}))], + ), + ], + metadata=metadata, + ) + + +def test_normalize_accepted_output_modes_treats_wildcards_as_unrestricted() -> None: + assert normalize_accepted_output_modes(["text/plain", "APPLICATION/JSON"]) == ( + "text/plain", + "application/json", + ) + assert normalize_accepted_output_modes(["text/plain", "*/*"]) is None + assert normalize_accepted_output_modes(["*"]) is None + + +def test_part_text_fallback_serializes_data_parts_as_stable_json() -> None: + assert part_text_fallback(DataPart(data={"tool": "bash", "status": "running"})) == ( + '{"status":"running","tool":"bash"}' + ) + + +def test_apply_accepted_output_modes_downgrades_task_data_parts_to_text() -> None: + task = Task( + id="task-send", + context_id="ctx-send", + status=TaskStatus( + state=TaskState.completed, + message=Message( + message_id="msg-send", + role=Role.agent, + parts=[Part(root=DataPart(data={"tool": 
"bash", "status": "running"}))], + task_id="task-send", + context_id="ctx-send", + ), + ), + artifacts=[ + Artifact( + artifact_id="artifact-send", + parts=[Part(root=DataPart(data={"tool": "bash", "status": "running"}))], + ) + ], + ) + + downgraded = apply_accepted_output_modes(task, ["text/plain"]) + + assert isinstance(downgraded, Task) + assert downgraded.status.message is not None + assert downgraded.status.message.parts[0].root.text == '{"status":"running","tool":"bash"}' + assert downgraded.artifacts is not None + assert downgraded.artifacts[0].parts[0].root.text == '{"status":"running","tool":"bash"}' + + +@pytest.mark.asyncio +async def test_negotiating_result_aggregator_persists_metadata_for_artifact_first_flow() -> None: + store = InMemoryTaskStore() + task_manager = TaskManager( + task_id="task-artifact-first", + context_id="ctx-artifact-first", + task_store=store, + initial_message=None, + ) + aggregator = NegotiatingResultAggregator(task_manager, ["text/plain"]) + queue = EventQueue() + + await queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id="task-artifact-first", + context_id="ctx-artifact-first", + artifact=Artifact( + artifact_id="artifact-1", + parts=[Part(root=TextPart(text="hello"))], + ), + append=False, + last_chunk=False, + ) + ) + await queue.enqueue_event( + TaskStatusUpdateEvent( + task_id="task-artifact-first", + context_id="ctx-artifact-first", + status=TaskStatus(state=TaskState.completed), + final=True, + ) + ) + + result, interrupted, bg_task = await aggregator.consume_and_break_on_interrupt( + EventConsumer(queue), + blocking=False, + ) + + assert interrupted is True + assert isinstance(result, Task) + assert bg_task is not None + assert extract_accepted_output_modes_from_metadata(result.metadata) == ("text/plain",) + assert result.artifacts is not None + assert [artifact.artifact_id for artifact in result.artifacts] == ["artifact-1"] + + await bg_task + stored = await store.get("task-artifact-first") + assert stored is 
not None + assert extract_accepted_output_modes_from_metadata(stored.metadata) == ("text/plain",) + + +@pytest.mark.asyncio +async def test_on_get_task_applies_persisted_output_negotiation() -> None: + store = InMemoryTaskStore() + task = _task_with_negotiated_outputs(task_id="task-get", context_id="ctx-get") + await store.save(task) + handler = OpencodeRequestHandler(agent_executor=AsyncMock(), task_store=store) + + result = await handler.on_get_task(TaskQueryParams(id="task-get")) + + assert result is not None + assert extract_accepted_output_modes_from_metadata(result.metadata) == ("text/plain",) + assert result.artifacts is not None + assert [artifact.artifact_id for artifact in result.artifacts] == [ + "task-get:text", + "task-get:json", + ] + assert result.artifacts[1].parts[0].root.text == '{"status":"completed","tool":"bash"}' + + +@pytest.mark.asyncio +async def test_resubscribe_terminal_task_applies_persisted_output_negotiation() -> None: + store = InMemoryTaskStore() + task = _task_with_negotiated_outputs(task_id="task-resub", context_id="ctx-resub") + await store.save(task) + handler = OpencodeRequestHandler(agent_executor=AsyncMock(), task_store=store) + + events = [] + async for event in handler.on_resubscribe_to_task(TaskIdParams(id="task-resub")): + events.append(event) + + assert len(events) == 1 + assert isinstance(events[0], Task) + assert events[0].artifacts is not None + assert [artifact.artifact_id for artifact in events[0].artifacts] == [ + "task-resub:text", + "task-resub:json", + ] + assert events[0].artifacts[1].parts[0].root.text == '{"status":"completed","tool":"bash"}' diff --git a/tests/server/test_transport_contract.py b/tests/server/test_transport_contract.py index 27a1e1d..a332eff 100644 --- a/tests/server/test_transport_contract.py +++ b/tests/server/test_transport_contract.py @@ -8,6 +8,7 @@ from a2a.server.apps.rest.rest_adapter import RESTAdapter from a2a.types import ( Artifact, + DataPart, Message, Part, Role, @@ -18,6 +19,7 
@@ TransportProtocol, ) +from opencode_a2a.output_modes import build_output_negotiation_metadata from opencode_a2a.server.application import ( AUTHENTICATED_EXTENDED_CARD_CACHE_CONTROL, PUBLIC_AGENT_CARD_CACHE_CONTROL, @@ -343,6 +345,64 @@ async def test_list_tasks_route_supports_history_artifacts_and_filters(monkeypat assert returned_task["artifacts"][0]["artifactId"] == "task-filtered-artifact" +@pytest.mark.asyncio +async def test_list_tasks_route_applies_persisted_output_negotiation(monkeypatch) -> None: + import opencode_a2a.server.application as app_module + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", DummyChatOpencodeUpstreamClient) + app = app_module.create_app( + make_settings( + test_bearer_token="test-token", + a2a_task_store_backend="memory", + ) + ) + task_store = app.state.task_store + metadata = build_output_negotiation_metadata(["text/plain"]) + assert metadata is not None + + task = Task( + id="task-negotiated-list", + context_id="ctx-negotiated-list", + status=TaskStatus(state=TaskState.completed, timestamp=datetime.now(UTC).isoformat()), + artifacts=[ + Artifact( + artifact_id="task-negotiated-list-text", + parts=[Part(root=TextPart(text="plain"))], + ), + Artifact( + artifact_id="task-negotiated-list-json", + parts=[Part(root=DataPart(data={"tool": "bash", "status": "completed"}))], + ), + ], + metadata=metadata, + ) + await task_store.save(task) + + transport = httpx.ASGITransport(app=app) + headers = {"Authorization": "Bearer test-token"} + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get( + "/v1/tasks", + headers=headers, + params={ + "contextId": "ctx-negotiated-list", + "includeArtifacts": "true", + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["totalSize"] == 1 + assert payload["tasks"][0]["id"] == "task-negotiated-list" + assert [artifact["artifactId"] for artifact in payload["tasks"][0]["artifacts"]] == [ + 
"task-negotiated-list-text", + "task-negotiated-list-json", + ] + assert payload["tasks"][0]["artifacts"][1]["parts"][0]["text"] == ( + '{"status":"completed","tool":"bash"}' + ) + + @pytest.mark.asyncio async def test_list_tasks_route_tolerates_invalid_stored_status_timestamp(monkeypatch) -> None: import opencode_a2a.server.application as app_module diff --git a/tests/test_metadata_access.py b/tests/test_metadata_access.py new file mode 100644 index 0000000..abea3e5 --- /dev/null +++ b/tests/test_metadata_access.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from opencode_a2a.metadata_access import ( + extract_first_namespaced_string, + extract_namespaced_value, +) + + +def test_extract_namespaced_value_returns_nested_metadata() -> None: + metadata = {"opencode": {"workspace": {"id": "wrk-1"}}} + + assert ( + extract_namespaced_value( + metadata, + namespace="opencode", + path=("workspace", "id"), + ) + == "wrk-1" + ) + + +def test_extract_namespaced_value_returns_none_for_invalid_shape() -> None: + metadata = {"opencode": {"workspace": "invalid"}} + + assert ( + extract_namespaced_value( + metadata, + namespace="opencode", + path=("workspace", "id"), + ) + is None + ) + + +def test_extract_first_namespaced_string_prefers_first_non_empty_value() -> None: + sources = ( + {"opencode": {"directory": " "}}, + {"opencode": {"directory": "services/api"}}, + {"opencode": {"directory": "services/worker"}}, + ) + + assert ( + extract_first_namespaced_string( + sources, + namespace="opencode", + path=("directory",), + ) + == "services/api" + )