Skip to content

Commit

Permalink
Fix withConnectedSocket in async mode (#2937)
Browse files Browse the repository at this point in the history
Motivation:

The async flavour of withConnectedSocket accidentally type-erased the
channel, which caused it to be unable to be used.

Modifications:

Un-erase the type of the channel.
Add a test.

Result:

withConnectedSocket works again
Resolves #2936
  • Loading branch information
Lukasa authored Oct 21, 2024
1 parent cc1c57c commit 49cd78b
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Sources/NIOPosix/Bootstrap.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1398,7 +1398,7 @@ extension ClientBootstrap {
private func initializeAndRegisterChannel<ChannelInitializerResult, PostRegistrationTransformationResult>(
channel: SocketChannel,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
registration: @escaping @Sendable (Channel) -> EventLoopFuture<Void>,
registration: @escaping @Sendable (SocketChannel) -> EventLoopFuture<Void>,
postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture<
PostRegistrationTransformationResult
>
Expand Down
79 changes: 79 additions & 0 deletions Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,69 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
}

func testServerClientBootstrap_withAsyncChannel_clientConnectedSocket() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3)
defer {
try! eventLoopGroup.syncShutdownGracefully()
}

let channel = try await ServerBootstrap(group: eventLoopGroup)
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(.autoRead, value: true)
.bind(
host: "127.0.0.1",
port: 0
) { channel in
channel.eventLoop.makeCompletedFuture { () -> NIOAsyncChannel<String, String> in
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
return try NIOAsyncChannel(
wrappingChannelSynchronously: channel,
configuration: .init(
inboundType: String.self,
outboundType: String.self
)
)
}
}

try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var iterator = stream.makeAsyncIterator()

group.addTask {
try await withThrowingTaskGroup(of: Void.self) { _ in
try await channel.executeThenClose { inbound in
for try await childChannel in inbound {
try await childChannel.executeThenClose { childChannelInbound, _ in
for try await value in childChannelInbound {
continuation.yield(.string(value))
}
}
}
}
}
}

let s = try Socket(protocolFamily: .inet, type: .stream)
XCTAssert(try s.connect(to: channel.channel.localAddress!))
let fd = try s.takeDescriptorOwnership()

let stringChannel = try await self.makeClientChannel(
eventLoopGroup: eventLoopGroup,
fileDescriptor: fd
)
try await stringChannel.executeThenClose { _, outbound in
try await outbound.write("hello")
}

await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))

group.cancelAll()
}
}

// MARK: Datagram Bootstrap

func testDatagramBootstrap_withAsyncChannel_andHostPort() async throws {
Expand Down Expand Up @@ -1280,6 +1343,22 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
}

private func makeClientChannel(
eventLoopGroup: EventLoopGroup,
fileDescriptor: CInt
) async throws -> NIOAsyncChannel<String, String> {
try await ClientBootstrap(group: eventLoopGroup)
.withConnectedSocket(fileDescriptor) { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler())
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
return try NIOAsyncChannel(wrappingChannelSynchronously: channel)
}
}
}

private func makeClientChannelWithProtocolNegotiation(
eventLoopGroup: EventLoopGroup,
port: Int,
Expand Down

0 comments on commit 49cd78b

Please sign in to comment.