diff --git a/Sources/NIOEmbedded/AsyncTestingChannel.swift b/Sources/NIOEmbedded/AsyncTestingChannel.swift index 26677b3db1b..4c062b31196 100644 --- a/Sources/NIOEmbedded/AsyncTestingChannel.swift +++ b/Sources/NIOEmbedded/AsyncTestingChannel.swift @@ -404,8 +404,10 @@ public final class NIOAsyncTestingChannel: Channel { /// This method is similar to ``NIOAsyncTestingChannel/readOutbound(as:)`` but will wait if the outbound buffer is empty. /// If available, this method reads one element of type `T` out of the ``NIOAsyncTestingChannel``'s outbound buffer. If the - /// first element was of a different type than requested, ``WrongTypeError`` will be thrown, if there - /// are no elements in the outbound buffer, `nil` will be returned. + /// first element was of a different type than requested, ``WrongTypeError`` will be thrown. If the channel has + /// already closed or closes before the next pending outbound write, `ChannelError.ioOnClosedChannel` will be + /// thrown. If there are no elements in the outbound buffer, this method will wait until there is one, and return + /// that element. /// /// Data hits the ``NIOAsyncTestingChannel``'s outbound buffer when data was written using `write`, then `flush`ed, and /// then travelled the `ChannelPipeline` all the way to the front. For data to hit the outbound buffer, the very @@ -423,12 +425,13 @@ public final class NIOAsyncTestingChannel: Channel { continuation.resume(returning: element) return } - self.channelcore.outboundBufferConsumer.append { element in - continuation.resume( - with: Result { - try self._cast(element) - } - ) + self.channelcore._enqueueOutboundBufferConsumer { element in + switch element { + case .success(let data): + continuation.resume(with: Result { try self._cast(data) }) + case .failure(let failure): + continuation.resume(throwing: failure) + } } } catch { continuation.resume(throwing: error) @@ -456,8 +459,10 @@ public final class NIOAsyncTestingChannel: Channel { /// This method is similar to ``NIOAsyncTestingChannel/readInbound(as:)`` but will wait if the inbound buffer is empty. /// If available, this method reads one element of type `T` out of the ``NIOAsyncTestingChannel``'s inbound buffer. If the - /// first element was of a different type than requested, ``WrongTypeError`` will be thrown, if there - /// are no elements in the outbound buffer, this method will wait until an element is in the inbound buffer. + /// first element was of a different type than requested, ``WrongTypeError`` will be thrown. If the channel has + /// already closed or closes before the next pending inbound write, `ChannelError.ioOnClosedChannel` will be thrown. + /// If there are no elements in the inbound buffer, this method will wait until there is one, and return that + /// element. /// /// Data hits the ``NIOAsyncTestingChannel``'s inbound buffer when data was send through the pipeline using `fireChannelRead` /// and then travelled the `ChannelPipeline` all the way to the back. For data to hit the inbound buffer, the @@ -473,12 +478,13 @@ public final class NIOAsyncTestingChannel: Channel { continuation.resume(returning: element) return } - self.channelcore.inboundBufferConsumer.append { element in - continuation.resume( - with: Result { - try self._cast(element) - } - ) + self.channelcore._enqueueInboundBufferConsumer { element in + switch element { + case .success(let data): + continuation.resume(with: Result { try self._cast(data) }) + case .failure(let failure): + continuation.resume(throwing: failure) + } } } catch { continuation.resume(throwing: error) diff --git a/Sources/NIOEmbedded/Embedded.swift b/Sources/NIOEmbedded/Embedded.swift index 4fa6aee8893..826514a94e0 100644 --- a/Sources/NIOEmbedded/Embedded.swift +++ b/Sources/NIOEmbedded/Embedded.swift @@ -487,8 +487,7 @@ class EmbeddedChannelCore: ChannelCore { var outboundBuffer: CircularBuffer = CircularBuffer() /// Contains observers that want to consume the first element that would be appended to the `outboundBuffer` - @usableFromInline - var outboundBufferConsumer: Deque<(NIOAny) -> Void> = [] + private var outboundBufferConsumer: Deque<(Result) -> Void> = [] /// Contains the unflushed items that went into the `Channel` @usableFromInline @@ -502,8 +501,7 @@ class EmbeddedChannelCore: ChannelCore { var inboundBuffer: CircularBuffer = CircularBuffer() /// Contains observers that want to consume the first element that would be appended to the `inboundBuffer` - @usableFromInline - var inboundBufferConsumer: Deque<(NIOAny) -> Void> = [] + private var inboundBufferConsumer: Deque<(Result) -> Void> = [] @usableFromInline internal struct Addresses { @@ -567,6 +565,14 @@ class EmbeddedChannelCore: ChannelCore { isActive = false promise?.succeed(()) + // Return a `.failure` result containing an error to all pending inbound and outbound consumers. + while let consumer = self.inboundBufferConsumer.popFirst() { + consumer(.failure(ChannelError.ioOnClosedChannel)) + } + while let consumer = self.outboundBufferConsumer.popFirst() { + consumer(.failure(ChannelError.ioOnClosedChannel)) + } + // As we called register() in the constructor of EmbeddedChannel we also need to ensure we call unregistered here. self.pipeline.syncOperations.fireChannelInactive() self.pipeline.syncOperations.fireChannelUnregistered() @@ -661,16 +667,46 @@ class EmbeddedChannelCore: ChannelCore { private func addToBuffer( buffer: inout CircularBuffer, - consumer: inout Deque<(NIOAny) -> Void>, + consumer: inout Deque<(Result) -> Void>, data: NIOAny ) { self.eventLoop.preconditionInEventLoop() if let consume = consumer.popFirst() { - consume(data) + consume(.success(data)) } else { buffer.append(data) } } + + /// Enqueue a consumer closure that will be invoked upon the next pending inbound write. + /// - Parameter newElement: The consumer closure to enqueue. Returns a `.failure` result if the channel has already + /// closed. + func _enqueueInboundBufferConsumer(_ newElement: @escaping (Result) -> Void) { + self.eventLoop.preconditionInEventLoop() + + // The channel has already closed: there cannot be any further writes. Return a `.failure` result with an error. + guard self.isOpen else { + newElement(.failure(ChannelError.ioOnClosedChannel)) + return + } + + self.inboundBufferConsumer.append(newElement) + } + + /// Enqueue a consumer closure that will be invoked upon the next pending outbound write. + /// - Parameter newElement: The consumer closure to enqueue. Returns a `.failure` result if the channel has already + /// closed. + func _enqueueOutboundBufferConsumer(_ newElement: @escaping (Result) -> Void) { + self.eventLoop.preconditionInEventLoop() + + // The channel has already closed: there cannot be any further writes. Return a `.failure` result with an error. + guard self.isOpen else { + newElement(.failure(ChannelError.ioOnClosedChannel)) + return + } + + self.outboundBufferConsumer.append(newElement) + } } // ChannelCores are basically never Sendable. diff --git a/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift b/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift index 4754029c31a..f8c69c42185 100644 --- a/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift +++ b/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift @@ -686,6 +686,70 @@ class AsyncTestingChannelTests: XCTestCase { try await XCTAsyncAssertTrue(try await channel.finish().isClean) } + func testWaitingForWriteTerminatesAfterChannelClose() async throws { + let channel = NIOAsyncTestingChannel() + + // Write some inbound and outbound data + for i in 1...3 { + try await channel.writeInbound(i) + try await channel.writeOutbound(i) + } + + // We should successfully see the three inbound and outbound writes + for i in 1...3 { + try await XCTAsyncAssertEqual(try await channel.waitForInboundWrite(), i) + try await XCTAsyncAssertEqual(try await channel.waitForOutboundWrite(), i) + } + + let task = Task { + // We close the channel after the third inbound/outbound write. Waiting again should result in a + // `ChannelError.ioOnClosedChannel` error. + await XCTAsyncAssertThrowsError(try await channel.waitForInboundWrite(as: Int.self)) { + XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel) + } + await XCTAsyncAssertThrowsError(try await channel.waitForOutboundWrite(as: Int.self)) { + XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel) + } + } + + // Close the channel without performing any writes + try await channel.close() + try await task.value + } + + func testEnqueueWriteConsumersBeforeChannelClosesWithoutAnyWrites() async throws { + let channel = NIOAsyncTestingChannel() + + let task = Task { + // We don't write anything to the channel and simply just close it. Waiting for an inbound/outbound write + // should result in a `ChannelError.ioOnClosedChannel` when the channel closes. + await XCTAsyncAssertThrowsError(try await channel.waitForInboundWrite(as: Int.self)) { + XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel) + } + await XCTAsyncAssertThrowsError(try await channel.waitForOutboundWrite(as: Int.self)) { + XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel) + } + } + + // Close the channel without performing any inbound or outbound writes + try await channel.close() + try await task.value + } + + func testEnqueueWriteConsumersAfterChannelClosesWithoutAnyWrites() async throws { + let channel = NIOAsyncTestingChannel() + // Immediately close the channel without performing any inbound or outbound writes + try await channel.close() + + // Now try to wait for an inbound/outbound write. This should result in a `ChannelError.ioOnClosedChannel`. + await XCTAsyncAssertThrowsError(try await channel.waitForInboundWrite(as: Int.self)) { + XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel) + } + await XCTAsyncAssertThrowsError(try await channel.waitForOutboundWrite(as: Int.self)) { + XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel) + } + } + func testGetSetOption() async throws { let channel = NIOAsyncTestingChannel() let option = ChannelOptions.socket(IPPROTO_IP, IP_TTL) diff --git a/Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift b/Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift index b9e123a4e64..36a79c2f090 100644 --- a/Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift +++ b/Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift @@ -697,6 +697,133 @@ class EmbeddedChannelTest: XCTestCase { XCTAssertTrue(try channel.finish().isClean) } + func testWriteInboundBufferConsumer() throws { + let channel = EmbeddedChannel() + let invocationPromise = channel.eventLoop.makePromise(of: Void.self) + + channel.channelcore._enqueueInboundBufferConsumer { element in + invocationPromise.succeed() + switch element { + case .success(let result): + XCTAssertEqual( + channel.channelcore.tryUnwrapData(result, as: ByteBuffer.self), + ByteBuffer(string: "hello") + ) + case .failure(let error): + XCTFail("Unexpectedly received an error: \(error)") + } + } + + var buf = channel.allocator.buffer(capacity: 10) + buf.writeString("hello") + try channel.writeInbound(buf) + + XCTAssertTrue(invocationPromise.futureResult.isFulfilled) + } + + func testWriteOutboundBufferConsumer() throws { + let channel = EmbeddedChannel() + let invocationPromise = channel.eventLoop.makePromise(of: Void.self) + + channel.channelcore._enqueueOutboundBufferConsumer { element in + invocationPromise.succeed() + switch element { + case .success(let result): + XCTAssertEqual( + channel.channelcore.tryUnwrapData(result, as: ByteBuffer.self), + ByteBuffer(string: "hello") + ) + case .failure(let error): + XCTFail("Unexpectedly received an error: \(error)") + } + } + + var buf = channel.allocator.buffer(capacity: 10) + buf.writeString("hello") + channel.write(buf, promise: nil) + channel.flush() + + XCTAssertTrue(invocationPromise.futureResult.isFulfilled) + } + + func testQueueMultipleInboundAndOutboundBufferConsumersBeforeChannelClose() async throws { + let channel = EmbeddedChannel() + let inboundInvocationPromises = [EventLoopPromise]( + repeating: channel.eventLoop.makePromise(of: Void.self), + count: 3 + ) + let outboundInvocationPromises = [EventLoopPromise]( + repeating: channel.eventLoop.makePromise(of: Void.self), + count: 3 + ) + + // Enqueue 3 inbound and outbound consumers + for i in 0..<3 { + // Since the channel closes, all queued consumers should get a `ChannelError.ioOnClosedChannel` + channel.channelcore._enqueueInboundBufferConsumer { element in + inboundInvocationPromises[i].succeed() + switch element { + case .failure(let failure): + XCTAssertEqual(failure as? ChannelError, ChannelError.ioOnClosedChannel) + case .success: + XCTFail("Unexpectedly received a successful result: no writes were performed on the channel.") + } + } + + channel.channelcore._enqueueOutboundBufferConsumer { element in + outboundInvocationPromises[i].succeed() + switch element { + case .failure(let failure): + XCTAssertEqual(failure as? ChannelError, ChannelError.ioOnClosedChannel) + case .success: + XCTFail("Unexpectedly received a successful result: no writes were performed on the channel.") + } + } + } + + // Close the channel without performing any writes + try await channel.close() + XCTAssertEqual(channel.channelcore.isOpen, false) + + // Check that all consumer closures were invoked + XCTAssertTrue(inboundInvocationPromises.map(\.futureResult.isFulfilled).allSatisfy { $0 }) + XCTAssertTrue(outboundInvocationPromises.map(\.futureResult.isFulfilled).allSatisfy { $0 }) + } + + func testQueueInboundAndOutboundBufferConsumerAfterChannelClose() async throws { + let channel = EmbeddedChannel() + let inboundInvocationPromise = channel.eventLoop.makePromise(of: Void.self) + let outboundInvocationPromise = channel.eventLoop.makePromise(of: Void.self) + + // Close the channel immediately + try await channel.close() + XCTAssertEqual(channel.channelcore.isOpen, false) + + // Since the consumers are enqueued after the channel closed, they should get a `ChannelError.ioOnClosedChannel` + channel.channelcore._enqueueInboundBufferConsumer { element in + inboundInvocationPromise.succeed() + switch element { + case .failure(let failure): + XCTAssertEqual(failure as? ChannelError, ChannelError.ioOnClosedChannel) + case .success: + XCTFail("Unexpectedly received a successful result: no writes were performed on the channel.") + } + } + + channel.channelcore._enqueueOutboundBufferConsumer { element in + outboundInvocationPromise.succeed() + switch element { + case .failure(let failure): + XCTAssertEqual(failure as? ChannelError, ChannelError.ioOnClosedChannel) + case .success: + XCTFail("Unexpectedly received a successful result: no writes were performed on the channel.") + } + } + + XCTAssertTrue(inboundInvocationPromise.futureResult.isFulfilled) + XCTAssertTrue(outboundInvocationPromise.futureResult.isFulfilled) + } + func testGetSetOption() throws { let channel = EmbeddedChannel() let option = ChannelOptions.socket(IPPROTO_IP, IP_TTL)