diff --git a/.gitignore b/.gitignore
index 1bccc90..8540077 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,5 +131,7 @@ dmypy.json
 
 # IDEs
 .idea
+.worktrees/
+docs/superpowers/
 
-coverage.lcov
\ No newline at end of file
+coverage.lcov
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1f84e57..695c4c0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,8 +16,32 @@ ADDED
   `TaskHubGrpcClient`, `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker` to
   support pre-configured channel passthrough and low-level gRPC channel
   customization.
-- Added `get_orchestration_history()` and `list_instance_ids()` to the sync and async gRPC clients.
-- Added in-memory backend support for `StreamInstanceHistory` and `ListInstanceIds` so local orchestration tests can retrieve history and page terminal instance IDs by completion window.
+- Added `GrpcWorkerResiliencyOptions` and `GrpcClientResiliencyOptions`, plus
+  `resiliency_options` constructor parameters on `TaskHubGrpcClient`,
+  `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker`, to configure hello
+  deadlines, silent-disconnect detection, reconnect backoff, and channel
+  recreation thresholds for SDK-managed gRPC connections.
+- Added `get_orchestration_history()` and `list_instance_ids()` to the sync
+  and async gRPC clients.
+- Added in-memory backend support for `StreamInstanceHistory` and
+  `ListInstanceIds` so local orchestration tests can retrieve history and page
+  terminal instance IDs by completion window.
+
+FIXED
+
+- Improved `TaskHubGrpcWorker` recovery from stale or disconnected gRPC streams
+  so configured hello timeouts apply on fresh connections, received work resets
+  failure tracking, SDK-owned channels are refreshed and cleaned up safely, and
+  caller-owned channels are never recreated or closed during reconnects.
+- Fixed `TaskHubGrpcWorker` so in-flight and queued work item completions keep
+  draining across graceful gRPC stream resets and worker shutdown before the
+  worker retires an SDK-owned channel.
+- Improved sync and async gRPC clients so repeated transport failures recreate
+  SDK-owned channels, while long-poll deadlines, successful replies, and
+  application-level RPC errors do not trigger unnecessary channel replacement.
+- Fixed `TaskHubGrpcClient.close()` so explicit sync client shutdown now closes
+  any previously retired SDK-owned gRPC channels immediately instead of waiting
+  for the delayed cleanup timer.
 
 ## v1.4.0
 
diff --git a/durabletask-azuremanaged/CHANGELOG.md b/durabletask-azuremanaged/CHANGELOG.md
index 639d75f..ac2af0b 100644
--- a/durabletask-azuremanaged/CHANGELOG.md
+++ b/durabletask-azuremanaged/CHANGELOG.md
@@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   `DurableTaskSchedulerClient`, `AsyncDurableTaskSchedulerClient`, and
   `DurableTaskSchedulerWorker` to allow combining custom gRPC interceptors
   with DTS defaults and to support pre-configured/customized gRPC channels.
+- Added pass-through `resiliency_options` support on
+  `DurableTaskSchedulerClient`, `AsyncDurableTaskSchedulerClient`, and
+  `DurableTaskSchedulerWorker` so Azure Managed applications can use the core
+  SDK's gRPC resiliency option types through their constructors.
 - Added `workerid` gRPC metadata on Durable Task Scheduler worker calls for
   improved worker identity and observability.
 - Improved sync access token refresh concurrency handling to avoid duplicate
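The option types named in these changelog entries are plain dataclasses passed at construction time. A minimal sketch of opting in from application code, assuming only the constructor parameters introduced in this change (the sidecar address is a placeholder):

    from durabletask.client import TaskHubGrpcClient
    from durabletask.grpc_options import GrpcClientResiliencyOptions

    # Recreate the SDK-owned channel after 3 consecutive transport failures,
    # but never more often than once every 10 seconds.
    resiliency = GrpcClientResiliencyOptions(
        channel_recreate_failure_threshold=3,
        min_recreate_interval_seconds=10.0,
    )
    client = TaskHubGrpcClient(
        host_address="localhost:4001",  # placeholder sidecar address
        resiliency_options=resiliency,
    )

Leaving `resiliency_options` unset keeps the default thresholds, and passing a caller-owned channel opts out of recreation entirely, as the FIXED entries note.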
diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py
index ea30471..ed5dcc9 100644
--- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py
+++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py
@@ -15,7 +15,10 @@
     DTSDefaultClientInterceptorImpl,
 )
 from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient
-from durabletask.grpc_options import GrpcChannelOptions
+from durabletask.grpc_options import (
+    GrpcChannelOptions,
+    GrpcClientResiliencyOptions,
+)
 import durabletask.internal.shared as shared
 from durabletask.payload.store import PayloadStore
 
@@ -30,6 +33,7 @@ def __init__(self, *,
                  secure_channel: bool = True,
                  interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
                  channel_options: Optional[GrpcChannelOptions] = None,
+                 resiliency_options: Optional[GrpcClientResiliencyOptions] = None,
                  default_version: Optional[str] = None,
                  payload_store: Optional[PayloadStore] = None,
                  log_handler: Optional[logging.Handler] = None,
@@ -54,6 +58,7 @@ def __init__(self, *,
             log_formatter=log_formatter,
             interceptors=resolved_interceptors,
             channel_options=channel_options,
+            resiliency_options=resiliency_options,
             default_version=default_version,
             payload_store=payload_store)
 
@@ -74,6 +79,8 @@ class AsyncDurableTaskSchedulerClient(AsyncTaskHubGrpcClient):
             If None, anonymous authentication will be used.
         secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS).
             Defaults to True.
+        resiliency_options (Optional[GrpcClientResiliencyOptions], optional): Client-side
+            gRPC resiliency settings forwarded to the base async client.
         default_version (Optional[str], optional): Default version string for orchestrations.
         payload_store (Optional[PayloadStore], optional): A payload store for externalizing
             large payloads. If None, payloads are sent inline.
@@ -104,6 +111,7 @@ def __init__(self, *,
                  secure_channel: bool = True,
                  interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None,
                  channel_options: Optional[GrpcChannelOptions] = None,
+                 resiliency_options: Optional[GrpcClientResiliencyOptions] = None,
                  default_version: Optional[str] = None,
                  payload_store: Optional[PayloadStore] = None,
                  log_handler: Optional[logging.Handler] = None,
@@ -128,5 +136,6 @@ def __init__(self, *,
             log_formatter=log_formatter,
             interceptors=resolved_interceptors,
             channel_options=channel_options,
+            resiliency_options=resiliency_options,
             default_version=default_version,
             payload_store=payload_store)
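The docstring addition above covers the async pass-through; the sketch below shows the same wiring from the Azure Managed side. The endpoint, task hub name, and instance ID are placeholders, and `token_credential=None` uses the anonymous authentication path documented above:

    import asyncio

    from durabletask.azuremanaged.client import AsyncDurableTaskSchedulerClient
    from durabletask.grpc_options import GrpcClientResiliencyOptions

    async def main() -> None:
        # resiliency_options is forwarded unchanged to the base async client.
        client = AsyncDurableTaskSchedulerClient(
            host_address="myhub.region.durabletask.io:443",  # placeholder endpoint
            taskhub="hub",                                   # placeholder task hub
            token_credential=None,
            resiliency_options=GrpcClientResiliencyOptions(),
        )
        try:
            state = await client.get_orchestration_state("instance-id")  # placeholder ID
            print(state)
        finally:
            await client.close()

    asyncio.run(main())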
diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py
index 7e9c0ef..6956ae2 100644
--- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py
+++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py
@@ -13,7 +13,10 @@
 
 from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \
     DTSDefaultClientInterceptorImpl
-from durabletask.grpc_options import GrpcChannelOptions
+from durabletask.grpc_options import (
+    GrpcChannelOptions,
+    GrpcWorkerResiliencyOptions,
+)
 import durabletask.internal.shared as shared
 from durabletask.payload.store import PayloadStore
 from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker
@@ -34,6 +37,8 @@ class DurableTaskSchedulerWorker(TaskHubGrpcWorker):
             If None, anonymous authentication will be used.
         secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS).
             Defaults to True.
+        resiliency_options (Optional[GrpcWorkerResiliencyOptions], optional): Worker-side
+            gRPC resiliency settings forwarded to the base worker.
         concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for
             controlling worker concurrency limits. If None, default concurrency
             settings will be used.
@@ -74,6 +79,7 @@ def __init__(self, *,
                  secure_channel: bool = True,
                  interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
                  channel_options: Optional[GrpcChannelOptions] = None,
+                 resiliency_options: Optional[GrpcWorkerResiliencyOptions] = None,
                  concurrency_options: Optional[ConcurrencyOptions] = None,
                  payload_store: Optional[PayloadStore] = None,
                  log_handler: Optional[logging.Handler] = None,
@@ -101,6 +107,7 @@ def __init__(self, *,
             log_formatter=log_formatter,
             interceptors=resolved_interceptors,
             channel_options=channel_options,
+            resiliency_options=resiliency_options,
             concurrency_options=concurrency_options,
             # DTS natively supports long timers so chunking is unnecessary
             maximum_timer_interval=None,
diff --git a/durabletask/client.py b/durabletask/client.py
index 0ef223b..e76cdfa 100644
--- a/durabletask/client.py
+++ b/durabletask/client.py
@@ -1,7 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
+import asyncio
 import logging
+import threading
+import time
 import uuid
 from dataclasses import dataclass
 from datetime import datetime
@@ -14,7 +17,10 @@
 import durabletask.history as history
 from durabletask.entities import EntityInstanceId
 from durabletask.entities.entity_metadata import EntityMetadata
-from durabletask.grpc_options import GrpcChannelOptions
+from durabletask.grpc_options import (
+    GrpcChannelOptions,
+    GrpcClientResiliencyOptions,
+)
 import durabletask.internal.helpers as helpers
 import durabletask.internal.history_helpers as history_helpers
 import durabletask.internal.orchestrator_service_pb2 as pb
@@ -22,6 +28,10 @@
 import durabletask.internal.shared as shared
 import durabletask.internal.tracing as tracing
 from durabletask import task
+from durabletask.internal.grpc_resiliency import (
+    FailureTracker,
+    is_client_transport_failure,
+)
 from durabletask.internal.client_helpers import (
     build_query_entities_req,
     build_query_instances_req,
@@ -166,24 +176,109 @@ def __init__(self, *,
                  secure_channel: bool = False,
                  interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
                  channel_options: Optional[GrpcChannelOptions] = None,
+                 resiliency_options: Optional[GrpcClientResiliencyOptions] = None,
                  default_version: Optional[str] = None,
                  payload_store: Optional[PayloadStore] = None):
         self._owns_channel = channel is None
+        self._host_address = (
+            host_address if host_address else shared.get_default_host_address()
+        )
+        self._secure_channel = secure_channel
+        self._channel_options = channel_options
+        self._resiliency_options = (
+            resiliency_options
+            if resiliency_options is not None
+            else GrpcClientResiliencyOptions()
+        )
+        resolved_interceptors = (
+            prepare_sync_interceptors(metadata, interceptors) if channel is None else interceptors
+        )
+        self._interceptors = (
+            list(resolved_interceptors)
+            if resolved_interceptors is not None
+            else None
+        )
         if channel is None:
-            interceptors = prepare_sync_interceptors(metadata, interceptors)
             channel = shared.get_grpc_channel(
-                host_address=host_address,
+                host_address=self._host_address,
                 secure_channel=secure_channel,
-                interceptors=interceptors,
+                interceptors=self._interceptors,
                 channel_options=channel_options,
             )
         self._channel = channel
         self._stub = stubs.TaskHubSidecarServiceStub(channel)
+        self._client_failure_tracker = FailureTracker(
+            self._resiliency_options.channel_recreate_failure_threshold
+        )
+        self._closing = False
+        self._last_recreate_time = 0.0
+        self._recreate_lock = threading.Lock()
+        self._retired_channels: dict[grpc.Channel, threading.Timer] = {}
         self._logger = shared.get_logger("client", log_handler, log_formatter)
         self.default_version = default_version
         self._payload_store = payload_store
 
+    def _invoke_unary(
+            self,
+            method_name: str,
+            request: Any,
+            *,
+            timeout: Optional[float] = None):
+        method = getattr(self._stub, method_name)
+        try:
+            if timeout is None:
+                response = method(request)
+            else:
+                response = method(request, timeout=timeout)
+        except grpc.RpcError as rpc_error:
+            status_code = rpc_error.code()
+            if is_client_transport_failure(method_name, status_code):
+                should_recreate = self._client_failure_tracker.record_failure()
+                if should_recreate:
+                    self._maybe_recreate_channel()
+            else:
+                self._client_failure_tracker.record_success()
+            raise
+        else:
+            self._client_failure_tracker.record_success()
+            return response
+
+    def _maybe_recreate_channel(self) -> None:
+        if not self._owns_channel or self._closing:
+            return
+        with self._recreate_lock:
+            if self._closing:
+                return
+            now = time.monotonic()
+            if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds:
+                return
+            old_channel = self._channel
+            self._channel = shared.get_grpc_channel(
+                host_address=self._host_address,
+                secure_channel=self._secure_channel,
+                interceptors=self._interceptors,
+                channel_options=self._channel_options,
+            )
+            self._stub = stubs.TaskHubSidecarServiceStub(self._channel)
+            self._last_recreate_time = now
+            self._client_failure_tracker.record_success()
+            close_timer = threading.Timer(
+                30.0,
+                self._close_retired_channel,
+                args=(old_channel,),
+            )
+            close_timer.daemon = True
+            self._retired_channels[old_channel] = close_timer
+            close_timer.start()
+
+    def _close_retired_channel(self, channel: grpc.Channel) -> None:
+        with self._recreate_lock:
+            close_timer = self._retired_channels.pop(channel, None)
+        if close_timer is None:
+            return
+        channel.close()
+
     def close(self) -> None:
         """Closes the underlying gRPC channel.
 
@@ -193,7 +288,15 @@ def close(self) -> None:
         it.
""" if self._owns_channel: - self._channel.close() + with self._recreate_lock: + self._closing = True + retired_channels = list(self._retired_channels.items()) + self._retired_channels.clear() + current_channel = self._channel + for retired_channel, close_timer in retired_channels: + close_timer.cancel() + retired_channel.close() + current_channel.close() def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, @@ -228,12 +331,12 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu payload_helpers.externalize_payloads( req, self._payload_store, instance_id=req.instanceId, ) - res: pb.CreateInstanceResponse = self._stub.StartInstance(req) + res: pb.CreateInstanceResponse = self._invoke_unary("StartInstance", req) return res.instanceId def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - res: pb.GetInstanceResponse = self._stub.GetInstance(req) + res: pb.GetInstanceResponse = self._invoke_unary("GetInstance", req) # De-externalize any large-payload tokens in the response if self._payload_store is not None and res.exists: payload_helpers.deexternalize_payloads(res, self._payload_store) @@ -273,7 +376,7 @@ def list_instance_ids(self, f"page_size={page_size}, " f"continuation_token={continuation_token}" ) - resp: pb.ListInstanceIdsResponse = self._stub.ListInstanceIds(req) + resp: pb.ListInstanceIdsResponse = self._invoke_unary("ListInstanceIds", req) next_token = resp.lastInstanceKey.value if resp.HasField("lastInstanceKey") else None return Page(items=list(resp.instanceIds), continuation_token=next_token) @@ -290,7 +393,7 @@ def get_all_orchestration_states(self, while True: req = build_query_instances_req(orchestration_query, _continuation_token) - resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req) + resp: pb.QueryInstancesResponse = self._invoke_unary("QueryInstances", req) if self._payload_store is not None: payload_helpers.deexternalize_payloads(resp, self._payload_store) states += [parse_orchestration_state(res) for res in resp.orchestrationState] @@ -303,11 +406,15 @@ def get_all_orchestration_states(self, def wait_for_orchestration_start(self, instance_id: str, *, fetch_payloads: bool = False, - timeout: int = 60) -> Optional[OrchestrationState]: + timeout: float = 60) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") - res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout) + res: pb.GetInstanceResponse = self._invoke_unary( + "WaitForInstanceStart", + req, + timeout=timeout, + ) if self._payload_store is not None and res.exists: payload_helpers.deexternalize_payloads(res, self._payload_store) return new_orchestration_state(req.instanceId, res) @@ -320,11 +427,15 @@ def wait_for_orchestration_start(self, instance_id: str, *, def wait_for_orchestration_completion(self, instance_id: str, *, fetch_payloads: bool = True, - timeout: int = 60) -> Optional[OrchestrationState]: + timeout: float = 60) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") - res: pb.GetInstanceResponse 
+            res: pb.GetInstanceResponse = self._invoke_unary(
+                "WaitForInstanceCompletion",
+                req,
+                timeout=timeout,
+            )
             if self._payload_store is not None and res.exists:
                 payload_helpers.deexternalize_payloads(res, self._payload_store)
             state = new_orchestration_state(req.instanceId, res)
@@ -345,7 +456,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
             payload_helpers.externalize_payloads(
                 req, self._payload_store, instance_id=instance_id,
             )
-        self._stub.RaiseEvent(req)
+        self._invoke_unary("RaiseEvent", req)
 
     def terminate_orchestration(self, instance_id: str, *,
                                 output: Optional[Any] = None,
@@ -357,17 +468,17 @@ def terminate_orchestration(self, instance_id: str, *,
             payload_helpers.externalize_payloads(
                 req, self._payload_store, instance_id=instance_id,
             )
-        self._stub.TerminateInstance(req)
+        self._invoke_unary("TerminateInstance", req)
 
     def suspend_orchestration(self, instance_id: str) -> None:
         req = pb.SuspendRequest(instanceId=instance_id)
         self._logger.info(f"Suspending instance '{instance_id}'.")
-        self._stub.SuspendInstance(req)
+        self._invoke_unary("SuspendInstance", req)
 
     def resume_orchestration(self, instance_id: str) -> None:
         req = pb.ResumeRequest(instanceId=instance_id)
         self._logger.info(f"Resuming instance '{instance_id}'.")
-        self._stub.ResumeInstance(req)
+        self._invoke_unary("ResumeInstance", req)
 
     def restart_orchestration(self, instance_id: str, *,
                               restart_with_new_instance_id: bool = False) -> str:
@@ -386,13 +497,13 @@ def restart_orchestration(self, instance_id: str, *,
                                        restartWithNewInstanceId=restart_with_new_instance_id)
 
         self._logger.info(f"Restarting instance '{instance_id}'.")
-        res: pb.RestartInstanceResponse = self._stub.RestartInstance(req)
+        res: pb.RestartInstanceResponse = self._invoke_unary("RestartInstance", req)
         return res.instanceId
 
     def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
         req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
         self._logger.info(f"Purging instance '{instance_id}'.")
-        resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
+        resp: pb.PurgeInstancesResponse = self._invoke_unary("PurgeInstances", req)
         return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
 
     def purge_orchestrations_by(self,
@@ -406,7 +517,7 @@ def purge_orchestrations_by(self,
                           f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
                           f"recursive={recursive}")
         req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
-        resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
+        resp: pb.PurgeInstancesResponse = self._invoke_unary("PurgeInstances", req)
         return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
 
     def signal_entity(self,
@@ -419,7 +530,7 @@ def signal_entity(self,
             payload_helpers.externalize_payloads(
                 req, self._payload_store, instance_id=str(entity_instance_id),
             )
-        self._stub.SignalEntity(req, None)  # TODO: Cancellation timeout?
+        self._invoke_unary("SignalEntity", req)  # TODO: Cancellation timeout?
 
     def get_entity(self,
                    entity_instance_id: EntityInstanceId,
@@ -427,7 +538,7 @@ def get_entity(self,
                   ) -> Optional[EntityMetadata]:
         req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
         self._logger.info(f"Getting entity '{entity_instance_id}'.")
-        res: pb.GetEntityResponse = self._stub.GetEntity(req)
+        res: pb.GetEntityResponse = self._invoke_unary("GetEntity", req)
         if not res.exists:
             return None
         if self._payload_store is not None:
@@ -446,7 +557,7 @@ def get_all_entities(self,
 
         while True:
             query_request = build_query_entities_req(entity_query, _continuation_token)
-            resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
+            resp: pb.QueryEntitiesResponse = self._invoke_unary("QueryEntities", query_request)
             if self._payload_store is not None:
                 payload_helpers.deexternalize_payloads(resp, self._payload_store)
             entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
@@ -472,7 +583,7 @@ def clean_entity_storage(self,
                 releaseOrphanedLocks=release_orphaned_locks,
                 continuationToken=_continuation_token
             )
-            resp: pb.CleanEntityStorageResponse = self._stub.CleanEntityStorage(req)
+            resp: pb.CleanEntityStorageResponse = self._invoke_unary("CleanEntityStorage", req)
             empty_entities_removed += resp.emptyEntitiesRemoved
             orphaned_locks_released += resp.orphanedLocksReleased
 
@@ -496,20 +607,46 @@ def __init__(self, *,
                  secure_channel: bool = False,
                  interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None,
                  channel_options: Optional[GrpcChannelOptions] = None,
+                 resiliency_options: Optional[GrpcClientResiliencyOptions] = None,
                  default_version: Optional[str] = None,
                  payload_store: Optional[PayloadStore] = None):
         self._owns_channel = channel is None
+        self._host_address = (
+            host_address if host_address else shared.get_default_host_address()
+        )
+        self._secure_channel = secure_channel
+        self._channel_options = channel_options
+        self._resiliency_options = (
+            resiliency_options
+            if resiliency_options is not None
+            else GrpcClientResiliencyOptions()
+        )
+        resolved_interceptors = (
+            prepare_async_interceptors(metadata, interceptors) if channel is None else interceptors
+        )
+        self._interceptors = (
+            list(resolved_interceptors)
+            if resolved_interceptors is not None
+            else None
+        )
         if channel is None:
-            interceptors = prepare_async_interceptors(metadata, interceptors)
             channel = shared.get_async_grpc_channel(
-                host_address=host_address,
+                host_address=self._host_address,
                 secure_channel=secure_channel,
-                interceptors=interceptors,
+                interceptors=self._interceptors,
                 channel_options=channel_options,
             )
         self._channel = channel
         self._stub = stubs.TaskHubSidecarServiceStub(channel)
+        self._client_failure_tracker = FailureTracker(
+            self._resiliency_options.channel_recreate_failure_threshold
+        )
+        self._closing = False
+        self._recreate_lock = asyncio.Lock()
+        self._last_recreate_time = 0.0
+        self._retired_channels: list[grpc.aio.Channel] = []
+        self._retired_channel_close_tasks: set[asyncio.Task[None]] = set()
         self._logger = shared.get_logger("async_client", log_handler, log_formatter)
         self.default_version = default_version
         self._payload_store = payload_store
@@ -523,6 +660,18 @@ async def close(self) -> None:
         it.
""" if self._owns_channel: + self._closing = True + async with self._recreate_lock: + retired_channels = list(self._retired_channels) + self._retired_channels.clear() + close_tasks = list(self._retired_channel_close_tasks) + self._retired_channel_close_tasks.clear() + for close_task in close_tasks: + close_task.cancel() + if close_tasks: + await asyncio.gather(*close_tasks, return_exceptions=True) + for retired_channel in retired_channels: + await retired_channel.close() await self._channel.close() async def __aenter__(self): @@ -531,6 +680,63 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() + async def _invoke_unary( + self, + method_name: str, + request: Any, + *, + timeout: Optional[float] = None): + method = getattr(self._stub, method_name) + try: + if timeout is None: + response = await method(request) + else: + response = await method(request, timeout=timeout) + except grpc.aio.AioRpcError as rpc_error: + if is_client_transport_failure(method_name, rpc_error.code()): + should_recreate = self._client_failure_tracker.record_failure() + if should_recreate: + await self._maybe_recreate_channel() + else: + self._client_failure_tracker.record_success() + raise + else: + self._client_failure_tracker.record_success() + return response + + async def _maybe_recreate_channel(self) -> None: + if not self._owns_channel or self._closing: + return + async with self._recreate_lock: + if self._closing: + return + now = time.monotonic() + if now - self._last_recreate_time < self._resiliency_options.min_recreate_interval_seconds: + return + old_channel = self._channel + self._channel = shared.get_async_grpc_channel( + host_address=self._host_address, + secure_channel=self._secure_channel, + interceptors=self._interceptors, + channel_options=self._channel_options, + ) + self._stub = stubs.TaskHubSidecarServiceStub(self._channel) + self._last_recreate_time = now + self._client_failure_tracker.record_success() + self._retired_channels.append(old_channel) + close_task = asyncio.create_task(self._close_retired_channel(old_channel)) + self._retired_channel_close_tasks.add(close_task) + close_task.add_done_callback(self._retired_channel_close_tasks.discard) + + async def _close_retired_channel(self, channel: grpc.aio.Channel) -> None: + try: + await asyncio.sleep(30.0) + await channel.close() + finally: + async with self._recreate_lock: + if channel in self._retired_channels: + self._retired_channels.remove(channel) + async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, @@ -561,13 +767,13 @@ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator await payload_helpers.externalize_payloads_async( req, self._payload_store, instance_id=req.instanceId, ) - res: pb.CreateInstanceResponse = await self._stub.StartInstance(req) + res: pb.CreateInstanceResponse = await self._invoke_unary("StartInstance", req) return res.instanceId async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - res: pb.GetInstanceResponse = await self._stub.GetInstance(req) + res: pb.GetInstanceResponse = await self._invoke_unary("GetInstance", req) if self._payload_store is not None and res.exists: await payload_helpers.deexternalize_payloads_async(res, self._payload_store) return 
 
@@ -606,7 +812,7 @@ async def list_instance_ids(self,
             f"page_size={page_size}, "
             f"continuation_token={continuation_token}"
         )
-        resp: pb.ListInstanceIdsResponse = await self._stub.ListInstanceIds(req)
+        resp: pb.ListInstanceIdsResponse = await self._invoke_unary("ListInstanceIds", req)
         next_token = resp.lastInstanceKey.value if resp.HasField("lastInstanceKey") else None
         return Page(items=list(resp.instanceIds), continuation_token=next_token)
 
@@ -623,7 +829,7 @@ async def get_all_orchestration_states(self,
 
         while True:
             req = build_query_instances_req(orchestration_query, _continuation_token)
-            resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req)
+            resp: pb.QueryInstancesResponse = await self._invoke_unary("QueryInstances", req)
             if self._payload_store is not None:
                 await payload_helpers.deexternalize_payloads_async(resp, self._payload_store)
             states += [parse_orchestration_state(res) for res in resp.orchestrationState]
@@ -636,11 +842,15 @@ async def get_all_orchestration_states(self,
     async def wait_for_orchestration_start(self, instance_id: str, *,
                                            fetch_payloads: bool = False,
-                                           timeout: int = 60) -> Optional[OrchestrationState]:
+                                           timeout: float = 60) -> Optional[OrchestrationState]:
         req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
         try:
             self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
-            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout)
+            res: pb.GetInstanceResponse = await self._invoke_unary(
+                "WaitForInstanceStart",
+                req,
+                timeout=timeout,
+            )
             if self._payload_store is not None and res.exists:
                 await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
             return new_orchestration_state(req.instanceId, res)
@@ -652,11 +862,15 @@ async def wait_for_orchestration_start(self, instance_id: str, *,
     async def wait_for_orchestration_completion(self, instance_id: str, *,
                                                 fetch_payloads: bool = True,
-                                                timeout: int = 60) -> Optional[OrchestrationState]:
+                                                timeout: float = 60) -> Optional[OrchestrationState]:
         req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
         try:
             self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
-            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout)
+            res: pb.GetInstanceResponse = await self._invoke_unary(
+                "WaitForInstanceCompletion",
+                req,
+                timeout=timeout,
+            )
             if self._payload_store is not None and res.exists:
                 await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
             state = new_orchestration_state(req.instanceId, res)
@@ -677,7 +891,7 @@ async def raise_orchestration_event(self, instance_id: str, event_name: str, *,
             await payload_helpers.externalize_payloads_async(
                 req, self._payload_store, instance_id=instance_id,
             )
-        await self._stub.RaiseEvent(req)
+        await self._invoke_unary("RaiseEvent", req)
 
     async def terminate_orchestration(self, instance_id: str, *,
                                       output: Optional[Any] = None,
@@ -689,17 +903,17 @@ async def terminate_orchestration(self, instance_id: str, *,
             await payload_helpers.externalize_payloads_async(
                 req, self._payload_store, instance_id=instance_id,
             )
-        await self._stub.TerminateInstance(req)
+        await self._invoke_unary("TerminateInstance", req)
 
     async def suspend_orchestration(self, instance_id: str) -> None:
         req = pb.SuspendRequest(instanceId=instance_id)
         self._logger.info(f"Suspending instance '{instance_id}'.")
-        await self._stub.SuspendInstance(req)
+        await self._invoke_unary("SuspendInstance", req)
 
     async def resume_orchestration(self, instance_id: str) -> None:
         req = pb.ResumeRequest(instanceId=instance_id)
         self._logger.info(f"Resuming instance '{instance_id}'.")
-        await self._stub.ResumeInstance(req)
+        await self._invoke_unary("ResumeInstance", req)
 
     async def restart_orchestration(self, instance_id: str, *,
                                     restart_with_new_instance_id: bool = False) -> str:
@@ -718,13 +932,13 @@ async def restart_orchestration(self, instance_id: str, *,
                                        restartWithNewInstanceId=restart_with_new_instance_id)
 
         self._logger.info(f"Restarting instance '{instance_id}'.")
-        res: pb.RestartInstanceResponse = await self._stub.RestartInstance(req)
+        res: pb.RestartInstanceResponse = await self._invoke_unary("RestartInstance", req)
         return res.instanceId
 
     async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
         req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
         self._logger.info(f"Purging instance '{instance_id}'.")
-        resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
+        resp: pb.PurgeInstancesResponse = await self._invoke_unary("PurgeInstances", req)
         return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
 
     async def purge_orchestrations_by(self,
@@ -738,7 +952,7 @@ async def purge_orchestrations_by(self,
                           f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
                           f"recursive={recursive}")
         req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
-        resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
+        resp: pb.PurgeInstancesResponse = await self._invoke_unary("PurgeInstances", req)
         return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
 
     async def signal_entity(self,
@@ -751,7 +965,7 @@ async def signal_entity(self,
             await payload_helpers.externalize_payloads_async(
                 req, self._payload_store, instance_id=str(entity_instance_id),
             )
-        await self._stub.SignalEntity(req, None)
+        await self._invoke_unary("SignalEntity", req)
 
     async def get_entity(self,
                          entity_instance_id: EntityInstanceId,
@@ -759,7 +973,7 @@ async def get_entity(self,
                        ) -> Optional[EntityMetadata]:
         req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
         self._logger.info(f"Getting entity '{entity_instance_id}'.")
-        res: pb.GetEntityResponse = await self._stub.GetEntity(req)
+        res: pb.GetEntityResponse = await self._invoke_unary("GetEntity", req)
         if not res.exists:
             return None
         if self._payload_store is not None:
@@ -778,7 +992,7 @@ async def get_all_entities(self,
 
         while True:
             query_request = build_query_entities_req(entity_query, _continuation_token)
-            resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request)
+            resp: pb.QueryEntitiesResponse = await self._invoke_unary("QueryEntities", query_request)
             if self._payload_store is not None:
                 await payload_helpers.deexternalize_payloads_async(resp, self._payload_store)
             entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
@@ -804,7 +1018,7 @@ async def clean_entity_storage(self,
                 releaseOrphanedLocks=release_orphaned_locks,
                 continuationToken=_continuation_token
             )
-            resp: pb.CleanEntityStorageResponse = await self._stub.CleanEntityStorage(req)
+            resp: pb.CleanEntityStorageResponse = await self._invoke_unary("CleanEntityStorage", req)
             empty_entities_removed += resp.emptyEntitiesRemoved
             orphaned_locks_released += resp.orphanedLocksReleased
diff --git a/durabletask/grpc_options.py b/durabletask/grpc_options.py
index 56f2236..f648104 100644
--- a/durabletask/grpc_options.py
+++ b/durabletask/grpc_options.py
@@ -100,3 +100,44 @@ def to_grpc_options(self) -> list[tuple[str, Any]]:
             options.append(("grpc.service_config", json.dumps(self.retry_policy.to_service_config())))
 
         return options
+
+
+@dataclass
+class GrpcWorkerResiliencyOptions:
+    """Configuration for worker-side gRPC resiliency behavior."""
+
+    hello_timeout_seconds: float = 30.0
+    silent_disconnect_timeout_seconds: float = 120.0
+    channel_recreate_failure_threshold: int = 5
+    reconnect_backoff_base_seconds: float = 1.0
+    reconnect_backoff_cap_seconds: float = 30.0
+
+    def __post_init__(self) -> None:
+        if self.hello_timeout_seconds <= 0:
+            raise ValueError("hello_timeout_seconds must be > 0")
+        if self.silent_disconnect_timeout_seconds < 0:
+            raise ValueError("silent_disconnect_timeout_seconds must be >= 0")
+        if self.channel_recreate_failure_threshold < 0:
+            raise ValueError("channel_recreate_failure_threshold must be >= 0")
+        if self.reconnect_backoff_base_seconds <= 0:
+            raise ValueError("reconnect_backoff_base_seconds must be > 0")
+        if self.reconnect_backoff_cap_seconds <= 0:
+            raise ValueError("reconnect_backoff_cap_seconds must be > 0")
+        if self.reconnect_backoff_cap_seconds < self.reconnect_backoff_base_seconds:
+            raise ValueError(
+                "reconnect_backoff_cap_seconds must be >= reconnect_backoff_base_seconds"
+            )
+
+
+@dataclass
+class GrpcClientResiliencyOptions:
+    """Configuration for client-side gRPC resiliency behavior."""
+
+    channel_recreate_failure_threshold: int = 5
+    min_recreate_interval_seconds: float = 30.0
+
+    def __post_init__(self) -> None:
+        if self.channel_recreate_failure_threshold < 0:
+            raise ValueError("channel_recreate_failure_threshold must be >= 0")
+        if self.min_recreate_interval_seconds < 0:
+            raise ValueError("min_recreate_interval_seconds must be >= 0")
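Both dataclasses validate their fields in `__post_init__`, so misconfiguration fails at construction time rather than mid-reconnect. A small sketch of that behavior, with values chosen purely for illustration (note that a `silent_disconnect_timeout_seconds` of 0 disables the silent-disconnect timeout in the worker loop further below):

    from durabletask.grpc_options import GrpcWorkerResiliencyOptions

    options = GrpcWorkerResiliencyOptions(
        hello_timeout_seconds=10.0,              # deadline for the initial Hello call
        silent_disconnect_timeout_seconds=60.0,  # how long a quiet stream is tolerated
        reconnect_backoff_base_seconds=0.5,
        reconnect_backoff_cap_seconds=15.0,
    )

    try:
        GrpcWorkerResiliencyOptions(
            reconnect_backoff_base_seconds=10.0,
            reconnect_backoff_cap_seconds=1.0,  # cap below base is rejected
        )
    except ValueError as err:
        print(err)  # "reconnect_backoff_cap_seconds must be >= reconnect_backoff_base_seconds"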
diff --git a/durabletask/internal/grpc_resiliency.py b/durabletask/internal/grpc_resiliency.py
new file mode 100644
index 0000000..0a8cdd6
--- /dev/null
+++ b/durabletask/internal/grpc_resiliency.py
@@ -0,0 +1,50 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import random
+from dataclasses import dataclass
+
+import grpc
+
+LONG_POLL_METHODS = {"WaitForInstanceStart", "WaitForInstanceCompletion"}
+
+
+def get_full_jitter_delay_seconds(
+    attempt: int,
+    *,
+    base_seconds: float,
+    cap_seconds: float,
+) -> float:
+    capped_attempt = min(attempt, 30)
+    upper_bound = min(cap_seconds, base_seconds * (2 ** capped_attempt))
+    return random.random() * upper_bound
+
+
+@dataclass
+class FailureTracker:
+    threshold: int
+    consecutive_failures: int = 0
+
+    def record_failure(self) -> bool:
+        if self.threshold <= 0:
+            return False
+        self.consecutive_failures += 1
+        return self.consecutive_failures >= self.threshold
+
+    def record_success(self) -> None:
+        self.consecutive_failures = 0
+
+
+def is_client_transport_failure(method_name: str, status_code: grpc.StatusCode) -> bool:
+    if status_code == grpc.StatusCode.UNAVAILABLE:
+        return True
+    if status_code == grpc.StatusCode.DEADLINE_EXCEEDED:
+        return method_name not in LONG_POLL_METHODS
+    return False
+
+
+def is_worker_transport_failure(status_code: grpc.StatusCode) -> bool:
+    return status_code in {
+        grpc.StatusCode.UNAVAILABLE,
+        grpc.StatusCode.DEADLINE_EXCEEDED,
+    }
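`get_full_jitter_delay_seconds` implements "full jitter" backoff: for attempt n the delay is drawn uniformly from [0, min(cap, base * 2**n)), with the exponent clamped at 30 to avoid overflow. A quick check of the bounds and of the failure-counting contract, using the module exactly as defined above:

    from durabletask.internal.grpc_resiliency import (
        FailureTracker,
        get_full_jitter_delay_seconds,
    )

    # With base=1s and cap=30s, attempt 3 draws from [0, 8s); attempt 5 and
    # beyond are clipped by the cap to [0, 30s).
    for attempt in (0, 3, 5, 10):
        delay = get_full_jitter_delay_seconds(attempt, base_seconds=1.0, cap_seconds=30.0)
        assert 0.0 <= delay < 30.0

    # FailureTracker trips only on the Nth consecutive failure; any success
    # resets the count, and a threshold of 0 or less never trips.
    tracker = FailureTracker(threshold=2)
    assert tracker.record_failure() is False
    assert tracker.record_failure() is True
    tracker.record_success()
    assert tracker.record_failure() is False

Note also the asymmetry between the two classifiers: DEADLINE_EXCEEDED counts as a client transport failure only for non-long-poll methods, since the long-poll waits legitimately hit their deadlines.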
diff --git a/durabletask/worker.py b/durabletask/worker.py
index 670a387..172a04d 100644
--- a/durabletask/worker.py
+++ b/durabletask/worker.py
@@ -6,12 +6,11 @@
 import json
 import logging
 import os
-import random
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta, timezone
-from threading import Event, Thread
+from threading import Event, Lock, Thread
 from types import GeneratorType
 from enum import Enum
 from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union, overload
@@ -21,7 +20,10 @@
 import grpc
 from google.protobuf import empty_pb2
 
-from durabletask.grpc_options import GrpcChannelOptions
+from durabletask.grpc_options import (
+    GrpcChannelOptions,
+    GrpcWorkerResiliencyOptions,
+)
 from durabletask.entities.entity_operation_failed_exception import EntityOperationFailedException
 from durabletask.internal import helpers
 from durabletask.internal.entity_state_shim import StateShim
@@ -36,6 +38,11 @@
 import durabletask.internal.orchestrator_service_pb2_grpc as stubs
 import durabletask.internal.shared as shared
 import durabletask.internal.tracing as tracing
+from durabletask.internal.grpc_resiliency import (
+    FailureTracker,
+    get_full_jitter_delay_seconds,
+    is_worker_transport_failure,
+)
 from durabletask.payload import helpers as payload_helpers
 from durabletask import task
 from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
@@ -45,6 +52,7 @@
 TOutput = TypeVar("TOutput")
 DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'
 DEFAULT_MAXIMUM_TIMER_INTERVAL = timedelta(days=3)
+_STREAM_CLOSED_SENTINEL = object()
 
 
 class ConcurrencyOptions:
@@ -115,6 +123,80 @@ class VersionFailureStrategy(Enum):
     FAIL = 2
 
 
+class _WorkItemStreamOutcome(Enum):
+    SHUTDOWN = "shutdown"
+    GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE = "graceful_close_before_first_message"
+    GRACEFUL_CLOSE_AFTER_MESSAGE = "graceful_close_after_message"
+    SILENT_DISCONNECT = "silent_disconnect"
+
+
+@dataclass
+class _TrackedChannelState:
+    channel: Any
+    ref_count: int = 0
+    close_when_released: bool = False
+
+
+class _InFlightChannelTracker:
+    def __init__(self):
+        self._lock = Lock()
+        self._states: dict[int, _TrackedChannelState] = {}
+
+    def acquire(self, channel: Any):
+        channel_key = id(channel)
+        with self._lock:
+            state = self._states.get(channel_key)
+            if state is None:
+                state = _TrackedChannelState(channel=channel)
+                self._states[channel_key] = state
+            state.ref_count += 1
+
+        released = False
+
+        def release() -> None:
+            nonlocal released
+            if released:
+                return
+            released = True
+
+            channel_to_close = None
+            with self._lock:
+                state = self._states.get(channel_key)
+                if state is None:
+                    return
+
+                state.ref_count -= 1
+                if state.ref_count == 0:
+                    if state.close_when_released:
+                        channel_to_close = state.channel
+                    del self._states[channel_key]
+
+            if channel_to_close is not None:
+                self._close_channel(channel_to_close)
+
+        return release
+
+    def retire(self, channel: Any) -> None:
+        channel_key = id(channel)
+        channel_to_close = None
+        with self._lock:
+            state = self._states.get(channel_key)
+            if state is None:
+                channel_to_close = channel
+            else:
+                state.close_when_released = True
+
+        if channel_to_close is not None:
+            self._close_channel(channel_to_close)
+
+    @staticmethod
+    def _close_channel(channel: Any) -> None:
+        try:
+            channel.close()
+        except Exception:
+            logging.debug("Ignoring channel close failure during worker cleanup.", exc_info=True)
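`_InFlightChannelTracker` is internal to `durabletask.worker`, so the following is an illustration of its contract rather than public API: retiring a channel that still has in-flight work defers the close until the last work item releases it, while retiring an idle channel closes it immediately. A stand-in object with a `close()` method plays the role of a gRPC channel here:

    class FakeChannel:
        def __init__(self) -> None:
            self.closed = False

        def close(self) -> None:
            self.closed = True

    tracker = _InFlightChannelTracker()
    channel = FakeChannel()

    release = tracker.acquire(channel)   # one work item now holds the channel
    tracker.retire(channel)              # retire while work is still in flight
    assert channel.closed is False       # close is deferred...
    release()                            # ...until the last holder releases
    assert channel.closed is True

    # Retiring a channel nobody holds closes it immediately.
    idle = FakeChannel()
    tracker.retire(idle)
    assert idle.closed is True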
+
+
 class VersioningOptions:
     """Configuration options for orchestrator and activity versioning.
 
@@ -369,6 +451,8 @@ class TaskHubGrpcWorker:
             interceptors to apply to the channel. Defaults to None.
         channel_options (Optional[GrpcChannelOptions], optional): Extra low-level
             gRPC channel configuration including retry/service config options.
+        resiliency_options (Optional[GrpcWorkerResiliencyOptions], optional): Worker-side
+            gRPC resiliency settings retained for reconnect handling.
         concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for
             controlling worker concurrency limits. If None, default settings are used.
@@ -436,6 +520,7 @@ def __init__(
             self,
             *,
             secure_channel: bool = False,
             interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
             channel_options: Optional[GrpcChannelOptions] = None,
+            resiliency_options: Optional[GrpcWorkerResiliencyOptions] = None,
             concurrency_options: Optional[ConcurrencyOptions] = None,
             maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL,
             payload_store: Optional[PayloadStore] = None,
@@ -451,6 +536,11 @@ def __init__(
         self._secure_channel = secure_channel
         self._payload_store = payload_store
         self._channel_options = channel_options
+        self._resiliency_options = (
+            resiliency_options
+            if resiliency_options is not None
+            else GrpcWorkerResiliencyOptions()
+        )
 
         # Use provided concurrency options or create default ones
         self._concurrency_options = (
@@ -490,6 +580,27 @@ def __enter__(self):
 
     def __exit__(self, type, value, traceback):
         self.stop()
 
+    def _classify_stream_outcome(
+            self,
+            *,
+            saw_message: bool,
+            timed_out: bool,
+    ) -> _WorkItemStreamOutcome:
+        if timed_out:
+            return _WorkItemStreamOutcome.SILENT_DISCONNECT
+        if saw_message:
+            return _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE
+        return _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE
+
+    def _should_count_worker_failure(
+            self,
+            status_code: grpc.StatusCode,
+    ) -> bool:
+        return is_worker_transport_failure(status_code)
+
+    def _can_recreate_channel(self) -> bool:
+        return self._channel is None
+
     def add_orchestrator(self, fn: task.Orchestrator[TInput, TOutput]) -> str:
         """Registers an orchestrator function with the worker."""
         if self._is_running:
@@ -579,6 +690,8 @@ def start(self):
         if self._auto_generate_work_item_filters:
             self._work_item_filters = WorkItemFilters._from_registry(self._registry)
 
+        self._shutdown.clear()
+
         def run_loop():
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
@@ -590,27 +703,29 @@ def run_loop():
         self._is_running = True
 
     async def _async_run_loop(self):
+        self._async_worker_manager.prepare_for_run()
         worker_task = asyncio.create_task(self._async_worker_manager.run())
 
-        # Connection state management for retry fix
-        current_channel = None
+        current_channel = self._channel
         current_stub = None
         current_reader_thread = None
         conn_retry_count = 0
-        conn_max_retry_delay = 60
+        failure_tracker = FailureTracker(
+            threshold=self._resiliency_options.channel_recreate_failure_threshold,
+        )
+        in_flight_channel_tracker = _InFlightChannelTracker()
+
+        def get_reconnect_delay_seconds() -> float:
+            return get_full_jitter_delay_seconds(
+                conn_retry_count,
+                base_seconds=self._resiliency_options.reconnect_backoff_base_seconds,
+                cap_seconds=self._resiliency_options.reconnect_backoff_cap_seconds,
+            )
 
         def create_fresh_connection():
             nonlocal current_channel, current_stub, conn_retry_count
-            if current_channel and self._channel is None:
-                try:
-                    current_channel.close()
-                except Exception:
-                    pass
-            current_channel = None
             current_stub = None
             try:
-                if self._channel is not None:
-                    current_channel = self._channel
-                else:
+                if current_channel is None:
                     current_channel = shared.get_grpc_channel(
                         self._host_address,
                         self._secure_channel,
@@ -618,16 +733,60 @@ def create_fresh_connection():
                         channel_options=self._channel_options,
                     )
                 current_stub = stubs.TaskHubSidecarServiceStub(current_channel)
-                current_stub.Hello(empty_pb2.Empty())
+                hello_timeout = self._resiliency_options.hello_timeout_seconds
+                current_stub.Hello(empty_pb2.Empty(), timeout=hello_timeout)
                 conn_retry_count = 0
                 self._logger.info(f"Created fresh connection to {self._host_address}")
             except Exception as e:
                 self._logger.warning(f"Failed to create connection: {e}")
-                current_channel = self._channel if self._channel is not None else None
                 current_stub = None
                 raise
 
-        def invalidate_connection():
+        def wrap_execution(handler, release):
+            def wrapped(*args, **kwargs):
+                try:
+                    return handler(*args, **kwargs)
+                finally:
+                    release()
+
+            return wrapped
+
+        def wrap_cancellation(handler, release):
+            def wrapped(*args, **kwargs):
+                try:
+                    return handler(*args, **kwargs)
+                finally:
+                    release()
+
+            return wrapped
+
+        def submit_work_item(
+            submit_func,
+            handler,
+            cancellation_handler,
+            request,
+            stub,
+            completion_token,
+            channel,
+        ):
+            release = in_flight_channel_tracker.acquire(channel)
+            try:
+                submit_func(
+                    wrap_execution(handler, release),
+                    wrap_cancellation(cancellation_handler, release),
+                    request,
+                    stub,
+                    completion_token,
+                )
+            except Exception:
+                release()
+                raise
+
+        def invalidate_connection(
+            *,
+            recreate_channel: bool = False,
+            close_channel: bool = False,
+        ):
             nonlocal current_channel, current_stub, current_reader_thread
             # Cancel the response stream first to signal the reader thread to stop
             if self._response_stream is not None:
@@ -647,13 +806,13 @@ def invalidate_connection():
                     pass
                 current_reader_thread = None
 
-            # Close the channel
-            if current_channel and self._channel is None:
-                try:
-                    current_channel.close()
-                except Exception:
-                    pass
-            current_channel = self._channel if self._channel is not None else None
+            if (
+                current_channel is not None
+                and self._can_recreate_channel()
+                and (recreate_channel or close_channel)
+            ):
+                in_flight_channel_tracker.retire(current_channel)
+                current_channel = None
             current_stub = None
 
         def should_invalidate_connection(rpc_error):
@@ -671,12 +830,18 @@ def should_invalidate_connection(rpc_error):
             if current_stub is None:
                 try:
                     create_fresh_connection()
-                except Exception:
+                except Exception as ex:
+                    recreate_channel = False
+                    if isinstance(ex, grpc.RpcError):
+                        error_code = ex.code()  # type: ignore
+                        if self._should_count_worker_failure(error_code):
+                            recreate_channel = (
+                                failure_tracker.record_failure()
+                                and self._can_recreate_channel()
+                            )
+                    invalidate_connection(recreate_channel=recreate_channel)
                     conn_retry_count += 1
-                    delay = min(
-                        conn_max_retry_delay,
-                        (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1),
-                    )
+                    delay = get_reconnect_delay_seconds()
                     self._logger.warning(
                         f"Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})"
                     )
@@ -685,7 +850,9 @@ def should_invalidate_connection(rpc_error):
                 continue
 
             try:
                 assert current_stub is not None
+                assert current_channel is not None
                 stub = current_stub
+                channel = current_channel
                 capabilities = []
                 if self._payload_store is not None:
                     capabilities.append(pb.WORKER_CAPABILITY_LARGE_PAYLOADS)
@@ -707,15 +874,18 @@ def should_invalidate_connection(rpc_error):
 
                 import queue
                 work_item_queue = queue.Queue()
+                saw_message = False
 
                 def stream_reader():
                     try:
                         response_stream = self._response_stream
                         if response_stream is None:
+                            work_item_queue.put(_STREAM_CLOSED_SENTINEL)
                             return
                         for work_item in response_stream:
                             work_item_queue.put(work_item)
+                        work_item_queue.put(_STREAM_CLOSED_SENTINEL)
                     except Exception as e:
                         work_item_queue.put(e)
 
@@ -724,63 +894,126 @@ def stream_reader():
                 current_reader_thread = threading.Thread(target=stream_reader, daemon=True)
                 current_reader_thread.start()
                 loop = asyncio.get_running_loop()
+                queue_timeout = (
+                    self._resiliency_options.silent_disconnect_timeout_seconds or None
+                )
+                stream_outcome = None
                 while not self._shutdown.is_set():
                     try:
-                        work_item = await loop.run_in_executor(
-                            None, work_item_queue.get
+                        work_item = await asyncio.wait_for(
+                            loop.run_in_executor(None, work_item_queue.get),
+                            timeout=queue_timeout,
+                        )
+                    except asyncio.TimeoutError:
+                        work_item = None
+                        stream_outcome = self._classify_stream_outcome(
+                            saw_message=saw_message,
+                            timed_out=True,
+                        )
+                        break
+
+                    if work_item is _STREAM_CLOSED_SENTINEL:
+                        stream_outcome = self._classify_stream_outcome(
+                            saw_message=saw_message,
+                            timed_out=False,
                         )
+                        break
+
+                    try:
                         if isinstance(work_item, Exception):
                             raise work_item
+
+                        saw_message = True
                         request_type = work_item.WhichOneof("request")
                         self._logger.debug(f'Received "{request_type}" work item')
+                        if work_item.HasField("healthPing"):
+                            failure_tracker.record_success()
+                            continue
+
+                        failure_tracker.record_success()
                         if work_item.HasField("orchestratorRequest"):
-                            self._async_worker_manager.submit_orchestration(
+                            submit_work_item(
+                                self._async_worker_manager.submit_orchestration,
                                 self._execute_orchestrator,
                                 self._cancel_orchestrator,
                                 work_item.orchestratorRequest,
                                 stub,
                                 work_item.completionToken,
+                                channel,
                             )
                         elif work_item.HasField("activityRequest"):
-                            self._async_worker_manager.submit_activity(
+                            submit_work_item(
+                                self._async_worker_manager.submit_activity,
                                 self._execute_activity,
                                 self._cancel_activity,
                                 work_item.activityRequest,
                                 stub,
                                 work_item.completionToken,
+                                channel,
                             )
                         elif work_item.HasField("entityRequest"):
-                            self._async_worker_manager.submit_entity_batch(
+                            submit_work_item(
+                                self._async_worker_manager.submit_entity_batch,
                                 self._execute_entity_batch,
                                 self._cancel_entity_batch,
                                 work_item.entityRequest,
                                 stub,
                                 work_item.completionToken,
+                                channel,
                            )
                         elif work_item.HasField("entityRequestV2"):
-                            self._async_worker_manager.submit_entity_batch(
+                            submit_work_item(
+                                self._async_worker_manager.submit_entity_batch,
                                 self._execute_entity_batch,
                                 self._cancel_entity_batch,
                                 work_item.entityRequestV2,
                                 stub,
-                                work_item.completionToken
+                                work_item.completionToken,
+                                channel,
                             )
-                        elif work_item.HasField("healthPing"):
-                            pass
                         else:
                             self._logger.warning(
                                 f"Unexpected work item type: {request_type}"
                             )
                     except Exception as e:
                         self._logger.warning(f"Error in work item stream: {e}")
-                        raise e
-                current_reader_thread.join(timeout=1)
-                self._logger.info("Work item stream ended normally")
+                        raise
+
+                if stream_outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE:
+                    self._logger.info(
+                        "Work item stream closed before receiving the first message"
+                    )
+                    invalidate_connection(close_channel=True)
+                    continue
+                if stream_outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE:
+                    self._logger.info("Work item stream closed after receiving messages")
+                    invalidate_connection(close_channel=True)
+                    continue
+                if stream_outcome is _WorkItemStreamOutcome.SILENT_DISCONNECT:
+                    self._logger.warning(
+                        f"Timed out waiting for work items from {self._host_address}"
+                    )
+                    recreate_channel = (
+                        failure_tracker.record_failure()
+                        and self._can_recreate_channel()
+                    )
+                    invalidate_connection(recreate_channel=recreate_channel)
+                    conn_retry_count += 1
+                    delay = get_reconnect_delay_seconds()
+                    if self._shutdown.wait(delay):
+                        break
+                    continue
 
             except grpc.RpcError as rpc_error:
                 should_invalidate = should_invalidate_connection(rpc_error)
-                if should_invalidate:
-                    invalidate_connection()
                 error_code = rpc_error.code()  # type: ignore
+                recreate_channel = False
+                if should_invalidate and self._should_count_worker_failure(error_code):
+                    recreate_channel = (
+                        failure_tracker.record_failure()
+                        and self._can_recreate_channel()
+                    )
+                if should_invalidate:
+                    invalidate_connection(recreate_channel=recreate_channel)
                 error_details = str(rpc_error)
 
                 if error_code == grpc.StatusCode.CANCELLED:
@@ -804,12 +1037,18 @@ def stream_reader():
                     self._logger.warning(
                         f"Application-level gRPC error ({error_code}): {rpc_error}"
                     )
-                    self._shutdown.wait(1)
+                    conn_retry_count += 1
+                    delay = get_reconnect_delay_seconds()
+                    if self._shutdown.wait(delay):
+                        break
             except Exception as ex:
-                invalidate_connection()
+                invalidate_connection(close_channel=True)
                 self._logger.warning(f"Unexpected error: {ex}")
-                self._shutdown.wait(1)
+                conn_retry_count += 1
+                delay = get_reconnect_delay_seconds()
+                if self._shutdown.wait(delay):
+                    break
 
-        invalidate_connection()
+        invalidate_connection(close_channel=True)
         self._logger.info("No longer listening for work items")
         self._async_worker_manager.shutdown()
         await worker_task
@@ -825,6 +1064,11 @@ def stop(self):
             self._response_stream.cancel()
         if self._runLoop is not None:
             self._runLoop.join(timeout=30)
+            if self._runLoop.is_alive():
+                self._logger.info(
+                    "Waiting for pending work items to finish before completing shutdown..."
+                )
+                self._runLoop.join()
         self._async_worker_manager.shutdown()
         self._logger.info("Worker shutdown completed")
         self._is_running = False
@@ -2648,11 +2892,22 @@ def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger):
         self._pending_activity_work: list = []
         self._pending_orchestration_work: list = []
         self._pending_entity_batch_work: list = []
-        self.thread_pool = ThreadPoolExecutor(
-            max_workers=concurrency_options.maximum_thread_pool_workers,
+        self.thread_pool = self._create_thread_pool()
+        self._shutdown = False
+
+    def _create_thread_pool(self) -> ThreadPoolExecutor:
+        return ThreadPoolExecutor(
+            max_workers=self.concurrency_options.maximum_thread_pool_workers,
             thread_name_prefix="DurableTask",
         )
+
+    def _ensure_thread_pool(self) -> None:
+        if getattr(self.thread_pool, "_shutdown", False):
+            self.thread_pool = self._create_thread_pool()
+
+    def prepare_for_run(self) -> None:
         self._shutdown = False
+        self._ensure_thread_pool()
 
     def _ensure_queues_for_current_loop(self):
         """Ensure queues are bound to the current event loop."""
@@ -2727,8 +2982,7 @@ def _ensure_queues_for_current_loop(self):
             self._pending_entity_batch_work.clear()
 
     async def run(self):
-        # Reset shutdown flag in case this manager is being reused
-        self._shutdown = False
+        self._ensure_thread_pool()
 
         # Ensure queues are properly bound to the current event loop
         self._ensure_queues_for_current_loop()
@@ -2790,6 +3044,9 @@ async def run(self):
                 except Exception as cancellation_exception:
                     self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}")
             self.shutdown()
+        finally:
+            if not getattr(self.thread_pool, "_shutdown", False):
+                self.thread_pool.shutdown(wait=True)
 
     async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
         # List to track running tasks
@@ -2833,12 +3090,7 @@ async def _run_func(self, func, *args, **kwargs):
             return await func(*args, **kwargs)
         else:
             loop = asyncio.get_running_loop()
-            # Avoid submitting to executor after shutdown
-            if (
-                getattr(self, "_shutdown", False) and getattr(self, "thread_pool", None) and getattr(
-                    self.thread_pool, "_shutdown", False)
-            ):
-                return None
+            self._ensure_thread_pool()
             return await loop.run_in_executor(
                 self.thread_pool, lambda: func(*args, **kwargs)
             )
@@ -2878,11 +3130,10 @@ def submit_entity_batch(self, func, cancellation_func, *args, **kwargs):
     def shutdown(self):
         self._shutdown = True
-        self.thread_pool.shutdown(wait=True)
 
     async def reset_for_new_run(self):
         """Reset the manager state for a new run."""
-        self._shutdown = False
+        self.prepare_for_run()
         # Clear any existing queues - they'll be recreated when needed
         if self.activity_queue is not None:
             # Clear existing queue by creating a new one
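`prepare_for_run()` and `_ensure_thread_pool()` exist so a stopped worker can be started again on the same instance: `start()` clears the shutdown event and the manager recreates its thread pool if the previous run shut it down. A hedged lifecycle sketch (the address is a placeholder, and a real worker would register orchestrators and activities before starting):

    from durabletask.worker import TaskHubGrpcWorker

    worker = TaskHubGrpcWorker(host_address="localhost:4001")  # placeholder address
    # ... register orchestrators/activities here ...

    worker.start()
    # ... process work items ...
    worker.stop()   # drains in-flight completions before finishing shutdown

    # Restarting reuses the same instance with a fresh thread pool.
    worker.start()
    worker.stop()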
[DefaultClientInterceptorImpl(METADATA)] +class FakeRpcError(grpc.RpcError): + def __init__(self, status_code: grpc.StatusCode): + super().__init__() + self._status_code = status_code + + def code(self): + return self._status_code + + +def make_aio_rpc_error(status_code: grpc.StatusCode) -> grpc.aio.AioRpcError: + return grpc.aio.AioRpcError( + status_code, + grpc.aio.Metadata(), + grpc.aio.Metadata(), + ) + + class FakePayloadStore(PayloadStore): TOKEN_PREFIX = 'fake://' @@ -290,6 +315,506 @@ def test_async_client_uses_provided_channel_directly(): mock_get_channel.assert_not_called() +def test_client_stores_resiliency_options_for_recreation(): + resiliency = GrpcClientResiliencyOptions(channel_recreate_failure_threshold=7) + channel_options = GrpcChannelOptions(max_receive_message_length=1234) + interceptors = [DefaultClientInterceptorImpl(METADATA)] + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=MagicMock() + ): + client = TaskHubGrpcClient( + host_address="localhost:4001", + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, + resiliency_options=resiliency, + ) + assert client._resiliency_options is resiliency + assert client._host_address == "localhost:4001" + assert client._secure_channel is True + assert client._channel_options is channel_options + assert client._interceptors == interceptors + + +def test_sync_client_recreates_sdk_owned_channel_with_original_transport_inputs(): + first_channel = MagicMock(name="first-channel") + second_channel = MagicMock(name="second-channel") + first_stub = MagicMock() + second_stub = MagicMock() + second_stub.GetInstance.return_value = MagicMock(exists=False) + host_address = "localhost:4001" + interceptors = [DefaultClientInterceptorImpl(METADATA)] + channel_options = GrpcChannelOptions(max_receive_message_length=1234) + + rpc_error = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + first_stub.GetInstance.side_effect = rpc_error + + timer = MagicMock() + + with patch( + "durabletask.client.shared.get_grpc_channel", + side_effect=[first_channel, second_channel], + ) as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ), patch("threading.Timer", return_value=timer) as mock_timer: + client = TaskHubGrpcClient( + host_address=host_address, + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ), + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + client.get_orchestration_state("abc") + + expected_channel_call = call( + host_address=host_address, + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, + ) + assert mock_get_channel.call_args_list == [ + expected_channel_call, + expected_channel_call, + ] + assert client._channel is second_channel + mock_timer.assert_called_once() + timer_call = mock_timer.call_args + assert timer_call.args[0] == 30.0 + assert timer_call.args[1].__self__ is client + assert timer_call.args[1].__func__ is TaskHubGrpcClient._close_retired_channel + assert timer_call.kwargs == {"args": (first_channel,)} + assert timer.daemon is True + timer.start.assert_called_once_with() + + +def test_sync_client_close_closes_retired_channels_immediately(): + first_channel = MagicMock(name="first-channel") + 
second_channel = MagicMock(name="second-channel") + first_stub = MagicMock() + first_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + second_stub = MagicMock() + second_stub.GetInstance.return_value = MagicMock(exists=False) + close_timer = MagicMock(name="close-timer") + + with patch( + "durabletask.client.shared.get_grpc_channel", + side_effect=[first_channel, second_channel], + ), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ), patch("threading.Timer", return_value=close_timer): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + + client.close() + + close_timer.cancel.assert_called_once_with() + first_channel.close.assert_called_once_with() + second_channel.close.assert_called_once_with() + assert client._retired_channels == {} + + +def test_sync_client_close_closes_all_retired_sdk_channels_immediately(): + first_channel = MagicMock(name="first-channel") + second_channel = MagicMock(name="second-channel") + third_channel = MagicMock(name="third-channel") + first_stub = MagicMock() + first_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + second_stub = MagicMock() + second_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + third_stub = MagicMock() + timer1 = MagicMock(name="close-timer-1") + timer2 = MagicMock(name="close-timer-2") + + with patch( + "durabletask.client.shared.get_grpc_channel", + side_effect=[first_channel, second_channel, third_channel], + ), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", + side_effect=[first_stub, second_stub, third_stub], + ), patch("threading.Timer", side_effect=[timer1, timer2]): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + + client.close() + + timer1.cancel.assert_called_once_with() + timer2.cancel.assert_called_once_with() + first_channel.close.assert_called_once_with() + second_channel.close.assert_called_once_with() + third_channel.close.assert_called_once_with() + assert client._retired_channels == {} + + +@pytest.mark.parametrize( + ("stub_method_name", "client_method_name"), + [ + ("WaitForInstanceStart", "wait_for_orchestration_start"), + ("WaitForInstanceCompletion", "wait_for_orchestration_completion"), + ], +) +def test_sync_client_resets_failure_tracking_after_long_poll_deadline( + stub_method_name: str, + client_method_name: str, +): + stub = MagicMock() + stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + getattr(stub, stub_method_name).side_effect = FakeRpcError(grpc.StatusCode.DEADLINE_EXCEEDED) + + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + with pytest.raises(TimeoutError): + getattr(client, client_method_name)("abc") + assert client._client_failure_tracker.consecutive_failures == 0 + + 
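For orientation, here is a minimal usage sketch of the client-side knobs these tests exercise. It relies only on the public surface shown above (TaskHubGrpcClient, GrpcClientResiliencyOptions, get_orchestration_state, close); the host address and instance id are placeholders, and the threshold and cooldown values mirror the defaults asserted in test_grpc_resiliency.py further down.

    from durabletask.client import TaskHubGrpcClient
    from durabletask.grpc_options import GrpcClientResiliencyOptions

    # Placeholder sidecar address; any reachable endpoint works here.
    client = TaskHubGrpcClient(
        host_address="localhost:4001",
        secure_channel=False,
        resiliency_options=GrpcClientResiliencyOptions(
            # After this many consecutive transport failures, the SDK-owned
            # channel is retired and recreated (0 disables recreation).
            channel_recreate_failure_threshold=5,
            # Cooldown between recreations, as pinned down by the
            # recreate-cooldown test below.
            min_recreate_interval_seconds=30.0,
        ),
    )
    try:
        # Hypothetical instance id; long-poll deadlines, successful replies,
        # and application-level errors do not count toward the threshold.
        state = client.get_orchestration_state("my-instance-id")
    finally:
        client.close()  # also closes retired SDK-owned channels immediately

Note that none of this applies when a caller-owned channel is passed in; the test directly below asserts that such channels are never recreated or closed by the client.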
+def test_sync_client_does_not_recreate_caller_owned_channel(): + provided_channel = MagicMock(name="provided-channel") + stub = MagicMock() + stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + + with patch("durabletask.client.shared.get_grpc_channel") as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ) as mock_stub, patch("threading.Timer") as mock_timer: + client = TaskHubGrpcClient( + channel=provided_channel, + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1), + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + client.close() + + assert client._channel is provided_channel + mock_get_channel.assert_not_called() + mock_stub.assert_called_once_with(provided_channel) + mock_timer.assert_not_called() + provided_channel.close.assert_not_called() + + +def test_sync_client_recreate_cooldown_prevents_immediate_repeated_recreation(): + first_channel = MagicMock(name="first-channel") + second_channel = MagicMock(name="second-channel") + third_channel = MagicMock(name="third-channel") + first_stub = MagicMock() + second_stub = MagicMock() + third_stub = MagicMock() + first_stub.GetInstance.side_effect = FakeRpcError(grpc.StatusCode.UNAVAILABLE) + second_stub.GetInstance.side_effect = [ + FakeRpcError(grpc.StatusCode.UNAVAILABLE), + FakeRpcError(grpc.StatusCode.UNAVAILABLE), + ] + timer1 = MagicMock(name="close-timer-1") + timer2 = MagicMock(name="close-timer-2") + + with patch( + "durabletask.client.shared.get_grpc_channel", + side_effect=[first_channel, second_channel, third_channel], + ) as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", + side_effect=[first_stub, second_stub, third_stub], + ), patch( + "durabletask.client.time.monotonic", side_effect=[100.0, 101.0, 131.0] + ), patch("threading.Timer", side_effect=[timer1, timer2]) as mock_timer: + client = TaskHubGrpcClient( + host_address=HOST_ADDRESS, + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=30.0, + ), + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client._channel is second_channel + assert mock_get_channel.call_count == 2 + + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client._channel is second_channel + assert mock_get_channel.call_count == 2 + + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client._channel is third_channel + + expected_channel_call = call( + host_address=HOST_ADDRESS, + secure_channel=False, + interceptors=None, + channel_options=None, + ) + assert mock_get_channel.call_args_list == [ + expected_channel_call, + expected_channel_call, + expected_channel_call, + ] + assert mock_timer.call_count == 2 + first_timer_call, second_timer_call = mock_timer.call_args_list + assert first_timer_call.args[0] == 30.0 + assert first_timer_call.args[1].__self__ is client + assert first_timer_call.args[1].__func__ is TaskHubGrpcClient._close_retired_channel + assert first_timer_call.kwargs == {"args": (first_channel,)} + assert second_timer_call.args[0] == 30.0 + assert second_timer_call.args[1].__self__ is client + assert second_timer_call.args[1].__func__ is TaskHubGrpcClient._close_retired_channel + assert second_timer_call.kwargs == {"args": (second_channel,)} + assert timer1.daemon is True + 
assert timer2.daemon is True + timer1.start.assert_called_once_with() + timer2.start.assert_called_once_with() + + +def test_sync_client_resets_failure_tracking_after_success(): + stub = MagicMock() + stub.GetInstance.side_effect = [ + FakeRpcError(grpc.StatusCode.UNAVAILABLE), + MagicMock(exists=False), + ] + + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client.get_orchestration_state("abc") is None + assert client._client_failure_tracker.consecutive_failures == 0 + + +def test_sync_client_resets_failure_tracking_after_application_error(): + stub = MagicMock() + stub.GetInstance.side_effect = [ + FakeRpcError(grpc.StatusCode.UNAVAILABLE), + FakeRpcError(grpc.StatusCode.INVALID_ARGUMENT), + ] + + with patch("durabletask.client.shared.get_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = TaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2) + ) + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + with pytest.raises(FakeRpcError): + client.get_orchestration_state("abc") + assert client._client_failure_tracker.consecutive_failures == 0 + + +@pytest.mark.asyncio +async def test_async_client_recreates_sdk_owned_channel_with_original_transport_inputs(): + rpc_error = make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE) + first_channel = MagicMock(name="first-channel") + first_channel.close = AsyncMock() + second_channel = MagicMock(name="second-channel") + second_channel.close = AsyncMock() + first_stub = MagicMock() + first_stub.GetInstance = AsyncMock(side_effect=rpc_error) + second_stub = MagicMock() + second_stub.GetInstance = AsyncMock(return_value=MagicMock(exists=False)) + host_address = "localhost:4001" + interceptors = [DefaultAsyncClientInterceptorImpl(METADATA)] + channel_options = GrpcChannelOptions(max_send_message_length=4321) + + with patch( + "durabletask.client.shared.get_async_grpc_channel", + side_effect=[first_channel, second_channel], + ) as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ): + client = AsyncTaskHubGrpcClient( + host_address=host_address, + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ), + ) + try: + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + await client.get_orchestration_state("abc") + finally: + await client.close() + + expected_channel_call = call( + host_address=host_address, + secure_channel=True, + interceptors=interceptors, + channel_options=channel_options, + ) + assert mock_get_channel.call_args_list == [ + expected_channel_call, + expected_channel_call, + ] + + +@pytest.mark.asyncio +async def test_async_client_does_not_count_wait_for_orchestration_deadline(): + stub = MagicMock() + stub.GetInstance = AsyncMock(side_effect=make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE)) + stub.WaitForInstanceCompletion = AsyncMock(side_effect=make_aio_rpc_error(grpc.StatusCode.DEADLINE_EXCEEDED)) 
+ + with patch("durabletask.client.shared.get_async_grpc_channel", return_value=MagicMock()), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = AsyncTaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=2) + ) + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + with pytest.raises(TimeoutError): + await client.wait_for_orchestration_completion("abc") + assert client._client_failure_tracker.consecutive_failures == 0 + + +@pytest.mark.asyncio +async def test_async_client_close_closes_retired_channels_immediately(): + rpc_error = make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE) + first_channel = MagicMock(name="first-channel") + first_channel.close = AsyncMock() + second_channel = MagicMock(name="second-channel") + second_channel.close = AsyncMock() + first_stub = MagicMock() + first_stub.GetInstance = AsyncMock(side_effect=rpc_error) + second_stub = MagicMock() + second_stub.GetInstance = AsyncMock(return_value=MagicMock(exists=False)) + cleanup_started = asyncio.Event() + release_cleanup = asyncio.Event() + + async def blocked_close_retired_channel(self, channel): + cleanup_started.set() + await release_cleanup.wait() + await channel.close() + + with patch( + "durabletask.client.shared.get_async_grpc_channel", + side_effect=[first_channel, second_channel], + ), patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ), patch.object( + AsyncTaskHubGrpcClient, + "_close_retired_channel", + new=blocked_close_retired_channel, + ): + client = AsyncTaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ) + ) + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + await cleanup_started.wait() + + try: + await client.close() + first_channel.close.assert_awaited_once() + second_channel.close.assert_awaited_once() + finally: + release_cleanup.set() + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_async_client_does_not_recreate_caller_owned_channel(): + provided_channel = MagicMock(name="provided-channel") + stub = MagicMock() + stub.GetInstance = AsyncMock(side_effect=make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE)) + + with patch("durabletask.client.shared.get_async_grpc_channel") as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", return_value=stub + ): + client = AsyncTaskHubGrpcClient( + channel=provided_channel, + resiliency_options=GrpcClientResiliencyOptions(channel_recreate_failure_threshold=1), + ) + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + with pytest.raises(grpc.aio.AioRpcError): + await client.get_orchestration_state("abc") + + assert client._channel is provided_channel + mock_get_channel.assert_not_called() + + +@pytest.mark.asyncio +async def test_async_client_close_prevents_channel_recreation_race(): + first_channel = MagicMock(name="first-channel") + first_channel.close = AsyncMock() + second_channel = MagicMock(name="second-channel") + second_channel.close = AsyncMock() + first_stub = MagicMock() + first_stub.GetInstance = AsyncMock(side_effect=make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE)) + second_stub = MagicMock() + second_stub.GetInstance = AsyncMock(return_value=MagicMock(exists=False)) + + with patch( + "durabletask.client.shared.get_async_grpc_channel", + 
side_effect=[first_channel, second_channel], + ) as mock_get_channel, patch( + "durabletask.client.stubs.TaskHubSidecarServiceStub", side_effect=[first_stub, second_stub] + ): + client = AsyncTaskHubGrpcClient( + resiliency_options=GrpcClientResiliencyOptions( + channel_recreate_failure_threshold=1, + min_recreate_interval_seconds=0.0, + ) + ) + await client._recreate_lock.acquire() + try: + rpc_task = asyncio.create_task(client.get_orchestration_state("abc")) + while first_stub.GetInstance.await_count == 0: + await asyncio.sleep(0) + close_task = asyncio.create_task(client.close()) + await asyncio.sleep(0) + finally: + client._recreate_lock.release() + + with pytest.raises(grpc.aio.AioRpcError): + _ = await rpc_task + _ = await close_task + + assert mock_get_channel.call_count == 1 + first_channel.close.assert_awaited_once() + second_channel.close.assert_not_awaited() + + +def test_worker_stores_resiliency_options(): + resiliency = GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=9) + worker = TaskHubGrpcWorker(resiliency_options=resiliency) + assert worker._resiliency_options is resiliency + + def test_get_orchestration_history_aggregates_chunks_and_deexternalizes_payloads(): store = FakePayloadStore() token = store.upload(b'history payload') diff --git a/tests/durabletask/test_grpc_resiliency.py b/tests/durabletask/test_grpc_resiliency.py new file mode 100644 index 0000000..f94f4e1 --- /dev/null +++ b/tests/durabletask/test_grpc_resiliency.py @@ -0,0 +1,185 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import grpc +import pytest + +from durabletask.grpc_options import ( + GrpcClientResiliencyOptions, + GrpcWorkerResiliencyOptions, +) +from durabletask.internal.grpc_resiliency import ( + FailureTracker, + get_full_jitter_delay_seconds, + is_client_transport_failure, + is_worker_transport_failure, +) + + +def test_worker_resiliency_defaults_are_enabled(): + options = GrpcWorkerResiliencyOptions() + + assert options.hello_timeout_seconds == 30.0 + assert options.silent_disconnect_timeout_seconds == 120.0 + assert options.channel_recreate_failure_threshold == 5 + assert options.reconnect_backoff_base_seconds == 1.0 + assert options.reconnect_backoff_cap_seconds == 30.0 + + +def test_worker_resiliency_allows_disabling_timeout_and_threshold(): + options = GrpcWorkerResiliencyOptions( + silent_disconnect_timeout_seconds=0.0, + channel_recreate_failure_threshold=0, + ) + + assert options.silent_disconnect_timeout_seconds == 0.0 + assert options.channel_recreate_failure_threshold == 0 + + +@pytest.mark.parametrize( + ("kwargs", "message"), + [ + ({"hello_timeout_seconds": 0.0}, "hello_timeout_seconds must be > 0"), + ( + {"silent_disconnect_timeout_seconds": -1.0}, + "silent_disconnect_timeout_seconds must be >= 0", + ), + ( + {"channel_recreate_failure_threshold": -1}, + "channel_recreate_failure_threshold must be >= 0", + ), + ( + {"reconnect_backoff_base_seconds": 0.0}, + "reconnect_backoff_base_seconds must be > 0", + ), + ( + {"reconnect_backoff_cap_seconds": 0.0}, + "reconnect_backoff_cap_seconds must be > 0", + ), + ( + { + "reconnect_backoff_base_seconds": 5.0, + "reconnect_backoff_cap_seconds": 1.0, + }, + "reconnect_backoff_cap_seconds must be >= " + "reconnect_backoff_base_seconds", + ), + ], +) +def test_worker_resiliency_rejects_invalid_values(kwargs, message): + with pytest.raises(ValueError, match=message): + GrpcWorkerResiliencyOptions(**kwargs) + + +def test_client_resiliency_defaults_are_enabled(): + options = 
GrpcClientResiliencyOptions() + + assert options.channel_recreate_failure_threshold == 5 + assert options.min_recreate_interval_seconds == 30.0 + + +@pytest.mark.parametrize( + ("kwargs", "message"), + [ + ( + {"channel_recreate_failure_threshold": -1}, + "channel_recreate_failure_threshold must be >= 0", + ), + ( + {"min_recreate_interval_seconds": -1.0}, + "min_recreate_interval_seconds must be >= 0", + ), + ], +) +def test_client_resiliency_rejects_invalid_values(kwargs, message): + with pytest.raises(ValueError, match=message): + GrpcClientResiliencyOptions(**kwargs) + + +def test_full_jitter_delay_is_capped(monkeypatch): + monkeypatch.setattr( + "durabletask.internal.grpc_resiliency.random.random", + lambda: 1.0, + ) + + delay = get_full_jitter_delay_seconds( + 10, + base_seconds=1.0, + cap_seconds=30.0, + ) + + assert delay == 30.0 + + +def test_full_jitter_delay_large_attempt_is_still_capped(monkeypatch): + monkeypatch.setattr( + "durabletask.internal.grpc_resiliency.random.random", + lambda: 1.0, + ) + + delay = get_full_jitter_delay_seconds( + 1_000, + base_seconds=1.0, + cap_seconds=30.0, + ) + + assert delay == 30.0 + + +def test_failure_tracker_trips_at_threshold(): + tracker = FailureTracker(threshold=3) + + assert tracker.record_failure() is False + assert tracker.record_failure() is False + assert tracker.record_failure() is True + + tracker.record_success() + + assert tracker.consecutive_failures == 0 + + +def test_failure_tracker_threshold_zero_never_trips(): + tracker = FailureTracker(threshold=0) + + assert tracker.record_failure() is False + assert tracker.record_failure() is False + assert tracker.record_failure() is False + assert tracker.consecutive_failures == 0 + + +@pytest.mark.parametrize( + "method_name", + [ + "WaitForInstanceStart", + "WaitForInstanceCompletion", + ], +) +def test_client_transport_failure_ignores_long_poll_deadlines(method_name): + assert ( + is_client_transport_failure( + method_name, + grpc.StatusCode.DEADLINE_EXCEEDED, + ) + is False + ) + assert ( + is_client_transport_failure( + "StartInstance", + grpc.StatusCode.DEADLINE_EXCEEDED, + ) + is True + ) + assert ( + is_client_transport_failure( + "GetInstance", + grpc.StatusCode.UNAVAILABLE, + ) + is True + ) + + +def test_worker_transport_failure_filters_application_errors(): + assert is_worker_transport_failure(grpc.StatusCode.UNAVAILABLE) is True + assert is_worker_transport_failure(grpc.StatusCode.DEADLINE_EXCEEDED) is True + assert is_worker_transport_failure(grpc.StatusCode.UNAUTHENTICATED) is False + assert is_worker_transport_failure(grpc.StatusCode.NOT_FOUND) is False diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py index 6fd1270..199c219 100644 --- a/tests/durabletask/test_worker_concurrency_loop.py +++ b/tests/durabletask/test_worker_concurrency_loop.py @@ -73,6 +73,7 @@ def cancel_dummy_activity(req, stub, completionToken): async def run_test(): # Start the worker manager's run loop in the background + worker._async_worker_manager.prepare_for_run() worker_task = asyncio.create_task(worker._async_worker_manager.run()) for req in orchestrator_requests: worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken()) @@ -133,6 +134,7 @@ def fn(*args, **kwargs): # Run the manager loop in a thread (sync context) def run_manager(): + manager.prepare_for_run() asyncio.run(manager.run()) t = threading.Thread(target=run_manager) diff --git 
a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py index 8482c20..22ffe23 100644 --- a/tests/durabletask/test_worker_concurrency_loop_async.py +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -72,6 +72,7 @@ async def cancel_dummy_activity(req, stub, completionToken): async def run_test(): # Clear stub state before each run stub.completed.clear() + grpc_worker._async_worker_manager.prepare_for_run() worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run()) # Need to yield to that thread in order to let it start up on the second run startup_attempts = 0 diff --git a/tests/durabletask/test_worker_resiliency.py b/tests/durabletask/test_worker_resiliency.py new file mode 100644 index 0000000..8ad64cd --- /dev/null +++ b/tests/durabletask/test_worker_resiliency.py @@ -0,0 +1,841 @@ +import asyncio +import grpc +from threading import Event, Timer +from unittest.mock import MagicMock + +import pytest + +from durabletask.grpc_options import GrpcWorkerResiliencyOptions +from durabletask.internal import orchestrator_service_pb2 as pb +from durabletask.worker import ( + _AsyncWorkerManager, + ConcurrencyOptions, + TaskHubGrpcWorker, + _WorkItemStreamOutcome, +) + + +class FakeRpcError(grpc.RpcError): + def __init__(self, status_code: grpc.StatusCode, details: str): + super().__init__() + self._status_code = status_code + self._details = details + + def code(self): + return self._status_code + + def details(self): + return self._details + + def __str__(self): + return self._details + + +class FakeResponseStream: + def __init__(self, items=(), error: grpc.RpcError | None = None): + self._items = list(items) + self._error = error + self.cancelled = False + + def __iter__(self): + yield from self._items + if self._error is not None: + raise self._error + + def cancel(self): + self.cancelled = True + + +class BlockingResponseStream: + def __init__(self): + self._cancel_event = Event() + self.cancelled = False + + def __iter__(self): + if not self._cancel_event.wait(timeout=0.5): + raise AssertionError("response stream was not cancelled") + return + yield + + def cancel(self): + self.cancelled = True + self._cancel_event.set() + + +class DummyWorkerManager: + def __init__(self): + self._shutdown_event = asyncio.Event() + self.submissions: list[tuple[str, tuple]] = [] + + def prepare_for_run(self): + self._shutdown_event = asyncio.Event() + + async def run(self): + await self._shutdown_event.wait() + + def submit_orchestration(self, *args): + self.submissions.append(("orchestrator", args)) + + def submit_activity(self, *args): + self.submissions.append(("activity", args)) + + def submit_entity_batch(self, *args): + self.submissions.append(("entity", args)) + + def shutdown(self): + self._shutdown_event.set() + + +def _complete_activity_request(req, stub, completion_token): + stub.CompleteActivityTask( + pb.ActivityResponse( + instanceId=req.orchestrationInstance.instanceId, + taskId=req.taskId, + completionToken=completion_token, + ) + ) + + +def _make_activity_work_item( + task_id: int = 1, + completion_token: str = "token", + instance_id: str = "instance-id", +) -> pb.WorkItem: + return pb.WorkItem( + activityRequest=pb.ActivityRequest( + name="test_activity", + taskId=task_id, + orchestrationInstance=pb.OrchestrationInstance(instanceId=instance_id), + ), + completionToken=completion_token, + ) + + +async def _wait_for_condition(predicate, *, timeout: float = 2.0): + loop = asyncio.get_running_loop() + 
deadline = loop.time() + timeout + while not predicate(): + if loop.time() >= deadline: + raise AssertionError("condition was not met before timeout") + await asyncio.sleep(0.01) + + +@pytest.mark.asyncio +async def test_async_worker_manager_honors_shutdown_requested_before_run(): + manager = _AsyncWorkerManager( + ConcurrencyOptions(maximum_thread_pool_workers=1), + MagicMock(), + ) + + manager.shutdown() + await asyncio.wait_for(manager.run(), timeout=1.0) + + +def test_worker_start_clears_prior_shutdown_request(): + worker = TaskHubGrpcWorker() + worker._shutdown.set() + run_started = Event() + + async def fake_run_loop(): + run_started.set() + + worker._async_run_loop = fake_run_loop + worker.start() + worker._runLoop.join(timeout=1.0) + + assert run_started.is_set() is True + assert worker._shutdown.is_set() is False + + worker.stop() + + +def test_worker_classifies_graceful_close_before_first_message(): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(silent_disconnect_timeout_seconds=5.0) + ) + outcome = worker._classify_stream_outcome( + saw_message=False, + timed_out=False, + ) + assert outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_BEFORE_FIRST_MESSAGE + + +def test_worker_classifies_graceful_close_after_message(): + worker = TaskHubGrpcWorker() + outcome = worker._classify_stream_outcome( + saw_message=True, + timed_out=False, + ) + assert outcome is _WorkItemStreamOutcome.GRACEFUL_CLOSE_AFTER_MESSAGE + + +def test_worker_classifies_silent_disconnect(): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(silent_disconnect_timeout_seconds=5.0) + ) + outcome = worker._classify_stream_outcome( + saw_message=False, + timed_out=True, + ) + assert outcome is _WorkItemStreamOutcome.SILENT_DISCONNECT + + +def test_worker_counts_only_transport_failures_for_recreation(): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=2) + ) + assert worker._should_count_worker_failure(grpc.StatusCode.UNAVAILABLE) is True + assert worker._should_count_worker_failure(grpc.StatusCode.UNAUTHENTICATED) is False + + +def test_worker_does_not_recreate_caller_owned_channel(): + worker = TaskHubGrpcWorker(channel=MagicMock()) + assert worker._can_recreate_channel() is False + + +@pytest.mark.asyncio +async def test_worker_applies_configured_hello_timeout(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(hello_timeout_seconds=12.5) + ) + worker._async_worker_manager = DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + stub = MagicMock() + stub.GetWorkItems.side_effect = FakeRpcError(grpc.StatusCode.CANCELLED, "stop") + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", lambda *args, **kwargs: MagicMock()) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", lambda channel: stub) + + await worker._async_run_loop() + + assert stub.Hello.call_args.kwargs["timeout"] == 12.5 + + +@pytest.mark.asyncio +async def test_worker_does_not_recreate_sdk_owned_channel_for_non_transport_setup_errors(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=2) + ) + worker._async_worker_manager = DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) 
+ 1}") + created_channels.append(channel) + return channel + + first_stub = MagicMock() + first_stub.Hello.side_effect = RuntimeError("boom") + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError(grpc.StatusCode.CANCELLED, "stop") + + stubs = [first_stub, second_stub] + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr( + "durabletask.worker.stubs.TaskHubSidecarServiceStub", + lambda channel: stubs.pop(0), + ) + + await worker._async_run_loop() + + assert len(created_channels) == 1 + + +@pytest.mark.asyncio +async def test_worker_recreates_sdk_owned_channel_after_transport_failure_threshold(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=2) + ) + worker._async_worker_manager = DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream(error=FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "first transport failure", + )))), + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream(error=FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "second transport failure", + )))), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(created_channels) == 2 + assert stub_channels[0] is created_channels[0] + assert stub_channels[1] is created_channels[0] + assert stub_channels[2] is created_channels[1] + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_recreates_sdk_owned_channel_after_silent_disconnect(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions( + channel_recreate_failure_threshold=1, + silent_disconnect_timeout_seconds=0.01, + ) + ) + worker._async_worker_manager = DummyWorkerManager() + + wait_calls = [] + + def shutdown_wait(timeout): + wait_calls.append(timeout) + return False + + monkeypatch.setattr(worker._shutdown, "wait", shutdown_wait) + + delay_calls = [] + + def fake_delay(attempt, *, base_seconds, cap_seconds): + delay_calls.append((attempt, base_seconds, cap_seconds)) + return 0.25 + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + blocking_stream = BlockingResponseStream() + stub_channels = [] + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=blocking_stream)), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.get_full_jitter_delay_seconds", fake_delay) + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + 
monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert blocking_stream.cancelled is True + assert delay_calls == [(1, 1.0, 30.0)] + assert wait_calls == [0.25] + assert len(created_channels) == 2 + assert stub_channels == created_channels + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_closes_sdk_owned_channel_on_graceful_stream_reset(monkeypatch): + worker = TaskHubGrpcWorker() + worker._async_worker_manager = DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + stub_channels = [] + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream())), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(created_channels) == 2 + assert stub_channels == created_channels + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_defers_sdk_owned_channel_close_until_inflight_completion_finishes(monkeypatch): + worker = TaskHubGrpcWorker() + worker_manager = DummyWorkerManager() + worker._async_worker_manager = worker_manager + worker._execute_activity = _complete_activity_request + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + completed_responses = [] + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[_make_activity_work_item()]) + + def complete_activity(response): + assert created_channels[0].close.call_count == 0 + completed_responses.append(response) + + first_stub.CompleteActivityTask.side_effect = complete_activity + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(worker_manager.submissions) == 1 + assert len(created_channels) == 2 + assert stub_channels == created_channels + created_channels[0].close.assert_not_called() + created_channels[1].close.assert_called_once() + + _, submission = worker_manager.submissions[0] + func, _, req, stub, completion_token = submission + func(req, stub, completion_token) + + assert len(completed_responses) == 1 + assert completed_responses[0].completionToken == "token" + created_channels[0].close.assert_called_once() + + +@pytest.mark.asyncio +async def 
test_worker_releases_inflight_channel_when_activity_handler_raises(monkeypatch): + worker = TaskHubGrpcWorker() + worker_manager = DummyWorkerManager() + worker._async_worker_manager = worker_manager + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + def fail_activity(req, stub, completion_token): + raise RuntimeError("boom") + + worker._execute_activity = fail_activity + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[_make_activity_work_item()]) + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + + def create_stub(channel): + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(worker_manager.submissions) == 1 + created_channels[0].close.assert_not_called() + + _, submission = worker_manager.submissions[0] + func, _, req, stub, completion_token = submission + with pytest.raises(RuntimeError, match="boom"): + func(req, stub, completion_token) + + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_shutdown_drains_real_manager_work_before_closing_retired_sdk_channel(monkeypatch): + worker = TaskHubGrpcWorker( + concurrency_options=ConcurrencyOptions( + maximum_concurrent_activity_work_items=1, + maximum_thread_pool_workers=1, + ) + ) + worker._execute_activity = _complete_activity_request + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + allow_first_completion = Event() + first_completion_started = Event() + completed_task_ids = [] + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[ + _make_activity_work_item(task_id=1, completion_token="token-1"), + _make_activity_work_item(task_id=2, completion_token="token-2"), + ]) + + def complete_activity(response): + completed_task_ids.append(response.taskId) + if response.taskId == 1: + first_completion_started.set() + Timer(0.2, allow_first_completion.set).start() + assert allow_first_completion.wait(timeout=5.0) + elif response.taskId == 2: + assert created_channels[0].close.call_count == 0 + + first_stub.CompleteActivityTask.side_effect = complete_activity + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + run_task = asyncio.create_task(worker._async_run_loop()) + await asyncio.wait_for(run_task, timeout=2.0) + + assert first_completion_started.is_set() is True + assert len(created_channels) == 2 + assert stub_channels == created_channels + assert 
completed_task_ids == [1, 2] + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_shutdown_runs_real_manager_cancellation_wrapper_before_closing_retired_sdk_channel(monkeypatch): + worker = TaskHubGrpcWorker( + concurrency_options=ConcurrencyOptions( + maximum_concurrent_activity_work_items=1, + maximum_thread_pool_workers=1, + ) + ) + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + allow_first_completion = Event() + first_completion_started = Event() + completed_task_ids = [] + cancelled_task_ids = [] + + def execute_activity(req, stub, completion_token): + if req.taskId == 1: + _complete_activity_request(req, stub, completion_token) + else: + raise RuntimeError("boom") + + def cancel_activity(req, stub, completion_token): + cancelled_task_ids.append(req.taskId) + assert created_channels[0].close.call_count == 0 + + worker._execute_activity = execute_activity + worker._cancel_activity = cancel_activity + + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[ + _make_activity_work_item(task_id=1, completion_token="token-1"), + _make_activity_work_item(task_id=2, completion_token="token-2"), + ]) + + def complete_activity(response): + completed_task_ids.append(response.taskId) + Timer(0.2, allow_first_completion.set).start() + first_completion_started.set() + assert allow_first_completion.wait(timeout=5.0) + + first_stub.CompleteActivityTask.side_effect = complete_activity + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + run_task = asyncio.create_task(worker._async_run_loop()) + await asyncio.wait_for(run_task, timeout=2.0) + + assert first_completion_started.is_set() is True + assert len(created_channels) == 2 + assert stub_channels == created_channels + assert completed_task_ids == [1] + assert cancelled_task_ids == [2] + created_channels[0].close.assert_called_once() + created_channels[1].close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_never_closes_caller_owned_channel_after_graceful_reset(monkeypatch): + provided_channel = MagicMock(name="provided-channel") + worker = TaskHubGrpcWorker(channel=provided_channel) + worker_manager = DummyWorkerManager() + worker._async_worker_manager = worker_manager + worker._execute_activity = _complete_activity_request + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + completed_responses = [] + first_stub = MagicMock() + first_stub.GetWorkItems.return_value = FakeResponseStream(items=[_make_activity_work_item()]) + + def complete_activity(response): + assert provided_channel.close.call_count == 0 + completed_responses.append(response) + + first_stub.CompleteActivityTask.side_effect = complete_activity + + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + + stubs = [first_stub, second_stub] + 
stub_channels = [] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr( + "durabletask.worker.shared.get_grpc_channel", + lambda *args, **kwargs: pytest.fail( + "SDK channel factory should not run for caller-owned channels" + ), + ) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert len(worker_manager.submissions) == 1 + assert stub_channels == [provided_channel, provided_channel] + provided_channel.close.assert_not_called() + + _, submission = worker_manager.submissions[0] + func, _, req, stub, completion_token = submission + func(req, stub, completion_token) + + assert len(completed_responses) == 1 + assert completed_responses[0].completionToken == "token" + provided_channel.close.assert_not_called() + + +@pytest.mark.asyncio +async def test_worker_uses_reconnect_backoff_helper_after_connection_failure(monkeypatch): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions( + reconnect_backoff_base_seconds=1.5, + reconnect_backoff_cap_seconds=9.0, + ) + ) + worker._async_worker_manager = DummyWorkerManager() + + wait_calls = [] + + def shutdown_wait(timeout): + wait_calls.append(timeout) + return False + + monkeypatch.setattr(worker._shutdown, "wait", shutdown_wait) + + delay_calls = [] + + def fake_delay(attempt, *, base_seconds, cap_seconds): + delay_calls.append((attempt, base_seconds, cap_seconds)) + return 0.75 + + channel = MagicMock(name="channel-1") + stub_channels = [] + first_stub = MagicMock() + first_stub.Hello.side_effect = FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "connect failed", + ) + second_stub = MagicMock() + second_stub.GetWorkItems.side_effect = FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ) + stubs = [first_stub, second_stub] + + def create_stub(current_channel): + stub_channels.append(current_channel) + return stubs.pop(0) + + monkeypatch.setattr("durabletask.worker.get_full_jitter_delay_seconds", fake_delay) + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", lambda *args, **kwargs: channel) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await worker._async_run_loop() + + assert delay_calls == [(1, 1.5, 9.0)] + assert wait_calls == [0.75] + assert stub_channels == [channel, channel] + channel.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_worker_never_replaces_caller_owned_channel_during_transport_failures(monkeypatch): + provided_channel = MagicMock(name="provided-channel") + worker = TaskHubGrpcWorker( + channel=provided_channel, + resiliency_options=GrpcWorkerResiliencyOptions(channel_recreate_failure_threshold=1), + ) + worker._async_worker_manager = DummyWorkerManager() + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + stub_channels = [] + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream(error=FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "transport failure", + )))), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + + def create_stub(channel): + stub_channels.append(channel) + return stubs.pop(0) + + monkeypatch.setattr( + "durabletask.worker.shared.get_grpc_channel", + lambda *args, **kwargs: pytest.fail("SDK channel factory should not run for caller-owned channels"), + ) + monkeypatch.setattr("durabletask.worker.stubs.TaskHubSidecarServiceStub", create_stub) + + await 
worker._async_run_loop() + + assert stub_channels == [provided_channel, provided_channel] + provided_channel.close.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("work_item", "expected_submissions"), + [ + (pb.WorkItem(healthPing=pb.HealthPing()), 0), + (_make_activity_work_item(), 1), + ], +) +async def test_worker_received_messages_reset_failure_tracker(monkeypatch, work_item, expected_submissions): + worker = TaskHubGrpcWorker( + resiliency_options=GrpcWorkerResiliencyOptions( + channel_recreate_failure_threshold=2, + silent_disconnect_timeout_seconds=5.0, + ) + ) + worker_manager = DummyWorkerManager() + worker._async_worker_manager = worker_manager + monkeypatch.setattr(worker._shutdown, "wait", lambda timeout: False) + + created_channels = [] + + def get_grpc_channel(*args, **kwargs): + channel = MagicMock(name=f"channel-{len(created_channels) + 1}") + created_channels.append(channel) + return channel + + stubs = [ + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream(error=FakeRpcError( + grpc.StatusCode.UNAVAILABLE, + "first transport failure", + )))), + MagicMock(GetWorkItems=MagicMock(return_value=FakeResponseStream( + items=[work_item], + error=FakeRpcError(grpc.StatusCode.UNAVAILABLE, "second transport failure"), + ))), + MagicMock(GetWorkItems=MagicMock(side_effect=FakeRpcError( + grpc.StatusCode.CANCELLED, + "stop", + ))), + ] + + monkeypatch.setattr("durabletask.worker.shared.get_grpc_channel", get_grpc_channel) + monkeypatch.setattr( + "durabletask.worker.stubs.TaskHubSidecarServiceStub", + lambda channel: stubs.pop(0), + ) + + await worker._async_run_loop() + + assert len(created_channels) == 1 + assert len(worker_manager.submissions) == expected_submissions
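Taken together, the worker tests above pin down a small configuration surface and a backoff contract. The sketch below restates both. The option values are the defaults asserted in test_grpc_resiliency.py, made explicit for readability, and full_jitter_delay is a reference model that satisfies the capping assertions for get_full_jitter_delay_seconds; it is not necessarily the SDK's internal implementation.

    import random

    from durabletask.grpc_options import GrpcWorkerResiliencyOptions
    from durabletask.worker import TaskHubGrpcWorker

    # Defaults shown explicitly. Per the validation tests above, a value of
    # 0.0 disables silent-disconnect detection and a threshold of 0 disables
    # channel recreation.
    worker = TaskHubGrpcWorker(
        resiliency_options=GrpcWorkerResiliencyOptions(
            hello_timeout_seconds=30.0,
            silent_disconnect_timeout_seconds=120.0,
            channel_recreate_failure_threshold=5,
            reconnect_backoff_base_seconds=1.0,
            reconnect_backoff_cap_seconds=30.0,
        )
    )

    def full_jitter_delay(attempt: int, *, base_seconds: float, cap_seconds: float) -> float:
        # Exponential backoff with full jitter. The exponent is clamped
        # before exponentiating so very large attempt numbers neither
        # overflow nor exceed cap_seconds (the tests assert attempt 1_000
        # still yields at most cap_seconds).
        exponent = min(attempt - 1, 32)
        ceiling = min(cap_seconds, base_seconds * (2 ** exponent))
        return random.random() * ceiling

With random.random() pinned to 1.0, this model reproduces the capped 30.0-second results asserted above; real delays are drawn uniformly from [0, ceiling).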