Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ public Output handleRequest(Input input, DurableContext context) {

var deliveries = futures.stream().map(DurableFuture::get).toList();
logger.info("All {} notifications delivered", deliveries.size());
// Test replay
context.wait("wait for finalization", Duration.ofSeconds(5));
return new Output(deliveries);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import software.amazon.lambda.durable.TypeToken;
import software.amazon.lambda.durable.context.DurableContextImpl;
import software.amazon.lambda.durable.exception.ConcurrencyExecutionException;
import software.amazon.lambda.durable.execution.ExecutionManager;
import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus;
import software.amazon.lambda.durable.model.OperationIdentifier;
import software.amazon.lambda.durable.model.OperationSubType;
Expand Down Expand Up @@ -42,6 +43,8 @@
*/
public class ParallelOperation<T> extends ConcurrencyOperation<T> {

private boolean skipCheckpoint = false;

public ParallelOperation(
OperationIdentifier operationIdentifier,
TypeToken<T> resultTypeToken,
Expand Down Expand Up @@ -79,6 +82,10 @@ protected <R> ChildContextOperation<R> createItem(

@Override
protected void handleSuccess() {
if (skipCheckpoint) {
// Do not send checkpoint during replay
return;
}
sendOperationUpdate(OperationUpdate.builder()
.action(OperationAction.SUCCEED)
.subType(getSubType().getValue())
Expand All @@ -99,8 +106,9 @@ protected void start() {

/**
 * Replay hook invoked with the {@code existing} operation record for this parallel operation.
 *
 * <p>As written here it calls {@code start()} and then stores
 * {@code ExecutionManager.isTerminalStatus(existing.status())} in {@code skipCheckpoint}.
 *
 * @param existing the operation record whose status is inspected (presumably the previously
 *     checkpointed state of this operation; confirm against the caller)
 */
@Override
protected void replay(Operation existing) {
// Always replay child branches for parallel
start();
// NOTE(review): the "no-op" remark below conflicts with the start() call above; confirm which of the
// two reflects the intended behavior.
// No-op: child branches handle their own replay via ChildContextOperation.replay().
// Remember whether the existing operation already has a terminal status. When skipCheckpoint is true,
// handleSuccess() returns early and does not re-send the checkpoint for the already-completed parallel
// context.
skipCheckpoint = ExecutionManager.isTerminalStatus(existing.status());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
package software.amazon.lambda.durable.operation;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.*;

import java.lang.reflect.Field;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -70,6 +72,18 @@ void setUp() {
when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet());
// All child operations are NOT in replay
when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null);
// Simulate the real backend: the parent concurrency operation is available in storage after completion
// so that waitForOperationCompletion() can find it. TestConcurrencyOperation.handleSuccess/Failure are no-ops
// (no checkpoint sent), so we stub this unconditionally for OPERATION_ID.
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID))
.thenReturn(Operation.builder()
.id(OPERATION_ID)
.name("test-concurrency")
.type(OperationType.CONTEXT)
.subType(OperationSubType.PARALLEL.getValue())
.status(OperationStatus.SUCCEEDED)
.build());
when(executionManager.sendOperationUpdate(any())).thenReturn(CompletableFuture.completedFuture(null));
}

private TestConcurrencyOperation createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount)
Expand Down Expand Up @@ -138,7 +152,7 @@ void allChildrenAlreadySucceed_callsHandleSuccess() throws Exception {
TypeToken.get(String.class),
SER_DES);

runJoin(op);
op.exposedJoin();

assertTrue(op.isSuccessHandled());
assertFalse(op.isFailureHandled());
Expand Down Expand Up @@ -171,7 +185,7 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception {
TypeToken.get(String.class),
SER_DES);

runJoin(op);
op.exposedJoin();

assertTrue(op.isSuccessHandled());
assertEquals(1, op.getSucceededCount());
Expand All @@ -191,14 +205,6 @@ void addItem_usesRootChildContextAsParent() throws Exception {
assertSame(childContext, op.getLastParentContext());
}

// ===== Helpers =====

/**
 * Runs {@code op.exposedJoin()} on a separate thread and waits up to two seconds for it to finish.
 *
 * <p>If the join is still running after the timeout, the worker thread is interrupted and the
 * calling test fails, rather than returning silently while the join is still in flight.
 *
 * @param op the operation whose join is exercised
 * @throws InterruptedException if the calling thread is interrupted while waiting
 * @throws AssertionError if the join does not complete within the timeout
 */
private void runJoin(TestConcurrencyOperation op) throws InterruptedException {
    var joinThread = new Thread(op::exposedJoin);
    joinThread.start();
    joinThread.join(2000);
    if (joinThread.isAlive()) {
        // Thread.join(millis) returns normally on timeout, so liveness must be checked explicitly.
        joinThread.interrupt();
        throw new AssertionError("exposedJoin() did not complete within 2 seconds");
    }
}

// ===== Test subclass =====

static class TestConcurrencyOperation extends ConcurrencyOperation<Void> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.mockito.Mockito.*;

import java.lang.reflect.Field;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -67,6 +68,26 @@ void setUp() {
mockIdGenerator = mock(OperationIdGenerator.class);
when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet());
when(executionManager.getOperationAndUpdateReplayState(anyString())).thenReturn(null);

// Simulate the real backend: when a SUCCEED checkpoint is sent for the parallel op,
// make getOperationAndUpdateReplayState return a SUCCEEDED operation so waitForOperationCompletion() can find
// it.
var succeededParallelOp = Operation.builder()
.id(OPERATION_ID)
.name("test-parallel")
.type(OperationType.CONTEXT)
.subType(OperationSubType.PARALLEL.getValue())
.status(OperationStatus.SUCCEEDED)
.build();
when(executionManager.sendOperationUpdate(argThat(u -> u != null
&& u.id() != null
&& u.id().equals(OPERATION_ID)
&& u.action() == OperationAction.SUCCEED)))
.thenAnswer(inv -> {
when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID))
.thenReturn(succeededParallelOp);
return CompletableFuture.completedFuture(null);
});
}

