-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathConcurrencyOperation.java
More file actions
342 lines (307 loc) · 15.5 KB
/
ConcurrencyOperation.java
File metadata and controls
342 lines (307 loc) · 15.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package software.amazon.lambda.durable.operation;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.lambda.model.OperationType;
import software.amazon.lambda.durable.DurableContext;
import software.amazon.lambda.durable.TypeToken;
import software.amazon.lambda.durable.config.RunInChildContextConfig;
import software.amazon.lambda.durable.context.DurableContextImpl;
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
import software.amazon.lambda.durable.execution.OperationIdGenerator;
import software.amazon.lambda.durable.execution.SuspendExecutionException;
import software.amazon.lambda.durable.execution.ThreadType;
import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus;
import software.amazon.lambda.durable.model.OperationIdentifier;
import software.amazon.lambda.durable.model.OperationSubType;
import software.amazon.lambda.durable.serde.SerDes;
import software.amazon.lambda.durable.util.ExceptionHelper;
/**
 * Abstract base class for concurrent execution of multiple child context operations.
 *
 * <p>Encapsulates shared concurrency logic: queue-based concurrency control, success/failure counting, and completion
 * checking. Both {@code ParallelOperation} and {@code MapOperation} extend this base.
 *
 * <p>Key design points:
 *
 * <ul>
 *   <li>Does NOT register its own thread — child context threads handle all suspension
 *   <li>Uses a pending queue plus a set of running children, bounded by {@code maxConcurrency}, for concurrency
 *       control; a single consumer task started by {@code executeItems()} drains the queue
 *   <li>Completion is decided by the private {@code canComplete()} check ({@code minSuccessful},
 *       {@code toleratedFailureCount}, and whether {@code join()} has been called); subclasses define what happens
 *       on completion via the abstract {@code handleCompletion()}
 *   <li>A child leaves the running set only when it completes, so when a child suspends its concurrency slot is NOT
 *       released
 * </ul>
 *
 * @param <T> the result type of this operation
 */
