diff --git a/src/Worker/Grpc/ReconnectBackoff.cs b/src/Worker/Grpc/GrpcBackoff.cs
similarity index 76%
rename from src/Worker/Grpc/ReconnectBackoff.cs
rename to src/Worker/Grpc/GrpcBackoff.cs
index dd08a58c..f60e1902 100644
--- a/src/Worker/Grpc/ReconnectBackoff.cs
+++ b/src/Worker/Grpc/GrpcBackoff.cs
@@ -6,9 +6,9 @@ namespace Microsoft.DurableTask.Worker.Grpc;
 
 /// <summary>
-/// Helpers for computing reconnect backoff delays in the gRPC worker.
+/// Helpers for computing reconnect and retry backoff delays in the gRPC worker.
 /// </summary>
-static class ReconnectBackoff
+static class GrpcBackoff
 {
     /// <summary>
     /// Creates a random source for reconnect jitter using an explicit random seed so multiple workers on
@@ -32,10 +32,11 @@ public static Random CreateRandom()
    /// <param name="baseDelay">The base delay used for the exponential growth.</param>
    /// <param name="cap">The maximum delay before jitter is applied.</param>
    /// <param name="random">The random source used for jitter.</param>
+   /// <param name="fullJitter">If true, applies full jitter: the delay is drawn uniformly between zero and the capped exponential bound. If false, applies a small additive jitter of up to 20% on top of that bound, so the delay stays at or slightly above it.</param>
    /// <returns>The computed jittered delay.</returns>
-    public static TimeSpan Compute(int attempt, TimeSpan baseDelay, TimeSpan cap, Random random)
+    public static TimeSpan Compute(int attempt, TimeSpan baseDelay, TimeSpan cap, Random random, bool fullJitter)
    {
-        if (baseDelay <= TimeSpan.Zero)
+        if (baseDelay <= TimeSpan.Zero || cap <= TimeSpan.Zero)
        {
            return TimeSpan.Zero;
        }
@@ -48,13 +49,13 @@ public static TimeSpan Compute(int attempt, TimeSpan baseDelay, TimeSpan cap, Ra
        // Cap the exponent to avoid overflow in 2^attempt for pathological attempt values.
        int safeAttempt = Math.Min(attempt, 30);
 
-        double capMs = Math.Max(0, cap.TotalMilliseconds);
        double exponentialMs = baseDelay.TotalMilliseconds * Math.Pow(2, safeAttempt);
-        double upperBoundMs = Math.Min(capMs, exponentialMs);
+        double upperBoundMs = Math.Min(cap.TotalMilliseconds, exponentialMs);
+
+        double jitteredMs = fullJitter
+            ? random.NextDouble() * upperBoundMs
+            : upperBoundMs + (random.NextDouble() * (upperBoundMs * .2));
 
-        // Full jitter intentionally allows any value in the retry window. The wide spread keeps many
-        // workers that saw the same outage from reconnecting in lockstep against the backend.
-        double jitteredMs = random.NextDouble() * upperBoundMs;
        return TimeSpan.FromMilliseconds(jitteredMs);
    }
 }
diff --git a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs
index 7f61b68a..4c5a18b2 100644
--- a/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs
+++ b/src/Worker/Grpc/GrpcDurableTaskWorker.Processor.cs
@@ -62,7 +62,7 @@ public async Task ExecuteAsync(CancellationToken cancellati
        // Tracks consecutive retry attempts for backoff calculation. Reset on first stream message.
        int reconnectAttempt = 0;
-        Random backoffRandom = ReconnectBackoff.CreateRandom();
+        Random backoffRandom = GrpcBackoff.CreateRandom();
 
        while (!cancellation.IsCancellationRequested)
        {
@@ -149,13 +149,16 @@ await this.ProcessWorkItemsAsync(
            try
            {
-                TimeSpan delay = ReconnectBackoff.Compute(
+                // Full jitter intentionally allows any value in the retry window. The wide spread keeps many
+                // workers that saw the same outage from reconnecting in lockstep against the backend.
+ TimeSpan delay = GrpcBackoff.Compute( reconnectAttempt, this.internalOptions.ReconnectBackoffBase, this.internalOptions.ReconnectBackoffCap, - backoffRandom); - this.Logger.ReconnectBackoff(reconnectAttempt, (int)Math.Min(int.MaxValue, delay.TotalMilliseconds)); - reconnectAttempt = Math.Min(reconnectAttempt + 1, 30); // cap to avoid overflow in 2^attempt + backoffRandom, + fullJitter: true); + this.Logger.ReconnectBackoff(reconnectAttempt, (int)delay.TotalMilliseconds); + reconnectAttempt++; await Task.Delay(delay, cancellation); } catch (OperationCanceledException) when (cancellation.IsCancellationRequested) @@ -472,91 +475,75 @@ void RunBackgroundTask(P.WorkItem? workItem, Func handler, CancellationTok if (workItem?.OrchestratorRequest != null) { - try - { - this.Logger.AbandoningOrchestratorWorkItem(instanceId, workItem?.CompletionToken ?? string.Empty); - await this.client.AbandonTaskOrchestratorWorkItemAsync( + this.Logger.AbandoningOrchestratorWorkItem(instanceId, workItem.CompletionToken ?? string.Empty); + await this.ExecuteWithRetryAsync( + async () => await this.client.AbandonTaskOrchestratorWorkItemAsync( new P.AbandonOrchestrationTaskRequest { - CompletionToken = workItem?.CompletionToken, + CompletionToken = workItem.CompletionToken, }, - cancellationToken: cancellation); - this.Logger.AbandonedOrchestratorWorkItem(instanceId, workItem?.CompletionToken ?? string.Empty); - } - catch (Exception abandonException) - { - this.Logger.UnexpectedError(abandonException, instanceId); - } + cancellationToken: cancellation), + nameof(this.client.AbandonTaskOrchestratorWorkItemAsync), + cancellation); + this.Logger.AbandonedOrchestratorWorkItem(instanceId, workItem.CompletionToken ?? string.Empty); } else if (workItem?.ActivityRequest != null) { - try - { - this.Logger.AbandoningActivityWorkItem( - instanceId, - workItem.ActivityRequest.Name, - workItem.ActivityRequest.TaskId, - workItem?.CompletionToken ?? string.Empty); - await this.client.AbandonTaskActivityWorkItemAsync( + this.Logger.AbandoningActivityWorkItem( + instanceId, + workItem.ActivityRequest.Name, + workItem.ActivityRequest.TaskId, + workItem.CompletionToken ?? string.Empty); + await this.ExecuteWithRetryAsync( + async () => await this.client.AbandonTaskActivityWorkItemAsync( new P.AbandonActivityTaskRequest { - CompletionToken = workItem?.CompletionToken, + CompletionToken = workItem.CompletionToken, }, - cancellationToken: cancellation); - this.Logger.AbandonedActivityWorkItem( - instanceId, - workItem.ActivityRequest.Name, - workItem.ActivityRequest.TaskId, - workItem?.CompletionToken ?? string.Empty); - } - catch (Exception abandonException) - { - this.Logger.UnexpectedError(abandonException, instanceId); - } + cancellationToken: cancellation), + nameof(this.client.AbandonTaskActivityWorkItemAsync), + cancellation); + this.Logger.AbandonedActivityWorkItem( + instanceId, + workItem.ActivityRequest.Name, + workItem.ActivityRequest.TaskId, + workItem.CompletionToken ?? string.Empty); } else if (workItem?.EntityRequest != null) { - try - { - this.Logger.AbandoningEntityWorkItem( - workItem.EntityRequest.InstanceId, - workItem?.CompletionToken ?? string.Empty); - await this.client.AbandonTaskEntityWorkItemAsync( + this.Logger.AbandoningEntityWorkItem( + workItem.EntityRequest.InstanceId, + workItem.CompletionToken ?? 
string.Empty); + await this.ExecuteWithRetryAsync( + async () => await this.client.AbandonTaskEntityWorkItemAsync( new P.AbandonEntityTaskRequest { - CompletionToken = workItem?.CompletionToken, + CompletionToken = workItem.CompletionToken, }, - cancellationToken: cancellation); - this.Logger.AbandonedEntityWorkItem( - workItem.EntityRequest.InstanceId, - workItem?.CompletionToken ?? string.Empty); - } - catch (Exception abandonException) - { - this.Logger.UnexpectedError(abandonException, workItem.EntityRequest.InstanceId); - } + cancellationToken: cancellation), + nameof(this.client.AbandonTaskEntityWorkItemAsync), + cancellation); + this.Logger.AbandonedEntityWorkItem( + workItem.EntityRequest.InstanceId, + workItem.CompletionToken ?? string.Empty); } else if (workItem?.EntityRequestV2 != null) { - try - { - this.Logger.AbandoningEntityWorkItem( - workItem.EntityRequestV2.InstanceId, - workItem?.CompletionToken ?? string.Empty); - await this.client.AbandonTaskEntityWorkItemAsync( + this.Logger.AbandoningEntityWorkItem( + workItem.EntityRequestV2.InstanceId, + workItem.CompletionToken ?? string.Empty); + await this.ExecuteWithRetryAsync( + async () => await this.client.AbandonTaskEntityWorkItemAsync( new P.AbandonEntityTaskRequest { - CompletionToken = workItem?.CompletionToken, + CompletionToken = workItem.CompletionToken, }, - cancellationToken: cancellation); - this.Logger.AbandonedEntityWorkItem( - workItem.EntityRequestV2.InstanceId, - workItem?.CompletionToken ?? string.Empty); - } - catch (Exception abandonException) - { - this.Logger.UnexpectedError(abandonException, workItem.EntityRequestV2.InstanceId); - } + cancellationToken: cancellation), + nameof(this.client.AbandonTaskEntityWorkItemAsync), + cancellation); + this.Logger.AbandonedEntityWorkItem( + workItem.EntityRequestV2.InstanceId, + workItem.CompletionToken ?? 
string.Empty); } } }); @@ -703,13 +690,16 @@ async Task OnRunOrchestratorAsync( if (!filterPassed) { this.Logger.AbandoningOrchestrationDueToOrchestrationFilter(request.InstanceId, completionToken); - await this.client.AbandonTaskOrchestratorWorkItemAsync( - new P.AbandonOrchestrationTaskRequest - { - CompletionToken = completionToken, - }, - cancellationToken: cancellationToken); - + await this.ExecuteWithRetryAsync( + async () => await this.client.AbandonTaskOrchestratorWorkItemAsync( + new P.AbandonOrchestrationTaskRequest + { + CompletionToken = completionToken, + }, + cancellationToken: cancellationToken), + nameof(this.client.AbandonTaskOrchestratorWorkItemAsync), + cancellationToken); + this.Logger.AbandonedOrchestratorWorkItem(request.InstanceId, completionToken); return; } @@ -804,13 +794,16 @@ await this.client.AbandonTaskOrchestratorWorkItemAsync( else { this.Logger.AbandoningOrchestrationDueToVersioning(request.InstanceId, completionToken); - await this.client.AbandonTaskOrchestratorWorkItemAsync( - new P.AbandonOrchestrationTaskRequest - { - CompletionToken = completionToken, - }, - cancellationToken: cancellationToken); - + await this.ExecuteWithRetryAsync( + async () => await this.client.AbandonTaskOrchestratorWorkItemAsync( + new P.AbandonOrchestrationTaskRequest + { + CompletionToken = completionToken, + }, + cancellationToken: cancellationToken), + nameof(this.client.AbandonTaskOrchestratorWorkItemAsync), + cancellationToken); + this.Logger.AbandonedOrchestratorWorkItem(request.InstanceId, completionToken); return; } } @@ -915,12 +908,16 @@ async Task OnRunActivityAsync(P.ActivityRequest request, string completionToken, if (this.worker.workerOptions.Versioning?.FailureStrategy == DurableTaskWorkerOptions.VersionFailureStrategy.Reject) { this.Logger.AbandoningActivityWorkItem(instance.InstanceId, request.Name, request.TaskId, completionToken); - await this.client.AbandonTaskActivityWorkItemAsync( - new P.AbandonActivityTaskRequest - { - CompletionToken = completionToken, - }, - cancellationToken: cancellation); + await this.ExecuteWithRetryAsync( + async () => await this.client.AbandonTaskActivityWorkItemAsync( + new P.AbandonActivityTaskRequest + { + CompletionToken = completionToken, + }, + cancellationToken: cancellation), + nameof(this.client.AbandonTaskActivityWorkItemAsync), + cancellation); + this.Logger.AbandonedActivityWorkItem(instance.InstanceId, request.Name, request.TaskId, completionToken); } return; @@ -954,7 +951,10 @@ await this.client.AbandonTaskActivityWorkItemAsync( // Stop the trace activity here to avoid including the completion time in the latency calculation traceActivity?.Stop(); - await this.client.CompleteActivityTaskAsync(response, cancellationToken: cancellation); + await this.ExecuteWithRetryAsync( + async () => await this.client.CompleteActivityTaskAsync(response, cancellationToken: cancellation), + nameof(this.client.CompleteActivityTaskAsync), + cancellation); } async Task OnRunEntityBatchAsync( @@ -1020,7 +1020,10 @@ async Task OnRunEntityBatchAsync( completionToken, operationInfos?.Take(batchResult.Results?.Count ?? 
0)); - await this.client.CompleteEntityTaskAsync(response, cancellationToken: cancellation); + await this.ExecuteWithRetryAsync( + async () => await this.client.CompleteEntityTaskAsync(response, cancellationToken: cancellation), + nameof(this.client.CompleteEntityTaskAsync), + cancellation); } /// @@ -1082,7 +1085,10 @@ async Task CompleteOrchestratorTaskWithChunkingAsync( }, }; - await this.client.CompleteOrchestratorTaskAsync(failureResponse, cancellationToken: cancellationToken); + await this.ExecuteWithRetryAsync( + async () => await this.client.CompleteOrchestratorTaskAsync(failureResponse, cancellationToken: cancellationToken), + nameof(this.client.CompleteOrchestratorTaskAsync), + cancellationToken); return; } @@ -1109,7 +1115,10 @@ static bool TryAddAction( if (totalSize <= maxChunkBytes) { // Response fits in one chunk, send it directly (isPartial defaults to false) - await this.client.CompleteOrchestratorTaskAsync(response, cancellationToken: cancellationToken); + await this.ExecuteWithRetryAsync( + async () => await this.client.CompleteOrchestratorTaskAsync(response, cancellationToken: cancellationToken), + nameof(this.client.CompleteOrchestratorTaskAsync), + cancellationToken); return; } @@ -1169,7 +1178,67 @@ static bool TryAddAction( chunkIndex++; // Send the chunk - await this.client.CompleteOrchestratorTaskAsync(chunkedResponse, cancellationToken: cancellationToken); + await this.ExecuteWithRetryAsync( + async () => await this.client.CompleteOrchestratorTaskAsync(chunkedResponse, cancellationToken: cancellationToken), + nameof(this.client.CompleteOrchestratorTaskAsync), + cancellationToken); + } + } + + async Task ExecuteWithRetryAsync( + Func action, + string operationName, + CancellationToken cancellationToken) + { + int maxAttempts = this.internalOptions.TransientRetryMaxAttempts; + TimeSpan baseDelay = this.internalOptions.TransientRetryBackoffBase; + TimeSpan cap = this.internalOptions.TransientRetryBackoffCap; + Random retryRandom; +#if NET6_0_OR_GREATER + retryRandom = Random.Shared; +#else + retryRandom = new Random(); +#endif + + for (int attempt = 1; ; attempt++) + { + try + { + await action(); + return; + } + catch (RpcException ex) when ( + (ex.StatusCode == StatusCode.Unavailable || + ex.StatusCode == StatusCode.Unknown || + ex.StatusCode == StatusCode.DeadlineExceeded || + ex.StatusCode == StatusCode.Internal) && + attempt < maxAttempts) + { + // Don't use full jitter since we want to keep the retry interval fairly fixed and increasing with + // each attempt. We don't have lockstep concerns in this case. 
+                // The loop counts attempts starting at 1, so zero-index the value passed to the backoff computation.
+                TimeSpan backoff = GrpcBackoff.Compute(attempt - 1, baseDelay, cap, retryRandom, fullJitter: false);
+
+                this.Logger.TransientGrpcRetry(
+                    operationName,
+                    attempt,
+                    maxAttempts,
+                    backoff.TotalMilliseconds,
+                    (int)ex.StatusCode,
+                    ex);
+
+                try
+                {
+                    await Task.Delay(backoff, cancellationToken);
+                }
+                catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
+                {
+                    // If shutting down during the retry delay, propagate the cancellation exception
+                    throw;
+                }
+
+                continue;
+            }
+        }
+    }
 }
diff --git a/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs b/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs
index 464c50a8..49bd6350 100644
--- a/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs
+++ b/src/Worker/Grpc/GrpcDurableTaskWorkerOptions.cs
@@ -135,6 +135,29 @@ internal class InternalOptions
    /// </summary>
    public TimeSpan ReconnectBackoffCap { get; set; } = TimeSpan.FromSeconds(30);
 
+    /// <summary>
+    /// Gets or sets the maximum number of attempts the worker will make when retrying a transient
+    /// gRPC call (such as completing or abandoning a work item). Once this many attempts have failed,
+    /// the most recent exception is rethrown. Defaults to 10.
+    /// </summary>
+    public int TransientRetryMaxAttempts { get; set; } = 10;
+
+    /// <summary>
+    /// Gets or sets the initial delay used when computing exponential backoff between retries of a
+    /// transient gRPC call. The delay doubles after each failed attempt, and the exponential component
+    /// is capped at <see cref="TransientRetryBackoffCap"/> before jitter is applied. In the default
+    /// biased-jitter mode, the final delay may therefore slightly exceed
+    /// <see cref="TransientRetryBackoffCap"/>. Defaults to 200 ms.
+    /// </summary>
+    public TimeSpan TransientRetryBackoffBase { get; set; } = TimeSpan.FromMilliseconds(200);
+
+    /// <summary>
+    /// Gets or sets the cap applied to the exponential backoff component between retries of a transient
+    /// gRPC call before jitter is applied. In the default biased-jitter mode, the final computed delay
+    /// may be slightly greater than this value. Defaults to 15 seconds.
+    /// </summary>
+    public TimeSpan TransientRetryBackoffCap { get; set; } = TimeSpan.FromSeconds(15);
+
    /// <summary>
    /// Gets or sets an optional callback invoked when the worker requests a fresh gRPC channel after
    /// repeated connect failures. The callback receives the previously-used channel and should return
diff --git a/src/Worker/Grpc/Logs.cs b/src/Worker/Grpc/Logs.cs
index ea585dcf..878efe9c 100644
--- a/src/Worker/Grpc/Logs.cs
+++ b/src/Worker/Grpc/Logs.cs
@@ -99,6 +99,9 @@ static partial class Logs
    public static partial void ReceivedHealthPing(this ILogger logger);
 
    [LoggerMessage(EventId = 76, Level = LogLevel.Information, Message = "Work-item stream ended by the backend (graceful close). Will reconnect.")]
-    public static partial void StreamEndedByPeer(this ILogger logger);
+    public static partial void StreamEndedByPeer(this ILogger logger);
+
+    [LoggerMessage(EventId = 77, Level = LogLevel.Warning, Message = "Transient gRPC error for '{OperationName}'. Attempt {Attempt} of {MaxAttempts}. Retrying in {BackoffMs} ms. StatusCode={StatusCode}")]
+    public static partial void TransientGrpcRetry(this ILogger logger, string operationName, int attempt, int maxAttempts, double backoffMs, int statusCode, Exception exception);
 }
}
diff --git a/test/Worker/Grpc.Tests/ExecuteWithRetryTests.cs b/test/Worker/Grpc.Tests/ExecuteWithRetryTests.cs
new file mode 100644
index 00000000..50ae19b5
--- /dev/null
+++ b/test/Worker/Grpc.Tests/ExecuteWithRetryTests.cs
@@ -0,0 +1,319 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License. + +using System.Reflection; +using Grpc.Core; +using Microsoft.DurableTask.Tests.Logging; +using Microsoft.DurableTask.Worker; +using Microsoft.DurableTask.Worker.Grpc.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Xunit.Abstractions; +using P = Microsoft.DurableTask.Protobuf; + +namespace Microsoft.DurableTask.Worker.Grpc.Tests; + +public class ExecuteWithRetryTests +{ + const string Category = "Microsoft.DurableTask.Worker.Grpc"; + + static readonly MethodInfo ExecuteWithRetryAsyncMethod = FindExecuteWithRetryAsyncMethod(); + + + static MethodInfo FindExecuteWithRetryAsyncMethod() + { + Type processorType = typeof(GrpcDurableTaskWorker).GetNestedType("Processor", BindingFlags.NonPublic)!; + return processorType.GetMethod("ExecuteWithRetryAsync", BindingFlags.Instance | BindingFlags.NonPublic)!; + } + + [Fact] + public async Task ExecuteWithRetryAsync_SucceedsOnFirstAttempt_DoesNotRetry() + { + // Arrange + object processor = CreateProcessor(); + int callCount = 0; + + // Act + await InvokeExecuteWithRetryAsync( + processor, + () => { callCount++; return Task.CompletedTask; }, + "TestOperation", + CancellationToken.None); + + // Assert + callCount.Should().Be(1); + } + + [Theory] + [InlineData(StatusCode.Unavailable)] + [InlineData(StatusCode.Unknown)] + [InlineData(StatusCode.DeadlineExceeded)] + [InlineData(StatusCode.Internal)] + public async Task ExecuteWithRetryAsync_TransientError_RetriesAndEventuallySucceeds(StatusCode statusCode) + { + // Arrange + object processor = CreateProcessor(); + int callCount = 0; + + // Act - fail once then succeed + await InvokeExecuteWithRetryAsync( + processor, + () => + { + callCount++; + if (callCount == 1) + { + throw new RpcException(new Status(statusCode, "transient error")); + } + + return Task.CompletedTask; + }, + "TestOperation", + CancellationToken.None); + + // Assert + callCount.Should().Be(2); + } + + [Theory] + [InlineData(StatusCode.InvalidArgument)] + [InlineData(StatusCode.AlreadyExists)] + [InlineData(StatusCode.PermissionDenied)] + public async Task ExecuteWithRetryAsync_NonTransientError_ThrowsWithoutRetrying(StatusCode statusCode) + { + // Arrange + object processor = CreateProcessor(); + int callCount = 0; + + // Act + Func act = () => InvokeExecuteWithRetryAsync( + processor, + () => + { + callCount++; + throw new RpcException(new Status(statusCode, "non-transient error")); + }, + "TestOperation", + CancellationToken.None); + + // Assert + await act.Should().ThrowAsync().Where(e => e.StatusCode == statusCode); + callCount.Should().Be(1); + } + + [Fact] + public async Task ExecuteWithRetryAsync_CancellationRequestedDuringRetryDelay_ThrowsOperationCanceledException() + { + // Arrange + using CancellationTokenSource cts = new(); + object processor = CreateProcessor(); + + // Act - cancel immediately after first failure so the retry delay is cancelled + Func act = () => InvokeExecuteWithRetryAsync( + processor, + () => + { + cts.Cancel(); + throw new RpcException(new Status(StatusCode.Unavailable, "transient error")); + }, + "TestOperation", + cts.Token); + + // Assert + await act.Should().ThrowAsync(); + } + + [Fact] + public async Task ExecuteWithRetryAsync_TransientError_LogsRetryAttempt() + { + // Arrange + TestLogProvider logProvider = new(new NullOutput()); + object processor = CreateProcessor(logProvider); + int callCount = 0; + const string operationName = "CompleteOrchestratorTaskAsync"; + + // 
Act - fail once then succeed + await InvokeExecuteWithRetryAsync( + processor, + () => + { + callCount++; + if (callCount == 1) + { + throw new RpcException(new Status(StatusCode.Unavailable, "transient error")); + } + + return Task.CompletedTask; + }, + operationName, + CancellationToken.None); + + // Assert + logProvider.TryGetLogs(Category, out IReadOnlyCollection? logs).Should().BeTrue(); + logs!.Should().Contain(log => + log.Message.Contains($"Transient gRPC error for '{operationName}'") && + log.Message.Contains("Attempt 1 of 10") && + log.Message.Contains($"StatusCode={(int)StatusCode.Unavailable}")); + } + + [Fact] + public async Task ExecuteWithRetryAsync_MultipleTransientErrors_LogsEachRetryAttempt() + { + // Arrange + TestLogProvider logProvider = new(new NullOutput()); + object processor = CreateProcessor(logProvider); + int callCount = 0; + const string operationName = "CompleteActivityTaskAsync"; + + // Act - fail twice then succeed + await InvokeExecuteWithRetryAsync( + processor, + () => + { + callCount++; + if (callCount < 3) + { + throw new RpcException(new Status(StatusCode.Unavailable, "transient error")); + } + + return Task.CompletedTask; + }, + operationName, + CancellationToken.None); + + // Assert + logProvider.TryGetLogs(Category, out IReadOnlyCollection? logs).Should().BeTrue(); + logs!.Should().Contain(log => + log.Message.Contains($"Transient gRPC error for '{operationName}'") && + log.Message.Contains("Attempt 1 of 10") && + log.Message.Contains($"StatusCode={(int)StatusCode.Unavailable}")); + logs.Should().Contain(log => + log.Message.Contains($"Transient gRPC error for '{operationName}'") && + log.Message.Contains("Attempt 2 of 10") && + log.Message.Contains($"StatusCode={(int)StatusCode.Unavailable}")); + callCount.Should().Be(3); + } + + [Fact] + public async Task ExecuteWithRetryAsync_TransientErrorExceedsMaxAttempts_ThrowsLastRpcException() + { + // Arrange - use a small backoff base to avoid long delays in the test + const int maxAttempts = 10; + object processor = CreateProcessor(transientRetryMaxAttempts: maxAttempts, transientRetryBackoffBase: TimeSpan.FromMilliseconds(1)); + int callCount = 0; + StatusCode lastStatusCode = StatusCode.Unavailable; + + // Act - always fail with a transient error + Func act = () => InvokeExecuteWithRetryAsync( + processor, + () => + { + callCount++; + throw new RpcException(new Status(lastStatusCode, "persistent transient error")); + }, + "TestOperation", + CancellationToken.None); + + // Assert - the last RpcException should be surfaced after max attempts. + await act.Should().ThrowAsync().Where(e => e.StatusCode == lastStatusCode); + callCount.Should().Be(maxAttempts); + } + + static object CreateProcessor( + TestLogProvider? logProvider = null, + int? transientRetryMaxAttempts = null, + TimeSpan? transientRetryBackoffBase = null) + { + ILoggerFactory loggerFactory = logProvider is null + ? 
NullLoggerFactory.Instance + : new SimpleLoggerFactory(logProvider); + + Mock factoryMock = new(MockBehavior.Strict); + GrpcDurableTaskWorkerOptions grpcOptions = new(); + if (transientRetryMaxAttempts.HasValue) + { + grpcOptions.Internal.TransientRetryMaxAttempts = transientRetryMaxAttempts.Value; + } + + if (transientRetryBackoffBase.HasValue) + { + grpcOptions.Internal.TransientRetryBackoffBase = transientRetryBackoffBase.Value; + } + + DurableTaskWorkerOptions workerOptions = new() + { + Logging = { UseLegacyCategories = false }, + }; + + GrpcDurableTaskWorker worker = new( + name: "Test", + factory: factoryMock.Object, + grpcOptions: new OptionsMonitorStub(grpcOptions), + workerOptions: new OptionsMonitorStub(workerOptions), + services: Mock.Of(), + loggerFactory: loggerFactory, + orchestrationFilter: null, + exceptionPropertiesProvider: null); + + CallInvoker callInvoker = Mock.Of(); + P.TaskHubSidecarService.TaskHubSidecarServiceClient client = new(callInvoker); + + Type processorType = typeof(GrpcDurableTaskWorker).GetNestedType("Processor", BindingFlags.NonPublic)!; + return Activator.CreateInstance( + processorType, + BindingFlags.Public | BindingFlags.Instance, + binder: null, + args: new object?[] { worker, client, null, null }, + culture: null)!; + } + + static Task InvokeExecuteWithRetryAsync( + object processor, + Func action, + string operationName, + CancellationToken cancellationToken) + { + return (Task)ExecuteWithRetryAsyncMethod.Invoke( + processor, + new object?[] { action, operationName, cancellationToken })!; + } + + sealed class OptionsMonitorStub : IOptionsMonitor where T : class, new() + { + readonly T value; + + public OptionsMonitorStub(T value) => this.value = value; + + public T CurrentValue => this.value; + + public T Get(string? name) => this.value; + + public IDisposable OnChange(Action listener) => NullDisposable.Instance; + + sealed class NullDisposable : IDisposable + { + public static readonly NullDisposable Instance = new(); + public void Dispose() { } + } + } + + sealed class SimpleLoggerFactory : ILoggerFactory + { + readonly ILoggerProvider provider; + + public SimpleLoggerFactory(ILoggerProvider provider) => this.provider = provider; + + public void AddProvider(ILoggerProvider loggerProvider) { } + + public ILogger CreateLogger(string categoryName) => this.provider.CreateLogger(categoryName); + + public void Dispose() { } + } + + sealed class NullOutput : ITestOutputHelper + { + public void WriteLine(string message) { } + public void WriteLine(string format, params object[] args) { } + } +} diff --git a/test/Worker/Grpc.Tests/GrpcBackoffTests.cs b/test/Worker/Grpc.Tests/GrpcBackoffTests.cs new file mode 100644 index 00000000..93b1cf11 --- /dev/null +++ b/test/Worker/Grpc.Tests/GrpcBackoffTests.cs @@ -0,0 +1,214 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +namespace Microsoft.DurableTask.Worker.Grpc.Tests; + +public class GrpcBackoffTests +{ + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Compute_ZeroBase_ReturnsZero(bool fullJitter) + { + // Arrange + Random random = new(42); + + // Act + TimeSpan delay = GrpcBackoff.Compute(attempt: 5, baseDelay: TimeSpan.Zero, cap: TimeSpan.FromSeconds(30), random, fullJitter); + + // Assert + delay.Should().Be(TimeSpan.Zero); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Compute_NegativeBase_ReturnsZero(bool fullJitter) + { + // Arrange + Random random = new(42); + + // Act + TimeSpan delay = GrpcBackoff.Compute(attempt: 0, baseDelay: TimeSpan.FromMilliseconds(-100), cap: TimeSpan.FromSeconds(30), random, fullJitter); + + // Assert + delay.Should().Be(TimeSpan.Zero); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Compute_NonPositiveCap_ReturnsZero(bool fullJitter) + { + // Arrange + DeterministicRandom random = new(0.999999); + + // Act + TimeSpan zero = GrpcBackoff.Compute(attempt: 3, baseDelay: TimeSpan.FromSeconds(1), cap: TimeSpan.Zero, random, fullJitter); + TimeSpan negative = GrpcBackoff.Compute(attempt: 3, baseDelay: TimeSpan.FromSeconds(1), cap: TimeSpan.FromSeconds(-1), random, fullJitter); + + // Assert + zero.Should().Be(TimeSpan.Zero); + negative.Should().Be(TimeSpan.Zero); + } + + [Fact] + public void Compute_FullJitter_NeverExceedsCap() + { + // Arrange + TimeSpan cap = TimeSpan.FromSeconds(30); + TimeSpan baseDelay = TimeSpan.FromSeconds(1); + Random random = new(1); + + // Act + Assert: try a wide range of attempts, including pathological values. + // Note: this invariant is full-jitter-specific — biased mode intentionally returns up to + // 1.2x the upper bound and so can legally exceed the cap. + for (int attempt = 0; attempt < 50; attempt++) + { + TimeSpan delay = GrpcBackoff.Compute(attempt, baseDelay, cap, random, fullJitter: true); + delay.Should().BeLessThanOrEqualTo(cap, $"attempt {attempt} produced {delay}"); + delay.Should().BeGreaterThanOrEqualTo(TimeSpan.Zero); + } + } + + [Fact] + public void Compute_FullJitter_GrowsExponentiallyUntilCap() + { + // Arrange: a Random that always returns ~1.0 forces the upper bound of the jitter window. + DeterministicRandom random = new(value: 0.999999); + TimeSpan baseDelay = TimeSpan.FromSeconds(1); + TimeSpan cap = TimeSpan.FromSeconds(30); + + // Act + double d0 = GrpcBackoff.Compute(0, baseDelay, cap, random, fullJitter: true).TotalMilliseconds; + double d1 = GrpcBackoff.Compute(1, baseDelay, cap, random, fullJitter: true).TotalMilliseconds; + double d2 = GrpcBackoff.Compute(2, baseDelay, cap, random, fullJitter: true).TotalMilliseconds; + double d3 = GrpcBackoff.Compute(3, baseDelay, cap, random, fullJitter: true).TotalMilliseconds; + double d10 = GrpcBackoff.Compute(10, baseDelay, cap, random, fullJitter: true).TotalMilliseconds; + + // Assert: roughly doubles each step until cap is reached. + d0.Should().BeApproximately(1000, 1); + d1.Should().BeApproximately(2000, 1); + d2.Should().BeApproximately(4000, 1); + d3.Should().BeApproximately(8000, 1); + d10.Should().BeApproximately(30000, 1, "should be clamped at the cap"); + } + + [Fact] + public void Compute_FullJitter_StaysWithinBounds() + { + // Arrange: with random=0 the result is 0; with random=1 the result is the bound. 
+ TimeSpan baseDelay = TimeSpan.FromSeconds(1); + TimeSpan cap = TimeSpan.FromSeconds(30); + + // Act + Assert: random=0 → 0 + TimeSpan low = GrpcBackoff.Compute(3, baseDelay, cap, new DeterministicRandom(0.0), fullJitter: true); + low.TotalMilliseconds.Should().BeApproximately(0, 0.5); + + // random ~1 → bound (= 8s for attempt=3, base=1s) + TimeSpan high = GrpcBackoff.Compute(3, baseDelay, cap, new DeterministicRandom(0.999999), fullJitter: true); + high.TotalMilliseconds.Should().BeApproximately(8000, 1); + } + + [Fact] + public void Compute_BiasedJitter_StaysWithinBounds() + { + // Arrange: biased jitter returns a value in [upperBound, upperBound * 1.2]. + // attempt=3, base=1s → upperBound=8s. + TimeSpan baseDelay = TimeSpan.FromSeconds(1); + TimeSpan cap = TimeSpan.FromSeconds(30); + + // Act + Assert: random=0 → upperBound (lower edge of biased window). + TimeSpan low = GrpcBackoff.Compute(3, baseDelay, cap, new DeterministicRandom(0.0), fullJitter: false); + low.TotalMilliseconds.Should().BeApproximately(8000, 1); + + // random ~1 → upperBound * 1.2 (upper edge). + TimeSpan high = GrpcBackoff.Compute(3, baseDelay, cap, new DeterministicRandom(0.999999), fullJitter: false); + high.TotalMilliseconds.Should().BeApproximately(9600, 1); + + // mid value → halfway. + TimeSpan mid = GrpcBackoff.Compute(3, baseDelay, cap, new DeterministicRandom(0.5), fullJitter: false); + mid.TotalMilliseconds.Should().BeApproximately(8800, 1); + } + + [Fact] + public void Compute_NegativeAttempt_TreatedAsZero() + { + // Arrange + DeterministicRandom random = new(0.999999); + + // Act + TimeSpan delay = GrpcBackoff.Compute(attempt: -5, baseDelay: TimeSpan.FromSeconds(1), cap: TimeSpan.FromSeconds(30), random, fullJitter: true); + + // Assert + delay.TotalMilliseconds.Should().BeApproximately(1000, 1); + } + + [Fact] + public void Compute_FullJitter_CapSmallerThanBase_ClampsToCap() + { + // Arrange: cap is intentionally smaller than baseDelay; the cap must still be honored. + // Note: biased mode would return up to 1.2 * cap here by design, so this invariant is + // full-jitter-only. + DeterministicRandom random = new(0.999999); + TimeSpan baseDelay = TimeSpan.FromSeconds(5); + TimeSpan cap = TimeSpan.FromSeconds(1); + + // Act + TimeSpan delay = GrpcBackoff.Compute(attempt: 3, baseDelay, cap, random, fullJitter: true); + + // Assert: with random ~ 1 the result is the bound, which must equal the cap. + delay.TotalMilliseconds.Should().BeApproximately(1000, 1); + delay.Should().BeLessThanOrEqualTo(cap); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Compute_AttemptIsCappedAt30(bool fullJitter) + { + // Arrange: pick a base/cap where the cap is large enough that 2^30 * base does not saturate it, + // so the exponent — not the cap — drives the upper bound. This lets us observe the internal + // attempt clamp at 30: any attempt ≥ 30 must yield the same upper bound as attempt = 30. + TimeSpan baseDelay = TimeSpan.FromMilliseconds(1); + TimeSpan cap = TimeSpan.FromDays(365); // 2^30 ms ≈ 12.4 days < cap. 
+ + // Act: produce a fresh DeterministicRandom for each call so the same NextDouble() value is + // sampled + TimeSpan at30 = GrpcBackoff.Compute(30, baseDelay, cap, new DeterministicRandom(1.0), fullJitter); + TimeSpan at31 = GrpcBackoff.Compute(31, baseDelay, cap, new DeterministicRandom(1.0), fullJitter); + TimeSpan at100 = GrpcBackoff.Compute(100, baseDelay, cap, new DeterministicRandom(1.0), fullJitter); + TimeSpan atIntMax = GrpcBackoff.Compute(int.MaxValue, baseDelay, cap, new DeterministicRandom(1.0), fullJitter); + + // Assert: all produce the same delay, equal to the attempt=30 value (sanity-checked against + // the analytical upper bound of 2^30 ms — exact for full jitter at random=1, and 2^30 * 1.2 + // for biased mode at random=1). + double expectedUpperBoundMs = Math.Pow(2, 30); // 2^30 ms + if (fullJitter) + { + // random = 1 → result == upper bound + at30.TotalMilliseconds.Should().BeApproximately(expectedUpperBoundMs, 1); + } + else + { + // random = 1 → result == upper bound * 1.2 + at30.TotalMilliseconds.Should().BeApproximately(expectedUpperBoundMs * 1.2, 1); + } + + at31.Should().Be(at30); + at100.Should().Be(at30); + atIntMax.Should().Be(at30); + } + + sealed class DeterministicRandom : Random + { + readonly double value; + + public DeterministicRandom(double value) + { + this.value = value; + } + + public override double NextDouble() => this.value; + } +} diff --git a/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerOptionsInternalTests.cs b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerOptionsInternalTests.cs index 5813bc68..87e70484 100644 --- a/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerOptionsInternalTests.cs +++ b/test/Worker/Grpc.Tests/GrpcDurableTaskWorkerOptionsInternalTests.cs @@ -20,7 +20,10 @@ public void InternalOptions_HasSafeDefaults() internalOptions.HelloDeadline.Should().Be(TimeSpan.FromSeconds(30)); internalOptions.ChannelRecreateFailureThreshold.Should().Be(5); internalOptions.ReconnectBackoffBase.Should().Be(TimeSpan.FromSeconds(1)); - internalOptions.ReconnectBackoffCap.Should().Be(TimeSpan.FromSeconds(30)); + internalOptions.ReconnectBackoffCap.Should().Be(TimeSpan.FromSeconds(30)); + internalOptions.TransientRetryBackoffBase.Should().Be(TimeSpan.FromMilliseconds(200)); + internalOptions.TransientRetryBackoffCap.Should().Be(TimeSpan.FromSeconds(15)); + internalOptions.TransientRetryMaxAttempts.Should().Be(10); internalOptions.SilentDisconnectTimeout.Should().Be(TimeSpan.FromSeconds(120)); internalOptions.ChannelRecreator.Should().BeNull(); } diff --git a/test/Worker/Grpc.Tests/ReconnectBackoffTests.cs b/test/Worker/Grpc.Tests/ReconnectBackoffTests.cs deleted file mode 100644 index 024f179e..00000000 --- a/test/Worker/Grpc.Tests/ReconnectBackoffTests.cs +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -namespace Microsoft.DurableTask.Worker.Grpc.Tests; - -public class ReconnectBackoffTests -{ - [Fact] - public void Compute_ZeroBase_ReturnsZero() - { - // Arrange - Random random = new(42); - - // Act - TimeSpan delay = ReconnectBackoff.Compute(attempt: 5, baseDelay: TimeSpan.Zero, cap: TimeSpan.FromSeconds(30), random); - - // Assert - delay.Should().Be(TimeSpan.Zero); - } - - [Fact] - public void Compute_NegativeBase_ReturnsZero() - { - // Arrange - Random random = new(42); - - // Act - TimeSpan delay = ReconnectBackoff.Compute(attempt: 0, baseDelay: TimeSpan.FromMilliseconds(-100), cap: TimeSpan.FromSeconds(30), random); - - // Assert - delay.Should().Be(TimeSpan.Zero); - } - - [Fact] - public void Compute_NeverExceedsCap() - { - // Arrange - TimeSpan cap = TimeSpan.FromSeconds(30); - TimeSpan baseDelay = TimeSpan.FromSeconds(1); - Random random = new(1); - - // Act + Assert: try a wide range of attempts, including pathological values. - for (int attempt = 0; attempt < 50; attempt++) - { - TimeSpan delay = ReconnectBackoff.Compute(attempt, baseDelay, cap, random); - delay.Should().BeLessThanOrEqualTo(cap, $"attempt {attempt} produced {delay}"); - delay.Should().BeGreaterThanOrEqualTo(TimeSpan.Zero); - } - } - - [Fact] - public void Compute_GrowsExponentiallyUntilCap() - { - // Arrange: a Random that always returns 1.0 forces the upper bound of the jitter window. - DeterministicRandom random = new(value: 0.999999); - TimeSpan baseDelay = TimeSpan.FromSeconds(1); - TimeSpan cap = TimeSpan.FromSeconds(30); - - // Act - double d0 = ReconnectBackoff.Compute(0, baseDelay, cap, random).TotalMilliseconds; - double d1 = ReconnectBackoff.Compute(1, baseDelay, cap, random).TotalMilliseconds; - double d2 = ReconnectBackoff.Compute(2, baseDelay, cap, random).TotalMilliseconds; - double d3 = ReconnectBackoff.Compute(3, baseDelay, cap, random).TotalMilliseconds; - double d10 = ReconnectBackoff.Compute(10, baseDelay, cap, random).TotalMilliseconds; - - // Assert: roughly doubles each step until cap is reached. - d0.Should().BeApproximately(1000, 1); - d1.Should().BeApproximately(2000, 1); - d2.Should().BeApproximately(4000, 1); - d3.Should().BeApproximately(8000, 1); - d10.Should().BeApproximately(30000, 1, "should be clamped at the cap"); - } - - [Fact] - public void Compute_WithFullJitter_StaysWithinBounds() - { - // Arrange: with random=0 the result is 0; with random=1 the result is the bound. - TimeSpan baseDelay = TimeSpan.FromSeconds(1); - TimeSpan cap = TimeSpan.FromSeconds(30); - - // Act + Assert: random=0 → 0 - TimeSpan low = ReconnectBackoff.Compute(3, baseDelay, cap, new DeterministicRandom(0.0)); - low.TotalMilliseconds.Should().BeApproximately(0, 0.5); - - // random ~1 → bound (= 8s for attempt=3, base=1s) - TimeSpan high = ReconnectBackoff.Compute(3, baseDelay, cap, new DeterministicRandom(0.999999)); - high.TotalMilliseconds.Should().BeApproximately(8000, 1); - } - - [Fact] - public void Compute_NegativeAttempt_TreatedAsZero() - { - // Arrange - DeterministicRandom random = new(0.999999); - - // Act - TimeSpan delay = ReconnectBackoff.Compute(attempt: -5, baseDelay: TimeSpan.FromSeconds(1), cap: TimeSpan.FromSeconds(30), random); - - // Assert - delay.TotalMilliseconds.Should().BeApproximately(1000, 1); - } - - [Fact] - public void Compute_CapSmallerThanBase_ClampsToCap() - { - // Arrange: cap is intentionally smaller than baseDelay; the cap must still be honored. 
- DeterministicRandom random = new(0.999999); - TimeSpan baseDelay = TimeSpan.FromSeconds(5); - TimeSpan cap = TimeSpan.FromSeconds(1); - - // Act - TimeSpan delay = ReconnectBackoff.Compute(attempt: 3, baseDelay, cap, random); - - // Assert: with random ~ 1 the result is the bound, which must equal the cap. - delay.TotalMilliseconds.Should().BeApproximately(1000, 1); - delay.Should().BeLessThanOrEqualTo(cap); - } - - [Fact] - public void Compute_NonPositiveCap_ReturnsZero() - { - // Arrange - DeterministicRandom random = new(0.999999); - - // Act - TimeSpan zero = ReconnectBackoff.Compute(attempt: 3, baseDelay: TimeSpan.FromSeconds(1), cap: TimeSpan.Zero, random); - TimeSpan negative = ReconnectBackoff.Compute(attempt: 3, baseDelay: TimeSpan.FromSeconds(1), cap: TimeSpan.FromSeconds(-1), random); - - // Assert - zero.Should().Be(TimeSpan.Zero); - negative.Should().Be(TimeSpan.Zero); - } - - sealed class DeterministicRandom : Random - { - readonly double value; - - public DeterministicRandom(double value) - { - this.value = value; - } - - public override double NextDouble() => this.value; - } -} diff --git a/test/Worker/Grpc.Tests/RunBackgroundTaskLoggingTests.cs b/test/Worker/Grpc.Tests/RunBackgroundTaskLoggingTests.cs index 86f9afec..5786d749 100644 --- a/test/Worker/Grpc.Tests/RunBackgroundTaskLoggingTests.cs +++ b/test/Worker/Grpc.Tests/RunBackgroundTaskLoggingTests.cs @@ -285,6 +285,214 @@ public async Task Logs_Abandoning_And_NoAbandoned_When_EntityV2_Abandon_Fails() await AssertEventually(() => fixture.GetLogs().Any(l => l.Message.Contains("Unexpected error") && l.Message.Contains(instanceId))); } + [Fact] + public async Task Retries_Abandon_Orchestrator_On_Transient_Error_Eventually_Succeeds() + { + await using var fixture = await TestFixture.CreateAsync(transientRetryBackoffBase: TimeSpan.FromMilliseconds(1)); + + string instanceId = Guid.NewGuid().ToString("N"); + string completionToken = Guid.NewGuid().ToString("N"); + + int abandonCallCount = 0; + var tcs = new TaskCompletionSource(); + fixture.ClientMock + .Setup(c => c.AbandonTaskOrchestratorWorkItemAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns((P.AbandonOrchestrationTaskRequest r, Metadata h, DateTime? 
d, CancellationToken ct) => + { + abandonCallCount++; + if (abandonCallCount == 1) + { + // First call: simulate a transient gRPC error + return RpcExceptionAsyncUnaryCall(StatusCode.Unavailable); + } + + // Second call: succeed + return CompletedAsyncUnaryCall(new P.AbandonOrchestrationTaskResponse(), () => tcs.TrySetResult(true)); + }); + + P.WorkItem workItem = new() + { + OrchestratorRequest = new P.OrchestratorRequest { InstanceId = instanceId }, + CompletionToken = completionToken, + }; + + fixture.InvokeRunBackgroundTask(workItem, () => Task.FromException(new Exception("boom"))); + + await WaitAsync(tcs.Task); + + // Verify the call was retried (called twice total) + abandonCallCount.Should().Be(2); + + // Verify the Abandoned log is present (retry succeeded) + await AssertEventually(() => fixture.GetLogs().Any(l => l.Message.Contains("Abandoned orchestrator work item") && l.Message.Contains(instanceId))); + + // Verify a retry warning was logged + await AssertEventually(() => fixture.GetLogs().Any(l => + l.EventId.Name == "TransientGrpcRetry" && + l.Message.Contains("AbandonTaskOrchestratorWorkItemAsync"))); + } + + [Fact] + public async Task Retries_Abandon_Activity_On_Transient_Error_Eventually_Succeeds() + { + await using var fixture = await TestFixture.CreateAsync(transientRetryBackoffBase: TimeSpan.FromMilliseconds(1)); + + string instanceId = Guid.NewGuid().ToString("N"); + string completionToken = Guid.NewGuid().ToString("N"); + + int abandonCallCount = 0; + var tcs = new TaskCompletionSource(); + fixture.ClientMock + .Setup(c => c.AbandonTaskActivityWorkItemAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns((P.AbandonActivityTaskRequest r, Metadata h, DateTime? d, CancellationToken ct) => + { + abandonCallCount++; + if (abandonCallCount == 1) + { + // First call: simulate a transient gRPC error + return RpcExceptionAsyncUnaryCall(StatusCode.Unavailable); + } + + // Second call: succeed + return CompletedAsyncUnaryCall(new P.AbandonActivityTaskResponse(), () => tcs.TrySetResult(true)); + }); + + P.WorkItem workItem = new() + { + ActivityRequest = new P.ActivityRequest + { + Name = "MyActivity", + TaskId = 42, + OrchestrationInstance = new P.OrchestrationInstance { InstanceId = instanceId }, + }, + CompletionToken = completionToken, + }; + + fixture.InvokeRunBackgroundTask(workItem, () => Task.FromException(new Exception("boom"))); + + await WaitAsync(tcs.Task); + + abandonCallCount.Should().Be(2); + await AssertEventually(() => fixture.GetLogs().Any(l => l.Message.Contains("Abandoned activity work item") && l.Message.Contains(instanceId))); + await AssertEventually(() => fixture.GetLogs().Any(l => + l.EventId.Name == "TransientGrpcRetry" && + l.Message.Contains("AbandonTaskActivityWorkItemAsync"))); + } + + [Fact] + public async Task Retries_Abandon_Orchestrator_Until_MaxAttempts_Then_Fails() + { + const int maxAttempts = 3; + await using var fixture = await TestFixture.CreateAsync( + transientRetryMaxAttempts: maxAttempts, + transientRetryBackoffBase: TimeSpan.FromMilliseconds(1)); + + string instanceId = Guid.NewGuid().ToString("N"); + string completionToken = Guid.NewGuid().ToString("N"); + + int abandonCallCount = 0; + fixture.ClientMock + .Setup(c => c.AbandonTaskOrchestratorWorkItemAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns((P.AbandonOrchestrationTaskRequest r, Metadata h, DateTime? 
d, CancellationToken ct) => + { + abandonCallCount++; + return RpcExceptionAsyncUnaryCall(StatusCode.Unavailable); + }); + + P.WorkItem workItem = new() + { + OrchestratorRequest = new P.OrchestratorRequest { InstanceId = instanceId }, + CompletionToken = completionToken, + }; + + fixture.InvokeRunBackgroundTask(workItem, () => Task.FromException(new Exception("boom"))); + + // Wait for all retries to be exhausted. The ExecuteAsync loop in the worker logs an "Unexpected error" (for the + // abandon exception) after the retry loop gives up, which signals the task has completed. + await AssertEventually( + () => fixture.GetLogs().Count(l => l.EventId.Name == "UnexpectedError") >= 1, + timeoutMs: 10000); + + // The Abandoned log should NOT be present since the abandon never succeeded (but the abandoning should be present) + await AssertEventually(() => fixture.GetLogs().Any(l => l.Message.Contains("Abandoning orchestrator work item") && l.Message.Contains(instanceId))); + await AssertEventually(() => fixture.GetLogs().Any(l => l.Message.Contains("Unexpected error") && l.Message.Contains(instanceId))); + Assert.DoesNotContain(fixture.GetLogs(), l => l.Message.Contains("Abandoned orchestrator work item") && l.Message.Contains(instanceId)); + + // Verify retry warnings were logged: one per retry attempt + IEnumerable retryLogs = fixture.GetLogs().Where(l => + l.EventId.Name == "TransientGrpcRetry" && + l.Message.Contains("AbandonTaskOrchestratorWorkItemAsync")); + retryLogs.Should().HaveCount(maxAttempts - 1); + + // The abandon RPC was called maxAttempts-1 times (retried) plus one final call that propagated + abandonCallCount.Should().Be(maxAttempts); + } + + [Theory] + [InlineData(StatusCode.InvalidArgument)] + [InlineData(StatusCode.PermissionDenied)] + [InlineData(StatusCode.NotFound)] + public async Task Non_Transient_Abandon_Orchestrator_Error_Is_Not_Retried(StatusCode statusCode) + { + await using var fixture = await TestFixture.CreateAsync(); + + string instanceId = Guid.NewGuid().ToString("N"); + string completionToken = Guid.NewGuid().ToString("N"); + + // Signal fires after the (single) abandon call, giving us a reliable completion signal + var callDoneTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + int abandonCallCount = 0; + fixture.ClientMock + .Setup(c => c.AbandonTaskOrchestratorWorkItemAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns((P.AbandonOrchestrationTaskRequest r, Metadata h, DateTime? 
d, CancellationToken ct) => + { + Interlocked.Increment(ref abandonCallCount); + callDoneTcs.TrySetResult(true); + return RpcExceptionAsyncUnaryCall(statusCode); + }); + + P.WorkItem workItem = new() + { + OrchestratorRequest = new P.OrchestratorRequest { InstanceId = instanceId }, + CompletionToken = completionToken, + }; + + fixture.InvokeRunBackgroundTask(workItem, () => Task.FromException(new Exception("boom"))); + + // Wait for the single abandon call to complete + await WaitAsync(callDoneTcs.Task); + + // Give a brief moment for the final log lines to flush + await Task.Delay(100); + + // The non-transient error must not have been retried – exactly one abandon call + abandonCallCount.Should().Be(1); + + // No retry warning should have been logged + Assert.DoesNotContain(fixture.GetLogs(), l => l.EventId.Name == "TransientGrpcRetry"); + + // The Abandoned log must not be present since the RPC failed without retry (but the abandoning should be present) + await AssertEventually(() => fixture.GetLogs().Any(l => l.Message.Contains("Abandoning orchestrator work item") && l.Message.Contains(instanceId))); + await AssertEventually(() => fixture.GetLogs().Any(l => l.Message.Contains("Unexpected error") && l.Message.Contains(instanceId))); + Assert.DoesNotContain(fixture.GetLogs(), l => l.Message.Contains("Abandoned orchestrator work item") && l.Message.Contains(instanceId)); + } + [Fact] public async Task Forwards_CancellationToken_To_Abandon_Orchestrator() { @@ -366,7 +574,9 @@ sealed class TestFixture : IAsyncDisposable this.RunBackgroundTaskMethod = runBackgroundTaskMethod; } - public static async Task CreateAsync() + public static async Task CreateAsync( + int? transientRetryMaxAttempts = null, + TimeSpan? transientRetryBackoffBase = null) { // Logging var logProvider = new TestLogProvider(new NullOutput()); @@ -375,7 +585,18 @@ public static async Task CreateAsync() var loggerFactory = new SimpleLoggerFactory(logProvider); // Options - var grpcOptions = new OptionsMonitorStub(new GrpcDurableTaskWorkerOptions()); + GrpcDurableTaskWorkerOptions grpcOptionsValue = new(); + if (transientRetryMaxAttempts.HasValue) + { + grpcOptionsValue.Internal.TransientRetryMaxAttempts = transientRetryMaxAttempts.Value; + } + + if (transientRetryBackoffBase.HasValue) + { + grpcOptionsValue.Internal.TransientRetryBackoffBase = transientRetryBackoffBase.Value; + } + + var grpcOptions = new OptionsMonitorStub(grpcOptionsValue); var workerOptions = new OptionsMonitorStub(new DurableTaskWorkerOptions()); // Factory (not used in these tests) @@ -450,6 +671,18 @@ static AsyncUnaryCall FaultedAsyncUnaryCall(Exception ex) () => { }); } + static AsyncUnaryCall RpcExceptionAsyncUnaryCall(StatusCode statusCode, string detail = "transient error") + { + RpcException ex = new(new Status(statusCode, detail)); + var respTask = Task.FromException(ex); + return new AsyncUnaryCall( + respTask, + Task.FromResult(new Metadata()), + () => new Status(statusCode, detail), + () => new Metadata(), + () => { }); + } + sealed class NullOutput : ITestOutputHelper { public void WriteLine(string message) { }
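
For context, here is a minimal standalone sketch of the retry delay schedule the new defaults produce. It mirrors the arithmetic in GrpcBackoff.Compute (exponential growth from TransientRetryBackoffBase, clamped at TransientRetryBackoffCap, then biased jitter of up to 20% on top), but the program itself — including the name BackoffScheduleSketch — is illustrative only and not part of this change.

```csharp
using System;

// Sketch: prints the per-retry delay window for the new transient-retry defaults.
// Reproduces the GrpcBackoff.Compute math; values are assumptions taken from the
// defaults in this diff (200 ms base, 15 s cap, 10 attempts).
class BackoffScheduleSketch
{
    static void Main()
    {
        TimeSpan baseDelay = TimeSpan.FromMilliseconds(200); // TransientRetryBackoffBase default
        TimeSpan cap = TimeSpan.FromSeconds(15);             // TransientRetryBackoffCap default
        int maxAttempts = 10;                                // TransientRetryMaxAttempts default

        // The retry loop delays after attempts 1..maxAttempts-1; the final attempt rethrows.
        for (int attempt = 0; attempt < maxAttempts - 1; attempt++)
        {
            int safeAttempt = Math.Min(attempt, 30);
            double upperMs = Math.Min(
                cap.TotalMilliseconds,
                baseDelay.TotalMilliseconds * Math.Pow(2, safeAttempt));

            // Biased jitter (fullJitter: false): delay lands in [upperMs, 1.2 * upperMs].
            // Full jitter (fullJitter: true), used for stream reconnects: delay lands in [0, upperMs].
            Console.WriteLine($"retry #{attempt + 1}: {upperMs:F0}..{upperMs * 1.2:F0} ms");
        }
    }
}
```

With these defaults the nine retry delays sum to at most roughly 66 seconds before the final RpcException surfaces; the new tests keep runs fast by shrinking TransientRetryBackoffBase to 1 ms through GrpcDurableTaskWorkerOptions.Internal, as shown in the fixtures above.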