diff --git a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java index 9fbf007b..70b1353b 100644 --- a/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java +++ b/sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java @@ -227,7 +227,7 @@ void testParallelReplayAfterInterruption_cachedResultsUsed() { var firstRunCount = executionCounts.get(); assertTrue(firstRunCount >= 3, "Expected at least 3 executions on first run but got " + firstRunCount); - var result2 = runner.run("test"); + var result2 = runner.runUntilComplete("test"); assertEquals(ExecutionStatus.SUCCEEDED, result2.getStatus()); assertEquals("A,B,C", result2.getResult(String.class)); assertEquals(firstRunCount, executionCounts.get(), "Branch functions should not re-execute on replay"); @@ -536,4 +536,97 @@ void testParallelResultSummary_succeededAndFailedCounts() { assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); assertEquals("3/2", result.getResult(String.class)); } + + @Test + void testParallelWithToleratedFailureCount_earlyTermination() { + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder() + .maxConcurrency(1) + .completionConfig(CompletionConfig.toleratedFailureCount(1)) + .build(); + var futures = new ArrayList>(); + var parallel = context.parallel("tolerated-fail", config); + + try (parallel) { + futures.add(parallel.branch("branch-ok", String.class, ctx -> "OK")); + futures.add(parallel.branch("branch-fail1", String.class, ctx -> { + throw new RuntimeException("failed: fail1"); + })); + futures.add(parallel.branch("branch-fail2", String.class, ctx -> { + throw new RuntimeException("failed: fail2"); + })); + futures.add(parallel.branch("branch-ok2", String.class, ctx -> "OK2")); + } + + var result = parallel.get(); + assertEquals(ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED, result.completionStatus()); + assertFalse(result.completionStatus().isSucceeded()); + assertEquals(4, result.size()); + assertEquals("OK", futures.get(0).get()); + + return "done"; + }); + + var result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + } + + @Test + void testParallelWithMinSuccessful_earlyTermination() { + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder() + .maxConcurrency(1) + .completionConfig(CompletionConfig.minSuccessful(2)) + .build(); + var futures = new ArrayList>(); + var parallel = context.parallel("min-successful", config); + + try (parallel) { + for (var item : List.of("a", "b", "c", "d", "e")) { + futures.add(parallel.branch("branch-" + item, String.class, ctx -> item.toUpperCase())); + } + } + + var result = parallel.get(); + assertEquals(ConcurrencyCompletionStatus.MIN_SUCCESSFUL_REACHED, result.completionStatus()); + assertTrue(result.completionStatus().isSucceeded()); + assertEquals(5, result.size()); + assertEquals("A", futures.get(0).get()); + assertEquals("B", futures.get(1).get()); + + return "done"; + }); + + var result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + } + + @Test + void testParallelWithAllSuccessful_stopsOnFirstFailure() { + var runner = LocalDurableTestRunner.create(String.class, (input, context) -> { + var config = ParallelConfig.builder() + .maxConcurrency(1) + .completionConfig(CompletionConfig.allSuccessful()) + .build(); + var futures = new ArrayList>(); + var parallel = context.parallel("all-successful", config); + + try (parallel) { + futures.add(parallel.branch("branch-ok1", String.class, ctx -> "OK1")); + futures.add(parallel.branch("branch-fail", String.class, ctx -> { + throw new RuntimeException("failed"); + })); + futures.add(parallel.branch("branch-ok2", String.class, ctx -> "OK2")); + } + + var result = parallel.get(); + assertEquals(ConcurrencyCompletionStatus.FAILURE_TOLERANCE_EXCEEDED, result.completionStatus()); + assertEquals("OK1", futures.get(0).get()); + + return "done"; + }); + + var result = runner.runUntilComplete("test"); + assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus()); + } } diff --git a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java index eacc1979..f5a644c6 100644 --- a/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java +++ b/sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java @@ -68,10 +68,12 @@ public ParallelOperation( protected void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletionStatus) { var items = getBranches(); int succeededCount = Math.toIntExact(items.stream() - .filter(item -> item.getOperation().status() == OperationStatus.SUCCEEDED) + .filter(item -> + item.getOperation() != null && item.getOperation().status() == OperationStatus.SUCCEEDED) .count()); int failedCount = Math.toIntExact(items.stream() - .filter(item -> item.getOperation().status() != OperationStatus.SUCCEEDED) + .filter(item -> + item.getOperation() != null && item.getOperation().status() != OperationStatus.SUCCEEDED) .count()); this.cachedResult = new ParallelResult(items.size(), succeededCount, failedCount, concurrencyCompletionStatus); if (skipCheckpoint) {