diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducerStrategies.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducerStrategies.swift index 5fe444b9bd..6de186197c 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducerStrategies.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducerStrategies.swift @@ -22,6 +22,7 @@ public enum NIOAsyncSequenceProducerBackPressureStrategies { public struct HighLowWatermark: NIOAsyncSequenceProducerBackPressureStrategy { private let lowWatermark: Int private let highWatermark: Int + private var hasOustandingDemand: Bool = true /// Initializes a new ``NIOAsyncSequenceProducerBackPressureStrategies/HighLowWatermark``. /// @@ -36,12 +37,29 @@ public enum NIOAsyncSequenceProducerBackPressureStrategies { public mutating func didYield(bufferDepth: Int) -> Bool { // We are demanding more until we reach the high watermark - bufferDepth < self.highWatermark + if bufferDepth < self.highWatermark { + precondition(self.hasOustandingDemand) + return true + } else { + self.hasOustandingDemand = false + return false + } } public mutating func didConsume(bufferDepth: Int) -> Bool { // We start demanding again once we are below the low watermark - bufferDepth < self.lowWatermark + if bufferDepth < self.lowWatermark { + if self.hasOustandingDemand { + // We are below and have outstanding demand + return true + } else { + // We are below but don't have outstanding demand but need more + self.hasOustandingDemand = true + return true + } + } else { + return self.hasOustandingDemand + } } } } diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceProducer+HighLowWatermarkBackPressureStrategyTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceProducer+HighLowWatermarkBackPressureStrategyTests.swift index 3506dc9ce8..f7a6790ecf 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceProducer+HighLowWatermarkBackPressureStrategyTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceProducer+HighLowWatermarkBackPressureStrategyTests.swift @@ -51,10 +51,10 @@ final class NIOAsyncSequenceProducerBackPressureStrategiesHighLowWatermarkTests: } func testDidConsume_whenAboveLowWatermark() { - XCTAssertFalse(self.strategy.didConsume(bufferDepth: 6)) + XCTAssertTrue(self.strategy.didConsume(bufferDepth: 6)) } func testDidConsume_whenAtLowWatermark() { - XCTAssertFalse(self.strategy.didConsume(bufferDepth: 5)) + XCTAssertTrue(self.strategy.didConsume(bufferDepth: 5)) } } diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift index 3720976224..7a7583f59e 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift @@ -149,6 +149,42 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { XCTAssertEqual(self.source.yield(contentsOf: [7, 8, 9, 10, 11]), .stopProducing) } + func testWatermarkBackpressure_whenBelowLowwatermark_andOutstandingDemand() async { + let newSequence = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark( + lowWatermark: 2, + highWatermark: 5 + ), + finishOnDeinit: false, + delegate: self.delegate + ) + let iterator = newSequence.sequence.makeAsyncIterator() + var eventsIterator = self.delegate.events.makeAsyncIterator() + let source = newSequence.source + + XCTAssertEqual(source.yield(1), .produceMore) + XCTAssertEqual(source.yield(2), .produceMore) + XCTAssertEqual(source.yield(3), .produceMore) + XCTAssertEqual(source.yield(4), .produceMore) + XCTAssertEqual(source.yield(5), .stopProducing) + XCTAssertEqualWithoutAutoclosure(await iterator.next(), 1) + XCTAssertEqualWithoutAutoclosure(await iterator.next(), 2) + XCTAssertEqualWithoutAutoclosure(await iterator.next(), 3) + XCTAssertEqualWithoutAutoclosure(await iterator.next(), 4) + XCTAssertEqualWithoutAutoclosure(await iterator.next(), 5) + XCTAssertEqualWithoutAutoclosure(await eventsIterator.next(), .produceMore) + XCTAssertEqual(source.yield(6), .produceMore) + XCTAssertEqual(source.yield(7), .produceMore) + XCTAssertEqual(source.yield(8), .produceMore) + XCTAssertEqualWithoutAutoclosure(await iterator.next(), 6) + XCTAssertEqualWithoutAutoclosure(await iterator.next(), 7) + XCTAssertEqualWithoutAutoclosure(await iterator.next(), 8) + source.finish() + XCTAssertEqualWithoutAutoclosure(await iterator.next(), nil) + XCTAssertEqualWithoutAutoclosure(await eventsIterator.next(), .didTerminate) + } + // MARK: - Yield func testYield_whenInitial_andStopDemanding() async {