From e4aaf2ca353da8b5a9d85fb439daf7e0a758012e Mon Sep 17 00:00:00 2001 From: Stanislav Yaglo Date: Wed, 27 Sep 2023 10:46:26 +0100 Subject: [PATCH 01/64] Update CNIOSHA1.h to support C++ mode (#2523) --- Sources/CNIOSHA1/include/CNIOSHA1.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/Sources/CNIOSHA1/include/CNIOSHA1.h b/Sources/CNIOSHA1/include/CNIOSHA1.h index 6ed17f79f5..b5c057565b 100644 --- a/Sources/CNIOSHA1/include/CNIOSHA1.h +++ b/Sources/CNIOSHA1/include/CNIOSHA1.h @@ -49,6 +49,10 @@ #include #include +#ifdef __cplusplus +extern "C" { +#endif + struct sha1_ctxt { union { uint8_t b8[20]; @@ -68,7 +72,12 @@ typedef struct sha1_ctxt SHA1_CTX; #define SHA1_RESULTLEN (160/8) +#ifdef __cplusplus +#define __min_size(x) (x) +#else #define __min_size(x) static (x) +#endif + extern void c_nio_sha1_init(struct sha1_ctxt *); extern void c_nio_sha1_pad(struct sha1_ctxt *); extern void c_nio_sha1_loop(struct sha1_ctxt *, const uint8_t *, size_t); @@ -79,5 +88,8 @@ extern void c_nio_sha1_result(struct sha1_ctxt *, char[__min_size(SHA1_RESULTLEN #define SHA1Update(x, y, z) c_nio_sha1_loop((x), (y), (z)) #define SHA1Final(x, y) c_nio_sha1_result((y), (x)) +#ifdef __cplusplus +} /* extern "C" */ +#endif #endif /*_CRYPTO_SHA1_H_*/ From 67553a7d6d6fb7bc62b41a02acf806091c2f5574 Mon Sep 17 00:00:00 2001 From: Rick Newton-Rogers Date: Mon, 2 Oct 2023 11:38:07 +0100 Subject: [PATCH 02/64] Bump minimum Swift version to 5.7 (#2524) * Bump minimum Swift version to 5.7 Motivation: Now that Swift 5.9 is GM we should update the supported versions and remove 5.6 Modifications: * Update `Package.swift` * Remove `#if swift(>=5.7)` guards * Delete the 5.6 docker compose file and make a 5.10 one * Update integration test script * Update docs Result: Remove support for Swift 5.6, add 5.10 * fix indentation issues * 5.9 docker image use release image --- CONTRIBUTING.md | 2 +- .../tests_02_syscall_wrappers/defines.sh | 2 +- Package.swift | 2 +- README.md | 9 +- SECURITY.md | 4 +- .../AsyncChannelInboundStream.swift | 9 - .../AsyncChannelOutboundWriter.swift | 6 - Sources/NIOCore/ChannelHandlers.swift | 6 - Sources/NIOCore/ChannelPipeline.swift | 8 - Sources/NIOCore/Codec.swift | 4 - Sources/NIOCore/EventLoop.swift | 233 +------- Sources/NIOCore/EventLoopFuture.swift | 528 +----------------- .../SingleStepByteToMessageDecoder.swift | 2 - .../NIOCore/UniversalBootstrapSupport.swift | 37 +- Sources/NIOHTTP1/HTTPPipelineSetup.swift | 140 +---- Sources/NIOPosix/Bootstrap.swift | 115 +--- .../MultiThreadedEventLoopGroup.swift | 24 +- Sources/NIOPosix/NIOThreadPool.swift | 69 +-- Sources/NIOPosix/NonBlockingFileIO.swift | 129 +---- Sources/NIOPosix/Thread.swift | 20 +- .../NIOWebSocketServerUpgrader.swift | 72 +-- .../HTTPServerUpgradeTests.swift | 4 - ...4.56.yaml => docker-compose.2204.510.yaml} | 46 +- docker/docker-compose.2204.59.yaml | 3 +- 24 files changed, 100 insertions(+), 1374 deletions(-) rename docker/{docker-compose.2004.56.yaml => docker-compose.2204.510.yaml} (71%) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index eab29e7456..f4953f0d60 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -72,7 +72,7 @@ For this reason, whenever you add new tests **you have to run a script** that ge ### Make sure your patch works for all supported versions of swift -The CI will do this for you. You can use the docker-compose files included if you wish to check locally. Currently all versions of swift >= 5.6 are supported. For example usage of docker compose see the main [README](./README.md#an-alternative-using-docker-compose) +The CI will do this for you. You can use the docker-compose files included if you wish to check locally. Currently all versions of swift >= 5.7 are supported. For example usage of docker compose see the main [README](./README.md#an-alternative-using-docker-compose) ### Make sure your code is performant diff --git a/IntegrationTests/tests_02_syscall_wrappers/defines.sh b/IntegrationTests/tests_02_syscall_wrappers/defines.sh index ac7793b046..c289646759 100644 --- a/IntegrationTests/tests_02_syscall_wrappers/defines.sh +++ b/IntegrationTests/tests_02_syscall_wrappers/defines.sh @@ -22,7 +22,7 @@ function make_package() { fi cat > "$tmpdir/syscallwrapper/Package.swift" <<"EOF" -// swift-tools-version:5.6 +// swift-tools-version:5.7 // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription diff --git a/Package.swift b/Package.swift index c23c131429..4ba1b85842 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.6 +// swift-tools-version:5.7 //===----------------------------------------------------------------------===// // // This source file is part of the SwiftNIO open source project diff --git a/README.md b/README.md index 55b77050d5..565a6d0788 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ It's like [Netty](https://netty.io), but written for Swift. The SwiftNIO project is split across multiple repositories: -Repository | NIO 2 (Swift 5.6+) +Repository | NIO 2 (Swift 5.7+) --- | --- [https://github.com/apple/swift-nio][repo-nio]
SwiftNIO core | `from: "2.0.0"` [https://github.com/apple/swift-nio-ssl][repo-nio-ssl]
TLS (SSL) support | `from: "2.0.0"` @@ -70,7 +70,7 @@ Redis | ✅ | ❌ | [swift-server/RediStack](https://github.com/swift-server/Red This is the current version of SwiftNIO and will be supported for the foreseeable future. -The most recent versions of SwiftNIO support Swift 5.6 and newer. The minimum Swift version supported by SwiftNIO releases are detailed below: +The most recent versions of SwiftNIO support Swift 5.7 and newer. The minimum Swift version supported by SwiftNIO releases are detailed below: SwiftNIO | Minimum Swift Version --------------------|---------------------- @@ -78,7 +78,8 @@ SwiftNIO | Minimum Swift Version `2.30.0 ..< 2.40.0` | 5.2 `2.40.0 ..< 2.43.0` | 5.4 `2.43.0 ..< 2.51.0` | 5.5.2 -`2.51.0 ...` | 5.6 +`2.51.0 ..< 2.60.0` | 5.6 +`2.60.0 ...` | 5.7 ### SwiftNIO 1 SwiftNIO 1 is considered end of life - it is strongly recommended that you move to a newer version. The Core NIO team does not actively work on this version. No new features will be added to this version but PRs which fix bugs or security vulnerabilities will be accepted until the end of May 2022. @@ -332,7 +333,7 @@ have a few prerequisites installed on your system. ### Linux -- Swift 5.6 or newer from [swift.org/download](https://swift.org/download/#releases). We always recommend to use the latest released version. +- Swift 5.7 or newer from [swift.org/download](https://swift.org/download/#releases). We always recommend to use the latest released version. - netcat (for integration tests only) - lsof (for integration tests only) - shasum (for integration tests only) diff --git a/SECURITY.md b/SECURITY.md index d1e7b82c60..4f548dcfa9 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -19,8 +19,10 @@ team would create the following patch releases: Swift 5.4 and later * NIO 2.50. + plus next patch release to address the issue for projects that support Swift 5.5.2 and later -* mainline + plus next patch release to address the issue for projects that support +* NIO 2.59. + plus next patch release to address the issue for projects that support Swift 5.6 and later +* mainline + plus next patch release to address the issue for projects that support + Swift 5.7 and later SwiftNIO 1.x is considered end of life and will not receive any security patches. diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift index 6826be3d9e..0b5d199fd0 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift @@ -48,20 +48,11 @@ public struct NIOAsyncChannelInboundStream: Sendable { } } - #if swift(>=5.7) @usableFromInline enum _Backing: Sendable { case asyncStream(AsyncThrowingStream) case producer(Producer) } - #else - // AsyncStream wasn't marked as `Sendable` in 5.6 - @usableFromInline - enum _Backing: @unchecked Sendable { - case asyncStream(AsyncThrowingStream) - case producer(Producer) - } - #endif /// The underlying async sequence. @usableFromInline diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift index 13934d5ec7..4fda58f12d 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -163,11 +163,5 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { } } -#if swift(>=5.7) @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncChannelOutboundWriter.TestSink: Sendable {} -#else -// AsyncStream wasn't marked as `Sendable` in 5.6 -@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -extension NIOAsyncChannelOutboundWriter.TestSink: @unchecked Sendable {} -#endif diff --git a/Sources/NIOCore/ChannelHandlers.swift b/Sources/NIOCore/ChannelHandlers.swift index 89dbc594d6..18a9318f4b 100644 --- a/Sources/NIOCore/ChannelHandlers.swift +++ b/Sources/NIOCore/ChannelHandlers.swift @@ -105,10 +105,8 @@ public final class AcceptBackoffHandler: ChannelDuplexHandler, RemovableChannelH } } -#if swift(>=5.7) @available(*, unavailable) extension AcceptBackoffHandler: Sendable {} -#endif /** ChannelHandler implementation which enforces back-pressure by stopping to read from the remote peer when it cannot write back fast enough. @@ -157,10 +155,8 @@ public final class BackPressureHandler: ChannelDuplexHandler, RemovableChannelHa } } -#if swift(>=5.7) @available(*, unavailable) extension BackPressureHandler: Sendable {} -#endif /// Triggers an IdleStateEvent when a Channel has not performed read, write, or both operation for a while. public final class IdleStateHandler: ChannelDuplexHandler, RemovableChannelHandler { @@ -347,7 +343,5 @@ public final class IdleStateHandler: ChannelDuplexHandler, RemovableChannelHandl } } -#if swift(>=5.7) @available(*, unavailable) extension IdleStateHandler: Sendable {} -#endif diff --git a/Sources/NIOCore/ChannelPipeline.swift b/Sources/NIOCore/ChannelPipeline.swift index d9d2c72ecf..246d86c0f0 100644 --- a/Sources/NIOCore/ChannelPipeline.swift +++ b/Sources/NIOCore/ChannelPipeline.swift @@ -946,9 +946,7 @@ public final class ChannelPipeline: ChannelInvoker { } } -#if swift(>=5.7) extension ChannelPipeline: @unchecked Sendable {} -#endif extension ChannelPipeline { /// Adds the provided channel handlers to the pipeline in the order given, taking account @@ -1274,10 +1272,8 @@ extension ChannelPipeline { } } -#if swift(>=5.7) @available(*, unavailable) extension ChannelPipeline.SynchronousOperations: Sendable {} -#endif extension ChannelPipeline { /// A `Position` within the `ChannelPipeline` used to insert handlers into the `ChannelPipeline`. @@ -1296,10 +1292,8 @@ extension ChannelPipeline { } } -#if swift(>=5.7) @available(*, unavailable) extension ChannelPipeline.Position: Sendable {} -#endif /// Special `ChannelHandler` that forwards all events to the `Channel.Unsafe` implementation. /* private but tests */ final class HeadChannelHandler: _ChannelOutboundHandler { @@ -1853,10 +1847,8 @@ public final class ChannelHandlerContext: ChannelInvoker { } } -#if swift(>=5.7) @available(*, unavailable) extension ChannelHandlerContext: Sendable {} -#endif extension ChannelHandlerContext { /// A `RemovalToken` is handed to a `RemovableChannelHandler` when its `removeHandler` function is invoked. A diff --git a/Sources/NIOCore/Codec.swift b/Sources/NIOCore/Codec.swift index 5d79ab39b7..a46beac870 100644 --- a/Sources/NIOCore/Codec.swift +++ b/Sources/NIOCore/Codec.swift @@ -485,10 +485,8 @@ public final class ByteToMessageHandler { } } -#if swift(>=5.7) @available(*, unavailable) extension ByteToMessageHandler: Sendable {} -#endif // MARK: ByteToMessageHandler: Test Helpers extension ByteToMessageHandler { @@ -776,10 +774,8 @@ public final class MessageToByteHandler: ChannelO } } -#if swift(>=5.7) @available(*, unavailable) extension MessageToByteHandler: Sendable {} -#endif extension MessageToByteHandler { public func handlerAdded(context: ChannelHandlerContext) { diff --git a/Sources/NIOCore/EventLoop.swift b/Sources/NIOCore/EventLoop.swift index 6675f118f9..fe65e3c55e 100644 --- a/Sources/NIOCore/EventLoop.swift +++ b/Sources/NIOCore/EventLoop.swift @@ -23,28 +23,16 @@ import CNIOLinux /// A `Scheduled` allows the user to either `cancel()` the execution of the scheduled task (if possible) or obtain a reference to the `EventLoopFuture` that /// will be notified once the execution is complete. public struct Scheduled { - #if swift(>=5.7) @usableFromInline typealias CancelationCallback = @Sendable () -> Void - #else - @usableFromInline typealias CancelationCallback = () -> Void - #endif /* private but usableFromInline */ @usableFromInline let _promise: EventLoopPromise /* private but usableFromInline */ @usableFromInline let _cancellationTask: CancelationCallback - #if swift(>=5.7) @inlinable @preconcurrency public init(promise: EventLoopPromise, cancellationTask: @escaping @Sendable () -> Void) { self._promise = promise self._cancellationTask = cancellationTask } - #else - @inlinable - public init(promise: EventLoopPromise, cancellationTask: @escaping () -> Void) { - self._promise = promise - self._cancellationTask = cancellationTask - } - #endif /// Try to cancel the execution of the scheduled task. /// @@ -63,19 +51,13 @@ public struct Scheduled { } } -#if swift(>=5.7) extension Scheduled: Sendable where T: Sendable {} -#endif /// Returned once a task was scheduled to be repeatedly executed on the `EventLoop`. /// /// A `RepeatedTask` allows the user to `cancel()` the repeated scheduling of further tasks. public final class RepeatedTask { - #if swift(>=5.7) typealias RepeatedTaskCallback = @Sendable (RepeatedTask) -> EventLoopFuture - #else - typealias RepeatedTaskCallback = (RepeatedTask) -> EventLoopFuture - #endif private let delay: TimeAmount private let eventLoop: EventLoop private let cancellationPromise: EventLoopPromise? @@ -196,9 +178,7 @@ public final class RepeatedTask { } } -#if swift(>=5.7) extension RepeatedTask: @unchecked Sendable {} -#endif /// An iterator over the `EventLoop`s forming an `EventLoopGroup`. /// @@ -226,9 +206,7 @@ public struct EventLoopIterator: Sequence, IteratorProtocol { } } -#if swift(>=5.7) extension EventLoopIterator: Sendable {} -#endif /// An EventLoop processes IO / tasks in an endless loop for `Channel`s until it's closed. /// @@ -270,16 +248,10 @@ public protocol EventLoop: EventLoopGroup { /// If it is necessary for correctness to confirm that you're on an event loop, prefer ``preconditionInEventLoop(file:line:)-7ukrq``. var inEventLoop: Bool { get } - #if swift(>=5.7) /// Submit a given task to be executed by the `EventLoop` @preconcurrency func execute(_ task: @escaping @Sendable () -> Void) - #else - /// Submit a given task to be executed by the `EventLoop` - func execute(_ task: @escaping () -> Void) - #endif - - #if swift(>=5.7) + /// Submit a given task to be executed by the `EventLoop`. Once the execution is complete the returned `EventLoopFuture` is notified. /// /// - parameters: @@ -287,16 +259,6 @@ public protocol EventLoop: EventLoopGroup { /// - returns: `EventLoopFuture` that is notified once the task was executed. @preconcurrency func submit(_ task: @escaping @Sendable () throws -> T) -> EventLoopFuture - #else - /// Submit a given task to be executed by the `EventLoop`. Once the execution is complete the returned `EventLoopFuture` is notified. - /// - /// - parameters: - /// - task: The closure that will be submitted to the `EventLoop` for execution. - /// - returns: `EventLoopFuture` that is notified once the task was executed. - func submit(_ task: @escaping () throws -> T) -> EventLoopFuture - #endif - - #if swift(>=5.7) /// Schedule a `task` that is executed by this `EventLoop` at the given time. /// @@ -309,20 +271,6 @@ public protocol EventLoop: EventLoopGroup { @discardableResult @preconcurrency func scheduleTask(deadline: NIODeadline, _ task: @escaping @Sendable () throws -> T) -> Scheduled - #else - /// Schedule a `task` that is executed by this `EventLoop` at the given time. - /// - /// - parameters: - /// - task: The synchronous task to run. As with everything that runs on the `EventLoop`, it must not block. - /// - returns: A `Scheduled` object which may be used to cancel the task if it has not yet run, or to wait - /// on the completion of the task. - /// - /// - note: You can only cancel a task before it has started executing. - @discardableResult - func scheduleTask(deadline: NIODeadline, _ task: @escaping () throws -> T) -> Scheduled - #endif - - #if swift(>=5.7) /// Schedule a `task` that is executed by this `EventLoop` after the given amount of time. /// @@ -336,19 +284,6 @@ public protocol EventLoop: EventLoopGroup { @discardableResult @preconcurrency func scheduleTask(in: TimeAmount, _ task: @escaping @Sendable () throws -> T) -> Scheduled - #else - /// Schedule a `task` that is executed by this `EventLoop` after the given amount of time. - /// - /// - parameters: - /// - task: The synchronous task to run. As with everything that runs on the `EventLoop`, it must not block. - /// - returns: A `Scheduled` object which may be used to cancel the task if it has not yet run, or to wait - /// on the completion of the task. - /// - /// - note: You can only cancel a task before it has started executing. - /// - note: The `in` value is clamped to a maximum when running on a Darwin-kernel. - @discardableResult - func scheduleTask(in: TimeAmount, _ task: @escaping () throws -> T) -> Scheduled - #endif /// Asserts that the current thread is the one tied to this `EventLoop`. /// Otherwise, the process will be abnormally terminated as per the semantics of `preconditionFailure(_:file:line:)`. @@ -725,7 +660,6 @@ extension NIODeadline { } extension EventLoop { - #if swift(>=5.7) /// Submit `task` to be run on this `EventLoop`. /// /// The returned `EventLoopFuture` will be completed when `task` has finished running. It will be succeeded with @@ -740,22 +674,7 @@ extension EventLoop { _submit(task) } @usableFromInline typealias SubmitCallback = @Sendable () throws -> T - #else - /// Submit `task` to be run on this `EventLoop`. - /// - /// The returned `EventLoopFuture` will be completed when `task` has finished running. It will be succeeded with - /// `task`'s return value or failed if the execution of `task` threw an error. - /// - /// - parameters: - /// - task: The synchronous task to run. As everything that runs on the `EventLoop`, it must not block. - /// - returns: An `EventLoopFuture` containing the result of `task`'s execution. - @inlinable - public func submit(_ task: @escaping () throws -> T) -> EventLoopFuture { - _submit(task) - } - @usableFromInline typealias SubmitCallback = () throws -> T - #endif - + @inlinable func _submit(_ task: @escaping SubmitCallback) -> EventLoopFuture { let promise: EventLoopPromise = makePromise(file: #fileID, line: #line) @@ -771,7 +690,6 @@ extension EventLoop { return promise.futureResult } - #if swift(>=5.7) /// Submit `task` to be run on this `EventLoop`. /// /// The returned `EventLoopFuture` will be completed when `task` has finished running. It will be identical to @@ -786,28 +704,12 @@ extension EventLoop { self._flatSubmit(task) } @usableFromInline typealias FlatSubmitCallback = @Sendable () -> EventLoopFuture - #else - /// Submit `task` to be run on this `EventLoop`. - /// - /// The returned `EventLoopFuture` will be completed when `task` has finished running. It will be identical to - /// the `EventLoopFuture` returned by `task`. - /// - /// - parameters: - /// - task: The asynchronous task to run. As with everything that runs on the `EventLoop`, it must not block. - /// - returns: An `EventLoopFuture` identical to the `EventLoopFuture` returned from `task`. - @inlinable - public func flatSubmit(_ task: @escaping () -> EventLoopFuture) -> EventLoopFuture { - self._flatSubmit(task) - } - @usableFromInline typealias FlatSubmitCallback = () -> EventLoopFuture - #endif - + @inlinable func _flatSubmit(_ task: @escaping FlatSubmitCallback) -> EventLoopFuture { self.submit(task).flatMap { $0 } } - - #if swift(>=5.7) + /// Schedule a `task` that is executed by this `EventLoop` at the given time. /// /// - parameters: @@ -828,28 +730,7 @@ extension EventLoop { self._flatScheduleTask(deadline: deadline, file: file, line: line, task) } @usableFromInline typealias FlatScheduleTaskDeadlineCallback = () throws -> EventLoopFuture - #else - /// Schedule a `task` that is executed by this `EventLoop` at the given time. - /// - /// - parameters: - /// - task: The asynchronous task to run. As with everything that runs on the `EventLoop`, it must not block. - /// - returns: A `Scheduled` object which may be used to cancel the task if it has not yet run, or to wait - /// on the full execution of the task, including its returned `EventLoopFuture`. - /// - /// - note: You can only cancel a task before it has started executing. - @discardableResult - @inlinable - public func flatScheduleTask( - deadline: NIODeadline, - file: StaticString = #fileID, - line: UInt = #line, - _ task: @escaping () throws -> EventLoopFuture - ) -> Scheduled { - self._flatScheduleTask(deadline: deadline, file: file, line: line, task) - } - @usableFromInline typealias FlatScheduleTaskDeadlineCallback = () throws -> EventLoopFuture - #endif - + @discardableResult @inlinable func _flatScheduleTask( @@ -865,7 +746,6 @@ extension EventLoop { return .init(promise: promise, cancellationTask: { scheduled.cancel() }) } - #if swift(>=5.7) /// Schedule a `task` that is executed by this `EventLoop` after the given amount of time. /// /// - parameters: @@ -886,28 +766,7 @@ extension EventLoop { self._flatScheduleTask(in: delay, file: file, line: line, task) } @usableFromInline typealias FlatScheduleTaskDelayCallback = @Sendable () throws -> EventLoopFuture - #else - /// Schedule a `task` that is executed by this `EventLoop` after the given amount of time. - /// - /// - parameters: - /// - task: The asynchronous task to run. As everything that runs on the `EventLoop`, it must not block. - /// - returns: A `Scheduled` object which may be used to cancel the task if it has not yet run, or to wait - /// on the full execution of the task, including its returned `EventLoopFuture`. - /// - /// - note: You can only cancel a task before it has started executing. - @discardableResult - @inlinable - public func flatScheduleTask( - in delay: TimeAmount, - file: StaticString = #fileID, - line: UInt = #line, - _ task: @escaping () throws -> EventLoopFuture - ) -> Scheduled { - self._flatScheduleTask(in: delay, file: file, line: line, task) - } - @usableFromInline typealias FlatScheduleTaskDelayCallback = () throws -> EventLoopFuture - #endif - + @inlinable func _flatScheduleTask( in delay: TimeAmount, @@ -998,7 +857,6 @@ extension EventLoop { // Do nothing } - #if swift(>=5.7) /// Schedule a repeated task to be executed by the `EventLoop` with a fixed delay between the end and start of each /// task. @@ -1020,28 +878,7 @@ extension EventLoop { self._scheduleRepeatedTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) } typealias ScheduleRepeatedTaskCallback = @Sendable (RepeatedTask) throws -> Void - #else - /// Schedule a repeated task to be executed by the `EventLoop` with a fixed delay between the end and start of each - /// task. - /// - /// - parameters: - /// - initialDelay: The delay after which the first task is executed. - /// - delay: The delay between the end of one task and the start of the next. - /// - promise: If non-nil, a promise to fulfill when the task is cancelled and all execution is complete. - /// - task: The closure that will be executed. - /// - return: `RepeatedTask` - @discardableResult - public func scheduleRepeatedTask( - initialDelay: TimeAmount, - delay: TimeAmount, - notifying promise: EventLoopPromise? = nil, - _ task: @escaping (RepeatedTask) throws -> Void - ) -> RepeatedTask { - self._scheduleRepeatedTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) - } - typealias ScheduleRepeatedTaskCallback = (RepeatedTask) throws -> Void - #endif - + func _scheduleRepeatedTask( initialDelay: TimeAmount, delay: TimeAmount, @@ -1059,7 +896,6 @@ extension EventLoop { return self.scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, notifying: promise, futureTask) } - #if swift(>=5.7) /// Schedule a repeated asynchronous task to be executed by the `EventLoop` with a fixed delay between the end and /// start of each task. /// @@ -1087,35 +923,7 @@ extension EventLoop { self._scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) } typealias ScheduleRepeatedAsyncTaskCallback = @Sendable (RepeatedTask) -> EventLoopFuture - #else - /// Schedule a repeated asynchronous task to be executed by the `EventLoop` with a fixed delay between the end and - /// start of each task. - /// - /// - note: The delay is measured from the completion of one run's returned future to the start of the execution of - /// the next run. For example: If you schedule a task once per second but your task takes two seconds to - /// complete, the time interval between two subsequent runs will actually be three seconds (2s run time plus - /// the 1s delay.) - /// - /// - parameters: - /// - initialDelay: The delay after which the first task is executed. - /// - delay: The delay between the end of one task and the start of the next. - /// - promise: If non-nil, a promise to fulfill when the task is cancelled and all execution is complete. - /// - task: The closure that will be executed. Task will keep repeating regardless of whether the future - /// gets fulfilled with success or error. - /// - /// - return: `RepeatedTask` - @discardableResult - public func scheduleRepeatedAsyncTask( - initialDelay: TimeAmount, - delay: TimeAmount, - notifying promise: EventLoopPromise? = nil, - _ task: @escaping (RepeatedTask) -> EventLoopFuture - ) -> RepeatedTask { - self._scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) - } - typealias ScheduleRepeatedAsyncTaskCallback = (RepeatedTask) -> EventLoopFuture - #endif - + func _scheduleRepeatedAsyncTask( initialDelay: TimeAmount, delay: TimeAmount, @@ -1210,20 +1018,12 @@ public protocol EventLoopGroup: AnyObject, _NIOPreconcurrencySendable { /// future or kick off some operation, use `any()`. func any() -> EventLoop - #if swift(>=5.7) /// Shuts down the eventloop gracefully. This function is clearly an outlier in that it uses a completion /// callback instead of an EventLoopFuture. The reason for that is that NIO's EventLoopFutures will call back on an event loop. /// The virtue of this function is to shut the event loop down. To work around that we call back on a DispatchQueue /// instead. @preconcurrency func shutdownGracefully(queue: DispatchQueue, _ callback: @Sendable @escaping (Error?) -> Void) - #else - /// Shuts down the eventloop gracefully. This function is clearly an outlier in that it uses a completion - /// callback instead of an EventLoopFuture. The reason for that is that NIO's EventLoopFutures will call back on an event loop. - /// The virtue of this function is to shut the event loop down. To work around that we call back on a DispatchQueue - /// instead. - func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) - #endif - + /// Returns an `EventLoopIterator` over the `EventLoop`s in this `EventLoopGroup`. /// /// - returns: `EventLoopIterator` @@ -1245,27 +1045,14 @@ extension EventLoopGroup { } extension EventLoopGroup { - #if swift(>=5.7) @preconcurrency public func shutdownGracefully(_ callback: @escaping @Sendable (Error?) -> Void) { self.shutdownGracefully(queue: .global(), callback) } - #else - public func shutdownGracefully(_ callback: @escaping (Error?) -> Void) { - self.shutdownGracefully(queue: .global(), callback) - } - #endif - - #if swift(>=5.7) @available(*, noasync, message: "this can end up blocking the calling thread", renamed: "shutdownGracefully()") public func syncShutdownGracefully() throws { try self._syncShutdownGracefully() } - #else - public func syncShutdownGracefully() throws { - try self._syncShutdownGracefully() - } - #endif private func _syncShutdownGracefully() throws { self._preconditionSafeToSyncShutdown(file: #fileID, line: #line) @@ -1306,9 +1093,7 @@ public enum NIOEventLoopGroupProvider { case createNew } -#if swift(>=5.7) extension NIOEventLoopGroupProvider: Sendable {} -#endif /// Different `Error`s that are specific to `EventLoop` operations / implementations. public enum EventLoopError: Error { diff --git a/Sources/NIOCore/EventLoopFuture.swift b/Sources/NIOCore/EventLoopFuture.swift index 6c79a08ef9..185e00fecc 100644 --- a/Sources/NIOCore/EventLoopFuture.swift +++ b/Sources/NIOCore/EventLoopFuture.swift @@ -24,13 +24,8 @@ import Dispatch /// This eliminates recursion when processing `flatMap()` chains. @usableFromInline internal struct CallbackList { - #if swift(>=5.7) @usableFromInline internal typealias Element = @Sendable () -> CallbackList - #else - @usableFromInline - internal typealias Element = () -> CallbackList - #endif @usableFromInline internal var firstCallback: Optional @usableFromInline @@ -196,7 +191,7 @@ public struct EventLoopPromise { public func fail(_ error: Error) { self._resolve(value: .failure(error)) } - + /// Complete the promise with the passed in `EventLoopFuture`. /// /// This method is equivalent to invoking `future.cascade(to: promise)`, @@ -259,7 +254,6 @@ public struct EventLoopPromise { } } - /// Holder for a result that will be provided later. /// /// Functions that promise to do work asynchronously can return an `EventLoopFuture`. @@ -447,7 +441,6 @@ extension EventLoopFuture: Equatable { // 'flatMap' and 'map' implementations. This is really the key of the entire system. extension EventLoopFuture { - #if swift(>=5.7) /// When the current `EventLoopFuture` is fulfilled, run the provided callback, /// which will provide a new `EventLoopFuture`. /// @@ -481,41 +474,7 @@ extension EventLoopFuture { self._flatMap(callback) } @usableFromInline typealias FlatMapCallback = @Sendable (Value) -> EventLoopFuture - #else - /// When the current `EventLoopFuture` is fulfilled, run the provided callback, - /// which will provide a new `EventLoopFuture`. - /// - /// This allows you to dynamically dispatch new asynchronous tasks as phases in a - /// longer series of processing steps. Note that you can use the results of the - /// current `EventLoopFuture` when determining how to dispatch the next operation. - /// - /// This works well when you have APIs that already know how to return `EventLoopFuture`s. - /// You can do something with the result of one and just return the next future: - /// - /// ``` - /// let d1 = networkRequest(args).future() - /// let d2 = d1.flatMap { t -> EventLoopFuture in - /// . . . something with t . . . - /// return netWorkRequest(args) - /// } - /// d2.whenSuccess { u in - /// NSLog("Result of second request: \(u)") - /// } - /// ``` - /// - /// Note: In a sense, the `EventLoopFuture` is returned before it's created. - /// - /// - parameters: - /// - callback: Function that will receive the value of this `EventLoopFuture` and return - /// a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value. - @inlinable - public func flatMap(_ callback: @escaping (Value) -> EventLoopFuture) -> EventLoopFuture { - self._flatMap(callback) - } - @usableFromInline typealias FlatMapCallback = (Value) -> EventLoopFuture - #endif - + @inlinable func _flatMap(_ callback: @escaping FlatMapCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -537,8 +496,7 @@ extension EventLoopFuture { } return next.futureResult } - - #if swift(>=5.7) + /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which /// performs a synchronous computation and returns a new value of type `NewValue`. The provided /// callback may optionally `throw`. @@ -559,28 +517,7 @@ extension EventLoopFuture { self._flatMapThrowing(callback) } @usableFromInline typealias FlatMapThrowingCallback = @Sendable (Value) throws -> NewValue - #else - /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which - /// performs a synchronous computation and returns a new value of type `NewValue`. The provided - /// callback may optionally `throw`. - /// - /// Operations performed in `flatMapThrowing` should not block, or they will block the entire - /// event loop. `flatMapThrowing` is intended for use when you have a data-driven function that - /// performs a simple data transformation that can potentially error. - /// - /// If your callback function throws, the returned `EventLoopFuture` will error. - /// - /// - parameters: - /// - callback: Function that will receive the value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value. - @inlinable - public func flatMapThrowing(_ callback: @escaping (Value) throws -> NewValue) -> EventLoopFuture { - self._flatMapThrowing(callback) - } - @usableFromInline typealias FlatMapThrowingCallback = (Value) throws -> NewValue - #endif - + @inlinable func _flatMapThrowing(_ callback: @escaping FlatMapThrowingCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -599,8 +536,7 @@ extension EventLoopFuture { } return next.futureResult } - - #if swift(>=5.7) + /// When the current `EventLoopFuture` is in an error state, run the provided callback, which /// may recover from the error and returns a new value of type `Value`. The provided callback may optionally `throw`, /// in which case the `EventLoopFuture` will be in a failed state with the new thrown error. @@ -621,28 +557,7 @@ extension EventLoopFuture { self._flatMapErrorThrowing(callback) } @usableFromInline typealias FlatMapErrorThrowingCallback = @Sendable (Error) throws -> Value - #else - /// When the current `EventLoopFuture` is in an error state, run the provided callback, which - /// may recover from the error and returns a new value of type `Value`. The provided callback may optionally `throw`, - /// in which case the `EventLoopFuture` will be in a failed state with the new thrown error. - /// - /// Operations performed in `flatMapErrorThrowing` should not block, or they will block the entire - /// event loop. `flatMapErrorThrowing` is intended for use when you have the ability to synchronously - /// recover from errors. - /// - /// If your callback function throws, the returned `EventLoopFuture` will error. - /// - /// - parameters: - /// - callback: Function that will receive the error value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value or a rethrown error. - @inlinable - public func flatMapErrorThrowing(_ callback: @escaping (Error) throws -> Value) -> EventLoopFuture { - self._flatMapErrorThrowing(callback) - } - @usableFromInline typealias FlatMapErrorThrowingCallback = (Error) throws -> Value - #endif - + @inlinable func _flatMapErrorThrowing(_ callback: @escaping FlatMapErrorThrowingCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -662,7 +577,6 @@ extension EventLoopFuture { return next.futureResult } - #if swift(>=5.7) /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which /// performs a synchronous computation and returns a new value of type `NewValue`. /// @@ -695,48 +609,11 @@ extension EventLoopFuture { self._map(callback) } @usableFromInline typealias MapCallback = @Sendable (Value) -> (NewValue) - #else - /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which - /// performs a synchronous computation and returns a new value of type `NewValue`. - /// - /// Operations performed in `map` should not block, or they will block the entire event - /// loop. `map` is intended for use when you have a data-driven function that performs - /// a simple data transformation that cannot error. - /// - /// If you have a data-driven function that can throw, you should use `flatMapThrowing` - /// instead. - /// - /// ``` - /// let future1 = eventually() - /// let future2 = future1.map { T -> U in - /// ... stuff ... - /// return u - /// } - /// let future3 = future2.map { U -> V in - /// ... stuff ... - /// return v - /// } - /// ``` - /// - /// - parameters: - /// - callback: Function that will receive the value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value. - @inlinable - public func map(_ callback: @escaping (Value) -> (NewValue)) -> EventLoopFuture { - self._map(callback) - } - @usableFromInline typealias MapCallback = (Value) -> (NewValue) - #endif - + @inlinable func _map(_ callback: @escaping MapCallback) -> EventLoopFuture { if NewValue.self == Value.self && NewValue.self == Void.self { - #if swift(>=5.7) self.whenSuccess(callback as! @Sendable (Value) -> Void) - #else - self.whenSuccess(callback as! (Value) -> Void) - #endif return self as! EventLoopFuture } else { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -747,7 +624,6 @@ extension EventLoopFuture { } } - #if swift(>=5.7) /// When the current `EventLoopFuture` is in an error state, run the provided callback, which /// may recover from the error by returning an `EventLoopFuture`. The callback is intended to potentially /// recover from the error by returning a new `EventLoopFuture` that will eventually contain the recovered @@ -765,25 +641,7 @@ extension EventLoopFuture { self._flatMapError(callback) } @usableFromInline typealias FlatMapErrorCallback = @Sendable (Error) -> EventLoopFuture - #else - /// When the current `EventLoopFuture` is in an error state, run the provided callback, which - /// may recover from the error by returning an `EventLoopFuture`. The callback is intended to potentially - /// recover from the error by returning a new `EventLoopFuture` that will eventually contain the recovered - /// result. - /// - /// If the callback cannot recover it should return a failed `EventLoopFuture`. - /// - /// - parameters: - /// - callback: Function that will receive the error value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the recovered value. - @inlinable - public func flatMapError(_ callback: @escaping (Error) -> EventLoopFuture) -> EventLoopFuture { - self._flatMapError(callback) - } - @usableFromInline typealias FlatMapErrorCallback = (Error) -> EventLoopFuture - #endif - + @inlinable func _flatMapError(_ callback: @escaping FlatMapErrorCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -806,7 +664,6 @@ extension EventLoopFuture { return next.futureResult } - #if swift(>=5.7) /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which /// performs a synchronous computation and returns either a new value (of type `NewValue`) or /// an error depending on the `Result` returned by the closure. @@ -826,27 +683,7 @@ extension EventLoopFuture { self._flatMapResult(body) } @usableFromInline typealias FlatMapResultCallback = @Sendable (Value) -> Result - #else - /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which - /// performs a synchronous computation and returns either a new value (of type `NewValue`) or - /// an error depending on the `Result` returned by the closure. - /// - /// Operations performed in `flatMapResult` should not block, or they will block the entire - /// event loop. `flatMapResult` is intended for use when you have a data-driven function that - /// performs a simple data transformation that can potentially error. - /// - /// - /// - parameters: - /// - body: Function that will receive the value of this `EventLoopFuture` and return - /// a new value or error lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value. - @inlinable - public func flatMapResult(_ body: @escaping (Value) -> Result) -> EventLoopFuture { - self._flatMapResult(body) - } - @usableFromInline typealias FlatMapResultCallback = (Value) -> Result - #endif - + @inlinable func _flatMapResult(_ body: @escaping FlatMapResultCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -866,7 +703,6 @@ extension EventLoopFuture { return next.futureResult } - #if swift(>=5.7) /// When the current `EventLoopFuture` is in an error state, run the provided callback, which /// can recover from the error and return a new value of type `Value`. The provided callback may not `throw`, /// so this function should be used when the error is always recoverable. @@ -885,26 +721,7 @@ extension EventLoopFuture { self._recover(callback) } @usableFromInline typealias RecoverCallback = @Sendable (Error) -> Value - #else - /// When the current `EventLoopFuture` is in an error state, run the provided callback, which - /// can recover from the error and return a new value of type `Value`. The provided callback may not `throw`, - /// so this function should be used when the error is always recoverable. - /// - /// Operations performed in `recover` should not block, or they will block the entire - /// event loop. `recover` is intended for use when you have the ability to synchronously - /// recover from errors. - /// - /// - parameters: - /// - callback: Function that will receive the error value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the recovered value. - @inlinable - public func recover(_ callback: @escaping (Error) -> Value) -> EventLoopFuture { - self._recover(callback) - } - @usableFromInline typealias RecoverCallback = (Error) -> Value - #endif - + @inlinable func _recover(_ callback: @escaping RecoverCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -919,11 +736,7 @@ extension EventLoopFuture { return next.futureResult } - #if swift(>=5.7) @usableFromInline typealias AddCallbackCallback = @Sendable () -> CallbackList - #else - @usableFromInline typealias AddCallbackCallback = () -> CallbackList - #endif /// Add a callback. If there's already a value, invoke it and return the resulting list of new callback functions. @inlinable internal func _addCallback(_ callback: @escaping AddCallbackCallback) -> CallbackList { @@ -934,8 +747,7 @@ extension EventLoopFuture { } return callback() } - - #if swift(>=5.7) + /// Add a callback. If there's already a value, run as much of the chain as we can. @inlinable @preconcurrency // TODO: We want to remove @preconcurrency but it results in more allocations in 1000_udpconnections @@ -943,15 +755,7 @@ extension EventLoopFuture { self._internalWhenComplete(callback) } @usableFromInline typealias InternalWhenCompleteCallback = @Sendable () -> CallbackList - #else - /// Add a callback. If there's already a value, run as much of the chain as we can. - @inlinable - internal func _whenComplete(_ callback: @escaping () -> CallbackList) { - self._internalWhenComplete(callback) - } - @usableFromInline typealias InternalWhenCompleteCallback = () -> CallbackList - #endif - + /// Add a callback. If there's already a value, run as much of the chain as we can. @inlinable internal func _internalWhenComplete(_ callback: @escaping InternalWhenCompleteCallback) { @@ -964,7 +768,6 @@ extension EventLoopFuture { } } - #if swift(>=5.7) /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has a success result. /// @@ -981,24 +784,7 @@ extension EventLoopFuture { self._whenSuccess(callback) } @usableFromInline typealias WhenSuccessCallback = @Sendable (Value) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has a success result. - /// - /// An observer callback cannot return a value, meaning that this function cannot be chained - /// from. If you are attempting to create a computation pipeline, consider `map` or `flatMap`. - /// If you find yourself passing the results from this `EventLoopFuture` to a new `EventLoopPromise` - /// in the body of this function, consider using `cascade` instead. - /// - /// - parameters: - /// - callback: The callback that is called with the successful result of the `EventLoopFuture`. - @inlinable - public func whenSuccess(_ callback: @escaping (Value) -> Void) { - self._whenSuccess(callback) - } - @usableFromInline typealias WhenSuccessCallback = (Value) -> Void - #endif - + @inlinable func _whenSuccess(_ callback: @escaping WhenSuccessCallback) { self._whenComplete { @@ -1008,8 +794,7 @@ extension EventLoopFuture { return CallbackList() } } - - #if swift(>=5.7) + /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has a failure result. /// @@ -1026,24 +811,7 @@ extension EventLoopFuture { self._whenFailure(callback) } @usableFromInline typealias WhenFailureCallback = @Sendable (Error) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has a failure result. - /// - /// An observer callback cannot return a value, meaning that this function cannot be chained - /// from. If you are attempting to create a computation pipeline, consider `recover` or `flatMapError`. - /// If you find yourself passing the results from this `EventLoopFuture` to a new `EventLoopPromise` - /// in the body of this function, consider using `cascade` instead. - /// - /// - parameters: - /// - callback: The callback that is called with the failed result of the `EventLoopFuture`. - @inlinable - public func whenFailure(_ callback: @escaping (Error) -> Void) { - self._whenFailure(callback) - } - @usableFromInline typealias WhenFailureCallback = (Error) -> Void - #endif - + @inlinable func _whenFailure(_ callback: @escaping WhenFailureCallback) { self._whenComplete { @@ -1054,7 +822,6 @@ extension EventLoopFuture { } } - #if swift(>=5.7) /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has any result. /// @@ -1066,18 +833,6 @@ extension EventLoopFuture { self._publicWhenComplete(callback) } @usableFromInline typealias WhenCompleteCallback = @Sendable (Result) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has any result. - /// - /// - parameters: - /// - callback: The callback that is called when the `EventLoopFuture` is fulfilled. - @inlinable - public func whenComplete(_ callback: @escaping (Result) -> Void) { - self._publicWhenComplete(callback) - } - @usableFromInline typealias WhenCompleteCallback = (Result) -> Void - #endif @inlinable func _publicWhenComplete(_ callback: @escaping WhenCompleteCallback) { self._whenComplete { @@ -1223,7 +978,6 @@ extension EventLoopFuture { // MARK: wait extension EventLoopFuture { - #if swift(>=5.7) /// Wait for the resolution of this `EventLoopFuture` by blocking the current thread until it /// resolves. /// @@ -1242,25 +996,6 @@ extension EventLoopFuture { public func wait(file: StaticString = #file, line: UInt = #line) throws -> Value { return try self._wait(file: file, line: line) } - #else - /// Wait for the resolution of this `EventLoopFuture` by blocking the current thread until it - /// resolves. - /// - /// If the `EventLoopFuture` resolves with a value, that value is returned from `wait()`. If - /// the `EventLoopFuture` resolves with an error, that error will be thrown instead. - /// `wait()` will block whatever thread it is called on, so it must not be called on event loop - /// threads: it is primarily useful for testing, or for building interfaces between blocking - /// and non-blocking code. - /// - /// This is also forbidden in async contexts: prefer ``EventLoopFuture/get``. - /// - /// - returns: The value of the `EventLoopFuture` when it completes. - /// - throws: The error value of the `EventLoopFuture` if it errors. - @inlinable - public func wait(file: StaticString = #file, line: UInt = #line) throws -> Value { - return try self._wait(file: file, line: line) - } - #endif @inlinable func _wait(file: StaticString, line: UInt) throws -> Value { @@ -1289,7 +1024,6 @@ extension EventLoopFuture { // MARK: fold extension EventLoopFuture { - #if swift(>=5.7) /// Returns a new `EventLoopFuture` that fires only when this `EventLoopFuture` and /// all the provided `futures` complete. It then provides the result of folding the value of this /// `EventLoopFuture` with the values of all the provided `futures`. @@ -1315,33 +1049,7 @@ extension EventLoopFuture { self._fold(futures, with: combiningFunction) } @usableFromInline typealias FoldCallback = @Sendable (Value, OtherValue) -> EventLoopFuture - #else - /// Returns a new `EventLoopFuture` that fires only when this `EventLoopFuture` and - /// all the provided `futures` complete. It then provides the result of folding the value of this - /// `EventLoopFuture` with the values of all the provided `futures`. - /// - /// This function is suited when you have APIs that already know how to return `EventLoopFuture`s. - /// - /// The returned `EventLoopFuture` will fail as soon as the a failure is encountered in any of the - /// `futures` (or in this one). However, the failure will not occur until all preceding - /// `EventLoopFutures` have completed. At the point the failure is encountered, all subsequent - /// `EventLoopFuture` objects will no longer be waited for. This function therefore fails fast: once - /// a failure is encountered, it will immediately fail the overall EventLoopFuture. - /// - /// - parameters: - /// - futures: An array of `EventLoopFuture` to wait for. - /// - with: A function that will be used to fold the values of two `EventLoopFuture`s and return a new value wrapped in an `EventLoopFuture`. - /// - returns: A new `EventLoopFuture` with the folded value whose callbacks run on `self.eventLoop`. - @inlinable - public func fold( - _ futures: [EventLoopFuture], - with combiningFunction: @escaping (Value, OtherValue) -> EventLoopFuture - ) -> EventLoopFuture { - self._fold(futures, with: combiningFunction) - } - @usableFromInline typealias FoldCallback = (Value, OtherValue) -> EventLoopFuture - #endif - + @inlinable func _fold( _ futures: [EventLoopFuture], @@ -1375,7 +1083,6 @@ extension EventLoopFuture { // MARK: reduce extension EventLoopFuture { - #if swift(>=5.7) /// Returns a new `EventLoopFuture` that fires only when all the provided futures complete. /// The new `EventLoopFuture` contains the result of reducing the `initialResult` with the /// values of the `[EventLoopFuture]`. @@ -1406,38 +1113,7 @@ extension EventLoopFuture { Self._reduce(initialResult, futures, on: eventLoop, nextPartialResult) } @usableFromInline typealias ReduceCallback = @Sendable (Value, InputValue) -> Value - #else - /// Returns a new `EventLoopFuture` that fires only when all the provided futures complete. - /// The new `EventLoopFuture` contains the result of reducing the `initialResult` with the - /// values of the `[EventLoopFuture]`. - /// - /// This function makes copies of the result for each EventLoopFuture, for a version which avoids - /// making copies, check out `reduce(into:)`. - /// - /// The returned `EventLoopFuture` will fail as soon as a failure is encountered in any of the - /// `futures`. However, the failure will not occur until all preceding - /// `EventLoopFutures` have completed. At the point the failure is encountered, all subsequent - /// `EventLoopFuture` objects will no longer be waited for. This function therefore fails fast: once - /// a failure is encountered, it will immediately fail the overall `EventLoopFuture`. - /// - /// - parameters: - /// - initialResult: An initial result to begin the reduction. - /// - futures: An array of `EventLoopFuture` to wait for. - /// - eventLoop: The `EventLoop` on which the new `EventLoopFuture` callbacks will fire. - /// - nextPartialResult: The bifunction used to produce partial results. - /// - returns: A new `EventLoopFuture` with the reduced value. - @inlinable - public static func reduce( - _ initialResult: Value, - _ futures: [EventLoopFuture], - on eventLoop: EventLoop, - _ nextPartialResult: @escaping (Value, InputValue) -> Value - ) -> EventLoopFuture { - Self._reduce(initialResult, futures, on: eventLoop, nextPartialResult) - } - @usableFromInline typealias ReduceCallback = (Value, InputValue) -> Value - #endif - + @inlinable static func _reduce( _ initialResult: Value, @@ -1454,7 +1130,6 @@ extension EventLoopFuture { return body } - #if swift(>=5.7) /// Returns a new `EventLoopFuture` that fires only when all the provided futures complete. /// The new `EventLoopFuture` contains the result of combining the `initialResult` with the /// values of the `[EventLoopFuture]`. This function is analogous to the standard library's @@ -1483,36 +1158,7 @@ extension EventLoopFuture { Self._reduce(into: initialResult, futures, on: eventLoop, updateAccumulatingResult) } @usableFromInline typealias ReduceIntoCallback = @Sendable (inout Value, InputValue) -> Void - #else - /// Returns a new `EventLoopFuture` that fires only when all the provided futures complete. - /// The new `EventLoopFuture` contains the result of combining the `initialResult` with the - /// values of the `[EventLoopFuture]`. This function is analogous to the standard library's - /// `reduce(into:)`, which does not make copies of the result type for each `EventLoopFuture`. - /// - /// The returned `EventLoopFuture` will fail as soon as a failure is encountered in any of the - /// `futures`. However, the failure will not occur until all preceding - /// `EventLoopFutures` have completed. At the point the failure is encountered, all subsequent - /// `EventLoopFuture` objects will no longer be waited for. This function therefore fails fast: once - /// a failure is encountered, it will immediately fail the overall `EventLoopFuture`. - /// - /// - parameters: - /// - initialResult: An initial result to begin the reduction. - /// - futures: An array of `EventLoopFuture` to wait for. - /// - eventLoop: The `EventLoop` on which the new `EventLoopFuture` callbacks will fire. - /// - updateAccumulatingResult: The bifunction used to combine partialResults with new elements. - /// - returns: A new `EventLoopFuture` with the combined value. - @inlinable - public static func reduce( - into initialResult: Value, - _ futures: [EventLoopFuture], - on eventLoop: EventLoop, - _ updateAccumulatingResult: @escaping (inout Value, InputValue) -> Void - ) -> EventLoopFuture { - Self._reduce(into: initialResult, futures, on: eventLoop, updateAccumulatingResult) - } - @usableFromInline typealias ReduceIntoCallback = (inout Value, InputValue) -> Void - #endif - + @inlinable static func _reduce( into initialResult: Value, @@ -1611,15 +1257,9 @@ extension EventLoopFuture { let reduced = eventLoop.makePromise(of: Void.self) let results: UnsafeMutableTransferBox<[Value?]> = .init(.init(repeating: nil, count: futures.count)) - #if swift(>=5.7) let callback = { @Sendable (index: Int, result: Value) in results.wrappedValue[index] = result } - #else - let callback = { (index: Int, result: Value) in - results.wrappedValue[index] = result - } - #endif if eventLoop.inEventLoop { self._reduceSuccesses0(reduced, futures, eventLoop, onValue: callback) @@ -1640,12 +1280,8 @@ extension EventLoopFuture { } } } - - #if swift(>=5.7) + @usableFromInline typealias ReduceSuccessCallback = @Sendable (Int, InputValue) -> Void - #else - @usableFromInline typealias ReduceSuccessCallback = (Int, InputValue) -> Void - #endif /// Loops through the futures array and attaches callbacks to execute `onValue` on the provided `EventLoop` when /// they succeed. The `onValue` will receive the index of the future that fulfilled the provided `Result`. /// @@ -1773,17 +1409,11 @@ extension EventLoopFuture { promise: EventLoopPromise<[Result]>) { let eventLoop = promise.futureResult.eventLoop let reduced = eventLoop.makePromise(of: Void.self) - + let results: UnsafeMutableTransferBox<[Result]> = .init(.init(repeating: .failure(OperationPlaceholderError()), count: futures.count)) - #if swift(>=5.7) let callback = { @Sendable (index: Int, result: Result) in results.wrappedValue[index] = result } - #else - let callback = { (index: Int, result: Result) in - results.wrappedValue[index] = result - } - #endif if eventLoop.inEventLoop { self._reduceCompletions0(reduced, futures, eventLoop, onResult: callback) @@ -1808,13 +1438,9 @@ extension EventLoopFuture { } } } - - #if swift(>=5.7) + @usableFromInline typealias ReduceCompletions = @Sendable (Int, Result) -> Void - #else - @usableFromInline typealias ReduceCompletions = (Int, Result) -> Void - #endif - + /// Loops through the futures array and attaches callbacks to execute `onResult` on the provided `EventLoop` when /// they complete. The `onResult` will receive the index of the future that fulfilled the provided `Result`. /// @@ -1889,7 +1515,6 @@ extension EventLoopFuture { // MARK: always extension EventLoopFuture { - #if swift(>=5.7) /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has any result. /// @@ -1902,20 +1527,7 @@ extension EventLoopFuture { self._always(callback) } @usableFromInline typealias AlwaysCallback = @Sendable (Result) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has any result. - /// - /// - parameters: - /// - callback: the callback that is called when the `EventLoopFuture` is fulfilled. - /// - returns: the current `EventLoopFuture` - @inlinable - public func always(_ callback: @escaping (Result) -> Void) -> EventLoopFuture { - self._always(callback) - } - @usableFromInline typealias AlwaysCallback = (Result) -> Void - #endif - + @inlinable func _always(_ callback: @escaping AlwaysCallback) -> EventLoopFuture { self.whenComplete { result in callback(result) } @@ -1974,8 +1586,7 @@ extension EventLoopFuture { return value } } - - #if swift(>=5.7) + /// Unwrap an `EventLoopFuture` where its type parameter is an `Optional`. /// /// Unwraps a future returning a new `EventLoopFuture` with either: the value returned by the closure passed in @@ -1998,30 +1609,7 @@ extension EventLoopFuture { self._unwrap(orElse: callback) } @usableFromInline typealias UnwrapCallback = @Sendable () -> NewValue - #else - /// Unwrap an `EventLoopFuture` where its type parameter is an `Optional`. - /// - /// Unwraps a future returning a new `EventLoopFuture` with either: the value returned by the closure passed in - /// the `orElse` parameter when the future resolved with value Optional.none, or the same value otherwise. For example: - /// ``` - /// var x = 2 - /// promise.futureResult.unwrap(orElse: { x * 2 }).wait() - /// ``` - /// - /// - parameters: - /// - orElse: a closure that returns the value of the returned `EventLoopFuture` when then resolved future's value - /// is `Optional.some()`. - /// - returns: an new `EventLoopFuture` with new type parameter `NewValue` and with the value returned by the closure - /// passed in the `orElse` parameter. - @inlinable - public func unwrap( - orElse callback: @escaping () -> NewValue - ) -> EventLoopFuture where Value == Optional { - self._unwrap(orElse: callback) - } - @usableFromInline typealias UnwrapCallback = () -> NewValue - #endif - + @inlinable func _unwrap( orElse callback: @escaping UnwrapCallback @@ -2038,7 +1626,6 @@ extension EventLoopFuture { // MARK: may block extension EventLoopFuture { - #if swift(>=5.7) /// Chain an `EventLoopFuture` providing the result of a IO / task that may block. For example: /// /// promise.futureResult.flatMapBlocking(onto: DispatchQueue.global()) { value in Int @@ -2058,27 +1645,7 @@ extension EventLoopFuture { self._flatMapBlocking(onto: queue, callbackMayBlock) } @usableFromInline typealias FlatMapBlockingCallback = @Sendable (Value) throws -> NewValue - #else - /// Chain an `EventLoopFuture` providing the result of a IO / task that may block. For example: - /// - /// promise.futureResult.flatMapBlocking(onto: DispatchQueue.global()) { value in Int - /// blockingTask(value) - /// } - /// - /// - parameters: - /// - onto: the `DispatchQueue` on which the blocking IO / task specified by `callbackMayBlock` is scheduled. - /// - callbackMayBlock: Function that will receive the value of this `EventLoopFuture` and return - /// a new `EventLoopFuture`. - @inlinable - public func flatMapBlocking( - onto queue: DispatchQueue, - _ callbackMayBlock: @escaping (Value) throws -> NewValue - ) -> EventLoopFuture { - self._flatMapBlocking(onto: queue, callbackMayBlock) - } - @usableFromInline typealias FlatMapBlockingCallback = (Value) throws -> NewValue - #endif - + @inlinable func _flatMapBlocking( onto queue: DispatchQueue, @@ -2088,7 +1655,7 @@ extension EventLoopFuture { queue.asyncWithFuture(eventLoop: self.eventLoop) { try callbackMayBlock(result) } } } - + /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has a success result. The observer callback is permitted to block. /// @@ -2106,8 +1673,7 @@ extension EventLoopFuture { queue.async { callbackMayBlock(value) } } } - - #if swift(>=5.7) + /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has a failure result. The observer callback is permitted to block. /// @@ -2125,34 +1691,14 @@ extension EventLoopFuture { self._whenFailureBlocking(onto: queue, callbackMayBlock) } @usableFromInline typealias WhenFailureBlockingCallback = @Sendable (Error) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has a failure result. The observer callback is permitted to block. - /// - /// An observer callback cannot return a value, meaning that this function cannot be chained - /// from. If you are attempting to create a computation pipeline, consider `recover` or `flatMapError`. - /// If you find yourself passing the results from this `EventLoopFuture` to a new `EventLoopPromise` - /// in the body of this function, consider using `cascade` instead. - /// - /// - parameters: - /// - onto: the `DispatchQueue` on which the blocking IO / task specified by `callbackMayBlock` is scheduled. - /// - callbackMayBlock: The callback that is called with the failed result of the `EventLoopFuture`. - @inlinable - public func whenFailureBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping (Error) -> Void) { - self._whenFailureBlocking(onto: queue, callbackMayBlock) - } - @usableFromInline typealias WhenFailureBlockingCallback = (Error) -> Void - #endif - + @inlinable func _whenFailureBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping WhenFailureBlockingCallback) { self.whenFailure { err in queue.async { callbackMayBlock(err) } } } - - #if swift(>=5.7) /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has any result. The observer callback is permitted to block. /// @@ -2165,20 +1711,7 @@ extension EventLoopFuture { self._whenCompleteBlocking(onto: queue, callbackMayBlock) } @usableFromInline typealias WhenCompleteBlocking = @Sendable (Result) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has any result. The observer callback is permitted to block. - /// - /// - parameters: - /// - onto: the `DispatchQueue` on which the blocking IO / task specified by `callbackMayBlock` is scheduled. - /// - callbackMayBlock: The callback that is called when the `EventLoopFuture` is fulfilled. - @inlinable - public func whenCompleteBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping (Result) -> Void) { - self._whenCompleteBlocking(onto: queue, callbackMayBlock) - } - @usableFromInline typealias WhenCompleteBlocking = (Result) -> Void - #endif - + @inlinable func _whenCompleteBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping WhenCompleteBlocking) { self.whenComplete { value in @@ -2283,7 +1816,6 @@ public struct _NIOEventLoopFutureIdentifier: Hashable, Sendable { self.opaqueID = _NIOEventLoopFutureIdentifier.obfuscatePointerValue(future: future) } - private static func obfuscatePointerValue(future: EventLoopFuture) -> UInt { // Note: // 1. 0xbf15ca5d is randomly picked such that it fits into both 32 and 64 bit address spaces diff --git a/Sources/NIOCore/SingleStepByteToMessageDecoder.swift b/Sources/NIOCore/SingleStepByteToMessageDecoder.swift index 5cda5a5794..2341a0550e 100644 --- a/Sources/NIOCore/SingleStepByteToMessageDecoder.swift +++ b/Sources/NIOCore/SingleStepByteToMessageDecoder.swift @@ -274,10 +274,8 @@ public final class NIOSingleStepByteToMessageProcessor=5.7) @available(*, unavailable) extension NIOSingleStepByteToMessageProcessor: Sendable {} -#endif // MARK: NIOSingleStepByteToMessageProcessor Public API extension NIOSingleStepByteToMessageProcessor { diff --git a/Sources/NIOCore/UniversalBootstrapSupport.swift b/Sources/NIOCore/UniversalBootstrapSupport.swift index 01ab959e1a..b89594a201 100644 --- a/Sources/NIOCore/UniversalBootstrapSupport.swift +++ b/Sources/NIOCore/UniversalBootstrapSupport.swift @@ -15,7 +15,6 @@ /// `NIOClientTCPBootstrapProtocol` is implemented by various underlying transport mechanisms. Typically, /// this will be the BSD Sockets API implemented by `ClientBootstrap`. public protocol NIOClientTCPBootstrapProtocol { - #if swift(>=5.7) /// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -35,28 +34,7 @@ public protocol NIOClientTCPBootstrapProtocol { /// - handler: A closure that initializes the provided `Channel`. @preconcurrency func channelInitializer(_ handler: @escaping @Sendable (Channel) -> EventLoopFuture) -> Self - #else - /// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// The connected `Channel` will operate on `ByteBuffer` as inbound and `IOData` as outbound messages. - /// - /// - warning: The `handler` closure may be invoked _multiple times_ so it's usually the right choice to instantiate - /// `ChannelHandler`s within `handler`. The reason `handler` may be invoked multiple times is that to - /// successfully set up a connection multiple connections might be setup in the process. Assuming a - /// hostname that resolves to both IPv4 and IPv6 addresses, NIO will follow - /// [_Happy Eyeballs_](https://en.wikipedia.org/wiki/Happy_Eyeballs) and race both an IPv4 and an IPv6 - /// connection. It is possible that both connections get fully established before the IPv4 connection - /// will be closed again because the IPv6 connection 'won the race'. Therefore the `channelInitializer` - /// might be called multiple times and it's important not to share stateful `ChannelHandler`s in more - /// than one `Channel`. - /// - /// - parameters: - /// - handler: A closure that initializes the provided `Channel`. - func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture) -> Self - #endif - - #if swift(>=5.7) + /// Sets the protocol handlers that will be added to the front of the `ChannelPipeline` right after the /// `channelInitializer` has been called. /// @@ -65,15 +43,6 @@ public protocol NIOClientTCPBootstrapProtocol { /// `protocolHandlers` to add the required `ChannelHandler`s for many TLS implementations. @preconcurrency func protocolHandlers(_ handlers: @escaping @Sendable () -> [ChannelHandler]) -> Self - #else - /// Sets the protocol handlers that will be added to the front of the `ChannelPipeline` right after the - /// `channelInitializer` has been called. - /// - /// Per bootstrap, you can only set the `protocolHandlers` once. Typically, `protocolHandlers` are used for the TLS - /// implementation. Most notably, `NIOClientTCPBootstrap`, NIO's "universal bootstrap" abstraction, uses - /// `protocolHandlers` to add the required `ChannelHandler`s for many TLS implementations. - func protocolHandlers(_ handlers: @escaping () -> [ChannelHandler]) -> Self - #endif /// Specifies a `ChannelOption` to be applied to the `SocketChannel`. /// @@ -81,7 +50,7 @@ public protocol NIOClientTCPBootstrapProtocol { /// - option: The option to be applied. /// - value: The value for the option. func channelOption(_ option: Option, value: Option.Value) -> Self - + /// Apply any understood convenience options to the bootstrap, removing them from the set of options if they are consumed. /// Method is optional to implement and should never be directly called by users. /// - parameters: @@ -198,7 +167,7 @@ public struct NIOClientTCPBootstrap { self.underlyingBootstrap = bootstrap self.tlsEnablerTypeErased = tlsEnabler } - + internal init(_ original : NIOClientTCPBootstrap, updating underlying : NIOClientTCPBootstrapProtocol) { self.underlyingBootstrap = underlying self.tlsEnablerTypeErased = original.tlsEnablerTypeErased diff --git a/Sources/NIOHTTP1/HTTPPipelineSetup.swift b/Sources/NIOHTTP1/HTTPPipelineSetup.swift index 251143f956..c433665afc 100644 --- a/Sources/NIOHTTP1/HTTPPipelineSetup.swift +++ b/Sources/NIOHTTP1/HTTPPipelineSetup.swift @@ -14,19 +14,11 @@ import NIOCore -#if swift(>=5.7) /// Configuration required to configure a HTTP client pipeline for upgrade. /// /// See the documentation for `HTTPClientUpgradeHandler` for details on these /// properties. public typealias NIOHTTPClientUpgradeConfiguration = (upgraders: [NIOHTTPClientProtocolUpgrader], completionHandler: @Sendable (ChannelHandlerContext) -> Void) -#else -/// Configuration required to configure a HTTP client pipeline for upgrade. -/// -/// See the documentation for `HTTPClientUpgradeHandler` for details on these -/// properties. -public typealias NIOHTTPClientUpgradeConfiguration = (upgraders: [NIOHTTPClientProtocolUpgrader], completionHandler: (ChannelHandlerContext) -> Void) -#endif /// Configuration required to configure a HTTP server pipeline for upgrade. /// @@ -35,11 +27,7 @@ public typealias NIOHTTPClientUpgradeConfiguration = (upgraders: [NIOHTTPClientP @available(*, deprecated, renamed: "NIOHTTPServerUpgradeConfiguration") public typealias HTTPUpgradeConfiguration = NIOHTTPServerUpgradeConfiguration -#if swift(>=5.7) public typealias NIOHTTPServerUpgradeConfiguration = (upgraders: [HTTPServerProtocolUpgrader], completionHandler: @Sendable (ChannelHandlerContext) -> Void) -#else -public typealias NIOHTTPServerUpgradeConfiguration = (upgraders: [HTTPServerProtocolUpgrader], completionHandler: (ChannelHandlerContext) -> Void) -#endif extension ChannelPipeline { /// Configure a `ChannelPipeline` for use as a HTTP client. @@ -56,7 +44,6 @@ extension ChannelPipeline { withClientUpgrade: nil) } - #if swift(>=5.7) /// Configure a `ChannelPipeline` for use as a HTTP client with a client upgrader configuration. /// /// - parameters: @@ -78,29 +65,7 @@ extension ChannelPipeline { withClientUpgrade: upgrade ) } - #else - /// Configure a `ChannelPipeline` for use as a HTTP client with a client upgrader configuration. - /// - /// - parameters: - /// - position: The position in the `ChannelPipeline` where to add the HTTP client handlers. Defaults to `.last`. - /// - leftOverBytesStrategy: The strategy to use when dealing with leftover bytes after removing the `HTTPDecoder` - /// from the pipeline. - /// - upgrade: Add a `HTTPClientUpgradeHandler` to the pipeline, configured for - /// HTTP upgrade. Should be a tuple of an array of `HTTPClientProtocolUpgrader` and - /// the upgrade completion handler. See the documentation on `HTTPClientUpgradeHandler` - /// for more details. - /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. - public func addHTTPClientHandlers(position: Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration?) -> EventLoopFuture { - self._addHTTPClientHandlers( - position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - withClientUpgrade: upgrade - ) - } - #endif - + private func _addHTTPClientHandlers(position: Position = .last, leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration?) -> EventLoopFuture { @@ -206,7 +171,6 @@ extension ChannelPipeline { return future } - #if swift(>=5.7) /// Configure a `ChannelPipeline` for use as a HTTP server. /// /// This function knows how to set up all first-party HTTP channel handlers appropriately @@ -244,44 +208,6 @@ extension ChannelPipeline { withErrorHandling: errorHandling ) } - #else - /// Configure a `ChannelPipeline` for use as a HTTP server. - /// - /// This function knows how to set up all first-party HTTP channel handlers appropriately - /// for server use. It supports the following features: - /// - /// 1. Providing assistance handling clients that pipeline HTTP requests, using the - /// `HTTPServerPipelineHandler`. - /// 2. Supporting HTTP upgrade, using the `HTTPServerUpgradeHandler`. - /// - /// This method will likely be extended in future with more support for other first-party - /// features. - /// - /// - parameters: - /// - position: Where in the pipeline to add the HTTP server handlers, defaults to `.last`. - /// - pipelining: Whether to provide assistance handling HTTP clients that pipeline - /// their requests. Defaults to `true`. If `false`, users will need to handle - /// clients that pipeline themselves. - /// - upgrade: Whether to add a `HTTPServerUpgradeHandler` to the pipeline, configured for - /// HTTP upgrade. Defaults to `nil`, which will not add the handler to the pipeline. If - /// provided should be a tuple of an array of `HTTPServerProtocolUpgrader` and the upgrade - /// completion handler. See the documentation on `HTTPServerUpgradeHandler` for more - /// details. - /// - errorHandling: Whether to provide assistance handling protocol errors (e.g. - /// failure to parse the HTTP request) by sending 400 errors. Defaults to `true`. - /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. - public func configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true) -> EventLoopFuture { - self._configureHTTPServerPipeline( - position: position, - withPipeliningAssistance: pipelining, - withServerUpgrade: upgrade, - withErrorHandling: errorHandling - ) - } - #endif /// Configure a `ChannelPipeline` for use as a HTTP server. /// @@ -410,7 +336,6 @@ extension ChannelPipeline { } extension ChannelPipeline.SynchronousOperations { - #if swift(>=5.7) /// Configure a `ChannelPipeline` for use as a HTTP client with a client upgrader configuration. /// /// - important: This **must** be called on the Channel's event loop. @@ -433,29 +358,6 @@ extension ChannelPipeline.SynchronousOperations { withClientUpgrade: upgrade ) } - #else - /// Configure a `ChannelPipeline` for use as a HTTP client with a client upgrader configuration. - /// - /// - important: This **must** be called on the Channel's event loop. - /// - parameters: - /// - position: The position in the `ChannelPipeline` where to add the HTTP client handlers. Defaults to `.last`. - /// - leftOverBytesStrategy: The strategy to use when dealing with leftover bytes after removing the `HTTPDecoder` - /// from the pipeline. - /// - upgrade: Add a `HTTPClientUpgradeHandler` to the pipeline, configured for - /// HTTP upgrade. Should be a tuple of an array of `HTTPClientProtocolUpgrader` and - /// the upgrade completion handler. See the documentation on `HTTPClientUpgradeHandler` - /// for more details. - /// - throws: If the pipeline could not be configured. - public func addHTTPClientHandlers(position: ChannelPipeline.Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil) throws { - try self._addHTTPClientHandlers( - position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - withClientUpgrade: upgrade - ) - } - #endif /// Configure a `ChannelPipeline` for use as a HTTP client. /// @@ -558,7 +460,6 @@ extension ChannelPipeline.SynchronousOperations { try self.addHandlers(handlers, position: position) } - #if swift(>=5.7) /// Configure a `ChannelPipeline` for use as a HTTP server. /// /// This function knows how to set up all first-party HTTP channel handlers appropriately @@ -597,45 +498,6 @@ extension ChannelPipeline.SynchronousOperations { withErrorHandling: errorHandling ) } - #else - /// Configure a `ChannelPipeline` for use as a HTTP server. - /// - /// This function knows how to set up all first-party HTTP channel handlers appropriately - /// for server use. It supports the following features: - /// - /// 1. Providing assistance handling clients that pipeline HTTP requests, using the - /// `HTTPServerPipelineHandler`. - /// 2. Supporting HTTP upgrade, using the `HTTPServerUpgradeHandler`. - /// - /// This method will likely be extended in future with more support for other first-party - /// features. - /// - /// - important: This **must** be called on the Channel's event loop. - /// - parameters: - /// - position: Where in the pipeline to add the HTTP server handlers, defaults to `.last`. - /// - pipelining: Whether to provide assistance handling HTTP clients that pipeline - /// their requests. Defaults to `true`. If `false`, users will need to handle - /// clients that pipeline themselves. - /// - upgrade: Whether to add a `HTTPServerUpgradeHandler` to the pipeline, configured for - /// HTTP upgrade. Defaults to `nil`, which will not add the handler to the pipeline. If - /// provided should be a tuple of an array of `HTTPServerProtocolUpgrader` and the upgrade - /// completion handler. See the documentation on `HTTPServerUpgradeHandler` for more - /// details. - /// - errorHandling: Whether to provide assistance handling protocol errors (e.g. - /// failure to parse the HTTP request) by sending 400 errors. Defaults to `true`. - /// - throws: If the pipeline could not be configured. - public func configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true) throws { - try self._configureHTTPServerPipeline( - position: position, - withPipeliningAssistance: pipelining, - withServerUpgrade: upgrade, - withErrorHandling: errorHandling - ) - } - #endif /// Configure a `ChannelPipeline` for use as a HTTP server. /// diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index 0e90bb85d6..7d5c2a6606 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -25,13 +25,8 @@ import struct WinSDK.DWORD import struct WinSDK.HANDLE #endif -#if swift(>=5.7) /// The type of all `channelInitializer` callbacks. internal typealias ChannelInitializerCallback = @Sendable (Channel) -> EventLoopFuture -#else -/// The type of all `channelInitializer` callbacks. -internal typealias ChannelInitializerCallback = (Channel) -> EventLoopFuture -#endif /// Common functionality for all NIO on sockets bootstraps. internal enum NIOOnSocketsBootstraps { @@ -151,7 +146,6 @@ public final class ServerBootstrap { self.enableMPTCP = false } - #if swift(>=5.7) /// Initialize the `ServerSocketChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -166,23 +160,7 @@ public final class ServerBootstrap { self.serverChannelInit = initializer return self } - #else - /// Initialize the `ServerSocketChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// The `ServerSocketChannel` uses the accepted `Channel`s as inbound messages. - /// - /// - note: To set the initializer for the accepted `SocketChannel`s, look at `ServerBootstrap.childChannelInitializer`. - /// - /// - parameters: - /// - initializer: A closure that initializes the provided `Channel`. - public func serverChannelInitializer(_ initializer: @escaping (Channel) -> EventLoopFuture) -> Self { - self.serverChannelInit = initializer - return self - } - #endif - #if swift(>=5.7) /// Initialize the accepted `SocketChannel`s with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. Note that if the `initializer` fails then the error will be /// fired in the *parent* channel. @@ -202,26 +180,6 @@ public final class ServerBootstrap { self.childChannelInit = initializer return self } - #else - /// Initialize the accepted `SocketChannel`s with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. Note that if the `initializer` fails then the error will be - /// fired in the *parent* channel. - /// - /// - warning: The `initializer` will be invoked once for every accepted connection. Therefore it's usually the - /// right choice to instantiate stateful `ChannelHandler`s within the closure to make sure they are not - /// accidentally shared across `Channel`s. There are expert use-cases where stateful handler need to be - /// shared across `Channel`s in which case the user is responsible to synchronise the state access - /// appropriately. - /// - /// The accepted `Channel` will operate on `ByteBuffer` as inbound and `IOData` as outbound messages. - /// - /// - parameters: - /// - initializer: A closure that initializes the provided `Channel`. - public func childChannelInitializer(_ initializer: @escaping (Channel) -> EventLoopFuture) -> Self { - self.childChannelInit = initializer - return self - } - #endif /// Specifies a `ChannelOption` to be applied to the `ServerSocketChannel`. /// @@ -738,11 +696,7 @@ private extension Channel { /// The connected `SocketChannel` will operate on `ByteBuffer` as inbound and on `IOData` as outbound messages. public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { private let group: EventLoopGroup - #if swift(>=5.7) private var protocolHandlers: Optional<@Sendable () -> [ChannelHandler]> - #else - private var protocolHandlers: Optional<() -> [ChannelHandler]> - #endif private var _channelInitializer: ChannelInitializerCallback private var channelInitializer: ChannelInitializerCallback { if let protocolHandlers = self.protocolHandlers { @@ -798,7 +752,6 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { self.enableMPTCP = false } - #if swift(>=5.7) /// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -821,31 +774,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { self._channelInitializer = handler return self } - #else - /// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// The connected `Channel` will operate on `ByteBuffer` as inbound and `IOData` as outbound messages. - /// - /// - warning: The `handler` closure may be invoked _multiple times_ so it's usually the right choice to instantiate - /// `ChannelHandler`s within `handler`. The reason `handler` may be invoked multiple times is that to - /// successfully set up a connection multiple connections might be setup in the process. Assuming a - /// hostname that resolves to both IPv4 and IPv6 addresses, NIO will follow - /// [_Happy Eyeballs_](https://en.wikipedia.org/wiki/Happy_Eyeballs) and race both an IPv4 and an IPv6 - /// connection. It is possible that both connections get fully established before the IPv4 connection - /// will be closed again because the IPv6 connection 'won the race'. Therefore the `channelInitializer` - /// might be called multiple times and it's important not to share stateful `ChannelHandler`s in more - /// than one `Channel`. - /// - /// - parameters: - /// - handler: A closure that initializes the provided `Channel`. - public func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture) -> Self { - self._channelInitializer = handler - return self - } - #endif - #if swift(>=5.7) /// Sets the protocol handlers that will be added to the front of the `ChannelPipeline` right after the /// `channelInitializer` has been called. /// @@ -858,19 +787,6 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { self.protocolHandlers = handlers return self } - #else - /// Sets the protocol handlers that will be added to the front of the `ChannelPipeline` right after the - /// `channelInitializer` has been called. - /// - /// Per bootstrap, you can only set the `protocolHandlers` once. Typically, `protocolHandlers` are used for the TLS - /// implementation. Most notably, `NIOClientTCPBootstrap`, NIO's "universal bootstrap" abstraction, uses - /// `protocolHandlers` to add the required `ChannelHandler`s for many TLS implementations. - public func protocolHandlers(_ handlers: @escaping () -> [ChannelHandler]) -> Self { - precondition(self.protocolHandlers == nil, "protocol handlers can only be set once") - self.protocolHandlers = handlers - return self - } - #endif /// Specifies a `ChannelOption` to be applied to the `SocketChannel`. /// @@ -1435,7 +1351,6 @@ public final class DatagramBootstrap { self.channelInitializer = nil } - #if swift(>=5.7) /// Initialize the bound `DatagramChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -1446,17 +1361,6 @@ public final class DatagramBootstrap { self.channelInitializer = handler return self } - #else - /// Initialize the bound `DatagramChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// - parameters: - /// - handler: A closure that initializes the provided `Channel`. - public func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture) -> Self { - self.channelInitializer = handler - return self - } - #endif /// Specifies a `ChannelOption` to be applied to the `DatagramChannel`. /// @@ -1998,7 +1902,6 @@ public final class NIOPipeBootstrap { self.channelInitializer = nil } - #if swift(>=5.7) /// Initialize the connected `PipeChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -2012,20 +1915,6 @@ public final class NIOPipeBootstrap { self.channelInitializer = handler return self } - #else - /// Initialize the connected `PipeChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// The connected `Channel` will operate on `ByteBuffer` as inbound and outbound messages. Please note that - /// `IOData.fileRegion` is _not_ supported for `PipeChannel`s because `sendfile` only works on sockets. - /// - /// - parameters: - /// - handler: A closure that initializes the provided `Channel`. - public func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture) -> Self { - self.channelInitializer = handler - return self - } - #endif /// Specifies a `ChannelOption` to be applied to the `PipeChannel`. /// @@ -2207,7 +2096,7 @@ extension NIOPipeBootstrap { throw error } } - + /// Create the `PipeChannel` with the provided input and output file descriptors. /// /// The input and output file descriptors must be distinct. If you have a single file descriptor, consider using @@ -2240,7 +2129,7 @@ extension NIOPipeBootstrap { postRegisterTransformation: { $0.makeSucceededFuture($1) } ) } - + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @_spi(AsyncChannel) // Should become private public func _takingOwnershipOfDescriptors( diff --git a/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift b/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift index f9214a5f91..8c85a873f3 100644 --- a/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift +++ b/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift @@ -53,12 +53,8 @@ typealias ThreadInitializer = (NIOThread) -> Void /// test. A good place to start a `MultiThreadedEventLoopGroup` is the `setUp` method of your `XCTestCase` /// subclass, a good place to shut it down is the `tearDown` method. public final class MultiThreadedEventLoopGroup: EventLoopGroup { - #if swift(>=5.7) private typealias ShutdownGracefullyCallback = @Sendable (Error?) -> Void - #else - private typealias ShutdownGracefullyCallback = (Error?) -> Void - #endif - + private enum RunState { case running case closing([(DispatchQueue, ShutdownGracefullyCallback)]) @@ -262,7 +258,6 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { } } - #if swift(>=5.7) /// Shut this `MultiThreadedEventLoopGroup` down which causes the `EventLoop`s and their associated threads to be /// shut down and release their resources. /// @@ -277,22 +272,7 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { public func shutdownGracefully(queue: DispatchQueue, _ handler: @escaping @Sendable (Error?) -> Void) { self._shutdownGracefully(queue: queue, handler) } - #else - /// Shut this `MultiThreadedEventLoopGroup` down which causes the `EventLoop`s and their associated threads to be - /// shut down and release their resources. - /// - /// Even though calling `shutdownGracefully` more than once should be avoided, it is safe to do so and execution - /// of the `handler` is guaranteed. - /// - /// - parameters: - /// - queue: The `DispatchQueue` to run `handler` on when the shutdown operation completes. - /// - handler: The handler which is called after the shutdown operation completes. The parameter will be `nil` - /// on success and contain the `Error` otherwise. - public func shutdownGracefully(queue: DispatchQueue, _ handler: @escaping (Error?) -> Void) { - self._shutdownGracefully(queue: queue, handler) - } - #endif - + private func _shutdownGracefully(queue: DispatchQueue, _ handler: @escaping ShutdownGracefullyCallback) { guard self.canBeShutDown else { queue.async { diff --git a/Sources/NIOPosix/NIOThreadPool.swift b/Sources/NIOPosix/NIOThreadPool.swift index b57c699bec..e6b631d1aa 100644 --- a/Sources/NIOPosix/NIOThreadPool.swift +++ b/Sources/NIOPosix/NIOThreadPool.swift @@ -18,7 +18,7 @@ import NIOConcurrencyHelpers /// Errors that may be thrown when executing work on a `NIOThreadPool` public enum NIOThreadPoolError { - + /// The `NIOThreadPool` was not active. public struct ThreadPoolInactive: Error { public init() {} @@ -55,14 +55,9 @@ public final class NIOThreadPool { /// The `WorkItem` was cancelled and will not be processed by the `NIOThreadPool`. case cancelled } - - #if swift(>=5.7) + /// The work that should be done by the `NIOThreadPool`. public typealias WorkItem = @Sendable (WorkItemState) -> Void - #else - /// The work that should be done by the `NIOThreadPool`. - public typealias WorkItem = (WorkItemState) -> Void - #endif private enum State { /// The `NIOThreadPool` is already stopped. case stopped @@ -78,7 +73,6 @@ public final class NIOThreadPool { private let numberOfThreads: Int private let canBeStopped: Bool - #if swift(>=5.7) /// Gracefully shutdown this `NIOThreadPool`. All tasks will be run before shutdown will take place. /// /// - parameters: @@ -88,17 +82,7 @@ public final class NIOThreadPool { public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping @Sendable (Error?) -> Void) { self._shutdownGracefully(queue: queue, callback) } - #else - /// Gracefully shutdown this `NIOThreadPool`. All tasks will be run before shutdown will take place. - /// - /// - parameters: - /// - queue: The `DispatchQueue` used to executed the callback - /// - callback: The function to be executed once the shutdown is complete. - public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { - self._shutdownGracefully(queue: queue, callback) - } - #endif - + private func _shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { guard self.canBeStopped else { queue.async { @@ -135,10 +119,9 @@ public final class NIOThreadPool { callback(nil) } } - - - #if swift(>=5.7) + + /// Submit a `WorkItem` to process. /// /// - note: This is a low-level method, in most cases the `runIfActive` method should be used. @@ -149,17 +132,6 @@ public final class NIOThreadPool { public func submit(_ body: @escaping WorkItem) { self._submit(body) } - #else - /// Submit a `WorkItem` to process. - /// - /// - note: This is a low-level method, in most cases the `runIfActive` method should be used. - /// - /// - parameters: - /// - body: The `WorkItem` to process by the `NIOThreadPool`. - public func submit(_ body: @escaping WorkItem) { - self._submit(body) - } - #endif private func _submit(_ body: @escaping WorkItem) { let item = self.lock.withLock { () -> WorkItem? in @@ -176,7 +148,7 @@ public final class NIOThreadPool { /* if item couldn't be added run it immediately indicating that it couldn't be run */ item.map { $0(.cancelled) } } - + /// Initialize a `NIOThreadPool` thread pool with `numberOfThreads` threads. /// /// - parameters: @@ -290,8 +262,7 @@ public final class NIOThreadPool { extension NIOThreadPool: @unchecked Sendable {} extension NIOThreadPool { - - #if swift(>=5.7) + /// Runs the submitted closure if the thread pool is still active, otherwise fails the promise. /// The closure will be run on the thread pool so can do blocking work. /// @@ -303,19 +274,7 @@ extension NIOThreadPool { public func runIfActive(eventLoop: EventLoop, _ body: @escaping @Sendable () throws -> T) -> EventLoopFuture { self._runIfActive(eventLoop: eventLoop, body) } - #else - /// Runs the submitted closure if the thread pool is still active, otherwise fails the promise. - /// The closure will be run on the thread pool so can do blocking work. - /// - /// - parameters: - /// - eventLoop: The `EventLoop` the returned `EventLoopFuture` will fire on. - /// - body: The closure which performs some blocking work to be done on the thread pool. - /// - returns: The `EventLoopFuture` of `promise` fulfilled with the result (or error) of the passed closure. - public func runIfActive(eventLoop: EventLoop, _ body: @escaping () throws -> T) -> EventLoopFuture { - self._runIfActive(eventLoop: eventLoop, body) - } - #endif - + private func _runIfActive(eventLoop: EventLoop, _ body: @escaping () throws -> T) -> EventLoopFuture { let promise = eventLoop.makePromise(of: T.self) self.submit { shouldRun in @@ -334,16 +293,10 @@ extension NIOThreadPool { } extension NIOThreadPool { - #if swift(>=5.7) @preconcurrency public func shutdownGracefully(_ callback: @escaping @Sendable (Error?) -> Void) { self.shutdownGracefully(queue: .global(), callback) } - #else - public func shutdownGracefully(_ callback: @escaping (Error?) -> Void) { - self.shutdownGracefully(queue: .global(), callback) - } - #endif /// Shuts down the thread pool gracefully. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @@ -360,16 +313,10 @@ extension NIOThreadPool { } } - #if swift(>=5.7) @available(*, noasync, message: "this can end up blocking the calling thread", renamed: "shutdownGracefully()") public func syncShutdownGracefully() throws { try self._syncShutdownGracefully() } - #else - public func syncShutdownGracefully() throws { - try self._syncShutdownGracefully() - } - #endif private func _syncShutdownGracefully() throws { let errorStorageLock = NIOLock() diff --git a/Sources/NIOPosix/NonBlockingFileIO.swift b/Sources/NIOPosix/NonBlockingFileIO.swift index 7fe20193e3..64daf5f128 100644 --- a/Sources/NIOPosix/NonBlockingFileIO.swift +++ b/Sources/NIOPosix/NonBlockingFileIO.swift @@ -52,8 +52,7 @@ public struct NonBlockingFileIO: Sendable { public init(threadPool: NIOThreadPool) { self.threadPool = threadPool } - - #if swift(>=5.7) + /// Read a `FileRegion` in chunks of `chunkSize` bytes on `NonBlockingFileIO`'s private thread /// pool which is separate from any `EventLoop` thread. /// @@ -89,44 +88,7 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop, chunkHandler: chunkHandler) } - #else - /// Read a `FileRegion` in chunks of `chunkSize` bytes on `NonBlockingFileIO`'s private thread - /// pool which is separate from any `EventLoop` thread. - /// - /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `fileRegion.readableBytes` is greater than - /// zero and there are enough bytes available `chunkHandler` will be called `1 + |_ fileRegion.readableBytes / chunkSize _|` - /// times, delivering `chunkSize` bytes each time. If less than `fileRegion.readableBytes` bytes can be read from the file, - /// `chunkHandler` will be called less often with the last invocation possibly being of less than `chunkSize` bytes. - /// - /// The allocation and reading of a subsequent chunk will only be attempted when `chunkHandler` succeeds. - /// - /// This method will not use the file descriptor's seek pointer which means there is no danger of reading from the - /// same `FileRegion` in multiple threads. - /// - /// - parameters: - /// - fileRegion: The file region to read. - /// - chunkSize: The size of the individual chunks to deliver. - /// - allocator: A `ByteBufferAllocator` used to allocate space for the chunks. - /// - eventLoop: The `EventLoop` to call `chunkHandler` on. - /// - chunkHandler: Called for every chunk read. The next chunk will be read upon successful completion of the returned `EventLoopFuture`. If the returned `EventLoopFuture` fails, the overall operation is aborted. - /// - returns: An `EventLoopFuture` which is the result of the overall operation. If either the reading of `fileHandle` or `chunkHandler` fails, the `EventLoopFuture` will fail too. If the reading of `fileHandle` as well as `chunkHandler` always succeeded, the `EventLoopFuture` will succeed too. - public func readChunked(fileRegion: FileRegion, - chunkSize: Int = NonBlockingFileIO.defaultChunkSize, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, - chunkHandler: @escaping (ByteBuffer) -> EventLoopFuture) -> EventLoopFuture { - let readableBytes = fileRegion.readableBytes - return self.readChunked(fileHandle: fileRegion.fileHandle, - fromOffset: Int64(fileRegion.readerIndex), - byteCount: readableBytes, - chunkSize: chunkSize, - allocator: allocator, - eventLoop: eventLoop, - chunkHandler: chunkHandler) - } - #endif - #if swift(>=5.7) /// Read `byteCount` bytes in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread /// pool which is separate from any `EventLoop` thread. /// @@ -165,47 +127,7 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop, chunkHandler: chunkHandler) } - #else - /// Read `byteCount` bytes in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread - /// pool which is separate from any `EventLoop` thread. - /// - /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `byteCount` is greater than - /// zero and there are enough bytes available `chunkHandler` will be called `1 + |_ byteCount / chunkSize _|` - /// times, delivering `chunkSize` bytes each time. If less than `byteCount` bytes can be read from `descriptor`, - /// `chunkHandler` will be called less often with the last invocation possibly being of less than `chunkSize` bytes. - /// - /// The allocation and reading of a subsequent chunk will only be attempted when `chunkHandler` succeeds. - /// - /// - note: `readChunked(fileRegion:chunkSize:allocator:eventLoop:chunkHandler:)` should be preferred as it uses - /// `FileRegion` object instead of raw `NIOFileHandle`s. In case you do want to use raw `NIOFileHandle`s, - /// please consider using `readChunked(fileHandle:fromOffset:chunkSize:allocator:eventLoop:chunkHandler:)` - /// because it doesn't use the file descriptor's seek pointer (which may be shared with other file - /// descriptors and even across processes.) - /// - /// - parameters: - /// - fileHandle: The `NIOFileHandle` to read from. - /// - byteCount: The number of bytes to read from `fileHandle`. - /// - chunkSize: The size of the individual chunks to deliver. - /// - allocator: A `ByteBufferAllocator` used to allocate space for the chunks. - /// - eventLoop: The `EventLoop` to call `chunkHandler` on. - /// - chunkHandler: Called for every chunk read. The next chunk will be read upon successful completion of the returned `EventLoopFuture`. If the returned `EventLoopFuture` fails, the overall operation is aborted. - /// - returns: An `EventLoopFuture` which is the result of the overall operation. If either the reading of `fileHandle` or `chunkHandler` fails, the `EventLoopFuture` will fail too. If the reading of `fileHandle` as well as `chunkHandler` always succeeded, the `EventLoopFuture` will succeed too. - public func readChunked(fileHandle: NIOFileHandle, - byteCount: Int, - chunkSize: Int = NonBlockingFileIO.defaultChunkSize, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, chunkHandler: @escaping (ByteBuffer) -> EventLoopFuture) -> EventLoopFuture { - return self.readChunked0(fileHandle: fileHandle, - fromOffset: nil, - byteCount: byteCount, - chunkSize: chunkSize, - allocator: allocator, - eventLoop: eventLoop, - chunkHandler: chunkHandler) - } - #endif - #if swift(>=5.7) /// Read `byteCount` bytes from offset `fileOffset` in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread /// pool which is separate from any `EventLoop` thread. /// @@ -246,53 +168,8 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop, chunkHandler: chunkHandler) } - #else - /// Read `byteCount` bytes from offset `fileOffset` in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread - /// pool which is separate from any `EventLoop` thread. - /// - /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `byteCount` is greater than - /// zero and there are enough bytes available `chunkHandler` will be called `1 + |_ byteCount / chunkSize _|` - /// times, delivering `chunkSize` bytes each time. If less than `byteCount` bytes can be read from `descriptor`, - /// `chunkHandler` will be called less often with the last invocation possibly being of less than `chunkSize` bytes. - /// - /// The allocation and reading of a subsequent chunk will only be attempted when `chunkHandler` succeeds. - /// - /// This method will not use the file descriptor's seek pointer which means there is no danger of reading from the - /// same `NIOFileHandle` in multiple threads. - /// - /// - note: `readChunked(fileRegion:chunkSize:allocator:eventLoop:chunkHandler:)` should be preferred as it uses - /// `FileRegion` object instead of raw `NIOFileHandle`s. - /// - /// - parameters: - /// - fileHandle: The `NIOFileHandle` to read from. - /// - byteCount: The number of bytes to read from `fileHandle`. - /// - chunkSize: The size of the individual chunks to deliver. - /// - allocator: A `ByteBufferAllocator` used to allocate space for the chunks. - /// - eventLoop: The `EventLoop` to call `chunkHandler` on. - /// - chunkHandler: Called for every chunk read. The next chunk will be read upon successful completion of the returned `EventLoopFuture`. If the returned `EventLoopFuture` fails, the overall operation is aborted. - /// - returns: An `EventLoopFuture` which is the result of the overall operation. If either the reading of `fileHandle` or `chunkHandler` fails, the `EventLoopFuture` will fail too. If the reading of `fileHandle` as well as `chunkHandler` always succeeded, the `EventLoopFuture` will succeed too. - public func readChunked(fileHandle: NIOFileHandle, - fromOffset fileOffset: Int64, - byteCount: Int, - chunkSize: Int = NonBlockingFileIO.defaultChunkSize, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, - chunkHandler: @escaping (ByteBuffer) -> EventLoopFuture) -> EventLoopFuture { - return self.readChunked0(fileHandle: fileHandle, - fromOffset: fileOffset, - byteCount: byteCount, - chunkSize: chunkSize, - allocator: allocator, - eventLoop: eventLoop, - chunkHandler: chunkHandler) - } - #endif - - #if swift(>=5.7) + private typealias ReadChunkHandler = @Sendable (ByteBuffer) -> EventLoopFuture - #else - private typealias ReadChunkHandler = (ByteBuffer) -> EventLoopFuture - #endif private func readChunked0(fileHandle: NIOFileHandle, fromOffset: Int64?, @@ -303,7 +180,7 @@ public struct NonBlockingFileIO: Sendable { precondition(chunkSize > 0, "chunkSize must be > 0 (is \(chunkSize))") let remainingReads = 1 + (byteCount / chunkSize) let lastReadSize = byteCount % chunkSize - + let promise = eventLoop.makePromise(of: Void.self) func _read(remainingReads: Int, bytesReadSoFar: Int64) { diff --git a/Sources/NIOPosix/Thread.swift b/Sources/NIOPosix/Thread.swift index be392dec4d..68b486c72a 100644 --- a/Sources/NIOPosix/Thread.swift +++ b/Sources/NIOPosix/Thread.swift @@ -17,7 +17,7 @@ import CNIOLinux #endif enum LowLevelThreadOperations { - + } protocol ThreadOps { @@ -185,8 +185,7 @@ public final class ThreadSpecificVariable { self.currentValue = value } - - #if swift(>=5.7) + /// The value for the current thread. @available(*, noasync, message: "threads can change between suspension points and therefore the thread specific value too") public var currentValue: Value? { @@ -197,18 +196,7 @@ public final class ThreadSpecificVariable { self.set(newValue) } } - #else - /// The value for the current thread. - public var currentValue: Value? { - get { - self.get() - } - set { - self.set(newValue) - } - } - #endif - + /// Get the current value for the calling thread. func get() -> Value? { guard let raw = self.key.get() else { return nil } @@ -218,7 +206,7 @@ public final class ThreadSpecificVariable { .takeUnretainedValue() .value.1 as! Value) } - + /// Set the current value for the calling threads. The `currentValue` for all other threads remains unchanged. func set(_ newValue: Value?) { if let raw = self.key.get() { diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 4d1f77f6a9..9a245edbbd 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -64,15 +64,10 @@ fileprivate extension HTTPHeaders { public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unchecked Sendable { // This type *is* Sendable but we can't express that properly until Swift 5.7. In the meantime // the conformance is `@unchecked`. - - #if swift(>=5.7) + // FIXME: remove @unchecked when 5.7 is the minimum supported version. private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture - #else - private typealias ShouldUpgrade = (Channel, HTTPRequestHead) -> EventLoopFuture - private typealias UpgradePipelineHandler = (Channel, HTTPRequestHead) -> EventLoopFuture - #endif /// RFC 6455 specs this as the required entry in the Upgrade header. public let supportedProtocol: String = "websocket" @@ -86,7 +81,6 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch private let maxFrameSize: Int private let automaticErrorHandling: Bool - #if swift(>=5.7) /// Create a new `NIOWebSocketServerUpgrader`. /// /// - parameters: @@ -112,34 +106,7 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch self.init(maxFrameSize: 1 << 14, automaticErrorHandling: automaticErrorHandling, shouldUpgrade: shouldUpgrade, upgradePipelineHandler: upgradePipelineHandler) } - #else - /// Create a new `NIOWebSocketServerUpgrader`. - /// - /// - parameters: - /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol - /// errors by sending error responses and closing the connection. Defaults to `true`, - /// may be set to `false` if the user wishes to handle their own errors. - /// - shouldUpgrade: A callback that determines whether the websocket request should be - /// upgraded. This callback is responsible for creating a `HTTPHeaders` object with - /// any headers that it needs on the response *except for* the `Upgrade`, `Connection`, - /// and `Sec-WebSocket-Accept` headers, which this upgrader will handle. Should return - /// an `EventLoopFuture` containing `nil` if the upgrade should be refused. - /// - upgradePipelineHandler: A function that will be called once the upgrade response is - /// flushed, and that is expected to mutate the `Channel` appropriately to handle the - /// websocket protocol. This only needs to add the user handlers: the - /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the - /// pipeline automatically. - public convenience init( - automaticErrorHandling: Bool = true, - shouldUpgrade: @escaping (Channel, HTTPRequestHead) -> EventLoopFuture, - upgradePipelineHandler: @escaping (Channel, HTTPRequestHead) -> EventLoopFuture - ) { - self.init(maxFrameSize: 1 << 14, automaticErrorHandling: automaticErrorHandling, - shouldUpgrade: shouldUpgrade, upgradePipelineHandler: upgradePipelineHandler) - } - #endif - #if swift(>=5.7) /// Create a new `NIOWebSocketServerUpgrader`. /// /// - parameters: @@ -174,42 +141,7 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch upgradePipelineHandler: upgradePipelineHandler ) } - #else - /// Create a new `NIOWebSocketServerUpgrader`. - /// - /// - parameters: - /// - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the - /// remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but - /// this is an objectively unreasonable maximum value (on AMD64 systems it is not - /// possible to even. Users may set this to any value up to `UInt32.max`. - /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol - /// errors by sending error responses and closing the connection. Defaults to `true`, - /// may be set to `false` if the user wishes to handle their own errors. - /// - shouldUpgrade: A callback that determines whether the websocket request should be - /// upgraded. This callback is responsible for creating a `HTTPHeaders` object with - /// any headers that it needs on the response *except for* the `Upgrade`, `Connection`, - /// and `Sec-WebSocket-Accept` headers, which this upgrader will handle. Should return - /// an `EventLoopFuture` containing `nil` if the upgrade should be refused. - /// - upgradePipelineHandler: A function that will be called once the upgrade response is - /// flushed, and that is expected to mutate the `Channel` appropriately to handle the - /// websocket protocol. This only needs to add the user handlers: the - /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the - /// pipeline automatically. - public convenience init( - maxFrameSize: Int, - automaticErrorHandling: Bool = true, - shouldUpgrade: @escaping (Channel, HTTPRequestHead) -> EventLoopFuture, - upgradePipelineHandler: @escaping (Channel, HTTPRequestHead) -> EventLoopFuture - ) { - self.init( - _maxFrameSize: maxFrameSize, - automaticErrorHandling: automaticErrorHandling, - shouldUpgrade: shouldUpgrade, - upgradePipelineHandler: upgradePipelineHandler - ) - } - #endif - + private init( _maxFrameSize maxFrameSize: Int, automaticErrorHandling: Bool, diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index d9821d6a5d..4f0fdc64e8 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -83,11 +83,7 @@ extension EmbeddedChannel { } } -#if swift(>=5.7) private typealias UpgradeCompletionHandler = @Sendable (ChannelHandlerContext) -> Void -#else -private typealias UpgradeCompletionHandler = (ChannelHandlerContext) -> Void -#endif private func serverHTTPChannelWithAutoremoval(group: EventLoopGroup, pipelining: Bool, diff --git a/docker/docker-compose.2004.56.yaml b/docker/docker-compose.2204.510.yaml similarity index 71% rename from docker/docker-compose.2004.56.yaml rename to docker/docker-compose.2204.510.yaml index f26cac2fec..ef1732e29c 100644 --- a/docker/docker-compose.2004.56.yaml +++ b/docker/docker-compose.2204.510.yaml @@ -3,28 +3,27 @@ version: "3" services: runtime-setup: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 build: args: - ubuntu_version: "focal" - swift_version: "5.6" + base_image: "swiftlang/swift:nightly-5.10-jammy" unit-tests: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 integration-tests: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 documentation-check: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 test: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 environment: - - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=22 + - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=21 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 - - MAX_ALLOCS_ALLOWED_1000_addHandlers_sync=39050 + - MAX_ALLOCS_ALLOWED_1000_addHandlers_sync=38050 - MAX_ALLOCS_ALLOWED_1000_addRemoveHandlers_handlercontext=8050 - MAX_ALLOCS_ALLOWED_1000_addRemoveHandlers_handlername=8050 - MAX_ALLOCS_ALLOWED_1000_addRemoveHandlers_handlertype=8050 @@ -35,13 +34,13 @@ services: - MAX_ALLOCS_ALLOWED_1000_getHandlers=8050 - MAX_ALLOCS_ALLOWED_1000_getHandlers_sync=36 - MAX_ALLOCS_ALLOWED_1000_reqs_1_conn=26400 - - MAX_ALLOCS_ALLOWED_1000_rst_connections=151050 - - MAX_ALLOCS_ALLOWED_1000_tcpbootstraps=4050 - - MAX_ALLOCS_ALLOWED_1000_tcpconnections=156050 + - MAX_ALLOCS_ALLOWED_1000_rst_connections=149050 + - MAX_ALLOCS_ALLOWED_1000_tcpbootstraps=3050 + - MAX_ALLOCS_ALLOWED_1000_tcpconnections=157050 - MAX_ALLOCS_ALLOWED_1000_udp_reqs=6050 - MAX_ALLOCS_ALLOWED_1000_udpbootstraps=2050 - MAX_ALLOCS_ALLOWED_1000_udpconnections=77050 - - MAX_ALLOCS_ALLOWED_1_reqs_1000_conn=404000 + - MAX_ALLOCS_ALLOWED_1_reqs_1000_conn=396000 - MAX_ALLOCS_ALLOWED_bytebuffer_lots_of_rw=2050 - MAX_ALLOCS_ALLOWED_creating_10000_headers=0 - MAX_ALLOCS_ALLOWED_decode_1000_ws_frames=2050 @@ -61,26 +60,27 @@ services: - MAX_ALLOCS_ALLOWED_get_100000_headers_canonical_form_trimming_whitespace_from_long_string=700050 - MAX_ALLOCS_ALLOWED_get_100000_headers_canonical_form_trimming_whitespace_from_short_string=700050 - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=2050 - - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=339 + - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 + - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=343 - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 - - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=60100 - - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=60050 - - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=86 + - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 + - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 + - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 - MAX_ALLOCS_ALLOWED_udp_1000_reqs_1_conn=6200 - - MAX_ALLOCS_ALLOWED_udp_1_reqs_1000_conn=161050 + - MAX_ALLOCS_ALLOWED_udp_1_reqs_1000_conn=167050 - FORCE_TEST_DISCOVERY=--enable-test-discovery - WARN_AS_ERROR_ARG=-Xswiftc -warnings-as-errors + - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error # - SANITIZER_ARG=--sanitize=thread # TSan broken still performance-test: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 shell: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 echo: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 http: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 diff --git a/docker/docker-compose.2204.59.yaml b/docker/docker-compose.2204.59.yaml index baf70a9674..357536253d 100644 --- a/docker/docker-compose.2204.59.yaml +++ b/docker/docker-compose.2204.59.yaml @@ -6,7 +6,8 @@ services: image: swift-nio:22.04-5.9 build: args: - base_image: "swiftlang/swift:nightly-5.9-jammy" + ubuntu_version: "jammy" + swift_version: "5.9" unit-tests: image: swift-nio:22.04-5.9 From de07e573d6fbce99cc72a558ce61dbb1ccddb205 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Mon, 2 Oct 2023 15:26:16 +0100 Subject: [PATCH 03/64] Introduce new typed `HTTPServerUpgrader` and `WebSocketServerUpgrader` (#2517) * Introduce new typed `HTTPServerUpgrader` and `WebSocketServerUpgrader` # Motivation With our new `NIOAsyncChannel` and typed bootstrap APIs we want to be able to let users spell out their pipeline in a typed way. Pipelines can contain handlers that have to make a forking decision such as HTTP upgrading. Our current `HTTPServerUpgradeHandler` is one of those handlers but it lacks strict typing. To interact nicely with our new typed APIs we need to have a new variant of the `HTTPServerUpgradeHandler` that can carry type information. # Modification This PR adds a few things: 1. A new `NIOTypedHTTPServerUpgradeHandler` + `NIOTypedHTTPServeProtocolUpgrader`. I also moved the state handling logic to a separate state machine. I thought about unifying the state machines of the _old_ handler and the new one but they differ in behaviour which makes the state machine more complicated. 2. A new `NIOTypedWebSocketServerUpgrader` that conforms to `NIOTypedHTTPServerProtocolUpgrader` 3. An overhauled WebSocket server example that fully uses Concurrency. # Result We now have a way to fully type the server side of HTTP protocol upgrading. Code review Update parameter names for new API and fix example * Introduce new configuration struct and rename to `UpgradablePipeline` * Review comments --- Package.swift | 1 + Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift | 148 +++++ .../NIOTypedHTTPServerUpgradeHandler.swift | 373 +++++++++++ ...OTypedHTTPServerUpgraderStateMachine.swift | 385 +++++++++++ .../NIOWebSocketServerUpgrader.swift | 182 ++++- Sources/NIOWebSocketClient/main.swift | 2 +- Sources/NIOWebSocketServer/Server.swift | 284 ++++++++ Sources/NIOWebSocketServer/main.swift | 282 -------- .../HTTPServerUpgradeTests.swift | 623 ++++++++++++++++-- .../WebSocketServerEndToEndTests.swift | 88 ++- 10 files changed, 1986 insertions(+), 382 deletions(-) create mode 100644 Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift create mode 100644 Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift create mode 100644 Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift create mode 100644 Sources/NIOWebSocketServer/Server.swift delete mode 100644 Sources/NIOWebSocketServer/main.swift diff --git a/Package.swift b/Package.swift index 4ba1b85842..d0b61c3fc5 100644 --- a/Package.swift +++ b/Package.swift @@ -140,6 +140,7 @@ let package = Package( "NIOCore", "NIOConcurrencyHelpers", "CNIOLLHTTP", + swiftCollections ] ), .target( diff --git a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift new file mode 100644 index 0000000000..c92e41715c --- /dev/null +++ b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift @@ -0,0 +1,148 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +@_spi(AsyncChannel) import NIOCore + +// MARK: - Server pipeline configuration + +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +@_spi(AsyncChannel) +public struct NIOUpgradableHTTPServerPipelineConfiguration { + /// Whether to provide assistance handling HTTP clients that pipeline + /// their requests. Defaults to `true`. If `false`, users will need to handle clients that pipeline themselves. + public var enablePipelining = true + + /// Whether to provide assistance handling protocol errors (e.g. failure to parse the HTTP + /// request) by sending 400 errors. Defaults to `true`. + public var enableErrorHandling = true + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableResponseHeaderValidation = true + + /// The configuration for the ``HTTPResponseEncoder``. + public var httpResponseEncoderConfiguration = HTTPResponseEncoder.Configuration() + + /// The configuration for the ``NIOTypedHTTPServerUpgradeHandler``. + public var upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPServerPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Assistance handling clients that pipeline HTTP requests. + /// 2. Assistance handling protocol errors. + /// 3. Outbound header fields validation to protect against response splitting attacks. + /// 4. HTTP protocol upgrades. + /// + /// The defaults will likely be extended in the future and we recommend to use this initializer to ensure + /// you get newer features automatically. + public init( + upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration + ) { + self.upgradeConfiguration = upgradeConfiguration + } +} + +extension ChannelPipeline { + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + @_spi(AsyncChannel) + public func configureUpgradableHTTPServerPipeline( + configuration: NIOUpgradableHTTPServerPipelineConfiguration + ) -> EventLoopFuture> { + self._configureUpgradableHTTPServerPipeline( + configuration: configuration + ) + } + + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + private func _configureUpgradableHTTPServerPipeline( + configuration: NIOUpgradableHTTPServerPipelineConfiguration + ) -> EventLoopFuture> { + let future: EventLoopFuture> + + if self.eventLoop.inEventLoop { + let result = Result, Error> { + try self.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: configuration + ) + } + future = self.eventLoop.makeCompletedFuture(result) + } else { + future = self.eventLoop.submit { + try self.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: configuration + ) + } + } + + return future + } +} + +extension ChannelPipeline.SynchronousOperations { + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + @_spi(AsyncChannel) + public func configureUpgradableHTTPServerPipeline( + configuration: NIOUpgradableHTTPServerPipelineConfiguration + ) throws -> EventLoopFuture { + self.eventLoop.assertInEventLoop() + + let responseEncoder = HTTPResponseEncoder(configuration: configuration.httpResponseEncoderConfiguration) + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + + var extraHTTPHandlers: [RemovableChannelHandler] = [requestDecoder] + extraHTTPHandlers.reserveCapacity(3) + + try self.addHandler(responseEncoder) + try self.addHandler(requestDecoder) + + if configuration.enablePipelining { + let pipeliningHandler = HTTPServerPipelineHandler() + try self.addHandler(pipeliningHandler) + extraHTTPHandlers.append(pipeliningHandler) + } + + if configuration.enableResponseHeaderValidation { + let headerValidationHandler = NIOHTTPResponseHeadersValidator() + try self.addHandler(headerValidationHandler) + extraHTTPHandlers.append(headerValidationHandler) + } + + if configuration.enableErrorHandling { + let errorHandler = HTTPServerProtocolErrorHandler() + try self.addHandler(errorHandler) + extraHTTPHandlers.append(errorHandler) + } + + let upgrader = NIOTypedHTTPServerUpgradeHandler( + httpEncoder: responseEncoder, + extraHTTPHandlers: extraHTTPHandlers, + upgradeConfiguration: configuration.upgradeConfiguration + ) + try self.addHandler(upgrader) + + return upgrader.upgradeResultFuture + } +} diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift new file mode 100644 index 0000000000..a665cd63a2 --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift @@ -0,0 +1,373 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +@_spi(AsyncChannel) import NIOCore + +/// An object that implements `NIOTypedHTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a server-side channel. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +@_spi(AsyncChannel) +public protocol NIOTypedHTTPServerProtocolUpgrader { + associatedtype UpgradeResult + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol needs in the request to successfully upgrade. These header fields + /// will be provided to the handler when it is asked to handle the upgrade. They will also be validated + /// against the inbound request's `Connection` header field. + var requiredUpgradeHeaders: [String] { get } + + /// Builds the upgrade response headers. Should return any headers that need to be supplied to the client + /// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should + /// return a failed future. + func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline + /// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received + /// data will be buffered. + func upgrade( + channel: Channel, + upgradeRequest: HTTPRequestHead + ) -> EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPServerUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +@_spi(AsyncChannel) +public struct NIOTypedHTTPServerUpgradeConfiguration { + /// The array of potential upgraders. + public var upgraders: [any NIOTypedHTTPServerProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + + public init( + upgraders: [any NIOTypedHTTPServerProtocolUpgrader], + notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture + ) { + self.upgraders = upgraders + self.notUpgradingCompletionHandler = notUpgradingCompletionHandler + } +} + +/// A server-side channel handler that receives HTTP requests and optionally performs an HTTP-upgrade. +/// +/// Removes itself from the channel pipeline after the first inbound request on the connection, regardless of +/// whether the upgrade succeeded or not. +/// +/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade +/// on any request but the first on a connection. This is primarily to handle clients that pipeline: it's +/// sufficiently difficult to ensure that the upgrade happens at a safe time while dealing with pipelined +/// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine: +/// the odds of someone needing to upgrade midway through the lifetime of a connection are very low. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +@_spi(AsyncChannel) +public final class NIOTypedHTTPServerUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler { + public typealias InboundIn = HTTPServerRequestPart + public typealias InboundOut = HTTPServerRequestPart + public typealias OutboundOut = HTTPServerResponsePart + + private let upgraders: [String: any NIOTypedHTTPServerProtocolUpgrader] + private let notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + private let httpEncoder: HTTPResponseEncoder + private let extraHTTPHandlers: [RemovableChannelHandler] + private var stateMachine = NIOTypedHTTPServerUpgraderStateMachine() + + private var _upgradeResultPromise: EventLoopPromise? + private var upgradeResultPromise: EventLoopPromise { + precondition( + self._upgradeResultPromise != nil, + "Tried to access the upgrade result before the handler was added to a pipeline" + ) + return self._upgradeResultPromise! + } + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: EventLoopFuture { + self.upgradeResultPromise.futureResult + } + + /// Create a ``NIOTypedHTTPServerUpgradeHandler``. + /// + /// - Parameters: + /// - httpEncoder: The ``HTTPResponseEncoder`` encoding responses from this handler and which will + /// be removed from the pipeline once the upgrade response is sent. This is used to ensure + /// that the pipeline will be in a clean state after upgrade. + /// - extraHTTPHandlers: Any other handlers that are directly related to handling HTTP. At the very least + /// this should include the `HTTPDecoder`, but should also include any other handler that cannot tolerate + /// receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init( + httpEncoder: HTTPResponseEncoder, + extraHTTPHandlers: [RemovableChannelHandler], + upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration + ) { + var upgraderMap = [String: any NIOTypedHTTPServerProtocolUpgrader]() + for upgrader in upgradeConfiguration.upgraders { + upgraderMap[upgrader.supportedProtocol.lowercased()] = upgrader + } + self.upgraders = upgraderMap + self.notUpgradingCompletionHandler = upgradeConfiguration.notUpgradingCompletionHandler + self.httpEncoder = httpEncoder + self.extraHTTPHandlers = extraHTTPHandlers + } + + public func handlerAdded(context: ChannelHandlerContext) { + self._upgradeResultPromise = context.eventLoop.makePromise(of: UpgradeResult.self) + } + + public func handlerRemoved(context: ChannelHandlerContext) { + switch self.stateMachine.handlerRemoved() { + case .failUpgradePromise: + self.upgradeResultPromise.fail(ChannelError.inappropriateOperationForState) + case .none: + break + } + } + + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.stateMachine.channelReadData(data) { + case .unwrapData: + let requestPart = self.unwrapInboundIn(data) + self.channelRead(context: context, requestPart: requestPart) + + case .fireChannelRead: + context.fireChannelRead(data) + + case .none: + break + } + } + + private func channelRead(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) { + switch self.stateMachine.channelReadRequestPart(requestPart) { + case .failUpgradePromise(let error): + self.upgradeResultPromise.fail(error) + + case .runNotUpgradingInitializer: + self.notUpgradingCompletionHandler(context.channel) + .hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: nil) + } + + case .findUpgrader(let head, let requestedProtocols, let allHeaderNames, let connectionHeader): + let protocolIterator = requestedProtocols.makeIterator() + self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: head, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ).whenComplete { result in + context.eventLoop.assertInEventLoop() + self.findingUpgradeCompleted(context: context, requestHead: head, result) + } + + case .startUpgrading(let upgrader, let requestHead, let responseHeaders, let proto): + self.startUpgrading( + context: context, + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto + ) + + case .none: + break + } + } + + private func upgradingHandlerCompleted( + context: ChannelHandlerContext, + _ result: Result, + requestHeadAndProtocol: (HTTPRequestHead, String)? + ) { + switch self.stateMachine.upgradingHandlerCompleted(result) { + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .fireErrorCaughtAndStartUnbuffering(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + self.unbuffer(context: context) + + case .startUnbuffering(let value): + if let requestHeadAndProtocol = requestHeadAndProtocol { + context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) + } + self.upgradeResultPromise.succeed(value) + self.unbuffer(context: context) + + case .removeHandler(let value): + if let requestHeadAndProtocol = requestHeadAndProtocol { + context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) + } + self.upgradeResultPromise.succeed(value) + context.pipeline.removeHandler(self, promise: nil) + + case .none: + break + } + } + + /// Attempt to upgrade a single protocol. + /// + /// Will recurse through `protocolIterator` if upgrade fails. + private func handleUpgradeForProtocol( + context: ChannelHandlerContext, + protocolIterator: Array.Iterator, + request: HTTPRequestHead, + allHeaderNames: Set, + connectionHeader: Set + ) -> EventLoopFuture<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?> { + // We want a local copy of the protocol iterator. We'll pass it to the next invocation of the function. + var protocolIterator = protocolIterator + guard let proto = protocolIterator.next() else { + // We're done! No suitable protocol for upgrade. + return context.eventLoop.makeSucceededFuture(nil) + } + + guard let upgrader = self.upgraders[proto.lowercased()] else { + return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + } + + let requiredHeaders = Set(upgrader.requiredUpgradeHeaders.map { $0.lowercased() }) + guard requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader) else { + return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + } + + let responseHeaders = self.buildUpgradeHeaders(protocol: proto) + return upgrader.buildUpgradeResponse( + channel: context.channel, + upgradeRequest: request, + initialResponseHeaders: responseHeaders + ) + .hop(to: context.eventLoop) + .map { (upgrader, $0, proto) } + .flatMapError { error in + // No upgrade here. We want to fire the error down the pipeline, and then try another loop iteration. + context.fireErrorCaught(error) + return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + } + } + + private func findingUpgradeCompleted( + context: ChannelHandlerContext, + requestHead: HTTPRequestHead, + _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> + ) { + switch self.stateMachine.findingUpgraderCompleted(requestHead: requestHead, result) { + case .startUpgrading(let upgrader, let responseHeaders, let proto): + self.startUpgrading( + context: context, + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto + ) + + case .runNotUpgradingInitializer: + self.notUpgradingCompletionHandler(context.channel) + .hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: nil) + } + + case .fireErrorCaughtAndStartUnbuffering(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + self.unbuffer(context: context) + + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .none: + break + } + } + + private func startUpgrading( + context: ChannelHandlerContext, + upgrader: any NIOTypedHTTPServerProtocolUpgrader, + requestHead: HTTPRequestHead, + responseHeaders: HTTPHeaders, + proto: String + ) { + // Before we finish the upgrade we have to remove the HTTPDecoder and any other non-Encoder HTTP + // handlers from the pipeline, to prevent them parsing any more data. We'll buffer the data until + // that completes. + // While there are a lot of Futures involved here it's quite possible that all of this code will + // actually complete synchronously: we just want to program for the possibility that it won't. + // Once that's done, we send the upgrade response, then remove the HTTP encoder, then call the + // internal handler, then call the user code, and then finally when the user code is done we do + // our final cleanup steps, namely we replay the received data we buffered in the meantime and + // then remove ourselves from the pipeline. + self.removeExtraHandlers(context: context).flatMap { + self.sendUpgradeResponse(context: context, responseHeaders: responseHeaders) + }.flatMap { + context.pipeline.removeHandler(self.httpEncoder) + }.flatMap { () -> EventLoopFuture in + return upgrader.upgrade(channel: context.channel, upgradeRequest: requestHead) + }.hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: (requestHead, proto)) + } + } + + /// Sends the 101 Switching Protocols response for the pipeline. + private func sendUpgradeResponse(context: ChannelHandlerContext, responseHeaders: HTTPHeaders) -> EventLoopFuture { + var response = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) + response.headers = responseHeaders + return context.writeAndFlush(wrapOutboundOut(HTTPServerResponsePart.head(response))) + } + + /// Builds the initial mandatory HTTP headers for HTTP upgrade responses. + private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders { + return HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)]) + } + + /// Removes any extra HTTP-related handlers from the channel pipeline. + private func removeExtraHandlers(context: ChannelHandlerContext) -> EventLoopFuture { + guard self.extraHTTPHandlers.count > 0 else { + return context.eventLoop.makeSucceededFuture(()) + } + + return .andAllSucceed(self.extraHTTPHandlers.map { context.pipeline.removeHandler($0) }, + on: context.eventLoop) + } + + private func unbuffer(context: ChannelHandlerContext) { + while true { + switch self.stateMachine.unbuffer() { + case .fireChannelRead(let data): + context.fireChannelRead(data) + + case .fireChannelReadCompleteAndRemoveHandler: + context.fireChannelReadComplete() + context.pipeline.removeHandler(self, promise: nil) + return + } + } + } +} diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift new file mode 100644 index 0000000000..d0fcf287de --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift @@ -0,0 +1,385 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import DequeModule +import NIOCore + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +struct NIOTypedHTTPServerUpgraderStateMachine { + @usableFromInline + enum State { + /// The state before we received a TLSUserEvent. We are just forwarding any read at this point. + case initial + + @usableFromInline + struct AwaitingUpgrader { + var seenFirstRequest: Bool + var buffer: Deque + } + + /// The request head has been received. We're currently running the future chain awaiting an upgrader. + case awaitingUpgrader(AwaitingUpgrader) + + @usableFromInline + struct UpgraderReady { + var upgrader: any NIOTypedHTTPServerProtocolUpgrader + var requestHead: HTTPRequestHead + var responseHeaders: HTTPHeaders + var proto: String + var buffer: Deque + } + + /// We have an upgrader, which means we can begin upgrade we are just waiting for the request end. + case upgraderReady(UpgraderReady) + + @usableFromInline + struct Upgrading { + var buffer: Deque + } + /// We are either running the upgrading handler. + case upgrading(Upgrading) + + @usableFromInline + struct Unbuffering { + var buffer: Deque + } + case unbuffering(Unbuffering) + + case finished + + case modifying + } + + private var state = State.initial + + @usableFromInline + enum HandlerRemovedAction { + case failUpgradePromise + } + + @inlinable + mutating func handlerRemoved() -> HandlerRemovedAction? { + switch self.state { + case .initial, .awaitingUpgrader, .upgraderReady, .upgrading, .unbuffering: + self.state = .finished + return .failUpgradePromise + + case .finished: + return .none + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelReadDataAction { + case unwrapData + case fireChannelRead + } + + @inlinable + mutating func channelReadData(_ data: NIOAny) -> ChannelReadDataAction? { + switch self.state { + case .initial: + return .unwrapData + + case .awaitingUpgrader(var awaitingUpgrader): + if awaitingUpgrader.seenFirstRequest { + // We should buffer the data since we have seen the full request. + self.state = .modifying + awaitingUpgrader.buffer.append(data) + self.state = .awaitingUpgrader(awaitingUpgrader) + return nil + } else { + // We shouldn't buffer. This means we are still expecting HTTP parts. + return .unwrapData + } + + case .upgraderReady: + // We have not seen the end of the HTTP request so this + // data is probably an HTTP request part. + return .unwrapData + + case .unbuffering(var unbuffering): + self.state = .modifying + unbuffering.buffer.append(data) + self.state = .unbuffering(unbuffering) + return nil + + case .finished: + return .fireChannelRead + + case .upgrading(var upgrading): + // We got a read while running ugprading. + // We have to buffer the read to unbuffer it afterwards + self.state = .modifying + upgrading.buffer.append(data) + self.state = .upgrading(upgrading) + return nil + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelReadRequestPartAction { + case failUpgradePromise(Error) + case runNotUpgradingInitializer + case startUpgrading( + upgrader: any NIOTypedHTTPServerProtocolUpgrader, + requestHead: HTTPRequestHead, + responseHeaders: HTTPHeaders, + proto: String + ) + case findUpgrader( + head: HTTPRequestHead, + requestedProtocols: [String], + allHeaderNames: Set, + connectionHeader: Set + ) + } + + @inlinable + mutating func channelReadRequestPart(_ requestPart: HTTPServerRequestPart) -> ChannelReadRequestPartAction? { + switch self.state { + case .initial: + guard case .head(let head) = requestPart else { + // The first data that we saw was not a head. This is a protocol error and we are just going to + // fail upgrading + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + } + + // Ok, we have a HTTP head. Check if it's an upgrade. + let requestedProtocols = head.headers[canonicalForm: "upgrade"].map(String.init) + guard requestedProtocols.count > 0 else { + // We have to buffer now since we got the request head but are not upgrading. + // The user is configuring the HTTP pipeline now. + var buffer = Deque() + buffer.append(NIOAny(requestPart)) + self.state = .upgrading(.init(buffer: buffer)) + return .runNotUpgradingInitializer + } + + // We can now transition to awaiting the upgrader. This means that we are trying to + // find an upgrade that can handle requested protocols. We are not buffering because + // we are waiting for the request end. + self.state = .awaitingUpgrader(.init(seenFirstRequest: false, buffer: .init())) + + let connectionHeader = Set(head.headers[canonicalForm: "connection"].map { $0.lowercased() }) + let allHeaderNames = Set(head.headers.map { $0.name.lowercased() }) + + return .findUpgrader( + head: head, + requestedProtocols: requestedProtocols, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) + + case .awaitingUpgrader(let awaitingUpgrader): + switch (awaitingUpgrader.seenFirstRequest, requestPart) { + case (true, _): + // This is weird we are seeing more requests parts after we have seen an end + // Let's fail upgrading + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case (false, .head): + // This is weird we are seeing another head but haven't seen the end for the request before + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case (false, .body): + // This is weird we are seeing body parts for a request that indicated that it wanted + // to upgrade. + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case (false, .end): + // Okay we got the end as expected. Just gotta store this in our state. + self.state = .awaitingUpgrader(.init(seenFirstRequest: true, buffer: awaitingUpgrader.buffer)) + return nil + } + + case .upgraderReady(let upgraderReady): + switch requestPart { + case .head: + // This is weird we are seeing another head but haven't seen the end for the request before + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case .body: + // This is weird we are seeing body parts for a request that indicated that it wanted + // to upgrade. + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case .end: + // Okay we got the end as expected and our upgrader is ready so let's start upgrading + self.state = .upgrading(.init(buffer: upgraderReady.buffer)) + return .startUpgrading( + upgrader: upgraderReady.upgrader, + requestHead: upgraderReady.requestHead, + responseHeaders: upgraderReady.responseHeaders, + proto: upgraderReady.proto + ) + } + + case .upgrading, .unbuffering, .finished: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum UpgradingHandlerCompletedAction { + case fireErrorCaughtAndStartUnbuffering(Error) + case removeHandler(UpgradeResult) + case fireErrorCaughtAndRemoveHandler(Error) + case startUnbuffering(UpgradeResult) + } + + @inlinable + mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { + switch self.state { + case .initial: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .upgrading(let upgrading): + switch result { + case .success(let value): + if !upgrading.buffer.isEmpty { + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .startUnbuffering(value) + } else { + self.state = .finished + return .removeHandler(value) + } + + case .failure(let error): + if !upgrading.buffer.isEmpty { + // So we failed to upgrade. There is nothing really that we can do here. + // We are unbuffering the reads but there shouldn't be any handler in the pipeline + // that expects a specific type of reads anyhow. + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .fireErrorCaughtAndStartUnbuffering(error) + } else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(error) + } + } + + case .finished: + // We have to tolerate this + return nil + + case .awaitingUpgrader, .upgraderReady, .unbuffering: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum FindingUpgraderCompletedAction { + case startUpgrading(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String) + case runNotUpgradingInitializer + case fireErrorCaughtAndStartUnbuffering(Error) + case fireErrorCaughtAndRemoveHandler(Error) + } + + @inlinable + mutating func findingUpgraderCompleted( + requestHead: HTTPRequestHead, + _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> + ) -> FindingUpgraderCompletedAction? { + switch self.state { + case .initial, .upgraderReady: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .awaitingUpgrader(let awaitingUpgrader): + switch result { + case .success(.some((let upgrader, let responseHeaders, let proto))): + if awaitingUpgrader.seenFirstRequest { + // We have seen the end of the request. So we can upgrade now. + self.state = .upgrading(.init(buffer: awaitingUpgrader.buffer)) + return .startUpgrading(upgrader: upgrader, responseHeaders: responseHeaders, proto: proto) + } else { + // We have not yet seen the end so we have to wait until that happens + self.state = .upgraderReady(.init( + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto, + buffer: awaitingUpgrader.buffer + )) + return nil + } + + case .success(.none): + // There was no upgrader to handle the request. We just run the not upgrading + // initializer now. + self.state = .upgrading(.init(buffer: awaitingUpgrader.buffer)) + return .runNotUpgradingInitializer + + case .failure(let error): + if !awaitingUpgrader.buffer.isEmpty { + self.state = .unbuffering(.init(buffer: awaitingUpgrader.buffer)) + return .fireErrorCaughtAndStartUnbuffering(error) + } else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(error) + } + } + + case .upgrading, .unbuffering, .finished: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum UnbufferAction { + case fireChannelRead(NIOAny) + case fireChannelReadCompleteAndRemoveHandler + } + + @inlinable + mutating func unbuffer() -> UnbufferAction { + switch self.state { + case .initial, .awaitingUpgrader, .upgraderReady, .upgrading, .finished: + preconditionFailure("Invalid state \(self.state)") + + case .unbuffering(var unbuffering): + self.state = .modifying + + if let element = unbuffering.buffer.popFirst() { + self.state = .unbuffering(unbuffering) + + return .fireChannelRead(element) + } else { + self.state = .finished + + return .fireChannelReadCompleteAndRemoveHandler + } + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + } + } + +} diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 9a245edbbd..baa66eec5a 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -13,8 +13,8 @@ //===----------------------------------------------------------------------===// import CNIOSHA1 -import NIOCore -import NIOHTTP1 +@_spi(AsyncChannel) import NIOCore +@_spi(AsyncChannel) import NIOHTTP1 let magicWebSocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" @@ -156,28 +156,136 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch } public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { - let key: String - let version: String - - do { - key = try upgradeRequest.headers.nonListHeader("Sec-WebSocket-Key") - version = try upgradeRequest.headers.nonListHeader("Sec-WebSocket-Version") - } catch { - return channel.eventLoop.makeFailedFuture(error) - } + return _buildUpgradeResponse( + channel: channel, + upgradeRequest: upgradeRequest, + initialResponseHeaders: initialResponseHeaders, + shouldUpgrade: self.shouldUpgrade + ) + } - // The version must be 13. - guard version == "13" else { - return channel.eventLoop.makeFailedFuture(NIOWebSocketUpgradeError.invalidUpgradeHeader) - } + public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + _upgrade( + channel: context.channel, + upgradeRequest: upgradeRequest, + maxFrameSize: self.maxFrameSize, + automaticErrorHandling: self.automaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) + } +} + +/// A `NIOTypedHTTPServerProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// Users may frequently want to offer multiple websocket endpoints on the same port. For this +/// reason, this `WebServerSocketUpgrader` only knows how to do the required parts of the upgrade and to +/// complete the handshake. Users are expected to provide a callback that examines the HTTP headers +/// (including the path) and determines whether this is a websocket upgrade request that is acceptable +/// to them. +/// +/// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to +/// remove the HTTP `ChannelHandler`s. +@_spi(AsyncChannel) +public final class NIOTypedWebSocketServerUpgrader: NIOTypedHTTPServerProtocolUpgrader, Sendable { + private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String = "websocket" + + /// We deliberately do not actually set any required headers here, because the websocket + /// spec annoyingly does not actually force the client to send these in the Upgrade header, + /// which NIO requires. We check for these manually. + public let requiredUpgradeHeaders: [String] = [] + + private let shouldUpgrade: ShouldUpgrade + private let upgradePipelineHandler: UpgradePipelineHandler + private let maxFrameSize: Int + private let enableAutomaticErrorHandling: Bool + + /// Create a new ``NIOTypedWebSocketServerUpgrader``. + /// + /// - Parameters: + /// - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the + /// remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but + /// this is an objectively unreasonable maximum value (on AMD64 systems it is not + /// possible to even. Users may set this to any value up to `UInt32.max`. + /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol + /// errors by sending error responses and closing the connection. Defaults to `true`, + /// may be set to `false` if the user wishes to handle their own errors. + /// - shouldUpgrade: A callback that determines whether the websocket request should be + /// upgraded. This callback is responsible for creating a `HTTPHeaders` object with + /// any headers that it needs on the response *except for* the `Upgrade`, `Connection`, + /// and `Sec-WebSocket-Accept` headers, which this upgrader will handle. Should return + /// an `EventLoopFuture` containing `nil` if the upgrade should be refused. + /// - enableAutomaticErrorHandling: A function that will be called once the upgrade response is + /// flushed, and that is expected to mutate the `Channel` appropriately to handle the + /// websocket protocol. This only needs to add the user handlers: the + /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the + /// pipeline automatically. + public init( + maxFrameSize: Int = 1 << 14, + enableAutomaticErrorHandling: Bool = true, + shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + ) { + precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") + self.shouldUpgrade = shouldUpgrade + self.upgradePipelineHandler = upgradePipelineHandler + self.maxFrameSize = maxFrameSize + self.enableAutomaticErrorHandling = enableAutomaticErrorHandling + } + + public func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { + _buildUpgradeResponse( + channel: channel, + upgradeRequest: upgradeRequest, + initialResponseHeaders: initialResponseHeaders, + shouldUpgrade: self.shouldUpgrade + ) + } - return self.shouldUpgrade(channel, upgradeRequest).flatMapThrowing { extraHeaders in - guard let extraHeaders = extraHeaders else { + public func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + _upgrade( + channel: channel, + upgradeRequest: upgradeRequest, + maxFrameSize: self.maxFrameSize, + automaticErrorHandling: self.enableAutomaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) + } +} + +private func _buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders, + shouldUpgrade: @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture +) -> EventLoopFuture { + let key: String + let version: String + + do { + key = try upgradeRequest.headers.nonListHeader("Sec-WebSocket-Key") + version = try upgradeRequest.headers.nonListHeader("Sec-WebSocket-Version") + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + + // The version must be 13. + guard version == "13" else { + return channel.eventLoop.makeFailedFuture(NIOWebSocketUpgradeError.invalidUpgradeHeader) + } + + return shouldUpgrade(channel, upgradeRequest) + .flatMapThrowing { extraHeaders in + guard var extraHeaders = extraHeaders else { throw NIOWebSocketUpgradeError.unsupportedWebSocketTarget } - return extraHeaders - }.map { (extraHeaders: HTTPHeaders) in - var extraHeaders = extraHeaders // Cool, we're good to go! Let's do our upgrade. We do this by concatenating the magic // GUID to the base64-encoded key and taking a SHA1 hash of the result. @@ -195,23 +303,27 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch return extraHeaders } - } - - public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { - /// We never use the automatic error handling feature of the WebSocketFrameDecoder: we always use the separate channel - /// handler. - var upgradeFuture = context.pipeline.addHandler(WebSocketFrameEncoder()).flatMap { - context.pipeline.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: self.maxFrameSize))) - } +} - if self.automaticErrorHandling { - upgradeFuture = upgradeFuture.flatMap { - context.pipeline.addHandler(WebSocketProtocolErrorHandler()) - } - } +private func _upgrade( + channel: Channel, + upgradeRequest: HTTPRequestHead, + maxFrameSize: Int, + automaticErrorHandling: Bool, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture +) -> EventLoopFuture { + /// We never use the automatic error handling feature of the WebSocketFrameDecoder: we always use the separate channel + /// handler. + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder()) + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize)) + ) - return upgradeFuture.flatMap { - self.upgradePipelineHandler(context.channel, upgradeRequest) + if automaticErrorHandling { + try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) } + }.flatMap { + upgradePipelineHandler(channel, upgradeRequest) } } diff --git a/Sources/NIOWebSocketClient/main.swift b/Sources/NIOWebSocketClient/main.swift index 1f3adecc4f..5e86a9e045 100644 --- a/Sources/NIOWebSocketClient/main.swift +++ b/Sources/NIOWebSocketClient/main.swift @@ -146,7 +146,7 @@ private final class WebSocketPingPongHandler: ChannelInboundHandler { private func pingTestFrameData(context: ChannelHandlerContext) { let buffer = context.channel.allocator.buffer(string: self.testFrameData) let frame = WebSocketFrame(fin: true, opcode: .ping, data: buffer) - context.write(self.wrapOutboundOut(frame), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(frame), promise: nil) } private func pong(context: ChannelHandlerContext, frame: WebSocketFrame) { diff --git a/Sources/NIOWebSocketServer/Server.swift b/Sources/NIOWebSocketServer/Server.swift new file mode 100644 index 0000000000..ba6af55144 --- /dev/null +++ b/Sources/NIOWebSocketServer/Server.swift @@ -0,0 +1,284 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if swift(>=5.9) +@_spi(AsyncChannel) import NIOCore +@_spi(AsyncChannel) import NIOPosix +@_spi(AsyncChannel) import NIOHTTP1 +@_spi(AsyncChannel) import NIOWebSocket + +let websocketResponse = """ + + + + + Swift NIO WebSocket Test Page + + + +

WebSocket Stream

+
+ + +""" + +@available(macOS 14, *) +@main +struct Server { + /// The server's host. + private let host: String + /// The server's port. + private let port: Int + /// The server's event loop group. + private let eventLoopGroup: MultiThreadedEventLoopGroup + + private static let responseBody = ByteBuffer(string: websocketResponse) + + enum UpgradeResult { + case websocket(NIOAsyncChannel) + case notUpgraded(NIOAsyncChannel>) + } + + static func main() async throws { + let server = Server( + host: "localhost", + port: 8888, + eventLoopGroup: .singleton + ) + try await server.run() + } + + /// This method starts the server and handles incoming connections. + func run() async throws { + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: self.eventLoopGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .bind( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + let upgrader = NIOTypedWebSocketServerUpgrader( + shouldUpgrade: { (channel, head) in + channel.eventLoop.makeSucceededFuture(HTTPHeaders()) + }, + upgradePipelineHandler: { (channel, _) in + channel.eventLoop.makeCompletedFuture { + let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) + return UpgradeResult.websocket(asyncChannel) + } + } + ) + + let serverUpgradeConfiguration = NIOTypedHTTPServerUpgradeConfiguration( + upgraders: [upgrader], + notUpgradingCompletionHandler: { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponsePartHandler()) + let asyncChannel = try NIOAsyncChannel>(synchronouslyWrapping: channel) + return UpgradeResult.notUpgraded(asyncChannel) + } + } + ) + + let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: .init(upgradeConfiguration: serverUpgradeConfiguration) + ) + + return negotiationResultFuture + } + } + + // We are handling each incoming connection in a separate child task. It is important + // to use a discarding task group here which automatically discards finished child tasks. + // A normal task group retains all child tasks and their outputs in memory until they are + // consumed by iterating the group or by exiting the group. Since, we are never consuming + // the results of the group we need the group to automatically discard them; otherwise, this + // would result in a memory leak over time. + try await withThrowingDiscardingTaskGroup { group in + for try await upgradeResult in channel.inboundStream { + group.addTask { + await self.handleUpgradeResult(upgradeResult) + } + } + } + } + + /// This method handles a single connection by echoing back all inbound data. + private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async { + // Note that this method is non-throwing and we are catching any error. + // We do this since we don't want to tear down the whole server when a single connection + // encounters an error. + do { + switch try await upgradeResult.get() { + case .websocket(let websocketChannel): + print("Handling websocket connection") + try await self.handleWebsocketChannel(websocketChannel) + print("Done handling websocket connection") + case .notUpgraded(let httpChannel): + print("Handling HTTP connection") + try await self.handleHTTPChannel(httpChannel) + print("Done handling HTTP connection") + } + } catch { + print("Hit error: \(error)") + } + } + + private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + for try await frame in channel.inboundStream { + switch frame.opcode { + case .ping: + print("Received ping") + var frameData = frame.data + let maskingKey = frame.maskKey + + if let maskingKey = maskingKey { + frameData.webSocketUnmask(maskingKey) + } + + let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) + try await channel.outboundWriter.write(responseFrame) + + case .connectionClose: + // This is an unsolicited close. We're going to send a response frame and + // then, when we've sent it, close up shop. We should send back the close code the remote + // peer sent us, unless they didn't send one at all. + print("Received close") + var data = frame.unmaskedData + let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() + let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) + try await channel.outboundWriter.write(closeFrame) + return + case .binary, .continuation, .pong: + // We ignore these frames. + break + default: + // Unknown frames are errors. + return + } + } + } + + group.addTask { + // This is our main business logic where we are just sending the current time + // every second. + while true { + // We can't really check for error here, but it's also not the purpose of the + // example so let's not worry about it. + let theTime = ContinuousClock().now + var buffer = channel.channel.allocator.buffer(capacity: 12) + buffer.writeString("\(theTime)") + + let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) + + print("Sending time") + try await channel.outboundWriter.write(frame) + try await Task.sleep(for: .seconds(1)) + } + } + + try await group.next() + group.cancelAll() + } + } + + + private func handleHTTPChannel(_ channel: NIOAsyncChannel>) async throws { + for try await requestPart in channel.inboundStream { + // We're not interested in request bodies here: we're just serving up GET responses + // to get the client to initiate a websocket request. + guard case .head(let head) = requestPart else { + return + } + + // GETs only. + guard case .GET = head.method else { + try await self.respond405(writer: channel.outboundWriter) + return + } + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/html") + headers.add(name: "Content-Length", value: String(Self.responseBody.readableBytes)) + headers.add(name: "Connection", value: "close") + let responseHead = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .ok, + headers: headers + ) + + try await channel.outboundWriter.write( + contentsOf: [ + .head(responseHead), + .body(Self.responseBody), + .end(nil) + ] + ) + } + } + + private func respond405(writer: NIOAsyncChannelOutboundWriter>) async throws { + var headers = HTTPHeaders() + headers.add(name: "Connection", value: "close") + headers.add(name: "Content-Length", value: "0") + let head = HTTPResponseHead( + version: .http1_1, + status: .methodNotAllowed, + headers: headers + ) + + try await writer.write( + contentsOf: [ + .head(head), + .end(nil) + ] + ) + } +} + +final class HTTPByteBufferResponsePartHandler: ChannelOutboundHandler { + typealias OutboundIn = HTTPPart + typealias OutboundOut = HTTPServerResponsePart + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let part = self.unwrapOutboundIn(data) + switch part { + case .head(let head): + context.write(self.wrapOutboundOut(.head(head)), promise: promise) + case .body(let buffer): + context.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise) + case .end(let trailers): + context.write(self.wrapOutboundOut(.end(trailers)), promise: promise) + } + } +} + +#else +@main +struct Server { + static func main() { + fatalError("Requires at least Swift 5.9") + } +} +#endif diff --git a/Sources/NIOWebSocketServer/main.swift b/Sources/NIOWebSocketServer/main.swift deleted file mode 100644 index d8bb6592e6..0000000000 --- a/Sources/NIOWebSocketServer/main.swift +++ /dev/null @@ -1,282 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -import NIOCore -import NIOPosix -import NIOHTTP1 -import NIOWebSocket - -let websocketResponse = """ - - - - - Swift NIO WebSocket Test Page - - - -

WebSocket Stream

-
- - -""" - -private final class HTTPHandler: ChannelInboundHandler, RemovableChannelHandler { - typealias InboundIn = HTTPServerRequestPart - typealias OutboundOut = HTTPServerResponsePart - - private var responseBody: ByteBuffer! - - func handlerAdded(context: ChannelHandlerContext) { - self.responseBody = context.channel.allocator.buffer(string: websocketResponse) - } - - func handlerRemoved(context: ChannelHandlerContext) { - self.responseBody = nil - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let reqPart = self.unwrapInboundIn(data) - - // We're not interested in request bodies here: we're just serving up GET responses - // to get the client to initiate a websocket request. - guard case .head(let head) = reqPart else { - return - } - - // GETs only. - guard case .GET = head.method else { - self.respond405(context: context) - return - } - - var headers = HTTPHeaders() - headers.add(name: "Content-Type", value: "text/html") - headers.add(name: "Content-Length", value: String(self.responseBody.readableBytes)) - headers.add(name: "Connection", value: "close") - let responseHead = HTTPResponseHead(version: .init(major: 1, minor: 1), - status: .ok, - headers: headers) - context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil) - context.write(self.wrapOutboundOut(.body(.byteBuffer(self.responseBody))), promise: nil) - context.write(self.wrapOutboundOut(.end(nil))).whenComplete { (_: Result) in - context.close(promise: nil) - } - context.flush() - } - - private func respond405(context: ChannelHandlerContext) { - var headers = HTTPHeaders() - headers.add(name: "Connection", value: "close") - headers.add(name: "Content-Length", value: "0") - let head = HTTPResponseHead(version: .http1_1, - status: .methodNotAllowed, - headers: headers) - context.write(self.wrapOutboundOut(.head(head)), promise: nil) - context.write(self.wrapOutboundOut(.end(nil))).whenComplete { (_: Result) in - context.close(promise: nil) - } - context.flush() - } -} - -private final class WebSocketTimeHandler: ChannelInboundHandler { - typealias InboundIn = WebSocketFrame - typealias OutboundOut = WebSocketFrame - - private var awaitingClose: Bool = false - - public func handlerAdded(context: ChannelHandlerContext) { - self.sendTime(context: context) - } - - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let frame = self.unwrapInboundIn(data) - - switch frame.opcode { - case .connectionClose: - self.receivedClose(context: context, frame: frame) - case .ping: - self.pong(context: context, frame: frame) - case .text: - var data = frame.unmaskedData - let text = data.readString(length: data.readableBytes) ?? "" - print(text) - case .binary, .continuation, .pong: - // We ignore these frames. - break - default: - // Unknown frames are errors. - self.closeOnError(context: context) - } - } - - public func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - } - - private func sendTime(context: ChannelHandlerContext) { - guard context.channel.isActive else { return } - - // We can't send if we sent a close message. - guard !self.awaitingClose else { return } - - // We can't really check for error here, but it's also not the purpose of the - // example so let's not worry about it. - let theTime = NIODeadline.now().uptimeNanoseconds - var buffer = context.channel.allocator.buffer(capacity: 12) - buffer.writeString("\(theTime)") - - let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) - context.writeAndFlush(self.wrapOutboundOut(frame)).map { - context.eventLoop.scheduleTask(in: .seconds(1), { self.sendTime(context: context) }) - }.whenFailure { (_: Error) in - context.close(promise: nil) - } - } - - private func receivedClose(context: ChannelHandlerContext, frame: WebSocketFrame) { - // Handle a received close frame. In websockets, we're just going to send the close - // frame and then close, unless we already sent our own close frame. - if awaitingClose { - // Cool, we started the close and were waiting for the user. We're done. - context.close(promise: nil) - } else { - // This is an unsolicited close. We're going to send a response frame and - // then, when we've sent it, close up shop. We should send back the close code the remote - // peer sent us, unless they didn't send one at all. - var data = frame.unmaskedData - let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() - let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) - _ = context.write(self.wrapOutboundOut(closeFrame)).map { () in - context.close(promise: nil) - } - } - } - - private func pong(context: ChannelHandlerContext, frame: WebSocketFrame) { - var frameData = frame.data - let maskingKey = frame.maskKey - - if let maskingKey = maskingKey { - frameData.webSocketUnmask(maskingKey) - } - - let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) - context.write(self.wrapOutboundOut(responseFrame), promise: nil) - } - - private func closeOnError(context: ChannelHandlerContext) { - // We have hit an error, we want to close. We do that by sending a close frame and then - // shutting down the write side of the connection. - var data = context.channel.allocator.buffer(capacity: 2) - data.write(webSocketErrorCode: .protocolError) - let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) - context.write(self.wrapOutboundOut(frame)).whenComplete { (_: Result) in - context.close(mode: .output, promise: nil) - } - awaitingClose = true - } -} - -let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - -let upgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel: Channel, head: HTTPRequestHead) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, - upgradePipelineHandler: { (channel: Channel, _: HTTPRequestHead) in - channel.pipeline.addHandler(WebSocketTimeHandler()) - }) - -let bootstrap = ServerBootstrap(group: group) - // Specify backlog and enable SO_REUSEADDR for the server itself - .serverChannelOption(ChannelOptions.backlog, value: 256) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - - // Set the handlers that are applied to the accepted Channels - .childChannelInitializer { channel in - let httpHandler = HTTPHandler() - let config: NIOHTTPServerUpgradeConfiguration = ( - upgraders: [ upgrader ], - completionHandler: { _ in - channel.pipeline.removeHandler(httpHandler, promise: nil) - } - ) - return channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: config).flatMap { - channel.pipeline.addHandler(httpHandler) - } - } - - // Enable SO_REUSEADDR for the accepted Channels - .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - -defer { - try! group.syncShutdownGracefully() -} - -// First argument is the program path -let arguments = CommandLine.arguments -let arg1 = arguments.dropFirst().first -let arg2 = arguments.dropFirst(2).first - -let defaultHost = "localhost" -let defaultPort = 8888 - -enum BindTo { - case ip(host: String, port: Int) - case unixDomainSocket(path: String) -} - -let bindTarget: BindTo -switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { -case (.some(let h), _ , .some(let p)): - /* we got two arguments, let's interpret that as host and port */ - bindTarget = .ip(host: h, port: p) - -case (let portString?, .none, _): - // Couldn't parse as number, expecting unix domain socket path. - bindTarget = .unixDomainSocket(path: portString) - -case (_, let p?, _): - // Only one argument --> port. - bindTarget = .ip(host: defaultHost, port: p) - -default: - bindTarget = .ip(host: defaultHost, port: defaultPort) -} - -let channel = try { () -> Channel in - switch bindTarget { - case .ip(let host, let port): - return try bootstrap.bind(host: host, port: port).wait() - case .unixDomainSocket(let path): - return try bootstrap.bind(unixDomainSocketPath: path).wait() - } -}() - -guard let localAddress = channel.localAddress else { - fatalError("Address was unable to bind. Please check that the socket was not closed or that the address family was understood.") -} -print("Server started and listening on \(localAddress)") - -// This will never unblock as we don't close the ServerChannel -try channel.closeFuture.wait() - -print("Server closed") diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index 4f0fdc64e8..9be0cc4c22 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -13,11 +13,10 @@ //===----------------------------------------------------------------------===// import XCTest -import Dispatch import NIOCore import NIOEmbedded @testable import NIOPosix -@testable import NIOHTTP1 +@testable @_spi(AsyncChannel) import NIOHTTP1 extension ChannelPipeline { fileprivate func assertDoesNotContainUpgrader() throws { @@ -36,7 +35,11 @@ extension ChannelPipeline { } fileprivate func assertContainsUpgrader() throws { - try self.assertContains(handlerType: HTTPServerUpgradeHandler.self) + do { + _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() + } catch { + try self.assertContains(handlerType: HTTPServerUpgradeHandler.self) + } } func assertContains(handlerType: Handler.Type) throws { @@ -87,7 +90,7 @@ private typealias UpgradeCompletionHandler = @Sendable (ChannelHandlerContext) - private func serverHTTPChannelWithAutoremoval(group: EventLoopGroup, pipelining: Bool, - upgraders: [HTTPServerProtocolUpgrader], + upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], extraHandlers: [ChannelHandler], _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (Channel, EventLoopFuture) { let p = group.next().makePromise(of: Channel.self) @@ -137,20 +140,6 @@ private func connectedClientChannel(group: EventLoopGroup, serverAddress: Socket .wait() } -private func setUpTestWithAutoremoval(pipelining: Bool = false, - upgraders: [HTTPServerProtocolUpgrader], - extraHandlers: [ChannelHandler], - _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (EventLoopGroup, Channel, Channel, Channel) { - let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval(group: group, - pipelining: pipelining, - upgraders: upgraders, - extraHandlers: extraHandlers, - upgradeCompletionHandler) - let clientChannel = try connectedClientChannel(group: group, serverAddress: serverChannel.localAddress!) - return (group, serverChannel, clientChannel, try connectedServerChannelFuture.wait()) -} - internal func assertResponseIs(response: String, expectedResponseLine: String, expectedResponseHeaders: [String]) { var lines = response.split(separator: "\r\n", omittingEmptySubsequences: false).map { String($0) } @@ -175,7 +164,9 @@ internal func assertResponseIs(response: String, expectedResponseLine: String, e XCTAssertEqual(lines.count, 0) } -private class ExplodingUpgrader: HTTPServerProtocolUpgrader { +protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader, NIOTypedHTTPServerProtocolUpgrader where UpgradeResult == Bool {} + +private class ExplodingUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] @@ -197,9 +188,14 @@ private class ExplodingUpgrader: HTTPServerProtocolUpgrader { XCTFail("upgrade called") return context.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + XCTFail("upgrade called") + return channel.eventLoop.makeSucceededFuture(true) + } } -private class UpgraderSaysNo: HTTPServerProtocolUpgrader { +private class UpgraderSaysNo: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] = [] @@ -219,9 +215,14 @@ private class UpgraderSaysNo: HTTPServerProtocolUpgrader { XCTFail("upgrade called") return context.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + XCTFail("upgrade called") + return channel.eventLoop.makeSucceededFuture(true) + } } -private class SuccessfulUpgrader: HTTPServerProtocolUpgrader { +private class SuccessfulUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] private let onUpgradeComplete: (HTTPRequestHead) -> () @@ -256,13 +257,18 @@ private class SuccessfulUpgrader: HTTPServerProtocolUpgrader { self.onUpgradeComplete(upgradeRequest) return context.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + self.onUpgradeComplete(upgradeRequest) + return channel.eventLoop.makeSucceededFuture(true) + } } -private class DelayedUnsuccessfulUpgrader: HTTPServerProtocolUpgrader { +private class DelayedUnsuccessfulUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] - private var upgradePromise: EventLoopPromise? + private var upgradePromise: EventLoopPromise? init(forProtocol `protocol`: String) { self.supportedProtocol = `protocol` @@ -277,19 +283,24 @@ private class DelayedUnsuccessfulUpgrader: HTTPServerProtocolUpgrader { func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { self.upgradePromise = context.eventLoop.makePromise() - return self.upgradePromise!.futureResult + return self.upgradePromise!.futureResult.map { _ in } } func unblockUpgrade(withError error: Error) { self.upgradePromise!.fail(error) } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + self.upgradePromise = channel.eventLoop.makePromise(of: Bool.self) + return self.upgradePromise!.futureResult + } } -private class UpgradeDelayer: HTTPServerProtocolUpgrader { +private class UpgradeDelayer: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] = [] - private var upgradePromise: EventLoopPromise? + private var upgradePromise: EventLoopPromise? public init(forProtocol `protocol`: String) { self.supportedProtocol = `protocol` @@ -303,11 +314,16 @@ private class UpgradeDelayer: HTTPServerProtocolUpgrader { public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { self.upgradePromise = context.eventLoop.makePromise() - return self.upgradePromise!.futureResult + return self.upgradePromise!.futureResult.map { _ in } } public func unblockUpgrade() { - self.upgradePromise!.succeed(()) + self.upgradePromise!.succeed(true) + } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + self.upgradePromise = channel.eventLoop.makePromise(of: Bool.self) + return self.upgradePromise!.futureResult } } @@ -392,6 +408,21 @@ private class ReentrantReadOnChannelReadCompleteHandler: ChannelInboundHandler { } class HTTPServerUpgradeTestCase: XCTestCase { + fileprivate func setUpTestWithAutoremoval(pipelining: Bool = false, + upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], + extraHandlers: [ChannelHandler], + notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, + _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (EventLoopGroup, Channel, Channel, Channel) { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval(group: group, + pipelining: pipelining, + upgraders: upgraders, + extraHandlers: extraHandlers, + upgradeCompletionHandler) + let clientChannel = try connectedClientChannel(group: group, serverAddress: serverChannel.localAddress!) + return (group, serverChannel, clientChannel, try connectedServerChannelFuture.wait()) + } + func testUpgradeWithoutUpgrade() throws { let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], extraHandlers: []) { (_: ChannelHandlerContext) in @@ -758,14 +789,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testDelayedUpgradeBehaviour() throws { - let g = DispatchGroup() - g.enter() - let upgrader = UpgradeDelayer(forProtocol: "myproto") let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { context in - g.leave() - } + extraHandlers: []) { context in } defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } @@ -784,8 +810,6 @@ class HTTPServerUpgradeTestCase: XCTestCase { let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - g.wait() - // Ok, we don't think this upgrade should have succeeded yet, but neither should it have failed. We want to // dispatch onto the server event loop and check that the channel still contains the upgrade handler. try connectedServer.pipeline.assertContainsUpgrader() @@ -800,16 +824,11 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testBuffersInboundDataDuringDelayedUpgrade() throws { - let g = DispatchGroup() - g.enter() - let upgrader = UpgradeDelayer(forProtocol: "myproto") let dataRecorder = DataRecorder() let (group, server, client, _) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: [dataRecorder]) { context in - g.leave() - } + extraHandlers: [dataRecorder]) { context in } defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } @@ -824,14 +843,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - // This request is safe to upgrade, but is immediately followed by non-HTTP data that will probably - // blow up the HTTP parser. + // This request is safe to upgrade, but is immediately followed by non-HTTP data. let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - // Wait for the upgrade machinery to run. - g.wait() - // Ok, send the application data in. let appData = "supersecretawesome data definitely not http\r\nawesome\r\ndata\ryeah" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: appData))).wait()) @@ -1538,3 +1553,517 @@ class HTTPServerUpgradeTestCase: XCTestCase { try channel.pipeline.assertContainsUpgrader() } } + +final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { + fileprivate override func setUpTestWithAutoremoval( + pipelining: Bool = false, + upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], + extraHandlers: [ChannelHandler], + notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, + _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler + ) throws -> (EventLoopGroup, Channel, Channel, Channel) { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let connectionChannelPromise = group.next().makePromise(of: Channel.self) + let serverChannelFuture = ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + connectionChannelPromise.succeed(channel) + var configuration = NIOUpgradableHTTPServerPipelineConfiguration( + upgradeConfiguration: .init( + upgraders: upgraders.map { $0 as! any NIOTypedHTTPServerProtocolUpgrader }, + notUpgradingCompletionHandler: { notUpgradingHandler?($0) ?? $0.eventLoop.makeSucceededFuture(false) } + ) + ) + configuration.enablePipelining = pipelining + return try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline(configuration: configuration) + .flatMap { result in + if result { + return channel.pipeline.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self) + .map { + upgradeCompletionHandler($0) + } + } else { + return channel.eventLoop.makeSucceededVoidFuture() + } + } + } + .flatMap { _ in + let futureResults = extraHandlers.map { channel.pipeline.addHandler($0) } + return EventLoopFuture.andAllSucceed(futureResults, on: channel.eventLoop) + } + }.bind(host: "127.0.0.1", port: 0) + let clientChannel = try connectedClientChannel(group: group, serverAddress: serverChannelFuture.wait().localAddress!) + return (group, try serverChannelFuture.wait(), clientChannel, try connectionChannelPromise.futureResult.wait()) + } + + func testNotUpgrading() throws { + let notUpgraderCbFired = UnsafeMutableTransferBox(false) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in } + + let (group, _, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [], + notUpgradingHandler: { channel in + notUpgraderCbFired.wrappedValue = true + // We're closing the connection now. + channel.close(promise: nil) + return channel.eventLoop.makeSucceededFuture(true) + } + ) { _ in } + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completePromise = group.next().makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + XCTAssertEqual(resultString, "") + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: notmyproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that the not upgrader got called. + XCTAssert(notUpgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + } + + // - MARK: The following tests are all overridden from the base class since they slightly differ in behaviour + + override func testSimpleUpgradeSucceeds() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + // This is called before completion block. + upgradeRequest.wrappedValue = req + upgradeHandlerCbFired.wrappedValue = true + + XCTAssert(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let (group, _, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { (context) in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completePromise = group.next().makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + } + + override func testUpgradeRespectsClientPreference() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + let explodingUpgrader = ExplodingUpgrader(forProtocol: "exploder") + let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: []) { context in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completePromise = group.next().makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto, exploder\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + } + + override func testUpgraderCanRejectUpgradeForPersonalReasons() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + let explodingUpgrader = UpgraderSaysNo(forProtocol: "noproto") + let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + let errorCatcher = ErrorSaver() + + let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: [errorCatcher]) { context in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completePromise = group.next().makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: noproto,myproto\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + + // And we want to confirm we saved the error. + XCTAssertEqual(errorCatcher.errors.count, 1) + + switch(errorCatcher.errors[0]) { + case UpgraderSaysNo.No.no: + break + default: + XCTFail("Unexpected error: \(errorCatcher.errors[0])") + } + } + + override func testUpgradeWithUpgradePayloadInlineWithRequestWorks() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + enum ReceivedTheWrongThingError: Error { case error } + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + class CheckWeReadInlineAndExtraData: ChannelDuplexHandler { + typealias InboundIn = ByteBuffer + typealias OutboundIn = Never + typealias OutboundOut = Never + + enum State { + case fresh + case added + case inlineDataRead + case extraDataRead + case closed + } + + private let firstByteDonePromise: EventLoopPromise + private let secondByteDonePromise: EventLoopPromise + private let allDonePromise: EventLoopPromise + private var state = State.fresh + + init(firstByteDonePromise: EventLoopPromise, + secondByteDonePromise: EventLoopPromise, + allDonePromise: EventLoopPromise) { + self.firstByteDonePromise = firstByteDonePromise + self.secondByteDonePromise = secondByteDonePromise + self.allDonePromise = allDonePromise + } + + func handlerAdded(context: ChannelHandlerContext) { + XCTAssertEqual(.fresh, self.state) + self.state = .added + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buf = self.unwrapInboundIn(data) + XCTAssertEqual(1, buf.readableBytes) + let stringRead = buf.readString(length: buf.readableBytes) + switch self.state { + case .added: + XCTAssertEqual("A", stringRead) + self.state = .inlineDataRead + if stringRead == .some("A") { + self.firstByteDonePromise.succeed(()) + } else { + self.firstByteDonePromise.fail(ReceivedTheWrongThingError.error) + } + case .inlineDataRead: + XCTAssertEqual("B", stringRead) + self.state = .extraDataRead + context.channel.close(promise: nil) + if stringRead == .some("B") { + self.secondByteDonePromise.succeed(()) + } else { + self.secondByteDonePromise.fail(ReceivedTheWrongThingError.error) + } + default: + XCTFail("channel read in wrong state \(self.state)") + } + } + + func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + XCTAssertEqual(.extraDataRead, self.state) + self.state = .closed + context.close(mode: mode, promise: promise) + + self.allDonePromise.succeed(()) + } + } + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let promiseGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try promiseGroup.syncShutdownGracefully()) + } + let firstByteDonePromise = promiseGroup.next().makePromise(of: Void.self) + let secondByteDonePromise = promiseGroup.next().makePromise(of: Void.self) + let allDonePromise = promiseGroup.next().makePromise(of: Void.self) + let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: []) { (context) in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + _ = context.channel.pipeline.addHandler(CheckWeReadInlineAndExtraData(firstByteDonePromise: firstByteDonePromise, + secondByteDonePromise: secondByteDonePromise, + allDonePromise: allDonePromise)) + } + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completePromise = group.next().makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + var request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + request += "A" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + XCTAssertNoThrow(try firstByteDonePromise.futureResult.wait() as Void) + + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: "B"))).wait()) + + XCTAssertNoThrow(try secondByteDonePromise.futureResult.wait() as Void) + + XCTAssertNoThrow(try allDonePromise.futureResult.wait() as Void) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + + XCTAssertNoThrow(try allDonePromise.futureResult.wait()) + } + + override func testWeTolerateUpgradeFuturesFromWrongEventLoops() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + let otherELG = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try otherELG.syncShutdownGracefully()) + } + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", + requiringHeaders: ["kafkaesque"], + buildUpgradeResponseFuture: { + // this is the wrong EL + otherELG.next().makeSucceededFuture($1) + }) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: []) { (context) in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completePromise = group.next().makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + } + + override func testUpgradeFiresUserEvent() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let eventSaver = UnsafeTransfer(UserEventSaver()) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { req in + XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) + } + + let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: [eventSaver.wrappedValue]) { context in + XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) + context.close(promise: nil) + } + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completePromise = group.next().makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade,kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we should have received one user event. We schedule this onto the + // event loop to guarantee thread safety. + XCTAssertNoThrow(try connectedServer.eventLoop.scheduleTask(deadline: .now()) { + XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) + if case .upgradeComplete(let proto, let req) = eventSaver.wrappedValue.events[0] { + XCTAssertEqual(proto, "myproto") + XCTAssertEqual(req.method, .OPTIONS) + XCTAssertEqual(req.uri, "*") + XCTAssertEqual(req.version, .http1_1) + } else { + XCTFail("Unexpected event: \(eventSaver.wrappedValue.events[0])") + } + }.futureResult.wait()) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + } +} diff --git a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift index a5ec8e6dc2..c636de6fc7 100644 --- a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift @@ -15,8 +15,8 @@ import XCTest @testable import NIOCore import NIOEmbedded -import NIOHTTP1 -@testable import NIOWebSocket +@_spi(AsyncChannel) import NIOHTTP1 +@testable @_spi(AsyncChannel) import NIOWebSocket extension EmbeddedChannel { func readAllInboundBuffers() throws -> ByteBuffer { @@ -112,10 +112,38 @@ private class WebSocketRecorderHandler: ChannelInboundHandler { } } +struct WebSocketServerUpgraderConfiguration { + let maxFrameSize: Int + let automaticErrorHandling: Bool + let shouldUpgrade: @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + let upgradePipelineHandler: @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + + @preconcurrency + init( + maxFrameSize: Int = 1 << 14, + automaticErrorHandling: Bool = true, + shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + ) { + self.maxFrameSize = maxFrameSize + self.automaticErrorHandling = automaticErrorHandling + self.shouldUpgrade = shouldUpgrade + self.upgradePipelineHandler = upgradePipelineHandler + } +} + class WebSocketServerEndToEndTests: XCTestCase { - private func createTestFixtures(upgraders: [NIOWebSocketServerUpgrader]) -> (loop: EmbeddedEventLoop, serverChannel: EmbeddedChannel, clientChannel: EmbeddedChannel) { + func createTestFixtures( + upgraders: [WebSocketServerUpgraderConfiguration] + ) -> (loop: EmbeddedEventLoop, serverChannel: EmbeddedChannel, clientChannel: EmbeddedChannel) { let loop = EmbeddedEventLoop() let serverChannel = EmbeddedChannel(loop: loop) + let upgraders = upgraders.map { NIOWebSocketServerUpgrader( + maxFrameSize: $0.maxFrameSize, + automaticErrorHandling: $0.automaticErrorHandling, + shouldUpgrade: $0.shouldUpgrade, + upgradePipelineHandler: $0.upgradePipelineHandler + )} XCTAssertNoThrow(try serverChannel.pipeline.configureHTTPServerPipeline( withServerUpgrade: (upgraders: upgraders as [HTTPServerProtocolUpgrader], completionHandler: { (context: ChannelHandlerContext) in } ) ).wait()) @@ -129,7 +157,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testBasicUpgradeDance() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -151,7 +179,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testUpgradeWithProtocolName() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -170,7 +198,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testCanRejectUpgrade() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(nil) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(nil) }, upgradePipelineHandler: { (channel, req) in XCTFail("Should not have called") return channel.eventLoop.makeSucceededFuture(()) @@ -201,7 +229,7 @@ class WebSocketServerEndToEndTests: XCTestCase { var acceptPromise: EventLoopPromise? = nil var upgradeComplete = false - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in acceptPromise = channel.eventLoop.makePromise() return acceptPromise!.futureResult }, @@ -240,7 +268,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testRequiresVersion13() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -262,7 +290,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testRequiresVersionHeader() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -284,7 +312,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testRequiresKeyHeader() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -306,7 +334,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testUpgradeMayAddCustomHeaders() throws { - let upgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in + let upgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in var hdrs = HTTPHeaders() hdrs.add(name: "TestHeader", value: "TestValue") return channel.eventLoop.makeSucceededFuture(hdrs) @@ -329,8 +357,8 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testMayRegisterMultipleWebSocketEndpoints() throws { - func buildHandler(path: String) -> NIOWebSocketServerUpgrader { - return NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in + func buildHandler(path: String) -> WebSocketServerUpgraderConfiguration { + return WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in guard head.uri == "/\(path)" else { return channel.eventLoop.makeSucceededFuture(nil) } var hdrs = HTTPHeaders() hdrs.add(name: "Target", value: path) @@ -360,7 +388,7 @@ class WebSocketServerEndToEndTests: XCTestCase { func testSendAFewFrames() throws { let recorder = WebSocketRecorderHandler() - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.pipeline.addHandler(recorder) @@ -398,7 +426,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testMaxFrameSize() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(maxFrameSize: 16, shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(maxFrameSize: 16, shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in return channel.eventLoop.makeSucceededFuture(()) }) @@ -423,7 +451,7 @@ class WebSocketServerEndToEndTests: XCTestCase { func testAutomaticErrorHandling() throws { let recorder = WebSocketRecorderHandler() - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.pipeline.addHandler(recorder) @@ -461,7 +489,7 @@ class WebSocketServerEndToEndTests: XCTestCase { func testNoAutomaticErrorHandling() throws { let recorder = WebSocketRecorderHandler() - let basicUpgrader = NIOWebSocketServerUpgrader(automaticErrorHandling: false, + let basicUpgrader = WebSocketServerUpgraderConfiguration(automaticErrorHandling: false, shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.pipeline.addHandler(recorder) @@ -498,3 +526,29 @@ class WebSocketServerEndToEndTests: XCTestCase { XCTAssertNoThrow(XCTAssertEqual([], try server.readAllOutboundBytes())) } } + +final class TypedWebSocketServerEndToEndTests: WebSocketServerEndToEndTests { + override func createTestFixtures( + upgraders: [WebSocketServerUpgraderConfiguration] + ) -> (loop: EmbeddedEventLoop, serverChannel: EmbeddedChannel, clientChannel: EmbeddedChannel) { + let loop = EmbeddedEventLoop() + let serverChannel = EmbeddedChannel(loop: loop) + let upgraders = upgraders.map { NIOTypedWebSocketServerUpgrader( + maxFrameSize: $0.maxFrameSize, + enableAutomaticErrorHandling: $0.automaticErrorHandling, + shouldUpgrade: $0.shouldUpgrade, + upgradePipelineHandler: $0.upgradePipelineHandler + )} + + XCTAssertNoThrow(try serverChannel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: .init( + upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration( + upgraders: upgraders, + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + ) + )) + let clientChannel = EmbeddedChannel(loop: loop) + return (loop: loop, serverChannel: serverChannel, clientChannel: clientChannel) + } +} From 388b85a2e61b4dd5522a396a8b235112c0e3aa09 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Tue, 3 Oct 2023 14:21:59 +0100 Subject: [PATCH 04/64] Fix Sendable warning in `NIOPipeBoostrap` (#2530) # Motivation We were accessing `self` in a `@Sendable` function in the `NIOPipeBoostrap`. The 5.10 compiler correctly diagnoses this as a Sendable violation since the bootstrap is in-fact not Sendable. # Modification Store the `channelOptions` in a local variable outside of the `@Sendable` function. # Result No more Sendable warnings on 5.10 --- Sources/NIOPosix/Bootstrap.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index 7d5c2a6606..d52f2b9053 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -2142,6 +2142,7 @@ extension NIOPipeBootstrap { "illegal file descriptor pair. The file descriptors \(input), \(output) " + "must be distinct and both positive integers.") let eventLoop = group.next() + let channelOptions = self._channelOptions try self.validateFileDescriptorIsNotAFile(input) try self.validateFileDescriptorIsNotAFile(output) @@ -2161,7 +2162,7 @@ extension NIOPipeBootstrap { @Sendable func setupChannel() -> EventLoopFuture { eventLoop.assertInEventLoop() - return self._channelOptions.applyAllChannelOptions(to: channel).flatMap { _ -> EventLoopFuture in + return channelOptions.applyAllChannelOptions(to: channel).flatMap { _ -> EventLoopFuture in channelInitializer(channel) }.flatMap { result in eventLoop.assertInEventLoop() From 44a74872464ae2b8ab7722455b679d63a511b183 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Tue, 3 Oct 2023 15:41:50 +0100 Subject: [PATCH 05/64] Breaking SPI(AsyncChannel): Align back pressure naming (#2527) # Motivation We spelled back pressure a few different ways throughout our async interfaces. # Modification This PR aligns to `backPressure` everywhere and uses `back pressure` in comments. --- .../NIOCore/AsyncChannel/AsyncChannel.swift | 30 +++++++++---------- .../AsyncChannelInboundStream.swift | 12 ++++---- Sources/NIOPosix/Bootstrap.swift | 28 ++++++++--------- .../AsyncChannel/AsyncChannelTests.swift | 4 +-- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift index 0fac8cce1d..aa82aaf374 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift @@ -19,14 +19,14 @@ /// the following functionality: /// /// - reads are presented as an `AsyncSequence` -/// - writes can be written to with async functions on a writer, providing backpressure +/// - writes can be written to with async functions on a writer, providing back pressure /// - channels can be closed seamlessly /// /// This type does not replace the full complexity of NIO's ``Channel``. In particular, it /// does not expose the following functionality: /// /// - user events -/// - traditional NIO backpressure such as writability signals and the ``Channel/read()`` call +/// - traditional NIO back pressure such as writability signals and the ``Channel/read()`` call /// /// Users are encouraged to separate their ``ChannelHandler``s into those that implement /// protocol-specific logic (such as parsers and encoders) and those that implement business @@ -37,8 +37,8 @@ public final class NIOAsyncChannel: Sendable { @_spi(AsyncChannel) public struct Configuration: Sendable { - /// The backpressure strategy of the ``NIOAsyncChannel/inboundStream``. - public var backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark + /// The back pressure strategy of the ``NIOAsyncChannel/inboundStream``. + public var backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark /// If outbound half closure should be enabled. Outbound half closure is triggered once /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. @@ -53,19 +53,19 @@ public final class NIOAsyncChannel: Senda /// Initializes a new ``NIOAsyncChannel/Configuration``. /// /// - Parameters: - /// - backpressureStrategy: The backpressure strategy of the ``NIOAsyncChannel/inboundStream``. Defaults + /// - backPressureStrategy: The back pressure strategy of the ``NIOAsyncChannel/inboundStream``. Defaults /// to a watermarked strategy (lowWatermark: 2, highWatermark: 10). /// - isOutboundHalfClosureEnabled: If outbound half closure should be enabled. Outbound half closure is triggered once /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. Defaults to `false`. /// - inboundType: The ``NIOAsyncChannel/inboundStream`` message's type. /// - outboundType: The ``NIOAsyncChannel/outboundWriter`` message's type. public init( - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark = .init(lowWatermark: 2, highWatermark: 10), + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark = .init(lowWatermark: 2, highWatermark: 10), isOutboundHalfClosureEnabled: Bool = false, inboundType: Inbound.Type = Inbound.self, outboundType: Outbound.Type = Outbound.self ) { - self.backpressureStrategy = backpressureStrategy + self.backPressureStrategy = backPressureStrategy self.isOutboundHalfClosureEnabled = isOutboundHalfClosureEnabled self.inboundType = inboundType self.outboundType = outboundType @@ -99,7 +99,7 @@ public final class NIOAsyncChannel: Senda channel.eventLoop.preconditionInEventLoop() self.channel = channel (self.inboundStream, self.outboundWriter) = try channel._syncAddAsyncHandlers( - backpressureStrategy: configuration.backpressureStrategy, + backPressureStrategy: configuration.backPressureStrategy, isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled ) } @@ -123,7 +123,7 @@ public final class NIOAsyncChannel: Senda channel.eventLoop.preconditionInEventLoop() self.channel = channel (self.inboundStream, self.outboundWriter) = try channel._syncAddAsyncHandlers( - backpressureStrategy: configuration.backpressureStrategy, + backPressureStrategy: configuration.backPressureStrategy, isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled ) @@ -147,13 +147,13 @@ public final class NIOAsyncChannel: Senda @_spi(AsyncChannel) public static func wrapAsyncChannelWithTransformations( synchronouslyWrapping channel: Channel, - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, isOutboundHalfClosureEnabled: Bool = false, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> NIOAsyncChannel where Outbound == Never { channel.eventLoop.preconditionInEventLoop() let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) = try channel._syncAddAsyncHandlersWithTransformations( - backpressureStrategy: backpressureStrategy, + backPressureStrategy: backPressureStrategy, isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, channelReadTransformation: channelReadTransformation ) @@ -174,7 +174,7 @@ extension Channel { @inlinable @_spi(AsyncChannel) public func _syncAddAsyncHandlers( - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, isOutboundHalfClosureEnabled: Bool ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { self.eventLoop.assertInEventLoop() @@ -182,7 +182,7 @@ extension Channel { let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) let inboundStream = try NIOAsyncChannelInboundStream.makeWrappingHandler( channel: self, - backpressureStrategy: backpressureStrategy, + backPressureStrategy: backPressureStrategy, closeRatchet: closeRatchet ) let writer = try NIOAsyncChannelOutboundWriter( @@ -196,7 +196,7 @@ extension Channel { @inlinable @_spi(AsyncChannel) public func _syncAddAsyncHandlersWithTransformations( - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, isOutboundHalfClosureEnabled: Bool, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { @@ -205,7 +205,7 @@ extension Channel { let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) let inboundStream = try NIOAsyncChannelInboundStream.makeTransformationHandler( channel: self, - backpressureStrategy: backpressureStrategy, + backPressureStrategy: backPressureStrategy, closeRatchet: closeRatchet, channelReadTransformation: channelReadTransformation ) diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift index 0b5d199fd0..746d128f5d 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift @@ -80,14 +80,14 @@ public struct NIOAsyncChannelInboundStream: Sendable { @inlinable init( channel: Channel, - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, closeRatchet: CloseRatchet, handler: NIOAsyncChannelInboundStreamChannelHandler ) throws { channel.eventLoop.preconditionInEventLoop() let strategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark - if let userProvided = backpressureStrategy { + if let userProvided = backPressureStrategy { strategy = userProvided } else { // Default strategy. These numbers are fairly arbitrary, but they line up with the default value of @@ -108,7 +108,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { @inlinable static func makeWrappingHandler( channel: Channel, - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, closeRatchet: CloseRatchet ) throws -> NIOAsyncChannelInboundStream { let handler = NIOAsyncChannelInboundStreamChannelHandler.makeHandler( @@ -118,7 +118,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { return try .init( channel: channel, - backpressureStrategy: backpressureStrategy, + backPressureStrategy: backPressureStrategy, closeRatchet: closeRatchet, handler: handler ) @@ -128,7 +128,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { @inlinable static func makeTransformationHandler( channel: Channel, - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, closeRatchet: CloseRatchet, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> NIOAsyncChannelInboundStream { @@ -140,7 +140,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { return try .init( channel: channel, - backpressureStrategy: backpressureStrategy, + backPressureStrategy: backPressureStrategy, closeRatchet: closeRatchet, handler: handler ) diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index d52f2b9053..02e5a80794 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -469,7 +469,7 @@ extension ServerBootstrap { /// - Parameters: /// - host: The host to bind on. /// - port: The port to bind on. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` /// method. /// - Returns: The result of the channel initializer. @@ -478,14 +478,14 @@ extension ServerBootstrap { public func bind( host: String, port: Int, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { let address = try SocketAddress.makeAddressResolvingHost(host, port: port) return try await bind( to: address, - serverBackpressureStrategy: serverBackpressureStrategy, + serverBackPressureStrategy: serverBackPressureStrategy, childChannelInitializer: childChannelInitializer ) } @@ -494,7 +494,7 @@ extension ServerBootstrap { /// /// - Parameters: /// - address: The `SocketAddress` to bind on. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` /// method. /// - Returns: The result of the channel initializer. @@ -502,7 +502,7 @@ extension ServerBootstrap { @_spi(AsyncChannel) public func bind( to address: SocketAddress, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { return try await bind0( @@ -514,7 +514,7 @@ extension ServerBootstrap { enableMPTCP: enableMPTCP ) }, - serverBackpressureStrategy: serverBackpressureStrategy, + serverBackPressureStrategy: serverBackPressureStrategy, childChannelInitializer: childChannelInitializer, registration: { serverChannel in serverChannel.registerAndDoSynchronously { serverChannel in @@ -530,7 +530,7 @@ extension ServerBootstrap { /// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist, /// unless `cleanupExistingSocketFile`is set to `true`. /// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` /// method. /// - Returns: The result of the channel initializer. @@ -539,7 +539,7 @@ extension ServerBootstrap { public func bind( unixDomainSocketPath: String, cleanupExistingSocketFile: Bool = false, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { if cleanupExistingSocketFile { @@ -550,7 +550,7 @@ extension ServerBootstrap { return try await self.bind( to: address, - serverBackpressureStrategy: serverBackpressureStrategy, + serverBackPressureStrategy: serverBackPressureStrategy, childChannelInitializer: childChannelInitializer ) } @@ -559,7 +559,7 @@ extension ServerBootstrap { /// /// - Parameters: /// - socket: The _Unix file descriptor_ representing the bound stream socket. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` /// method. /// - Returns: The result of the channel initializer. @@ -568,7 +568,7 @@ extension ServerBootstrap { public func bind( _ socket: NIOBSDSocket.Handle, cleanupExistingSocketFile: Bool = false, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { return try await bind0( @@ -582,7 +582,7 @@ extension ServerBootstrap { group: childEventLoopGroup ) }, - serverBackpressureStrategy: serverBackpressureStrategy, + serverBackPressureStrategy: serverBackPressureStrategy, childChannelInitializer: childChannelInitializer, registration: { serverChannel in let promise = serverChannel.eventLoop.makePromise(of: Void.self) @@ -595,7 +595,7 @@ extension ServerBootstrap { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) private func bind0( makeServerChannel: @escaping (SelectableEventLoop, EventLoopGroup, Bool) throws -> ServerSocketChannel, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, registration: @escaping @Sendable (Channel) -> EventLoopFuture ) -> EventLoopFuture> { @@ -625,7 +625,7 @@ extension ServerBootstrap { let asyncChannel = try NIOAsyncChannel .wrapAsyncChannelWithTransformations( synchronouslyWrapping: serverChannel, - backpressureStrategy: serverBackpressureStrategy, + backPressureStrategy: serverBackPressureStrategy, channelReadTransformation: { channel -> EventLoopFuture in // The channelReadTransformation is run on the EL of the server channel // We have to make sure that we execute child channel initializer on the diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 170add83fe..277cd4bbee 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -317,7 +317,7 @@ final class AsyncChannelTests: XCTestCase { try await channel.closeIgnoringSuppression() } - func testManagingBackpressure() async throws { + func testManagingBackPressure() async throws { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let readCounter = ReadCounter() @@ -326,7 +326,7 @@ final class AsyncChannelTests: XCTestCase { try NIOAsyncChannel( synchronouslyWrapping: channel, configuration: .init( - backpressureStrategy: .init(lowWatermark: 2, highWatermark: 4), + backPressureStrategy: .init(lowWatermark: 2, highWatermark: 4), inboundType: Void.self, outboundType: Never.self ) From f114f99f2044ab5f0c307c517cc2c72b0f695e85 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Wed, 4 Oct 2023 10:34:11 +0100 Subject: [PATCH 06/64] SPI(AsyncChannel): Make `NIOAsyncChannel` a struct (#2528) # Motivation To reduce allocations when wrapping a channel into a `NIOAsyncChannel` we should make the `NIOAsyncChannel` a struct instead of a class. # Modification This PR changes the `NIOAsyncChannel` to a struct. --- Sources/NIOCore/AsyncChannel/AsyncChannel.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift index aa82aaf374..466b996355 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift @@ -34,7 +34,7 @@ /// logic should use ``NIOAsyncChannel`` to consume and produce data to the network. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @_spi(AsyncChannel) -public final class NIOAsyncChannel: Sendable { +public struct NIOAsyncChannel: Sendable { @_spi(AsyncChannel) public struct Configuration: Sendable { /// The back pressure strategy of the ``NIOAsyncChannel/inboundStream``. @@ -76,6 +76,8 @@ public final class NIOAsyncChannel: Senda @_spi(AsyncChannel) public let channel: Channel /// The stream of inbound messages. + /// + /// - Important: The `inboundStream` is a unicast `AsyncSequence` and only one iterator can be created. @_spi(AsyncChannel) public let inboundStream: NIOAsyncChannelInboundStream /// The writer for writing outbound messages. From d367dc0a878d5b5e0027b2131d182fa8a5fa606a Mon Sep 17 00:00:00 2001 From: David Nadoba Date: Wed, 4 Oct 2023 15:42:26 +0200 Subject: [PATCH 07/64] Tolerate empty HTTP response body parts (#2531) * Tolerate empty HTTP response body parts ### Motivation Empty HTTP response body parts currently eagerly end a response. This is unexpected and against NIOs documented behaviour. ### Modification - add test which send an empty response body part and would previously fail - skip empty response body parts ### Result Users can send empty body parts without ending the response. * add comment explaining why we do it here instead of fixing it at the source * move empty body part logic to `writeChunk` --- Sources/NIOHTTP1/HTTPEncoder.swift | 42 +++++------- .../NIOHTTP1Tests/HTTPServerClientTest.swift | 66 +++++++++++++++++++ 2 files changed, 83 insertions(+), 25 deletions(-) diff --git a/Sources/NIOHTTP1/HTTPEncoder.swift b/Sources/NIOHTTP1/HTTPEncoder.swift index da3fa33a69..c79283ef87 100644 --- a/Sources/NIOHTTP1/HTTPEncoder.swift +++ b/Sources/NIOHTTP1/HTTPEncoder.swift @@ -15,26 +15,24 @@ import NIOCore private func writeChunk(wrapOutboundOut: (IOData) -> NIOAny, context: ChannelHandlerContext, isChunked: Bool, chunk: IOData, promise: EventLoopPromise?) { - let (mW1, mW2, mW3): (EventLoopPromise?, EventLoopPromise?, EventLoopPromise?) - - switch (isChunked, promise) { - case (true, .some(let p)): - /* chunked encoding and the user's interested: we need three promises and need to cascade into the users promise */ - let (w1, w2, w3) = (context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise) - w1.futureResult.and(w2.futureResult).and(w3.futureResult).map { (_: ((((), ()), ()))) in }.cascade(to: p) - (mW1, mW2, mW3) = (w1, w2, w3) - case (false, .some(let p)): - /* not chunked, so just use the user's promise for the actual data */ - (mW1, mW2, mW3) = (nil, p, nil) - case (_, .none): - /* user isn't interested, let's not bother even allocating promises */ - (mW1, mW2, mW3) = (nil, nil, nil) - } - let readableBytes = chunk.readableBytes - /* we don't want to copy the chunk unnecessarily and therefore call write an annoyingly large number of times */ - if isChunked { + // we don't want to copy the chunk unnecessarily and therefore call write an annoyingly large number of times + // we also don't frame empty chunks as they would otherwise end the response stream + // we still need to write the empty IODate to complete the promise in the right order but is otherwise a no-op. + if isChunked && readableBytes > 0 { + let (mW1, mW2, mW3): (EventLoopPromise?, EventLoopPromise?, EventLoopPromise?) + + if let p = promise { + /* chunked encoding and the user's interested: we need three promises and need to cascade into the users promise */ + let (w1, w2, w3) = (context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise) + w1.futureResult.and(w2.futureResult).and(w3.futureResult).map { (_: ((((), ()), ()))) in }.cascade(to: p) + (mW1, mW2, mW3) = (w1, w2, w3) + } else { + /* user isn't interested, let's not bother even allocating promises */ + (mW1, mW2, mW3) = (nil, nil, nil) + } + var buffer = context.channel.allocator.buffer(capacity: 32) let len = String(readableBytes, radix: 16) buffer.writeString(len) @@ -47,7 +45,7 @@ private func writeChunk(wrapOutboundOut: (IOData) -> NIOAny, context: ChannelHan buffer.moveReaderIndex(forwardBy: buffer.readableBytes - 2) context.write(wrapOutboundOut(.byteBuffer(buffer)), promise: mW3) } else { - context.write(wrapOutboundOut(chunk), promise: mW2) + context.write(wrapOutboundOut(chunk), promise: promise) } } @@ -195,12 +193,6 @@ public final class HTTPRequestEncoder: ChannelOutboundHandler, RemovableChannelH buffer.write(request: request) }, context: context, headers: request.headers, promise: promise) case .body(let bodyPart): - guard bodyPart.readableBytes > 0 else { - // Empty writes shouldn't send any bytes in chunked or identity encoding. - context.write(self.wrapOutboundOut(bodyPart), promise: promise) - return - } - writeChunk(wrapOutboundOut: self.wrapOutboundOut, context: context, isChunked: self.isChunked, chunk: bodyPart, promise: promise) case .end(let trailers): writeTrailers(wrapOutboundOut: self.wrapOutboundOut, context: context, isChunked: self.isChunked, trailers: trailers, promise: promise) diff --git a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift index c3ab6fbcee..dc6b82fdfb 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift @@ -251,6 +251,25 @@ class HTTPServerClientTest : XCTestCase { self.sentEnd = true self.maybeClose(context: context) } + case "/zero-length-body-part": + + let r = HTTPServerResponsePart.head(.init(version: req.version, status: .ok)) + context.write(self.wrapOutboundOut(r)).whenFailure { error in + XCTFail("unexpected error \(error)") + } + + context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(ByteBuffer())))).whenFailure { error in + XCTFail("unexpected error \(error)") + } + context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(ByteBuffer(string: "Hello World"))))).whenFailure { error in + XCTFail("unexpected error \(error)") + } + context.write(self.wrapOutboundOut(.end(nil))).recover { error in + XCTFail("unexpected error \(error)") + }.whenComplete { (_: Result) in + self.sentEnd = true + self.maybeClose(context: context) + } default: XCTFail("received request to unknown URI \(req.uri)") } @@ -437,6 +456,53 @@ class HTTPServerClientTest : XCTestCase { func testSimpleGetTrailersFileRegion() throws { try testSimpleGetTrailers(.fileRegion) } + + func testSimpleGetChunkedEncodingWithZeroLengthBodyPart() throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + var expectedHeaders = HTTPHeaders() + expectedHeaders.add(name: "transfer-encoding", value: "chunked") + + let accumulation = HTTPClientResponsePartAssertHandler(.http1_1, .ok, expectedHeaders, "Hello World") + + let httpHandler = SimpleHTTPServer(.byteBuffer) + let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + + // Set the handlers that are appled to the accepted Channels + .childChannelInitializer { channel in + // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true).flatMap { + channel.pipeline.addHandler(httpHandler) + } + }.bind(host: "127.0.0.1", port: 0).wait()) + + defer { + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) + } + + let clientChannel = try assertNoThrowWithValue(ClientBootstrap(group: group) + .channelInitializer { channel in + channel.pipeline.addHTTPClientHandlers().flatMap { + channel.pipeline.addHandler(accumulation) + } + } + .connect(to: serverChannel.localAddress!) + .wait()) + + defer { + XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) + } + + var head = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/zero-length-body-part") + head.headers.add(name: "Host", value: "apple.com") + clientChannel.write(NIOAny(HTTPClientRequestPart.head(head)), promise: nil) + try clientChannel.writeAndFlush(NIOAny(HTTPClientRequestPart.end(nil))).wait() + accumulation.syncWaitForCompletion() + } private func testSimpleGetTrailers(_ mode: SendMode) throws { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) From 6dd3e08b8e3c004f2254e18116f7088b46a128cc Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Wed, 4 Oct 2023 15:50:40 +0100 Subject: [PATCH 08/64] SPI(AsyncChannel): Update the concurrency documentation (#2529) # Motivation We have done quite some changes on the new async interfaces and our documentation needs some updating. # Modification This PR updates the concurrency interop documentation to reflect the latest APIs. --- .../NIOCore/Docs.docc/swift-concurrency.md | 375 +++++++++--------- 1 file changed, 197 insertions(+), 178 deletions(-) diff --git a/Sources/NIOCore/Docs.docc/swift-concurrency.md b/Sources/NIOCore/Docs.docc/swift-concurrency.md index 90a260f282..609a2c87d4 100644 --- a/Sources/NIOCore/Docs.docc/swift-concurrency.md +++ b/Sources/NIOCore/Docs.docc/swift-concurrency.md @@ -2,17 +2,21 @@ This article explains how to interface between NIO and Swift Concurrency. -NIO was created before native Concurrency support in Swift existed, hence, NIO had to solve -a few problems that have solutions in the language today. Since the introduction of Swift Concurrency, -NIO has added numerous features to make the interop between NIO's ``Channel`` eventing system and Swift's -Concurrency primitives as easy as possible. +NIO was created before native Concurrency support in Swift existed, hence, NIO +had to solve a few problems that have solutions in the language today. Since the +introduction of Swift Concurrency, NIO has added numerous features to make the +interop between NIO's ``Channel`` eventing system and Swift's Concurrency +primitives as easy as possible. -### EventLoopFuture bridges +### EventLoopFuture/Promise bridges -The first bridges that NIO introduced added methods on ``EventLoopFuture`` and ``EventLoopPromise`` -to enable communication between Concurrency and NIO. These methods are ``EventLoopFuture/get()`` and ``EventLoopPromise/completeWithTask(_:)``. +The first bridges that NIO introduced added methods on ``EventLoopFuture`` and +``EventLoopPromise`` to enable communication between Concurrency and NIO. These +methods are ``EventLoopFuture/get()`` and +``EventLoopPromise/completeWithTask(_:)``. -> Warning: The future ``EventLoopFuture/get()`` method does not support task cancellation. +> Warning: The future ``EventLoopFuture/get()`` method does not support task +> cancellation. Here is a small example of how these work: @@ -29,99 +33,115 @@ promise.completeWithTask { let result = try await promise.futureResult.get() ``` -> Note: The `completeWithTask` method creates an unstructured task under the hood. +> Note: The `completeWithTask` method creates an unstructured task under the +> hood. ### Channel bridges -The ``EventLoopFuture`` and ``EventLoopPromise`` bridges already allow async code to interact with -some of NIO's types. However, they only work where we have request-response-like interfaces. -On the other hand, NIO's ``Channel`` type contains a ``ChannelPipeline`` which can be roughly -described as a bi-directional streaming pipeline. To bridge such a pipeline into Concurrency required -new types. Importantly, these types need to uphold the channel's back-pressure and writability guarantees. -NIO introduced the ``NIOThrowingAsyncSequenceProducer``, ``NIOAsyncSequenceProducer`` and the ``NIOAsyncWriter`` -which form the foundation to bridge a ``Channel``. -On top of these foundational types, NIO provides the `NIOAsyncChannel` which is used to wrap a -``Channel`` to produce an interface that can be consumed directly from Swift Concurrency. The following -sections cover the details of the foundational types and how the `NIOAsyncChannel` works. +The ``EventLoopFuture`` and ``EventLoopPromise`` bridges already allow async +code to interact with some of NIO's types. However, they only work where we have +request-response-like interfaces. On the other hand, NIO's ``Channel`` type +contains a ``ChannelPipeline`` which can be roughly described as a +bi-directional streaming pipeline. To bridge such a pipeline into Concurrency +required new types. Importantly, these types need to uphold the channel's +back pressure and writability guarantees. NIO introduced the +``NIOThrowingAsyncSequenceProducer``, ``NIOAsyncSequenceProducer`` and the +``NIOAsyncWriter`` which form the foundation to bridge a ``Channel``. On top of +these foundational types, NIO provides the `NIOAsyncChannel` which is used to +wrap a ``Channel`` to produce an interface that can be consumed directly from +Swift Concurrency. The following sections cover the details of the foundational +types and how the `NIOAsyncChannel` works. #### NIOThrowingAsyncSequenceProducer and NIOAsyncSequenceProducer -The ``NIOThrowingAsyncSequenceProducer`` and ``NIOAsyncSequenceProducer`` are asynchronous sequences -similar to Swift's `AsyncStream`. Their purpose is to provide a back-pressured bridge between a -synchronous producer and an asynchronous consumer. These types are highly configurable and generic which -makes them usable in a lot of places with very good performance; however, at the same time they are -not the easiest types to hold. We recommend that you **never** expose them in public API but rather -wrap them in your own async sequence. +The ``NIOThrowingAsyncSequenceProducer`` and ``NIOAsyncSequenceProducer`` are +asynchronous sequences similar to Swift's `AsyncStream`. Their purpose is to +provide a back pressured bridge between a synchronous producer and an +asynchronous consumer. These types are highly configurable and generic which +makes them usable in a lot of places with very good performance; however, at the +same time they are not the easiest types to hold. We recommend that you +**never** expose them in public API but rather wrap them in your own async +sequence. #### NIOAsyncWriter -The ``NIOAsyncWriter`` is used for bridging from an asynchronous producer to a synchronous consumer. -It also has back-pressure support which allows the consumer to stop the producer by suspending the +The ``NIOAsyncWriter`` is used for bridging from an asynchronous producer to a +synchronous consumer. It also has back pressure support which allows the +consumer to stop the producer by suspending the ``NIOAsyncWriter/yield(contentsOf:)`` method. - -> Important: Everything below this is currently not public API but can be tested it by using `@_spi(AsyncChannel) import`. -The APIs might change until they become publicly available. - #### NIOAsyncChannel -The above types are used to bridge both the read and write side of a ``Channel`` into Swift Concurrency. -This can be done by wrapping a ``Channel`` via the `NIOAsyncChannel/init(synchronouslyWrapping:backpressureStrategy:isOutboundHalfClosureEnabled:inboundType:outboundType:)` -initializer. Under the hood, this initializer adds two channel handlers to the end of the channel's pipeline. -These handlers bridge the read and write side of the channel. Additionally, the handlers work together -to close the channel once both the reading and the writing have finished. +The above types are used to bridge both the read and write side of a ``Channel`` +into Swift Concurrency. This can be done by wrapping a ``Channel`` via the +`NIOAsyncChannel/init(synchronouslyWrapping:configuration:)` +initializer. Under the hood, this initializer adds two channel handlers to the +end of the channel's pipeline. These handlers bridge the read and write side of +the channel. Additionally, the handlers work together to close the channel once +both the reading and the writing have finished. - -This is how you can wrap an existing channel into a `NIOAsyncChannel`, consume the inbound data and -echo it back outbound. +This is how you can wrap an existing channel into a `NIOAsyncChannel`, consume +the inbound data and echo it back outbound. ```swift let channel = ... -let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: ByteBuffer.self, outboundType: ByteBuffer.self) +let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) for try await inboundData in asyncChannel.inboundStream { try await asyncChannel.outboundWriter.write(inboundData) } ``` -The above code works nicely; however, you must be very careful at what point you wrap your channel -otherwise you might lose some reads. For example your channel might be created by a `ServerBootstrap` -for a new inbound connection. The channel might start to produce reads as soon as it registered its -IO which happens after your channel initializer ran. To avoid potentially losing reads the channel -must be wrapped before it registered its IO. -Another example is when the channel contains a handler that does protocol negotiation. Protocol negotiation handlers -are usually waiting for some data to be exchanged before deciding what protocol to chose. Afterwards, they -often modify the channel's pipeline and add the protocol appropriate handlers to it. This is another -case where wrapping of the `Channel` into a `NIOAsyncChannel` needs to happen at the right time to avoid -losing reads. +The above code works nicely; however, you must be very careful at what point you +wrap your channel otherwise you might lose some reads. For example your channel +might be created by a `ServerBootstrap` for a new inbound connection. The +channel might start to produce reads as soon as it registered its IO which +happens after your channel initializer ran. To avoid potentially losing reads +the channel must be wrapped before it registered its IO. Another example is when +the channel contains a handler that does protocol negotiation. Protocol +negotiation handlers are usually waiting for some data to be exchanged before +deciding what protocol to chose. Afterwards, they often modify the channel's +pipeline and add the protocol appropriate handlers to it. This is another case +where wrapping of the `Channel` into a `NIOAsyncChannel` needs to happen at the +right time to avoid losing reads. -### Async bootstrap +### Asynchronous bootstrap methods -NIO offers three different kind of bootstraps `ServerBootstrap`, `ClientBootstrap` and `DatagramBootstrap`. -The next section is going to focus on how to use the methods of these three types to bootstrap connections -using `NIOAsyncChannel`. +NIO offers a multitude of bootstraps. To avoid the above problems +and enable a seamless experience when using NIO from Swift Concurrency, +the bootstraps gained new generic asynchronous methods. +The next section is going to focus on how to use the methods to boostrap a TCP +server and client. #### ServerBootstrap -The server bootstrap is used to create a new TCP based server. Once any of the bind methods on the `ServerBootstrap` -is called, a new listening socket is created to handle new inbound TCP connections. Let's take a look -at the new `NIOAsyncChannel` based bind methods. +The server bootstrap is used to create a new TCP based server. Once any of the +bind methods on the `ServerBootstrap` is called, a new listening socket is +created to handle new inbound TCP connections. Let's use the new methods +to setup a TCP server and configure a `NIOAsyncChannel` for each inbound +connection. ```swift let serverChannel = try await ServerBootstrap(group: eventLoopGroup) .bind( host: "127.0.0.1", - port: 0, - childChannelInboundType: ByteBuffer.self, - childChannelOutboundType: ByteBuffer.self - ) + port: 1234 + ) { childChannel in + // This closure is called for every inbound connection + childChannel.eventLoop.makeCompletedFuture { + return try NIOAsyncChannel( + synchronouslyWrapping: childChannel + ) + } + } try await withThrowingDiscardingTaskGroup { group in for try await connectionChannel in serverChannel.inboundStream { group.addTask { do { for try await inboundData in connectionChannel.inboundStream { + // Let's echo back all inbound data try await connectionChannel.outboundWriter.write(inboundData) } } catch { @@ -132,28 +152,38 @@ try await withThrowingDiscardingTaskGroup { group in } ``` -In the above code, we are bootstrapping a new TCP server which we assign to `serverChannel`. -The `serverChannel` is a `NIOAsyncChannel` whose inbound type is a `NIOAsyncChannel` and whose -outbound type is `Never`. This is due to the fact that each inbound connection gets its own separate child channel. -The inbound and outbound types of each inbound connection is `ByteBuffer` as specified in the bootstrap. -Afterwards, we handle each inbound connection in separate child tasks and echo the data back. - -> Important: Make sure to use discarding task groups which automatically reap finished child tasks. -Normal task groups will result in a memory leak since they do not reap their child tasks automatically. +In the above code, we are bootstrapping a new TCP server which we assign to +`serverChannel`. Furthermore, in the trailing closure of `bind` we are +configuring the pipeline of each inbound connection. In our example, we are +wrapping each child channel in a `NIOAsyncChannel`. The resulting +`serverChannel` is a `NIOAsyncChannel` whose inbound type is a `NIOAsyncChannel` +and whose outbound type is `Never`. This is due to the fact that each inbound +connection gets its own separate child channel. The inbound and outbound types +of each inbound connection is `ByteBuffer` as specified in the bootstrap. +Afterwards, we handle each inbound connection in separate child tasks and echo +the data back. + +> Important: Make sure to use discarding task groups which automatically reap +finished child tasks. Normal task groups will result in a memory leak since they +do not reap their child tasks automatically. #### ClientBootstrap -The client bootstrap is used to create a new TCP based client. Let's take a look at the new -`NIOAsyncChannel` based connect methods. +The client bootstrap is used to create a new TCP based client. Let's take a look +how to bootstrap a TCP connection and send some data to the server. ```swift let clientChannel = try await ClientBootstrap(group: eventLoopGroup) .connect( host: "127.0.0.1", - port: 0, - channelInboundType: ByteBuffer.self, - channelOutboundType: ByteBuffer.self - ) + port: 1234 + ) { channel in + channel.eventLoop.makeCompletedFuture { + return try NIOAsyncChannel( + synchronouslyWrapping: channel + ) + } + } clientChannel.outboundWriter.write(ByteBuffer(string: "hello")) @@ -162,126 +192,115 @@ for try await inboundData in clientChannel.inboundStream { } ``` -#### DatagramBootstrap -> Important: Support for `DatagramBootstrap` with `NIOAsyncChannel` hasn't landed yet. - -#### Protocol negotiation - -The above bootstrap methods work great in the case where we know the types of the resulting channels -at compile time. However, as mentioned previously protocol negotiation is another case where the timing -of wrapping the ``Channel`` is important that we haven't covered with the `bind` methods that take -an inbound and outbound type yet. -To solve the problem of protocol negotiation, NIO introduced a new ``ChannelHandler`` protocol called -`NIOProtocolNegotiationHandler`. This protocol requires a single future property `NIOProtocolNegotiationHandler/protocolNegotiationResult` -that is completed once the handler is finished with protocol negotiation. In the successful case, -the future can either indicate that protocol negotiation is fully done by returning `NIOProtocolNegotiationResult/finished(_:)` or -indicate that further protocol negotiation needs to be done by returning `NIOProtocolNegotiationResult/deferredResult(_:)`. -Additionally, the various bootstraps provide another set of `bind()`/`connect()` methods that handle protocol negotiation. -Let's walk through how to setup a `ServerBootstrap` with protocol negotiation. - -First, we have to define our negotiation result. For this example, we are negotiating between a -`String` based and `UInt8` based channel. Additionally, we also need an error that we can throw -if protocol negotiation failed. +#### Dynamic pipeline modifications + +The above bootstrap methods work great in the case where we know the types of +the resulting channels at compile time. However, there are some scenarios where +the type is only determined at runtime. Such cases include +[Application-Layer-Protocol-Negotiation](https://en.wikipedia.org/wiki/Application-Layer_Protocol_Negotiation) +or [HTTP protocol +upgrades](https://en.wikipedia.org/wiki/HTTP/1.1_Upgrade_header). To support +those scenarios it is essential that channel handlers that dynamically configure +the pipeline carry type information which allows us to runtime to determine how +the pipeline was configured at runtime. To support this NIO introduced multiple +new `ChannelHandler` and corresponding pipeline configuration methods. Those +types are: + +1. `NIOTypedApplicationProtocolNegotiationHandler` for TLS based ALPN +2. `NIOTypedHTTPServerUpgradeHandler` and + `configureUpgradableHTTPServerPipeline` for server-side HTTP protocol + upgrades +2. `NIOTypedHTTPClientUpgradeHandler` and + `configureUpgradableHTTPClientPipeline` for client-side HTTP protocol + upgrades + +All of those types have one thing in common - they are generic over the result +of the dynamic pipeline configuration. This allows users to exhaustively switch +over the result and correctly handle each case. The following example +demonstrates how this works for a client-side websocket upgrade. + ```swift -enum NegotiationResult { - case string(NIOAsyncChannel) - case byte(NIOAsyncChannel) +enum UpgradeResult { + case websocket(NIOAsyncChannel) + case notUpgraded } -struct ProtocolNegotiationError: Error {} -``` - -Next, we have to setup our bootstrap. We are adding a `NIOTypedApplicationProtocolNegotiationHandler` -to each child channel's pipeline. This handler listens for user inbound events of the type `TLSUserEvent` -and then calls the provided closure with the result. In our example, we are handling either `string` -or `byte` application protocols. Importantly, we now have to wrap the channel into a `NIOAsyncChannel` ourselves and -return the finished `NIOProtocolNegotiationResult`. -```swift -let serverBoostrap = try await ServerBootstrap(group: eventLoopGroup) - .childChannelInitializer { channel in +let upgradeResult: EventLoopFuture = try await ClientBootstrap(group: eventLoopGroup) + .connect( + host: "127.0.0.1", + port: 1234 + ) { channel in channel.eventLoop.makeCompletedFuture { - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: channel.eventLoop) { alpnResult, channel in - switch alpnResult { - case .negotiated(let alpn): - switch alpn { - case "string": - return channel.eventLoop.makeCompletedFuture { - let asyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel, - isOutboundHalfClosureEnabled: true, - inboundType: String.self, - outboundType: String.self - ) - - return NIOProtocolNegotiationResult.finished(NegotiationResult.string(asyncChannel)) - } - case "byte": - return channel.eventLoop.makeCompletedFuture { - let asyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel, - isOutboundHalfClosureEnabled: true, - inboundType: UInt8.self, - outboundType: UInt8.self - ) - - return NIOProtocolNegotiationResult.finished(NegotiationResult.byte(asyncChannel)) - } - default: - return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError()) + // Configure the websocket upgrader + let upgrader = NIOTypedWebSocketClientUpgrader( + upgradePipelineHandler: { channel, _ in + // This configures the pipeline after the websocket upgrade was successful. + // We are wrapping the pipeline in a NIOAsyncChannel. + channel.eventLoop.makeCompletedFuture { + let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) + return UpgradeResult.websocket(asyncChannel) } - case .fallback: - return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError()) } - } + ) + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "0") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: [upgrader], + notUpgradingCompletionHandler: { channel in + channel.eventLoop.makeCompletedFuture { + return UpgradeResult.notUpgraded + } + } + ) - try channel.pipeline.syncOperations.addHandler(negotiationHandler) + let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) + ) + + return upgradeResult } } ``` -Lastly, we can now bind the `serverChannel` and handle the incoming connections. In the code below, -you can see that our server channel is now a `NIOAsyncChannel` of `NegotiationResult`s instead of -child channels. -```swift -let serverChannel = serverBootstrap.bind( - host: "127.0.0.1", - port: 1995, - protocolNegotiationHandlerType: NIOTypedApplicationProtocolNegotiationHandler.self -) +After having configured the pipeline to negotiate a websocket upgrade. We can +switch over the the `upgradeResult`. Importantly, we have to `await` the +`upgradeResult` first since it has to be negotiated on the connection. -try await withThrowingDiscardingTaskGroup { group in - for try await negotiationResult in serverChannel.inboundStream { - group.addTask { - do { - switch negotiationResult { - case .string(let channel): - for try await inboundData in channel.inboundStream { - try await channel.outboundWriter.write(inboundData) - } - case .byte(let channel): - for try await value in channel.inboundStream { - try await channel.outboundWriter.write(inboundData) - } - } - } catch { - // Handle errors - } - } - } +``` +switch try await upgradeResult.get() { +case .websocket(let websocketChannel): + print("Handling websocket connection") + try await self.handleWebsocketChannel(websocketChannel) + print("Done handling websocket connection") +case .notUpgraded: + // The upgrade to websocket did not succeed. + print("Upgrade declined") } ``` - ### General guidance #### Where should your code live? -Before the introduction of Swift Concurrency both implementations of network protocols and business logic -were often written inside ``ChannelHandler``s. This made it easier to get started; however, it came with -some downsides. First, implementing business logic inside channel handlers requires the business logic to -also handle all of the invariants that the ``ChannelHandler`` protocol brings with it. This often requires -writing complex state machines. Additionally, the business logic becomes very tied to NIO and hard to -port between different systems. -Because of the above reasons we recommend to implement your business logic using Swift Concurrency primitives and the -`NIOAsyncChannel` based bootstraps. Network protocol implementation should still be implemented as +Before the introduction of Swift Concurrency both implementations of network +protocols and business logic were often written inside ``ChannelHandler``s. This +made it easier to get started; however, it came with some downsides. First, +implementing business logic inside channel handlers requires the business logic +to also handle all of the invariants that the ``ChannelHandler`` protocol brings +with it. This often requires writing complex state machines. Additionally, the +business logic becomes very tied to NIO and hard to port between different +systems. Because of the above reasons we recommend to implement your business +logic using Swift Concurrency primitives and the `NIOAsyncChannel` based +bootstraps. Network protocol implementation should still be implemented as ``ChannelHandler``s. From ded781e2dd57082b39ae5bc8d4a2e2504d9eda21 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Thu, 5 Oct 2023 15:03:00 +0100 Subject: [PATCH 09/64] Fix test availability for tests (#2533) # Motivation Building on Darwin is currently broken since we are missing availability annotations on some of the tests. --- Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift | 1 + Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift | 5 +++++ Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift | 1 + 3 files changed, 7 insertions(+) diff --git a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift index 42de8b87d5..56c7bd3cfd 100644 --- a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift @@ -278,6 +278,7 @@ private final class RecordingHTTPHandler: ChannelInboundHandler, RemovableChanne } } +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) class HTTPClientUpgradeTestCase: XCTestCase { // MARK: Test basic happy path requests and responses. diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index 9be0cc4c22..3b4ab37c9b 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -34,6 +34,7 @@ extension ChannelPipeline { } } + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) fileprivate func assertContainsUpgrader() throws { do { _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() @@ -88,6 +89,7 @@ extension EmbeddedChannel { private typealias UpgradeCompletionHandler = @Sendable (ChannelHandlerContext) -> Void +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) private func serverHTTPChannelWithAutoremoval(group: EventLoopGroup, pipelining: Bool, upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], @@ -164,6 +166,7 @@ internal func assertResponseIs(response: String, expectedResponseLine: String, e XCTAssertEqual(lines.count, 0) } +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader, NIOTypedHTTPServerProtocolUpgrader where UpgradeResult == Bool {} private class ExplodingUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { @@ -407,6 +410,7 @@ private class ReentrantReadOnChannelReadCompleteHandler: ChannelInboundHandler { } } +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) class HTTPServerUpgradeTestCase: XCTestCase { fileprivate func setUpTestWithAutoremoval(pipelining: Bool = false, upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], @@ -1554,6 +1558,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { } } +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { fileprivate override func setUpTestWithAutoremoval( pipelining: Bool = false, diff --git a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift index c636de6fc7..0058aa6fc7 100644 --- a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift @@ -527,6 +527,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } } +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) final class TypedWebSocketServerEndToEndTests: WebSocketServerEndToEndTests { override func createTestFixtures( upgraders: [WebSocketServerUpgraderConfiguration] From 4ab8b98f9d781bb5cedbcb41d3248d80822dc303 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Fri, 6 Oct 2023 16:33:25 +0100 Subject: [PATCH 10/64] Introduce new typed `HTTPClientUpgrader` and `WebSocketClientUpgrader` (#2526) * Introduce new typed `HTTPClientUpgrader` and `WebSocketClientUpgrader` # Motivation In my previous PR https://github.com/apple/swift-nio/pull/2517, I added a new typed `HTTPServerUpgrader` and corresponding implementation for the `WebSocketServerUpgrader`. The goal of those is to carry type information across HTTP upgrades which allows us to build fully typed pipelines. # Modification This PR adds a few things: 1. A new `NIOTypedHttpClientUpgradeHandler` + `NIOTypedHttpClientProtocolUpgrader`. I also moved the state handling to a separate state machine. Similar to the server PR I did not unify the state machine between the newly typed and untyped upgrade handlers since they differ in logic. 2. A new `NIOTypedWebSocketClientUpgrader` 3. An overhauled WebSocket client example. # Result This is the last missing piece of dynamic pipeline changing where we did not carry around the type information. After this PR lands, we can finalize the `AsyncChannel` and async typed NIO pieces. * Remove availability on the protocols --- Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift | 122 +++++- .../NIOTypedHTTPClientUpgradeHandler.swift | 286 ++++++++++++++ ...OTypedHTTPClientUpgraderStateMachine.swift | 334 +++++++++++++++++ .../NIOTypedHTTPServerUpgradeHandler.swift | 3 +- .../NIOWebSocketClientUpgrader.swift | 146 ++++++-- Sources/NIOWebSocketClient/Client.swift | 144 ++++++++ Sources/NIOWebSocketClient/main.swift | 233 ------------ .../HTTPClientUpgradeTests.swift | 348 ++++++++++++++++-- .../WebSocketClientEndToEndTests.swift | 229 +++++++++++- 9 files changed, 1520 insertions(+), 325 deletions(-) create mode 100644 Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift create mode 100644 Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift create mode 100644 Sources/NIOWebSocketClient/Client.swift delete mode 100644 Sources/NIOWebSocketClient/main.swift diff --git a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift index c92e41715c..e69d034ba1 100644 --- a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift +++ b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift @@ -32,7 +32,7 @@ public struct NIOUpgradableHTTPServerPipelineConfiguration @@ -43,10 +43,6 @@ public struct NIOUpgradableHTTPServerPipelineConfiguration ) { @@ -109,11 +105,12 @@ extension ChannelPipeline.SynchronousOperations { ) throws -> EventLoopFuture { self.eventLoop.assertInEventLoop() - let responseEncoder = HTTPResponseEncoder(configuration: configuration.httpResponseEncoderConfiguration) + let responseEncoder = HTTPResponseEncoder(configuration: configuration.encoderConfiguration) let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) - var extraHTTPHandlers: [RemovableChannelHandler] = [requestDecoder] - extraHTTPHandlers.reserveCapacity(3) + var extraHTTPHandlers = [RemovableChannelHandler]() + extraHTTPHandlers.reserveCapacity(4) + extraHTTPHandlers.append(requestDecoder) try self.addHandler(responseEncoder) try self.addHandler(requestDecoder) @@ -146,3 +143,112 @@ extension ChannelPipeline.SynchronousOperations { return upgrader.upgradeResultFuture } } + +// MARK: - Client pipeline configuration + +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +@_spi(AsyncChannel) +public struct NIOUpgradableHTTPClientPipelineConfiguration { + /// The strategy to use when dealing with leftover bytes after removing the ``HTTPDecoder`` from the pipeline. + public var leftOverBytesStrategy = RemoveAfterUpgradeStrategy.dropBytes + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableOutboundHeaderValidation = true + + /// The configuration for the ``HTTPRequestEncoder``. + public var encoderConfiguration = HTTPRequestEncoder.Configuration() + + /// The configuration for the ``NIOTypedHTTPClientUpgradeHandler``. + public var upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPClientPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Outbound header fields validation to protect against response splitting attacks. + public init( + upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration + ) { + self.upgradeConfiguration = upgradeConfiguration + } +} + +extension ChannelPipeline { + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + @_spi(AsyncChannel) + public func configureUpgradableHTTPClientPipeline( + configuration: NIOUpgradableHTTPClientPipelineConfiguration + ) -> EventLoopFuture> { + self._configureUpgradableHTTPClientPipeline(configuration: configuration) + } + + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + private func _configureUpgradableHTTPClientPipeline( + configuration: NIOUpgradableHTTPClientPipelineConfiguration + ) -> EventLoopFuture> { + let future: EventLoopFuture> + + if self.eventLoop.inEventLoop { + let result = Result, Error> { + try self.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: configuration + ) + } + future = self.eventLoop.makeCompletedFuture(result) + } else { + future = self.eventLoop.submit { + try self.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: configuration + ) + } + } + + return future + } +} + +extension ChannelPipeline.SynchronousOperations { + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + @_spi(AsyncChannel) + public func configureUpgradableHTTPClientPipeline( + configuration: NIOUpgradableHTTPClientPipelineConfiguration + ) throws -> EventLoopFuture { + self.eventLoop.assertInEventLoop() + + let requestEncoder = HTTPRequestEncoder(configuration: configuration.encoderConfiguration) + let responseDecoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: configuration.leftOverBytesStrategy)) + var httpHandlers = [RemovableChannelHandler]() + httpHandlers.reserveCapacity(3) + httpHandlers.append(requestEncoder) + httpHandlers.append(responseDecoder) + + try self.addHandler(requestEncoder) + try self.addHandler(responseDecoder) + + if configuration.enableOutboundHeaderValidation { + let headerValidationHandler = NIOHTTPRequestHeadersValidator() + try self.addHandler(headerValidationHandler) + httpHandlers.append(headerValidationHandler) + } + + let upgrader = NIOTypedHTTPClientUpgradeHandler( + httpHandlers: httpHandlers, + upgradeConfiguration: configuration.upgradeConfiguration + ) + try self.addHandler(upgrader) + + return upgrader.upgradeResultFuture + } +} diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift new file mode 100644 index 0000000000..82c6f129f8 --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift @@ -0,0 +1,286 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2013 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +import NIOCore + +/// An object that implements `NIOTypedHTTPClientProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a client-side channel. +/// It has the option of denying this upgrade based upon the server response. +@_spi(AsyncChannel) +public protocol NIOTypedHTTPClientProtocolUpgrader { + associatedtype UpgradeResult: Sendable + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol requires in the request to successfully upgrade. + /// These header fields will be added to the outbound request's "Connection" header field. + /// It is the responsibility of the custom headers call to actually add these required headers. + var requiredUpgradeHeaders: [String] { get } + + /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. + func addCustom(upgradeRequestHeaders: inout HTTPHeaders) + + /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. + func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel + /// pipeline to add whatever channel handlers are required. + /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. + func upgrade(channel: Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPClientUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +@_spi(AsyncChannel) +public struct NIOTypedHTTPClientUpgradeConfiguration { + /// The initial request head that is sent out once the channel becomes active. + public var upgradeRequestHead: HTTPRequestHead + + /// The array of potential upgraders. + public var upgraders: [any NIOTypedHTTPClientProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + + public init( + upgradeRequestHead: HTTPRequestHead, + upgraders: [any NIOTypedHTTPClientProtocolUpgrader], + notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture + ) { + precondition(upgraders.count > 0, "A minimum of one protocol upgrader must be specified.") + self.upgradeRequestHead = upgradeRequestHead + self.upgraders = upgraders + self.notUpgradingCompletionHandler = notUpgradingCompletionHandler + } +} + +/// A client-side channel handler that sends a HTTP upgrade handshake request to perform a HTTP-upgrade. +/// This handler will add all appropriate headers to perform an upgrade to +/// the a protocol. It may add headers for a set of protocols in preference order. +/// If the upgrade fails (i.e. response is not 101 Switching Protocols), this handler simply +/// removes itself from the pipeline. If the upgrade is successful, it upgrades the pipeline to the new protocol. +/// +/// The request sends an order of preference to request which protocol it would like to use for the upgrade. +/// It will only upgrade to the protocol that is returned first in the list and does not currently +/// have the capability to upgrade to multiple simultaneous layered protocols. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +@_spi(AsyncChannel) +public final class NIOTypedHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableChannelHandler { + public typealias OutboundIn = HTTPClientRequestPart + public typealias OutboundOut = HTTPClientRequestPart + public typealias InboundIn = HTTPClientResponsePart + public typealias InboundOut = HTTPClientResponsePart + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: EventLoopFuture { + self.upgradeResultPromise.futureResult + } + + private let upgradeRequestHead: HTTPRequestHead + private let httpHandlers: [RemovableChannelHandler] + private let notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + private var stateMachine: NIOTypedHTTPClientUpgraderStateMachine + private var _upgradeResultPromise: EventLoopPromise? + private var upgradeResultPromise: EventLoopPromise { + precondition( + self._upgradeResultPromise != nil, + "Tried to access the upgrade result before the handler was added to a pipeline" + ) + return self._upgradeResultPromise! + } + + /// Create a ``NIOTypedHTTPClientUpgradeHandler``. + /// + /// - Parameters: + /// - httpHandlers: All `RemovableChannelHandler` objects which will be removed from the pipeline + /// once the upgrade response is sent. This is used to ensure that the pipeline will be in a clean state + /// after the upgrade. It should include any handlers that are directly related to handling HTTP. + /// At the very least this should include the `HTTPEncoder` and `HTTPDecoder`, but should also include + /// any other handler that cannot tolerate receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init( + httpHandlers: [RemovableChannelHandler], + upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration + ) { + self.httpHandlers = httpHandlers + var upgradeRequestHead = upgradeConfiguration.upgradeRequestHead + Self.addHeaders( + to: &upgradeRequestHead, + upgraders: upgradeConfiguration.upgraders + ) + self.upgradeRequestHead = upgradeRequestHead + self.stateMachine = .init(upgraders: upgradeConfiguration.upgraders) + self.notUpgradingCompletionHandler = upgradeConfiguration.notUpgradingCompletionHandler + } + + public func handlerAdded(context: ChannelHandlerContext) { + self._upgradeResultPromise = context.eventLoop.makePromise(of: UpgradeResult.self) + } + + public func handlerRemoved(context: ChannelHandlerContext) { + switch self.stateMachine.handlerRemoved() { + case .failUpgradePromise: + self.upgradeResultPromise.fail(ChannelError.inappropriateOperationForState) + case .none: + break + } + } + + public func channelActive(context: ChannelHandlerContext) { + switch self.stateMachine.channelActive() { + case .writeUpgradeRequest: + context.write(self.wrapOutboundOut(.head(self.upgradeRequestHead)), promise: nil) + context.write(self.wrapOutboundOut(.body(.byteBuffer(.init()))), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + + case .none: + break + } + } + + private static func addHeaders( + to requestHead: inout HTTPRequestHead, + upgraders: [any NIOTypedHTTPClientProtocolUpgrader] + ) { + let requiredHeaders = ["upgrade"] + upgraders.flatMap { $0.requiredUpgradeHeaders } + requestHead.headers.add(name: "Connection", value: requiredHeaders.joined(separator: ",")) + + let allProtocols = upgraders.map { $0.supportedProtocol.lowercased() } + requestHead.headers.add(name: "Upgrade", value: allProtocols.joined(separator: ",")) + + // Allow each upgrader the chance to add custom headers. + for upgrader in upgraders { + upgrader.addCustom(upgradeRequestHeaders: &requestHead.headers) + } + } + + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + switch self.stateMachine.write() { + case .failWrite(let error): + promise?.fail(error) + + case .forwardWrite: + context.write(data, promise: promise) + } + } + + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.stateMachine.channelReadData(data) { + case .unwrapData: + let responsePart = self.unwrapInboundIn(data) + self.channelRead(context: context, responsePart: responsePart) + + case .fireChannelRead: + context.fireChannelRead(data) + + case .none: + break + } + } + + public func channelRead(context: ChannelHandlerContext, responsePart: HTTPClientResponsePart) { + switch self.stateMachine.channelReadResponsePart(responsePart) { + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .runNotUpgradingInitializer: + self.notUpgradingCompletionHandler(context.channel) + .hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result) + } + + case .startUpgrading(let upgrader, let responseHead): + self.startUpgrading( + context: context, + upgrader: upgrader, + responseHead: responseHead + ) + + case .none: + break + } + } + + private func startUpgrading( + context: ChannelHandlerContext, + upgrader: any NIOTypedHTTPClientProtocolUpgrader, + responseHead: HTTPResponseHead + ) { + // Before we start the upgrade we have to remove the HTTPEncoder and HTTPDecoder handlers from the + // pipeline, to prevent them parsing any more data. We'll buffer the incoming data until that completes. + self.removeHTTPHandlers(context: context) + .flatMap { + upgrader.upgrade(channel: context.channel, upgradeResponse: responseHead) + }.hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result) + } + } + + private func upgradingHandlerCompleted( + context: ChannelHandlerContext, + _ result: Result + ) { + switch self.stateMachine.upgradingHandlerCompleted(result) { + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .fireErrorCaughtAndStartUnbuffering(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + self.unbuffer(context: context) + + case .startUnbuffering(let value): + self.upgradeResultPromise.succeed(value) + self.unbuffer(context: context) + + case .removeHandler(let value): + self.upgradeResultPromise.succeed(value) + context.pipeline.removeHandler(self, promise: nil) + + case .none: + break + } + } + + private func unbuffer(context: ChannelHandlerContext) { + while true { + switch self.stateMachine.unbuffer() { + case .fireChannelRead(let data): + context.fireChannelRead(data) + + case .fireChannelReadCompleteAndRemoveHandler: + context.fireChannelReadComplete() + context.pipeline.removeHandler(self, promise: nil) + return + } + } + } + + /// Removes any extra HTTP-related handlers from the channel pipeline. + private func removeHTTPHandlers(context: ChannelHandlerContext) -> EventLoopFuture { + guard self.httpHandlers.count > 0 else { + return context.eventLoop.makeSucceededFuture(()) + } + + let removeFutures = self.httpHandlers.map { context.pipeline.removeHandler($0) } + return .andAllSucceed(removeFutures, on: context.eventLoop) + } +} diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift new file mode 100644 index 0000000000..fa04481ea9 --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift @@ -0,0 +1,334 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import DequeModule +import NIOCore + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +struct NIOTypedHTTPClientUpgraderStateMachine { + @usableFromInline + enum State { + /// The state before we received a TLSUserEvent. We are just forwarding any read at this point. + case initial(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) + + /// The request has been sent. We are waiting for the upgrade response. + case awaitingUpgradeResponseHead(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) + + @usableFromInline + struct AwaitingUpgradeResponseEnd { + var upgrader: any NIOTypedHTTPClientProtocolUpgrader + var responseHead: HTTPResponseHead + } + /// We received the response head and are just waiting for the response end. + case awaitingUpgradeResponseEnd(AwaitingUpgradeResponseEnd) + + @usableFromInline + struct Upgrading { + var buffer: Deque + } + /// We are either running the upgrading handler. + case upgrading(Upgrading) + + @usableFromInline + struct Unbuffering { + var buffer: Deque + } + case unbuffering(Unbuffering) + + case finished + + case modifying + } + + private var state: State + + init(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) { + self.state = .initial(upgraders: upgraders) + } + + @usableFromInline + enum HandlerRemovedAction { + case failUpgradePromise + } + + @inlinable + mutating func handlerRemoved() -> HandlerRemovedAction? { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading, .unbuffering: + self.state = .finished + return .failUpgradePromise + + case .finished: + return .none + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelActiveAction { + case writeUpgradeRequest + } + + @inlinable + mutating func channelActive() -> ChannelActiveAction? { + switch self.state { + case .initial(let upgraders): + self.state = .awaitingUpgradeResponseHead(upgraders: upgraders) + return .writeUpgradeRequest + + case .finished: + return nil + + case .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .unbuffering, .upgrading: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum WriteAction { + case failWrite(Error) + case forwardWrite + } + + @usableFromInline + func write() -> WriteAction { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading: + return .failWrite(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade) + + case .unbuffering, .finished: + return .forwardWrite + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelReadDataAction { + case unwrapData + case fireChannelRead + } + + @inlinable + mutating func channelReadData(_ data: NIOAny) -> ChannelReadDataAction? { + switch self.state { + case .initial: + return .unwrapData + + case .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd: + return .unwrapData + + case .upgrading(var upgrading): + // We got a read while running upgrading. + // We have to buffer the read to unbuffer it afterwards + self.state = .modifying + upgrading.buffer.append(data) + self.state = .upgrading(upgrading) + return nil + + case .unbuffering(var unbuffering): + self.state = .modifying + unbuffering.buffer.append(data) + self.state = .unbuffering(unbuffering) + return nil + + case .finished: + return .fireChannelRead + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + + @usableFromInline + enum ChannelReadResponsePartAction { + case fireErrorCaughtAndRemoveHandler(Error) + case runNotUpgradingInitializer + case startUpgrading( + upgrader: any NIOTypedHTTPClientProtocolUpgrader, + responseHeaders: HTTPResponseHead + ) + } + + @inlinable + mutating func channelReadResponsePart(_ responsePart: HTTPClientResponsePart) -> ChannelReadResponsePartAction? { + switch self.state { + case .initial: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + case .awaitingUpgradeResponseHead(let upgraders): + // We should decide if we can upgrade based on the first response header: if we aren't upgrading, + // by the time the body comes in we should be out of the pipeline. That means that if we don't think we're + // upgrading, the only thing we should see is a response head. Anything else in an error. + guard case .head(let response) = responsePart else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.invalidHTTPOrdering) + } + + // Assess whether the server has accepted our upgrade request. + guard case .switchingProtocols = response.status else { + var buffer = Deque() + buffer.append(.init(responsePart)) + self.state = .upgrading(.init(buffer: buffer)) + return .runNotUpgradingInitializer + } + + // Ok, we have a HTTP response. Check if it's an upgrade confirmation. + // If it's not, we want to pass it on and remove ourselves from the channel pipeline. + let acceptedProtocols = response.headers[canonicalForm: "upgrade"] + + // At the moment we only upgrade to the first protocol returned from the server. + guard let protocolName = acceptedProtocols.first?.lowercased() else { + // There are no upgrade protocols returned. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + + let matchingUpgrader = upgraders + .first(where: { $0.supportedProtocol.lowercased() == protocolName }) + + guard let upgrader = matchingUpgrader else { + // There is no upgrader for this protocol. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + + guard upgrader.shouldAllowUpgrade(upgradeResponse: response) else { + // The upgrader says no. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + + // We received the response head and decided that we can upgrade. + // We now need to wait for the response end and then we can perform the upgrade + self.state = .awaitingUpgradeResponseEnd(.init( + upgrader: upgrader, + responseHead: response + )) + return .none + + case .awaitingUpgradeResponseEnd(let awaitingUpgradeResponseEnd): + switch responsePart { + case .head: + // We got two HTTP response heads. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.invalidHTTPOrdering) + + case .body: + // We tolerate body parts to be send but just ignore them + return .none + + case .end: + // We got the response end and can now run the upgrader. + self.state = .upgrading(.init(buffer: .init())) + return .startUpgrading( + upgrader: awaitingUpgradeResponseEnd.upgrader, + responseHeaders: awaitingUpgradeResponseEnd.responseHead + ) + } + + case .upgrading, .unbuffering, .finished: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum UpgradingHandlerCompletedAction { + case fireErrorCaughtAndStartUnbuffering(Error) + case removeHandler(UpgradeResult) + case fireErrorCaughtAndRemoveHandler(Error) + case startUnbuffering(UpgradeResult) + } + + @inlinable + mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .unbuffering: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + case .upgrading(let upgrading): + switch result { + case .success(let value): + if !upgrading.buffer.isEmpty { + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .startUnbuffering(value) + } else { + self.state = .finished + return .removeHandler(value) + } + + case .failure(let error): + if !upgrading.buffer.isEmpty { + // So we failed to upgrade. There is nothing really that we can do here. + // We are unbuffering the reads but there shouldn't be any handler in the pipeline + // that expects a specific type of reads anyhow. + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .fireErrorCaughtAndStartUnbuffering(error) + } else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(error) + } + } + + case .finished: + // We have to tolerate this + return nil + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum UnbufferAction { + case fireChannelRead(NIOAny) + case fireChannelReadCompleteAndRemoveHandler + } + + @inlinable + mutating func unbuffer() -> UnbufferAction { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading, .finished: + preconditionFailure("Invalid state \(self.state)") + + case .unbuffering(var unbuffering): + self.state = .modifying + + if let element = unbuffering.buffer.popFirst() { + self.state = .unbuffering(unbuffering) + + return .fireChannelRead(element) + } else { + self.state = .finished + + return .fireChannelReadCompleteAndRemoveHandler + } + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + } + } +} diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift index a665cd63a2..9e43f2d7d3 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift @@ -15,10 +15,9 @@ /// An object that implements `NIOTypedHTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to /// a protocol on a server-side channel. -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) @_spi(AsyncChannel) public protocol NIOTypedHTTPServerProtocolUpgrader { - associatedtype UpgradeResult + associatedtype UpgradeResult: Sendable /// The protocol this upgrader knows how to support. var supportedProtocol: String { get } diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index 5e7df19c4f..b92c5f121c 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// import NIOCore -import NIOHTTP1 +@_spi(AsyncChannel) import NIOHTTP1 import _NIOBase64 @available(*, deprecated, renamed: "NIOWebSocketClientUpgrader") @@ -25,7 +25,6 @@ public typealias NIOWebClientSocketUpgrader = NIOWebSocketClientUpgrader /// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the /// pipeline to remove the HTTP `ChannelHandler`s. public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { - /// RFC 6455 specs this as the required entry in the Upgrade header. public let supportedProtocol: String = "websocket" /// None of the websocket headers are actually defined as 'required'. @@ -34,7 +33,7 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { private let requestKey: String private let maxFrameSize: Int private let automaticErrorHandling: Bool - private let upgradePipelineHandler: (Channel, HTTPResponseHead) -> EventLoopFuture + private let upgradePipelineHandler: @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture /// - Parameters: /// - requestKey: sent to the server in the `Sec-WebSocket-Key` HTTP header. Default is random request key. @@ -45,7 +44,7 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { requestKey: String = randomRequestKey(), maxFrameSize: Int = 1 << 14, automaticErrorHandling: Bool = true, - upgradePipelineHandler: @escaping (Channel, HTTPResponseHead) -> EventLoopFuture + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture ) { precondition(requestKey != "", "The request key must contain a valid Sec-WebSocket-Key") precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") @@ -57,49 +56,81 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { /// Add additional headers that are needed for a WebSocket upgrade request. public func addCustom(upgradeRequestHeaders: inout HTTPHeaders) { - upgradeRequestHeaders.add(name: "Sec-WebSocket-Key", value: self.requestKey) - upgradeRequestHeaders.add(name: "Sec-WebSocket-Version", value: "13") + _addCustom(upgradeRequestHeaders: &upgradeRequestHeaders, requestKey: self.requestKey) } - /// Allow or deny the upgrade based on the upgrade HTTP response - /// headers containing the correct accept key. public func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { - - let acceptValueHeader = upgradeResponse.headers["Sec-WebSocket-Accept"] + _shouldAllowUpgrade(upgradeResponse: upgradeResponse, requestKey: self.requestKey) + } - guard acceptValueHeader.count == 1 else { - return false - } + public func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + _upgrade( + channel: context.channel, + upgradeResponse: upgradeResponse, + maxFrameSize: self.maxFrameSize, + enableAutomaticErrorHandling: self.automaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) + } +} - // Validate the response key in 'Sec-WebSocket-Accept'. - var hasher = SHA1() - hasher.update(string: self.requestKey) - hasher.update(string: magicWebSocketGUID) - let expectedAcceptValue = String(base64Encoding: hasher.finish()) +/// A `NIOTypedHTTPClientProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request. +/// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the +/// pipeline to remove the HTTP `ChannelHandler`s. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +@_spi(AsyncChannel) +public final class NIOTypedWebSocketClientUpgrader: NIOTypedHTTPClientProtocolUpgrader { + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String = "websocket" + /// None of the websocket headers are actually defined as 'required'. + public let requiredUpgradeHeaders: [String] = [] + + private let requestKey: String + private let maxFrameSize: Int + private let enableAutomaticErrorHandling: Bool + private let upgradePipelineHandler: @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture - return expectedAcceptValue == acceptValueHeader[0] + /// - Parameters: + /// - requestKey: Sent to the server in the `Sec-WebSocket-Key` HTTP header. Default is random request key. + /// - maxFrameSize: Largest incoming `WebSocketFrame` size in bytes. Default is 16,384 bytes. + /// - enableAutomaticErrorHandling: If true, adds `WebSocketProtocolErrorHandler` to the channel pipeline to catch and respond to WebSocket protocol errors. Default is true. + /// - upgradePipelineHandler: Called once the upgrade was successful. + public init( + requestKey: String = NIOWebSocketClientUpgrader.randomRequestKey(), + maxFrameSize: Int = 1 << 14, + enableAutomaticErrorHandling: Bool = true, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture + ) { + precondition(requestKey != "", "The request key must contain a valid Sec-WebSocket-Key") + precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") + self.requestKey = requestKey + self.upgradePipelineHandler = upgradePipelineHandler + self.maxFrameSize = maxFrameSize + self.enableAutomaticErrorHandling = enableAutomaticErrorHandling } - /// Called when the upgrade response has been flushed and it is safe to mutate the channel - /// pipeline. Adds channel handlers for websocket frame encoding, decoding and errors. - public func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + public func addCustom(upgradeRequestHeaders: inout NIOHTTP1.HTTPHeaders) { + _addCustom(upgradeRequestHeaders: &upgradeRequestHeaders, requestKey: self.requestKey) + } - var upgradeFuture = context.pipeline.addHandler(WebSocketFrameEncoder()).flatMap { - context.pipeline.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: self.maxFrameSize))) - } - - if self.automaticErrorHandling { - upgradeFuture = upgradeFuture.flatMap { - context.pipeline.addHandler(WebSocketProtocolErrorHandler()) - } - } - - return upgradeFuture.flatMap { - self.upgradePipelineHandler(context.channel, upgradeResponse) - } + public func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { + _shouldAllowUpgrade(upgradeResponse: upgradeResponse, requestKey: self.requestKey) + } + + public func upgrade(channel: Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + _upgrade( + channel: channel, + upgradeResponse: upgradeResponse, + maxFrameSize: self.maxFrameSize, + enableAutomaticErrorHandling: self.enableAutomaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) } } + @available(*, unavailable) extension NIOWebSocketClientUpgrader: Sendable {} @@ -128,3 +159,48 @@ extension NIOWebSocketClientUpgrader { return NIOWebSocketClientUpgrader.randomRequestKey(using: &generator) } } + +/// Add additional headers that are needed for a WebSocket upgrade request. +private func _addCustom(upgradeRequestHeaders: inout HTTPHeaders, requestKey: String) { + upgradeRequestHeaders.add(name: "Sec-WebSocket-Key", value: requestKey) + upgradeRequestHeaders.add(name: "Sec-WebSocket-Version", value: "13") +} + +/// Allow or deny the upgrade based on the upgrade HTTP response +/// headers containing the correct accept key. +private func _shouldAllowUpgrade(upgradeResponse: HTTPResponseHead, requestKey: String) -> Bool { + let acceptValueHeader = upgradeResponse.headers["Sec-WebSocket-Accept"] + + guard acceptValueHeader.count == 1 else { + return false + } + + // Validate the response key in 'Sec-WebSocket-Accept'. + var hasher = SHA1() + hasher.update(string: requestKey) + hasher.update(string: magicWebSocketGUID) + let expectedAcceptValue = String(base64Encoding: hasher.finish()) + + return expectedAcceptValue == acceptValueHeader[0] +} + +/// Called when the upgrade response has been flushed and it is safe to mutate the channel +/// pipeline. Adds channel handlers for websocket frame encoding, decoding and errors. +private func _upgrade( + channel: Channel, + upgradeResponse: HTTPResponseHead, + maxFrameSize: Int, + enableAutomaticErrorHandling: Bool, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture +) -> EventLoopFuture { + return channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder()) + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize))) + if enableAutomaticErrorHandling { + try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) + } + } + .flatMap { + upgradePipelineHandler(channel, upgradeResponse) + } +} diff --git a/Sources/NIOWebSocketClient/Client.swift b/Sources/NIOWebSocketClient/Client.swift new file mode 100644 index 0000000000..fef5b1e15e --- /dev/null +++ b/Sources/NIOWebSocketClient/Client.swift @@ -0,0 +1,144 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if swift(>=5.9) +@_spi(AsyncChannel) import NIOCore +@_spi(AsyncChannel) import NIOPosix +@_spi(AsyncChannel) import NIOHTTP1 +@_spi(AsyncChannel) import NIOWebSocket + +@available(macOS 14, *) +@main +struct Client { + /// The host to connect to. + private let host: String + /// The port to connect to. + private let port: Int + /// The client's event loop group. + private let eventLoopGroup: MultiThreadedEventLoopGroup + + enum UpgradeResult { + case websocket(NIOAsyncChannel) + case notUpgraded + } + + static func main() async throws { + let client = Client( + host: "localhost", + port: 8888, + eventLoopGroup: .singleton + ) + try await client.run() + } + + /// This method starts the client and tries to setup a WebSocket connection. + func run() async throws { + let upgradeResult: EventLoopFuture = try await ClientBootstrap(group: self.eventLoopGroup) + .connect( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + let upgrader = NIOTypedWebSocketClientUpgrader( + upgradePipelineHandler: { (channel, _) in + channel.eventLoop.makeCompletedFuture { + let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) + return UpgradeResult.websocket(asyncChannel) + } + } + ) + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "0") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: [upgrader], + notUpgradingCompletionHandler: { channel in + channel.eventLoop.makeCompletedFuture { + return UpgradeResult.notUpgraded + } + } + ) + + let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) + ) + + return negotiationResultFuture + } + } + + // We are awaiting and handling the upgrade result now. + try await self.handleUpgradeResult(upgradeResult) + } + + /// This method handles the upgrade result. + private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async throws { + switch try await upgradeResult.get() { + case .websocket(let websocketChannel): + print("Handling websocket connection") + try await self.handleWebsocketChannel(websocketChannel) + print("Done handling websocket connection") + case .notUpgraded: + // The upgrade to websocket did not succeed. We are just exiting in this case. + print("Upgrade declined") + } + } + + private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { + // We are sending a ping frame and then + // start to handle all inbound frames. + + let pingFrame = WebSocketFrame(fin: true, opcode: .ping, data: ByteBuffer(string: "Hello!")) + try await channel.outboundWriter.write(pingFrame) + + for try await frame in channel.inboundStream { + switch frame.opcode { + case .pong: + print("Received pong: \(String(buffer: frame.data))") + + case .text: + print("Received: \(String(buffer: frame.data))") + + case .connectionClose: + // Handle a received close frame. We're just going to close by returning from this method. + print("Received Close instruction from server") + return + case .binary, .continuation, .ping: + // We ignore these frames. + break + default: + // Unknown frames are errors. + return + } + } + } +} + +#else +@main +struct Server { + static func main() { + fatalError("Requires at least Swift 5.9") + } +} +#endif diff --git a/Sources/NIOWebSocketClient/main.swift b/Sources/NIOWebSocketClient/main.swift deleted file mode 100644 index 5e86a9e045..0000000000 --- a/Sources/NIOWebSocketClient/main.swift +++ /dev/null @@ -1,233 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -import NIOCore -import NIOPosix -import NIOHTTP1 -import NIOWebSocket - -print("Establishing connection.") - -enum ConnectTo { - case ip(host: String, port: Int) - case unixDomainSocket(path: String) -} - -// The HTTP handler to be used to initiate the request. -// This initial request will be adapted by the WebSocket upgrader to contain the upgrade header parameters. -// Channel read will only be called if the upgrade fails. - -private final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHandler { - public typealias InboundIn = HTTPClientResponsePart - public typealias OutboundOut = HTTPClientRequestPart - - public let target: ConnectTo - - public init(target: ConnectTo) { - self.target = target - } - - public func channelActive(context: ChannelHandlerContext) { - print("Client connected to \(context.remoteAddress!)") - - // We are connected. It's time to send the message to the server to initialize the upgrade dance. - var headers = HTTPHeaders() - if case let .ip(host: host, port: port) = target { - headers.add(name: "Host", value: "\(host):\(port)") - } - headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") - headers.add(name: "Content-Length", value: "\(0)") - - let requestHead = HTTPRequestHead(version: .http1_1, - method: .GET, - uri: "/", - headers: headers) - - context.write(self.wrapOutboundOut(.head(requestHead)), promise: nil) - - let body = HTTPClientRequestPart.body(.byteBuffer(ByteBuffer())) - context.write(self.wrapOutboundOut(body), promise: nil) - - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) - } - - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - - let clientResponse = self.unwrapInboundIn(data) - - print("Upgrade failed") - - switch clientResponse { - case .head(let responseHead): - print("Received status: \(responseHead.status)") - case .body(let byteBuffer): - let string = String(buffer: byteBuffer) - print("Received: '\(string)' back from the server.") - case .end: - print("Closing channel.") - context.close(promise: nil) - } - } - - public func handlerRemoved(context: ChannelHandlerContext) { - print("HTTP handler removed.") - } - - public func errorCaught(context: ChannelHandlerContext, error: Error) { - print("error: ", error) - - // As we are not really interested getting notified on success or failure - // we just pass nil as promise to reduce allocations. - context.close(promise: nil) - } -} - -// The web socket handler to be used once the upgrade has occurred. -// One added, it sends a ping-pong round trip with "Hello World" data. -// It also listens for any text frames from the server and prints them. - -private final class WebSocketPingPongHandler: ChannelInboundHandler { - typealias InboundIn = WebSocketFrame - typealias OutboundOut = WebSocketFrame - - let testFrameData: String = "Hello World" - - // This is being hit, channel active won't be called as it is already added. - public func handlerAdded(context: ChannelHandlerContext) { - print("WebSocket handler added.") - self.pingTestFrameData(context: context) - } - - public func handlerRemoved(context: ChannelHandlerContext) { - print("WebSocket handler removed.") - } - - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let frame = self.unwrapInboundIn(data) - - switch frame.opcode { - case .pong: - self.pong(context: context, frame: frame) - case .text: - var data = frame.unmaskedData - let text = data.readString(length: data.readableBytes) ?? "" - print("Websocket: Received \(text)") - case .connectionClose: - self.receivedClose(context: context, frame: frame) - case .binary, .continuation, .ping: - // We ignore these frames. - break - default: - // Unknown frames are errors. - self.closeOnError(context: context) - } - } - - public func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - } - - private func receivedClose(context: ChannelHandlerContext, frame: WebSocketFrame) { - // Handle a received close frame. We're just going to close. - print("Received Close instruction from server") - context.close(promise: nil) - } - - private func pingTestFrameData(context: ChannelHandlerContext) { - let buffer = context.channel.allocator.buffer(string: self.testFrameData) - let frame = WebSocketFrame(fin: true, opcode: .ping, data: buffer) - context.writeAndFlush(self.wrapOutboundOut(frame), promise: nil) - } - - private func pong(context: ChannelHandlerContext, frame: WebSocketFrame) { - var frameData = frame.data - if let frameDataString = frameData.readString(length: self.testFrameData.count) { - print("Websocket: Received: \(frameDataString)") - } - } - - private func closeOnError(context: ChannelHandlerContext) { - // We have hit an error, we want to close. We do that by sending a close frame and then - // shutting down the write side of the connection. The server will respond with a close of its own. - var data = context.channel.allocator.buffer(capacity: 2) - data.write(webSocketErrorCode: .protocolError) - let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) - context.write(self.wrapOutboundOut(frame)).whenComplete { (_: Result) in - context.close(mode: .output, promise: nil) - } - } -} - -// First argument is the program path -let arguments = CommandLine.arguments -let arg1 = arguments.dropFirst().first -let arg2 = arguments.dropFirst(2).first - -let defaultHost = "::1" -let defaultPort: Int = 8888 - -let connectTarget: ConnectTo -switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { -case (.some(let h), _ , .some(let p)): - /* we got two arguments, let's interpret that as host and port */ - connectTarget = .ip(host: h, port: p) -case (.some(let portString), .none, _): - /* couldn't parse as number, expecting unix domain socket path */ - connectTarget = .unixDomainSocket(path: portString) -case (_, .some(let p), _): - /* only one argument --> port */ - connectTarget = .ip(host: defaultHost, port: p) -default: - connectTarget = .ip(host: defaultHost, port: defaultPort) -} - -let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) -let bootstrap = ClientBootstrap(group: group) - // Enable SO_REUSEADDR. - .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .channelInitializer { channel in - - let httpHandler = HTTPInitialRequestHandler(target: connectTarget) - - let websocketUpgrader = NIOWebSocketClientUpgrader(requestKey: "OfS0wDaT5NoxF2gqm7Zj2YtetzM=", - upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in - channel.pipeline.addHandler(WebSocketPingPongHandler()) - }) - - let config: NIOHTTPClientUpgradeConfiguration = ( - upgraders: [ websocketUpgrader ], - completionHandler: { _ in - channel.pipeline.removeHandler(httpHandler, promise: nil) - }) - - return channel.pipeline.addHTTPClientHandlers(withClientUpgrade: config).flatMap { - channel.pipeline.addHandler(httpHandler) - } -} -defer { - try! group.syncShutdownGracefully() -} - -let channel = try { () -> Channel in - switch connectTarget { - case .ip(let host, let port): - return try bootstrap.connect(host: host, port: port).wait() - case .unixDomainSocket(let path): - return try bootstrap.connect(unixDomainSocketPath: path).wait() - } -}() - -// Will be closed after we echo-ed back to the server. -try channel.closeFuture.wait() - -print("Client closed") diff --git a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift index 56c7bd3cfd..8344f4741e 100644 --- a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift @@ -16,7 +16,7 @@ import XCTest import Dispatch @testable import NIOCore import NIOEmbedded -@testable import NIOHTTP1 +@_spi(AsyncChannel) @testable import NIOHTTP1 extension EmbeddedChannel { @@ -32,31 +32,10 @@ extension EmbeddedChannel { } } -private func setUpClientChannel(clientHTTPHandler: RemovableChannelHandler, - clientUpgraders: [NIOHTTPClientProtocolUpgrader], - _ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void) throws -> EmbeddedChannel { - - let channel = EmbeddedChannel() - - let config: NIOHTTPClientUpgradeConfiguration = ( - upgraders: clientUpgraders, - completionHandler: { context in - channel.pipeline.removeHandler(clientHTTPHandler, promise: nil) - upgradeCompletionHandler(context) - }) - - try channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config).flatMap({ - channel.pipeline.addHandler(clientHTTPHandler) - }).wait() - - try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) - .wait() - - return channel -} +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader, NIOTypedHTTPClientProtocolUpgrader where UpgradeResult == Bool {} -private final class SuccessfulClientUpgrader: NIOHTTPClientProtocolUpgrader { - +private final class SuccessfulClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] fileprivate let upgradeHeaders: [(String,String)] @@ -87,10 +66,15 @@ private final class SuccessfulClientUpgrader: NIOHTTPClientProtocolUpgrader { self.upgradeContextResponseCallCount += 1 return context.channel.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + self.upgradeContextResponseCallCount += 1 + return channel.eventLoop.makeSucceededFuture(true) + } } -private final class ExplodingClientUpgrader: NIOHTTPClientProtocolUpgrader { - +private final class ExplodingClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { + fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] fileprivate let upgradeHeaders: [(String,String)] @@ -118,10 +102,15 @@ private final class ExplodingClientUpgrader: NIOHTTPClientProtocolUpgrader { XCTFail("Upgrade should not be called.") return context.channel.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + XCTFail("Upgrade should not be called.") + return channel.eventLoop.makeSucceededFuture(false) + } } -private final class DenyingClientUpgrader: NIOHTTPClientProtocolUpgrader { - +private final class DenyingClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { + fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] fileprivate let upgradeHeaders: [(String,String)] @@ -152,9 +141,14 @@ private final class DenyingClientUpgrader: NIOHTTPClientProtocolUpgrader { XCTFail("Upgrade should not be called.") return context.channel.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + XCTFail("Upgrade should not be called.") + return channel.eventLoop.makeSucceededFuture(false) + } } -private final class UpgradeDelayClientUpgrader: NIOHTTPClientProtocolUpgrader { +private final class UpgradeDelayClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] @@ -188,7 +182,14 @@ private final class UpgradeDelayClientUpgrader: NIOHTTPClientProtocolUpgrader { context.pipeline.addHandler(self.upgradedHandler) } } - + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + self.upgradePromise = channel.eventLoop.makePromise() + return self.upgradePromise!.futureResult.flatMap { + channel.pipeline.addHandler(self.upgradedHandler) + }.map { _ in true} + } + fileprivate func unblockUpgrade() { self.upgradePromise!.succeed(()) } @@ -278,9 +279,41 @@ private final class RecordingHTTPHandler: ChannelInboundHandler, RemovableChanne } } +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +private func assertPipelineContainsUpgradeHandler(channel: Channel) { + let handler = try? channel.pipeline.syncOperations.handler(type: NIOHTTPClientUpgradeHandler.self) + let typedHandler = try? channel.pipeline.syncOperations.handler(type: NIOTypedHTTPClientUpgradeHandler.self) + + XCTAssertTrue(handler != nil || typedHandler != nil) +} + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) class HTTPClientUpgradeTestCase: XCTestCase { - + func setUpClientChannel( + clientHTTPHandler: RemovableChannelHandler, + clientUpgraders: [any TypedAndUntypedHTTPClientProtocolUpgrader], + _ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void + ) throws -> EmbeddedChannel { + + let channel = EmbeddedChannel() + + let config: NIOHTTPClientUpgradeConfiguration = ( + upgraders: clientUpgraders, + completionHandler: { context in + channel.pipeline.removeHandler(clientHTTPHandler, promise: nil) + upgradeCompletionHandler(context) + }) + + try channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config).flatMap({ + channel.pipeline.addHandler(clientHTTPHandler) + }).wait() + + try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + .wait() + + return channel + } + // MARK: Test basic happy path requests and responses. func testSimpleUpgradeSucceeds() throws { @@ -320,9 +353,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { .assertContains(handlerType: HTTPRequestEncoder.self)) XCTAssertNoThrow(try clientChannel.pipeline .assertContains(handlerType: ByteToMessageHandler.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: NIOHTTPClientUpgradeHandler.self)) - + assertPipelineContainsUpgradeHandler(channel: clientChannel) + // Push the successful server response. let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" @@ -401,8 +433,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { let clientUpgrader = SuccessfulClientUpgrader(forProtocol: upgradeProtocol, upgradeHeaders: clientHeaders) - let clientUpgraders: [NIOHTTPClientProtocolUpgrader] = [unusedClientUpgrader, clientUpgrader] - + let clientUpgraders: [any TypedAndUntypedHTTPClientProtocolUpgrader] = [unusedClientUpgrader, clientUpgrader] + // The process should kick-off independently by sending the upgrade request to the server. let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(), clientUpgraders: clientUpgraders) { (context) in @@ -457,7 +489,7 @@ class HTTPClientUpgradeTestCase: XCTestCase { } } - final class AddHandlerClientUpgrader: NIOHTTPClientProtocolUpgrader { + final class AddHandlerClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { fileprivate let requiredUpgradeHeaders: [String] = [] fileprivate let supportedProtocol: String fileprivate let handler: T @@ -476,6 +508,10 @@ class HTTPClientUpgradeTestCase: XCTestCase { func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { return context.pipeline.addHandler(handler) } + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + return channel.pipeline.addHandler(handler).map { _ in true } + } } var upgradeHandlerCallbackFired = false @@ -923,3 +959,239 @@ class HTTPClientUpgradeTestCase: XCTestCase { } } } + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { + override func setUpClientChannel( + clientHTTPHandler: RemovableChannelHandler, + clientUpgraders: [any TypedAndUntypedHTTPClientProtocolUpgrader], + _ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void + ) throws -> EmbeddedChannel { + + let channel = EmbeddedChannel() + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "\(0)") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let upgraders: [any NIOTypedHTTPClientProtocolUpgrader] = Array(clientUpgraders.map { $0 as! any NIOTypedHTTPClientProtocolUpgrader }) + + let config = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: upgraders + ) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(clientHTTPHandler) + }.map { _ in + false + } + } + var configuration = NIOUpgradableHTTPClientPipelineConfiguration(upgradeConfiguration: config) + configuration.leftOverBytesStrategy = .forwardBytes + let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: configuration) + let context = try channel.pipeline.syncOperations.context(handlerType: NIOTypedHTTPClientUpgradeHandler.self) + + try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + .wait() + upgradeResult.whenSuccess { result in + if result { + upgradeCompletionHandler(context) + } + } + + return channel + } + + // - MARK: The following tests are all overridden from the base class since they slightly differ in behaviour + + override func testUpgradeOnlyHandlesKnownProtocols() throws { + var upgradeHandlerCallbackFired = false + + let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") + let clientHandler = RecordingHTTPHandler() + + // The process should kick-off independently by sending the upgrade request to the server. + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: unknownProtocol\r\n\r\n" + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) + } + + clientChannel.embeddedEventLoop.run() + + // Should fail with error (response is malformed) and remove upgrader from pipeline. + + // Check that the http elements are not removed from the pipeline. + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: HTTPRequestEncoder.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: ByteToMessageHandler.self)) + + // Check that the HTTP handler received its response. + XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) + // Check an error is reported + XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) + + XCTAssertFalse(upgradeHandlerCallbackFired) + + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + } + + override func testUpgradeResponseCanBeRejectedByClientUpgrader() throws { + let upgradeProtocol = "myProto" + + var upgradeHandlerCallbackFired = false + + let clientUpgrader = DenyingClientUpgrader(forProtocol: upgradeProtocol) + let clientHandler = RecordingHTTPHandler() + + // The process should kick-off independently by sending the upgrade request to the server. + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .upgraderDeniedUpgrade) + } + + clientChannel.embeddedEventLoop.run() + + // Should fail with error (response is denied) and remove upgrader from pipeline. + + // Check that the http elements are not removed from the pipeline. + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: HTTPRequestEncoder.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: ByteToMessageHandler.self)) + + XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) + + // Check that the HTTP handler received its response. + XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) + + // Check an error is reported + XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) + + XCTAssertFalse(upgradeHandlerCallbackFired) + + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + } + + override func testFiresOutboundErrorDuringAddingHandlers() throws { + let upgradeProtocol = "myProto" + var errorOnAdditionalChannelWrite: Error? + var upgradeHandlerCallbackFired = false + + let clientUpgrader = UpgradeDelayClientUpgrader(forProtocol: upgradeProtocol) + let clientHandler = RecordingHTTPHandler() + + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { (context) in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + // Push the successful server response. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) + + let promise = clientChannel.eventLoop.makePromise(of: Void.self) + + promise.futureResult.whenFailure() { error in + errorOnAdditionalChannelWrite = error + } + + // Send another outbound request during the upgrade. + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let secondRequest: HTTPClientRequestPart = .head(requestHead) + clientChannel.writeAndFlush(secondRequest, promise: promise) + + clientChannel.embeddedEventLoop.run() + + let promiseError = errorOnAdditionalChannelWrite as! NIOHTTPClientUpgradeError + XCTAssertEqual(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade, promiseError) + + // Soundness check that the upgrade was delayed. + XCTAssertEqual(0, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) + + // Upgrade now. + clientUpgrader.unblockUpgrade() + clientChannel.embeddedEventLoop.run() + + // Check that the upgrade was still successful, despite the interruption. + XCTAssert(upgradeHandlerCallbackFired) + XCTAssertEqual(1, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) + } + + override func testUpgradeResponseMissingAllProtocols() throws { + var upgradeHandlerCallbackFired = false + + let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") + let clientHandler = RecordingHTTPHandler() + + // The process should kick-off independently by sending the upgrade request to the server. + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\n\r\n" + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) + } + + clientChannel.embeddedEventLoop.run() + + // Should fail with error (response is malformed) and remove upgrader from pipeline. + + // Check that the http elements are not removed from the pipeline. + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: HTTPRequestEncoder.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: ByteToMessageHandler.self)) + + // Check that the HTTP handler received its response. + XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) + // Check an error is reported + XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) + + XCTAssertFalse(upgradeHandlerCallbackFired) + + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + } +} diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index fd974e253b..94df819eb9 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -15,8 +15,8 @@ import XCTest import NIOCore import NIOEmbedded -import NIOHTTP1 -@testable import NIOWebSocket +@_spi(AsyncChannel) import NIOHTTP1 +@_spi(AsyncChannel) @testable import NIOWebSocket extension EmbeddedChannel { @@ -146,12 +146,12 @@ private class WebSocketRecorderHandler: ChannelInboundHandler, ChannelOutboundHa } } +private func basicRequest(path: String = "/") -> String { + return "GET \(path) HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0" +} + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) class WebSocketClientEndToEndTests: XCTestCase { - - private func basicRequest(path: String = "/") -> String { - return "GET \(path) HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0" - } - func testSimpleUpgradeSucceeds() throws { var upgradeHandlerCallbackFired = false @@ -173,7 +173,7 @@ class WebSocketClientEndToEndTests: XCTestCase { // Read the server request. if let requestString = try clientChannel.readByteBufferOutputAsString() { - XCTAssertEqual(requestString, self.basicRequest() + "\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Key: \(requestKey)\r\nSec-WebSocket-Version: 13\r\n\r\n") + XCTAssertEqual(requestString, basicRequest() + "\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Key: \(requestKey)\r\nSec-WebSocket-Version: 13\r\n\r\n") } else { XCTFail() } @@ -284,7 +284,7 @@ class WebSocketClientEndToEndTests: XCTestCase { XCTAssertNoThrow(try clientChannel.close().wait()) } - private func runSuccessfulUpgrade() throws -> (EmbeddedChannel, WebSocketRecorderHandler) { + fileprivate func runSuccessfulUpgrade() throws -> (EmbeddedChannel, WebSocketRecorderHandler) { let handler = WebSocketRecorderHandler() @@ -404,3 +404,214 @@ class WebSocketClientEndToEndTests: XCTestCase { XCTAssertNoThrow(try clientChannel.close().wait()) } } + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedWebSocketClientEndToEndTests: WebSocketClientEndToEndTests { + func setUpClientChannel( + clientUpgraders: [any NIOTypedHTTPClientProtocolUpgrader], + notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture + ) throws -> (EmbeddedChannel, EventLoopFuture) { + + let channel = EmbeddedChannel() + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "\(0)") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let config = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: clientUpgraders, + notUpgradingCompletionHandler: notUpgradingCompletionHandler + ) + + let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: .init(upgradeConfiguration: config)) + + try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + .wait() + + return (channel, upgradeResult) + } + + override func testSimpleUpgradeSucceeds() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Read the server request. + if let requestString = try clientChannel.readByteBufferOutputAsString() { + XCTAssertEqual(requestString, basicRequest() + "\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Key: \(requestKey)\r\nSec-WebSocket-Version: 13\r\n\r\n") + } else { + XCTFail() + } + + // Push the successful server response. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" + + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) + + clientChannel.embeddedEventLoop.run() + + // Once upgraded, validate the http pipeline has been removed. + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: HTTPRequestEncoder.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: ByteToMessageHandler.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + // Check that the pipeline now has the correct websocket handlers added. + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: WebSocketFrameEncoder.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: ByteToMessageHandler.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: WebSocketRecorderHandler.self)) + + try upgradeResult.wait() + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + } + + override func testRejectUpgradeIfMissingAcceptKey() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response but with a missing accept key. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n" + + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + + XCTAssertThrowsError(try upgradeResult.wait()) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + } + + override func testRejectUpgradeIfIncorrectAcceptKey() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + let responseKey = "notACorrectKeyL1am=F1y=nn=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response but with an incorrect response key. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" + + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + + XCTAssertThrowsError(try upgradeResult.wait()) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + } + + override func testRejectUpgradeIfNotWebsocket() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response with an incorrect protocol. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: myProtocol\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" + + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + + XCTAssertThrowsError(try upgradeResult.wait()) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + } + + override fileprivate func runSuccessfulUpgrade() throws -> (EmbeddedChannel, WebSocketRecorderHandler) { + let handler = WebSocketRecorderHandler() + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: "OfS0wDaT5NoxF2gqm7Zj2YtetzM=", + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(handler) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:yKEqitDFPE81FyIhKTm+ojBqigk=\r\n\r\n" + + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) + + clientChannel.embeddedEventLoop.run() + + // We now have a successful upgrade, clear the output channels read to test the frames. + XCTAssertNoThrow(try clientChannel.readOutbound(as: ByteBuffer.self)) + + clientChannel.embeddedEventLoop.run() + + try upgradeResult.wait() + + return (clientChannel, handler) + } +} From d782de84df6b6abd915cdefe1d20d9fa58f700d8 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Sun, 8 Oct 2023 21:31:22 +0100 Subject: [PATCH 11/64] Adopt `package-benchmark` (#2534) # Motivation We want to migrate our allocation and later on also our performance tests to use the `package-benchmark` plugin. This plugin makes writing benchmarks way easier than our current setup. Furthermore, debugging benchmarks is also possible from within Xcode now. # Modification This PR adds the setup for the benchmarking infrastructure and connects it with out # Result Allocations tests are more accessible and easier to iterate. --- Benchmarks/.gitignore | 9 ++ .../NIOPosixBenchmarks/Benchmarks.swift | 32 +++++++ .../NIOPosixBenchmarks/TCPEcho.swift | 91 +++++++++++++++++++ Benchmarks/Package.swift | 41 +++++++++ .../5.10/NIOPosixBenchmarks.TCPEcho.p90.json | 3 + .../5.7/NIOPosixBenchmarks.TCPEcho.p90.json | 3 + .../5.8/NIOPosixBenchmarks.TCPEcho.p90.json | 3 + .../5.9/NIOPosixBenchmarks.TCPEcho.p90.json | 3 + .../main/NIOPosixBenchmarks.TCPEcho.p90.json | 3 + README.md | 13 +++ dev/update-benchmark-thresholds.sh | 41 +++++++++ docker/Dockerfile | 5 +- docker/docker-compose.2204.510.yaml | 5 + docker/docker-compose.2204.57.yaml | 5 + docker/docker-compose.2204.58.yaml | 5 + docker/docker-compose.2204.59.yaml | 5 + docker/docker-compose.2204.main.yaml | 5 + docker/docker-compose.yaml | 10 +- 18 files changed, 278 insertions(+), 4 deletions(-) create mode 100644 Benchmarks/.gitignore create mode 100644 Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift create mode 100644 Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift create mode 100644 Benchmarks/Package.swift create mode 100644 Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json create mode 100644 Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json create mode 100644 Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json create mode 100644 Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json create mode 100644 Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json create mode 100755 dev/update-benchmark-thresholds.sh diff --git a/Benchmarks/.gitignore b/Benchmarks/.gitignore new file mode 100644 index 0000000000..2517bcdfa8 --- /dev/null +++ b/Benchmarks/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +/.build +/Packages +xcuserdata/ +DerivedData/ +.swiftpm/configuration/registries.json +.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata +.netrc +.benchmarkBaselines/ \ No newline at end of file diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift new file mode 100644 index 0000000000..f579891e75 --- /dev/null +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Benchmark + +let benchmarks = { + let defaultMetrics: [BenchmarkMetric] = [ + .mallocCountTotal, + ] + + Benchmark( + "TCPEcho", + configuration: .init( + metrics: defaultMetrics, + timeUnits: .milliseconds, + scalingFactor: .mega + ) + ) { benchmark in + try runTCPEcho(numberOfWrites: benchmark.scaledIterations.upperBound) + } +} diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift new file mode 100644 index 0000000000..1c656ea217 --- /dev/null +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift @@ -0,0 +1,91 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import NIOPosix + +private final class EchoChannelHandler: ChannelInboundHandler { + fileprivate typealias InboundIn = ByteBuffer + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + context.writeAndFlush(data, promise: nil) + } +} + +private final class EchoRequestChannelHandler: ChannelInboundHandler { + fileprivate typealias InboundIn = ByteBuffer + + private let bufferSize = 10000 + private let numberOfWrites: Int + private var batchCount = 0 + private let data: NIOAny + private let readsCompletePromise: EventLoopPromise + private var receivedData = 0 + + init(numberOfWrites: Int, readsCompletePromise: EventLoopPromise) { + self.numberOfWrites = numberOfWrites + self.readsCompletePromise = readsCompletePromise + self.data = NIOAny(ByteBuffer(repeating: 0, count: self.bufferSize)) + } + + func channelActive(context: ChannelHandlerContext) { + for _ in 0..> $HOME/.profile diff --git a/docker/docker-compose.2204.510.yaml b/docker/docker-compose.2204.510.yaml index ef1732e29c..6c4362987c 100644 --- a/docker/docker-compose.2204.510.yaml +++ b/docker/docker-compose.2204.510.yaml @@ -76,6 +76,11 @@ services: performance-test: image: swift-nio:22.04-5.10 + update-benchmark-baseline: + image: swift-nio:22.04-5.10 + environment: + - SWIFT_VERSION=5.10 + shell: image: swift-nio:22.04-5.10 diff --git a/docker/docker-compose.2204.57.yaml b/docker/docker-compose.2204.57.yaml index 31df38fc98..eccb233c9c 100644 --- a/docker/docker-compose.2204.57.yaml +++ b/docker/docker-compose.2204.57.yaml @@ -76,6 +76,11 @@ services: performance-test: image: swift-nio:22.04-5.7 + update-benchmark-baseline: + image: swift-nio:22.04-5.7 + environment: + - SWIFT_VERSION=5.7 + shell: image: swift-nio:22.04-5.7 diff --git a/docker/docker-compose.2204.58.yaml b/docker/docker-compose.2204.58.yaml index f4910ba24b..c7f4059b09 100644 --- a/docker/docker-compose.2204.58.yaml +++ b/docker/docker-compose.2204.58.yaml @@ -77,6 +77,11 @@ services: performance-test: image: swift-nio:22.04-5.8 + update-benchmark-baseline: + image: swift-nio:22.04-5.8 + environment: + - SWIFT_VERSION=5.8 + shell: image: swift-nio:22.04-5.8 diff --git a/docker/docker-compose.2204.59.yaml b/docker/docker-compose.2204.59.yaml index 357536253d..c1d94f76b5 100644 --- a/docker/docker-compose.2204.59.yaml +++ b/docker/docker-compose.2204.59.yaml @@ -77,6 +77,11 @@ services: performance-test: image: swift-nio:22.04-5.9 + update-benchmark-baseline: + image: swift-nio:22.04-5.9 + environment: + - SWIFT_VERSION=5.9 + shell: image: swift-nio:22.04-5.9 diff --git a/docker/docker-compose.2204.main.yaml b/docker/docker-compose.2204.main.yaml index e4da57992a..ad3a7fb6ad 100644 --- a/docker/docker-compose.2204.main.yaml +++ b/docker/docker-compose.2204.main.yaml @@ -76,6 +76,11 @@ services: performance-test: image: swift-nio:22.04-main + update-benchmark-baseline: + image: swift-nio:22.04-main + environment: + - SWIFT_VERSION=main + shell: image: swift-nio:22.04-main diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 8bf9b4adb6..c563622f75 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -16,8 +16,8 @@ services: depends_on: [runtime-setup] volumes: - ~/.ssh:/root/.ssh - - ..:/code:z - working_dir: /code + - ..:/swift-nio:z + working_dir: /swift-nio cap_drop: - CAP_NET_RAW - CAP_NET_BIND_SERVICE @@ -40,12 +40,16 @@ services: test: <<: *common - command: /bin/bash -xcl "uname -a && swift -version && swift $${SWIFT_TEST_VERB-test} $${FORCE_TEST_DISCOVERY-} $${WARN_AS_ERROR_ARG-} $${SANITIZER_ARG-} $${IMPORT_CHECK_ARG-} && ./scripts/integration_tests.sh $${INTEGRATION_TESTS_ARG-}" + command: /bin/bash -xcl "uname -a && swift -version && swift $${SWIFT_TEST_VERB-test} $${FORCE_TEST_DISCOVERY-} $${WARN_AS_ERROR_ARG-} $${SANITIZER_ARG-} $${IMPORT_CHECK_ARG-} && ./scripts/integration_tests.sh $${INTEGRATION_TESTS_ARG-} && cd Benchmarks && swift package benchmark baseline check --check-absolute-path Thresholds/$${SWIFT_VERSION-}/" performance-test: <<: *common command: /bin/bash -xcl "swift build -c release -Xswiftc -Xllvm -Xswiftc -align-all-functions=5 -Xswiftc -Xllvm -Xswiftc -align-all-blocks=5 && ./.build/release/NIOPerformanceTester" + update-benchmark-baseline: + <<: *common + command: /bin/bash -xcl "cd Benchmarks && swift package --disable-sandbox --scratch-path .build/$${SWIFT_VERSION-}/ --allow-writing-to-package-directory benchmark --format metricP90AbsoluteThresholds --path Thresholds/$${SWIFT_VERSION-}/" + # util shell: From 076af6dafd9ae2ddd3095753ef7a618d545bee17 Mon Sep 17 00:00:00 2001 From: Max Desiatov Date: Mon, 9 Oct 2023 08:20:53 +0100 Subject: [PATCH 12/64] Fix missing whitespace in `README.md` (#2535) "It also provides[`EmbeddedChannel`][ec]" -> "It also provides [`EmbeddedChannel`][ec]" Co-authored-by: Cory Benfield --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e195673ad3..c256db1ad0 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,7 @@ In general, [`ChannelHandler`][ch]s are designed to be highly re-usable componen SwiftNIO ships with many [`ChannelHandler`][ch]s built in that provide useful functionality, such as HTTP parsing. In addition, high-performance applications will want to provide as much of their logic as possible in [`ChannelHandler`][ch]s, as it helps avoid problems with context switching. -Additionally, SwiftNIO ships with a few [`Channel`][c] implementations. In particular, it ships with `ServerSocketChannel`, a [`Channel`][c] for sockets that accept inbound connections; `SocketChannel`, a [`Channel`][c] for TCP connections; and `DatagramChannel`, a [`Channel`][c] for UDP sockets. All of these are provided by the `NIOPosix` module. It also provides[`EmbeddedChannel`][ec], a [`Channel`][c] primarily used for testing, provided by the `NIOEmbedded` module. +Additionally, SwiftNIO ships with a few [`Channel`][c] implementations. In particular, it ships with `ServerSocketChannel`, a [`Channel`][c] for sockets that accept inbound connections; `SocketChannel`, a [`Channel`][c] for TCP connections; and `DatagramChannel`, a [`Channel`][c] for UDP sockets. All of these are provided by the `NIOPosix` module. It also provides [`EmbeddedChannel`][ec], a [`Channel`][c] primarily used for testing, provided by the `NIOEmbedded` module. ##### A Note on Blocking From 7a9d37ef497478a73f61e9a446bb3ff38e30b1ca Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Mon, 9 Oct 2023 10:45:00 +0100 Subject: [PATCH 13/64] Add `NIOAsyncChannel` benchmark (#2536) # Motivation We want to benchmark our `NIOAsyncChannel` to see how it compares to the synchronous implementation with `ChannelHandler`s # Modification This PR adds a new `TCPEchoAsyncChannel` benchmark that mimics the `TCPEcho` benchmark but uses our new async bridges. Since Swift Concurrency, is normally using a global executor this benchmark would have quite high variation. To reduce this variant I introduced code to hook the global executor and set an `EventLoop` as the executor. In the future, if we get task executors we can change the code to us them instead. # Result New baseline benchmarks for the `NIOAsyncChannel`. --- .../NIOPosixBenchmarks/Benchmarks.swift | 34 ++++++- .../NIOPosixBenchmarks/TCPEcho.swift | 3 +- .../TCPEchoAsyncChannel.swift | 89 +++++++++++++++++++ .../Util/GlobalExecutor.swift | 33 +++++++ Benchmarks/Package.swift | 2 +- .../5.10/NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 3 + .../5.7/NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- .../5.8/NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- .../5.9/NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 3 + .../main/NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 3 + 13 files changed, 171 insertions(+), 9 deletions(-) create mode 100644 Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift create mode 100644 Benchmarks/Benchmarks/NIOPosixBenchmarks/Util/GlobalExecutor.swift create mode 100644 Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json create mode 100644 Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json create mode 100644 Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift index f579891e75..1590fcd65a 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift @@ -13,6 +13,9 @@ //===----------------------------------------------------------------------===// import Benchmark +import NIOPosix + +private let eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1).next() let benchmarks = { let defaultMetrics: [BenchmarkMetric] = [ @@ -27,6 +30,35 @@ let benchmarks = { scalingFactor: .mega ) ) { benchmark in - try runTCPEcho(numberOfWrites: benchmark.scaledIterations.upperBound) + try runTCPEcho( + numberOfWrites: benchmark.scaledIterations.upperBound, + eventLoop: eventLoop + ) + } + + // This benchmark is only available above 5.9 since our EL conformance + // to serial executor is also gated behind 5.9. + #if compiler(>=5.9) + Benchmark( + "TCPEchoAsyncChannel", + configuration: .init( + metrics: defaultMetrics, + timeUnits: .milliseconds, + scalingFactor: .mega, + setup: { + swiftTaskEnqueueGlobalHook = { job, _ in + eventLoop.executor.enqueue(job) + } + }, + teardown: { + swiftTaskEnqueueGlobalHook = nil + } + ) + ) { benchmark in + try await runTCPEchoAsyncChannel( + numberOfWrites: benchmark.scaledIterations.upperBound, + eventLoop: eventLoop + ) } + #endif } diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift index 1c656ea217..b7124ad50d 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift @@ -55,8 +55,7 @@ private final class EchoRequestChannelHandler: ChannelInboundHandler { } } -func runTCPEcho(numberOfWrites: Int) throws { - let eventLoop = MultiThreadedEventLoopGroup.singleton.next() +func runTCPEcho(numberOfWrites: Int, eventLoop: any EventLoop) throws { let serverChannel = try ServerBootstrap(group: eventLoop) .childChannelInitializer { channel in channel.eventLoop.makeCompletedFuture { diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift new file mode 100644 index 0000000000..88abcf5470 --- /dev/null +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift @@ -0,0 +1,89 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@_spi(AsyncChannel) import NIOCore +@_spi(AsyncChannel) import NIOPosix + +func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async throws { + let serverChannel = try await ServerBootstrap(group: eventLoop) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel in + channel.eventLoop.makeCompletedFuture { + return try NIOAsyncChannel( + synchronouslyWrapping: channel, + configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + ) + ) + } + } + + let clientChannel = try await ClientBootstrap(group: eventLoop) + .connect( + host: "127.0.0.1", + port: serverChannel.channel.localAddress!.port! + ) { channel in + channel.eventLoop.makeCompletedFuture { + return try NIOAsyncChannel( + synchronouslyWrapping: channel, + configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + ) + ) + } + } + + let bufferSize = 10000 + + try await withThrowingTaskGroup(of: Void.self) { group in + // This child task is echoing back the data on the server. + group.addTask { + for try await connectionChannel in serverChannel.inboundStream { + for try await inboundData in connectionChannel.inboundStream { + try await connectionChannel.outboundWriter.write(inboundData) + } + } + } + + // This child task is collecting the echoed back responses. + group.addTask { + var receivedData = 0 + for try await inboundData in clientChannel.inboundStream { + receivedData += inboundData.readableBytes + + if receivedData == numberOfWrites * bufferSize { + return + } + } + } + + // Let's start sending data. + let data = ByteBuffer(repeating: 0, count: bufferSize) + for _ in 0.. Void) -> Void + +var swiftTaskEnqueueGlobalHook: EnqueueGlobalHook? { + get { _swiftTaskEnqueueGlobalHook.pointee } + set { _swiftTaskEnqueueGlobalHook.pointee = newValue } +} + +private let _swiftTaskEnqueueGlobalHook: UnsafeMutablePointer = + dlsym(dlopen(nil, RTLD_LAZY), "swift_task_enqueueGlobal_hook").assumingMemoryBound(to: EnqueueGlobalHook?.self) diff --git a/Benchmarks/Package.swift b/Benchmarks/Package.swift index cb7ef412e1..8797a9249f 100644 --- a/Benchmarks/Package.swift +++ b/Benchmarks/Package.swift @@ -18,7 +18,7 @@ import PackageDescription let package = Package( name: "benchmarks", platforms: [ - .macOS(.v13), + .macOS("14"), ], dependencies: [ .package(path: "../"), diff --git a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json index a920579610..fa70aea890 100644 --- a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 93 + "mallocCountTotal" : 90 } \ No newline at end of file diff --git a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json new file mode 100644 index 0000000000..ddd4e94cf2 --- /dev/null +++ b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json @@ -0,0 +1,3 @@ +{ + "mallocCountTotal" : 5554895 +} \ No newline at end of file diff --git a/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json index 18de727583..1859f424c5 100644 --- a/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 95 + "mallocCountTotal" : 92 } \ No newline at end of file diff --git a/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json index 18de727583..1859f424c5 100644 --- a/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 95 + "mallocCountTotal" : 92 } \ No newline at end of file diff --git a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json index 18de727583..1859f424c5 100644 --- a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 95 + "mallocCountTotal" : 92 } \ No newline at end of file diff --git a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json new file mode 100644 index 0000000000..ffdf0ae74c --- /dev/null +++ b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json @@ -0,0 +1,3 @@ +{ + "mallocCountTotal" : 5636901 +} \ No newline at end of file diff --git a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json index a920579610..fa70aea890 100644 --- a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 93 + "mallocCountTotal" : 90 } \ No newline at end of file diff --git a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json new file mode 100644 index 0000000000..ddd4e94cf2 --- /dev/null +++ b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json @@ -0,0 +1,3 @@ +{ + "mallocCountTotal" : 5554895 +} \ No newline at end of file From c2d8eed7e6a8bc5dd6c38c64d0f48e3a42090c17 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Mon, 9 Oct 2023 14:00:57 +0100 Subject: [PATCH 14/64] Add customization point for scheduling `ExecutorJob`s on `EventLoop`s (#2538) # Motivation Currently, the NIO's EventLoop conformance to the `SerialExecutor` protocol always uses `execute` to schedule the actual job. However, the closure for `execute` has to close over the job and the `EventLoop` itself; hence, it always allocates. Since jobs are a very fine grained object in Concurrency that are created a lot this lead to millions of allocations in even small benchmarks. # Modification This PR provides a customization point for `EventLoop`s to execute `ExecutorJob`s directly. For `SelectableEventLoop` we store a type erased `UnownedJob` in our `ScheduledTask` and just run it right away. # Result No more allocations when NIO's EL is used as a `SerialExecutor`. --- .../NIOPosixBenchmarks/Benchmarks.swift | 2 +- .../NIOPosixBenchmarks/TCPEcho.swift | 6 +-- .../TCPEchoAsyncChannel.swift | 6 +-- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 2 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 2 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 2 +- ...r.swift => EventLoop+SerialExecutor.swift} | 8 ++-- Sources/NIOCore/EventLoop.swift | 16 +++++++ .../MultiThreadedEventLoopGroup.swift | 47 +++++++++++++++++-- Sources/NIOPosix/SelectableEventLoop.swift | 28 ++++++++++- 10 files changed, 100 insertions(+), 19 deletions(-) rename Sources/NIOCore/{ActorExecutor.swift => EventLoop+SerialExecutor.swift} (92%) diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift index 1590fcd65a..56bb64dd61 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift @@ -15,7 +15,7 @@ import Benchmark import NIOPosix -private let eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1).next() +private let eventLoop = MultiThreadedEventLoopGroup.singleton.next() let benchmarks = { let defaultMetrics: [BenchmarkMetric] = [ diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift index b7124ad50d..a1ca7a5df4 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift @@ -26,7 +26,7 @@ private final class EchoChannelHandler: ChannelInboundHandler { private final class EchoRequestChannelHandler: ChannelInboundHandler { fileprivate typealias InboundIn = ByteBuffer - private let bufferSize = 10000 + private let messageSize = 10000 private let numberOfWrites: Int private var batchCount = 0 private let data: NIOAny @@ -36,7 +36,7 @@ private final class EchoRequestChannelHandler: ChannelInboundHandler { init(numberOfWrites: Int, readsCompletePromise: EventLoopPromise) { self.numberOfWrites = numberOfWrites self.readsCompletePromise = readsCompletePromise - self.data = NIOAny(ByteBuffer(repeating: 0, count: self.bufferSize)) + self.data = NIOAny(ByteBuffer(repeating: 0, count: self.messageSize)) } func channelActive(context: ChannelHandlerContext) { @@ -49,7 +49,7 @@ private final class EchoRequestChannelHandler: ChannelInboundHandler { let buffer = self.unwrapInboundIn(data) self.receivedData += buffer.readableBytes - if self.receivedData == self.numberOfWrites * self.bufferSize { + if self.receivedData == self.numberOfWrites * self.messageSize { self.readsCompletePromise.succeed() } } diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift index 88abcf5470..ef28d9e4c1 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift @@ -48,7 +48,7 @@ func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async thr } } - let bufferSize = 10000 + let messageSize = 10000 try await withThrowingTaskGroup(of: Void.self) { group in // This child task is echoing back the data on the server. @@ -66,14 +66,14 @@ func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async thr for try await inboundData in clientChannel.inboundStream { receivedData += inboundData.readableBytes - if receivedData == numberOfWrites * bufferSize { + if receivedData == numberOfWrites * messageSize { return } } } // Let's start sending data. - let data = ByteBuffer(repeating: 0, count: bufferSize) + let data = ByteBuffer(repeating: 0, count: messageSize) for _ in 0..=5.9) +@usableFromInline +struct ErasedUnownedJob { + @usableFromInline + let erasedJob: Any + + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + init(job: UnownedJob) { + self.erasedJob = job + } + + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + @inlinable + var unownedJob: UnownedJob { + // This force-cast is safe since we only store an UnownedJob + self.erasedJob as! UnownedJob + } +} +#endif + @usableFromInline internal struct ScheduledTask { + @usableFromInline + enum UnderlyingTask { + case function(() -> Void) + #if swift(>=5.9) + case unownedJob(ErasedUnownedJob) + #endif + } + /// The id of the scheduled task. /// /// - Important: This id has two purposes. First, it is used to give this struct an identity so that we can implement ``Equatable`` @@ -411,21 +439,32 @@ internal struct ScheduledTask { /// This means, the ids need to be unique for a given ``SelectableEventLoop`` and they need to be in ascending order. @usableFromInline let id: UInt64 - let task: () -> Void - private let failFn: (Error) ->() + let task: UnderlyingTask + private let failFn: ((Error) ->())? @usableFromInline internal let readyTime: NIODeadline @usableFromInline init(id: UInt64, _ task: @escaping () -> Void, _ failFn: @escaping (Error) -> Void, _ time: NIODeadline) { self.id = id - self.task = task + self.task = .function(task) self.failFn = failFn self.readyTime = time } + #if swift(>=5.9) + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + @usableFromInline + init(id: UInt64, job: consuming ExecutorJob, readyTime: NIODeadline) { + self.id = id + self.task = .unownedJob(.init(job: UnownedJob(job))) + self.readyTime = readyTime + self.failFn = nil + } + #endif + func fail(_ error: Error) { - failFn(error) + failFn?(error) } } diff --git a/Sources/NIOPosix/SelectableEventLoop.swift b/Sources/NIOPosix/SelectableEventLoop.swift index 52339eccee..d76dc6005d 100644 --- a/Sources/NIOPosix/SelectableEventLoop.swift +++ b/Sources/NIOPosix/SelectableEventLoop.swift @@ -299,6 +299,20 @@ Further information: }, .now())) } + #if swift(>=5.9) + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + @usableFromInline + func enqueue(_ job: consuming ExecutorJob) { + let scheduledTask = ScheduledTask( + id: self.scheduledTaskCounter.loadThenWrappingIncrement(ordering: .relaxed), + job: job, + readyTime: .now() + ) + // nothing we can do if we fail enqueuing here. + try? self._schedule0(scheduledTask) + } + #endif + /// Add the `ScheduledTask` to be executed. @usableFromInline internal func _schedule0(_ task: ScheduledTask) throws { @@ -515,7 +529,19 @@ Further information: for task in self.tasksCopy { /* for macOS: in case any calls we make to Foundation put objects into an autoreleasepool */ withAutoReleasePool { - task.task() + switch task.task { + case .function(let function): + function() + + #if swift(>=5.9) + case .unownedJob(let erasedUnownedJob): + if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) { + erasedUnownedJob.unownedJob.runSynchronously(on: self.asUnownedSerialExecutor()) + } else { + fatalError("Tried to run an UnownedJob without runtime support") + } + #endif + } } } // Drop everything (but keep the capacity) so we can fill it again on the next iteration. From 94cc73dfbb8c13cb72c39b1d120a1e5828f0b864 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Mon, 9 Oct 2023 15:39:25 +0100 Subject: [PATCH 15/64] Improve performance of `NIOAsyncChannel` (#2539) * Implement fast paths for `AsyncChannelInboundStreamChannelHandler` * Add `_TinyArray` from `swift-certificates` * Implement single element customization point in `NIOAsyncChannelOutboundWriterHandler` * Call the single element optimization more often and store suspended producers in `_TinyArray`. * Update thresholds * Fix compiler warning --- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 2 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 2 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 2 +- NOTICE.txt | 13 +- Package.swift | 1 + ...ncChannelInboundStreamChannelHandler.swift | 12 +- .../AsyncChannelOutboundWriterHandler.swift | 31 ++ .../AsyncSequences/NIOAsyncWriter.swift | 93 ++++- Sources/_NIODataStructures/_TinyArray.swift | 356 ++++++++++++++++++ 9 files changed, 487 insertions(+), 25 deletions(-) create mode 100644 Sources/_NIODataStructures/_TinyArray.swift diff --git a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json index 53b20f1974..9255c6c429 100644 --- a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json +++ b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 2557298 + "mallocCountTotal" : 1317015 } \ No newline at end of file diff --git a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json index b4e98b55e9..810becb3a2 100644 --- a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json +++ b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 2557305 + "mallocCountTotal" : 1317022 } \ No newline at end of file diff --git a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json index 53b20f1974..9255c6c429 100644 --- a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json +++ b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 2557298 + "mallocCountTotal" : 1317015 } \ No newline at end of file diff --git a/NOTICE.txt b/NOTICE.txt index 5090ff4f3e..e977e8e139 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -78,10 +78,21 @@ This product contains a derivation of Fabian Fett's 'Base64.swift'. * https://github.com/fabianfett/swift-base64-kit/blob/master/LICENSE * HOMEPAGE: * https://github.com/fabianfett/swift-base64-kit - + +--- + This product contains a derivation of "XCTest+AsyncAwait.swift" from AsyncHTTPClient. * LICENSE (Apache License 2.0): * https://www.apache.org/licenses/LICENSE-2.0 * HOMEPAGE: * https://github.com/swift-server/async-http-client + +--- + +This product contains a derivation of "_TinyArray.swift" from SwiftCertificates. + + * LICENSE (Apache License 2.0): + * https://www.apache.org/licenses/LICENSE-2.0 + * HOMEPAGE: + * https://github.com/apple/swift-certificates diff --git a/Package.swift b/Package.swift index d0b61c3fc5..f7c7adb4c6 100644 --- a/Package.swift +++ b/Package.swift @@ -50,6 +50,7 @@ let package = Package( "CNIODarwin", "CNIOLinux", "CNIOWindows", + "_NIODataStructures", swiftCollections, swiftAtomics, ] diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift index 31d22f3ca6..d1d28c4c3c 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift @@ -326,15 +326,23 @@ struct NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate: @unchecked Se @inlinable func didTerminate() { - self.eventLoop.execute { + if self.eventLoop.inEventLoop { self._didTerminate() + } else { + self.eventLoop.execute { + self._didTerminate() + } } } @inlinable func produceMore() { - self.eventLoop.execute { + if self.eventLoop.inEventLoop { self._produceMore() + } else { + self.eventLoop.execute { + self._produceMore() + } } } } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift index b15795b9c5..3d0a78ef3e 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift @@ -75,6 +75,23 @@ internal final class NIOAsyncChannelOutboundWriterHandler self._doOutboundWrites(context: context, writes: sequence) } + @inlinable + func _didYield(element: OutboundOut) { + // This is always called from an async context, so we must loop-hop. + // Because we always loop-hop, we're always at the top of a stack frame. As this + // is the only source of writes for us, and as this channel handler doesn't implement + // func write(), we cannot possibly re-entrantly write. That means we can skip many of the + // awkward re-entrancy protections NIO usually requires, and can safely just do an iterative + // write. + self.eventLoop.preconditionInEventLoop() + guard let context = self.context else { + // Already removed from the channel by now, we can stop. + return + } + + self._doOutboundWrite(context: context, write: element) + } + @inlinable func _didTerminate(error: Error?) { self.eventLoop.preconditionInEventLoop() @@ -102,6 +119,12 @@ internal final class NIOAsyncChannelOutboundWriterHandler context.flush() } + @inlinable + func _doOutboundWrite(context: ChannelHandlerContext, write: OutboundOut) { + context.write(self.wrapOutboundOut(write), promise: nil) + context.flush() + } + @inlinable func handlerAdded(context: ChannelHandlerContext) { self.context = context @@ -153,6 +176,14 @@ extension NIOAsyncChannelOutboundWriterHandler { } } + @inlinable + func didYield(_ element: OutboundOut) { + // This always called from an async context, so we must loop-hop. + self.eventLoop.execute { + self.handler._didYield(element: element) + } + } + @inlinable func didTerminate(error: Error?) { // This always called from an async context, so we must loop-hop. diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift index dd07fec57f..f0356b574e 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift @@ -15,6 +15,7 @@ import Atomics import DequeModule import NIOConcurrencyHelpers +import _NIODataStructures /// The delegate of the ``NIOAsyncWriter``. It is the consumer of the yielded writes to the ``NIOAsyncWriter``. /// Furthermore, the delegate gets informed when the ``NIOAsyncWriter`` terminated. @@ -464,6 +465,21 @@ extension NIOAsyncWriter { // is immediately returning and just enqueues the Job on the executor suspendedYields.forEach { $0.continuation.resume() } + case .callDidYieldElementAndResumeContinuations(let delegate, let element, let suspendedYields): + // We are calling the delegate while holding lock. This can lead to potential crashes + // if the delegate calls `setWritability` reentrantly. However, we call this + // out in the docs of the delegate + delegate.didYield(element) + + // It is safe to resume the continuations while holding the lock since resume + // is immediately returning and just enqueues the Job on the executor + suspendedYields.forEach { $0.continuation.resume() } + + case .resumeContinuations(let suspendedYields): + // It is safe to resume the continuations while holding the lock since resume + // is immediately returning and just enqueues the Job on the executor + suspendedYields.forEach { $0.continuation.resume() } + case .callDidYieldAndDidTerminate(let delegate, let elements): // We are calling the delegate while holding lock. This can lead to potential crashes // if the delegate calls `setWritability` reentrantly. However, we call this @@ -679,7 +695,7 @@ extension NIOAsyncWriter { case streaming( isWritable: Bool, cancelledYields: [YieldID], - suspendedYields: [SuspendedYield], + suspendedYields: _TinyArray, elements: Deque, delegate: Delegate ) @@ -759,7 +775,12 @@ extension NIOAsyncWriter { enum SetWritabilityAction { /// Indicates that ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)`` should be called /// and all continuations should be resumed. - case callDidYieldAndResumeContinuations(Delegate, Deque, [SuspendedYield]) + case callDidYieldAndResumeContinuations(Delegate, Deque, _TinyArray) + /// Indicates that ``NIOAsyncWriterSinkDelegate/didYield(element:)`` should be called + /// and all continuations should be resumed. + case callDidYieldElementAndResumeContinuations(Delegate, Element, _TinyArray) + /// Indicates that all continuations should be resumed. + case resumeContinuations(_TinyArray) /// Indicates that ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)`` and /// ``NIOAsyncWriterSinkDelegate/didTerminate(error:)``should be called. case callDidYieldAndDidTerminate(Delegate, Deque) @@ -776,7 +797,7 @@ extension NIOAsyncWriter { return .none - case .streaming(let isWritable, let cancelledYields, let suspendedYields, let elements, let delegate): + case .streaming(let isWritable, let cancelledYields, let suspendedYields, var elements, let delegate): if isWritable == newWritability { // The writability didn't change so we can just early exit here return .none @@ -786,19 +807,53 @@ extension NIOAsyncWriter { // We became writable again. This means we have to resume all the continuations // and yield the values. - self._state = .streaming( - isWritable: newWritability, - cancelledYields: cancelledYields, - suspendedYields: [], - elements: .init(), - delegate: delegate - ) + if elements.count == 0 { + // We just have to resume the continuations + self._state = .streaming( + isWritable: newWritability, + cancelledYields: cancelledYields, + suspendedYields: .init(), + elements: elements, + delegate: delegate + ) + + return .resumeContinuations(suspendedYields) + } else if elements.count == 1 { + // We have exactly one element in the buffer. Let's + // pop it and re-use the buffer right away + self._state = .modifying - // We are taking the whole array of suspended yields and the deque of elements - // and allocate a new empty one. - // As a performance optimization we could always keep multiple arrays/deques and - // switch between them but I don't think this is the performance critical part. - return .callDidYieldAndResumeContinuations(delegate, elements, suspendedYields) + // This force-unwrap is safe since we just checked the count for 1. + let element = elements.popFirst()! + + self._state = .streaming( + isWritable: newWritability, + cancelledYields: cancelledYields, + suspendedYields: .init(), + elements: elements, + delegate: delegate + ) + + return .callDidYieldElementAndResumeContinuations( + delegate, + element, + suspendedYields + ) + } else { + self._state = .streaming( + isWritable: newWritability, + cancelledYields: cancelledYields, + suspendedYields: .init(), + elements: .init(), + delegate: delegate + ) + + // We are taking the whole array of suspended yields and the deque of elements + // and allocate a new empty one. + // As a performance optimization we could always keep multiple arrays/deques and + // switch between them but I don't think this is the performance critical part. + return .callDidYieldAndResumeContinuations(delegate, elements, suspendedYields) + } } else { // We became unwritable nothing really to do here precondition(suspendedYields.isEmpty, "No yield should be suspended at this point") @@ -867,7 +922,7 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: isWritable, cancelledYields: [], - suspendedYields: [], + suspendedYields: .init(), elements: .init(), delegate: delegate ) @@ -985,7 +1040,7 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: isWritable, cancelledYields: [yieldID], - suspendedYields: [], + suspendedYields: .init(), elements: .init(), delegate: delegate ) @@ -1047,7 +1102,7 @@ extension NIOAsyncWriter { /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called. case callDidTerminate(Delegate) /// Indicates that all continuations should be resumed. - case resumeContinuations([SuspendedYield]) + case resumeContinuations(_TinyArray) /// Indicates that nothing should be done. case none } @@ -1096,7 +1151,7 @@ extension NIOAsyncWriter { case callDidTerminate(Delegate, Error?) /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called and all /// continuations should be resumed with the given error. - case resumeContinuationsWithErrorAndCallDidTerminate(Delegate, [SuspendedYield], Error) + case resumeContinuationsWithErrorAndCallDidTerminate(Delegate, _TinyArray, Error) /// Indicates that nothing should be done. case none } diff --git a/Sources/_NIODataStructures/_TinyArray.swift b/Sources/_NIODataStructures/_TinyArray.swift new file mode 100644 index 0000000000..bc1c154b6a --- /dev/null +++ b/Sources/_NIODataStructures/_TinyArray.swift @@ -0,0 +1,356 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftCertificates open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftCertificates project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftCertificates project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// ``TinyArray`` is a ``RandomAccessCollection`` optimised to store zero or one ``Element``. +/// It supports arbitrary many elements but if only up to one ``Element`` is stored it does **not** allocate separate storage on the heap +/// and instead stores the ``Element`` inline. +public struct _TinyArray { + @usableFromInline + enum Storage { + case one(Element) + case arbitrary([Element]) + } + + @usableFromInline + var storage: Storage +} + +// MARK: - TinyArray "public" interface + +extension _TinyArray: Equatable where Element: Equatable {} +extension _TinyArray: Hashable where Element: Hashable {} +extension _TinyArray: Sendable where Element: Sendable {} + +extension _TinyArray: RandomAccessCollection { + public typealias Element = Element + + public typealias Index = Int + + @inlinable + public func makeIterator() -> Iterator { + return Iterator(storage: self.storage) + } + + public struct Iterator: IteratorProtocol { + @usableFromInline + let _storage: Storage + @usableFromInline + var _index: Index + + @usableFromInline + init(storage: Storage) { + self._storage = storage + self._index = storage.startIndex + } + + @inlinable + public mutating func next() -> Element? { + if self._index == self._storage.endIndex { + return nil + } + + defer { + self._index &+= 1 + } + + return self._storage[self._index] + } + } + + @inlinable + public subscript(position: Int) -> Element { + get { + self.storage[position] + } + } + + @inlinable + public var startIndex: Int { + self.storage.startIndex + } + + @inlinable + public var endIndex: Int { + self.storage.endIndex + } +} + +extension _TinyArray { + @inlinable + public init(_ elements: some Sequence) { + self.storage = .init(elements) + } + + @inlinable + public init(_ elements: some Sequence>) throws { + self.storage = try .init(elements) + } + + @inlinable + public init() { + self.storage = .init() + } + + @inlinable + public mutating func append(_ newElement: Element) { + self.storage.append(newElement) + } + + @inlinable + public mutating func append(contentsOf newElements: some Sequence) { + self.storage.append(contentsOf: newElements) + } + + @discardableResult + @inlinable + public mutating func remove(at index: Int) -> Element { + self.storage.remove(at: index) + } + + @inlinable + public mutating func removeAll(where shouldBeRemoved: (Element) throws -> Bool) rethrows { + try self.storage.removeAll(where: shouldBeRemoved) + } + + @inlinable + public mutating func sort(by areInIncreasingOrder: (Element, Element) throws -> Bool) rethrows { + try self.storage.sort(by: areInIncreasingOrder) + } +} + +// MARK: - TinyArray.Storage "private" implementation + +extension _TinyArray.Storage: Equatable where Element: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.one(let lhs), .one(let rhs)): + return lhs == rhs + case (.arbitrary(let lhs), .arbitrary(let rhs)): + // we don't use lhs.elementsEqual(rhs) so we can hit the fast path from Array + // if both arrays share the same underlying storage: https://github.com/apple/swift/blob/b42019005988b2d13398025883e285a81d323efa/stdlib/public/core/Array.swift#L1775 + return lhs == rhs + + case (.one(let element), .arbitrary(let array)), + (.arbitrary(let array), .one(let element)): + guard array.count == 1 else { + return false + } + return element == array[0] + + } + } +} +extension _TinyArray.Storage: Hashable where Element: Hashable { + @inlinable + func hash(into hasher: inout Hasher) { + // same strategy as Array: https://github.com/apple/swift/blob/b42019005988b2d13398025883e285a81d323efa/stdlib/public/core/Array.swift#L1801 + hasher.combine(count) + for element in self { + hasher.combine(element) + } + } +} +extension _TinyArray.Storage: Sendable where Element: Sendable {} + +extension _TinyArray.Storage: RandomAccessCollection { + @inlinable + subscript(position: Int) -> Element { + get { + switch self { + case .one(let element): + guard position == 0 else { + fatalError("index \(position) out of bounds") + } + return element + case .arbitrary(let elements): + return elements[position] + } + } + } + + @inlinable + var startIndex: Int { + 0 + } + + @inlinable + var endIndex: Int { + switch self { + case .one: return 1 + case .arbitrary(let elements): return elements.endIndex + } + } +} + +extension _TinyArray.Storage { + @inlinable + init(_ elements: some Sequence) { + self = .arbitrary([]) + self.append(contentsOf: elements) + } + + @inlinable + init(_ newElements: some Sequence>) throws { + var iterator = newElements.makeIterator() + guard let firstElement = try iterator.next()?.get() else { + self = .arbitrary([]) + return + } + guard let secondElement = try iterator.next()?.get() else { + // newElements just contains a single element + // and we hit the fast path + self = .one(firstElement) + return + } + + var elements: [Element] = [] + elements.reserveCapacity(newElements.underestimatedCount) + elements.append(firstElement) + elements.append(secondElement) + while let nextElement = try iterator.next()?.get() { + elements.append(nextElement) + } + self = .arbitrary(elements) + } + + @inlinable + init() { + self = .arbitrary([]) + } + + @inlinable + mutating func append(_ newElement: Element) { + self.append(contentsOf: CollectionOfOne(newElement)) + } + + @inlinable + mutating func append(contentsOf newElements: some Sequence) { + switch self { + case .one(let firstElement): + var iterator = newElements.makeIterator() + guard let secondElement = iterator.next() else { + // newElements is empty, nothing to do + return + } + var elements: [Element] = [] + elements.reserveCapacity(1 + newElements.underestimatedCount) + elements.append(firstElement) + elements.append(secondElement) + elements.appendRemainingElements(from: &iterator) + self = .arbitrary(elements) + + case .arbitrary(var elements): + if elements.isEmpty { + // if `self` is currently empty and `newElements` just contains a single + // element, we skip allocating an array and set `self` to `.one(firstElement)` + var iterator = newElements.makeIterator() + guard let firstElement = iterator.next() else { + // newElements is empty, nothing to do + return + } + guard let secondElement = iterator.next() else { + // newElements just contains a single element + // and we hit the fast path + self = .one(firstElement) + return + } + elements.reserveCapacity(elements.count + newElements.underestimatedCount) + elements.append(firstElement) + elements.append(secondElement) + elements.appendRemainingElements(from: &iterator) + self = .arbitrary(elements) + + } else { + elements.append(contentsOf: newElements) + self = .arbitrary(elements) + } + + } + } + + @discardableResult + @inlinable + mutating func remove(at index: Int) -> Element { + switch self { + case .one(let oldElement): + guard index == 0 else { + fatalError("index \(index) out of bounds") + } + self = .arbitrary([]) + return oldElement + + case .arbitrary(var elements): + defer { + self = .arbitrary(elements) + } + return elements.remove(at: index) + + } + } + + @inlinable + mutating func removeAll(where shouldBeRemoved: (Element) throws -> Bool) rethrows { + switch self { + case .one(let oldElement): + if try shouldBeRemoved(oldElement) { + self = .arbitrary([]) + } + + case .arbitrary(var elements): + defer { + self = .arbitrary(elements) + } + return try elements.removeAll(where: shouldBeRemoved) + + } + } + + @inlinable + mutating func sort(by areInIncreasingOrder: (Element, Element) throws -> Bool) rethrows { + switch self { + case .one: + // a collection of just one element is always sorted, nothing to do + break + case .arbitrary(var elements): + defer { + self = .arbitrary(elements) + } + + try elements.sort(by: areInIncreasingOrder) + } + } +} + +extension Array { + @inlinable + mutating func appendRemainingElements(from iterator: inout some IteratorProtocol) { + while let nextElement = iterator.next() { + append(nextElement) + } + } +} From 8c3cac7774668e260b7a9d9da8468d82a619829a Mon Sep 17 00:00:00 2001 From: Johannes Weiss Date: Tue, 10 Oct 2023 16:15:11 +0100 Subject: [PATCH 16/64] perf tests: reset BB indices after every iteration (#2544) --- .../ByteBufferWriteMultipleBenchmarks.swift | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift b/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift index e4e7387a5d..489af7b4a1 100644 --- a/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift +++ b/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift @@ -34,6 +34,7 @@ final class ByteBufferReadWriteMultipleIntegersBenchmark: func run() throws -> Int { var result: I = 0 for _ in 0..: func run() throws -> Int { var result: I = 0 for _ in 0.. Date: Wed, 11 Oct 2023 10:58:16 +0100 Subject: [PATCH 17/64] measureRunTime use DispatchTime (#2545) Motivation: Using `Date()` in benchmarks leaves us open to the possibility that a clock update could occur between the start and end times affecting the measured interval. Modifications: Switch `measureRunTime` and related functions to use `DispatchTime` uptime in nanoseconds to calculate intervals Result: Tests should no longer be susceptible to clock updates. --- Tests/NIOPosixTests/SystemCallWrapperHelpers.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Tests/NIOPosixTests/SystemCallWrapperHelpers.swift b/Tests/NIOPosixTests/SystemCallWrapperHelpers.swift index ae1ed7ac15..24b2086587 100644 --- a/Tests/NIOPosixTests/SystemCallWrapperHelpers.swift +++ b/Tests/NIOPosixTests/SystemCallWrapperHelpers.swift @@ -19,10 +19,10 @@ import Foundation public func measureRunTime(_ body: () throws -> Int) rethrows -> TimeInterval { func measureOne(_ body: () throws -> Int) rethrows -> TimeInterval { - let start = Date() + let start = DispatchTime.now().uptimeNanoseconds _ = try body() - let end = Date() - return end.timeIntervalSince(start) + let end = DispatchTime.now().uptimeNanoseconds + return Double(end - start)/1_000_000 } _ = try measureOne(body) From 421fcec525a45351a44142857bab866006e7b275 Mon Sep 17 00:00:00 2001 From: Lorenzo Fritzsch Date: Thu, 12 Oct 2023 16:07:52 +0200 Subject: [PATCH 18/64] Fix overflow (#2543) * Fix overflow * Replace unchecked operation `&*` with `multipliedReportingOverflow` method and add unit tests for underflow/overflow cases * PR fixes --------- Co-authored-by: George Barnett --- Sources/NIOCore/EventLoop.swift | 36 ++++++++++++++++++++---- Tests/NIOCoreTests/TimeAmountTests.swift | 18 ++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/Sources/NIOCore/EventLoop.swift b/Sources/NIOCore/EventLoop.swift index 1c5ae5f964..b7f601f4a9 100644 --- a/Sources/NIOCore/EventLoop.swift +++ b/Sources/NIOCore/EventLoop.swift @@ -436,9 +436,11 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of microseconds this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func microseconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * 1000) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000)) } /// Creates a new `TimeAmount` for the given amount of milliseconds. @@ -446,9 +448,11 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of milliseconds this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func milliseconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * (1000 * 1000)) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000)) } /// Creates a new `TimeAmount` for the given amount of seconds. @@ -456,9 +460,11 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of seconds this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func seconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * (1000 * 1000 * 1000)) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000)) } /// Creates a new `TimeAmount` for the given amount of minutes. @@ -466,9 +472,11 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of minutes this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func minutes(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * (1000 * 1000 * 1000 * 60)) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000 * 60)) } /// Creates a new `TimeAmount` for the given amount of hours. @@ -476,9 +484,27 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of hours this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func hours(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * (1000 * 1000 * 1000 * 60 * 60)) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000 * 60 * 60)) + } + + /// Converts `amount` to nanoseconds multiplying it by `multiplier`. The return value is capped to `Int64.max` if the multiplication overflows and `Int64.min` if it underflows. + /// + /// - parameters: + /// - amount: the amount to be converted to nanoseconds. + /// - multiplier: the multiplier that converts the given amount to nanoseconds. + /// - returns: the amount converted to nanoseconds within [Int64.min, Int64.max]. + @inlinable + static func _cappedNanoseconds(amount: Int64, multiplier: Int64) -> Int64 { + let nanosecondsMultiplication = amount.multipliedReportingOverflow(by: multiplier) + if nanosecondsMultiplication.overflow { + return amount >= 0 ? .max : .min + } else { + return nanosecondsMultiplication.partialValue + } } } diff --git a/Tests/NIOCoreTests/TimeAmountTests.swift b/Tests/NIOCoreTests/TimeAmountTests.swift index dbee1d4831..7fb6a4efe8 100644 --- a/Tests/NIOCoreTests/TimeAmountTests.swift +++ b/Tests/NIOCoreTests/TimeAmountTests.swift @@ -43,4 +43,22 @@ class TimeAmountTests: XCTestCase { lhs -= rhs XCTAssertEqual(lhs, .nanoseconds(0)) } + + func testTimeAmountCappedOverflow() { + let overflowCap = TimeAmount.nanoseconds(Int64.max) + XCTAssertEqual(TimeAmount.microseconds(.max), overflowCap) + XCTAssertEqual(TimeAmount.milliseconds(.max), overflowCap) + XCTAssertEqual(TimeAmount.seconds(.max), overflowCap) + XCTAssertEqual(TimeAmount.minutes(.max), overflowCap) + XCTAssertEqual(TimeAmount.hours(.max), overflowCap) + } + + func testTimeAmountCappedUnderflow() { + let underflowCap = TimeAmount.nanoseconds(.min) + XCTAssertEqual(TimeAmount.microseconds(.min), underflowCap) + XCTAssertEqual(TimeAmount.milliseconds(.min), underflowCap) + XCTAssertEqual(TimeAmount.seconds(.min), underflowCap) + XCTAssertEqual(TimeAmount.minutes(.min), underflowCap) + XCTAssertEqual(TimeAmount.hours(.min), underflowCap) + } } From 2910d6b20323ac4f9003ab51711f9bfb7e912c22 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Thu, 12 Oct 2023 16:33:51 +0100 Subject: [PATCH 19/64] Call `NIOAsyncWriterSinkDelegate` outside of the lock (#2547) * Call `NIOAsyncWriterSinkDelegate` outside of the lock # Motivation The current `NIOAsyncWriter` implementation expects that the delegate is called while holding the lock to avoid reentrancy issues. However, this prevents us from executing the delegate calls directly on the `EventLoop` if we are on it already. # Modification This moves all of the delegate calls outside of the locks and adds protection against reentrancy into the state machine. # Result Less allocations. Clarify the reentrancy problems in docs and protect against them in the writer * Code review --- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 4 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 2 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 2 +- .../AsyncChannelOutboundWriterHandler.swift | 21 +- .../AsyncSequences/NIOAsyncWriter.swift | 478 ++++++++++++------ .../AsyncChannel/AsyncChannelTests.swift | 4 +- .../AsyncSequences/NIOAsyncWriterTests.swift | 45 +- 7 files changed, 390 insertions(+), 166 deletions(-) diff --git a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json index 9255c6c429..74498fb02f 100644 --- a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json +++ b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 1317015 -} \ No newline at end of file + "mallocCountTotal" : 164419 +} diff --git a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json index 810becb3a2..c38c7cbbfd 100644 --- a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json +++ b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 1317022 + "mallocCountTotal" : 164426 } \ No newline at end of file diff --git a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json index 9255c6c429..617e73531c 100644 --- a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json +++ b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 1317015 + "mallocCountTotal" : 164419 } \ No newline at end of file diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift index 3d0a78ef3e..33f4d363dc 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift @@ -170,25 +170,34 @@ extension NIOAsyncChannelOutboundWriterHandler { @inlinable func didYield(contentsOf sequence: Deque) { - // This always called from an async context, so we must loop-hop. - self.eventLoop.execute { + if self.eventLoop.inEventLoop { self.handler._didYield(sequence: sequence) + } else { + self.eventLoop.execute { + self.handler._didYield(sequence: sequence) + } } } @inlinable func didYield(_ element: OutboundOut) { - // This always called from an async context, so we must loop-hop. - self.eventLoop.execute { + if self.eventLoop.inEventLoop { self.handler._didYield(element: element) + } else { + self.eventLoop.execute { + self.handler._didYield(element: element) + } } } @inlinable func didTerminate(error: Error?) { - // This always called from an async context, so we must loop-hop. - self.eventLoop.execute { + if self.eventLoop.inEventLoop { self.handler._didTerminate(error: error) + } else { + self.eventLoop.execute { + self.handler._didTerminate(error: error) + } } } } diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift index f0356b574e..b2ae81d499 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift @@ -20,9 +20,8 @@ import _NIODataStructures /// The delegate of the ``NIOAsyncWriter``. It is the consumer of the yielded writes to the ``NIOAsyncWriter``. /// Furthermore, the delegate gets informed when the ``NIOAsyncWriter`` terminated. /// -/// - Important: The methods on the delegate are called while a lock inside of the ``NIOAsyncWriter`` is held. This is done to -/// guarantee the ordering of the writes. However, this means you **MUST NOT** call ``NIOAsyncWriter/Sink/setWritability(to:)`` -/// from within ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)`` or ``NIOAsyncWriterSinkDelegate/didTerminate(error:)``. +/// - Important: The methods on the delegate might be called on arbitrary threads and the implementation must ensure +/// that proper synchronization is in place. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public protocol NIOAsyncWriterSinkDelegate: Sendable { /// The `Element` type of the delegate and the writer. @@ -35,7 +34,9 @@ public protocol NIOAsyncWriterSinkDelegate: Sendable { /// until the ``NIOAsyncWriter`` becomes writable again. All buffered writes, while the ``NIOAsyncWriter`` is not writable, /// will be coalesced into a single sequence. /// - /// - Important: You **MUST NOT** call ``NIOAsyncWriter/Sink/setWritability(to:)`` from within this method. + /// The delegate might reentrantly call ``NIOAsyncWriter/Sink/setWritability(to:)`` while still processing writes. + /// This might trigger more calls to one of the `didYield` methods and it is up to the delegate to make sure that this reentrancy is + /// correctly guarded against. func didYield(contentsOf sequence: Deque) /// This method is called once a single element was yielded to the ``NIOAsyncWriter``. @@ -47,7 +48,9 @@ public protocol NIOAsyncWriterSinkDelegate: Sendable { /// /// - Note: This a fast path that you can optionally implement. By default this will just call ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)``. /// - /// - Important: You **MUST NOT** call ``NIOAsyncWriter/Sink/setWritability(to:)`` from within this method. + /// The delegate might reentrantly call ``NIOAsyncWriter/Sink/setWritability(to:)`` while still processing writes. + /// This might trigger more calls to one of the `didYield` methods and it is up to the delegate to make sure that this reentrancy is + /// correctly guarded against. func didYield(_ element: Element) /// This method is called once the ``NIOAsyncWriter`` is terminated. @@ -63,8 +66,6 @@ public protocol NIOAsyncWriterSinkDelegate: Sendable { /// - Parameter error: The error that terminated the ``NIOAsyncWriter``. If the writer was terminated without an /// error this value is `nil`. This can be either the error passed to ``NIOAsyncWriter/finish(error:)`` or /// to ``NIOAsyncWriter/Sink/finish(error:)``. - /// - /// - Important: You **MUST NOT** call ``NIOAsyncWriter/Sink/setWritability(to:)`` from within this method. func didTerminate(error: Error?) } @@ -433,63 +434,46 @@ extension NIOAsyncWriter { @inlinable /* fileprivate */ internal func writerDeinitialized() { - self._lock.withLock { - let action = self._stateMachine.writerDeinitialized() + let action = self._lock.withLock { + self._stateMachine.writerDeinitialized() + } - switch action { - case .callDidTerminate(let delegate): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didTerminate(error: nil) + switch action { + case .callDidTerminate(let delegate): + delegate.didTerminate(error: nil) - case .none: - break - } + case .none: + break } + } @inlinable /* fileprivate */ internal func setWritability(to writability: Bool) { - self._lock.withLock { - let action = self._stateMachine.setWritability(to: writability) - - switch action { - case .callDidYieldAndResumeContinuations(let delegate, let elements, let suspendedYields): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didYield(contentsOf: elements) + let action = self._lock.withLock { + self._stateMachine.setWritability(to: writability) + } - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - suspendedYields.forEach { $0.continuation.resume() } + switch action { + case .callDidYieldAndResumeContinuations(let delegate, let elements, let suspendedYields): + delegate.didYield(contentsOf: elements) + suspendedYields.forEach { $0.continuation.resume() } + self.unbufferQueuedEvents() - case .callDidYieldElementAndResumeContinuations(let delegate, let element, let suspendedYields): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didYield(element) + case .callDidYieldElementAndResumeContinuations(let delegate, let element, let suspendedYields): + delegate.didYield(element) + suspendedYields.forEach { $0.continuation.resume() } + self.unbufferQueuedEvents() - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - suspendedYields.forEach { $0.continuation.resume() } + case .resumeContinuations(let suspendedYields): + suspendedYields.forEach { $0.continuation.resume() } - case .resumeContinuations(let suspendedYields): - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - suspendedYields.forEach { $0.continuation.resume() } + case .callDidYieldAndDidTerminate(let delegate, let elements, let error): + delegate.didYield(contentsOf: elements) + delegate.didTerminate(error: error) - case .callDidYieldAndDidTerminate(let delegate, let elements): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didYield(contentsOf: elements) - delegate.didTerminate(error: nil) - - case .none: - return - } + case .none: + return } } @@ -505,13 +489,10 @@ extension NIOAsyncWriter { switch action { case .callDidYield(let delegate): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - // We are allocating a new Deque for every write here - delegate.didYield(contentsOf: Deque(sequence)) self._lock.unlock() + delegate.didYield(contentsOf: Deque(sequence)) + self.unbufferQueuedEvents() case .returnNormally: self._lock.unlock() @@ -533,18 +514,16 @@ extension NIOAsyncWriter { } } } onCancel: { - self._lock.withLock { - let action = self._stateMachine.cancel(yieldID: yieldID) + let action = self._lock.withLock { + self._stateMachine.cancel(yieldID: yieldID) + } - switch action { - case .resumeContinuation(let continuation): - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - continuation.resume() + switch action { + case .resumeContinuation(let continuation): + continuation.resume() - case .none: - break - } + case .none: + break } } } @@ -561,12 +540,9 @@ extension NIOAsyncWriter { switch action { case .callDidYield(let delegate): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - - delegate.didYield(element) self._lock.unlock() + delegate.didYield(element) + self.unbufferQueuedEvents() case .returnNormally: self._lock.unlock() @@ -588,69 +564,73 @@ extension NIOAsyncWriter { } } } onCancel: { - self._lock.withLock { - let action = self._stateMachine.cancel(yieldID: yieldID) + let action = self._lock.withLock { + self._stateMachine.cancel(yieldID: yieldID) + } - switch action { - case .resumeContinuation(let continuation): - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - continuation.resume() + switch action { + case .resumeContinuation(let continuation): + continuation.resume() - case .none: - break - } + case .none: + break } } } @inlinable /* fileprivate */ internal func writerFinish(error: Error?) { - self._lock.withLock { - let action = self._stateMachine.writerFinish() + let action = self._lock.withLock { + self._stateMachine.writerFinish(error: error) + } - switch action { - case .callDidTerminate(let delegate): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didTerminate(error: error) + switch action { + case .callDidTerminate(let delegate): + delegate.didTerminate(error: error) - case .resumeContinuations(let suspendedYields): - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - suspendedYields.forEach { $0.continuation.resume() } + case .resumeContinuations(let suspendedYields): + suspendedYields.forEach { $0.continuation.resume() } - case .none: - break - } + case .none: + break } } @inlinable /* fileprivate */ internal func sinkFinish(error: Error?) { - self._lock.withLock { - let action = self._stateMachine.sinkFinish(error: error) + let action = self._lock.withLock { + self._stateMachine.sinkFinish(error: error) + } + + switch action { + case .callDidTerminate(let delegate, let error): + delegate.didTerminate(error: error) + + case .resumeContinuationsWithError(let suspendedYields, let error): + suspendedYields.forEach { $0.continuation.resume(throwing: error) } + + case .resumeContinuationsWithErrorAndCallDidTerminate(let delegate, let suspendedYields, let error): + delegate.didTerminate(error: error) + suspendedYields.forEach { $0.continuation.resume(throwing: error) } + + case .none: + break + } + } + + @inlinable + /* fileprivate */ internal func unbufferQueuedEvents() { + while let action = self._lock.withLock({ self._stateMachine.unbufferQueuedEvents()}) { switch action { case .callDidTerminate(let delegate, let error): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didTerminate(error: error) - - case .resumeContinuationsWithErrorAndCallDidTerminate(let delegate, let suspendedYields, let error): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate delegate.didTerminate(error: error) - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - suspendedYields.forEach { $0.continuation.resume(throwing: error) } + case .callDidYield(let delegate, let elements): + delegate.didYield(contentsOf: elements) - case .none: - break + case .callDidYieldElement(let delegate, let element): + delegate.didYield(element) } } } @@ -694,6 +674,7 @@ extension NIOAsyncWriter { /// The state after a call to ``NIOAsyncWriter/yield(contentsOf:)``. case streaming( isWritable: Bool, + inDelegateOutcall: Bool, cancelledYields: [YieldID], suspendedYields: _TinyArray, elements: Deque, @@ -705,7 +686,8 @@ extension NIOAsyncWriter { /// 2. ``NIOAsyncWriter/finish(completion:)`` was called. case writerFinished( elements: Deque, - delegate: Delegate + delegate: Delegate, + error: Error? ) /// The state once the sink has been finished or the writer has been finished and all elements @@ -747,7 +729,7 @@ extension NIOAsyncWriter { return .callDidTerminate(delegate) - case .streaming(_, _, let suspendedYields, let elements, let delegate): + case .streaming(_, _, _, let suspendedYields, let elements, let delegate): // The writer got deinited after we started streaming. // This is normal and we need to transition to finished // and call the delegate. However, we should not have @@ -783,7 +765,7 @@ extension NIOAsyncWriter { case resumeContinuations(_TinyArray) /// Indicates that ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)`` and /// ``NIOAsyncWriterSinkDelegate/didTerminate(error:)``should be called. - case callDidYieldAndDidTerminate(Delegate, Deque) + case callDidYieldAndDidTerminate(Delegate, Deque, Error?) /// Indicates that nothing should be done. case none } @@ -797,13 +779,13 @@ extension NIOAsyncWriter { return .none - case .streaming(let isWritable, let cancelledYields, let suspendedYields, var elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, var elements, let delegate): if isWritable == newWritability { // The writability didn't change so we can just early exit here return .none } - if newWritability { + if newWritability && !inDelegateOutcall { // We became writable again. This means we have to resume all the continuations // and yield the values. @@ -811,6 +793,7 @@ extension NIOAsyncWriter { // We just have to resume the continuations self._state = .streaming( isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: .init(), elements: elements, @@ -828,6 +811,7 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: newWritability, + inDelegateOutcall: true, // We are now making a call to the delegate cancelledYields: cancelledYields, suspendedYields: .init(), elements: elements, @@ -842,6 +826,7 @@ extension NIOAsyncWriter { } else { self._state = .streaming( isWritable: newWritability, + inDelegateOutcall: true, // We are now making a call to the delegate cancelledYields: cancelledYields, suspendedYields: .init(), elements: .init(), @@ -854,13 +839,23 @@ extension NIOAsyncWriter { // switch between them but I don't think this is the performance critical part. return .callDidYieldAndResumeContinuations(delegate, elements, suspendedYields) } + } else if newWritability && inDelegateOutcall { + // We became writable but are in a delegate outcall. + // We just have to store the new writability here + self._state = .streaming( + isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + elements: elements, + delegate: delegate + ) + return .none } else { // We became unwritable nothing really to do here - precondition(suspendedYields.isEmpty, "No yield should be suspended at this point") - precondition(elements.isEmpty, "No element should be buffered at this point") - self._state = .streaming( isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, elements: elements, @@ -869,7 +864,7 @@ extension NIOAsyncWriter { return .none } - case .writerFinished(let elements, let delegate): + case .writerFinished(let elements, let delegate, let error): if !newWritability { // We are not writable so we can't deliver the outstanding elements return .none @@ -877,7 +872,7 @@ extension NIOAsyncWriter { self._state = .finished(sinkError: nil) - return .callDidYieldAndDidTerminate(delegate, elements) + return .callDidYieldAndDidTerminate(delegate, elements, error) case .finished: // We are already finished nothing to do here @@ -921,6 +916,7 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: isWritable, // If we are writable we are going to make an outcall cancelledYields: [], suspendedYields: .init(), elements: .init(), @@ -929,26 +925,43 @@ extension NIOAsyncWriter { return .init(isWritable: isWritable, delegate: delegate) - case .streaming(let isWritable, var cancelledYields, let suspendedYields, var elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, let suspendedYields, var elements, let delegate): + self._state = .modifying + if let index = cancelledYields.firstIndex(of: yieldID) { // We already marked the yield as cancelled. We have to remove it and // throw an error. - self._state = .modifying - cancelledYields.remove(at: index) - if isWritable { + switch (isWritable, inDelegateOutcall) { + case (true, false): // We are writable so we can yield the elements right away and then // return normally. self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: true, // We are now making a call to the delegate cancelledYields: cancelledYields, suspendedYields: suspendedYields, elements: elements, delegate: delegate ) return .callDidYield(delegate) - } else { + + case (true, true): + // We are writable but already calling out to the delegate + // so we have to buffer the elements. + elements.append(contentsOf: sequence) + + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + elements: elements, + delegate: delegate + ) + return .returnNormally + case (false, _): // We are not writable so we are just going to enqueue the writes // and return normally. We are not suspending the yield since the Task // is marked as cancelled. @@ -956,6 +969,7 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, elements: elements, @@ -966,8 +980,42 @@ extension NIOAsyncWriter { } } else { // Yield hasn't been marked as cancelled. - // This means we can either call the delegate or suspend - return .init(isWritable: isWritable, delegate: delegate) + + switch (isWritable, inDelegateOutcall) { + case (true, false): + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: true, // We are now making a call to the delegate + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + elements: elements, + delegate: delegate + ) + + return .callDidYield(delegate) + case (true, true): + elements.append(contentsOf: sequence) + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + elements: elements, + delegate: delegate + ) + return .returnNormally + case (false, _): + // We are not writable + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + elements: elements, + delegate: delegate + ) + return .suspendTask + } } case .writerFinished: @@ -991,7 +1039,7 @@ extension NIOAsyncWriter { yieldID: YieldID ) where S.Element == Element { switch self._state { - case .streaming(let isWritable, let cancelledYields, var suspendedYields, var elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, var suspendedYields, var elements, let delegate): // We have a suspended yield at this point that hasn't been cancelled yet. // We need to store the yield now. @@ -1006,6 +1054,7 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, elements: elements, @@ -1039,6 +1088,7 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: false, cancelledYields: [yieldID], suspendedYields: .init(), elements: .init(), @@ -1047,7 +1097,7 @@ extension NIOAsyncWriter { return .none - case .streaming(let isWritable, var cancelledYields, var suspendedYields, let elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, var suspendedYields, let elements, let delegate): if let index = suspendedYields.firstIndex(where: { $0.yieldID == yieldID }) { self._state = .modifying // We have a suspended yield for the id. We need to resume the continuation now. @@ -1061,6 +1111,7 @@ extension NIOAsyncWriter { // We are keeping the elements that the yield produced. self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, elements: elements, @@ -1078,6 +1129,7 @@ extension NIOAsyncWriter { cancelledYields.append(yieldID) self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, elements: elements, @@ -1108,7 +1160,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal mutating func writerFinish() -> WriterFinishAction { + /* fileprivate */ internal mutating func writerFinish(error: Error?) -> WriterFinishAction { switch self._state { case .initial(_, let delegate): // Nothing was ever written so we can transition to finished @@ -1116,18 +1168,30 @@ extension NIOAsyncWriter { return .callDidTerminate(delegate) - case .streaming(_, _, let suspendedYields, let elements, let delegate): + case .streaming(_, let inDelegateOutcall, _, let suspendedYields, let elements, let delegate): // We are currently streaming and the writer got finished. if elements.isEmpty { - // We have no elements left and can transition to finished directly - self._state = .finished(sinkError: nil) - - return .callDidTerminate(delegate) + if inDelegateOutcall { + // We are in an outcall already and have to buffer + // the didTerminate call. + self._state = .writerFinished( + elements: elements, + delegate: delegate, + error: error + ) + return .none + } else { + // We have no elements left and are not in an outcall so we + // can transition to finished directly + self._state = .finished(sinkError: nil) + return .callDidTerminate(delegate) + } } else { // There are still elements left which we need to deliver once we become writable again self._state = .writerFinished( elements: elements, - delegate: delegate + delegate: delegate, + error: error ) // We are not resuming the continuations with the error here since their elements @@ -1152,6 +1216,8 @@ extension NIOAsyncWriter { /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called and all /// continuations should be resumed with the given error. case resumeContinuationsWithErrorAndCallDidTerminate(Delegate, _TinyArray, Error) + /// Indicates that all continuations should be resumed with the given error. + case resumeContinuationsWithError(_TinyArray, Error) /// Indicates that nothing should be done. case none } @@ -1165,18 +1231,29 @@ extension NIOAsyncWriter { return .callDidTerminate(delegate, error) - case .streaming(_, _, let suspendedYields, _, let delegate): - // We are currently streaming and the writer got finished. - // We can transition to finished and need to resume all continuations. - self._state = .finished(sinkError: error) + case .streaming(_, let inDelegateOutcall, _, let suspendedYields, _, let delegate): + if inDelegateOutcall { + // We are currently streaming and the sink got finished. + // However we are in an outcall so we have to delay the call to didTerminate + // but we can resume the continuations already. + self._state = .writerFinished(elements: .init(), delegate: delegate, error: error) - return .resumeContinuationsWithErrorAndCallDidTerminate( - delegate, - suspendedYields, - error ?? NIOAsyncWriterError.alreadyFinished() - ) + return .resumeContinuationsWithError( + suspendedYields, + error ?? NIOAsyncWriterError.alreadyFinished() + ) + } else { + // We are currently streaming and the writer got finished. + // We can transition to finished and need to resume all continuations. + self._state = .finished(sinkError: error) + return .resumeContinuationsWithErrorAndCallDidTerminate( + delegate, + suspendedYields, + error ?? NIOAsyncWriterError.alreadyFinished() + ) + } - case .writerFinished(_, let delegate): + case .writerFinished(_, let delegate, let error): // The writer already finished and we were waiting to become writable again // The Sink finished before we became writable so we can drop the elements and // transition to finished @@ -1192,5 +1269,108 @@ extension NIOAsyncWriter { preconditionFailure("Invalid state") } } + + /// Actions returned by `sinkFinish()`. + @usableFromInline + enum UnbufferQueuedEventsAction { + case callDidYield(Delegate, Deque) + case callDidYieldElement(Delegate, Element) + case callDidTerminate(Delegate, Error?) + } + + @inlinable + /* fileprivate */ internal mutating func unbufferQueuedEvents() -> UnbufferQueuedEventsAction? { + switch self._state { + case .initial: + preconditionFailure("Invalid state") + + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, var elements, let delegate): + precondition(inDelegateOutcall, "We must be in a delegate outcall when we unbuffer events") + + if elements.count == 0 { + // Nothing to do. We haven't gotten any writes. + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: false, // We can now indicate that we are done with the outcall + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + elements: elements, + delegate: delegate + ) + return .none + } else if elements.count > 1 { + // We have to yield all of the elements now. + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + elements: .init(), + delegate: delegate + ) + + return .callDidYield(delegate, elements) + + } else { + // There is only a single element and we can optimize this to not + // yield the whole Deque + self._state = .modifying + + // This force-unwrap is safe since we just checked the count of the Deque + // and it must be 1 here. + let element = elements.popFirst()! + + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + elements: elements, + delegate: delegate + ) + + return .callDidYieldElement(delegate, element) + } + + case .writerFinished(var elements, let delegate, let error): + if elements.isEmpty { + // We have returned the last buffered elements and have to + // call didTerminate now. + self._state = .finished(sinkError: nil) + return .callDidTerminate(delegate, error) + } else if elements.count > 1 { + // We have to yield all of the elements now. + self._state = .writerFinished( + elements: .init(), + delegate: delegate, + error: error + ) + + return .callDidYield(delegate, elements) + } else { + // There is only a single element and we can optimize this to not + // yield the whole Deque + self._state = .modifying + + // This force-unwrap is safe since we just checked the count of the Deque + // and it must be 1 here. + let element = elements.popFirst()! + + self._state = .writerFinished( + elements: .init(), + delegate: delegate, + error: error + ) + + return .callDidYieldElement(delegate, element) + } + + case .finished: + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } } } diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 277cd4bbee..01d281477c 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -85,7 +85,7 @@ final class AsyncChannelTests: XCTestCase { inboundReader = wrapped.inboundStream try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(0, closeRecorder.outboundCloses) + XCTAssertEqual(1, closeRecorder.outboundCloses) } } @@ -159,7 +159,7 @@ final class AsyncChannelTests: XCTestCase { inboundReader = wrapped.inboundStream try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(0, closeRecorder.allCloses) + XCTAssertEqual(1, closeRecorder.allCloses) } } diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift index 3acb2c0ffc..d4ab9764d9 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift @@ -15,25 +15,32 @@ import DequeModule import NIOCore import XCTest +import NIOConcurrencyHelpers private struct SomeError: Error, Hashable {} private final class MockAsyncWriterDelegate: NIOAsyncWriterSinkDelegate, @unchecked Sendable { typealias Element = String - var didYieldCallCount = 0 + var _didYieldCallCount = NIOLockedValueBox(0) + var didYieldCallCount: Int { + self._didYieldCallCount.withLockedValue { $0 } + } var didYieldHandler: ((Deque) -> Void)? func didYield(contentsOf sequence: Deque) { - self.didYieldCallCount += 1 + self._didYieldCallCount.withLockedValue { $0 += 1 } if let didYieldHandler = self.didYieldHandler { didYieldHandler(sequence) } } - var didTerminateCallCount = 0 + var _didTerminateCallCount = NIOLockedValueBox(0) + var didTerminateCallCount: Int { + self._didTerminateCallCount.withLockedValue { $0 } + } var didTerminateHandler: ((Error?) -> Void)? func didTerminate(error: Error?) { - self.didTerminateCallCount += 1 + self._didTerminateCallCount.withLockedValue { $0 += 1 } if let didTerminateHandler = self.didTerminateHandler { didTerminateHandler(error) } @@ -68,6 +75,8 @@ final class NIOAsyncWriterTests: XCTestCase { } func testMultipleConcurrentWrites() async throws { + var elements = 0 + self.delegate.didYieldHandler = { elements += $0.count } let task1 = Task { [writer] in for i in 0...9 { try await writer!.yield("message\(i)") @@ -88,7 +97,33 @@ final class NIOAsyncWriterTests: XCTestCase { try await task2.value try await task3.value - XCTAssertEqual(self.delegate.didYieldCallCount, 30) + XCTAssertEqual(elements, 30) + } + + func testMultipleConcurrentBatchWrites() async throws { + var elements = 0 + self.delegate.didYieldHandler = { elements += $0.count } + let task1 = Task { [writer] in + for i in 0...9 { + try await writer!.yield(contentsOf: ["message\(i).1", "message\(i).2"]) + } + } + let task2 = Task { [writer] in + for i in 10...19 { + try await writer!.yield(contentsOf: ["message\(i).1", "message\(i).2"]) + } + } + let task3 = Task { [writer] in + for i in 20...29 { + try await writer!.yield(contentsOf: ["message\(i).1", "message\(i).2"]) + } + } + + try await task1.value + try await task2.value + try await task3.value + + XCTAssertEqual(elements, 60) } func testWriterCoalescesWrites() async throws { From 7954dba6de44c76e9efb7cefa02922b372ba1012 Mon Sep 17 00:00:00 2001 From: Lorenzo Fritzsch Date: Fri, 13 Oct 2023 12:13:30 +0200 Subject: [PATCH 20/64] Add jitter support to recurring tasks (#2542) * Add jitter support * Add jitter support to all functions managing recurring tasks * Fix typo * Fix typos * PR fixes * Remove unnecessary `@preconcurrency` annotations. * Fix tests --- Sources/NIOCore/EventLoop.swift | 70 ++++++++++++++++++++++++- Tests/NIOPosixTests/EventLoopTest.swift | 52 ++++++++++++++++++ 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/Sources/NIOCore/EventLoop.swift b/Sources/NIOCore/EventLoop.swift index b7f601f4a9..098e6f1652 100644 --- a/Sources/NIOCore/EventLoop.swift +++ b/Sources/NIOCore/EventLoop.swift @@ -807,6 +807,7 @@ extension EventLoop { ) -> Scheduled { self._flatScheduleTask(in: delay, file: file, line: line, task) } + @usableFromInline typealias FlatScheduleTaskDelayCallback = @Sendable () throws -> EventLoopFuture @inlinable @@ -919,6 +920,29 @@ extension EventLoop { ) -> RepeatedTask { self._scheduleRepeatedTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) } + + /// Schedule a repeated task to be executed by the `EventLoop` with a fixed delay between the end and start of each + /// task. + /// + /// - parameters: + /// - initialDelay: The delay after which the first task is executed. + /// - delay: The delay between the end of one task and the start of the next. + /// - maximumAllowableJitter: Exclusive upper bound of jitter range added to the `delay` parameter. + /// - promise: If non-nil, a promise to fulfill when the task is cancelled and all execution is complete. + /// - task: The closure that will be executed. + /// - return: `RepeatedTask` + @discardableResult + public func scheduleRepeatedTask( + initialDelay: TimeAmount, + delay: TimeAmount, + maximumAllowableJitter: TimeAmount, + notifying promise: EventLoopPromise? = nil, + _ task: @escaping @Sendable (RepeatedTask) throws -> Void + ) -> RepeatedTask { + let jitteredInitialDelay = Self._getJitteredDelay(delay: initialDelay, maximumAllowableJitter: maximumAllowableJitter) + let jitteredDelay = Self._getJitteredDelay(delay: delay, maximumAllowableJitter: maximumAllowableJitter) + return self.scheduleRepeatedTask(initialDelay: jitteredInitialDelay, delay: jitteredDelay, notifying: promise, task) + } typealias ScheduleRepeatedTaskCallback = @Sendable (RepeatedTask) throws -> Void func _scheduleRepeatedTask( @@ -964,6 +988,36 @@ extension EventLoop { ) -> RepeatedTask { self._scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) } + + /// Schedule a repeated asynchronous task to be executed by the `EventLoop` with a fixed delay between the end and + /// start of each task. + /// + /// - note: The delay is measured from the completion of one run's returned future to the start of the execution of + /// the next run. For example: If you schedule a task once per second but your task takes two seconds to + /// complete, the time interval between two subsequent runs will actually be three seconds (2s run time plus + /// the 1s delay.) + /// + /// - parameters: + /// - initialDelay: The delay after which the first task is executed. + /// - delay: The delay between the end of one task and the start of the next. + /// - maximumAllowableJitter: Exclusive upper bound of jitter range added to the `delay` parameter. + /// - promise: If non-nil, a promise to fulfill when the task is cancelled and all execution is complete. + /// - task: The closure that will be executed. Task will keep repeating regardless of whether the future + /// gets fulfilled with success or error. + /// + /// - return: `RepeatedTask` + @discardableResult + public func scheduleRepeatedAsyncTask( + initialDelay: TimeAmount, + delay: TimeAmount, + maximumAllowableJitter: TimeAmount, + notifying promise: EventLoopPromise? = nil, + _ task: @escaping @Sendable (RepeatedTask) -> EventLoopFuture + ) -> RepeatedTask { + let jitteredInitialDelay = Self._getJitteredDelay(delay: initialDelay, maximumAllowableJitter: maximumAllowableJitter) + let jitteredDelay = Self._getJitteredDelay(delay: delay, maximumAllowableJitter: maximumAllowableJitter) + return self._scheduleRepeatedAsyncTask(initialDelay: jitteredInitialDelay, delay: jitteredDelay, notifying: promise, task) + } typealias ScheduleRepeatedAsyncTaskCallback = @Sendable (RepeatedTask) -> EventLoopFuture func _scheduleRepeatedAsyncTask( @@ -976,7 +1030,21 @@ extension EventLoop { repeated.begin(in: initialDelay) return repeated } - + + /// Adds a random amount of `.nanoseconds` (within `.zero.. TimeAmount { + let jitter = TimeAmount.nanoseconds(Int64.random(in: .zero..(0) + let loop = EmbeddedEventLoop() + + _ = loop.scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, maximumAllowableJitter: maximumAllowableJitter, { RepeatedTask in + counter.wrappingIncrement(ordering: .relaxed) + let p = loop.makePromise(of: Void.self) + loop.scheduleTask(in: .milliseconds(10)) { + p.succeed(()) + } + return p.futureResult + }) + + for _ in 0..<10 { + // just running shouldn't do anything + loop.run() + } + let timeRange = TimeAmount.hours(1) + // Due to jittered delays is not possible to exactly know how many tasks will be executed in a given time range, + // instead calculate a range representing an estimate of the number of tasks executed during that given time range. + let minNumberOfExecutedTasks = (timeRange.nanoseconds - initialDelay.nanoseconds) / (delay.nanoseconds + maximumAllowableJitter.nanoseconds) + let maxNumberOfExecutedTasks = (timeRange.nanoseconds - initialDelay.nanoseconds) / delay.nanoseconds + 1 + + loop.advanceTime(by: timeRange) + XCTAssertTrue((minNumberOfExecutedTasks...maxNumberOfExecutedTasks).contains(counter.load(ordering: .relaxed))) + } + public func testEventLoopGroupMakeIterator() throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) defer { @@ -803,6 +834,27 @@ public final class EventLoopTest : XCTestCase { semaphore.signal() XCTAssertEqual(XCTWaiter.wait(for: [expect1, expect2], timeout: 0.5), .completed) } + + func testRepeatedTaskIsJittered() throws { + let initialDelay = TimeAmount.minutes(5) + let delay = TimeAmount.minutes(2) + let maximumAllowableJitter = TimeAmount.minutes(1) + let counter = ManagedAtomic(0) + let loop = EmbeddedEventLoop() + + _ = loop.scheduleRepeatedTask(initialDelay: initialDelay, delay: delay, maximumAllowableJitter: maximumAllowableJitter, { RepeatedTask in + counter.wrappingIncrement(ordering: .relaxed) + }) + + let timeRange = TimeAmount.hours(1) + // Due to jittered delays is not possible to exactly know how many tasks will be executed in a given time range, + // instead calculate a range representing an estimate of the number of tasks executed during that given time range. + let minNumberOfExecutedTasks = (timeRange.nanoseconds - initialDelay.nanoseconds) / (delay.nanoseconds + maximumAllowableJitter.nanoseconds) + let maxNumberOfExecutedTasks = (timeRange.nanoseconds - initialDelay.nanoseconds) / delay.nanoseconds + 1 + + loop.advanceTime(by: timeRange) + XCTAssertTrue((minNumberOfExecutedTasks...maxNumberOfExecutedTasks).contains(counter.load(ordering: .relaxed))) + } func testCancelledScheduledTasksDoNotHoldOnToRunClosure() { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) From e9676f92712f363063238fbbc7c83a8934e34633 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Mon, 16 Oct 2023 10:40:46 +0100 Subject: [PATCH 21/64] Fix flaky `testRemovesAllHTTPRelatedHandlersAfterUpgrade` test (#2552) # Motivation This test has been flaky since we only waited for the untyped handler to be removed but we also execute this test for the typed test subclass. # Modification This PR fixes the test by also waiting for the typed handler to no longer be present in the pipeline. # Result Less flaky tests. --- Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index 3b4ab37c9b..9717817ef7 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -62,8 +62,15 @@ extension ChannelPipeline { // handler present, keep waiting usleep(50) } catch ChannelPipelineError.notFound { - // No upgrader, we're good. - return + // Checking if the typed variant is present + do { + _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() + // handler present, keep waiting + usleep(50) + } catch ChannelPipelineError.notFound { + // No upgrader, we're good. + return + } } } From 17eab37e935e47e9300cc062637b9146ff4799c8 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Mon, 16 Oct 2023 11:48:45 +0100 Subject: [PATCH 22/64] Fix `Sendable` conformance for `Lock` (#2556) # Motivation The latest nightly toolchains have merged the patch that removes `Sendable` conformance from the various `Unsafe*Pointer` APIs. Our `Lock` type was using such a type internally and had a `Sendable` conformance. This conformance was now failing since the pointer was no longer `Sendable`. # Modification This PR changes the `Sendable` conformance of `Lock` to `@unchecked Sendable`. # Result No more `Sendable` warnings in non-strict mode on nightly toolchains --- Sources/NIOConcurrencyHelpers/lock.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/NIOConcurrencyHelpers/lock.swift b/Sources/NIOConcurrencyHelpers/lock.swift index 2aac90ade1..5df4af7b15 100644 --- a/Sources/NIOConcurrencyHelpers/lock.swift +++ b/Sources/NIOConcurrencyHelpers/lock.swift @@ -295,5 +295,5 @@ internal func debugOnly(_ body: () -> Void) { } @available(*, deprecated) -extension Lock: Sendable {} +extension Lock: @unchecked Sendable {} extension ConditionLock: @unchecked Sendable {} From dda031625c71bc62f27186315594abd164f6ea06 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Mon, 16 Oct 2023 13:10:29 +0100 Subject: [PATCH 23/64] Remove `NIOProtocolNegotiationResult` (#2554) # Motivation After playing around more with the new async bootstrap methods, I came to the conclusion that the `NIOProtocolNegotiationResult` isn't carrying its weight. The `NIOProtocolNegotiationResult` is a glorified `EventLoopFuture` with an easy way to recursively resolve nested futures. Furthermore, it forces any nested protocol negotiation to use the same generic type for the end result. # Modification This PR removes the `NIOProtocolNegotiationResult` and changes all the tests to be solely based on `EventLoopFuture`s. # Result Less code to support the async bootstrap and better composition of nested protocol negotiation handlers. --- Sources/NIOCore/ChannelHandler.swift | 86 ------------------- ...pplicationProtocolNegotiationHandler.swift | 16 ++-- .../AsyncChannelBootstrapTests.swift | 68 +++++++-------- ...ationProtocolNegotiationHandlerTests.swift | 26 +++--- 4 files changed, 55 insertions(+), 141 deletions(-) diff --git a/Sources/NIOCore/ChannelHandler.swift b/Sources/NIOCore/ChannelHandler.swift index 1e630cdbd4..93f11f537b 100644 --- a/Sources/NIOCore/ChannelHandler.swift +++ b/Sources/NIOCore/ChannelHandler.swift @@ -343,89 +343,3 @@ extension RemovableChannelHandler { context.leavePipeline(removalToken: removalToken) } } - -/// The result of protocol negotiation. -@_spi(AsyncChannel) -public struct NIOProtocolNegotiationResult { - fileprivate enum Result { - /// Indicates that the protocol negotiation finished. - case finished(NegotiationResult) - /// Indicates that protocol negotiation has been deferred to the next handler. - case deferredResult(EventLoopFuture>) - } - - private let _result: Result - - /// The final result of protocol negotiation. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - var result: NegotiationResult { - get async throws { - switch self._result { - case .finished(let negotiationResult): - return negotiationResult - case .deferredResult(let eventLoopFuture): - return try await eventLoopFuture.flatMap { $0._result.resolve(on: eventLoopFuture.eventLoop) }.get() - } - } - } - - /// Intializes a new ``NIOProtocolNegotiationResult`` with a final result. - /// - /// - Parameter result: The final result of protocol negotiation. - public init(result: NegotiationResult) { - self._result = .finished(result) - } - - /// Intializes a new ``NIOProtocolNegotiationResult`` with a deferred result. - /// - /// - Parameter deferredResult: The deferred result. - public init(deferredResult: EventLoopFuture>) { - self._result = .deferredResult(deferredResult) - } - - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @available(*, deprecated, renamed: "getResult") - public func waitForFinalResult() async throws -> NegotiationResult { - try await self.result - } -} - -extension EventLoopFuture { - /// Get the result/error from the protocol negotiation. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func getResult() async throws -> NegotiationResult where Value == NIOProtocolNegotiationResult { - try await self.get().result - } -} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult.Result { - fileprivate func resolve(on eventLoop: EventLoop) -> EventLoopFuture { - Self.resolve(on: eventLoop, result: self) - } - - fileprivate static func resolve(on eventLoop: EventLoop, result: Self) -> EventLoopFuture { - switch result { - case .finished(let negotiationResult): - return eventLoop.makeSucceededFuture(negotiationResult) - - case .deferredResult(let future): - return future.flatMap { result in - return resolve(on: eventLoop, result: result._result) - } - } - } -} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult: Equatable where NegotiationResult: Equatable {} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult: Sendable where NegotiationResult: Sendable {} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult.Result: Equatable where NegotiationResult: Equatable {} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult.Result: Sendable where NegotiationResult: Sendable {} diff --git a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift index 903f665974..1166525ac4 100644 --- a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift +++ b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift @@ -48,18 +48,18 @@ public final class NIOTypedApplicationProtocolNegotiationHandler> { + public var protocolNegotiationResult: EventLoopFuture { return self.negotiatedPromise.futureResult } - private var negotiatedPromise: EventLoopPromise> { + private var negotiatedPromise: EventLoopPromise { precondition(self._negotiatedPromise != nil, "Tried to access the protocol negotiation result before the handler was added to a pipeline") return self._negotiatedPromise! } - private var _negotiatedPromise: EventLoopPromise>? + private var _negotiatedPromise: EventLoopPromise? - private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture> - private var stateMachine = ProtocolNegotiationHandlerStateMachine>() + private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture + private var stateMachine = ProtocolNegotiationHandlerStateMachine() /// Create an `ApplicationProtocolNegotiationHandler` with the given completion /// callback. @@ -67,7 +67,7 @@ public final class NIOTypedApplicationProtocolNegotiationHandler EventLoopFuture>) { + public init(alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture) { self.completionHandler = alpnCompleteHandler } @@ -77,7 +77,7 @@ public final class NIOTypedApplicationProtocolNegotiationHandler EventLoopFuture>) { + public convenience init(alpnCompleteHandler: @escaping (ALPNResult) -> EventLoopFuture) { self.init { result, _ in alpnCompleteHandler(result) } @@ -137,7 +137,7 @@ public final class NIOTypedApplicationProtocolNegotiationHandler, Error>) { + private func userFutureCompleted(context: ChannelHandlerContext, result: Result) { switch self.stateMachine.userFutureCompleted(with: result) { case .fireErrorCaughtAndRemoveHandler(let error): self.negotiatedPromise.fail(error) diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 7e9061b88d..0aa1306310 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -262,7 +262,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try! eventLoopGroup.syncShutdownGracefully() } - let channel: NIOAsyncChannel>, Never> = try await ServerBootstrap(group: eventLoopGroup) + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: eventLoopGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .childChannelOption(ChannelOptions.autoRead, value: true) .bind( @@ -282,7 +282,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try await withThrowingTaskGroup(of: Void.self) { group in for try await negotiationResult in channel.inboundStream { group.addTask { - switch try await negotiationResult.getResult() { + switch try await negotiationResult.get() { case .string(let channel): for try await value in channel.inboundStream { continuation.yield(.string(value)) @@ -302,7 +302,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: channel.channel.localAddress!.port!, proposedALPN: .string ) - let stringNegotiationResult = try await stringNegotiationResultFuture.getResult() + let stringNegotiationResult = try await stringNegotiationResultFuture.get() switch stringNegotiationResult { case .string(let stringChannel): // This is the actual content @@ -317,7 +317,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: channel.channel.localAddress!.port!, proposedALPN: .byte ) - let byteNegotiationResult = try await byteNegotiationResultFuture.getResult() + let byteNegotiationResult = try await byteNegotiationResultFuture.get() switch byteNegotiationResult { case .string: preconditionFailure() @@ -337,7 +337,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try! eventLoopGroup.syncShutdownGracefully() } - let channel: NIOAsyncChannel>, Never> = try await ServerBootstrap(group: eventLoopGroup) + let channel: NIOAsyncChannel>, Never> = try await ServerBootstrap(group: eventLoopGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .bind( host: "127.0.0.1", @@ -356,7 +356,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try await withThrowingTaskGroup(of: Void.self) { group in for try await negotiationResult in channel.inboundStream { group.addTask { - switch try await negotiationResult.getResult() { + switch try await negotiationResult.get().get() { case .string(let channel): for try await value in channel.inboundStream { continuation.yield(.string(value)) @@ -377,7 +377,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedOuterALPN: .string, proposedInnerALPN: .string ) - switch try await stringStringNegotiationResult.getResult() { + switch try await stringStringNegotiationResult.get().get() { case .string(let stringChannel): // This is the actual content try await stringChannel.outboundWriter.write("hello") @@ -392,7 +392,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedOuterALPN: .byte, proposedInnerALPN: .string ) - switch try await byteStringNegotiationResult.getResult() { + switch try await byteStringNegotiationResult.get().get() { case .string(let stringChannel): // This is the actual content try await stringChannel.outboundWriter.write("hello") @@ -407,7 +407,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedOuterALPN: .byte, proposedInnerALPN: .byte ) - switch try await byteByteNegotiationResult.getResult() { + switch try await byteByteNegotiationResult.get().get() { case .string: preconditionFailure() case .byte(let byteChannel): @@ -422,7 +422,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedOuterALPN: .string, proposedInnerALPN: .byte ) - switch try await stringByteNegotiationResult.getResult() { + switch try await stringByteNegotiationResult.get().get() { case .string: preconditionFailure() case .byte(let byteChannel): @@ -460,7 +460,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { } let channels = NIOLockedValueBox<[Channel]>([Channel]()) - let channel: NIOAsyncChannel>, Never> = try await ServerBootstrap(group: eventLoopGroup) + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: eventLoopGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .serverChannelInitializer { channel in channel.eventLoop.makeCompletedFuture { @@ -485,7 +485,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try await withThrowingTaskGroup(of: Void.self) { group in for try await negotiationResult in channel.inboundStream { group.addTask { - switch try await negotiationResult.getResult() { + switch try await negotiationResult.get() { case .string(let channel): for try await value in channel.inboundStream { continuation.yield(.string(value)) @@ -506,7 +506,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedALPN: .unknown ) await XCTAssertThrowsError( - try await failedProtocolNegotiation.getResult() + try await failedProtocolNegotiation.get() ) // Let's check that we can still open a new connection @@ -515,7 +515,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: channel.channel.localAddress!.port!, proposedALPN: .string ) - switch try await stringNegotiationResult.getResult() { + switch try await stringNegotiationResult.get() { case .string(let stringChannel): // This is the actual content try await stringChannel.outboundWriter.write("hello") @@ -575,7 +575,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { let port = channel.localAddress!.port! try await channel.close() - try await withThrowingTaskGroup(of: EventLoopFuture>.self) { group in + try await withThrowingTaskGroup(of: EventLoopFuture.self) { group in group.addTask { // We have to use a fixed port here since we only get the channel once protocol negotiation is done try await self.makeUDPServerChannelWithProtocolNegotiation( @@ -599,7 +599,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { let firstNegotiationResult = try await group.next() let secondNegotiationResult = try await group.next() - switch (try await firstNegotiationResult?.getResult(), try await secondNegotiationResult?.getResult()) { + switch (try await firstNegotiationResult?.get(), try await secondNegotiationResult?.get()) { case (.string(let firstChannel), .string(let secondChannel)): var firstInboundIterator = firstChannel.inboundStream.makeAsyncIterator() var secondInboundIterator = secondChannel.inboundStream.makeAsyncIterator() @@ -662,7 +662,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { let toChannel = FileHandle(fileDescriptor: pipe1WriteFH, closeOnDealloc: false) let fromChannel = FileHandle(fileDescriptor: pipe2ReadFH, closeOnDealloc: false) - try await withThrowingTaskGroup(of: EventLoopFuture>.self) { group in + try await withThrowingTaskGroup(of: EventLoopFuture.self) { group in group.addTask { do { return try await NIOPipeBootstrap(group: eventLoopGroup) @@ -682,7 +682,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try toChannel.writeBytes(.init(string: "alpn:string\nHello\n")) let negotiationResult = try await group.next() - switch try await negotiationResult?.getResult() { + switch try await negotiationResult?.get() { case .string(let channel): var inboundIterator = channel.inboundStream.makeAsyncIterator() do { @@ -732,7 +732,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try! eventLoopGroup.syncShutdownGracefully() } - try await withThrowingTaskGroup(of: EventLoopFuture>.self) { group in + try await withThrowingTaskGroup(of: EventLoopFuture.self) { group in group.addTask { // We have to use a fixed port here since we only get the channel once protocol negotiation is done try await self.makeRawSocketServerChannelWithProtocolNegotiation( @@ -753,7 +753,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { let firstNegotiationResult = try await group.next() let secondNegotiationResult = try await group.next() - switch (try await firstNegotiationResult?.getResult(), try await secondNegotiationResult?.getResult()) { + switch (try await firstNegotiationResult?.get(), try await secondNegotiationResult?.get()) { case (.string(let firstChannel), .string(let secondChannel)): var firstInboundIterator = firstChannel.inboundStream.makeAsyncIterator() var secondInboundIterator = secondChannel.inboundStream.makeAsyncIterator() @@ -820,7 +820,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { private func makeRawSocketServerChannelWithProtocolNegotiation( eventLoopGroup: EventLoopGroup - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { try await NIORawSocketBootstrap(group: eventLoopGroup) .bind( host: "127.0.0.1", @@ -837,7 +837,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { private func makeRawSocketClientChannelWithProtocolNegotiation( eventLoopGroup: EventLoopGroup, proposedALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { try await NIORawSocketBootstrap(group: eventLoopGroup) .connect( host: "127.0.0.1", @@ -870,7 +870,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { eventLoopGroup: EventLoopGroup, port: Int, proposedALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { return try await ClientBootstrap(group: eventLoopGroup) .connect( to: .init(ipAddress: "127.0.0.1", port: port) @@ -886,7 +886,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: Int, proposedOuterALPN: TLSUserEventHandler.ALPN, proposedInnerALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture> { return try await ClientBootstrap(group: eventLoopGroup) .connect( to: .init(ipAddress: "127.0.0.1", port: port) @@ -921,7 +921,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { eventLoopGroup: EventLoopGroup, port: Int, proposedALPN: TLSUserEventHandler.ALPN? = nil - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { try await DatagramBootstrap(group: eventLoopGroup) .bind( host: "127.0.0.1", @@ -954,7 +954,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { eventLoopGroup: EventLoopGroup, port: Int, proposedALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { try await DatagramBootstrap(group: eventLoopGroup) .connect( host: "127.0.0.1", @@ -973,7 +973,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedALPN: TLSUserEventHandler.ALPN? = nil, inboundID: UInt8? = nil, outboundID: UInt8? = nil - ) throws -> EventLoopFuture> { + ) throws -> EventLoopFuture { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder(inboundID: inboundID))) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder(outboundID: outboundID))) try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedALPN)) @@ -985,11 +985,11 @@ final class AsyncChannelBootstrapTests: XCTestCase { channel: Channel, proposedOuterALPN: TLSUserEventHandler.ALPN? = nil, proposedInnerALPN: TLSUserEventHandler.ALPN? = nil - ) throws -> EventLoopFuture> { + ) throws -> EventLoopFuture> { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN)) - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in + let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler> { alpnResult, channel in switch alpnResult { case .negotiated(let alpn): switch alpn { @@ -998,14 +998,14 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) - return NIOProtocolNegotiationResult(deferredResult: negotiationFuture) + return negotiationFuture } case "byte": return channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) let negotiationHandler = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) - return NIOProtocolNegotiationResult(deferredResult: negotiationHandler) + return negotiationHandler } default: return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } @@ -1019,7 +1019,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { } @discardableResult - private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture> { + private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture { let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in switch alpnResult { case .negotiated(let alpn): @@ -1031,7 +1031,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { synchronouslyWrapping: channel ) - return NIOProtocolNegotiationResult(result: .string(asyncChannel)) + return .string(asyncChannel) } case "byte": return channel.eventLoop.makeCompletedFuture { @@ -1041,7 +1041,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { synchronouslyWrapping: channel ) - return NIOProtocolNegotiationResult(result: .byte(asyncChannel)) + return .byte(asyncChannel) } default: return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } diff --git a/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift b/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift index 99bc3e795d..c2d4d771d3 100644 --- a/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift +++ b/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift @@ -31,7 +31,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let channel = EmbeddedChannel() let handler = NIOTypedApplicationProtocolNegotiationHandler { result, channel in - return channel.eventLoop.makeSucceededFuture(.init(result: (.negotiated(result)))) + return channel.eventLoop.makeSucceededFuture(.negotiated(result)) } try channel.pipeline.addHandler(handler).wait() try channel.pipeline.removeHandler(handler).wait() @@ -49,14 +49,14 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { called = true XCTAssertEqual(result, self.negotiatedResult) XCTAssertTrue(emChannel === channel) - return loop.makeSucceededFuture(.init(result: (.negotiated(result)))) + return loop.makeSucceededFuture(.negotiated(result)) } try emChannel.pipeline.addHandler(handler).wait() emChannel.pipeline.fireUserInboundEventTriggered(negotiatedEvent) XCTAssertTrue(called) - XCTAssertEqual(try handler.protocolNegotiationResult.wait(), .init(result: (.negotiated(negotiatedResult)))) + XCTAssertEqual(try handler.protocolNegotiationResult.wait(), .negotiated(negotiatedResult)) } func testIgnoresUnknownUserEvents() throws { @@ -65,7 +65,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let handler = NIOTypedApplicationProtocolNegotiationHandler { result in XCTFail("Negotiation fired") - return loop.makeSucceededFuture(.init(result: (.failed))) + return loop.makeSucceededFuture(.failed) } try channel.pipeline.addHandler(handler).wait() @@ -86,7 +86,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let handler = NIOTypedApplicationProtocolNegotiationHandler { result in XCTFail("Should not be called") - return loop.makeSucceededFuture(.init(result: (.failed))) + return loop.makeSucceededFuture(.failed) } try channel.pipeline.addHandler(handler).wait() @@ -101,7 +101,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { func testBufferingWhileWaitingForFuture() throws { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) + let continuePromise = loop.makePromise(of: NegotiationResult.self) let handler = NIOTypedApplicationProtocolNegotiationHandler { result in return continuePromise.futureResult @@ -119,7 +119,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { XCTAssertNoThrow(XCTAssertNil(try channel.readInbound())) // Complete the pipeline swap. - continuePromise.succeed(.init(result: (.failed))) + continuePromise.succeed(.failed) // Now everything should have been unbuffered. XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "writes")) @@ -132,7 +132,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { func testNothingBufferedDoesNotFireReadCompleted() throws { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) + let continuePromise = loop.makePromise(of: NegotiationResult.self) let handler = NIOTypedApplicationProtocolNegotiationHandler { result in continuePromise.futureResult @@ -150,7 +150,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { // Now satisfy the future, which forces data unbuffering. As we haven't buffered any data, // readComplete should not be fired. - continuePromise.succeed(.init(result: (.failed))) + continuePromise.succeed(.failed) XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 0) XCTAssertTrue(try channel.finish().isClean) @@ -159,7 +159,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { func testUnbufferingFiresReadCompleted() throws { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) + let continuePromise = loop.makePromise(of: NegotiationResult.self) let handler = NIOTypedApplicationProtocolNegotiationHandler { result in continuePromise.futureResult @@ -179,7 +179,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 1) // Now satisfy the future, which forces data unbuffering. This should fire readComplete. - continuePromise.succeed(.init(result: (.failed))) + continuePromise.succeed(.failed) XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write")) XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 2) @@ -190,7 +190,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { func testUnbufferingHandlesReentrantReads() throws { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) + let continuePromise = loop.makePromise(of: NegotiationResult.self) let handler = NIOTypedApplicationProtocolNegotiationHandler { result in continuePromise.futureResult @@ -211,7 +211,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 1) // Now satisfy the future, which forces data unbuffering. This should fire readComplete. - continuePromise.succeed(.init(result: .failed)) + continuePromise.succeed(.failed) XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write")) XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write")) From 0fb8cb794730d5acd67701bbb5c72543c8530c1d Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Mon, 16 Oct 2023 15:02:43 +0100 Subject: [PATCH 24/64] Add docs for the async NIO APIs (#2549) --- .../TCPEchoAsyncChannel.swift | 10 +- .../NIOCore/AsyncChannel/AsyncChannel.swift | 28 +- .../AsyncChannelOutboundWriter.swift | 4 +- .../NIOTypedHTTPClientUpgradeHandler.swift | 2 +- Sources/NIOTCPEchoClient/Client.swift | 4 +- Sources/NIOTCPEchoServer/Server.swift | 6 +- Sources/NIOWebSocketClient/Client.swift | 4 +- Sources/NIOWebSocketServer/Server.swift | 16 +- .../AsyncChannel/AsyncChannelTests.swift | 24 +- .../AsyncChannelBootstrapTests.swift | 78 +- docs/public-async-nio-apis.md | 1121 +++++++++++++++++ 11 files changed, 1210 insertions(+), 87 deletions(-) create mode 100644 docs/public-async-nio-apis.md diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift index ef28d9e4c1..99e5d0cf56 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift @@ -53,9 +53,9 @@ func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async thr try await withThrowingTaskGroup(of: Void.self) { group in // This child task is echoing back the data on the server. group.addTask { - for try await connectionChannel in serverChannel.inboundStream { - for try await inboundData in connectionChannel.inboundStream { - try await connectionChannel.outboundWriter.write(inboundData) + for try await connectionChannel in serverChannel.inbound { + for try await inboundData in connectionChannel.inbound { + try await connectionChannel.outbound.write(inboundData) } } } @@ -63,7 +63,7 @@ func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async thr // This child task is collecting the echoed back responses. group.addTask { var receivedData = 0 - for try await inboundData in clientChannel.inboundStream { + for try await inboundData in clientChannel.inbound { receivedData += inboundData.readableBytes if receivedData == numberOfWrites * messageSize { @@ -75,7 +75,7 @@ func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async thr // Let's start sending data. let data = ByteBuffer(repeating: 0, count: messageSize) for _ in 0..: Sendable { /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. public var isOutboundHalfClosureEnabled: Bool - /// The ``NIOAsyncChannel/inboundStream`` message's type. + /// The ``NIOAsyncChannel/inbound`` message's type. public var inboundType: Inbound.Type - /// The ``NIOAsyncChannel/outboundWriter`` message's type. + /// The ``NIOAsyncChannel/outbound`` message's type. public var outboundType: Outbound.Type /// Initializes a new ``NIOAsyncChannel/Configuration``. /// /// - Parameters: - /// - backPressureStrategy: The back pressure strategy of the ``NIOAsyncChannel/inboundStream``. Defaults + /// - backPressureStrategy: The back pressure strategy of the ``NIOAsyncChannel/inbound``. Defaults /// to a watermarked strategy (lowWatermark: 2, highWatermark: 10). /// - isOutboundHalfClosureEnabled: If outbound half closure should be enabled. Outbound half closure is triggered once /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. Defaults to `false`. - /// - inboundType: The ``NIOAsyncChannel/inboundStream`` message's type. - /// - outboundType: The ``NIOAsyncChannel/outboundWriter`` message's type. + /// - inboundType: The ``NIOAsyncChannel/inbound`` message's type. + /// - outboundType: The ``NIOAsyncChannel/outbound`` message's type. public init( backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark = .init(lowWatermark: 2, highWatermark: 10), isOutboundHalfClosureEnabled: Bool = false, @@ -77,12 +77,12 @@ public struct NIOAsyncChannel: Sendable { public let channel: Channel /// The stream of inbound messages. /// - /// - Important: The `inboundStream` is a unicast `AsyncSequence` and only one iterator can be created. + /// - Important: The `inbound` stream is a unicast `AsyncSequence` and only one iterator can be created. @_spi(AsyncChannel) - public let inboundStream: NIOAsyncChannelInboundStream + public let inbound: NIOAsyncChannelInboundStream /// The writer for writing outbound messages. @_spi(AsyncChannel) - public let outboundWriter: NIOAsyncChannelOutboundWriter + public let outbound: NIOAsyncChannelOutboundWriter /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel``. /// @@ -100,7 +100,7 @@ public struct NIOAsyncChannel: Sendable { ) throws { channel.eventLoop.preconditionInEventLoop() self.channel = channel - (self.inboundStream, self.outboundWriter) = try channel._syncAddAsyncHandlers( + (self.inbound, self.outbound) = try channel._syncAddAsyncHandlers( backPressureStrategy: configuration.backPressureStrategy, isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled ) @@ -108,7 +108,7 @@ public struct NIOAsyncChannel: Sendable { /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel`` where the outbound type is `Never`. /// - /// This initializer will finish the ``NIOAsyncChannel/outboundWriter`` immediately. + /// This initializer will finish the ``NIOAsyncChannel/outbound`` immediately. /// /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. @@ -124,12 +124,12 @@ public struct NIOAsyncChannel: Sendable { ) throws where Outbound == Never { channel.eventLoop.preconditionInEventLoop() self.channel = channel - (self.inboundStream, self.outboundWriter) = try channel._syncAddAsyncHandlers( + (self.inbound, self.outbound) = try channel._syncAddAsyncHandlers( backPressureStrategy: configuration.backPressureStrategy, isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled ) - self.outboundWriter.finish() + self.outbound.finish() } @inlinable @@ -141,8 +141,8 @@ public struct NIOAsyncChannel: Sendable { ) { channel.eventLoop.preconditionInEventLoop() self.channel = channel - self.inboundStream = inboundStream - self.outboundWriter = outboundWriter + self.inbound = inboundStream + self.outbound = outboundWriter } @inlinable diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift index 4fda58f12d..9af339bf68 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -138,7 +138,9 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { } } - /// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away. + /// Send an asynchronous sequence of writes into the ``ChannelPipeline``. + /// + /// This will flush after every write. /// /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. @inlinable diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift index 82c6f129f8..2fd8e25af2 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift @@ -190,7 +190,7 @@ public final class NIOTypedHTTPClientUpgradeHandler: Ch } } - public func channelRead(context: ChannelHandlerContext, responsePart: HTTPClientResponsePart) { + private func channelRead(context: ChannelHandlerContext, responsePart: HTTPClientResponsePart) { switch self.stateMachine.channelReadResponsePart(responsePart) { case .fireErrorCaughtAndRemoveHandler(let error): self.upgradeResultPromise.fail(error) diff --git a/Sources/NIOTCPEchoClient/Client.swift b/Sources/NIOTCPEchoClient/Client.swift index 0bbb1dba3e..29d8d5d65a 100644 --- a/Sources/NIOTCPEchoClient/Client.swift +++ b/Sources/NIOTCPEchoClient/Client.swift @@ -69,9 +69,9 @@ struct Client { } print("Connection(\(number)): Writing request") - try await channel.outboundWriter.write("Hello on connection \(number)") + try await channel.outbound.write("Hello on connection \(number)") - for try await inboundData in channel.inboundStream { + for try await inboundData in channel.inbound { print("Connection(\(number)): Received response (\(inboundData))") // We only expect a single response so we can exit here. diff --git a/Sources/NIOTCPEchoServer/Server.swift b/Sources/NIOTCPEchoServer/Server.swift index 1d38049058..b467fb99de 100644 --- a/Sources/NIOTCPEchoServer/Server.swift +++ b/Sources/NIOTCPEchoServer/Server.swift @@ -64,7 +64,7 @@ struct Server { // the results of the group we need the group to automatically discard them; otherwise, this // would result in a memory leak over time. try await withThrowingDiscardingTaskGroup { group in - for try await connectionChannel in channel.inboundStream { + for try await connectionChannel in channel.inbound { group.addTask { print("Handling new connection") await self.handleConnection(channel: connectionChannel) @@ -80,9 +80,9 @@ struct Server { // We do this since we don't want to tear down the whole server when a single connection // encounters an error. do { - for try await inboundData in channel.inboundStream { + for try await inboundData in channel.inbound { print("Received request (\(inboundData))") - try await channel.outboundWriter.write(inboundData) + try await channel.outbound.write(inboundData) } } catch { print("Hit error: \(error)") diff --git a/Sources/NIOWebSocketClient/Client.swift b/Sources/NIOWebSocketClient/Client.swift index fef5b1e15e..cb6270e6d9 100644 --- a/Sources/NIOWebSocketClient/Client.swift +++ b/Sources/NIOWebSocketClient/Client.swift @@ -109,9 +109,9 @@ struct Client { // start to handle all inbound frames. let pingFrame = WebSocketFrame(fin: true, opcode: .ping, data: ByteBuffer(string: "Hello!")) - try await channel.outboundWriter.write(pingFrame) + try await channel.outbound.write(pingFrame) - for try await frame in channel.inboundStream { + for try await frame in channel.inbound { switch frame.opcode { case .pong: print("Received pong: \(String(buffer: frame.data))") diff --git a/Sources/NIOWebSocketServer/Server.swift b/Sources/NIOWebSocketServer/Server.swift index ba6af55144..01bb64994d 100644 --- a/Sources/NIOWebSocketServer/Server.swift +++ b/Sources/NIOWebSocketServer/Server.swift @@ -114,7 +114,7 @@ struct Server { // the results of the group we need the group to automatically discard them; otherwise, this // would result in a memory leak over time. try await withThrowingDiscardingTaskGroup { group in - for try await upgradeResult in channel.inboundStream { + for try await upgradeResult in channel.inbound { group.addTask { await self.handleUpgradeResult(upgradeResult) } @@ -146,7 +146,7 @@ struct Server { private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { - for try await frame in channel.inboundStream { + for try await frame in channel.inbound { switch frame.opcode { case .ping: print("Received ping") @@ -158,7 +158,7 @@ struct Server { } let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) - try await channel.outboundWriter.write(responseFrame) + try await channel.outbound.write(responseFrame) case .connectionClose: // This is an unsolicited close. We're going to send a response frame and @@ -168,7 +168,7 @@ struct Server { var data = frame.unmaskedData let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) - try await channel.outboundWriter.write(closeFrame) + try await channel.outbound.write(closeFrame) return case .binary, .continuation, .pong: // We ignore these frames. @@ -193,7 +193,7 @@ struct Server { let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) print("Sending time") - try await channel.outboundWriter.write(frame) + try await channel.outbound.write(frame) try await Task.sleep(for: .seconds(1)) } } @@ -205,7 +205,7 @@ struct Server { private func handleHTTPChannel(_ channel: NIOAsyncChannel>) async throws { - for try await requestPart in channel.inboundStream { + for try await requestPart in channel.inbound { // We're not interested in request bodies here: we're just serving up GET responses // to get the client to initiate a websocket request. guard case .head(let head) = requestPart else { @@ -214,7 +214,7 @@ struct Server { // GETs only. guard case .GET = head.method else { - try await self.respond405(writer: channel.outboundWriter) + try await self.respond405(writer: channel.outbound) return } @@ -228,7 +228,7 @@ struct Server { headers: headers ) - try await channel.outboundWriter.write( + try await channel.outbound.write( contentsOf: [ .head(responseHead), .body(Self.responseBody), diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 01d281477c..5f7cae78dd 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -25,7 +25,7 @@ final class AsyncChannelTests: XCTestCase { try NIOAsyncChannel(synchronouslyWrapping: channel) } - var iterator = wrapped.inboundStream.makeAsyncIterator() + var iterator = wrapped.inbound.makeAsyncIterator() try await channel.writeInbound("hello") let firstRead = try await iterator.next() XCTAssertEqual(firstRead, "hello") @@ -51,8 +51,8 @@ final class AsyncChannelTests: XCTestCase { try NIOAsyncChannel(synchronouslyWrapping: channel) } - try await wrapped.outboundWriter.write("hello") - try await wrapped.outboundWriter.write("world") + try await wrapped.outbound.write("hello") + try await wrapped.outbound.write("world") let firstRead = try await channel.waitForOutboundWrite(as: String.self) let secondRead = try await channel.waitForOutboundWrite(as: String.self) @@ -82,7 +82,7 @@ final class AsyncChannelTests: XCTestCase { ) ) } - inboundReader = wrapped.inboundStream + inboundReader = wrapped.inbound try await channel.testingEventLoop.executeInContext { XCTAssertEqual(1, closeRecorder.outboundCloses) @@ -119,7 +119,7 @@ final class AsyncChannelTests: XCTestCase { ) ) } - inboundReader = wrapped.inboundStream + inboundReader = wrapped.inbound try await channel.testingEventLoop.executeInContext { XCTAssertEqual(0, closeRecorder.outboundCloses) @@ -156,7 +156,7 @@ final class AsyncChannelTests: XCTestCase { ) ) } - inboundReader = wrapped.inboundStream + inboundReader = wrapped.inbound try await channel.testingEventLoop.executeInContext { XCTAssertEqual(1, closeRecorder.allCloses) @@ -230,7 +230,7 @@ final class AsyncChannelTests: XCTestCase { try await channel.close().get() - let reads = try await Array(wrapped.inboundStream) + let reads = try await Array(wrapped.inbound) XCTAssertEqual(reads, ["hello"]) } @@ -246,7 +246,7 @@ final class AsyncChannelTests: XCTestCase { channel.pipeline.fireErrorCaught(TestError.bang) } - var iterator = wrapped.inboundStream.makeAsyncIterator() + var iterator = wrapped.inbound.makeAsyncIterator() let first = try await iterator.next() XCTAssertEqual(first, "hello") @@ -271,7 +271,7 @@ final class AsyncChannelTests: XCTestCase { await withThrowingTaskGroup(of: Void.self) { group in group.addTask { - try await wrapped.outboundWriter.write("hello") + try await wrapped.outbound.write("hello") lock.withLockedValue { XCTAssertTrue($0) } @@ -378,7 +378,7 @@ final class AsyncChannelTests: XCTestCase { XCTAssertEqual(readCounter.readCount, 6) // Now consume three elements from the pipeline. This should not unbuffer the read, as 3 elements remain. - var reader = wrapped.inboundStream.makeAsyncIterator() + var reader = wrapped.inbound.makeAsyncIterator() for _ in 0..<3 { try await XCTAsyncAssertNotNil(await reader.next()) } @@ -437,12 +437,12 @@ final class AsyncChannelTests: XCTestCase { try NIOAsyncChannel(synchronouslyWrapping: channel) } - var iterator = wrapped.inboundStream.makeAsyncIterator() + var iterator = wrapped.inbound.makeAsyncIterator() try await channel.writeInbound("hello") let firstRead = try await iterator.next() XCTAssertEqual(firstRead, "hello") - try await wrapped.outboundWriter.write("world") + try await wrapped.outbound.write("world") let write = try await channel.waitForOutboundWrite(as: String.self) XCTAssertEqual(write, "world") diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 0aa1306310..cd74144a45 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -239,8 +239,8 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { _ in - for try await childChannel in channel.inboundStream { - for try await value in childChannel.inboundStream { + for try await childChannel in channel.inbound { + for try await value in childChannel.inbound { continuation.yield(.string(value)) } } @@ -248,7 +248,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { } let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!) - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.outbound.write("hello") await XCTAsyncAssertEqual(await iterator.next(), .string("hello")) @@ -280,15 +280,15 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inboundStream { + for try await negotiationResult in channel.inbound { group.addTask { switch try await negotiationResult.get() { case .string(let channel): - for try await value in channel.inboundStream { + for try await value in channel.inbound { continuation.yield(.string(value)) } case .byte(let channel): - for try await value in channel.inboundStream { + for try await value in channel.inbound { continuation.yield(.byte(value)) } } @@ -306,7 +306,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { switch stringNegotiationResult { case .string(let stringChannel): // This is the actual content - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.outbound.write("hello") await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -323,7 +323,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { preconditionFailure() case .byte(let byteChannel): // This is the actual content - try await byteChannel.outboundWriter.write(UInt8(8)) + try await byteChannel.outbound.write(UInt8(8)) await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) } @@ -354,15 +354,15 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inboundStream { + for try await negotiationResult in channel.inbound { group.addTask { switch try await negotiationResult.get().get() { case .string(let channel): - for try await value in channel.inboundStream { + for try await value in channel.inbound { continuation.yield(.string(value)) } case .byte(let channel): - for try await value in channel.inboundStream { + for try await value in channel.inbound { continuation.yield(.byte(value)) } } @@ -380,7 +380,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { switch try await stringStringNegotiationResult.get().get() { case .string(let stringChannel): // This is the actual content - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.outbound.write("hello") await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -395,7 +395,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { switch try await byteStringNegotiationResult.get().get() { case .string(let stringChannel): // This is the actual content - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.outbound.write("hello") await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -412,7 +412,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { preconditionFailure() case .byte(let byteChannel): // This is the actual content - try await byteChannel.outboundWriter.write(UInt8(8)) + try await byteChannel.outbound.write(UInt8(8)) await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) } @@ -427,7 +427,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { preconditionFailure() case .byte(let byteChannel): // This is the actual content - try await byteChannel.outboundWriter.write(UInt8(8)) + try await byteChannel.outbound.write(UInt8(8)) await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) } @@ -483,15 +483,15 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inboundStream { + for try await negotiationResult in channel.inbound { group.addTask { switch try await negotiationResult.get() { case .string(let channel): - for try await value in channel.inboundStream { + for try await value in channel.inbound { continuation.yield(.string(value)) } case .byte(let channel): - for try await value in channel.inboundStream { + for try await value in channel.inbound { continuation.yield(.byte(value)) } } @@ -518,7 +518,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { switch try await stringNegotiationResult.get() { case .string(let stringChannel): // This is the actual content - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.outbound.write("hello") await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -549,13 +549,13 @@ final class AsyncChannelBootstrapTests: XCTestCase { eventLoopGroup: eventLoopGroup, port: serverChannel.channel.localAddress!.port! ) - var serverInboundIterator = serverChannel.inboundStream.makeAsyncIterator() - var clientInboundIterator = clientChannel.inboundStream.makeAsyncIterator() + var serverInboundIterator = serverChannel.inbound.makeAsyncIterator() + var clientInboundIterator = clientChannel.inbound.makeAsyncIterator() - try await clientChannel.outboundWriter.write("request") + try await clientChannel.outbound.write("request") try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") - try await serverChannel.outboundWriter.write("response") + try await serverChannel.outbound.write("response") try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") } @@ -601,13 +601,13 @@ final class AsyncChannelBootstrapTests: XCTestCase { switch (try await firstNegotiationResult?.get(), try await secondNegotiationResult?.get()) { case (.string(let firstChannel), .string(let secondChannel)): - var firstInboundIterator = firstChannel.inboundStream.makeAsyncIterator() - var secondInboundIterator = secondChannel.inboundStream.makeAsyncIterator() + var firstInboundIterator = firstChannel.inbound.makeAsyncIterator() + var secondInboundIterator = secondChannel.inbound.makeAsyncIterator() - try await firstChannel.outboundWriter.write("request") + try await firstChannel.outbound.write("request") try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") - try await secondChannel.outboundWriter.write("response") + try await secondChannel.outbound.write("response") try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") default: @@ -640,14 +640,14 @@ final class AsyncChannelBootstrapTests: XCTestCase { throw error } - var inboundIterator = channel.inboundStream.makeAsyncIterator() + var inboundIterator = channel.inbound.makeAsyncIterator() do { try toChannel.writeBytes(.init(string: "Request")) try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) let response = ByteBuffer(string: "Response") - try await channel.outboundWriter.write(response) + try await channel.outbound.write(response) XCTAssertEqual(try fromChannel.readBytes(ofExactLength: response.readableBytes), Array(buffer: response)) } catch { // We only got to close the FDs that are not owned by the PipeChannel @@ -684,12 +684,12 @@ final class AsyncChannelBootstrapTests: XCTestCase { let negotiationResult = try await group.next() switch try await negotiationResult?.get() { case .string(let channel): - var inboundIterator = channel.inboundStream.makeAsyncIterator() + var inboundIterator = channel.inbound.makeAsyncIterator() do { try await XCTAsyncAssertEqual(try await inboundIterator.next(), "Hello") let response = ByteBuffer(string: "Response") - try await channel.outboundWriter.write("Response") + try await channel.outbound.write("Response") XCTAssertEqual(try fromChannel.readBytes(ofExactLength: response.readableBytes), Array(buffer: response)) } catch { // We only got to close the FDs that are not owned by the PipeChannel @@ -715,13 +715,13 @@ final class AsyncChannelBootstrapTests: XCTestCase { let serverChannel = try await self.makeRawSocketServerChannel(eventLoopGroup: eventLoopGroup) let clientChannel = try await self.makeRawSocketClientChannel(eventLoopGroup: eventLoopGroup) - var serverInboundIterator = serverChannel.inboundStream.makeAsyncIterator() - var clientInboundIterator = clientChannel.inboundStream.makeAsyncIterator() + var serverInboundIterator = serverChannel.inbound.makeAsyncIterator() + var clientInboundIterator = clientChannel.inbound.makeAsyncIterator() - try await clientChannel.outboundWriter.write("request") + try await clientChannel.outbound.write("request") try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") - try await serverChannel.outboundWriter.write("response") + try await serverChannel.outbound.write("response") try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") } @@ -755,13 +755,13 @@ final class AsyncChannelBootstrapTests: XCTestCase { switch (try await firstNegotiationResult?.get(), try await secondNegotiationResult?.get()) { case (.string(let firstChannel), .string(let secondChannel)): - var firstInboundIterator = firstChannel.inboundStream.makeAsyncIterator() - var secondInboundIterator = secondChannel.inboundStream.makeAsyncIterator() + var firstInboundIterator = firstChannel.inbound.makeAsyncIterator() + var secondInboundIterator = secondChannel.inbound.makeAsyncIterator() - try await firstChannel.outboundWriter.write("request") + try await firstChannel.outbound.write("request") try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") - try await secondChannel.outboundWriter.write("response") + try await secondChannel.outbound.write("response") try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") default: diff --git a/docs/public-async-nio-apis.md b/docs/public-async-nio-apis.md new file mode 100644 index 0000000000..85fba9c912 --- /dev/null +++ b/docs/public-async-nio-apis.md @@ -0,0 +1,1121 @@ +# Async NIO bridges + +This is a summary of all the new APIs we introduced to make NIO and the various +network protocols work with new async interfaces. The intention of this document +is to quickly outline the incompatibilities in the original API and then show a +holistic view over all the new APIs. + +## What APIs are needed? + +The first problem that we had to tackle was bridging NIO's `Channel` to Swift +Concurrency. To do so we introduced two new foundational types - the +`NIOAsyncSequenceProducer` and the `NIOAsyncWriter`. Those allow us to bridge +the read and the write side of the `Channel` while propagating the back-pressure +across the bridge. + +On top of those two types, we built out the `NIOAsyncChannel` which allows users +to bridge a `Channel` into Swift Concurrency. To do this it inserts two channel +handlers which bridge the read and write side using the +`NIOAsyncSequenceProducer` and the `NIOAsyncWriter`. + +Next up we had to look at the bootstraps. Here the import part is that the +`Channel`s **must** be wrapped at the correct timing otherwise there is the +potential that reads might be dropped. This is not problematic for most of the +bootstraps since they call their various `channelInitializer`s and +`childChannelInitializer`s at the right time. However, there was one tricky +bootstrap - `ServerBootstrap`. The `ServerBootstrap` multiplexes the incoming +connections and we have to make sure that the wrapping of the child channels +happens at the correct time. Additionally, the new bootstrap APIs **must** be +able to relay the type information of the configured channels to the +`bind`/`connect` methods. + +The next thing we had to tackle was networking protocols that dynamically +re-configure the `ChannelPipeline`. The two examples that we provide +implementations for are HTTP/1 protocol upgrades and Application Protocol +Negotiation (ALPN) via TLS. Similar to the bootstraps we have to ensure that the +type information is retained so that users can correctly identify which +reconfiguration path has been taken. + +Lastly, we had to look at how to handle protocols that multiplex, like HTTP/2. +Multiplexing protocols need to expose a typed async interface to consume new +inbound connections/streams and to open new outbound connections/streams where +applicable. + +## Proposed APIs + +This section contains all the new APIs that we are adding and gives us a holistic +overview to review them. + +### `NIOAsyncChannel` + +```swift +/// Wraps a NIO ``Channel`` object into a form suitable for use in Swift Concurrency. +/// +/// ``NIOAsyncChannel`` abstracts the notion of a NIO ``Channel`` into something that +/// can safely be used in a structured concurrency context. In particular, this exposes +/// the following functionality: +/// +/// - reads are presented as an `AsyncSequence` +/// - writes can be written to with async functions on a writer, providing back pressure +/// - channels can be closed seamlessly +/// +/// This type does not replace the full complexity of NIO's ``Channel``. In particular, it +/// does not expose the following functionality: +/// +/// - user events +/// - traditional NIO back pressure such as writability signals and the ``Channel/read()`` call +/// +/// Users are encouraged to separate their ``ChannelHandler``s into those that implement +/// protocol-specific logic (such as parsers and encoders) and those that implement business +/// logic. Protocol-specific logic should be implemented as a ``ChannelHandler``, while business +/// logic should use ``NIOAsyncChannel`` to consume and produce data to the network. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct NIOAsyncChannel : Sendable where Inbound : Sendable, Outbound : Sendable { + public struct Configuration : Sendable { + /// The back pressure strategy of the ``NIOAsyncChannel/inboundStream``. + public var backPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark + + /// If outbound half closure should be enabled. Outbound half closure is triggered once + /// the ``NIOAsyncChannelOutboundWriter`` is either finished or deinitialized. + public var isOutboundHalfClosureEnabled: Bool + + /// The ``NIOAsyncChannel/inbound`` message's type. + public var inboundType: Inbound.Type + + /// The ``NIOAsyncChannel/outbound`` message's type. + public var outboundType: Outbound.Type + + /// Initializes a new ``NIOAsyncChannel/Configuration``. + /// + /// - Parameters: + /// - backPressureStrategy: The back pressure strategy of the ``NIOAsyncChannel/inbound``. Defaults + /// to a watermarked strategy (lowWatermark: 2, highWatermark: 10). + /// - isOutboundHalfClosureEnabled: If outbound half closure should be enabled. Outbound half closure is triggered once + /// the ``NIOAsyncChannelOutboundWriter`` is either finished or deinitialized. Defaults to `false`. + /// - inboundType: The ``NIOAsyncChannel/inbound`` message's type. + /// - outboundType: The ``NIOAsyncChannel/outbound`` message's type. + public init(backPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark = .init(lowWatermark: 2, highWatermark: 10), isOutboundHalfClosureEnabled: Bool = false, inboundType: Inbound.Type = Inbound.self, outboundType: Outbound.Type = Outbound.self) + } + + /// The underlying channel being wrapped by this ``NIOAsyncChannel``. + public let channel: NIOCore.Channel + + /// The stream of inbound messages. + /// + /// - Important: The `inbound` stream is a unicast `AsyncSequence` and only one iterator can be created. + public let inbound: NIOCore.NIOAsyncChannelInboundStream + + /// The writer for writing outbound messages. + public let outbound: NIOCore.NIOAsyncChannelOutboundWriter + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel``. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - configuration: The ``NIOAsyncChannel``s configuration. + @inlinable public init(synchronouslyWrapping channel: NIOCore.Channel, configuration: NIOCore.NIOAsyncChannel.Configuration = .init()) throws + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel`` where the outbound type is `Never`. + /// + /// This initializer will finish the ``NIOAsyncChannel/outboundWriter`` immediately. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - configuration: The ``NIOAsyncChannel``s configuration. + @inlinable public init(synchronouslyWrapping channel: NIOCore.Channel, configuration: NIOCore.NIOAsyncChannel.Configuration = .init()) throws where Outbound == Never + + /// This method is only used from our server bootstrap to allow us to run the child channel initializer + /// at the right moment. + /// + /// - Important: This is not considered stable API and should not be used. + @inlinable public static func _wrapAsyncChannelWithTransformations(synchronouslyWrapping channel: NIOCore.Channel, backPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, isOutboundHalfClosureEnabled: Bool = false, channelReadTransformation: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) throws -> NIOCore.NIOAsyncChannel where Outbound == Never +} + +/// The inbound message asynchronous sequence of a ``NIOAsyncChannel``. +/// +/// This is a unicast async sequence that allows a single iterator to be created. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct NIOAsyncChannelInboundStream : Sendable where Inbound : Sendable { + /// A source used for driving a ``NIOAsyncChannelInboundStream`` during tests. + public struct TestSource { + /// Yields the element to the inbound stream. + /// + /// - Parameter element: The element to yield to the inbound stream. + @inlinable public func yield(_ element: Inbound) + + /// Finished the inbound stream. + /// + /// - Parameter error: The error to throw, or nil, to finish normally. + @inlinable public func finish(throwing error: Error? = nil) + } + + /// Creates a new stream with a source for testing. + /// + /// This is useful for writing unit tests where you want to drive a ``NIOAsyncChannelInboundStream``. + /// + /// - Returns: A tuple containing the input stream and a test source to drive it. + @inlinable public static func makeTestingStream() -> (NIOCore.NIOAsyncChannelInboundStream, NIOCore.NIOAsyncChannelInboundStream.TestSource) +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOAsyncChannelInboundStream : AsyncSequence { + /// The type of element produced by this asynchronous sequence. + public typealias Element = Inbound + + /// The type of asynchronous iterator that produces elements of this + /// asynchronous sequence. + public struct AsyncIterator : AsyncIteratorProtocol { + /// Asynchronously advances to the next element and returns it, or ends the + /// sequence if there is no next element. + /// + /// - Returns: The next element, if it exists, or `nil` to signal the end of + /// the sequence. + @inlinable public mutating func next() async throws -> NIOCore.NIOAsyncChannelInboundStream.Element? + } + + /// Creates the asynchronous iterator that produces elements of this + /// asynchronous sequence. + /// + /// - Returns: An instance of the `AsyncIterator` type used to produce + /// elements of the asynchronous sequence. + @inlinable public func makeAsyncIterator() -> NIOCore.NIOAsyncChannelInboundStream.AsyncIterator +} + +/// A ``NIOAsyncChannelOutboundWriter`` is used to write and flush new outbound messages in a channel. +/// +/// The writer acts as a bridge between the Concurrency and NIO world. It allows to write and flush messages into the +/// underlying ``Channel``. Furthermore, it respects back-pressure of the channel by suspending the calls to write until +/// the channel becomes writable again. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct NIOAsyncChannelOutboundWriter : Sendable where OutboundOut : Sendable { + /// An `AsyncSequence` backing a ``NIOAsyncChannelOutboundWriter`` for testing purposes. + public struct TestSink : AsyncSequence { + /// The type of element produced by this asynchronous sequence. + public typealias Element = OutboundOut + + /// Creates the asynchronous iterator that produces elements of this + /// asynchronous sequence. + /// + /// - Returns: An instance of the `AsyncIterator` type used to produce + /// elements of the asynchronous sequence. + public func makeAsyncIterator() -> NIOCore.NIOAsyncChannelOutboundWriter.TestSink.AsyncIterator + + /// The type of asynchronous iterator that produces elements of this + /// asynchronous sequence. + public struct AsyncIterator : AsyncIteratorProtocol { + /// Asynchronously advances to the next element and returns it, or ends the + /// sequence if there is no next element. + /// + /// - Returns: The next element, if it exists, or `nil` to signal the end of + /// the sequence. + public mutating func next() async -> NIOCore.NIOAsyncChannelOutboundWriter.TestSink.Element? + } + } + + /// Creates a new ``NIOAsyncChannelOutboundWriter`` backed by a ``NIOAsyncChannelOutboundWriter/TestSink``. + /// This is mostly useful for testing purposes where one wants to observe the written data. + @inlinable public static func makeTestingWriter() -> (NIOCore.NIOAsyncChannelOutboundWriter, NIOCore.NIOAsyncChannelOutboundWriter.TestSink) + + /// Send a write into the ``ChannelPipeline`` and flush it right away. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable public func write(_ data: OutboundOut) async throws + + /// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable public func write(contentsOf sequence: Writes) async throws where OutboundOut == Writes.Element, Writes : Sequence + + /// Send an asynchronous sequence of writes into the ``ChannelPipeline``. + /// + /// This will flush after every write. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable public func write(contentsOf sequence: Writes) async throws where OutboundOut == Writes.Element, Writes : AsyncSequence + + /// Finishes the writer. + /// + /// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it. + public func finish() +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOAsyncChannelOutboundWriter.TestSink : Sendable {} +``` + +### Bootstraps + +```swift +extension ClientBootstrap { + /// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established. + /// + /// - Parameters: + /// - host: The host to connect to. + /// - port: The port to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(host: String, port: Int, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Specify the `address` to connect to for the TCP `Channel` that will be established. + /// + /// - Parameters: + /// - address: The address to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(to address: NIOCore.SocketAddress, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established. + /// + /// - Parameters: + /// - unixDomainSocketPath: The _Unix domain socket_ path to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(unixDomainSocketPath: String, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Use the existing connected socket file descriptor. + /// + /// - Parameters: + /// - descriptor: The _Unix file descriptor_ representing the connected stream socket. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withConnectedSocket(_ socket: NIOCore.NIOBSDSocket.Handle, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable +} + +extension ServerBootstrap { + /// Bind the `ServerSocketChannel` to the `host` and `port` parameters. + /// + /// - Parameters: + /// - host: The host to bind on. + /// - port: The port to bind on. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(host: String, port: Int, serverBackPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> NIOCore.NIOAsyncChannel where Output : Sendable + + /// Bind the `ServerSocketChannel` to the `address` parameter. + /// + /// - Parameters: + /// - address: The `SocketAddress` to bind on. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(to address: NIOCore.SocketAddress, serverBackPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> NIOCore.NIOAsyncChannel where Output : Sendable + + /// Bind the `ServerSocketChannel` to a UNIX Domain Socket. + /// + /// - Parameters: + /// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist, + /// unless `cleanupExistingSocketFile`is set to `true`. + /// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool = false, serverBackPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> NIOCore.NIOAsyncChannel where Output : Sendable + + /// Use the existing bound socket file descriptor. + /// + /// - Parameters: + /// - socket: The _Unix file descriptor_ representing the bound stream socket. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(_ socket: NIOCore.NIOBSDSocket.Handle, cleanupExistingSocketFile: Bool = false, serverBackPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> NIOCore.NIOAsyncChannel where Output : Sendable +} + +extension DatagramBootstrap { + /// Use the existing bound socket file descriptor. + /// + /// - Parameters: + /// - socket: The _Unix file descriptor_ representing the bound stream socket. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withBoundSocket(_ socket: NIOCore.NIOBSDSocket.Handle, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Bind the `DatagramChannel` to `host` and `port`. + /// + /// - Parameters: + /// - host: The host to bind on. + /// - port: The port to bind on. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(host: String, port: Int, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Bind the `DatagramChannel` to the `address`. + /// + /// - Parameters: + /// - address: The `SocketAddress` to bind on. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(to address: NIOCore.SocketAddress, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Bind the `DatagramChannel` to the `unixDomainSocketPath`. + /// + /// - Parameters: + /// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist, + /// unless `cleanupExistingSocketFile`is set to `true`. + /// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool = false, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Connect the `DatagramChannel` to `host` and `port`. + /// + /// - Parameters: + /// - host: The host to connect to. + /// - port: The port to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(host: String, port: Int, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Connect the `DatagramChannel` to the `address`. + /// + /// - Parameters: + /// - address: The `SocketAddress` to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(to address: NIOCore.SocketAddress, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Connect the `DatagramChannel` to the `unixDomainSocketPath`. + /// + /// - Parameters: + /// - unixDomainSocketPath: The path of the UNIX Domain Socket to connect to. `path` must not exist, it will be created by the system. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(unixDomainSocketPath: String, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable +} + +extension NIOPipeBootstrap { + + /// Create the `PipeChannel` with the provided file descriptor which is used for both input & output. + /// + /// This method is useful for specialilsed use-cases where you want to use `NIOPipeBootstrap` for say a serial line. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `inputOutput` when the `Channel` + /// becomes inactive. You _must not_ do any further operations with `inputOutput`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing it. + /// + /// - Parameters: + /// - inputOutput: The _Unix file descriptor_ for the input & output. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func takingOwnershipOfDescriptor(inputOutput: CInt, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Create the `PipeChannel` with the provided input and output file descriptors. + /// + /// The input and output file descriptors must be distinct. If you have a single file descriptor, consider using + /// `ClientBootstrap.withConnectedSocket(descriptor:)` if it's a socket or + /// `NIOPipeBootstrap.takingOwnershipOfDescriptor` if it is not a socket. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `input` and `output` + /// when the `Channel` becomes inactive. You _must not_ do any further operations `input` or + /// `output`, including `close`. + /// If this method returns a failed future, you still own the file descriptors and are responsible for + /// closing them. + /// + /// - Parameters: + /// - input: The _Unix file descriptor_ for the input (ie. the read side). + /// - output: The _Unix file descriptor_ for the output (ie. the write side). + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func takingOwnershipOfDescriptors(input: CInt, output: CInt, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable +} + +extension NIORawSocketBootstrap { + + /// Bind the `Channel` to `host`. + /// All packets or errors matching the `ipProtocol` specified are passed to the resulting `Channel`. + /// + /// - Parameters: + /// - host: The host to bind on. + /// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(host: String, ipProtocol: NIOCore.NIOIPProtocol, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Connect the `Channel` to `host`. + /// + /// - Parameters: + /// - host: The host to connect to. + /// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(host: String, ipProtocol: NIOCore.NIOIPProtocol, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable +} +``` + +### HTTPUpgrade + +```swift +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOUpgradableHTTPClientPipelineConfiguration where UpgradeResult : Sendable { + + /// The strategy to use when dealing with leftover bytes after removing the ``HTTPDecoder`` from the pipeline. + public var leftOverBytesStrategy: NIOHTTP1.RemoveAfterUpgradeStrategy + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableOutboundHeaderValidation: Bool + + /// The configuration for the ``HTTPRequestEncoder``. + public var encoderConfiguration: NIOHTTP1.HTTPRequestEncoder.Configuration + + /// The configuration for the ``NIOTypedHTTPClientUpgradeHandler``. + public var upgradeConfiguration: NIOHTTP1.NIOTypedHTTPClientUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPClientPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Outbound header fields validation to protect against response splitting attacks. + public init(upgradeConfiguration: NIOHTTP1.NIOTypedHTTPClientUpgradeConfiguration) +} + +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOUpgradableHTTPServerPipelineConfiguration where UpgradeResult : Sendable { + + /// Whether to provide assistance handling HTTP clients that pipeline + /// their requests. Defaults to `true`. If `false`, users will need to handle clients that pipeline themselves. + public var enablePipelining: Bool + + /// Whether to provide assistance handling protocol errors (e.g. failure to parse the HTTP + /// request) by sending 400 errors. Defaults to `true`. + public var enableErrorHandling: Bool + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableResponseHeaderValidation: Bool + + /// The configuration for the ``HTTPResponseEncoder``. + public var encoderConfiguration: NIOHTTP1.HTTPResponseEncoder.Configuration + + /// The configuration for the ``NIOTypedHTTPServerUpgradeHandler``. + public var upgradeConfiguration: NIOHTTP1.NIOTypedHTTPServerUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPServerPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Assistance handling clients that pipeline HTTP requests. + /// 2. Assistance handling protocol errors. + /// 3. Outbound header fields validation to protect against response splitting attacks. + public init(upgradeConfiguration: NIOHTTP1.NIOTypedHTTPServerUpgradeConfiguration) +} + +extension ChannelPipeline { + + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPServerPipeline(configuration: NIOHTTP1.NIOUpgradableHTTPServerPipelineConfiguration) -> NIOCore.EventLoopFuture> where UpgradeResult : Sendable +} + +extension ChannelPipeline.SynchronousOperations { + + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPServerPipeline(configuration: NIOHTTP1.NIOUpgradableHTTPServerPipelineConfiguration) throws -> NIOCore.EventLoopFuture where UpgradeResult : Sendable +} + +extension ChannelPipeline { + + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPClientPipeline(configuration: NIOHTTP1.NIOUpgradableHTTPClientPipelineConfiguration) -> NIOCore.EventLoopFuture> where UpgradeResult : Sendable +} + +extension ChannelPipeline.SynchronousOperations { + + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPClientPipeline(configuration: NIOHTTP1.NIOUpgradableHTTPClientPipelineConfiguration) throws -> NIOCore.EventLoopFuture where UpgradeResult : Sendable +} + +/// An object that implements `NIOTypedHTTPClientProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a client-side channel. +/// It has the option of denying this upgrade based upon the server response. +public protocol NIOTypedHTTPClientProtocolUpgrader { + + associatedtype UpgradeResult : Sendable + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol requires in the request to successfully upgrade. + /// These header fields will be added to the outbound request's "Connection" header field. + /// It is the responsibility of the custom headers call to actually add these required headers. + var requiredUpgradeHeaders: [String] { get } + + /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. + func addCustom(upgradeRequestHeaders: inout NIOHTTP1.HTTPHeaders) + + /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. + func shouldAllowUpgrade(upgradeResponse: NIOHTTP1.HTTPResponseHead) -> Bool + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel + /// pipeline to add whatever channel handlers are required. + /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. + func upgrade(channel: NIOCore.Channel, upgradeResponse: NIOHTTP1.HTTPResponseHead) -> NIOCore.EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPClientUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOTypedHTTPClientUpgradeConfiguration where UpgradeResult : Sendable { + + /// The initial request head that is sent out once the channel becomes active. + public var upgradeRequestHead: NIOHTTP1.HTTPRequestHead + + /// The array of potential upgraders. + public var upgraders: [NIOHTTP1.NIOTypedHTTPClientProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture + + public init(upgradeRequestHead: NIOHTTP1.HTTPRequestHead, upgraders: [NIOHTTP1.NIOTypedHTTPClientProtocolUpgrader], notUpgradingCompletionHandler: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) +} + +/// A client-side channel handler that sends a HTTP upgrade handshake request to perform a HTTP-upgrade. +/// This handler will add all appropriate headers to perform an upgrade to +/// the a protocol. It may add headers for a set of protocols in preference order. +/// If the upgrade fails (i.e. response is not 101 Switching Protocols), this handler simply +/// removes itself from the pipeline. If the upgrade is successful, it upgrades the pipeline to the new protocol. +/// +/// The request sends an order of preference to request which protocol it would like to use for the upgrade. +/// It will only upgrade to the protocol that is returned first in the list and does not currently +/// have the capability to upgrade to multiple simultaneous layered protocols. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final public class NIOTypedHTTPClientUpgradeHandler : NIOCore.ChannelDuplexHandler, NIOCore.RemovableChannelHandler where UpgradeResult : Sendable { + + /// The type of the outbound data which is wrapped in `NIOAny`. + public typealias OutboundIn = NIOHTTP1.HTTPClientRequestPart + + /// The type of the outbound data which will be forwarded to the next `ChannelOutboundHandler` in the `ChannelPipeline`. + public typealias OutboundOut = NIOHTTP1.HTTPClientRequestPart + + /// The type of the inbound data which is wrapped in `NIOAny`. + public typealias InboundIn = NIOHTTP1.HTTPClientResponsePart + + /// The type of the inbound data which will be forwarded to the next `ChannelInboundHandler` in the `ChannelPipeline`. + public typealias InboundOut = NIOHTTP1.HTTPClientResponsePart + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: NIOCore.EventLoopFuture { get } + + /// Create a ``NIOTypedHTTPClientUpgradeHandler``. + /// + /// - Parameters: + /// - httpHandlers: All `RemovableChannelHandler` objects which will be removed from the pipeline + /// once the upgrade response is sent. This is used to ensure that the pipeline will be in a clean state + /// after the upgrade. It should include any handlers that are directly related to handling HTTP. + /// At the very least this should include the `HTTPEncoder` and `HTTPDecoder`, but should also include + /// any other handler that cannot tolerate receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init(httpHandlers: [NIOCore.RemovableChannelHandler], upgradeConfiguration: NIOHTTP1.NIOTypedHTTPClientUpgradeConfiguration) + + /// Called when this `ChannelHandler` is added to the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerAdded(context: NIOCore.ChannelHandlerContext) + + /// Called when this `ChannelHandler` is removed from the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerRemoved(context: NIOCore.ChannelHandlerContext) + + /// Called when the `Channel` has become active, and is able to send and receive data. + /// + /// This should call `context.fireChannelActive` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func channelActive(context: NIOCore.ChannelHandlerContext) + + /// Called to request a write operation. The write operation will write the messages through the + /// `ChannelPipeline`. Those are then ready to be flushed to the actual `Channel` when + /// `Channel.flush` or `ChannelHandlerContext.flush` is called. + /// + /// This should call `context.write` to forward the operation to the next `_ChannelOutboundHandler` in the `ChannelPipeline` or + /// complete the `EventLoopPromise` to let the caller know that the operation completed. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - data: The data to write through the `Channel`, wrapped in a `NIOAny`. + /// - promise: The `EventLoopPromise` which should be notified once the operation completes, or nil if no notification should take place. + public func write(context: NIOCore.ChannelHandlerContext, data: NIOCore.NIOAny, promise: NIOCore.EventLoopPromise?) + + /// Called when some data has been read from the remote peer. + /// + /// This should call `context.fireChannelRead` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - data: The data read from the remote peer, wrapped in a `NIOAny`. + public func channelRead(context: NIOCore.ChannelHandlerContext, data: NIOCore.NIOAny) +} + +/// An object that implements `NIOTypedHTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a server-side channel. +public protocol NIOTypedHTTPServerProtocolUpgrader { + + associatedtype UpgradeResult : Sendable + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol needs in the request to successfully upgrade. These header fields + /// will be provided to the handler when it is asked to handle the upgrade. They will also be validated + /// against the inbound request's `Connection` header field. + var requiredUpgradeHeaders: [String] { get } + + /// Builds the upgrade response headers. Should return any headers that need to be supplied to the client + /// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should + /// return a failed future. + func buildUpgradeResponse(channel: NIOCore.Channel, upgradeRequest: NIOHTTP1.HTTPRequestHead, initialResponseHeaders: NIOHTTP1.HTTPHeaders) -> NIOCore.EventLoopFuture + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline + /// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received + /// data will be buffered. + func upgrade(channel: NIOCore.Channel, upgradeRequest: NIOHTTP1.HTTPRequestHead) -> NIOCore.EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPServerUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOTypedHTTPServerUpgradeConfiguration where UpgradeResult : Sendable { + + /// The array of potential upgraders. + public var upgraders: [NIOHTTP1.NIOTypedHTTPServerProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture + + public init(upgraders: [NIOHTTP1.NIOTypedHTTPServerProtocolUpgrader], notUpgradingCompletionHandler: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) +} + +/// A server-side channel handler that receives HTTP requests and optionally performs an HTTP-upgrade. +/// +/// Removes itself from the channel pipeline after the first inbound request on the connection, regardless of +/// whether the upgrade succeeded or not. +/// +/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade +/// on any request but the first on a connection. This is primarily to handle clients that pipeline: it's +/// sufficiently difficult to ensure that the upgrade happens at a safe time while dealing with pipelined +/// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine: +/// the odds of someone needing to upgrade midway through the lifetime of a connection are very low. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final public class NIOTypedHTTPServerUpgradeHandler : NIOCore.ChannelInboundHandler, NIOCore.RemovableChannelHandler where UpgradeResult : Sendable { + + /// The type of the inbound data which is wrapped in `NIOAny`. + public typealias InboundIn = NIOHTTP1.HTTPServerRequestPart + + /// The type of the inbound data which will be forwarded to the next `ChannelInboundHandler` in the `ChannelPipeline`. + public typealias InboundOut = NIOHTTP1.HTTPServerRequestPart + + /// The type of the outbound data which will be forwarded to the next `ChannelOutboundHandler` in the `ChannelPipeline`. + public typealias OutboundOut = NIOHTTP1.HTTPServerResponsePart + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: NIOCore.EventLoopFuture { get } + + /// Create a ``NIOTypedHTTPServerUpgradeHandler``. + /// + /// - Parameters: + /// - httpEncoder: The ``HTTPResponseEncoder`` encoding responses from this handler and which will + /// be removed from the pipeline once the upgrade response is sent. This is used to ensure + /// that the pipeline will be in a clean state after upgrade. + /// - extraHTTPHandlers: Any other handlers that are directly related to handling HTTP. At the very least + /// this should include the `HTTPDecoder`, but should also include any other handler that cannot tolerate + /// receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init(httpEncoder: NIOHTTP1.HTTPResponseEncoder, extraHTTPHandlers: [NIOCore.RemovableChannelHandler], upgradeConfiguration: NIOHTTP1.NIOTypedHTTPServerUpgradeConfiguration) + + /// Called when this `ChannelHandler` is added to the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerAdded(context: NIOCore.ChannelHandlerContext) + + /// Called when this `ChannelHandler` is removed from the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerRemoved(context: NIOCore.ChannelHandlerContext) + + /// Called when some data has been read from the remote peer. + /// + /// This should call `context.fireChannelRead` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - data: The data read from the remote peer, wrapped in a `NIOAny`. + public func channelRead(context: NIOCore.ChannelHandlerContext, data: NIOCore.NIOAny) +} +``` + +### Websocket + +```swift +/// A `NIOTypedHTTPClientProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request. +/// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the +/// pipeline to remove the HTTP `ChannelHandler`s. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final public class NIOTypedWebSocketClientUpgrader : NIOHTTP1.NIOTypedHTTPClientProtocolUpgrader where UpgradeResult : Sendable { + + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String + + /// None of the websocket headers are actually defined as 'required'. + public let requiredUpgradeHeaders: [String] + + /// - Parameters: + /// - requestKey: Sent to the server in the `Sec-WebSocket-Key` HTTP header. Default is random request key. + /// - maxFrameSize: Largest incoming `WebSocketFrame` size in bytes. Default is 16,384 bytes. + /// - enableAutomaticErrorHandling: If true, adds `WebSocketProtocolErrorHandler` to the channel pipeline to catch and respond to WebSocket protocol errors. Default is true. + /// - upgradePipelineHandler: Called once the upgrade was successful. + public init(requestKey: String = NIOWebSocketClientUpgrader.randomRequestKey(), maxFrameSize: Int = 1 << 14, enableAutomaticErrorHandling: Bool = true, upgradePipelineHandler: @escaping @Sendable (NIOCore.Channel, NIOHTTP1.HTTPResponseHead) -> NIOCore.EventLoopFuture) + + /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. + public func addCustom(upgradeRequestHeaders: inout NIOHTTP1.HTTPHeaders) + + /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. + public func shouldAllowUpgrade(upgradeResponse: NIOHTTP1.HTTPResponseHead) -> Bool + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel + /// pipeline to add whatever channel handlers are required. + /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. + public func upgrade(channel: NIOCore.Channel, upgradeResponse: NIOHTTP1.HTTPResponseHead) -> NIOCore.EventLoopFuture +} + +/// A `NIOTypedHTTPServerProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// Users may frequently want to offer multiple websocket endpoints on the same port. For this +/// reason, this `WebServerSocketUpgrader` only knows how to do the required parts of the upgrade and to +/// complete the handshake. Users are expected to provide a callback that examines the HTTP headers +/// (including the path) and determines whether this is a websocket upgrade request that is acceptable +/// to them. +/// +/// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to +/// remove the HTTP `ChannelHandler`s. +final public class NIOTypedWebSocketServerUpgrader : NIOHTTP1.NIOTypedHTTPServerProtocolUpgrader, Sendable where UpgradeResult : Sendable { + + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String + + /// We deliberately do not actually set any required headers here, because the websocket + /// spec annoyingly does not actually force the client to send these in the Upgrade header, + /// which NIO requires. We check for these manually. + public let requiredUpgradeHeaders: [String] + + /// Create a new ``NIOTypedWebSocketServerUpgrader``. + /// + /// - Parameters: + /// - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the + /// remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but + /// this is an objectively unreasonable maximum value (on AMD64 systems it is not + /// possible to even. Users may set this to any value up to `UInt32.max`. + /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol + /// errors by sending error responses and closing the connection. Defaults to `true`, + /// may be set to `false` if the user wishes to handle their own errors. + /// - shouldUpgrade: A callback that determines whether the websocket request should be + /// upgraded. This callback is responsible for creating a `HTTPHeaders` object with + /// any headers that it needs on the response *except for* the `Upgrade`, `Connection`, + /// and `Sec-WebSocket-Accept` headers, which this upgrader will handle. Should return + /// an `EventLoopFuture` containing `nil` if the upgrade should be refused. + /// - enableAutomaticErrorHandling: A function that will be called once the upgrade response is + /// flushed, and that is expected to mutate the `Channel` appropriately to handle the + /// websocket protocol. This only needs to add the user handlers: the + /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the + /// pipeline automatically. + public init(maxFrameSize: Int = 1 << 14, enableAutomaticErrorHandling: Bool = true, shouldUpgrade: @escaping @Sendable (NIOCore.Channel, NIOHTTP1.HTTPRequestHead) -> NIOCore.EventLoopFuture, upgradePipelineHandler: @escaping @Sendable (NIOCore.Channel, NIOHTTP1.HTTPRequestHead) -> NIOCore.EventLoopFuture) + + /// Builds the upgrade response headers. Should return any headers that need to be supplied to the client + /// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should + /// return a failed future. + public func buildUpgradeResponse(channel: NIOCore.Channel, upgradeRequest: NIOHTTP1.HTTPRequestHead, initialResponseHeaders: NIOHTTP1.HTTPHeaders) -> NIOCore.EventLoopFuture + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline + /// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received + /// data will be buffered. + public func upgrade(channel: NIOCore.Channel, upgradeRequest: NIOHTTP1.HTTPRequestHead) -> NIOCore.EventLoopFuture +} +``` + +### ALPN + +```swift +/// A helper ``ChannelInboundHandler`` that makes it easy to swap channel pipelines +/// based on the result of an ALPN negotiation. +/// +/// The standard pattern used by applications that want to use ALPN is to select +/// an application protocol based on the result, optionally falling back to some +/// default protocol. To do this in SwiftNIO requires that the channel pipeline be +/// reconfigured based on the result of the ALPN negotiation. This channel handler +/// encapsulates that logic in a generic form that doesn't depend on the specific +/// TLS implementation in use by using ``TLSUserEvent`` +/// +/// The user of this channel handler provides a single closure that is called with +/// an ``ALPNResult`` when the ALPN negotiation is complete. Based on that result +/// the user is free to reconfigure the ``ChannelPipeline`` as required, and should +/// return an ``EventLoopFuture`` that will complete when the pipeline is reconfigured. +/// +/// Until the ``EventLoopFuture`` completes, this channel handler will buffer inbound +/// data. When the ``EventLoopFuture`` completes, the buffered data will be replayed +/// down the channel. Then, finally, this channel handler will automatically remove +/// itself from the channel pipeline, leaving the pipeline in its final +/// configuration. +/// +/// Importantly, this is a typed variant of the ``ApplicationProtocolNegotiationHandler`` and allows the user to +/// specify a type that must be returned from the supplied closure. The result will then be used to succeed the ``NIOTypedApplicationProtocolNegotiationHandler/protocolNegotiationResult`` +/// promise. This allows us to construct pipelines that include protocol negotiation handlers and be able to bridge them into ``NIOAsyncChannel`` +/// based bootstraps. +final public class NIOTypedApplicationProtocolNegotiationHandler : NIOCore.ChannelInboundHandler, NIOCore.RemovableChannelHandler { + + /// The type of the inbound data which is wrapped in `NIOAny`. + public typealias InboundIn = Any + + /// The type of the inbound data which will be forwarded to the next `ChannelInboundHandler` in the `ChannelPipeline`. + public typealias InboundOut = Any + + public var protocolNegotiationResult: NIOCore.EventLoopFuture { get } + + /// Create an `ApplicationProtocolNegotiationHandler` with the given completion + /// callback. + /// + /// - Parameter alpnCompleteHandler: The closure that will fire when ALPN + /// negotiation has completed. + public init(alpnCompleteHandler: @escaping (NIOTLS.ALPNResult, NIOCore.Channel) -> NIOCore.EventLoopFuture) + + /// Create an `ApplicationProtocolNegotiationHandler` with the given completion + /// callback. + /// + /// - Parameter alpnCompleteHandler: The closure that will fire when ALPN + /// negotiation has completed. + public convenience init(alpnCompleteHandler: @escaping (NIOTLS.ALPNResult) -> NIOCore.EventLoopFutureNegotiationResult>) + + /// Called when this `ChannelHandler` is added to the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerAdded(context: NIOCore.ChannelHandlerContext) + + /// Called when this `ChannelHandler` is removed from the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerRemoved(context: NIOCore.ChannelHandlerContext) + + /// Called when a user inbound event has been triggered. + /// + /// This should call `context.fireUserInboundEventTriggered` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - event: The event. + public func userInboundEventTriggered(context: NIOCore.ChannelHandlerContext, event: Any) + + /// Called when some data has been read from the remote peer. + /// + /// This should call `context.fireChannelRead` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - data: The data read from the remote peer, wrapped in a `NIOAny`. + public func channelRead(context: NIOCore.ChannelHandlerContext, data: NIOCore.NIOAny) + + /// Called when the `Channel` has become inactive and is no longer able to send and receive data. + /// + /// This should call `context.fireChannelInactive` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func channelInactive(context: NIOCore.ChannelHandlerContext) +} +``` + +### HTTP/2.0 + +```swift +extension NIOHTTP2Handler { + /// A variant of `NIOHTTP2Handler.StreamMultiplexer` which creates a child channel for each HTTP/2 stream and + /// provides access to inbound HTTP/2 streams. + /// + /// In general in NIO applications it is helpful to consider each HTTP/2 stream as an + /// independent stream of HTTP/2 frames. This multiplexer achieves this by creating a + /// number of in-memory `HTTP2StreamChannel` objects, one for each stream. These operate + /// on ``HTTP2Frame/FramePayload`` objects as their base communication + /// atom, as opposed to the regular NIO `SelectableChannel` objects which use `ByteBuffer` + /// and `IOData`. + /// + /// Outbound stream channel objects are initialized upon creation using the supplied `streamStateInitializer` which returns a type + /// `Output`. This type may be `HTTP2Frame` or changed to any other type. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public struct AsyncStreamMultiplexer { + /// Create a stream channel initialized with the provided closure + public func createStreamChannel(_ initializer: @escaping NIOChannelInitializerWithOutput) async throws -> Output + } +} + +/// `NIOHTTP2InboundStreamChannels` provides access to inbound stream channels as a generic `AsyncSequence`. +/// They make use of generics to allow for wrapping the stream `Channel`s, for example as `NIOAsyncChannel`s or protocol negotiation objects. +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +public struct NIOHTTP2InboundStreamChannels: AsyncSequence { + public struct AsyncIterator: AsyncIteratorProtocol { + public typealias Element = Output + + public mutating func next() async throws -> Output? + } + + public typealias Element = Output + + public func makeAsyncIterator() -> AsyncIterator +} + +extension Channel { + /// Configures a `ChannelPipeline` to speak HTTP/2 and sets up mapping functions so that it may be interacted with from concurrent code. + /// + /// In general this is not entirely useful by itself, as HTTP/2 is a negotiated protocol. This helper does not handle negotiation. + /// Instead, this simply adds the handler required to speak HTTP/2 after negotiation has completed, or when agreed by prior knowledge. + /// Use this function to setup a HTTP/2 pipeline if you wish to use async sequence abstractions over inbound and outbound streams. + /// Using this rather than implementing a similar function yourself allows that pipeline to evolve without breaking your code. + /// + /// - Parameters: + /// - mode: The mode this pipeline will operate in, server or client. + /// - configuration: The settings that will be used when establishing the connection and new streams. + /// - inboundStreamInitializer: A closure that will be called whenever the remote peer initiates a new stream. + /// The output of this closure is the element type of the returned multiplexer + /// - Returns: An `EventLoopFuture` containing the `AsyncStreamMultiplexer` inserted into this pipeline, which can + /// be used to initiate new streams and iterate over inbound HTTP/2 stream channels. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public func configureAsyncHTTP2Pipeline( + mode: NIOHTTP2Handler.ParserMode, + configuration: NIOHTTP2Handler.Configuration = .init(), + inboundStreamInitializer: @escaping NIOChannelInitializerWithOutput + ) -> EventLoopFuture> + + /// Configures a `ChannelPipeline` to speak either HTTP/1.1 or HTTP/2 according to what can be negotiated with the client. + /// + /// This helper takes care of configuring the server pipeline such that it negotiates whether to + /// use HTTP/1.1 or HTTP/2. + /// + /// This function doesn't configure the TLS handler. Callers of this function need to add a TLS + /// handler appropriately configured to perform protocol negotiation. + /// + /// - Parameters: + /// - http2Configuration: The settings that will be used when establishing the HTTP/2 connections and new HTTP/2 streams. + /// - http1ConnectionInitializer: An optional callback that will be invoked only when the negotiated protocol + /// is HTTP/1.1 to configure the connection channel. + /// - http2ConnectionInitializer: An optional callback that will be invoked only when the negotiated protocol + /// is HTTP/2 to configure the connection channel. + /// - http2InboundStreamInitializer: A closure that will be called whenever the remote peer initiates a new stream. + /// The output of this closure is the element type of the returned multiplexer + /// - Returns: An `EventLoopFuture` containing a ``NIOTypedApplicationProtocolNegotiationHandler`` that completes when the channel + /// is ready to negotiate. This can then be used to access the protocol negotiation result which may itself + /// be waited on to retrieve the result of the negotiation. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public func configureAsyncHTTPServerPipeline( + http2Configuration: NIOHTTP2Handler.Configuration = .init(), + http1ConnectionInitializer: @escaping NIOChannelInitializerWithOutput, + http2ConnectionInitializer: @escaping NIOChannelInitializerWithOutput, + http2InboundStreamInitializer: @escaping NIOChannelInitializerWithOutput + ) -> EventLoopFuture) + >>> + +extension ChannelPipeline.SynchronousOperations { + /// Configures a `ChannelPipeline` to speak HTTP/2 and sets up mapping functions so that it may be interacted with from concurrent code. + /// + /// This operation **must** be called on the event loop. + /// + /// In general this is not entirely useful by itself, as HTTP/2 is a negotiated protocol. This helper does not handle negotiation. + /// Instead, this simply adds the handler required to speak HTTP/2 after negotiation has completed, or when agreed by prior knowledge. + /// Use this function to setup a HTTP/2 pipeline if you wish to use async sequence abstractions over inbound and outbound streams, + /// as it allows that pipeline to evolve without breaking your code. + /// + /// - Parameters: + /// - mode: The mode this pipeline will operate in, server or client. + /// - configuration: The settings that will be used when establishing the connection and new streams. + /// - inboundStreamInitializer: A closure that will be called whenever the remote peer initiates a new stream. + /// The output of this closure is the element type of the returned multiplexer + /// - Returns: An `EventLoopFuture` containing the `AsyncStreamMultiplexer` inserted into this pipeline, which can + /// be used to initiate new streams and iterate over inbound HTTP/2 stream channels. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public func configureAsyncHTTP2Pipeline( + mode: NIOHTTP2Handler.ParserMode, + configuration: NIOHTTP2Handler.Configuration = .init(), + inboundStreamInitializer: @escaping NIOChannelInitializerWithOutput + ) throws -> NIOHTTP2Handler.AsyncStreamMultiplexer +} + +/// `NIONegotiatedHTTPVersion` is a generic negotiation result holder for HTTP/1.1 and HTTP/2 +public enum NIONegotiatedHTTPVersion { + case http1_1(HTTP1Output) + case http2(HTTP2Output) +} +``` \ No newline at end of file From f38b7fd38a1f5ba3c38ec63322a25b7e18939a62 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Tue, 17 Oct 2023 14:14:02 +0100 Subject: [PATCH 25/64] Remove SPI from `NIOAsyncChannel`, new bootstrap methods, protocol negotiation and HTTP upgrade. (#2548) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Motivation Over the past months, we have been working on new async bridges to make using NIO's `Channel` from Swift Concurrency possible. Since this work was far reaching we have opted to land all of it as SPI. Now the time has come and we feel confident enough to make the SPI official API. This comes after testing the new APIs in various scenarios such as HTTP 1&2, HTTP upgrades, protocol negotiation and in benchmarks. # Modification This PR removes the SPI from the `NIOAsyncChannel`, the bootstrap methods, protocol negotiation and HTTP upgrade. # Result Everyone can use the our new APIs🚀 --- .../TCPEchoAsyncChannel.swift | 4 +-- .../NIOCore/AsyncChannel/AsyncChannel.swift | 33 ++++++++----------- .../AsyncChannelInboundStream.swift | 6 +--- .../AsyncChannelOutboundWriter.swift | 7 +--- .../NIOCore/Docs.docc/swift-concurrency.md | 4 +-- Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift | 8 +---- .../NIOTypedHTTPClientUpgradeHandler.swift | 3 -- .../NIOTypedHTTPServerUpgradeHandler.swift | 5 +-- Sources/NIOPosix/Bootstrap.swift | 24 ++------------ Sources/NIOPosix/RawSocketBootstrap.swift | 4 +-- Sources/NIOTCPEchoClient/Client.swift | 4 +-- Sources/NIOTCPEchoServer/Server.swift | 4 +-- ...pplicationProtocolNegotiationHandler.swift | 23 ++++--------- .../NIOWebSocketClientUpgrader.swift | 3 +- .../NIOWebSocketServerUpgrader.swift | 5 ++- Sources/NIOWebSocketClient/Client.swift | 8 ++--- Sources/NIOWebSocketServer/Server.swift | 8 ++--- .../AsyncChannelInboundStreamTests.swift | 2 +- .../AsyncChannelOutboundWriterTests.swift | 2 +- .../AsyncChannel/AsyncChannelTests.swift | 4 +-- .../HTTPClientUpgradeTests.swift | 2 +- .../HTTPServerUpgradeTests.swift | 2 +- .../AsyncChannelBootstrapTests.swift | 6 ++-- ...ationProtocolNegotiationHandlerTests.swift | 4 +-- .../WebSocketClientEndToEndTests.swift | 4 +-- .../WebSocketServerEndToEndTests.swift | 4 +-- docs/public-async-nio-apis.md | 4 +-- 27 files changed, 64 insertions(+), 123 deletions(-) diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift index 99e5d0cf56..bfe553c0db 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOCore -@_spi(AsyncChannel) import NIOPosix +import NIOCore +import NIOPosix func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async throws { let serverChannel = try await ServerBootstrap(group: eventLoop) diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift index 96856b7327..1619199571 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift @@ -33,15 +33,13 @@ /// logic. Protocol-specific logic should be implemented as a ``ChannelHandler``, while business /// logic should use ``NIOAsyncChannel`` to consume and produce data to the network. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -@_spi(AsyncChannel) public struct NIOAsyncChannel: Sendable { - @_spi(AsyncChannel) public struct Configuration: Sendable { - /// The back pressure strategy of the ``NIOAsyncChannel/inboundStream``. + /// The back pressure strategy of the ``NIOAsyncChannel/inbound``. public var backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark /// If outbound half closure should be enabled. Outbound half closure is triggered once - /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. + /// the ``NIOAsyncChannelOutboundWriter`` is either finished or deinitialized. public var isOutboundHalfClosureEnabled: Bool /// The ``NIOAsyncChannel/inbound`` message's type. @@ -56,7 +54,7 @@ public struct NIOAsyncChannel: Sendable { /// - backPressureStrategy: The back pressure strategy of the ``NIOAsyncChannel/inbound``. Defaults /// to a watermarked strategy (lowWatermark: 2, highWatermark: 10). /// - isOutboundHalfClosureEnabled: If outbound half closure should be enabled. Outbound half closure is triggered once - /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. Defaults to `false`. + /// the ``NIOAsyncChannelOutboundWriter`` is either finished or deinitialized. Defaults to `false`. /// - inboundType: The ``NIOAsyncChannel/inbound`` message's type. /// - outboundType: The ``NIOAsyncChannel/outbound`` message's type. public init( @@ -73,15 +71,12 @@ public struct NIOAsyncChannel: Sendable { } /// The underlying channel being wrapped by this ``NIOAsyncChannel``. - @_spi(AsyncChannel) public let channel: Channel /// The stream of inbound messages. /// /// - Important: The `inbound` stream is a unicast `AsyncSequence` and only one iterator can be created. - @_spi(AsyncChannel) public let inbound: NIOAsyncChannelInboundStream /// The writer for writing outbound messages. - @_spi(AsyncChannel) public let outbound: NIOAsyncChannelOutboundWriter /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel``. @@ -93,7 +88,6 @@ public struct NIOAsyncChannel: Sendable { /// - channel: The ``Channel`` to wrap. /// - configuration: The ``NIOAsyncChannel``s configuration. @inlinable - @_spi(AsyncChannel) public init( synchronouslyWrapping channel: Channel, configuration: Configuration = .init() @@ -117,10 +111,9 @@ public struct NIOAsyncChannel: Sendable { /// - channel: The ``Channel`` to wrap. /// - configuration: The ``NIOAsyncChannel``s configuration. @inlinable - @_spi(AsyncChannel) public init( synchronouslyWrapping channel: Channel, - configuration: Configuration + configuration: Configuration = .init() ) throws where Outbound == Never { channel.eventLoop.preconditionInEventLoop() self.channel = channel @@ -133,8 +126,7 @@ public struct NIOAsyncChannel: Sendable { } @inlinable - @_spi(AsyncChannel) - public init( + internal init( channel: Channel, inboundStream: NIOAsyncChannelInboundStream, outboundWriter: NIOAsyncChannelOutboundWriter @@ -145,9 +137,13 @@ public struct NIOAsyncChannel: Sendable { self.outbound = outboundWriter } + + /// This method is only used from our server bootstrap to allow us to run the child channel initializer + /// at the right moment. + /// + /// - Important: This is not considered stable API and should not be used. @inlinable - @_spi(AsyncChannel) - public static func wrapAsyncChannelWithTransformations( + public static func _wrapAsyncChannelWithTransformations( synchronouslyWrapping channel: Channel, backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, isOutboundHalfClosureEnabled: Bool = false, @@ -171,11 +167,9 @@ public struct NIOAsyncChannel: Sendable { } extension Channel { - // TODO: We need to remove the public and spi here once we make the AsyncChannel methods public @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @inlinable - @_spi(AsyncChannel) - public func _syncAddAsyncHandlers( + func _syncAddAsyncHandlers( backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, isOutboundHalfClosureEnabled: Bool ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { @@ -196,8 +190,7 @@ extension Channel { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @inlinable - @_spi(AsyncChannel) - public func _syncAddAsyncHandlersWithTransformations( + func _syncAddAsyncHandlersWithTransformations( backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, isOutboundHalfClosureEnabled: Bool, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift index 746d128f5d..fb713929f6 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift @@ -16,7 +16,6 @@ /// /// This is a unicast async sequence that allows a single iterator to be created. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -@_spi(AsyncChannel) public struct NIOAsyncChannelInboundStream: Sendable { @usableFromInline typealias Producer = NIOThrowingAsyncSequenceProducer @@ -149,10 +148,8 @@ public struct NIOAsyncChannelInboundStream: Sendable { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncChannelInboundStream: AsyncSequence { - @_spi(AsyncChannel) public typealias Element = Inbound - @_spi(AsyncChannel) public struct AsyncIterator: AsyncIteratorProtocol { @usableFromInline enum _Backing { @@ -172,7 +169,7 @@ extension NIOAsyncChannelInboundStream: AsyncSequence { } } - @inlinable @_spi(AsyncChannel) + @inlinable public mutating func next() async throws -> Element? { switch self._backing { case .asyncStream(var iterator): @@ -189,7 +186,6 @@ extension NIOAsyncChannelInboundStream: AsyncSequence { } @inlinable - @_spi(AsyncChannel) public func makeAsyncIterator() -> AsyncIterator { return AsyncIterator(self._backing) } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift index 9af339bf68..50e4d2ad4d 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -12,13 +12,12 @@ // //===----------------------------------------------------------------------===// -/// A ``NIOAsyncChannelWriter`` is used to write and flush new outbound messages in a channel. +/// A ``NIOAsyncChannelOutboundWriter`` is used to write and flush new outbound messages in a channel. /// /// The writer acts as a bridge between the Concurrency and NIO world. It allows to write and flush messages into the /// underlying ``Channel``. Furthermore, it respects back-pressure of the channel by suspending the calls to write until /// the channel becomes writable again. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -@_spi(AsyncChannel) public struct NIOAsyncChannelOutboundWriter: Sendable { @usableFromInline typealias _Writer = NIOAsyncChannelOutboundWriterHandler.Writer @@ -112,7 +111,6 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { /// /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. @inlinable - @_spi(AsyncChannel) public func write(_ data: OutboundOut) async throws { switch self._backing { case .asyncStream(let continuation): @@ -126,7 +124,6 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { /// /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. @inlinable - @_spi(AsyncChannel) public func write(contentsOf sequence: Writes) async throws where Writes.Element == OutboundOut { switch self._backing { case .asyncStream(let continuation): @@ -144,7 +141,6 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { /// /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. @inlinable - @_spi(AsyncChannel) public func write(contentsOf sequence: Writes) async throws where Writes.Element == OutboundOut { for try await data in sequence { try await self.write(data) @@ -154,7 +150,6 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { /// Finishes the writer. /// /// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it. - @_spi(AsyncChannel) public func finish() { switch self._backing { case .asyncStream(let continuation): diff --git a/Sources/NIOCore/Docs.docc/swift-concurrency.md b/Sources/NIOCore/Docs.docc/swift-concurrency.md index 609a2c87d4..0a5b58f47c 100644 --- a/Sources/NIOCore/Docs.docc/swift-concurrency.md +++ b/Sources/NIOCore/Docs.docc/swift-concurrency.md @@ -46,7 +46,7 @@ bi-directional streaming pipeline. To bridge such a pipeline into Concurrency required new types. Importantly, these types need to uphold the channel's back pressure and writability guarantees. NIO introduced the ``NIOThrowingAsyncSequenceProducer``, ``NIOAsyncSequenceProducer`` and the -``NIOAsyncWriter`` which form the foundation to bridge a ``Channel``. On top of +``NIOAsyncChannelOutboundWriter`` which form the foundation to bridge a ``Channel``. On top of these foundational types, NIO provides the `NIOAsyncChannel` which is used to wrap a ``Channel`` to produce an interface that can be consumed directly from Swift Concurrency. The following sections cover the details of the foundational @@ -65,7 +65,7 @@ sequence. #### NIOAsyncWriter -The ``NIOAsyncWriter`` is used for bridging from an asynchronous producer to a +The ``NIOAsyncChannelOutboundWriter`` is used for bridging from an asynchronous producer to a synchronous consumer. It also has back pressure support which allows the consumer to stop the producer by suspending the ``NIOAsyncWriter/yield(contentsOf:)`` method. diff --git a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift index e69d034ba1..57fe5fd780 100644 --- a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift +++ b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift @@ -11,13 +11,12 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOCore +import NIOCore // MARK: - Server pipeline configuration /// Configuration for an upgradable HTTP pipeline. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -@_spi(AsyncChannel) public struct NIOUpgradableHTTPServerPipelineConfiguration { /// Whether to provide assistance handling HTTP clients that pipeline /// their requests. Defaults to `true`. If `false`, users will need to handle clients that pipeline themselves. @@ -58,7 +57,6 @@ extension ChannelPipeline { /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - @_spi(AsyncChannel) public func configureUpgradableHTTPServerPipeline( configuration: NIOUpgradableHTTPServerPipelineConfiguration ) -> EventLoopFuture> { @@ -99,7 +97,6 @@ extension ChannelPipeline.SynchronousOperations { /// - configuration: The HTTP pipeline's configuration. /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - @_spi(AsyncChannel) public func configureUpgradableHTTPServerPipeline( configuration: NIOUpgradableHTTPServerPipelineConfiguration ) throws -> EventLoopFuture { @@ -148,7 +145,6 @@ extension ChannelPipeline.SynchronousOperations { /// Configuration for an upgradable HTTP pipeline. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -@_spi(AsyncChannel) public struct NIOUpgradableHTTPClientPipelineConfiguration { /// The strategy to use when dealing with leftover bytes after removing the ``HTTPDecoder`` from the pipeline. public var leftOverBytesStrategy = RemoveAfterUpgradeStrategy.dropBytes @@ -182,7 +178,6 @@ extension ChannelPipeline { /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - @_spi(AsyncChannel) public func configureUpgradableHTTPClientPipeline( configuration: NIOUpgradableHTTPClientPipelineConfiguration ) -> EventLoopFuture> { @@ -221,7 +216,6 @@ extension ChannelPipeline.SynchronousOperations { /// - configuration: The HTTP pipeline's configuration. /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - @_spi(AsyncChannel) public func configureUpgradableHTTPClientPipeline( configuration: NIOUpgradableHTTPClientPipelineConfiguration ) throws -> EventLoopFuture { diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift index 2fd8e25af2..f5a2f505ec 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift @@ -16,7 +16,6 @@ import NIOCore /// An object that implements `NIOTypedHTTPClientProtocolUpgrader` knows how to handle HTTP upgrade to /// a protocol on a client-side channel. /// It has the option of denying this upgrade based upon the server response. -@_spi(AsyncChannel) public protocol NIOTypedHTTPClientProtocolUpgrader { associatedtype UpgradeResult: Sendable @@ -42,7 +41,6 @@ public protocol NIOTypedHTTPClientProtocolUpgrader { /// The upgrade configuration for the ``NIOTypedHTTPClientUpgradeHandler``. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -@_spi(AsyncChannel) public struct NIOTypedHTTPClientUpgradeConfiguration { /// The initial request head that is sent out once the channel becomes active. public var upgradeRequestHead: HTTPRequestHead @@ -76,7 +74,6 @@ public struct NIOTypedHTTPClientUpgradeConfiguration { /// It will only upgrade to the protocol that is returned first in the list and does not currently /// have the capability to upgrade to multiple simultaneous layered protocols. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -@_spi(AsyncChannel) public final class NIOTypedHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableChannelHandler { public typealias OutboundIn = HTTPClientRequestPart public typealias OutboundOut = HTTPClientRequestPart diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift index 9e43f2d7d3..55b21e5982 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift @@ -11,11 +11,10 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOCore +import NIOCore /// An object that implements `NIOTypedHTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to /// a protocol on a server-side channel. -@_spi(AsyncChannel) public protocol NIOTypedHTTPServerProtocolUpgrader { associatedtype UpgradeResult: Sendable @@ -47,7 +46,6 @@ public protocol NIOTypedHTTPServerProtocolUpgrader { /// The upgrade configuration for the ``NIOTypedHTTPServerUpgradeHandler``. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -@_spi(AsyncChannel) public struct NIOTypedHTTPServerUpgradeConfiguration { /// The array of potential upgraders. public var upgraders: [any NIOTypedHTTPServerProtocolUpgrader] @@ -76,7 +74,6 @@ public struct NIOTypedHTTPServerUpgradeConfiguration { /// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine: /// the odds of someone needing to upgrade midway through the lifetime of a connection are very low. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -@_spi(AsyncChannel) public final class NIOTypedHTTPServerUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler { public typealias InboundIn = HTTPServerRequestPart public typealias InboundOut = HTTPServerRequestPart diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index 02e5a80794..f503d77444 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOCore +import NIOCore #if os(Windows) import ucrt @@ -474,7 +474,6 @@ extension ServerBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( host: String, port: Int, @@ -499,7 +498,6 @@ extension ServerBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( to address: SocketAddress, serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, @@ -535,7 +533,6 @@ extension ServerBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( unixDomainSocketPath: String, cleanupExistingSocketFile: Bool = false, @@ -564,7 +561,6 @@ extension ServerBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( _ socket: NIOBSDSocket.Handle, cleanupExistingSocketFile: Bool = false, @@ -623,7 +619,7 @@ extension ServerBootstrap { name: "AcceptHandler" ) let asyncChannel = try NIOAsyncChannel - .wrapAsyncChannelWithTransformations( + ._wrapAsyncChannelWithTransformations( synchronouslyWrapping: serverChannel, backPressureStrategy: serverBackPressureStrategy, channelReadTransformation: { channel -> EventLoopFuture in @@ -1067,7 +1063,6 @@ extension ClientBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( host: String, port: Int, @@ -1093,7 +1088,6 @@ extension ClientBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( to address: SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1118,7 +1112,6 @@ extension ClientBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( unixDomainSocketPath: String, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1138,7 +1131,6 @@ extension ClientBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func withConnectedSocket( _ socket: NIOBSDSocket.Handle, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1567,7 +1559,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func withBoundSocket( _ socket: NIOBSDSocket.Handle, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1598,7 +1589,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( host: String, port: Int, @@ -1623,7 +1613,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( to address: SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1649,7 +1638,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( unixDomainSocketPath: String, cleanupExistingSocketFile: Bool = false, @@ -1679,7 +1667,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( host: String, port: Int, @@ -1704,7 +1691,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( to address: SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1728,7 +1714,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( unixDomainSocketPath: String, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -2077,7 +2062,6 @@ extension NIOPipeBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func takingOwnershipOfDescriptor( inputOutput: CInt, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -2116,7 +2100,6 @@ extension NIOPipeBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func takingOwnershipOfDescriptors( input: CInt, output: CInt, @@ -2131,8 +2114,7 @@ extension NIOPipeBootstrap { } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) // Should become private - public func _takingOwnershipOfDescriptors( + func _takingOwnershipOfDescriptors( input: CInt, output: CInt, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, diff --git a/Sources/NIOPosix/RawSocketBootstrap.swift b/Sources/NIOPosix/RawSocketBootstrap.swift index 9712bd3a9b..9847f17e1f 100644 --- a/Sources/NIOPosix/RawSocketBootstrap.swift +++ b/Sources/NIOPosix/RawSocketBootstrap.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOCore +import NIOCore /// A `RawSocketBootstrap` is an easy way to interact with IP based protocols other then TCP and UDP. /// @@ -204,7 +204,6 @@ extension NIORawSocketBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( host: String, ipProtocol: NIOIPProtocol, @@ -227,7 +226,6 @@ extension NIORawSocketBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( host: String, ipProtocol: NIOIPProtocol, diff --git a/Sources/NIOTCPEchoClient/Client.swift b/Sources/NIOTCPEchoClient/Client.swift index 29d8d5d65a..0d8bd4404f 100644 --- a/Sources/NIOTCPEchoClient/Client.swift +++ b/Sources/NIOTCPEchoClient/Client.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// #if swift(>=5.9) -@_spi(AsyncChannel) import NIOCore -@_spi(AsyncChannel) import NIOPosix +import NIOCore +import NIOPosix @available(macOS 14, *) @main diff --git a/Sources/NIOTCPEchoServer/Server.swift b/Sources/NIOTCPEchoServer/Server.swift index b467fb99de..390fff795b 100644 --- a/Sources/NIOTCPEchoServer/Server.swift +++ b/Sources/NIOTCPEchoServer/Server.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// #if swift(>=5.9) -@_spi(AsyncChannel) import NIOCore -@_spi(AsyncChannel) import NIOPosix +import NIOCore +import NIOPosix @available(macOS 14, *) @main diff --git a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift index 1166525ac4..d3e63fcf47 100644 --- a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift +++ b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift @@ -12,9 +12,9 @@ // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOCore +import NIOCore -/// A helper ``ChannelInboundHandler`` that makes it easy to swap channel pipelines +/// A helper `ChannelInboundHandler` that makes it easy to swap channel pipelines /// based on the result of an ALPN negotiation. /// /// The standard pattern used by applications that want to use ALPN is to select @@ -26,28 +26,24 @@ /// /// The user of this channel handler provides a single closure that is called with /// an ``ALPNResult`` when the ALPN negotiation is complete. Based on that result -/// the user is free to reconfigure the ``ChannelPipeline`` as required, and should -/// return an ``EventLoopFuture`` that will complete when the pipeline is reconfigured. +/// the user is free to reconfigure the `ChannelPipeline` as required, and should +/// return an `EventLoopFuture` that will complete when the pipeline is reconfigured. /// -/// Until the ``EventLoopFuture`` completes, this channel handler will buffer inbound -/// data. When the ``EventLoopFuture`` completes, the buffered data will be replayed +/// Until the `EventLoopFuture` completes, this channel handler will buffer inbound +/// data. When the `EventLoopFuture` completes, the buffered data will be replayed /// down the channel. Then, finally, this channel handler will automatically remove /// itself from the channel pipeline, leaving the pipeline in its final /// configuration. /// /// Importantly, this is a typed variant of the ``ApplicationProtocolNegotiationHandler`` and allows the user to /// specify a type that must be returned from the supplied closure. The result will then be used to succeed the ``NIOTypedApplicationProtocolNegotiationHandler/protocolNegotiationResult`` -/// promise. This allows us to construct pipelines that include protocol negotiation handlers and be able to bridge them into ``NIOAsyncChannel`` +/// promise. This allows us to construct pipelines that include protocol negotiation handlers and be able to bridge them into `NIOAsyncChannel` /// based bootstraps. -@_spi(AsyncChannel) public final class NIOTypedApplicationProtocolNegotiationHandler: ChannelInboundHandler, RemovableChannelHandler { - @_spi(AsyncChannel) public typealias InboundIn = Any - @_spi(AsyncChannel) public typealias InboundOut = Any - @_spi(AsyncChannel) public var protocolNegotiationResult: EventLoopFuture { return self.negotiatedPromise.futureResult } @@ -66,7 +62,6 @@ public final class NIOTypedApplicationProtocolNegotiationHandler EventLoopFuture) { self.completionHandler = alpnCompleteHandler } @@ -76,7 +71,6 @@ public final class NIOTypedApplicationProtocolNegotiationHandler EventLoopFuture) { self.init { result, _ in alpnCompleteHandler(result) @@ -97,7 +91,6 @@ public final class NIOTypedApplicationProtocolNegotiationHandler: NIOTypedHTTPClientProtocolUpgrader { /// RFC 6455 specs this as the required entry in the Upgrade header. public let supportedProtocol: String = "websocket" diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index baa66eec5a..4580d0ec07 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -13,8 +13,8 @@ //===----------------------------------------------------------------------===// import CNIOSHA1 -@_spi(AsyncChannel) import NIOCore -@_spi(AsyncChannel) import NIOHTTP1 +import NIOCore +import NIOHTTP1 let magicWebSocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" @@ -185,7 +185,6 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch /// /// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to /// remove the HTTP `ChannelHandler`s. -@_spi(AsyncChannel) public final class NIOTypedWebSocketServerUpgrader: NIOTypedHTTPServerProtocolUpgrader, Sendable { private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture diff --git a/Sources/NIOWebSocketClient/Client.swift b/Sources/NIOWebSocketClient/Client.swift index cb6270e6d9..6477416684 100644 --- a/Sources/NIOWebSocketClient/Client.swift +++ b/Sources/NIOWebSocketClient/Client.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// #if swift(>=5.9) -@_spi(AsyncChannel) import NIOCore -@_spi(AsyncChannel) import NIOPosix -@_spi(AsyncChannel) import NIOHTTP1 -@_spi(AsyncChannel) import NIOWebSocket +import NIOCore +import NIOPosix +import NIOHTTP1 +import NIOWebSocket @available(macOS 14, *) @main diff --git a/Sources/NIOWebSocketServer/Server.swift b/Sources/NIOWebSocketServer/Server.swift index 01bb64994d..525c64b00d 100644 --- a/Sources/NIOWebSocketServer/Server.swift +++ b/Sources/NIOWebSocketServer/Server.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// #if swift(>=5.9) -@_spi(AsyncChannel) import NIOCore -@_spi(AsyncChannel) import NIOPosix -@_spi(AsyncChannel) import NIOHTTP1 -@_spi(AsyncChannel) import NIOWebSocket +import NIOCore +import NIOPosix +import NIOHTTP1 +import NIOWebSocket let websocketResponse = """ diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift index f3d60baee6..f12df126a5 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) @testable import NIOCore +@testable import NIOCore import XCTest final class AsyncChannelInboundStreamTests: XCTestCase { diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift index 1a58a54358..49d070aab6 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) @testable import NIOCore +@testable import NIOCore import XCTest final class AsyncChannelOutboundWriterTests: XCTestCase { diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 5f7cae78dd..c5feb91e81 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// import Atomics import NIOConcurrencyHelpers -@_spi(AsyncChannel) @testable import NIOCore +@testable import NIOCore import NIOEmbedded import XCTest @@ -41,7 +41,7 @@ final class AsyncChannelTests: XCTestCase { let thirdRead = try await iterator.next() XCTAssertNil(thirdRead) - try await channel.close() + try await channel.closeFuture.get() } func testAsyncChannelBasicWrites() async throws { diff --git a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift index 8344f4741e..c22297c53b 100644 --- a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift @@ -16,7 +16,7 @@ import XCTest import Dispatch @testable import NIOCore import NIOEmbedded -@_spi(AsyncChannel) @testable import NIOHTTP1 +@testable import NIOHTTP1 extension EmbeddedChannel { diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index 9717817ef7..1677a83525 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -16,7 +16,7 @@ import XCTest import NIOCore import NIOEmbedded @testable import NIOPosix -@testable @_spi(AsyncChannel) import NIOHTTP1 +@testable import NIOHTTP1 extension ChannelPipeline { fileprivate func assertDoesNotContainUpgrader() throws { diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index cd74144a45..9788b284f8 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -13,10 +13,10 @@ //===----------------------------------------------------------------------===// import NIOConcurrencyHelpers -@_spi(AsyncChannel) @testable import NIOCore -@_spi(AsyncChannel) @testable import NIOPosix +@testable import NIOCore +@testable import NIOPosix import XCTest -@_spi(AsyncChannel) import NIOTLS +import NIOTLS private final class IPHeaderRemoverHandler: ChannelInboundHandler { typealias InboundIn = AddressedEnvelope diff --git a/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift b/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift index c2d4d771d3..36fc7bd04a 100644 --- a/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift +++ b/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOTLS -@_spi(AsyncChannel) import NIOCore +import NIOTLS +import NIOCore import NIOEmbedded import XCTest import NIOTestUtils diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index 94df819eb9..bd9cef6936 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -15,8 +15,8 @@ import XCTest import NIOCore import NIOEmbedded -@_spi(AsyncChannel) import NIOHTTP1 -@_spi(AsyncChannel) @testable import NIOWebSocket +import NIOHTTP1 +@testable import NIOWebSocket extension EmbeddedChannel { diff --git a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift index 0058aa6fc7..44246e3ab0 100644 --- a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift @@ -15,8 +15,8 @@ import XCTest @testable import NIOCore import NIOEmbedded -@_spi(AsyncChannel) import NIOHTTP1 -@testable @_spi(AsyncChannel) import NIOWebSocket +import NIOHTTP1 +@testable import NIOWebSocket extension EmbeddedChannel { func readAllInboundBuffers() throws -> ByteBuffer { diff --git a/docs/public-async-nio-apis.md b/docs/public-async-nio-apis.md index 85fba9c912..2a3ff118c1 100644 --- a/docs/public-async-nio-apis.md +++ b/docs/public-async-nio-apis.md @@ -72,7 +72,7 @@ overview to review them. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public struct NIOAsyncChannel : Sendable where Inbound : Sendable, Outbound : Sendable { public struct Configuration : Sendable { - /// The back pressure strategy of the ``NIOAsyncChannel/inboundStream``. + /// The back pressure strategy of the ``NIOAsyncChannel/inbound``. public var backPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark /// If outbound half closure should be enabled. Outbound half closure is triggered once @@ -1118,4 +1118,4 @@ public enum NIONegotiatedHTTPVersion Date: Tue, 17 Oct 2023 16:36:52 +0100 Subject: [PATCH 26/64] Remove continuation resumption inside locks (#2558) Motivation: Whilst investigating deadlocking code elsewhere we discovered that resuming a continuation inside a lock can call deadlocks. The reason is that - `withTaskCancellationHandler` takes a sequence's underlying runtime lock then executes the cancellation handler. - The task cancellation handler then does work which requires taking the NIO-level lock - If a NIO method on a separate thread then takes the NIO-level lock and within it resumes a continuation e.g. in `yield` we will deadlock because the resumption attempts to obtain the underlying runtime lock. Modifications: Do not resume continuations within locks. Result: Fewer deadlock opportunities --- .../AsyncSequences/NIOAsyncWriter.swift | 15 +++ .../NIOThrowingAsyncSequenceProducer.swift | 126 ++++++++++-------- 2 files changed, 85 insertions(+), 56 deletions(-) diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift index b2ae81d499..a2e6fdae67 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift @@ -450,6 +450,9 @@ extension NIOAsyncWriter { @inlinable /* fileprivate */ internal func setWritability(to writability: Bool) { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler let action = self._lock.withLock { self._stateMachine.setWritability(to: writability) } @@ -514,6 +517,9 @@ extension NIOAsyncWriter { } } } onCancel: { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler let action = self._lock.withLock { self._stateMachine.cancel(yieldID: yieldID) } @@ -564,6 +570,9 @@ extension NIOAsyncWriter { } } } onCancel: { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler let action = self._lock.withLock { self._stateMachine.cancel(yieldID: yieldID) } @@ -580,6 +589,9 @@ extension NIOAsyncWriter { @inlinable /* fileprivate */ internal func writerFinish(error: Error?) { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler let action = self._lock.withLock { self._stateMachine.writerFinish(error: error) } @@ -598,6 +610,9 @@ extension NIOAsyncWriter { @inlinable /* fileprivate */ internal func sinkFinish(error: Error?) { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler let action = self._lock.withLock { self._stateMachine.sinkFinish(error: error) } diff --git a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift index 2574d42d92..3b5e2776ad 100644 --- a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift +++ b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift @@ -414,63 +414,65 @@ extension NIOThrowingAsyncSequenceProducer { @inlinable /* fileprivate */ internal func yield(_ sequence: S) -> Source.YieldResult where S.Element == Element { - self._lock.withLock { - let action = self._stateMachine.yield(sequence) + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let action = self._lock.withLock { + self._stateMachine.yield(sequence) + } - switch action { - case .returnProduceMore: - return .produceMore + switch action { + case .returnProduceMore: + return .produceMore - case .returnStopProducing: - return .stopProducing + case .returnStopProducing: + return .stopProducing - case .returnDropped: - return .dropped + case .returnDropped: + return .dropped - case .resumeContinuationAndReturnProduceMore(let continuation, let element): - // It is safe to resume the continuation while holding the lock - // since the task will get enqueued on its executor and the resume method - // is returning immediately - continuation.resume(returning: element) + case .resumeContinuationAndReturnProduceMore(let continuation, let element): + continuation.resume(returning: element) - return .produceMore + return .produceMore - case .resumeContinuationAndReturnStopProducing(let continuation, let element): - // It is safe to resume the continuation while holding the lock - // since the task will get enqueued on its executor and the resume method - // is returning immediately - continuation.resume(returning: element) + case .resumeContinuationAndReturnStopProducing(let continuation, let element): + continuation.resume(returning: element) - return .stopProducing - } + return .stopProducing } } @inlinable /* fileprivate */ internal func finish(_ failure: Failure?) { - let delegate: Delegate? = self._lock.withLock { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let (delegate, action): (Delegate?, NIOThrowingAsyncSequenceProducer.StateMachine.FinishAction) = self._lock.withLock { let action = self._stateMachine.finish(failure) - + switch action { - case .resumeContinuationWithFailureAndCallDidTerminate(let continuation, let failure): + case .resumeContinuationWithFailureAndCallDidTerminate: let delegate = self._delegate self._delegate = nil + return (delegate, action) - // It is safe to resume the continuation while holding the lock - // since the task will get enqueued on its executor and the resume method - // is returning immediately - switch failure { - case .some(let error): - continuation.resume(throwing: error) - case .none: - continuation.resume(returning: nil) - } - - return delegate + case .none: + return (nil, action) + } + } + switch action { + case .resumeContinuationWithFailureAndCallDidTerminate(let continuation, let failure): + switch failure { + case .some(let error): + continuation.resume(throwing: error) case .none: - return nil + continuation.resume(returning: nil) } + + case .none: + break } delegate?.didTerminate() @@ -549,7 +551,10 @@ extension NIOThrowingAsyncSequenceProducer { } } } onCancel: { - let delegate: Delegate? = self._lock.withLock { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let (delegate, action): (Delegate?, NIOThrowingAsyncSequenceProducer.StateMachine.CancelledAction) = self._lock.withLock { let action = self._stateMachine.cancelled() switch action { @@ -557,33 +562,42 @@ extension NIOThrowingAsyncSequenceProducer { let delegate = self._delegate self._delegate = nil - return delegate - - case .resumeContinuationWithCancellationErrorAndCallDidTerminate(let continuation): - // We have deprecated the generic Failure type in the public API and Failure should - // now be `Swift.Error`. However, if users have not migrated to the new API they could - // still use a custom generic Error type and this cast might fail. - // In addition, we use `NIOThrowingAsyncSequenceProducer` in the implementation of the - // non-throwing variant `NIOAsyncSequenceProducer` where `Failure` will be `Never` and - // this cast will fail as well. - // Everything is marked @inlinable and the Failure type is known at compile time, - // therefore this cast should be optimised away in release build. - if let failure = CancellationError() as? Failure { - continuation.resume(throwing: failure) - } else { - continuation.resume(returning: nil) - } + return (delegate, action) + case .resumeContinuationWithCancellationErrorAndCallDidTerminate: let delegate = self._delegate self._delegate = nil - return delegate + return (delegate, action) case .none: - return nil + return (nil, action) } } + switch action { + case .callDidTerminate: + break + + case .resumeContinuationWithCancellationErrorAndCallDidTerminate(let continuation): + // We have deprecated the generic Failure type in the public API and Failure should + // now be `Swift.Error`. However, if users have not migrated to the new API they could + // still use a custom generic Error type and this cast might fail. + // In addition, we use `NIOThrowingAsyncSequenceProducer` in the implementation of the + // non-throwing variant `NIOAsyncSequenceProducer` where `Failure` will be `Never` and + // this cast will fail as well. + // Everything is marked @inlinable and the Failure type is known at compile time, + // therefore this cast should be optimised away in release build. + if let failure = CancellationError() as? Failure { + continuation.resume(throwing: failure) + } else { + continuation.resume(returning: nil) + } + + case .none: + break + } + delegate?.didTerminate() } } From a9071cc110fcbe2bcf77c27b479d63f7e222d687 Mon Sep 17 00:00:00 2001 From: Rick Newton-Rogers Date: Tue, 17 Oct 2023 17:10:53 +0100 Subject: [PATCH 27/64] waitForUpgraderToBeRemoved availability guard (#2559) Motivation: Tests fail to compile on 5.7, 5.8 Modifications: Add waitForUpgraderToBeRemoved availability guard Result: Tests should compile Co-authored-by: Franz Busch --- Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index 1677a83525..d2d4dcc870 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -55,6 +55,7 @@ extension ChannelPipeline { // Waits up to 1 second for the upgrader to be removed by polling the pipeline // every 50ms checking for the handler. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) fileprivate func waitForUpgraderToBeRemoved() throws { for _ in 0..<20 { do { From a41b02c580bf75bba5548492bf4b2b014d55413c Mon Sep 17 00:00:00 2001 From: hamzahrmalik Date: Wed, 18 Oct 2023 14:26:43 +0100 Subject: [PATCH 28/64] Avoid terminating when a precondition is not met in HTTPServerPipelineHandler (#2550) * Avoid terminating when a precondition is not met in HTTPServerPipelineHandler * Address review comments * rename precondition to assertion --------- Co-authored-by: Cory Benfield --- .../NIOHTTP1/HTTPServerPipelineHandler.swift | 213 ++++++++++++++---- .../HTTPServerPipelineHandlerTest.swift | 50 +++- 2 files changed, 222 insertions(+), 41 deletions(-) diff --git a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift index 489eef8954..1e68af848b 100644 --- a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift +++ b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift @@ -13,15 +13,6 @@ //===----------------------------------------------------------------------===// import NIOCore -/// A utility function that runs the body code only in debug builds, without -/// emitting compiler warnings. -/// -/// This is currently the only way to do this in Swift: see -/// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. -internal func debugOnly(_ body: () -> Void) { - assert({ body(); return true }()) -} - /// A `ChannelHandler` that handles HTTP pipelining by buffering inbound data until a /// response has been sent. /// @@ -56,6 +47,10 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha public typealias OutboundIn = HTTPServerResponsePart public typealias OutboundOut = HTTPServerResponsePart + // If this is true AND we're in a debug build, crash the program when an invariant in violated + // Otherwise, we will try to handle the situation as cleanly as possible + internal var failOnPreconditions: Bool = true + public init() { self.nextExpectedInboundMessage = nil self.nextExpectedOutboundMessage = nil @@ -66,6 +61,59 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } } + private enum ConnectionStateAction { + /// A precondition has been violated. Should send an error down the pipeline + case warnPreconditionViolated(message: String) + + /// A further state change was attempted when a precondition has already been violated. + /// Should force close this connection + case forceCloseConnection + + /// Nothing to do + case none + } + + public struct ConnectionStateError: Error, CustomStringConvertible, Hashable { + enum Base: Hashable, CustomStringConvertible { + /// A precondition was violated + case preconditionViolated(message: String) + + var description: String { + switch self { + case .preconditionViolated(let message): + return "Precondition violated \(message)" + } + } + } + + private var base: Base + private var file: String + private var line: Int + + private init(base: Base, file: String, line: Int) { + self.base = base + self.file = file + self.line = line + } + + public static func ==(lhs: ConnectionStateError, rhs: ConnectionStateError) -> Bool { + lhs.base == rhs.base + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(self.base) + } + + /// A precondition was violated + public static func preconditionViolated(message: String, file: String = #fileID, line: Int = #line) -> Self { + .init(base: .preconditionViolated(message: message), file: file, line: line) + } + + public var description: String { + "\(self.base) file \(self.file) line \(self.line)" + } + } + /// The state of the HTTP connection. private enum ConnectionState { /// We are waiting for a HTTP response to complete before we @@ -92,53 +140,76 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha /// never suppress reads again. case sentCloseOutput - mutating func requestHeadReceived() { + /// The user has violated an invariant. We should refuse further IO now + case preconditionFailed + + mutating func requestHeadReceived() -> ConnectionStateAction { switch self { + case .preconditionFailed: + return .forceCloseConnection case .idle: self = .requestAndResponseEndPending + return .none case .requestAndResponseEndPending, .responseEndPending, .requestEndPending, .sentCloseOutputRequestEndPending, .sentCloseOutput: - preconditionFailure("received request head in state \(self)") + let message = "received request head in state \(self)" + self = .preconditionFailed + return .warnPreconditionViolated(message: message) } } - mutating func responseEndReceived() { + mutating func responseEndReceived() -> ConnectionStateAction { switch self { + case .preconditionFailed: + return .forceCloseConnection case .responseEndPending: // Got the response we were waiting for. self = .idle + return .none case .requestAndResponseEndPending: // We got a response while still receiving a request, which we have to // wait for. self = .requestEndPending + return .none case .sentCloseOutput, .sentCloseOutputRequestEndPending: // This is a user error: they have sent close(mode: .output), but are continuing to write. // The write will fail, so we can allow it to pass. - () + return .none case .requestEndPending, .idle: - preconditionFailure("Unexpectedly received a response in state \(self)") + let message = "Unexpectedly received a response in state \(self)" + self = .preconditionFailed + return .warnPreconditionViolated(message: message) } } - mutating func requestEndReceived() { + mutating func requestEndReceived() -> ConnectionStateAction { switch self { + case .preconditionFailed: + return .forceCloseConnection case .requestEndPending: // Got the request end we were waiting for. self = .idle + return .none case .requestAndResponseEndPending: // We got a request and the response isn't done, wait for the // response. self = .responseEndPending + return .none case .sentCloseOutputRequestEndPending: // Got the request end we were waiting for. self = .sentCloseOutput + return .none case .responseEndPending, .idle, .sentCloseOutput: - preconditionFailure("Received second request") + let message = "Received second request" + self = .preconditionFailed + return .warnPreconditionViolated(message: message) } } mutating func closeOutputSent() { switch self { + case .preconditionFailed: + break case .idle, .responseEndPending: self = .sentCloseOutput case .requestEndPending, .requestAndResponseEndPending: @@ -221,35 +292,37 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha self.eventBuffer.append(.channelRead(data)) return } else { - self.deliverOneMessage(context: context, data: data) + let connectionStateAction = self.deliverOneMessage(context: context, data: data) + _ = self.handleConnectionStateAction(context: context, action: connectionStateAction, promise: nil) } } - private func deliverOneMessage(context: ChannelHandlerContext, data: NIOAny) { - assert(self.lifecycleState != .quiescingLastRequestEndReceived && - self.lifecycleState != .quiescingCompleted, - "deliverOneMessage called in lifecycle illegal state \(self.lifecycleState)") + private func deliverOneMessage(context: ChannelHandlerContext, data: NIOAny) -> ConnectionStateAction { + self.checkAssertion(self.lifecycleState != .quiescingLastRequestEndReceived && + self.lifecycleState != .quiescingCompleted, + "deliverOneMessage called in lifecycle illegal state \(self.lifecycleState)") let msg = self.unwrapInboundIn(data) debugOnly { switch msg { case .head: - assert(self.nextExpectedInboundMessage == .head) + self.checkAssertion(self.nextExpectedInboundMessage == .head) self.nextExpectedInboundMessage = .bodyOrEnd case .body: - assert(self.nextExpectedInboundMessage == .bodyOrEnd) + self.checkAssertion(self.nextExpectedInboundMessage == .bodyOrEnd) case .end: - assert(self.nextExpectedInboundMessage == .bodyOrEnd) + self.checkAssertion(self.nextExpectedInboundMessage == .bodyOrEnd) self.nextExpectedInboundMessage = .head } } + let action: ConnectionStateAction switch msg { case .head: - self.state.requestHeadReceived() + action = self.state.requestHeadReceived() case .end: // New request is complete. We don't want any more data from now on. - self.state.requestEndReceived() + action = self.state.requestEndReceived() if self.lifecycleState == .quiescingWaitingForRequestEnd { self.lifecycleState = .quiescingLastRequestEndReceived @@ -260,10 +333,11 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha context.close(promise: nil) } case .body: - () + action = .none } context.fireChannelRead(data) + return action } private func deliverOneError(context: ChannelHandlerContext, error: Error) { @@ -280,14 +354,16 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { case is ChannelShouldQuiesceEvent: - assert(self.lifecycleState == .acceptingEvents, - "unexpected lifecycle state when receiving ChannelShouldQuiesceEvent: \(self.lifecycleState)") + self.checkAssertion(self.lifecycleState == .acceptingEvents, + "unexpected lifecycle state when receiving ChannelShouldQuiesceEvent: \(self.lifecycleState)") switch self.state { case .responseEndPending: // we're not in the middle of a request, let's just shut the door self.lifecycleState = .quiescingLastRequestEndReceived self.eventBuffer.removeAll() - case .idle, .sentCloseOutput: + case .preconditionFailed, + // An invariant has been violated already, this time we close the connection + .idle, .sentCloseOutput: // we're completely idle, let's just close self.lifecycleState = .quiescingCompleted self.eventBuffer.removeAll() @@ -324,20 +400,20 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - assert(self.state != .requestEndPending, - "Received second response while waiting for first one to complete") + self.checkAssertion(self.state != .requestEndPending, + "Received second response while waiting for first one to complete") debugOnly { let res = self.unwrapOutboundIn(data) switch res { case .head(let head) where head.isInformational: - assert(self.nextExpectedOutboundMessage == .head) + self.checkAssertion(self.nextExpectedOutboundMessage == .head) case .head: - assert(self.nextExpectedOutboundMessage == .head) + self.checkAssertion(self.nextExpectedOutboundMessage == .head) self.nextExpectedOutboundMessage = .bodyOrEnd case .body: - assert(self.nextExpectedOutboundMessage == .bodyOrEnd) + self.checkAssertion(self.nextExpectedOutboundMessage == .bodyOrEnd) case .end: - assert(self.nextExpectedOutboundMessage == .bodyOrEnd) + self.checkAssertion(self.nextExpectedOutboundMessage == .bodyOrEnd) self.nextExpectedOutboundMessage = .head } } @@ -367,7 +443,7 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha case .quiescingCompleted: // Uh, why are we writing more data here? We'll write it, but it should be guaranteed // to fail. - assertionFailure("Wrote in quiescing completed state") + self.assertionFailed("Wrote in quiescing completed state") context.write(data, promise: promise) } case .body, .head: @@ -375,7 +451,10 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } if startReadingAgain { - self.state.responseEndReceived() + let connectionStateAction = self.state.responseEndReceived() + if self.handleConnectionStateAction(context: context, action: connectionStateAction, promise: promise) { + return + } self.deliverPendingRequests(context: context) self.startReading(context: context) } @@ -477,6 +556,30 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } } + /// - Returns: True if an error was sent, ie the caller should not continue + private func handleConnectionStateAction( + context: ChannelHandlerContext, + action: ConnectionStateAction, + promise: EventLoopPromise? + ) -> Bool { + switch action { + case .warnPreconditionViolated(let message): + self.assertionFailed(message) + let error = ConnectionStateError.preconditionViolated(message: message) + self.deliverOneError(context: context, error: error) + promise?.fail(error) + return true + case .forceCloseConnection: + let message = "The connection has been forcefully closed because further IO was attempted after a precondition was violated" + let error = ConnectionStateError.preconditionViolated(message: message) + promise?.fail(error) + self.close(context: context, mode: .all, promise: nil) + return true + case .none: + return false + } + } + /// A response has been sent: we can now start passing reads through /// again if there are no further pending requests, and send any read() /// call we may have swallowed. @@ -503,8 +606,10 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha switch event { case .channelRead(let read): - self.deliverOneMessage(context: context, data: read) - deliveredRead = true + let connectionStateAction = self.deliverOneMessage(context: context, data: read) + if !self.handleConnectionStateAction(context: context, action: connectionStateAction, promise: nil) { + deliveredRead = true + } case .error(let error): self.deliverOneError(context: context, error: error) case .halfClose: @@ -557,6 +662,34 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha self.eventBuffer.removeSubrange(firstHead...) } + + /// A utility function that runs the body code only in debug builds, without + /// emitting compiler warnings. + /// + /// This is currently the only way to do this in Swift: see + /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. + private func debugOnly(_ body: () -> Void) { + self.checkAssertion({ body(); return true }()) + } + + /// Calls assertionFailure if and only if `self.failOnPreconditions` is true. This allows us to avoid terminating the program in tests + private func assertionFailed(_ message: @autoclosure () -> String, file: StaticString = #file, line: UInt = #line) { + if self.failOnPreconditions { + assertionFailure(message(), file: file, line: line) + } + } + + /// Calls assert if and only if `self.failOnPreconditions` is true. This allows us to avoid terminating the program in tests + private func checkAssertion( + _ closure: @autoclosure () -> Bool, + _ message: @autoclosure () -> String = String(), + file: StaticString = #file, + line: UInt = #line + ) { + if self.failOnPreconditions { + assert(closure(), message(), file: file, line: line) + } + } } @available(*, unavailable) diff --git a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift index d03d119713..572a1a5ce1 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift @@ -1010,7 +1010,7 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) } - func testServerCanRespondProcessingMultipleTimes() throws { + func testServerCanRespondProcessingMultipleTimes() throws { // Send in a request. XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) @@ -1181,4 +1181,52 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try self.channel.writeOutbound(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer())))) XCTAssertNoThrow(try self.channel.writeOutbound(HTTPServerResponsePart.end(nil))) } + + func testSendingHeadTwiceGivesError() throws { + self.pipelineHandler.failOnPreconditions = false + // Sending a head once is normal + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) + // Sending a head twice is an error + XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "received request head in state requestAndResponseEndPending")) + } + } + + func testServerRespondToNothing() { + self.pipelineHandler.failOnPreconditions = false + + // Writing an end whilst in state idle is an error + XCTAssertThrowsError(try self.channel.writeOutbound(HTTPServerResponsePart.end(nil))) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Unexpectedly received a response in state idle")) + } + + // Calling finish surfaces the error again + XCTAssertThrowsError(try self.channel.finish()) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Unexpectedly received a response in state idle")) + } + } + + func testServerRequestEndFirstIsError() { + self.pipelineHandler.failOnPreconditions = false + // End sending a request which was never started + XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Received second request")) + } + } + + func testForcefulShutdownWhenViolatedPrecondition() { + self.pipelineHandler.failOnPreconditions = false + + // End sending a request which was never started + XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Received second request")) + } + // The handler should now refuse further io, and forcefully shutdown + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) + + self.channel.embeddedEventLoop.run() + + // Ensure the channel is closed + XCTAssertNoThrow(try self.channel.closeFuture.wait()) + } } From b7c122e09ee93433f834798061451f1daacf25aa Mon Sep 17 00:00:00 2001 From: hamzahrmalik Date: Wed, 18 Oct 2023 15:54:51 +0100 Subject: [PATCH 29/64] Fix flakiness in testDelayedUpgradeBehaviour (#2557) --- .../HTTPClientUpgradeTests.swift | 54 ++--- .../HTTPServerUpgradeTests.swift | 229 ++++++++---------- 2 files changed, 118 insertions(+), 165 deletions(-) diff --git a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift index c22297c53b..89f2b64c40 100644 --- a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift @@ -349,10 +349,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { } // Validate the pipeline still has http handlers. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) assertPipelineContainsUpgradeHandler(channel: clientChannel) // Push the successful server response. @@ -577,10 +575,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { clientChannel.embeddedEventLoop.run() // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) @@ -619,10 +615,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { // Should fail with error (response is malformed) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) @@ -662,10 +656,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { // Should fail with error (response is malformed) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertLessThanOrEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) @@ -707,10 +699,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { // Should fail with error (response is malformed) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertLessThanOrEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) @@ -754,10 +744,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { // Should fail with error (response is denied) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) @@ -1038,10 +1026,8 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { // Should fail with error (response is malformed) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) @@ -1083,10 +1069,8 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { // Should fail with error (response is denied) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) @@ -1179,10 +1163,8 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { // Should fail with error (response is malformed) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index d2d4dcc870..378d64d0a8 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -35,15 +35,15 @@ extension ChannelPipeline { } @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - fileprivate func assertContainsUpgrader() throws { + fileprivate func assertContainsUpgrader() { do { _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() } catch { - try self.assertContains(handlerType: HTTPServerUpgradeHandler.self) + self.assertContains(handlerType: HTTPServerUpgradeHandler.self) } } - func assertContains(handlerType: Handler.Type) throws { + func assertContains(handlerType: Handler.Type) { XCTAssertNoThrow(try self.context(handlerType: handlerType).wait(), "did not find handler") } @@ -312,9 +312,14 @@ private class UpgradeDelayer: TypedAndUntypedHTTPServerProtocolUpgrader { let requiredUpgradeHeaders: [String] = [] private var upgradePromise: EventLoopPromise? + private let upgradeRequestedPromise: EventLoopPromise - public init(forProtocol `protocol`: String) { + /// - Parameters: + /// - protocol: The protocol this upgrader knows how to support. + /// - upgradeRequestedPromise: Will be fulfilled when upgrade() is called + init(forProtocol `protocol`: String, upgradeRequestedPromise: EventLoopPromise) { self.supportedProtocol = `protocol` + self.upgradeRequestedPromise = upgradeRequestedPromise } public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { @@ -325,6 +330,7 @@ private class UpgradeDelayer: TypedAndUntypedHTTPServerProtocolUpgrader { public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { self.upgradePromise = context.eventLoop.makePromise() + upgradeRequestedPromise.succeed() return self.upgradePromise!.futureResult.map { _ in } } @@ -333,7 +339,8 @@ private class UpgradeDelayer: TypedAndUntypedHTTPServerProtocolUpgrader { } func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { - self.upgradePromise = channel.eventLoop.makePromise(of: Bool.self) + self.upgradePromise = channel.eventLoop.makePromise() + self.upgradeRequestedPromise.succeed() return self.upgradePromise!.futureResult } } @@ -420,30 +427,31 @@ private class ReentrantReadOnChannelReadCompleteHandler: ChannelInboundHandler { @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) class HTTPServerUpgradeTestCase: XCTestCase { + + static let eventLoop = MultiThreadedEventLoopGroup.singleton.next() + fileprivate func setUpTestWithAutoremoval(pipelining: Bool = false, upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], extraHandlers: [ChannelHandler], notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, - _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (EventLoopGroup, Channel, Channel, Channel) { - let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval(group: group, + _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (Channel, Channel, Channel) { + let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval(group: Self.eventLoop, pipelining: pipelining, upgraders: upgraders, extraHandlers: extraHandlers, upgradeCompletionHandler) - let clientChannel = try connectedClientChannel(group: group, serverAddress: serverChannel.localAddress!) - return (group, serverChannel, clientChannel, try connectedServerChannelFuture.wait()) + let clientChannel = try connectedClientChannel(group: Self.eventLoop, serverAddress: serverChannel.localAddress!) + return (serverChannel, clientChannel, try connectedServerChannelFuture.wait()) } func testUpgradeWithoutUpgrade() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\n\r\n" @@ -454,14 +462,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeAfterInitialRequest() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } // This request fires a subsequent upgrade in immediately. It should also be ignored. @@ -509,7 +516,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { upgraderCbFired.wrappedValue = true } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { (context) in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -518,11 +525,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -549,14 +554,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeRequiresCorrectHeaders() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nConnection: upgrade\r\nUpgrade: myproto\r\n\r\n" @@ -567,14 +571,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeRequiresHeadersInConnection() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } // This request is missing a 'Kafkaesque' connection header. @@ -586,14 +589,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeOnlyHandlesKnownProtocols() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nConnection: upgrade\r\nUpgrade: something-else\r\n\r\n" @@ -615,7 +617,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { upgraderCbFired.wrappedValue = true } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], extraHandlers: []) { context in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -624,11 +626,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -663,16 +663,14 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: [eventSaver.wrappedValue]) { context in XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -720,7 +718,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { } let errorCatcher = ErrorSaver() - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], extraHandlers: [errorCatcher]) { context in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -729,11 +727,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -771,15 +767,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { func testUpgradeIsCaseInsensitive() throws { let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["WeIrDcAsE"]) { req in } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { context in context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -801,14 +795,12 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testDelayedUpgradeBehaviour() throws { - let upgrader = UpgradeDelayer(forProtocol: "myproto") - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let upgradeRequestPromise = Self.eventLoop.makePromise(of: Void.self) + let upgrader = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { context in } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = SingleHTTPResponseAccumulator { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -824,8 +816,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Ok, we don't think this upgrade should have succeeded yet, but neither should it have failed. We want to // dispatch onto the server event loop and check that the channel still contains the upgrade handler. - try connectedServer.pipeline.assertContainsUpgrader() + connectedServer.pipeline.assertContainsUpgrader() + // Wait for the upgrade function to be called + try upgradeRequestPromise.futureResult.wait() // Ok, let's unblock the upgrade now. The machinery should do its thing. try server.eventLoop.submit { upgrader.unblockUpgrade() @@ -836,16 +830,15 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testBuffersInboundDataDuringDelayedUpgrade() throws { - let upgrader = UpgradeDelayer(forProtocol: "myproto") + let upgradeRequestPromise = Self.eventLoop.makePromise(of: Void.self) + let upgrader = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) let dataRecorder = DataRecorder() - let (group, server, client, _) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (server, client, _) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: [dataRecorder]) { context in } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -903,7 +896,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Upgrade has been requested but not proceeded. XCTAssertTrue(upgradeRequested) - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self))) // Ok, now we can upgrade. Upgrader should be out of the pipeline, and we should have seen the 101 response. @@ -951,14 +944,14 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Upgrade has been requested but not proceeded for the failing protocol. XCTAssertEqual(upgradingProtocol, "failingProtocol") - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))) XCTAssertNoThrow(try channel.throwIfErrorCaught()) // Ok, now we'll fail the promise. This will catch an error, but the upgrade won't happen: instead, the second handler will be fired. failingProtocolPromise.fail(No.no) XCTAssertEqual(upgradingProtocol, "myproto") - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))) XCTAssertThrowsError(try channel.throwIfErrorCaught()) { error in @@ -1001,7 +994,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Upgrade has been requested but not proceeded. XCTAssertTrue(upgradeRequested) - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))) XCTAssertNoThrow(try channel.throwIfErrorCaught()) @@ -1062,7 +1055,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Upgrade has been requested but not proceeded. XCTAssertTrue(upgradeRequested) - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))) XCTAssertNoThrow(try channel.throwIfErrorCaught()) @@ -1118,17 +1111,14 @@ class HTTPServerUpgradeTestCase: XCTestCase { func testRemovesAllHTTPRelatedHandlersAfterUpgrade() throws { let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { req in } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(pipelining: true, + let (_, client, connectedServer) = try setUpTestWithAutoremoval(pipelining: true, upgraders: [upgrader], extraHandlers: []) { context in } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } // First, validate the pipeline is right. - XCTAssertNoThrow(try connectedServer.pipeline.assertContains(handlerType: ByteToMessageHandler.self)) - XCTAssertNoThrow(try connectedServer.pipeline.assertContains(handlerType: HTTPResponseEncoder.self)) - XCTAssertNoThrow(try connectedServer.pipeline.assertContains(handlerType: HTTPServerPipelineHandler.self)) + connectedServer.pipeline.assertContains(handlerType: ByteToMessageHandler.self) + connectedServer.pipeline.assertContains(handlerType: HTTPResponseEncoder.self) + connectedServer.pipeline.assertContains(handlerType: HTTPServerPipelineHandler.self) // This request is safe to upgrade. let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" @@ -1229,7 +1219,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { let firstByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let secondByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let allDonePromise = promiseGroup.next().makePromise(of: Void.self) - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { (context) in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -1239,11 +1229,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { secondByteDonePromise: secondByteDonePromise, allDonePromise: allDonePromise)) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -1286,7 +1273,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertNoThrow(try channel.finish()) } - let delayer = UpgradeDelayer(forProtocol: "myproto") + let upgradeRequestPromise = Self.eventLoop.makePromise(of: Void.self) + let delayer = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) defer { delayer.unblockUpgrade() } @@ -1298,7 +1286,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { channel.embeddedEventLoop.run() // Upgrade has been requested but not proceeded. - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(try XCTAssertNil(channel.readInbound(as: ByteBuffer.self))) // The 101 has been sent. @@ -1334,7 +1322,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertNoThrow(try channel.finish()) } - let delayer = UpgradeDelayer(forProtocol: "myproto") + let upgradeRequestPromise = Self.eventLoop.makePromise(of: Void.self) + let delayer = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) defer { delayer.unblockUpgrade() } @@ -1347,7 +1336,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { channel.embeddedEventLoop.run() // Upgrade has been requested but not proceeded. - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(try XCTAssertNil(channel.readInbound(as: ByteBuffer.self))) // The 101 has been sent. @@ -1399,7 +1388,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { upgraderCbFired.wrappedValue = true } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { (context) in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -1408,11 +1397,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -1474,7 +1460,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Nothing should have been forwarded. XCTAssertTrue(dataRecorder.receivedData().isEmpty) // The upgrade handler should still be in the pipeline. - try channel.pipeline.assertContainsUpgrader() + channel.pipeline.assertContainsUpgrader() } func testFailedUpgradeResponseWriteThrowsError() throws { @@ -1519,7 +1505,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Nothing should have been forwarded. XCTAssertTrue(dataRecorder.receivedData().isEmpty) // The upgrade handler should still be in the pipeline. - try channel.pipeline.assertContainsUpgrader() + channel.pipeline.assertContainsUpgrader() } func testFailedUpgraderThrowsError() throws { @@ -1562,7 +1548,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Nothing should have been forwarded. XCTAssertTrue(dataRecorder.receivedData().isEmpty) // The upgrade handler should still be in the pipeline. - try channel.pipeline.assertContainsUpgrader() + channel.pipeline.assertContainsUpgrader() } } @@ -1574,10 +1560,9 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { extraHandlers: [ChannelHandler], notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler - ) throws -> (EventLoopGroup, Channel, Channel, Channel) { - let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let connectionChannelPromise = group.next().makePromise(of: Channel.self) - let serverChannelFuture = ServerBootstrap(group: group) + ) throws -> (Channel, Channel, Channel) { + let connectionChannelPromise = Self.eventLoop.makePromise(of: Channel.self) + let serverChannelFuture = ServerBootstrap(group: Self.eventLoop) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .childChannelInitializer { channel in channel.eventLoop.makeCompletedFuture { @@ -1606,8 +1591,8 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { return EventLoopFuture.andAllSucceed(futureResults, on: channel.eventLoop) } }.bind(host: "127.0.0.1", port: 0) - let clientChannel = try connectedClientChannel(group: group, serverAddress: serverChannelFuture.wait().localAddress!) - return (group, try serverChannelFuture.wait(), clientChannel, try connectionChannelPromise.futureResult.wait()) + let clientChannel = try connectedClientChannel(group: Self.eventLoop, serverAddress: serverChannelFuture.wait().localAddress!) + return (try serverChannelFuture.wait(), clientChannel, try connectionChannelPromise.futureResult.wait()) } func testNotUpgrading() throws { @@ -1615,7 +1600,7 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval( + let (_, client, connectedServer) = try setUpTestWithAutoremoval( upgraders: [upgrader], extraHandlers: [], notUpgradingHandler: { channel in @@ -1625,11 +1610,9 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { return channel.eventLoop.makeSucceededFuture(true) } ) { _ in } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") XCTAssertEqual(resultString, "") @@ -1669,7 +1652,7 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { upgraderCbFired.wrappedValue = true } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval( + let (_, client, connectedServer) = try setUpTestWithAutoremoval( upgraders: [upgrader], extraHandlers: [] ) { (context) in @@ -1680,11 +1663,9 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -1724,7 +1705,7 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { upgraderCbFired.wrappedValue = true } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], extraHandlers: []) { context in // This is called before the upgrader gets called. XCTAssertNotNil(upgradeRequest.wrappedValue) @@ -1733,11 +1714,9 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -1778,7 +1757,7 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { } let errorCatcher = ErrorSaver() - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], extraHandlers: [errorCatcher]) { context in // This is called before the upgrader gets called. XCTAssertNotNil(upgradeRequest.wrappedValue) @@ -1787,11 +1766,9 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -1915,7 +1892,7 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { let firstByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let secondByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let allDonePromise = promiseGroup.next().makePromise(of: Void.self) - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { (context) in // This is called before the upgrader gets called. XCTAssertNotNil(upgradeRequest.wrappedValue) @@ -1925,11 +1902,9 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { secondByteDonePromise: secondByteDonePromise, allDonePromise: allDonePromise)) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -1988,7 +1963,7 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { upgraderCbFired.wrappedValue = true } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { (context) in // This is called before the upgrader gets called. XCTAssertNotNil(upgradeRequest.wrappedValue) @@ -1997,11 +1972,9 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -2036,16 +2009,14 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: [eventSaver.wrappedValue]) { context in XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, From eb2757d88debbd34dabe326af8b8e238b10619cb Mon Sep 17 00:00:00 2001 From: Johannes Weiss Date: Thu, 19 Oct 2023 20:46:28 +0100 Subject: [PATCH 30/64] NonBlockingFileIO: tolerate chunk handlers from other ELs (#2562) --- Sources/NIOPosix/NonBlockingFileIO.swift | 2 +- .../NIOPosixTests/NonBlockingFileIOTest.swift | 25 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/Sources/NIOPosix/NonBlockingFileIO.swift b/Sources/NIOPosix/NonBlockingFileIO.swift index 64daf5f128..dacdc576be 100644 --- a/Sources/NIOPosix/NonBlockingFileIO.swift +++ b/Sources/NIOPosix/NonBlockingFileIO.swift @@ -202,7 +202,7 @@ public struct NonBlockingFileIO: Sendable { return } let bytesRead = Int64(buffer.readableBytes) - chunkHandler(buffer).whenComplete { result in + chunkHandler(buffer).hop(to: eventLoop).whenComplete { result in switch result { case .success(_): eventLoop.assertInEventLoop() diff --git a/Tests/NIOPosixTests/NonBlockingFileIOTest.swift b/Tests/NIOPosixTests/NonBlockingFileIOTest.swift index 6c4ff1b1d9..1e742b1e92 100644 --- a/Tests/NIOPosixTests/NonBlockingFileIOTest.swift +++ b/Tests/NIOPosixTests/NonBlockingFileIOTest.swift @@ -1016,4 +1016,29 @@ class NonBlockingFileIOTest: XCTestCase { } }) } + + func testChunkedReadingToleratesChunkHandlersWithForeignEventLoops() throws { + let content = "hello" + let contentBytes = Array(content.utf8) + var numCalls = 0 + let otherGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + try! otherGroup.syncShutdownGracefully() + } + try withTemporaryFile(content: content) { (fileHandle, path) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 0, endIndex: 5) + try self.fileIO.readChunked(fileRegion: fr, + chunkSize: 1, + allocator: self.allocator, + eventLoop: self.eventLoop) { buf in + var buf = buf + XCTAssertTrue(self.eventLoop.inEventLoop) + XCTAssertEqual(1, buf.readableBytes) + XCTAssertEqual(contentBytes[numCalls], buf.readBytes(length: 1)?.first!) + numCalls += 1 + return otherGroup.next().makeSucceededFuture(()) + }.wait() + } + XCTAssertEqual(content.utf8.count, numCalls) + } } From f71ae347479a2b79087d6d1714de890630c8251e Mon Sep 17 00:00:00 2001 From: Si Beaumont Date: Fri, 20 Oct 2023 17:53:17 +0100 Subject: [PATCH 31/64] Support disabling body aggregation in NIOHTTP1TestServer (#2563) --- Sources/NIOTestUtils/NIOHTTP1TestServer.swift | 14 ++++++- .../NIOHTTP1TestServerTest.swift | 37 +++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/Sources/NIOTestUtils/NIOHTTP1TestServer.swift b/Sources/NIOTestUtils/NIOHTTP1TestServer.swift index 1bddd9057f..c787c3e7d5 100644 --- a/Sources/NIOTestUtils/NIOHTTP1TestServer.swift +++ b/Sources/NIOTestUtils/NIOHTTP1TestServer.swift @@ -167,6 +167,7 @@ private final class AggregateBodyHandler: ChannelInboundHandler { /// XCTAssertNoThrow(XCTAssertEqual(responseBody, try requestComplete.wait())) public final class NIOHTTP1TestServer { private let eventLoop: EventLoop + private let aggregateBody: Bool // all protected by eventLoop private let inboundBuffer: BlockingQueue = .init() private var currentClientChannel: Channel? = nil @@ -213,7 +214,11 @@ public final class NIOHTTP1TestServer { return } channel.pipeline.configureHTTPServerPipeline().flatMap { - channel.pipeline.addHandler(AggregateBodyHandler()) + if self.aggregateBody { + return channel.pipeline.addHandler(AggregateBodyHandler()) + } else { + return self.eventLoop.makeSucceededVoidFuture() + } }.flatMap { channel.pipeline.addHandler(WebServerHandler(webServer: self)) }.whenSuccess { @@ -221,8 +226,13 @@ public final class NIOHTTP1TestServer { } } - public init(group: EventLoopGroup) { + public convenience init(group: EventLoopGroup) { + self.init(group: group, aggregateBody: true) + } + + public init(group: EventLoopGroup, aggregateBody: Bool) { self.eventLoop = group.next() + self.aggregateBody = aggregateBody self.serverChannel = try! ServerBootstrap(group: self.eventLoop) .childChannelOption(ChannelOptions.autoRead, value: false) diff --git a/Tests/NIOTestUtilsTests/NIOHTTP1TestServerTest.swift b/Tests/NIOTestUtilsTests/NIOHTTP1TestServerTest.swift index 9d453a243d..10741197b1 100644 --- a/Tests/NIOTestUtilsTests/NIOHTTP1TestServerTest.swift +++ b/Tests/NIOTestUtilsTests/NIOHTTP1TestServerTest.swift @@ -313,6 +313,43 @@ class NIOHTTP1TestServerTest: XCTestCase { XCTAssertNotNil(channel) XCTAssertNoThrow(try channel.closeFuture.wait()) } + + func testReceiveBodyWithoutAggregation() { + let testServer = NIOHTTP1TestServer(group: self.group, aggregateBody: false) + + let responsePromise = self.group.next().makePromise(of: String.self) + var channel: Channel! + XCTAssertNoThrow(channel = try self.connect(serverPort: testServer.serverPort, + responsePromise: responsePromise).wait()) + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/uri", headers: headers) + channel.writeAndFlush(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: nil) + XCTAssertNoThrow(try testServer.receiveHeadAndVerify { head in + XCTAssertEqual(head.uri, "/uri") + XCTAssertEqual(head.headers["Content-Type"], ["text/plain; charset=utf-8"]) + }) + XCTAssertNoThrow(try testServer.writeOutbound(.head(.init(version: .http1_1, status: .ok)))) + + for _ in 0..<10 { + channel.writeAndFlush(NIOAny(HTTPClientRequestPart.body(.byteBuffer(ByteBuffer(string: "ping")))), promise: nil) + XCTAssertNoThrow(try testServer.receiveBodyAndVerify { buffer in + XCTAssertEqual(String(buffer: buffer), "ping") + }) + XCTAssertNoThrow(try testServer.writeOutbound(.body(.byteBuffer(ByteBuffer(string: "pong"))))) + } + + channel.writeAndFlush(NIOAny(HTTPClientRequestPart.end(nil)), promise: nil) + XCTAssertNoThrow(try testServer.receiveEndAndVerify { trailers in + XCTAssertNil(trailers) + }) + XCTAssertNoThrow(try testServer.writeOutbound(.end(nil))) + + XCTAssertNoThrow(try testServer.stop()) + XCTAssertNotNil(channel) + XCTAssertNoThrow(try channel.closeFuture.wait()) + } } private final class TestHTTPHandler: ChannelInboundHandler { From 1c205735141370d14a7345ee661d2e90afe56d2e Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Mon, 23 Oct 2023 13:13:41 +0100 Subject: [PATCH 32/64] Add autogenerated files from VSCode to .gitignore (#2567) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a70f621b48..ec69488c87 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ Package.resolved DerivedData .swiftpm .*.sw? +.vscode/launch.json From 86d05fb79ff293da4da4e94f9fa5d481f2b36dcd Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Mon, 23 Oct 2023 19:00:13 +0100 Subject: [PATCH 33/64] Mark retroactive conformances appropriately. (#2569) Motivation: In nightly Swift, we now need to mark retroactive conformances when they are intentional. These conformances are safe for us, so we can safely suppress the warnings. Modifications: - Mark NIOFoundationCompat retroactive conformances. Result: Nightly builds work again --- .../NIOFoundationCompat/ByteBuffer-foundation.swift | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift index 271707ad83..fd6362bf7f 100644 --- a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift +++ b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift @@ -375,9 +375,17 @@ extension ByteBufferAllocator { } // MARK: - Conformances +#if compiler(>=5.11) +extension ByteBufferView: @retroactive ContiguousBytes {} +extension ByteBufferView: @retroactive DataProtocol {} +extension ByteBufferView: @retroactive MutableDataProtocol {} +#else extension ByteBufferView: ContiguousBytes {} +extension ByteBufferView: DataProtocol {} +extension ByteBufferView: MutableDataProtocol {} +#endif -extension ByteBufferView: DataProtocol { +extension ByteBufferView { public typealias Regions = CollectionOfOne public var regions: CollectionOfOne { @@ -385,8 +393,6 @@ extension ByteBufferView: DataProtocol { } } -extension ByteBufferView: MutableDataProtocol {} - // MARK: - Data extension Data { From 39d1be7fc92eec3f6f87e67b401fa9d729e12d1e Mon Sep 17 00:00:00 2001 From: Max Desiatov Date: Tue, 24 Oct 2023 13:41:04 +0100 Subject: [PATCH 34/64] Mention file length in bytes in `readFileSize` explicitly (#2572) The rest of the functions on `NonBlockingFileIO` are explicit in the fact that they count buffer length in bytes. `readFileSize` should be explicit in its documentation too. --- Sources/NIOPosix/NonBlockingFileIO.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/NIOPosix/NonBlockingFileIO.swift b/Sources/NIOPosix/NonBlockingFileIO.swift index dacdc576be..2e07175bc3 100644 --- a/Sources/NIOPosix/NonBlockingFileIO.swift +++ b/Sources/NIOPosix/NonBlockingFileIO.swift @@ -376,12 +376,12 @@ public struct NonBlockingFileIO: Sendable { } } - /// Returns the length of the file associated with `fileHandle`. + /// Returns the length of the file in bytes associated with `fileHandle`. /// /// - parameters: /// - fileHandle: The `NIOFileHandle` to read from. /// - eventLoop: The `EventLoop` to create the returned `EventLoopFuture` from. - /// - returns: An `EventLoopFuture` which is fulfilled if the write was successful or fails on error. + /// - returns: An `EventLoopFuture` which is fulfilled with the length of the file in bytes if the write was successful or fails on error. public func readFileSize(fileHandle: NIOFileHandle, eventLoop: EventLoop) -> EventLoopFuture { return self.threadPool.runIfActive(eventLoop: eventLoop) { From 26838898afbd4e819988c035a7e617638b5a8bc8 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Tue, 24 Oct 2023 19:15:05 +0300 Subject: [PATCH 35/64] Add support for async VSock bootstrap methods (#2561) * Add support for async VSock bootstrap methods # Motivation We are about to release our new async bootstrap methods that retain the type information of the various initializers. While developing those we also added support for VSock in NIO; however, we missed adding support for the async bootstrap methods. # Modification This PR adds new async bootstrap methods that take a `VSockAddress` and a test for it. # Result Support for async VSock bootstrap. * Update docs --- Sources/NIOPosix/Bootstrap.swift | 77 ++++++++++++++++++- .../AsyncChannelBootstrapTests.swift | 61 +++++++++++++++ 2 files changed, 135 insertions(+), 3 deletions(-) diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index f503d77444..1fcc4e3a5c 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -552,6 +552,40 @@ extension ServerBootstrap { ) } + /// Bind the `ServerSocketChannel` to a VSOCK socket. + /// + /// - Parameters: + /// - vsockAddress: The VSOCK socket address to bind on. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind( + to vsockAddress: VsockAddress, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> NIOAsyncChannel { + func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel { + try ServerSocketChannel(eventLoop: eventLoop, group: childEventLoopGroup, protocolFamily: .vsock, enableMPTCP: enableMPTCP) + } + + return try await self.bind0( + makeServerChannel: makeChannel, + serverBackPressureStrategy: serverBackPressureStrategy, + childChannelInitializer: childChannelInitializer + ) { channel in + channel.register().flatMap { + let promise = channel.eventLoop.makePromise(of: Void.self) + channel.triggerUserOutboundEvent0( + VsockChannelEvents.BindToAddress(vsockAddress), + promise: promise + ) + return promise.futureResult + } + }.get() + } + /// Use the existing bound socket file descriptor. /// /// - Parameters: @@ -593,7 +627,7 @@ extension ServerBootstrap { makeServerChannel: @escaping (SelectableEventLoop, EventLoopGroup, Bool) throws -> ServerSocketChannel, serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - registration: @escaping @Sendable (Channel) -> EventLoopFuture + registration: @escaping @Sendable (ServerSocketChannel) -> EventLoopFuture ) -> EventLoopFuture> { let eventLoop = self.group.next() let childEventLoopGroup = self.childGroup @@ -931,6 +965,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { /// - address: The VSOCK address to connect to. /// - returns: An `EventLoopFuture` for when the `Channel` is connected. public func connect(to address: VsockAddress) -> EventLoopFuture { + let connectTimeout = self.connectTimeout return self.initializeAndRegisterNewChannel( eventLoop: self.group.next(), protocolFamily: .vsock @@ -938,8 +973,8 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { let connectPromise = channel.eventLoop.makePromise(of: Void.self) channel.triggerUserOutboundEvent(VsockChannelEvents.ConnectToAddress( address), promise: connectPromise) - let cancelTask = channel.eventLoop.scheduleTask(in: self.connectTimeout) { - connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout)) + let cancelTask = channel.eventLoop.scheduleTask(in: connectTimeout) { + connectPromise.fail(ChannelError.connectTimeout(connectTimeout)) channel.close(promise: nil) } connectPromise.futureResult.whenComplete { (_: Result) in @@ -1123,6 +1158,42 @@ extension ClientBootstrap { ) } + /// Specify the VSOCK address to connect to for the `Channel`. + /// + /// - Parameters: + /// - address: The VSOCK address to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect( + to address: VsockAddress, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + let connectTimeout = self.connectTimeout + return try await self.initializeAndRegisterNewChannel( + eventLoop: self.group.next(), + protocolFamily: NIOBSDSocket.ProtocolFamily.vsock, + channelInitializer: channelInitializer, + postRegisterTransformation: { result, eventLoop in + return eventLoop.makeSucceededFuture(result) + } + ) { channel in + let connectPromise = channel.eventLoop.makePromise(of: Void.self) + channel.triggerUserOutboundEvent(VsockChannelEvents.ConnectToAddress( address), promise: connectPromise) + + let cancelTask = channel.eventLoop.scheduleTask(in: connectTimeout) { + connectPromise.fail(ChannelError.connectTimeout(connectTimeout)) + channel.close(promise: nil) + } + connectPromise.futureResult.whenComplete { (_: Result) in + cancelTask.cancel() + } + + return connectPromise.futureResult + }.get().1 + } + /// Use the existing connected socket file descriptor. /// /// - Parameters: diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 9788b284f8..527a2ccd08 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -770,6 +770,67 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } + // MARK: VSock + + func testVSock() async throws { + try XCTSkipUnless(System.supportsVsock, "No vsock transport available") + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + + let port = VsockAddress.Port(1234) + + let serverChannel = try await ServerBootstrap(group: eventLoopGroup) + .bind( + to: VsockAddress(cid: .any, port: port) + ) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) + return try NIOAsyncChannel(synchronouslyWrapping: channel) + } + } + + #if canImport(Darwin) + let connectAddress = VsockAddress(cid: .any, port: port) + #elseif os(Linux) || os(Android) + let connectAddress = VsockAddress(cid: .local, port: port) + #endif + + try await withThrowingTaskGroup(of: Void.self) { group in + let (stream, continuation) = AsyncStream.makeStream() + var iterator = stream.makeAsyncIterator() + + group.addTask { + try await withThrowingTaskGroup(of: Void.self) { _ in + for try await childChannel in serverChannel.inbound { + for try await value in childChannel.inbound { + continuation.yield(.string(value)) + } + } + } + } + + let stringChannel = try await ClientBootstrap(group: eventLoopGroup) + .connect(to: connectAddress) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) + return try NIOAsyncChannel(synchronouslyWrapping: channel) + } + } + try await stringChannel.outbound.write("hello") + + await XCTAsyncAssertEqual(await iterator.next(), .string("hello")) + + group.cancelAll() + } + } + + // MARK: - Test Helpers private func makePipeFileDescriptors() -> (pipe1ReadFH: Int32, pipe1WriteFH: Int32, pipe2ReadFH: Int32, pipe2WriteFH: Int32) { From 935dbdf114e8008ae190af78442db50a7c4eaec9 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Tue, 24 Oct 2023 20:18:22 +0300 Subject: [PATCH 36/64] Add support for unidirectional `NIOPipeBootstrap` (#2560) * Add support for unidirectional `NIOPipeBootstrap` # Motivation In some scenarios, it is useful to only have either an input or output side for a `PipeChannel`. This fixes https://github.com/apple/swift-nio/issues/2444. # Modification This PR adds new methods to `NIOPipeBootstrap` that make either the input or the output optional. Furthermore, I am intentionally breaking the API for the new async methods since those haven't shipped yet to reflect the same API there. # Result It is now possible to bootstrap a `PipeChannel` with either the input or output side closed. * Docs and naming --- .../AsyncChannelOutboundWriterHandler.swift | 12 + Sources/NIOCrashTester/OutputGrepper.swift | 6 +- Sources/NIOPosix/Bootstrap.swift | 229 ++++++++++----- Sources/NIOPosix/PipeChannel.swift | 70 +++-- Sources/NIOPosix/PipePair.swift | 47 ++-- .../AsyncChannelBootstrapTests.swift | 265 ++++++++++++++---- 6 files changed, 455 insertions(+), 174 deletions(-) diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift index 33f4d363dc..2a9de1328c 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift @@ -147,6 +147,18 @@ internal final class NIOAsyncChannelOutboundWriterHandler self.sink?.setWritability(to: context.channel.isWritable) context.fireChannelWritabilityChanged() } + + @inlinable + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case ChannelEvent.outputClosed: + self.sink?.finish() + default: + break + } + + context.fireUserInboundEventTriggered(event) + } } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) diff --git a/Sources/NIOCrashTester/OutputGrepper.swift b/Sources/NIOCrashTester/OutputGrepper.swift index ee99921bd1..a226a38bf7 100644 --- a/Sources/NIOCrashTester/OutputGrepper.swift +++ b/Sources/NIOCrashTester/OutputGrepper.swift @@ -22,7 +22,6 @@ internal struct OutputGrepper { internal static func make(group: EventLoopGroup) -> OutputGrepper { let processToChannel = Pipe() - let deadPipe = Pipe() // just so we have an output... let eventLoop = group.next() let outputPromise = eventLoop.makePromise(of: ProgramOutput.self) @@ -34,13 +33,10 @@ internal struct OutputGrepper { channel.pipeline.addHandlers([ByteToMessageHandler(NewlineFramer()), GrepHandler(promise: outputPromise)]) } - .takingOwnershipOfDescriptors(input: dup(processToChannel.fileHandleForReading.fileDescriptor), - output: dup(deadPipe.fileHandleForWriting.fileDescriptor)) + .takingOwnershipOfDescriptor(input: dup(processToChannel.fileHandleForReading.fileDescriptor)) let processOutputPipe = NIOFileHandle(descriptor: dup(processToChannel.fileHandleForWriting.fileDescriptor)) processToChannel.fileHandleForReading.closeFile() processToChannel.fileHandleForWriting.closeFile() - deadPipe.fileHandleForReading.closeFile() - deadPipe.fileHandleForWriting.closeFile() channelFuture.cascadeFailure(to: outputPromise) return OutputGrepper(result: outputPromise.futureResult, processOutputPipe: processOutputPipe) diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index 1fcc4e3a5c..bf70b9863e 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -2055,53 +2055,63 @@ public final class NIOPipeBootstrap { /// - output: The _Unix file descriptor_ for the output (ie. the write side). /// - Returns: an `EventLoopFuture` to deliver the `Channel`. public func takingOwnershipOfDescriptors(input: CInt, output: CInt) -> EventLoopFuture { - precondition(input >= 0 && output >= 0 && input != output, - "illegal file descriptor pair. The file descriptors \(input), \(output) " + - "must be distinct and both positive integers.") - let eventLoop = group.next() - do { - try self.validateFileDescriptorIsNotAFile(input) - try self.validateFileDescriptorIsNotAFile(output) - } catch { - return eventLoop.makeFailedFuture(error) - } + self._takingOwnershipOfDescriptors(input: input, output: output) + } - let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) } - let channel: PipeChannel - do { - let inputFH = NIOFileHandle(descriptor: input) - let outputFH = NIOFileHandle(descriptor: output) - channel = try PipeChannel(eventLoop: eventLoop as! SelectableEventLoop, - inputPipe: inputFH, - outputPipe: outputFH) - } catch { - return eventLoop.makeFailedFuture(error) - } + /// Create the `PipeChannel` with the provided input file descriptor. + /// + /// The input file descriptor must be distinct. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `input` when the `Channel` + /// becomes inactive. You _must not_ do any further operations to `input`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing them. + /// + /// - Parameters: + /// - input: The _Unix file descriptor_ for the input (ie. the read side). + /// - Returns: an `EventLoopFuture` to deliver the `Channel`. + public func takingOwnershipOfDescriptor( + input: CInt + ) -> EventLoopFuture { + self._takingOwnershipOfDescriptors(input: input, output: nil) + } - func setupChannel() -> EventLoopFuture { - eventLoop.assertInEventLoop() - return self._channelOptions.applyAllChannelOptions(to: channel).flatMap { - channelInitializer(channel) - }.flatMap { - eventLoop.assertInEventLoop() - let promise = eventLoop.makePromise(of: Void.self) - channel.registerAlreadyConfigured0(promise: promise) - return promise.futureResult - }.map { - channel - }.flatMapError { error in - channel.close0(error: error, mode: .all, promise: nil) - return channel.eventLoop.makeFailedFuture(error) - } - } + /// Create the `PipeChannel` with the provided output file descriptor. + /// + /// The output file descriptor must be distinct. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `output` when the `Channel` + /// becomes inactive. You _must not_ do any further operations to `output`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing them. + /// + /// - Parameters: + /// - output: The _Unix file descriptor_ for the output (ie. the write side). + /// - Returns: an `EventLoopFuture` to deliver the `Channel`. + public func takingOwnershipOfDescriptor( + output: CInt + ) -> EventLoopFuture { + self._takingOwnershipOfDescriptors(input: nil, output: output) + } - if eventLoop.inEventLoop { - return setupChannel() - } else { - return eventLoop.flatSubmit { - setupChannel() + private func _takingOwnershipOfDescriptors(input: CInt?, output: CInt?) -> EventLoopFuture { + let channelInitializer: @Sendable (Channel) -> EventLoopFuture = { + let eventLoop = self.group.next() + let channelInitializer = self.channelInitializer + return { channel in + if let channelInitializer = channelInitializer { + return channelInitializer(channel).map { channel } + } else { + return eventLoop.makeSucceededFuture(channel) + } } - } + + }() + return self._takingOwnershipOfDescriptors( + input: input, + output: output, + channelInitializer: channelInitializer + ) } @available(*, deprecated, renamed: "takingOwnershipOfDescriptor(inputOutput:)") @@ -2154,9 +2164,7 @@ extension NIOPipeBootstrap { /// Create the `PipeChannel` with the provided input and output file descriptors. /// - /// The input and output file descriptors must be distinct. If you have a single file descriptor, consider using - /// `ClientBootstrap.withConnectedSocket(descriptor:)` if it's a socket or - /// `NIOPipeBootstrap.takingOwnershipOfDescriptor` if it is not a socket. + /// The input and output file descriptors must be distinct. /// /// - Note: If this method returns a succeeded future, SwiftNIO will close `input` and `output` /// when the `Channel` becomes inactive. You _must not_ do any further operations `input` or @@ -2179,41 +2187,112 @@ extension NIOPipeBootstrap { try await self._takingOwnershipOfDescriptors( input: input, output: output, - channelInitializer: channelInitializer, - postRegisterTransformation: { $0.makeSucceededFuture($1) } + channelInitializer: channelInitializer ) } + /// Create the `PipeChannel` with the provided input file descriptor. + /// + /// The input file descriptor must be distinct. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `input` when the `Channel` + /// becomes inactive. You _must not_ do any further operations to `input`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing them. + /// + /// - Parameters: + /// - input: The _Unix file descriptor_ for the input (ie. the read side). + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - func _takingOwnershipOfDescriptors( + public func takingOwnershipOfDescriptor( input: CInt, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + try await self._takingOwnershipOfDescriptors( + input: input, + output: nil, + channelInitializer: channelInitializer + ) + } + + /// Create the `PipeChannel` with the provided output file descriptor. + /// + /// The output file descriptor must be distinct. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `output` when the `Channel` + /// becomes inactive. You _must not_ do any further operations to `output`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing them. + /// + /// - Parameters: + /// - output: The _Unix file descriptor_ for the output (ie. the write side). + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func takingOwnershipOfDescriptor( output: CInt, - channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (EventLoop, ChannelInitializerResult) -> EventLoopFuture - ) async throws -> PostRegistrationTransformationResult { - precondition(input >= 0 && output >= 0 && input != output, - "illegal file descriptor pair. The file descriptors \(input), \(output) " + + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + try await self._takingOwnershipOfDescriptors( + input: nil, + output: output, + channelInitializer: channelInitializer + ) + } + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + func _takingOwnershipOfDescriptors( + input: CInt?, + output: CInt?, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> ChannelInitializerResult { + try await self._takingOwnershipOfDescriptors( + input: input, + output: output, + channelInitializer: channelInitializer + ).get() + } + + func _takingOwnershipOfDescriptors( + input: CInt?, + output: CInt?, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) -> EventLoopFuture { + precondition(input ?? 0 >= 0 && output ?? 0 >= 0 && input != output, + "illegal file descriptor pair. The file descriptors \(String(describing: input)), \(String(describing: output)) " + "must be distinct and both positive integers.") + precondition(!(input == nil && output == nil), "Either input or output has to be set") let eventLoop = group.next() let channelOptions = self._channelOptions - try self.validateFileDescriptorIsNotAFile(input) - try self.validateFileDescriptorIsNotAFile(output) - let channelInitializer = { (channel: Channel) -> EventLoopFuture in - let initializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) } - return initializer(channel).flatMap { channelInitializer(channel) } + let channel: PipeChannel + let inputFileHandle: NIOFileHandle? + let outputFileHandle: NIOFileHandle? + do { + if let input = input { + try self.validateFileDescriptorIsNotAFile(input) + } + if let output = output { + try self.validateFileDescriptorIsNotAFile(output) + } + + inputFileHandle = input.flatMap { NIOFileHandle(descriptor: $0) } + outputFileHandle = output.flatMap { NIOFileHandle(descriptor: $0) } + channel = try PipeChannel( + eventLoop: eventLoop as! SelectableEventLoop, + inputPipe: inputFileHandle, + outputPipe: outputFileHandle + ) + } catch { + return eventLoop.makeFailedFuture(error) } - let inputFileHandle = NIOFileHandle(descriptor: input) - let outputFileHandle = NIOFileHandle(descriptor: output) - let channel = try PipeChannel( - eventLoop: eventLoop as! SelectableEventLoop, - inputPipe: inputFileHandle, - outputPipe: outputFileHandle - ) @Sendable - func setupChannel() -> EventLoopFuture { + func setupChannel() -> EventLoopFuture { eventLoop.assertInEventLoop() return channelOptions.applyAllChannelOptions(to: channel).flatMap { _ -> EventLoopFuture in channelInitializer(channel) @@ -2221,7 +2300,15 @@ extension NIOPipeBootstrap { eventLoop.assertInEventLoop() let promise = eventLoop.makePromise(of: Void.self) channel.registerAlreadyConfigured0(promise: promise) - return promise.futureResult.flatMap { postRegisterTransformation(eventLoop, result) } + return promise.futureResult.map { result } + }.flatMap { result -> EventLoopFuture in + if inputFileHandle == nil { + return channel.close(mode: .input).map { result } + } + if outputFileHandle == nil { + return channel.close(mode: .output).map { result } + } + return channel.selectableEventLoop.makeSucceededFuture(result) }.flatMapError { error in channel.close0(error: error, mode: .all, promise: nil) return channel.eventLoop.makeFailedFuture(error) @@ -2229,11 +2316,11 @@ extension NIOPipeBootstrap { } if eventLoop.inEventLoop { - return try await setupChannel().get() + return setupChannel() } else { - return try await eventLoop.flatSubmit { + return eventLoop.flatSubmit { setupChannel() - }.get() + } } } } diff --git a/Sources/NIOPosix/PipeChannel.swift b/Sources/NIOPosix/PipeChannel.swift index 053f5f14a4..069fdfcc40 100644 --- a/Sources/NIOPosix/PipeChannel.swift +++ b/Sources/NIOPosix/PipeChannel.swift @@ -21,14 +21,18 @@ final class PipeChannel: BaseStreamSocketChannel { case output } - init(eventLoop: SelectableEventLoop, - inputPipe: NIOFileHandle, - outputPipe: NIOFileHandle) throws { + init( + eventLoop: SelectableEventLoop, + inputPipe: NIOFileHandle?, + outputPipe: NIOFileHandle? + ) throws { self.pipePair = try PipePair(inputFD: inputPipe, outputFD: outputPipe) - try super.init(socket: self.pipePair, - parent: nil, - eventLoop: eventLoop, - recvAllocator: AdaptiveRecvByteBufferAllocator()) + try super.init( + socket: self.pipePair, + parent: nil, + eventLoop: eventLoop, + recvAllocator: AdaptiveRecvByteBufferAllocator() + ) } func registrationForInput(interested: SelectorEventSet, registrationID: SelectorRegistrationID) -> NIORegistration { @@ -56,50 +60,62 @@ final class PipeChannel: BaseStreamSocketChannel { } override func register(selector: Selector, interested: SelectorEventSet) throws { - try selector.register(selectable: self.pipePair.inputFD, - interested: interested.intersection([.read, .reset]), - makeRegistration: self.registrationForInput) - try selector.register(selectable: self.pipePair.outputFD, - interested: interested.intersection([.write, .reset]), - makeRegistration: self.registrationForOutput) + if let inputFD = self.pipePair.inputFD { + try selector.register( + selectable: inputFD, + interested: interested.intersection([.read, .reset]), + makeRegistration: self.registrationForInput + ) + } + if let outputFD = self.pipePair.outputFD { + try selector.register( + selectable: outputFD, + interested: interested.intersection([.write, .reset]), + makeRegistration: self.registrationForOutput + ) + } } override func deregister(selector: Selector, mode: CloseMode) throws { - if (mode == .all || mode == .input) && self.pipePair.inputFD.isOpen { - try selector.deregister(selectable: self.pipePair.inputFD) + if let inputFD = self.pipePair.inputFD, (mode == .all || mode == .input) && inputFD.isOpen { + try selector.deregister(selectable: inputFD) } - if (mode == .all || mode == .output) && self.pipePair.outputFD.isOpen { - try selector.deregister(selectable: self.pipePair.outputFD) + if let outputFD = self.pipePair.outputFD, (mode == .all || mode == .output) && outputFD.isOpen { + try selector.deregister(selectable: outputFD) } } override func reregister(selector: Selector, interested: SelectorEventSet) throws { - if self.pipePair.inputFD.isOpen { - try selector.reregister(selectable: self.pipePair.inputFD, - interested: interested.intersection([.read, .reset])) + if let inputFD = self.pipePair.inputFD, inputFD.isOpen { + try selector.reregister( + selectable: inputFD, + interested: interested.intersection([.read, .reset]) + ) } - if self.pipePair.outputFD.isOpen { - try selector.reregister(selectable: self.pipePair.outputFD, - interested: interested.intersection([.write, .reset])) + if let outputFD = self.pipePair.outputFD, outputFD.isOpen { + try selector.reregister( + selectable: outputFD, + interested: interested.intersection([.write, .reset]) + ) } } override func readEOF() { super.readEOF() - guard self.pipePair.inputFD.isOpen else { + guard let inputFD = self.pipePair.inputFD, inputFD.isOpen else { return } try! self.selectableEventLoop.deregister(channel: self, mode: .input) - try! self.pipePair.inputFD.close() + try! inputFD.close() } override func writeEOF() { - guard self.pipePair.outputFD.isOpen else { + guard let outputFD = self.pipePair.outputFD, outputFD.isOpen else { return } try! self.selectableEventLoop.deregister(channel: self, mode: .output) - try! self.pipePair.outputFD.close() + try! outputFD.close() } override func shutdownSocket(mode: CloseMode) throws { diff --git a/Sources/NIOPosix/PipePair.swift b/Sources/NIOPosix/PipePair.swift index 921fdab3ce..cb9a92d4b3 100644 --- a/Sources/NIOPosix/PipePair.swift +++ b/Sources/NIOPosix/PipePair.swift @@ -38,14 +38,14 @@ extension SelectableFileHandle: Selectable { final class PipePair: SocketProtocol { typealias SelectableType = SelectableFileHandle - let inputFD: SelectableFileHandle - let outputFD: SelectableFileHandle + let inputFD: SelectableFileHandle? + let outputFD: SelectableFileHandle? - init(inputFD: NIOFileHandle, outputFD: NIOFileHandle) throws { - self.inputFD = SelectableFileHandle(inputFD) - self.outputFD = SelectableFileHandle(outputFD) + init(inputFD: NIOFileHandle?, outputFD: NIOFileHandle?) throws { + self.inputFD = inputFD.flatMap { SelectableFileHandle($0) } + self.outputFD = outputFD.flatMap { SelectableFileHandle($0) } try self.ignoreSIGPIPE() - for fileHandle in [inputFD, outputFD] { + for fileHandle in [inputFD, outputFD].compactMap({ $0 }) { try fileHandle.withUnsafeFileDescriptor { try NIOFileHandle.setNonBlocking(fileDescriptor: $0) } @@ -53,7 +53,7 @@ final class PipePair: SocketProtocol { } func ignoreSIGPIPE() throws { - for fileHandle in [self.inputFD, self.outputFD] { + for fileHandle in [self.inputFD, self.outputFD].compactMap({ $0 }) { try fileHandle.withUnsafeHandle { try PipePair.ignoreSIGPIPE(descriptor: $0) } @@ -61,7 +61,7 @@ final class PipePair: SocketProtocol { } var description: String { - return "PipePair { in=\(inputFD), out=\(outputFD) }" + return "PipePair { in=\(String(describing: inputFD)), out=\(String(describing: inputFD)) }" } func connect(to address: SocketAddress) throws -> Bool { @@ -73,19 +73,28 @@ final class PipePair: SocketProtocol { } func write(pointer: UnsafeRawBufferPointer) throws -> IOResult { - return try self.outputFD.withUnsafeHandle { + guard let outputFD = self.outputFD else { + fatalError("Internal inconsistency inside NIO. Please file a bug") + } + return try outputFD.withUnsafeHandle { try Posix.write(descriptor: $0, pointer: pointer.baseAddress!, size: pointer.count) } } func writev(iovecs: UnsafeBufferPointer) throws -> IOResult { - return try self.outputFD.withUnsafeHandle { + guard let outputFD = self.outputFD else { + fatalError("Internal inconsistency inside NIO. Please file a bug") + } + return try outputFD.withUnsafeHandle { try Posix.writev(descriptor: $0, iovecs: iovecs) } } func read(pointer: UnsafeMutableRawBufferPointer) throws -> IOResult { - return try self.inputFD.withUnsafeHandle { + guard let inputFD = self.inputFD else { + fatalError("Internal inconsistency inside NIO. Please file a bug") + } + return try inputFD.withUnsafeHandle { try Posix.read(descriptor: $0, pointer: pointer.baseAddress!, size: pointer.count) } } @@ -119,30 +128,30 @@ final class PipePair: SocketProtocol { func shutdown(how: Shutdown) throws { switch how { case .RD: - try self.inputFD.close() + try self.inputFD?.close() case .WR: - try self.outputFD.close() + try self.outputFD?.close() case .RDWR: try self.close() } } var isOpen: Bool { - return self.inputFD.isOpen || self.outputFD.isOpen + return self.inputFD?.isOpen ?? false || self.outputFD?.isOpen ?? false } func close() throws { - guard self.inputFD.isOpen || self.outputFD.isOpen else { + guard self.isOpen else { throw ChannelError.alreadyClosed } let r1 = Result { - if self.inputFD.isOpen { - try self.inputFD.close() + if let inputFD = self.inputFD, inputFD.isOpen { + try inputFD.close() } } let r2 = Result { - if self.outputFD.isOpen { - try self.outputFD.close() + if let outputFD = self.outputFD, outputFD.isOpen { + try outputFD.close() } } try r1.get() diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 527a2ccd08..2775b48d2d 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -620,86 +620,237 @@ final class AsyncChannelBootstrapTests: XCTestCase { func testPipeBootstrap() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let (pipe1ReadFH, pipe1WriteFH, pipe2ReadFH, pipe2WriteFH) = self.makePipeFileDescriptors() - let toChannel = FileHandle(fileDescriptor: pipe1WriteFH, closeOnDealloc: false) - let fromChannel = FileHandle(fileDescriptor: pipe2ReadFH, closeOnDealloc: false) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + let (pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD) = self.makePipeFileDescriptors() let channel: NIOAsyncChannel + let toChannel: NIOAsyncChannel + let fromChannel: NIOAsyncChannel do { channel = try await NIOPipeBootstrap(group: eventLoopGroup) .takingOwnershipOfDescriptors( - input: pipe1ReadFH, - output: pipe2WriteFH + input: pipe1ReadFD, + output: pipe2WriteFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(synchronouslyWrapping: channel) + } + } + } catch { + try [pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + do { + toChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + output: pipe1WriteFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(synchronouslyWrapping: channel) + } + } + } catch { + try [pipe1WriteFD, pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + do { + fromChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + input: pipe2ReadFD ) { channel in channel.eventLoop.makeCompletedFuture { try NIOAsyncChannel(synchronouslyWrapping: channel) } } } catch { - [pipe1ReadFH, pipe1WriteFH, pipe2ReadFH, pipe2WriteFH].forEach { try? SystemCalls.close(descriptor: $0) } + try [pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } throw error } var inboundIterator = channel.inbound.makeAsyncIterator() + var fromChannelInboundIterator = fromChannel.inbound.makeAsyncIterator() + + try await toChannel.outbound.write(.init(string: "Request")) + try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) + + let response = ByteBuffer(string: "Response") + try await channel.outbound.write(response) + try await XCTAsyncAssertEqual(try await fromChannelInboundIterator.next(), response) + } + + func testPipeBootstrap_whenInputNil() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + let (pipe1ReadFD, pipe1WriteFD) = self.makePipeFileDescriptors() + let channel: NIOAsyncChannel + let fromChannel: NIOAsyncChannel do { - try toChannel.writeBytes(.init(string: "Request")) - try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) + channel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + output: pipe1WriteFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(synchronouslyWrapping: channel) + } + } + } catch { + try [pipe1ReadFD, pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } - let response = ByteBuffer(string: "Response") - try await channel.outbound.write(response) - XCTAssertEqual(try fromChannel.readBytes(ofExactLength: response.readableBytes), Array(buffer: response)) + do { + fromChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + input: pipe1ReadFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(synchronouslyWrapping: channel) + } + } } catch { - // We only got to close the FDs that are not owned by the PipeChannel - [pipe1WriteFH, pipe2ReadFH].forEach { try? SystemCalls.close(descriptor: $0) } + try [pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } throw error } + + var inboundIterator = channel.inbound.makeAsyncIterator() + var fromChannelInboundIterator = fromChannel.inbound.makeAsyncIterator() + + try await XCTAsyncAssertEqual(try await inboundIterator.next(), nil) + + let response = ByteBuffer(string: "Response") + try await channel.outbound.write(response) + try await XCTAsyncAssertEqual(try await fromChannelInboundIterator.next(), response) + } + + func testPipeBootstrap_whenOutputNil() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + let (pipe1ReadFD, pipe1WriteFD) = self.makePipeFileDescriptors() + let channel: NIOAsyncChannel + let toChannel: NIOAsyncChannel + + do { + channel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + input: pipe1ReadFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(synchronouslyWrapping: channel) + } + } + } catch { + try [pipe1ReadFD, pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + do { + toChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + output: pipe1WriteFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(synchronouslyWrapping: channel) + } + } + } catch { + try [pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + var inboundIterator = channel.inbound.makeAsyncIterator() + + try await toChannel.outbound.write(.init(string: "Request")) + try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) + + let response = ByteBuffer(string: "Response") + await XCTAsyncAssertThrowsError(try await channel.outbound.write(response)) { error in + XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) + } } func testPipeBootstrap_withProtocolNegotiation() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let (pipe1ReadFH, pipe1WriteFH, pipe2ReadFH, pipe2WriteFH) = self.makePipeFileDescriptors() - let toChannel = FileHandle(fileDescriptor: pipe1WriteFH, closeOnDealloc: false) - let fromChannel = FileHandle(fileDescriptor: pipe2ReadFH, closeOnDealloc: false) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + let (pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD) = self.makePipeFileDescriptors() + let negotiationResult: EventLoopFuture + let toChannel: NIOAsyncChannel + let fromChannel: NIOAsyncChannel - try await withThrowingTaskGroup(of: EventLoopFuture.self) { group in - group.addTask { - do { - return try await NIOPipeBootstrap(group: eventLoopGroup) - .takingOwnershipOfDescriptors( - input: pipe1ReadFH, - output: pipe2WriteFH - ) { channel in - return channel.eventLoop.makeCompletedFuture { - return try self.configureProtocolNegotiationHandlers(channel: channel) - } - } - } catch { - [pipe1ReadFH, pipe1WriteFH, pipe2ReadFH, pipe2WriteFH].forEach { try? SystemCalls.close(descriptor: $0) } - throw error + do { + negotiationResult = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptors( + input: pipe1ReadFD, + output: pipe2WriteFD + ) { channel in + return channel.eventLoop.makeCompletedFuture { + return try self.configureProtocolNegotiationHandlers(channel: channel) + } } - } + } catch { + try [pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } - try toChannel.writeBytes(.init(string: "alpn:string\nHello\n")) - let negotiationResult = try await group.next() - switch try await negotiationResult?.get() { - case .string(let channel): - var inboundIterator = channel.inbound.makeAsyncIterator() - do { - try await XCTAsyncAssertEqual(try await inboundIterator.next(), "Hello") - - let response = ByteBuffer(string: "Response") - try await channel.outbound.write("Response") - XCTAssertEqual(try fromChannel.readBytes(ofExactLength: response.readableBytes), Array(buffer: response)) - } catch { - // We only got to close the FDs that are not owned by the PipeChannel - [pipe1WriteFH, pipe2ReadFH].forEach { try? SystemCalls.close(descriptor: $0) } - throw error + do { + toChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + output: pipe1WriteFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(synchronouslyWrapping: channel) + } } + } catch { + try [pipe1WriteFD, pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } - case .byte, nil: - fatalError() + do { + fromChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + input: pipe2ReadFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(synchronouslyWrapping: channel) + } + } + } catch { + try [pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + var fromChannelInboundIterator = fromChannel.inbound.makeAsyncIterator() + + try await toChannel.outbound.write(.init(string: "alpn:string\nHello\n")) + switch try await negotiationResult.get() { + case .string(let channel): + var inboundIterator = channel.inbound.makeAsyncIterator() + do { + try await XCTAsyncAssertEqual(try await inboundIterator.next(), "Hello") + + let expectedResponse = ByteBuffer(string: "Response\n") + try await channel.outbound.write("Response") + let response = try await fromChannelInboundIterator.next() + XCTAssertEqual(response, expectedResponse) + } catch { + // We only got to close the FDs that are not owned by the PipeChannel + [pipe1WriteFD, pipe2ReadFD].forEach { try? SystemCalls.close(descriptor: $0) } + throw error } + + case .byte: + fatalError() } } @@ -833,18 +984,28 @@ final class AsyncChannelBootstrapTests: XCTestCase { // MARK: - Test Helpers - private func makePipeFileDescriptors() -> (pipe1ReadFH: Int32, pipe1WriteFH: Int32, pipe2ReadFH: Int32, pipe2WriteFH: Int32) { - var pipe1FDs: [Int32] = [-1, -1] + private func makePipeFileDescriptors() -> (pipe1ReadFD: CInt, pipe1WriteFD: CInt, pipe2ReadFD: CInt, pipe2WriteFD: CInt) { + var pipe1FDs: [CInt] = [-1, -1] pipe1FDs.withUnsafeMutableBufferPointer { ptr in XCTAssertEqual(0, pipe(ptr.baseAddress!)) } - var pipe2FDs: [Int32] = [-1, -1] + var pipe2FDs: [CInt] = [-1, -1] pipe2FDs.withUnsafeMutableBufferPointer { ptr in XCTAssertEqual(0, pipe(ptr.baseAddress!)) } return (pipe1FDs[0], pipe1FDs[1], pipe2FDs[0], pipe2FDs[1]) } + private func makePipeFileDescriptors() -> (pipeReadFD: CInt, pipeWriteFD: CInt) { + var pipeFDs: [CInt] = [-1, -1] + pipeFDs.withUnsafeMutableBufferPointer { ptr in + XCTAssertEqual(0, pipe(ptr.baseAddress!)) + } + return (pipeFDs[0], pipeFDs[1]) + } + + + private func makeRawSocketServerChannel(eventLoopGroup: EventLoopGroup) async throws -> NIOAsyncChannel { try await NIORawSocketBootstrap(group: eventLoopGroup) .bind( From 54c85cb26308b89846d4671f23954dce088da2b0 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Wed, 25 Oct 2023 10:20:49 +0100 Subject: [PATCH 37/64] Fix thread-safety issues in TCPThroughputBenchmark (#2537) Motivation: Several thread-safety issues were missed in code review. This patch fixes them. Modifications: - Removed the use of an unstructured Task, replaced with eventLoop.execute to ServerHandler's EventLoop. - Stopped ClientHandler reaching into the benchmark object without any synchronization, used promises and event loop hops instead. Result: Thread safety is back --- .../TCPThroughputBenchmark.swift | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift b/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift index a1b88cabac..aaf3f8acb6 100644 --- a/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift +++ b/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift @@ -32,21 +32,21 @@ final class TCPThroughputBenchmark: Benchmark { private var clientChannel: Channel! private var message: ByteBuffer! - private var isDonePromise: EventLoopPromise! + private var serverEventLoop: EventLoop! final class ServerHandler: ChannelInboundHandler { public typealias InboundIn = ByteBuffer public typealias OutboundOut = ByteBuffer - private var channel: Channel! + private var context: ChannelHandlerContext! public func channelActive(context: ChannelHandlerContext) { - self.channel = context.channel + self.context = context } public func send(_ message: ByteBuffer, times count: Int) { for _ in 0..? - init(_ benchmark: TCPThroughputBenchmark) { - self.benchmark = benchmark + init() { self.messagesReceived = 0 } + func prepareRun(expectedMessages: Int, promise: EventLoopPromise) { + self.expectedMessages = expectedMessages + self.completionPromise = promise + } + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.messagesReceived += 1 - if (self.benchmark.messages == self.messagesReceived) { - self.benchmark.isDonePromise.succeed() + + if (self.expectedMessages == self.messagesReceived) { + let promise = self.completionPromise + self.messagesReceived = 0 + self.expectedMessages = nil + self.completionPromise = nil + + promise!.succeed() } } } @@ -95,12 +106,12 @@ final class TCPThroughputBenchmark: Benchmark { func setUp() throws { self.group = MultiThreadedEventLoopGroup(numberOfThreads: 4) - let connectionEstablished: EventLoopPromise = self.group.next().makePromise() + let connectionEstablished: EventLoopPromise = self.group.next().makePromise() self.serverChannel = try ServerBootstrap(group: self.group) .childChannelInitializer { channel in self.serverHandler = ServerHandler() - connectionEstablished.succeed() + connectionEstablished.succeed(channel.eventLoop) return channel.pipeline.addHandler(self.serverHandler) } .bind(host: "127.0.0.1", port: 0) @@ -109,7 +120,7 @@ final class TCPThroughputBenchmark: Benchmark { self.clientChannel = try ClientBootstrap(group: group) .channelInitializer { channel in channel.pipeline.addHandler(ByteToMessageHandler(StreamDecoder())).flatMap { _ in - channel.pipeline.addHandler(ClientHandler(self)) + channel.pipeline.addHandler(ClientHandler()) } } .connect(to: serverChannel.localAddress!) @@ -122,7 +133,7 @@ final class TCPThroughputBenchmark: Benchmark { } self.message = message - try connectionEstablished.futureResult.wait() + self.serverEventLoop = try connectionEstablished.futureResult.wait() } func tearDown() { @@ -132,12 +143,22 @@ final class TCPThroughputBenchmark: Benchmark { } func run() throws -> Int { - self.isDonePromise = self.group.next().makePromise() - Task { - self.serverHandler.send(self.message, times: self.messages) + let isDonePromise = self.clientChannel.eventLoop.makePromise(of: Void.self) + let clientChannel = self.clientChannel! + let expectedMessages = self.messages + + try clientChannel.eventLoop.submit { + try clientChannel.pipeline.syncOperations.handler(type: ClientHandler.self).prepareRun(expectedMessages: expectedMessages, promise: isDonePromise) + }.wait() + + let serverHandler = self.serverHandler! + let message = self.message! + let messages = self.messages + + self.serverEventLoop.execute { + serverHandler.send(message, times: messages) } - try self.isDonePromise.futureResult.wait() - self.isDonePromise = nil + try isDonePromise.futureResult.wait() return 0 } } From 95a4eaa0bcc324043b163136920dade083c995be Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 25 Oct 2023 16:56:00 +0100 Subject: [PATCH 38/64] Add async version of NIOThreadPool.runIfActive (#2566) * Add async version of NIOThreadPool.runIfActive * Changes from comments in PR Remove no longer supported code Add tests for errors being thrown, and the thread pool not being active * Collapse async runIfActive into one function * remove whitespace * Make T Sendable --------- Co-authored-by: Cory Benfield --- Sources/NIOPosix/NIOThreadPool.swift | 23 ++++++++++ Tests/NIOPosixTests/NIOThreadPoolTest.swift | 49 +++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/Sources/NIOPosix/NIOThreadPool.swift b/Sources/NIOPosix/NIOThreadPool.swift index e6b631d1aa..6d29801894 100644 --- a/Sources/NIOPosix/NIOThreadPool.swift +++ b/Sources/NIOPosix/NIOThreadPool.swift @@ -290,6 +290,29 @@ extension NIOThreadPool { } return promise.futureResult } + + /// Runs the submitted closure if the thread pool is still active, otherwise throw an error. + /// The closure will be run on the thread pool so can do blocking work. + /// + /// - parameters: + /// - body: The closure which performs some blocking work to be done on the thread pool. + /// - returns: result of the passed closure. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func runIfActive(_ body: @escaping @Sendable () throws -> T) async throws -> T { + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + self.submit { shouldRun in + guard case shouldRun = NIOThreadPool.WorkItemState.active else { + cont.resume(throwing: NIOThreadPoolError.ThreadPoolInactive()) + return + } + do { + try cont.resume(returning: body()) + } catch { + cont.resume(throwing: error) + } + } + } + } } extension NIOThreadPool { diff --git a/Tests/NIOPosixTests/NIOThreadPoolTest.swift b/Tests/NIOPosixTests/NIOThreadPoolTest.swift index a36e4794ac..b51a96ad47 100644 --- a/Tests/NIOPosixTests/NIOThreadPoolTest.swift +++ b/Tests/NIOPosixTests/NIOThreadPoolTest.swift @@ -14,6 +14,7 @@ import XCTest @testable import NIOPosix +import Atomics import Dispatch import NIOConcurrencyHelpers import NIOEmbedded @@ -110,6 +111,54 @@ class NIOThreadPoolTest: XCTestCase { } } + func testAsyncThreadPool() async throws { + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } + let numberOfThreads = 1 + let pool = NIOThreadPool(numberOfThreads: numberOfThreads) + pool.start() + do { + let hitCount = ManagedAtomic(false) + try await pool.runIfActive { + hitCount.store(true, ordering: .relaxed) + } + XCTAssertEqual(hitCount.load(ordering: .relaxed), true) + } catch {} + try await pool.shutdownGracefully() + } + + func testAsyncThreadPoolErrorPropagation() async throws { + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } + struct ThreadPoolError: Error {} + let numberOfThreads = 1 + let pool = NIOThreadPool(numberOfThreads: numberOfThreads) + pool.start() + do { + try await pool.runIfActive { + throw ThreadPoolError() + } + XCTFail("Should not get here as closure sent to runIfActive threw an error") + } catch { + XCTAssertNotNil(error as? ThreadPoolError, "Error thrown should be of type ThreadPoolError") + } + try await pool.shutdownGracefully() + } + + func testAsyncThreadPoolNotActiveError() async throws { + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } + struct ThreadPoolError: Error {} + let numberOfThreads = 1 + let pool = NIOThreadPool(numberOfThreads: numberOfThreads) + do { + try await pool.runIfActive { + throw ThreadPoolError() + } + XCTFail("Should not get here as thread pool isn't active") + } catch { + XCTAssertNotNil(error as? NIOThreadPoolError.ThreadPoolInactive, "Error thrown should be of type ThreadPoolError") + } + try await pool.shutdownGracefully() + } + func testAsyncShutdownWorks() async throws { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let threadPool = NIOThreadPool(numberOfThreads: 17) From 740fc734f3266e8e374817cea0db48d2da807007 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Thu, 26 Oct 2023 08:03:51 +0100 Subject: [PATCH 39/64] Fix concurrency doc APIs (#2575) Our Concurrency doc article got a bit outdated when we renamed `inboundStream` and `outboundWriter` to `inbound` and `outbound` --- Sources/NIOCore/Docs.docc/swift-concurrency.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Sources/NIOCore/Docs.docc/swift-concurrency.md b/Sources/NIOCore/Docs.docc/swift-concurrency.md index 0a5b58f47c..c8ddaadd86 100644 --- a/Sources/NIOCore/Docs.docc/swift-concurrency.md +++ b/Sources/NIOCore/Docs.docc/swift-concurrency.md @@ -87,8 +87,8 @@ the inbound data and echo it back outbound. let channel = ... let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) -for try await inboundData in asyncChannel.inboundStream { - try await asyncChannel.outboundWriter.write(inboundData) +for try await inboundData in asyncChannel.inbound { + try await asyncChannel.outbound.write(inboundData) } ``` @@ -137,12 +137,12 @@ let serverChannel = try await ServerBootstrap(group: eventLoopGroup) } try await withThrowingDiscardingTaskGroup { group in - for try await connectionChannel in serverChannel.inboundStream { + for try await connectionChannel in serverChannel.inbound { group.addTask { do { - for try await inboundData in connectionChannel.inboundStream { + for try await inboundData in connectionChannel.inbound { // Let's echo back all inbound data - try await connectionChannel.outboundWriter.write(inboundData) + try await connectionChannel.outbound.write(inboundData) } } catch { // Handle errors @@ -185,9 +185,9 @@ let clientChannel = try await ClientBootstrap(group: eventLoopGroup) } } -clientChannel.outboundWriter.write(ByteBuffer(string: "hello")) +clientChannel.outbound.write(ByteBuffer(string: "hello")) -for try await inboundData in clientChannel.inboundStream { +for try await inboundData in clientChannel.inbound { print(inboundData) } ``` From 8c238f2cc18b4b2883caacac9c95e852f67e5cdb Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Fri, 27 Oct 2023 15:48:23 +0100 Subject: [PATCH 40/64] Back out new typed HTTP protocol upgrader (#2579) # Motivation We got reports in https://github.com/apple/swift-nio/issues/2574 that our new typed HTTP upgrader are hitting a Swift compiler bug which manifests in a runtime crash on older iOS/macOS/etc. # Modification This PR backs out the new typed HTTP protocol upgrader APIs so that we can unblock our users until the Swift compiler bug is fixed. # Result No more crashes for our users. --- Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift | 248 --------- .../NIOTypedHTTPClientUpgradeHandler.swift | 283 ---------- ...OTypedHTTPClientUpgraderStateMachine.swift | 334 ----------- .../NIOTypedHTTPServerUpgradeHandler.swift | 369 ------------- ...OTypedHTTPServerUpgraderStateMachine.swift | 385 ------------- Sources/NIOTCPEchoClient/Client.swift | 2 +- Sources/NIOTCPEchoServer/Server.swift | 2 +- .../NIOWebSocketClientUpgrader.swift | 56 -- .../NIOWebSocketServerUpgrader.swift | 84 --- Sources/NIOWebSocketClient/Client.swift | 241 ++++---- Sources/NIOWebSocketServer/Server.swift | 461 ++++++++-------- .../HTTPClientUpgradeTests.swift | 235 +------- .../HTTPServerUpgradeTests.swift | 518 +----------------- .../WebSocketClientEndToEndTests.swift | 211 ------- .../WebSocketServerEndToEndTests.swift | 27 - 15 files changed, 367 insertions(+), 3089 deletions(-) delete mode 100644 Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift delete mode 100644 Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift delete mode 100644 Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift delete mode 100644 Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift delete mode 100644 Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift diff --git a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift deleted file mode 100644 index 57fe5fd780..0000000000 --- a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift +++ /dev/null @@ -1,248 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -import NIOCore - -// MARK: - Server pipeline configuration - -/// Configuration for an upgradable HTTP pipeline. -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -public struct NIOUpgradableHTTPServerPipelineConfiguration { - /// Whether to provide assistance handling HTTP clients that pipeline - /// their requests. Defaults to `true`. If `false`, users will need to handle clients that pipeline themselves. - public var enablePipelining = true - - /// Whether to provide assistance handling protocol errors (e.g. failure to parse the HTTP - /// request) by sending 400 errors. Defaults to `true`. - public var enableErrorHandling = true - - /// Whether to validate outbound response headers to confirm that they are - /// spec compliant. Defaults to `true`. - public var enableResponseHeaderValidation = true - - /// The configuration for the ``HTTPResponseEncoder``. - public var encoderConfiguration = HTTPResponseEncoder.Configuration() - - /// The configuration for the ``NIOTypedHTTPServerUpgradeHandler``. - public var upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration - - /// Initializes a new ``NIOUpgradableHTTPServerPipelineConfiguration`` with default values. - /// - /// The current defaults provide the following features: - /// 1. Assistance handling clients that pipeline HTTP requests. - /// 2. Assistance handling protocol errors. - /// 3. Outbound header fields validation to protect against response splitting attacks. - public init( - upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration - ) { - self.upgradeConfiguration = upgradeConfiguration - } -} - -extension ChannelPipeline { - /// Configure a `ChannelPipeline` for use as an HTTP server. - /// - /// - Parameters: - /// - configuration: The HTTP pipeline's configuration. - /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` - /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. - @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - public func configureUpgradableHTTPServerPipeline( - configuration: NIOUpgradableHTTPServerPipelineConfiguration - ) -> EventLoopFuture> { - self._configureUpgradableHTTPServerPipeline( - configuration: configuration - ) - } - - @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - private func _configureUpgradableHTTPServerPipeline( - configuration: NIOUpgradableHTTPServerPipelineConfiguration - ) -> EventLoopFuture> { - let future: EventLoopFuture> - - if self.eventLoop.inEventLoop { - let result = Result, Error> { - try self.syncOperations.configureUpgradableHTTPServerPipeline( - configuration: configuration - ) - } - future = self.eventLoop.makeCompletedFuture(result) - } else { - future = self.eventLoop.submit { - try self.syncOperations.configureUpgradableHTTPServerPipeline( - configuration: configuration - ) - } - } - - return future - } -} - -extension ChannelPipeline.SynchronousOperations { - /// Configure a `ChannelPipeline` for use as an HTTP server. - /// - /// - Parameters: - /// - configuration: The HTTP pipeline's configuration. - /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. - @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - public func configureUpgradableHTTPServerPipeline( - configuration: NIOUpgradableHTTPServerPipelineConfiguration - ) throws -> EventLoopFuture { - self.eventLoop.assertInEventLoop() - - let responseEncoder = HTTPResponseEncoder(configuration: configuration.encoderConfiguration) - let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) - - var extraHTTPHandlers = [RemovableChannelHandler]() - extraHTTPHandlers.reserveCapacity(4) - extraHTTPHandlers.append(requestDecoder) - - try self.addHandler(responseEncoder) - try self.addHandler(requestDecoder) - - if configuration.enablePipelining { - let pipeliningHandler = HTTPServerPipelineHandler() - try self.addHandler(pipeliningHandler) - extraHTTPHandlers.append(pipeliningHandler) - } - - if configuration.enableResponseHeaderValidation { - let headerValidationHandler = NIOHTTPResponseHeadersValidator() - try self.addHandler(headerValidationHandler) - extraHTTPHandlers.append(headerValidationHandler) - } - - if configuration.enableErrorHandling { - let errorHandler = HTTPServerProtocolErrorHandler() - try self.addHandler(errorHandler) - extraHTTPHandlers.append(errorHandler) - } - - let upgrader = NIOTypedHTTPServerUpgradeHandler( - httpEncoder: responseEncoder, - extraHTTPHandlers: extraHTTPHandlers, - upgradeConfiguration: configuration.upgradeConfiguration - ) - try self.addHandler(upgrader) - - return upgrader.upgradeResultFuture - } -} - -// MARK: - Client pipeline configuration - -/// Configuration for an upgradable HTTP pipeline. -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -public struct NIOUpgradableHTTPClientPipelineConfiguration { - /// The strategy to use when dealing with leftover bytes after removing the ``HTTPDecoder`` from the pipeline. - public var leftOverBytesStrategy = RemoveAfterUpgradeStrategy.dropBytes - - /// Whether to validate outbound response headers to confirm that they are - /// spec compliant. Defaults to `true`. - public var enableOutboundHeaderValidation = true - - /// The configuration for the ``HTTPRequestEncoder``. - public var encoderConfiguration = HTTPRequestEncoder.Configuration() - - /// The configuration for the ``NIOTypedHTTPClientUpgradeHandler``. - public var upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration - - /// Initializes a new ``NIOUpgradableHTTPClientPipelineConfiguration`` with default values. - /// - /// The current defaults provide the following features: - /// 1. Outbound header fields validation to protect against response splitting attacks. - public init( - upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration - ) { - self.upgradeConfiguration = upgradeConfiguration - } -} - -extension ChannelPipeline { - /// Configure a `ChannelPipeline` for use as an HTTP client. - /// - /// - Parameters: - /// - configuration: The HTTP pipeline's configuration. - /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` - /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. - @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - public func configureUpgradableHTTPClientPipeline( - configuration: NIOUpgradableHTTPClientPipelineConfiguration - ) -> EventLoopFuture> { - self._configureUpgradableHTTPClientPipeline(configuration: configuration) - } - - @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - private func _configureUpgradableHTTPClientPipeline( - configuration: NIOUpgradableHTTPClientPipelineConfiguration - ) -> EventLoopFuture> { - let future: EventLoopFuture> - - if self.eventLoop.inEventLoop { - let result = Result, Error> { - try self.syncOperations.configureUpgradableHTTPClientPipeline( - configuration: configuration - ) - } - future = self.eventLoop.makeCompletedFuture(result) - } else { - future = self.eventLoop.submit { - try self.syncOperations.configureUpgradableHTTPClientPipeline( - configuration: configuration - ) - } - } - - return future - } -} - -extension ChannelPipeline.SynchronousOperations { - /// Configure a `ChannelPipeline` for use as an HTTP client. - /// - /// - Parameters: - /// - configuration: The HTTP pipeline's configuration. - /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. - @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) - public func configureUpgradableHTTPClientPipeline( - configuration: NIOUpgradableHTTPClientPipelineConfiguration - ) throws -> EventLoopFuture { - self.eventLoop.assertInEventLoop() - - let requestEncoder = HTTPRequestEncoder(configuration: configuration.encoderConfiguration) - let responseDecoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: configuration.leftOverBytesStrategy)) - var httpHandlers = [RemovableChannelHandler]() - httpHandlers.reserveCapacity(3) - httpHandlers.append(requestEncoder) - httpHandlers.append(responseDecoder) - - try self.addHandler(requestEncoder) - try self.addHandler(responseDecoder) - - if configuration.enableOutboundHeaderValidation { - let headerValidationHandler = NIOHTTPRequestHeadersValidator() - try self.addHandler(headerValidationHandler) - httpHandlers.append(headerValidationHandler) - } - - let upgrader = NIOTypedHTTPClientUpgradeHandler( - httpHandlers: httpHandlers, - upgradeConfiguration: configuration.upgradeConfiguration - ) - try self.addHandler(upgrader) - - return upgrader.upgradeResultFuture - } -} diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift deleted file mode 100644 index f5a2f505ec..0000000000 --- a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift +++ /dev/null @@ -1,283 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2013 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -import NIOCore - -/// An object that implements `NIOTypedHTTPClientProtocolUpgrader` knows how to handle HTTP upgrade to -/// a protocol on a client-side channel. -/// It has the option of denying this upgrade based upon the server response. -public protocol NIOTypedHTTPClientProtocolUpgrader { - associatedtype UpgradeResult: Sendable - - /// The protocol this upgrader knows how to support. - var supportedProtocol: String { get } - - /// All the header fields the protocol requires in the request to successfully upgrade. - /// These header fields will be added to the outbound request's "Connection" header field. - /// It is the responsibility of the custom headers call to actually add these required headers. - var requiredUpgradeHeaders: [String] { get } - - /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. - func addCustom(upgradeRequestHeaders: inout HTTPHeaders) - - /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. - func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool - - /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel - /// pipeline to add whatever channel handlers are required. - /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. - func upgrade(channel: Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture -} - -/// The upgrade configuration for the ``NIOTypedHTTPClientUpgradeHandler``. -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -public struct NIOTypedHTTPClientUpgradeConfiguration { - /// The initial request head that is sent out once the channel becomes active. - public var upgradeRequestHead: HTTPRequestHead - - /// The array of potential upgraders. - public var upgraders: [any NIOTypedHTTPClientProtocolUpgrader] - - /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used - /// to configure handlers that expect HTTP. - public var notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture - - public init( - upgradeRequestHead: HTTPRequestHead, - upgraders: [any NIOTypedHTTPClientProtocolUpgrader], - notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture - ) { - precondition(upgraders.count > 0, "A minimum of one protocol upgrader must be specified.") - self.upgradeRequestHead = upgradeRequestHead - self.upgraders = upgraders - self.notUpgradingCompletionHandler = notUpgradingCompletionHandler - } -} - -/// A client-side channel handler that sends a HTTP upgrade handshake request to perform a HTTP-upgrade. -/// This handler will add all appropriate headers to perform an upgrade to -/// the a protocol. It may add headers for a set of protocols in preference order. -/// If the upgrade fails (i.e. response is not 101 Switching Protocols), this handler simply -/// removes itself from the pipeline. If the upgrade is successful, it upgrades the pipeline to the new protocol. -/// -/// The request sends an order of preference to request which protocol it would like to use for the upgrade. -/// It will only upgrade to the protocol that is returned first in the list and does not currently -/// have the capability to upgrade to multiple simultaneous layered protocols. -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -public final class NIOTypedHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableChannelHandler { - public typealias OutboundIn = HTTPClientRequestPart - public typealias OutboundOut = HTTPClientRequestPart - public typealias InboundIn = HTTPClientResponsePart - public typealias InboundOut = HTTPClientResponsePart - - /// The upgrade future which will be completed once protocol upgrading has been done. - public var upgradeResultFuture: EventLoopFuture { - self.upgradeResultPromise.futureResult - } - - private let upgradeRequestHead: HTTPRequestHead - private let httpHandlers: [RemovableChannelHandler] - private let notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture - private var stateMachine: NIOTypedHTTPClientUpgraderStateMachine - private var _upgradeResultPromise: EventLoopPromise? - private var upgradeResultPromise: EventLoopPromise { - precondition( - self._upgradeResultPromise != nil, - "Tried to access the upgrade result before the handler was added to a pipeline" - ) - return self._upgradeResultPromise! - } - - /// Create a ``NIOTypedHTTPClientUpgradeHandler``. - /// - /// - Parameters: - /// - httpHandlers: All `RemovableChannelHandler` objects which will be removed from the pipeline - /// once the upgrade response is sent. This is used to ensure that the pipeline will be in a clean state - /// after the upgrade. It should include any handlers that are directly related to handling HTTP. - /// At the very least this should include the `HTTPEncoder` and `HTTPDecoder`, but should also include - /// any other handler that cannot tolerate receiving non-HTTP data. - /// - upgradeConfiguration: The upgrade configuration. - public init( - httpHandlers: [RemovableChannelHandler], - upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration - ) { - self.httpHandlers = httpHandlers - var upgradeRequestHead = upgradeConfiguration.upgradeRequestHead - Self.addHeaders( - to: &upgradeRequestHead, - upgraders: upgradeConfiguration.upgraders - ) - self.upgradeRequestHead = upgradeRequestHead - self.stateMachine = .init(upgraders: upgradeConfiguration.upgraders) - self.notUpgradingCompletionHandler = upgradeConfiguration.notUpgradingCompletionHandler - } - - public func handlerAdded(context: ChannelHandlerContext) { - self._upgradeResultPromise = context.eventLoop.makePromise(of: UpgradeResult.self) - } - - public func handlerRemoved(context: ChannelHandlerContext) { - switch self.stateMachine.handlerRemoved() { - case .failUpgradePromise: - self.upgradeResultPromise.fail(ChannelError.inappropriateOperationForState) - case .none: - break - } - } - - public func channelActive(context: ChannelHandlerContext) { - switch self.stateMachine.channelActive() { - case .writeUpgradeRequest: - context.write(self.wrapOutboundOut(.head(self.upgradeRequestHead)), promise: nil) - context.write(self.wrapOutboundOut(.body(.byteBuffer(.init()))), promise: nil) - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) - - case .none: - break - } - } - - private static func addHeaders( - to requestHead: inout HTTPRequestHead, - upgraders: [any NIOTypedHTTPClientProtocolUpgrader] - ) { - let requiredHeaders = ["upgrade"] + upgraders.flatMap { $0.requiredUpgradeHeaders } - requestHead.headers.add(name: "Connection", value: requiredHeaders.joined(separator: ",")) - - let allProtocols = upgraders.map { $0.supportedProtocol.lowercased() } - requestHead.headers.add(name: "Upgrade", value: allProtocols.joined(separator: ",")) - - // Allow each upgrader the chance to add custom headers. - for upgrader in upgraders { - upgrader.addCustom(upgradeRequestHeaders: &requestHead.headers) - } - } - - public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - switch self.stateMachine.write() { - case .failWrite(let error): - promise?.fail(error) - - case .forwardWrite: - context.write(data, promise: promise) - } - } - - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - switch self.stateMachine.channelReadData(data) { - case .unwrapData: - let responsePart = self.unwrapInboundIn(data) - self.channelRead(context: context, responsePart: responsePart) - - case .fireChannelRead: - context.fireChannelRead(data) - - case .none: - break - } - } - - private func channelRead(context: ChannelHandlerContext, responsePart: HTTPClientResponsePart) { - switch self.stateMachine.channelReadResponsePart(responsePart) { - case .fireErrorCaughtAndRemoveHandler(let error): - self.upgradeResultPromise.fail(error) - context.fireErrorCaught(error) - context.pipeline.removeHandler(self, promise: nil) - - case .runNotUpgradingInitializer: - self.notUpgradingCompletionHandler(context.channel) - .hop(to: context.eventLoop) - .whenComplete { result in - self.upgradingHandlerCompleted(context: context, result) - } - - case .startUpgrading(let upgrader, let responseHead): - self.startUpgrading( - context: context, - upgrader: upgrader, - responseHead: responseHead - ) - - case .none: - break - } - } - - private func startUpgrading( - context: ChannelHandlerContext, - upgrader: any NIOTypedHTTPClientProtocolUpgrader, - responseHead: HTTPResponseHead - ) { - // Before we start the upgrade we have to remove the HTTPEncoder and HTTPDecoder handlers from the - // pipeline, to prevent them parsing any more data. We'll buffer the incoming data until that completes. - self.removeHTTPHandlers(context: context) - .flatMap { - upgrader.upgrade(channel: context.channel, upgradeResponse: responseHead) - }.hop(to: context.eventLoop) - .whenComplete { result in - self.upgradingHandlerCompleted(context: context, result) - } - } - - private func upgradingHandlerCompleted( - context: ChannelHandlerContext, - _ result: Result - ) { - switch self.stateMachine.upgradingHandlerCompleted(result) { - case .fireErrorCaughtAndRemoveHandler(let error): - self.upgradeResultPromise.fail(error) - context.fireErrorCaught(error) - context.pipeline.removeHandler(self, promise: nil) - - case .fireErrorCaughtAndStartUnbuffering(let error): - self.upgradeResultPromise.fail(error) - context.fireErrorCaught(error) - self.unbuffer(context: context) - - case .startUnbuffering(let value): - self.upgradeResultPromise.succeed(value) - self.unbuffer(context: context) - - case .removeHandler(let value): - self.upgradeResultPromise.succeed(value) - context.pipeline.removeHandler(self, promise: nil) - - case .none: - break - } - } - - private func unbuffer(context: ChannelHandlerContext) { - while true { - switch self.stateMachine.unbuffer() { - case .fireChannelRead(let data): - context.fireChannelRead(data) - - case .fireChannelReadCompleteAndRemoveHandler: - context.fireChannelReadComplete() - context.pipeline.removeHandler(self, promise: nil) - return - } - } - } - - /// Removes any extra HTTP-related handlers from the channel pipeline. - private func removeHTTPHandlers(context: ChannelHandlerContext) -> EventLoopFuture { - guard self.httpHandlers.count > 0 else { - return context.eventLoop.makeSucceededFuture(()) - } - - let removeFutures = self.httpHandlers.map { context.pipeline.removeHandler($0) } - return .andAllSucceed(removeFutures, on: context.eventLoop) - } -} diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift deleted file mode 100644 index fa04481ea9..0000000000 --- a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift +++ /dev/null @@ -1,334 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import DequeModule -import NIOCore - -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -struct NIOTypedHTTPClientUpgraderStateMachine { - @usableFromInline - enum State { - /// The state before we received a TLSUserEvent. We are just forwarding any read at this point. - case initial(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) - - /// The request has been sent. We are waiting for the upgrade response. - case awaitingUpgradeResponseHead(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) - - @usableFromInline - struct AwaitingUpgradeResponseEnd { - var upgrader: any NIOTypedHTTPClientProtocolUpgrader - var responseHead: HTTPResponseHead - } - /// We received the response head and are just waiting for the response end. - case awaitingUpgradeResponseEnd(AwaitingUpgradeResponseEnd) - - @usableFromInline - struct Upgrading { - var buffer: Deque - } - /// We are either running the upgrading handler. - case upgrading(Upgrading) - - @usableFromInline - struct Unbuffering { - var buffer: Deque - } - case unbuffering(Unbuffering) - - case finished - - case modifying - } - - private var state: State - - init(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) { - self.state = .initial(upgraders: upgraders) - } - - @usableFromInline - enum HandlerRemovedAction { - case failUpgradePromise - } - - @inlinable - mutating func handlerRemoved() -> HandlerRemovedAction? { - switch self.state { - case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading, .unbuffering: - self.state = .finished - return .failUpgradePromise - - case .finished: - return .none - - case .modifying: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - } - } - - @usableFromInline - enum ChannelActiveAction { - case writeUpgradeRequest - } - - @inlinable - mutating func channelActive() -> ChannelActiveAction? { - switch self.state { - case .initial(let upgraders): - self.state = .awaitingUpgradeResponseHead(upgraders: upgraders) - return .writeUpgradeRequest - - case .finished: - return nil - - case .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .unbuffering, .upgrading: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - - case .modifying: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - } - } - - @usableFromInline - enum WriteAction { - case failWrite(Error) - case forwardWrite - } - - @usableFromInline - func write() -> WriteAction { - switch self.state { - case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading: - return .failWrite(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade) - - case .unbuffering, .finished: - return .forwardWrite - - case .modifying: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - } - } - - @usableFromInline - enum ChannelReadDataAction { - case unwrapData - case fireChannelRead - } - - @inlinable - mutating func channelReadData(_ data: NIOAny) -> ChannelReadDataAction? { - switch self.state { - case .initial: - return .unwrapData - - case .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd: - return .unwrapData - - case .upgrading(var upgrading): - // We got a read while running upgrading. - // We have to buffer the read to unbuffer it afterwards - self.state = .modifying - upgrading.buffer.append(data) - self.state = .upgrading(upgrading) - return nil - - case .unbuffering(var unbuffering): - self.state = .modifying - unbuffering.buffer.append(data) - self.state = .unbuffering(unbuffering) - return nil - - case .finished: - return .fireChannelRead - - case .modifying: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - } - } - - - @usableFromInline - enum ChannelReadResponsePartAction { - case fireErrorCaughtAndRemoveHandler(Error) - case runNotUpgradingInitializer - case startUpgrading( - upgrader: any NIOTypedHTTPClientProtocolUpgrader, - responseHeaders: HTTPResponseHead - ) - } - - @inlinable - mutating func channelReadResponsePart(_ responsePart: HTTPClientResponsePart) -> ChannelReadResponsePartAction? { - switch self.state { - case .initial: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - - case .awaitingUpgradeResponseHead(let upgraders): - // We should decide if we can upgrade based on the first response header: if we aren't upgrading, - // by the time the body comes in we should be out of the pipeline. That means that if we don't think we're - // upgrading, the only thing we should see is a response head. Anything else in an error. - guard case .head(let response) = responsePart else { - self.state = .finished - return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.invalidHTTPOrdering) - } - - // Assess whether the server has accepted our upgrade request. - guard case .switchingProtocols = response.status else { - var buffer = Deque() - buffer.append(.init(responsePart)) - self.state = .upgrading(.init(buffer: buffer)) - return .runNotUpgradingInitializer - } - - // Ok, we have a HTTP response. Check if it's an upgrade confirmation. - // If it's not, we want to pass it on and remove ourselves from the channel pipeline. - let acceptedProtocols = response.headers[canonicalForm: "upgrade"] - - // At the moment we only upgrade to the first protocol returned from the server. - guard let protocolName = acceptedProtocols.first?.lowercased() else { - // There are no upgrade protocols returned. - self.state = .finished - return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.responseProtocolNotFound) - } - - let matchingUpgrader = upgraders - .first(where: { $0.supportedProtocol.lowercased() == protocolName }) - - guard let upgrader = matchingUpgrader else { - // There is no upgrader for this protocol. - self.state = .finished - return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.responseProtocolNotFound) - } - - guard upgrader.shouldAllowUpgrade(upgradeResponse: response) else { - // The upgrader says no. - self.state = .finished - return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) - } - - // We received the response head and decided that we can upgrade. - // We now need to wait for the response end and then we can perform the upgrade - self.state = .awaitingUpgradeResponseEnd(.init( - upgrader: upgrader, - responseHead: response - )) - return .none - - case .awaitingUpgradeResponseEnd(let awaitingUpgradeResponseEnd): - switch responsePart { - case .head: - // We got two HTTP response heads. - self.state = .finished - return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.invalidHTTPOrdering) - - case .body: - // We tolerate body parts to be send but just ignore them - return .none - - case .end: - // We got the response end and can now run the upgrader. - self.state = .upgrading(.init(buffer: .init())) - return .startUpgrading( - upgrader: awaitingUpgradeResponseEnd.upgrader, - responseHeaders: awaitingUpgradeResponseEnd.responseHead - ) - } - - case .upgrading, .unbuffering, .finished: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - - - case .modifying: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - } - } - - @usableFromInline - enum UpgradingHandlerCompletedAction { - case fireErrorCaughtAndStartUnbuffering(Error) - case removeHandler(UpgradeResult) - case fireErrorCaughtAndRemoveHandler(Error) - case startUnbuffering(UpgradeResult) - } - - @inlinable - mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { - switch self.state { - case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .unbuffering: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - - case .upgrading(let upgrading): - switch result { - case .success(let value): - if !upgrading.buffer.isEmpty { - self.state = .unbuffering(.init(buffer: upgrading.buffer)) - return .startUnbuffering(value) - } else { - self.state = .finished - return .removeHandler(value) - } - - case .failure(let error): - if !upgrading.buffer.isEmpty { - // So we failed to upgrade. There is nothing really that we can do here. - // We are unbuffering the reads but there shouldn't be any handler in the pipeline - // that expects a specific type of reads anyhow. - self.state = .unbuffering(.init(buffer: upgrading.buffer)) - return .fireErrorCaughtAndStartUnbuffering(error) - } else { - self.state = .finished - return .fireErrorCaughtAndRemoveHandler(error) - } - } - - case .finished: - // We have to tolerate this - return nil - - case .modifying: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - } - } - - @usableFromInline - enum UnbufferAction { - case fireChannelRead(NIOAny) - case fireChannelReadCompleteAndRemoveHandler - } - - @inlinable - mutating func unbuffer() -> UnbufferAction { - switch self.state { - case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading, .finished: - preconditionFailure("Invalid state \(self.state)") - - case .unbuffering(var unbuffering): - self.state = .modifying - - if let element = unbuffering.buffer.popFirst() { - self.state = .unbuffering(unbuffering) - - return .fireChannelRead(element) - } else { - self.state = .finished - - return .fireChannelReadCompleteAndRemoveHandler - } - - case .modifying: - fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - - } - } -} diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift deleted file mode 100644 index 55b21e5982..0000000000 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift +++ /dev/null @@ -1,369 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -import NIOCore - -/// An object that implements `NIOTypedHTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to -/// a protocol on a server-side channel. -public protocol NIOTypedHTTPServerProtocolUpgrader { - associatedtype UpgradeResult: Sendable - - /// The protocol this upgrader knows how to support. - var supportedProtocol: String { get } - - /// All the header fields the protocol needs in the request to successfully upgrade. These header fields - /// will be provided to the handler when it is asked to handle the upgrade. They will also be validated - /// against the inbound request's `Connection` header field. - var requiredUpgradeHeaders: [String] { get } - - /// Builds the upgrade response headers. Should return any headers that need to be supplied to the client - /// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should - /// return a failed future. - func buildUpgradeResponse( - channel: Channel, - upgradeRequest: HTTPRequestHead, - initialResponseHeaders: HTTPHeaders - ) -> EventLoopFuture - - /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline - /// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received - /// data will be buffered. - func upgrade( - channel: Channel, - upgradeRequest: HTTPRequestHead - ) -> EventLoopFuture -} - -/// The upgrade configuration for the ``NIOTypedHTTPServerUpgradeHandler``. -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -public struct NIOTypedHTTPServerUpgradeConfiguration { - /// The array of potential upgraders. - public var upgraders: [any NIOTypedHTTPServerProtocolUpgrader] - - /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used - /// to configure handlers that expect HTTP. - public var notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture - - public init( - upgraders: [any NIOTypedHTTPServerProtocolUpgrader], - notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture - ) { - self.upgraders = upgraders - self.notUpgradingCompletionHandler = notUpgradingCompletionHandler - } -} - -/// A server-side channel handler that receives HTTP requests and optionally performs an HTTP-upgrade. -/// -/// Removes itself from the channel pipeline after the first inbound request on the connection, regardless of -/// whether the upgrade succeeded or not. -/// -/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade -/// on any request but the first on a connection. This is primarily to handle clients that pipeline: it's -/// sufficiently difficult to ensure that the upgrade happens at a safe time while dealing with pipelined -/// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine: -/// the odds of someone needing to upgrade midway through the lifetime of a connection are very low. -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -public final class NIOTypedHTTPServerUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler { - public typealias InboundIn = HTTPServerRequestPart - public typealias InboundOut = HTTPServerRequestPart - public typealias OutboundOut = HTTPServerResponsePart - - private let upgraders: [String: any NIOTypedHTTPServerProtocolUpgrader] - private let notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture - private let httpEncoder: HTTPResponseEncoder - private let extraHTTPHandlers: [RemovableChannelHandler] - private var stateMachine = NIOTypedHTTPServerUpgraderStateMachine() - - private var _upgradeResultPromise: EventLoopPromise? - private var upgradeResultPromise: EventLoopPromise { - precondition( - self._upgradeResultPromise != nil, - "Tried to access the upgrade result before the handler was added to a pipeline" - ) - return self._upgradeResultPromise! - } - - /// The upgrade future which will be completed once protocol upgrading has been done. - public var upgradeResultFuture: EventLoopFuture { - self.upgradeResultPromise.futureResult - } - - /// Create a ``NIOTypedHTTPServerUpgradeHandler``. - /// - /// - Parameters: - /// - httpEncoder: The ``HTTPResponseEncoder`` encoding responses from this handler and which will - /// be removed from the pipeline once the upgrade response is sent. This is used to ensure - /// that the pipeline will be in a clean state after upgrade. - /// - extraHTTPHandlers: Any other handlers that are directly related to handling HTTP. At the very least - /// this should include the `HTTPDecoder`, but should also include any other handler that cannot tolerate - /// receiving non-HTTP data. - /// - upgradeConfiguration: The upgrade configuration. - public init( - httpEncoder: HTTPResponseEncoder, - extraHTTPHandlers: [RemovableChannelHandler], - upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration - ) { - var upgraderMap = [String: any NIOTypedHTTPServerProtocolUpgrader]() - for upgrader in upgradeConfiguration.upgraders { - upgraderMap[upgrader.supportedProtocol.lowercased()] = upgrader - } - self.upgraders = upgraderMap - self.notUpgradingCompletionHandler = upgradeConfiguration.notUpgradingCompletionHandler - self.httpEncoder = httpEncoder - self.extraHTTPHandlers = extraHTTPHandlers - } - - public func handlerAdded(context: ChannelHandlerContext) { - self._upgradeResultPromise = context.eventLoop.makePromise(of: UpgradeResult.self) - } - - public func handlerRemoved(context: ChannelHandlerContext) { - switch self.stateMachine.handlerRemoved() { - case .failUpgradePromise: - self.upgradeResultPromise.fail(ChannelError.inappropriateOperationForState) - case .none: - break - } - } - - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - switch self.stateMachine.channelReadData(data) { - case .unwrapData: - let requestPart = self.unwrapInboundIn(data) - self.channelRead(context: context, requestPart: requestPart) - - case .fireChannelRead: - context.fireChannelRead(data) - - case .none: - break - } - } - - private func channelRead(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) { - switch self.stateMachine.channelReadRequestPart(requestPart) { - case .failUpgradePromise(let error): - self.upgradeResultPromise.fail(error) - - case .runNotUpgradingInitializer: - self.notUpgradingCompletionHandler(context.channel) - .hop(to: context.eventLoop) - .whenComplete { result in - self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: nil) - } - - case .findUpgrader(let head, let requestedProtocols, let allHeaderNames, let connectionHeader): - let protocolIterator = requestedProtocols.makeIterator() - self.handleUpgradeForProtocol( - context: context, - protocolIterator: protocolIterator, - request: head, - allHeaderNames: allHeaderNames, - connectionHeader: connectionHeader - ).whenComplete { result in - context.eventLoop.assertInEventLoop() - self.findingUpgradeCompleted(context: context, requestHead: head, result) - } - - case .startUpgrading(let upgrader, let requestHead, let responseHeaders, let proto): - self.startUpgrading( - context: context, - upgrader: upgrader, - requestHead: requestHead, - responseHeaders: responseHeaders, - proto: proto - ) - - case .none: - break - } - } - - private func upgradingHandlerCompleted( - context: ChannelHandlerContext, - _ result: Result, - requestHeadAndProtocol: (HTTPRequestHead, String)? - ) { - switch self.stateMachine.upgradingHandlerCompleted(result) { - case .fireErrorCaughtAndRemoveHandler(let error): - self.upgradeResultPromise.fail(error) - context.fireErrorCaught(error) - context.pipeline.removeHandler(self, promise: nil) - - case .fireErrorCaughtAndStartUnbuffering(let error): - self.upgradeResultPromise.fail(error) - context.fireErrorCaught(error) - self.unbuffer(context: context) - - case .startUnbuffering(let value): - if let requestHeadAndProtocol = requestHeadAndProtocol { - context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) - } - self.upgradeResultPromise.succeed(value) - self.unbuffer(context: context) - - case .removeHandler(let value): - if let requestHeadAndProtocol = requestHeadAndProtocol { - context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) - } - self.upgradeResultPromise.succeed(value) - context.pipeline.removeHandler(self, promise: nil) - - case .none: - break - } - } - - /// Attempt to upgrade a single protocol. - /// - /// Will recurse through `protocolIterator` if upgrade fails. - private func handleUpgradeForProtocol( - context: ChannelHandlerContext, - protocolIterator: Array.Iterator, - request: HTTPRequestHead, - allHeaderNames: Set, - connectionHeader: Set - ) -> EventLoopFuture<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?> { - // We want a local copy of the protocol iterator. We'll pass it to the next invocation of the function. - var protocolIterator = protocolIterator - guard let proto = protocolIterator.next() else { - // We're done! No suitable protocol for upgrade. - return context.eventLoop.makeSucceededFuture(nil) - } - - guard let upgrader = self.upgraders[proto.lowercased()] else { - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) - } - - let requiredHeaders = Set(upgrader.requiredUpgradeHeaders.map { $0.lowercased() }) - guard requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader) else { - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) - } - - let responseHeaders = self.buildUpgradeHeaders(protocol: proto) - return upgrader.buildUpgradeResponse( - channel: context.channel, - upgradeRequest: request, - initialResponseHeaders: responseHeaders - ) - .hop(to: context.eventLoop) - .map { (upgrader, $0, proto) } - .flatMapError { error in - // No upgrade here. We want to fire the error down the pipeline, and then try another loop iteration. - context.fireErrorCaught(error) - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) - } - } - - private func findingUpgradeCompleted( - context: ChannelHandlerContext, - requestHead: HTTPRequestHead, - _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> - ) { - switch self.stateMachine.findingUpgraderCompleted(requestHead: requestHead, result) { - case .startUpgrading(let upgrader, let responseHeaders, let proto): - self.startUpgrading( - context: context, - upgrader: upgrader, - requestHead: requestHead, - responseHeaders: responseHeaders, - proto: proto - ) - - case .runNotUpgradingInitializer: - self.notUpgradingCompletionHandler(context.channel) - .hop(to: context.eventLoop) - .whenComplete { result in - self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: nil) - } - - case .fireErrorCaughtAndStartUnbuffering(let error): - self.upgradeResultPromise.fail(error) - context.fireErrorCaught(error) - self.unbuffer(context: context) - - case .fireErrorCaughtAndRemoveHandler(let error): - self.upgradeResultPromise.fail(error) - context.fireErrorCaught(error) - context.pipeline.removeHandler(self, promise: nil) - - case .none: - break - } - } - - private func startUpgrading( - context: ChannelHandlerContext, - upgrader: any NIOTypedHTTPServerProtocolUpgrader, - requestHead: HTTPRequestHead, - responseHeaders: HTTPHeaders, - proto: String - ) { - // Before we finish the upgrade we have to remove the HTTPDecoder and any other non-Encoder HTTP - // handlers from the pipeline, to prevent them parsing any more data. We'll buffer the data until - // that completes. - // While there are a lot of Futures involved here it's quite possible that all of this code will - // actually complete synchronously: we just want to program for the possibility that it won't. - // Once that's done, we send the upgrade response, then remove the HTTP encoder, then call the - // internal handler, then call the user code, and then finally when the user code is done we do - // our final cleanup steps, namely we replay the received data we buffered in the meantime and - // then remove ourselves from the pipeline. - self.removeExtraHandlers(context: context).flatMap { - self.sendUpgradeResponse(context: context, responseHeaders: responseHeaders) - }.flatMap { - context.pipeline.removeHandler(self.httpEncoder) - }.flatMap { () -> EventLoopFuture in - return upgrader.upgrade(channel: context.channel, upgradeRequest: requestHead) - }.hop(to: context.eventLoop) - .whenComplete { result in - self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: (requestHead, proto)) - } - } - - /// Sends the 101 Switching Protocols response for the pipeline. - private func sendUpgradeResponse(context: ChannelHandlerContext, responseHeaders: HTTPHeaders) -> EventLoopFuture { - var response = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) - response.headers = responseHeaders - return context.writeAndFlush(wrapOutboundOut(HTTPServerResponsePart.head(response))) - } - - /// Builds the initial mandatory HTTP headers for HTTP upgrade responses. - private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders { - return HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)]) - } - - /// Removes any extra HTTP-related handlers from the channel pipeline. - private func removeExtraHandlers(context: ChannelHandlerContext) -> EventLoopFuture { - guard self.extraHTTPHandlers.count > 0 else { - return context.eventLoop.makeSucceededFuture(()) - } - - return .andAllSucceed(self.extraHTTPHandlers.map { context.pipeline.removeHandler($0) }, - on: context.eventLoop) - } - - private func unbuffer(context: ChannelHandlerContext) { - while true { - switch self.stateMachine.unbuffer() { - case .fireChannelRead(let data): - context.fireChannelRead(data) - - case .fireChannelReadCompleteAndRemoveHandler: - context.fireChannelReadComplete() - context.pipeline.removeHandler(self, promise: nil) - return - } - } - } -} diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift deleted file mode 100644 index d0fcf287de..0000000000 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift +++ /dev/null @@ -1,385 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -import DequeModule -import NIOCore - -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -struct NIOTypedHTTPServerUpgraderStateMachine { - @usableFromInline - enum State { - /// The state before we received a TLSUserEvent. We are just forwarding any read at this point. - case initial - - @usableFromInline - struct AwaitingUpgrader { - var seenFirstRequest: Bool - var buffer: Deque - } - - /// The request head has been received. We're currently running the future chain awaiting an upgrader. - case awaitingUpgrader(AwaitingUpgrader) - - @usableFromInline - struct UpgraderReady { - var upgrader: any NIOTypedHTTPServerProtocolUpgrader - var requestHead: HTTPRequestHead - var responseHeaders: HTTPHeaders - var proto: String - var buffer: Deque - } - - /// We have an upgrader, which means we can begin upgrade we are just waiting for the request end. - case upgraderReady(UpgraderReady) - - @usableFromInline - struct Upgrading { - var buffer: Deque - } - /// We are either running the upgrading handler. - case upgrading(Upgrading) - - @usableFromInline - struct Unbuffering { - var buffer: Deque - } - case unbuffering(Unbuffering) - - case finished - - case modifying - } - - private var state = State.initial - - @usableFromInline - enum HandlerRemovedAction { - case failUpgradePromise - } - - @inlinable - mutating func handlerRemoved() -> HandlerRemovedAction? { - switch self.state { - case .initial, .awaitingUpgrader, .upgraderReady, .upgrading, .unbuffering: - self.state = .finished - return .failUpgradePromise - - case .finished: - return .none - - case .modifying: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - } - } - - @usableFromInline - enum ChannelReadDataAction { - case unwrapData - case fireChannelRead - } - - @inlinable - mutating func channelReadData(_ data: NIOAny) -> ChannelReadDataAction? { - switch self.state { - case .initial: - return .unwrapData - - case .awaitingUpgrader(var awaitingUpgrader): - if awaitingUpgrader.seenFirstRequest { - // We should buffer the data since we have seen the full request. - self.state = .modifying - awaitingUpgrader.buffer.append(data) - self.state = .awaitingUpgrader(awaitingUpgrader) - return nil - } else { - // We shouldn't buffer. This means we are still expecting HTTP parts. - return .unwrapData - } - - case .upgraderReady: - // We have not seen the end of the HTTP request so this - // data is probably an HTTP request part. - return .unwrapData - - case .unbuffering(var unbuffering): - self.state = .modifying - unbuffering.buffer.append(data) - self.state = .unbuffering(unbuffering) - return nil - - case .finished: - return .fireChannelRead - - case .upgrading(var upgrading): - // We got a read while running ugprading. - // We have to buffer the read to unbuffer it afterwards - self.state = .modifying - upgrading.buffer.append(data) - self.state = .upgrading(upgrading) - return nil - - case .modifying: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - } - } - - @usableFromInline - enum ChannelReadRequestPartAction { - case failUpgradePromise(Error) - case runNotUpgradingInitializer - case startUpgrading( - upgrader: any NIOTypedHTTPServerProtocolUpgrader, - requestHead: HTTPRequestHead, - responseHeaders: HTTPHeaders, - proto: String - ) - case findUpgrader( - head: HTTPRequestHead, - requestedProtocols: [String], - allHeaderNames: Set, - connectionHeader: Set - ) - } - - @inlinable - mutating func channelReadRequestPart(_ requestPart: HTTPServerRequestPart) -> ChannelReadRequestPartAction? { - switch self.state { - case .initial: - guard case .head(let head) = requestPart else { - // The first data that we saw was not a head. This is a protocol error and we are just going to - // fail upgrading - return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) - } - - // Ok, we have a HTTP head. Check if it's an upgrade. - let requestedProtocols = head.headers[canonicalForm: "upgrade"].map(String.init) - guard requestedProtocols.count > 0 else { - // We have to buffer now since we got the request head but are not upgrading. - // The user is configuring the HTTP pipeline now. - var buffer = Deque() - buffer.append(NIOAny(requestPart)) - self.state = .upgrading(.init(buffer: buffer)) - return .runNotUpgradingInitializer - } - - // We can now transition to awaiting the upgrader. This means that we are trying to - // find an upgrade that can handle requested protocols. We are not buffering because - // we are waiting for the request end. - self.state = .awaitingUpgrader(.init(seenFirstRequest: false, buffer: .init())) - - let connectionHeader = Set(head.headers[canonicalForm: "connection"].map { $0.lowercased() }) - let allHeaderNames = Set(head.headers.map { $0.name.lowercased() }) - - return .findUpgrader( - head: head, - requestedProtocols: requestedProtocols, - allHeaderNames: allHeaderNames, - connectionHeader: connectionHeader - ) - - case .awaitingUpgrader(let awaitingUpgrader): - switch (awaitingUpgrader.seenFirstRequest, requestPart) { - case (true, _): - // This is weird we are seeing more requests parts after we have seen an end - // Let's fail upgrading - return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) - - case (false, .head): - // This is weird we are seeing another head but haven't seen the end for the request before - return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) - - case (false, .body): - // This is weird we are seeing body parts for a request that indicated that it wanted - // to upgrade. - return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) - - case (false, .end): - // Okay we got the end as expected. Just gotta store this in our state. - self.state = .awaitingUpgrader(.init(seenFirstRequest: true, buffer: awaitingUpgrader.buffer)) - return nil - } - - case .upgraderReady(let upgraderReady): - switch requestPart { - case .head: - // This is weird we are seeing another head but haven't seen the end for the request before - return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) - - case .body: - // This is weird we are seeing body parts for a request that indicated that it wanted - // to upgrade. - return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) - - case .end: - // Okay we got the end as expected and our upgrader is ready so let's start upgrading - self.state = .upgrading(.init(buffer: upgraderReady.buffer)) - return .startUpgrading( - upgrader: upgraderReady.upgrader, - requestHead: upgraderReady.requestHead, - responseHeaders: upgraderReady.responseHeaders, - proto: upgraderReady.proto - ) - } - - case .upgrading, .unbuffering, .finished: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - - - case .modifying: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - } - } - - @usableFromInline - enum UpgradingHandlerCompletedAction { - case fireErrorCaughtAndStartUnbuffering(Error) - case removeHandler(UpgradeResult) - case fireErrorCaughtAndRemoveHandler(Error) - case startUnbuffering(UpgradeResult) - } - - @inlinable - mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { - switch self.state { - case .initial: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - - case .upgrading(let upgrading): - switch result { - case .success(let value): - if !upgrading.buffer.isEmpty { - self.state = .unbuffering(.init(buffer: upgrading.buffer)) - return .startUnbuffering(value) - } else { - self.state = .finished - return .removeHandler(value) - } - - case .failure(let error): - if !upgrading.buffer.isEmpty { - // So we failed to upgrade. There is nothing really that we can do here. - // We are unbuffering the reads but there shouldn't be any handler in the pipeline - // that expects a specific type of reads anyhow. - self.state = .unbuffering(.init(buffer: upgrading.buffer)) - return .fireErrorCaughtAndStartUnbuffering(error) - } else { - self.state = .finished - return .fireErrorCaughtAndRemoveHandler(error) - } - } - - case .finished: - // We have to tolerate this - return nil - - case .awaitingUpgrader, .upgraderReady, .unbuffering: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - - case .modifying: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - } - } - - @usableFromInline - enum FindingUpgraderCompletedAction { - case startUpgrading(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String) - case runNotUpgradingInitializer - case fireErrorCaughtAndStartUnbuffering(Error) - case fireErrorCaughtAndRemoveHandler(Error) - } - - @inlinable - mutating func findingUpgraderCompleted( - requestHead: HTTPRequestHead, - _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> - ) -> FindingUpgraderCompletedAction? { - switch self.state { - case .initial, .upgraderReady: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - - case .awaitingUpgrader(let awaitingUpgrader): - switch result { - case .success(.some((let upgrader, let responseHeaders, let proto))): - if awaitingUpgrader.seenFirstRequest { - // We have seen the end of the request. So we can upgrade now. - self.state = .upgrading(.init(buffer: awaitingUpgrader.buffer)) - return .startUpgrading(upgrader: upgrader, responseHeaders: responseHeaders, proto: proto) - } else { - // We have not yet seen the end so we have to wait until that happens - self.state = .upgraderReady(.init( - upgrader: upgrader, - requestHead: requestHead, - responseHeaders: responseHeaders, - proto: proto, - buffer: awaitingUpgrader.buffer - )) - return nil - } - - case .success(.none): - // There was no upgrader to handle the request. We just run the not upgrading - // initializer now. - self.state = .upgrading(.init(buffer: awaitingUpgrader.buffer)) - return .runNotUpgradingInitializer - - case .failure(let error): - if !awaitingUpgrader.buffer.isEmpty { - self.state = .unbuffering(.init(buffer: awaitingUpgrader.buffer)) - return .fireErrorCaughtAndStartUnbuffering(error) - } else { - self.state = .finished - return .fireErrorCaughtAndRemoveHandler(error) - } - } - - case .upgrading, .unbuffering, .finished: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - - case .modifying: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - } - } - - @usableFromInline - enum UnbufferAction { - case fireChannelRead(NIOAny) - case fireChannelReadCompleteAndRemoveHandler - } - - @inlinable - mutating func unbuffer() -> UnbufferAction { - switch self.state { - case .initial, .awaitingUpgrader, .upgraderReady, .upgrading, .finished: - preconditionFailure("Invalid state \(self.state)") - - case .unbuffering(var unbuffering): - self.state = .modifying - - if let element = unbuffering.buffer.popFirst() { - self.state = .unbuffering(unbuffering) - - return .fireChannelRead(element) - } else { - self.state = .finished - - return .fireChannelReadCompleteAndRemoveHandler - } - - case .modifying: - fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - - } - } - -} diff --git a/Sources/NIOTCPEchoClient/Client.swift b/Sources/NIOTCPEchoClient/Client.swift index 0d8bd4404f..9bc0e0c9aa 100644 --- a/Sources/NIOTCPEchoClient/Client.swift +++ b/Sources/NIOTCPEchoClient/Client.swift @@ -15,7 +15,7 @@ import NIOCore import NIOPosix -@available(macOS 14, *) +@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) @main struct Client { /// The host to connect to. diff --git a/Sources/NIOTCPEchoServer/Server.swift b/Sources/NIOTCPEchoServer/Server.swift index 390fff795b..edc52f2e1b 100644 --- a/Sources/NIOTCPEchoServer/Server.swift +++ b/Sources/NIOTCPEchoServer/Server.swift @@ -15,7 +15,7 @@ import NIOCore import NIOPosix -@available(macOS 14, *) +@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) @main struct Server { /// The server's host. diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index ac6e92ee32..6483954bde 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -74,62 +74,6 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { } } -/// A `NIOTypedHTTPClientProtocolUpgrader` that knows how to do the WebSocket upgrade dance. -/// -/// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request. -/// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the -/// pipeline to remove the HTTP `ChannelHandler`s. -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -public final class NIOTypedWebSocketClientUpgrader: NIOTypedHTTPClientProtocolUpgrader { - /// RFC 6455 specs this as the required entry in the Upgrade header. - public let supportedProtocol: String = "websocket" - /// None of the websocket headers are actually defined as 'required'. - public let requiredUpgradeHeaders: [String] = [] - - private let requestKey: String - private let maxFrameSize: Int - private let enableAutomaticErrorHandling: Bool - private let upgradePipelineHandler: @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture - - /// - Parameters: - /// - requestKey: Sent to the server in the `Sec-WebSocket-Key` HTTP header. Default is random request key. - /// - maxFrameSize: Largest incoming `WebSocketFrame` size in bytes. Default is 16,384 bytes. - /// - enableAutomaticErrorHandling: If true, adds `WebSocketProtocolErrorHandler` to the channel pipeline to catch and respond to WebSocket protocol errors. Default is true. - /// - upgradePipelineHandler: Called once the upgrade was successful. - public init( - requestKey: String = NIOWebSocketClientUpgrader.randomRequestKey(), - maxFrameSize: Int = 1 << 14, - enableAutomaticErrorHandling: Bool = true, - upgradePipelineHandler: @escaping @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture - ) { - precondition(requestKey != "", "The request key must contain a valid Sec-WebSocket-Key") - precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") - self.requestKey = requestKey - self.upgradePipelineHandler = upgradePipelineHandler - self.maxFrameSize = maxFrameSize - self.enableAutomaticErrorHandling = enableAutomaticErrorHandling - } - - public func addCustom(upgradeRequestHeaders: inout NIOHTTP1.HTTPHeaders) { - _addCustom(upgradeRequestHeaders: &upgradeRequestHeaders, requestKey: self.requestKey) - } - - public func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { - _shouldAllowUpgrade(upgradeResponse: upgradeResponse, requestKey: self.requestKey) - } - - public func upgrade(channel: Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { - _upgrade( - channel: channel, - upgradeResponse: upgradeResponse, - maxFrameSize: self.maxFrameSize, - enableAutomaticErrorHandling: self.enableAutomaticErrorHandling, - upgradePipelineHandler: self.upgradePipelineHandler - ) - } -} - - @available(*, unavailable) extension NIOWebSocketClientUpgrader: Sendable {} diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 4580d0ec07..44b9f56731 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -175,90 +175,6 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch } } -/// A `NIOTypedHTTPServerProtocolUpgrader` that knows how to do the WebSocket upgrade dance. -/// -/// Users may frequently want to offer multiple websocket endpoints on the same port. For this -/// reason, this `WebServerSocketUpgrader` only knows how to do the required parts of the upgrade and to -/// complete the handshake. Users are expected to provide a callback that examines the HTTP headers -/// (including the path) and determines whether this is a websocket upgrade request that is acceptable -/// to them. -/// -/// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to -/// remove the HTTP `ChannelHandler`s. -public final class NIOTypedWebSocketServerUpgrader: NIOTypedHTTPServerProtocolUpgrader, Sendable { - private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture - private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture - - /// RFC 6455 specs this as the required entry in the Upgrade header. - public let supportedProtocol: String = "websocket" - - /// We deliberately do not actually set any required headers here, because the websocket - /// spec annoyingly does not actually force the client to send these in the Upgrade header, - /// which NIO requires. We check for these manually. - public let requiredUpgradeHeaders: [String] = [] - - private let shouldUpgrade: ShouldUpgrade - private let upgradePipelineHandler: UpgradePipelineHandler - private let maxFrameSize: Int - private let enableAutomaticErrorHandling: Bool - - /// Create a new ``NIOTypedWebSocketServerUpgrader``. - /// - /// - Parameters: - /// - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the - /// remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but - /// this is an objectively unreasonable maximum value (on AMD64 systems it is not - /// possible to even. Users may set this to any value up to `UInt32.max`. - /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol - /// errors by sending error responses and closing the connection. Defaults to `true`, - /// may be set to `false` if the user wishes to handle their own errors. - /// - shouldUpgrade: A callback that determines whether the websocket request should be - /// upgraded. This callback is responsible for creating a `HTTPHeaders` object with - /// any headers that it needs on the response *except for* the `Upgrade`, `Connection`, - /// and `Sec-WebSocket-Accept` headers, which this upgrader will handle. Should return - /// an `EventLoopFuture` containing `nil` if the upgrade should be refused. - /// - enableAutomaticErrorHandling: A function that will be called once the upgrade response is - /// flushed, and that is expected to mutate the `Channel` appropriately to handle the - /// websocket protocol. This only needs to add the user handlers: the - /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the - /// pipeline automatically. - public init( - maxFrameSize: Int = 1 << 14, - enableAutomaticErrorHandling: Bool = true, - shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture, - upgradePipelineHandler: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture - ) { - precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") - self.shouldUpgrade = shouldUpgrade - self.upgradePipelineHandler = upgradePipelineHandler - self.maxFrameSize = maxFrameSize - self.enableAutomaticErrorHandling = enableAutomaticErrorHandling - } - - public func buildUpgradeResponse( - channel: Channel, - upgradeRequest: HTTPRequestHead, - initialResponseHeaders: HTTPHeaders - ) -> EventLoopFuture { - _buildUpgradeResponse( - channel: channel, - upgradeRequest: upgradeRequest, - initialResponseHeaders: initialResponseHeaders, - shouldUpgrade: self.shouldUpgrade - ) - } - - public func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { - _upgrade( - channel: channel, - upgradeRequest: upgradeRequest, - maxFrameSize: self.maxFrameSize, - automaticErrorHandling: self.enableAutomaticErrorHandling, - upgradePipelineHandler: self.upgradePipelineHandler - ) - } -} - private func _buildUpgradeResponse( channel: Channel, upgradeRequest: HTTPRequestHead, diff --git a/Sources/NIOWebSocketClient/Client.swift b/Sources/NIOWebSocketClient/Client.swift index 6477416684..a2698536fe 100644 --- a/Sources/NIOWebSocketClient/Client.swift +++ b/Sources/NIOWebSocketClient/Client.swift @@ -12,127 +12,136 @@ // //===----------------------------------------------------------------------===// #if swift(>=5.9) -import NIOCore -import NIOPosix -import NIOHTTP1 -import NIOWebSocket - -@available(macOS 14, *) @main struct Client { - /// The host to connect to. - private let host: String - /// The port to connect to. - private let port: Int - /// The client's event loop group. - private let eventLoopGroup: MultiThreadedEventLoopGroup - - enum UpgradeResult { - case websocket(NIOAsyncChannel) - case notUpgraded - } - - static func main() async throws { - let client = Client( - host: "localhost", - port: 8888, - eventLoopGroup: .singleton - ) - try await client.run() - } - - /// This method starts the client and tries to setup a WebSocket connection. - func run() async throws { - let upgradeResult: EventLoopFuture = try await ClientBootstrap(group: self.eventLoopGroup) - .connect( - host: self.host, - port: self.port - ) { channel in - channel.eventLoop.makeCompletedFuture { - let upgrader = NIOTypedWebSocketClientUpgrader( - upgradePipelineHandler: { (channel, _) in - channel.eventLoop.makeCompletedFuture { - let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) - return UpgradeResult.websocket(asyncChannel) - } - } - ) - - var headers = HTTPHeaders() - headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") - headers.add(name: "Content-Length", value: "0") - - let requestHead = HTTPRequestHead( - version: .http1_1, - method: .GET, - uri: "/", - headers: headers - ) - - let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( - upgradeRequestHead: requestHead, - upgraders: [upgrader], - notUpgradingCompletionHandler: { channel in - channel.eventLoop.makeCompletedFuture { - return UpgradeResult.notUpgraded - } - } - ) - - let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( - configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) - ) - - return negotiationResultFuture - } - } - - // We are awaiting and handling the upgrade result now. - try await self.handleUpgradeResult(upgradeResult) - } - - /// This method handles the upgrade result. - private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async throws { - switch try await upgradeResult.get() { - case .websocket(let websocketChannel): - print("Handling websocket connection") - try await self.handleWebsocketChannel(websocketChannel) - print("Done handling websocket connection") - case .notUpgraded: - // The upgrade to websocket did not succeed. We are just exiting in this case. - print("Upgrade declined") - } + static func main() { + fatalError("Disabled due to https://github.com/apple/swift-nio/issues/2574") } +} - private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { - // We are sending a ping frame and then - // start to handle all inbound frames. - - let pingFrame = WebSocketFrame(fin: true, opcode: .ping, data: ByteBuffer(string: "Hello!")) - try await channel.outbound.write(pingFrame) - - for try await frame in channel.inbound { - switch frame.opcode { - case .pong: - print("Received pong: \(String(buffer: frame.data))") - - case .text: - print("Received: \(String(buffer: frame.data))") +// Commented out due https://github.com/apple/swift-nio/issues/2574 - case .connectionClose: - // Handle a received close frame. We're just going to close by returning from this method. - print("Received Close instruction from server") - return - case .binary, .continuation, .ping: - // We ignore these frames. - break - default: - // Unknown frames are errors. - return - } - } - } -} +//import NIOCore +//import NIOPosix +//import NIOHTTP1 +//import NIOWebSocket +// +//@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) +//@main +//struct Client { +// /// The host to connect to. +// private let host: String +// /// The port to connect to. +// private let port: Int +// /// The client's event loop group. +// private let eventLoopGroup: MultiThreadedEventLoopGroup +// +// enum UpgradeResult { +// case websocket(NIOAsyncChannel) +// case notUpgraded +// } +// +// static func main() async throws { +// let client = Client( +// host: "localhost", +// port: 8888, +// eventLoopGroup: .singleton +// ) +// try await client.run() +// } +// +// /// This method starts the client and tries to setup a WebSocket connection. +// func run() async throws { +// let upgradeResult: EventLoopFuture = try await ClientBootstrap(group: self.eventLoopGroup) +// .connect( +// host: self.host, +// port: self.port +// ) { channel in +// channel.eventLoop.makeCompletedFuture { +// let upgrader = NIOTypedWebSocketClientUpgrader( +// upgradePipelineHandler: { (channel, _) in +// channel.eventLoop.makeCompletedFuture { +// let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) +// return UpgradeResult.websocket(asyncChannel) +// } +// } +// ) +// +// var headers = HTTPHeaders() +// headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") +// headers.add(name: "Content-Length", value: "0") +// +// let requestHead = HTTPRequestHead( +// version: .http1_1, +// method: .GET, +// uri: "/", +// headers: headers +// ) +// +// let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( +// upgradeRequestHead: requestHead, +// upgraders: [upgrader], +// notUpgradingCompletionHandler: { channel in +// channel.eventLoop.makeCompletedFuture { +// return UpgradeResult.notUpgraded +// } +// } +// ) +// +// let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( +// configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) +// ) +// +// return negotiationResultFuture +// } +// } +// +// // We are awaiting and handling the upgrade result now. +// try await self.handleUpgradeResult(upgradeResult) +// } +// +// /// This method handles the upgrade result. +// private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async throws { +// switch try await upgradeResult.get() { +// case .websocket(let websocketChannel): +// print("Handling websocket connection") +// try await self.handleWebsocketChannel(websocketChannel) +// print("Done handling websocket connection") +// case .notUpgraded: +// // The upgrade to websocket did not succeed. We are just exiting in this case. +// print("Upgrade declined") +// } +// } +// +// private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { +// // We are sending a ping frame and then +// // start to handle all inbound frames. +// +// let pingFrame = WebSocketFrame(fin: true, opcode: .ping, data: ByteBuffer(string: "Hello!")) +// try await channel.outbound.write(pingFrame) +// +// for try await frame in channel.inbound { +// switch frame.opcode { +// case .pong: +// print("Received pong: \(String(buffer: frame.data))") +// +// case .text: +// print("Received: \(String(buffer: frame.data))") +// +// case .connectionClose: +// // Handle a received close frame. We're just going to close by returning from this method. +// print("Received Close instruction from server") +// return +// case .binary, .continuation, .ping: +// // We ignore these frames. +// break +// default: +// // Unknown frames are errors. +// return +// } +// } +// } +//} #else @main diff --git a/Sources/NIOWebSocketServer/Server.swift b/Sources/NIOWebSocketServer/Server.swift index 525c64b00d..dad8fbd12c 100644 --- a/Sources/NIOWebSocketServer/Server.swift +++ b/Sources/NIOWebSocketServer/Server.swift @@ -41,238 +41,247 @@ let websocketResponse = """ """ -@available(macOS 14, *) @main struct Server { - /// The server's host. - private let host: String - /// The server's port. - private let port: Int - /// The server's event loop group. - private let eventLoopGroup: MultiThreadedEventLoopGroup - - private static let responseBody = ByteBuffer(string: websocketResponse) - - enum UpgradeResult { - case websocket(NIOAsyncChannel) - case notUpgraded(NIOAsyncChannel>) - } - - static func main() async throws { - let server = Server( - host: "localhost", - port: 8888, - eventLoopGroup: .singleton - ) - try await server.run() - } - - /// This method starts the server and handles incoming connections. - func run() async throws { - let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: self.eventLoopGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .bind( - host: self.host, - port: self.port - ) { channel in - channel.eventLoop.makeCompletedFuture { - let upgrader = NIOTypedWebSocketServerUpgrader( - shouldUpgrade: { (channel, head) in - channel.eventLoop.makeSucceededFuture(HTTPHeaders()) - }, - upgradePipelineHandler: { (channel, _) in - channel.eventLoop.makeCompletedFuture { - let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) - return UpgradeResult.websocket(asyncChannel) - } - } - ) - - let serverUpgradeConfiguration = NIOTypedHTTPServerUpgradeConfiguration( - upgraders: [upgrader], - notUpgradingCompletionHandler: { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponsePartHandler()) - let asyncChannel = try NIOAsyncChannel>(synchronouslyWrapping: channel) - return UpgradeResult.notUpgraded(asyncChannel) - } - } - ) - - let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( - configuration: .init(upgradeConfiguration: serverUpgradeConfiguration) - ) - - return negotiationResultFuture - } - } - - // We are handling each incoming connection in a separate child task. It is important - // to use a discarding task group here which automatically discards finished child tasks. - // A normal task group retains all child tasks and their outputs in memory until they are - // consumed by iterating the group or by exiting the group. Since, we are never consuming - // the results of the group we need the group to automatically discard them; otherwise, this - // would result in a memory leak over time. - try await withThrowingDiscardingTaskGroup { group in - for try await upgradeResult in channel.inbound { - group.addTask { - await self.handleUpgradeResult(upgradeResult) - } - } - } - } - - /// This method handles a single connection by echoing back all inbound data. - private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async { - // Note that this method is non-throwing and we are catching any error. - // We do this since we don't want to tear down the whole server when a single connection - // encounters an error. - do { - switch try await upgradeResult.get() { - case .websocket(let websocketChannel): - print("Handling websocket connection") - try await self.handleWebsocketChannel(websocketChannel) - print("Done handling websocket connection") - case .notUpgraded(let httpChannel): - print("Handling HTTP connection") - try await self.handleHTTPChannel(httpChannel) - print("Done handling HTTP connection") - } - } catch { - print("Hit error: \(error)") - } - } - - private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - for try await frame in channel.inbound { - switch frame.opcode { - case .ping: - print("Received ping") - var frameData = frame.data - let maskingKey = frame.maskKey - - if let maskingKey = maskingKey { - frameData.webSocketUnmask(maskingKey) - } - - let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) - try await channel.outbound.write(responseFrame) - - case .connectionClose: - // This is an unsolicited close. We're going to send a response frame and - // then, when we've sent it, close up shop. We should send back the close code the remote - // peer sent us, unless they didn't send one at all. - print("Received close") - var data = frame.unmaskedData - let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() - let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) - try await channel.outbound.write(closeFrame) - return - case .binary, .continuation, .pong: - // We ignore these frames. - break - default: - // Unknown frames are errors. - return - } - } - } - - group.addTask { - // This is our main business logic where we are just sending the current time - // every second. - while true { - // We can't really check for error here, but it's also not the purpose of the - // example so let's not worry about it. - let theTime = ContinuousClock().now - var buffer = channel.channel.allocator.buffer(capacity: 12) - buffer.writeString("\(theTime)") - - let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) - - print("Sending time") - try await channel.outbound.write(frame) - try await Task.sleep(for: .seconds(1)) - } - } - - try await group.next() - group.cancelAll() - } - } - - - private func handleHTTPChannel(_ channel: NIOAsyncChannel>) async throws { - for try await requestPart in channel.inbound { - // We're not interested in request bodies here: we're just serving up GET responses - // to get the client to initiate a websocket request. - guard case .head(let head) = requestPart else { - return - } - - // GETs only. - guard case .GET = head.method else { - try await self.respond405(writer: channel.outbound) - return - } - - var headers = HTTPHeaders() - headers.add(name: "Content-Type", value: "text/html") - headers.add(name: "Content-Length", value: String(Self.responseBody.readableBytes)) - headers.add(name: "Connection", value: "close") - let responseHead = HTTPResponseHead( - version: .init(major: 1, minor: 1), - status: .ok, - headers: headers - ) - - try await channel.outbound.write( - contentsOf: [ - .head(responseHead), - .body(Self.responseBody), - .end(nil) - ] - ) - } - } - - private func respond405(writer: NIOAsyncChannelOutboundWriter>) async throws { - var headers = HTTPHeaders() - headers.add(name: "Connection", value: "close") - headers.add(name: "Content-Length", value: "0") - let head = HTTPResponseHead( - version: .http1_1, - status: .methodNotAllowed, - headers: headers - ) - - try await writer.write( - contentsOf: [ - .head(head), - .end(nil) - ] - ) + static func main() { + fatalError("Disabled due to https://github.com/apple/swift-nio/issues/2574") } } -final class HTTPByteBufferResponsePartHandler: ChannelOutboundHandler { - typealias OutboundIn = HTTPPart - typealias OutboundOut = HTTPServerResponsePart +// Commented out due https://github.com/apple/swift-nio/issues/2574 - func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - let part = self.unwrapOutboundIn(data) - switch part { - case .head(let head): - context.write(self.wrapOutboundOut(.head(head)), promise: promise) - case .body(let buffer): - context.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise) - case .end(let trailers): - context.write(self.wrapOutboundOut(.end(trailers)), promise: promise) - } - } -} +//@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) +//@main +//struct Server { +// /// The server's host. +// private let host: String +// /// The server's port. +// private let port: Int +// /// The server's event loop group. +// private let eventLoopGroup: MultiThreadedEventLoopGroup +// +// private static let responseBody = ByteBuffer(string: websocketResponse) +// +// enum UpgradeResult { +// case websocket(NIOAsyncChannel) +// case notUpgraded(NIOAsyncChannel>) +// } +// +// static func main() async throws { +// let server = Server( +// host: "localhost", +// port: 8888, +// eventLoopGroup: .singleton +// ) +// try await server.run() +// } +// +// /// This method starts the server and handles incoming connections. +// func run() async throws { +// let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: self.eventLoopGroup) +// .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) +// .bind( +// host: self.host, +// port: self.port +// ) { channel in +// channel.eventLoop.makeCompletedFuture { +// let upgrader = NIOTypedWebSocketServerUpgrader( +// shouldUpgrade: { (channel, head) in +// channel.eventLoop.makeSucceededFuture(HTTPHeaders()) +// }, +// upgradePipelineHandler: { (channel, _) in +// channel.eventLoop.makeCompletedFuture { +// let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) +// return UpgradeResult.websocket(asyncChannel) +// } +// } +// ) +// +// let serverUpgradeConfiguration = NIOTypedHTTPServerUpgradeConfiguration( +// upgraders: [upgrader], +// notUpgradingCompletionHandler: { channel in +// channel.eventLoop.makeCompletedFuture { +// try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponsePartHandler()) +// let asyncChannel = try NIOAsyncChannel>(synchronouslyWrapping: channel) +// return UpgradeResult.notUpgraded(asyncChannel) +// } +// } +// ) +// +// let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( +// configuration: .init(upgradeConfiguration: serverUpgradeConfiguration) +// ) +// +// return negotiationResultFuture +// } +// } +// +// // We are handling each incoming connection in a separate child task. It is important +// // to use a discarding task group here which automatically discards finished child tasks. +// // A normal task group retains all child tasks and their outputs in memory until they are +// // consumed by iterating the group or by exiting the group. Since, we are never consuming +// // the results of the group we need the group to automatically discard them; otherwise, this +// // would result in a memory leak over time. +// try await withThrowingDiscardingTaskGroup { group in +// for try await upgradeResult in channel.inbound { +// group.addTask { +// await self.handleUpgradeResult(upgradeResult) +// } +// } +// } +// } +// +// /// This method handles a single connection by echoing back all inbound data. +// private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async { +// // Note that this method is non-throwing and we are catching any error. +// // We do this since we don't want to tear down the whole server when a single connection +// // encounters an error. +// do { +// switch try await upgradeResult.get() { +// case .websocket(let websocketChannel): +// print("Handling websocket connection") +// try await self.handleWebsocketChannel(websocketChannel) +// print("Done handling websocket connection") +// case .notUpgraded(let httpChannel): +// print("Handling HTTP connection") +// try await self.handleHTTPChannel(httpChannel) +// print("Done handling HTTP connection") +// } +// } catch { +// print("Hit error: \(error)") +// } +// } +// +// private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { +// try await withThrowingTaskGroup(of: Void.self) { group in +// group.addTask { +// for try await frame in channel.inbound { +// switch frame.opcode { +// case .ping: +// print("Received ping") +// var frameData = frame.data +// let maskingKey = frame.maskKey +// +// if let maskingKey = maskingKey { +// frameData.webSocketUnmask(maskingKey) +// } +// +// let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) +// try await channel.outbound.write(responseFrame) +// +// case .connectionClose: +// // This is an unsolicited close. We're going to send a response frame and +// // then, when we've sent it, close up shop. We should send back the close code the remote +// // peer sent us, unless they didn't send one at all. +// print("Received close") +// var data = frame.unmaskedData +// let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() +// let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) +// try await channel.outbound.write(closeFrame) +// return +// case .binary, .continuation, .pong: +// // We ignore these frames. +// break +// default: +// // Unknown frames are errors. +// return +// } +// } +// } +// +// group.addTask { +// // This is our main business logic where we are just sending the current time +// // every second. +// while true { +// // We can't really check for error here, but it's also not the purpose of the +// // example so let's not worry about it. +// let theTime = ContinuousClock().now +// var buffer = channel.channel.allocator.buffer(capacity: 12) +// buffer.writeString("\(theTime)") +// +// let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) +// +// print("Sending time") +// try await channel.outbound.write(frame) +// try await Task.sleep(for: .seconds(1)) +// } +// } +// +// try await group.next() +// group.cancelAll() +// } +// } +// +// +// private func handleHTTPChannel(_ channel: NIOAsyncChannel>) async throws { +// for try await requestPart in channel.inbound { +// // We're not interested in request bodies here: we're just serving up GET responses +// // to get the client to initiate a websocket request. +// guard case .head(let head) = requestPart else { +// return +// } +// +// // GETs only. +// guard case .GET = head.method else { +// try await self.respond405(writer: channel.outbound) +// return +// } +// +// var headers = HTTPHeaders() +// headers.add(name: "Content-Type", value: "text/html") +// headers.add(name: "Content-Length", value: String(Self.responseBody.readableBytes)) +// headers.add(name: "Connection", value: "close") +// let responseHead = HTTPResponseHead( +// version: .init(major: 1, minor: 1), +// status: .ok, +// headers: headers +// ) +// +// try await channel.outbound.write( +// contentsOf: [ +// .head(responseHead), +// .body(Self.responseBody), +// .end(nil) +// ] +// ) +// } +// } +// +// private func respond405(writer: NIOAsyncChannelOutboundWriter>) async throws { +// var headers = HTTPHeaders() +// headers.add(name: "Connection", value: "close") +// headers.add(name: "Content-Length", value: "0") +// let head = HTTPResponseHead( +// version: .http1_1, +// status: .methodNotAllowed, +// headers: headers +// ) +// +// try await writer.write( +// contentsOf: [ +// .head(head), +// .end(nil) +// ] +// ) +// } +//} +// +//final class HTTPByteBufferResponsePartHandler: ChannelOutboundHandler { +// typealias OutboundIn = HTTPPart +// typealias OutboundOut = HTTPServerResponsePart +// +// func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { +// let part = self.unwrapOutboundIn(data) +// switch part { +// case .head(let head): +// context.write(self.wrapOutboundOut(.head(head)), promise: promise) +// case .body(let buffer): +// context.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise) +// case .end(let trailers): +// context.write(self.wrapOutboundOut(.end(trailers)), promise: promise) +// } +// } +//} #else @main diff --git a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift index 89f2b64c40..7bdd4c3622 100644 --- a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift @@ -33,7 +33,7 @@ extension EmbeddedChannel { } @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader, NIOTypedHTTPClientProtocolUpgrader where UpgradeResult == Bool {} +protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader {} private final class SuccessfulClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { fileprivate let supportedProtocol: String @@ -282,9 +282,8 @@ private final class RecordingHTTPHandler: ChannelInboundHandler, RemovableChanne @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) private func assertPipelineContainsUpgradeHandler(channel: Channel) { let handler = try? channel.pipeline.syncOperations.handler(type: NIOHTTPClientUpgradeHandler.self) - let typedHandler = try? channel.pipeline.syncOperations.handler(type: NIOTypedHTTPClientUpgradeHandler.self) - XCTAssertTrue(handler != nil || typedHandler != nil) + XCTAssertTrue(handler != nil) } @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) @@ -947,233 +946,3 @@ class HTTPClientUpgradeTestCase: XCTestCase { } } } - -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { - override func setUpClientChannel( - clientHTTPHandler: RemovableChannelHandler, - clientUpgraders: [any TypedAndUntypedHTTPClientProtocolUpgrader], - _ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void - ) throws -> EmbeddedChannel { - - let channel = EmbeddedChannel() - - var headers = HTTPHeaders() - headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") - headers.add(name: "Content-Length", value: "\(0)") - - let requestHead = HTTPRequestHead( - version: .http1_1, - method: .GET, - uri: "/", - headers: headers - ) - - let upgraders: [any NIOTypedHTTPClientProtocolUpgrader] = Array(clientUpgraders.map { $0 as! any NIOTypedHTTPClientProtocolUpgrader }) - - let config = NIOTypedHTTPClientUpgradeConfiguration( - upgradeRequestHead: requestHead, - upgraders: upgraders - ) { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(clientHTTPHandler) - }.map { _ in - false - } - } - var configuration = NIOUpgradableHTTPClientPipelineConfiguration(upgradeConfiguration: config) - configuration.leftOverBytesStrategy = .forwardBytes - let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: configuration) - let context = try channel.pipeline.syncOperations.context(handlerType: NIOTypedHTTPClientUpgradeHandler.self) - - try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) - .wait() - upgradeResult.whenSuccess { result in - if result { - upgradeCompletionHandler(context) - } - } - - return channel - } - - // - MARK: The following tests are all overridden from the base class since they slightly differ in behaviour - - override func testUpgradeOnlyHandlesKnownProtocols() throws { - var upgradeHandlerCallbackFired = false - - let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") - let clientHandler = RecordingHTTPHandler() - - // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true - } - defer { - XCTAssertNoThrow(try clientChannel.finish()) - } - - let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: unknownProtocol\r\n\r\n" - XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in - XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) - } - - clientChannel.embeddedEventLoop.run() - - // Should fail with error (response is malformed) and remove upgrader from pipeline. - - // Check that the http elements are not removed from the pipeline. - clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) - clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) - - // Check that the HTTP handler received its response. - XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) - // Check an error is reported - XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) - - XCTAssertFalse(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) - } - - override func testUpgradeResponseCanBeRejectedByClientUpgrader() throws { - let upgradeProtocol = "myProto" - - var upgradeHandlerCallbackFired = false - - let clientUpgrader = DenyingClientUpgrader(forProtocol: upgradeProtocol) - let clientHandler = RecordingHTTPHandler() - - // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true - } - defer { - XCTAssertNoThrow(try clientChannel.finish()) - } - - let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" - XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in - XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .upgraderDeniedUpgrade) - } - - clientChannel.embeddedEventLoop.run() - - // Should fail with error (response is denied) and remove upgrader from pipeline. - - // Check that the http elements are not removed from the pipeline. - clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) - clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) - - XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) - - // Check that the HTTP handler received its response. - XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) - - // Check an error is reported - XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) - - XCTAssertFalse(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) - } - - override func testFiresOutboundErrorDuringAddingHandlers() throws { - let upgradeProtocol = "myProto" - var errorOnAdditionalChannelWrite: Error? - var upgradeHandlerCallbackFired = false - - let clientUpgrader = UpgradeDelayClientUpgrader(forProtocol: upgradeProtocol) - let clientHandler = RecordingHTTPHandler() - - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { (context) in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true - } - defer { - XCTAssertNoThrow(try clientChannel.finish()) - } - - // Push the successful server response. - let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" - XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - - let promise = clientChannel.eventLoop.makePromise(of: Void.self) - - promise.futureResult.whenFailure() { error in - errorOnAdditionalChannelWrite = error - } - - // Send another outbound request during the upgrade. - let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") - let secondRequest: HTTPClientRequestPart = .head(requestHead) - clientChannel.writeAndFlush(secondRequest, promise: promise) - - clientChannel.embeddedEventLoop.run() - - let promiseError = errorOnAdditionalChannelWrite as! NIOHTTPClientUpgradeError - XCTAssertEqual(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade, promiseError) - - // Soundness check that the upgrade was delayed. - XCTAssertEqual(0, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) - - // Upgrade now. - clientUpgrader.unblockUpgrade() - clientChannel.embeddedEventLoop.run() - - // Check that the upgrade was still successful, despite the interruption. - XCTAssert(upgradeHandlerCallbackFired) - XCTAssertEqual(1, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) - } - - override func testUpgradeResponseMissingAllProtocols() throws { - var upgradeHandlerCallbackFired = false - - let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") - let clientHandler = RecordingHTTPHandler() - - // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true - } - defer { - XCTAssertNoThrow(try clientChannel.finish()) - } - - let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\n\r\n" - XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in - XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) - } - - clientChannel.embeddedEventLoop.run() - - // Should fail with error (response is malformed) and remove upgrader from pipeline. - - // Check that the http elements are not removed from the pipeline. - clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) - clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) - - // Check that the HTTP handler received its response. - XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) - // Check an error is reported - XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) - - XCTAssertFalse(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) - } -} diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index 378d64d0a8..4393adcfc6 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -36,11 +36,7 @@ extension ChannelPipeline { @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) fileprivate func assertContainsUpgrader() { - do { - _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() - } catch { - self.assertContains(handlerType: HTTPServerUpgradeHandler.self) - } + self.assertContains(handlerType: HTTPServerUpgradeHandler.self) } func assertContains(handlerType: Handler.Type) { @@ -63,15 +59,7 @@ extension ChannelPipeline { // handler present, keep waiting usleep(50) } catch ChannelPipelineError.notFound { - // Checking if the typed variant is present - do { - _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() - // handler present, keep waiting - usleep(50) - } catch ChannelPipelineError.notFound { - // No upgrader, we're good. - return - } + return } } @@ -175,7 +163,7 @@ internal func assertResponseIs(response: String, expectedResponseLine: String, e } @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader, NIOTypedHTTPServerProtocolUpgrader where UpgradeResult == Bool {} +protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader {} private class ExplodingUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String @@ -1551,503 +1539,3 @@ class HTTPServerUpgradeTestCase: XCTestCase { channel.pipeline.assertContainsUpgrader() } } - -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { - fileprivate override func setUpTestWithAutoremoval( - pipelining: Bool = false, - upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], - extraHandlers: [ChannelHandler], - notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, - _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler - ) throws -> (Channel, Channel, Channel) { - let connectionChannelPromise = Self.eventLoop.makePromise(of: Channel.self) - let serverChannelFuture = ServerBootstrap(group: Self.eventLoop) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelInitializer { channel in - channel.eventLoop.makeCompletedFuture { - connectionChannelPromise.succeed(channel) - var configuration = NIOUpgradableHTTPServerPipelineConfiguration( - upgradeConfiguration: .init( - upgraders: upgraders.map { $0 as! any NIOTypedHTTPServerProtocolUpgrader }, - notUpgradingCompletionHandler: { notUpgradingHandler?($0) ?? $0.eventLoop.makeSucceededFuture(false) } - ) - ) - configuration.enablePipelining = pipelining - return try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline(configuration: configuration) - .flatMap { result in - if result { - return channel.pipeline.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self) - .map { - upgradeCompletionHandler($0) - } - } else { - return channel.eventLoop.makeSucceededVoidFuture() - } - } - } - .flatMap { _ in - let futureResults = extraHandlers.map { channel.pipeline.addHandler($0) } - return EventLoopFuture.andAllSucceed(futureResults, on: channel.eventLoop) - } - }.bind(host: "127.0.0.1", port: 0) - let clientChannel = try connectedClientChannel(group: Self.eventLoop, serverAddress: serverChannelFuture.wait().localAddress!) - return (try serverChannelFuture.wait(), clientChannel, try connectionChannelPromise.futureResult.wait()) - } - - func testNotUpgrading() throws { - let notUpgraderCbFired = UnsafeMutableTransferBox(false) - - let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in } - - let (_, client, connectedServer) = try setUpTestWithAutoremoval( - upgraders: [upgrader], - extraHandlers: [], - notUpgradingHandler: { channel in - notUpgraderCbFired.wrappedValue = true - // We're closing the connection now. - channel.close(promise: nil) - return channel.eventLoop.makeSucceededFuture(true) - } - ) { _ in } - - - let completePromise = Self.eventLoop.makePromise(of: Void.self) - let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - XCTAssertEqual(resultString, "") - completePromise.succeed(()) - } - XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - - // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: notmyproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" - XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - - // Let the machinery do its thing. - XCTAssertNoThrow(try completePromise.futureResult.wait()) - - // At this time we want to assert that the not upgrader got called. - XCTAssert(notUpgraderCbFired.wrappedValue) - - // We also want to confirm that the upgrade handler is no longer in the pipeline. - try connectedServer.pipeline.assertDoesNotContainUpgrader() - } - - // - MARK: The following tests are all overridden from the base class since they slightly differ in behaviour - - override func testSimpleUpgradeSucceeds() throws { - // This test is different since we call the completionHandler after the upgrader - // modified the pipeline in the typed version. - let upgradeRequest = UnsafeMutableTransferBox(nil) - let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) - let upgraderCbFired = UnsafeMutableTransferBox(false) - - let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in - // This is called before completion block. - upgradeRequest.wrappedValue = req - upgradeHandlerCbFired.wrappedValue = true - - XCTAssert(upgradeHandlerCbFired.wrappedValue) - upgraderCbFired.wrappedValue = true - } - - let (_, client, connectedServer) = try setUpTestWithAutoremoval( - upgraders: [upgrader], - extraHandlers: [] - ) { (context) in - // This is called before the upgrader gets called. - XCTAssertNotNil(upgradeRequest.wrappedValue) - upgradeHandlerCbFired.wrappedValue = true - - // We're closing the connection now. - context.close(promise: nil) - } - - - let completePromise = Self.eventLoop.makePromise(of: Void.self) - let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) - completePromise.succeed(()) - } - XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - - // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" - XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - - // Let the machinery do its thing. - XCTAssertNoThrow(try completePromise.futureResult.wait()) - - // At this time we want to assert that everything got called. Their own callbacks assert - // that the ordering was correct. - XCTAssert(upgradeHandlerCbFired.wrappedValue) - XCTAssert(upgraderCbFired.wrappedValue) - - // We also want to confirm that the upgrade handler is no longer in the pipeline. - try connectedServer.pipeline.assertDoesNotContainUpgrader() - } - - override func testUpgradeRespectsClientPreference() throws { - // This test is different since we call the completionHandler after the upgrader - // modified the pipeline in the typed version. - let upgradeRequest = UnsafeMutableTransferBox(nil) - let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) - let upgraderCbFired = UnsafeMutableTransferBox(false) - - let explodingUpgrader = ExplodingUpgrader(forProtocol: "exploder") - let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in - upgradeRequest.wrappedValue = req - XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) - upgraderCbFired.wrappedValue = true - } - - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], - extraHandlers: []) { context in - // This is called before the upgrader gets called. - XCTAssertNotNil(upgradeRequest.wrappedValue) - upgradeHandlerCbFired.wrappedValue = true - - // We're closing the connection now. - context.close(promise: nil) - } - - - let completePromise = Self.eventLoop.makePromise(of: Void.self) - let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) - completePromise.succeed(()) - } - XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - - // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto, exploder\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" - XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - - // Let the machinery do its thing. - XCTAssertNoThrow(try completePromise.futureResult.wait()) - - // At this time we want to assert that everything got called. Their own callbacks assert - // that the ordering was correct. - XCTAssert(upgradeHandlerCbFired.wrappedValue) - XCTAssert(upgraderCbFired.wrappedValue) - - // We also want to confirm that the upgrade handler is no longer in the pipeline. - try connectedServer.pipeline.waitForUpgraderToBeRemoved() - } - - override func testUpgraderCanRejectUpgradeForPersonalReasons() throws { - // This test is different since we call the completionHandler after the upgrader - // modified the pipeline in the typed version. - let upgradeRequest = UnsafeMutableTransferBox(nil) - let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) - let upgraderCbFired = UnsafeMutableTransferBox(false) - - let explodingUpgrader = UpgraderSaysNo(forProtocol: "noproto") - let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in - upgradeRequest.wrappedValue = req - XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) - upgraderCbFired.wrappedValue = true - } - let errorCatcher = ErrorSaver() - - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], - extraHandlers: [errorCatcher]) { context in - // This is called before the upgrader gets called. - XCTAssertNotNil(upgradeRequest.wrappedValue) - upgradeHandlerCbFired.wrappedValue = true - - // We're closing the connection now. - context.close(promise: nil) - } - - - let completePromise = Self.eventLoop.makePromise(of: Void.self) - let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) - completePromise.succeed(()) - } - XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - - // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: noproto,myproto\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" - XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - - // Let the machinery do its thing. - XCTAssertNoThrow(try completePromise.futureResult.wait()) - - // At this time we want to assert that everything got called. Their own callbacks assert - // that the ordering was correct. - XCTAssert(upgradeHandlerCbFired.wrappedValue) - XCTAssert(upgraderCbFired.wrappedValue) - - // We also want to confirm that the upgrade handler is no longer in the pipeline. - try connectedServer.pipeline.waitForUpgraderToBeRemoved() - - // And we want to confirm we saved the error. - XCTAssertEqual(errorCatcher.errors.count, 1) - - switch(errorCatcher.errors[0]) { - case UpgraderSaysNo.No.no: - break - default: - XCTFail("Unexpected error: \(errorCatcher.errors[0])") - } - } - - override func testUpgradeWithUpgradePayloadInlineWithRequestWorks() throws { - // This test is different since we call the completionHandler after the upgrader - // modified the pipeline in the typed version. - enum ReceivedTheWrongThingError: Error { case error } - let upgradeRequest = UnsafeMutableTransferBox(nil) - let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) - let upgraderCbFired = UnsafeMutableTransferBox(false) - - class CheckWeReadInlineAndExtraData: ChannelDuplexHandler { - typealias InboundIn = ByteBuffer - typealias OutboundIn = Never - typealias OutboundOut = Never - - enum State { - case fresh - case added - case inlineDataRead - case extraDataRead - case closed - } - - private let firstByteDonePromise: EventLoopPromise - private let secondByteDonePromise: EventLoopPromise - private let allDonePromise: EventLoopPromise - private var state = State.fresh - - init(firstByteDonePromise: EventLoopPromise, - secondByteDonePromise: EventLoopPromise, - allDonePromise: EventLoopPromise) { - self.firstByteDonePromise = firstByteDonePromise - self.secondByteDonePromise = secondByteDonePromise - self.allDonePromise = allDonePromise - } - - func handlerAdded(context: ChannelHandlerContext) { - XCTAssertEqual(.fresh, self.state) - self.state = .added - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - var buf = self.unwrapInboundIn(data) - XCTAssertEqual(1, buf.readableBytes) - let stringRead = buf.readString(length: buf.readableBytes) - switch self.state { - case .added: - XCTAssertEqual("A", stringRead) - self.state = .inlineDataRead - if stringRead == .some("A") { - self.firstByteDonePromise.succeed(()) - } else { - self.firstByteDonePromise.fail(ReceivedTheWrongThingError.error) - } - case .inlineDataRead: - XCTAssertEqual("B", stringRead) - self.state = .extraDataRead - context.channel.close(promise: nil) - if stringRead == .some("B") { - self.secondByteDonePromise.succeed(()) - } else { - self.secondByteDonePromise.fail(ReceivedTheWrongThingError.error) - } - default: - XCTFail("channel read in wrong state \(self.state)") - } - } - - func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { - XCTAssertEqual(.extraDataRead, self.state) - self.state = .closed - context.close(mode: mode, promise: promise) - - self.allDonePromise.succeed(()) - } - } - - let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in - upgradeRequest.wrappedValue = req - XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) - upgraderCbFired.wrappedValue = true - } - - let promiseGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - defer { - XCTAssertNoThrow(try promiseGroup.syncShutdownGracefully()) - } - let firstByteDonePromise = promiseGroup.next().makePromise(of: Void.self) - let secondByteDonePromise = promiseGroup.next().makePromise(of: Void.self) - let allDonePromise = promiseGroup.next().makePromise(of: Void.self) - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { (context) in - // This is called before the upgrader gets called. - XCTAssertNotNil(upgradeRequest.wrappedValue) - upgradeHandlerCbFired.wrappedValue = true - - _ = context.channel.pipeline.addHandler(CheckWeReadInlineAndExtraData(firstByteDonePromise: firstByteDonePromise, - secondByteDonePromise: secondByteDonePromise, - allDonePromise: allDonePromise)) - } - - - let completePromise = Self.eventLoop.makePromise(of: Void.self) - let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) - completePromise.succeed(()) - } - XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - - // This request is safe to upgrade. - var request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" - request += "A" - XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - - XCTAssertNoThrow(try firstByteDonePromise.futureResult.wait() as Void) - - XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: "B"))).wait()) - - XCTAssertNoThrow(try secondByteDonePromise.futureResult.wait() as Void) - - XCTAssertNoThrow(try allDonePromise.futureResult.wait() as Void) - - // Let the machinery do its thing. - XCTAssertNoThrow(try completePromise.futureResult.wait()) - - // At this time we want to assert that everything got called. Their own callbacks assert - // that the ordering was correct. - XCTAssert(upgradeHandlerCbFired.wrappedValue) - XCTAssert(upgraderCbFired.wrappedValue) - - // We also want to confirm that the upgrade handler is no longer in the pipeline. - try connectedServer.pipeline.assertDoesNotContainUpgrader() - - XCTAssertNoThrow(try allDonePromise.futureResult.wait()) - } - - override func testWeTolerateUpgradeFuturesFromWrongEventLoops() throws { - // This test is different since we call the completionHandler after the upgrader - // modified the pipeline in the typed version. - let upgradeRequest = UnsafeMutableTransferBox(nil) - let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) - let upgraderCbFired = UnsafeMutableTransferBox(false) - let otherELG = MultiThreadedEventLoopGroup(numberOfThreads: 1) - defer { - XCTAssertNoThrow(try otherELG.syncShutdownGracefully()) - } - - let upgrader = SuccessfulUpgrader(forProtocol: "myproto", - requiringHeaders: ["kafkaesque"], - buildUpgradeResponseFuture: { - // this is the wrong EL - otherELG.next().makeSucceededFuture($1) - }) { req in - upgradeRequest.wrappedValue = req - XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) - upgraderCbFired.wrappedValue = true - } - - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { (context) in - // This is called before the upgrader gets called. - XCTAssertNotNil(upgradeRequest.wrappedValue) - upgradeHandlerCbFired.wrappedValue = true - - // We're closing the connection now. - context.close(promise: nil) - } - - - let completePromise = Self.eventLoop.makePromise(of: Void.self) - let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) - completePromise.succeed(()) - } - XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - - // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" - XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - - // Let the machinery do its thing. - XCTAssertNoThrow(try completePromise.futureResult.wait()) - - // At this time we want to assert that everything got called. Their own callbacks assert - // that the ordering was correct. - XCTAssert(upgradeHandlerCbFired.wrappedValue) - XCTAssert(upgraderCbFired.wrappedValue) - - // We also want to confirm that the upgrade handler is no longer in the pipeline. - try connectedServer.pipeline.assertDoesNotContainUpgrader() - } - - override func testUpgradeFiresUserEvent() throws { - // This test is different since we call the completionHandler after the upgrader - // modified the pipeline in the typed version. - let eventSaver = UnsafeTransfer(UserEventSaver()) - - let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { req in - XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) - } - - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: [eventSaver.wrappedValue]) { context in - XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) - context.close(promise: nil) - } - - - let completePromise = Self.eventLoop.makePromise(of: Void.self) - let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) - completePromise.succeed(()) - } - XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - - // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade,kafkaesque\r\n\r\n" - XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - - // Let the machinery do its thing. - XCTAssertNoThrow(try completePromise.futureResult.wait()) - - // At this time we should have received one user event. We schedule this onto the - // event loop to guarantee thread safety. - XCTAssertNoThrow(try connectedServer.eventLoop.scheduleTask(deadline: .now()) { - XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) - if case .upgradeComplete(let proto, let req) = eventSaver.wrappedValue.events[0] { - XCTAssertEqual(proto, "myproto") - XCTAssertEqual(req.method, .OPTIONS) - XCTAssertEqual(req.uri, "*") - XCTAssertEqual(req.version, .http1_1) - } else { - XCTFail("Unexpected event: \(eventSaver.wrappedValue.events[0])") - } - }.futureResult.wait()) - - // We also want to confirm that the upgrade handler is no longer in the pipeline. - try connectedServer.pipeline.waitForUpgraderToBeRemoved() - } -} diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index bd9cef6936..137e897988 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -404,214 +404,3 @@ class WebSocketClientEndToEndTests: XCTestCase { XCTAssertNoThrow(try clientChannel.close().wait()) } } - -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -final class TypedWebSocketClientEndToEndTests: WebSocketClientEndToEndTests { - func setUpClientChannel( - clientUpgraders: [any NIOTypedHTTPClientProtocolUpgrader], - notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture - ) throws -> (EmbeddedChannel, EventLoopFuture) { - - let channel = EmbeddedChannel() - - var headers = HTTPHeaders() - headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") - headers.add(name: "Content-Length", value: "\(0)") - - let requestHead = HTTPRequestHead( - version: .http1_1, - method: .GET, - uri: "/", - headers: headers - ) - - let config = NIOTypedHTTPClientUpgradeConfiguration( - upgradeRequestHead: requestHead, - upgraders: clientUpgraders, - notUpgradingCompletionHandler: notUpgradingCompletionHandler - ) - - let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: .init(upgradeConfiguration: config)) - - try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) - .wait() - - return (channel, upgradeResult) - } - - override func testSimpleUpgradeSucceeds() throws { - let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" - let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk=" - - let basicUpgrader = NIOTypedWebSocketClientUpgrader( - requestKey: requestKey, - upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in - channel.pipeline.addHandler(WebSocketRecorderHandler()) - }) - - // The process should kick-off independently by sending the upgrade request to the server. - let (clientChannel, upgradeResult) = try setUpClientChannel( - clientUpgraders: [basicUpgrader], - notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } - ) - - // Read the server request. - if let requestString = try clientChannel.readByteBufferOutputAsString() { - XCTAssertEqual(requestString, basicRequest() + "\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Key: \(requestKey)\r\nSec-WebSocket-Version: 13\r\n\r\n") - } else { - XCTFail() - } - - // Push the successful server response. - let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" - - XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - - clientChannel.embeddedEventLoop.run() - - // Once upgraded, validate the http pipeline has been removed. - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: ByteToMessageHandler.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) - - // Check that the pipeline now has the correct websocket handlers added. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: WebSocketFrameEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: WebSocketRecorderHandler.self)) - - try upgradeResult.wait() - - // Close the pipeline. - XCTAssertNoThrow(try clientChannel.close().wait()) - } - - override func testRejectUpgradeIfMissingAcceptKey() throws { - let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" - - let basicUpgrader = NIOTypedWebSocketClientUpgrader( - requestKey: requestKey, - upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in - channel.pipeline.addHandler(WebSocketRecorderHandler()) - }) - - // The process should kick-off independently by sending the upgrade request to the server. - let (clientChannel, upgradeResult) = try setUpClientChannel( - clientUpgraders: [basicUpgrader], - notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } - ) - - // Push the successful server response but with a missing accept key. - let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n" - - XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in - XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) - } - - // Close the pipeline. - XCTAssertNoThrow(try clientChannel.close().wait()) - - XCTAssertThrowsError(try upgradeResult.wait()) { error in - XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) - } - } - - override func testRejectUpgradeIfIncorrectAcceptKey() throws { - let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" - let responseKey = "notACorrectKeyL1am=F1y=nn=" - - let basicUpgrader = NIOTypedWebSocketClientUpgrader( - requestKey: requestKey, - upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in - channel.pipeline.addHandler(WebSocketRecorderHandler()) - }) - - // The process should kick-off independently by sending the upgrade request to the server. - let (clientChannel, upgradeResult) = try setUpClientChannel( - clientUpgraders: [basicUpgrader], - notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } - ) - - // Push the successful server response but with an incorrect response key. - let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" - - XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in - XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) - } - - // Close the pipeline. - XCTAssertNoThrow(try clientChannel.close().wait()) - - XCTAssertThrowsError(try upgradeResult.wait()) { error in - XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) - } - } - - override func testRejectUpgradeIfNotWebsocket() throws { - let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" - let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk=" - - let basicUpgrader = NIOTypedWebSocketClientUpgrader( - requestKey: requestKey, - upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in - channel.pipeline.addHandler(WebSocketRecorderHandler()) - }) - - // The process should kick-off independently by sending the upgrade request to the server. - let (clientChannel, upgradeResult) = try setUpClientChannel( - clientUpgraders: [basicUpgrader], - notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } - ) - - // Push the successful server response with an incorrect protocol. - let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: myProtocol\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" - - XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in - XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.responseProtocolNotFound) - } - - // Close the pipeline. - XCTAssertNoThrow(try clientChannel.close().wait()) - - XCTAssertThrowsError(try upgradeResult.wait()) { error in - XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.responseProtocolNotFound) - } - } - - override fileprivate func runSuccessfulUpgrade() throws -> (EmbeddedChannel, WebSocketRecorderHandler) { - let handler = WebSocketRecorderHandler() - - let basicUpgrader = NIOTypedWebSocketClientUpgrader( - requestKey: "OfS0wDaT5NoxF2gqm7Zj2YtetzM=", - upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in - channel.pipeline.addHandler(handler) - }) - - // The process should kick-off independently by sending the upgrade request to the server. - let (clientChannel, upgradeResult) = try setUpClientChannel( - clientUpgraders: [basicUpgrader], - notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } - ) - - // Push the successful server response. - let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:yKEqitDFPE81FyIhKTm+ojBqigk=\r\n\r\n" - - XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - - clientChannel.embeddedEventLoop.run() - - // We now have a successful upgrade, clear the output channels read to test the frames. - XCTAssertNoThrow(try clientChannel.readOutbound(as: ByteBuffer.self)) - - clientChannel.embeddedEventLoop.run() - - try upgradeResult.wait() - - return (clientChannel, handler) - } -} diff --git a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift index 44246e3ab0..2a1a3c6980 100644 --- a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift @@ -526,30 +526,3 @@ class WebSocketServerEndToEndTests: XCTestCase { XCTAssertNoThrow(XCTAssertEqual([], try server.readAllOutboundBytes())) } } - -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -final class TypedWebSocketServerEndToEndTests: WebSocketServerEndToEndTests { - override func createTestFixtures( - upgraders: [WebSocketServerUpgraderConfiguration] - ) -> (loop: EmbeddedEventLoop, serverChannel: EmbeddedChannel, clientChannel: EmbeddedChannel) { - let loop = EmbeddedEventLoop() - let serverChannel = EmbeddedChannel(loop: loop) - let upgraders = upgraders.map { NIOTypedWebSocketServerUpgrader( - maxFrameSize: $0.maxFrameSize, - enableAutomaticErrorHandling: $0.automaticErrorHandling, - shouldUpgrade: $0.shouldUpgrade, - upgradePipelineHandler: $0.upgradePipelineHandler - )} - - XCTAssertNoThrow(try serverChannel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( - configuration: .init( - upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration( - upgraders: upgraders, - notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } - ) - ) - )) - let clientChannel = EmbeddedChannel(loop: loop) - return (loop: loop, serverChannel: serverChannel, clientChannel: clientChannel) - } -} From 9497e442486aab515c8486ef8153a506f93a5032 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Fri, 27 Oct 2023 17:52:46 +0100 Subject: [PATCH 41/64] Fix exclusive access violation in `NIOAsyncChannelOutboundWriterHandler` (#2580) # Motivation We were setting `self.sink = nil` in the `NIOAsyncChannelOutboundWriterHandler` twice in the same call stack which is an exclusivity violation. This happens because the first `self.sink = nil` triggers the `didTerminate` delegate call which again triggered `self.sink = nil`. # Modification This PR changes the code to only call `self.sink?.finish()` and only sets the `sink` to `nil` in the `didTerminate` implementation. This follows what we do for the inbound handler implementation. I also added a test that triggers this exclusivity violation. # Result No more exclusivity violations in our code. --- .../AsyncChannelOutboundWriterHandler.swift | 2 +- .../AsyncChannel/AsyncChannelTests.swift | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift index 2a9de1328c..bae110c61c 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift @@ -133,7 +133,7 @@ internal final class NIOAsyncChannelOutboundWriterHandler @inlinable func handlerRemoved(context: ChannelHandlerContext) { self.context = nil - self.sink = nil + self.sink?.finish() } @inlinable diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index c5feb91e81..1047e54437 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -18,6 +18,25 @@ import NIOEmbedded import XCTest final class AsyncChannelTests: XCTestCase { + func testAsyncChannelCloseOnWrite() async throws { + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + final class CloseOnWriteHandler: ChannelOutboundHandler { + typealias OutboundIn = String + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + context.close(promise: promise) + } + } + let channel = NIOAsyncTestingChannel() + let wrapped = try await channel.testingEventLoop.executeInContext { + try channel.pipeline.syncOperations.addHandler(CloseOnWriteHandler()) + return try NIOAsyncChannel(synchronouslyWrapping: channel) + } + + try await wrapped.outbound.write("Test") + try await channel.closeFuture.get() + } + func testAsyncChannelBasicFunctionality() async throws { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() From 853522d90871b4b63262843196685795b5008c46 Mon Sep 17 00:00:00 2001 From: Rick Newton-Rogers Date: Wed, 1 Nov 2023 10:34:16 +0000 Subject: [PATCH 42/64] use feature-specific guard for @retroactive (#2581) Motivation: We should use the more granular guard for uses of @retroactive rather than coarse swift versions to guard against corner-cases Modifications: Switch `#if compiler(>=5.11)` for `#if hasFeature(RetroactiveAttribute)` Result: No change in most cases, more protected against corner-cases. --- Sources/NIOFoundationCompat/ByteBuffer-foundation.swift | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift index fd6362bf7f..11d4c7e183 100644 --- a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift +++ b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift @@ -375,7 +375,8 @@ extension ByteBufferAllocator { } // MARK: - Conformances -#if compiler(>=5.11) +#if swift(>=5.8) +#if hasFeature(RetroactiveAttribute) extension ByteBufferView: @retroactive ContiguousBytes {} extension ByteBufferView: @retroactive DataProtocol {} extension ByteBufferView: @retroactive MutableDataProtocol {} @@ -384,6 +385,11 @@ extension ByteBufferView: ContiguousBytes {} extension ByteBufferView: DataProtocol {} extension ByteBufferView: MutableDataProtocol {} #endif +#else +extension ByteBufferView: ContiguousBytes {} +extension ByteBufferView: DataProtocol {} +extension ByteBufferView: MutableDataProtocol {} +#endif extension ByteBufferView { public typealias Regions = CollectionOfOne From 23977a932f90d3d7066929cb9a3d6074158d6a50 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Mon, 6 Nov 2023 12:38:18 +0000 Subject: [PATCH 43/64] Fix spelling of retroactive guard (#2586) Motivation: `hasFeature(RetroactiveAttribute)` doesn't work as expected, but `$RetroactiveAttribute` does. Modifications: Switch from `hasFeature` to `$RetroactiveAttribute`. Result: - `@retroactive` is applied appropriately --- Sources/NIOFoundationCompat/ByteBuffer-foundation.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift index 11d4c7e183..0ee8a685dc 100644 --- a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift +++ b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift @@ -376,7 +376,7 @@ extension ByteBufferAllocator { // MARK: - Conformances #if swift(>=5.8) -#if hasFeature(RetroactiveAttribute) +#if $RetroactiveAttribute extension ByteBufferView: @retroactive ContiguousBytes {} extension ByteBufferView: @retroactive DataProtocol {} extension ByteBufferView: @retroactive MutableDataProtocol {} From 035141dddb9e35ec7ef1f45cc64ec43e53161f3c Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Wed, 8 Nov 2023 10:38:49 +0000 Subject: [PATCH 44/64] Fix reordering/reentrancy bug in `NIOAsyncWriter` + `NIOAsyncChannel` (#2587) * Fix reordering/reentrancy bug in `NIOAsyncWriter` + `NIOAsyncChannel` # Motivation While testing the latest async interfaces we found a potential reordering/reentrancy bug in the `NIOAsyncWriter`. This was caused due to our latest performance changes where we fast-pathed calls in `didYield` to not hop. The problem was in the following flow: 1. Task 1: Calls `outbound.write()` -> which led an `EventLoop` enqueue with `didYield` 2. Task 1: Calls `outbound.write()` -> which led an `EventLoop` enqueue with `didYield` 3. EventLoop: While processing the write from 1. the channel became **not** writable 4. Task 1: Calls `outbound.write()` -> which lead to buffering the write in the writer's state machine since we are **not** writable 5. EventLoop: While still processing the write from 1. the channel became writable again -> We informed the `NIOAsyncWriter` about this which unbuffered the write in 4. that was stored in the state machine and call `didYield`. Since, we are on the EventLoop already we processed the write right away The above flow show-cases a flow where we reordered the write in 2. and 4. # Modification This PR fixes the above issue while upholding a few constraints: 1. Produce as few context switches as possible 2. Minimize allocations I tried different approaches but in the end decided to do the following: 1. Make sure to never call `didYield/didTerminate` from calls on the `NIOAsyncWriter.Sink` 2. Don't coalesce the elements of different writes in the `NIOAsyncWriter` but rather use the suspended tasks to retry a write after they were suspended. I choose to do this since I wanted to avoid any allocation (remember writers are `some Sequence`) and because we assume that continuous contention in a multi producer pattern is low. 3. Make sure that `Sink.finish()` is terminal and does not lead to a `didTerminate` event. This is in line with 1. One important thing to call out, our `writer.finish()` method is not `async` we have to buffer the finish event and deliver it with the yield that got buffered before we transitioned to `writerFinished`. # Result No more reordering/reentrancy problems in `NIOAsyncWriter` or `NIOAsyncChannel`. * Code review --- .../AsyncSequences/NIOAsyncWriter.swift | 661 ++++++++---------- .../AsyncSequences/NIOAsyncWriterTests.swift | 79 +-- 2 files changed, 325 insertions(+), 415 deletions(-) diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift index a2e6fdae67..5e113fae17 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift @@ -31,26 +31,20 @@ public protocol NIOAsyncWriterSinkDelegate: Sendable { /// /// If the ``NIOAsyncWriter`` was writable when the sequence was yielded, the sequence will be forwarded /// right away to the delegate. If the ``NIOAsyncWriter`` was _NOT_ writable then the sequence will be buffered - /// until the ``NIOAsyncWriter`` becomes writable again. All buffered writes, while the ``NIOAsyncWriter`` is not writable, - /// will be coalesced into a single sequence. + /// until the ``NIOAsyncWriter`` becomes writable again. /// /// The delegate might reentrantly call ``NIOAsyncWriter/Sink/setWritability(to:)`` while still processing writes. - /// This might trigger more calls to one of the `didYield` methods and it is up to the delegate to make sure that this reentrancy is - /// correctly guarded against. func didYield(contentsOf sequence: Deque) /// This method is called once a single element was yielded to the ``NIOAsyncWriter``. /// /// If the ``NIOAsyncWriter`` was writable when the sequence was yielded, the sequence will be forwarded /// right away to the delegate. If the ``NIOAsyncWriter`` was _NOT_ writable then the sequence will be buffered - /// until the ``NIOAsyncWriter`` becomes writable again. All buffered writes, while the ``NIOAsyncWriter`` is not writable, - /// will be coalesced into a single sequence. + /// until the ``NIOAsyncWriter`` becomes writable again. /// /// - Note: This a fast path that you can optionally implement. By default this will just call ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)``. /// /// The delegate might reentrantly call ``NIOAsyncWriter/Sink/setWritability(to:)`` while still processing writes. - /// This might trigger more calls to one of the `didYield` methods and it is up to the delegate to make sure that this reentrancy is - /// correctly guarded against. func didYield(_ element: Element) /// This method is called once the ``NIOAsyncWriter`` is terminated. @@ -59,13 +53,11 @@ public protocol NIOAsyncWriterSinkDelegate: Sendable { /// - The ``NIOAsyncWriter`` is deinited and all yielded elements have been delivered to the delegate. /// - ``NIOAsyncWriter/finish()`` is called and all yielded elements have been delivered to the delegate. /// - ``NIOAsyncWriter/finish(error:)`` is called and all yielded elements have been delivered to the delegate. - /// - ``NIOAsyncWriter/Sink/finish()`` or ``NIOAsyncWriter/Sink/finish(error:)`` is called. /// - /// - Note: This is guaranteed to be called _exactly_ once. + /// - Note: This is guaranteed to be called _at most_ once. /// /// - Parameter error: The error that terminated the ``NIOAsyncWriter``. If the writer was terminated without an - /// error this value is `nil`. This can be either the error passed to ``NIOAsyncWriter/finish(error:)`` or - /// to ``NIOAsyncWriter/Sink/finish(error:)``. + /// error this value is `nil`. This can be either the error passed to ``NIOAsyncWriter/finish(error:)``. func didTerminate(error: Error?) } @@ -231,15 +223,10 @@ public struct NIOAsyncWriter< /// /// If the ``NIOAsyncWriter`` is writable the sequence will get forwarded to the ``NIOAsyncWriterSinkDelegate`` immediately. /// Otherwise, the sequence will be buffered and the call to ``NIOAsyncWriter/yield(contentsOf:)`` will get suspended until the ``NIOAsyncWriter`` - /// becomes writable again. If the calling `Task` gets cancelled at any point the call to ``NIOAsyncWriter/yield(contentsOf:)`` - /// will be resumed. + /// becomes writable again. /// - /// If the ``NIOAsyncWriter/finish()`` or ``NIOAsyncWriter/finish(error:)`` method is called while a call to - /// ``NIOAsyncWriter/yield(contentsOf:)`` is suspended then the call will be resumed and the yielded sequence will be kept buffered. - /// - /// If the ``NIOAsyncWriter/Sink/finish()`` or ``NIOAsyncWriter/Sink/finish(error:)`` method is called while - /// a call to ``NIOAsyncWriter/yield(contentsOf:)`` is suspended then the call will be resumed with an error and the - /// yielded sequence is dropped. + /// If the calling `Task` gets cancelled at any point the call to ``NIOAsyncWriter/yield(contentsOf:)`` + /// will be resumed. Consequently, the provided elements will not be yielded. /// /// This can be called more than once and from multiple `Task`s at the same time. /// @@ -253,22 +240,17 @@ public struct NIOAsyncWriter< /// /// If the ``NIOAsyncWriter`` is writable the element will get forwarded to the ``NIOAsyncWriterSinkDelegate`` immediately. /// Otherwise, the element will be buffered and the call to ``NIOAsyncWriter/yield(_:)`` will get suspended until the ``NIOAsyncWriter`` - /// becomes writable again. If the calling `Task` gets cancelled at any point the call to ``NIOAsyncWriter/yield(_:)`` - /// will be resumed. - /// - /// If the ``NIOAsyncWriter/finish()`` or ``NIOAsyncWriter/finish(error:)`` method is called while a call to - /// ``NIOAsyncWriter/yield(_:)`` is suspended then the call will be resumed and the yielded sequence will be kept buffered. + /// becomes writable again. /// - /// If the ``NIOAsyncWriter/Sink/finish()`` or ``NIOAsyncWriter/Sink/finish(error:)`` method is called while - /// a call to ``NIOAsyncWriter/yield(_:)`` is suspended then the call will be resumed with an error and the - /// yielded sequence is dropped. + /// If the calling `Task` gets cancelled at any point the call to ``NIOAsyncWriter/yield(_:)`` + /// will be resumed. Consequently, the provided element will not be yielded. /// /// This can be called more than once and from multiple `Task`s at the same time. /// /// - Parameter element: The element to yield. @inlinable public func yield(_ element: Element) async throws { - try await self._storage.yield(element) + try await self._storage.yield(contentsOf: CollectionOfOne(element)) } /// Finishes the writer. @@ -277,7 +259,7 @@ public struct NIOAsyncWriter< /// or ``NIOAsyncWriter/yield(_:)`` will be resumed. Any subsequent calls to ``NIOAsyncWriter/yield(contentsOf:)`` /// or ``NIOAsyncWriter/yield(_:)`` will throw. /// - /// Any element that have been yielded elements before the writer has been finished which have not been delivered yet are continued + /// Any element that have been yielded before the writer has been finished which have not been delivered yet are continued /// to be buffered and will be delivered once the writer becomes writable again. /// /// - Note: Calling this function more than once has no effect. @@ -292,7 +274,7 @@ public struct NIOAsyncWriter< /// or ``NIOAsyncWriter/yield(_:)`` will be resumed. Any subsequent calls to ``NIOAsyncWriter/yield(contentsOf:)`` /// or ``NIOAsyncWriter/yield(_:)`` will throw. /// - /// Any element that have been yielded elements before the writer has been finished which have not been delivered yet are continued + /// Any element that have been yielded before the writer has been finished which have not been delivered yet are continued /// to be buffered and will be delivered once the writer becomes writable again. /// /// - Note: Calling this function more than once has no effect. @@ -458,22 +440,8 @@ extension NIOAsyncWriter { } switch action { - case .callDidYieldAndResumeContinuations(let delegate, let elements, let suspendedYields): - delegate.didYield(contentsOf: elements) - suspendedYields.forEach { $0.continuation.resume() } - self.unbufferQueuedEvents() - - case .callDidYieldElementAndResumeContinuations(let delegate, let element, let suspendedYields): - delegate.didYield(element) - suspendedYields.forEach { $0.continuation.resume() } - self.unbufferQueuedEvents() - case .resumeContinuations(let suspendedYields): - suspendedYields.forEach { $0.continuation.resume() } - - case .callDidYieldAndDidTerminate(let delegate, let elements, let error): - delegate.didYield(contentsOf: elements) - delegate.didTerminate(error: error) + suspendedYields.forEach { $0.continuation.resume(returning: .retry) } case .none: return @@ -483,85 +451,42 @@ extension NIOAsyncWriter { @inlinable /* fileprivate */ internal func yield(contentsOf sequence: S) async throws where S.Element == Element { let yieldID = self._yieldIDGenerator.generateUniqueYieldID() - - try await withTaskCancellationHandler { - // We are manually locking here to hold the lock across the withCheckedContinuation call - self._lock.lock() - - let action = self._stateMachine.yield(contentsOf: sequence, yieldID: yieldID) - - switch action { - case .callDidYield(let delegate): - // We are allocating a new Deque for every write here - self._lock.unlock() - delegate.didYield(contentsOf: Deque(sequence)) - self.unbufferQueuedEvents() - - case .returnNormally: - self._lock.unlock() + while true { + switch try await self._yield(contentsOf: sequence, yieldID: yieldID) { + case .retry: + continue + case .yielded: return - - case .throwError(let error): - self._lock.unlock() - throw error - - case .suspendTask: - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - self._stateMachine.yield( - contentsOf: sequence, - continuation: continuation, - yieldID: yieldID - ) - - self._lock.unlock() - } - } - } onCancel: { - // We must not resume the continuation while holding the lock - // because it can deadlock in combination with the underlying ulock - // in cases where we race with a cancellation handler - let action = self._lock.withLock { - self._stateMachine.cancel(yieldID: yieldID) - } - - switch action { - case .resumeContinuation(let continuation): - continuation.resume() - - case .none: - break } } } @inlinable - /* fileprivate */ internal func yield(_ element: Element) async throws { - let yieldID = self._yieldIDGenerator.generateUniqueYieldID() + /* fileprivate */ internal func _yield(contentsOf sequence: S, yieldID: StateMachine.YieldID?) async throws -> StateMachine.YieldResult where S.Element == Element { + let yieldID = yieldID ?? self._yieldIDGenerator.generateUniqueYieldID() - try await withTaskCancellationHandler { + return try await withTaskCancellationHandler { // We are manually locking here to hold the lock across the withCheckedContinuation call self._lock.lock() - let action = self._stateMachine.yield(contentsOf: CollectionOfOne(element), yieldID: yieldID) + let action = self._stateMachine.yield(contentsOf: sequence, yieldID: yieldID) switch action { case .callDidYield(let delegate): + // We are allocating a new Deque for every write here self._lock.unlock() - delegate.didYield(element) + delegate.didYield(contentsOf: Deque(sequence)) self.unbufferQueuedEvents() - - case .returnNormally: - self._lock.unlock() - return + return .yielded case .throwError(let error): self._lock.unlock() throw error case .suspendTask: - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in self._stateMachine.yield( - contentsOf: CollectionOfOne(element), + contentsOf: sequence, continuation: continuation, yieldID: yieldID ) @@ -578,8 +503,8 @@ extension NIOAsyncWriter { } switch action { - case .resumeContinuation(let continuation): - continuation.resume() + case .resumeContinuationWithCancellationError(let continuation): + continuation.resume(throwing: CancellationError()) case .none: break @@ -601,7 +526,7 @@ extension NIOAsyncWriter { delegate.didTerminate(error: error) case .resumeContinuations(let suspendedYields): - suspendedYields.forEach { $0.continuation.resume() } + suspendedYields.forEach { $0.continuation.resume(returning: .retry) } case .none: break @@ -618,16 +543,9 @@ extension NIOAsyncWriter { } switch action { - case .callDidTerminate(let delegate, let error): - delegate.didTerminate(error: error) - case .resumeContinuationsWithError(let suspendedYields, let error): suspendedYields.forEach { $0.continuation.resume(throwing: error) } - case .resumeContinuationsWithErrorAndCallDidTerminate(let delegate, let suspendedYields, let error): - delegate.didTerminate(error: error) - suspendedYields.forEach { $0.continuation.resume(throwing: error) } - case .none: break } @@ -641,11 +559,9 @@ extension NIOAsyncWriter { case .callDidTerminate(let delegate, let error): delegate.didTerminate(error: error) - case .callDidYield(let delegate, let elements): - delegate.didYield(contentsOf: elements) - - case .callDidYieldElement(let delegate, let element): - delegate.didYield(element) + case .resumeContinuations(let suspendedYields): + suspendedYields.forEach { $0.continuation.resume(returning: .retry) } + return } } } @@ -667,18 +583,26 @@ extension NIOAsyncWriter { /// The yield's produced sequence of elements. /// The yield's continuation. @usableFromInline - var continuation: CheckedContinuation + var continuation: CheckedContinuation @inlinable - init(yieldID: YieldID, continuation: CheckedContinuation) { + init(yieldID: YieldID, continuation: CheckedContinuation) { self.yieldID = yieldID self.continuation = continuation } } + /// The internal result of a yield. + @usableFromInline + /* private */ internal enum YieldResult { + /// Indicates that the elements got yielded to the sink. + case yielded + /// Indicates that the yield should be retried. + case retry + } /// The current state of our ``NIOAsyncWriter``. @usableFromInline - /* private */ internal enum State { + /* private */ internal enum State: CustomStringConvertible { /// The initial state before either a call to ``NIOAsyncWriter/yield(contentsOf:)`` or /// ``NIOAsyncWriter/finish(completion:)`` happened. case initial( @@ -692,15 +616,19 @@ extension NIOAsyncWriter { inDelegateOutcall: Bool, cancelledYields: [YieldID], suspendedYields: _TinyArray, - elements: Deque, delegate: Delegate ) - /// The state once the writer finished and there are still elements that need to be delivered. This can happen if: + /// The state once the writer finished and there are still tasks that need to write. This can happen if: /// 1. The ``NIOAsyncWriter`` was deinited /// 2. ``NIOAsyncWriter/finish(completion:)`` was called. case writerFinished( - elements: Deque, + isWritable: Bool, + inDelegateOutcall: Bool, + suspendedYields: _TinyArray, + cancelledYields: [YieldID], + // These are the yields that have been enqueued before the writer got finished. + bufferedYieldIDs: _TinyArray, delegate: Delegate, error: Error? ) @@ -711,6 +639,22 @@ extension NIOAsyncWriter { /// Internal state to avoid CoW. case modifying + + @usableFromInline + var description: String { + switch self { + case .initial(let isWritable, _): + return "initial(isWritable: \(isWritable))" + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, _): + return "streaming(isWritable: \(isWritable), inDelegateOutcall: \(inDelegateOutcall), cancelledYields: \(cancelledYields.count), suspendedYields: \(suspendedYields.count))" + case .writerFinished(let isWritable, let inDelegateOutcall, let suspendedYields, let cancelledYields, let bufferedYieldIDs, _, _): + return "writerFinished(isWritable: \(isWritable), inDelegateOutcall: \(inDelegateOutcall), suspendedYields: \(suspendedYields.count), cancelledYields: \(cancelledYields.count), bufferedYieldIDs: \(bufferedYieldIDs.count)" + case .finished: + return "finished" + case .modifying: + return "modifying" + } + } } /// The state machine's current state. @@ -744,16 +688,15 @@ extension NIOAsyncWriter { return .callDidTerminate(delegate) - case .streaming(_, _, _, let suspendedYields, let elements, let delegate): + case .streaming(_, _, _, let suspendedYields, let delegate): // The writer got deinited after we started streaming. // This is normal and we need to transition to finished // and call the delegate. However, we should not have // any suspended yields because they MUST strongly retain // the writer. precondition(suspendedYields.isEmpty, "We have outstanding suspended yields") - precondition(elements.isEmpty, "We have buffered elements") - // We have no elements left and can transition to finished directly + // We can transition to finished directly self._state = .finished(sinkError: nil) return .callDidTerminate(delegate) @@ -770,23 +713,12 @@ extension NIOAsyncWriter { /// Actions returned by `setWritability()`. @usableFromInline enum SetWritabilityAction { - /// Indicates that ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)`` should be called - /// and all continuations should be resumed. - case callDidYieldAndResumeContinuations(Delegate, Deque, _TinyArray) - /// Indicates that ``NIOAsyncWriterSinkDelegate/didYield(element:)`` should be called - /// and all continuations should be resumed. - case callDidYieldElementAndResumeContinuations(Delegate, Element, _TinyArray) - /// Indicates that all continuations should be resumed. + /// Indicates that all writer continuations should be resumed. case resumeContinuations(_TinyArray) - /// Indicates that ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)`` and - /// ``NIOAsyncWriterSinkDelegate/didTerminate(error:)``should be called. - case callDidYieldAndDidTerminate(Delegate, Deque, Error?) - /// Indicates that nothing should be done. - case none } @inlinable - /* fileprivate */ internal mutating func setWritability(to newWritability: Bool) -> SetWritabilityAction { + /* fileprivate */ internal mutating func setWritability(to newWritability: Bool) -> SetWritabilityAction? { switch self._state { case .initial(_, let delegate): // We just need to store the new writability state @@ -794,75 +726,31 @@ extension NIOAsyncWriter { return .none - case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, var elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, let delegate): if isWritable == newWritability { // The writability didn't change so we can just early exit here return .none } if newWritability && !inDelegateOutcall { - // We became writable again. This means we have to resume all the continuations - // and yield the values. - - if elements.count == 0 { - // We just have to resume the continuations - self._state = .streaming( - isWritable: newWritability, - inDelegateOutcall: inDelegateOutcall, - cancelledYields: cancelledYields, - suspendedYields: .init(), - elements: elements, - delegate: delegate - ) - - return .resumeContinuations(suspendedYields) - } else if elements.count == 1 { - // We have exactly one element in the buffer. Let's - // pop it and re-use the buffer right away - self._state = .modifying - - // This force-unwrap is safe since we just checked the count for 1. - let element = elements.popFirst()! - - self._state = .streaming( - isWritable: newWritability, - inDelegateOutcall: true, // We are now making a call to the delegate - cancelledYields: cancelledYields, - suspendedYields: .init(), - elements: elements, - delegate: delegate - ) - - return .callDidYieldElementAndResumeContinuations( - delegate, - element, - suspendedYields - ) - } else { - self._state = .streaming( - isWritable: newWritability, - inDelegateOutcall: true, // We are now making a call to the delegate - cancelledYields: cancelledYields, - suspendedYields: .init(), - elements: .init(), - delegate: delegate - ) + // We became writable again. This means we have to resume all the continuations. + self._state = .streaming( + isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: .init(), + delegate: delegate + ) - // We are taking the whole array of suspended yields and the deque of elements - // and allocate a new empty one. - // As a performance optimization we could always keep multiple arrays/deques and - // switch between them but I don't think this is the performance critical part. - return .callDidYieldAndResumeContinuations(delegate, elements, suspendedYields) - } + return .resumeContinuations(suspendedYields) } else if newWritability && inDelegateOutcall { // We became writable but are in a delegate outcall. - // We just have to store the new writability here + // We just have to store the new writability here. self._state = .streaming( isWritable: newWritability, inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) return .none @@ -873,21 +761,56 @@ extension NIOAsyncWriter { inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) return .none } - case .writerFinished(let elements, let delegate, let error): + case .writerFinished(_, let inDelegateOutcall, let suspendedYields, let cancelledYields, let bufferedYieldIDs, let delegate, let error): if !newWritability { // We are not writable so we can't deliver the outstanding elements return .none } - self._state = .finished(sinkError: nil) + if newWritability && !inDelegateOutcall { + // We became writable again. This means we have to resume all the continuations. + self._state = .writerFinished( + isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: .init(), + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) - return .callDidYieldAndDidTerminate(delegate, elements, error) + return .resumeContinuations(suspendedYields) + } else if newWritability && inDelegateOutcall { + // We became writable but are in a delegate outcall. + // We just have to store the new writability here. + self._state = .writerFinished( + isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + return .none + } else { + // We became unwritable nothing really to do here + self._state = .writerFinished( + isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + return .none + } case .finished: // We are already finished nothing to do here @@ -905,8 +828,6 @@ extension NIOAsyncWriter { case callDidYield(Delegate) /// Indicates that the calling `Task` should get suspended. case suspendTask - /// Indicates that the method should just return. - case returnNormally /// Indicates the given error should be thrown. case throwError(Error) @@ -934,65 +855,28 @@ extension NIOAsyncWriter { inDelegateOutcall: isWritable, // If we are writable we are going to make an outcall cancelledYields: [], suspendedYields: .init(), - elements: .init(), delegate: delegate ) return .init(isWritable: isWritable, delegate: delegate) - case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, let suspendedYields, var elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, let suspendedYields, let delegate): self._state = .modifying if let index = cancelledYields.firstIndex(of: yieldID) { // We already marked the yield as cancelled. We have to remove it and - // throw an error. + // throw a CancellationError. cancelledYields.remove(at: index) - switch (isWritable, inDelegateOutcall) { - case (true, false): - // We are writable so we can yield the elements right away and then - // return normally. - self._state = .streaming( - isWritable: isWritable, - inDelegateOutcall: true, // We are now making a call to the delegate - cancelledYields: cancelledYields, - suspendedYields: suspendedYields, - elements: elements, - delegate: delegate - ) - return .callDidYield(delegate) - - case (true, true): - // We are writable but already calling out to the delegate - // so we have to buffer the elements. - elements.append(contentsOf: sequence) - - self._state = .streaming( - isWritable: isWritable, - inDelegateOutcall: inDelegateOutcall, - cancelledYields: cancelledYields, - suspendedYields: suspendedYields, - elements: elements, - delegate: delegate - ) - return .returnNormally - case (false, _): - // We are not writable so we are just going to enqueue the writes - // and return normally. We are not suspending the yield since the Task - // is marked as cancelled. - elements.append(contentsOf: sequence) - - self._state = .streaming( - isWritable: isWritable, - inDelegateOutcall: inDelegateOutcall, - cancelledYields: cancelledYields, - suspendedYields: suspendedYields, - elements: elements, - delegate: delegate - ) + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + delegate: delegate + ) - return .returnNormally - } + return .throwError(CancellationError()) } else { // Yield hasn't been marked as cancelled. @@ -1003,40 +887,77 @@ extension NIOAsyncWriter { inDelegateOutcall: true, // We are now making a call to the delegate cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) return .callDidYield(delegate) - case (true, true): - elements.append(contentsOf: sequence) + case (true, true), (false, _): self._state = .streaming( isWritable: isWritable, inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) - return .returnNormally - case (false, _): - // We are not writable - self._state = .streaming( + return .suspendTask + } + } + + case .writerFinished(let isWritable, let inDelegateOutcall, let suspendedYields, var cancelledYields, let bufferedYieldIDs, let delegate, let error): + if bufferedYieldIDs.contains(yieldID) { + // This yield was buffered before we became finished so we still have to deliver it + self._state = .modifying + + if let index = cancelledYields.firstIndex(of: yieldID) { + // We already marked the yield as cancelled. We have to remove it and + // throw a CancellationError. + cancelledYields.remove(at: index) + + self._state = .writerFinished( isWritable: isWritable, inDelegateOutcall: inDelegateOutcall, - cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, - delegate: delegate + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error ) - return .suspendTask + + return .throwError(CancellationError()) + } else { + // Yield hasn't been marked as cancelled. + + switch (isWritable, inDelegateOutcall) { + case (true, false): + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: true, // We are now making a call to the delegate + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + + return .callDidYield(delegate) + case (true, true), (false, _): + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + return .suspendTask + } } + } else { + // We are already finished and still tried to write something + return .throwError(NIOAsyncWriterError.alreadyFinished()) } - case .writerFinished: - // We are already finished and still tried to write something - return .throwError(NIOAsyncWriterError.alreadyFinished()) - case .finished(let sinkError): // We are already finished and still tried to write something return .throwError(sinkError ?? NIOAsyncWriterError.alreadyFinished()) @@ -1050,11 +971,11 @@ extension NIOAsyncWriter { @inlinable /* fileprivate */ internal mutating func yield( contentsOf sequence: S, - continuation: CheckedContinuation, + continuation: CheckedContinuation, yieldID: YieldID ) where S.Element == Element { switch self._state { - case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, var suspendedYields, var elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, var suspendedYields, let delegate): // We have a suspended yield at this point that hasn't been cancelled yet. // We need to store the yield now. @@ -1065,14 +986,12 @@ extension NIOAsyncWriter { continuation: continuation ) suspendedYields.append(suspendedYield) - elements.append(contentsOf: sequence) self._state = .streaming( isWritable: isWritable, inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) @@ -1087,7 +1006,8 @@ extension NIOAsyncWriter { /// Actions returned by `cancel()`. @usableFromInline enum CancelAction { - case resumeContinuation(CheckedContinuation) + /// Indicates that the continuation should be resumed with a `CancellationError`. + case resumeContinuationWithCancellationError(CheckedContinuation) /// Indicates that nothing should be done. case none } @@ -1106,13 +1026,12 @@ extension NIOAsyncWriter { inDelegateOutcall: false, cancelledYields: [yieldID], suspendedYields: .init(), - elements: .init(), delegate: delegate ) return .none - case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, var suspendedYields, let elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, var suspendedYields, let delegate): if let index = suspendedYields.firstIndex(where: { $0.yieldID == yieldID }) { self._state = .modifying // We have a suspended yield for the id. We need to resume the continuation now. @@ -1129,11 +1048,10 @@ extension NIOAsyncWriter { inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) - return .resumeContinuation(suspendedYield.continuation) + return .resumeContinuationWithCancellationError(suspendedYield.continuation) } else { self._state = .modifying @@ -1147,14 +1065,60 @@ extension NIOAsyncWriter { inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) return .none } - case .writerFinished, .finished: + case .writerFinished(let isWritable, let inDelegateOutcall, var suspendedYields, var cancelledYields, let bufferedYieldIDs, let delegate, let error): + guard bufferedYieldIDs.contains(yieldID) else { + return .none + } + if let index = suspendedYields.firstIndex(where: { $0.yieldID == yieldID }) { + self._state = .modifying + // We have a suspended yield for the id. We need to resume the continuation now. + + // Removing can be quite expensive if it produces a gap in the array. + // Since we are not expecting a lot of elements in this array it should be fine + // to just remove. If this turns out to be a performance pitfall, we can + // swap the elements before removing. So that we always remove the last element. + let suspendedYield = suspendedYields.remove(at: index) + + // We are keeping the elements that the yield produced. + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + + return .resumeContinuationWithCancellationError(suspendedYield.continuation) + + } else { + self._state = .modifying + // There is no suspended yield. This can mean that we either already yielded + // or that the call to `yield` is coming afterwards. We need to store + // the ID here. However, if the yield already happened we will never remove the + // stored ID. The only way to avoid doing this would be storing every ID + cancelledYields.append(yieldID) + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + + return .none + } + + case .finished: // We are already finished and there is nothing to do return .none @@ -1183,14 +1147,18 @@ extension NIOAsyncWriter { return .callDidTerminate(delegate) - case .streaming(_, let inDelegateOutcall, _, let suspendedYields, let elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, let delegate): // We are currently streaming and the writer got finished. - if elements.isEmpty { + if suspendedYields.isEmpty { if inDelegateOutcall { // We are in an outcall already and have to buffer // the didTerminate call. self._state = .writerFinished( - elements: elements, + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: .init(), + cancelledYields: cancelledYields, + bufferedYieldIDs: .init(), delegate: delegate, error: error ) @@ -1202,16 +1170,18 @@ extension NIOAsyncWriter { return .callDidTerminate(delegate) } } else { - // There are still elements left which we need to deliver once we become writable again + // There are still suspended writer tasks which we need to deliver once we become writable again self._state = .writerFinished( - elements: elements, + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: _TinyArray(suspendedYields.map { $0.yieldID }), delegate: delegate, error: error ) - // We are not resuming the continuations with the error here since their elements - // are still queued up. If they try to yield again they will run into an alreadyFinished error - return .resumeContinuations(suspendedYields) + return .none } case .writerFinished, .finished: @@ -1226,11 +1196,6 @@ extension NIOAsyncWriter { /// Actions returned by `sinkFinish()`. @usableFromInline enum SinkFinishAction { - /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called. - case callDidTerminate(Delegate, Error?) - /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called and all - /// continuations should be resumed with the given error. - case resumeContinuationsWithErrorAndCallDidTerminate(Delegate, _TinyArray, Error) /// Indicates that all continuations should be resumed with the given error. case resumeContinuationsWithError(_TinyArray, Error) /// Indicates that nothing should be done. @@ -1240,41 +1205,29 @@ extension NIOAsyncWriter { @inlinable /* fileprivate */ internal mutating func sinkFinish(error: Error?) -> SinkFinishAction { switch self._state { - case .initial(_, let delegate): + case .initial(_, _): // Nothing was ever written so we can transition to finished self._state = .finished(sinkError: error) - return .callDidTerminate(delegate, error) - - case .streaming(_, let inDelegateOutcall, _, let suspendedYields, _, let delegate): - if inDelegateOutcall { - // We are currently streaming and the sink got finished. - // However we are in an outcall so we have to delay the call to didTerminate - // but we can resume the continuations already. - self._state = .writerFinished(elements: .init(), delegate: delegate, error: error) - - return .resumeContinuationsWithError( - suspendedYields, - error ?? NIOAsyncWriterError.alreadyFinished() - ) - } else { - // We are currently streaming and the writer got finished. - // We can transition to finished and need to resume all continuations. - self._state = .finished(sinkError: error) - return .resumeContinuationsWithErrorAndCallDidTerminate( - delegate, - suspendedYields, - error ?? NIOAsyncWriterError.alreadyFinished() - ) - } + return .none - case .writerFinished(_, let delegate, let error): - // The writer already finished and we were waiting to become writable again - // The Sink finished before we became writable so we can drop the elements and - // transition to finished + case .streaming(_, _, _, let suspendedYields, _): + // We are currently streaming and the sink got finished. + // We can transition to finished and need to resume all continuations. self._state = .finished(sinkError: error) + return .resumeContinuationsWithError( + suspendedYields, + error ?? NIOAsyncWriterError.alreadyFinished() + ) - return .callDidTerminate(delegate, error) + case .writerFinished(_, _, let suspendedYields, _, _, _, _): + // The writer already got finished and the sink got finished too now. + // We can transition to finished and need to resume all continuations. + self._state = .finished(sinkError: error) + return .resumeContinuationsWithError( + suspendedYields, + error ?? NIOAsyncWriterError.alreadyFinished() + ) case .finished: // We are already finished and there is nothing to do @@ -1288,8 +1241,7 @@ extension NIOAsyncWriter { /// Actions returned by `sinkFinish()`. @usableFromInline enum UnbufferQueuedEventsAction { - case callDidYield(Delegate, Deque) - case callDidYieldElement(Delegate, Element) + case resumeContinuations(_TinyArray) case callDidTerminate(Delegate, Error?) } @@ -1299,85 +1251,54 @@ extension NIOAsyncWriter { case .initial: preconditionFailure("Invalid state") - case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, var elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, let delegate): precondition(inDelegateOutcall, "We must be in a delegate outcall when we unbuffer events") + // We have to resume the other suspended yields now. - if elements.count == 0 { - // Nothing to do. We haven't gotten any writes. + if suspendedYields.isEmpty { + // There are no other writer suspended writer tasks so we can just return self._state = .streaming( isWritable: isWritable, - inDelegateOutcall: false, // We can now indicate that we are done with the outcall + inDelegateOutcall: false, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) return .none - } else if elements.count > 1 { - // We have to yield all of the elements now. - self._state = .streaming( - isWritable: isWritable, - inDelegateOutcall: inDelegateOutcall, - cancelledYields: cancelledYields, - suspendedYields: suspendedYields, - elements: .init(), - delegate: delegate - ) - - return .callDidYield(delegate, elements) - } else { - // There is only a single element and we can optimize this to not - // yield the whole Deque - self._state = .modifying - - // This force-unwrap is safe since we just checked the count of the Deque - // and it must be 1 here. - let element = elements.popFirst()! - + // We have to resume the other suspended yields now. self._state = .streaming( isWritable: isWritable, - inDelegateOutcall: inDelegateOutcall, + inDelegateOutcall: false, cancelledYields: cancelledYields, - suspendedYields: suspendedYields, - elements: elements, + suspendedYields: .init(), delegate: delegate ) - - return .callDidYieldElement(delegate, element) + return .resumeContinuations(suspendedYields) } - case .writerFinished(var elements, let delegate, let error): - if elements.isEmpty { - // We have returned the last buffered elements and have to - // call didTerminate now. + case .writerFinished(let isWritable, let inDelegateOutcall, let suspendedYields, let cancelledYields, let bufferedYieldIDs, let delegate, let error): + precondition(inDelegateOutcall, "We must be in a delegate outcall when we unbuffer events") + if suspendedYields.isEmpty { + // We were the last writer task and can now call didTerminate self._state = .finished(sinkError: nil) return .callDidTerminate(delegate, error) - } else if elements.count > 1 { - // We have to yield all of the elements now. - self._state = .writerFinished( - elements: .init(), - delegate: delegate, - error: error - ) - - return .callDidYield(delegate, elements) } else { - // There is only a single element and we can optimize this to not - // yield the whole Deque + // There are still other writer tasks that need to be resumed self._state = .modifying - // This force-unwrap is safe since we just checked the count of the Deque - // and it must be 1 here. - let element = elements.popFirst()! self._state = .writerFinished( - elements: .init(), + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: .init(), + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, delegate: delegate, error: error ) - return .callDidYieldElement(delegate, element) + return .resumeContinuations(suspendedYields) } case .finished: diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift index d4ab9764d9..4c4517b311 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift @@ -126,36 +126,6 @@ final class NIOAsyncWriterTests: XCTestCase { XCTAssertEqual(elements, 60) } - func testWriterCoalescesWrites() async throws { - var writes = [Deque]() - self.delegate.didYieldHandler = { - writes.append($0) - } - self.sink.setWritability(to: false) - - let task1 = Task { [writer] in - try await writer!.yield("message1") - } - task1.cancel() - try await task1.value - - let task2 = Task { [writer] in - try await writer!.yield("message2") - } - task2.cancel() - try await task2.value - - let task3 = Task { [writer] in - try await writer!.yield("message3") - } - task3.cancel() - try await task3.value - - self.sink.setWritability(to: true) - - XCTAssertEqual(writes, [Deque(["message1", "message2", "message3"])]) - } - // MARK: - WriterDeinitialized func testWriterDeinitialized_whenInitial() async throws { @@ -183,11 +153,11 @@ final class NIOAsyncWriterTests: XCTestCase { func testWriterDeinitialized_whenFinished() async throws { self.sink.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertEqual(self.delegate.didTerminateCallCount, 0) self.writer = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } // MARK: - ToggleWritability @@ -235,6 +205,9 @@ final class NIOAsyncWriterTests: XCTestCase { self.sink.setWritability(to: true) + // Sleep a bit so that the other Task can retry the yield + try await Task.sleep(nanoseconds: 1_000_000) + XCTAssertEqual(self.delegate.didYieldCallCount, 1) XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } @@ -273,6 +246,9 @@ final class NIOAsyncWriterTests: XCTestCase { self.sink.setWritability(to: true) + // Sleep a bit so that the other Task can retry the yield + try await Task.sleep(nanoseconds: 1_000_000) + XCTAssertEqual(self.delegate.didYieldCallCount, 1) XCTAssertEqual(self.delegate.didTerminateCallCount, 1) } @@ -282,7 +258,7 @@ final class NIOAsyncWriterTests: XCTestCase { self.sink.setWritability(to: false) - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } // MARK: - Yield @@ -348,8 +324,10 @@ final class NIOAsyncWriterTests: XCTestCase { task.cancel() - await XCTAssertNoThrow(try await task.value) - XCTAssertEqual(self.delegate.didYieldCallCount, 2) + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError) + } + XCTAssertEqual(self.delegate.didYieldCallCount, 1) } func testYield_whenWriterFinished() async throws { @@ -376,7 +354,7 @@ final class NIOAsyncWriterTests: XCTestCase { await XCTAssertThrowsError(try await self.writer.yield("message1")) { error in XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) } - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } func testYield_whenFinishedError() async throws { @@ -385,7 +363,7 @@ final class NIOAsyncWriterTests: XCTestCase { await XCTAssertThrowsError(try await self.writer.yield("message1")) { error in XCTAssertTrue(error is SomeError) } - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } // MARK: - Cancel @@ -401,8 +379,10 @@ final class NIOAsyncWriterTests: XCTestCase { task.cancel() - await XCTAssertNoThrow(try await task.value) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError) + } + XCTAssertEqual(self.delegate.didYieldCallCount, 0) XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } @@ -421,8 +401,10 @@ final class NIOAsyncWriterTests: XCTestCase { task.cancel() - await XCTAssertNoThrow(try await task.value) - XCTAssertEqual(self.delegate.didYieldCallCount, 2) + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError) + } + XCTAssertEqual(self.delegate.didYieldCallCount, 1) XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } @@ -442,12 +424,14 @@ final class NIOAsyncWriterTests: XCTestCase { task.cancel() - await XCTAssertNoThrow(try await task.value) + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError) + } XCTAssertEqual(self.delegate.didYieldCallCount, 1) XCTAssertEqual(self.delegate.didTerminateCallCount, 0) self.sink.setWritability(to: true) - XCTAssertEqual(self.delegate.didYieldCallCount, 2) + XCTAssertEqual(self.delegate.didYieldCallCount, 1) } func testCancel_whenFinished() async throws { @@ -510,7 +494,12 @@ final class NIOAsyncWriterTests: XCTestCase { self.writer.finish() XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + + // We have to become writable again to unbuffer the yield + self.sink.setWritability(to: true) + await XCTAssertNoThrow(try await task.value) + XCTAssertEqual(self.delegate.didTerminateCallCount, 1) } func testWriterFinish_whenFinished() { @@ -527,7 +516,7 @@ final class NIOAsyncWriterTests: XCTestCase { func testSinkFinish_whenInitial() async throws { self.sink = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } func testSinkFinish_whenStreaming() async throws { @@ -539,7 +528,7 @@ final class NIOAsyncWriterTests: XCTestCase { self.sink = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } func testSinkFinish_whenFinished() async throws { From fa977dc16b34e658c4e9698e1dbdc7a3650e9b46 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Thu, 9 Nov 2023 10:40:55 -0800 Subject: [PATCH 45/64] Fixing an issue with CNIOSHA1 missing an #include for the BYTE_ORDER define. (#2584) When building swift-nio with a system with explicit modules (bazel build rules) I was getting incorrect sha1 results. It turns out the root cause is the `BYTE_ORDER` macro was not defined in my build context. Using `-Wundef` in clang I was seeing: ``` Sources/CNIOSHA1/c_nio_sha1.c:56:7: error: '__linux__' is not defined, evaluates to 0 [-Werror,-Wundef] ^ Sources/CNIOSHA1/c_nio_sha1.c:63:5: error: 'BYTE_ORDER' is not defined, evaluates to 0 [-Werror,-Wundef] ^ Sources/CNIOSHA1/c_nio_sha1.c:63:19: error: 'BIG_ENDIAN' is not defined, evaluates to 0 [-Werror,-Wundef] ^ Sources/CNIOSHA1/c_nio_sha1.c:113:5: error: 'BYTE_ORDER' is not defined, evaluates to 0 [-Werror,-Wundef] ^ Sources/CNIOSHA1/c_nio_sha1.c:113:19: error: 'LITTLE_ENDIAN' is not defined, evaluates to 0 [-Werror,-Wundef] ^ Sources/CNIOSHA1/c_nio_sha1.c:222:5: error: 'BYTE_ORDER' is not defined, evaluates to 0 [-Werror,-Wundef] ^ Sources/CNIOSHA1/c_nio_sha1.c:222:19: error: 'BIG_ENDIAN' is not defined, evaluates to 0 [-Werror,-Wundef] ^ Sources/CNIOSHA1/c_nio_sha1.c:267:5: error: 'BYTE_ORDER' is not defined, evaluates to 0 [-Werror,-Wundef] ^ Sources/CNIOSHA1/c_nio_sha1.c:267:19: error: 'BIG_ENDIAN' is not defined, evaluates to 0 [-Werror,-Wundef] ``` The soundness check was not actually signaling an error because it was implicitly comparing 0 to 0. This change includes an update to the `#include`'s to ensure `BIG_ENDIAN` is defined on macOS and updates the soundness check to have an explicit error if `BYTE_ORDER` is not defined. Co-authored-by: Cory Benfield --- Sources/CNIOSHA1/c_nio_sha1.c | 8 +++++--- Sources/CNIOSHA1/include/CNIOSHA1.h | 1 + Sources/CNIOSHA1/update_and_patch_sha1.sh | 4 +++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/Sources/CNIOSHA1/c_nio_sha1.c b/Sources/CNIOSHA1/c_nio_sha1.c index 965185e7b5..30b61c8793 100644 --- a/Sources/CNIOSHA1/c_nio_sha1.c +++ b/Sources/CNIOSHA1/c_nio_sha1.c @@ -4,6 +4,7 @@ - defined the __min_size macro inline - included sys/endian.h on Android - use welcoming language (soundness check) + - ensure BYTE_ORDER is defined */ /* $KAME: sha1.c,v 1.5 2000/11/08 06:13:08 itojun Exp $ */ /*- @@ -53,14 +54,15 @@ #endif #ifdef __ANDROID__ #include -#elif __linux__ +#elif defined(__linux__) || defined(__APPLE__) #include #endif - /* soundness check */ -#if BYTE_ORDER != BIG_ENDIAN +#if !defined(BYTE_ORDER) +#error "BYTE_ORDER not defined" +#elif BYTE_ORDER != BIG_ENDIAN # if BYTE_ORDER != LITTLE_ENDIAN # define unsupported 1 # endif diff --git a/Sources/CNIOSHA1/include/CNIOSHA1.h b/Sources/CNIOSHA1/include/CNIOSHA1.h index b5c057565b..b4d8c9524d 100644 --- a/Sources/CNIOSHA1/include/CNIOSHA1.h +++ b/Sources/CNIOSHA1/include/CNIOSHA1.h @@ -4,6 +4,7 @@ - defined the __min_size macro inline - included sys/endian.h on Android - use welcoming language (soundness check) + - ensure BYTE_ORDER is defined */ /* $FreeBSD$ */ /* $KAME: sha1.h,v 1.5 2000/03/27 04:36:23 sumikawa Exp $ */ diff --git a/Sources/CNIOSHA1/update_and_patch_sha1.sh b/Sources/CNIOSHA1/update_and_patch_sha1.sh index c88709ebca..8557928a34 100755 --- a/Sources/CNIOSHA1/update_and_patch_sha1.sh +++ b/Sources/CNIOSHA1/update_and_patch_sha1.sh @@ -33,6 +33,7 @@ for f in sha1.c sha1.h; do echo " - defined the __min_size macro inline" echo " - included sys/endian.h on Android" echo " - use welcoming language (soundness check)" + echo " - ensure BYTE_ORDER is defined" echo "*/" curl -Ls "https://raw.githubusercontent.com/freebsd/freebsd/master/sys/crypto/$f" ) > "$here/c_nio_$f" @@ -52,8 +53,9 @@ $sed -e $'/#define _CRYPTO_SHA1_H_/a #include \\\n#include ' $sed -e 's/u_int\([0-9]\+\)_t/uint\1_t/g' \ -e '/^#include/d' \ - -e $'/__FBSDID/c #include "include/CNIOSHA1.h"\\n#include \\n#if !defined(bzero)\\n#define bzero(b,l) memset((b), \'\\\\0\', (l))\\n#endif\\n#if !defined(bcopy)\\n#define bcopy(s,d,l) memmove((d), (s), (l))\\n#endif\\n#ifdef __ANDROID__\\n#include \\n#elif __linux__\\n#include \\n#endif' \ + -e $'/__FBSDID/c #include "include/CNIOSHA1.h"\\n#include \\n#if !defined(bzero)\\n#define bzero(b,l) memset((b), \'\\\\0\', (l))\\n#endif\\n#if !defined(bcopy)\\n#define bcopy(s,d,l) memmove((d), (s), (l))\\n#endif\\n#ifdef __ANDROID__\\n#include \\n#elif defined(__linux__) || defined(__APPLE__)\\n#include \\n#endif' \ -e 's/sanit[y]/soundness/g' \ + -e 's/#if BYTE_ORDER != BIG_ENDIAN/#if !defined(BYTE_ORDER)\\n#error "BYTE_ORDER not defined"\\n#elif BYTE_ORDER != BIG_ENDIAN/' \ -i "$here/c_nio_sha1.c" mv "$here/c_nio_sha1.h" "$here/include/CNIOSHA1.h" From 118de503e2893966cec0111e512641e3cf6bc8f8 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Mon, 13 Nov 2023 11:37:54 +0000 Subject: [PATCH 46/64] =?UTF-8?q?Add=20`withInboundOutboud`=20to=20`NIOAsy?= =?UTF-8?q?ncChannel`=20and=20deprecate=20deinit=20ba=E2=80=A6=20(#2589)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add `withInboundOutboud` to `NIOAsyncChannel` and deprecate deinit based cleanup # Motivation We just released our new async NIO APIs and have already gotten quite a bunch of feedback from adopters. One of the feedback was that the deinit based closing that we have added to the `NIOAsyncChannel` has caused problems since it leads to unexpected closure of their `Channel`. Furthermore, it makes it impossible to determine how many open sockets a program has at any given time since deinit based clean up relies on the optimizer and can happen at random times. # Modifications This PR adds new inits to `NIOAsyncSequenceProducer` and `NIOAsyncWriter` which disable the `deinit` based clean up and instead replace them with an assertion. This allows developers to still catch these issues at debug time. Furthermore, I added a new `withInboundOutbound` scoped access to `NIOAsyncChannel` which will close the channel at the end of the scope. This still gives users a nice API while not having to care much about closing themselves. # Result We are no longer using deinit based clean up and bring back one of the core principles of NIO which is deterministic resource usage. * Review comments * Internal labels for closure arguments * Rename to `executeThenCloseChannel` * Actually call `sinkDeinitialized` * Change preconditions * Move logic to deinits * Rename to `executeThenClose` and review nits --- .../TCPEchoAsyncChannel.swift | 50 +-- .../NIOCore/AsyncChannel/AsyncChannel.swift | 82 ++++- .../AsyncChannelInboundStream.swift | 12 +- ...ncChannelInboundStreamChannelHandler.swift | 20 +- .../AsyncChannelOutboundWriter.swift | 5 +- .../AsyncChannelOutboundWriterHandler.swift | 16 +- .../NIOCore/AsyncChannel/CloseRatchet.swift | 93 ------ .../NIOAsyncSequenceProducer.swift | 78 +++-- .../AsyncSequences/NIOAsyncWriter.swift | 162 +++++---- .../NIOThrowingAsyncSequenceProducer.swift | 76 ++++- .../NIOCore/Docs.docc/swift-concurrency.md | 34 +- .../NIOAsyncSequenceProducerBenchmark.swift | 6 +- .../NIOAsyncWriterSingleWritesBenchmark.swift | 2 +- Sources/NIOTCPEchoClient/Client.swift | 18 +- Sources/NIOTCPEchoServer/Server.swift | 20 +- .../AsyncChannel/AsyncChannelTests.swift | 313 +++++++----------- .../NIOAsyncSequenceTests.swift | 83 ++++- .../AsyncSequences/NIOAsyncWriterTests.swift | 68 +++- .../NIOThrowingAsyncSequenceTests.swift | 88 ++++- .../AsyncChannelBootstrapTests.swift | 310 ++++++++++------- 20 files changed, 884 insertions(+), 652 deletions(-) delete mode 100644 Sources/NIOCore/AsyncChannel/CloseRatchet.swift diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift index bfe553c0db..269bde7e53 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift @@ -53,37 +53,43 @@ func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async thr try await withThrowingTaskGroup(of: Void.self) { group in // This child task is echoing back the data on the server. group.addTask { - for try await connectionChannel in serverChannel.inbound { - for try await inboundData in connectionChannel.inbound { - try await connectionChannel.outbound.write(inboundData) + try await serverChannel.executeThenClose { serverChannelInbound in + for try await connectionChannel in serverChannelInbound { + try await connectionChannel.executeThenClose { connectionChannelInbound, connectionChannelOutbound in + for try await inboundData in connectionChannelInbound { + try await connectionChannelOutbound.write(inboundData) + } + } } } } - // This child task is collecting the echoed back responses. - group.addTask { - var receivedData = 0 - for try await inboundData in clientChannel.inbound { - receivedData += inboundData.readableBytes + try await clientChannel.executeThenClose { inbound, outbound in + // This child task is collecting the echoed back responses. + group.addTask { + var receivedData = 0 + for try await inboundData in inbound { + receivedData += inboundData.readableBytes - if receivedData == numberOfWrites * messageSize { - return + if receivedData == numberOfWrites * messageSize { + return + } } } - } - // Let's start sending data. - let data = ByteBuffer(repeating: 0, count: messageSize) - for _ in 0..: Sendable { /// The underlying channel being wrapped by this ``NIOAsyncChannel``. public let channel: Channel + /// The stream of inbound messages. /// /// - Important: The `inbound` stream is a unicast `AsyncSequence` and only one iterator can be created. - public let inbound: NIOAsyncChannelInboundStream + @available(*, deprecated, message: "Use the executeThenClose scoped method instead.") + public var inbound: NIOAsyncChannelInboundStream { + self._inbound + } /// The writer for writing outbound messages. - public let outbound: NIOAsyncChannelOutboundWriter + @available(*, deprecated, message: "Use the executeThenClose scoped method instead.") + public var outbound: NIOAsyncChannelOutboundWriter { + self._outbound + } + + @usableFromInline + let _inbound: NIOAsyncChannelInboundStream + @usableFromInline + let _outbound: NIOAsyncChannelOutboundWriter /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel``. /// @@ -94,7 +106,7 @@ public struct NIOAsyncChannel: Sendable { ) throws { channel.eventLoop.preconditionInEventLoop() self.channel = channel - (self.inbound, self.outbound) = try channel._syncAddAsyncHandlers( + (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( backPressureStrategy: configuration.backPressureStrategy, isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled ) @@ -117,12 +129,12 @@ public struct NIOAsyncChannel: Sendable { ) throws where Outbound == Never { channel.eventLoop.preconditionInEventLoop() self.channel = channel - (self.inbound, self.outbound) = try channel._syncAddAsyncHandlers( + (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( backPressureStrategy: configuration.backPressureStrategy, isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled ) - self.outbound.finish() + self._outbound.finish() } @inlinable @@ -133,8 +145,8 @@ public struct NIOAsyncChannel: Sendable { ) { channel.eventLoop.preconditionInEventLoop() self.channel = channel - self.inbound = inboundStream - self.outbound = outboundWriter + self._inbound = inboundStream + self._outbound = outboundWriter } @@ -164,6 +176,52 @@ public struct NIOAsyncChannel: Sendable { outboundWriter: outboundWriter ) } + + /// Provides scoped access to the inbound and outbound side of the underlying ``Channel``. + /// + /// - Important: After this method returned the underlying ``Channel`` will be closed. + /// + /// - Parameter body: A closure that gets scoped access to the inbound and outbound. + public func executeThenClose( + _ body: (_ inbound: NIOAsyncChannelInboundStream, _ outbound: NIOAsyncChannelOutboundWriter) async throws -> Result + ) async throws -> Result { + let result: Result + do { + result = try await body(self._inbound, self._outbound) + } catch let bodyError { + do { + self._outbound.finish() + try await self.channel.close().get() + throw bodyError + } catch { + throw bodyError + } + } + + do { + self._outbound.finish() + try await self.channel.close().get() + } catch { + if let error = error as? ChannelError, error == .alreadyClosed { + return result + } + throw error + } + return result + } + + /// Provides scoped access to the inbound side of the underlying ``Channel``. + /// + /// - Important: After this method returned the underlying ``Channel`` will be closed. + /// + /// - Parameter body: A closure that gets scoped access to the inbound. + public func executeThenClose( + _ body: (_ inbound: NIOAsyncChannelInboundStream) async throws -> Result + ) async throws -> Result where Outbound == Never { + try await self.executeThenClose { inbound, _ in + try await body(inbound) + } + } } extension Channel { @@ -175,15 +233,13 @@ extension Channel { ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { self.eventLoop.assertInEventLoop() - let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) let inboundStream = try NIOAsyncChannelInboundStream.makeWrappingHandler( channel: self, - backPressureStrategy: backPressureStrategy, - closeRatchet: closeRatchet + backPressureStrategy: backPressureStrategy ) let writer = try NIOAsyncChannelOutboundWriter( channel: self, - closeRatchet: closeRatchet + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled ) return (inboundStream, writer) } @@ -197,16 +253,14 @@ extension Channel { ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { self.eventLoop.assertInEventLoop() - let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) let inboundStream = try NIOAsyncChannelInboundStream.makeTransformationHandler( channel: self, backPressureStrategy: backPressureStrategy, - closeRatchet: closeRatchet, channelReadTransformation: channelReadTransformation ) let writer = try NIOAsyncChannelOutboundWriter( channel: self, - closeRatchet: closeRatchet + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled ) return (inboundStream, writer) } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift index fb713929f6..f00896de94 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift @@ -80,7 +80,6 @@ public struct NIOAsyncChannelInboundStream: Sendable { init( channel: Channel, backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, - closeRatchet: CloseRatchet, handler: NIOAsyncChannelInboundStreamChannelHandler ) throws { channel.eventLoop.preconditionInEventLoop() @@ -96,6 +95,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { let sequence = Producer.makeSequence( backPressureStrategy: strategy, + finishOnDeinit: false, delegate: NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate(handler: handler) ) handler.source = sequence.source @@ -107,18 +107,15 @@ public struct NIOAsyncChannelInboundStream: Sendable { @inlinable static func makeWrappingHandler( channel: Channel, - backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, - closeRatchet: CloseRatchet + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? ) throws -> NIOAsyncChannelInboundStream { let handler = NIOAsyncChannelInboundStreamChannelHandler.makeHandler( - eventLoop: channel.eventLoop, - closeRatchet: closeRatchet + eventLoop: channel.eventLoop ) return try .init( channel: channel, backPressureStrategy: backPressureStrategy, - closeRatchet: closeRatchet, handler: handler ) } @@ -128,19 +125,16 @@ public struct NIOAsyncChannelInboundStream: Sendable { static func makeTransformationHandler( channel: Channel, backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, - closeRatchet: CloseRatchet, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> NIOAsyncChannelInboundStream { let handler = NIOAsyncChannelInboundStreamChannelHandler.makeHandlerWithTransformations( eventLoop: channel.eventLoop, - closeRatchet: closeRatchet, channelReadTransformation: channelReadTransformation ) return try .init( channel: channel, backPressureStrategy: backPressureStrategy, - closeRatchet: closeRatchet, handler: handler ) } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift index d1d28c4c3c..c04b203560 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift @@ -63,10 +63,6 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler NIOAsyncChannelInboundStreamChannelHandler where InboundIn == ProducerElement { return .init( eventLoop: eventLoop, - closeRatchet: closeRatchet, transformation: .syncWrapping { $0 } ) } @@ -109,12 +101,10 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler EventLoopFuture ) -> NIOAsyncChannelInboundStreamChannelHandler where InboundIn == Channel { return .init( eventLoop: eventLoop, - closeRatchet: closeRatchet, transformation: .transformation( channelReadTransformation: channelReadTransformation ) @@ -277,14 +267,6 @@ extension NIOAsyncChannelInboundStreamChannelHandler { // Wedges the read open forever, we'll never read again. self.producingState = .producingPausedWithOutstandingRead - - switch self.closeRatchet.closeRead() { - case .nothing: - break - - case .close: - self.context?.close(promise: nil) - } } @inlinable diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift index 50e4d2ad4d..3af5751e6b 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -84,15 +84,16 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { @inlinable init( channel: Channel, - closeRatchet: CloseRatchet + isOutboundHalfClosureEnabled: Bool ) throws { let handler = NIOAsyncChannelOutboundWriterHandler( eventLoop: channel.eventLoop, - closeRatchet: closeRatchet + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled ) let writer = _Writer.makeWriter( elementType: OutboundOut.self, isWritable: true, + finishOnDeinit: false, delegate: .init(handler: handler) ) handler.sink = writer.sink diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift index bae110c61c..59fad5e3e1 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift @@ -45,17 +45,16 @@ internal final class NIOAsyncChannelOutboundWriterHandler @usableFromInline let eventLoop: EventLoop - /// The shared `CloseRatchet` between this handler and the inbound stream handler. @usableFromInline - let closeRatchet: CloseRatchet + let isOutboundHalfClosureEnabled: Bool @inlinable init( eventLoop: EventLoop, - closeRatchet: CloseRatchet + isOutboundHalfClosureEnabled: Bool ) { self.eventLoop = eventLoop - self.closeRatchet = closeRatchet + self.isOutboundHalfClosureEnabled = isOutboundHalfClosureEnabled } @inlinable @@ -96,15 +95,8 @@ internal final class NIOAsyncChannelOutboundWriterHandler func _didTerminate(error: Error?) { self.eventLoop.preconditionInEventLoop() - switch self.closeRatchet.closeWrite() { - case .nothing: - break - - case .closeOutput: + if self.isOutboundHalfClosureEnabled { self.context?.close(mode: .output, promise: nil) - - case .close: - self.context?.close(promise: nil) } self.sink = nil diff --git a/Sources/NIOCore/AsyncChannel/CloseRatchet.swift b/Sources/NIOCore/AsyncChannel/CloseRatchet.swift deleted file mode 100644 index 9d87ea4d07..0000000000 --- a/Sources/NIOCore/AsyncChannel/CloseRatchet.swift +++ /dev/null @@ -1,93 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2022-2023 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -/// A helper type that lets ``NIOAsyncChannelAdapterHandler`` and ``NIOAsyncChannelWriterHandler`` collude -/// to ensure that the ``Channel`` they share is closed appropriately. -/// -/// The strategy of this type is that it keeps track of which side has closed, so that the handlers can work out -/// which of them was "last", in order to arrange closure. -@usableFromInline -final class CloseRatchet { - @usableFromInline - enum State { - case notClosed(isOutboundHalfClosureEnabled: Bool) - case readClosed - case writeClosed - case bothClosed - - @inlinable - mutating func closeRead() -> CloseReadAction { - switch self { - case .notClosed: - self = .readClosed - return .nothing - case .writeClosed: - self = .bothClosed - return .close - case .readClosed, .bothClosed: - preconditionFailure("Duplicate read closure") - } - } - - @inlinable - mutating func closeWrite() -> CloseWriteAction { - switch self { - case .notClosed(let isOutboundHalfClosureEnabled): - self = .writeClosed - - if isOutboundHalfClosureEnabled { - return .closeOutput - } else { - return .nothing - } - case .readClosed: - self = .bothClosed - return .close - case .writeClosed, .bothClosed: - preconditionFailure("Duplicate write closure") - } - } - } - - @usableFromInline - var _state: State - - @inlinable - init(isOutboundHalfClosureEnabled: Bool) { - self._state = .notClosed(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) - } - - @usableFromInline - enum CloseReadAction { - case nothing - case close - } - - @inlinable - func closeRead() -> CloseReadAction { - return self._state.closeRead() - } - - @usableFromInline - enum CloseWriteAction { - case nothing - case close - case closeOutput - } - - @inlinable - func closeWrite() -> CloseWriteAction { - return self._state.closeWrite() - } -} diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift index ba39820021..c724f7798f 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift @@ -139,17 +139,52 @@ public struct NIOAsyncSequenceProducer< /// - Parameters: /// - elementType: The element type of the sequence. /// - backPressureStrategy: The back-pressure strategy of the sequence. + /// - finishOnDeinit: Indicates if ``NIOAsyncSequenceProducer/Source/finish()`` should be called on deinit of the the source. + /// We do not recommend to rely on deinit based resource tear down. /// - delegate: The delegate of the sequence /// - Returns: A ``NIOAsyncSequenceProducer/Source`` and a ``NIOAsyncSequenceProducer``. @inlinable public static func makeSequence( elementType: Element.Type = Element.self, backPressureStrategy: Strategy, + finishOnDeinit: Bool, delegate: Delegate ) -> NewSequence { let newSequence = NIOThrowingAsyncSequenceProducer.makeNonThrowingSequence( elementType: Element.self, backPressureStrategy: backPressureStrategy, + finishOnDeinit: finishOnDeinit, + delegate: delegate + ) + + let sequence = self.init(throwingSequence: newSequence.sequence) + + return .init(source: Source(throwingSource: newSequence.source), sequence: sequence) + } + + /// Initializes a new ``NIOAsyncSequenceProducer`` and a ``NIOAsyncSequenceProducer/Source``. + /// + /// - Important: This method returns a struct containing a ``NIOAsyncSequenceProducer/Source`` and + /// a ``NIOAsyncSequenceProducer``. The source MUST be held by the caller and + /// used to signal new elements or finish. The sequence MUST be passed to the actual consumer and MUST NOT be held by the + /// caller. This is due to the fact that deiniting the sequence is used as part of a trigger to terminate the underlying source. + /// + /// - Parameters: + /// - elementType: The element type of the sequence. + /// - backPressureStrategy: The back-pressure strategy of the sequence. + /// - delegate: The delegate of the sequence + /// - Returns: A ``NIOAsyncSequenceProducer/Source`` and a ``NIOAsyncSequenceProducer``. + @inlinable + @available(*, deprecated, renamed: "makeSequence(elementType:backPressureStrategy:finishOnDeinit:delegate:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + public static func makeSequence( + elementType: Element.Type = Element.self, + backPressureStrategy: Strategy, + delegate: Delegate + ) -> NewSequence { + let newSequence = NIOThrowingAsyncSequenceProducer.makeNonThrowingSequence( + elementType: Element.self, + backPressureStrategy: backPressureStrategy, + finishOnDeinit: true, delegate: delegate ) @@ -209,45 +244,20 @@ extension NIOAsyncSequenceProducer { /// This type allows the producer to synchronously `yield` new elements to the ``NIOAsyncSequenceProducer`` /// and to `finish` the sequence. public struct Source { - /// This class is needed to hook the deinit to observe once all references to the ``NIOAsyncSequenceProducer/Source`` are dropped. - /// - /// - Important: This is safe to be unchecked ``Sendable`` since the `storage` is ``Sendable`` and `immutable`. @usableFromInline - /* fileprivate */ internal final class InternalClass: Sendable { - @usableFromInline - typealias ThrowingSource = NIOThrowingAsyncSequenceProducer< - Element, - Never, - Strategy, - Delegate - >.Source - - @usableFromInline - /* fileprivate */ internal let _throwingSource: ThrowingSource - - @inlinable - init(throwingSource: ThrowingSource) { - self._throwingSource = throwingSource - } - - @inlinable - deinit { - // We need to call finish here to resume any suspended continuation. - self._throwingSource.finish() - } - } - - @usableFromInline - /* private */ internal let _internalClass: InternalClass + typealias ThrowingSource = NIOThrowingAsyncSequenceProducer< + Element, + Never, + Strategy, + Delegate + >.Source @usableFromInline - /* private */ internal var _throwingSource: InternalClass.ThrowingSource { - self._internalClass._throwingSource - } + /* private */ internal var _throwingSource: ThrowingSource @usableFromInline - /* fileprivate */ internal init(throwingSource: InternalClass.ThrowingSource) { - self._internalClass = .init(throwingSource: throwingSource) + /* fileprivate */ internal init(throwingSource: ThrowingSource) { + self._throwingSource = throwingSource } /// The result of a call to ``NIOAsyncSequenceProducer/Source/yield(_:)``. diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift index 5e113fae17..9a8947f231 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift @@ -161,14 +161,23 @@ public struct NIOAsyncWriter< @usableFromInline internal let _storage: Storage + @usableFromInline + internal let _finishOnDeinit: Bool + @inlinable - init(storage: Storage) { + init(storage: Storage, finishOnDeinit: Bool) { self._storage = storage + self._finishOnDeinit = finishOnDeinit } @inlinable deinit { - _storage.writerDeinitialized() + if !self._finishOnDeinit && !self._storage.isWriterFinished { + preconditionFailure("Deinited NIOAsyncWriter without calling finish()") + } else { + // We need to call finish here to resume any suspended continuation. + self._storage.writerFinish(error: nil) + } } } @@ -193,16 +202,49 @@ public struct NIOAsyncWriter< /// - delegate: The delegate of the writer. /// - Returns: A ``NIOAsyncWriter/NewWriter``. @inlinable + @available(*, deprecated, renamed: "makeWriter(elementType:isWritable:finishOnDeinit:delegate:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + public static func makeWriter( + elementType: Element.Type = Element.self, + isWritable: Bool, + delegate: Delegate + ) -> NewWriter { + let writer = Self( + isWritable: isWritable, + finishOnDeinit: true, + delegate: delegate + ) + let sink = Sink(storage: writer._storage, finishOnDeinit: true) + + return .init(sink: sink, writer: writer) + } + + /// Initializes a new ``NIOAsyncWriter`` and a ``NIOAsyncWriter/Sink``. + /// + /// - Important: This method returns a struct containing a ``NIOAsyncWriter/Sink`` and + /// a ``NIOAsyncWriter``. The sink MUST be held by the caller and is used to set the writability. + /// The writer MUST be passed to the actual producer and MUST NOT be held by the + /// caller. This is due to the fact that deiniting the sequence is used as part of a trigger to terminate the underlying sink. + /// + /// - Parameters: + /// - elementType: The element type of the sequence. + /// - isWritable: The initial writability state of the writer. + /// - finishOnDeinit: Indicates if ``NIOAsyncWriter/finish()`` should be called on deinit. We do not recommend to rely on + /// deinit based resource tear down. + /// - delegate: The delegate of the writer. + /// - Returns: A ``NIOAsyncWriter/NewWriter``. + @inlinable public static func makeWriter( elementType: Element.Type = Element.self, isWritable: Bool, + finishOnDeinit: Bool, delegate: Delegate ) -> NewWriter { let writer = Self( isWritable: isWritable, + finishOnDeinit: finishOnDeinit, delegate: delegate ) - let sink = Sink(storage: writer._storage) + let sink = Sink(storage: writer._storage, finishOnDeinit: finishOnDeinit) return .init(sink: sink, writer: writer) } @@ -210,13 +252,14 @@ public struct NIOAsyncWriter< @inlinable /* private */ internal init( isWritable: Bool, + finishOnDeinit: Bool, delegate: Delegate ) { let storage = Storage( isWritable: isWritable, delegate: delegate ) - self._internalClass = .init(storage: storage) + self._internalClass = .init(storage: storage, finishOnDeinit: finishOnDeinit) } /// Yields a sequence of new elements to the ``NIOAsyncWriter``. @@ -297,15 +340,23 @@ extension NIOAsyncWriter { @usableFromInline /* fileprivate */ internal let _storage: Storage + @usableFromInline + internal let _finishOnDeinit: Bool + @inlinable - init(storage: Storage) { + init(storage: Storage, finishOnDeinit: Bool) { self._storage = storage + self._finishOnDeinit = finishOnDeinit } @inlinable deinit { - // We need to call finish here to resume any suspended continuation. - self._storage.sinkFinish(error: nil) + if !self._finishOnDeinit && !self._storage.isSinkFinished { + preconditionFailure("Deinited NIOAsyncWriter.Sink without calling sink.finish()") + } else { + // We need to call finish here to resume any suspended continuation. + self._storage.sinkFinish(error: nil) + } } } @@ -318,8 +369,8 @@ extension NIOAsyncWriter { } @inlinable - init(storage: Storage) { - self._internalClass = .init(storage: storage) + init(storage: Storage, finishOnDeinit: Bool) { + self._internalClass = .init(storage: storage, finishOnDeinit: finishOnDeinit) } /// Sets the writability of the ``NIOAsyncWriter``. @@ -407,27 +458,24 @@ extension NIOAsyncWriter { /* private */ internal var _stateMachine: StateMachine @inlinable - /* fileprivate */ internal init( - isWritable: Bool, - delegate: Delegate - ) { - self._stateMachine = .init(isWritable: isWritable, delegate: delegate) + internal var isWriterFinished: Bool { + self._lock.withLock { self._stateMachine.isWriterFinished } } @inlinable - /* fileprivate */ internal func writerDeinitialized() { - let action = self._lock.withLock { - self._stateMachine.writerDeinitialized() - } - - switch action { - case .callDidTerminate(let delegate): - delegate.didTerminate(error: nil) - - case .none: - break - } + internal var isSinkFinished: Bool { + self._lock.withLock { self._stateMachine.isSinkFinished } + } + @inlinable + /* fileprivate */ internal init( + isWritable: Bool, + delegate: Delegate + ) { + self._stateMachine = .init( + isWritable: isWritable, + delegate: delegate + ) } @inlinable @@ -662,54 +710,38 @@ extension NIOAsyncWriter { /* private */ internal var _state: State @inlinable - init( - isWritable: Bool, - delegate: Delegate - ) { - self._state = .initial(isWritable: isWritable, delegate: delegate) - } - - /// Actions returned by `writerDeinitialized()`. - @usableFromInline - enum WriterDeinitializedAction { - /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called. - case callDidTerminate(Delegate) - /// Indicates that nothing should be done. - case none + internal var isWriterFinished: Bool { + switch self._state { + case .initial, .streaming: + return false + case .writerFinished, .finished: + return true + case .modifying: + preconditionFailure("Invalid state") + } } @inlinable - /* fileprivate */ internal mutating func writerDeinitialized() -> WriterDeinitializedAction { + internal var isSinkFinished: Bool { switch self._state { - case .initial(_, let delegate): - // The writer deinited before writing anything. - // We can transition to finished and inform our delegate - self._state = .finished(sinkError: nil) - - return .callDidTerminate(delegate) - - case .streaming(_, _, _, let suspendedYields, let delegate): - // The writer got deinited after we started streaming. - // This is normal and we need to transition to finished - // and call the delegate. However, we should not have - // any suspended yields because they MUST strongly retain - // the writer. - precondition(suspendedYields.isEmpty, "We have outstanding suspended yields") - - // We can transition to finished directly - self._state = .finished(sinkError: nil) - - return .callDidTerminate(delegate) - - case .finished, .writerFinished: - // We are already finished nothing to do here - return .none - + case .initial, .streaming, .writerFinished: + return false + case .finished: + return true case .modifying: preconditionFailure("Invalid state") } } + + @inlinable + init( + isWritable: Bool, + delegate: Delegate + ) { + self._state = .initial(isWritable: isWritable, delegate: delegate) + } + /// Actions returned by `setWritability()`. @usableFromInline enum SetWritabilityAction { diff --git a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift index 3b5e2776ad..8f7020adab 100644 --- a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift +++ b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift @@ -85,6 +85,38 @@ public struct NIOThrowingAsyncSequenceProducer< self._internalClass._storage } + /// Initializes a new ``NIOThrowingAsyncSequenceProducer`` and a ``NIOThrowingAsyncSequenceProducer/Source``. + /// + /// - Important: This method returns a struct containing a ``NIOThrowingAsyncSequenceProducer/Source`` and + /// a ``NIOThrowingAsyncSequenceProducer``. The source MUST be held by the caller and + /// used to signal new elements or finish. The sequence MUST be passed to the actual consumer and MUST NOT be held by the + /// caller. This is due to the fact that deiniting the sequence is used as part of a trigger to terminate the underlying source. + /// + /// - Parameters: + /// - elementType: The element type of the sequence. + /// - failureType: The failure type of the sequence. Must be `Swift.Error` + /// - backPressureStrategy: The back-pressure strategy of the sequence. + /// - finishOnDeinit: Indicates if ``NIOThrowingAsyncSequenceProducer/Source/finish()`` should be called on deinit of the. + /// We do not recommend to rely on deinit based resource tear down. + /// - delegate: The delegate of the sequence + /// - Returns: A ``NIOThrowingAsyncSequenceProducer/Source`` and a ``NIOThrowingAsyncSequenceProducer``. + @inlinable + public static func makeSequence( + elementType: Element.Type = Element.self, + failureType: Failure.Type = Error.self, + backPressureStrategy: Strategy, + finishOnDeinit: Bool, + delegate: Delegate + ) -> NewSequence where Failure == Error { + let sequence = Self( + backPressureStrategy: backPressureStrategy, + delegate: delegate + ) + let source = Source(storage: sequence._storage, finishOnDeinit: finishOnDeinit) + + return .init(source: source, sequence: sequence) + } + /// Initializes a new ``NIOThrowingAsyncSequenceProducer`` and a ``NIOThrowingAsyncSequenceProducer/Source``. /// /// - Important: This method returns a struct containing a ``NIOThrowingAsyncSequenceProducer/Source`` and @@ -110,7 +142,7 @@ public struct NIOThrowingAsyncSequenceProducer< backPressureStrategy: backPressureStrategy, delegate: delegate ) - let source = Source(storage: sequence._storage) + let source = Source(storage: sequence._storage, finishOnDeinit: true) return .init(source: source, sequence: sequence) } @@ -129,6 +161,7 @@ public struct NIOThrowingAsyncSequenceProducer< /// - delegate: The delegate of the sequence /// - Returns: A ``NIOThrowingAsyncSequenceProducer/Source`` and a ``NIOThrowingAsyncSequenceProducer``. @inlinable + @available(*, deprecated, renamed: "makeSequence(elementType:failureType:backPressureStrategy:finishOnDeinit:delegate:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") public static func makeSequence( elementType: Element.Type = Element.self, failureType: Failure.Type = Error.self, @@ -139,7 +172,7 @@ public struct NIOThrowingAsyncSequenceProducer< backPressureStrategy: backPressureStrategy, delegate: delegate ) - let source = Source(storage: sequence._storage) + let source = Source(storage: sequence._storage, finishOnDeinit: true) return .init(source: source, sequence: sequence) } @@ -149,13 +182,14 @@ public struct NIOThrowingAsyncSequenceProducer< internal static func makeNonThrowingSequence( elementType: Element.Type = Element.self, backPressureStrategy: Strategy, + finishOnDeinit: Bool, delegate: Delegate ) -> NewSequence where Failure == Never { let sequence = Self( backPressureStrategy: backPressureStrategy, delegate: delegate ) - let source = Source(storage: sequence._storage) + let source = Source(storage: sequence._storage, finishOnDeinit: finishOnDeinit) return .init(source: source, sequence: sequence) } @@ -238,15 +272,23 @@ extension NIOThrowingAsyncSequenceProducer { @usableFromInline internal let _storage: Storage + @usableFromInline + internal let _finishOnDeinit: Bool + @inlinable - init(storage: Storage) { + init(storage: Storage, finishOnDeinit: Bool) { self._storage = storage + self._finishOnDeinit = finishOnDeinit } @inlinable deinit { - // We need to call finish here to resume any suspended continuation. - self._storage.finish(nil) + if !self._finishOnDeinit && !self._storage.isFinished { + preconditionFailure("Deinited NIOAsyncSequenceProducer.Source without calling source.finish()") + } else { + // We need to call finish here to resume any suspended continuation. + self._storage.finish(nil) + } } } @@ -259,8 +301,8 @@ extension NIOThrowingAsyncSequenceProducer { } @usableFromInline - /* fileprivate */ internal init(storage: Storage) { - self._internalClass = .init(storage: storage) + /* fileprivate */ internal init(storage: Storage, finishOnDeinit: Bool) { + self._internalClass = .init(storage: storage, finishOnDeinit: finishOnDeinit) } /// The result of a call to ``NIOThrowingAsyncSequenceProducer/Source/yield(_:)``. @@ -357,6 +399,11 @@ extension NIOThrowingAsyncSequenceProducer { @usableFromInline /* private */ internal var _delegate: Delegate? + @inlinable + var isFinished: Bool { + self._lock.withLock { self._stateMachine.isFinished } + } + @usableFromInline /* fileprivate */ internal init( backPressureStrategy: Strategy, @@ -648,6 +695,19 @@ extension NIOThrowingAsyncSequenceProducer { @usableFromInline /* private */ internal var _state: State + @inlinable + var isFinished: Bool { + switch self._state { + case .initial, .streaming: + return false + case .cancelled, .sourceFinished, .finished: + return true + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Initializes a new `StateMachine`. /// /// We are passing and holding the back-pressure strategy here because diff --git a/Sources/NIOCore/Docs.docc/swift-concurrency.md b/Sources/NIOCore/Docs.docc/swift-concurrency.md index c8ddaadd86..891ad9e4fe 100644 --- a/Sources/NIOCore/Docs.docc/swift-concurrency.md +++ b/Sources/NIOCore/Docs.docc/swift-concurrency.md @@ -87,8 +87,10 @@ the inbound data and echo it back outbound. let channel = ... let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) -for try await inboundData in asyncChannel.inbound { - try await asyncChannel.outbound.write(inboundData) +try await asyncChannel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + try await outbound.write(inboundData) + } } ``` @@ -137,15 +139,19 @@ let serverChannel = try await ServerBootstrap(group: eventLoopGroup) } try await withThrowingDiscardingTaskGroup { group in - for try await connectionChannel in serverChannel.inbound { - group.addTask { - do { - for try await inboundData in connectionChannel.inbound { - // Let's echo back all inbound data - try await connectionChannel.outbound.write(inboundData) + try await serverChannel.executeThenClose { serverChannelInbound in + for try await connectionChannel in serverChannelInbound { + group.addTask { + do { + try await connectionChannel.executeThenClose { connectionChannelInbound, connectionChannelOutbound in + for try await inboundData in connectionChannelInbound { + // Let's echo back all inbound data + try await connectionChannelOutbound.write(inboundData) + } + } + } catch { + // Handle errors } - } catch { - // Handle errors } } } @@ -185,10 +191,12 @@ let clientChannel = try await ClientBootstrap(group: eventLoopGroup) } } -clientChannel.outbound.write(ByteBuffer(string: "hello")) +try await clientChannel.executeThenClose { inbound, outbound in + try await outbound.write(ByteBuffer(string: "hello")) -for try await inboundData in clientChannel.inbound { - print(inboundData) + for try await inboundData in inbound { + print(inboundData) + } } ``` diff --git a/Sources/NIOPerformanceTester/NIOAsyncSequenceProducerBenchmark.swift b/Sources/NIOPerformanceTester/NIOAsyncSequenceProducerBenchmark.swift index b26a0f99b7..e8424e3fa2 100644 --- a/Sources/NIOPerformanceTester/NIOAsyncSequenceProducerBenchmark.swift +++ b/Sources/NIOPerformanceTester/NIOAsyncSequenceProducerBenchmark.swift @@ -31,7 +31,11 @@ final class NIOAsyncSequenceProducerBenchmark: AsyncBenchmark, NIOAsyncSequenceP } func setUp() async throws { - let producer = SequenceProducer.makeSequence(backPressureStrategy: .init(lowWatermark: 100, highWatermark: 500), delegate: self) + let producer = SequenceProducer.makeSequence( + backPressureStrategy: .init(lowWatermark: 100, highWatermark: 500), + finishOnDeinit: false, + delegate: self + ) self.iterator = producer.sequence.makeAsyncIterator() self.source = producer.source } diff --git a/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift b/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift index b2122b8c05..8ee2cda9d5 100644 --- a/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift +++ b/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift @@ -38,7 +38,7 @@ final class NIOAsyncWriterSingleWritesBenchmark: AsyncBenchmark, @unchecked Send init(iterations: Int) { self.iterations = iterations self.delegate = .init() - let newWriter = NIOAsyncWriter.makeWriter(isWritable: true, delegate: self.delegate) + let newWriter = NIOAsyncWriter.makeWriter(isWritable: true, finishOnDeinit: false, delegate: self.delegate) self.writer = newWriter.writer self.sink = newWriter.sink } diff --git a/Sources/NIOTCPEchoClient/Client.swift b/Sources/NIOTCPEchoClient/Client.swift index 9bc0e0c9aa..a334c7f605 100644 --- a/Sources/NIOTCPEchoClient/Client.swift +++ b/Sources/NIOTCPEchoClient/Client.swift @@ -68,16 +68,18 @@ struct Client { } } - print("Connection(\(number)): Writing request") - try await channel.outbound.write("Hello on connection \(number)") + try await channel.executeThenClose { inbound, outbound in + print("Connection(\(number)): Writing request") + try await outbound.write("Hello on connection \(number)") - for try await inboundData in channel.inbound { - print("Connection(\(number)): Received response (\(inboundData))") + for try await inboundData in inbound { + print("Connection(\(number)): Received response (\(inboundData))") - // We only expect a single response so we can exit here. - // Once, we exit out of this loop and the references to the `NIOAsyncChannel` are dropped - // the connection is going to close itself. - break + // We only expect a single response so we can exit here. + // Once, we exit out of this loop and the references to the `NIOAsyncChannel` are dropped + // the connection is going to close itself. + break + } } } } diff --git a/Sources/NIOTCPEchoServer/Server.swift b/Sources/NIOTCPEchoServer/Server.swift index edc52f2e1b..368df645a9 100644 --- a/Sources/NIOTCPEchoServer/Server.swift +++ b/Sources/NIOTCPEchoServer/Server.swift @@ -64,11 +64,13 @@ struct Server { // the results of the group we need the group to automatically discard them; otherwise, this // would result in a memory leak over time. try await withThrowingDiscardingTaskGroup { group in - for try await connectionChannel in channel.inbound { - group.addTask { - print("Handling new connection") - await self.handleConnection(channel: connectionChannel) - print("Done handling connection") + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + group.addTask { + print("Handling new connection") + await self.handleConnection(channel: connectionChannel) + print("Done handling connection") + } } } } @@ -80,9 +82,11 @@ struct Server { // We do this since we don't want to tear down the whole server when a single connection // encounters an error. do { - for try await inboundData in channel.inbound { - print("Received request (\(inboundData))") - try await channel.outbound.write(inboundData) + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + print("Received request (\(inboundData))") + try await outbound.write(inboundData) + } } } catch { print("Hit error: \(error)") diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 1047e54437..8c7fcf7348 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -33,8 +33,9 @@ final class AsyncChannelTests: XCTestCase { return try NIOAsyncChannel(synchronouslyWrapping: channel) } - try await wrapped.outbound.write("Test") - try await channel.closeFuture.get() + try await wrapped.executeThenClose { _, outbound in + try await outbound.write("Test") + } } func testAsyncChannelBasicFunctionality() async throws { @@ -44,23 +45,23 @@ final class AsyncChannelTests: XCTestCase { try NIOAsyncChannel(synchronouslyWrapping: channel) } - var iterator = wrapped.inbound.makeAsyncIterator() - try await channel.writeInbound("hello") - let firstRead = try await iterator.next() - XCTAssertEqual(firstRead, "hello") - - try await channel.writeInbound("world") - let secondRead = try await iterator.next() - XCTAssertEqual(secondRead, "world") + try await wrapped.executeThenClose { inbound, _ in + var iterator = inbound.makeAsyncIterator() + try await channel.writeInbound("hello") + let firstRead = try await iterator.next() + XCTAssertEqual(firstRead, "hello") - try await channel.testingEventLoop.executeInContext { - channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) - } + try await channel.writeInbound("world") + let secondRead = try await iterator.next() + XCTAssertEqual(secondRead, "world") - let thirdRead = try await iterator.next() - XCTAssertNil(thirdRead) + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + } - try await channel.closeFuture.get() + let thirdRead = try await iterator.next() + XCTAssertNil(thirdRead) + } } func testAsyncChannelBasicWrites() async throws { @@ -70,138 +71,51 @@ final class AsyncChannelTests: XCTestCase { try NIOAsyncChannel(synchronouslyWrapping: channel) } - try await wrapped.outbound.write("hello") - try await wrapped.outbound.write("world") - - let firstRead = try await channel.waitForOutboundWrite(as: String.self) - let secondRead = try await channel.waitForOutboundWrite(as: String.self) - - XCTAssertEqual(firstRead, "hello") - XCTAssertEqual(secondRead, "world") - - try await channel.close() - } - - func testDroppingTheWriterClosesTheWriteSideOfTheChannel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } - let channel = NIOAsyncTestingChannel() - let closeRecorder = CloseRecorder() - try await channel.pipeline.addHandler(closeRecorder) - - let inboundReader: NIOAsyncChannelInboundStream + try await wrapped.executeThenClose { _, outbound in + try await outbound.write("hello") + try await outbound.write("world") - do { - let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel( - synchronouslyWrapping: channel, - configuration: .init( - isOutboundHalfClosureEnabled: true, - inboundType: Never.self, - outboundType: Never.self - ) - ) - } - inboundReader = wrapped.inbound + let firstRead = try await channel.waitForOutboundWrite(as: String.self) + let secondRead = try await channel.waitForOutboundWrite(as: String.self) - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(1, closeRecorder.outboundCloses) - } + XCTAssertEqual(firstRead, "hello") + XCTAssertEqual(secondRead, "world") } - - await channel.testingEventLoop.run() - - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(1, closeRecorder.outboundCloses) - } - - // Just use this to keep the inbound reader alive. - withExtendedLifetime(inboundReader) {} - channel.close(promise: nil) } - func testDroppingTheWriterDoesntCloseTheWriteSideOfTheChannelIfHalfClosureIsDisabled() async throws { + func testFinishingTheWriterClosesTheWriteSideOfTheChannel() async throws { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let closeRecorder = CloseRecorder() try await channel.pipeline.addHandler(closeRecorder) - let inboundReader: NIOAsyncChannelInboundStream - - do { - let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel( - synchronouslyWrapping: channel, - configuration: .init( - isOutboundHalfClosureEnabled: false, - inboundType: Never.self, - outboundType: Never.self - ) + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel( + synchronouslyWrapping: channel, + configuration: .init( + isOutboundHalfClosureEnabled: true, + inboundType: Never.self, + outboundType: Never.self ) - } - inboundReader = wrapped.inbound - - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(0, closeRecorder.outboundCloses) - } - } - - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(0, closeRecorder.outboundCloses) + ) } - // Just use this to keep the inbound reader alive. - withExtendedLifetime(inboundReader) {} - channel.close(promise: nil) - } - - func testDroppingTheWriterFirstLeadsToChannelClosureWhenReaderIsAlsoDropped() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } - let channel = NIOAsyncTestingChannel() - let closeRecorder = CloseRecorder() - try await channel.pipeline.addHandler(CloseSuppressor()) - try await channel.pipeline.addHandler(closeRecorder) - - do { - let inboundReader: NIOAsyncChannelInboundStream - - do { - let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel( - synchronouslyWrapping: channel, - configuration: .init( - isOutboundHalfClosureEnabled: true, - inboundType: Never.self, - outboundType: Never.self - ) - ) - } - inboundReader = wrapped.inbound - - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(1, closeRecorder.allCloses) - } - } + try await wrapped.executeThenClose { inbound, outbound in + outbound.finish() await channel.testingEventLoop.run() - // First we see half-closure. try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(1, closeRecorder.allCloses) + XCTAssertEqual(1, closeRecorder.outboundCloses) } // Just use this to keep the inbound reader alive. - withExtendedLifetime(inboundReader) {} - } + withExtendedLifetime(inbound) {} - // Now the inbound reader is dead, we see full closure. - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(2, closeRecorder.allCloses) } - - try await channel.closeIgnoringSuppression() } - func testDroppingEverythingClosesTheChannel() async throws { + func testDroppingEverythingDoesntCloseTheChannel() async throws { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let closeRecorder = CloseRecorder() @@ -209,7 +123,7 @@ final class AsyncChannelTests: XCTestCase { try await channel.pipeline.addHandler(closeRecorder) do { - let wrapped = try await channel.testingEventLoop.executeInContext { + _ = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel( synchronouslyWrapping: channel, configuration: .init( @@ -223,14 +137,11 @@ final class AsyncChannelTests: XCTestCase { try await channel.testingEventLoop.executeInContext { XCTAssertEqual(0, closeRecorder.allCloses) } - - // Just use this to keep the wrapper alive until here. - withExtendedLifetime(wrapped) {} } // Now that everything is dead, we see full closure. try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(1, closeRecorder.allCloses) + XCTAssertEqual(0, closeRecorder.allCloses) } try await channel.closeIgnoringSuppression() @@ -249,8 +160,10 @@ final class AsyncChannelTests: XCTestCase { try await channel.close().get() - let reads = try await Array(wrapped.inbound) - XCTAssertEqual(reads, ["hello"]) + try await wrapped.executeThenClose { inbound, _ in + let reads = try await Array(inbound) + XCTAssertEqual(reads, ["hello"]) + } } func testErrorsArePropagatedButAfterReads() async throws { @@ -265,12 +178,14 @@ final class AsyncChannelTests: XCTestCase { channel.pipeline.fireErrorCaught(TestError.bang) } - var iterator = wrapped.inbound.makeAsyncIterator() - let first = try await iterator.next() - XCTAssertEqual(first, "hello") + try await wrapped.executeThenClose { inbound, _ in + var iterator = inbound.makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first, "hello") - try await XCTAssertThrowsError(await iterator.next()) { error in - XCTAssertEqual(error as? TestError, .bang) + try await XCTAssertThrowsError(await iterator.next()) { error in + XCTAssertEqual(error as? TestError, .bang) + } } } @@ -290,9 +205,11 @@ final class AsyncChannelTests: XCTestCase { await withThrowingTaskGroup(of: Void.self) { group in group.addTask { - try await wrapped.outbound.write("hello") - lock.withLockedValue { - XCTAssertTrue($0) + try await wrapped.executeThenClose { _, outbound in + try await outbound.write("hello") + lock.withLockedValue { + XCTAssertTrue($0) + } } } @@ -307,8 +224,6 @@ final class AsyncChannelTests: XCTestCase { } } } - - try await channel.close().get() } func testBufferDropsReadsIfTheReaderIsGone() async throws { @@ -396,57 +311,59 @@ final class AsyncChannelTests: XCTestCase { } XCTAssertEqual(readCounter.readCount, 6) - // Now consume three elements from the pipeline. This should not unbuffer the read, as 3 elements remain. - var reader = wrapped.inbound.makeAsyncIterator() - for _ in 0..<3 { + try await wrapped.executeThenClose { inbound, outbound in + // Now consume three elements from the pipeline. This should not unbuffer the read, as 3 elements remain. + var reader = inbound.makeAsyncIterator() + for _ in 0..<3 { + try await XCTAsyncAssertNotNil(await reader.next()) + } + await channel.testingEventLoop.run() + XCTAssertEqual(readCounter.readCount, 6) + + // Removing the next element should trigger an automatic read. try await XCTAsyncAssertNotNil(await reader.next()) - } - await channel.testingEventLoop.run() - XCTAssertEqual(readCounter.readCount, 6) + await channel.testingEventLoop.run() + XCTAssertEqual(readCounter.readCount, 7) - // Removing the next element should trigger an automatic read. - try await XCTAsyncAssertNotNil(await reader.next()) - await channel.testingEventLoop.run() - XCTAssertEqual(readCounter.readCount, 7) + // Reads now work again, even if more data arrives. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() - // Reads now work again, even if more data arrives. - try await channel.testingEventLoop.executeInContext { - channel.pipeline.read() - channel.pipeline.read() - channel.pipeline.read() + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelReadComplete() - channel.pipeline.fireChannelRead(NIOAny(())) - channel.pipeline.fireChannelReadComplete() - - channel.pipeline.read() - channel.pipeline.read() - channel.pipeline.read() - } - XCTAssertEqual(readCounter.readCount, 13) + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 13) - // The next reads arriving pushes us past the limit again. - // This time we won't read. - try await channel.testingEventLoop.executeInContext { - channel.pipeline.fireChannelRead(NIOAny(())) - channel.pipeline.fireChannelRead(NIOAny(())) - channel.pipeline.fireChannelReadComplete() - } - XCTAssertEqual(readCounter.readCount, 13) + // The next reads arriving pushes us past the limit again. + // This time we won't read. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelReadComplete() + } + XCTAssertEqual(readCounter.readCount, 13) - // This time we'll consume 4 more elements, and we won't find a read at all. - for _ in 0..<4 { - try await XCTAsyncAssertNotNil(await reader.next()) - } - await channel.testingEventLoop.run() - XCTAssertEqual(readCounter.readCount, 13) + // This time we'll consume 4 more elements, and we won't find a read at all. + for _ in 0..<4 { + try await XCTAsyncAssertNotNil(await reader.next()) + } + await channel.testingEventLoop.run() + XCTAssertEqual(readCounter.readCount, 13) - // But the next reads work fine. - try await channel.testingEventLoop.executeInContext { - channel.pipeline.read() - channel.pipeline.read() - channel.pipeline.read() + // But the next reads work fine. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 16) } - XCTAssertEqual(readCounter.readCount, 16) } func testCanWrapAChannelSynchronously() async throws { @@ -456,23 +373,23 @@ final class AsyncChannelTests: XCTestCase { try NIOAsyncChannel(synchronouslyWrapping: channel) } - var iterator = wrapped.inbound.makeAsyncIterator() - try await channel.writeInbound("hello") - let firstRead = try await iterator.next() - XCTAssertEqual(firstRead, "hello") - - try await wrapped.outbound.write("world") - let write = try await channel.waitForOutboundWrite(as: String.self) - XCTAssertEqual(write, "world") + try await wrapped.executeThenClose { inbound, outbound in + var iterator = inbound.makeAsyncIterator() + try await channel.writeInbound("hello") + let firstRead = try await iterator.next() + XCTAssertEqual(firstRead, "hello") - try await channel.testingEventLoop.executeInContext { - channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) - } + try await outbound.write("world") + let write = try await channel.waitForOutboundWrite(as: String.self) + XCTAssertEqual(write, "world") - let secondRead = try await iterator.next() - XCTAssertNil(secondRead) + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + } - try await channel.close() + let secondRead = try await iterator.next() + XCTAssertNil(secondRead) + } } } diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift index 114b06c576..7f96495a1a 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift @@ -102,6 +102,7 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let result = NIOAsyncSequenceProducer.makeSequence( elementType: Int.self, backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: false, delegate: self.delegate ) self.source = result.source @@ -112,6 +113,8 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { self.backPressureStrategy = nil self.delegate = nil self.sequence = nil + self.source.finish() + self.source = nil super.tearDown() } @@ -307,39 +310,82 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { // MARK: - Source Deinited func testSourceDeinited_whenInitial() async { - self.source = nil + var newSequence: NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil + + source = nil + XCTAssertNil(source) + XCTAssertNotNil(sequence) } func testSourceDeinited_whenStreaming_andSuspended() async throws { + var newSequence: NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil + // We are registering our demand and sleeping a bit to make // sure the other child task runs when the demand is registered - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - let element = await sequence.first { _ in true } + let element = await sequence!.first { _ in true } return element } try await Task.sleep(nanoseconds: 1_000_000) - self.source = nil + source = nil return try await group.next() ?? nil } XCTAssertEqual(element, nil) + XCTAssertNil(source) XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferEmpty() async throws { - _ = self.source.yield(contentsOf: []) + var newSequence: NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil - self.source = nil + _ = source!.yield(contentsOf: []) + + source = nil - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return await sequence.first { _ in true } + return await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -350,14 +396,27 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferNotEmpty() async throws { - _ = self.source.yield(contentsOf: [1]) + var newSequence: NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil - self.source = nil + _ = source!.yield(contentsOf: [1]) + + source = nil - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return await sequence.first { _ in true } + return await sequence!.first { _ in true } } return try await group.next() ?? nil diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift index 4c4517b311..091e41743e 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift @@ -60,6 +60,7 @@ final class NIOAsyncWriterTests: XCTestCase { let newWriter = NIOAsyncWriter.makeWriter( elementType: String.self, isWritable: true, + finishOnDeinit: false, delegate: self.delegate ) self.writer = newWriter.writer @@ -67,6 +68,12 @@ final class NIOAsyncWriterTests: XCTestCase { } override func tearDown() { + if let writer = self.writer { + writer.finish() + } + if let sink = self.sink { + sink.finish() + } self.delegate = nil self.writer = nil self.sink = nil @@ -129,16 +136,42 @@ final class NIOAsyncWriterTests: XCTestCase { // MARK: - WriterDeinitialized func testWriterDeinitialized_whenInitial() async throws { - self.writer = nil + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + elementType: String.self, + isWritable: true, + finishOnDeinit: true, + delegate: self.delegate + ) + let sink = newWriter!.sink + var writer: NIOAsyncWriter? = newWriter!.writer + newWriter = nil + + writer = nil XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertNil(writer) + + sink.finish() } func testWriterDeinitialized_whenStreaming() async throws { - try await writer.yield("message1") - self.writer = nil + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + elementType: String.self, + isWritable: true, + finishOnDeinit: true, + delegate: self.delegate + ) + let sink = newWriter!.sink + var writer: NIOAsyncWriter? = newWriter!.writer + newWriter = nil + + try await writer!.yield("message1") + writer = nil XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertNil(writer) + + sink.finish() } func testWriterDeinitialized_whenWriterFinished() async throws { @@ -514,20 +547,43 @@ final class NIOAsyncWriterTests: XCTestCase { // MARK: - Sink Finish func testSinkFinish_whenInitial() async throws { - self.sink = nil + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + elementType: String.self, + isWritable: true, + finishOnDeinit: true, + delegate: self.delegate + ) + var sink: NIOAsyncWriter.Sink? = newWriter!.sink + let writer = newWriter!.writer + newWriter = nil + + sink = nil + XCTAssertNil(sink) + XCTAssertNotNil(writer) XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } func testSinkFinish_whenStreaming() async throws { + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + elementType: String.self, + isWritable: true, + finishOnDeinit: true, + delegate: self.delegate + ) + var sink: NIOAsyncWriter.Sink? = newWriter!.sink + let writer = newWriter!.writer + newWriter = nil + Task { [writer] in - try await writer!.yield("message1") + try await writer.yield("message1") } try await Task.sleep(nanoseconds: 1_000_000) - self.sink = nil + sink = nil + XCTAssertNil(sink) XCTAssertEqual(self.delegate.didTerminateCallCount, 0) } diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift index ebd38f87db..e63ccb8391 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift @@ -41,6 +41,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { elementType: Int.self, failureType: Error.self, backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: false, delegate: self.delegate ) self.source = result.source @@ -51,6 +52,8 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { self.backPressureStrategy = nil self.delegate = nil self.sequence = nil + self.source.finish() + self.source = nil super.tearDown() } @@ -383,39 +386,86 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { // MARK: - Source Deinited func testSourceDeinited_whenInitial() async { - self.source = nil + var newSequence: NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil + + source = nil + + XCTAssertNil(source) + XCTAssertNotNil(sequence) } func testSourceDeinited_whenStreaming_andSuspended() async throws { + var newSequence: NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil + // We are registering our demand and sleeping a bit to make // sure the other child task runs when the demand is registered - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - let element = try await sequence.first { _ in true } + let element = try await sequence!.first { _ in true } return element } try await Task.sleep(nanoseconds: 1_000_000) - self.source = nil + source = nil return try await group.next() ?? nil } XCTAssertEqual(element, nil) + XCTAssertNil(source) XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferEmpty() async throws { - _ = self.source.yield(contentsOf: []) + var newSequence: NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil - self.source = nil + _ = source!.yield(contentsOf: []) + + source = nil - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await sequence.first { _ in true } + return try await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -426,14 +476,28 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferNotEmpty() async throws { - _ = self.source.yield(contentsOf: [1]) + var newSequence: NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil - self.source = nil + _ = source!.yield(contentsOf: [1]) + + source = nil - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await sequence.first { _ in true } + return try await sequence!.first { _ in true } } return try await group.next() ?? nil diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 2775b48d2d..fb87269dfb 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -239,16 +239,22 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { _ in - for try await childChannel in channel.inbound { - for try await value in childChannel.inbound { - continuation.yield(.string(value)) + try await channel.executeThenClose { inbound in + for try await childChannel in inbound { + try await childChannel.executeThenClose { childChannelInbound, _ in + for try await value in childChannelInbound { + continuation.yield(.string(value)) + } + } } } } } let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!) - try await stringChannel.outbound.write("hello") + try await stringChannel.executeThenClose { _, outbound in + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await iterator.next(), .string("hello")) @@ -280,16 +286,22 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inbound { - group.addTask { - switch try await negotiationResult.get() { - case .string(let channel): - for try await value in channel.inbound { - continuation.yield(.string(value)) - } - case .byte(let channel): - for try await value in channel.inbound { - continuation.yield(.byte(value)) + try await channel.executeThenClose { inbound in + for try await negotiationResult in inbound { + group.addTask { + switch try await negotiationResult.get() { + case .string(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.string(value)) + } + } + case .byte(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.byte(value)) + } + } } } } @@ -305,8 +317,10 @@ final class AsyncChannelBootstrapTests: XCTestCase { let stringNegotiationResult = try await stringNegotiationResultFuture.get() switch stringNegotiationResult { case .string(let stringChannel): - // This is the actual content - try await stringChannel.outbound.write("hello") + try await stringChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -322,8 +336,10 @@ final class AsyncChannelBootstrapTests: XCTestCase { case .string: preconditionFailure() case .byte(let byteChannel): - // This is the actual content - try await byteChannel.outbound.write(UInt8(8)) + try await byteChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write(UInt8(8)) + } await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) } @@ -354,16 +370,22 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inbound { - group.addTask { - switch try await negotiationResult.get().get() { - case .string(let channel): - for try await value in channel.inbound { - continuation.yield(.string(value)) - } - case .byte(let channel): - for try await value in channel.inbound { - continuation.yield(.byte(value)) + try await channel.executeThenClose { inbound in + for try await negotiationResult in inbound { + group.addTask { + switch try await negotiationResult.get().get() { + case .string(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.string(value)) + } + } + case .byte(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.byte(value)) + } + } } } } @@ -379,8 +401,10 @@ final class AsyncChannelBootstrapTests: XCTestCase { ) switch try await stringStringNegotiationResult.get().get() { case .string(let stringChannel): - // This is the actual content - try await stringChannel.outbound.write("hello") + try await stringChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -394,8 +418,10 @@ final class AsyncChannelBootstrapTests: XCTestCase { ) switch try await byteStringNegotiationResult.get().get() { case .string(let stringChannel): - // This is the actual content - try await stringChannel.outbound.write("hello") + try await stringChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -411,8 +437,10 @@ final class AsyncChannelBootstrapTests: XCTestCase { case .string: preconditionFailure() case .byte(let byteChannel): - // This is the actual content - try await byteChannel.outbound.write(UInt8(8)) + try await byteChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write(UInt8(8)) + } await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) } @@ -426,8 +454,10 @@ final class AsyncChannelBootstrapTests: XCTestCase { case .string: preconditionFailure() case .byte(let byteChannel): - // This is the actual content - try await byteChannel.outbound.write(UInt8(8)) + try await byteChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write(UInt8(8)) + } await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) } @@ -483,16 +513,22 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inbound { - group.addTask { - switch try await negotiationResult.get() { - case .string(let channel): - for try await value in channel.inbound { - continuation.yield(.string(value)) - } - case .byte(let channel): - for try await value in channel.inbound { - continuation.yield(.byte(value)) + try await channel.executeThenClose { inbound in + for try await negotiationResult in inbound { + group.addTask { + switch try await negotiationResult.get() { + case .string(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.string(value)) + } + } + case .byte(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.byte(value)) + } + } } } } @@ -517,8 +553,10 @@ final class AsyncChannelBootstrapTests: XCTestCase { ) switch try await stringNegotiationResult.get() { case .string(let stringChannel): - // This is the actual content - try await stringChannel.outbound.write("hello") + try await stringChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -549,14 +587,18 @@ final class AsyncChannelBootstrapTests: XCTestCase { eventLoopGroup: eventLoopGroup, port: serverChannel.channel.localAddress!.port! ) - var serverInboundIterator = serverChannel.inbound.makeAsyncIterator() - var clientInboundIterator = clientChannel.inbound.makeAsyncIterator() + try await serverChannel.executeThenClose { serverChannelInbound, serverChannelOutbound in + try await clientChannel.executeThenClose { clientChannelInbound, clientChannelOutbound in + var serverInboundIterator = serverChannelInbound.makeAsyncIterator() + var clientInboundIterator = clientChannelInbound.makeAsyncIterator() - try await clientChannel.outbound.write("request") - try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") + try await clientChannelOutbound.write("request") + try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") - try await serverChannel.outbound.write("response") - try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") + try await serverChannelOutbound.write("response") + try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") + } + } } func testDatagramBootstrap_withProtocolNegotiation_andHostPort() async throws { @@ -601,14 +643,18 @@ final class AsyncChannelBootstrapTests: XCTestCase { switch (try await firstNegotiationResult?.get(), try await secondNegotiationResult?.get()) { case (.string(let firstChannel), .string(let secondChannel)): - var firstInboundIterator = firstChannel.inbound.makeAsyncIterator() - var secondInboundIterator = secondChannel.inbound.makeAsyncIterator() + try await firstChannel.executeThenClose { firstChannelInbound, firstChannelOutbound in + try await secondChannel.executeThenClose { secondChannelInbound, secondChannelOutbound in + var firstInboundIterator = firstChannelInbound.makeAsyncIterator() + var secondInboundIterator = secondChannelInbound.makeAsyncIterator() - try await firstChannel.outbound.write("request") - try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") + try await firstChannelOutbound.write("request") + try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") - try await secondChannel.outbound.write("response") - try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") + try await secondChannelOutbound.write("response") + try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") + } + } default: preconditionFailure() @@ -671,15 +717,21 @@ final class AsyncChannelBootstrapTests: XCTestCase { throw error } - var inboundIterator = channel.inbound.makeAsyncIterator() - var fromChannelInboundIterator = fromChannel.inbound.makeAsyncIterator() + try await channel.executeThenClose { channelInbound, channelOutbound in + try await fromChannel.executeThenClose { fromChannelInbound, _ in + try await toChannel.executeThenClose { _, toChannelOutbound in + var inboundIterator = channelInbound.makeAsyncIterator() + var fromChannelInboundIterator = fromChannelInbound.makeAsyncIterator() - try await toChannel.outbound.write(.init(string: "Request")) - try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) + try await toChannelOutbound.write(.init(string: "Request")) + try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) - let response = ByteBuffer(string: "Response") - try await channel.outbound.write(response) - try await XCTAsyncAssertEqual(try await fromChannelInboundIterator.next(), response) + let response = ByteBuffer(string: "Response") + try await channelOutbound.write(response) + try await XCTAsyncAssertEqual(try await fromChannelInboundIterator.next(), response) + } + } + } } func testPipeBootstrap_whenInputNil() async throws { @@ -719,14 +771,18 @@ final class AsyncChannelBootstrapTests: XCTestCase { throw error } - var inboundIterator = channel.inbound.makeAsyncIterator() - var fromChannelInboundIterator = fromChannel.inbound.makeAsyncIterator() + try await channel.executeThenClose { channelInbound, channelOutbound in + try await fromChannel.executeThenClose { fromChannelInbound, _ in + var inboundIterator = channelInbound.makeAsyncIterator() + var fromChannelInboundIterator = fromChannelInbound.makeAsyncIterator() - try await XCTAsyncAssertEqual(try await inboundIterator.next(), nil) + try await XCTAsyncAssertEqual(try await inboundIterator.next(), nil) - let response = ByteBuffer(string: "Response") - try await channel.outbound.write(response) - try await XCTAsyncAssertEqual(try await fromChannelInboundIterator.next(), response) + let response = ByteBuffer(string: "Response") + try await channelOutbound.write(response) + try await XCTAsyncAssertEqual(try await fromChannelInboundIterator.next(), response) + } + } } func testPipeBootstrap_whenOutputNil() async throws { @@ -766,14 +822,18 @@ final class AsyncChannelBootstrapTests: XCTestCase { throw error } - var inboundIterator = channel.inbound.makeAsyncIterator() + try await channel.executeThenClose { channelInbound, channelOutbound in + try await toChannel.executeThenClose { _, toChannelOutbound in + var inboundIterator = channelInbound.makeAsyncIterator() - try await toChannel.outbound.write(.init(string: "Request")) - try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) + try await toChannelOutbound.write(.init(string: "Request")) + try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) - let response = ByteBuffer(string: "Response") - await XCTAsyncAssertThrowsError(try await channel.outbound.write(response)) { error in - XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) + let response = ByteBuffer(string: "Response") + await XCTAsyncAssertThrowsError(try await channelOutbound.write(response)) { error in + XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) + } + } } } @@ -830,27 +890,33 @@ final class AsyncChannelBootstrapTests: XCTestCase { throw error } - var fromChannelInboundIterator = fromChannel.inbound.makeAsyncIterator() - - try await toChannel.outbound.write(.init(string: "alpn:string\nHello\n")) - switch try await negotiationResult.get() { - case .string(let channel): - var inboundIterator = channel.inbound.makeAsyncIterator() - do { - try await XCTAsyncAssertEqual(try await inboundIterator.next(), "Hello") - - let expectedResponse = ByteBuffer(string: "Response\n") - try await channel.outbound.write("Response") - let response = try await fromChannelInboundIterator.next() - XCTAssertEqual(response, expectedResponse) - } catch { - // We only got to close the FDs that are not owned by the PipeChannel - [pipe1WriteFD, pipe2ReadFD].forEach { try? SystemCalls.close(descriptor: $0) } - throw error - } + try await fromChannel.executeThenClose { fromChannelInbound, _ in + try await toChannel.executeThenClose { _, toChannelOutbound in + var fromChannelInboundIterator = fromChannelInbound.makeAsyncIterator() + + try await toChannelOutbound.write(.init(string: "alpn:string\nHello\n")) + switch try await negotiationResult.get() { + case .string(let channel): + try await channel.executeThenClose { channelInbound, channelOutbound in + var inboundIterator = channelInbound.makeAsyncIterator() + do { + try await XCTAsyncAssertEqual(try await inboundIterator.next(), "Hello") + + let expectedResponse = ByteBuffer(string: "Response\n") + try await channelOutbound.write("Response") + let response = try await fromChannelInboundIterator.next() + XCTAssertEqual(response, expectedResponse) + } catch { + // We only got to close the FDs that are not owned by the PipeChannel + [pipe1WriteFD, pipe2ReadFD].forEach { try? SystemCalls.close(descriptor: $0) } + throw error + } + } - case .byte: - fatalError() + case .byte: + fatalError() + } + } } } @@ -866,14 +932,18 @@ final class AsyncChannelBootstrapTests: XCTestCase { let serverChannel = try await self.makeRawSocketServerChannel(eventLoopGroup: eventLoopGroup) let clientChannel = try await self.makeRawSocketClientChannel(eventLoopGroup: eventLoopGroup) - var serverInboundIterator = serverChannel.inbound.makeAsyncIterator() - var clientInboundIterator = clientChannel.inbound.makeAsyncIterator() + try await serverChannel.executeThenClose { serverChannelInbound, serverChannelOutbound in + try await clientChannel.executeThenClose { clientChannelInbound, clientChannelOutbound in + var serverInboundIterator = serverChannelInbound.makeAsyncIterator() + var clientInboundIterator = clientChannelInbound.makeAsyncIterator() - try await clientChannel.outbound.write("request") - try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") + try await clientChannelOutbound.write("request") + try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") - try await serverChannel.outbound.write("response") - try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") + try await serverChannelOutbound.write("response") + try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") + } + } } func testRawSocketBootstrap_withProtocolNegotiation() async throws { @@ -906,14 +976,18 @@ final class AsyncChannelBootstrapTests: XCTestCase { switch (try await firstNegotiationResult?.get(), try await secondNegotiationResult?.get()) { case (.string(let firstChannel), .string(let secondChannel)): - var firstInboundIterator = firstChannel.inbound.makeAsyncIterator() - var secondInboundIterator = secondChannel.inbound.makeAsyncIterator() + try await firstChannel.executeThenClose { firstChannelInbound, firstChannelOutbound in + try await secondChannel.executeThenClose { secondChannelInbound, secondChannelOutbound in + var firstInboundIterator = firstChannelInbound.makeAsyncIterator() + var secondInboundIterator = secondChannelInbound.makeAsyncIterator() - try await firstChannel.outbound.write("request") - try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") + try await firstChannelOutbound.write("request") + try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") - try await secondChannel.outbound.write("response") - try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") + try await secondChannelOutbound.write("response") + try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") + } + } default: preconditionFailure() @@ -956,9 +1030,13 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { _ in - for try await childChannel in serverChannel.inbound { - for try await value in childChannel.inbound { - continuation.yield(.string(value)) + try await serverChannel.executeThenClose { inbound in + for try await childChannel in inbound { + try await childChannel.executeThenClose { childChannelInbound, _ in + for try await value in childChannelInbound { + continuation.yield(.string(value)) + } + } } } } @@ -973,7 +1051,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { return try NIOAsyncChannel(synchronouslyWrapping: channel) } } - try await stringChannel.outbound.write("hello") + try await stringChannel.executeThenClose { _, outbound in + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await iterator.next(), .string("hello")) From 1040927f12d045a31aa8d09b6d99e03483f120b0 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Wed, 15 Nov 2023 13:42:53 +0000 Subject: [PATCH 47/64] Add `closeOnDeinit` to the `NIOAsyncChannel` init (#2592) * Add `closeOnDeinit` to the `NIOAsyncChannel` init # Motivation In my previous PR, I already did the work to add `finishOnDeinit` configuration to the `NIOAsyncWriter` and `NIOAsyncSequenceProducer`. This PR also automatically migrated the `NIOAsyncChanell` to set the `finishOnDeinit = false`. This was intentional since we really want users to not use the deinit based cleanup; however, it also broke all current adopters of this API semantically and they might now run into the preconditions. # Modification This PR reverts the change in `NIOAsyncChannel` and does the usual deprecate + new init dance to provide users to configure this behaviour while still nudging them to check that this is really what they want. # Result Easier migration without semantically breaking current adopters of `NIOAsyncChannel`. * Rename to `wrappingChannelSynchronously` --- .../TCPEchoAsyncChannel.swift | 4 +- .../NIOCore/AsyncChannel/AsyncChannel.swift | 101 ++++++++++++++++-- .../AsyncChannelInboundStream.swift | 9 +- .../AsyncChannelOutboundWriter.swift | 5 +- .../NIOCore/Docs.docc/swift-concurrency.md | 6 +- Sources/NIOPosix/Bootstrap.swift | 2 +- Sources/NIOTCPEchoClient/Client.swift | 2 +- Sources/NIOTCPEchoServer/Server.swift | 2 +- Sources/NIOWebSocketClient/Client.swift | 2 +- Sources/NIOWebSocketServer/Server.swift | 4 +- .../AsyncChannel/AsyncChannelTests.swift | 22 ++-- .../AsyncChannelBootstrapTests.swift | 38 +++---- 12 files changed, 145 insertions(+), 52 deletions(-) diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift index 269bde7e53..0a080598a8 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEchoAsyncChannel.swift @@ -23,7 +23,7 @@ func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async thr ) { channel in channel.eventLoop.makeCompletedFuture { return try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: .init( inboundType: ByteBuffer.self, outboundType: ByteBuffer.self @@ -39,7 +39,7 @@ func runTCPEchoAsyncChannel(numberOfWrites: Int, eventLoop: EventLoop) async thr ) { channel in channel.eventLoop.makeCompletedFuture { return try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: .init( inboundType: ByteBuffer.self, outboundType: ByteBuffer.self diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift index 2a2bcf217a..3bbaaf4ba9 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift @@ -100,6 +100,55 @@ public struct NIOAsyncChannel: Sendable { /// - channel: The ``Channel`` to wrap. /// - configuration: The ``NIOAsyncChannel``s configuration. @inlinable + public init( + wrappingChannelSynchronously channel: Channel, + configuration: Configuration = .init() + ) throws { + channel.eventLoop.preconditionInEventLoop() + self.channel = channel + (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( + backPressureStrategy: configuration.backPressureStrategy, + isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled, + closeOnDeinit: false + ) + } + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel`` where the outbound type is `Never`. + /// + /// This initializer will finish the ``NIOAsyncChannel/outbound`` immediately. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - configuration: The ``NIOAsyncChannel``s configuration. + @inlinable + public init( + wrappingChannelSynchronously channel: Channel, + configuration: Configuration = .init() + ) throws where Outbound == Never { + channel.eventLoop.preconditionInEventLoop() + self.channel = channel + (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( + backPressureStrategy: configuration.backPressureStrategy, + isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled, + closeOnDeinit: false + ) + + self._outbound.finish() + } + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel``. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - configuration: The ``NIOAsyncChannel``s configuration. + @available(*, deprecated, renamed: "init(wrappingChannelSynchronously:configuration:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + @inlinable public init( synchronouslyWrapping channel: Channel, configuration: Configuration = .init() @@ -108,7 +157,8 @@ public struct NIOAsyncChannel: Sendable { self.channel = channel (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( backPressureStrategy: configuration.backPressureStrategy, - isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled + isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled, + closeOnDeinit: true ) } @@ -123,6 +173,7 @@ public struct NIOAsyncChannel: Sendable { /// - channel: The ``Channel`` to wrap. /// - configuration: The ``NIOAsyncChannel``s configuration. @inlinable + @available(*, deprecated, renamed: "init(wrappingChannelSynchronously:configuration:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") public init( synchronouslyWrapping channel: Channel, configuration: Configuration = .init() @@ -131,7 +182,8 @@ public struct NIOAsyncChannel: Sendable { self.channel = channel (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( backPressureStrategy: configuration.backPressureStrategy, - isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled + isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled, + closeOnDeinit: true ) self._outbound.finish() @@ -149,12 +201,12 @@ public struct NIOAsyncChannel: Sendable { self._outbound = outboundWriter } - /// This method is only used from our server bootstrap to allow us to run the child channel initializer /// at the right moment. /// /// - Important: This is not considered stable API and should not be used. @inlinable + @available(*, deprecated, message: "This method has been deprecated since it defaults to deinit based resource teardown") public static func _wrapAsyncChannelWithTransformations( synchronouslyWrapping channel: Channel, backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, @@ -165,6 +217,35 @@ public struct NIOAsyncChannel: Sendable { let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) = try channel._syncAddAsyncHandlersWithTransformations( backPressureStrategy: backPressureStrategy, isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: true, + channelReadTransformation: channelReadTransformation + ) + + outboundWriter.finish() + + return .init( + channel: channel, + inboundStream: inboundStream, + outboundWriter: outboundWriter + ) + } + + /// This method is only used from our server bootstrap to allow us to run the child channel initializer + /// at the right moment. + /// + /// - Important: This is not considered stable API and should not be used. + @inlinable + public static func _wrapAsyncChannelWithTransformations( + wrappingChannelSynchronously channel: Channel, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + isOutboundHalfClosureEnabled: Bool = false, + channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture + ) throws -> NIOAsyncChannel where Outbound == Never { + channel.eventLoop.preconditionInEventLoop() + let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) = try channel._syncAddAsyncHandlersWithTransformations( + backPressureStrategy: backPressureStrategy, + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: false, channelReadTransformation: channelReadTransformation ) @@ -229,17 +310,20 @@ extension Channel { @inlinable func _syncAddAsyncHandlers( backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, - isOutboundHalfClosureEnabled: Bool + isOutboundHalfClosureEnabled: Bool, + closeOnDeinit: Bool ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { self.eventLoop.assertInEventLoop() let inboundStream = try NIOAsyncChannelInboundStream.makeWrappingHandler( channel: self, - backPressureStrategy: backPressureStrategy + backPressureStrategy: backPressureStrategy, + closeOnDeinit: closeOnDeinit ) let writer = try NIOAsyncChannelOutboundWriter( channel: self, - isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: closeOnDeinit ) return (inboundStream, writer) } @@ -249,6 +333,7 @@ extension Channel { func _syncAddAsyncHandlersWithTransformations( backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, isOutboundHalfClosureEnabled: Bool, + closeOnDeinit: Bool, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { self.eventLoop.assertInEventLoop() @@ -256,11 +341,13 @@ extension Channel { let inboundStream = try NIOAsyncChannelInboundStream.makeTransformationHandler( channel: self, backPressureStrategy: backPressureStrategy, + closeOnDeinit: closeOnDeinit, channelReadTransformation: channelReadTransformation ) let writer = try NIOAsyncChannelOutboundWriter( channel: self, - isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: closeOnDeinit ) return (inboundStream, writer) } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift index f00896de94..0a672dc3ed 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift @@ -80,6 +80,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { init( channel: Channel, backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + closeOnDeinit: Bool, handler: NIOAsyncChannelInboundStreamChannelHandler ) throws { channel.eventLoop.preconditionInEventLoop() @@ -95,7 +96,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { let sequence = Producer.makeSequence( backPressureStrategy: strategy, - finishOnDeinit: false, + finishOnDeinit: closeOnDeinit, delegate: NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate(handler: handler) ) handler.source = sequence.source @@ -107,7 +108,8 @@ public struct NIOAsyncChannelInboundStream: Sendable { @inlinable static func makeWrappingHandler( channel: Channel, - backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + closeOnDeinit: Bool ) throws -> NIOAsyncChannelInboundStream { let handler = NIOAsyncChannelInboundStreamChannelHandler.makeHandler( eventLoop: channel.eventLoop @@ -116,6 +118,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { return try .init( channel: channel, backPressureStrategy: backPressureStrategy, + closeOnDeinit: closeOnDeinit, handler: handler ) } @@ -125,6 +128,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { static func makeTransformationHandler( channel: Channel, backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + closeOnDeinit: Bool, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> NIOAsyncChannelInboundStream { let handler = NIOAsyncChannelInboundStreamChannelHandler.makeHandlerWithTransformations( @@ -135,6 +139,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { return try .init( channel: channel, backPressureStrategy: backPressureStrategy, + closeOnDeinit: closeOnDeinit, handler: handler ) } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift index 3af5751e6b..d89332e255 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -84,7 +84,8 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { @inlinable init( channel: Channel, - isOutboundHalfClosureEnabled: Bool + isOutboundHalfClosureEnabled: Bool, + closeOnDeinit: Bool ) throws { let handler = NIOAsyncChannelOutboundWriterHandler( eventLoop: channel.eventLoop, @@ -93,7 +94,7 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { let writer = _Writer.makeWriter( elementType: OutboundOut.self, isWritable: true, - finishOnDeinit: false, + finishOnDeinit: closeOnDeinit, delegate: .init(handler: handler) ) handler.sink = writer.sink diff --git a/Sources/NIOCore/Docs.docc/swift-concurrency.md b/Sources/NIOCore/Docs.docc/swift-concurrency.md index 891ad9e4fe..7198e18260 100644 --- a/Sources/NIOCore/Docs.docc/swift-concurrency.md +++ b/Sources/NIOCore/Docs.docc/swift-concurrency.md @@ -85,7 +85,7 @@ the inbound data and echo it back outbound. ```swift let channel = ... -let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) +let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) try await asyncChannel.executeThenClose { inbound, outbound in for try await inboundData in inbound { @@ -186,7 +186,7 @@ let clientChannel = try await ClientBootstrap(group: eventLoopGroup) ) { channel in channel.eventLoop.makeCompletedFuture { return try NIOAsyncChannel( - synchronouslyWrapping: channel + wrappingChannelSynchronously: channel ) } } @@ -245,7 +245,7 @@ let upgradeResult: EventLoopFuture = try await ClientBootstrap(gr // This configures the pipeline after the websocket upgrade was successful. // We are wrapping the pipeline in a NIOAsyncChannel. channel.eventLoop.makeCompletedFuture { - let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) + let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) return UpgradeResult.websocket(asyncChannel) } } diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index bf70b9863e..3b565f1b71 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -654,7 +654,7 @@ extension ServerBootstrap { ) let asyncChannel = try NIOAsyncChannel ._wrapAsyncChannelWithTransformations( - synchronouslyWrapping: serverChannel, + wrappingChannelSynchronously: serverChannel, backPressureStrategy: serverBackPressureStrategy, channelReadTransformation: { channel -> EventLoopFuture in // The channelReadTransformation is run on the EL of the server channel diff --git a/Sources/NIOTCPEchoClient/Client.swift b/Sources/NIOTCPEchoClient/Client.swift index a334c7f605..c3fcb963dd 100644 --- a/Sources/NIOTCPEchoClient/Client.swift +++ b/Sources/NIOTCPEchoClient/Client.swift @@ -59,7 +59,7 @@ struct Client { try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(NewlineDelimiterCoder())) return try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: NIOAsyncChannel.Configuration( inboundType: String.self, outboundType: String.self diff --git a/Sources/NIOTCPEchoServer/Server.swift b/Sources/NIOTCPEchoServer/Server.swift index 368df645a9..30b786b79d 100644 --- a/Sources/NIOTCPEchoServer/Server.swift +++ b/Sources/NIOTCPEchoServer/Server.swift @@ -48,7 +48,7 @@ struct Server { try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(NewlineDelimiterCoder())) return try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: NIOAsyncChannel.Configuration( inboundType: String.self, outboundType: String.self diff --git a/Sources/NIOWebSocketClient/Client.swift b/Sources/NIOWebSocketClient/Client.swift index a2698536fe..8ad4db2d64 100644 --- a/Sources/NIOWebSocketClient/Client.swift +++ b/Sources/NIOWebSocketClient/Client.swift @@ -61,7 +61,7 @@ struct Client { // let upgrader = NIOTypedWebSocketClientUpgrader( // upgradePipelineHandler: { (channel, _) in // channel.eventLoop.makeCompletedFuture { -// let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) +// let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) // return UpgradeResult.websocket(asyncChannel) // } // } diff --git a/Sources/NIOWebSocketServer/Server.swift b/Sources/NIOWebSocketServer/Server.swift index dad8fbd12c..7cdc84ff64 100644 --- a/Sources/NIOWebSocketServer/Server.swift +++ b/Sources/NIOWebSocketServer/Server.swift @@ -91,7 +91,7 @@ struct Server { // }, // upgradePipelineHandler: { (channel, _) in // channel.eventLoop.makeCompletedFuture { -// let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel) +// let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) // return UpgradeResult.websocket(asyncChannel) // } // } @@ -102,7 +102,7 @@ struct Server { // notUpgradingCompletionHandler: { channel in // channel.eventLoop.makeCompletedFuture { // try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponsePartHandler()) -// let asyncChannel = try NIOAsyncChannel>(synchronouslyWrapping: channel) +// let asyncChannel = try NIOAsyncChannel>(wrappingChannelSynchronously: channel) // return UpgradeResult.notUpgraded(asyncChannel) // } // } diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 8c7fcf7348..7fb92d1a82 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -30,7 +30,7 @@ final class AsyncChannelTests: XCTestCase { let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { try channel.pipeline.syncOperations.addHandler(CloseOnWriteHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await wrapped.executeThenClose { _, outbound in @@ -42,7 +42,7 @@ final class AsyncChannelTests: XCTestCase { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await wrapped.executeThenClose { inbound, _ in @@ -68,7 +68,7 @@ final class AsyncChannelTests: XCTestCase { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await wrapped.executeThenClose { _, outbound in @@ -91,7 +91,7 @@ final class AsyncChannelTests: XCTestCase { let wrapped = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: .init( isOutboundHalfClosureEnabled: true, inboundType: Never.self, @@ -125,7 +125,7 @@ final class AsyncChannelTests: XCTestCase { do { _ = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: .init( isOutboundHalfClosureEnabled: false, inboundType: Never.self, @@ -151,7 +151,7 @@ final class AsyncChannelTests: XCTestCase { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await channel.writeInbound("hello") @@ -170,7 +170,7 @@ final class AsyncChannelTests: XCTestCase { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await channel.writeInbound("hello") @@ -193,7 +193,7 @@ final class AsyncChannelTests: XCTestCase { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await channel.testingEventLoop.executeInContext { @@ -233,7 +233,7 @@ final class AsyncChannelTests: XCTestCase { do { // Create the NIOAsyncChannel, then drop it. The handler will still be in the pipeline. _ = try await channel.testingEventLoop.executeInContext { - _ = try NIOAsyncChannel(synchronouslyWrapping: channel) + _ = try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } @@ -258,7 +258,7 @@ final class AsyncChannelTests: XCTestCase { try await channel.pipeline.addHandler(readCounter) let wrapped = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: .init( backPressureStrategy: .init(lowWatermark: 2, highWatermark: 4), inboundType: Void.self, @@ -370,7 +370,7 @@ final class AsyncChannelTests: XCTestCase { guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await wrapped.executeThenClose { inbound, outbound in diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index fb87269dfb..099074c71b 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -224,7 +224,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) return try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: .init( inboundType: String.self, outboundType: String.self @@ -681,7 +681,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { output: pipe2WriteFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { @@ -695,7 +695,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { output: pipe1WriteFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { @@ -709,7 +709,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { input: pipe2ReadFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { @@ -749,7 +749,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { output: pipe1WriteFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { @@ -763,7 +763,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { input: pipe1ReadFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { @@ -800,7 +800,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { input: pipe1ReadFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { @@ -814,7 +814,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { output: pipe1WriteFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { @@ -868,7 +868,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { output: pipe1WriteFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { @@ -882,7 +882,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { input: pipe2ReadFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { @@ -1014,7 +1014,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } @@ -1048,7 +1048,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } try await stringChannel.executeThenClose { _, outbound in @@ -1098,7 +1098,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder(inboundID: 1))) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder(outboundID: 2))) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } @@ -1115,7 +1115,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder(inboundID: 2))) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder(outboundID: 1))) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } @@ -1163,7 +1163,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } @@ -1214,7 +1214,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } @@ -1247,7 +1247,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } @@ -1330,7 +1330,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { return channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) let asyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel + wrappingChannelSynchronously: channel ) return .string(asyncChannel) @@ -1340,7 +1340,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteBufferToByteHandler()) let asyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel + wrappingChannelSynchronously: channel ) return .byte(asyncChannel) From c46199386cc12f4e4ae799cebf361d97d4234feb Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Wed, 15 Nov 2023 18:58:27 +0000 Subject: [PATCH 48/64] Revert "Back out new typed HTTP protocol upgrader (#2579)" (#2593) * Revert "Back out new typed HTTP protocol upgrader (#2579)" # Motivation We have reverted the typed HTTP protocol upgrader pieces since adopters were running into a compiler bug (https://github.com/apple/swift/pull/69459) that caused the compiler to emit strong references to `swift_getExtendedExistentialTypeMetadata`. The problem is that `swift_getExtendedExistentialTypeMetadata` is not available on older runtimes before constrained existentials have been introduced. This caused adopters to run into runtime crashes when loading any library compiled with this NIO code. # Modifications This PR reverts the revert and guard all new code in a compiler guard that checks that we are either on non-Darwin platforms or on a new enough Swift compiler that contains the fix. # Result We can offer the typed HTTP upgrade code to our adopters again. * Add compiler guards --- Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift | 250 +++++++++ .../NIOTypedHTTPClientUpgradeHandler.swift | 285 ++++++++++ ...OTypedHTTPClientUpgraderStateMachine.swift | 335 +++++++++++ .../NIOTypedHTTPServerUpgradeHandler.swift | 371 +++++++++++++ ...OTypedHTTPServerUpgraderStateMachine.swift | 386 +++++++++++++ .../NIOWebSocketClientUpgrader.swift | 57 ++ .../NIOWebSocketServerUpgrader.swift | 86 +++ Sources/NIOWebSocketClient/Client.swift | 245 ++++---- Sources/NIOWebSocketServer/Server.swift | 469 ++++++++-------- .../HTTPClientUpgradeTests.swift | 240 +++++++- .../HTTPServerUpgradeTests.swift | 523 +++++++++++++++++- .../WebSocketClientEndToEndTests.swift | 213 +++++++ .../WebSocketServerEndToEndTests.swift | 29 + 13 files changed, 3124 insertions(+), 365 deletions(-) create mode 100644 Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift create mode 100644 Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift create mode 100644 Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift create mode 100644 Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift create mode 100644 Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift diff --git a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift new file mode 100644 index 0000000000..4135203a8e --- /dev/null +++ b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift @@ -0,0 +1,250 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +import NIOCore + +// MARK: - Server pipeline configuration + +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOUpgradableHTTPServerPipelineConfiguration { + /// Whether to provide assistance handling HTTP clients that pipeline + /// their requests. Defaults to `true`. If `false`, users will need to handle clients that pipeline themselves. + public var enablePipelining = true + + /// Whether to provide assistance handling protocol errors (e.g. failure to parse the HTTP + /// request) by sending 400 errors. Defaults to `true`. + public var enableErrorHandling = true + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableResponseHeaderValidation = true + + /// The configuration for the ``HTTPResponseEncoder``. + public var encoderConfiguration = HTTPResponseEncoder.Configuration() + + /// The configuration for the ``NIOTypedHTTPServerUpgradeHandler``. + public var upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPServerPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Assistance handling clients that pipeline HTTP requests. + /// 2. Assistance handling protocol errors. + /// 3. Outbound header fields validation to protect against response splitting attacks. + public init( + upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration + ) { + self.upgradeConfiguration = upgradeConfiguration + } +} + +extension ChannelPipeline { + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPServerPipeline( + configuration: NIOUpgradableHTTPServerPipelineConfiguration + ) -> EventLoopFuture> { + self._configureUpgradableHTTPServerPipeline( + configuration: configuration + ) + } + + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + private func _configureUpgradableHTTPServerPipeline( + configuration: NIOUpgradableHTTPServerPipelineConfiguration + ) -> EventLoopFuture> { + let future: EventLoopFuture> + + if self.eventLoop.inEventLoop { + let result = Result, Error> { + try self.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: configuration + ) + } + future = self.eventLoop.makeCompletedFuture(result) + } else { + future = self.eventLoop.submit { + try self.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: configuration + ) + } + } + + return future + } +} + +extension ChannelPipeline.SynchronousOperations { + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPServerPipeline( + configuration: NIOUpgradableHTTPServerPipelineConfiguration + ) throws -> EventLoopFuture { + self.eventLoop.assertInEventLoop() + + let responseEncoder = HTTPResponseEncoder(configuration: configuration.encoderConfiguration) + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + + var extraHTTPHandlers = [RemovableChannelHandler]() + extraHTTPHandlers.reserveCapacity(4) + extraHTTPHandlers.append(requestDecoder) + + try self.addHandler(responseEncoder) + try self.addHandler(requestDecoder) + + if configuration.enablePipelining { + let pipeliningHandler = HTTPServerPipelineHandler() + try self.addHandler(pipeliningHandler) + extraHTTPHandlers.append(pipeliningHandler) + } + + if configuration.enableResponseHeaderValidation { + let headerValidationHandler = NIOHTTPResponseHeadersValidator() + try self.addHandler(headerValidationHandler) + extraHTTPHandlers.append(headerValidationHandler) + } + + if configuration.enableErrorHandling { + let errorHandler = HTTPServerProtocolErrorHandler() + try self.addHandler(errorHandler) + extraHTTPHandlers.append(errorHandler) + } + + let upgrader = NIOTypedHTTPServerUpgradeHandler( + httpEncoder: responseEncoder, + extraHTTPHandlers: extraHTTPHandlers, + upgradeConfiguration: configuration.upgradeConfiguration + ) + try self.addHandler(upgrader) + + return upgrader.upgradeResultFuture + } +} + +// MARK: - Client pipeline configuration + +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOUpgradableHTTPClientPipelineConfiguration { + /// The strategy to use when dealing with leftover bytes after removing the ``HTTPDecoder`` from the pipeline. + public var leftOverBytesStrategy = RemoveAfterUpgradeStrategy.dropBytes + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableOutboundHeaderValidation = true + + /// The configuration for the ``HTTPRequestEncoder``. + public var encoderConfiguration = HTTPRequestEncoder.Configuration() + + /// The configuration for the ``NIOTypedHTTPClientUpgradeHandler``. + public var upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPClientPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Outbound header fields validation to protect against response splitting attacks. + public init( + upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration + ) { + self.upgradeConfiguration = upgradeConfiguration + } +} + +extension ChannelPipeline { + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPClientPipeline( + configuration: NIOUpgradableHTTPClientPipelineConfiguration + ) -> EventLoopFuture> { + self._configureUpgradableHTTPClientPipeline(configuration: configuration) + } + + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + private func _configureUpgradableHTTPClientPipeline( + configuration: NIOUpgradableHTTPClientPipelineConfiguration + ) -> EventLoopFuture> { + let future: EventLoopFuture> + + if self.eventLoop.inEventLoop { + let result = Result, Error> { + try self.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: configuration + ) + } + future = self.eventLoop.makeCompletedFuture(result) + } else { + future = self.eventLoop.submit { + try self.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: configuration + ) + } + } + + return future + } +} + +extension ChannelPipeline.SynchronousOperations { + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPClientPipeline( + configuration: NIOUpgradableHTTPClientPipelineConfiguration + ) throws -> EventLoopFuture { + self.eventLoop.assertInEventLoop() + + let requestEncoder = HTTPRequestEncoder(configuration: configuration.encoderConfiguration) + let responseDecoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: configuration.leftOverBytesStrategy)) + var httpHandlers = [RemovableChannelHandler]() + httpHandlers.reserveCapacity(3) + httpHandlers.append(requestEncoder) + httpHandlers.append(responseDecoder) + + try self.addHandler(requestEncoder) + try self.addHandler(responseDecoder) + + if configuration.enableOutboundHeaderValidation { + let headerValidationHandler = NIOHTTPRequestHeadersValidator() + try self.addHandler(headerValidationHandler) + httpHandlers.append(headerValidationHandler) + } + + let upgrader = NIOTypedHTTPClientUpgradeHandler( + httpHandlers: httpHandlers, + upgradeConfiguration: configuration.upgradeConfiguration + ) + try self.addHandler(upgrader) + + return upgrader.upgradeResultFuture + } +} +#endif diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift new file mode 100644 index 0000000000..ea76a74b91 --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift @@ -0,0 +1,285 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2013 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +import NIOCore + +/// An object that implements `NIOTypedHTTPClientProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a client-side channel. +/// It has the option of denying this upgrade based upon the server response. +public protocol NIOTypedHTTPClientProtocolUpgrader { + associatedtype UpgradeResult: Sendable + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol requires in the request to successfully upgrade. + /// These header fields will be added to the outbound request's "Connection" header field. + /// It is the responsibility of the custom headers call to actually add these required headers. + var requiredUpgradeHeaders: [String] { get } + + /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. + func addCustom(upgradeRequestHeaders: inout HTTPHeaders) + + /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. + func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel + /// pipeline to add whatever channel handlers are required. + /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. + func upgrade(channel: Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPClientUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOTypedHTTPClientUpgradeConfiguration { + /// The initial request head that is sent out once the channel becomes active. + public var upgradeRequestHead: HTTPRequestHead + + /// The array of potential upgraders. + public var upgraders: [any NIOTypedHTTPClientProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + + public init( + upgradeRequestHead: HTTPRequestHead, + upgraders: [any NIOTypedHTTPClientProtocolUpgrader], + notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture + ) { + precondition(upgraders.count > 0, "A minimum of one protocol upgrader must be specified.") + self.upgradeRequestHead = upgradeRequestHead + self.upgraders = upgraders + self.notUpgradingCompletionHandler = notUpgradingCompletionHandler + } +} + +/// A client-side channel handler that sends a HTTP upgrade handshake request to perform a HTTP-upgrade. +/// This handler will add all appropriate headers to perform an upgrade to +/// the a protocol. It may add headers for a set of protocols in preference order. +/// If the upgrade fails (i.e. response is not 101 Switching Protocols), this handler simply +/// removes itself from the pipeline. If the upgrade is successful, it upgrades the pipeline to the new protocol. +/// +/// The request sends an order of preference to request which protocol it would like to use for the upgrade. +/// It will only upgrade to the protocol that is returned first in the list and does not currently +/// have the capability to upgrade to multiple simultaneous layered protocols. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public final class NIOTypedHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableChannelHandler { + public typealias OutboundIn = HTTPClientRequestPart + public typealias OutboundOut = HTTPClientRequestPart + public typealias InboundIn = HTTPClientResponsePart + public typealias InboundOut = HTTPClientResponsePart + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: EventLoopFuture { + self.upgradeResultPromise.futureResult + } + + private let upgradeRequestHead: HTTPRequestHead + private let httpHandlers: [RemovableChannelHandler] + private let notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + private var stateMachine: NIOTypedHTTPClientUpgraderStateMachine + private var _upgradeResultPromise: EventLoopPromise? + private var upgradeResultPromise: EventLoopPromise { + precondition( + self._upgradeResultPromise != nil, + "Tried to access the upgrade result before the handler was added to a pipeline" + ) + return self._upgradeResultPromise! + } + + /// Create a ``NIOTypedHTTPClientUpgradeHandler``. + /// + /// - Parameters: + /// - httpHandlers: All `RemovableChannelHandler` objects which will be removed from the pipeline + /// once the upgrade response is sent. This is used to ensure that the pipeline will be in a clean state + /// after the upgrade. It should include any handlers that are directly related to handling HTTP. + /// At the very least this should include the `HTTPEncoder` and `HTTPDecoder`, but should also include + /// any other handler that cannot tolerate receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init( + httpHandlers: [RemovableChannelHandler], + upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration + ) { + self.httpHandlers = httpHandlers + var upgradeRequestHead = upgradeConfiguration.upgradeRequestHead + Self.addHeaders( + to: &upgradeRequestHead, + upgraders: upgradeConfiguration.upgraders + ) + self.upgradeRequestHead = upgradeRequestHead + self.stateMachine = .init(upgraders: upgradeConfiguration.upgraders) + self.notUpgradingCompletionHandler = upgradeConfiguration.notUpgradingCompletionHandler + } + + public func handlerAdded(context: ChannelHandlerContext) { + self._upgradeResultPromise = context.eventLoop.makePromise(of: UpgradeResult.self) + } + + public func handlerRemoved(context: ChannelHandlerContext) { + switch self.stateMachine.handlerRemoved() { + case .failUpgradePromise: + self.upgradeResultPromise.fail(ChannelError.inappropriateOperationForState) + case .none: + break + } + } + + public func channelActive(context: ChannelHandlerContext) { + switch self.stateMachine.channelActive() { + case .writeUpgradeRequest: + context.write(self.wrapOutboundOut(.head(self.upgradeRequestHead)), promise: nil) + context.write(self.wrapOutboundOut(.body(.byteBuffer(.init()))), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + + case .none: + break + } + } + + private static func addHeaders( + to requestHead: inout HTTPRequestHead, + upgraders: [any NIOTypedHTTPClientProtocolUpgrader] + ) { + let requiredHeaders = ["upgrade"] + upgraders.flatMap { $0.requiredUpgradeHeaders } + requestHead.headers.add(name: "Connection", value: requiredHeaders.joined(separator: ",")) + + let allProtocols = upgraders.map { $0.supportedProtocol.lowercased() } + requestHead.headers.add(name: "Upgrade", value: allProtocols.joined(separator: ",")) + + // Allow each upgrader the chance to add custom headers. + for upgrader in upgraders { + upgrader.addCustom(upgradeRequestHeaders: &requestHead.headers) + } + } + + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + switch self.stateMachine.write() { + case .failWrite(let error): + promise?.fail(error) + + case .forwardWrite: + context.write(data, promise: promise) + } + } + + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.stateMachine.channelReadData(data) { + case .unwrapData: + let responsePart = self.unwrapInboundIn(data) + self.channelRead(context: context, responsePart: responsePart) + + case .fireChannelRead: + context.fireChannelRead(data) + + case .none: + break + } + } + + private func channelRead(context: ChannelHandlerContext, responsePart: HTTPClientResponsePart) { + switch self.stateMachine.channelReadResponsePart(responsePart) { + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .runNotUpgradingInitializer: + self.notUpgradingCompletionHandler(context.channel) + .hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result) + } + + case .startUpgrading(let upgrader, let responseHead): + self.startUpgrading( + context: context, + upgrader: upgrader, + responseHead: responseHead + ) + + case .none: + break + } + } + + private func startUpgrading( + context: ChannelHandlerContext, + upgrader: any NIOTypedHTTPClientProtocolUpgrader, + responseHead: HTTPResponseHead + ) { + // Before we start the upgrade we have to remove the HTTPEncoder and HTTPDecoder handlers from the + // pipeline, to prevent them parsing any more data. We'll buffer the incoming data until that completes. + self.removeHTTPHandlers(context: context) + .flatMap { + upgrader.upgrade(channel: context.channel, upgradeResponse: responseHead) + }.hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result) + } + } + + private func upgradingHandlerCompleted( + context: ChannelHandlerContext, + _ result: Result + ) { + switch self.stateMachine.upgradingHandlerCompleted(result) { + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .fireErrorCaughtAndStartUnbuffering(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + self.unbuffer(context: context) + + case .startUnbuffering(let value): + self.upgradeResultPromise.succeed(value) + self.unbuffer(context: context) + + case .removeHandler(let value): + self.upgradeResultPromise.succeed(value) + context.pipeline.removeHandler(self, promise: nil) + + case .none: + break + } + } + + private func unbuffer(context: ChannelHandlerContext) { + while true { + switch self.stateMachine.unbuffer() { + case .fireChannelRead(let data): + context.fireChannelRead(data) + + case .fireChannelReadCompleteAndRemoveHandler: + context.fireChannelReadComplete() + context.pipeline.removeHandler(self, promise: nil) + return + } + } + } + + /// Removes any extra HTTP-related handlers from the channel pipeline. + private func removeHTTPHandlers(context: ChannelHandlerContext) -> EventLoopFuture { + guard self.httpHandlers.count > 0 else { + return context.eventLoop.makeSucceededFuture(()) + } + + let removeFutures = self.httpHandlers.map { context.pipeline.removeHandler($0) } + return .andAllSucceed(removeFutures, on: context.eventLoop) + } +} +#endif diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift new file mode 100644 index 0000000000..875fb2ce64 --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift @@ -0,0 +1,335 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +import DequeModule +import NIOCore + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +struct NIOTypedHTTPClientUpgraderStateMachine { + @usableFromInline + enum State { + /// The state before we received a TLSUserEvent. We are just forwarding any read at this point. + case initial(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) + + /// The request has been sent. We are waiting for the upgrade response. + case awaitingUpgradeResponseHead(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) + + @usableFromInline + struct AwaitingUpgradeResponseEnd { + var upgrader: any NIOTypedHTTPClientProtocolUpgrader + var responseHead: HTTPResponseHead + } + /// We received the response head and are just waiting for the response end. + case awaitingUpgradeResponseEnd(AwaitingUpgradeResponseEnd) + + @usableFromInline + struct Upgrading { + var buffer: Deque + } + /// We are either running the upgrading handler. + case upgrading(Upgrading) + + @usableFromInline + struct Unbuffering { + var buffer: Deque + } + case unbuffering(Unbuffering) + + case finished + + case modifying + } + + private var state: State + + init(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) { + self.state = .initial(upgraders: upgraders) + } + + @usableFromInline + enum HandlerRemovedAction { + case failUpgradePromise + } + + @inlinable + mutating func handlerRemoved() -> HandlerRemovedAction? { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading, .unbuffering: + self.state = .finished + return .failUpgradePromise + + case .finished: + return .none + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelActiveAction { + case writeUpgradeRequest + } + + @inlinable + mutating func channelActive() -> ChannelActiveAction? { + switch self.state { + case .initial(let upgraders): + self.state = .awaitingUpgradeResponseHead(upgraders: upgraders) + return .writeUpgradeRequest + + case .finished: + return nil + + case .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .unbuffering, .upgrading: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum WriteAction { + case failWrite(Error) + case forwardWrite + } + + @usableFromInline + func write() -> WriteAction { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading: + return .failWrite(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade) + + case .unbuffering, .finished: + return .forwardWrite + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelReadDataAction { + case unwrapData + case fireChannelRead + } + + @inlinable + mutating func channelReadData(_ data: NIOAny) -> ChannelReadDataAction? { + switch self.state { + case .initial: + return .unwrapData + + case .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd: + return .unwrapData + + case .upgrading(var upgrading): + // We got a read while running upgrading. + // We have to buffer the read to unbuffer it afterwards + self.state = .modifying + upgrading.buffer.append(data) + self.state = .upgrading(upgrading) + return nil + + case .unbuffering(var unbuffering): + self.state = .modifying + unbuffering.buffer.append(data) + self.state = .unbuffering(unbuffering) + return nil + + case .finished: + return .fireChannelRead + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + + @usableFromInline + enum ChannelReadResponsePartAction { + case fireErrorCaughtAndRemoveHandler(Error) + case runNotUpgradingInitializer + case startUpgrading( + upgrader: any NIOTypedHTTPClientProtocolUpgrader, + responseHeaders: HTTPResponseHead + ) + } + + @inlinable + mutating func channelReadResponsePart(_ responsePart: HTTPClientResponsePart) -> ChannelReadResponsePartAction? { + switch self.state { + case .initial: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + case .awaitingUpgradeResponseHead(let upgraders): + // We should decide if we can upgrade based on the first response header: if we aren't upgrading, + // by the time the body comes in we should be out of the pipeline. That means that if we don't think we're + // upgrading, the only thing we should see is a response head. Anything else in an error. + guard case .head(let response) = responsePart else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.invalidHTTPOrdering) + } + + // Assess whether the server has accepted our upgrade request. + guard case .switchingProtocols = response.status else { + var buffer = Deque() + buffer.append(.init(responsePart)) + self.state = .upgrading(.init(buffer: buffer)) + return .runNotUpgradingInitializer + } + + // Ok, we have a HTTP response. Check if it's an upgrade confirmation. + // If it's not, we want to pass it on and remove ourselves from the channel pipeline. + let acceptedProtocols = response.headers[canonicalForm: "upgrade"] + + // At the moment we only upgrade to the first protocol returned from the server. + guard let protocolName = acceptedProtocols.first?.lowercased() else { + // There are no upgrade protocols returned. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + + let matchingUpgrader = upgraders + .first(where: { $0.supportedProtocol.lowercased() == protocolName }) + + guard let upgrader = matchingUpgrader else { + // There is no upgrader for this protocol. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + + guard upgrader.shouldAllowUpgrade(upgradeResponse: response) else { + // The upgrader says no. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + + // We received the response head and decided that we can upgrade. + // We now need to wait for the response end and then we can perform the upgrade + self.state = .awaitingUpgradeResponseEnd(.init( + upgrader: upgrader, + responseHead: response + )) + return .none + + case .awaitingUpgradeResponseEnd(let awaitingUpgradeResponseEnd): + switch responsePart { + case .head: + // We got two HTTP response heads. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.invalidHTTPOrdering) + + case .body: + // We tolerate body parts to be send but just ignore them + return .none + + case .end: + // We got the response end and can now run the upgrader. + self.state = .upgrading(.init(buffer: .init())) + return .startUpgrading( + upgrader: awaitingUpgradeResponseEnd.upgrader, + responseHeaders: awaitingUpgradeResponseEnd.responseHead + ) + } + + case .upgrading, .unbuffering, .finished: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum UpgradingHandlerCompletedAction { + case fireErrorCaughtAndStartUnbuffering(Error) + case removeHandler(UpgradeResult) + case fireErrorCaughtAndRemoveHandler(Error) + case startUnbuffering(UpgradeResult) + } + + @inlinable + mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .unbuffering: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + case .upgrading(let upgrading): + switch result { + case .success(let value): + if !upgrading.buffer.isEmpty { + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .startUnbuffering(value) + } else { + self.state = .finished + return .removeHandler(value) + } + + case .failure(let error): + if !upgrading.buffer.isEmpty { + // So we failed to upgrade. There is nothing really that we can do here. + // We are unbuffering the reads but there shouldn't be any handler in the pipeline + // that expects a specific type of reads anyhow. + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .fireErrorCaughtAndStartUnbuffering(error) + } else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(error) + } + } + + case .finished: + // We have to tolerate this + return nil + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum UnbufferAction { + case fireChannelRead(NIOAny) + case fireChannelReadCompleteAndRemoveHandler + } + + @inlinable + mutating func unbuffer() -> UnbufferAction { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading, .finished: + preconditionFailure("Invalid state \(self.state)") + + case .unbuffering(var unbuffering): + self.state = .modifying + + if let element = unbuffering.buffer.popFirst() { + self.state = .unbuffering(unbuffering) + + return .fireChannelRead(element) + } else { + self.state = .finished + + return .fireChannelReadCompleteAndRemoveHandler + } + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + } + } +} +#endif diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift new file mode 100644 index 0000000000..1a1a47988c --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift @@ -0,0 +1,371 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +import NIOCore + +/// An object that implements `NIOTypedHTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a server-side channel. +public protocol NIOTypedHTTPServerProtocolUpgrader { + associatedtype UpgradeResult: Sendable + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol needs in the request to successfully upgrade. These header fields + /// will be provided to the handler when it is asked to handle the upgrade. They will also be validated + /// against the inbound request's `Connection` header field. + var requiredUpgradeHeaders: [String] { get } + + /// Builds the upgrade response headers. Should return any headers that need to be supplied to the client + /// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should + /// return a failed future. + func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline + /// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received + /// data will be buffered. + func upgrade( + channel: Channel, + upgradeRequest: HTTPRequestHead + ) -> EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPServerUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOTypedHTTPServerUpgradeConfiguration { + /// The array of potential upgraders. + public var upgraders: [any NIOTypedHTTPServerProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + + public init( + upgraders: [any NIOTypedHTTPServerProtocolUpgrader], + notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture + ) { + self.upgraders = upgraders + self.notUpgradingCompletionHandler = notUpgradingCompletionHandler + } +} + +/// A server-side channel handler that receives HTTP requests and optionally performs an HTTP-upgrade. +/// +/// Removes itself from the channel pipeline after the first inbound request on the connection, regardless of +/// whether the upgrade succeeded or not. +/// +/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade +/// on any request but the first on a connection. This is primarily to handle clients that pipeline: it's +/// sufficiently difficult to ensure that the upgrade happens at a safe time while dealing with pipelined +/// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine: +/// the odds of someone needing to upgrade midway through the lifetime of a connection are very low. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public final class NIOTypedHTTPServerUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler { + public typealias InboundIn = HTTPServerRequestPart + public typealias InboundOut = HTTPServerRequestPart + public typealias OutboundOut = HTTPServerResponsePart + + private let upgraders: [String: any NIOTypedHTTPServerProtocolUpgrader] + private let notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + private let httpEncoder: HTTPResponseEncoder + private let extraHTTPHandlers: [RemovableChannelHandler] + private var stateMachine = NIOTypedHTTPServerUpgraderStateMachine() + + private var _upgradeResultPromise: EventLoopPromise? + private var upgradeResultPromise: EventLoopPromise { + precondition( + self._upgradeResultPromise != nil, + "Tried to access the upgrade result before the handler was added to a pipeline" + ) + return self._upgradeResultPromise! + } + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: EventLoopFuture { + self.upgradeResultPromise.futureResult + } + + /// Create a ``NIOTypedHTTPServerUpgradeHandler``. + /// + /// - Parameters: + /// - httpEncoder: The ``HTTPResponseEncoder`` encoding responses from this handler and which will + /// be removed from the pipeline once the upgrade response is sent. This is used to ensure + /// that the pipeline will be in a clean state after upgrade. + /// - extraHTTPHandlers: Any other handlers that are directly related to handling HTTP. At the very least + /// this should include the `HTTPDecoder`, but should also include any other handler that cannot tolerate + /// receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init( + httpEncoder: HTTPResponseEncoder, + extraHTTPHandlers: [RemovableChannelHandler], + upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration + ) { + var upgraderMap = [String: any NIOTypedHTTPServerProtocolUpgrader]() + for upgrader in upgradeConfiguration.upgraders { + upgraderMap[upgrader.supportedProtocol.lowercased()] = upgrader + } + self.upgraders = upgraderMap + self.notUpgradingCompletionHandler = upgradeConfiguration.notUpgradingCompletionHandler + self.httpEncoder = httpEncoder + self.extraHTTPHandlers = extraHTTPHandlers + } + + public func handlerAdded(context: ChannelHandlerContext) { + self._upgradeResultPromise = context.eventLoop.makePromise(of: UpgradeResult.self) + } + + public func handlerRemoved(context: ChannelHandlerContext) { + switch self.stateMachine.handlerRemoved() { + case .failUpgradePromise: + self.upgradeResultPromise.fail(ChannelError.inappropriateOperationForState) + case .none: + break + } + } + + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.stateMachine.channelReadData(data) { + case .unwrapData: + let requestPart = self.unwrapInboundIn(data) + self.channelRead(context: context, requestPart: requestPart) + + case .fireChannelRead: + context.fireChannelRead(data) + + case .none: + break + } + } + + private func channelRead(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) { + switch self.stateMachine.channelReadRequestPart(requestPart) { + case .failUpgradePromise(let error): + self.upgradeResultPromise.fail(error) + + case .runNotUpgradingInitializer: + self.notUpgradingCompletionHandler(context.channel) + .hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: nil) + } + + case .findUpgrader(let head, let requestedProtocols, let allHeaderNames, let connectionHeader): + let protocolIterator = requestedProtocols.makeIterator() + self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: head, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ).whenComplete { result in + context.eventLoop.assertInEventLoop() + self.findingUpgradeCompleted(context: context, requestHead: head, result) + } + + case .startUpgrading(let upgrader, let requestHead, let responseHeaders, let proto): + self.startUpgrading( + context: context, + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto + ) + + case .none: + break + } + } + + private func upgradingHandlerCompleted( + context: ChannelHandlerContext, + _ result: Result, + requestHeadAndProtocol: (HTTPRequestHead, String)? + ) { + switch self.stateMachine.upgradingHandlerCompleted(result) { + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .fireErrorCaughtAndStartUnbuffering(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + self.unbuffer(context: context) + + case .startUnbuffering(let value): + if let requestHeadAndProtocol = requestHeadAndProtocol { + context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) + } + self.upgradeResultPromise.succeed(value) + self.unbuffer(context: context) + + case .removeHandler(let value): + if let requestHeadAndProtocol = requestHeadAndProtocol { + context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) + } + self.upgradeResultPromise.succeed(value) + context.pipeline.removeHandler(self, promise: nil) + + case .none: + break + } + } + + /// Attempt to upgrade a single protocol. + /// + /// Will recurse through `protocolIterator` if upgrade fails. + private func handleUpgradeForProtocol( + context: ChannelHandlerContext, + protocolIterator: Array.Iterator, + request: HTTPRequestHead, + allHeaderNames: Set, + connectionHeader: Set + ) -> EventLoopFuture<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?> { + // We want a local copy of the protocol iterator. We'll pass it to the next invocation of the function. + var protocolIterator = protocolIterator + guard let proto = protocolIterator.next() else { + // We're done! No suitable protocol for upgrade. + return context.eventLoop.makeSucceededFuture(nil) + } + + guard let upgrader = self.upgraders[proto.lowercased()] else { + return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + } + + let requiredHeaders = Set(upgrader.requiredUpgradeHeaders.map { $0.lowercased() }) + guard requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader) else { + return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + } + + let responseHeaders = self.buildUpgradeHeaders(protocol: proto) + return upgrader.buildUpgradeResponse( + channel: context.channel, + upgradeRequest: request, + initialResponseHeaders: responseHeaders + ) + .hop(to: context.eventLoop) + .map { (upgrader, $0, proto) } + .flatMapError { error in + // No upgrade here. We want to fire the error down the pipeline, and then try another loop iteration. + context.fireErrorCaught(error) + return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + } + } + + private func findingUpgradeCompleted( + context: ChannelHandlerContext, + requestHead: HTTPRequestHead, + _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> + ) { + switch self.stateMachine.findingUpgraderCompleted(requestHead: requestHead, result) { + case .startUpgrading(let upgrader, let responseHeaders, let proto): + self.startUpgrading( + context: context, + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto + ) + + case .runNotUpgradingInitializer: + self.notUpgradingCompletionHandler(context.channel) + .hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: nil) + } + + case .fireErrorCaughtAndStartUnbuffering(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + self.unbuffer(context: context) + + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .none: + break + } + } + + private func startUpgrading( + context: ChannelHandlerContext, + upgrader: any NIOTypedHTTPServerProtocolUpgrader, + requestHead: HTTPRequestHead, + responseHeaders: HTTPHeaders, + proto: String + ) { + // Before we finish the upgrade we have to remove the HTTPDecoder and any other non-Encoder HTTP + // handlers from the pipeline, to prevent them parsing any more data. We'll buffer the data until + // that completes. + // While there are a lot of Futures involved here it's quite possible that all of this code will + // actually complete synchronously: we just want to program for the possibility that it won't. + // Once that's done, we send the upgrade response, then remove the HTTP encoder, then call the + // internal handler, then call the user code, and then finally when the user code is done we do + // our final cleanup steps, namely we replay the received data we buffered in the meantime and + // then remove ourselves from the pipeline. + self.removeExtraHandlers(context: context).flatMap { + self.sendUpgradeResponse(context: context, responseHeaders: responseHeaders) + }.flatMap { + context.pipeline.removeHandler(self.httpEncoder) + }.flatMap { () -> EventLoopFuture in + return upgrader.upgrade(channel: context.channel, upgradeRequest: requestHead) + }.hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: (requestHead, proto)) + } + } + + /// Sends the 101 Switching Protocols response for the pipeline. + private func sendUpgradeResponse(context: ChannelHandlerContext, responseHeaders: HTTPHeaders) -> EventLoopFuture { + var response = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) + response.headers = responseHeaders + return context.writeAndFlush(wrapOutboundOut(HTTPServerResponsePart.head(response))) + } + + /// Builds the initial mandatory HTTP headers for HTTP upgrade responses. + private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders { + return HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)]) + } + + /// Removes any extra HTTP-related handlers from the channel pipeline. + private func removeExtraHandlers(context: ChannelHandlerContext) -> EventLoopFuture { + guard self.extraHTTPHandlers.count > 0 else { + return context.eventLoop.makeSucceededFuture(()) + } + + return .andAllSucceed(self.extraHTTPHandlers.map { context.pipeline.removeHandler($0) }, + on: context.eventLoop) + } + + private func unbuffer(context: ChannelHandlerContext) { + while true { + switch self.stateMachine.unbuffer() { + case .fireChannelRead(let data): + context.fireChannelRead(data) + + case .fireChannelReadCompleteAndRemoveHandler: + context.fireChannelReadComplete() + context.pipeline.removeHandler(self, promise: nil) + return + } + } + } +} +#endif diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift new file mode 100644 index 0000000000..c4fa19c348 --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift @@ -0,0 +1,386 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +import DequeModule +import NIOCore + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +struct NIOTypedHTTPServerUpgraderStateMachine { + @usableFromInline + enum State { + /// The state before we received a TLSUserEvent. We are just forwarding any read at this point. + case initial + + @usableFromInline + struct AwaitingUpgrader { + var seenFirstRequest: Bool + var buffer: Deque + } + + /// The request head has been received. We're currently running the future chain awaiting an upgrader. + case awaitingUpgrader(AwaitingUpgrader) + + @usableFromInline + struct UpgraderReady { + var upgrader: any NIOTypedHTTPServerProtocolUpgrader + var requestHead: HTTPRequestHead + var responseHeaders: HTTPHeaders + var proto: String + var buffer: Deque + } + + /// We have an upgrader, which means we can begin upgrade we are just waiting for the request end. + case upgraderReady(UpgraderReady) + + @usableFromInline + struct Upgrading { + var buffer: Deque + } + /// We are either running the upgrading handler. + case upgrading(Upgrading) + + @usableFromInline + struct Unbuffering { + var buffer: Deque + } + case unbuffering(Unbuffering) + + case finished + + case modifying + } + + private var state = State.initial + + @usableFromInline + enum HandlerRemovedAction { + case failUpgradePromise + } + + @inlinable + mutating func handlerRemoved() -> HandlerRemovedAction? { + switch self.state { + case .initial, .awaitingUpgrader, .upgraderReady, .upgrading, .unbuffering: + self.state = .finished + return .failUpgradePromise + + case .finished: + return .none + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelReadDataAction { + case unwrapData + case fireChannelRead + } + + @inlinable + mutating func channelReadData(_ data: NIOAny) -> ChannelReadDataAction? { + switch self.state { + case .initial: + return .unwrapData + + case .awaitingUpgrader(var awaitingUpgrader): + if awaitingUpgrader.seenFirstRequest { + // We should buffer the data since we have seen the full request. + self.state = .modifying + awaitingUpgrader.buffer.append(data) + self.state = .awaitingUpgrader(awaitingUpgrader) + return nil + } else { + // We shouldn't buffer. This means we are still expecting HTTP parts. + return .unwrapData + } + + case .upgraderReady: + // We have not seen the end of the HTTP request so this + // data is probably an HTTP request part. + return .unwrapData + + case .unbuffering(var unbuffering): + self.state = .modifying + unbuffering.buffer.append(data) + self.state = .unbuffering(unbuffering) + return nil + + case .finished: + return .fireChannelRead + + case .upgrading(var upgrading): + // We got a read while running ugprading. + // We have to buffer the read to unbuffer it afterwards + self.state = .modifying + upgrading.buffer.append(data) + self.state = .upgrading(upgrading) + return nil + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelReadRequestPartAction { + case failUpgradePromise(Error) + case runNotUpgradingInitializer + case startUpgrading( + upgrader: any NIOTypedHTTPServerProtocolUpgrader, + requestHead: HTTPRequestHead, + responseHeaders: HTTPHeaders, + proto: String + ) + case findUpgrader( + head: HTTPRequestHead, + requestedProtocols: [String], + allHeaderNames: Set, + connectionHeader: Set + ) + } + + @inlinable + mutating func channelReadRequestPart(_ requestPart: HTTPServerRequestPart) -> ChannelReadRequestPartAction? { + switch self.state { + case .initial: + guard case .head(let head) = requestPart else { + // The first data that we saw was not a head. This is a protocol error and we are just going to + // fail upgrading + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + } + + // Ok, we have a HTTP head. Check if it's an upgrade. + let requestedProtocols = head.headers[canonicalForm: "upgrade"].map(String.init) + guard requestedProtocols.count > 0 else { + // We have to buffer now since we got the request head but are not upgrading. + // The user is configuring the HTTP pipeline now. + var buffer = Deque() + buffer.append(NIOAny(requestPart)) + self.state = .upgrading(.init(buffer: buffer)) + return .runNotUpgradingInitializer + } + + // We can now transition to awaiting the upgrader. This means that we are trying to + // find an upgrade that can handle requested protocols. We are not buffering because + // we are waiting for the request end. + self.state = .awaitingUpgrader(.init(seenFirstRequest: false, buffer: .init())) + + let connectionHeader = Set(head.headers[canonicalForm: "connection"].map { $0.lowercased() }) + let allHeaderNames = Set(head.headers.map { $0.name.lowercased() }) + + return .findUpgrader( + head: head, + requestedProtocols: requestedProtocols, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) + + case .awaitingUpgrader(let awaitingUpgrader): + switch (awaitingUpgrader.seenFirstRequest, requestPart) { + case (true, _): + // This is weird we are seeing more requests parts after we have seen an end + // Let's fail upgrading + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case (false, .head): + // This is weird we are seeing another head but haven't seen the end for the request before + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case (false, .body): + // This is weird we are seeing body parts for a request that indicated that it wanted + // to upgrade. + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case (false, .end): + // Okay we got the end as expected. Just gotta store this in our state. + self.state = .awaitingUpgrader(.init(seenFirstRequest: true, buffer: awaitingUpgrader.buffer)) + return nil + } + + case .upgraderReady(let upgraderReady): + switch requestPart { + case .head: + // This is weird we are seeing another head but haven't seen the end for the request before + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case .body: + // This is weird we are seeing body parts for a request that indicated that it wanted + // to upgrade. + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case .end: + // Okay we got the end as expected and our upgrader is ready so let's start upgrading + self.state = .upgrading(.init(buffer: upgraderReady.buffer)) + return .startUpgrading( + upgrader: upgraderReady.upgrader, + requestHead: upgraderReady.requestHead, + responseHeaders: upgraderReady.responseHeaders, + proto: upgraderReady.proto + ) + } + + case .upgrading, .unbuffering, .finished: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum UpgradingHandlerCompletedAction { + case fireErrorCaughtAndStartUnbuffering(Error) + case removeHandler(UpgradeResult) + case fireErrorCaughtAndRemoveHandler(Error) + case startUnbuffering(UpgradeResult) + } + + @inlinable + mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { + switch self.state { + case .initial: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .upgrading(let upgrading): + switch result { + case .success(let value): + if !upgrading.buffer.isEmpty { + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .startUnbuffering(value) + } else { + self.state = .finished + return .removeHandler(value) + } + + case .failure(let error): + if !upgrading.buffer.isEmpty { + // So we failed to upgrade. There is nothing really that we can do here. + // We are unbuffering the reads but there shouldn't be any handler in the pipeline + // that expects a specific type of reads anyhow. + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .fireErrorCaughtAndStartUnbuffering(error) + } else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(error) + } + } + + case .finished: + // We have to tolerate this + return nil + + case .awaitingUpgrader, .upgraderReady, .unbuffering: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum FindingUpgraderCompletedAction { + case startUpgrading(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String) + case runNotUpgradingInitializer + case fireErrorCaughtAndStartUnbuffering(Error) + case fireErrorCaughtAndRemoveHandler(Error) + } + + @inlinable + mutating func findingUpgraderCompleted( + requestHead: HTTPRequestHead, + _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> + ) -> FindingUpgraderCompletedAction? { + switch self.state { + case .initial, .upgraderReady: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .awaitingUpgrader(let awaitingUpgrader): + switch result { + case .success(.some((let upgrader, let responseHeaders, let proto))): + if awaitingUpgrader.seenFirstRequest { + // We have seen the end of the request. So we can upgrade now. + self.state = .upgrading(.init(buffer: awaitingUpgrader.buffer)) + return .startUpgrading(upgrader: upgrader, responseHeaders: responseHeaders, proto: proto) + } else { + // We have not yet seen the end so we have to wait until that happens + self.state = .upgraderReady(.init( + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto, + buffer: awaitingUpgrader.buffer + )) + return nil + } + + case .success(.none): + // There was no upgrader to handle the request. We just run the not upgrading + // initializer now. + self.state = .upgrading(.init(buffer: awaitingUpgrader.buffer)) + return .runNotUpgradingInitializer + + case .failure(let error): + if !awaitingUpgrader.buffer.isEmpty { + self.state = .unbuffering(.init(buffer: awaitingUpgrader.buffer)) + return .fireErrorCaughtAndStartUnbuffering(error) + } else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(error) + } + } + + case .upgrading, .unbuffering, .finished: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum UnbufferAction { + case fireChannelRead(NIOAny) + case fireChannelReadCompleteAndRemoveHandler + } + + @inlinable + mutating func unbuffer() -> UnbufferAction { + switch self.state { + case .initial, .awaitingUpgrader, .upgraderReady, .upgrading, .finished: + preconditionFailure("Invalid state \(self.state)") + + case .unbuffering(var unbuffering): + self.state = .modifying + + if let element = unbuffering.buffer.popFirst() { + self.state = .unbuffering(unbuffering) + + return .fireChannelRead(element) + } else { + self.state = .finished + + return .fireChannelReadCompleteAndRemoveHandler + } + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + } + } + +} +#endif diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index 6483954bde..d1b190c288 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -74,6 +74,63 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { } } +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +/// A `NIOTypedHTTPClientProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request. +/// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the +/// pipeline to remove the HTTP `ChannelHandler`s. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public final class NIOTypedWebSocketClientUpgrader: NIOTypedHTTPClientProtocolUpgrader { + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String = "websocket" + /// None of the websocket headers are actually defined as 'required'. + public let requiredUpgradeHeaders: [String] = [] + + private let requestKey: String + private let maxFrameSize: Int + private let enableAutomaticErrorHandling: Bool + private let upgradePipelineHandler: @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture + + /// - Parameters: + /// - requestKey: Sent to the server in the `Sec-WebSocket-Key` HTTP header. Default is random request key. + /// - maxFrameSize: Largest incoming `WebSocketFrame` size in bytes. Default is 16,384 bytes. + /// - enableAutomaticErrorHandling: If true, adds `WebSocketProtocolErrorHandler` to the channel pipeline to catch and respond to WebSocket protocol errors. Default is true. + /// - upgradePipelineHandler: Called once the upgrade was successful. + public init( + requestKey: String = NIOWebSocketClientUpgrader.randomRequestKey(), + maxFrameSize: Int = 1 << 14, + enableAutomaticErrorHandling: Bool = true, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture + ) { + precondition(requestKey != "", "The request key must contain a valid Sec-WebSocket-Key") + precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") + self.requestKey = requestKey + self.upgradePipelineHandler = upgradePipelineHandler + self.maxFrameSize = maxFrameSize + self.enableAutomaticErrorHandling = enableAutomaticErrorHandling + } + + public func addCustom(upgradeRequestHeaders: inout NIOHTTP1.HTTPHeaders) { + _addCustom(upgradeRequestHeaders: &upgradeRequestHeaders, requestKey: self.requestKey) + } + + public func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { + _shouldAllowUpgrade(upgradeResponse: upgradeResponse, requestKey: self.requestKey) + } + + public func upgrade(channel: Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + _upgrade( + channel: channel, + upgradeResponse: upgradeResponse, + maxFrameSize: self.maxFrameSize, + enableAutomaticErrorHandling: self.enableAutomaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) + } +} +#endif + @available(*, unavailable) extension NIOWebSocketClientUpgrader: Sendable {} diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 44b9f56731..14f29f750b 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -175,6 +175,92 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch } } +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +/// A `NIOTypedHTTPServerProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// Users may frequently want to offer multiple websocket endpoints on the same port. For this +/// reason, this `WebServerSocketUpgrader` only knows how to do the required parts of the upgrade and to +/// complete the handshake. Users are expected to provide a callback that examines the HTTP headers +/// (including the path) and determines whether this is a websocket upgrade request that is acceptable +/// to them. +/// +/// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to +/// remove the HTTP `ChannelHandler`s. +public final class NIOTypedWebSocketServerUpgrader: NIOTypedHTTPServerProtocolUpgrader, Sendable { + private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String = "websocket" + + /// We deliberately do not actually set any required headers here, because the websocket + /// spec annoyingly does not actually force the client to send these in the Upgrade header, + /// which NIO requires. We check for these manually. + public let requiredUpgradeHeaders: [String] = [] + + private let shouldUpgrade: ShouldUpgrade + private let upgradePipelineHandler: UpgradePipelineHandler + private let maxFrameSize: Int + private let enableAutomaticErrorHandling: Bool + + /// Create a new ``NIOTypedWebSocketServerUpgrader``. + /// + /// - Parameters: + /// - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the + /// remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but + /// this is an objectively unreasonable maximum value (on AMD64 systems it is not + /// possible to even. Users may set this to any value up to `UInt32.max`. + /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol + /// errors by sending error responses and closing the connection. Defaults to `true`, + /// may be set to `false` if the user wishes to handle their own errors. + /// - shouldUpgrade: A callback that determines whether the websocket request should be + /// upgraded. This callback is responsible for creating a `HTTPHeaders` object with + /// any headers that it needs on the response *except for* the `Upgrade`, `Connection`, + /// and `Sec-WebSocket-Accept` headers, which this upgrader will handle. Should return + /// an `EventLoopFuture` containing `nil` if the upgrade should be refused. + /// - enableAutomaticErrorHandling: A function that will be called once the upgrade response is + /// flushed, and that is expected to mutate the `Channel` appropriately to handle the + /// websocket protocol. This only needs to add the user handlers: the + /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the + /// pipeline automatically. + public init( + maxFrameSize: Int = 1 << 14, + enableAutomaticErrorHandling: Bool = true, + shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + ) { + precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") + self.shouldUpgrade = shouldUpgrade + self.upgradePipelineHandler = upgradePipelineHandler + self.maxFrameSize = maxFrameSize + self.enableAutomaticErrorHandling = enableAutomaticErrorHandling + } + + public func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { + _buildUpgradeResponse( + channel: channel, + upgradeRequest: upgradeRequest, + initialResponseHeaders: initialResponseHeaders, + shouldUpgrade: self.shouldUpgrade + ) + } + + public func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + _upgrade( + channel: channel, + upgradeRequest: upgradeRequest, + maxFrameSize: self.maxFrameSize, + automaticErrorHandling: self.enableAutomaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) + } +} +#endif + private func _buildUpgradeResponse( channel: Channel, upgradeRequest: HTTPRequestHead, diff --git a/Sources/NIOWebSocketClient/Client.swift b/Sources/NIOWebSocketClient/Client.swift index 8ad4db2d64..5efa89993a 100644 --- a/Sources/NIOWebSocketClient/Client.swift +++ b/Sources/NIOWebSocketClient/Client.swift @@ -11,137 +11,130 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if swift(>=5.9) +#if (!canImport(Darwin) && swift(>=5.9)) || (canImport(Darwin) && swift(>=5.10)) +import NIOCore +import NIOPosix +import NIOHTTP1 +import NIOWebSocket + +@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) @main struct Client { - static func main() { - fatalError("Disabled due to https://github.com/apple/swift-nio/issues/2574") + /// The host to connect to. + private let host: String + /// The port to connect to. + private let port: Int + /// The client's event loop group. + private let eventLoopGroup: MultiThreadedEventLoopGroup + + enum UpgradeResult { + case websocket(NIOAsyncChannel) + case notUpgraded } -} -// Commented out due https://github.com/apple/swift-nio/issues/2574 + static func main() async throws { + let client = Client( + host: "localhost", + port: 8888, + eventLoopGroup: .singleton + ) + try await client.run() + } -//import NIOCore -//import NIOPosix -//import NIOHTTP1 -//import NIOWebSocket -// -//@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) -//@main -//struct Client { -// /// The host to connect to. -// private let host: String -// /// The port to connect to. -// private let port: Int -// /// The client's event loop group. -// private let eventLoopGroup: MultiThreadedEventLoopGroup -// -// enum UpgradeResult { -// case websocket(NIOAsyncChannel) -// case notUpgraded -// } -// -// static func main() async throws { -// let client = Client( -// host: "localhost", -// port: 8888, -// eventLoopGroup: .singleton -// ) -// try await client.run() -// } -// -// /// This method starts the client and tries to setup a WebSocket connection. -// func run() async throws { -// let upgradeResult: EventLoopFuture = try await ClientBootstrap(group: self.eventLoopGroup) -// .connect( -// host: self.host, -// port: self.port -// ) { channel in -// channel.eventLoop.makeCompletedFuture { -// let upgrader = NIOTypedWebSocketClientUpgrader( -// upgradePipelineHandler: { (channel, _) in -// channel.eventLoop.makeCompletedFuture { -// let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) -// return UpgradeResult.websocket(asyncChannel) -// } -// } -// ) -// -// var headers = HTTPHeaders() -// headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") -// headers.add(name: "Content-Length", value: "0") -// -// let requestHead = HTTPRequestHead( -// version: .http1_1, -// method: .GET, -// uri: "/", -// headers: headers -// ) -// -// let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( -// upgradeRequestHead: requestHead, -// upgraders: [upgrader], -// notUpgradingCompletionHandler: { channel in -// channel.eventLoop.makeCompletedFuture { -// return UpgradeResult.notUpgraded -// } -// } -// ) -// -// let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( -// configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) -// ) -// -// return negotiationResultFuture -// } -// } -// -// // We are awaiting and handling the upgrade result now. -// try await self.handleUpgradeResult(upgradeResult) -// } -// -// /// This method handles the upgrade result. -// private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async throws { -// switch try await upgradeResult.get() { -// case .websocket(let websocketChannel): -// print("Handling websocket connection") -// try await self.handleWebsocketChannel(websocketChannel) -// print("Done handling websocket connection") -// case .notUpgraded: -// // The upgrade to websocket did not succeed. We are just exiting in this case. -// print("Upgrade declined") -// } -// } -// -// private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { -// // We are sending a ping frame and then -// // start to handle all inbound frames. -// -// let pingFrame = WebSocketFrame(fin: true, opcode: .ping, data: ByteBuffer(string: "Hello!")) -// try await channel.outbound.write(pingFrame) -// -// for try await frame in channel.inbound { -// switch frame.opcode { -// case .pong: -// print("Received pong: \(String(buffer: frame.data))") -// -// case .text: -// print("Received: \(String(buffer: frame.data))") -// -// case .connectionClose: -// // Handle a received close frame. We're just going to close by returning from this method. -// print("Received Close instruction from server") -// return -// case .binary, .continuation, .ping: -// // We ignore these frames. -// break -// default: -// // Unknown frames are errors. -// return -// } -// } -// } -//} + /// This method starts the client and tries to setup a WebSocket connection. + func run() async throws { + let upgradeResult: EventLoopFuture = try await ClientBootstrap(group: self.eventLoopGroup) + .connect( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + let upgrader = NIOTypedWebSocketClientUpgrader( + upgradePipelineHandler: { (channel, _) in + channel.eventLoop.makeCompletedFuture { + let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) + return UpgradeResult.websocket(asyncChannel) + } + } + ) + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "0") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: [upgrader], + notUpgradingCompletionHandler: { channel in + channel.eventLoop.makeCompletedFuture { + return UpgradeResult.notUpgraded + } + } + ) + + let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) + ) + + return negotiationResultFuture + } + } + + // We are awaiting and handling the upgrade result now. + try await self.handleUpgradeResult(upgradeResult) + } + + /// This method handles the upgrade result. + private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async throws { + switch try await upgradeResult.get() { + case .websocket(let websocketChannel): + print("Handling websocket connection") + try await self.handleWebsocketChannel(websocketChannel) + print("Done handling websocket connection") + case .notUpgraded: + // The upgrade to websocket did not succeed. We are just exiting in this case. + print("Upgrade declined") + } + } + + private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { + // We are sending a ping frame and then + // start to handle all inbound frames. + + let pingFrame = WebSocketFrame(fin: true, opcode: .ping, data: ByteBuffer(string: "Hello!")) + try await channel.executeThenClose { inbound, outbound in + try await outbound.write(pingFrame) + + for try await frame in inbound { + switch frame.opcode { + case .pong: + print("Received pong: \(String(buffer: frame.data))") + + case .text: + print("Received: \(String(buffer: frame.data))") + + case .connectionClose: + // Handle a received close frame. We're just going to close by returning from this method. + print("Received Close instruction from server") + return + case .binary, .continuation, .ping: + // We ignore these frames. + break + default: + // Unknown frames are errors. + return + } + } + } + } +} #else @main diff --git a/Sources/NIOWebSocketServer/Server.swift b/Sources/NIOWebSocketServer/Server.swift index 7cdc84ff64..9ef311c57a 100644 --- a/Sources/NIOWebSocketServer/Server.swift +++ b/Sources/NIOWebSocketServer/Server.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if swift(>=5.9) +#if (!canImport(Darwin) && swift(>=5.9)) || (canImport(Darwin) && swift(>=5.10)) import NIOCore import NIOPosix import NIOHTTP1 @@ -41,247 +41,244 @@ let websocketResponse = """ """ +@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) @main struct Server { - static func main() { - fatalError("Disabled due to https://github.com/apple/swift-nio/issues/2574") + /// The server's host. + private let host: String + /// The server's port. + private let port: Int + /// The server's event loop group. + private let eventLoopGroup: MultiThreadedEventLoopGroup + + private static let responseBody = ByteBuffer(string: websocketResponse) + + enum UpgradeResult { + case websocket(NIOAsyncChannel) + case notUpgraded(NIOAsyncChannel>) + } + + static func main() async throws { + let server = Server( + host: "localhost", + port: 8888, + eventLoopGroup: .singleton + ) + try await server.run() + } + + /// This method starts the server and handles incoming connections. + func run() async throws { + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: self.eventLoopGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .bind( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + let upgrader = NIOTypedWebSocketServerUpgrader( + shouldUpgrade: { (channel, head) in + channel.eventLoop.makeSucceededFuture(HTTPHeaders()) + }, + upgradePipelineHandler: { (channel, _) in + channel.eventLoop.makeCompletedFuture { + let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) + return UpgradeResult.websocket(asyncChannel) + } + } + ) + + let serverUpgradeConfiguration = NIOTypedHTTPServerUpgradeConfiguration( + upgraders: [upgrader], + notUpgradingCompletionHandler: { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponsePartHandler()) + let asyncChannel = try NIOAsyncChannel>(wrappingChannelSynchronously: channel) + return UpgradeResult.notUpgraded(asyncChannel) + } + } + ) + + let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: .init(upgradeConfiguration: serverUpgradeConfiguration) + ) + + return negotiationResultFuture + } + } + + // We are handling each incoming connection in a separate child task. It is important + // to use a discarding task group here which automatically discards finished child tasks. + // A normal task group retains all child tasks and their outputs in memory until they are + // consumed by iterating the group or by exiting the group. Since, we are never consuming + // the results of the group we need the group to automatically discard them; otherwise, this + // would result in a memory leak over time. + try await withThrowingDiscardingTaskGroup { group in + try await channel.executeThenClose { inbound in + for try await upgradeResult in inbound { + group.addTask { + await self.handleUpgradeResult(upgradeResult) + } + } + } + } + } + + /// This method handles a single connection by echoing back all inbound data. + private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async { + // Note that this method is non-throwing and we are catching any error. + // We do this since we don't want to tear down the whole server when a single connection + // encounters an error. + do { + switch try await upgradeResult.get() { + case .websocket(let websocketChannel): + print("Handling websocket connection") + try await self.handleWebsocketChannel(websocketChannel) + print("Done handling websocket connection") + case .notUpgraded(let httpChannel): + print("Handling HTTP connection") + try await self.handleHTTPChannel(httpChannel) + print("Done handling HTTP connection") + } + } catch { + print("Hit error: \(error)") + } + } + + private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { + try await channel.executeThenClose { inbound, outbound in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + for try await frame in inbound { + switch frame.opcode { + case .ping: + print("Received ping") + var frameData = frame.data + let maskingKey = frame.maskKey + + if let maskingKey = maskingKey { + frameData.webSocketUnmask(maskingKey) + } + + let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) + try await outbound.write(responseFrame) + + case .connectionClose: + // This is an unsolicited close. We're going to send a response frame and + // then, when we've sent it, close up shop. We should send back the close code the remote + // peer sent us, unless they didn't send one at all. + print("Received close") + var data = frame.unmaskedData + let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() + let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) + try await outbound.write(closeFrame) + return + case .binary, .continuation, .pong: + // We ignore these frames. + break + default: + // Unknown frames are errors. + return + } + } + } + + group.addTask { + // This is our main business logic where we are just sending the current time + // every second. + while true { + // We can't really check for error here, but it's also not the purpose of the + // example so let's not worry about it. + let theTime = ContinuousClock().now + var buffer = channel.channel.allocator.buffer(capacity: 12) + buffer.writeString("\(theTime)") + + let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) + + print("Sending time") + try await outbound.write(frame) + try await Task.sleep(for: .seconds(1)) + } + } + + try await group.next() + group.cancelAll() + } + } + } + + + private func handleHTTPChannel(_ channel: NIOAsyncChannel>) async throws { + try await channel.executeThenClose { inbound, outbound in + for try await requestPart in inbound { + // We're not interested in request bodies here: we're just serving up GET responses + // to get the client to initiate a websocket request. + guard case .head(let head) = requestPart else { + return + } + + // GETs only. + guard case .GET = head.method else { + try await self.respond405(writer: outbound) + return + } + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/html") + headers.add(name: "Content-Length", value: String(Self.responseBody.readableBytes)) + headers.add(name: "Connection", value: "close") + let responseHead = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .ok, + headers: headers + ) + + try await outbound.write( + contentsOf: [ + .head(responseHead), + .body(Self.responseBody), + .end(nil) + ] + ) + } + } + } + + private func respond405(writer: NIOAsyncChannelOutboundWriter>) async throws { + var headers = HTTPHeaders() + headers.add(name: "Connection", value: "close") + headers.add(name: "Content-Length", value: "0") + let head = HTTPResponseHead( + version: .http1_1, + status: .methodNotAllowed, + headers: headers + ) + + try await writer.write( + contentsOf: [ + .head(head), + .end(nil) + ] + ) } } -// Commented out due https://github.com/apple/swift-nio/issues/2574 +final class HTTPByteBufferResponsePartHandler: ChannelOutboundHandler { + typealias OutboundIn = HTTPPart + typealias OutboundOut = HTTPServerResponsePart -//@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) -//@main -//struct Server { -// /// The server's host. -// private let host: String -// /// The server's port. -// private let port: Int -// /// The server's event loop group. -// private let eventLoopGroup: MultiThreadedEventLoopGroup -// -// private static let responseBody = ByteBuffer(string: websocketResponse) -// -// enum UpgradeResult { -// case websocket(NIOAsyncChannel) -// case notUpgraded(NIOAsyncChannel>) -// } -// -// static func main() async throws { -// let server = Server( -// host: "localhost", -// port: 8888, -// eventLoopGroup: .singleton -// ) -// try await server.run() -// } -// -// /// This method starts the server and handles incoming connections. -// func run() async throws { -// let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: self.eventLoopGroup) -// .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) -// .bind( -// host: self.host, -// port: self.port -// ) { channel in -// channel.eventLoop.makeCompletedFuture { -// let upgrader = NIOTypedWebSocketServerUpgrader( -// shouldUpgrade: { (channel, head) in -// channel.eventLoop.makeSucceededFuture(HTTPHeaders()) -// }, -// upgradePipelineHandler: { (channel, _) in -// channel.eventLoop.makeCompletedFuture { -// let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) -// return UpgradeResult.websocket(asyncChannel) -// } -// } -// ) -// -// let serverUpgradeConfiguration = NIOTypedHTTPServerUpgradeConfiguration( -// upgraders: [upgrader], -// notUpgradingCompletionHandler: { channel in -// channel.eventLoop.makeCompletedFuture { -// try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponsePartHandler()) -// let asyncChannel = try NIOAsyncChannel>(wrappingChannelSynchronously: channel) -// return UpgradeResult.notUpgraded(asyncChannel) -// } -// } -// ) -// -// let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( -// configuration: .init(upgradeConfiguration: serverUpgradeConfiguration) -// ) -// -// return negotiationResultFuture -// } -// } -// -// // We are handling each incoming connection in a separate child task. It is important -// // to use a discarding task group here which automatically discards finished child tasks. -// // A normal task group retains all child tasks and their outputs in memory until they are -// // consumed by iterating the group or by exiting the group. Since, we are never consuming -// // the results of the group we need the group to automatically discard them; otherwise, this -// // would result in a memory leak over time. -// try await withThrowingDiscardingTaskGroup { group in -// for try await upgradeResult in channel.inbound { -// group.addTask { -// await self.handleUpgradeResult(upgradeResult) -// } -// } -// } -// } -// -// /// This method handles a single connection by echoing back all inbound data. -// private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async { -// // Note that this method is non-throwing and we are catching any error. -// // We do this since we don't want to tear down the whole server when a single connection -// // encounters an error. -// do { -// switch try await upgradeResult.get() { -// case .websocket(let websocketChannel): -// print("Handling websocket connection") -// try await self.handleWebsocketChannel(websocketChannel) -// print("Done handling websocket connection") -// case .notUpgraded(let httpChannel): -// print("Handling HTTP connection") -// try await self.handleHTTPChannel(httpChannel) -// print("Done handling HTTP connection") -// } -// } catch { -// print("Hit error: \(error)") -// } -// } -// -// private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { -// try await withThrowingTaskGroup(of: Void.self) { group in -// group.addTask { -// for try await frame in channel.inbound { -// switch frame.opcode { -// case .ping: -// print("Received ping") -// var frameData = frame.data -// let maskingKey = frame.maskKey -// -// if let maskingKey = maskingKey { -// frameData.webSocketUnmask(maskingKey) -// } -// -// let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) -// try await channel.outbound.write(responseFrame) -// -// case .connectionClose: -// // This is an unsolicited close. We're going to send a response frame and -// // then, when we've sent it, close up shop. We should send back the close code the remote -// // peer sent us, unless they didn't send one at all. -// print("Received close") -// var data = frame.unmaskedData -// let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() -// let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) -// try await channel.outbound.write(closeFrame) -// return -// case .binary, .continuation, .pong: -// // We ignore these frames. -// break -// default: -// // Unknown frames are errors. -// return -// } -// } -// } -// -// group.addTask { -// // This is our main business logic where we are just sending the current time -// // every second. -// while true { -// // We can't really check for error here, but it's also not the purpose of the -// // example so let's not worry about it. -// let theTime = ContinuousClock().now -// var buffer = channel.channel.allocator.buffer(capacity: 12) -// buffer.writeString("\(theTime)") -// -// let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) -// -// print("Sending time") -// try await channel.outbound.write(frame) -// try await Task.sleep(for: .seconds(1)) -// } -// } -// -// try await group.next() -// group.cancelAll() -// } -// } -// -// -// private func handleHTTPChannel(_ channel: NIOAsyncChannel>) async throws { -// for try await requestPart in channel.inbound { -// // We're not interested in request bodies here: we're just serving up GET responses -// // to get the client to initiate a websocket request. -// guard case .head(let head) = requestPart else { -// return -// } -// -// // GETs only. -// guard case .GET = head.method else { -// try await self.respond405(writer: channel.outbound) -// return -// } -// -// var headers = HTTPHeaders() -// headers.add(name: "Content-Type", value: "text/html") -// headers.add(name: "Content-Length", value: String(Self.responseBody.readableBytes)) -// headers.add(name: "Connection", value: "close") -// let responseHead = HTTPResponseHead( -// version: .init(major: 1, minor: 1), -// status: .ok, -// headers: headers -// ) -// -// try await channel.outbound.write( -// contentsOf: [ -// .head(responseHead), -// .body(Self.responseBody), -// .end(nil) -// ] -// ) -// } -// } -// -// private func respond405(writer: NIOAsyncChannelOutboundWriter>) async throws { -// var headers = HTTPHeaders() -// headers.add(name: "Connection", value: "close") -// headers.add(name: "Content-Length", value: "0") -// let head = HTTPResponseHead( -// version: .http1_1, -// status: .methodNotAllowed, -// headers: headers -// ) -// -// try await writer.write( -// contentsOf: [ -// .head(head), -// .end(nil) -// ] -// ) -// } -//} -// -//final class HTTPByteBufferResponsePartHandler: ChannelOutboundHandler { -// typealias OutboundIn = HTTPPart -// typealias OutboundOut = HTTPServerResponsePart -// -// func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { -// let part = self.unwrapOutboundIn(data) -// switch part { -// case .head(let head): -// context.write(self.wrapOutboundOut(.head(head)), promise: promise) -// case .body(let buffer): -// context.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise) -// case .end(let trailers): -// context.write(self.wrapOutboundOut(.end(trailers)), promise: promise) -// } -// } -//} + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let part = self.unwrapOutboundIn(data) + switch part { + case .head(let head): + context.write(self.wrapOutboundOut(.head(head)), promise: promise) + case .body(let buffer): + context.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise) + case .end(let trailers): + context.write(self.wrapOutboundOut(.end(trailers)), promise: promise) + } + } +} #else @main diff --git a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift index 7bdd4c3622..195338f9ef 100644 --- a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift @@ -32,8 +32,13 @@ extension EmbeddedChannel { } } +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader, NIOTypedHTTPClientProtocolUpgrader where UpgradeResult == Bool {} +#else @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader {} +#endif private final class SuccessfulClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { fileprivate let supportedProtocol: String @@ -282,8 +287,9 @@ private final class RecordingHTTPHandler: ChannelInboundHandler, RemovableChanne @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) private func assertPipelineContainsUpgradeHandler(channel: Channel) { let handler = try? channel.pipeline.syncOperations.handler(type: NIOHTTPClientUpgradeHandler.self) + let typedHandler = try? channel.pipeline.syncOperations.handler(type: NIOTypedHTTPClientUpgradeHandler.self) - XCTAssertTrue(handler != nil) + XCTAssertTrue(handler != nil || typedHandler != nil) } @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) @@ -946,3 +952,235 @@ class HTTPClientUpgradeTestCase: XCTestCase { } } } + +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { + override func setUpClientChannel( + clientHTTPHandler: RemovableChannelHandler, + clientUpgraders: [any TypedAndUntypedHTTPClientProtocolUpgrader], + _ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void + ) throws -> EmbeddedChannel { + + let channel = EmbeddedChannel() + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "\(0)") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let upgraders: [any NIOTypedHTTPClientProtocolUpgrader] = Array(clientUpgraders.map { $0 as! any NIOTypedHTTPClientProtocolUpgrader }) + + let config = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: upgraders + ) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(clientHTTPHandler) + }.map { _ in + false + } + } + var configuration = NIOUpgradableHTTPClientPipelineConfiguration(upgradeConfiguration: config) + configuration.leftOverBytesStrategy = .forwardBytes + let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: configuration) + let context = try channel.pipeline.syncOperations.context(handlerType: NIOTypedHTTPClientUpgradeHandler.self) + + try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + .wait() + upgradeResult.whenSuccess { result in + if result { + upgradeCompletionHandler(context) + } + } + + return channel + } + + // - MARK: The following tests are all overridden from the base class since they slightly differ in behaviour + + override func testUpgradeOnlyHandlesKnownProtocols() throws { + var upgradeHandlerCallbackFired = false + + let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") + let clientHandler = RecordingHTTPHandler() + + // The process should kick-off independently by sending the upgrade request to the server. + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: unknownProtocol\r\n\r\n" + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) + } + + clientChannel.embeddedEventLoop.run() + + // Should fail with error (response is malformed) and remove upgrader from pipeline. + + // Check that the http elements are not removed from the pipeline. + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) + + // Check that the HTTP handler received its response. + XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) + // Check an error is reported + XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) + + XCTAssertFalse(upgradeHandlerCallbackFired) + + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + } + + override func testUpgradeResponseCanBeRejectedByClientUpgrader() throws { + let upgradeProtocol = "myProto" + + var upgradeHandlerCallbackFired = false + + let clientUpgrader = DenyingClientUpgrader(forProtocol: upgradeProtocol) + let clientHandler = RecordingHTTPHandler() + + // The process should kick-off independently by sending the upgrade request to the server. + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .upgraderDeniedUpgrade) + } + + clientChannel.embeddedEventLoop.run() + + // Should fail with error (response is denied) and remove upgrader from pipeline. + + // Check that the http elements are not removed from the pipeline. + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) + + XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) + + // Check that the HTTP handler received its response. + XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) + + // Check an error is reported + XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) + + XCTAssertFalse(upgradeHandlerCallbackFired) + + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + } + + override func testFiresOutboundErrorDuringAddingHandlers() throws { + let upgradeProtocol = "myProto" + var errorOnAdditionalChannelWrite: Error? + var upgradeHandlerCallbackFired = false + + let clientUpgrader = UpgradeDelayClientUpgrader(forProtocol: upgradeProtocol) + let clientHandler = RecordingHTTPHandler() + + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { (context) in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + // Push the successful server response. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) + + let promise = clientChannel.eventLoop.makePromise(of: Void.self) + + promise.futureResult.whenFailure() { error in + errorOnAdditionalChannelWrite = error + } + + // Send another outbound request during the upgrade. + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let secondRequest: HTTPClientRequestPart = .head(requestHead) + clientChannel.writeAndFlush(secondRequest, promise: promise) + + clientChannel.embeddedEventLoop.run() + + let promiseError = errorOnAdditionalChannelWrite as! NIOHTTPClientUpgradeError + XCTAssertEqual(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade, promiseError) + + // Soundness check that the upgrade was delayed. + XCTAssertEqual(0, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) + + // Upgrade now. + clientUpgrader.unblockUpgrade() + clientChannel.embeddedEventLoop.run() + + // Check that the upgrade was still successful, despite the interruption. + XCTAssert(upgradeHandlerCallbackFired) + XCTAssertEqual(1, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) + } + + override func testUpgradeResponseMissingAllProtocols() throws { + var upgradeHandlerCallbackFired = false + + let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") + let clientHandler = RecordingHTTPHandler() + + // The process should kick-off independently by sending the upgrade request to the server. + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\n\r\n" + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) + } + + clientChannel.embeddedEventLoop.run() + + // Should fail with error (response is malformed) and remove upgrader from pipeline. + + // Check that the http elements are not removed from the pipeline. + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) + + // Check that the HTTP handler received its response. + XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) + // Check an error is reported + XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) + + XCTAssertFalse(upgradeHandlerCallbackFired) + + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + } +} +#endif diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index 4393adcfc6..70d55eab55 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -36,7 +36,11 @@ extension ChannelPipeline { @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) fileprivate func assertContainsUpgrader() { - self.assertContains(handlerType: HTTPServerUpgradeHandler.self) + do { + _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() + } catch { + self.assertContains(handlerType: HTTPServerUpgradeHandler.self) + } } func assertContains(handlerType: Handler.Type) { @@ -59,7 +63,15 @@ extension ChannelPipeline { // handler present, keep waiting usleep(50) } catch ChannelPipelineError.notFound { - return + // Checking if the typed variant is present + do { + _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() + // handler present, keep waiting + usleep(50) + } catch ChannelPipelineError.notFound { + // No upgrader, we're good. + return + } } } @@ -162,8 +174,13 @@ internal func assertResponseIs(response: String, expectedResponseLine: String, e XCTAssertEqual(lines.count, 0) } +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader, NIOTypedHTTPServerProtocolUpgrader where UpgradeResult == Bool {} +#else @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader {} +#endif private class ExplodingUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String @@ -1539,3 +1556,505 @@ class HTTPServerUpgradeTestCase: XCTestCase { channel.pipeline.assertContainsUpgrader() } } + +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { + fileprivate override func setUpTestWithAutoremoval( + pipelining: Bool = false, + upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], + extraHandlers: [ChannelHandler], + notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, + _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler + ) throws -> (Channel, Channel, Channel) { + let connectionChannelPromise = Self.eventLoop.makePromise(of: Channel.self) + let serverChannelFuture = ServerBootstrap(group: Self.eventLoop) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + connectionChannelPromise.succeed(channel) + var configuration = NIOUpgradableHTTPServerPipelineConfiguration( + upgradeConfiguration: .init( + upgraders: upgraders.map { $0 as! any NIOTypedHTTPServerProtocolUpgrader }, + notUpgradingCompletionHandler: { notUpgradingHandler?($0) ?? $0.eventLoop.makeSucceededFuture(false) } + ) + ) + configuration.enablePipelining = pipelining + return try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline(configuration: configuration) + .flatMap { result in + if result { + return channel.pipeline.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self) + .map { + upgradeCompletionHandler($0) + } + } else { + return channel.eventLoop.makeSucceededVoidFuture() + } + } + } + .flatMap { _ in + let futureResults = extraHandlers.map { channel.pipeline.addHandler($0) } + return EventLoopFuture.andAllSucceed(futureResults, on: channel.eventLoop) + } + }.bind(host: "127.0.0.1", port: 0) + let clientChannel = try connectedClientChannel(group: Self.eventLoop, serverAddress: serverChannelFuture.wait().localAddress!) + return (try serverChannelFuture.wait(), clientChannel, try connectionChannelPromise.futureResult.wait()) + } + + func testNotUpgrading() throws { + let notUpgraderCbFired = UnsafeMutableTransferBox(false) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [], + notUpgradingHandler: { channel in + notUpgraderCbFired.wrappedValue = true + // We're closing the connection now. + channel.close(promise: nil) + return channel.eventLoop.makeSucceededFuture(true) + } + ) { _ in } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + XCTAssertEqual(resultString, "") + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: notmyproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that the not upgrader got called. + XCTAssert(notUpgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + } + + // - MARK: The following tests are all overridden from the base class since they slightly differ in behaviour + + override func testSimpleUpgradeSucceeds() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + // This is called before completion block. + upgradeRequest.wrappedValue = req + upgradeHandlerCbFired.wrappedValue = true + + XCTAssert(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { (context) in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + } + + override func testUpgradeRespectsClientPreference() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + let explodingUpgrader = ExplodingUpgrader(forProtocol: "exploder") + let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: []) { context in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto, exploder\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + } + + override func testUpgraderCanRejectUpgradeForPersonalReasons() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + let explodingUpgrader = UpgraderSaysNo(forProtocol: "noproto") + let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + let errorCatcher = ErrorSaver() + + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: [errorCatcher]) { context in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: noproto,myproto\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + + // And we want to confirm we saved the error. + XCTAssertEqual(errorCatcher.errors.count, 1) + + switch(errorCatcher.errors[0]) { + case UpgraderSaysNo.No.no: + break + default: + XCTFail("Unexpected error: \(errorCatcher.errors[0])") + } + } + + override func testUpgradeWithUpgradePayloadInlineWithRequestWorks() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + enum ReceivedTheWrongThingError: Error { case error } + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + class CheckWeReadInlineAndExtraData: ChannelDuplexHandler { + typealias InboundIn = ByteBuffer + typealias OutboundIn = Never + typealias OutboundOut = Never + + enum State { + case fresh + case added + case inlineDataRead + case extraDataRead + case closed + } + + private let firstByteDonePromise: EventLoopPromise + private let secondByteDonePromise: EventLoopPromise + private let allDonePromise: EventLoopPromise + private var state = State.fresh + + init(firstByteDonePromise: EventLoopPromise, + secondByteDonePromise: EventLoopPromise, + allDonePromise: EventLoopPromise) { + self.firstByteDonePromise = firstByteDonePromise + self.secondByteDonePromise = secondByteDonePromise + self.allDonePromise = allDonePromise + } + + func handlerAdded(context: ChannelHandlerContext) { + XCTAssertEqual(.fresh, self.state) + self.state = .added + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buf = self.unwrapInboundIn(data) + XCTAssertEqual(1, buf.readableBytes) + let stringRead = buf.readString(length: buf.readableBytes) + switch self.state { + case .added: + XCTAssertEqual("A", stringRead) + self.state = .inlineDataRead + if stringRead == .some("A") { + self.firstByteDonePromise.succeed(()) + } else { + self.firstByteDonePromise.fail(ReceivedTheWrongThingError.error) + } + case .inlineDataRead: + XCTAssertEqual("B", stringRead) + self.state = .extraDataRead + context.channel.close(promise: nil) + if stringRead == .some("B") { + self.secondByteDonePromise.succeed(()) + } else { + self.secondByteDonePromise.fail(ReceivedTheWrongThingError.error) + } + default: + XCTFail("channel read in wrong state \(self.state)") + } + } + + func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + XCTAssertEqual(.extraDataRead, self.state) + self.state = .closed + context.close(mode: mode, promise: promise) + + self.allDonePromise.succeed(()) + } + } + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let promiseGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try promiseGroup.syncShutdownGracefully()) + } + let firstByteDonePromise = promiseGroup.next().makePromise(of: Void.self) + let secondByteDonePromise = promiseGroup.next().makePromise(of: Void.self) + let allDonePromise = promiseGroup.next().makePromise(of: Void.self) + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: []) { (context) in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + _ = context.channel.pipeline.addHandler(CheckWeReadInlineAndExtraData(firstByteDonePromise: firstByteDonePromise, + secondByteDonePromise: secondByteDonePromise, + allDonePromise: allDonePromise)) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + var request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + request += "A" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + XCTAssertNoThrow(try firstByteDonePromise.futureResult.wait() as Void) + + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: "B"))).wait()) + + XCTAssertNoThrow(try secondByteDonePromise.futureResult.wait() as Void) + + XCTAssertNoThrow(try allDonePromise.futureResult.wait() as Void) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + + XCTAssertNoThrow(try allDonePromise.futureResult.wait()) + } + + override func testWeTolerateUpgradeFuturesFromWrongEventLoops() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + let otherELG = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try otherELG.syncShutdownGracefully()) + } + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", + requiringHeaders: ["kafkaesque"], + buildUpgradeResponseFuture: { + // this is the wrong EL + otherELG.next().makeSucceededFuture($1) + }) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: []) { (context) in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + } + + override func testUpgradeFiresUserEvent() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let eventSaver = UnsafeTransfer(UserEventSaver()) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { req in + XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) + } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: [eventSaver.wrappedValue]) { context in + XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade,kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we should have received one user event. We schedule this onto the + // event loop to guarantee thread safety. + XCTAssertNoThrow(try connectedServer.eventLoop.scheduleTask(deadline: .now()) { + XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) + if case .upgradeComplete(let proto, let req) = eventSaver.wrappedValue.events[0] { + XCTAssertEqual(proto, "myproto") + XCTAssertEqual(req.method, .OPTIONS) + XCTAssertEqual(req.uri, "*") + XCTAssertEqual(req.version, .http1_1) + } else { + XCTFail("Unexpected event: \(eventSaver.wrappedValue.events[0])") + } + }.futureResult.wait()) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + } +} +#endif diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index 137e897988..35b89c02b7 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -404,3 +404,216 @@ class WebSocketClientEndToEndTests: XCTestCase { XCTAssertNoThrow(try clientChannel.close().wait()) } } + +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedWebSocketClientEndToEndTests: WebSocketClientEndToEndTests { + func setUpClientChannel( + clientUpgraders: [any NIOTypedHTTPClientProtocolUpgrader], + notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture + ) throws -> (EmbeddedChannel, EventLoopFuture) { + + let channel = EmbeddedChannel() + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "\(0)") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let config = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: clientUpgraders, + notUpgradingCompletionHandler: notUpgradingCompletionHandler + ) + + let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: .init(upgradeConfiguration: config)) + + try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + .wait() + + return (channel, upgradeResult) + } + + override func testSimpleUpgradeSucceeds() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Read the server request. + if let requestString = try clientChannel.readByteBufferOutputAsString() { + XCTAssertEqual(requestString, basicRequest() + "\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Key: \(requestKey)\r\nSec-WebSocket-Version: 13\r\n\r\n") + } else { + XCTFail() + } + + // Push the successful server response. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" + + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) + + clientChannel.embeddedEventLoop.run() + + // Once upgraded, validate the http pipeline has been removed. + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: HTTPRequestEncoder.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: ByteToMessageHandler.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + // Check that the pipeline now has the correct websocket handlers added. + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: WebSocketFrameEncoder.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: ByteToMessageHandler.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: WebSocketRecorderHandler.self)) + + try upgradeResult.wait() + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + } + + override func testRejectUpgradeIfMissingAcceptKey() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response but with a missing accept key. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n" + + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + + XCTAssertThrowsError(try upgradeResult.wait()) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + } + + override func testRejectUpgradeIfIncorrectAcceptKey() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + let responseKey = "notACorrectKeyL1am=F1y=nn=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response but with an incorrect response key. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" + + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + + XCTAssertThrowsError(try upgradeResult.wait()) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + } + + override func testRejectUpgradeIfNotWebsocket() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response with an incorrect protocol. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: myProtocol\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" + + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + + XCTAssertThrowsError(try upgradeResult.wait()) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + } + + override fileprivate func runSuccessfulUpgrade() throws -> (EmbeddedChannel, WebSocketRecorderHandler) { + let handler = WebSocketRecorderHandler() + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: "OfS0wDaT5NoxF2gqm7Zj2YtetzM=", + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(handler) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:yKEqitDFPE81FyIhKTm+ojBqigk=\r\n\r\n" + + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) + + clientChannel.embeddedEventLoop.run() + + // We now have a successful upgrade, clear the output channels read to test the frames. + XCTAssertNoThrow(try clientChannel.readOutbound(as: ByteBuffer.self)) + + clientChannel.embeddedEventLoop.run() + + try upgradeResult.wait() + + return (clientChannel, handler) + } +} +#endif diff --git a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift index 2a1a3c6980..609a4e9325 100644 --- a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift @@ -526,3 +526,32 @@ class WebSocketServerEndToEndTests: XCTestCase { XCTAssertNoThrow(XCTAssertEqual([], try server.readAllOutboundBytes())) } } + +#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedWebSocketServerEndToEndTests: WebSocketServerEndToEndTests { + override func createTestFixtures( + upgraders: [WebSocketServerUpgraderConfiguration] + ) -> (loop: EmbeddedEventLoop, serverChannel: EmbeddedChannel, clientChannel: EmbeddedChannel) { + let loop = EmbeddedEventLoop() + let serverChannel = EmbeddedChannel(loop: loop) + let upgraders = upgraders.map { NIOTypedWebSocketServerUpgrader( + maxFrameSize: $0.maxFrameSize, + enableAutomaticErrorHandling: $0.automaticErrorHandling, + shouldUpgrade: $0.shouldUpgrade, + upgradePipelineHandler: $0.upgradePipelineHandler + )} + + XCTAssertNoThrow(try serverChannel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: .init( + upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration( + upgraders: upgraders, + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + ) + )) + let clientChannel = EmbeddedChannel(loop: loop) + return (loop: loop, serverChannel: serverChannel, clientChannel: clientChannel) + } +} +#endif From 702cd7c56d5d44eeba73fdf83918339b26dc855c Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Thu, 16 Nov 2023 11:30:21 +0000 Subject: [PATCH 49/64] Fix the typed HTTP upgrade compiler guards (#2594) The compiler guards were unnecessarily complex and I also didn't cover 3 methods where we used the types. --- Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift | 2 +- .../NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift | 2 +- .../NIOTypedHTTPClientUpgraderStateMachine.swift | 2 +- .../NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift | 2 +- .../NIOTypedHTTPServerUpgraderStateMachine.swift | 2 +- .../NIOWebSocket/NIOWebSocketClientUpgrader.swift | 2 +- .../NIOWebSocket/NIOWebSocketServerUpgrader.swift | 2 +- Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift | 10 +++++++--- Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift | 12 ++++++++++-- .../WebSocketClientEndToEndTests.swift | 2 +- .../WebSocketServerEndToEndTests.swift | 2 +- 11 files changed, 26 insertions(+), 14 deletions(-) diff --git a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift index 4135203a8e..9021062488 100644 --- a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift +++ b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) import NIOCore // MARK: - Server pipeline configuration diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift index ea76a74b91..c683b61b3e 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) import NIOCore /// An object that implements `NIOTypedHTTPClientProtocolUpgrader` knows how to handle HTTP upgrade to diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift index 875fb2ce64..6e9c696811 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) import DequeModule import NIOCore diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift index 1a1a47988c..b6a90b1294 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) import NIOCore /// An object that implements `NIOTypedHTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift index c4fa19c348..bc2536f7c8 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) import DequeModule import NIOCore diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index d1b190c288..a9e456f857 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -74,7 +74,7 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { } } -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) /// A `NIOTypedHTTPClientProtocolUpgrader` that knows how to do the WebSocket upgrade dance. /// /// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request. diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 14f29f750b..0672bc4a06 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -175,7 +175,7 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch } } -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) /// A `NIOTypedHTTPServerProtocolUpgrader` that knows how to do the WebSocket upgrade dance. /// /// Users may frequently want to offer multiple websocket endpoints on the same port. For this diff --git a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift index 195338f9ef..a0cda42f73 100644 --- a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift @@ -32,7 +32,7 @@ extension EmbeddedChannel { } } -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader, NIOTypedHTTPClientProtocolUpgrader where UpgradeResult == Bool {} #else @@ -287,9 +287,13 @@ private final class RecordingHTTPHandler: ChannelInboundHandler, RemovableChanne @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) private func assertPipelineContainsUpgradeHandler(channel: Channel) { let handler = try? channel.pipeline.syncOperations.handler(type: NIOHTTPClientUpgradeHandler.self) - let typedHandler = try? channel.pipeline.syncOperations.handler(type: NIOTypedHTTPClientUpgradeHandler.self) + #if !canImport(Darwin) || swift(>=5.10) + let typedHandler = try? channel.pipeline.syncOperations.handler(type: NIOTypedHTTPClientUpgradeHandler.self) XCTAssertTrue(handler != nil || typedHandler != nil) + #else + XCTAssertTrue(handler != nil) + #endif } @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) @@ -953,7 +957,7 @@ class HTTPClientUpgradeTestCase: XCTestCase { } } -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { override func setUpClientChannel( diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index 70d55eab55..5b48485751 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -36,11 +36,15 @@ extension ChannelPipeline { @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) fileprivate func assertContainsUpgrader() { + #if !canImport(Darwin) || swift(>=5.10) do { _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() } catch { self.assertContains(handlerType: HTTPServerUpgradeHandler.self) } + #else + self.assertContains(handlerType: HTTPServerUpgradeHandler.self) + #endif } func assertContains(handlerType: Handler.Type) { @@ -63,6 +67,7 @@ extension ChannelPipeline { // handler present, keep waiting usleep(50) } catch ChannelPipelineError.notFound { + #if !canImport(Darwin) || swift(>=5.10) // Checking if the typed variant is present do { _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() @@ -72,6 +77,9 @@ extension ChannelPipeline { // No upgrader, we're good. return } + #else + return + #endif } } @@ -174,7 +182,7 @@ internal func assertResponseIs(response: String, expectedResponseLine: String, e XCTAssertEqual(lines.count, 0) } -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader, NIOTypedHTTPServerProtocolUpgrader where UpgradeResult == Bool {} #else @@ -1557,7 +1565,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { } } -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { fileprivate override func setUpTestWithAutoremoval( diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index 35b89c02b7..1e64a27544 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -405,7 +405,7 @@ class WebSocketClientEndToEndTests: XCTestCase { } } -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) final class TypedWebSocketClientEndToEndTests: WebSocketClientEndToEndTests { func setUpClientChannel( diff --git a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift index 609a4e9325..d73d1f21dc 100644 --- a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift @@ -527,7 +527,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } } -#if !canImport(Darwin) || (canImport(Darwin) && swift(>=5.10)) +#if !canImport(Darwin) || swift(>=5.10) @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) final class TypedWebSocketServerEndToEndTests: WebSocketServerEndToEndTests { override func createTestFixtures( From e69354f1237d34e3453248dfca57aade20a288d8 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 16 Nov 2023 16:21:21 +0000 Subject: [PATCH 50/64] Async apis for NonBlockingFileIO (#2576) * Async versions of NonBlockingFileIO functions I have not produced async versions of the readChunked functions. These would likely be replaced with an AsyncSequence though. * Add tests for async versions of NonBlockingFileIO * Add wrapper for FileRegion to avoid Sendable issues * Allocate ByteBuffer in readSync * Update read_10000_chunks_from_file alloc count * Add UnsafeTransfer to NIOPosix. Use in NonBlockingFileIO.openFile functions * Update Sources/NIOPosix/NonBlockingFileIO.swift Co-authored-by: George Barnett * formatting, docc * update comment * replace openFile with withOpenFile * More indentation * Update comments * Remove docc links to NIOCore types * Fix read docc links * Catch errors when closing file handle in withOpenFile * rename to withFileHandle * Fix double close on error Also - move @available above each function - Remove wording about separate from Task pool * Update Sources/NIOPosix/NonBlockingFileIO.swift Co-authored-by: George Barnett * withOpenFile is now withFileRegion --------- Co-authored-by: George Barnett --- Sources/NIOPosix/NIOThreadPool.swift | 2 +- Sources/NIOPosix/NonBlockingFileIO.swift | 536 +++++++++++++--- Sources/NIOPosix/UnsafeTransfer.swift | 33 + .../NIOPosixTests/NonBlockingFileIOTest.swift | 605 ++++++++++++++++++ Tests/NIOPosixTests/TestUtils.swift | 58 ++ docker/docker-compose.2204.510.yaml | 2 +- docker/docker-compose.2204.57.yaml | 2 +- docker/docker-compose.2204.58.yaml | 2 +- docker/docker-compose.2204.59.yaml | 2 +- docker/docker-compose.2204.main.yaml | 2 +- 10 files changed, 1166 insertions(+), 78 deletions(-) create mode 100644 Sources/NIOPosix/UnsafeTransfer.swift diff --git a/Sources/NIOPosix/NIOThreadPool.swift b/Sources/NIOPosix/NIOThreadPool.swift index 6d29801894..39db189196 100644 --- a/Sources/NIOPosix/NIOThreadPool.swift +++ b/Sources/NIOPosix/NIOThreadPool.swift @@ -360,4 +360,4 @@ extension NIOThreadPool { } } } -} +} \ No newline at end of file diff --git a/Sources/NIOPosix/NonBlockingFileIO.swift b/Sources/NIOPosix/NonBlockingFileIO.swift index 2e07175bc3..63830a8e0c 100644 --- a/Sources/NIOPosix/NonBlockingFileIO.swift +++ b/Sources/NIOPosix/NonBlockingFileIO.swift @@ -15,7 +15,7 @@ import NIOCore import NIOConcurrencyHelpers -/// `NonBlockingFileIO` is a helper that allows you to read files without blocking the calling thread. +/// ``NonBlockingFileIO`` is a helper that allows you to read files without blocking the calling thread. /// /// It is worth noting that `kqueue`, `epoll` or `poll` returning claiming a file is readable does not mean that the /// data is already available in the kernel's memory. In other words, a `read` from a file can still block even if @@ -25,19 +25,19 @@ import NIOConcurrencyHelpers /// - [`epoll`](http://man7.org/linux/man-pages/man7/epoll.7.html): "epoll is simply a faster poll(2), and can be used wherever the latter is used since it shares the same semantics." /// - [`kqueue`](https://www.freebsd.org/cgi/man.cgi?query=kqueue&sektion=2): "Returns when the file pointer is not at the end of file." /// -/// `NonBlockingFileIO` helps to work around this issue by maintaining its own thread pool that is used to read the data +/// ``NonBlockingFileIO`` helps to work around this issue by maintaining its own thread pool that is used to read the data /// from the files into memory. It will then hand the (in-memory) data back which makes it available without the possibility /// of blocking. public struct NonBlockingFileIO: Sendable { - /// The default and recommended size for `NonBlockingFileIO`'s thread pool. + /// The default and recommended size for ``NonBlockingFileIO``'s thread pool. public static let defaultThreadPoolSize = 2 /// The default and recommended chunk size. public static let defaultChunkSize = 128*1024 - /// `NonBlockingFileIO` errors. + /// ``NonBlockingFileIO`` errors. public enum Error: Swift.Error { - /// `NonBlockingFileIO` is meant to be used with file descriptors that are set to the default (blocking) mode. + /// ``NonBlockingFileIO`` is meant to be used with file descriptors that are set to the default (blocking) mode. /// It doesn't make sense to use it with a file descriptor where `O_NONBLOCK` is set therefore this error is /// raised when that was requested. case descriptorSetToNonBlocking @@ -45,7 +45,7 @@ public struct NonBlockingFileIO: Sendable { private let threadPool: NIOThreadPool - /// Initialize a `NonBlockingFileIO` which uses the `NIOThreadPool`. + /// Initialize a ``NonBlockingFileIO`` which uses the `NIOThreadPool`. /// /// - parameters: /// - threadPool: The `NIOThreadPool` that will be used for all the IO. @@ -53,7 +53,7 @@ public struct NonBlockingFileIO: Sendable { self.threadPool = threadPool } - /// Read a `FileRegion` in chunks of `chunkSize` bytes on `NonBlockingFileIO`'s private thread + /// Read a `FileRegion` in chunks of `chunkSize` bytes on ``NonBlockingFileIO``'s private thread /// pool which is separate from any `EventLoop` thread. /// /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `fileRegion.readableBytes` is greater than @@ -89,7 +89,7 @@ public struct NonBlockingFileIO: Sendable { chunkHandler: chunkHandler) } - /// Read `byteCount` bytes in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread + /// Read `byteCount` bytes in chunks of `chunkSize` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread /// pool which is separate from any `EventLoop` thread. /// /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `byteCount` is greater than @@ -128,7 +128,7 @@ public struct NonBlockingFileIO: Sendable { chunkHandler: chunkHandler) } - /// Read `byteCount` bytes from offset `fileOffset` in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread + /// Read `byteCount` bytes from offset `fileOffset` in chunks of `chunkSize` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread /// pool which is separate from any `EventLoop` thread. /// /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `byteCount` is greater than @@ -225,7 +225,7 @@ public struct NonBlockingFileIO: Sendable { return promise.futureResult } - /// Read a `FileRegion` in `NonBlockingFileIO`'s private thread pool which is separate from any `EventLoop` thread. + /// Read a `FileRegion` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. /// /// The returned `ByteBuffer` will not have less than `fileRegion.readableBytes` unless we hit end-of-file in which /// case the `ByteBuffer` will contain the bytes available to read. @@ -250,15 +250,15 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop) } - /// Read `byteCount` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread pool which is separate from any `EventLoop` thread. + /// Read `byteCount` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. /// /// The returned `ByteBuffer` will not have less than `byteCount` bytes unless we hit end-of-file in which /// case the `ByteBuffer` will contain the bytes available to read. /// /// - note: Only use this function for small enough `byteCount`s as it will need to allocate enough memory to hold `byteCount` bytes. - /// - note: `read(fileRegion:allocator:eventLoop:)` should be preferred as it uses `FileRegion` object instead of + /// - note: ``read(fileRegion:allocator:eventLoop:)`` should be preferred as it uses `FileRegion` object instead of /// raw `NIOFileHandle`s. In case you do want to use raw `NIOFileHandle`s, - /// please consider using `read(fileHandle:fromOffset:byteCount:allocator:eventLoop:)` + /// please consider using ``read(fileHandle:fromOffset:byteCount:allocator:eventLoop:)`` /// because it doesn't use the file descriptor's seek pointer (which may be shared with other file /// descriptors and even across processes.) /// @@ -279,7 +279,7 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop) } - /// Read `byteCount` bytes starting at `fileOffset` from `fileHandle` in `NonBlockingFileIO`'s private thread pool + /// Read `byteCount` bytes starting at `fileOffset` from `fileHandle` in ``NonBlockingFileIO``'s private thread pool /// which is separate from any `EventLoop` thread. /// /// The returned `ByteBuffer` will not have less than `byteCount` bytes unless we hit end-of-file in which @@ -289,7 +289,7 @@ public struct NonBlockingFileIO: Sendable { /// same `fileHandle` in multiple threads. /// /// - note: Only use this function for small enough `byteCount`s as it will need to allocate enough memory to hold `byteCount` bytes. - /// - note: `read(fileRegion:allocator:eventLoop:)` should be preferred as it uses `FileRegion` object instead of raw `NIOFileHandle`s. + /// - note: ``read(fileRegion:allocator:eventLoop:)`` should be preferred as it uses `FileRegion` object instead of raw `NIOFileHandle`s. /// /// - parameters: /// - fileHandle: The `NIOFileHandle` to read. @@ -320,40 +320,49 @@ public struct NonBlockingFileIO: Sendable { } let byteCount = rawByteCount < Int32.max ? rawByteCount : size_t(Int32.max) - var buf = allocator.buffer(capacity: byteCount) return self.threadPool.runIfActive(eventLoop: eventLoop) { () -> ByteBuffer in - var bytesRead = 0 - while bytesRead < byteCount { - let n = try buf.writeWithUnsafeMutableBytes(minimumWritableBytes: byteCount - bytesRead) { ptr in - let res = try fileHandle.withUnsafeFileDescriptor { descriptor -> IOResult in - if let offset = fromOffset { - return try Posix.pread(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - bytesRead, - offset: off_t(offset) + off_t(bytesRead)) - } else { - return try Posix.read(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - bytesRead) - } - } - switch res { - case .processed(let n): - assert(n >= 0, "read claims to have read a negative number of bytes \(n)") - return n - case .wouldBlock: - throw Error.descriptorSetToNonBlocking + try self.readSync(fileHandle: fileHandle, fromOffset: fromOffset, byteCount: byteCount, allocator: allocator) + } + } + + private func readSync( + fileHandle: NIOFileHandle, + fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems + byteCount: Int, + allocator: ByteBufferAllocator + ) throws -> ByteBuffer { + var bytesRead = 0 + var buf = allocator.buffer(capacity: byteCount) + while bytesRead < byteCount { + let n = try buf.writeWithUnsafeMutableBytes(minimumWritableBytes: byteCount - bytesRead) { ptr in + let res = try fileHandle.withUnsafeFileDescriptor { descriptor -> IOResult in + if let offset = fromOffset { + return try Posix.pread(descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - bytesRead, + offset: off_t(offset) + off_t(bytesRead)) + } else { + return try Posix.read(descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - bytesRead) } } - if n == 0 { - // EOF - break - } else { - bytesRead += n + switch res { + case .processed(let n): + assert(n >= 0, "read claims to have read a negative number of bytes \(n)") + return n + case .wouldBlock: + throw Error.descriptorSetToNonBlocking } } - return buf + if n == 0 { + // EOF + break + } else { + bytesRead += n + } } + return buf } /// Changes the file size of `fileHandle` to `size`. @@ -394,7 +403,7 @@ public struct NonBlockingFileIO: Sendable { } } - /// Write `buffer` to `fileHandle` in `NonBlockingFileIO`'s private thread pool which is separate from any `EventLoop` thread. + /// Write `buffer` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. /// /// - parameters: /// - fileHandle: The `NIOFileHandle` to write to. @@ -407,7 +416,7 @@ public struct NonBlockingFileIO: Sendable { return self.write0(fileHandle: fileHandle, toOffset: nil, buffer: buffer, eventLoop: eventLoop) } - /// Write `buffer` starting from `toOffset` to `fileHandle` in `NonBlockingFileIO`'s private thread pool which is separate from any `EventLoop` thread. + /// Write `buffer` starting from `toOffset` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. /// /// - parameters: /// - fileHandle: The `NIOFileHandle` to write to. @@ -433,35 +442,44 @@ public struct NonBlockingFileIO: Sendable { } return self.threadPool.runIfActive(eventLoop: eventLoop) { - var buf = buffer - - var offsetAccumulator: Int = 0 - repeat { - let n = try buf.readWithUnsafeReadableBytes { ptr in - precondition(ptr.count == byteCount - offsetAccumulator) - let res: IOResult = try fileHandle.withUnsafeFileDescriptor { descriptor in - if let toOffset = toOffset { - return try Posix.pwrite(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - offsetAccumulator, - offset: off_t(toOffset + Int64(offsetAccumulator))) - } else { - return try Posix.write(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - offsetAccumulator) - } - } - switch res { - case .processed(let n): - assert(n >= 0, "write claims to have written a negative number of bytes \(n)") - return n - case .wouldBlock: - throw Error.descriptorSetToNonBlocking + try self.writeSync(fileHandle: fileHandle, byteCount: byteCount, toOffset: toOffset, buffer: buffer) + } + } + + private func writeSync( + fileHandle: NIOFileHandle, + byteCount: Int, + toOffset: Int64?, + buffer: ByteBuffer + ) throws { + var buf = buffer + + var offsetAccumulator: Int = 0 + repeat { + let n = try buf.readWithUnsafeReadableBytes { ptr in + precondition(ptr.count == byteCount - offsetAccumulator) + let res: IOResult = try fileHandle.withUnsafeFileDescriptor { descriptor in + if let toOffset = toOffset { + return try Posix.pwrite(descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - offsetAccumulator, + offset: off_t(toOffset + Int64(offsetAccumulator))) + } else { + return try Posix.write(descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - offsetAccumulator) } } - offsetAccumulator += n - } while offsetAccumulator < byteCount - } + switch res { + case .processed(let n): + assert(n >= 0, "write claims to have written a negative number of bytes \(n)") + return n + case .wouldBlock: + throw Error.descriptorSetToNonBlocking + } + } + offsetAccumulator += n + } while offsetAccumulator < byteCount } /// Open the file at `path` for reading on a private thread pool which is separate from any `EventLoop` thread. @@ -712,3 +730,377 @@ public struct NIODirectoryEntry: Hashable { } } #endif + +extension NonBlockingFileIO { + /// Read a `FileRegion` in ``NonBlockingFileIO``'s private thread pool. + /// + /// The returned `ByteBuffer` will not have less than the minimum of `fileRegion.readableBytes` and `UInt32.max` unless we hit + /// end-of-file in which case the `ByteBuffer` will contain the bytes available to read. + /// + /// This method will not use the file descriptor's seek pointer which means there is no danger of reading from the + /// same `FileRegion` in multiple threads. + /// + /// - note: Only use this function for small enough `FileRegion`s as it will need to allocate enough memory to hold `fileRegion.readableBytes` bytes. + /// - note: In most cases you should prefer one of the `readChunked` functions. + /// + /// - parameters: + /// - fileRegion: The file region to read. + /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. + /// - returns: ByteBuffer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func read(fileRegion: FileRegion, allocator: ByteBufferAllocator) async throws -> ByteBuffer { + let readableBytes = fileRegion.readableBytes + return try await self.read( + fileHandle: fileRegion.fileHandle, + fromOffset: Int64(fileRegion.readerIndex), + byteCount: readableBytes, + allocator: allocator + ) + } + + /// Read `byteCount` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread pool. + /// + /// The returned `ByteBuffer` will not have less than `byteCount` bytes unless we hit end-of-file in which + /// case the `ByteBuffer` will contain the bytes available to read. + /// + /// - note: Only use this function for small enough `byteCount`s as it will need to allocate enough memory to hold `byteCount` bytes. + /// - note: ``read(fileRegion:allocator:eventLoop:)`` should be preferred as it uses `FileRegion` object instead of + /// raw `NIOFileHandle`s. In case you do want to use raw `NIOFileHandle`s, + /// please consider using ``read(fileHandle:fromOffset:byteCount:allocator:eventLoop:)`` + /// because it doesn't use the file descriptor's seek pointer (which may be shared with other file + /// descriptors and even across processes.) + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to read. + /// - byteCount: The number of bytes to read from `fileHandle`. + /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. + /// - returns: ByteBuffer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func read( + fileHandle: NIOFileHandle, + byteCount: Int, + allocator: ByteBufferAllocator + ) async throws-> ByteBuffer { + return try await self.read0( + fileHandle: fileHandle, + fromOffset: nil, + byteCount: byteCount, + allocator: allocator + ) + } + + /// Read `byteCount` bytes starting at `fileOffset` from `fileHandle` in ``NonBlockingFileIO``'s private thread pool + ///. + /// + /// The returned `ByteBuffer` will not have less than `byteCount` bytes unless we hit end-of-file in which + /// case the `ByteBuffer` will contain the bytes available to read. + /// + /// This method will not use the file descriptor's seek pointer which means there is no danger of reading from the + /// same `fileHandle` in multiple threads. + /// + /// - note: Only use this function for small enough `byteCount`s as it will need to allocate enough memory to hold `byteCount` bytes. + /// - note: ``read(fileRegion:allocator:eventLoop:)`` should be preferred as it uses `FileRegion` object instead of raw `NIOFileHandle`s. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to read. + /// - fileOffset: The offset to read from. + /// - byteCount: The number of bytes to read from `fileHandle`. + /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. + /// - returns: ByteBuffer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func read(fileHandle: NIOFileHandle, + fromOffset fileOffset: Int64, + byteCount: Int, + allocator: ByteBufferAllocator) async throws -> ByteBuffer { + return try await self.read0(fileHandle: fileHandle, + fromOffset: fileOffset, + byteCount: byteCount, + allocator: allocator) + } + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + private func read0(fileHandle: NIOFileHandle, + fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems + byteCount rawByteCount: Int, + allocator: ByteBufferAllocator) async throws -> ByteBuffer { + guard rawByteCount > 0 else { + return allocator.buffer(capacity: 0) + } + let byteCount = rawByteCount < Int32.max ? rawByteCount : size_t(Int32.max) + + return try await self.threadPool.runIfActive { () -> ByteBuffer in + try self.readSync(fileHandle: fileHandle, fromOffset: fromOffset, byteCount: byteCount, allocator: allocator) + } + } + + /// Changes the file size of `fileHandle` to `size`. + /// + /// If `size` is smaller than the current file size, the remaining bytes will be truncated and are lost. If `size` + /// is larger than the current file size, the gap will be filled with zero bytes. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to write to. + /// - size: The new file size in bytes to write. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func changeFileSize(fileHandle: NIOFileHandle, + size: Int64) async throws { + return try await self.threadPool.runIfActive { + try fileHandle.withUnsafeFileDescriptor { descriptor -> Void in + try Posix.ftruncate(descriptor: descriptor, size: off_t(size)) + } + } + } + + /// Returns the length of the file associated with `fileHandle`. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to read from. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func readFileSize(fileHandle: NIOFileHandle) async throws -> Int64 { + return try await self.threadPool.runIfActive { + return try fileHandle.withUnsafeFileDescriptor { descriptor in + let curr = try Posix.lseek(descriptor: descriptor, offset: 0, whence: SEEK_CUR) + let eof = try Posix.lseek(descriptor: descriptor, offset: 0, whence: SEEK_END) + try Posix.lseek(descriptor: descriptor, offset: curr, whence: SEEK_SET) + return Int64(eof) + } + } + } + + /// Write `buffer` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to write to. + /// - buffer: The `ByteBuffer` to write. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func write(fileHandle: NIOFileHandle, + buffer: ByteBuffer) async throws { + return try await self.write0(fileHandle: fileHandle, toOffset: nil, buffer: buffer) + } + + /// Write `buffer` starting from `toOffset` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to write to. + /// - toOffset: The file offset to write to. + /// - buffer: The `ByteBuffer` to write. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func write(fileHandle: NIOFileHandle, + toOffset: Int64, + buffer: ByteBuffer) async throws { + return try await self.write0(fileHandle: fileHandle, toOffset: toOffset, buffer: buffer) + } + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + private func write0(fileHandle: NIOFileHandle, + toOffset: Int64?, + buffer: ByteBuffer) async throws { + let byteCount = buffer.readableBytes + + guard byteCount > 0 else { + return + } + + return try await self.threadPool.runIfActive { + try self.writeSync(fileHandle: fileHandle, byteCount: byteCount, toOffset: toOffset, buffer: buffer) + } + } + + /// Open file at `path` and query its size on a private thread pool, run an operation given + /// the resulting file region and then close the file handle. + /// + /// The will return the result of the operation. + /// + /// - note: This function opens a file and queries it size which are both blocking operations + /// + /// - parameters: + /// - path: The path of the file to be opened for reading. + /// - body: operation to run with file handle and region + /// - returns: return value of operation + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withFileRegion( + path: String, + _ body: (_ fileRegion: FileRegion) async throws -> Result + ) async throws -> Result { + let fileRegion = try await self.threadPool.runIfActive { + let fh = try NIOFileHandle(path: path) + do { + let fr = try FileRegion(fileHandle: fh) + return UnsafeTransfer(fr) + } catch { + _ = try? fh.close() + throw error + } + } + let result: Result + do { + result = try await body(fileRegion.wrappedValue) + } catch { + try fileRegion.wrappedValue.fileHandle.close() + throw error + } + try fileRegion.wrappedValue.fileHandle.close() + return result + } + + /// Open file at `path` on a private thread pool, run an operation given the file handle and then close the file handle. + /// + /// This function will return the result of the operation. + /// + /// - parameters: + /// - path: The path of the file to be opened for writing. + /// - mode: File access mode. + /// - flags: Additional POSIX flags. + /// - returns: return value of operation + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withFileHandle( + path: String, + mode: NIOFileHandle.Mode, + flags: NIOFileHandle.Flags = .default, + _ body: (NIOFileHandle) async throws -> Result + ) async throws -> Result { + let fileHandle = try await self.threadPool.runIfActive { + return try UnsafeTransfer(NIOFileHandle(path: path, mode: mode, flags: flags)) + } + let result: Result + do { + result = try await body(fileHandle.wrappedValue) + } catch { + try fileHandle.wrappedValue.close() + throw error + } + try fileHandle.wrappedValue.close() + return result + } + +#if !os(Windows) + + /// Returns information about a file at `path` on a private thread pool. + /// + /// - note: If `path` is a symlink, information about the link, not the file it points to. + /// + /// - parameters: + /// - path: The path of the file to get information about. + /// - returns: file information. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func lstat(path: String) async throws -> stat { + return try await self.threadPool.runIfActive { + var s = stat() + try Posix.lstat(pathname: path, outStat: &s) + return s + } + } + + /// Creates a symbolic link to a `destination` file at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the link. + /// - destination: Target path where this link will point to. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func symlink(path: String, to destination: String) async throws { + return try await self.threadPool.runIfActive { + try Posix.symlink(pathname: path, destination: destination) + } + } + + /// Returns target of the symbolic link at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the link to read. + /// - returns: link target. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func readlink(path: String) async throws -> String { + return try await self.threadPool.runIfActive { + let maxLength = Int(PATH_MAX) + let pointer = UnsafeMutableBufferPointer.allocate(capacity: maxLength) + defer { + pointer.deallocate() + } + let length = try Posix.readlink(pathname: path, outPath: pointer.baseAddress!, outPathSize: maxLength) + return String(decoding: UnsafeRawBufferPointer(pointer).prefix(length), as: UTF8.self) + } + } + + /// Removes symbolic link at `path` on a private thread pool which is separate from any `EventLoop` thread. + /// + /// - parameters: + /// - path: The path of the link to remove. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func unlink(path: String) async throws { + return try await self.threadPool.runIfActive { + try Posix.unlink(pathname: path) + } + } + + + /// Creates directory at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the directory to be created. + /// - withIntermediateDirectories: Whether intermediate directories should be created. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func createDirectory(path: String, withIntermediateDirectories createIntermediates: Bool = false, mode: NIOPOSIXFileMode) async throws { + return try await self.threadPool.runIfActive { + if createIntermediates { + #if canImport(Darwin) + try Posix.mkpath_np(pathname: path, mode: mode) + #else + try self.createDirectory0(path, mode: mode) + #endif + } else { + try Posix.mkdir(pathname: path, mode: mode) + } + } + } + + /// List contents of the directory at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the directory to list the content of. + /// - returns: The directory entries. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func listDirectory(path: String) async throws -> [NIODirectoryEntry] { + return try await self.threadPool.runIfActive { + let dir = try Posix.opendir(pathname: path) + var entries: [NIODirectoryEntry] = [] + do { + while let entry = try Posix.readdir(dir: dir) { + let name = withUnsafeBytes(of: entry.pointee.d_name) { pointer -> String in + let ptr = pointer.baseAddress!.assumingMemoryBound(to: CChar.self) + return String(cString: ptr) + } + entries.append(NIODirectoryEntry(ino: UInt64(entry.pointee.d_ino), type: entry.pointee.d_type, name: name)) + } + try? Posix.closedir(dir: dir) + } catch { + try? Posix.closedir(dir: dir) + throw error + } + return entries + } + } + + /// Renames the file at `path` to `newName` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the file to be renamed. + /// - newName: New file name. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func rename(path: String, newName: String) async throws { + return try await self.threadPool.runIfActive() { + try Posix.rename(pathname: path, newName: newName) + } + } + + /// Removes the file at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the file to be removed. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func remove(path: String) async throws { + return try await self.threadPool.runIfActive() { + try Posix.remove(pathname: path) + } + } +#endif +} \ No newline at end of file diff --git a/Sources/NIOPosix/UnsafeTransfer.swift b/Sources/NIOPosix/UnsafeTransfer.swift new file mode 100644 index 0000000000..daef8dacc0 --- /dev/null +++ b/Sources/NIOPosix/UnsafeTransfer.swift @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2021-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// ``UnsafeTransfer`` can be used to make non-`Sendable` values `Sendable`. +/// As the name implies, the usage of this is unsafe because it disables the sendable checking of the compiler. +/// It can be used similar to `@unsafe Sendable` but for values instead of types. +@usableFromInline +struct UnsafeTransfer { + @usableFromInline + var wrappedValue: Wrapped + + @inlinable + init(_ wrappedValue: Wrapped) { + self.wrappedValue = wrappedValue + } +} + +extension UnsafeTransfer: @unchecked Sendable {} + +extension UnsafeTransfer: Equatable where Wrapped: Equatable {} +extension UnsafeTransfer: Hashable where Wrapped: Hashable {} + diff --git a/Tests/NIOPosixTests/NonBlockingFileIOTest.swift b/Tests/NIOPosixTests/NonBlockingFileIOTest.swift index 1e742b1e92..cd5d4d5f34 100644 --- a/Tests/NIOPosixTests/NonBlockingFileIOTest.swift +++ b/Tests/NIOPosixTests/NonBlockingFileIOTest.swift @@ -1041,4 +1041,609 @@ class NonBlockingFileIOTest: XCTestCase { } XCTAssertEqual(content.utf8.count, numCalls) } + +} + + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NonBlockingFileIOTest { + func testAsyncBasicFileIOWorks() async throws { + let content = "hello" + try await withTemporaryFile(content: content) { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 0, endIndex: 5) + var buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(content.utf8.count, buf.readableBytes) + XCTAssertEqual(content, buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncOffsetWorks() async throws { + let content = "hello" + try await withTemporaryFile(content: content) { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 3, endIndex: 5) + var buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(2, buf.readableBytes) + XCTAssertEqual("lo", buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncOffsetBeyondEOF() async throws { + let content = "hello" + try await withTemporaryFile(content: content) { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 3000, endIndex: 3001) + var buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(0, buf.readableBytes) + XCTAssertEqual("", buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncEmptyReadWorks() async throws { + try await withTemporaryFile { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 0, endIndex: 0) + let buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(0, buf.readableBytes) + } + } + + func testAsyncReadingShortWorks() async throws { + let content = "hello" + try await withTemporaryFile(content: "hello") { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 0, endIndex: 10) + var buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(content.utf8.count, buf.readableBytes) + XCTAssertEqual(content, buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncDoesNotBlockTheThreadOrEventLoop() async throws { + try await withPipe { readFH, writeFH in + async let byteBufferTask = try await self.fileIO.read( + fileHandle: readFH, + byteCount: 10, + allocator: self.allocator) + do { + try await self.threadPool.runIfActive { + try writeFH.withUnsafeFileDescriptor { writeFD in + _ = try Posix.write(descriptor: writeFD, pointer: "X", size: 1) + } + try writeFH.close() + } + var buf = try await byteBufferTask + XCTAssertEqual(1, buf.readableBytes) + XCTAssertEqual("X", buf.readString(length: buf.readableBytes)) + } + return [readFH] + } + } + + func testAsyncGettingErrorWhenEventLoopGroupIsShutdown() async throws { + try await self.threadPool.shutdownGracefully() + + try await withPipe { readFH, writeFH in + do { + _ = try await self.fileIO.read( + fileHandle: readFH, + byteCount: 1, + allocator: self.allocator) + XCTFail("testAsyncGettingErrorWhenEventLoopGroupIsShutdown: fileIO.read should throw an error") + } catch { + XCTAssertTrue(error is NIOThreadPoolError.ThreadPoolInactive) + } + return [readFH, writeFH] + } + } + + func testAsyncReadDoesNotReadShort() async throws { + try await withPipe { readFH, writeFH in + async let bufferTask = try await self.fileIO.read(fileHandle: readFH, + byteCount: 10, + allocator: self.allocator) + for i in 0..<10 { + try await Task.sleep(nanoseconds: 5_000_000) + try await self.threadPool.runIfActive { + try writeFH.withUnsafeFileDescriptor { writeFD in + _ = try Posix.write(descriptor: writeFD, pointer: "\(i)", size: 1) + } + } + } + try writeFH.close() + + var buf = try await bufferTask + XCTAssertEqual(10, buf.readableBytes) + XCTAssertEqual("0123456789", buf.readString(length: buf.readableBytes)) + return [readFH] + } + } + + func testAsyncReadMoreThanIntMaxBytesDoesntThrow() async throws { + try XCTSkipIf(MemoryLayout.size == MemoryLayout.size) + // here we try to read way more data back from the file than it contains but it serves the purpose + // even on a small file the OS will return EINVAL if you try to read > INT_MAX bytes + try await withTemporaryFile(content: "some-dummy-content", { (filehandle, path) -> Void in + let content = try await self.fileIO.read(fileHandle: filehandle, byteCount:Int(Int32.max)+10, allocator: .init()) + XCTAssertEqual(String(buffer: content), "some-dummy-content") + }) + } + + func testAsyncReadingFileSize() async throws { + try await withTemporaryFile(content: "0123456789") { (fileHandle, _) -> Void in + let size = try await self.fileIO.readFileSize(fileHandle: fileHandle) + XCTAssertEqual(size, 10) + } + } + + func testAsyncChangeFileSizeShrink() async throws { + try await withTemporaryFile(content: "0123456789") { (fileHandle, _) -> Void in + try await self.fileIO.changeFileSize(fileHandle: fileHandle, + size: 1) + let fileRegion = try FileRegion(fileHandle: fileHandle) + var buf = try await self.fileIO.read(fileRegion: fileRegion, + allocator: self.allocator) + XCTAssertEqual("0", buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncChangeFileSizeGrow() async throws { + try await withTemporaryFile(content: "0123456789") { (fileHandle, _) -> Void in + try await self.fileIO.changeFileSize(fileHandle: fileHandle, + size: 100) + let fileRegion = try FileRegion(fileHandle: fileHandle) + var buf = try await self.fileIO.read(fileRegion: fileRegion, + allocator: self.allocator) + let zeros = (1...90).map { _ in UInt8(0) } + guard let bytes = buf.readBytes(length: buf.readableBytes)?.suffix(from: 10) else { + XCTFail("readBytes(length:) should not be nil") + return + } + XCTAssertEqual(zeros, Array(bytes)) + } + } + + func testAsyncWriting() async throws { + try await withTemporaryFile(content: "") { (fileHandle, path) in + var buffer = self.allocator.buffer(capacity: 3) + buffer.writeStaticString("123") + + try await self.fileIO.write(fileHandle: fileHandle, + buffer: buffer) + let offset = try fileHandle.withUnsafeFileDescriptor { + try Posix.lseek(descriptor: $0, offset: 0, whence: SEEK_SET) + } + XCTAssertEqual(offset, 0) + + let readBuffer = try await self.fileIO.read(fileHandle: fileHandle, + byteCount: 3, + allocator: self.allocator) + XCTAssertEqual(readBuffer.getString(at: 0, length: 3), "123") + } + } + + func testAsyncWriteMultipleTimes() async throws { + try await withTemporaryFile(content: "AAA") { (fileHandle, path) in + var buffer = self.allocator.buffer(capacity: 3) + buffer.writeStaticString("xxx") + + for i in 0 ..< 3 { + buffer.writeString("\(i)") + try await self.fileIO.write(fileHandle: fileHandle, + buffer: buffer) + } + let offset = try fileHandle.withUnsafeFileDescriptor { + try Posix.lseek(descriptor: $0, offset: 0, whence: SEEK_SET) + } + XCTAssertEqual(offset, 0) + + let expectedOutput = "xxx0xxx01xxx012" + let readBuffer = try await self.fileIO.read(fileHandle: fileHandle, + byteCount: expectedOutput.utf8.count, + allocator: self.allocator) + XCTAssertEqual(expectedOutput, String(decoding: readBuffer.readableBytesView, as: Unicode.UTF8.self)) + } + } + + func testAsyncWritingWithOffset() async throws { + try await withTemporaryFile(content: "hello") { (fileHandle, _) -> Void in + var buffer = self.allocator.buffer(capacity: 3) + buffer.writeStaticString("123") + + try await self.fileIO.write(fileHandle: fileHandle, + toOffset: 1, + buffer: buffer) + let offset = try fileHandle.withUnsafeFileDescriptor { + try Posix.lseek(descriptor: $0, offset: 0, whence: SEEK_SET) + } + XCTAssertEqual(offset, 0) + + var readBuffer = try await self.fileIO.read(fileHandle: fileHandle, + byteCount: 5, + allocator: self.allocator) + XCTAssertEqual(5, readBuffer.readableBytes) + XCTAssertEqual("h123o", readBuffer.readString(length: readBuffer.readableBytes)) + } + } + + // This is undefined behavior and may cause different + // results on other platforms. Please add #if:s according + // to platform requirements. + func testAsyncWritingBeyondEOF() async throws { + try await withTemporaryFile(content: "hello") { (fileHandle, _) -> Void in + var buffer = self.allocator.buffer(capacity: 3) + buffer.writeStaticString("123") + + try await self.fileIO.write(fileHandle: fileHandle, + toOffset: 6, + buffer: buffer) + + let fileRegion = try FileRegion(fileHandle: fileHandle) + var buf = try await self.fileIO.read(fileRegion: fileRegion, + allocator: self.allocator) + XCTAssertEqual(9, buf.readableBytes) + XCTAssertEqual("hello", buf.readString(length: 5)) + XCTAssertEqual([ UInt8(0) ], buf.readBytes(length: 1)) + XCTAssertEqual("123", buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncFileOpenWorks() async throws { + let content = "123" + try await withTemporaryFile(content: content) { (fileHandle, path) -> Void in + try await self.fileIO.withFileRegion(path: path) { fr in + try fr.fileHandle.withUnsafeFileDescriptor { fd in + XCTAssertGreaterThanOrEqual(fd, 0) + } + XCTAssertTrue(fr.fileHandle.isOpen) + XCTAssertEqual(0, fr.readerIndex) + XCTAssertEqual(3, fr.endIndex) + } + } + } + + func testAsyncFileOpenWorksWithEmptyFile() async throws { + let content = "" + try await withTemporaryFile(content: content) { (fileHandle, path) -> Void in + try await self.fileIO.withFileRegion(path: path) { fr in + try fr.fileHandle.withUnsafeFileDescriptor { fd in + XCTAssertGreaterThanOrEqual(fd, 0) + } + XCTAssertTrue(fr.fileHandle.isOpen) + XCTAssertEqual(0, fr.readerIndex) + XCTAssertEqual(0, fr.endIndex) + } + } + } + + func testAsyncFileOpenFails() async throws { + do { + _ = try await self.fileIO.withFileRegion(path: "/dev/null/this/does/not/exist") { _ in} + XCTFail("should've thrown") + } catch let e as IOError where e.errnoCode == ENOTDIR { + // OK + } catch { + XCTFail("wrong error: \(error)") + } + } + + func testAsyncOpeningFilesForWriting() async throws { + try await withTemporaryDirectory { dir in + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: .write, + flags: .allowFileCreation() + ) { _ in } + } + } + + func testAsyncOpeningFilesForWritingFailsIfWeDontAllowItExplicitly() async throws { + do { + try await withTemporaryDirectory { dir in + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: .write, + flags: .default + ) { _ in } + } + XCTFail("testAsyncOpeningFilesForWritingFailsIfWeDontAllowItExplicitly: openFile should fail") + } catch { + XCTAssertEqual(ENOENT, (error as? IOError)?.errnoCode) + } + } + + func testAsyncOpeningFilesForWritingDoesNotAllowReading() async throws { + try await withTemporaryDirectory { dir in + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: .write, + flags: .allowFileCreation() + ) { fileHandle in + XCTAssertEqual(-1 /* read must fail */, + try fileHandle.withUnsafeFileDescriptor { fd -> ssize_t in + var data: UInt8 = 0 + return withUnsafeMutableBytes(of: &data) { ptr in + read(fd, ptr.baseAddress, ptr.count) + } + }) + } + } + } + + func testAsyncOpeningFilesForWritingAndReading() async throws { + try await withTemporaryDirectory { dir in + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .allowFileCreation() + ) { fileHandle in + XCTAssertEqual(0 /* read should read EOF */, + try fileHandle.withUnsafeFileDescriptor { fd -> ssize_t in + var data: UInt8 = 0 + return withUnsafeMutableBytes(of: &data) { ptr in + read(fd, ptr.baseAddress, ptr.count) + } + }) + } + } + } + + func testAsyncOpeningFilesForWritingDoesNotImplyTruncation() async throws { + try await withTemporaryDirectory { dir in + // open 1 + write + do { + try await self.fileIO.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .allowFileCreation() + ) { fileHandle in + try fileHandle.withUnsafeFileDescriptor { fd in + var data = UInt8(ascii: "X") + XCTAssertEqual(IOResult.processed(1), + try withUnsafeBytes(of: &data) { ptr in + try Posix.write(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) + }) + } + } + } + + // open 2 + write again + read + do { + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .default + ) { fileHandle in + try fileHandle.withUnsafeFileDescriptor { fd in + try Posix.lseek(descriptor: fd, offset: 0, whence: SEEK_END) + var data = UInt8(ascii: "Y") + XCTAssertEqual(IOResult.processed(1), + try withUnsafeBytes(of: &data) { ptr in + try Posix.write(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) + }) + } + XCTAssertEqual(2 /* both bytes */, + try fileHandle.withUnsafeFileDescriptor { fd -> ssize_t in + var data: UInt16 = 0 + try Posix.lseek(descriptor: fd, offset: 0, whence: SEEK_SET) + let readReturn = withUnsafeMutableBytes(of: &data) { ptr in + read(fd, ptr.baseAddress, ptr.count) + } + XCTAssertEqual(UInt16(bigEndian: (UInt16(UInt8(ascii: "X")) << 8) | UInt16(UInt8(ascii: "Y"))), + data) + return readReturn + }) + } + } + } + } + + func testAsyncOpeningFilesForWritingCanUseTruncation() async throws { + try await withTemporaryDirectory { dir in + // open 1 + write + do { + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .allowFileCreation() + ) { fileHandle in + try fileHandle.withUnsafeFileDescriptor { fd in + var data = UInt8(ascii: "X") + XCTAssertEqual(IOResult.processed(1), + try withUnsafeBytes(of: &data) { ptr in + try Posix.write(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) + }) + } + } + } + // open 2 (with truncation) + write again + read + do { + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .posix(flags: O_TRUNC, mode: 0) + ) { fileHandle in + try fileHandle.withUnsafeFileDescriptor { fd in + try Posix.lseek(descriptor: fd, offset: 0, whence: SEEK_END) + var data = UInt8(ascii: "Y") + XCTAssertEqual(IOResult.processed(1), + try withUnsafeBytes(of: &data) { ptr in + try Posix.write(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) + }) + } + XCTAssertEqual(1 /* read should read just one byte because we truncated the file */, + try fileHandle.withUnsafeFileDescriptor { fd -> ssize_t in + var data: UInt16 = 0 + try Posix.lseek(descriptor: fd, offset: 0, whence: SEEK_SET) + let readReturn = withUnsafeMutableBytes(of: &data) { ptr in + read(fd, ptr.baseAddress, ptr.count) + } + XCTAssertEqual(UInt16(bigEndian: UInt16(UInt8(ascii: "Y")) << 8), data) + return readReturn + }) + } + } + } + } + + func testAsyncReadFromOffset() async throws { + try await withTemporaryFile(content: "hello world") { (fileHandle, path) in + let buffer = try await self.fileIO.read(fileHandle: fileHandle, + fromOffset: 6, + byteCount: 5, + allocator: ByteBufferAllocator()) + let string = String(decoding: buffer.readableBytesView, as: Unicode.UTF8.self) + XCTAssertEqual("world", string) + } + } + + func testAsyncReadFromOffsetAfterEOFDeliversExactlyOneChunk() async throws { + try await withTemporaryFile(content: "hello world") { (fileHandle, path) in + let readableBytes = try await self.fileIO.read( + fileHandle: fileHandle, + fromOffset: 100, + byteCount: 5, + allocator: .init() + ).readableBytes + XCTAssertEqual(0,readableBytes) + } + } + + func testAsyncReadFromEOFDeliversExactlyOneChunk() async throws { + try await withTemporaryFile(content: "") { (fileHandle, path) in + let readableBytes = try await self.fileIO.read( + fileHandle: fileHandle, + byteCount: 5, + allocator: .init() + ).readableBytes + XCTAssertEqual(0, readableBytes) + } + } + + func testAsyncThrowsErrorOnUnstartedPool() async throws { + await withTemporaryFile(content: "hello, world") { fileHandle, path in + let threadPool = NIOThreadPool(numberOfThreads: 1) + let fileIO = NonBlockingFileIO(threadPool: threadPool) + do { + try await fileIO.withFileRegion(path: path) { _ in } + XCTFail("testAsyncThrowsErrorOnUnstartedPool: openFile should throw an error") + } catch { + } + } + } + + func testAsyncLStat() async throws { + try await withTemporaryFile(content: "hello, world") { _, path in + let stat = try await self.fileIO.lstat(path: path) + XCTAssertEqual(12, stat.st_size) + XCTAssertEqual(S_IFREG, S_IFMT & stat.st_mode) + } + + try await withTemporaryDirectory { path in + let stat = try await self.fileIO.lstat(path: path) + XCTAssertEqual(S_IFDIR, S_IFMT & stat.st_mode) + } + } + + func testAsyncSymlink() async throws { + try await withTemporaryFile(content: "hello, world") { _, path in + let symlink = "\(path).symlink" + try await self.fileIO.symlink(path: symlink, to: path) + + let link = try await self.fileIO.readlink(path: symlink) + XCTAssertEqual(path, link) + let stat = try await self.fileIO.lstat(path: symlink) + XCTAssertEqual(S_IFLNK, S_IFMT & stat.st_mode) + + try await self.fileIO.unlink(path: symlink) + do { + _ = try await self.fileIO.lstat(path: symlink) + XCTFail("testAsyncSymlink: lstat should throw an error after unlink") + } catch { + XCTAssertEqual(ENOENT, (error as? IOError)?.errnoCode) + } + } + } + + func testAsyncCreateDirectory() async throws { + try await withTemporaryDirectory { path in + let dir = "\(path)/f1/f2///f3" + try await self.fileIO.createDirectory(path: dir, withIntermediateDirectories: true, mode: S_IRWXU) + + let stat = try await self.fileIO.lstat(path: dir) + XCTAssertEqual(S_IFDIR, S_IFMT & stat.st_mode) + + try await self.fileIO.createDirectory(path: "\(dir)/f4", withIntermediateDirectories: false, mode: S_IRWXU) + + let stat2 = try await self.fileIO.lstat(path: dir) + XCTAssertEqual(S_IFDIR, S_IFMT & stat2.st_mode) + + let dir3 = "\(path)/f4/." + try await self.fileIO.createDirectory(path: dir3, withIntermediateDirectories: true, mode: S_IRWXU) + } + } + + func testAsyncListDirectory() async throws { + try await withTemporaryDirectory { path in + let file = "\(path)/file" + try await self.fileIO.withFileHandle( + path: file, + mode: .write, + flags: .allowFileCreation() + ) { handle in + let list = try await self.fileIO.listDirectory(path: path) + XCTAssertEqual([".", "..", "file"], list.sorted(by: { $0.name < $1.name }).map(\.name)) + } + } + } + + func testAsyncRename() async throws { + try await withTemporaryDirectory { path in + let file = "\(path)/file" + try await self.fileIO.withFileHandle( + path: file, + mode: .write, + flags: .allowFileCreation() + ) { handle in + let stat = try await self.fileIO.lstat(path: file) + XCTAssertEqual(S_IFREG, S_IFMT & stat.st_mode) + + let new = "\(path).new" + try await self.fileIO.rename(path: file, newName: new) + + let stat2 = try await self.fileIO.lstat(path: new) + XCTAssertEqual(S_IFREG, S_IFMT & stat2.st_mode) + + do { + _ = try await self.fileIO.lstat(path: file) + XCTFail("testAsyncRename: lstat should throw an error after file renamed") + } catch { + XCTAssertEqual(ENOENT, (error as? IOError)?.errnoCode) + } + } + } + } + + func testAsyncRemove() async throws { + try await withTemporaryDirectory { path in + let file = "\(path)/file" + try await self.fileIO.withFileHandle( + path: file, + mode: .write, + flags: .allowFileCreation() + ) { handle in + let stat = try await self.fileIO.lstat(path: file) + XCTAssertEqual(S_IFREG, S_IFMT & stat.st_mode) + + try await self.fileIO.remove(path: file) + do { + _ = try await self.fileIO.lstat(path: file) + XCTFail("testAsyncRemove: lstat should throw an error after file removed") + } catch { + XCTAssertEqual(ENOENT, (error as? IOError)?.errnoCode) + } + } + } + } } diff --git a/Tests/NIOPosixTests/TestUtils.swift b/Tests/NIOPosixTests/TestUtils.swift index 273b6b145c..910744bc64 100644 --- a/Tests/NIOPosixTests/TestUtils.swift +++ b/Tests/NIOPosixTests/TestUtils.swift @@ -60,6 +60,28 @@ func withPipe(_ body: (NIOCore.NIOFileHandle, NIOCore.NIOFileHandle) throws -> [ } } +func withPipe(_ body: (NIOCore.NIOFileHandle, NIOCore.NIOFileHandle) async throws -> [NIOCore.NIOFileHandle]) async throws { + var fds: [Int32] = [-1, -1] + fds.withUnsafeMutableBufferPointer { ptr in + XCTAssertEqual(0, pipe(ptr.baseAddress!)) + } + let readFH = NIOFileHandle(descriptor: fds[0]) + let writeFH = NIOFileHandle(descriptor: fds[1]) + var toClose: [NIOFileHandle] = [readFH, writeFH] + var error: Error? = nil + do { + toClose = try await body(readFH, writeFH) + } catch let err { + error = err + } + try toClose.forEach { fh in + XCTAssertNoThrow(try fh.close()) + } + if let error = error { + throw error + } +} + func withTemporaryDirectory(_ body: (String) throws -> T) rethrows -> T { let dir = createTemporaryDirectory() defer { @@ -68,6 +90,14 @@ func withTemporaryDirectory(_ body: (String) throws -> T) rethrows -> T { return try body(dir) } +func withTemporaryDirectory(_ body: (String) async throws -> T) async rethrows -> T { + let dir = createTemporaryDirectory() + defer { + try? FileManager.default.removeItem(atPath: dir) + } + return try await body(dir) +} + /// This function creates a filename that can be used for a temporary UNIX domain socket path. /// /// If the temporary directory is too long to store a UNIX domain socket path, it will `chdir` into the temporary @@ -131,6 +161,34 @@ func withTemporaryFile(content: String? = nil, _ body: (NIOCore.NIOFileHandle } return try body(fileHandle, path) } + +func withTemporaryFile(content: String? = nil, _ body: @escaping @Sendable (NIOCore.NIOFileHandle, String) async throws -> T) async rethrows -> T { + let (fd, path) = openTemporaryFile() + let fileHandle = NIOFileHandle(descriptor: fd) + defer { + XCTAssertNoThrow(try fileHandle.close()) + XCTAssertEqual(0, unlink(path)) + } + if let content = content { + try Array(content.utf8).withUnsafeBufferPointer { ptr in + var toWrite = ptr.count + var start = ptr.baseAddress! + while toWrite > 0 { + let res = try Posix.write(descriptor: fd, pointer: start, size: toWrite) + switch res { + case .processed(let written): + toWrite -= written + start = start + written + case .wouldBlock: + XCTFail("unexpectedly got .wouldBlock from a file") + continue + } + } + XCTAssertEqual(0, lseek(fd, 0, SEEK_SET)) + } + } + return try await body(fileHandle, path) +} var temporaryDirectory: String { get { #if targetEnvironment(simulator) diff --git a/docker/docker-compose.2204.510.yaml b/docker/docker-compose.2204.510.yaml index 6c4362987c..a9089f4982 100644 --- a/docker/docker-compose.2204.510.yaml +++ b/docker/docker-compose.2204.510.yaml @@ -62,7 +62,7 @@ services: - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=343 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 diff --git a/docker/docker-compose.2204.57.yaml b/docker/docker-compose.2204.57.yaml index eccb233c9c..d6d16f9c34 100644 --- a/docker/docker-compose.2204.57.yaml +++ b/docker/docker-compose.2204.57.yaml @@ -63,7 +63,7 @@ services: - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=341 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 diff --git a/docker/docker-compose.2204.58.yaml b/docker/docker-compose.2204.58.yaml index c7f4059b09..827f76771e 100644 --- a/docker/docker-compose.2204.58.yaml +++ b/docker/docker-compose.2204.58.yaml @@ -63,7 +63,7 @@ services: - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=341 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 diff --git a/docker/docker-compose.2204.59.yaml b/docker/docker-compose.2204.59.yaml index c1d94f76b5..c348433d9e 100644 --- a/docker/docker-compose.2204.59.yaml +++ b/docker/docker-compose.2204.59.yaml @@ -63,7 +63,7 @@ services: - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=350 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 diff --git a/docker/docker-compose.2204.main.yaml b/docker/docker-compose.2204.main.yaml index ad3a7fb6ad..5d08b84b1f 100644 --- a/docker/docker-compose.2204.main.yaml +++ b/docker/docker-compose.2204.main.yaml @@ -62,7 +62,7 @@ services: - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=343 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 From 703ee9172b6a28904e110668cdef6054130a7486 Mon Sep 17 00:00:00 2001 From: Si Beaumont Date: Fri, 24 Nov 2023 14:57:33 +0000 Subject: [PATCH 51/64] Remove precondition on result of IOCTL_VM_SOCKETS_GET_LOCAL_CID (#2588) --- Sources/NIOPosix/VsockAddress.swift | 3 +-- Tests/NIOPosixTests/TestUtils.swift | 8 ++++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/Sources/NIOPosix/VsockAddress.swift b/Sources/NIOPosix/VsockAddress.swift index 87e93e91b2..8e15e893e2 100644 --- a/Sources/NIOPosix/VsockAddress.swift +++ b/Sources/NIOPosix/VsockAddress.swift @@ -220,12 +220,11 @@ extension VsockAddress.ContextID { let fd = socketFD #elseif os(Linux) || os(Android) let request = CNIOLinux_IOCTL_VM_SOCKETS_GET_LOCAL_CID - let fd = try! Posix.open(file: "/dev/vsock", oFlag: O_RDONLY | O_CLOEXEC) + let fd = try Posix.open(file: "/dev/vsock", oFlag: O_RDONLY | O_CLOEXEC) defer { try! Posix.close(descriptor: fd) } #endif var cid = Self.any.rawValue try Posix.ioctl(fd: fd, request: request, ptr: &cid) - precondition(cid != Self.any.rawValue) return Self(rawValue: cid) } } diff --git a/Tests/NIOPosixTests/TestUtils.swift b/Tests/NIOPosixTests/TestUtils.swift index 910744bc64..461273194e 100644 --- a/Tests/NIOPosixTests/TestUtils.swift +++ b/Tests/NIOPosixTests/TestUtils.swift @@ -31,6 +31,14 @@ extension System { #if canImport(Darwin) || os(Linux) || os(Android) guard let socket = try? Socket(protocolFamily: .vsock, type: .stream) else { return false } XCTAssertNoThrow(try socket.close()) +#if !canImport(Darwin) + do { + let fd = try Posix.open(file: "/dev/vsock", oFlag: O_RDONLY | O_CLOEXEC) + try Posix.close(descriptor: fd) + } catch { + return false + } +#endif return true #else return false From 80249029d094f61328fc128508a18d117c183c00 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Fri, 24 Nov 2023 15:53:55 +0000 Subject: [PATCH 52/64] Add missing availability guards in tests (#2596) # Motivation We were missing a few availability guards in our tests which causes compile time errors when building for older Darwin SDKs. # Modification This PR adds the missing availability guards. --- .../AsyncChannel/AsyncChannelInboundStreamTests.swift | 1 + .../AsyncChannel/AsyncChannelOutboundWriterTests.swift | 1 + Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift | 2 ++ Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift | 3 +++ .../AsyncSequences/NIOThrowingAsyncSequenceTests.swift | 1 + Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift | 2 ++ Tests/NIOPosixTests/SerialExecutorTests.swift | 1 + 7 files changed, 11 insertions(+) diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift index f12df126a5..94e17ec593 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift @@ -15,6 +15,7 @@ @testable import NIOCore import XCTest +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelInboundStreamTests: XCTestCase { func testTestingStream() async throws { let (stream, source) = NIOAsyncChannelInboundStream.makeTestingStream() diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift index 49d070aab6..68446e2a34 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift @@ -15,6 +15,7 @@ @testable import NIOCore import XCTest +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelOutboundWriterTests: XCTestCase { func testTestingWriter() async throws { let (writer, sink) = NIOAsyncChannelOutboundWriter.makeTestingWriter() diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 7fb92d1a82..e7d320bf01 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -425,6 +425,7 @@ private final class CloseSuppressor: ChannelOutboundHandler, RemovableChannelHan } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncTestingChannel { fileprivate func closeIgnoringSuppression() async throws { try await self.pipeline.context(handlerType: CloseSuppressor.self).flatMap { @@ -455,6 +456,7 @@ private enum TestError: Error { case bang } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension Array { fileprivate init(_ sequence: AS) async throws where AS.Element == Self.Element { self = [] diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift index 7f96495a1a..40325ce09b 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift @@ -15,6 +15,7 @@ import NIOCore import XCTest +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class MockNIOElementStreamBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategy, @unchecked Sendable { enum Event { case didYield @@ -48,6 +49,7 @@ final class MockNIOElementStreamBackPressureStrategy: NIOAsyncSequenceProducerBa } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class MockNIOBackPressuredStreamSourceDelegate: NIOAsyncSequenceProducerDelegate, @unchecked Sendable { enum Event { case produceMore @@ -673,6 +675,7 @@ fileprivate func XCTAssertEqualWithoutAutoclosure( XCTAssertEqual(expression1, expression2, message(), file: file, line: line) } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension AsyncSequence { /// Collect all elements in the sequence into an array. fileprivate func collect() async rethrows -> [Element] { diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift index e63ccb8391..95978bc9cf 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift @@ -850,6 +850,7 @@ fileprivate func XCTAssertEqualWithoutAutoclosure( XCTAssertEqual(expression1, expression2, message(), file: file, line: line) } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension AsyncSequence { /// Collect all elements in the sequence into an array. fileprivate func collect() async rethrows -> [Element] { diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 099074c71b..1fb1bdd810 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -191,6 +191,7 @@ private final class AddressedEnvelopingHandler: ChannelDuplexHandler { } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelBootstrapTests: XCTestCase { enum NegotiationResult { case string(NIOAsyncChannel) @@ -1358,6 +1359,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension AsyncStream { fileprivate static func makeStream( of elementType: Element.Type = Element.self, diff --git a/Tests/NIOPosixTests/SerialExecutorTests.swift b/Tests/NIOPosixTests/SerialExecutorTests.swift index 84f5249222..56bc882bec 100644 --- a/Tests/NIOPosixTests/SerialExecutorTests.swift +++ b/Tests/NIOPosixTests/SerialExecutorTests.swift @@ -61,6 +61,7 @@ final class SerialExecutorTests: XCTestCase { try await self._testBasicExecutorFitsOnEventLoop(loop1: loops[0], loop2: loops[1]) } + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) func testBasicExecutorFitsOnEventLoop_AsyncTestingEventLoop() async throws { let loop1 = NIOAsyncTestingEventLoop() let loop2 = NIOAsyncTestingEventLoop() From 4223cb37762cb1bdbd0a05f56fa79c9ac0eee98b Mon Sep 17 00:00:00 2001 From: finagolfin Date: Mon, 27 Nov 2023 14:35:30 +0530 Subject: [PATCH 53/64] Build for Android with NDK 26, by accounting for the new nullability annotations (#2600) Motivation: Fix build for the latest LTS NDK 26 Modifications: - Update C declarations - Add force unwraps where needed Result: Everything works on Android with NDK 26b --- Sources/NIOCore/BSDSocketAPI.swift | 5 +++++ Sources/NIOCore/Interfaces.swift | 4 ++-- Sources/NIOCore/SystemCallHelpers.swift | 9 +++++++-- Sources/NIOPosix/System.swift | 23 +++++++++++++---------- Sources/NIOPosix/ThreadPosix.swift | 8 ++++++++ 5 files changed, 35 insertions(+), 14 deletions(-) diff --git a/Sources/NIOCore/BSDSocketAPI.swift b/Sources/NIOCore/BSDSocketAPI.swift index c370644eca..07ba7bebcb 100644 --- a/Sources/NIOCore/BSDSocketAPI.swift +++ b/Sources/NIOCore/BSDSocketAPI.swift @@ -68,8 +68,13 @@ import Musl #endif import CNIOLinux +#if os(Android) +private let sysInet_ntop: @convention(c) (CInt, UnsafeRawPointer, UnsafeMutablePointer, socklen_t) -> UnsafePointer? = inet_ntop +private let sysInet_pton: @convention(c) (CInt, UnsafePointer, UnsafeMutableRawPointer) -> CInt = inet_pton +#else private let sysInet_ntop: @convention(c) (CInt, UnsafeRawPointer?, UnsafeMutablePointer?, socklen_t) -> UnsafePointer? = inet_ntop private let sysInet_pton: @convention(c) (CInt, UnsafePointer?, UnsafeMutableRawPointer?) -> CInt = inet_pton +#endif #elseif canImport(Darwin) import Darwin diff --git a/Sources/NIOCore/Interfaces.swift b/Sources/NIOCore/Interfaces.swift index 1c656db020..5fcbbf83ab 100644 --- a/Sources/NIOCore/Interfaces.swift +++ b/Sources/NIOCore/Interfaces.swift @@ -123,7 +123,7 @@ public final class NIONetworkInterface { } #else internal init?(_ caddr: ifaddrs) { - self.name = String(cString: caddr.ifa_name) + self.name = String(cString: caddr.ifa_name!) guard caddr.ifa_addr != nil else { return nil @@ -414,7 +414,7 @@ extension NIONetworkDevice { } #else internal init?(_ caddr: ifaddrs) { - self.name = String(cString: caddr.ifa_name) + self.name = String(cString: caddr.ifa_name!) self.address = caddr.ifa_addr.flatMap { $0.convert() } self.netmask = caddr.ifa_netmask.flatMap { $0.convert() } diff --git a/Sources/NIOCore/SystemCallHelpers.swift b/Sources/NIOCore/SystemCallHelpers.swift index dc9b457a19..b74092a139 100644 --- a/Sources/NIOCore/SystemCallHelpers.swift +++ b/Sources/NIOCore/SystemCallHelpers.swift @@ -43,11 +43,16 @@ private let sysOpenWithMode: @convention(c) (UnsafePointer, CInt, NIOPOSI private let sysLseek: @convention(c) (CInt, off_t, CInt) -> off_t = lseek private let sysRead: @convention(c) (CInt, UnsafeMutableRawPointer?, size_t) -> size_t = read #endif -private let sysIfNameToIndex: @convention(c) (UnsafePointer?) -> CUnsignedInt = if_nametoindex +#if os(Android) +private let sysIfNameToIndex: @convention(c) (UnsafePointer) -> CUnsignedInt = if_nametoindex +private let sysGetifaddrs: @convention(c) (UnsafeMutablePointer?>) -> CInt = getifaddrs +#else +private let sysIfNameToIndex: @convention(c) (UnsafePointer?) -> CUnsignedInt = if_nametoindex #if !os(Windows) private let sysGetifaddrs: @convention(c) (UnsafeMutablePointer?>?) -> CInt = getifaddrs #endif +#endif private func isUnacceptableErrno(_ code: Int32) -> Bool { switch code { @@ -173,7 +178,7 @@ enum SystemCalls { @inline(never) internal static func if_nametoindex(_ name: UnsafePointer?) throws -> CUnsignedInt { return try syscall(blocking: false) { - sysIfNameToIndex(name) + sysIfNameToIndex(name!) }.result } diff --git a/Sources/NIOPosix/System.swift b/Sources/NIOPosix/System.swift index 025e60f0ed..fc6c0584dc 100644 --- a/Sources/NIOPosix/System.swift +++ b/Sources/NIOPosix/System.swift @@ -90,7 +90,7 @@ func sysRecvFrom_wrapper(sockfd: CInt, buf: UnsafeMutableRawPointer, len: CLong, return recvfrom(sockfd, buf, len, flags, src_addr, addrlen) // src_addr is 'UnsafeMutablePointer', but it need to be 'UnsafePointer' } func sysWritev_wrapper(fd: CInt, iov: UnsafePointer?, iovcnt: CInt) -> CLong { - return CLong(writev(fd, iov, iovcnt)) // cast 'Int32' to 'CLong' + return CLong(writev(fd, iov!, iovcnt)) // cast 'Int32' to 'CLong' } private let sysWritev = sysWritev_wrapper #elseif !os(Windows) @@ -106,12 +106,16 @@ private let sysGetpeername: @convention(c) (CInt, UnsafeMutablePointer private let sysGetsockname: @convention(c) (CInt, UnsafeMutablePointer?, UnsafeMutablePointer?) -> CInt = getsockname #endif +#if os(Android) +private let sysIfNameToIndex: @convention(c) (UnsafePointer) -> CUnsignedInt = if_nametoindex +#else private let sysIfNameToIndex: @convention(c) (UnsafePointer?) -> CUnsignedInt = if_nametoindex +#endif #if !os(Windows) private let sysSocketpair: @convention(c) (CInt, CInt, CInt, UnsafeMutablePointer?) -> CInt = socketpair #endif -#if os(Linux) && !canImport(Musl) +#if (os(Linux) && !canImport(Musl)) || os(Android) private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer) -> CInt = fstat private let sysStat: @convention(c) (UnsafePointer, UnsafeMutablePointer) -> CInt = stat private let sysLstat: @convention(c) (UnsafePointer, UnsafeMutablePointer) -> CInt = lstat @@ -122,9 +126,14 @@ private let sysMkdir: @convention(c) (UnsafePointer, mode_t) -> CInt = mk private let sysOpendir: @convention(c) (UnsafePointer) -> OpaquePointer? = opendir private let sysReaddir: @convention(c) (OpaquePointer) -> UnsafeMutablePointer? = readdir private let sysClosedir: @convention(c) (OpaquePointer) -> CInt = closedir +#if os(Android) +private let sysRename: @convention(c) (UnsafePointer, UnsafePointer) -> CInt = rename +private let sysRemove: @convention(c) (UnsafePointer) -> CInt = remove +#else private let sysRename: @convention(c) (UnsafePointer?, UnsafePointer?) -> CInt = rename private let sysRemove: @convention(c) (UnsafePointer?) -> CInt = remove -#elseif canImport(Darwin) || os(Android) +#endif +#elseif canImport(Darwin) private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer?) -> CInt = fstat private let sysStat: @convention(c) (UnsafePointer?, UnsafeMutablePointer?) -> CInt = stat private let sysLstat: @convention(c) (UnsafePointer?, UnsafeMutablePointer?) -> CInt = lstat @@ -132,16 +141,10 @@ private let sysSymlink: @convention(c) (UnsafePointer?, UnsafePointer?, UnsafeMutablePointer?, Int) -> CLong = readlink private let sysUnlink: @convention(c) (UnsafePointer?) -> CInt = unlink private let sysMkdir: @convention(c) (UnsafePointer?, mode_t) -> CInt = mkdir -#if os(Android) -private let sysOpendir: @convention(c) (UnsafePointer?) -> OpaquePointer? = opendir -private let sysReaddir: @convention(c) (OpaquePointer?) -> UnsafeMutablePointer? = readdir -private let sysClosedir: @convention(c) (OpaquePointer?) -> CInt = closedir -#else private let sysMkpath: @convention(c) (UnsafePointer?, mode_t) -> CInt = mkpath_np private let sysOpendir: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = opendir private let sysReaddir: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutablePointer? = readdir private let sysClosedir: @convention(c) (UnsafeMutablePointer?) -> CInt = closedir -#endif private let sysRename: @convention(c) (UnsafePointer?, UnsafePointer?) -> CInt = rename private let sysRemove: @convention(c) (UnsafePointer?) -> CInt = remove #endif @@ -732,7 +735,7 @@ internal enum Posix { @inline(never) internal static func if_nametoindex(_ name: UnsafePointer?) throws -> CUnsignedInt { return try syscall(blocking: false) { - sysIfNameToIndex(name) + sysIfNameToIndex(name!) }.result } diff --git a/Sources/NIOPosix/ThreadPosix.swift b/Sources/NIOPosix/ThreadPosix.swift index b6e0ed4a30..852f08f634 100644 --- a/Sources/NIOPosix/ThreadPosix.swift +++ b/Sources/NIOPosix/ThreadPosix.swift @@ -19,7 +19,11 @@ import CNIOLinux private let sys_pthread_getname_np = CNIOLinux_pthread_getname_np private let sys_pthread_setname_np = CNIOLinux_pthread_setname_np +#if os(Android) +private typealias ThreadDestructor = @convention(c) (UnsafeMutableRawPointer) -> UnsafeMutableRawPointer +#else private typealias ThreadDestructor = @convention(c) (UnsafeMutableRawPointer?) -> UnsafeMutableRawPointer? +#endif #elseif canImport(Darwin) private let sys_pthread_getname_np = pthread_getname_np // Emulate the same method signature as pthread_setname_np on Linux. @@ -111,7 +115,11 @@ enum ThreadOpsPosix: ThreadOps { body(NIOThread(handle: hThread, desiredName: name)) + #if os(Android) + return UnsafeMutableRawPointer(bitPattern: 0xdeadbee)! + #else return nil + #endif }, args: argv0) precondition(res == 0, "Unable to create thread: \(res)") From 505bda969bf69d12b0f8de8a66b62d5c1001f514 Mon Sep 17 00:00:00 2001 From: Ryu <87907656+Ryu0118@users.noreply.github.com> Date: Wed, 29 Nov 2023 20:41:30 +0900 Subject: [PATCH 54/64] Update README.md (#2602) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c256db1ad0..725f5cd30c 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ Protocol | Client | Server | Repository | Module | Comment --- | --- | --- | --- | --- | --- HTTP | ✅| ❌ | [swift-server/async-http-client](https://github.com/swift-server/async-http-client) | `AsyncHTTPClient` | SSWG community project gRPC | ✅| ✅ | [grpc/grpc-swift](https://github.com/grpc/grpc-swift) | `GRPC` | also offers a low-level API; SSWG community project -APNS | ✅ | ❌ | [kylebrowning/APNSwift](https://github.com/kylebrowning/APNSwift) | `APNSwift` | SSWG community project +APNS | ✅ | ❌ | [swift-server-community/APNSwift](https://github.com/swift-server-community/APNSwift) | `APNSwift` | SSWG community project PostgreSQL | ✅ | ❌ | [vapor/postgres-nio](https://github.com/vapor/postgres-nio) | `PostgresNIO` | SSWG community project Redis | ✅ | ❌ | [swift-server/RediStack](https://github.com/swift-server/RediStack) | `RediStack` | SSWG community project From fc6e3c0eefb28adf641531180b81aaf41b02ed20 Mon Sep 17 00:00:00 2001 From: Alastair Houghton Date: Wed, 29 Nov 2023 17:27:13 +0000 Subject: [PATCH 55/64] Changes to support building with Musl (#2595) * Changes to support building with Musl Define `_GNU_SOURCE` in the `Package.swift` rather than in `shim.c` (this is required because `_GNU_SOURCE` affects modular headers). Add an import for Musl to IO.swift. Add code to disable `SIGPIPE` to `SocketProtocols.swift`. Remove types from a pile of functions in `System.swift`; Swift will use the correct type automatically (except in cases where there are multiple versions of a function, e.g. `ioctl()`, in which case we need to be explicit which one we mean). * Fix some test failures caused by failing to define `_GNU_SOURCE`. A couple of the integration tests grab code and build it outside of the normal `Package.swift`, so they needed fixing to define `_GNU_SOURCE` themselves. --- .../tests_02_syscall_wrappers/defines.sh | 5 +- .../tests_05_assertions/defines.sh | 5 +- Package.swift | 5 +- Sources/CNIOLinux/shim.c | 5 +- Sources/NIOCore/ChannelOption.swift | 2 +- Sources/NIOCore/IO.swift | 4 +- Sources/NIOPosix/BSDSocketAPICommon.swift | 4 +- Sources/NIOPosix/SocketProtocols.swift | 6 ++ Sources/NIOPosix/System.swift | 58 ++++++------------- 9 files changed, 46 insertions(+), 48 deletions(-) diff --git a/IntegrationTests/tests_02_syscall_wrappers/defines.sh b/IntegrationTests/tests_02_syscall_wrappers/defines.sh index c289646759..c97d1aca84 100644 --- a/IntegrationTests/tests_02_syscall_wrappers/defines.sh +++ b/IntegrationTests/tests_02_syscall_wrappers/defines.sh @@ -36,7 +36,10 @@ let package = Package( dependencies: ["CNIOLinux", "CNIODarwin", "NIOCore"]), .target( name: "CNIOLinux", - dependencies: []), + dependencies: [], + cSettings: [ + .define("_GNU_SOURCE") + ]), .target( name: "CNIODarwin", dependencies: []), diff --git a/IntegrationTests/tests_05_assertions/defines.sh b/IntegrationTests/tests_05_assertions/defines.sh index 22ffb942d7..0c2ed428d1 100644 --- a/IntegrationTests/tests_05_assertions/defines.sh +++ b/IntegrationTests/tests_05_assertions/defines.sh @@ -31,7 +31,10 @@ let package = Package( dependencies: ["CNIOLinux", "CNIODarwin", "NIOCore"]), .target( name: "CNIOLinux", - dependencies: []), + dependencies: [], + cSettings: [ + .define("_GNU_SOURCE") + ]), .target( name: "CNIODarwin", dependencies: []), diff --git a/Package.swift b/Package.swift index f7c7adb4c6..68d37ba072 100644 --- a/Package.swift +++ b/Package.swift @@ -115,7 +115,10 @@ let package = Package( ), .target( name: "CNIOLinux", - dependencies: [] + dependencies: [], + cSettings: [ + .define("_GNU_SOURCE"), + ] ), .target( name: "CNIODarwin", diff --git a/Sources/CNIOLinux/shim.c b/Sources/CNIOLinux/shim.c index 236da99ff4..d56c3bba69 100644 --- a/Sources/CNIOLinux/shim.c +++ b/Sources/CNIOLinux/shim.c @@ -18,7 +18,10 @@ void CNIOLinux_i_do_nothing_just_working_around_a_darwin_toolchain_bug(void) {} #ifdef __linux__ -#define _GNU_SOURCE +#ifndef _GNU_SOURCE +#error You must define _GNU_SOURCE +#endif + #include #include #include diff --git a/Sources/NIOCore/ChannelOption.swift b/Sources/NIOCore/ChannelOption.swift index 5f1f2f6adc..074561544f 100644 --- a/Sources/NIOCore/ChannelOption.swift +++ b/Sources/NIOCore/ChannelOption.swift @@ -19,7 +19,7 @@ public protocol ChannelOption: Equatable, _NIOPreconcurrencySendable { } public typealias SocketOptionName = Int32 -#if os(Linux) || os(Android) +#if (os(Linux) || os(Android)) && !canImport(Musl) public typealias SocketOptionLevel = Int public typealias SocketOptionValue = Int #else diff --git a/Sources/NIOCore/IO.swift b/Sources/NIOCore/IO.swift index 4377154317..3fdd7a10f9 100644 --- a/Sources/NIOCore/IO.swift +++ b/Sources/NIOCore/IO.swift @@ -28,8 +28,10 @@ import typealias WinSDK.WORD internal func MAKELANGID(_ p: WORD, _ s: WORD) -> DWORD { return DWORD((s << 10) | p) } -#elseif os(Linux) || os(Android) +#elseif canImport(Glibc) import Glibc +#elseif canImport(Musl) +import Musl #elseif canImport(Darwin) import Darwin #else diff --git a/Sources/NIOPosix/BSDSocketAPICommon.swift b/Sources/NIOPosix/BSDSocketAPICommon.swift index 539136cc26..721143a990 100644 --- a/Sources/NIOPosix/BSDSocketAPICommon.swift +++ b/Sources/NIOPosix/BSDSocketAPICommon.swift @@ -83,8 +83,8 @@ extension NIOBSDSocket.SocketType { internal static let stream: NIOBSDSocket.SocketType = NIOBSDSocket.SocketType(rawValue: SOCK_STREAM) #endif - - #if os(Linux) + + #if os(Linux) && !canImport(Musl) internal static let raw: NIOBSDSocket.SocketType = NIOBSDSocket.SocketType(rawValue: CInt(SOCK_RAW.rawValue)) #else diff --git a/Sources/NIOPosix/SocketProtocols.swift b/Sources/NIOPosix/SocketProtocols.swift index 62f1c46569..90cca51fad 100644 --- a/Sources/NIOPosix/SocketProtocols.swift +++ b/Sources/NIOPosix/SocketProtocols.swift @@ -73,7 +73,13 @@ protocol SocketProtocol: BaseSocketProtocol { // This is a lazily initialised global variable that when read for the first time, will ignore SIGPIPE. private let globallyIgnoredSIGPIPE: Bool = { /* no F_SETNOSIGPIPE on Linux :( */ + #if canImport(Glibc) _ = Glibc.signal(SIGPIPE, SIG_IGN) + #elseif canImport(Musl) + _ = Musl.signal(SIGPIPE, SIG_IGN) + #else + #error("Don't know which stdlib to use") + #endif return true }() #endif diff --git a/Sources/NIOPosix/System.swift b/Sources/NIOPosix/System.swift index fc6c0584dc..6fea9a56f8 100644 --- a/Sources/NIOPosix/System.swift +++ b/Sources/NIOPosix/System.swift @@ -115,53 +115,31 @@ private let sysIfNameToIndex: @convention(c) (UnsafePointer?) -> CUnsigne private let sysSocketpair: @convention(c) (CInt, CInt, CInt, UnsafeMutablePointer?) -> CInt = socketpair #endif -#if (os(Linux) && !canImport(Musl)) || os(Android) -private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer) -> CInt = fstat -private let sysStat: @convention(c) (UnsafePointer, UnsafeMutablePointer) -> CInt = stat -private let sysLstat: @convention(c) (UnsafePointer, UnsafeMutablePointer) -> CInt = lstat -private let sysSymlink: @convention(c) (UnsafePointer, UnsafePointer) -> CInt = symlink -private let sysReadlink: @convention(c) (UnsafePointer, UnsafeMutablePointer, Int) -> CLong = readlink -private let sysUnlink: @convention(c) (UnsafePointer) -> CInt = unlink -private let sysMkdir: @convention(c) (UnsafePointer, mode_t) -> CInt = mkdir -private let sysOpendir: @convention(c) (UnsafePointer) -> OpaquePointer? = opendir -private let sysReaddir: @convention(c) (OpaquePointer) -> UnsafeMutablePointer? = readdir -private let sysClosedir: @convention(c) (OpaquePointer) -> CInt = closedir -#if os(Android) -private let sysRename: @convention(c) (UnsafePointer, UnsafePointer) -> CInt = rename -private let sysRemove: @convention(c) (UnsafePointer) -> CInt = remove -#else -private let sysRename: @convention(c) (UnsafePointer?, UnsafePointer?) -> CInt = rename -private let sysRemove: @convention(c) (UnsafePointer?) -> CInt = remove -#endif -#elseif canImport(Darwin) -private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer?) -> CInt = fstat -private let sysStat: @convention(c) (UnsafePointer?, UnsafeMutablePointer?) -> CInt = stat -private let sysLstat: @convention(c) (UnsafePointer?, UnsafeMutablePointer?) -> CInt = lstat -private let sysSymlink: @convention(c) (UnsafePointer?, UnsafePointer?) -> CInt = symlink -private let sysReadlink: @convention(c) (UnsafePointer?, UnsafeMutablePointer?, Int) -> CLong = readlink -private let sysUnlink: @convention(c) (UnsafePointer?) -> CInt = unlink -private let sysMkdir: @convention(c) (UnsafePointer?, mode_t) -> CInt = mkdir -private let sysMkpath: @convention(c) (UnsafePointer?, mode_t) -> CInt = mkpath_np -private let sysOpendir: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = opendir -private let sysReaddir: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutablePointer? = readdir -private let sysClosedir: @convention(c) (UnsafeMutablePointer?) -> CInt = closedir -private let sysRename: @convention(c) (UnsafePointer?, UnsafePointer?) -> CInt = rename -private let sysRemove: @convention(c) (UnsafePointer?) -> CInt = remove +#if os(Linux) || os(Android) || canImport(Darwin) +private let sysFstat = fstat +private let sysStat = stat +private let sysLstat = lstat +private let sysSymlink = symlink +private let sysReadlink = readlink +private let sysUnlink = unlink +private let sysMkdir = mkdir +private let sysOpendir = opendir +private let sysReaddir = readdir +private let sysClosedir = closedir +private let sysRename = rename +private let sysRemove = remove #endif #if os(Linux) || os(Android) -private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIOLinux_sendmmsg -private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIOLinux_recvmmsg +private let sysSendMmsg = CNIOLinux_sendmmsg +private let sysRecvMmsg = CNIOLinux_recvmmsg #elseif canImport(Darwin) private let sysKevent = kevent -private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIODarwin_sendmmsg -private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIODarwin_recvmmsg +private let sysMkpath = mkpath_np +private let sysSendMmsg = CNIODarwin_sendmmsg +private let sysRecvMmsg = CNIODarwin_recvmmsg #endif #if !os(Windows) -#if canImport(Musl) -private let sysIoctl: @convention(c) (CInt, CInt, UnsafeMutableRawPointer) -> CInt = ioctl -#else private let sysIoctl: @convention(c) (CInt, CUnsignedLong, UnsafeMutableRawPointer) -> CInt = ioctl -#endif // canImport(Musl) #endif // !os(Windows) private func isUnacceptableErrno(_ code: Int32) -> Bool { From a60f1972d176e9fd15744a88399bb4d970f8183e Mon Sep 17 00:00:00 2001 From: hamzahrmalik Date: Thu, 30 Nov 2023 09:23:25 +0000 Subject: [PATCH 56/64] Add tests to validate the behaviour when requests/response content-length headers are wrong in HTTP1 (#2601) Add tests to validate the behaviour when requests/response content-length headers are wrong --- Tests/NIOHTTP1Tests/ContentLengthTests.swift | 131 +++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 Tests/NIOHTTP1Tests/ContentLengthTests.swift diff --git a/Tests/NIOHTTP1Tests/ContentLengthTests.swift b/Tests/NIOHTTP1Tests/ContentLengthTests.swift new file mode 100644 index 0000000000..b295df6f47 --- /dev/null +++ b/Tests/NIOHTTP1Tests/ContentLengthTests.swift @@ -0,0 +1,131 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import XCTest +import NIOCore +import NIOEmbedded +import NIOHTTP1 + +final class ContentLengthTests: XCTestCase { + + /// Client receives a response longer than the content-length header + func testResponseContentTooLong() throws { + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHTTPClientHandlers() + defer { + _ = try? channel.finish() + } + // Receive a response with a content-length header of 2 but a body of more than 2 bytes + let badResponse = "HTTP/1.1 200 OK\r\nServer: foo\r\nContent-Length: 2\r\n\r\ntoo many bytes" + + XCTAssertThrowsError(try channel.sendRequestAndReceiveResponse(response: badResponse)) { error in + XCTAssertEqual(error as? HTTPParserError, .invalidConstant) + } + + channel.embeddedEventLoop.run() + } + + /// Client receives a response shorter than the content-length header + func testResponseContentTooShort() throws { + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHTTPClientHandlers() + defer { + _ = try? channel.finish() + } + // Receive a response with a content-length header of 100 but a body of less than 100 bytes + let badResponse = "HTTP/1.1 200 OK\r\nServer: foo\r\nContent-Length: 100\r\n\r\nnot many bytes" + + // First is successful, it just waits for more bytes + XCTAssertNoThrow(try channel.sendRequestAndReceiveResponse(response: badResponse)) + // It is waiting for 100-14 = 86 more bytes + // We will send the same response again (75 bytes) + // The client will consider this as part of the body of the previous response. No error expected + XCTAssertNoThrow(try channel.sendRequestAndReceiveResponse(response: badResponse)) + // Now the client is expected only 86-75 = 11 bytes. We wil send the same 75 byte request again + // An error is expected because everything from the 12th byte forward will be parsed as a new message, which isn't well formed + XCTAssertThrowsError(try channel.sendRequestAndReceiveResponse(response: badResponse)) { error in + XCTAssertEqual(error as? HTTPParserError, .invalidConstant) + } + + channel.embeddedEventLoop.run() + } + + /// Server receives a request longer than the content-length header + func testRequestContentTooLong() throws { + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.configureHTTPServerPipeline() + defer { + _ = try? channel.finish() + } + // Receive a request with a content-length header of 2 but a body of more than 2 bytes + let badRequest = "POST / HTTP/1.1\r\nContent-Length: 2\r\n\r\nhello" + // First one is fine, the extra bytes will be treated as the next request + XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) + // Which means the next request is now malformed + XCTAssertThrowsError(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) { error in + XCTAssertEqual(error as? HTTPParserError, .invalidMethod) + } + + channel.embeddedEventLoop.run() + } + + /// Server receives a request shorter than the content-length header + func testRequestContentTooShort() throws { + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.configureHTTPServerPipeline() + defer { + _ = try? channel.finish() + } + // Receive a request with a content-length header of 100 but a body of less + let badRequest = "POST / HTTP/1.1\r\nContent-Length: 100\r\n\r\nnot many bytes" + // First one is fine, server will wait for 100-14 (86) further bytes to come + XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: false)) + // The full request (60 bytes) will be treated as the body of the original request + XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: false)) + // The original request is still 26 bytes short. Sending the request once more will complete it + XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) + // The leftover bytes from the previous write (we wrote 100 bytes where it wanted 26) will form a new malformed request + XCTAssertThrowsError(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) { error in + XCTAssertEqual(error as? HTTPParserError, .invalidMethod) + } + + channel.embeddedEventLoop.run() + } +} + +extension EmbeddedChannel { + /// Do a request-response cycle + /// Asserts that sending the request won't fail + /// Throws if receiving the response fails + fileprivate func sendRequestAndReceiveResponse(response: String) throws { + // Send a request + XCTAssertNoThrow(try self.writeOutbound(HTTPClientRequestPart.head(.init(version: .http1_1, method: .GET, uri: "/")))) + XCTAssertNoThrow(try self.writeOutbound(HTTPClientRequestPart.end(nil))) + // Receive a response + try self.writeInbound(ByteBuffer(string: response)) + } + + /// Do a response-request cycle + /// Throws if receiving the request fails + /// Asserts that sending the response won't fail + fileprivate func receiveRequestAndSendResponse(request: String, sendResponse: Bool) throws { + // Receive a request + try self.writeInbound(ByteBuffer(string: request)) + // Send a response + if sendResponse { + XCTAssertNoThrow(try self.writeOutbound(HTTPServerResponsePart.head(.init(version: .http1_1, status: .ok)))) + XCTAssertNoThrow(try self.writeOutbound(HTTPServerResponsePart.end(nil))) + } + } +} From 4a42bc24c6e03f53b0c677164d8ea539f82179e6 Mon Sep 17 00:00:00 2001 From: Si Beaumont Date: Thu, 30 Nov 2023 11:08:35 +0000 Subject: [PATCH 57/64] Add NIOAsyncWriterSinkDelegate._didSuspend hook for testing (#2597) * Add internal _didSuspend to NIOAsyncWriter for testing * PR feedback: Rename assert(numSuspends, ...) to assert(suspendCallCount, ...) * PR feedback: Add equivalent hook in NIOThrowingAsyncSequenceProducer * PR feedback: Add equivalent hook in NIOAsyncSequenceProducer --------- Co-authored-by: Franz Busch --- .../AsyncSequences/NIOAsyncWriter.swift | 4 + .../NIOThrowingAsyncSequenceProducer.swift | 4 + .../NIOAsyncSequenceTests.swift | 68 +++-- .../AsyncSequences/NIOAsyncWriterTests.swift | 268 +++++++++++------- .../NIOThrowingAsyncSequenceTests.swift | 167 +++++++---- 5 files changed, 328 insertions(+), 183 deletions(-) diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift index 9a8947f231..ed1011368c 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift @@ -456,6 +456,9 @@ extension NIOAsyncWriter { /// The state machine. @usableFromInline /* private */ internal var _stateMachine: StateMachine + /// Hook used in testing. + @usableFromInline + internal var _didSuspend: (() -> Void)? @inlinable internal var isWriterFinished: Bool { @@ -540,6 +543,7 @@ extension NIOAsyncWriter { ) self._lock.unlock() + self._didSuspend?() } } } onCancel: { diff --git a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift index 8f7020adab..7b852dd6b6 100644 --- a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift +++ b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift @@ -398,6 +398,9 @@ extension NIOThrowingAsyncSequenceProducer { /// The delegate. @usableFromInline /* private */ internal var _delegate: Delegate? + /// Hook used in testing. + @usableFromInline + internal var _didSuspend: (() -> Void)? @inlinable var isFinished: Bool { @@ -595,6 +598,7 @@ extension NIOThrowingAsyncSequenceProducer { case .none: self._lock.unlock() } + self._didSuspend?() } } } onCancel: { diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift index 40325ce09b..db2471c43b 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -import NIOCore +@testable import NIOCore import XCTest @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @@ -266,11 +266,14 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } func testFinish_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._throwingSequence._storage._didSuspend = { suspended.fulfill() } + async let element = sequence.first { _ in true } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) self.source.finish() @@ -346,15 +349,16 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { var source = newSequence?.source newSequence = nil - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + let suspended = expectation(description: "task suspended") + sequence!._throwingSequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { let element = await sequence!.first { _ in true } return element } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) source = nil @@ -432,14 +436,18 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { // MARK: - Task cancel func testTaskCancel_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._throwingSequence._storage._didSuspend = { suspended.fulfill() } + + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) task.cancel() let value = await task.value @@ -448,41 +456,51 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } func testTaskCancel_whenStreaming_andNotSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + let resumed = expectation(description: "task resumed") + let cancelled = expectation(description: "task cancelled") + + sequence._throwingSequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() - let value = await iterator.next() - // Sleeping here a bit to make sure we hit the case where - // we are streaming and still retain the iterator. - try? await Task.sleep(nanoseconds: 1_000_000) + let value = await iterator.next() + resumed.fulfill() + await fulfillment(of: [cancelled], timeout: 1) return value } - try await Task.sleep(nanoseconds: 2_000_000) - + await fulfillment(of: [suspended], timeout: 1) _ = self.source.yield(contentsOf: [1]) + await fulfillment(of: [resumed], timeout: 1) task.cancel() + cancelled.fulfill() + let value = await task.value XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) XCTAssertEqual(value, 1) } func testTaskCancel_whenSourceFinished() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._throwingSequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) self.source.finish() + XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) task.cancel() let value = await task.value @@ -491,15 +509,17 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { func testTaskCancel_whenStreaming_andTaskIsAlreadyCancelled() async throws { let sequence = try XCTUnwrap(self.sequence) + + let cancelled = expectation(description: "task cancelled") + let task: Task = Task { - // We are sleeping here to allow some time for us to cancel the task. - // Once the Task is cancelled we will call `next()` - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) let iterator = sequence.makeAsyncIterator() return await iterator.next() } task.cancel() + cancelled.fulfill() let value = await task.value diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift index 091e41743e..52963e9004 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// import DequeModule -import NIOCore +@testable import NIOCore import XCTest import NIOConcurrencyHelpers @@ -34,6 +34,18 @@ private final class MockAsyncWriterDelegate: NIOAsyncWriterSinkDelegate, @unchec } } + var _didSuspendCallCount = NIOLockedValueBox(0) + var didSuspendCallCount: Int { + self._didSuspendCallCount.withLockedValue { $0 } + } + var didSuspendHandler: (() -> Void)? + func didSuspend() { + self._didSuspendCallCount.withLockedValue { $0 += 1 } + if let didSuspendHandler = self.didSuspendHandler { + didSuspendHandler() + } + } + var _didTerminateCallCount = NIOLockedValueBox(0) var didTerminateCallCount: Int { self._didTerminateCallCount.withLockedValue { $0 } @@ -65,6 +77,7 @@ final class NIOAsyncWriterTests: XCTestCase { ) self.writer = newWriter.writer self.sink = newWriter.sink + self.sink._storage._didSuspend = self.delegate.didSuspend } override func tearDown() { @@ -81,6 +94,18 @@ final class NIOAsyncWriterTests: XCTestCase { super.tearDown() } + func assert( + suspendCallCount: Int, + yieldCallCount: Int, + terminateCallCount: Int, + file: StaticString = #filePath, + line: UInt = #line + ) { + XCTAssertEqual(self.delegate.didSuspendCallCount, suspendCallCount, "Unexpeced suspends", file: file, line: line) + XCTAssertEqual(self.delegate.didYieldCallCount, yieldCallCount, "Unexpected yields", file: file, line: line) + XCTAssertEqual(self.delegate.didTerminateCallCount, terminateCallCount, "Unexpected terminates", file: file, line: line) + } + func testMultipleConcurrentWrites() async throws { var elements = 0 self.delegate.didYieldHandler = { elements += $0.count } @@ -148,7 +173,7 @@ final class NIOAsyncWriterTests: XCTestCase { writer = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) XCTAssertNil(writer) sink.finish() @@ -168,7 +193,7 @@ final class NIOAsyncWriterTests: XCTestCase { try await writer!.yield("message1") writer = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 1) XCTAssertNil(writer) sink.finish() @@ -179,18 +204,17 @@ final class NIOAsyncWriterTests: XCTestCase { self.writer.finish() self.writer = nil - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 1) } func testWriterDeinitialized_whenFinished() async throws { self.sink.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) self.writer = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } // MARK: - ToggleWritability @@ -198,15 +222,18 @@ final class NIOAsyncWriterTests: XCTestCase { func testSetWritability_whenInitial() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message1") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 0) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) } func testSetWritability_whenStreaming_andBecomingUnwritable() async throws { @@ -215,75 +242,88 @@ final class NIOAsyncWriterTests: XCTestCase { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message2") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) } func testSetWritability_whenStreaming_andBecomingWritable() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + let resumed = expectation(description: "yield completed") + Task { [writer] in try await writer!.yield("message2") + resumed.fulfill() } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.sink.setWritability(to: true) - // Sleep a bit so that the other Task can retry the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [resumed], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) } func testSetWritability_whenStreaming_andSettingSameWritability() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message1") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) // Setting the writability to the same state again shouldn't change anything self.sink.setWritability(to: false) - XCTAssertEqual(self.delegate.didYieldCallCount, 0) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) } func testSetWritability_whenWriterFinished() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + let resumed = expectation(description: "yield completed") + Task { [writer] in try await writer!.yield("message1") + resumed.fulfill() } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.writer.finish() - XCTAssertEqual(self.delegate.didYieldCallCount, 0) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) self.sink.setWritability(to: true) - // Sleep a bit so that the other Task can retry the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [resumed], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 1) } func testSetWritability_whenFinished() async throws { @@ -291,7 +331,7 @@ final class NIOAsyncWriterTests: XCTestCase { self.sink.setWritability(to: false) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } // MARK: - Yield @@ -299,86 +339,98 @@ final class NIOAsyncWriterTests: XCTestCase { func testYield_whenInitial_andWritable() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) } func testYield_whenInitial_andNotWritable() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message2") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) } func testYield_whenStreaming_andWritable() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) try await self.writer.yield("message2") - XCTAssertEqual(self.delegate.didYieldCallCount, 2) + self.assert(suspendCallCount: 0, yieldCallCount: 2, terminateCallCount: 0) } func testYield_whenStreaming_andNotWritable() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message2") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) } func testYield_whenStreaming_andYieldCancelled() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) + + let cancelled = expectation(description: "task cancelled") let task = Task { [writer] in - // Sleeping here a bit to delay the call to yield - // The idea is that we call yield once the Task is - // already cancelled - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) try await writer!.yield("message2") } task.cancel() + cancelled.fulfill() await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) } - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) } func testYield_whenWriterFinished() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message1") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.writer.finish() await XCTAssertThrowsError(try await self.writer.yield("message1")) { error in XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) } - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) } func testYield_whenFinished() async throws { @@ -387,7 +439,7 @@ final class NIOAsyncWriterTests: XCTestCase { await XCTAssertThrowsError(try await self.writer.yield("message1")) { error in XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) } - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } func testYield_whenFinishedError() async throws { @@ -396,75 +448,80 @@ final class NIOAsyncWriterTests: XCTestCase { await XCTAssertThrowsError(try await self.writer.yield("message1")) { error in XCTAssertTrue(error is SomeError) } - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } // MARK: - Cancel func testCancel_whenInitial() async throws { + let cancelled = expectation(description: "task cancelled") + let task = Task { [writer] in - // Sleeping here a bit to delay the call to yield - // The idea is that we call yield once the Task is - // already cancelled - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) try await writer!.yield("message1") } task.cancel() + cancelled.fulfill() await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) } - XCTAssertEqual(self.delegate.didYieldCallCount, 0) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } func testCancel_whenStreaming_andCancelBeforeYield() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) + + let cancelled = expectation(description: "task cancelled") let task = Task { [writer] in - // Sleeping here a bit to delay the call to yield - // The idea is that we call yield once the Task is - // already cancelled - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) try await writer!.yield("message2") } task.cancel() + cancelled.fulfill() await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) } - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) } func testCancel_whenStreaming_andCancelAfterSuspendedYield() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + let task = Task { [writer] in try await writer!.yield("message2") } - // Sleeping here to give the task enough time to suspend on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) task.cancel() await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) } - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) self.sink.setWritability(to: true) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) } func testCancel_whenFinished() async throws { @@ -472,23 +529,20 @@ final class NIOAsyncWriterTests: XCTestCase { XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + let cancelled = expectation(description: "task cancelled") + let task = Task { [writer] in - // Sleeping here a bit to delay the call to yield - // The idea is that we call yield once the Task is - // already cancelled - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) try await writer!.yield("message1") } - // Sleeping here to give the task enough time to suspend on the yield - try await Task.sleep(nanoseconds: 1_000_000) - task.cancel() + cancelled.fulfill() - XCTAssertEqual(self.delegate.didYieldCallCount, 0) await XCTAssertThrowsError(try await task.value) { error in XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) } + XCTAssertEqual(self.delegate.didYieldCallCount, 0) } // MARK: - Writer Finish @@ -496,13 +550,13 @@ final class NIOAsyncWriterTests: XCTestCase { func testWriterFinish_whenInitial() async throws { self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) } func testWriterFinish_whenInitial_andFailure() async throws { self.writer.finish(error: SomeError()) - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) } func testWriterFinish_whenStreaming() async throws { @@ -510,38 +564,42 @@ final class NIOAsyncWriterTests: XCTestCase { self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 1) } func testWriterFinish_whenStreaming_AndBufferedElements() async throws { // We are setting up a suspended yield here to check that it gets resumed self.sink.setWritability(to: false) + + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } let task = Task { [writer] in try await writer!.yield("message1") } - - // Sleeping here to give the task enough time to suspend on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) // We have to become writable again to unbuffer the yield self.sink.setWritability(to: true) await XCTAssertNoThrow(try await task.value) - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 1) } func testWriterFinish_whenFinished() { // This tests just checks that finishing again is a no-op self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) } // MARK: - Sink Finish @@ -561,7 +619,7 @@ final class NIOAsyncWriterTests: XCTestCase { XCTAssertNil(sink) XCTAssertNotNil(writer) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } func testSinkFinish_whenStreaming() async throws { @@ -575,25 +633,33 @@ final class NIOAsyncWriterTests: XCTestCase { let writer = newWriter!.writer newWriter = nil - Task { [writer] in - try await writer.yield("message1") - } - - try await Task.sleep(nanoseconds: 1_000_000) + try await writer.yield("message1") sink = nil XCTAssertNil(sink) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) } func testSinkFinish_whenFinished() async throws { self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) self.sink = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) + } +} + +#if !canImport(Darwin) && swift(<5.10) +extension XCTestCase { + func fulfillment( + of expectations: [XCTestExpectation], + timeout seconds: TimeInterval, + enforceOrder enforceOrderOfFulfillment: Bool = false + ) async { + wait(for: expectations, timeout: seconds) } } +#endif diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift index 95978bc9cf..3e71d71988 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -import NIOCore +@testable import NIOCore import XCTest @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) @@ -103,15 +103,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testYield_whenStreaming_andSuspended_andStopProducing() async throws { self.backPressureStrategy.didYieldHandler = { _ in false } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) let result = self.source.yield(contentsOf: [1]) @@ -128,15 +131,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testYield_whenStreaming_andSuspended_andProduceMore() async throws { self.backPressureStrategy.didYieldHandler = { _ in true } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) let result = self.source.yield(contentsOf: [1]) @@ -153,15 +159,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testYieldEmptySequence_whenStreaming_andSuspended_andStopProducing() async throws { self.backPressureStrategy.didYieldHandler = { _ in false } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) - try await withThrowingTaskGroup(of: Void.self) { group in + await withThrowingTaskGroup(of: Void.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) let result = self.source.yield(contentsOf: []) @@ -175,15 +184,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testYieldEmptySequence_whenStreaming_andSuspended_andProduceMore() async throws { self.backPressureStrategy.didYieldHandler = { _ in true } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) - try await withThrowingTaskGroup(of: Void.self) { group in + await withThrowingTaskGroup(of: Void.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) let result = self.source.yield(contentsOf: []) @@ -231,16 +243,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testFinish_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { let element = try await sequence.first { _ in true } return element } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.source.finish() @@ -312,15 +326,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testFinishError_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) await XCTAssertThrowsError(try await withThrowingTaskGroup(of: Void.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.source.finish(ChannelError.alreadyClosed) @@ -423,15 +439,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { var source = newSequence?.source newSequence = nil - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + + let suspended = expectation(description: "task suspended") + sequence!._storage._didSuspend = { suspended.fulfill() } + group.addTask { let element = try await sequence!.first { _ in true } return element } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) source = nil @@ -511,14 +529,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { // MARK: - Task cancel func testTaskCancel_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return try await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) task.cancel() let result = await task.result @@ -531,8 +552,6 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { @available(*, deprecated, message: "tests the deprecated custom generic failure type") func testTaskCancel_whenStreaming_andSuspended_withCustomErrorType() async throws { struct CustomError: Error {} - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let backPressureStrategy = MockNIOElementStreamBackPressureStrategy() let delegate = MockNIOBackPressuredStreamSourceDelegate() let new = NIOThrowingAsyncSequenceProducer.makeSequence( @@ -542,11 +561,16 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { delegate: delegate ) let sequence = new.sequence + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return try await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) task.cancel() let result = await task.result @@ -558,36 +582,47 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testTaskCancel_whenStreaming_andNotSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + let suspended = expectation(description: "task suspended") + let resumed = expectation(description: "task resumed") + let cancelled = expectation(description: "task cancelled") + + sequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() let element = try await iterator.next() - - // Sleeping here to give the other Task a chance to cancel this one. - try? await Task.sleep(nanoseconds: 1_000_000) + resumed.fulfill() + await fulfillment(of: [cancelled], timeout: 1) return element } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) _ = self.source.yield(contentsOf: [1]) + await fulfillment(of: [resumed], timeout: 1) + task.cancel() + cancelled.fulfill() + let value = try await task.value XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) XCTAssertEqual(value, 1) } func testTaskCancel_whenSourceFinished() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return try await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) self.source.finish() XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) @@ -598,15 +633,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testTaskCancel_whenStreaming_andTaskIsAlreadyCancelled() async throws { let sequence = try XCTUnwrap(self.sequence) + + let cancelled = expectation(description: "task cancelled") + let task: Task = Task { - // We are sleeping here to allow some time for us to cancel the task. - // Once the Task is cancelled we will call `next()` - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) let iterator = sequence.makeAsyncIterator() return try await iterator.next() } task.cancel() + cancelled.fulfill() let result = await task.result @@ -627,16 +664,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { delegate: delegate ) let sequence = new.sequence + + let cancelled = expectation(description: "task cancelled") + let task: Task = Task { - // We are sleeping here to allow some time for us to cancel the task. - // Once the Task is cancelled we will call `next()` - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) let iterator = sequence.makeAsyncIterator() return try await iterator.next() } task.cancel() - + cancelled.fulfill() + let result = await task.result try withExtendedLifetime(new.source) { XCTAssertNil(try result.get()) @@ -647,14 +686,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testNext_whenInitial_whenDemand() async throws { self.backPressureStrategy.didNextHandler = { _ in true } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + Task { // Would prefer to use async let _ here but that is not allowed yet _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.produceMore]) @@ -662,14 +704,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testNext_whenInitial_whenNoDemand() async throws { self.backPressureStrategy.didNextHandler = { _ in false } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + Task { // Would prefer to use async let _ here but that is not allowed yet _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) } @@ -679,14 +724,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { _ = self.source.yield(contentsOf: []) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didYield]) - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + Task { // Would prefer to use async let _ here but that is not allowed yet _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) } @@ -696,14 +744,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { _ = self.source.yield(contentsOf: []) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didYield]) - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + Task { // Would prefer to use async let _ here but that is not allowed yet _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) } From 890486769b96966a952849fc1bd23eaec75e5b40 Mon Sep 17 00:00:00 2001 From: Alastair Houghton Date: Tue, 5 Dec 2023 15:49:43 +0000 Subject: [PATCH 58/64] Fix warnings caused by not defining the feature macros. (#2606) We should define `_GNU_SOURCE` for all of the C modules, really. Relying on `features.h` doesn't work for modular headers. --- Package.swift | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/Package.swift b/Package.swift index 68d37ba072..ac54050d6d 100644 --- a/Package.swift +++ b/Package.swift @@ -107,7 +107,10 @@ let package = Package( ), .target( name: "CNIOAtomics", - dependencies: [] + dependencies: [], + cSettings: [ + .define("_GNU_SOURCE"), + ] ), .target( name: "CNIOSHA1", @@ -159,7 +162,10 @@ let package = Package( ), .target( name: "CNIOLLHTTP", - cSettings: [.define("LLHTTP_STRICT_MODE")] + cSettings: [ + .define("_GNU_SOURCE"), + .define("LLHTTP_STRICT_MODE") + ] ), .target( name: "NIOTLS", From 50ae57ac9d2817f045896316a8145890fe1b59f8 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Fri, 8 Dec 2023 13:12:29 +0000 Subject: [PATCH 59/64] Fix test availability annotations (#2607) # Motivation We are currently missing a bunch of annotations in our tests which leads to compilation failures if you build against generic iOS/macOS/watchOS/tvOS # Modification This PR adds the missing availability annotations in the tests. --- .../AsyncChannel/AsyncChannelTests.swift | 12 +------ Tests/NIOCoreTests/AsyncSequenceTests.swift | 2 +- .../AsyncTestingChannelTests.swift | 33 +------------------ .../AsyncTestingEventLoopTests.swift | 25 +------------- Tests/NIOPosixTests/NIOThreadPoolTest.swift | 5 +-- Tests/NIOPosixTests/SerialExecutorTests.swift | 5 +-- Tests/NIOPosixTests/TestUtils.swift | 3 ++ .../GlobalSingletonsTests.swift | 2 ++ 8 files changed, 11 insertions(+), 76 deletions(-) diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index e7d320bf01..cafddee308 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -17,9 +17,9 @@ import NIOConcurrencyHelpers import NIOEmbedded import XCTest +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelTests: XCTestCase { func testAsyncChannelCloseOnWrite() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } final class CloseOnWriteHandler: ChannelOutboundHandler { typealias OutboundIn = String @@ -39,7 +39,6 @@ final class AsyncChannelTests: XCTestCase { } func testAsyncChannelBasicFunctionality() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel(wrappingChannelSynchronously: channel) @@ -65,7 +64,6 @@ final class AsyncChannelTests: XCTestCase { } func testAsyncChannelBasicWrites() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel(wrappingChannelSynchronously: channel) @@ -84,7 +82,6 @@ final class AsyncChannelTests: XCTestCase { } func testFinishingTheWriterClosesTheWriteSideOfTheChannel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let closeRecorder = CloseRecorder() try await channel.pipeline.addHandler(closeRecorder) @@ -116,7 +113,6 @@ final class AsyncChannelTests: XCTestCase { } func testDroppingEverythingDoesntCloseTheChannel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let closeRecorder = CloseRecorder() try await channel.pipeline.addHandler(CloseSuppressor()) @@ -148,7 +144,6 @@ final class AsyncChannelTests: XCTestCase { } func testReadsArePropagated() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel(wrappingChannelSynchronously: channel) @@ -167,7 +162,6 @@ final class AsyncChannelTests: XCTestCase { } func testErrorsArePropagatedButAfterReads() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel(wrappingChannelSynchronously: channel) @@ -190,7 +184,6 @@ final class AsyncChannelTests: XCTestCase { } func testChannelBecomingNonWritableDelaysWriters() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel(wrappingChannelSynchronously: channel) @@ -227,7 +220,6 @@ final class AsyncChannelTests: XCTestCase { } func testBufferDropsReadsIfTheReaderIsGone() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() try await channel.pipeline.addHandler(CloseSuppressor()).get() do { @@ -252,7 +244,6 @@ final class AsyncChannelTests: XCTestCase { } func testManagingBackPressure() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let readCounter = ReadCounter() try await channel.pipeline.addHandler(readCounter) @@ -367,7 +358,6 @@ final class AsyncChannelTests: XCTestCase { } func testCanWrapAChannelSynchronously() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel(wrappingChannelSynchronously: channel) diff --git a/Tests/NIOCoreTests/AsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequenceTests.swift index 07c51b34d3..9a000a79ef 100644 --- a/Tests/NIOCoreTests/AsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequenceTests.swift @@ -25,9 +25,9 @@ fileprivate struct TestCase { } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncSequenceCollectTests: XCTestCase { func testAsyncSequenceCollect() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let testCases = [ TestCase([ [], diff --git a/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift b/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift index c3122e9b84..5a40874674 100644 --- a/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift +++ b/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift @@ -17,9 +17,9 @@ import Atomics import NIOCore @testable import NIOEmbedded +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) class AsyncTestingChannelTests: XCTestCase { func testSingleHandlerInit() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } class Handler: ChannelInboundHandler { typealias InboundIn = Never } @@ -29,7 +29,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testEmptyInit() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } class Handler: ChannelInboundHandler { typealias InboundIn = Never @@ -43,7 +42,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testMultipleHandlerInit() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } class Handler: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = Never let identifier: String @@ -67,7 +65,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWaitForInboundWrite() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let task = Task { try await XCTAsyncAssertEqual(try await channel.waitForInboundWrite(), 1) @@ -82,7 +79,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWaitForMultipleInboundWritesInParallel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let task = Task { let task1 = Task { try await channel.waitForInboundWrite(as: Int.self) } @@ -102,7 +98,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWaitForOutboundWrite() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let task = Task { try await XCTAsyncAssertEqual(try await channel.waitForOutboundWrite(), 1) @@ -117,7 +112,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWaitForMultipleOutboundWritesInParallel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let task = Task { let task1 = Task { try await channel.waitForOutboundWrite(as: Int.self) } @@ -137,7 +131,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteOutboundByteBuffer() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() var buf = channel.allocator.buffer(capacity: 1024) buf.writeString("hello") @@ -158,7 +151,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteOutboundByteBufferMultipleTimes() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() var buf = channel.allocator.buffer(capacity: 1024) buf.writeString("hello") @@ -179,7 +171,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteInboundByteBuffer() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() var buf = channel.allocator.buffer(capacity: 1024) buf.writeString("hello") @@ -192,7 +183,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteInboundByteBufferMultipleTimes() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() var buf = channel.allocator.buffer(capacity: 1024) buf.writeString("hello") @@ -213,7 +203,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteInboundByteBufferReThrow() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(ExceptionThrowingInboundHandler()).wait()) await XCTAsyncAssertThrowsError(try await channel.writeInbound("msg")) { error in @@ -223,7 +212,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteOutboundByteBufferReThrow() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(ExceptionThrowingOutboundHandler()).wait()) await XCTAsyncAssertThrowsError(try await channel.writeOutbound("msg")) { error in @@ -233,7 +221,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testReadOutboundWrongTypeThrows() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try await XCTAsyncAssertTrue(await channel.writeOutbound("hello").isFull) do { @@ -248,7 +235,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testReadInboundWrongTypeThrows() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try await XCTAsyncAssertTrue(await channel.writeInbound("hello").isFull) do { @@ -263,7 +249,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWrongTypesWithFastpathTypes() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let buffer = channel.allocator.buffer(capacity: 0) @@ -312,7 +297,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testCloseMultipleTimesThrows() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try await XCTAsyncAssertTrue(await channel.finish().isClean) @@ -326,7 +310,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testCloseOnInactiveIsOk() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let inactiveHandler = CloseInChannelInactiveHandler() XCTAssertNoThrow(try channel.pipeline.addHandler(inactiveHandler).wait()) @@ -337,7 +320,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testEmbeddedLifecycle() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let handler = ChannelLifecycleHandler() XCTAssertEqual(handler.currentState, .unregistered) @@ -383,7 +365,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testEmbeddedChannelAndPipelineAndChannelCoreShareTheEventLoop() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let pipelineEventLoop = channel.pipeline.eventLoop XCTAssert(pipelineEventLoop === channel.eventLoop) @@ -392,7 +373,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSendingAnythingOnEmbeddedChannel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let buffer = ByteBufferAllocator().buffer(capacity: 5) let socketAddress = try SocketAddress(unixDomainSocketPath: "path") @@ -411,7 +391,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testActiveWhenConnectPromiseFiresAndInactiveWhenClosePromiseFires() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() XCTAssertFalse(channel.isActive) let connectPromise = channel.eventLoop.makePromise(of: Void.self) @@ -431,7 +410,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteWithoutFlushDoesNotWrite() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let buf = ByteBuffer(bytes: [1]) @@ -445,7 +423,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSetLocalAddressAfterSuccessfulBind() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let bindPromise = channel.eventLoop.makePromise(of: Void.self) @@ -459,7 +436,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSetRemoteAddressAfterSuccessfulConnect() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let connectPromise = channel.eventLoop.makePromise(of: Void.self) @@ -473,7 +449,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testUnprocessedOutboundUserEventFailsOnEmbeddedChannel() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() XCTAssertThrowsError(try channel.triggerUserOutboundEvent("event").wait()) { (error: Error) in @@ -487,7 +462,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testEmbeddedChannelWritabilityIsWritable() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let opaqueChannel: Channel = channel @@ -500,7 +474,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testFinishWithRecursivelyScheduledTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let invocations = AtomicCounter() @@ -518,7 +491,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSyncOptionsAreSupported() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try channel.testingEventLoop.submit { let options = channel.syncOptions @@ -530,7 +502,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testGetChannelOptionAutoReadIsSupported() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try channel.testingEventLoop.submit { let options = channel.syncOptions @@ -541,7 +512,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSetGetChannelOptionAllowRemoteHalfClosureIsSupported() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try channel.testingEventLoop.submit { let options = channel.syncOptions @@ -559,7 +529,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSecondFinishThrows() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() _ = try await channel.finish() await XCTAsyncAssertThrowsError(try await channel.finish()) diff --git a/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift b/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift index 334d1c50e1..0e67bd99d2 100644 --- a/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift +++ b/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift @@ -19,9 +19,9 @@ import Atomics private class EmbeddedTestError: Error { } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class NIOAsyncTestingEventLoopTests: XCTestCase { func testExecuteDoesNotImmediatelyRunTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() try await loop.executeInContext { @@ -33,7 +33,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecuteWillRunAllTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let runCount = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() loop.execute { runCount.wrappingIncrement(ordering: .relaxed) } @@ -50,7 +49,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecuteWillRunTasksAddedRecursively() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() @@ -80,7 +78,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecuteRunsImmediately() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() loop.execute { callbackRan.store(true, ordering: .relaxed) } @@ -99,7 +96,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testTasksScheduledAfterRunDontRun() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() loop.scheduleTask(deadline: loop.now) { callbackRan.store(true, ordering: .relaxed) } @@ -120,7 +116,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testSubmitRunsImmediately() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() _ = loop.submit { callbackRan.store(true, ordering: .relaxed) } @@ -139,7 +134,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testSyncShutdownGracefullyRunsTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() loop.scheduleTask(deadline: loop.now) { callbackRan.store(true, ordering: .relaxed) } @@ -154,7 +148,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testShutdownGracefullyRunsTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() loop.scheduleTask(deadline: loop.now) { callbackRan.store(true, ordering: .relaxed) } @@ -169,7 +162,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testCanControlTime() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackCount = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() _ = loop.scheduleTask(in: .nanoseconds(5)) { @@ -198,7 +190,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testCanScheduleMultipleTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() for index in 1...10 { @@ -219,7 +210,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecutedTasksFromScheduledOnesAreRun() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() _ = loop.scheduleTask(in: .nanoseconds(5)) { @@ -240,7 +230,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksFromScheduledTasksProperlySchedule() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() _ = loop.scheduleTask(in: .nanoseconds(5)) { @@ -280,7 +269,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksFromExecutedTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() loop.execute { @@ -299,7 +287,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testCancellingScheduledTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let loop = NIOAsyncTestingEventLoop() let task = loop.scheduleTask(in: .nanoseconds(10), { XCTFail("Cancelled task ran") }) _ = loop.scheduleTask(in: .nanoseconds(5)) { @@ -310,7 +297,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksFuturesFire() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let fired = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() let task = loop.scheduleTask(in: .nanoseconds(5)) { true } @@ -323,7 +309,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksFuturesError() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let err = EmbeddedTestError() let fired = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() @@ -345,7 +330,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testTaskOrdering() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } // This test validates that the ordering of task firing on NIOAsyncTestingEventLoop via // advanceTime(by:) is the same as on MultiThreadedEventLoopGroup: specifically, that tasks run via // schedule that expire "now" all run at the same time, and that any work they schedule is run @@ -426,7 +410,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testCancelledScheduledTasksDoNotHoldOnToRunClosure() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() defer { XCTAssertNoThrow(try eventLoop.syncShutdownGracefully()) @@ -466,7 +449,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testDrainScheduledTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let tasksRun = ManagedAtomic(0) let startTime = eventLoop.now @@ -486,7 +468,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testDrainScheduledTasksDoesNotRunNewlyScheduledTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let tasksRun = ManagedAtomic(0) @@ -503,7 +484,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testAdvanceTimeToDeadline() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let deadline = NIODeadline.uptimeNanoseconds(0) + .seconds(42) @@ -517,7 +497,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testWeCantTimeTravelByAdvancingTimeToThePast() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let tasksRun = ManagedAtomic(0) @@ -539,7 +518,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecuteInOrder() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let counter = ManagedAtomic(0) @@ -563,7 +541,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksInOrder() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let counter = ManagedAtomic(0) diff --git a/Tests/NIOPosixTests/NIOThreadPoolTest.swift b/Tests/NIOPosixTests/NIOThreadPoolTest.swift index b51a96ad47..3a4d772781 100644 --- a/Tests/NIOPosixTests/NIOThreadPoolTest.swift +++ b/Tests/NIOPosixTests/NIOThreadPoolTest.swift @@ -19,6 +19,7 @@ import Dispatch import NIOConcurrencyHelpers import NIOEmbedded +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) class NIOThreadPoolTest: XCTestCase { func testThreadNamesAreSetUp() { let numberOfThreads = 11 @@ -112,7 +113,6 @@ class NIOThreadPoolTest: XCTestCase { } func testAsyncThreadPool() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let numberOfThreads = 1 let pool = NIOThreadPool(numberOfThreads: numberOfThreads) pool.start() @@ -127,7 +127,6 @@ class NIOThreadPoolTest: XCTestCase { } func testAsyncThreadPoolErrorPropagation() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } struct ThreadPoolError: Error {} let numberOfThreads = 1 let pool = NIOThreadPool(numberOfThreads: numberOfThreads) @@ -144,7 +143,6 @@ class NIOThreadPoolTest: XCTestCase { } func testAsyncThreadPoolNotActiveError() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } struct ThreadPoolError: Error {} let numberOfThreads = 1 let pool = NIOThreadPool(numberOfThreads: numberOfThreads) @@ -160,7 +158,6 @@ class NIOThreadPoolTest: XCTestCase { } func testAsyncShutdownWorks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let threadPool = NIOThreadPool(numberOfThreads: 17) let eventLoop = NIOAsyncTestingEventLoop() diff --git a/Tests/NIOPosixTests/SerialExecutorTests.swift b/Tests/NIOPosixTests/SerialExecutorTests.swift index 56bc882bec..85619b0ca3 100644 --- a/Tests/NIOPosixTests/SerialExecutorTests.swift +++ b/Tests/NIOPosixTests/SerialExecutorTests.swift @@ -37,14 +37,12 @@ actor EventLoopBoundActor { } #endif +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) final class SerialExecutorTests: XCTestCase { private func _testBasicExecutorFitsOnEventLoop(loop1: EventLoop, loop2: EventLoop) async throws { #if compiler(<5.9) throw XCTSkip("Custom executors are only supported in 5.9") #else - guard #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) else { - throw XCTSkip("Custom executors not available on this platform") - } let testActor = EventLoopBoundActor(loop: loop1) await testActor.assertInLoop(loop1) @@ -61,7 +59,6 @@ final class SerialExecutorTests: XCTestCase { try await self._testBasicExecutorFitsOnEventLoop(loop1: loops[0], loop2: loops[1]) } - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) func testBasicExecutorFitsOnEventLoop_AsyncTestingEventLoop() async throws { let loop1 = NIOAsyncTestingEventLoop() let loop2 = NIOAsyncTestingEventLoop() diff --git a/Tests/NIOPosixTests/TestUtils.swift b/Tests/NIOPosixTests/TestUtils.swift index 461273194e..0c38f94d8f 100644 --- a/Tests/NIOPosixTests/TestUtils.swift +++ b/Tests/NIOPosixTests/TestUtils.swift @@ -68,6 +68,7 @@ func withPipe(_ body: (NIOCore.NIOFileHandle, NIOCore.NIOFileHandle) throws -> [ } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) func withPipe(_ body: (NIOCore.NIOFileHandle, NIOCore.NIOFileHandle) async throws -> [NIOCore.NIOFileHandle]) async throws { var fds: [Int32] = [-1, -1] fds.withUnsafeMutableBufferPointer { ptr in @@ -98,6 +99,7 @@ func withTemporaryDirectory(_ body: (String) throws -> T) rethrows -> T { return try body(dir) } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) func withTemporaryDirectory(_ body: (String) async throws -> T) async rethrows -> T { let dir = createTemporaryDirectory() defer { @@ -170,6 +172,7 @@ func withTemporaryFile(content: String? = nil, _ body: (NIOCore.NIOFileHandle return try body(fileHandle, path) } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) func withTemporaryFile(content: String? = nil, _ body: @escaping @Sendable (NIOCore.NIOFileHandle, String) async throws -> T) async rethrows -> T { let (fd, path) = openTemporaryFile() let fileHandle = NIOFileHandle(descriptor: fd) diff --git a/Tests/NIOSingletonsTests/GlobalSingletonsTests.swift b/Tests/NIOSingletonsTests/GlobalSingletonsTests.swift index 0c534568d1..d55fab898e 100644 --- a/Tests/NIOSingletonsTests/GlobalSingletonsTests.swift +++ b/Tests/NIOSingletonsTests/GlobalSingletonsTests.swift @@ -18,11 +18,13 @@ import NIOPosix import Foundation final class NIOSingletonsTests: XCTestCase { + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) func testSingletonMultiThreadedEventLoopWorks() async throws { let works = try await MultiThreadedEventLoopGroup.singleton.any().submit { "yes" }.get() XCTAssertEqual(works, "yes") } + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) func testSingletonBlockingPoolWorks() async throws { let works = try await NIOThreadPool.singleton.runIfActive( eventLoop: MultiThreadedEventLoopGroup.singleton.any() From 4c77ef01844a6551c3de441fd0a12174566673af Mon Sep 17 00:00:00 2001 From: Si Beaumont Date: Tue, 19 Dec 2023 16:34:36 +0000 Subject: [PATCH 60/64] Fix building tests on Swift 5.9.2 Linux (#2610) --- Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift index 52963e9004..9b93841252 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift @@ -652,7 +652,7 @@ final class NIOAsyncWriterTests: XCTestCase { } } -#if !canImport(Darwin) && swift(<5.10) +#if !canImport(Darwin) && swift(<5.9.2) extension XCTestCase { func fulfillment( of expectations: [XCTestExpectation], From 23e995b3af75c54a192c170cc80dffa446a262ce Mon Sep 17 00:00:00 2001 From: David Nadoba Date: Fri, 22 Dec 2023 10:07:35 +0000 Subject: [PATCH 61/64] Set `SWIFT_VERSION` environment variable to resolve to the correct benchmarks thresholds path (#2613) * Set `SWIFT_VERSION` environment variable to resolve to the correct benchmarks thresholds path * mallocs have increased * update benchmark results manually * update thresholds again * disable flaky benchmark --- .../NIOPosixBenchmarks/Benchmarks.swift | 44 ++++++++++--------- .../5.10/NIOPosixBenchmarks.TCPEcho.p90.json | 4 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 3 -- .../5.7/NIOPosixBenchmarks.TCPEcho.p90.json | 4 +- .../5.8/NIOPosixBenchmarks.TCPEcho.p90.json | 4 +- .../5.9/NIOPosixBenchmarks.TCPEcho.p90.json | 4 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 3 -- .../main/NIOPosixBenchmarks.TCPEcho.p90.json | 4 +- ...sixBenchmarks.TCPEchoAsyncChannel.p90.json | 3 -- docker/docker-compose.2204.510.yaml | 1 + docker/docker-compose.2204.57.yaml | 1 + docker/docker-compose.2204.58.yaml | 1 + docker/docker-compose.2204.59.yaml | 1 + docker/docker-compose.2204.main.yaml | 1 + 14 files changed, 38 insertions(+), 40 deletions(-) delete mode 100644 Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json delete mode 100644 Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json delete mode 100644 Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift index 56bb64dd61..521ce7df00 100644 --- a/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift @@ -39,26 +39,28 @@ let benchmarks = { // This benchmark is only available above 5.9 since our EL conformance // to serial executor is also gated behind 5.9. #if compiler(>=5.9) - Benchmark( - "TCPEchoAsyncChannel", - configuration: .init( - metrics: defaultMetrics, - timeUnits: .milliseconds, - scalingFactor: .mega, - setup: { - swiftTaskEnqueueGlobalHook = { job, _ in - eventLoop.executor.enqueue(job) - } - }, - teardown: { - swiftTaskEnqueueGlobalHook = nil - } - ) - ) { benchmark in - try await runTCPEchoAsyncChannel( - numberOfWrites: benchmark.scaledIterations.upperBound, - eventLoop: eventLoop - ) - } +// In addition this benchmark currently doesn't produce deterministic results on our CI +// and therefore is currently disabled +// Benchmark( +// "TCPEchoAsyncChannel", +// configuration: .init( +// metrics: defaultMetrics, +// timeUnits: .milliseconds, +// scalingFactor: .mega, +// setup: { +// swiftTaskEnqueueGlobalHook = { job, _ in +// eventLoop.executor.enqueue(job) +// } +// }, +// teardown: { +// swiftTaskEnqueueGlobalHook = nil +// } +// ) +// ) { benchmark in +// try await runTCPEchoAsyncChannel( +// numberOfWrites: benchmark.scaledIterations.upperBound, +// eventLoop: eventLoop +// ) +// } #endif } diff --git a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json index fa70aea890..c6a93680d0 100644 --- a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 90 -} \ No newline at end of file + "mallocCountTotal" : 108 +} diff --git a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json deleted file mode 100644 index 74498fb02f..0000000000 --- a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "mallocCountTotal" : 164419 -} diff --git a/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json index 1859f424c5..248bd96061 100644 --- a/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 92 -} \ No newline at end of file + "mallocCountTotal" : 110 +} diff --git a/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json index 1859f424c5..248bd96061 100644 --- a/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 92 -} \ No newline at end of file + "mallocCountTotal" : 110 +} diff --git a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json index 1859f424c5..248bd96061 100644 --- a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 92 -} \ No newline at end of file + "mallocCountTotal" : 110 +} diff --git a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json deleted file mode 100644 index c38c7cbbfd..0000000000 --- a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "mallocCountTotal" : 164426 -} \ No newline at end of file diff --git a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json index fa70aea890..c6a93680d0 100644 --- a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 90 -} \ No newline at end of file + "mallocCountTotal" : 108 +} diff --git a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json deleted file mode 100644 index 617e73531c..0000000000 --- a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEchoAsyncChannel.p90.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "mallocCountTotal" : 164419 -} \ No newline at end of file diff --git a/docker/docker-compose.2204.510.yaml b/docker/docker-compose.2204.510.yaml index a9089f4982..757084acc9 100644 --- a/docker/docker-compose.2204.510.yaml +++ b/docker/docker-compose.2204.510.yaml @@ -20,6 +20,7 @@ services: test: image: swift-nio:22.04-5.10 environment: + - SWIFT_VERSION=5.10 - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=21 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 diff --git a/docker/docker-compose.2204.57.yaml b/docker/docker-compose.2204.57.yaml index d6d16f9c34..415223ba5b 100644 --- a/docker/docker-compose.2204.57.yaml +++ b/docker/docker-compose.2204.57.yaml @@ -21,6 +21,7 @@ services: test: image: swift-nio:22.04-5.7 environment: + - SWIFT_VERSION=5.7 - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=22 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 diff --git a/docker/docker-compose.2204.58.yaml b/docker/docker-compose.2204.58.yaml index 827f76771e..af365fa86e 100644 --- a/docker/docker-compose.2204.58.yaml +++ b/docker/docker-compose.2204.58.yaml @@ -21,6 +21,7 @@ services: test: image: swift-nio:22.04-5.8 environment: + - SWIFT_VERSION=5.8 - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=22 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 diff --git a/docker/docker-compose.2204.59.yaml b/docker/docker-compose.2204.59.yaml index c348433d9e..47a13adc14 100644 --- a/docker/docker-compose.2204.59.yaml +++ b/docker/docker-compose.2204.59.yaml @@ -21,6 +21,7 @@ services: test: image: swift-nio:22.04-5.9 environment: + - SWIFT_VERSION=5.9 - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=21 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 diff --git a/docker/docker-compose.2204.main.yaml b/docker/docker-compose.2204.main.yaml index 5d08b84b1f..6198161e61 100644 --- a/docker/docker-compose.2204.main.yaml +++ b/docker/docker-compose.2204.main.yaml @@ -20,6 +20,7 @@ services: test: image: swift-nio:22.04-main environment: + - SWIFT_VERSION=main - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=21 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 From 1445dcabeb6ee6c32006899d8977c77088e80dba Mon Sep 17 00:00:00 2001 From: Gustavo Cairo Date: Fri, 22 Dec 2023 07:47:32 -0300 Subject: [PATCH 62/64] Add cxx interop build pipeline (#2614) * Add cxx interop build pipeline * Add imports to source * Update docker-compose for 5.9 and 5.10 * Remove test target --------- Co-authored-by: George Barnett Co-authored-by: David Nadoba --- docker/docker-compose.2204.510.yaml | 3 ++ docker/docker-compose.2204.59.yaml | 3 ++ docker/docker-compose.yaml | 4 ++ scripts/cxx-interop-compatibility.sh | 79 ++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+) create mode 100755 scripts/cxx-interop-compatibility.sh diff --git a/docker/docker-compose.2204.510.yaml b/docker/docker-compose.2204.510.yaml index 757084acc9..7016963d72 100644 --- a/docker/docker-compose.2204.510.yaml +++ b/docker/docker-compose.2204.510.yaml @@ -90,3 +90,6 @@ services: http: image: swift-nio:22.04-5.10 + + cxx-interop-build: + image: swift-nio:22.04-5.10 diff --git a/docker/docker-compose.2204.59.yaml b/docker/docker-compose.2204.59.yaml index 47a13adc14..889c4cb8b8 100644 --- a/docker/docker-compose.2204.59.yaml +++ b/docker/docker-compose.2204.59.yaml @@ -91,3 +91,6 @@ services: http: image: swift-nio:22.04-5.9 + + cxx-interop-build: + image: swift-nio:22.04-5.9 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index c563622f75..0c3c846e81 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -49,6 +49,10 @@ services: update-benchmark-baseline: <<: *common command: /bin/bash -xcl "cd Benchmarks && swift package --disable-sandbox --scratch-path .build/$${SWIFT_VERSION-}/ --allow-writing-to-package-directory benchmark --format metricP90AbsoluteThresholds --path Thresholds/$${SWIFT_VERSION-}/" + + cxx-interop-build: + <<: *common + command: /bin/bash -xcl "./scripts/cxx-interop-compatibility.sh" # util diff --git a/scripts/cxx-interop-compatibility.sh b/scripts/cxx-interop-compatibility.sh new file mode 100755 index 0000000000..4109810786 --- /dev/null +++ b/scripts/cxx-interop-compatibility.sh @@ -0,0 +1,79 @@ +#!/bin/bash +##===----------------------------------------------------------------------===## +## +## This source file is part of the SwiftNIO open source project +## +## Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +## Licensed under Apache License v2.0 +## +## See LICENSE.txt for license information +## See CONTRIBUTORS.txt for the list of SwiftNIO project authors +## +## SPDX-License-Identifier: Apache-2.0 +## +##===----------------------------------------------------------------------===## + +set -eu + +sourcedir=$(pwd) +workingdir=$(mktemp -d) +projectname=$(basename $workingdir) + +cd $workingdir +swift package init + +cat << EOF > Package.swift +// swift-tools-version: 5.9 + +import PackageDescription + +let package = Package( + name: "interop", + products: [ + .library( + name: "interop", + targets: ["interop"] + ), + ], + dependencies: [ + .package(path: "$sourcedir") + ], + targets: [ + .target( + name: "interop", + // Depend on all products of swift-nio to make sure they're all + // compatible with cxx interop. + dependencies: [ + .product(name: "NIO", package: "swift-nio"), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + .product(name: "NIOTLS", package: "swift-nio"), + .product(name: "NIOEmbedded", package: "swift-nio"), + .product(name: "NIOPosix", package: "swift-nio"), + .product(name: "NIOHTTP1", package: "swift-nio"), + .product(name: "NIOFoundationCompat", package: "swift-nio"), + .product(name: "NIOWebSocket", package: "swift-nio"), + .product(name: "NIOTestUtils", package: "swift-nio"), + .product(name: "_NIOConcurrency", package: "swift-nio") + ], + swiftSettings: [.interoperabilityMode(.Cxx)] + ) + ] +) +EOF + +cat << EOF > Sources/$projectname/$(echo $projectname | tr . _).swift +import NIO +import NIOCore +import NIOConcurrencyHelpers +import NIOTLS +import NIOEmbedded +import NIOPosix +import NIOHTTP1 +import NIOFoundationCompat +import NIOWebSocket +import NIOTestUtils +import _NIOConcurrency +EOF + +swift build From 5c668eb47ec74202c146ce4e679555b86c3a38b9 Mon Sep 17 00:00:00 2001 From: Johannes Weiss Date: Tue, 2 Jan 2024 14:57:31 +0000 Subject: [PATCH 63/64] allow setting MTELG.singleton as Swift Concurrency executor (#2564) --- .../NIOCrashTester/CrashTests+EventLoop.swift | 33 ++++- .../MultiThreadedEventLoopGroup.swift | 6 +- .../PosixSingletons+ConcurrencyTakeOver.swift | 123 ++++++++++++++++++ Sources/NIOPosix/SelectableEventLoop.swift | 4 +- Sources/NIOTCPEchoClient/Client.swift | 3 +- Sources/NIOTCPEchoServer/Server.swift | 3 +- Sources/NIOWebSocketClient/Client.swift | 3 +- Sources/NIOWebSocketServer/Server.swift | 3 +- 8 files changed, 168 insertions(+), 10 deletions(-) create mode 100644 Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift diff --git a/Sources/NIOCrashTester/CrashTests+EventLoop.swift b/Sources/NIOCrashTester/CrashTests+EventLoop.swift index 658d5f36ae..7ce7f4a3ef 100644 --- a/Sources/NIOCrashTester/CrashTests+EventLoop.swift +++ b/Sources/NIOCrashTester/CrashTests+EventLoop.swift @@ -11,7 +11,9 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + #if !canImport(Darwin) || os(macOS) +import Dispatch import NIOCore import NIOPosix @@ -177,5 +179,34 @@ struct EventLoopCrashTests { ) { NIOSingletons.groupLoopCountSuggestion = -1 } + + #if compiler(>=5.9) // We only support Concurrency executor take-over on 5.9+ + let testInstallingSingletonMTELGAsConcurrencyExecutorWorksButOnlyOnce = CrashTest( + regex: #"Fatal error: Must be called only once"# + ) { + guard NIOSingletons.unsafeTryInstallSingletonPosixEventLoopGroupAsConcurrencyGlobalExecutor() else { + print("Installation failed, that's unexpected -> let's not crash") + return + } + + // Yes, this pattern is bad abuse but this is a crash test, we don't mind. + let semaphoreAbuse = DispatchSemaphore(value: 0) + if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) { + Task { + precondition(MultiThreadedEventLoopGroup.currentEventLoop != nil) + try await Task.sleep(nanoseconds: 123) + precondition(MultiThreadedEventLoopGroup.currentEventLoop != nil) + semaphoreAbuse.signal() + } + } else { + semaphoreAbuse.signal() + } + semaphoreAbuse.wait() + print("Okay, worked") + + // This should crash + _ = NIOSingletons.unsafeTryInstallSingletonPosixEventLoopGroupAsConcurrencyGlobalExecutor() + } + #endif // compiler(>=5.9) } -#endif +#endif // !canImport(Darwin) || os(macOS) diff --git a/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift b/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift index 0a5d87c6c3..622122dea6 100644 --- a/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift +++ b/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift @@ -402,7 +402,7 @@ extension MultiThreadedEventLoopGroup: CustomStringConvertible { } } -#if swift(>=5.9) +#if compiler(>=5.9) @usableFromInline struct ErasedUnownedJob { @usableFromInline @@ -427,7 +427,7 @@ internal struct ScheduledTask { @usableFromInline enum UnderlyingTask { case function(() -> Void) - #if swift(>=5.9) + #if compiler(>=5.9) case unownedJob(ErasedUnownedJob) #endif } @@ -452,7 +452,7 @@ internal struct ScheduledTask { self.readyTime = time } - #if swift(>=5.9) + #if compiler(>=5.9) @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) @usableFromInline init(id: UInt64, job: consuming ExecutorJob, readyTime: NIODeadline) { diff --git a/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift b/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift new file mode 100644 index 0000000000..ad50430c9f --- /dev/null +++ b/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift @@ -0,0 +1,123 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Atomics +import NIOCore + +#if compiler(>=5.9) +private protocol SilenceWarning { + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + func enqueue(_ job: UnownedJob) +} +@available(macOS 14, *) +extension SelectableEventLoop: SilenceWarning {} +#endif + +private let _haveWeTakenOverTheConcurrencyPool = ManagedAtomic(false) +extension NIOSingletons { + /// Install ``MultiThreadedEventLoopGroup/singleton`` as Swift Concurrency's global executor. + /// + /// This allows to use Swift Concurrency and retain high-performance I/O alleviating the otherwise necessary thread switches between + /// Swift Concurrency's own global pool and a place (like an `EventLoop`) that allows to perform I/O + /// + /// This method uses an atomic compare and exchange to install the hook (and makes sure it's not already set). This unilateral atomic memory + /// operation doesn't guarantee anything because another piece of code could have done the same without using atomic operations. But we + /// do our best. + /// + /// - warning: You may only call this method from the main thread. + /// - warning: You may only call this method once. + @discardableResult + public static func unsafeTryInstallSingletonPosixEventLoopGroupAsConcurrencyGlobalExecutor() -> Bool { + #if /* minimum supported */ compiler(>=5.9) && /* maximum tested */ swift(<5.11) + guard #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) else { + return false + } + + typealias ConcurrencyEnqueueGlobalHook = @convention(thin) ( + UnownedJob, @convention(thin) (UnownedJob) -> Void + ) -> Void + + guard + _haveWeTakenOverTheConcurrencyPool.compareExchange( + expected: false, + desired: true, + ordering: .relaxed + ).exchanged + else { + fatalError("Must be called only once") + } + + #if canImport(Darwin) + guard pthread_main_np() == 1 else { + fatalError("Must be called from the main thread") + } + #endif + + // Unsafe 1: We pretend that the hook's type is actually fully equivalent to `ConcurrencyEnqueueGlobalHook` + // @convention(thin) (UnownedJob, @convention(thin) (UnownedJob) -> Void) -> Void + // which isn't formally guaranteed. + let concurrencyEnqueueGlobalHookPtr = dlsym( + dlopen(nil, RTLD_NOW), + "swift_task_enqueueGlobal_hook" + )?.assumingMemoryBound(to: UnsafeRawPointer?.AtomicRepresentation.self) + guard let concurrencyEnqueueGlobalHookPtr = concurrencyEnqueueGlobalHookPtr else { + return false + } + + // We will use an atomic operation to swap the pointers aiming to protect against other code that attempts + // to swap the pointer. This isn't guaranteed to work as we can't be sure that the other code will actually + // use atomic compares and exchanges to. Nevertheless, we're doing our best. + let concurrencyEnqueueGlobalHookAtomic = UnsafeAtomic(at: concurrencyEnqueueGlobalHookPtr) + // note: We don't need to destroy this atomic as we're borrowing the storage from `dlsym`. + + return withUnsafeTemporaryAllocation( + of: ConcurrencyEnqueueGlobalHook.self, + capacity: 1 + ) { enqueueOnNIOPtr -> Bool in + // Unsafe 2: We mandate that we're actually getting _the_ function pointer to the closure below which + // isn't formally guaranteed by Swift. + enqueueOnNIOPtr.baseAddress!.initialize(to: { job, _ in + // This formally picks a random EventLoop from the singleton group. However, `EventLoopGroup.any()` + // attempts to be sticky. So if we're already in an `EventLoop` that's part of the singleton + // `EventLoopGroup`, we'll get that one and be very fast (avoid a thread switch). + let targetEL = MultiThreadedEventLoopGroup.singleton.any() + + (targetEL.executor as! any SilenceWarning).enqueue(job) + }) + + // Unsafe 3: We mandate that the function pointer can be reinterpreted as `UnsafeRawPointer` which isn't + // formally guaranteed by Swift + return enqueueOnNIOPtr.baseAddress!.withMemoryRebound( + to: UnsafeRawPointer.self, + capacity: 1 + ) { enqueueOnNIOPtr in + // Unsafe 4: We just pretend that we're the only ones in the world pulling this trick (or at least + // that the others also use a `compareExchange`)... + guard concurrencyEnqueueGlobalHookAtomic.compareExchange( + expected: nil, + desired: enqueueOnNIOPtr.pointee, + ordering: .relaxed + ).exchanged else { + return false + } + + // nice, everything worked. + return true + } + } + #else + return false + #endif + } +} diff --git a/Sources/NIOPosix/SelectableEventLoop.swift b/Sources/NIOPosix/SelectableEventLoop.swift index d76dc6005d..3d498a968f 100644 --- a/Sources/NIOPosix/SelectableEventLoop.swift +++ b/Sources/NIOPosix/SelectableEventLoop.swift @@ -299,7 +299,7 @@ Further information: }, .now())) } - #if swift(>=5.9) + #if compiler(>=5.9) @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) @usableFromInline func enqueue(_ job: consuming ExecutorJob) { @@ -533,7 +533,7 @@ Further information: case .function(let function): function() - #if swift(>=5.9) + #if compiler(>=5.9) case .unownedJob(let erasedUnownedJob): if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) { erasedUnownedJob.unownedJob.runSynchronously(on: self.asUnownedSerialExecutor()) diff --git a/Sources/NIOTCPEchoClient/Client.swift b/Sources/NIOTCPEchoClient/Client.swift index c3fcb963dd..16f2d5da4e 100644 --- a/Sources/NIOTCPEchoClient/Client.swift +++ b/Sources/NIOTCPEchoClient/Client.swift @@ -11,7 +11,8 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if swift(>=5.9) + +#if compiler(>=5.9) import NIOCore import NIOPosix diff --git a/Sources/NIOTCPEchoServer/Server.swift b/Sources/NIOTCPEchoServer/Server.swift index 30b786b79d..1ccfccc33e 100644 --- a/Sources/NIOTCPEchoServer/Server.swift +++ b/Sources/NIOTCPEchoServer/Server.swift @@ -11,7 +11,8 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if swift(>=5.9) + +#if compiler(>=5.9) import NIOCore import NIOPosix diff --git a/Sources/NIOWebSocketClient/Client.swift b/Sources/NIOWebSocketClient/Client.swift index 5efa89993a..0bf42d3b16 100644 --- a/Sources/NIOWebSocketClient/Client.swift +++ b/Sources/NIOWebSocketClient/Client.swift @@ -11,7 +11,8 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if (!canImport(Darwin) && swift(>=5.9)) || (canImport(Darwin) && swift(>=5.10)) + +#if (!canImport(Darwin) && compiler(>=5.9)) || (canImport(Darwin) && compiler(>=5.10)) import NIOCore import NIOPosix import NIOHTTP1 diff --git a/Sources/NIOWebSocketServer/Server.swift b/Sources/NIOWebSocketServer/Server.swift index 9ef311c57a..560fe16e39 100644 --- a/Sources/NIOWebSocketServer/Server.swift +++ b/Sources/NIOWebSocketServer/Server.swift @@ -11,7 +11,8 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if (!canImport(Darwin) && swift(>=5.9)) || (canImport(Darwin) && swift(>=5.10)) + +#if (!canImport(Darwin) && compiler(>=5.9)) || (canImport(Darwin) && compiler(>=5.10)) import NIOCore import NIOPosix import NIOHTTP1 From 52908578ed60747714836f3cad5eb51969b6a582 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Wed, 3 Jan 2024 07:16:25 +0000 Subject: [PATCH 64/64] Fix the broken performance test binary (#2619) Motivation: The performance test binary was crashing ever since #2589 added the crash on deinit flow. Crashes here are preventing us from using the performance tester. Modifications: Correctly clean up the async writer. Result: The writer is cleaned up now. --- .../NIOAsyncWriterSingleWritesBenchmark.swift | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift b/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift index 8ee2cda9d5..ab14f33182 100644 --- a/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift +++ b/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift @@ -44,7 +44,9 @@ final class NIOAsyncWriterSingleWritesBenchmark: AsyncBenchmark, @unchecked Send } func setUp() async throws {} - func tearDown() {} + func tearDown() { + self.writer.finish() + } func run() async throws -> Int { for i in 0..