diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index a7693591..4aa32d60 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -19,6 +19,7 @@ import software.amazon.awssdk.services.lambda.model.OperationAction; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; +import software.amazon.awssdk.services.lambda.model.OperationUpdate; import software.amazon.lambda.durable.DurableConfig; import software.amazon.lambda.durable.TestUtils; import software.amazon.lambda.durable.TypeToken; @@ -44,14 +45,21 @@ class ParallelOperationTest { private DurableContextImpl durableContext; private ExecutionManager executionManager; + // Thread-safe backing store for getOperationAndUpdateReplayState. + // Tests pre-populate this; doAnswer writes here before firing onCheckpointComplete, + // guaranteeing visibility to any thread that reads after the future unblocks. + private ConcurrentHashMap operationStore; @BeforeEach void setUp() { durableContext = mock(DurableContextImpl.class); executionManager = mock(ExecutionManager.class); + operationStore = new ConcurrentHashMap<>(); when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(null, ThreadType.CONTEXT)); - when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null); + // Delegate to operationStore so all reads see the latest write, regardless of thread. + when(executionManager.getOperationAndUpdateReplayState(anyString())) + .thenAnswer(inv -> operationStore.get(inv.getArgument(0))); var childContext = mock(DurableContextImpl.class); when(childContext.getExecutionManager()).thenReturn(executionManager); @@ -77,9 +85,10 @@ void setUp() { .when(executionManager) .registerOperation(any()); - // Simulate the real backend for all sendOperationUpdate calls: - // - For SUCCEED on the parallel op: update the stub and fire onCheckpointComplete to unblock join(). - // - For everything else (START, child checkpoints): just return a completed future. + // Simulate the real backend for all sendOperationUpdate calls. + // For SUCCEED on the parallel op: write to operationStore first (establishes happens-before + // via ConcurrentHashMap's volatile semantics), then fire onCheckpointComplete to unblock join(). + // This ordering guarantees getOperationAndUpdateReplayState() never returns null after unblocking. var succeededParallelOp = Operation.builder() .id(OPERATION_ID) .name("test-parallel") @@ -88,11 +97,11 @@ void setUp() { .status(OperationStatus.SUCCEEDED) .build(); doAnswer(inv -> { - var update = (software.amazon.awssdk.services.lambda.model.OperationUpdate) inv.getArgument(0); + var update = (OperationUpdate) inv.getArgument(0); if (OPERATION_ID.equals(update.id()) && update.action() == OperationAction.SUCCEED) { - when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)) - .thenReturn(succeededParallelOp); + // Write before completing the future — ConcurrentHashMap guarantees visibility. + operationStore.put(OPERATION_ID, succeededParallelOp); var op = registeredOps.get(OPERATION_ID); if (op != null) { op.onCheckpointComplete(succeededParallelOp);