private ParallelOperation<Void> createOperation(int maxConcurrency, int minSuccessful, int toleratedFailureCount) {
Expand Down Expand Up @@ -153,7 +174,7 @@ void handleSuccess_sendsSucceedCheckpoint() throws Exception {
op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES);
op.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES);

runJoin(op);
op.get();

verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED));
}
Expand All @@ -179,7 +200,7 @@ void minSuccessful_joinCompletesWhenThresholdMet() throws Exception {
op.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES);

// Should not throw
assertDoesNotThrow(() -> runJoin(op));
op.get();
assertEquals(1, op.getSucceededCount());
}

Expand All @@ -199,6 +220,100 @@ void contextHierarchy_branchesUseParallelContextAsParent() throws Exception {
assertNotNull(childOp);
}

// ===== Replay =====

@Test
void replay_doesNotSendStartCheckpoint() throws Exception {
    // The parallel operation is already known to the service but is still in STARTED status.
    var startedParallel = Operation.builder()
            .id(OPERATION_ID)
            .name("test-parallel")
            .type(OperationType.CONTEXT)
            .subType(OperationSubType.PARALLEL.getValue())
            .status(OperationStatus.STARTED)
            .build();
    when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(startedParallel);

    // Both branches are recorded as SUCCEEDED together with their serialized results.
    var firstBranchDetails = ContextDetails.builder().result("\"r1\"").build();
    var firstBranch = Operation.builder()
            .id("child-1")
            .name("branch-1")
            .type(OperationType.CONTEXT)
            .subType(OperationSubType.PARALLEL_BRANCH.getValue())
            .status(OperationStatus.SUCCEEDED)
            .contextDetails(firstBranchDetails)
            .build();
    when(executionManager.getOperationAndUpdateReplayState("child-1")).thenReturn(firstBranch);

    var secondBranchDetails = ContextDetails.builder().result("\"r2\"").build();
    var secondBranch = Operation.builder()
            .id("child-2")
            .name("branch-2")
            .type(OperationType.CONTEXT)
            .subType(OperationSubType.PARALLEL_BRANCH.getValue())
            .status(OperationStatus.SUCCEEDED)
            .contextDetails(secondBranchDetails)
            .build();
    when(executionManager.getOperationAndUpdateReplayState("child-2")).thenReturn(secondBranch);

    var parallel = createOperation(-1, -1, 0);
    setOperationIdGenerator(parallel, mockIdGenerator);
    parallel.execute();
    parallel.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES);
    parallel.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES);

    parallel.get();

    // No START may be sent on replay, and exactly one SUCCEED is expected because the
    // parallel operation itself had not reached a terminal status yet.
    verify(executionManager, never()).sendOperationUpdate(argThat(sent -> sent.action() == OperationAction.START));
    verify(executionManager, times(1))
            .sendOperationUpdate(argThat(sent -> sent.action() == OperationAction.SUCCEED));
}

