diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/callback/WaitForCallbackFailedExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/callback/WaitForCallbackFailedExample.java index 28e90581..ce6d909a 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/callback/WaitForCallbackFailedExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/callback/WaitForCallbackFailedExample.java @@ -9,6 +9,7 @@ import software.amazon.lambda.durable.config.WaitForCallbackConfig; import software.amazon.lambda.durable.examples.types.ApprovalRequest; import software.amazon.lambda.durable.exception.SerDesException; +import software.amazon.lambda.durable.execution.SuspendExecutionException; import software.amazon.lambda.durable.serde.JacksonSerDes; public class WaitForCallbackFailedExample extends DurableHandler { @@ -31,6 +32,9 @@ public String handleRequest(ApprovalRequest input, DurableContext context) { .serDes(new FailedSerDes()) .build()) .build()); + } catch (SuspendExecutionException e) { + // not to swallow the SuspendExecutionException + throw e; } catch (Exception ex) { return ex.getClass().getSimpleName() + ":" + ex.getMessage(); } diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/DeserializationFailedParallelExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/DeserializationFailedParallelExample.java index 5b568cf7..a54aedfa 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/DeserializationFailedParallelExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/parallel/DeserializationFailedParallelExample.java @@ -10,6 +10,7 @@ import software.amazon.lambda.durable.config.ParallelBranchConfig; import software.amazon.lambda.durable.config.ParallelConfig; import software.amazon.lambda.durable.exception.SerDesException; +import software.amazon.lambda.durable.execution.SuspendExecutionException; import software.amazon.lambda.durable.serde.JacksonSerDes; /** @@ -55,6 +56,8 @@ public String handleRequest(Input input, DurableContext context) { parallel.get(); try { return future.get(); + } catch (SuspendExecutionException e) { + throw e; } catch (Exception e) { return e.getMessage(); } diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/step/DeserializationFailureExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/step/DeserializationFailureExample.java index 0514120c..e990acf2 100644 --- a/examples/src/main/java/software/amazon/lambda/durable/examples/step/DeserializationFailureExample.java +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/step/DeserializationFailureExample.java @@ -8,6 +8,7 @@ import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.config.StepConfig; import software.amazon.lambda.durable.exception.SerDesException; +import software.amazon.lambda.durable.execution.SuspendExecutionException; import software.amazon.lambda.durable.serde.JacksonSerDes; public class DeserializationFailureExample extends DurableHandler { @@ -22,6 +23,8 @@ public String handleRequest(String input, DurableContext context) { throw new RuntimeException("this is a test"); }, StepConfig.builder().serDes(new FailedSerDes()).build()); + } catch (SuspendExecutionException e) { + throw e; } catch (Exception e) { context.wait("suspend and replay", Duration.ofSeconds(1)); return e.getClass().getSimpleName() + ":" + e.getMessage(); diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java index 70b1353b..e1e33d0f 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java @@ -14,6 +14,7 @@ import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.ExecutionStatus; import software.amazon.lambda.durable.testing.LocalDurableTestRunner; +import software.amazon.lambda.durable.testing.TestOperation; class ParallelIntegrationTest { @@ -598,7 +599,14 @@ void testParallelWithMinSuccessful_earlyTermination() { }); var result = runner.runUntilComplete("test"); - assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + assertEquals( + ExecutionStatus.SUCCEEDED, + result.getStatus(), + String.join( + " ", + result.getOperations().stream() + .map(TestOperation::toString) + .toList())); } @Test diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java index b83c7598..1e9401cc 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java @@ -74,6 +74,20 @@ public final class DurableConfig { private static final String PROJECT_VERSION = getProjectVersion(VERSION_FILE); private static final String USER_AGENT_SUFFIX = "@aws/durable-execution-sdk-java/" + PROJECT_VERSION; + /** + * A default ExecutorService for running user-defined operations. Uses a cached thread pool with daemon threads by + * default. + * + *

This executor is used exclusively for user operations. Internal SDK coordination uses the + * InternalExecutor::INSTANCE + */ + private static final ExecutorService DEFAULT_USER_THREAD_POOL = Executors.newCachedThreadPool(r -> { + Thread t = new Thread(r); + t.setName("durable-exec-" + t.getId()); + t.setDaemon(true); + return t; + }); + private final DurableExecutionClient durableExecutionClient; private final SerDes serDes; private final ExecutorService executorService; @@ -250,12 +264,7 @@ private static String getProjectVersion(String versionFile) { */ private static ExecutorService createDefaultExecutor() { logger.debug("Creating default ExecutorService"); - return Executors.newCachedThreadPool(r -> { - Thread t = new Thread(r); - t.setName("durable-exec-" + t.getId()); - t.setDaemon(true); - return t; - }); + return DEFAULT_USER_THREAD_POOL; } /** Builder for DurableConfig. Provides fluent API for configuring SDK components. */ diff --git a/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java b/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java index 6cec3b5f..466ed8fc 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java @@ -280,6 +280,7 @@ public static boolean isTerminalStatus(OperationStatus status) { * @param exception the unrecoverable exception that caused termination */ public void terminateExecution(UnrecoverableDurableExecutionException exception) { + stopAllOperations(exception); executionExceptionFuture.completeExceptionally(exception); throw exception; } @@ -287,10 +288,15 @@ public void terminateExecution(UnrecoverableDurableExecutionException exception) /** Suspends the execution by completing the execution exception future with a {@link SuspendExecutionException}. */ public void suspendExecution() { var ex = new SuspendExecutionException(); + stopAllOperations(ex); executionExceptionFuture.completeExceptionally(ex); throw ex; } + private void stopAllOperations(Exception cause) { + registeredOperations.values().forEach(op -> op.getCompletionFuture().completeExceptionally(cause)); + } + /** * return a future that completes when userFuture completes successfully or the execution is terminated or * suspended. diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java index a06c84ec..a5026517 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java @@ -22,6 +22,7 @@ import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; +import software.amazon.lambda.durable.util.ExceptionHelper; /** * Base class for all durable operations (STEP, WAIT, etc.). @@ -187,7 +188,7 @@ protected Operation waitForOperationCompletion() { // is between `isOperationCompleted` and `thenRun`. // If this operation is a branch/iteration of a ConcurrencyOperation (map or parallel), the branches/iterations // must be completed sequentially to avoid race conditions. - synchronized (parentOperation == null ? completionFuture : parentOperation) { + synchronized (parentOperation == null ? completionFuture : parentOperation.completionFuture) { if (!isOperationCompleted()) { // Operation not done yet logger.trace( @@ -208,7 +209,11 @@ protected Operation waitForOperationCompletion() { } // Block until operation completes. No-op if the future is already completed. - completionFuture.join(); + try { + completionFuture.join(); + } catch (Throwable throwable) { + ExceptionHelper.sneakyThrow(ExceptionHelper.unwrapCompletableFuture(throwable)); + } // Get result based on status var op = getOperation(); @@ -290,7 +295,7 @@ protected void markAlreadyCompleted() { private void markCompletionFutureCompleted() { // It's important that we synchronize access to the future, otherwise the processing could happen // on someone else's thread and cause a race condition. - synchronized (parentOperation == null ? completionFuture : parentOperation) { + synchronized (parentOperation == null ? completionFuture : parentOperation.completionFuture) { // Completing the future here will also run any other completion stages that have been attached // to the future. In our case, other contexts may have attached a function to reactivate themselves, // so they will definitely have a chance to reactivate before we finish completing and deactivating diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java index 1eb4648e..31f9d43c 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java @@ -21,12 +21,15 @@ import software.amazon.lambda.durable.TypeToken; import software.amazon.lambda.durable.config.RunInChildContextConfig; import software.amazon.lambda.durable.context.DurableContextImpl; +import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; import software.amazon.lambda.durable.execution.OperationIdGenerator; +import software.amazon.lambda.durable.execution.SuspendExecutionException; import software.amazon.lambda.durable.execution.ThreadType; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.serde.SerDes; +import software.amazon.lambda.durable.util.ExceptionHelper; /** * Abstract base class for concurrent execution of multiple child context operations. @@ -143,7 +146,7 @@ protected ChildContextOperation enqueueItem( } private void notifyConsumerThread() { - synchronized (this) { + synchronized (completionFuture) { consumerThreadListener.get().complete(null); } } @@ -156,61 +159,80 @@ protected void executeItems() { AtomicInteger failedCount = new AtomicInteger(0); Runnable consumer = () -> { - while (true) { - // Set a new future if it's completed so that it will be able to receive a notification of - // new items when the thread is checking completion condition and processing - // the queued items below. - synchronized (this) { - if (consumerThreadListener.get() != null - && consumerThreadListener.get().isDone()) { - consumerThreadListener.set(new CompletableFuture<>()); + try { + while (true) { + // Set a new future if it's completed so that it will be able to receive a notification of + // new items when the thread is checking completion condition and processing + // the queued items below. + synchronized (completionFuture) { + if (consumerThreadListener.get() != null + && consumerThreadListener.get().isDone()) { + consumerThreadListener.set(new CompletableFuture<>()); + } } - } - // Process completion condition. Quit the loop if the condition is met. - if (isOperationCompleted()) { - return; - } - var completionStatus = canComplete(succeededCount, failedCount, runningChildren); - if (completionStatus != null) { - handleCompletion(completionStatus); - return; - } + // Process completion condition. Quit the loop if the condition is met. + if (isOperationCompleted()) { + return; + } + var completionStatus = canComplete(succeededCount, failedCount, runningChildren); + if (completionStatus != null) { + handleCompletion(completionStatus); + return; + } - // process new items in the queue - while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) { - var next = pendingQueue.poll(); - runningChildren.add(next); - logger.debug("Executing operation {}", next.getName()); - next.execute(); - } + // process new items in the queue + while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) { + var next = pendingQueue.poll(); + runningChildren.add(next); + logger.debug("Executing operation {}", next.getName()); + next.execute(); + } - // If consumerThreadListener has been completed when processing above, waitForChildCompletion will - // immediately return null and repeat the above again - var child = waitForChildCompletion(succeededCount, failedCount, runningChildren); - - // child may be null if the consumer thread is woken up due to new items added or completion condition - // changed - if (child != null) { - if (runningChildren.contains(child)) { - runningChildren.remove(child); - onItemComplete(succeededCount, failedCount, (ChildContextOperation) child); - } else { - throw new IllegalStateException("Unexpected completion: " + child); + // If consumerThreadListener has been completed when processing above, waitForChildCompletion will + // immediately return null and repeat the above again + var child = waitForChildCompletion(succeededCount, failedCount, runningChildren); + + // child may be null if the consumer thread is woken up due to new items added or completion + // condition + // changed + if (child != null) { + if (runningChildren.contains(child)) { + runningChildren.remove(child); + onItemComplete(succeededCount, failedCount, (ChildContextOperation) child); + } else { + throw new IllegalStateException("Unexpected completion: " + child); + } } } + } catch (Throwable ex) { + handleException(ex); } }; // run consumer in the user thread pool, although it's not a real user thread runUserHandler(consumer, getOperationId(), ThreadType.CONTEXT); } + private void handleException(Throwable ex) { + Throwable throwable = ExceptionHelper.unwrapCompletableFuture(ex); + if (throwable instanceof SuspendExecutionException suspendExecutionException) { + // Rethrow Error immediately — do not checkpoint + throw suspendExecutionException; + } + if (throwable instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) { + throw terminateExecution(unrecoverableDurableExecutionException); + } + + throw terminateExecutionWithIllegalDurableOperationException( + String.format("Unexpected exception in concurrency operation: %s", throwable)); + } + private BaseDurableOperation waitForChildCompletion( AtomicInteger succeededCount, AtomicInteger failedCount, Set runningChildren) { var threadContext = getCurrentThreadContext(); CompletableFuture future; - synchronized (this) { + synchronized (completionFuture) { // check again in synchronized block to prevent race conditions if (isOperationCompleted()) { return null; @@ -238,7 +260,12 @@ private BaseDurableOperation waitForChildCompletion( executionManager.deregisterActiveThread(threadContext.threadId()); } } - return future.thenApply(o -> (BaseDurableOperation) o).join(); + try { + return future.thenApply(o -> (BaseDurableOperation) o).join(); + } catch (Throwable throwable) { + ExceptionHelper.sneakyThrow(ExceptionHelper.unwrapCompletableFuture(throwable)); + throw throwable; + } } /** diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java index c342d02e..9e094914 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -14,11 +14,14 @@ import software.amazon.lambda.durable.config.CompletionConfig; import software.amazon.lambda.durable.config.MapConfig; import software.amazon.lambda.durable.context.DurableContextImpl; +import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException; +import software.amazon.lambda.durable.execution.SuspendExecutionException; import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus; import software.amazon.lambda.durable.model.MapResult; import software.amazon.lambda.durable.model.OperationIdentifier; import software.amazon.lambda.durable.model.OperationSubType; import software.amazon.lambda.durable.serde.SerDes; +import software.amazon.lambda.durable.util.ExceptionHelper; /** * Executes a map operation: applies a function to each item in a collection concurrently, with each item running in its @@ -153,8 +156,18 @@ protected void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletio } else { try { resultItems.set(i, MapResult.MapResultItem.succeeded(branch.get())); - } catch (Exception e) { - resultItems.set(i, MapResult.MapResultItem.failed(MapResult.MapError.of(e))); + } catch (Throwable exception) { + Throwable throwable = ExceptionHelper.unwrapCompletableFuture(exception); + if (throwable instanceof SuspendExecutionException suspendExecutionException) { + // Rethrow Error immediately — do not checkpoint + throw suspendExecutionException; + } + if (throwable + instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) { + // terminate the execution and throw the exception if it's not recoverable + throw terminateExecution(unrecoverableDurableExecutionException); + } + resultItems.set(i, MapResult.MapResultItem.failed(MapResult.MapError.of(throwable))); } } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index f5a644c6..8142f108 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -113,6 +113,9 @@ public ParallelResult get() { /** Calls {@link #get()} if not already called. Guarantees that the context is closed. */ @Override public void close() { + if (isJoined.get()) { + return; + } join(); } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java index 9e15b1fb..64a10f6f 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/WaitForConditionOperation.java @@ -4,7 +4,6 @@ import java.time.Duration; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; import java.util.function.BiFunction; import software.amazon.awssdk.services.lambda.model.Operation; import software.amazon.awssdk.services.lambda.model.OperationAction; @@ -40,7 +39,6 @@ public class WaitForConditionOperation extends SerializableDurableOperation> checkFunc; private final WaitForConditionConfig config; - private final ExecutorService userExecutor; public WaitForConditionOperation( String operationId, @@ -57,7 +55,6 @@ public WaitForConditionOperation( this.checkFunc = checkFunc; this.config = config; - this.userExecutor = durableContext.getDurableConfig().getExecutorService(); } @Override diff --git a/sdk/src/test/java/software/amazon/lambda/durable/DurableConfigTest.java b/sdk/src/test/java/software/amazon/lambda/durable/DurableConfigTest.java index 209684fb..f2766ccb 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/DurableConfigTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/DurableConfigTest.java @@ -190,7 +190,7 @@ void testBuilder_MultipleBuilds_CreateIndependentInstances() { assertEquals(config1.getDurableExecutionClient(), config2.getDurableExecutionClient()); // ExecutorService should be different instances (each gets its own) - assertNotSame(config1.getExecutorService(), config2.getExecutorService()); + assertSame(config1.getExecutorService(), config2.getExecutorService()); } @Test @@ -210,7 +210,7 @@ void testDefaultConfig_CreatesNewInstancesEachTime() { var config2 = DurableConfig.defaultConfig(); assertNotSame(config1, config2); - assertNotSame(config1.getExecutorService(), config2.getExecutorService()); + assertSame(config1.getExecutorService(), config2.getExecutorService()); } @Test