-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathBaseDurableOperation.java
More file actions
399 lines (351 loc) · 17 KB
/
BaseDurableOperation.java
File metadata and controls
399 lines (351 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.operation;
import java.time.Instant;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.lambda.model.Operation;
import software.amazon.awssdk.services.lambda.model.OperationType;
import software.amazon.awssdk.services.lambda.model.OperationUpdate;
import software.amazon.lambda.durable.context.DurableContextImpl;
import software.amazon.lambda.durable.exception.IllegalDurableOperationException;
import software.amazon.lambda.durable.exception.NonDeterministicExecutionException;
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
import software.amazon.lambda.durable.execution.ExecutionManager;
import software.amazon.lambda.durable.execution.SuspendExecutionException;
import software.amazon.lambda.durable.execution.ThreadContext;
import software.amazon.lambda.durable.execution.ThreadType;
import software.amazon.lambda.durable.model.OperationIdentifier;
import software.amazon.lambda.durable.model.OperationSubType;
import software.amazon.lambda.durable.util.ExceptionHelper;
/**
* Base class for all durable operations (STEP, WAIT, etc.).
*
* <p>Key methods:
*
* <ul>
* <li>{@code execute()} starts the operation (returns immediately)
* <li>{@code get()} blocks until complete and returns the result
* </ul>
*
* <p>The separation allows:
*
* <ul>
* <li>Starting multiple async operations quickly
* <li>Blocking on results later when needed
* <li>Proper thread coordination via future
* </ul>
*/
public abstract class BaseDurableOperation {
private static final Logger logger = LoggerFactory.getLogger(BaseDurableOperation.class);
private final OperationIdentifier operationIdentifier;
protected final ExecutionManager executionManager;
protected final CompletableFuture<BaseDurableOperation> completionFuture;
protected final BaseDurableOperation parentOperation;
private final DurableContextImpl durableContext;
private final AtomicReference<CompletableFuture<Void>> runningUserHandler = new AtomicReference<>(null);
/**
* Constructs a new durable operation.
*
* @param operationIdentifier the unique identifier for this operation
* @param durableContext the parent context this operation belongs to
* @param parentOperation the parent operation if this is a branch/iteration of a ConcurrencyOperation
*/
protected BaseDurableOperation(
OperationIdentifier operationIdentifier,
DurableContextImpl durableContext,
BaseDurableOperation parentOperation) {
this.operationIdentifier = operationIdentifier;
this.parentOperation = parentOperation;
this.durableContext = durableContext;
this.executionManager = durableContext.getExecutionManager();
this.completionFuture = new CompletableFuture<>();
// register this operation in ExecutionManager so that the operation can receive updates from ExecutionManager
executionManager.registerOperation(this);
}
public CompletableFuture<BaseDurableOperation> getCompletionFuture() {
return completionFuture;
}
/** Gets the operation sub-type (e.g. RUN_IN_CHILD_CONTEXT, WAIT_FOR_CALLBACK). */
public OperationSubType getSubType() {
return operationIdentifier.subType();
}
/** Gets the unique identifier for this operation. */
public String getOperationId() {
return operationIdentifier.operationId();
}
/** Gets the operation name (may be null). */
public String getName() {
return operationIdentifier.name();
}
/** Gets the parent context. */
protected DurableContextImpl getContext() {
return durableContext;
}
/** Gets the operation type. */
public OperationType getType() {
return operationIdentifier.operationType();
}
/**
* Starts the operation by checking for an existing checkpoint. If a checkpoint exists, validates and replays it;
* otherwise starts fresh execution.
*/
public void execute() {
var existing = getOperation();
if (existing != null) {
validateReplay(existing);
replay(existing);
} else {
if (durableContext.isReplaying()) {
this.durableContext.setExecutionMode();
}
start();
}
}
/** Starts the operation on first execution (no existing checkpoint). */
protected abstract void start();
/**
* Replays the operation from an existing checkpoint.
*
* @param existing the checkpointed operation state
*/
protected abstract void replay(Operation existing);
/**
* Gets the Operation from ExecutionManager and update the replay state from REPLAY to EXECUTE if operation is not
* found. Operation IDs are globally unique (prefixed for child contexts), so no parentId is needed for lookups.
*
* @return the operation if found, otherwise null
*/
protected Operation getOperation() {
return executionManager.getOperationAndUpdateReplayState(getOperationId());
}
/**
* Gets the direct child Operations of this context operation
*
* @return list of the child Operations
*/
protected List<Operation> getChildOperations() {
return executionManager.getChildOperations(getOperationId());
}
/**
* Checks if it's called from a Step.
*
* @throws IllegalDurableOperationException if it's in a step
*/
private void validateCurrentThreadType() {
ThreadType current = getCurrentThreadContext().threadType();
if (current == ThreadType.STEP) {
var message = String.format(
"Nested %s operation is not supported on %s from within a %s execution.",
getType(), getName(), current);
// terminate execution and throw the exception
throw terminateExecutionWithIllegalDurableOperationException(message);
}
}
/** Returns true if this operation has completed (successfully or exceptionally). */
protected boolean isOperationCompleted() {
return completionFuture.isDone();
}
/**
* Waits for the operation to complete. Deregisters the current thread to allow Lambda suspension if the operation
* is still in progress, then re-registers when the operation completes.
*
* @return the completed operation
*/
protected Operation waitForOperationCompletion() {
validateCurrentThreadType();
var threadContext = getCurrentThreadContext();
// It's important that we synchronize access to the future. Otherwise, a race condition could happen if the
// completionFuture is completed by a user thread (a step or child context thread) when the execution here
// is between `isOperationCompleted` and `thenRun`.
// If this operation is a branch/iteration of a ConcurrencyOperation (map or parallel), the branches/iterations
// must be completed sequentially to avoid race conditions.
synchronized (parentOperation == null ? completionFuture : parentOperation.completionFuture) {
if (!isOperationCompleted()) {
// Operation not done yet
logger.trace(
"deregistering thread {} when waiting for operation {} ({}) to complete ({})",
threadContext.threadId(),
getOperation(),
getType(),
completionFuture);
// Add a completion stage to completionFuture so that when the completionFuture is completed,
// it will register the current Context thread synchronously to make sure it is always registered
// strictly before the execution thread (Step or child context) is deregistered.
completionFuture.thenRun(() -> registerActiveThread(threadContext.threadId()));
// Deregister the current thread to allow suspension
executionManager.deregisterActiveThread(threadContext.threadId());
}
}
// Block until operation completes. No-op if the future is already completed.
try {
completionFuture.join();
} catch (Throwable throwable) {
ExceptionHelper.sneakyThrow(ExceptionHelper.unwrapCompletableFuture(throwable));
}
// Get result based on status
var op = getOperation();
if (op == null) {
throw terminateExecutionWithIllegalDurableOperationException(
String.format("%s operation not found: %s", getType(), getOperationId()));
}
return op;
}
protected void runUserHandler(Runnable runnable, String contextId, ThreadType threadType) {
Runnable wrapped = () -> {
executionManager.setCurrentThreadContext(new ThreadContext(contextId, threadType));
try {
runnable.run();
} finally {
if (contextId != null) {
try {
// if this is a child context or a step context, we need to
// deregister the context's thread from the execution manager
executionManager.deregisterActiveThread(contextId);
} catch (SuspendExecutionException e) {
// Expected when this is the last active thread. Must catch here because:
// 1/ This runs in a worker thread detached from handlerFuture
// 2/ Uncaught exception would prevent stepAsync().get() from resume
// Suspension/Termination is already signaled via
// suspendExecutionFuture/terminateExecutionFuture
// before the throw.
}
}
}
};
// runUserHandler is used to ensure that only one user handler is running at a time
if (runningUserHandler.get() != null) {
throw new IllegalStateException("User handler already running");
}
// Thread registration is intentionally split across two threads:
// 1. registerActiveThread on the PARENT thread — ensures the child is tracked before the
// parent can deregister and trigger suspension (race prevention).
// 2. setCurrentContext on the CHILD thread — sets the ThreadLocal so operations inside
// the child context know which context they belong to.
// registerActiveThread is idempotent (no-op if already registered).
registerActiveThread(contextId);
if (!runningUserHandler.compareAndSet(
null,
CompletableFuture.runAsync(
wrapped, getContext().getDurableConfig().getExecutorService()))) {
throw new IllegalStateException("User handler already running");
}
}
/**
* Receives operation updates from ExecutionManager. Completes the internal future when the operation reaches a
* terminal status, unblocking any threads waiting on this operation.
*
* @param operation the updated operation state
*/
public void onCheckpointComplete(Operation operation) {
if (ExecutionManager.isTerminalStatus(operation.status())) {
// This method handles only terminal status updates. Override this method if a DurableOperation needs to
// handle other updates.
logger.trace("In onCheckpointComplete, completing operation {} ({})", getOperationId(), completionFuture);
markCompletionFutureCompleted();
}
}
/** Marks the operation as already completed (in replay). */
protected void markAlreadyCompleted() {
// When the operation is already completed in a replay, we complete completionFuture immediately
// so that the `get` method will be unblocked and the context thread will be registered
logger.trace("In markAlreadyCompleted, completing operation: {} ({}).", getOperationId(), completionFuture);
markCompletionFutureCompleted();
}
private void markCompletionFutureCompleted() {
// It's important that we synchronize access to the future, otherwise the processing could happen
// on someone else's thread and cause a race condition.
synchronized (parentOperation == null ? completionFuture : parentOperation.completionFuture) {
// Completing the future here will also run any other completion stages that have been attached
// to the future. In our case, other contexts may have attached a function to reactivate themselves,
// so they will definitely have a chance to reactivate before we finish completing and deactivating
// whatever operations were just checkpointed.
completionFuture.complete(this);
}
}
/**
* Terminates the execution with the given exception.
*
* @param exception the unrecoverable exception
* @return never returns normally; always throws
*/
protected RuntimeException terminateExecution(UnrecoverableDurableExecutionException exception) {
executionManager.terminateExecution(exception);
// Exception is already thrown from above. Keep the throw statement below to make tests happy
throw exception;
}
/**
* Terminates the execution with an {@link IllegalDurableOperationException}.
*
* @param message the error message
* @return never returns normally; always throws
*/
protected RuntimeException terminateExecutionWithIllegalDurableOperationException(String message) {
return terminateExecution(new IllegalDurableOperationException(message));
}
/**
* Registers a thread as active in the execution manager.
*
* @param threadId the thread identifier to register
*/
protected void registerActiveThread(String threadId) {
executionManager.registerActiveThread(threadId);
}
/** Returns the current thread's context from the execution manager. */
protected ThreadContext getCurrentThreadContext() {
return executionManager.getCurrentThreadContext();
}
/** Polls the backend for updates to this operation. */
protected CompletableFuture<Operation> pollForOperationUpdates() {
return executionManager.pollForOperationUpdates(getOperationId());
}
/**
* Polls the backend for updates to this operation at a specific time.
*
* @param at the time to poll for updates
* @return a future that completes with the updated operation
*/
protected CompletableFuture<Operation> pollForOperationUpdates(Instant at) {
return executionManager.pollForOperationUpdates(getOperationId(), at);
}
/** Sends an operation update synchronously (blocks until the update is acknowledged). */
protected void sendOperationUpdate(OperationUpdate.Builder builder) {
sendOperationUpdateAsync(builder).join();
}
/** Sends an operation update asynchronously. */
protected CompletableFuture<Void> sendOperationUpdateAsync(OperationUpdate.Builder builder) {
var updateBuilder =
builder.id(getOperationId()).name(getName()).type(getType()).parentId(durableContext.getContextId());
if (getSubType() != null) {
updateBuilder.subType(getSubType().getValue());
}
return executionManager.sendOperationUpdate(updateBuilder.build());
}
/** Validates that current operation matches checkpointed operation during replay. */
protected void validateReplay(Operation checkpointed) {
if (checkpointed == null || checkpointed.type() == null) {
return; // First execution, no validation needed
}
if (!checkpointed.type().equals(getType())) {
throw terminateExecution(new NonDeterministicExecutionException(String.format(
"Operation type mismatch for \"%s\". Expected %s, got %s",
getOperationId(), checkpointed.type(), getType())));
}
if (!Objects.equals(checkpointed.name(), getName())) {
throw terminateExecution(new NonDeterministicExecutionException(String.format(
"Operation name mismatch for \"%s\". Expected \"%s\", got \"%s\"",
getOperationId(), checkpointed.name(), getName())));
}
if ((getSubType() == null && checkpointed.subType() != null)
|| getSubType() != null
&& !Objects.equals(checkpointed.subType(), getSubType().getValue())) {
throw terminateExecution(new NonDeterministicExecutionException(String.format(
"Operation subType mismatch for \"%s\". Expected \"%s\", got \"%s\"",
getOperationId(), checkpointed.subType(), getSubType())));
}
}
}