From 49cd78bb767b1773688d33f6acb91f507ad109b8 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Mon, 21 Oct 2024 15:52:31 +0100 Subject: [PATCH] Fix withConnectedSocket in async mode (#2937) Motivation: The async flavour of withConnectedSocket accidentally type-erased the channel, which caused it to be unable to be used. Modifications: Un-erase the type of the channel. Add a test. Result: withConnectedSocket works again Resolves #2936 --- Sources/NIOPosix/Bootstrap.swift | 2 +- .../AsyncChannelBootstrapTests.swift | 79 +++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index 3d1e3b9d93..f37abb37c5 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -1398,7 +1398,7 @@ extension ClientBootstrap { private func initializeAndRegisterChannel( channel: SocketChannel, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - registration: @escaping @Sendable (Channel) -> EventLoopFuture, + registration: @escaping @Sendable (SocketChannel) -> EventLoopFuture, postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< PostRegistrationTransformationResult > diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 98d8943598..62dfbbc291 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -609,6 +609,69 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } + func testServerClientBootstrap_withAsyncChannel_clientConnectedSocket() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + + let channel = try await ServerBootstrap(group: eventLoopGroup) + .serverChannelOption(.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(.autoRead, value: true) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel in + channel.eventLoop.makeCompletedFuture { () -> NIOAsyncChannel in + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + inboundType: String.self, + outboundType: String.self + ) + ) + } + } + + try await withThrowingTaskGroup(of: Void.self) { group in + let (stream, continuation) = AsyncStream.makeStream() + var iterator = stream.makeAsyncIterator() + + group.addTask { + try await withThrowingTaskGroup(of: Void.self) { _ in + try await channel.executeThenClose { inbound in + for try await childChannel in inbound { + try await childChannel.executeThenClose { childChannelInbound, _ in + for try await value in childChannelInbound { + continuation.yield(.string(value)) + } + } + } + } + } + } + + let s = try Socket(protocolFamily: .inet, type: .stream) + XCTAssert(try s.connect(to: channel.channel.localAddress!)) + let fd = try s.takeDescriptorOwnership() + + let stringChannel = try await self.makeClientChannel( + eventLoopGroup: eventLoopGroup, + fileDescriptor: fd + ) + try await stringChannel.executeThenClose { _, outbound in + try await outbound.write("hello") + } + + await XCTAsyncAssertEqual(await iterator.next(), .string("hello")) + + group.cancelAll() + } + } + // MARK: Datagram Bootstrap func testDatagramBootstrap_withAsyncChannel_andHostPort() async throws { @@ -1280,6 +1343,22 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } + private func makeClientChannel( + eventLoopGroup: EventLoopGroup, + fileDescriptor: CInt + ) async throws -> NIOAsyncChannel { + try await ClientBootstrap(group: eventLoopGroup) + .withConnectedSocket(fileDescriptor) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler()) + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + } + private func makeClientChannelWithProtocolNegotiation( eventLoopGroup: EventLoopGroup, port: Int,