diff --git a/docs/adr/003-completable-future-based-coordination.md b/docs/adr/003-completable-future-based-coordination.md index d6d4e432e..4b571eaf6 100644 --- a/docs/adr/003-completable-future-based-coordination.md +++ b/docs/adr/003-completable-future-based-coordination.md @@ -1,4 +1,177 @@ # ADR-003: CompletableFuture-Based Operation Coordination -**Status:** Todo -**Date:** 2026-02-18 \ No newline at end of file +**Status:** Review +**Date:** 2026-02-18 + +## Context + +Currently, the SDK employs a Phaser-based mechanism for coordinating operations. The design is detailed in [ADR-002: Phaser-Based Operation Coordination](002-phaser-based-coordination.md). + +With this design, we can: + +- Register a thread when it begins and deregister it when it completes; +- Block `DurableFuture::get()` calls until the operation completes; +- Suspend execution when no registered thread exists. + +However, this design has a few issues: + +- We allow the Phasers to advance over predefined phase ranges (0 - RUNNING, 1 - COMPLETE). If we received duplicate completion updates from local runner or backend API, the phase could be advanced to 2, 3, and so on. +- We assume that there is only one party during operation replay, and two parties when receiving an operation state from checkpoint API. We call Phaser `arriveAndAwaitAdvance` once or twice based on this assumption, but it could be incorrect. In complex scenarios, this could lead to a deadlock (not enough arrive calls) or exceeding the phase range (too many arrive calls). +- The Phaser has higher complexity and cognitive overhead compared to other synchronization mechanisms. + +## Decision + +We will implement operation coordination using `CompletableFuture`., + +### Threads + +Each piece of user code (e.g. the main Lambda function body, a step body, a child context body) runs in its own user thread from the user thread pool. +Execution manager tracks active running user threads. +When a new step or a new child context is created, a new thread is created and registered in execution manager. +When the step or the child context completes, the corresponding thread is deregistered from execution manager. +When the user code is blocked on `DurableFuture::get()` or another synchronous durable operation (e.g., `wait()`), the caller thread is deregistered from execution manager. +When there is no registered thread in execution manager, the durable execution is suspended. + +A special SDK thread is created and managed by the SDK to make checkpoint API requests. + +### CompletableFuture + +The `CompletableFuture` is used to manage the completion of operations. It allows us to track the progress of operations and handle their completion in a more flexible and readable manner. + +Each durable operation has a `CompletableFuture` field. +This field is used by user threads and the SDK thread communicate the completion of operations. + +For example, when a context executes a step, the communication occurs as follows + +```mermaid +sequenceDiagram + participant Context as Context Thread + participant Future as CompletableFuture + participant EM as Execution Manager + participant SDK as SDK Thread + participant Step as Step Thread + + Note over Context: calling context.stepAsync() + Context->>Context: create StepOperation + Context->>Future: create CompletableFuture + Note over EM: Step Thread lifecycle in EM + Context->>EM: register Step Thread + activate Step + activate EM + Context->>+Step: create Step Thread + Note over Context: calling step.get() + Context->>Future: check if CompletableFuture is done + alt is not done + Context->>EM: deregister Context Thread + Context->>Future: attach a callback to register context thread when CompletableFuture is done + Context->>Future: wait for CompletableFuture to complete + Note over Context: (BLOCKED) + end + + Note over Step: executing Step logic + Step->>Step: execute user function + Step->>+SDK: checkpoint SUCCESS + SDK->>SDK: call checkpoint API + SDK->>SDK: handle checkpoint response + SDK->>+Future: complete CompletableFuture + alt callback attached + Future->>EM: register Context Thread + Future->>Context: unblock Context Thread + Note over Context: (UNBLOCKED) + end + Future-->>-SDK: CompletableFuture completed + SDK-->>-Step: checkpoint done + Context->>Context: retrieve the step result + Step->>EM: deregister Step thread + deactivate Step + deactivate EM + +``` + +| | Context Thread | Step Thread | SDK Thread | +|---|-------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| 1 | create StepOperation (a CompletableFuture is created) | (not created) | (idle) | +| 2 | checkpoint START event (synchronously or asynchronously) | (not created) | call checkpoint API | +| 3 | create and register the Step thread | execute user code for the step | (idle) | +| 4 | call `DurableFuture::get()`, deregister the context thread and wait for the `CompletableFuture` to complete | (continue) | (idle) | +| 5 | (blocked) | checkpoint the step result and wait for checkpoint call to complete | call checkpoint API, and handle the API response. If it is a terminal response, complete the step operation CompletableFuture, register and unblock the context thread. | +| 6 | retrieve the result of the step | deregister and terminate the Step thread | (idle) | + +If the step code completes quickly, an alternative scenario could happen as follows + +```mermaid +sequenceDiagram + participant Context as Context Thread + participant Future as CompletableFuture + participant EM as Execution Manager + participant SDK as SDK Thread + participant Step as Step Thread + + Note over Context: calling context.stepAsync() + Context->>Context: create StepOperation + Context->>Future: create CompletableFuture + Note over EM: Step Thread lifecycle in EM + Context->>EM: register Step Thread + activate EM + Context->>Step: create Step Thread + activate Step + Step->>Step: execute user function + Step->>EM: checkpoint SUCCESS + EM->>SDK: checkpoint SUCCESS + activate SDK + SDK->>SDK: call checkpoint API + SDK->>SDK: handle checkpoint response + SDK->>+Future: complete CompletableFuture + Note over Future: no callback attached + Future-->>-SDK: CompletableFuture completed + SDK-->>Step: checkpoint done + deactivate SDK + Step->>EM: deregister Step thread + deactivate EM + deactivate Step + + Note over Context: calling step.get() + Context->>Future: check if CompletableFuture is done + alt is done + Context->>Context: retrieve the step result + end + + +``` + +| | Context Thread | Step Thread | SDK Thread | +|---|---------------------------------------------------------------------------------------------|---------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------| +| 1 | create StepOperation (a CompletableFuture is created) | (not created) | (idle) | +| 2 | checkpoint START event (synchronously or asynchronously) | (not created) | call checkpoint API | +| 3 | create and register the Step thread | execute user code for the step and complete quickly | (idle) | +| 5 | (do something else or just get starved) | checkpoint the step result and wait for checkpoint call to complete | call checkpoint API, and handle the API response. If it is a terminal response, complete the Step operation CompletableFuture. | +| 4 | call `DurableFuture::get()` (non-blocking because `CompletableFuture` is already completed) | deregister and terminate the Step thread | (idle) | +| 6 | retrieve the result of the step | (ended) | (idle) | + +The following two key mechanisms make `CompletableFuture` based solution work properly. + +- Strict ordering of `register and unblock the context thread` and `deregister and terminate the Step thread`. + - When a step completes, it calls checkpoint API to checkpoint the result and wait for the checkpoint call to complete. + - SDK thread receives the checkpoint request, makes the API call, and processes the API response. + - If the response contains a terminal operation state (it should for a succeeded or failed step), it will send the response to the `StepOperation` to complete `CompletableFuture`. When completing the future, the attached completion stages will be executed synchronously, which will register any context threads that are waiting for the result of the step. + - When SDK thread completes the API request and registers all waiting threads, the step thread continues to deregister itself from execution manager. +- Synchronized access to `CompletableFuture`. + - When a context thread calls `DurableFuture::get()`, it checks if `CompletableFuture` is done. + 1. If the future is done, `get()` will return the operation result. Otherwise, the context thread will + 2. deregister itself from execution manager; + 3. attach a completion stage to `CompletableFuture` that will re-register the context thread when later the future is completed; + 4. wait for `CompletableFuture` to complete. + - Meantime, `CompletableFuture` can be completed by SDK thread when handling the checkpoint API responses. + - A race condition will occur if this happens when the context thread is between the step `a` and `c`. + - To prevent the race condition, all the mutating access to `CompletableFuture` either to complete the future or to attach a completion stage is synchronized. + +## Consequences + +Enables: +- Support for complex scenarios which were not supported by Phaser +- Reduced implementation complexity and improved readability +- `CompletableFuture` based implementation of `DurableFuture::allOf` and `DurableFuture::anyOf` + +Cost: +- Synchronized access to `CompletableFuture` +- Obscured ordering of thread registration/deregistration \ No newline at end of file diff --git a/sdk-integration-tests/src/test/java/com/amazonaws/lambda/durable/ChildContextIntegrationTest.java b/sdk-integration-tests/src/test/java/com/amazonaws/lambda/durable/ChildContextIntegrationTest.java index 6780ed19a..db471ce8a 100644 --- a/sdk-integration-tests/src/test/java/com/amazonaws/lambda/durable/ChildContextIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/com/amazonaws/lambda/durable/ChildContextIntegrationTest.java @@ -221,7 +221,7 @@ void waitInsideChildContextReturnsPendingThenCompletes() { runner.advanceTime(); // Second run - should complete - var result2 = runner.run("test"); + var result2 = runner.runUntilComplete("test"); assertEquals(ExecutionStatus.SUCCEEDED, result2.getStatus()); assertEquals("done", result2.getResult(String.class)); } diff --git a/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/HistoryEventProcessor.java b/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/HistoryEventProcessor.java index 2f2474ba7..e08a61832 100644 --- a/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/HistoryEventProcessor.java +++ b/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/HistoryEventProcessor.java @@ -26,6 +26,7 @@ public TestResult processEvents(List events, Class outputType) var operationEvents = new HashMap>(); var status = ExecutionStatus.PENDING; String result = null; + ErrorObject error = null; for (var event : events) { var eventType = event.eventType(); @@ -51,9 +52,34 @@ public TestResult processEvents(List events, Class outputType) result = details.result().payload(); } } - case EXECUTION_FAILED -> status = ExecutionStatus.FAILED; - case EXECUTION_TIMED_OUT -> status = ExecutionStatus.FAILED; - case EXECUTION_STOPPED -> status = ExecutionStatus.FAILED; + case EXECUTION_FAILED -> { + status = ExecutionStatus.FAILED; + var details = event.executionFailedDetails(); + if (details != null + && details.error() != null + && details.error().payload() != null) { + error = details.error().payload(); + } + } + case EXECUTION_TIMED_OUT -> { + status = ExecutionStatus.FAILED; + var details = event.executionTimedOutDetails(); + if (details != null + && details.error() != null + && details.error().payload() != null) { + error = details.error().payload(); + } + } + case EXECUTION_STOPPED -> { + status = ExecutionStatus.FAILED; + + var details = event.executionStoppedDetails(); + if (details != null + && details.error() != null + && details.error().payload() != null) { + error = details.error().payload(); + } + } case STEP_STARTED -> { if (operationId != null) { operations.putIfAbsent( @@ -186,7 +212,7 @@ public TestResult processEvents(List events, Class outputType) testOperations.add(new TestOperation(entry.getValue(), opEvents, serDes)); } - return new TestResult<>(status, result, null, testOperations, events, serDes); + return new TestResult<>(status, result, error, testOperations, events, serDes); } private Operation createStepOperation( diff --git a/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/LocalDurableTestRunner.java b/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/LocalDurableTestRunner.java index aa7101f89..c4782f255 100644 --- a/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/LocalDurableTestRunner.java +++ b/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/LocalDurableTestRunner.java @@ -172,20 +172,26 @@ public TestResult run(I input) { return storage.toTestResult(output); } - /** Run until completion (SUCCEEDED or FAILED), simulating Lambda re-invocations. */ + /** + * Run until completion (SUCCEEDED or FAILED) or pending manual intervention, simulating Lambda re-invocations. + * Operations that don't require manual intervention (like WAIT in STARTED or STEP in PENDING) will be automatically + * advanced. + * + * @param input The input to process + * @return Final test result (SUCCEEDED or FAILED) or PENDING if operations pending manual intervention + */ public TestResult runUntilComplete(I input) { TestResult result = null; for (int i = 0; i < MAX_INVOCATIONS; i++) { result = run(input); - if (result.getStatus() != ExecutionStatus.PENDING) { - return result; // SUCCEEDED or FAILED - we're done - } - - if (skipTime) { - storage.advanceReadyOperations(); // Auto-advance and continue loop - } else { - return result; // Return PENDING - let test manually advance time + if (result.getStatus() != ExecutionStatus.PENDING || !skipTime || !storage.advanceReadyOperations()) { + // break the loop if + // - Return SUCCEEDED or FAILED - we're done + // - Return PENDING and let test manually advance operations if + // - auto advance is disabled, or + // - no operations can be auto advanced + break; } } return result; diff --git a/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/LocalMemoryExecutionClient.java b/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/LocalMemoryExecutionClient.java index 75cb655a8..bd8941c12 100644 --- a/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/LocalMemoryExecutionClient.java +++ b/sdk-testing/src/main/java/com/amazonaws/lambda/durable/testing/LocalMemoryExecutionClient.java @@ -13,6 +13,7 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import software.amazon.awssdk.services.lambda.model.*; @@ -63,10 +64,16 @@ public List getEventsForOperation(String operationId) { return allEvents.stream().filter(e -> operationId.equals(e.id())).toList(); } - /** Advance all operations (simulates time passing for retries/waits). */ - public void advanceReadyOperations() { + /** + * Advance all operations (simulates time passing for retries/waits). + * + * @return true if any operations were advanced, false otherwise + */ + public boolean advanceReadyOperations() { + var replaced = new AtomicBoolean(false); operations.replaceAll((key, op) -> { if (op.status() == OperationStatus.PENDING) { + replaced.set(true); return op.toBuilder().status(OperationStatus.READY).build(); } if (op.status() == OperationStatus.STARTED && op.type() == OperationType.WAIT) { @@ -81,10 +88,12 @@ public void advanceReadyOperations() { .build(); var event = eventProcessor.processUpdate(update, succeededOp); allEvents.add(event); + replaced.set(true); return succeededOp; } return op; }); + return replaced.get(); } public void completeChainedInvoke(String name, OperationResult result) { diff --git a/sdk/src/main/java/com/amazonaws/lambda/durable/DurableContext.java b/sdk/src/main/java/com/amazonaws/lambda/durable/DurableContext.java index 6e6305d78..952b2c00a 100644 --- a/sdk/src/main/java/com/amazonaws/lambda/durable/DurableContext.java +++ b/sdk/src/main/java/com/amazonaws/lambda/durable/DurableContext.java @@ -3,6 +3,7 @@ package com.amazonaws.lambda.durable; import com.amazonaws.lambda.durable.execution.ExecutionManager; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import com.amazonaws.lambda.durable.logging.DurableLogger; import com.amazonaws.lambda.durable.operation.CallbackOperation; @@ -64,8 +65,8 @@ private DurableContext( static DurableContext createRootContext( ExecutionManager executionManager, DurableConfig durableConfig, Context lambdaContext) { var ctx = new DurableContext(executionManager, durableConfig, lambdaContext, null); - executionManager.registerActiveThread(ROOT_CONTEXT, ThreadType.CONTEXT); - executionManager.setCurrentContext(ROOT_CONTEXT, ThreadType.CONTEXT); + executionManager.registerActiveThread(ROOT_CONTEXT); + executionManager.setCurrentThreadContext(new ThreadContext(ROOT_CONTEXT, ThreadType.CONTEXT)); return ctx; } diff --git a/sdk/src/main/java/com/amazonaws/lambda/durable/execution/ExecutionManager.java b/sdk/src/main/java/com/amazonaws/lambda/durable/execution/ExecutionManager.java index 4ad84cef6..8aa9e4237 100644 --- a/sdk/src/main/java/com/amazonaws/lambda/durable/execution/ExecutionManager.java +++ b/sdk/src/main/java/com/amazonaws/lambda/durable/execution/ExecutionManager.java @@ -8,9 +8,11 @@ import java.time.Duration; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -56,8 +58,8 @@ public class ExecutionManager { // ===== Thread Coordination ===== private final Map> registeredOperations = Collections.synchronizedMap(new HashMap<>()); - private final Map activeThreads = Collections.synchronizedMap(new HashMap<>()); - private static final ThreadLocal currentContext = new ThreadLocal<>(); + private final Set activeThreads = Collections.synchronizedSet(new HashSet<>()); + private static final ThreadLocal currentThreadContext = new ThreadLocal<>(); private final CompletableFuture executionExceptionFuture = new CompletableFuture<>(); // ===== Checkpoint Batching ===== @@ -174,54 +176,48 @@ public boolean hasOperationsForContext(String parentId) { } // ===== Thread Coordination ===== + /** Sets the current thread's ThreadContext (threadId and threadType). Called when a user thread is started. */ + public void setCurrentThreadContext(ThreadContext threadContext) { + currentThreadContext.set(threadContext); + } + + /** Returns the current thread's ThreadContext (threadId and threadType), or null if not set. */ + public ThreadContext getCurrentThreadContext() { + return currentThreadContext.get(); + } + /** - * Registers a thread as active without setting the thread local OperationContext. Use this when registration must - * happen on a different thread than execution. Call setCurrentContext() on the execution thread to set the local - * OperationContext. + * Registers a thread as active. * - * @see OperationContext + * @see ThreadContext */ - public void registerActiveThread(String threadId, ThreadType threadType) { - if (activeThreads.containsKey(threadId)) { - logger.trace("Thread '{}' ({}) already registered as active", threadId, threadType); + public void registerActiveThread(String threadId) { + if (activeThreads.contains(threadId)) { + logger.trace("Thread '{}' already registered as active", threadId); return; } - activeThreads.put(threadId, threadType); - logger.trace( - "Registered thread '{}' ({}) as active (no context). Active threads: {}", - threadId, - threadType, - activeThreads.size()); + activeThreads.add(threadId); + logger.trace("Registered thread '{}' as active. Active threads: {}", threadId, activeThreads.size()); } /** - * Sets the current thread's context. Use after registerActiveThreadWithoutContext() when the execution thread is - * different from the registration thread. + * Mark a thread as inactive. If no threads remain, suspends the execution. + * + * @param threadId the thread ID to deregister */ - public void setCurrentContext(String contextId, ThreadType threadType) { - currentContext.set(new OperationContext(contextId, threadType)); - } - - /** Returns the current thread's context, or null if not set. */ - public OperationContext getCurrentContext() { - return currentContext.get(); - } - - public void deregisterActiveThreadAndUnsetCurrentContext(String threadId) { + public void deregisterActiveThread(String threadId) { // Skip if already suspended if (executionExceptionFuture.isDone()) { return; } - if (!activeThreads.containsKey(threadId)) { + boolean removed = activeThreads.remove(threadId); + if (removed) { + logger.trace("Deregistered thread '{}' Active threads: {}", threadId, activeThreads.size()); + } else { logger.warn("Thread '{}' not active, cannot deregister", threadId); - return; } - ThreadType type = activeThreads.remove(threadId); - currentContext.remove(); - logger.trace("Deregistered thread '{}' ({}). Active threads: {}", threadId, type, activeThreads.size()); - if (activeThreads.isEmpty()) { logger.info("No active threads remaining - suspending execution"); suspendExecution(); @@ -256,7 +252,7 @@ public void shutdown() { checkpointBatcher.shutdown(); } - private boolean isTerminalStatus(OperationStatus status) { + public static boolean isTerminalStatus(OperationStatus status) { return status == OperationStatus.SUCCEEDED || status == OperationStatus.FAILED || status == OperationStatus.CANCELLED diff --git a/sdk/src/main/java/com/amazonaws/lambda/durable/execution/OperationContext.java b/sdk/src/main/java/com/amazonaws/lambda/durable/execution/ThreadContext.java similarity index 73% rename from sdk/src/main/java/com/amazonaws/lambda/durable/execution/OperationContext.java rename to sdk/src/main/java/com/amazonaws/lambda/durable/execution/ThreadContext.java index 12c49e968..62dc938c0 100644 --- a/sdk/src/main/java/com/amazonaws/lambda/durable/execution/OperationContext.java +++ b/sdk/src/main/java/com/amazonaws/lambda/durable/execution/ThreadContext.java @@ -3,4 +3,4 @@ package com.amazonaws.lambda.durable.execution; /** Holds the current thread's execution context. */ -public record OperationContext(String contextId, ThreadType threadType) {} +public record ThreadContext(String threadId, ThreadType threadType) {} diff --git a/sdk/src/main/java/com/amazonaws/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/com/amazonaws/lambda/durable/operation/BaseDurableOperation.java index 443638e13..fcbb3fce3 100644 --- a/sdk/src/main/java/com/amazonaws/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/com/amazonaws/lambda/durable/operation/BaseDurableOperation.java @@ -9,6 +9,7 @@ import com.amazonaws.lambda.durable.exception.SerDesException; import com.amazonaws.lambda.durable.exception.UnrecoverableDurableExecutionException; import com.amazonaws.lambda.durable.execution.ExecutionManager; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import com.amazonaws.lambda.durable.serde.SerDes; import com.amazonaws.lambda.durable.util.ExceptionHelper; @@ -19,7 +20,6 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.lambda.model.ErrorObject; import software.amazon.awssdk.services.lambda.model.Operation; -import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.OperationType; import software.amazon.awssdk.services.lambda.model.OperationUpdate; @@ -51,7 +51,7 @@ public abstract class BaseDurableOperation implements DurableFuture { private final ExecutionManager executionManager; private final TypeToken resultTypeToken; private final SerDes resultSerDes; - private final CompletableFuture completionFuture; + protected final CompletableFuture completionFuture; protected BaseDurableOperation( String operationId, @@ -125,7 +125,7 @@ protected Operation getOperation() { * @throws IllegalDurableOperationException if it's in a step */ private void validateCurrentThreadType() { - ThreadType current = executionManager.getCurrentContext().threadType(); + ThreadType current = getCurrentThreadContext().threadType(); if (current == ThreadType.STEP) { var message = String.format( "Nested %s operation is not supported on %s from within a %s execution.", @@ -145,33 +145,34 @@ protected Operation waitForOperationCompletion() { validateCurrentThreadType(); - var context = executionManager.getCurrentContext(); + var threadContext = getCurrentThreadContext(); - // Use a synchronized block here to prevent the completionFuture from being completed by the execution thread - // (a step or child context thread) when it's inside the `if` block where the completion check is done (not - // completed) while the callback isn't added to the completionFuture or the current thread isn't deregistered. + // It's important that we synchronize access to the future. Otherwise, a race condition could happen if the + // completionFuture is completed by a user thread (a step or child context thread) when the execution here + // is between `isOperationCompleted` and `thenRun`. synchronized (completionFuture) { if (!isOperationCompleted()) { // Operation not done yet - logger.debug("get() on {} attempting to deregister context: {}", getType(), context.contextId()); - - // Add a callback to completionFuture so that when the completionFuture is completed, + logger.trace( + "deregistering thread {} when waiting for operation {} ({}) to complete ({})", + threadContext.threadId(), + getOperation(), + getType(), + completionFuture); + + // Add a completion stage to completionFuture so that when the completionFuture is completed, // it will register the current Context thread synchronously to make sure it is always registered - // before the execution thread (Step or child context) is deregistered. - completionFuture.thenRun(() -> registerActiveThread(context.contextId(), context.threadType())); + // strictly before the execution thread (Step or child context) is deregistered. + completionFuture.thenRun(() -> registerActiveThread(threadContext.threadId())); // Deregister the current thread to allow suspension - deregisterActiveThreadAndUnsetCurrentContext(context.contextId()); + deregisterActiveThread(threadContext.threadId()); } } // Block until operation completes. No-op if the future is already completed. - logger.trace("Waiting for operation to finish {} ({})", getOperationId(), completionFuture); completionFuture.join(); - // Reactivate current context. No-op if this is called twice. - setCurrentContext(context.contextId(), context.threadType()); - // Get result based on status var op = getOperation(); if (op == null) { @@ -183,9 +184,17 @@ protected Operation waitForOperationCompletion() { /** Receives operation updates from ExecutionManager and updates the internal state of the operation */ public void onCheckpointComplete(Operation operation) { - if (isTerminalStatus(operation.status())) { + if (ExecutionManager.isTerminalStatus(operation.status())) { + // This method handles only terminal status updates. Override this method if a DurableOperation needs to + // handle other updates. + logger.trace("In onCheckpointComplete, completing operation {} ({})", operationId, completionFuture); + // It's important that we synchronize access to the future, otherwise the processing could happen + // on someone else's thread and cause a race condition. synchronized (completionFuture) { - logger.trace("In onCheckpointComplete, completing operation {} ({})", operationId, completionFuture); + // Completing the future here will also run any other completion stages that have been attached + // to the future. In our case, other contexts may have attached a function to reactivate themselves, + // so they will definitely have a chance to reactivate before we finish completing and deactivating + // whatever operations were just checkpointed. completionFuture.complete(null); } } @@ -196,6 +205,9 @@ protected void markAlreadyCompleted() { // When the operation is already completed in a replay, we complete completionFuture immediately // so that the `get` method will be unblocked and the context thread will be registered logger.trace("In markAlreadyCompleted, completing operation: {} ({}).", operationId, completionFuture); + + // It's important that we synchronize access to the future, otherwise the processing could happen + // on someone else's thread and cause a race condition. synchronized (completionFuture) { completionFuture.complete(null); } @@ -213,16 +225,20 @@ protected T terminateExecutionWithIllegalDurableOperationException(String messag } // advanced thread and context control - protected void deregisterActiveThreadAndUnsetCurrentContext(String threadId) { - executionManager.deregisterActiveThreadAndUnsetCurrentContext(threadId); + protected void deregisterActiveThread(String threadId) { + executionManager.deregisterActiveThread(threadId); } - protected void registerActiveThread(String threadId, ThreadType threadType) { - executionManager.registerActiveThread(threadId, threadType); + protected void registerActiveThread(String threadId) { + executionManager.registerActiveThread(threadId); } - protected void setCurrentContext(String stepThreadId, ThreadType step) { - executionManager.setCurrentContext(stepThreadId, step); + protected ThreadContext getCurrentThreadContext() { + return executionManager.getCurrentThreadContext(); + } + + protected void setCurrentThreadContext(ThreadContext threadContext) { + executionManager.setCurrentThreadContext(threadContext); } // polling and checkpointing @@ -315,12 +331,4 @@ protected void validateReplay(Operation checkpointed) { operationId, checkpointed.name(), getName()))); } } - - private boolean isTerminalStatus(OperationStatus status) { - return status == OperationStatus.SUCCEEDED - || status == OperationStatus.FAILED - || status == OperationStatus.CANCELLED - || status == OperationStatus.TIMED_OUT - || status == OperationStatus.STOPPED; - } } diff --git a/sdk/src/main/java/com/amazonaws/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/com/amazonaws/lambda/durable/operation/ChildContextOperation.java index d3363d1c6..cb28c7891 100644 --- a/sdk/src/main/java/com/amazonaws/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/com/amazonaws/lambda/durable/operation/ChildContextOperation.java @@ -12,6 +12,7 @@ import com.amazonaws.lambda.durable.exception.UnrecoverableDurableExecutionException; import com.amazonaws.lambda.durable.execution.ExecutionManager; import com.amazonaws.lambda.durable.execution.SuspendExecutionException; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import com.amazonaws.lambda.durable.serde.SerDes; import com.amazonaws.lambda.durable.util.ExceptionHelper; @@ -105,10 +106,10 @@ private void executeChildContext() { // 2. setCurrentContext on the CHILD thread — sets the ThreadLocal so operations inside // the child context know which context they belong to. // registerActiveThread is idempotent (no-op if already registered). - registerActiveThread(contextId, ThreadType.CONTEXT); + registerActiveThread(contextId); userExecutor.execute(() -> { - setCurrentContext(contextId, ThreadType.CONTEXT); + setCurrentThreadContext(new ThreadContext(contextId, ThreadType.CONTEXT)); try { var childContext = DurableContext.createChildContext(executionManager, durableConfig, lambdaContext, contextId); @@ -128,7 +129,7 @@ private void executeChildContext() { handleChildContextFailure(e); } finally { try { - deregisterActiveThreadAndUnsetCurrentContext(contextId); + deregisterActiveThread(contextId); } catch (SuspendExecutionException e) { // Expected when this is the last active thread — suspension already signaled } @@ -160,6 +161,10 @@ private void checkpointSuccess(T result) { private void handleChildContextFailure(Throwable exception) { exception = ExceptionHelper.unwrapCompletableFuture(exception); + if (exception instanceof SuspendExecutionException) { + // Rethrow Error immediately — do not checkpoint + ExceptionHelper.sneakyThrow(exception); + } if (exception instanceof UnrecoverableDurableExecutionException) { terminateExecution((UnrecoverableDurableExecutionException) exception); } diff --git a/sdk/src/main/java/com/amazonaws/lambda/durable/operation/StepOperation.java b/sdk/src/main/java/com/amazonaws/lambda/durable/operation/StepOperation.java index 973bd65d3..8e6018eda 100644 --- a/sdk/src/main/java/com/amazonaws/lambda/durable/operation/StepOperation.java +++ b/sdk/src/main/java/com/amazonaws/lambda/durable/operation/StepOperation.java @@ -12,6 +12,7 @@ import com.amazonaws.lambda.durable.exception.UnrecoverableDurableExecutionException; import com.amazonaws.lambda.durable.execution.ExecutionManager; import com.amazonaws.lambda.durable.execution.SuspendExecutionException; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import com.amazonaws.lambda.durable.logging.DurableLogger; import com.amazonaws.lambda.durable.util.ExceptionHelper; @@ -124,13 +125,13 @@ private void executeStepLogic(int attempt) { var stepThreadId = getOperationId() + "-step"; // Register step thread as active BEFORE executor runs (prevents suspension when handler deregisters) - // thread local OperationContext is set inside the executor since that's where the step actually runs - registerActiveThread(stepThreadId, ThreadType.STEP); + // thread local ThreadContext is set inside the executor since that's where the step actually runs + registerActiveThread(stepThreadId); // Execute user code in customer-configured executor userExecutor.execute(() -> { - // Set thread local OperationContext on the executor thread - setCurrentContext(stepThreadId, ThreadType.STEP); + // Set thread local ThreadContext on the executor thread + setCurrentThreadContext(new ThreadContext(stepThreadId, ThreadType.STEP)); // Set operation context for logging in this thread durableLogger.setOperationContext(getOperationId(), getName(), attempt); try { @@ -163,7 +164,7 @@ private void executeStepLogic(int attempt) { handleStepFailure(e, attempt); } finally { try { - deregisterActiveThreadAndUnsetCurrentContext(stepThreadId); + deregisterActiveThread(stepThreadId); } catch (SuspendExecutionException e) { // Expected when this is the last active thread. Must catch here because: // 1/ This runs in a worker thread detached from handlerFuture diff --git a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/BaseDurableOperationTest.java b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/BaseDurableOperationTest.java index 75822acd7..b53b86376 100644 --- a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/BaseDurableOperationTest.java +++ b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/BaseDurableOperationTest.java @@ -20,7 +20,7 @@ import com.amazonaws.lambda.durable.exception.NonDeterministicExecutionException; import com.amazonaws.lambda.durable.exception.SerDesException; import com.amazonaws.lambda.durable.execution.ExecutionManager; -import com.amazonaws.lambda.durable.execution.OperationContext; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import com.amazonaws.lambda.durable.serde.JacksonSerDes; import com.amazonaws.lambda.durable.serde.SerDes; @@ -54,7 +54,7 @@ class BaseDurableOperationTest { @BeforeEach void setUp() { executionManager = mock(ExecutionManager.class); - when(executionManager.getCurrentContext()).thenReturn(new OperationContext(CONTEXT_ID, ThreadType.CONTEXT)); + when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(CONTEXT_ID, ThreadType.CONTEXT)); when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(OPERATION); } @@ -79,28 +79,6 @@ public String get() { assertEquals(OPERATION, op.getOperation()); } - @Test - void waitForOperationCompletionThrowsIfInStep() { - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("context", ThreadType.STEP)); - - BaseDurableOperation op = - new BaseDurableOperation<>( - OPERATION_ID, OPERATION_NAME, OPERATION_TYPE, RESULT_TYPE, SER_DES, executionManager) { - @Override - public void execute() { - assertThrows(IllegalDurableOperationException.class, this::waitForOperationCompletion); - } - - @Override - public String get() { - return RESULT; - } - }; - - op.execute(); - verify(executionManager).terminateExecution(any(IllegalDurableOperationException.class)); - } - @Test void waitForOperationCompletionThrowsIfOperationMissing() { when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(null); @@ -149,9 +127,8 @@ public String get() { op.onCheckpointComplete( Operation.builder().status(OperationStatus.SUCCEEDED).build()); assertEquals(RESULT, future.get()); - verify(executionManager).deregisterActiveThreadAndUnsetCurrentContext(CONTEXT_ID); - verify(executionManager).registerActiveThread(CONTEXT_ID, ThreadType.CONTEXT); - verify(executionManager).setCurrentContext(CONTEXT_ID, ThreadType.CONTEXT); + verify(executionManager).deregisterActiveThread(CONTEXT_ID); + verify(executionManager).registerActiveThread(CONTEXT_ID); } } @@ -173,9 +150,8 @@ public String get() { }; op.execute(); - verify(executionManager, never()).deregisterActiveThreadAndUnsetCurrentContext(CONTEXT_ID); - verify(executionManager, never()).registerActiveThread(CONTEXT_ID, ThreadType.CONTEXT); - verify(executionManager).setCurrentContext(CONTEXT_ID, ThreadType.CONTEXT); + verify(executionManager, never()).deregisterActiveThread(CONTEXT_ID); + verify(executionManager, never()).registerActiveThread(CONTEXT_ID); } @Test @@ -304,7 +280,7 @@ public String get() { assertEquals("abc", deserializeResult(serializeResult("abc"))); assertEquals("", deserializeResult("\"\"")); assertThrows(SerDesException.class, () -> deserializeResult("x")); - return ""; + return RESULT; } }; op.get(); @@ -327,7 +303,7 @@ public String get() { Throwable ex = deserializeException(serializeException(new RuntimeException("test exception"))); assertInstanceOf(RuntimeException.class, ex); assertEquals("test exception", ex.getMessage()); - return ""; + return RESULT; } }; diff --git a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/CallbackOperationTest.java b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/CallbackOperationTest.java index cee882271..ed06a0066 100644 --- a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/CallbackOperationTest.java +++ b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/CallbackOperationTest.java @@ -10,9 +10,9 @@ import com.amazonaws.lambda.durable.TypeToken; import com.amazonaws.lambda.durable.exception.CallbackFailedException; import com.amazonaws.lambda.durable.exception.CallbackTimeoutException; -import com.amazonaws.lambda.durable.exception.IllegalDurableOperationException; import com.amazonaws.lambda.durable.exception.SerDesException; import com.amazonaws.lambda.durable.execution.ExecutionManager; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import com.amazonaws.lambda.durable.serde.JacksonSerDes; import com.amazonaws.lambda.durable.serde.SerDes; @@ -25,6 +25,10 @@ class CallbackOperationTest { + private static final String OPERATION_ID = "1"; + private static final String EXECUTION_OPERATION_ID = "0"; + private static final String OPERATION_NAME = "approval"; + /** Custom SerDes that tracks deserialization calls. */ static class TrackingSerDes implements SerDes { private final JacksonSerDes delegate = new JacksonSerDes(); @@ -75,7 +79,7 @@ private ExecutionManager createExecutionManager(List initialOperation "test-token", initialState, DurableConfig.builder().withDurableExecutionClient(client).build()); - executionManager.setCurrentContext("Root", ThreadType.CONTEXT); + executionManager.setCurrentThreadContext(new ThreadContext("Root", ThreadType.CONTEXT)); return executionManager; } @@ -85,8 +89,8 @@ void executeCreatesCheckpointAndGetsCallbackId() { var serDes = new JacksonSerDes(); var operation = new CallbackOperation<>( - "1", - "approval", + OPERATION_ID, + OPERATION_NAME, TypeToken.get(String.class), CallbackConfig.builder().serDes(serDes).build(), executionManager); @@ -105,7 +109,8 @@ void executeWithConfigSetsOptions() { .serDes(serDes) .build(); - var operation = new CallbackOperation<>("1", "approval", TypeToken.get(String.class), config, executionManager); + var operation = new CallbackOperation<>( + OPERATION_ID, OPERATION_NAME, TypeToken.get(String.class), config, executionManager); operation.execute(); assertNotNull(operation.callbackId()); @@ -114,8 +119,8 @@ void executeWithConfigSetsOptions() { @Test void replayReturnsExistingCallbackIdWhenSucceeded() { var existingCallback = Operation.builder() - .id("1") - .name("approval") + .id(OPERATION_ID) + .name(OPERATION_NAME) .type(OperationType.CALLBACK) .status(OperationStatus.SUCCEEDED) .callbackDetails(CallbackDetails.builder() @@ -127,8 +132,8 @@ void replayReturnsExistingCallbackIdWhenSucceeded() { var serDes = new JacksonSerDes(); var operation = new CallbackOperation<>( - "1", - "approval", + OPERATION_ID, + OPERATION_NAME, TypeToken.get(String.class), CallbackConfig.builder().serDes(serDes).build(), executionManager); @@ -140,8 +145,8 @@ void replayReturnsExistingCallbackIdWhenSucceeded() { @Test void getReturnsDeserializedResultWhenSucceeded() { var existingCallback = Operation.builder() - .id("1") - .name("approval") + .id(OPERATION_ID) + .name(OPERATION_NAME) .type(OperationType.CALLBACK) .status(OperationStatus.SUCCEEDED) .callbackDetails(CallbackDetails.builder() @@ -153,8 +158,8 @@ void getReturnsDeserializedResultWhenSucceeded() { var serDes = new JacksonSerDes(); var operation = new CallbackOperation<>( - "1", - "approval", + OPERATION_ID, + OPERATION_NAME, TypeToken.get(String.class), CallbackConfig.builder().serDes(serDes).build(), executionManager); @@ -167,8 +172,8 @@ void getReturnsDeserializedResultWhenSucceeded() { @Test void getThrowsCallbackExceptionWhenFailed() { var existingCallback = Operation.builder() - .id("1") - .name("approval") + .id(OPERATION_ID) + .name(OPERATION_NAME) .type(OperationType.CALLBACK) .status(OperationStatus.FAILED) .callbackDetails(CallbackDetails.builder() @@ -183,8 +188,8 @@ void getThrowsCallbackExceptionWhenFailed() { var serDes = new JacksonSerDes(); var operation = new CallbackOperation<>( - "1", - "approval", + OPERATION_ID, + OPERATION_NAME, TypeToken.get(String.class), CallbackConfig.builder().serDes(serDes).build(), executionManager); @@ -197,8 +202,8 @@ void getThrowsCallbackExceptionWhenFailed() { @Test void getThrowsCallbackTimeoutExceptionWhenTimedOut() { var existingCallback = Operation.builder() - .id("1") - .name("approval") + .id(OPERATION_ID) + .name(OPERATION_NAME) .type(OperationType.CALLBACK) .status(OperationStatus.TIMED_OUT) .callbackDetails( @@ -208,8 +213,8 @@ void getThrowsCallbackTimeoutExceptionWhenTimedOut() { var serDes = new JacksonSerDes(); var operation = new CallbackOperation<>( - "1", - "approval", + OPERATION_ID, + OPERATION_NAME, TypeToken.get(String.class), CallbackConfig.builder().serDes(serDes).build(), executionManager); @@ -224,8 +229,8 @@ void operationUsesCustomSerDesWhenConfigContainsOne() { var customSerDes = new TrackingSerDes(); var existingCallback = Operation.builder() - .id("1") - .name("approval") + .id(OPERATION_ID) + .name(OPERATION_NAME) .type(OperationType.CALLBACK) .status(OperationStatus.SUCCEEDED) .callbackDetails(CallbackDetails.builder() @@ -236,7 +241,8 @@ void operationUsesCustomSerDesWhenConfigContainsOne() { var executionManager = createExecutionManager(List.of(existingCallback)); var config = CallbackConfig.builder().serDes(customSerDes).build(); - var operation = new CallbackOperation<>("1", "approval", TypeToken.get(String.class), config, executionManager); + var operation = new CallbackOperation<>( + OPERATION_ID, OPERATION_NAME, TypeToken.get(String.class), config, executionManager); operation.execute(); var result = operation.get(); @@ -250,8 +256,8 @@ void operationUsesDefaultSerDesWhenConfigIsNull() { var customSerDes = new TrackingSerDes(); var existingCallback = Operation.builder() - .id("1") - .name("approval") + .id(OPERATION_ID) + .name(OPERATION_NAME) .type(OperationType.CALLBACK) .status(OperationStatus.SUCCEEDED) .callbackDetails(CallbackDetails.builder() @@ -262,8 +268,8 @@ void operationUsesDefaultSerDesWhenConfigIsNull() { var executionManager = createExecutionManager(List.of(existingCallback)); var operation = new CallbackOperation<>( - "1", - "approval", + OPERATION_ID, + OPERATION_NAME, TypeToken.get(String.class), CallbackConfig.builder().serDes(customSerDes).build(), executionManager); @@ -280,8 +286,8 @@ void operationUsesDefaultSerDesWhenConfigSerDesIsNull() { var customSerDes = new TrackingSerDes(); var existingCallback = Operation.builder() - .id("1") - .name("approval") + .id(OPERATION_ID) + .name(OPERATION_NAME) .type(OperationType.CALLBACK) .status(OperationStatus.SUCCEEDED) .callbackDetails(CallbackDetails.builder() @@ -292,7 +298,8 @@ void operationUsesDefaultSerDesWhenConfigSerDesIsNull() { var executionManager = createExecutionManager(List.of(existingCallback)); var config = CallbackConfig.builder().serDes(customSerDes).build(); - var operation = new CallbackOperation<>("1", "approval", TypeToken.get(String.class), config, executionManager); + var operation = new CallbackOperation<>( + OPERATION_ID, OPERATION_NAME, TypeToken.get(String.class), config, executionManager); operation.execute(); var result = operation.get(); @@ -306,8 +313,8 @@ void getThrowsSerDesExceptionWithHelpfulMessageWhenDeserializationFails() { var failingSerDes = new FailingSerDes(); var existingCallback = Operation.builder() - .id("1") - .name("approval") + .id(OPERATION_ID) + .name(OPERATION_NAME) .type(OperationType.CALLBACK) .status(OperationStatus.SUCCEEDED) .callbackDetails(CallbackDetails.builder() @@ -318,8 +325,8 @@ void getThrowsSerDesExceptionWithHelpfulMessageWhenDeserializationFails() { var executionManager = createExecutionManager(List.of(existingCallback)); var operation = new CallbackOperation<>( - "1", - "approval", + OPERATION_ID, + OPERATION_NAME, TypeToken.get(String.class), CallbackConfig.builder().serDes(failingSerDes).build(), executionManager); @@ -328,33 +335,4 @@ void getThrowsSerDesExceptionWithHelpfulMessageWhenDeserializationFails() { var exception = assertThrows(SerDesException.class, operation::get); assertEquals("Invalid base64 encoding", exception.getMessage()); } - - @Test - void getThrowsExceptionWhenCalledWithinStep() { - var existingCallback = Operation.builder() - .id("1") - .name("approval") - .type(OperationType.CALLBACK) - .status(OperationStatus.SUCCEEDED) - .callbackDetails(CallbackDetails.builder() - .callbackId("test-callback-123") - .result("invalid-data") - .build()) - .build(); - var executionManager = createExecutionManager(List.of(existingCallback)); - executionManager.setCurrentContext("Root", ThreadType.STEP); - - var operation = new CallbackOperation<>( - "1", - "approval", - TypeToken.get(String.class), - CallbackConfig.builder().serDes(new JacksonSerDes()).build(), - executionManager); - operation.execute(); - - var exception = assertThrows(IllegalDurableOperationException.class, operation::get); - assertEquals( - "Nested CALLBACK operation is not supported on approval from within a Step execution.", - exception.getMessage()); - } } diff --git a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/ChildContextOperationTest.java b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/ChildContextOperationTest.java index c63212eda..ff976a8b4 100644 --- a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/ChildContextOperationTest.java +++ b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/ChildContextOperationTest.java @@ -10,7 +10,7 @@ import com.amazonaws.lambda.durable.exception.ChildContextFailedException; import com.amazonaws.lambda.durable.exception.NonDeterministicExecutionException; import com.amazonaws.lambda.durable.execution.ExecutionManager; -import com.amazonaws.lambda.durable.execution.OperationContext; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import com.amazonaws.lambda.durable.serde.JacksonSerDes; import java.util.List; @@ -30,7 +30,7 @@ class ChildContextOperationTest { private ExecutionManager createMockExecutionManager() { var executionManager = mock(ExecutionManager.class); - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("Root", ThreadType.CONTEXT)); + when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("Root", ThreadType.CONTEXT)); return executionManager; } diff --git a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/InvokeOperationTest.java b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/InvokeOperationTest.java index e94e461bb..b0dbd5952 100644 --- a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/InvokeOperationTest.java +++ b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/InvokeOperationTest.java @@ -4,24 +4,24 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.amazonaws.lambda.durable.InvokeConfig; import com.amazonaws.lambda.durable.TypeToken; -import com.amazonaws.lambda.durable.exception.IllegalDurableOperationException; import com.amazonaws.lambda.durable.exception.InvokeException; import com.amazonaws.lambda.durable.exception.InvokeFailedException; import com.amazonaws.lambda.durable.exception.InvokeStoppedException; import com.amazonaws.lambda.durable.exception.InvokeTimedOutException; import com.amazonaws.lambda.durable.execution.ExecutionManager; -import com.amazonaws.lambda.durable.execution.OperationContext; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import com.amazonaws.lambda.durable.serde.JacksonSerDes; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.lambda.model.ChainedInvokeDetails; import software.amazon.awssdk.services.lambda.model.ErrorObject; +import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; class InvokeOperationTest { @@ -32,33 +32,16 @@ class InvokeOperationTest { @BeforeEach void setUp() { executionManager = mock(ExecutionManager.class); - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("root", ThreadType.CONTEXT)); - } - - @Test - void getThrowsIllegalStateExceptionWhenCalledFromStepContext() { - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("1-step", ThreadType.STEP)); - var operation = new InvokeOperation<>( - OPERATION_ID, - "test-invoke", - "function-name", - "{}", - TypeToken.get(String.class), - InvokeConfig.builder().serDes(new JacksonSerDes()).build(), - executionManager); - - var ex = assertThrows(IllegalDurableOperationException.class, operation::get); - assertTrue(ex.getMessage().contains("Nested CHAINED_INVOKE operation is not supported")); - assertTrue(ex.getMessage().contains("test-invoke")); + when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("root", ThreadType.CONTEXT)); } @Test void getDoesNotThrowWhenCalledFromHandlerContext() { - var op = software.amazon.awssdk.services.lambda.model.Operation.builder() + var op = Operation.builder() .id(OPERATION_ID) .name("test-invoke") .status(OperationStatus.SUCCEEDED) - .chainedInvokeDetails(software.amazon.awssdk.services.lambda.model.ChainedInvokeDetails.builder() + .chainedInvokeDetails(ChainedInvokeDetails.builder() .result("\"cached-result\"") .build()) .build(); @@ -80,11 +63,11 @@ void getDoesNotThrowWhenCalledFromHandlerContext() { @Test void getInvokeFailedExceptionWhenInvocationFailed() { - var op = software.amazon.awssdk.services.lambda.model.Operation.builder() + var op = Operation.builder() .id(OPERATION_ID) .name("test-invoke") .status(OperationStatus.FAILED) - .chainedInvokeDetails(software.amazon.awssdk.services.lambda.model.ChainedInvokeDetails.builder() + .chainedInvokeDetails(ChainedInvokeDetails.builder() .error(ErrorObject.builder() .errorType("errorType") .errorMessage("errorMessage") @@ -112,11 +95,11 @@ void getInvokeFailedExceptionWhenInvocationFailed() { @Test void getInvokeTimedOutExceptionWhenInvocationTimedOut() { - var op = software.amazon.awssdk.services.lambda.model.Operation.builder() + var op = Operation.builder() .id(OPERATION_ID) .name("test-invoke") .status(OperationStatus.TIMED_OUT) - .chainedInvokeDetails(software.amazon.awssdk.services.lambda.model.ChainedInvokeDetails.builder() + .chainedInvokeDetails(ChainedInvokeDetails.builder() .error(ErrorObject.builder() .errorType("errorType") .errorMessage("errorMessage") @@ -144,11 +127,11 @@ void getInvokeTimedOutExceptionWhenInvocationTimedOut() { @Test void getInvokeStoppedExceptionWhenInvocationTimedOut() { - var op = software.amazon.awssdk.services.lambda.model.Operation.builder() + var op = Operation.builder() .id(OPERATION_ID) .name("test-invoke") .status(OperationStatus.STOPPED) - .chainedInvokeDetails(software.amazon.awssdk.services.lambda.model.ChainedInvokeDetails.builder() + .chainedInvokeDetails(ChainedInvokeDetails.builder() .error(ErrorObject.builder() .errorType("errorType") .errorMessage("errorMessage") @@ -176,11 +159,11 @@ void getInvokeStoppedExceptionWhenInvocationTimedOut() { @Test void getInvokeFailedExceptionWhenInvocationEndedUnexpectedly() { - var op = software.amazon.awssdk.services.lambda.model.Operation.builder() + var op = Operation.builder() .id(OPERATION_ID) .name("test-invoke") .status(OperationStatus.CANCELLED) - .chainedInvokeDetails(software.amazon.awssdk.services.lambda.model.ChainedInvokeDetails.builder() + .chainedInvokeDetails(ChainedInvokeDetails.builder() .error(ErrorObject.builder() .errorType("errorType") .errorMessage("errorMessage") diff --git a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/StepOperationTest.java b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/StepOperationTest.java index 208baec84..767612840 100644 --- a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/StepOperationTest.java +++ b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/StepOperationTest.java @@ -8,11 +8,10 @@ import com.amazonaws.lambda.durable.DurableConfig; import com.amazonaws.lambda.durable.StepConfig; import com.amazonaws.lambda.durable.TypeToken; -import com.amazonaws.lambda.durable.exception.IllegalDurableOperationException; import com.amazonaws.lambda.durable.exception.StepFailedException; import com.amazonaws.lambda.durable.exception.StepInterruptedException; import com.amazonaws.lambda.durable.execution.ExecutionManager; -import com.amazonaws.lambda.durable.execution.OperationContext; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import com.amazonaws.lambda.durable.logging.DurableLogger; import com.amazonaws.lambda.durable.serde.JacksonSerDes; @@ -27,10 +26,12 @@ class StepOperationTest { private static final String OPERATION_ID = "1"; + private static final String OPERATION_NAME = "test-step"; + private static final String RESULT = "result"; private ExecutionManager createMockExecutionManager() { var executionManager = mock(ExecutionManager.class); - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("handler", ThreadType.CONTEXT)); + when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("handler", ThreadType.CONTEXT)); return executionManager; } @@ -42,7 +43,7 @@ private void mockFailedOperation( List stackTrace) { var operation = Operation.builder() .id(OPERATION_ID) - .name("test-step") + .name(OPERATION_NAME) .status(OperationStatus.FAILED) .stepDetails(StepDetails.builder() .error(ErrorObject.builder() @@ -57,44 +58,22 @@ private void mockFailedOperation( when(executionManager.getOperationAndUpdateReplayState("1")).thenReturn(operation); } - @Test - void getThrowsIllegalStateExceptionWhenCalledFromStepContext() { - var executionManager = mock(ExecutionManager.class); - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("1-step", ThreadType.STEP)); - - var operation = new StepOperation<>( - OPERATION_ID, - "test-step", - () -> "result", - TypeToken.get(String.class), - StepConfig.builder().serDes(new JacksonSerDes()).build(), - executionManager, - mock(DurableLogger.class), - DurableConfig.builder() - .withExecutorService(Executors.newCachedThreadPool()) - .build()); - - var ex = assertThrows(IllegalDurableOperationException.class, operation::get); - assertTrue(ex.getMessage().contains("Nested STEP operation is not supported")); - assertTrue(ex.getMessage().contains("test-step")); - } - @Test void getDoesNotThrowWhenCalledFromHandlerContext() { var op = Operation.builder() .id(OPERATION_ID) - .name("test-step") + .name(OPERATION_NAME) .status(OperationStatus.SUCCEEDED) .stepDetails(StepDetails.builder().result("\"cached-result\"").build()) .build(); var executionManager = mock(ExecutionManager.class); - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("handler", ThreadType.CONTEXT)); + when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext("handler", ThreadType.CONTEXT)); when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(op); var operation = new StepOperation<>( OPERATION_ID, - "test-step", - () -> "result", + OPERATION_NAME, + () -> RESULT, TypeToken.get(String.class), StepConfig.builder().serDes(new JacksonSerDes()).build(), executionManager, @@ -124,8 +103,8 @@ void getThrowsOriginalExceptionWhenClassIsAvailable() { var operation = new StepOperation<>( OPERATION_ID, - "test-step", - () -> "result", + OPERATION_NAME, + () -> RESULT, TypeToken.get(String.class), StepConfig.builder().serDes(serDes).build(), executionManager, @@ -159,8 +138,8 @@ void getThrowsOriginalCustomExceptionWhenClassIsAvailable() { var operation = new StepOperation<>( OPERATION_ID, - "test-step", - () -> "result", + OPERATION_NAME, + () -> RESULT, TypeToken.get(String.class), StepConfig.builder().serDes(serDes).build(), executionManager, @@ -185,8 +164,8 @@ void getFallsBackToStepFailedExceptionWhenClassNotFound() { var operation = new StepOperation<>( OPERATION_ID, - "test-step", - () -> "result", + OPERATION_NAME, + () -> RESULT, TypeToken.get(String.class), StepConfig.builder().serDes(new JacksonSerDes()).build(), executionManager, @@ -217,8 +196,8 @@ void getFallsBackToStepFailedExceptionWhenDeserializationFails() { var operation = new StepOperation<>( OPERATION_ID, - "test-step", - () -> "result", + OPERATION_NAME, + () -> RESULT, TypeToken.get(String.class), StepConfig.builder().serDes(new JacksonSerDes()).build(), executionManager, @@ -244,8 +223,8 @@ void getFallsBackToStepFailedExceptionWhenErrorDataIsNull() { var operation = new StepOperation<>( OPERATION_ID, - "test-step", - () -> "result", + OPERATION_NAME, + () -> RESULT, TypeToken.get(String.class), StepConfig.builder().serDes(new JacksonSerDes()).build(), executionManager, @@ -271,8 +250,8 @@ void getThrowsStepInterruptedExceptionDirectly() { var operation = new StepOperation<>( OPERATION_ID, - "test-step", - () -> "result", + OPERATION_NAME, + () -> RESULT, TypeToken.get(String.class), StepConfig.builder().serDes(new JacksonSerDes()).build(), executionManager, @@ -285,7 +264,7 @@ void getThrowsStepInterruptedExceptionDirectly() { var thrown = assertThrows(StepInterruptedException.class, operation::get); assertEquals(OPERATION_ID, thrown.getOperation().id()); - assertEquals("test-step", thrown.getOperation().name()); + assertEquals(OPERATION_NAME, thrown.getOperation().name()); } // Custom exception for testing diff --git a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/WaitOperationTest.java b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/WaitOperationTest.java index 2e750d258..4fedcdf96 100644 --- a/sdk/src/test/java/com/amazonaws/lambda/durable/operation/WaitOperationTest.java +++ b/sdk/src/test/java/com/amazonaws/lambda/durable/operation/WaitOperationTest.java @@ -9,18 +9,20 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.amazonaws.lambda.durable.exception.IllegalDurableOperationException; import com.amazonaws.lambda.durable.execution.ExecutionManager; -import com.amazonaws.lambda.durable.execution.OperationContext; +import com.amazonaws.lambda.durable.execution.ThreadContext; import com.amazonaws.lambda.durable.execution.ThreadType; import java.time.Duration; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationStatus; import software.amazon.awssdk.services.lambda.model.WaitDetails; class WaitOperationTest { private static final String OPERATION_ID = "2"; + private static final String CONTEXT_ID = "handler"; + private static final String OPERATION_NAME = "test-wait"; private ExecutionManager executionManager; @BeforeEach @@ -33,7 +35,8 @@ void constructor_withNullDuration_shouldThrow() { var executionManager = mock(ExecutionManager.class); var exception = assertThrows( - IllegalArgumentException.class, () -> new WaitOperation("1", "test-wait", null, executionManager)); + IllegalArgumentException.class, + () -> new WaitOperation(OPERATION_ID, OPERATION_NAME, null, executionManager)); assertEquals("Wait duration cannot be null", exception.getMessage()); } @@ -44,7 +47,7 @@ void constructor_withZeroDuration_shouldThrow() { var exception = assertThrows( IllegalArgumentException.class, - () -> new WaitOperation("1", "test-wait", Duration.ofSeconds(0), executionManager)); + () -> new WaitOperation(OPERATION_ID, OPERATION_NAME, Duration.ofSeconds(0), executionManager)); assertTrue(exception.getMessage().contains("Wait duration")); assertTrue(exception.getMessage().contains("at least 1 second")); @@ -56,7 +59,7 @@ void constructor_withSubSecondDuration_shouldThrow() { var exception = assertThrows( IllegalArgumentException.class, - () -> new WaitOperation("1", "test-wait", Duration.ofMillis(500), executionManager)); + () -> new WaitOperation(OPERATION_ID, OPERATION_NAME, Duration.ofMillis(500), executionManager)); assertTrue(exception.getMessage().contains("Wait duration")); assertTrue(exception.getMessage().contains("at least 1 second")); @@ -66,34 +69,23 @@ void constructor_withSubSecondDuration_shouldThrow() { void constructor_withValidDuration_shouldPass() { var executionManager = mock(ExecutionManager.class); - var operation = new WaitOperation("1", "test-wait", Duration.ofSeconds(10), executionManager); - - assertEquals("1", operation.getOperationId()); - } - - @Test - void getThrowsIllegalStateExceptionWhenCalledFromStepContext() { - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("1-step", ThreadType.STEP)); - - var operation = new WaitOperation("2", "test-invoke", Duration.ofSeconds(10), executionManager); + var operation = new WaitOperation(OPERATION_ID, OPERATION_NAME, Duration.ofSeconds(10), executionManager); - var ex = assertThrows(IllegalDurableOperationException.class, operation::get); - assertEquals( - "Nested WAIT operation is not supported on test-invoke from within a Step execution.", ex.getMessage()); + assertEquals(OPERATION_ID, operation.getOperationId()); } @Test void getDoesNotThrowWhenCalledFromHandlerContext() { - var op = software.amazon.awssdk.services.lambda.model.Operation.builder() + var op = Operation.builder() .id(OPERATION_ID) - .name("test-invoke") + .name(OPERATION_NAME) .status(OperationStatus.SUCCEEDED) .waitDetails(WaitDetails.builder().build()) .build(); - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("handler", ThreadType.CONTEXT)); + when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(CONTEXT_ID, ThreadType.CONTEXT)); when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(op); - var operation = new WaitOperation(OPERATION_ID, "test-invoke", Duration.ofSeconds(10), executionManager); + var operation = new WaitOperation(OPERATION_ID, OPERATION_NAME, Duration.ofSeconds(10), executionManager); operation.onCheckpointComplete(op); var result = operation.get(); @@ -102,15 +94,15 @@ void getDoesNotThrowWhenCalledFromHandlerContext() { @Test void getSucceededWhenStarted() { - var op = software.amazon.awssdk.services.lambda.model.Operation.builder() + var op = Operation.builder() .id(OPERATION_ID) - .name("test-invoke") + .name(OPERATION_NAME) .status(OperationStatus.SUCCEEDED) .build(); - when(executionManager.getCurrentContext()).thenReturn(new OperationContext("handler", ThreadType.CONTEXT)); + when(executionManager.getCurrentThreadContext()).thenReturn(new ThreadContext(CONTEXT_ID, ThreadType.CONTEXT)); when(executionManager.getOperationAndUpdateReplayState(OPERATION_ID)).thenReturn(op); - var operation = new WaitOperation(OPERATION_ID, "test-invoke", Duration.ofSeconds(10), executionManager); + var operation = new WaitOperation(OPERATION_ID, OPERATION_NAME, Duration.ofSeconds(10), executionManager); operation.onCheckpointComplete(op); // we currently don't check the operation status at all, so it's not blocked or failed