public abstract class ConcurrencyOperation<T> extends SerializableDurableOperation<T> {
// Class-wide SLF4J logger.
private static final Logger logger = LoggerFactory.getLogger(ConcurrencyOperation.class);
// Maximum number of children the consumer keeps in its running set at the same time.
private final int maxConcurrency;
// Success threshold that completes the operation early (MIN_SUCCESSFUL_REACHED); null disables the check.
private final Integer minSuccessful;
// Number of failures tolerated; exceeding it completes the operation (FAILURE_TOLERANCE_EXCEEDED).
// null disables the check.
private final Integer toleratedFailureCount;
// Produces the operation ids of enqueued items; seeded with this operation's own id.
private final OperationIdGenerator operationIdGenerator;
// Child context created from the owning context in the constructor; passed to every item as its parent context.
private final DurableContextImpl rootContext;
// Every item that was ever enqueued, in enqueue order; returned as-is by getBranches().
// access by context thread only
private final List<ChildContextOperation<?>> branches = Collections.synchronizedList(new ArrayList<>());
// Items waiting for a free concurrency slot.
// put only by context thread and consume only by consumer thread
private final Queue<ChildContextOperation<?>> pendingQueue = new ConcurrentLinkedDeque<>();
// Becomes true in join(); the ALL_COMPLETED condition is only considered once it is set.
// set by context thread and used by consumer thread
protected final AtomicBoolean isJoined = new AtomicBoolean(false);
// Holds the future the consumer waits on in addition to the running children. Completing it (with null)
// wakes up the consumer thread for either new items or checking completion condition (isJoined changed).
// The consumer swaps in a fresh future at the start of each loop iteration once the current one is done.
private final AtomicReference<CompletableFuture<BaseDurableOperation>> consumerThreadListener;
/**
 * Sets up the shared concurrency state for a parallel/map style operation.
 *
 * @param operationIdentifier identifier of this concurrency operation
 * @param resultTypeToken type token of the aggregated result
 * @param resultSerDes serializer used for the aggregated result
 * @param durableContext the owning context; a child context is derived from it and used as the parent of all items
 * @param maxConcurrency maximum number of items running at the same time
 * @param minSuccessful number of successes that completes the operation early, or {@code null} for no threshold
 * @param toleratedFailureCount number of failures tolerated before the operation completes, or {@code null} for no
 *     limit
 */
protected ConcurrencyOperation(
        OperationIdentifier operationIdentifier,
        TypeToken<T> resultTypeToken,
        SerDes resultSerDes,
        DurableContextImpl durableContext,
        int maxConcurrency,
        Integer minSuccessful,
        Integer toleratedFailureCount) {
    super(operationIdentifier, resultTypeToken, resultSerDes, durableContext);

    // Completion thresholds and the concurrency limit are plain copies of the arguments.
    this.toleratedFailureCount = toleratedFailureCount;
    this.minSuccessful = minSuccessful;
    this.maxConcurrency = maxConcurrency;

    // Start with an open (not yet completed) wake-up future for the consumer.
    this.consumerThreadListener = new AtomicReference<>(new CompletableFuture<>());

    // Item ids and the shared parent context are both derived from this operation's own id.
    this.operationIdGenerator = new OperationIdGenerator(getOperationId());
    this.rootContext = durableContext.createChildContext(getOperationId(), getName());
}
// ========== Template methods for subclasses ==========
/**
 * Builds the child context operation that represents one item (a branch or an iteration) of this concurrency
 * operation. The new operation is of type {@code OperationType.CONTEXT} and receives this operation as its last
 * constructor argument.
 *
 * @param operationId the unique operation ID for this item
 * @param name the name of this item
 * @param function the user function to execute
 * @param resultType the result type token
 * @param serDes the serializer placed into the item's {@code RunInChildContextConfig}
 * @param branchSubType the sub-type of the branch operation
 * @param parentContext the parent durable context
 * @param <R> the result type of the child operation
 * @return a new ChildContextOperation
 */
protected <R> ChildContextOperation<R> createItem(
        String operationId,
        String name,
        Function<DurableContext, R> function,
        TypeToken<R> resultType,
        SerDes serDes,
        OperationSubType branchSubType,
        DurableContextImpl parentContext) {
    var itemIdentifier = OperationIdentifier.of(operationId, name, OperationType.CONTEXT, branchSubType);
    var itemConfig = RunInChildContextConfig.builder().serDes(serDes).build();
    return new ChildContextOperation<>(itemIdentifier, function, resultType, itemConfig, parentContext, this);
}
/**
 * Called when the concurrency operation completes. Subclasses define checkpointing behavior.
 *
 * <p>Invoked from the consumer task started by {@code executeItems()} as soon as the completion check yields a
 * status; the consumer loop exits right after this call returns.
 *
 * @param concurrencyCompletionStatus the reason the operation is considered complete
 */
protected abstract void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletionStatus);
// ========== Concurrency control ==========
/**
 * Registers a new item and places it on the pending queue; the item is not started here. Call
 * {@link #executeItems()} to begin execution after all items have been enqueued. Keeping creation separate from
 * execution prevents early termination from blocking item creation when all items are known upfront (e.g., map
 * operations).
 *
 * <p>The item gets the next generated operation id and uses the shared root context as its parent. After
 * queuing, the consumer is notified so that an already running consumer can pick the item up.
 *
 * @param name the name of the item
 * @param function the user function the item executes
 * @param resultType the result type token of the item
 * @param serDes the serializer for the item's result
 * @param branchSubType the sub-type of the item operation
 * @param <R> the result type of the item
 * @return the newly created (not yet started) child operation
 */
