diff --git a/src/adcp/__init__.py b/src/adcp/__init__.py index 39b93fb9..22deb88d 100644 --- a/src/adcp/__init__.py +++ b/src/adcp/__init__.py @@ -27,7 +27,7 @@ build_synthetic_capabilities, validate_capabilities, ) -from adcp.client import ADCPClient, ADCPMultiAgentClient +from adcp.client import ADCPClient, ADCPMultiAgentClient, Checkpoint from adcp.exceptions import ( # noqa: F401 AdagentsNotFoundError, AdagentsTimeoutError, @@ -566,6 +566,7 @@ def get_adcp_version() -> str: # Client classes "ADCPClient", "ADCPMultiAgentClient", + "Checkpoint", "RegistryClient", "PropertyRegistry", "RegistrySync", diff --git a/src/adcp/protocols/a2a.py b/src/adcp/protocols/a2a.py index 3fdcbbb5..3f519b73 100644 --- a/src/adcp/protocols/a2a.py +++ b/src/adcp/protocols/a2a.py @@ -523,7 +523,7 @@ def _process_task_response(self, task: Task, debug_info: DebugInfo | None) -> Ta """Process a Task response from A2A into our TaskResult format.""" task_state = task.status.state - if task_state == "completed": + if task_state == TaskState.completed: # Extract the result from the artifacts array result_data = self._extract_result_from_task(task) @@ -542,7 +542,7 @@ def _process_task_response(self, task: Task, debug_info: DebugInfo | None) -> Ta }, debug_info=debug_info, ) - elif task_state == "failed": + elif task_state == TaskState.failed: # Protocol-level failure - extract error message from TextPart error_msg = self._extract_text_from_task(task) or "Task failed" return TaskResult[Any]( diff --git a/tests/fixtures/public_api_snapshot.json b/tests/fixtures/public_api_snapshot.json index a5bcf2cc..0fafdde3 100644 --- a/tests/fixtures/public_api_snapshot.json +++ b/tests/fixtures/public_api_snapshot.json @@ -72,6 +72,7 @@ "ChangeHandler", "CheckGovernanceRequest", "CheckGovernanceResponse", + "Checkpoint", "ComplyTestControllerRequest", "ComplyTestControllerResponse", "ConsentBasis",