diff --git a/examples/src/main/java/software/amazon/lambda/durable/examples/MapExample.java b/examples/src/main/java/software/amazon/lambda/durable/examples/MapExample.java new file mode 100644 index 000000000..231741185 --- /dev/null +++ b/examples/src/main/java/software/amazon/lambda/durable/examples/MapExample.java @@ -0,0 +1,41 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.examples; + +import java.util.List; +import software.amazon.lambda.durable.ConcurrencyConfig; +import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.DurableHandler; +import software.amazon.lambda.durable.ParallelBranchConfig; +import software.amazon.lambda.durable.TypeToken; + +/** + * Example demonstrating the concurrent map and parallel operations of the Durable Execution SDK. + * + *

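This handler starts two concurrent operations and returns the joined branch results: + * + *

    + *
  1. Map: square each element of [1, 2, 3] with mapAsync + *
  2. Parallel: run two branches that return "hello" and "world" + *
  3. Wait for the parallel operation, then return "hello world" + *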
+ */ +public class MapExample extends DurableHandler { + + @Override + public String handleRequest(GreetingRequest input, DurableContext context) { + var squared = context.mapAsync( + "map example", + List.of(1, 2, 3), + (ctx, item, index) -> item * item, + TypeToken.get(Integer.class), + new ConcurrencyConfig(10, 2, 1)); + + var parallel = context.parallelAsync("parallel example", new ConcurrencyConfig(10, 2, 1)); + var b1 = parallel.branch("branch1", TypeToken.get(String.class), ctx -> "hello", new ParallelBranchConfig()); + var b2 = parallel.branch("branch2", TypeToken.get(String.class), ctx -> "world", new ParallelBranchConfig()); + + parallel.get(); // wait for the parallel operation to complete before reading the branches + return b1.get() + " " + b2.get(); + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/BatchResult.java b/sdk/src/main/java/software/amazon/lambda/durable/BatchResult.java new file mode 100644 index 000000000..11784e495 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/BatchResult.java @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable; + +public class BatchResult extends ParallelResult { + // TODO: add per-item results/errors alongside the inherited statistics +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ConcurrencyConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/ConcurrencyConfig.java new file mode 100644 index 000000000..8674a83ca --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/ConcurrencyConfig.java @@ -0,0 +1,5 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable; + +public record ConcurrencyConfig(int maxConcurrency, int minSuccessful, int toleratedFailureCount) {} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java index 32c9d7944..5954d0649 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableContext.java @@ -8,6 +8,7 @@ import java.security.NoSuchAlgorithmException; import java.time.Duration; import java.util.HexFormat; +import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; @@ -21,8 +22,11 @@ import software.amazon.lambda.durable.operation.CallbackOperation; import software.amazon.lambda.durable.operation.ChildContextOperation; import software.amazon.lambda.durable.operation.InvokeOperation; +import software.amazon.lambda.durable.operation.MapOperation; +import software.amazon.lambda.durable.operation.ParallelOperation; import software.amazon.lambda.durable.operation.StepOperation; import software.amazon.lambda.durable.operation.WaitOperation; +import software.amazon.lambda.durable.serde.JacksonSerDes; import software.amazon.lambda.durable.validation.ParameterValidator; public class DurableContext extends BaseContext { @@ -335,7 +339,7 @@ private DurableFuture runInChildContextAsync( var operationId = nextOperationId(); var operation = new ChildContextOperation<>( - operationId, name, func, subType, typeToken, getDurableConfig().getSerDes(), this); + operationId, name, func, subType, typeToken, getDurableConfig().getSerDes(), this, null); operation.execute(); return operation; @@ -438,6 +442,28 @@ public DurableFuture waitForCallbackAsync( OperationSubType.WAIT_FOR_CALLBACK); } + // parallel operations + public DurableParallelFuture parallelAsync(String name, 
ConcurrencyConfig config) { + var operationId = nextOperationId(); + var operation = new ParallelOperation(operationId, name, config, this); + operation.execute(); + return operation; + } + + // map operations: each item runs in its own child context and uses the configured SerDes + public DurableFuture> mapAsync( + String name, + List collection, + MapFunction func, + TypeToken resultTypeToken, + ConcurrencyConfig config) { + var operationId = nextOperationId(); + var operation = new MapOperation<>( + operationId, name, collection, func, resultTypeToken, getDurableConfig().getSerDes(), config, this); + operation.execute(); + return operation; + } + + // =============== accessors ================ /** * Returns a logger with execution context information for replay-aware logging. @@ -474,7 +500,7 @@ public void close() { * prefixed with the parent hashed contextId (e.g. "-1", "-2" inside parent context ). This * matches the Python SDK's stepPrefix convention and prevents ID collisions in checkpoint batches. */ - private String nextOperationId() { + public String nextOperationId() { var counter = String.valueOf(operationCounter.incrementAndGet()); var rawId = getContextId() != null ? getContextId() + "-" + counter : counter; try { diff --git a/sdk/src/main/java/software/amazon/lambda/durable/DurableParallelFuture.java b/sdk/src/main/java/software/amazon/lambda/durable/DurableParallelFuture.java new file mode 100644 index 000000000..2e1b514c1 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/DurableParallelFuture.java @@ -0,0 +1,10 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable; + +import java.util.function.Function; + +public interface DurableParallelFuture extends DurableFuture { + DurableFuture branch( + String name, TypeToken resultType, Function func, ParallelBranchConfig config); +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java b/sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java new file mode 100644 index 000000000..f704fd0c7 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/MapFunction.java @@ -0,0 +1,8 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable; + +@FunctionalInterface +public interface MapFunction { + O apply(DurableContext context, I item, int index); +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelBranchConfig.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelBranchConfig.java new file mode 100644 index 000000000..95938dd00 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/ParallelBranchConfig.java @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable; + +public class ParallelBranchConfig { + // Placeholder for per-branch options (SerDes, etc.); no options are defined yet +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/ParallelResult.java b/sdk/src/main/java/software/amazon/lambda/durable/ParallelResult.java new file mode 100644 index 000000000..77bd81a83 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/ParallelResult.java @@ -0,0 +1,6 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable; + +/** Statistics of a parallel operation (succeeded, failed, etc.)
*/ +public class ParallelResult {} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/model/OperationSubType.java b/sdk/src/main/java/software/amazon/lambda/durable/model/OperationSubType.java index 9e778ef07..3a240ed6a 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/model/OperationSubType.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/model/OperationSubType.java @@ -11,7 +11,9 @@ public enum OperationSubType { RUN_IN_CHILD_CONTEXT("RunInChildContext"), MAP("Map"), + MAP_ITERATION("MapIteration"), PARALLEL("Parallel"), + PARALLEL_BRANCH("ParallelBranch"), WAIT_FOR_CALLBACK("WaitForCallback"); private final String value; diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseConcurrentOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseConcurrentOperation.java new file mode 100644 index 000000000..c26e90c33 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseConcurrentOperation.java @@ -0,0 +1,118 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.operation; + +import java.util.ArrayList; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import software.amazon.awssdk.services.lambda.model.OperationAction; +import software.amazon.awssdk.services.lambda.model.OperationType; +import software.amazon.awssdk.services.lambda.model.OperationUpdate; +import software.amazon.lambda.durable.ConcurrencyConfig; +import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.model.OperationSubType; +import software.amazon.lambda.durable.serde.NoopSerDes; +import software.amazon.lambda.durable.serde.SerDes; + +public abstract class BaseConcurrentOperation extends BaseDurableOperation { + + private final ArrayList> branches; + private final Queue> queue; + private final DurableContext rootContext; + private final AtomicInteger succeeded; + private final AtomicInteger failed; + private final OperationSubType subType; + private final ConcurrencyConfig config; + private final AtomicInteger activeBranches; + + public BaseConcurrentOperation( + String operationId, + String name, + OperationSubType subType, + ConcurrencyConfig config, + DurableContext durableContext) { + super(operationId, name, OperationType.CONTEXT, new TypeToken<>() {}, new NoopSerDes(), durableContext); + this.branches = new ArrayList<>(); + this.queue = new ConcurrentLinkedQueue<>(); + this.rootContext = durableContext.createChildContext(operationId, name); + this.config = config; + this.succeeded = new AtomicInteger(0); + this.failed = new AtomicInteger(0); + this.subType = subType; + this.activeBranches = new AtomicInteger(0); + } + + protected ChildContextOperation branchInternal( + String name, TypeToken resultType, SerDes resultSerDes, Function func) { + var operationId = 
this.rootContext.nextOperationId(); + ChildContextOperation operation; + + synchronized (this.branches) { + operation = new ChildContextOperation<>( + operationId, + name, + func, + subType, + resultType, + resultSerDes, + rootContext, + this); + branches.add(operation); + queue.add(operation); + } + + executeNewBranchIfConcurrencyAllows(); + + return operation; + } + + private void executeNewBranchIfConcurrencyAllows() { + synchronized (this) { + // start the next queued branch only while fewer than maxConcurrency branches are active + if (activeBranches.get() < config.maxConcurrency()) { + if (!queue.isEmpty()) { + activeBranches.incrementAndGet(); + + var op = queue.poll(); + op.execute(); + } + } + } + } + + @Override + public void onChildContextComplete(ChildContextOperation parallelBranchOperation) { + if (isOperationCompleted()) { + return; + } + + activeBranches.decrementAndGet(); + + // handle branch results + try { + parallelBranchOperation.get(); + succeeded.incrementAndGet(); + } catch (Exception e) { + failed.incrementAndGet(); + } + + if (isDone()) { + var kind = subType == OperationSubType.MAP_ITERATION ? OperationSubType.MAP : OperationSubType.PARALLEL; + sendOperationUpdateAsync(OperationUpdate.builder() + .action(OperationAction.SUCCEED) + .subType(kind.getValue()) + .payload("")); + rootContext.close(); + } else { + // we must make sure the thread for the new branch is registered before the child thread is deregistered + executeNewBranchIfConcurrencyAllows(); + } + } + + private boolean isDone() { + return succeeded.get() >= config.minSuccessful() || failed.get() > config.toleratedFailureCount(); + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java index 5ee970f66..13740af85 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java @@ -50,8 +50,8 @@
public abstract class BaseDurableOperation implements DurableFuture { private final String name; private final OperationType operationType; private final ExecutionManager executionManager; - private final TypeToken resultTypeToken; - private final SerDes resultSerDes; + protected final TypeToken resultTypeToken; + protected final SerDes resultSerDes; protected final CompletableFuture completionFuture; private final DurableContext durableContext; @@ -338,4 +338,9 @@ protected void validateReplay(Operation checkpointed) { operationId, checkpointed.name(), getName()))); } } + + protected void onChildContextComplete(ChildContextOperation tChildContextOperation) { + // Default is a no-op; a parent that owns child contexts (e.g. BaseConcurrentOperation) overrides this + // hook, which ChildContextOperation invokes after the child's function returns or throws. + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java index 95c710dee..ced590de5 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ChildContextOperation.java @@ -42,6 +42,7 @@ public class ChildContextOperation extends BaseDurableOperation { private final Function function; private final ExecutorService userExecutor; + private final BaseDurableOperation parentOperation; private boolean replayChildContext; private T reconstructedResult; private final OperationSubType subType; @@ -53,11 +54,13 @@ public ChildContextOperation( OperationSubType subType, TypeToken resultTypeToken, SerDes resultSerDes, - DurableContext durableContext) { + DurableContext durableContext, + BaseDurableOperation parentOperation) { super(operationId, name, OperationType.CONTEXT, resultTypeToken, resultSerDes, durableContext); this.function = function; this.userExecutor = getContext().getDurableConfig().getExecutorService(); this.subType = subType; + this.parentOperation = parentOperation; } /** Starts the operation.
*/ @@ -118,6 +121,10 @@ private void executeChildContext() { handleChildContextSuccess(result); } catch (Throwable e) { handleChildContextFailure(e); + } finally { + if (parentOperation != null) { + parentOperation.onChildContextComplete(this); + } } } }; @@ -138,6 +145,9 @@ private void handleChildContextSuccess(T result) { } private void checkpointSuccess(T result) { + if (parentOperation != null && parentOperation.isOperationCompleted()) { + return; // Already completed by parent operation + } var serialized = serializeResult(result); var serializedBytes = serialized.getBytes(StandardCharsets.UTF_8); @@ -169,6 +179,10 @@ private void handleChildContextFailure(Throwable exception) { terminateExecution((UnrecoverableDurableExecutionException) exception); } + if (parentOperation != null && parentOperation.isOperationCompleted()) { + return; // Already completed by parent operation + } + final ErrorObject errorObject; if (exception instanceof DurableOperationException opEx) { errorObject = opEx.getErrorObject(); @@ -210,6 +224,8 @@ public T get() { case MAP -> throw new ChildContextFailedException(op); case PARALLEL -> throw new ChildContextFailedException(op); case RUN_IN_CHILD_CONTEXT -> throw new ChildContextFailedException(op); + case PARALLEL_BRANCH -> throw new ChildContextFailedException(op); + case MAP_ITERATION -> throw new ChildContextFailedException(op); }; } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java new file mode 100644 index 000000000..5824355dc --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java @@ -0,0 +1,85 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.operation; + +import java.util.ArrayList; +import java.util.SequencedCollection; +import software.amazon.awssdk.services.lambda.model.Operation; +import software.amazon.awssdk.services.lambda.model.OperationAction; +import software.amazon.awssdk.services.lambda.model.OperationUpdate; +import software.amazon.lambda.durable.BatchResult; +import software.amazon.lambda.durable.ConcurrencyConfig; +import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.MapFunction; +import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.model.OperationSubType; +import software.amazon.lambda.durable.serde.SerDes; + +public class MapOperation extends BaseConcurrentOperation> { + private final MapFunction func; + private final ArrayList> iterations; + private final SequencedCollection collection; + private final SerDes serDes; + private final TypeToken branchResultTypeToken; + + public MapOperation( + String operationId, + String name, + SequencedCollection collection, + MapFunction func, + TypeToken resultTypeToken, + SerDes resultSerDes, + ConcurrencyConfig config, + DurableContext durableContext) { + super(operationId, name, OperationSubType.MAP_ITERATION, config, durableContext); + this.func = func; + this.iterations = new ArrayList<>(); + this.branchResultTypeToken = resultTypeToken; + this.serDes = resultSerDes; + this.collection = collection; + } + + /** Sends the START update (sub type MAP), then creates one branch per collection item in iteration order. */ + @Override + protected void start() { + sendOperationUpdateAsync( + OperationUpdate.builder().action(OperationAction.START).subType(OperationSubType.MAP.getValue())); + for (var item : collection) { + int index = iterations.size(); + iterations.add(branchInternal( + getName() + "-iteration-" + index, + branchResultTypeToken, + serDes, + (ctx) -> func.apply(ctx, item, index))); + } + } + + /** + * Replays the operation. 
+ * + * @param existing the checkpointed operation; its status selects the replay path + */ + @Override + protected void replay(Operation existing) { + switch (existing.status()) { + case SUCCEEDED, FAILED -> markAlreadyCompleted(); + case STARTED -> start(); // NOTE(review): start() sends the START update again on replay; confirm intended + } + } + + /** + * Blocks until the operation completes and returns the result. + * + *

This delegates to operation.get() which handles: - Thread deregistration (allows suspension) - Thread + * reactivation (resumes execution) - Result retrieval + * + * @return the operation result + */ + @Override + public BatchResult get() { + waitForOperationCompletion(); + // build the batch results + // + return new BatchResult<>(); + } +} diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java new file mode 100644 index 000000000..15e9eac11 --- /dev/null +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -0,0 +1,71 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package software.amazon.lambda.durable.operation; + +import java.util.function.Function; +import software.amazon.awssdk.services.lambda.model.Operation; +import software.amazon.awssdk.services.lambda.model.OperationAction; +import software.amazon.awssdk.services.lambda.model.OperationUpdate; +import software.amazon.lambda.durable.ConcurrencyConfig; +import software.amazon.lambda.durable.DurableContext; +import software.amazon.lambda.durable.DurableFuture; +import software.amazon.lambda.durable.DurableParallelFuture; +import software.amazon.lambda.durable.ParallelBranchConfig; +import software.amazon.lambda.durable.ParallelResult; +import software.amazon.lambda.durable.TypeToken; +import software.amazon.lambda.durable.model.OperationSubType; +import software.amazon.lambda.durable.serde.SerDes; + +public class ParallelOperation extends BaseConcurrentOperation implements DurableParallelFuture { + + public ParallelOperation(String operationId, String name, ConcurrencyConfig config, DurableContext durableContext) { + super(operationId, name, OperationSubType.PARALLEL_BRANCH, config, durableContext); + } + + public DurableFuture branch( + String name, TypeToken resultType, SerDes resultSerDes, Function func) { + return branchInternal(name, 
resultType, resultSerDes, func); + } + + /** + * Interface entry point: adds a branch that uses the SerDes from the context's durable config. + * The branch config is currently unused because ParallelBranchConfig defines no options yet. + */ + @Override + public <T> DurableFuture<T> branch( + String name, TypeToken<T> resultType, Function<DurableContext, T> func, ParallelBranchConfig config) { + return branchInternal(name, resultType, getContext().getDurableConfig().getSerDes(), func); + } + + /** Starts the operation. */ + @Override + protected void start() { + sendOperationUpdateAsync( + OperationUpdate.builder().action(OperationAction.START).subType(OperationSubType.PARALLEL.getValue())); + } + + /** Replays the operation. */ + @Override + protected void replay(Operation existing) { + // always replay the branches + + } + + /** + * Blocks until the operation completes and returns the result. + * + *

This delegates to operation.get() which handles: - Thread deregistration (allows suspension) - Thread + * reactivation (resumes execution) - Result retrieval + * + * @return the operation result + */ + @Override + public ParallelResult get() { + // wait for all to complete + waitForOperationCompletion(); + + // This method only returns stats of the branches (succeeded, failed, etc) + // Users need to use each branch to check for the result or error. + return new ParallelResult(); + } +} diff --git a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java index 9ad9ba084..a60804179 100644 --- a/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java +++ b/sdk/src/test/java/software/amazon/lambda/durable/operation/ChildContextOperationTest.java @@ -58,7 +58,8 @@ private ChildContextOperation createOperation( OperationSubType.RUN_IN_CHILD_CONTEXT, TypeToken.get(String.class), SERDES, - durableContext); + durableContext, + null); } // ===== SUCCEEDED replay =====