protected <R> ChildContextOperation<R> enqueueItem(
        String name,
        Function<DurableContext, R> function,
        TypeToken<R> resultType,
        SerDes serDes,
        OperationSubType branchSubType) {
    var itemId = this.operationIdGenerator.nextOperationId();
    ChildContextOperation<R> item =
            createItem(itemId, name, function, resultType, serDes, branchSubType, this.rootContext);

    // Record the item first, then make it available to the consumer.
    branches.add(item);
    pendingQueue.add(item);
    logger.debug("Item enqueued {}", name);

    // Wake the consumer so it notices the new item.
    notifyConsumerThread();
    return item;
}
/**
 * Wakes the consumer by completing the current listener future with {@code null}. Runs while holding the
 * {@code completionFuture} monitor, the same monitor the consumer holds when it swaps or waits on the listener.
 */
private void notifyConsumerThread() {
    synchronized (completionFuture) {
        var listener = consumerThreadListener.get();
        listener.complete(null);
    }
}
/**
 * Starts execution of all enqueued items by submitting a single consumer task.
 *
 * <p>The consumer loops over these steps until the operation is complete:
 *
 * <ol>
 *   <li>replace the wake-up listener with a fresh future if the current one is already completed;
 *   <li>stop if the operation is already completed, or if {@code canComplete} yields a status (in which case
 *       {@code handleCompletion} is called first);
 *   <li>start pending items while fewer than {@code maxConcurrency} children are running;
 *   <li>wait for a running child to complete or for a wake-up, and record the outcome of a completed child.
 * </ol>
 *
 * <p>Items enqueued after this call are picked up as well, because {@code enqueueItem} notifies the listener.
 * Any throwable raised inside the loop is routed to {@code handleException}.
 */
protected void executeItems() {
    // variables accessed only by the consumer thread. Put them here to avoid accidentally used by other threads
    Set<BaseDurableOperation> runningChildren = new HashSet<>();
    AtomicInteger succeededCount = new AtomicInteger(0);
    AtomicInteger failedCount = new AtomicInteger(0);
    Runnable consumer = () -> {
        try {
            while (true) {
                // Set a new future if it's completed so that it will be able to receive a notification of
                // new items when the thread is checking completion condition and processing
                // the queued items below.
                synchronized (completionFuture) {
                    if (consumerThreadListener.get() != null
                            && consumerThreadListener.get().isDone()) {
                        consumerThreadListener.set(new CompletableFuture<>());
                    }
                }
                // Process completion condition. Quit the loop if the condition is met.
                if (isOperationCompleted()) {
                    return;
                }
                var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
                if (completionStatus != null) {
                    // Children still in runningChildren are not touched here; the consumer simply stops.
                    handleCompletion(completionStatus);
                    return;
                }
                // process new items in the queue
                while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) {
                    var next = pendingQueue.poll();
                    // Track the child before starting it so its completion can be matched below.
                    runningChildren.add(next);
                    logger.debug("Executing operation {}", next.getName());
                    next.execute();
                }
                // If consumerThreadListener has been completed when processing above, waitForChildCompletion will
                // immediately return null and repeat the above again
                var child = waitForChildCompletion(succeededCount, failedCount, runningChildren);
                // child may be null if the consumer thread is woken up due to new items added or the completion
                // condition changed
                if (child != null) {
                    if (runningChildren.contains(child)) {
                        // Free the slot, then count the child as succeeded or failed.
                        runningChildren.remove(child);
                        onItemComplete(succeededCount, failedCount, (ChildContextOperation<?>) child);
                    } else {
                        // A completion from something that was never started by this consumer; this is
                        // caught by the catch below and passed to handleException.
                        throw new IllegalStateException("Unexpected completion: " + child);
                    }
                }
            }
        } catch (Throwable ex) {
            handleException(ex);
        }
    };
    // run consumer in the user thread pool, although it's not a real user thread
    runUserHandler(consumer, getOperationId(), ThreadType.CONTEXT);
}
/**
 * Translates a throwable raised by the consumer task into the matching terminal action. Every path throws, so
 * this method never returns normally.
 *
 * <ul>
 *   <li>{@code SuspendExecutionException}: rethrown unchanged.
 *   <li>{@code UnrecoverableDurableExecutionException}: execution is terminated with that exception.
 *   <li>anything else: logged with its stack trace, then execution is terminated with an illegal-operation
 *       error whose message contains the throwable's string form.
 * </ul>
 *
 * @param ex the throwable caught by the consumer loop; it is first passed through
 *     {@code ExceptionHelper.unwrapCompletableFuture} (presumably stripping CompletableFuture wrapper exceptions)
 */
