Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions Sources/NIOEmbedded/AsyncTestingChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,10 @@ public final class NIOAsyncTestingChannel: Channel {

/// This method is similar to ``NIOAsyncTestingChannel/readOutbound(as:)`` but will wait if the outbound buffer is empty.
/// If available, this method reads one element of type `T` out of the ``NIOAsyncTestingChannel``'s outbound buffer. If the
/// first element was of a different type than requested, ``WrongTypeError`` will be thrown, if there
/// are no elements in the outbound buffer, `nil` will be returned.
/// first element was of a different type than requested, ``WrongTypeError`` will be thrown. If the channel has
/// already closed or closes before the next pending outbound write, `ChannelError.ioOnClosedChannel` will be
/// thrown. If there are no elements in the outbound buffer, this method will wait until there is one, and return
/// that element.
///
/// Data hits the ``NIOAsyncTestingChannel``'s outbound buffer when data was written using `write`, then `flush`ed, and
/// then travelled the `ChannelPipeline` all the way to the front. For data to hit the outbound buffer, the very
Expand All @@ -423,12 +425,13 @@ public final class NIOAsyncTestingChannel: Channel {
continuation.resume(returning: element)
return
}
self.channelcore.outboundBufferConsumer.append { element in
continuation.resume(
with: Result {
try self._cast(element)
}
)
self.channelcore.enqueueOutboundBufferConsumer { element in
switch element {
case .success(let data):
continuation.resume(with: Result { try self._cast(data) })
case .failure(let failure):
continuation.resume(throwing: failure)
}
}
} catch {
continuation.resume(throwing: error)
Expand Down Expand Up @@ -456,8 +459,10 @@ public final class NIOAsyncTestingChannel: Channel {

/// This method is similar to ``NIOAsyncTestingChannel/readInbound(as:)`` but will wait if the inbound buffer is empty.
/// If available, this method reads one element of type `T` out of the ``NIOAsyncTestingChannel``'s inbound buffer. If the
/// first element was of a different type than requested, ``WrongTypeError`` will be thrown, if there
/// are no elements in the outbound buffer, this method will wait until an element is in the inbound buffer.
/// first element was of a different type than requested, ``WrongTypeError`` will be thrown. If the channel has
/// already closed or closes before the next pending inbound write, `ChannelError.ioOnClosedChannel` will be thrown.
/// If there are no elements in the inbound buffer, this method will wait until there is one, and return that
/// element.
///
/// Data hits the ``NIOAsyncTestingChannel``'s inbound buffer when data was send through the pipeline using `fireChannelRead`
/// and then travelled the `ChannelPipeline` all the way to the back. For data to hit the inbound buffer, the
Expand All @@ -473,12 +478,13 @@ public final class NIOAsyncTestingChannel: Channel {
continuation.resume(returning: element)
return
}
self.channelcore.inboundBufferConsumer.append { element in
continuation.resume(
with: Result {
try self._cast(element)
}
)
self.channelcore.enqueueInboundBufferConsumer { element in
switch element {
case .success(let data):
continuation.resume(with: Result { try self._cast(data) })
case .failure(let failure):
continuation.resume(throwing: failure)
}
}
} catch {
continuation.resume(throwing: error)
Expand Down
48 changes: 42 additions & 6 deletions Sources/NIOEmbedded/Embedded.swift
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,7 @@ class EmbeddedChannelCore: ChannelCore {
var outboundBuffer: CircularBuffer<NIOAny> = CircularBuffer()

/// Contains observers that want to consume the first element that would be appended to the `outboundBuffer`
@usableFromInline
var outboundBufferConsumer: Deque<(NIOAny) -> Void> = []
private var outboundBufferConsumer: Deque<(Result<NIOAny, Error>) -> Void> = []

/// Contains the unflushed items that went into the `Channel`
@usableFromInline
Expand All @@ -502,8 +501,7 @@ class EmbeddedChannelCore: ChannelCore {
var inboundBuffer: CircularBuffer<NIOAny> = CircularBuffer()

/// Contains observers that want to consume the first element that would be appended to the `inboundBuffer`
@usableFromInline
var inboundBufferConsumer: Deque<(NIOAny) -> Void> = []
private var inboundBufferConsumer: Deque<(Result<NIOAny, Error>) -> Void> = []

@usableFromInline
internal struct Addresses {
Expand Down Expand Up @@ -567,6 +565,14 @@ class EmbeddedChannelCore: ChannelCore {
isActive = false
promise?.succeed(())

// Return a `.failure` result containing an error to all pending inbound and outbound consumers.
while let consumer = self.inboundBufferConsumer.popFirst() {
consumer(.failure(ChannelError.ioOnClosedChannel))
}
while let consumer = self.outboundBufferConsumer.popFirst() {
consumer(.failure(ChannelError.ioOnClosedChannel))
}

// As we called register() in the constructor of EmbeddedChannel we also need to ensure we call unregistered here.
self.pipeline.syncOperations.fireChannelInactive()
self.pipeline.syncOperations.fireChannelUnregistered()
Expand Down Expand Up @@ -661,16 +667,46 @@ class EmbeddedChannelCore: ChannelCore {

private func addToBuffer(
buffer: inout CircularBuffer<NIOAny>,
consumer: inout Deque<(NIOAny) -> Void>,
consumer: inout Deque<(Result<NIOAny, Error>) -> Void>,
data: NIOAny
) {
self.eventLoop.preconditionInEventLoop()
if let consume = consumer.popFirst() {
consume(data)
consume(.success(data))
} else {
buffer.append(data)
}
}

/// Enqueue a consumer closure that will be invoked upon the next pending inbound write.
/// - Parameter newElement: The consumer closure to enqueue. Returns a `.failure` result if the channel has already
/// closed.
func enqueueInboundBufferConsumer(_ newElement: @escaping (Result<NIOAny, Error>) -> Void) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: we typically annotate methods only supposed to be called on the EL with a leading underscore, and I recommend making this private if it can be.

Copy link
Contributor Author

@aryan-25 aryan-25 Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a leading underscore to both methods (enqueue{In}{Out}boundBufferConsumer(_:)).

Those methods are defined in EmbeddedChannelCore and are also called from NIOAsyncTestingChannel, so unfortunately, they cannot be made private.

self.eventLoop.preconditionInEventLoop()

// The channel has already closed: there cannot be any further writes. Return a `.failure` result with an error.
guard self.isOpen else {
newElement(.failure(ChannelError.ioOnClosedChannel))
return
}

self.inboundBufferConsumer.append(newElement)
}

/// Enqueue a consumer closure that will be invoked upon the next pending outbound write.
/// - Parameter newElement: The consumer closure to enqueue. Returns a `.failure` result if the channel has already
/// closed.
func enqueueOutboundBufferConsumer(_ newElement: @escaping (Result<NIOAny, Error>) -> Void) {
self.eventLoop.preconditionInEventLoop()

// The channel has already closed: there cannot be any further writes. Return a `.failure` result with an error.
guard self.isOpen else {
newElement(.failure(ChannelError.ioOnClosedChannel))
return
}

self.outboundBufferConsumer.append(newElement)
}
}

// ChannelCores are basically never Sendable.
Expand Down
64 changes: 64 additions & 0 deletions Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,70 @@ class AsyncTestingChannelTests: XCTestCase {
try await XCTAsyncAssertTrue(try await channel.finish().isClean)
}

func testWaitingForWriteTerminatesAfterChannelClose() async throws {
let channel = NIOAsyncTestingChannel()

// Write some inbound and outbound data
for i in 1...3 {
try await channel.writeInbound(i)
try await channel.writeOutbound(i)
}

// We should successfully see the three inbound and outbound writes
for i in 1...3 {
try await XCTAsyncAssertEqual(try await channel.waitForInboundWrite(), i)
try await XCTAsyncAssertEqual(try await channel.waitForOutboundWrite(), i)
}

let task = Task {
// We close the channel after the third inbound/outbound write. Waiting again should result in a
// `ChannelError.ioOnClosedChannel` error.
await XCTAsyncAssertThrowsError(try await channel.waitForInboundWrite(as: Int.self)) {
XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel)
}
await XCTAsyncAssertThrowsError(try await channel.waitForOutboundWrite(as: Int.self)) {
XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel)
}
}

// Close the channel without performing any writes
try await channel.close()
try await task.value
}

func testEnqueueWriteConsumersBeforeChannelClosesWithoutAnyWrites() async throws {
let channel = NIOAsyncTestingChannel()

let task = Task {
// We don't write anything to the channel and simply just close it. Waiting for an inbound/outbound write
// should result in a `ChannelError.ioOnClosedChannel` when the channel closes.
await XCTAsyncAssertThrowsError(try await channel.waitForInboundWrite(as: Int.self)) {
XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel)
}
await XCTAsyncAssertThrowsError(try await channel.waitForOutboundWrite(as: Int.self)) {
XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel)
}
}

// Close the channel without performing any inbound or outbound writes
try await channel.close()
try await task.value
}

func testEnqueueWriteConsumersAfterChannelClosesWithoutAnyWrites() async throws {
let channel = NIOAsyncTestingChannel()
// Immediately close the channel without performing any inbound or outbound writes
try await channel.close()

// Now try to wait for an inbound/outbound write. This should result in a `ChannelError.ioOnClosedChannel`.
await XCTAsyncAssertThrowsError(try await channel.waitForInboundWrite(as: Int.self)) {
XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel)
}
await XCTAsyncAssertThrowsError(try await channel.waitForOutboundWrite(as: Int.self)) {
XCTAssertEqual($0 as? ChannelError, ChannelError.ioOnClosedChannel)
}
}

func testGetSetOption() async throws {
let channel = NIOAsyncTestingChannel()
let option = ChannelOptions.socket(IPPROTO_IP, IP_TTL)
Expand Down
127 changes: 127 additions & 0 deletions Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,133 @@ class EmbeddedChannelTest: XCTestCase {
XCTAssertTrue(try channel.finish().isClean)
}

func testWriteInboundBufferConsumer() throws {
let channel = EmbeddedChannel()
let invocationPromise = channel.eventLoop.makePromise(of: Void.self)

channel.channelcore.enqueueInboundBufferConsumer { element in
invocationPromise.succeed()
switch element {
case .success(let result):
XCTAssertEqual(
channel.channelcore.tryUnwrapData(result, as: ByteBuffer.self),
ByteBuffer(string: "hello")
)
case .failure(let error):
XCTFail("Unexpectedly received an error: \(error)")
}
}

var buf = channel.allocator.buffer(capacity: 10)
buf.writeString("hello")
try channel.writeInbound(buf)

XCTAssertTrue(invocationPromise.futureResult.isFulfilled)
}

func testWriteOutboundBufferConsumer() throws {
let channel = EmbeddedChannel()
let invocationPromise = channel.eventLoop.makePromise(of: Void.self)

channel.channelcore.enqueueOutboundBufferConsumer { element in
invocationPromise.succeed()
switch element {
case .success(let result):
XCTAssertEqual(
channel.channelcore.tryUnwrapData(result, as: ByteBuffer.self),
ByteBuffer(string: "hello")
)
case .failure(let error):
XCTFail("Unexpectedly received an error: \(error)")
}
}

var buf = channel.allocator.buffer(capacity: 10)
buf.writeString("hello")
channel.write(buf, promise: nil)
channel.flush()

XCTAssertTrue(invocationPromise.futureResult.isFulfilled)
}

func testQueueMultipleInboundAndOutboundBufferConsumersBeforeChannelClose() async throws {
let channel = EmbeddedChannel()
let inboundInvocationPromises = [EventLoopPromise<Void>](
repeating: channel.eventLoop.makePromise(of: Void.self),
count: 3
)
let outboundInvocationPromises = [EventLoopPromise<Void>](
repeating: channel.eventLoop.makePromise(of: Void.self),
count: 3
)

// Enqueue 3 inbound and outbound consumers
for i in 0..<3 {
// Since the channel closes, all queued consumers should get a `ChannelError.ioOnClosedChannel`
channel.channelcore.enqueueInboundBufferConsumer { element in
inboundInvocationPromises[i].succeed()
switch element {
case .failure(let failure):
XCTAssertEqual(failure as? ChannelError, ChannelError.ioOnClosedChannel)
case .success:
XCTFail("Unexpectedly received a successful result: no writes were performed on the channel.")
}
}

channel.channelcore.enqueueOutboundBufferConsumer { element in
outboundInvocationPromises[i].succeed()
switch element {
case .failure(let failure):
XCTAssertEqual(failure as? ChannelError, ChannelError.ioOnClosedChannel)
case .success:
XCTFail("Unexpectedly received a successful result: no writes were performed on the channel.")
}
}
}

// Close the channel without performing any writes
try await channel.close()
XCTAssertEqual(channel.channelcore.isOpen, false)

// Check that all consumer closures were invoked
XCTAssertTrue(inboundInvocationPromises.map(\.futureResult.isFulfilled).allSatisfy { $0 })
XCTAssertTrue(outboundInvocationPromises.map(\.futureResult.isFulfilled).allSatisfy { $0 })
}

func testQueueInboundAndOutboundBufferConsumerAfterChannelClose() async throws {
let channel = EmbeddedChannel()
let inboundInvocationPromise = channel.eventLoop.makePromise(of: Void.self)
let outboundInvocationPromise = channel.eventLoop.makePromise(of: Void.self)

// Close the channel immediately
try await channel.close()
XCTAssertEqual(channel.channelcore.isOpen, false)

// Since the consumers are enqueued after the channel closed, they should get a `ChannelError.ioOnClosedChannel`
channel.channelcore.enqueueInboundBufferConsumer { element in
inboundInvocationPromise.succeed()
switch element {
case .failure(let failure):
XCTAssertEqual(failure as? ChannelError, ChannelError.ioOnClosedChannel)
case .success:
XCTFail("Unexpectedly received a successful result: no writes were performed on the channel.")
}
}

channel.channelcore.enqueueOutboundBufferConsumer { element in
outboundInvocationPromise.succeed()
switch element {
case .failure(let failure):
XCTAssertEqual(failure as? ChannelError, ChannelError.ioOnClosedChannel)
case .success:
XCTFail("Unexpectedly received a successful result: no writes were performed on the channel.")
}
}

XCTAssertTrue(inboundInvocationPromise.futureResult.isFulfilled)
XCTAssertTrue(outboundInvocationPromise.futureResult.isFulfilled)
}

func testGetSetOption() throws {
let channel = EmbeddedChannel()
let option = ChannelOptions.socket(IPPROTO_IP, IP_TTL)
Expand Down