@Test
void replay_doesNotSendSucceedCheckpointWhenParallelAlreadySucceeded() throws Exception {
    // The parallel operation is already recorded with a SUCCEEDED status.
    var succeededParallel = Operation.builder()
            .id(OPERATION_ID)
            .name("test-parallel")
            .type(OperationType.CONTEXT)
            .subType(OperationSubType.PARALLEL.getValue())
            .status(OperationStatus.SUCCEEDED)
            .build();
    when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(succeededParallel);

    // Both branches are recorded as SUCCEEDED together with their serialized results.
    var firstBranchDetails = ContextDetails.builder().result("\"r1\"").build();
    var firstBranch = Operation.builder()
            .id("child-1")
            .name("branch-1")
            .type(OperationType.CONTEXT)
            .subType(OperationSubType.PARALLEL_BRANCH.getValue())
            .status(OperationStatus.SUCCEEDED)
            .contextDetails(firstBranchDetails)
            .build();
    when(executionManager.getOperationAndUpdateReplayState("child-1")).thenReturn(firstBranch);

    var secondBranchDetails = ContextDetails.builder().result("\"r2\"").build();
    var secondBranch = Operation.builder()
            .id("child-2")
            .name("branch-2")
            .type(OperationType.CONTEXT)
            .subType(OperationSubType.PARALLEL_BRANCH.getValue())
            .status(OperationStatus.SUCCEEDED)
            .contextDetails(secondBranchDetails)
            .build();
    when(executionManager.getOperationAndUpdateReplayState("child-2")).thenReturn(secondBranch);

    var parallel = createOperation(-1, -1, 0);
    setOperationIdGenerator(parallel, mockIdGenerator);
    parallel.execute();
    parallel.addItem("branch-1", ctx -> "r1", TypeToken.get(String.class), SER_DES);
    parallel.addItem("branch-2", ctx -> "r2", TypeToken.get(String.class), SER_DES);

    parallel.get();

    // Neither START nor SUCCEED may be sent when the parallel operation was already SUCCEEDED.
    verify(executionManager, never()).sendOperationUpdate(argThat(sent -> sent.action() == OperationAction.START));
    verify(executionManager, never())
            .sendOperationUpdate(argThat(sent -> sent.action() == OperationAction.SUCCEED));
}

// ===== handleFailure still sends SUCCEED =====

@Test
Expand All @@ -224,22 +339,10 @@ void handleFailure_sendsSucceedCheckpointEvenWhenFailureToleranceExceeded() thro
TypeToken.get(String.class),
SER_DES);

runJoin(op);
op.get();

verify(executionManager).sendOperationUpdate(argThat(update -> update.action() == OperationAction.SUCCEED));
verify(executionManager, never())
.sendOperationUpdate(argThat(update -> update.action() == OperationAction.FAIL));
}

// ===== Helpers =====

/**
 * Invokes {@code op.get()} on a separate thread and waits at most two seconds for it to return.
 *
 * <p>When the worker is still alive after the wait, it is interrupted and the test is failed.
 *
 * @param op the parallel operation to join
 * @throws InterruptedException if the calling thread is interrupted while waiting
 */
private void runJoin(ParallelOperation<?> op) throws InterruptedException {
    Thread worker = new Thread(op::get);
    worker.start();
    worker.join(2000);
    if (!worker.isAlive()) {
        // Finished within the time limit; nothing more to do.
        return;
    }
    worker.interrupt();
    fail("join() did not complete within 2 seconds");
}
}
Loading