private void handleException(Throwable ex) {
    Throwable throwable = ExceptionHelper.unwrapCompletableFuture(ex);
    if (throwable instanceof SuspendExecutionException suspendExecutionException) {
        // Suspension is a control-flow signal, not a failure: rethrow immediately — do not checkpoint
        throw suspendExecutionException;
    }
    if (throwable instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) {
        throw terminateExecution(unrecoverableDurableExecutionException);
    }
    // Only the throwable's toString() ends up in the termination message below, so log it here to keep
    // the stack trace and cause chain available for diagnosis.
    logger.error("Unexpected exception in concurrency operation", throwable);
    throw terminateExecutionWithIllegalDurableOperationException(
            String.format("Unexpected exception in concurrency operation: %s", throwable));
}
/**
 * Blocks the consumer until one of the running children completes or the consumer is woken up.
 *
 * <p>While holding the {@code completionFuture} monitor it re-checks the completion conditions, then builds a
 * combined future over the completion futures of all running children, plus the wake-up listener when fewer than
 * {@code maxConcurrency} children are running. If that combined future is not done yet, the current thread is
 * deregistered as active (and re-registered by a callback once the future completes) before blocking on it.
 *
 * @param succeededCount number of children that succeeded so far (used for the completion re-check)
 * @param failedCount number of children that failed so far (used for the completion re-check)
 * @param runningChildren children that were started and have not been processed as completed yet
 * @return the child whose completion future finished, or {@code null} when the operation is already completed,
 *     can be completed, or the wake-up listener fired (the listener is completed with {@code null})
 */
private BaseDurableOperation waitForChildCompletion(
        AtomicInteger succeededCount, AtomicInteger failedCount, Set<BaseDurableOperation> runningChildren) {
    var threadContext = getCurrentThreadContext();
    CompletableFuture<Object> future;
    synchronized (completionFuture) {
        // check again in synchronized block to prevent race conditions
        if (isOperationCompleted()) {
            return null;
        }
        var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
        if (completionStatus != null) {
            // Let the caller's loop observe the status and call handleCompletion.
            return null;
        }
        ArrayList<CompletableFuture<BaseDurableOperation>> futures;
        futures = new ArrayList<>(runningChildren.stream()
                .map(BaseDurableOperation::getCompletionFuture)
                .toList());
        if (futures.size() < maxConcurrency) {
            // add a future to listen to the new items if there is a vacancy
            // (within this class the reference is never set to null, so the compareAndSet is only a safeguard)
            consumerThreadListener.compareAndSet(null, new CompletableFuture<>());
            futures.add(consumerThreadListener.get());
        }
        // future will be completed immediately if any future of the list is already completed
        future = CompletableFuture.anyOf(futures.toArray(CompletableFuture[]::new));
        // skip deregistering the current thread if there is more completed future to process
        if (!future.isDone()) {
            // NOTE(review): the re-register callback is attached before the deregistration below. If a child
            // future completes in between, register would run before deregister — confirm the execution
            // manager tolerates that ordering.
            future.thenRun(() -> registerActiveThread(threadContext.threadId()));
            // Deregister the current thread to allow suspension
            executionManager.deregisterActiveThread(threadContext.threadId());
        }
    }
    try {
        // Block outside the monitor so notifyConsumerThread() can acquire it and complete the listener.
        return future.thenApply(o -> (BaseDurableOperation) o).join();
    } catch (Throwable throwable) {
        // Rethrow the unwrapped cause; the trailing throw presumably only satisfies the compiler because
        // sneakyThrow is expected not to return — verify against ExceptionHelper.
        ExceptionHelper.sneakyThrow(ExceptionHelper.unwrapCompletableFuture(throwable));
        throw throwable;
    }
}
/**
 * Records the outcome of a child that has finished. Invoked by the consumer task in {@code executeItems()} right
 * after the child has been removed from the running set.
 *
 * <p>The outcome is determined by calling {@code child.get()}: a normal return increments
 * {@code succeededCount}, any throwable increments {@code failedCount}. This method only updates the counters;
 * evaluating the completion criteria and starting further items is left to the consumer loop.
 *
 * @param succeededCount counter of successfully completed children
 * @param failedCount counter of failed children
 * @param child the child operation that completed
 */
