diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index 016cab11e..6071f6d14 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -57,6 +57,7 @@ public abstract class ConcurrencyOperation extends BaseDurableOperation { private final Set completedOperations = Collections.synchronizedSet(new HashSet()); private ConcurrencyCompletionStatus completionStatus; private OperationIdGenerator operationIdGenerator; + private final DurableContextImpl rootContext; protected ConcurrencyOperation( OperationIdentifier operationIdentifier, @@ -73,6 +74,7 @@ protected ConcurrencyOperation( this.toleratedFailureCount = toleratedFailureCount; this.failureRateThreshold = failureRateThreshold; this.operationIdGenerator = new OperationIdGenerator(getOperationId()); + this.rootContext = durableContext.createChildContext(getOperationId(), getName()); } protected ConcurrencyOperation( @@ -142,7 +144,7 @@ public ChildContextOperation addItem( String name, Function function, TypeToken resultType, SerDes serDes) { if (isOperationCompleted()) throw new IllegalStateException("Cannot add items to a completed operation"); var operationId = this.operationIdGenerator.nextOperationId(); - var childOp = createItem(operationId, name, function, resultType, serDes, getContext()); + var childOp = createItem(operationId, name, function, resultType, serDes, this.rootContext); childOperations.add(childOp); pendingQueue.add(childOp); logger.debug("Item added {}", name); diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java index 8c4bb710f..a55d1c139 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ConcurrencyOperationTest.java @@ -38,6 +38,7 @@ class ConcurrencyOperationTest { private static final TypeToken RESULT_TYPE = TypeToken.get(Void.class); private DurableContextImpl durableContext; + private DurableContextImpl childContext; private ExecutionManager executionManager; private AtomicInteger operationIdCounter; private OperationIdGenerator mockIdGenerator; @@ -48,11 +49,20 @@ void setUp() { executionManager = mock(ExecutionManager.class); operationIdCounter = new AtomicInteger(0); + var childContext = mock(DurableContextImpl.class); + this.childContext = childContext; + when(childContext.getExecutionManager()).thenReturn(executionManager); + when(childContext.getDurableConfig()) + .thenReturn(DurableConfig.builder() + .withExecutorService(Executors.newCachedThreadPool()) + .build()); + when(durableContext.getExecutionManager()).thenReturn(executionManager); when(durableContext.getDurableConfig()) .thenReturn(DurableConfig.builder() .withExecutorService(Executors.newCachedThreadPool()) .build()); + when(durableContext.createChildContext(anyString(), anyString())).thenReturn(childContext); when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("Root", ThreadType.CONTEXT)); mockIdGenerator = mock(OperationIdGenerator.class); when(mockIdGenerator.nextOperationId()).thenAnswer(inv -> "child-" + operationIdCounter.incrementAndGet()); @@ -167,6 +177,18 @@ void singleChildAlreadySucceeds_fullCycle() throws Exception { assertFalse(functionCalled.get(), "Function should not be called during SUCCEEDED replay"); } + @Test + void addItem_usesRootChildContextAsParent() throws Exception { + var op = createOperation(-1, -1, 0); + + op.addItem("branch-1", ctx -> "result", TypeToken.get(String.class), SER_DES); + + // rootContext is created via durableContext.createChildContext(...) in the constructor, + // so the parentContext passed to createItem must be that child context, not durableContext itself + assertNotSame(durableContext, op.getLastParentContext()); + assertSame(childContext, op.getLastParentContext()); + } + // ===== Helpers ===== private void runJoin(TestConcurrencyOperation op) throws InterruptedException { @@ -182,6 +204,7 @@ static class TestConcurrencyOperation extends ConcurrencyOperation { private boolean successHandled = false; private boolean failureHandled = false; private final AtomicInteger executingCount = new AtomicInteger(0); + private DurableContextImpl lastParentContext; TestConcurrencyOperation( OperationIdentifier operationIdentifier, @@ -209,6 +232,7 @@ protected ChildContextOperation createItem( TypeToken resultType, SerDes serDes, DurableContextImpl parentContext) { + lastParentContext = parentContext; return new ChildContextOperation( OperationIdentifier.of(operationId, name, OperationType.CONTEXT, OperationSubType.PARALLEL_BRANCH), function, @@ -260,5 +284,9 @@ boolean isSuccessHandled() { boolean isFailureHandled() { return failureHandled; } + + DurableContextImpl getLastParentContext() { + return lastParentContext; + } } } diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java index dc6a93bd6..00345e60d 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ParallelOperationTest.java @@ -47,11 +47,19 @@ void setUp() { executionManager = mock(ExecutionManager.class); operationIdCounter = new AtomicInteger(0); + var childContext = mock(DurableContextImpl.class); + when(childContext.getExecutionManager()).thenReturn(executionManager); + when(childContext.getDurableConfig()) + .thenReturn(DurableConfig.builder() + .withExecutorService(Executors.newCachedThreadPool()) + .build()); + when(durableContext.getExecutionManager()).thenReturn(executionManager); when(durableContext.getDurableConfig()) .thenReturn(DurableConfig.builder() .withExecutorService(Executors.newCachedThreadPool()) .build()); + when(durableContext.createChildContext(anyString(), anyString())).thenReturn(childContext); when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("Root", ThreadType.CONTEXT)); // Default: no existing operations (fresh execution) mockIdGenerator = mock(OperationIdGenerator.class);