private void onItemComplete(
        AtomicInteger succeededCount, AtomicInteger failedCount, ChildContextOperation<?> child) {
    logger.debug("OnItemComplete called by {}, Id: {}", child.getName(), child.getOperationId());

    // child.get() is evaluated without holding any lock — it may block waiting for a checkpoint response.
    try {
        child.get();
    } catch (Throwable failure) {
        logger.debug("Child operation {} failed: {}", child.getOperationId(), failure.getMessage());
        failedCount.incrementAndGet();
        return;
    }

    logger.debug("Result succeeded - {}", child.getName());
    succeededCount.incrementAndGet();
}
// ========== Completion logic ==========
/**
 * Decides whether the concurrency operation has reached a terminal condition. The conditions are evaluated in
 * this order: minimum successes reached, failure tolerance exceeded, all items finished.
 *
 * @param succeededCount counter of successfully completed children
 * @param failedCount counter of failed children
 * @param runningChildren children that were started and have not been processed as completed yet
 * @return the completion status if the operation is complete, or null if it should continue
 */
private ConcurrencyCompletionStatus canComplete(
        AtomicInteger succeededCount, AtomicInteger failedCount, Set<BaseDurableOperation> runningChildren) {
    final int successes = succeededCount.get();
    final int failures = failedCount.get();

    // Enough successes: finish early.
    boolean minSuccessfulReached = minSuccessful != null && successes >= minSuccessful;
    if (minSuccessfulReached) {
        return ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED;
    }

    // More failures than tolerated: finish early.
    boolean toleranceExceeded = toleratedFailureCount != null && failures > toleratedFailureCount;
    if (toleranceExceeded) {
        return ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED;
    }

    // Nothing queued and nothing running. This only counts once join() has set isJoined, which is why
    // join() wakes the consumer so this condition gets re-evaluated.
    boolean everythingFinished = isJoined.get() && pendingQueue.isEmpty() && runningChildren.isEmpty();
    return everythingFinished ? ConcurrencyCompletionStatus.ALL_COMPLETED : null;
}
/**
 * Marks the operation as joined and waits for it to reach a terminal state.
 *
 * <p>First validates that {@code minSuccessful}, when set, does not exceed the number of registered items. Then
 * sets {@code isJoined}, wakes the consumer so it re-evaluates the completion condition, and finally delegates to
 * {@code waitForOperationCompletion()}.
 *
 * @throws IllegalStateException if {@code minSuccessful} is larger than the number of registered items
 */
protected void join() {
    if (minSuccessful != null) {
        int registered = branches.size();
        if (minSuccessful > registered) {
            throw new IllegalStateException("minSuccessful (" + minSuccessful
                    + ") exceeds the number of registered items (" + registered + ")");
        }
    }

    isJoined.set(true);
    // The "all completed" condition depends on isJoined, so the consumer has to check it again now.
    notifyConsumerThread();
    waitForOperationCompletion();
}
/**
 * Returns the items registered through {@code enqueueItem}, in registration order. This is the internal
 * synchronized list itself, not a copy.
 *
 * @return the list of all registered child operations
 */
protected List<ChildContextOperation<?>> getBranches() {
    return this.branches;
}
}