diff --git a/.gitignore b/.gitignore index a70f621b48..ec69488c87 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ Package.resolved DerivedData .swiftpm .*.sw? +.vscode/launch.json diff --git a/Benchmarks/.gitignore b/Benchmarks/.gitignore new file mode 100644 index 0000000000..2517bcdfa8 --- /dev/null +++ b/Benchmarks/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +/.build +/Packages +xcuserdata/ +DerivedData/ +.swiftpm/configuration/registries.json +.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata +.netrc +.benchmarkBaselines/ \ No newline at end of file diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift new file mode 100644 index 0000000000..521ce7df00 --- /dev/null +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/Benchmarks.swift @@ -0,0 +1,66 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Benchmark +import NIOPosix + +private let eventLoop = MultiThreadedEventLoopGroup.singleton.next() + +let benchmarks = { + let defaultMetrics: [BenchmarkMetric] = [ + .mallocCountTotal, + ] + + Benchmark( + "TCPEcho", + configuration: .init( + metrics: defaultMetrics, + timeUnits: .milliseconds, + scalingFactor: .mega + ) + ) { benchmark in + try runTCPEcho( + numberOfWrites: benchmark.scaledIterations.upperBound, + eventLoop: eventLoop + ) + } + + // This benchmark is only available above 5.9 since our EL conformance + // to serial executor is also gated behind 5.9. + #if compiler(>=5.9) +// In addition this benchmark currently doesn't produce deterministic results on our CI +// and therefore is currently disabled +// Benchmark( +// "TCPEchoAsyncChannel", +// configuration: .init( +// metrics: defaultMetrics, +// timeUnits: .milliseconds, +// scalingFactor: .mega, +// setup: { +// swiftTaskEnqueueGlobalHook = { job, _ in +// eventLoop.executor.enqueue(job) +// } +// }, +// teardown: { +// swiftTaskEnqueueGlobalHook = nil +// } +// ) +// ) { benchmark in +// try await runTCPEchoAsyncChannel( +// numberOfWrites: benchmark.scaledIterations.upperBound, +// eventLoop: eventLoop +// ) +// } + #endif +} diff --git a/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift new file mode 100644 index 0000000000..a1ca7a5df4 --- /dev/null +++ b/Benchmarks/Benchmarks/NIOPosixBenchmarks/TCPEcho.swift @@ -0,0 +1,90 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore +import NIOPosix + +private final class EchoChannelHandler: ChannelInboundHandler { + fileprivate typealias InboundIn = ByteBuffer + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + context.writeAndFlush(data, promise: nil) + } +} + +private final class EchoRequestChannelHandler: ChannelInboundHandler { + fileprivate typealias InboundIn = ByteBuffer + + private let messageSize = 10000 + private let numberOfWrites: Int + private var batchCount = 0 + private let data: NIOAny + private let readsCompletePromise: EventLoopPromise + private var receivedData = 0 + + init(numberOfWrites: Int, readsCompletePromise: EventLoopPromise) { + self.numberOfWrites = numberOfWrites + self.readsCompletePromise = readsCompletePromise + self.data = NIOAny(ByteBuffer(repeating: 0, count: self.messageSize)) + } + + func channelActive(context: ChannelHandlerContext) { + for _ in 0.. Void) -> Void + +var swiftTaskEnqueueGlobalHook: EnqueueGlobalHook? { + get { _swiftTaskEnqueueGlobalHook.pointee } + set { _swiftTaskEnqueueGlobalHook.pointee = newValue } +} + +private let _swiftTaskEnqueueGlobalHook: UnsafeMutablePointer = + dlsym(dlopen(nil, RTLD_LAZY), "swift_task_enqueueGlobal_hook").assumingMemoryBound(to: EnqueueGlobalHook?.self) diff --git a/Benchmarks/Package.swift b/Benchmarks/Package.swift new file mode 100644 index 0000000000..8797a9249f --- /dev/null +++ b/Benchmarks/Package.swift @@ -0,0 +1,41 @@ +// swift-tools-version: 5.7 +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftCertificates open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftCertificates project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftCertificates project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import PackageDescription + +let package = Package( + name: "benchmarks", + platforms: [ + .macOS("14"), + ], + dependencies: [ + .package(path: "../"), + .package(url: "https://github.com/ordo-one/package-benchmark.git", from: "1.11.1"), + ], + targets: [ + .executableTarget( + name: "NIOPosixBenchmarks", + dependencies: [ + .product(name: "Benchmark", package: "package-benchmark"), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOPosix", package: "swift-nio"), + ], + path: "Benchmarks/NIOPosixBenchmarks", + plugins: [ + .plugin(name: "BenchmarkPlugin", package: "package-benchmark") + ] + ), + ] +) diff --git a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json new file mode 100644 index 0000000000..c6a93680d0 --- /dev/null +++ b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json @@ -0,0 +1,3 @@ +{ + "mallocCountTotal" : 108 +} diff --git a/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json new file mode 100644 index 0000000000..248bd96061 --- /dev/null +++ b/Benchmarks/Thresholds/5.7/NIOPosixBenchmarks.TCPEcho.p90.json @@ -0,0 +1,3 @@ +{ + "mallocCountTotal" : 110 +} diff --git a/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json new file mode 100644 index 0000000000..248bd96061 --- /dev/null +++ b/Benchmarks/Thresholds/5.8/NIOPosixBenchmarks.TCPEcho.p90.json @@ -0,0 +1,3 @@ +{ + "mallocCountTotal" : 110 +} diff --git a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json new file mode 100644 index 0000000000..248bd96061 --- /dev/null +++ b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json @@ -0,0 +1,3 @@ +{ + "mallocCountTotal" : 110 +} diff --git a/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json new file mode 100644 index 0000000000..c6a93680d0 --- /dev/null +++ b/Benchmarks/Thresholds/main/NIOPosixBenchmarks.TCPEcho.p90.json @@ -0,0 +1,3 @@ +{ + "mallocCountTotal" : 108 +} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index eab29e7456..f4953f0d60 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -72,7 +72,7 @@ For this reason, whenever you add new tests **you have to run a script** that ge ### Make sure your patch works for all supported versions of swift -The CI will do this for you. You can use the docker-compose files included if you wish to check locally. Currently all versions of swift >= 5.6 are supported. For example usage of docker compose see the main [README](./README.md#an-alternative-using-docker-compose) +The CI will do this for you. You can use the docker-compose files included if you wish to check locally. Currently all versions of swift >= 5.7 are supported. For example usage of docker compose see the main [README](./README.md#an-alternative-using-docker-compose) ### Make sure your code is performant diff --git a/IntegrationTests/tests_02_syscall_wrappers/defines.sh b/IntegrationTests/tests_02_syscall_wrappers/defines.sh index ac7793b046..c97d1aca84 100644 --- a/IntegrationTests/tests_02_syscall_wrappers/defines.sh +++ b/IntegrationTests/tests_02_syscall_wrappers/defines.sh @@ -22,7 +22,7 @@ function make_package() { fi cat > "$tmpdir/syscallwrapper/Package.swift" <<"EOF" -// swift-tools-version:5.6 +// swift-tools-version:5.7 // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription @@ -36,7 +36,10 @@ let package = Package( dependencies: ["CNIOLinux", "CNIODarwin", "NIOCore"]), .target( name: "CNIOLinux", - dependencies: []), + dependencies: [], + cSettings: [ + .define("_GNU_SOURCE") + ]), .target( name: "CNIODarwin", dependencies: []), diff --git a/IntegrationTests/tests_05_assertions/defines.sh b/IntegrationTests/tests_05_assertions/defines.sh index 22ffb942d7..0c2ed428d1 100644 --- a/IntegrationTests/tests_05_assertions/defines.sh +++ b/IntegrationTests/tests_05_assertions/defines.sh @@ -31,7 +31,10 @@ let package = Package( dependencies: ["CNIOLinux", "CNIODarwin", "NIOCore"]), .target( name: "CNIOLinux", - dependencies: []), + dependencies: [], + cSettings: [ + .define("_GNU_SOURCE") + ]), .target( name: "CNIODarwin", dependencies: []), diff --git a/NOTICE.txt b/NOTICE.txt index 5090ff4f3e..e977e8e139 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -78,10 +78,21 @@ This product contains a derivation of Fabian Fett's 'Base64.swift'. * https://github.com/fabianfett/swift-base64-kit/blob/master/LICENSE * HOMEPAGE: * https://github.com/fabianfett/swift-base64-kit - + +--- + This product contains a derivation of "XCTest+AsyncAwait.swift" from AsyncHTTPClient. * LICENSE (Apache License 2.0): * https://www.apache.org/licenses/LICENSE-2.0 * HOMEPAGE: * https://github.com/swift-server/async-http-client + +--- + +This product contains a derivation of "_TinyArray.swift" from SwiftCertificates. + + * LICENSE (Apache License 2.0): + * https://www.apache.org/licenses/LICENSE-2.0 + * HOMEPAGE: + * https://github.com/apple/swift-certificates diff --git a/Package.swift b/Package.swift index c23c131429..ac54050d6d 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.6 +// swift-tools-version:5.7 //===----------------------------------------------------------------------===// // // This source file is part of the SwiftNIO open source project @@ -50,6 +50,7 @@ let package = Package( "CNIODarwin", "CNIOLinux", "CNIOWindows", + "_NIODataStructures", swiftCollections, swiftAtomics, ] @@ -106,7 +107,10 @@ let package = Package( ), .target( name: "CNIOAtomics", - dependencies: [] + dependencies: [], + cSettings: [ + .define("_GNU_SOURCE"), + ] ), .target( name: "CNIOSHA1", @@ -114,7 +118,10 @@ let package = Package( ), .target( name: "CNIOLinux", - dependencies: [] + dependencies: [], + cSettings: [ + .define("_GNU_SOURCE"), + ] ), .target( name: "CNIODarwin", @@ -140,6 +147,7 @@ let package = Package( "NIOCore", "NIOConcurrencyHelpers", "CNIOLLHTTP", + swiftCollections ] ), .target( @@ -154,7 +162,10 @@ let package = Package( ), .target( name: "CNIOLLHTTP", - cSettings: [.define("LLHTTP_STRICT_MODE")] + cSettings: [ + .define("_GNU_SOURCE"), + .define("LLHTTP_STRICT_MODE") + ] ), .target( name: "NIOTLS", diff --git a/README.md b/README.md index 55b77050d5..725f5cd30c 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ It's like [Netty](https://netty.io), but written for Swift. The SwiftNIO project is split across multiple repositories: -Repository | NIO 2 (Swift 5.6+) +Repository | NIO 2 (Swift 5.7+) --- | --- [https://github.com/apple/swift-nio][repo-nio]
SwiftNIO core | `from: "2.0.0"` [https://github.com/apple/swift-nio-ssl][repo-nio-ssl]
TLS (SSL) support | `from: "2.0.0"` @@ -60,7 +60,7 @@ Protocol | Client | Server | Repository | Module | Comment --- | --- | --- | --- | --- | --- HTTP | ✅| ❌ | [swift-server/async-http-client](https://github.com/swift-server/async-http-client) | `AsyncHTTPClient` | SSWG community project gRPC | ✅| ✅ | [grpc/grpc-swift](https://github.com/grpc/grpc-swift) | `GRPC` | also offers a low-level API; SSWG community project -APNS | ✅ | ❌ | [kylebrowning/APNSwift](https://github.com/kylebrowning/APNSwift) | `APNSwift` | SSWG community project +APNS | ✅ | ❌ | [swift-server-community/APNSwift](https://github.com/swift-server-community/APNSwift) | `APNSwift` | SSWG community project PostgreSQL | ✅ | ❌ | [vapor/postgres-nio](https://github.com/vapor/postgres-nio) | `PostgresNIO` | SSWG community project Redis | ✅ | ❌ | [swift-server/RediStack](https://github.com/swift-server/RediStack) | `RediStack` | SSWG community project @@ -70,7 +70,7 @@ Redis | ✅ | ❌ | [swift-server/RediStack](https://github.com/swift-server/Red This is the current version of SwiftNIO and will be supported for the foreseeable future. -The most recent versions of SwiftNIO support Swift 5.6 and newer. The minimum Swift version supported by SwiftNIO releases are detailed below: +The most recent versions of SwiftNIO support Swift 5.7 and newer. The minimum Swift version supported by SwiftNIO releases are detailed below: SwiftNIO | Minimum Swift Version --------------------|---------------------- @@ -78,7 +78,8 @@ SwiftNIO | Minimum Swift Version `2.30.0 ..< 2.40.0` | 5.2 `2.40.0 ..< 2.43.0` | 5.4 `2.43.0 ..< 2.51.0` | 5.5.2 -`2.51.0 ...` | 5.6 +`2.51.0 ..< 2.60.0` | 5.6 +`2.60.0 ...` | 5.7 ### SwiftNIO 1 SwiftNIO 1 is considered end of life - it is strongly recommended that you move to a newer version. The Core NIO team does not actively work on this version. No new features will be added to this version but PRs which fix bugs or security vulnerabilities will be accepted until the end of May 2022. @@ -159,7 +160,7 @@ In general, [`ChannelHandler`][ch]s are designed to be highly re-usable componen SwiftNIO ships with many [`ChannelHandler`][ch]s built in that provide useful functionality, such as HTTP parsing. In addition, high-performance applications will want to provide as much of their logic as possible in [`ChannelHandler`][ch]s, as it helps avoid problems with context switching. -Additionally, SwiftNIO ships with a few [`Channel`][c] implementations. In particular, it ships with `ServerSocketChannel`, a [`Channel`][c] for sockets that accept inbound connections; `SocketChannel`, a [`Channel`][c] for TCP connections; and `DatagramChannel`, a [`Channel`][c] for UDP sockets. All of these are provided by the `NIOPosix` module. It also provides[`EmbeddedChannel`][ec], a [`Channel`][c] primarily used for testing, provided by the `NIOEmbedded` module. +Additionally, SwiftNIO ships with a few [`Channel`][c] implementations. In particular, it ships with `ServerSocketChannel`, a [`Channel`][c] for sockets that accept inbound connections; `SocketChannel`, a [`Channel`][c] for TCP connections; and `DatagramChannel`, a [`Channel`][c] for UDP sockets. All of these are provided by the `NIOPosix` module. It also provides [`EmbeddedChannel`][ec], a [`Channel`][c] primarily used for testing, provided by the `NIOEmbedded` module. ##### A Note on Blocking @@ -332,7 +333,7 @@ have a few prerequisites installed on your system. ### Linux -- Swift 5.6 or newer from [swift.org/download](https://swift.org/download/#releases). We always recommend to use the latest released version. +- Swift 5.7 or newer from [swift.org/download](https://swift.org/download/#releases). We always recommend to use the latest released version. - netcat (for integration tests only) - lsof (for integration tests only) - shasum (for integration tests only) @@ -351,6 +352,19 @@ apt-get install -y git curl libatomic1 libxml2 netcat-openbsd lsof perl dnf install swift-lang /usr/bin/nc /usr/bin/lsof /usr/bin/shasum ``` +### Benchmarks + +Benchmarks for `swift-nio` are in a separate Swift Package in the `Benchmarks` subfolder of this repository. +They use the [`package-benchmark`](https://github.com/ordo-one/package-benchmark) plugin. +Benchmarks depends on the [`jemalloc`](https://jemalloc.net) memory allocation library, which is used by `package-benchmark` to capture memory allocation statistics. +An installation guide can be found in the [Getting Started article](https://swiftpackageindex.com/ordo-one/package-benchmark/documentation/benchmark/gettingstarted#Installing-Prerequisites-and-Platform-Support) of `package-benchmark`. +Afterwards you can run the benchmarks from CLI by going to the `Benchmarks` subfolder (e.g. `cd Benchmarks`) and invoking: +``` +swift package benchmark +``` + +For more information please refer to `swift package benchmark --help` or the [documentation of `package-benchmark`](https://swiftpackageindex.com/ordo-one/package-benchmark/documentation/benchmark). + [ch]: https://swiftpackageindex.com/apple/swift-nio/main/documentation/niocore/channelhandler [c]: https://swiftpackageindex.com/apple/swift-nio/main/documentation/niocore/channel [chc]: https://swiftpackageindex.com/apple/swift-nio/main/documentation/niocore/channelhandlercontext diff --git a/SECURITY.md b/SECURITY.md index d1e7b82c60..4f548dcfa9 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -19,8 +19,10 @@ team would create the following patch releases: Swift 5.4 and later * NIO 2.50. + plus next patch release to address the issue for projects that support Swift 5.5.2 and later -* mainline + plus next patch release to address the issue for projects that support +* NIO 2.59. + plus next patch release to address the issue for projects that support Swift 5.6 and later +* mainline + plus next patch release to address the issue for projects that support + Swift 5.7 and later SwiftNIO 1.x is considered end of life and will not receive any security patches. diff --git a/Sources/CNIOLinux/shim.c b/Sources/CNIOLinux/shim.c index 236da99ff4..d56c3bba69 100644 --- a/Sources/CNIOLinux/shim.c +++ b/Sources/CNIOLinux/shim.c @@ -18,7 +18,10 @@ void CNIOLinux_i_do_nothing_just_working_around_a_darwin_toolchain_bug(void) {} #ifdef __linux__ -#define _GNU_SOURCE +#ifndef _GNU_SOURCE +#error You must define _GNU_SOURCE +#endif + #include #include #include diff --git a/Sources/CNIOSHA1/c_nio_sha1.c b/Sources/CNIOSHA1/c_nio_sha1.c index 965185e7b5..30b61c8793 100644 --- a/Sources/CNIOSHA1/c_nio_sha1.c +++ b/Sources/CNIOSHA1/c_nio_sha1.c @@ -4,6 +4,7 @@ - defined the __min_size macro inline - included sys/endian.h on Android - use welcoming language (soundness check) + - ensure BYTE_ORDER is defined */ /* $KAME: sha1.c,v 1.5 2000/11/08 06:13:08 itojun Exp $ */ /*- @@ -53,14 +54,15 @@ #endif #ifdef __ANDROID__ #include -#elif __linux__ +#elif defined(__linux__) || defined(__APPLE__) #include #endif - /* soundness check */ -#if BYTE_ORDER != BIG_ENDIAN +#if !defined(BYTE_ORDER) +#error "BYTE_ORDER not defined" +#elif BYTE_ORDER != BIG_ENDIAN # if BYTE_ORDER != LITTLE_ENDIAN # define unsupported 1 # endif diff --git a/Sources/CNIOSHA1/include/CNIOSHA1.h b/Sources/CNIOSHA1/include/CNIOSHA1.h index 6ed17f79f5..b4d8c9524d 100644 --- a/Sources/CNIOSHA1/include/CNIOSHA1.h +++ b/Sources/CNIOSHA1/include/CNIOSHA1.h @@ -4,6 +4,7 @@ - defined the __min_size macro inline - included sys/endian.h on Android - use welcoming language (soundness check) + - ensure BYTE_ORDER is defined */ /* $FreeBSD$ */ /* $KAME: sha1.h,v 1.5 2000/03/27 04:36:23 sumikawa Exp $ */ @@ -49,6 +50,10 @@ #include #include +#ifdef __cplusplus +extern "C" { +#endif + struct sha1_ctxt { union { uint8_t b8[20]; @@ -68,7 +73,12 @@ typedef struct sha1_ctxt SHA1_CTX; #define SHA1_RESULTLEN (160/8) +#ifdef __cplusplus +#define __min_size(x) (x) +#else #define __min_size(x) static (x) +#endif + extern void c_nio_sha1_init(struct sha1_ctxt *); extern void c_nio_sha1_pad(struct sha1_ctxt *); extern void c_nio_sha1_loop(struct sha1_ctxt *, const uint8_t *, size_t); @@ -79,5 +89,8 @@ extern void c_nio_sha1_result(struct sha1_ctxt *, char[__min_size(SHA1_RESULTLEN #define SHA1Update(x, y, z) c_nio_sha1_loop((x), (y), (z)) #define SHA1Final(x, y) c_nio_sha1_result((y), (x)) +#ifdef __cplusplus +} /* extern "C" */ +#endif #endif /*_CRYPTO_SHA1_H_*/ diff --git a/Sources/CNIOSHA1/update_and_patch_sha1.sh b/Sources/CNIOSHA1/update_and_patch_sha1.sh index c88709ebca..8557928a34 100755 --- a/Sources/CNIOSHA1/update_and_patch_sha1.sh +++ b/Sources/CNIOSHA1/update_and_patch_sha1.sh @@ -33,6 +33,7 @@ for f in sha1.c sha1.h; do echo " - defined the __min_size macro inline" echo " - included sys/endian.h on Android" echo " - use welcoming language (soundness check)" + echo " - ensure BYTE_ORDER is defined" echo "*/" curl -Ls "https://raw.githubusercontent.com/freebsd/freebsd/master/sys/crypto/$f" ) > "$here/c_nio_$f" @@ -52,8 +53,9 @@ $sed -e $'/#define _CRYPTO_SHA1_H_/a #include \\\n#include ' $sed -e 's/u_int\([0-9]\+\)_t/uint\1_t/g' \ -e '/^#include/d' \ - -e $'/__FBSDID/c #include "include/CNIOSHA1.h"\\n#include \\n#if !defined(bzero)\\n#define bzero(b,l) memset((b), \'\\\\0\', (l))\\n#endif\\n#if !defined(bcopy)\\n#define bcopy(s,d,l) memmove((d), (s), (l))\\n#endif\\n#ifdef __ANDROID__\\n#include \\n#elif __linux__\\n#include \\n#endif' \ + -e $'/__FBSDID/c #include "include/CNIOSHA1.h"\\n#include \\n#if !defined(bzero)\\n#define bzero(b,l) memset((b), \'\\\\0\', (l))\\n#endif\\n#if !defined(bcopy)\\n#define bcopy(s,d,l) memmove((d), (s), (l))\\n#endif\\n#ifdef __ANDROID__\\n#include \\n#elif defined(__linux__) || defined(__APPLE__)\\n#include \\n#endif' \ -e 's/sanit[y]/soundness/g' \ + -e 's/#if BYTE_ORDER != BIG_ENDIAN/#if !defined(BYTE_ORDER)\\n#error "BYTE_ORDER not defined"\\n#elif BYTE_ORDER != BIG_ENDIAN/' \ -i "$here/c_nio_sha1.c" mv "$here/c_nio_sha1.h" "$here/include/CNIOSHA1.h" diff --git a/Sources/NIOConcurrencyHelpers/lock.swift b/Sources/NIOConcurrencyHelpers/lock.swift index 2aac90ade1..5df4af7b15 100644 --- a/Sources/NIOConcurrencyHelpers/lock.swift +++ b/Sources/NIOConcurrencyHelpers/lock.swift @@ -295,5 +295,5 @@ internal func debugOnly(_ body: () -> Void) { } @available(*, deprecated) -extension Lock: Sendable {} +extension Lock: @unchecked Sendable {} extension ConditionLock: @unchecked Sendable {} diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift index 0fac8cce1d..3bbaaf4ba9 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift @@ -19,53 +19,51 @@ /// the following functionality: /// /// - reads are presented as an `AsyncSequence` -/// - writes can be written to with async functions on a writer, providing backpressure +/// - writes can be written to with async functions on a writer, providing back pressure /// - channels can be closed seamlessly /// /// This type does not replace the full complexity of NIO's ``Channel``. In particular, it /// does not expose the following functionality: /// /// - user events -/// - traditional NIO backpressure such as writability signals and the ``Channel/read()`` call +/// - traditional NIO back pressure such as writability signals and the ``Channel/read()`` call /// /// Users are encouraged to separate their ``ChannelHandler``s into those that implement /// protocol-specific logic (such as parsers and encoders) and those that implement business /// logic. Protocol-specific logic should be implemented as a ``ChannelHandler``, while business /// logic should use ``NIOAsyncChannel`` to consume and produce data to the network. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -@_spi(AsyncChannel) -public final class NIOAsyncChannel: Sendable { - @_spi(AsyncChannel) +public struct NIOAsyncChannel: Sendable { public struct Configuration: Sendable { - /// The backpressure strategy of the ``NIOAsyncChannel/inboundStream``. - public var backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark + /// The back pressure strategy of the ``NIOAsyncChannel/inbound``. + public var backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark /// If outbound half closure should be enabled. Outbound half closure is triggered once - /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. + /// the ``NIOAsyncChannelOutboundWriter`` is either finished or deinitialized. public var isOutboundHalfClosureEnabled: Bool - /// The ``NIOAsyncChannel/inboundStream`` message's type. + /// The ``NIOAsyncChannel/inbound`` message's type. public var inboundType: Inbound.Type - /// The ``NIOAsyncChannel/outboundWriter`` message's type. + /// The ``NIOAsyncChannel/outbound`` message's type. public var outboundType: Outbound.Type /// Initializes a new ``NIOAsyncChannel/Configuration``. /// /// - Parameters: - /// - backpressureStrategy: The backpressure strategy of the ``NIOAsyncChannel/inboundStream``. Defaults + /// - backPressureStrategy: The back pressure strategy of the ``NIOAsyncChannel/inbound``. Defaults /// to a watermarked strategy (lowWatermark: 2, highWatermark: 10). /// - isOutboundHalfClosureEnabled: If outbound half closure should be enabled. Outbound half closure is triggered once - /// the ``NIOAsyncChannelWriter`` is either finished or deinitialized. Defaults to `false`. - /// - inboundType: The ``NIOAsyncChannel/inboundStream`` message's type. - /// - outboundType: The ``NIOAsyncChannel/outboundWriter`` message's type. + /// the ``NIOAsyncChannelOutboundWriter`` is either finished or deinitialized. Defaults to `false`. + /// - inboundType: The ``NIOAsyncChannel/inbound`` message's type. + /// - outboundType: The ``NIOAsyncChannel/outbound`` message's type. public init( - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark = .init(lowWatermark: 2, highWatermark: 10), + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark = .init(lowWatermark: 2, highWatermark: 10), isOutboundHalfClosureEnabled: Bool = false, inboundType: Inbound.Type = Inbound.self, outboundType: Outbound.Type = Outbound.self ) { - self.backpressureStrategy = backpressureStrategy + self.backPressureStrategy = backPressureStrategy self.isOutboundHalfClosureEnabled = isOutboundHalfClosureEnabled self.inboundType = inboundType self.outboundType = outboundType @@ -73,14 +71,25 @@ public final class NIOAsyncChannel: Senda } /// The underlying channel being wrapped by this ``NIOAsyncChannel``. - @_spi(AsyncChannel) public let channel: Channel + /// The stream of inbound messages. - @_spi(AsyncChannel) - public let inboundStream: NIOAsyncChannelInboundStream + /// + /// - Important: The `inbound` stream is a unicast `AsyncSequence` and only one iterator can be created. + @available(*, deprecated, message: "Use the executeThenClose scoped method instead.") + public var inbound: NIOAsyncChannelInboundStream { + self._inbound + } /// The writer for writing outbound messages. - @_spi(AsyncChannel) - public let outboundWriter: NIOAsyncChannelOutboundWriter + @available(*, deprecated, message: "Use the executeThenClose scoped method instead.") + public var outbound: NIOAsyncChannelOutboundWriter { + self._outbound + } + + @usableFromInline + let _inbound: NIOAsyncChannelInboundStream + @usableFromInline + let _outbound: NIOAsyncChannelOutboundWriter /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel``. /// @@ -91,22 +100,71 @@ public final class NIOAsyncChannel: Senda /// - channel: The ``Channel`` to wrap. /// - configuration: The ``NIOAsyncChannel``s configuration. @inlinable - @_spi(AsyncChannel) + public init( + wrappingChannelSynchronously channel: Channel, + configuration: Configuration = .init() + ) throws { + channel.eventLoop.preconditionInEventLoop() + self.channel = channel + (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( + backPressureStrategy: configuration.backPressureStrategy, + isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled, + closeOnDeinit: false + ) + } + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel`` where the outbound type is `Never`. + /// + /// This initializer will finish the ``NIOAsyncChannel/outbound`` immediately. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - configuration: The ``NIOAsyncChannel``s configuration. + @inlinable + public init( + wrappingChannelSynchronously channel: Channel, + configuration: Configuration = .init() + ) throws where Outbound == Never { + channel.eventLoop.preconditionInEventLoop() + self.channel = channel + (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( + backPressureStrategy: configuration.backPressureStrategy, + isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled, + closeOnDeinit: false + ) + + self._outbound.finish() + } + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel``. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - configuration: The ``NIOAsyncChannel``s configuration. + @available(*, deprecated, renamed: "init(wrappingChannelSynchronously:configuration:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + @inlinable public init( synchronouslyWrapping channel: Channel, configuration: Configuration = .init() ) throws { channel.eventLoop.preconditionInEventLoop() self.channel = channel - (self.inboundStream, self.outboundWriter) = try channel._syncAddAsyncHandlers( - backpressureStrategy: configuration.backpressureStrategy, - isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled + (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( + backPressureStrategy: configuration.backPressureStrategy, + isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled, + closeOnDeinit: true ) } /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel`` where the outbound type is `Never`. /// - /// This initializer will finish the ``NIOAsyncChannel/outboundWriter`` immediately. + /// This initializer will finish the ``NIOAsyncChannel/outbound`` immediately. /// /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. @@ -115,46 +173,51 @@ public final class NIOAsyncChannel: Senda /// - channel: The ``Channel`` to wrap. /// - configuration: The ``NIOAsyncChannel``s configuration. @inlinable - @_spi(AsyncChannel) + @available(*, deprecated, renamed: "init(wrappingChannelSynchronously:configuration:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") public init( synchronouslyWrapping channel: Channel, - configuration: Configuration + configuration: Configuration = .init() ) throws where Outbound == Never { channel.eventLoop.preconditionInEventLoop() self.channel = channel - (self.inboundStream, self.outboundWriter) = try channel._syncAddAsyncHandlers( - backpressureStrategy: configuration.backpressureStrategy, - isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled + (self._inbound, self._outbound) = try channel._syncAddAsyncHandlers( + backPressureStrategy: configuration.backPressureStrategy, + isOutboundHalfClosureEnabled: configuration.isOutboundHalfClosureEnabled, + closeOnDeinit: true ) - self.outboundWriter.finish() + self._outbound.finish() } @inlinable - @_spi(AsyncChannel) - public init( + internal init( channel: Channel, inboundStream: NIOAsyncChannelInboundStream, outboundWriter: NIOAsyncChannelOutboundWriter ) { channel.eventLoop.preconditionInEventLoop() self.channel = channel - self.inboundStream = inboundStream - self.outboundWriter = outboundWriter + self._inbound = inboundStream + self._outbound = outboundWriter } + /// This method is only used from our server bootstrap to allow us to run the child channel initializer + /// at the right moment. + /// + /// - Important: This is not considered stable API and should not be used. @inlinable - @_spi(AsyncChannel) - public static func wrapAsyncChannelWithTransformations( + @available(*, deprecated, message: "This method has been deprecated since it defaults to deinit based resource teardown") + public static func _wrapAsyncChannelWithTransformations( synchronouslyWrapping channel: Channel, - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, isOutboundHalfClosureEnabled: Bool = false, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> NIOAsyncChannel where Outbound == Never { channel.eventLoop.preconditionInEventLoop() let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) = try channel._syncAddAsyncHandlersWithTransformations( - backpressureStrategy: backpressureStrategy, + backPressureStrategy: backPressureStrategy, isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: true, channelReadTransformation: channelReadTransformation ) @@ -166,52 +229,125 @@ public final class NIOAsyncChannel: Senda outboundWriter: outboundWriter ) } + + /// This method is only used from our server bootstrap to allow us to run the child channel initializer + /// at the right moment. + /// + /// - Important: This is not considered stable API and should not be used. + @inlinable + public static func _wrapAsyncChannelWithTransformations( + wrappingChannelSynchronously channel: Channel, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + isOutboundHalfClosureEnabled: Bool = false, + channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture + ) throws -> NIOAsyncChannel where Outbound == Never { + channel.eventLoop.preconditionInEventLoop() + let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) = try channel._syncAddAsyncHandlersWithTransformations( + backPressureStrategy: backPressureStrategy, + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: false, + channelReadTransformation: channelReadTransformation + ) + + outboundWriter.finish() + + return .init( + channel: channel, + inboundStream: inboundStream, + outboundWriter: outboundWriter + ) + } + + /// Provides scoped access to the inbound and outbound side of the underlying ``Channel``. + /// + /// - Important: After this method returned the underlying ``Channel`` will be closed. + /// + /// - Parameter body: A closure that gets scoped access to the inbound and outbound. + public func executeThenClose( + _ body: (_ inbound: NIOAsyncChannelInboundStream, _ outbound: NIOAsyncChannelOutboundWriter) async throws -> Result + ) async throws -> Result { + let result: Result + do { + result = try await body(self._inbound, self._outbound) + } catch let bodyError { + do { + self._outbound.finish() + try await self.channel.close().get() + throw bodyError + } catch { + throw bodyError + } + } + + do { + self._outbound.finish() + try await self.channel.close().get() + } catch { + if let error = error as? ChannelError, error == .alreadyClosed { + return result + } + throw error + } + return result + } + + /// Provides scoped access to the inbound side of the underlying ``Channel``. + /// + /// - Important: After this method returned the underlying ``Channel`` will be closed. + /// + /// - Parameter body: A closure that gets scoped access to the inbound. + public func executeThenClose( + _ body: (_ inbound: NIOAsyncChannelInboundStream) async throws -> Result + ) async throws -> Result where Outbound == Never { + try await self.executeThenClose { inbound, _ in + try await body(inbound) + } + } } extension Channel { - // TODO: We need to remove the public and spi here once we make the AsyncChannel methods public @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @inlinable - @_spi(AsyncChannel) - public func _syncAddAsyncHandlers( - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, - isOutboundHalfClosureEnabled: Bool + func _syncAddAsyncHandlers( + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + isOutboundHalfClosureEnabled: Bool, + closeOnDeinit: Bool ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { self.eventLoop.assertInEventLoop() - let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) let inboundStream = try NIOAsyncChannelInboundStream.makeWrappingHandler( channel: self, - backpressureStrategy: backpressureStrategy, - closeRatchet: closeRatchet + backPressureStrategy: backPressureStrategy, + closeOnDeinit: closeOnDeinit ) let writer = try NIOAsyncChannelOutboundWriter( channel: self, - closeRatchet: closeRatchet + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: closeOnDeinit ) return (inboundStream, writer) } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @inlinable - @_spi(AsyncChannel) - public func _syncAddAsyncHandlersWithTransformations( - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + func _syncAddAsyncHandlersWithTransformations( + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, isOutboundHalfClosureEnabled: Bool, + closeOnDeinit: Bool, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) { self.eventLoop.assertInEventLoop() - let closeRatchet = CloseRatchet(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) let inboundStream = try NIOAsyncChannelInboundStream.makeTransformationHandler( channel: self, - backpressureStrategy: backpressureStrategy, - closeRatchet: closeRatchet, + backPressureStrategy: backPressureStrategy, + closeOnDeinit: closeOnDeinit, channelReadTransformation: channelReadTransformation ) let writer = try NIOAsyncChannelOutboundWriter( channel: self, - closeRatchet: closeRatchet + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: closeOnDeinit ) return (inboundStream, writer) } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift index 6826be3d9e..0a672dc3ed 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift @@ -16,7 +16,6 @@ /// /// This is a unicast async sequence that allows a single iterator to be created. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -@_spi(AsyncChannel) public struct NIOAsyncChannelInboundStream: Sendable { @usableFromInline typealias Producer = NIOThrowingAsyncSequenceProducer @@ -48,20 +47,11 @@ public struct NIOAsyncChannelInboundStream: Sendable { } } - #if swift(>=5.7) @usableFromInline enum _Backing: Sendable { case asyncStream(AsyncThrowingStream) case producer(Producer) } - #else - // AsyncStream wasn't marked as `Sendable` in 5.6 - @usableFromInline - enum _Backing: @unchecked Sendable { - case asyncStream(AsyncThrowingStream) - case producer(Producer) - } - #endif /// The underlying async sequence. @usableFromInline @@ -89,14 +79,14 @@ public struct NIOAsyncChannelInboundStream: Sendable { @inlinable init( channel: Channel, - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, - closeRatchet: CloseRatchet, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + closeOnDeinit: Bool, handler: NIOAsyncChannelInboundStreamChannelHandler ) throws { channel.eventLoop.preconditionInEventLoop() let strategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark - if let userProvided = backpressureStrategy { + if let userProvided = backPressureStrategy { strategy = userProvided } else { // Default strategy. These numbers are fairly arbitrary, but they line up with the default value of @@ -106,6 +96,7 @@ public struct NIOAsyncChannelInboundStream: Sendable { let sequence = Producer.makeSequence( backPressureStrategy: strategy, + finishOnDeinit: closeOnDeinit, delegate: NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate(handler: handler) ) handler.source = sequence.source @@ -117,18 +108,17 @@ public struct NIOAsyncChannelInboundStream: Sendable { @inlinable static func makeWrappingHandler( channel: Channel, - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, - closeRatchet: CloseRatchet + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + closeOnDeinit: Bool ) throws -> NIOAsyncChannelInboundStream { let handler = NIOAsyncChannelInboundStreamChannelHandler.makeHandler( - eventLoop: channel.eventLoop, - closeRatchet: closeRatchet + eventLoop: channel.eventLoop ) return try .init( channel: channel, - backpressureStrategy: backpressureStrategy, - closeRatchet: closeRatchet, + backPressureStrategy: backPressureStrategy, + closeOnDeinit: closeOnDeinit, handler: handler ) } @@ -137,20 +127,19 @@ public struct NIOAsyncChannelInboundStream: Sendable { @inlinable static func makeTransformationHandler( channel: Channel, - backpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, - closeRatchet: CloseRatchet, + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + closeOnDeinit: Bool, channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> NIOAsyncChannelInboundStream { let handler = NIOAsyncChannelInboundStreamChannelHandler.makeHandlerWithTransformations( eventLoop: channel.eventLoop, - closeRatchet: closeRatchet, channelReadTransformation: channelReadTransformation ) return try .init( channel: channel, - backpressureStrategy: backpressureStrategy, - closeRatchet: closeRatchet, + backPressureStrategy: backPressureStrategy, + closeOnDeinit: closeOnDeinit, handler: handler ) } @@ -158,10 +147,8 @@ public struct NIOAsyncChannelInboundStream: Sendable { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncChannelInboundStream: AsyncSequence { - @_spi(AsyncChannel) public typealias Element = Inbound - @_spi(AsyncChannel) public struct AsyncIterator: AsyncIteratorProtocol { @usableFromInline enum _Backing { @@ -181,7 +168,7 @@ extension NIOAsyncChannelInboundStream: AsyncSequence { } } - @inlinable @_spi(AsyncChannel) + @inlinable public mutating func next() async throws -> Element? { switch self._backing { case .asyncStream(var iterator): @@ -198,7 +185,6 @@ extension NIOAsyncChannelInboundStream: AsyncSequence { } @inlinable - @_spi(AsyncChannel) public func makeAsyncIterator() -> AsyncIterator { return AsyncIterator(self._backing) } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift index 31d22f3ca6..c04b203560 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStreamChannelHandler.swift @@ -63,10 +63,6 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler NIOAsyncChannelInboundStreamChannelHandler where InboundIn == ProducerElement { return .init( eventLoop: eventLoop, - closeRatchet: closeRatchet, transformation: .syncWrapping { $0 } ) } @@ -109,12 +101,10 @@ internal final class NIOAsyncChannelInboundStreamChannelHandler EventLoopFuture ) -> NIOAsyncChannelInboundStreamChannelHandler where InboundIn == Channel { return .init( eventLoop: eventLoop, - closeRatchet: closeRatchet, transformation: .transformation( channelReadTransformation: channelReadTransformation ) @@ -277,14 +267,6 @@ extension NIOAsyncChannelInboundStreamChannelHandler { // Wedges the read open forever, we'll never read again. self.producingState = .producingPausedWithOutstandingRead - - switch self.closeRatchet.closeRead() { - case .nothing: - break - - case .close: - self.context?.close(promise: nil) - } } @inlinable @@ -326,15 +308,23 @@ struct NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate: @unchecked Se @inlinable func didTerminate() { - self.eventLoop.execute { + if self.eventLoop.inEventLoop { self._didTerminate() + } else { + self.eventLoop.execute { + self._didTerminate() + } } } @inlinable func produceMore() { - self.eventLoop.execute { + if self.eventLoop.inEventLoop { self._produceMore() + } else { + self.eventLoop.execute { + self._produceMore() + } } } } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift index 13934d5ec7..d89332e255 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -12,13 +12,12 @@ // //===----------------------------------------------------------------------===// -/// A ``NIOAsyncChannelWriter`` is used to write and flush new outbound messages in a channel. +/// A ``NIOAsyncChannelOutboundWriter`` is used to write and flush new outbound messages in a channel. /// /// The writer acts as a bridge between the Concurrency and NIO world. It allows to write and flush messages into the /// underlying ``Channel``. Furthermore, it respects back-pressure of the channel by suspending the calls to write until /// the channel becomes writable again. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -@_spi(AsyncChannel) public struct NIOAsyncChannelOutboundWriter: Sendable { @usableFromInline typealias _Writer = NIOAsyncChannelOutboundWriterHandler.Writer @@ -85,15 +84,17 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { @inlinable init( channel: Channel, - closeRatchet: CloseRatchet + isOutboundHalfClosureEnabled: Bool, + closeOnDeinit: Bool ) throws { let handler = NIOAsyncChannelOutboundWriterHandler( eventLoop: channel.eventLoop, - closeRatchet: closeRatchet + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled ) let writer = _Writer.makeWriter( elementType: OutboundOut.self, isWritable: true, + finishOnDeinit: closeOnDeinit, delegate: .init(handler: handler) ) handler.sink = writer.sink @@ -112,7 +113,6 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { /// /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. @inlinable - @_spi(AsyncChannel) public func write(_ data: OutboundOut) async throws { switch self._backing { case .asyncStream(let continuation): @@ -126,7 +126,6 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { /// /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. @inlinable - @_spi(AsyncChannel) public func write(contentsOf sequence: Writes) async throws where Writes.Element == OutboundOut { switch self._backing { case .asyncStream(let continuation): @@ -138,11 +137,12 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { } } - /// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away. + /// Send an asynchronous sequence of writes into the ``ChannelPipeline``. + /// + /// This will flush after every write. /// /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. @inlinable - @_spi(AsyncChannel) public func write(contentsOf sequence: Writes) async throws where Writes.Element == OutboundOut { for try await data in sequence { try await self.write(data) @@ -152,7 +152,6 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { /// Finishes the writer. /// /// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it. - @_spi(AsyncChannel) public func finish() { switch self._backing { case .asyncStream(let continuation): @@ -163,11 +162,5 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { } } -#if swift(>=5.7) @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncChannelOutboundWriter.TestSink: Sendable {} -#else -// AsyncStream wasn't marked as `Sendable` in 5.6 -@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -extension NIOAsyncChannelOutboundWriter.TestSink: @unchecked Sendable {} -#endif diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift index b15795b9c5..59fad5e3e1 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriterHandler.swift @@ -45,17 +45,16 @@ internal final class NIOAsyncChannelOutboundWriterHandler @usableFromInline let eventLoop: EventLoop - /// The shared `CloseRatchet` between this handler and the inbound stream handler. @usableFromInline - let closeRatchet: CloseRatchet + let isOutboundHalfClosureEnabled: Bool @inlinable init( eventLoop: EventLoop, - closeRatchet: CloseRatchet + isOutboundHalfClosureEnabled: Bool ) { self.eventLoop = eventLoop - self.closeRatchet = closeRatchet + self.isOutboundHalfClosureEnabled = isOutboundHalfClosureEnabled } @inlinable @@ -76,18 +75,28 @@ internal final class NIOAsyncChannelOutboundWriterHandler } @inlinable - func _didTerminate(error: Error?) { + func _didYield(element: OutboundOut) { + // This is always called from an async context, so we must loop-hop. + // Because we always loop-hop, we're always at the top of a stack frame. As this + // is the only source of writes for us, and as this channel handler doesn't implement + // func write(), we cannot possibly re-entrantly write. That means we can skip many of the + // awkward re-entrancy protections NIO usually requires, and can safely just do an iterative + // write. self.eventLoop.preconditionInEventLoop() + guard let context = self.context else { + // Already removed from the channel by now, we can stop. + return + } - switch self.closeRatchet.closeWrite() { - case .nothing: - break + self._doOutboundWrite(context: context, write: element) + } - case .closeOutput: - self.context?.close(mode: .output, promise: nil) + @inlinable + func _didTerminate(error: Error?) { + self.eventLoop.preconditionInEventLoop() - case .close: - self.context?.close(promise: nil) + if self.isOutboundHalfClosureEnabled { + self.context?.close(mode: .output, promise: nil) } self.sink = nil @@ -102,6 +111,12 @@ internal final class NIOAsyncChannelOutboundWriterHandler context.flush() } + @inlinable + func _doOutboundWrite(context: ChannelHandlerContext, write: OutboundOut) { + context.write(self.wrapOutboundOut(write), promise: nil) + context.flush() + } + @inlinable func handlerAdded(context: ChannelHandlerContext) { self.context = context @@ -110,7 +125,7 @@ internal final class NIOAsyncChannelOutboundWriterHandler @inlinable func handlerRemoved(context: ChannelHandlerContext) { self.context = nil - self.sink = nil + self.sink?.finish() } @inlinable @@ -124,6 +139,18 @@ internal final class NIOAsyncChannelOutboundWriterHandler self.sink?.setWritability(to: context.channel.isWritable) context.fireChannelWritabilityChanged() } + + @inlinable + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case ChannelEvent.outputClosed: + self.sink?.finish() + default: + break + } + + context.fireUserInboundEventTriggered(event) + } } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @@ -147,17 +174,34 @@ extension NIOAsyncChannelOutboundWriterHandler { @inlinable func didYield(contentsOf sequence: Deque) { - // This always called from an async context, so we must loop-hop. - self.eventLoop.execute { + if self.eventLoop.inEventLoop { self.handler._didYield(sequence: sequence) + } else { + self.eventLoop.execute { + self.handler._didYield(sequence: sequence) + } + } + } + + @inlinable + func didYield(_ element: OutboundOut) { + if self.eventLoop.inEventLoop { + self.handler._didYield(element: element) + } else { + self.eventLoop.execute { + self.handler._didYield(element: element) + } } } @inlinable func didTerminate(error: Error?) { - // This always called from an async context, so we must loop-hop. - self.eventLoop.execute { + if self.eventLoop.inEventLoop { self.handler._didTerminate(error: error) + } else { + self.eventLoop.execute { + self.handler._didTerminate(error: error) + } } } } diff --git a/Sources/NIOCore/AsyncChannel/CloseRatchet.swift b/Sources/NIOCore/AsyncChannel/CloseRatchet.swift deleted file mode 100644 index 9d87ea4d07..0000000000 --- a/Sources/NIOCore/AsyncChannel/CloseRatchet.swift +++ /dev/null @@ -1,93 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2022-2023 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -/// A helper type that lets ``NIOAsyncChannelAdapterHandler`` and ``NIOAsyncChannelWriterHandler`` collude -/// to ensure that the ``Channel`` they share is closed appropriately. -/// -/// The strategy of this type is that it keeps track of which side has closed, so that the handlers can work out -/// which of them was "last", in order to arrange closure. -@usableFromInline -final class CloseRatchet { - @usableFromInline - enum State { - case notClosed(isOutboundHalfClosureEnabled: Bool) - case readClosed - case writeClosed - case bothClosed - - @inlinable - mutating func closeRead() -> CloseReadAction { - switch self { - case .notClosed: - self = .readClosed - return .nothing - case .writeClosed: - self = .bothClosed - return .close - case .readClosed, .bothClosed: - preconditionFailure("Duplicate read closure") - } - } - - @inlinable - mutating func closeWrite() -> CloseWriteAction { - switch self { - case .notClosed(let isOutboundHalfClosureEnabled): - self = .writeClosed - - if isOutboundHalfClosureEnabled { - return .closeOutput - } else { - return .nothing - } - case .readClosed: - self = .bothClosed - return .close - case .writeClosed, .bothClosed: - preconditionFailure("Duplicate write closure") - } - } - } - - @usableFromInline - var _state: State - - @inlinable - init(isOutboundHalfClosureEnabled: Bool) { - self._state = .notClosed(isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled) - } - - @usableFromInline - enum CloseReadAction { - case nothing - case close - } - - @inlinable - func closeRead() -> CloseReadAction { - return self._state.closeRead() - } - - @usableFromInline - enum CloseWriteAction { - case nothing - case close - case closeOutput - } - - @inlinable - func closeWrite() -> CloseWriteAction { - return self._state.closeWrite() - } -} diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift index ba39820021..c724f7798f 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift @@ -139,17 +139,52 @@ public struct NIOAsyncSequenceProducer< /// - Parameters: /// - elementType: The element type of the sequence. /// - backPressureStrategy: The back-pressure strategy of the sequence. + /// - finishOnDeinit: Indicates if ``NIOAsyncSequenceProducer/Source/finish()`` should be called on deinit of the the source. + /// We do not recommend to rely on deinit based resource tear down. /// - delegate: The delegate of the sequence /// - Returns: A ``NIOAsyncSequenceProducer/Source`` and a ``NIOAsyncSequenceProducer``. @inlinable public static func makeSequence( elementType: Element.Type = Element.self, backPressureStrategy: Strategy, + finishOnDeinit: Bool, delegate: Delegate ) -> NewSequence { let newSequence = NIOThrowingAsyncSequenceProducer.makeNonThrowingSequence( elementType: Element.self, backPressureStrategy: backPressureStrategy, + finishOnDeinit: finishOnDeinit, + delegate: delegate + ) + + let sequence = self.init(throwingSequence: newSequence.sequence) + + return .init(source: Source(throwingSource: newSequence.source), sequence: sequence) + } + + /// Initializes a new ``NIOAsyncSequenceProducer`` and a ``NIOAsyncSequenceProducer/Source``. + /// + /// - Important: This method returns a struct containing a ``NIOAsyncSequenceProducer/Source`` and + /// a ``NIOAsyncSequenceProducer``. The source MUST be held by the caller and + /// used to signal new elements or finish. The sequence MUST be passed to the actual consumer and MUST NOT be held by the + /// caller. This is due to the fact that deiniting the sequence is used as part of a trigger to terminate the underlying source. + /// + /// - Parameters: + /// - elementType: The element type of the sequence. + /// - backPressureStrategy: The back-pressure strategy of the sequence. + /// - delegate: The delegate of the sequence + /// - Returns: A ``NIOAsyncSequenceProducer/Source`` and a ``NIOAsyncSequenceProducer``. + @inlinable + @available(*, deprecated, renamed: "makeSequence(elementType:backPressureStrategy:finishOnDeinit:delegate:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + public static func makeSequence( + elementType: Element.Type = Element.self, + backPressureStrategy: Strategy, + delegate: Delegate + ) -> NewSequence { + let newSequence = NIOThrowingAsyncSequenceProducer.makeNonThrowingSequence( + elementType: Element.self, + backPressureStrategy: backPressureStrategy, + finishOnDeinit: true, delegate: delegate ) @@ -209,45 +244,20 @@ extension NIOAsyncSequenceProducer { /// This type allows the producer to synchronously `yield` new elements to the ``NIOAsyncSequenceProducer`` /// and to `finish` the sequence. public struct Source { - /// This class is needed to hook the deinit to observe once all references to the ``NIOAsyncSequenceProducer/Source`` are dropped. - /// - /// - Important: This is safe to be unchecked ``Sendable`` since the `storage` is ``Sendable`` and `immutable`. @usableFromInline - /* fileprivate */ internal final class InternalClass: Sendable { - @usableFromInline - typealias ThrowingSource = NIOThrowingAsyncSequenceProducer< - Element, - Never, - Strategy, - Delegate - >.Source - - @usableFromInline - /* fileprivate */ internal let _throwingSource: ThrowingSource - - @inlinable - init(throwingSource: ThrowingSource) { - self._throwingSource = throwingSource - } - - @inlinable - deinit { - // We need to call finish here to resume any suspended continuation. - self._throwingSource.finish() - } - } - - @usableFromInline - /* private */ internal let _internalClass: InternalClass + typealias ThrowingSource = NIOThrowingAsyncSequenceProducer< + Element, + Never, + Strategy, + Delegate + >.Source @usableFromInline - /* private */ internal var _throwingSource: InternalClass.ThrowingSource { - self._internalClass._throwingSource - } + /* private */ internal var _throwingSource: ThrowingSource @usableFromInline - /* fileprivate */ internal init(throwingSource: InternalClass.ThrowingSource) { - self._internalClass = .init(throwingSource: throwingSource) + /* fileprivate */ internal init(throwingSource: ThrowingSource) { + self._throwingSource = throwingSource } /// The result of a call to ``NIOAsyncSequenceProducer/Source/yield(_:)``. diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift index dd07fec57f..ed1011368c 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift @@ -15,13 +15,13 @@ import Atomics import DequeModule import NIOConcurrencyHelpers +import _NIODataStructures /// The delegate of the ``NIOAsyncWriter``. It is the consumer of the yielded writes to the ``NIOAsyncWriter``. /// Furthermore, the delegate gets informed when the ``NIOAsyncWriter`` terminated. /// -/// - Important: The methods on the delegate are called while a lock inside of the ``NIOAsyncWriter`` is held. This is done to -/// guarantee the ordering of the writes. However, this means you **MUST NOT** call ``NIOAsyncWriter/Sink/setWritability(to:)`` -/// from within ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)`` or ``NIOAsyncWriterSinkDelegate/didTerminate(error:)``. +/// - Important: The methods on the delegate might be called on arbitrary threads and the implementation must ensure +/// that proper synchronization is in place. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public protocol NIOAsyncWriterSinkDelegate: Sendable { /// The `Element` type of the delegate and the writer. @@ -31,22 +31,20 @@ public protocol NIOAsyncWriterSinkDelegate: Sendable { /// /// If the ``NIOAsyncWriter`` was writable when the sequence was yielded, the sequence will be forwarded /// right away to the delegate. If the ``NIOAsyncWriter`` was _NOT_ writable then the sequence will be buffered - /// until the ``NIOAsyncWriter`` becomes writable again. All buffered writes, while the ``NIOAsyncWriter`` is not writable, - /// will be coalesced into a single sequence. + /// until the ``NIOAsyncWriter`` becomes writable again. /// - /// - Important: You **MUST NOT** call ``NIOAsyncWriter/Sink/setWritability(to:)`` from within this method. + /// The delegate might reentrantly call ``NIOAsyncWriter/Sink/setWritability(to:)`` while still processing writes. func didYield(contentsOf sequence: Deque) /// This method is called once a single element was yielded to the ``NIOAsyncWriter``. /// /// If the ``NIOAsyncWriter`` was writable when the sequence was yielded, the sequence will be forwarded /// right away to the delegate. If the ``NIOAsyncWriter`` was _NOT_ writable then the sequence will be buffered - /// until the ``NIOAsyncWriter`` becomes writable again. All buffered writes, while the ``NIOAsyncWriter`` is not writable, - /// will be coalesced into a single sequence. + /// until the ``NIOAsyncWriter`` becomes writable again. /// /// - Note: This a fast path that you can optionally implement. By default this will just call ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)``. /// - /// - Important: You **MUST NOT** call ``NIOAsyncWriter/Sink/setWritability(to:)`` from within this method. + /// The delegate might reentrantly call ``NIOAsyncWriter/Sink/setWritability(to:)`` while still processing writes. func didYield(_ element: Element) /// This method is called once the ``NIOAsyncWriter`` is terminated. @@ -55,15 +53,11 @@ public protocol NIOAsyncWriterSinkDelegate: Sendable { /// - The ``NIOAsyncWriter`` is deinited and all yielded elements have been delivered to the delegate. /// - ``NIOAsyncWriter/finish()`` is called and all yielded elements have been delivered to the delegate. /// - ``NIOAsyncWriter/finish(error:)`` is called and all yielded elements have been delivered to the delegate. - /// - ``NIOAsyncWriter/Sink/finish()`` or ``NIOAsyncWriter/Sink/finish(error:)`` is called. /// - /// - Note: This is guaranteed to be called _exactly_ once. + /// - Note: This is guaranteed to be called _at most_ once. /// /// - Parameter error: The error that terminated the ``NIOAsyncWriter``. If the writer was terminated without an - /// error this value is `nil`. This can be either the error passed to ``NIOAsyncWriter/finish(error:)`` or - /// to ``NIOAsyncWriter/Sink/finish(error:)``. - /// - /// - Important: You **MUST NOT** call ``NIOAsyncWriter/Sink/setWritability(to:)`` from within this method. + /// error this value is `nil`. This can be either the error passed to ``NIOAsyncWriter/finish(error:)``. func didTerminate(error: Error?) } @@ -167,14 +161,23 @@ public struct NIOAsyncWriter< @usableFromInline internal let _storage: Storage + @usableFromInline + internal let _finishOnDeinit: Bool + @inlinable - init(storage: Storage) { + init(storage: Storage, finishOnDeinit: Bool) { self._storage = storage + self._finishOnDeinit = finishOnDeinit } @inlinable deinit { - _storage.writerDeinitialized() + if !self._finishOnDeinit && !self._storage.isWriterFinished { + preconditionFailure("Deinited NIOAsyncWriter without calling finish()") + } else { + // We need to call finish here to resume any suspended continuation. + self._storage.writerFinish(error: nil) + } } } @@ -199,16 +202,49 @@ public struct NIOAsyncWriter< /// - delegate: The delegate of the writer. /// - Returns: A ``NIOAsyncWriter/NewWriter``. @inlinable + @available(*, deprecated, renamed: "makeWriter(elementType:isWritable:finishOnDeinit:delegate:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + public static func makeWriter( + elementType: Element.Type = Element.self, + isWritable: Bool, + delegate: Delegate + ) -> NewWriter { + let writer = Self( + isWritable: isWritable, + finishOnDeinit: true, + delegate: delegate + ) + let sink = Sink(storage: writer._storage, finishOnDeinit: true) + + return .init(sink: sink, writer: writer) + } + + /// Initializes a new ``NIOAsyncWriter`` and a ``NIOAsyncWriter/Sink``. + /// + /// - Important: This method returns a struct containing a ``NIOAsyncWriter/Sink`` and + /// a ``NIOAsyncWriter``. The sink MUST be held by the caller and is used to set the writability. + /// The writer MUST be passed to the actual producer and MUST NOT be held by the + /// caller. This is due to the fact that deiniting the sequence is used as part of a trigger to terminate the underlying sink. + /// + /// - Parameters: + /// - elementType: The element type of the sequence. + /// - isWritable: The initial writability state of the writer. + /// - finishOnDeinit: Indicates if ``NIOAsyncWriter/finish()`` should be called on deinit. We do not recommend to rely on + /// deinit based resource tear down. + /// - delegate: The delegate of the writer. + /// - Returns: A ``NIOAsyncWriter/NewWriter``. + @inlinable public static func makeWriter( elementType: Element.Type = Element.self, isWritable: Bool, + finishOnDeinit: Bool, delegate: Delegate ) -> NewWriter { let writer = Self( isWritable: isWritable, + finishOnDeinit: finishOnDeinit, delegate: delegate ) - let sink = Sink(storage: writer._storage) + let sink = Sink(storage: writer._storage, finishOnDeinit: finishOnDeinit) return .init(sink: sink, writer: writer) } @@ -216,28 +252,24 @@ public struct NIOAsyncWriter< @inlinable /* private */ internal init( isWritable: Bool, + finishOnDeinit: Bool, delegate: Delegate ) { let storage = Storage( isWritable: isWritable, delegate: delegate ) - self._internalClass = .init(storage: storage) + self._internalClass = .init(storage: storage, finishOnDeinit: finishOnDeinit) } /// Yields a sequence of new elements to the ``NIOAsyncWriter``. /// /// If the ``NIOAsyncWriter`` is writable the sequence will get forwarded to the ``NIOAsyncWriterSinkDelegate`` immediately. /// Otherwise, the sequence will be buffered and the call to ``NIOAsyncWriter/yield(contentsOf:)`` will get suspended until the ``NIOAsyncWriter`` - /// becomes writable again. If the calling `Task` gets cancelled at any point the call to ``NIOAsyncWriter/yield(contentsOf:)`` - /// will be resumed. + /// becomes writable again. /// - /// If the ``NIOAsyncWriter/finish()`` or ``NIOAsyncWriter/finish(error:)`` method is called while a call to - /// ``NIOAsyncWriter/yield(contentsOf:)`` is suspended then the call will be resumed and the yielded sequence will be kept buffered. - /// - /// If the ``NIOAsyncWriter/Sink/finish()`` or ``NIOAsyncWriter/Sink/finish(error:)`` method is called while - /// a call to ``NIOAsyncWriter/yield(contentsOf:)`` is suspended then the call will be resumed with an error and the - /// yielded sequence is dropped. + /// If the calling `Task` gets cancelled at any point the call to ``NIOAsyncWriter/yield(contentsOf:)`` + /// will be resumed. Consequently, the provided elements will not be yielded. /// /// This can be called more than once and from multiple `Task`s at the same time. /// @@ -251,22 +283,17 @@ public struct NIOAsyncWriter< /// /// If the ``NIOAsyncWriter`` is writable the element will get forwarded to the ``NIOAsyncWriterSinkDelegate`` immediately. /// Otherwise, the element will be buffered and the call to ``NIOAsyncWriter/yield(_:)`` will get suspended until the ``NIOAsyncWriter`` - /// becomes writable again. If the calling `Task` gets cancelled at any point the call to ``NIOAsyncWriter/yield(_:)`` - /// will be resumed. - /// - /// If the ``NIOAsyncWriter/finish()`` or ``NIOAsyncWriter/finish(error:)`` method is called while a call to - /// ``NIOAsyncWriter/yield(_:)`` is suspended then the call will be resumed and the yielded sequence will be kept buffered. + /// becomes writable again. /// - /// If the ``NIOAsyncWriter/Sink/finish()`` or ``NIOAsyncWriter/Sink/finish(error:)`` method is called while - /// a call to ``NIOAsyncWriter/yield(_:)`` is suspended then the call will be resumed with an error and the - /// yielded sequence is dropped. + /// If the calling `Task` gets cancelled at any point the call to ``NIOAsyncWriter/yield(_:)`` + /// will be resumed. Consequently, the provided element will not be yielded. /// /// This can be called more than once and from multiple `Task`s at the same time. /// /// - Parameter element: The element to yield. @inlinable public func yield(_ element: Element) async throws { - try await self._storage.yield(element) + try await self._storage.yield(contentsOf: CollectionOfOne(element)) } /// Finishes the writer. @@ -275,7 +302,7 @@ public struct NIOAsyncWriter< /// or ``NIOAsyncWriter/yield(_:)`` will be resumed. Any subsequent calls to ``NIOAsyncWriter/yield(contentsOf:)`` /// or ``NIOAsyncWriter/yield(_:)`` will throw. /// - /// Any element that have been yielded elements before the writer has been finished which have not been delivered yet are continued + /// Any element that have been yielded before the writer has been finished which have not been delivered yet are continued /// to be buffered and will be delivered once the writer becomes writable again. /// /// - Note: Calling this function more than once has no effect. @@ -290,7 +317,7 @@ public struct NIOAsyncWriter< /// or ``NIOAsyncWriter/yield(_:)`` will be resumed. Any subsequent calls to ``NIOAsyncWriter/yield(contentsOf:)`` /// or ``NIOAsyncWriter/yield(_:)`` will throw. /// - /// Any element that have been yielded elements before the writer has been finished which have not been delivered yet are continued + /// Any element that have been yielded before the writer has been finished which have not been delivered yet are continued /// to be buffered and will be delivered once the writer becomes writable again. /// /// - Note: Calling this function more than once has no effect. @@ -313,15 +340,23 @@ extension NIOAsyncWriter { @usableFromInline /* fileprivate */ internal let _storage: Storage + @usableFromInline + internal let _finishOnDeinit: Bool + @inlinable - init(storage: Storage) { + init(storage: Storage, finishOnDeinit: Bool) { self._storage = storage + self._finishOnDeinit = finishOnDeinit } @inlinable deinit { - // We need to call finish here to resume any suspended continuation. - self._storage.sinkFinish(error: nil) + if !self._finishOnDeinit && !self._storage.isSinkFinished { + preconditionFailure("Deinited NIOAsyncWriter.Sink without calling sink.finish()") + } else { + // We need to call finish here to resume any suspended continuation. + self._storage.sinkFinish(error: nil) + } } } @@ -334,8 +369,8 @@ extension NIOAsyncWriter { } @inlinable - init(storage: Storage) { - self._internalClass = .init(storage: storage) + init(storage: Storage, finishOnDeinit: Bool) { + self._internalClass = .init(storage: storage, finishOnDeinit: finishOnDeinit) } /// Sets the writability of the ``NIOAsyncWriter``. @@ -421,67 +456,67 @@ extension NIOAsyncWriter { /// The state machine. @usableFromInline /* private */ internal var _stateMachine: StateMachine + /// Hook used in testing. + @usableFromInline + internal var _didSuspend: (() -> Void)? + + @inlinable + internal var isWriterFinished: Bool { + self._lock.withLock { self._stateMachine.isWriterFinished } + } + + @inlinable + internal var isSinkFinished: Bool { + self._lock.withLock { self._stateMachine.isSinkFinished } + } @inlinable /* fileprivate */ internal init( isWritable: Bool, delegate: Delegate ) { - self._stateMachine = .init(isWritable: isWritable, delegate: delegate) + self._stateMachine = .init( + isWritable: isWritable, + delegate: delegate + ) } @inlinable - /* fileprivate */ internal func writerDeinitialized() { - self._lock.withLock { - let action = self._stateMachine.writerDeinitialized() + /* fileprivate */ internal func setWritability(to writability: Bool) { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let action = self._lock.withLock { + self._stateMachine.setWritability(to: writability) + } - switch action { - case .callDidTerminate(let delegate): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didTerminate(error: nil) + switch action { + case .resumeContinuations(let suspendedYields): + suspendedYields.forEach { $0.continuation.resume(returning: .retry) } - case .none: - break - } + case .none: + return } } @inlinable - /* fileprivate */ internal func setWritability(to writability: Bool) { - self._lock.withLock { - let action = self._stateMachine.setWritability(to: writability) - - switch action { - case .callDidYieldAndResumeContinuations(let delegate, let elements, let suspendedYields): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didYield(contentsOf: elements) - - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - suspendedYields.forEach { $0.continuation.resume() } - - case .callDidYieldAndDidTerminate(let delegate, let elements): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didYield(contentsOf: elements) - delegate.didTerminate(error: nil) - - case .none: + /* fileprivate */ internal func yield(contentsOf sequence: S) async throws where S.Element == Element { + let yieldID = self._yieldIDGenerator.generateUniqueYieldID() + while true { + switch try await self._yield(contentsOf: sequence, yieldID: yieldID) { + case .retry: + continue + case .yielded: return } } } @inlinable - /* fileprivate */ internal func yield(contentsOf sequence: S) async throws where S.Element == Element { - let yieldID = self._yieldIDGenerator.generateUniqueYieldID() + /* fileprivate */ internal func _yield(contentsOf sequence: S, yieldID: StateMachine.YieldID?) async throws -> StateMachine.YieldResult where S.Element == Element { + let yieldID = yieldID ?? self._yieldIDGenerator.generateUniqueYieldID() - try await withTaskCancellationHandler { + return try await withTaskCancellationHandler { // We are manually locking here to hold the lock across the withCheckedContinuation call self._lock.lock() @@ -489,24 +524,18 @@ extension NIOAsyncWriter { switch action { case .callDidYield(let delegate): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - // We are allocating a new Deque for every write here - delegate.didYield(contentsOf: Deque(sequence)) self._lock.unlock() - - case .returnNormally: - self._lock.unlock() - return + delegate.didYield(contentsOf: Deque(sequence)) + self.unbufferQueuedEvents() + return .yielded case .throwError(let error): self._lock.unlock() throw error case .suspendTask: - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in self._stateMachine.yield( contentsOf: sequence, continuation: continuation, @@ -514,127 +543,77 @@ extension NIOAsyncWriter { ) self._lock.unlock() + self._didSuspend?() } } } onCancel: { - self._lock.withLock { - let action = self._stateMachine.cancel(yieldID: yieldID) + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let action = self._lock.withLock { + self._stateMachine.cancel(yieldID: yieldID) + } - switch action { - case .resumeContinuation(let continuation): - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - continuation.resume() + switch action { + case .resumeContinuationWithCancellationError(let continuation): + continuation.resume(throwing: CancellationError()) - case .none: - break - } + case .none: + break } } } @inlinable - /* fileprivate */ internal func yield(_ element: Element) async throws { - let yieldID = self._yieldIDGenerator.generateUniqueYieldID() - - try await withTaskCancellationHandler { - // We are manually locking here to hold the lock across the withCheckedContinuation call - self._lock.lock() - - let action = self._stateMachine.yield(contentsOf: CollectionOfOne(element), yieldID: yieldID) - - switch action { - case .callDidYield(let delegate): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - - delegate.didYield(element) - self._lock.unlock() - - case .returnNormally: - self._lock.unlock() - return - - case .throwError(let error): - self._lock.unlock() - throw error - - case .suspendTask: - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - self._stateMachine.yield( - contentsOf: CollectionOfOne(element), - continuation: continuation, - yieldID: yieldID - ) + /* fileprivate */ internal func writerFinish(error: Error?) { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let action = self._lock.withLock { + self._stateMachine.writerFinish(error: error) + } - self._lock.unlock() - } - } - } onCancel: { - self._lock.withLock { - let action = self._stateMachine.cancel(yieldID: yieldID) + switch action { + case .callDidTerminate(let delegate): + delegate.didTerminate(error: error) - switch action { - case .resumeContinuation(let continuation): - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - continuation.resume() + case .resumeContinuations(let suspendedYields): + suspendedYields.forEach { $0.continuation.resume(returning: .retry) } - case .none: - break - } - } + case .none: + break } } @inlinable - /* fileprivate */ internal func writerFinish(error: Error?) { - self._lock.withLock { - let action = self._stateMachine.writerFinish() + /* fileprivate */ internal func sinkFinish(error: Error?) { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let action = self._lock.withLock { + self._stateMachine.sinkFinish(error: error) + } - switch action { - case .callDidTerminate(let delegate): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didTerminate(error: error) + switch action { + case .resumeContinuationsWithError(let suspendedYields, let error): + suspendedYields.forEach { $0.continuation.resume(throwing: error) } - case .resumeContinuations(let suspendedYields): - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - suspendedYields.forEach { $0.continuation.resume() } - - case .none: - break - } + case .none: + break } } - @inlinable - /* fileprivate */ internal func sinkFinish(error: Error?) { - self._lock.withLock { - let action = self._stateMachine.sinkFinish(error: error) + @inlinable + /* fileprivate */ internal func unbufferQueuedEvents() { + while let action = self._lock.withLock({ self._stateMachine.unbufferQueuedEvents()}) { switch action { case .callDidTerminate(let delegate, let error): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate - delegate.didTerminate(error: error) - - case .resumeContinuationsWithErrorAndCallDidTerminate(let delegate, let suspendedYields, let error): - // We are calling the delegate while holding lock. This can lead to potential crashes - // if the delegate calls `setWritability` reentrantly. However, we call this - // out in the docs of the delegate delegate.didTerminate(error: error) - // It is safe to resume the continuations while holding the lock since resume - // is immediately returning and just enqueues the Job on the executor - suspendedYields.forEach { $0.continuation.resume(throwing: error) } - - case .none: - break + case .resumeContinuations(let suspendedYields): + suspendedYields.forEach { $0.continuation.resume(returning: .retry) } + return } } } @@ -656,18 +635,26 @@ extension NIOAsyncWriter { /// The yield's produced sequence of elements. /// The yield's continuation. @usableFromInline - var continuation: CheckedContinuation + var continuation: CheckedContinuation @inlinable - init(yieldID: YieldID, continuation: CheckedContinuation) { + init(yieldID: YieldID, continuation: CheckedContinuation) { self.yieldID = yieldID self.continuation = continuation } } + /// The internal result of a yield. + @usableFromInline + /* private */ internal enum YieldResult { + /// Indicates that the elements got yielded to the sink. + case yielded + /// Indicates that the yield should be retried. + case retry + } /// The current state of our ``NIOAsyncWriter``. @usableFromInline - /* private */ internal enum State { + /* private */ internal enum State: CustomStringConvertible { /// The initial state before either a call to ``NIOAsyncWriter/yield(contentsOf:)`` or /// ``NIOAsyncWriter/finish(completion:)`` happened. case initial( @@ -678,18 +665,24 @@ extension NIOAsyncWriter { /// The state after a call to ``NIOAsyncWriter/yield(contentsOf:)``. case streaming( isWritable: Bool, + inDelegateOutcall: Bool, cancelledYields: [YieldID], - suspendedYields: [SuspendedYield], - elements: Deque, + suspendedYields: _TinyArray, delegate: Delegate ) - /// The state once the writer finished and there are still elements that need to be delivered. This can happen if: + /// The state once the writer finished and there are still tasks that need to write. This can happen if: /// 1. The ``NIOAsyncWriter`` was deinited /// 2. ``NIOAsyncWriter/finish(completion:)`` was called. case writerFinished( - elements: Deque, - delegate: Delegate + isWritable: Bool, + inDelegateOutcall: Bool, + suspendedYields: _TinyArray, + cancelledYields: [YieldID], + // These are the yields that have been enqueued before the writer got finished. + bufferedYieldIDs: _TinyArray, + delegate: Delegate, + error: Error? ) /// The state once the sink has been finished or the writer has been finished and all elements @@ -698,6 +691,22 @@ extension NIOAsyncWriter { /// Internal state to avoid CoW. case modifying + + @usableFromInline + var description: String { + switch self { + case .initial(let isWritable, _): + return "initial(isWritable: \(isWritable))" + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, _): + return "streaming(isWritable: \(isWritable), inDelegateOutcall: \(inDelegateOutcall), cancelledYields: \(cancelledYields.count), suspendedYields: \(suspendedYields.count))" + case .writerFinished(let isWritable, let inDelegateOutcall, let suspendedYields, let cancelledYields, let bufferedYieldIDs, _, _): + return "writerFinished(isWritable: \(isWritable), inDelegateOutcall: \(inDelegateOutcall), suspendedYields: \(suspendedYields.count), cancelledYields: \(cancelledYields.count), bufferedYieldIDs: \(bufferedYieldIDs.count)" + case .finished: + return "finished" + case .modifying: + return "modifying" + } + } } /// The state machine's current state. @@ -705,70 +714,47 @@ extension NIOAsyncWriter { /* private */ internal var _state: State @inlinable - init( - isWritable: Bool, - delegate: Delegate - ) { - self._state = .initial(isWritable: isWritable, delegate: delegate) - } - - /// Actions returned by `writerDeinitialized()`. - @usableFromInline - enum WriterDeinitializedAction { - /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called. - case callDidTerminate(Delegate) - /// Indicates that nothing should be done. - case none + internal var isWriterFinished: Bool { + switch self._state { + case .initial, .streaming: + return false + case .writerFinished, .finished: + return true + case .modifying: + preconditionFailure("Invalid state") + } } @inlinable - /* fileprivate */ internal mutating func writerDeinitialized() -> WriterDeinitializedAction { + internal var isSinkFinished: Bool { switch self._state { - case .initial(_, let delegate): - // The writer deinited before writing anything. - // We can transition to finished and inform our delegate - self._state = .finished(sinkError: nil) - - return .callDidTerminate(delegate) - - case .streaming(_, _, let suspendedYields, let elements, let delegate): - // The writer got deinited after we started streaming. - // This is normal and we need to transition to finished - // and call the delegate. However, we should not have - // any suspended yields because they MUST strongly retain - // the writer. - precondition(suspendedYields.isEmpty, "We have outstanding suspended yields") - precondition(elements.isEmpty, "We have buffered elements") - - // We have no elements left and can transition to finished directly - self._state = .finished(sinkError: nil) - - return .callDidTerminate(delegate) - - case .finished, .writerFinished: - // We are already finished nothing to do here - return .none - + case .initial, .streaming, .writerFinished: + return false + case .finished: + return true case .modifying: preconditionFailure("Invalid state") } } + + @inlinable + init( + isWritable: Bool, + delegate: Delegate + ) { + self._state = .initial(isWritable: isWritable, delegate: delegate) + } + /// Actions returned by `setWritability()`. @usableFromInline enum SetWritabilityAction { - /// Indicates that ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)`` should be called - /// and all continuations should be resumed. - case callDidYieldAndResumeContinuations(Delegate, Deque, [SuspendedYield]) - /// Indicates that ``NIOAsyncWriterSinkDelegate/didYield(contentsOf:)`` and - /// ``NIOAsyncWriterSinkDelegate/didTerminate(error:)``should be called. - case callDidYieldAndDidTerminate(Delegate, Deque) - /// Indicates that nothing should be done. - case none + /// Indicates that all writer continuations should be resumed. + case resumeContinuations(_TinyArray) } @inlinable - /* fileprivate */ internal mutating func setWritability(to newWritability: Bool) -> SetWritabilityAction { + /* fileprivate */ internal mutating func setWritability(to newWritability: Bool) -> SetWritabilityAction? { switch self._state { case .initial(_, let delegate): // We just need to store the new writability state @@ -776,53 +762,91 @@ extension NIOAsyncWriter { return .none - case .streaming(let isWritable, let cancelledYields, let suspendedYields, let elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, let delegate): if isWritable == newWritability { // The writability didn't change so we can just early exit here return .none } - if newWritability { - // We became writable again. This means we have to resume all the continuations - // and yield the values. - + if newWritability && !inDelegateOutcall { + // We became writable again. This means we have to resume all the continuations. self._state = .streaming( isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, - suspendedYields: [], - elements: .init(), + suspendedYields: .init(), delegate: delegate ) - // We are taking the whole array of suspended yields and the deque of elements - // and allocate a new empty one. - // As a performance optimization we could always keep multiple arrays/deques and - // switch between them but I don't think this is the performance critical part. - return .callDidYieldAndResumeContinuations(delegate, elements, suspendedYields) + return .resumeContinuations(suspendedYields) + } else if newWritability && inDelegateOutcall { + // We became writable but are in a delegate outcall. + // We just have to store the new writability here. + self._state = .streaming( + isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + delegate: delegate + ) + return .none } else { // We became unwritable nothing really to do here - precondition(suspendedYields.isEmpty, "No yield should be suspended at this point") - precondition(elements.isEmpty, "No element should be buffered at this point") - self._state = .streaming( isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) return .none } - case .writerFinished(let elements, let delegate): + case .writerFinished(_, let inDelegateOutcall, let suspendedYields, let cancelledYields, let bufferedYieldIDs, let delegate, let error): if !newWritability { // We are not writable so we can't deliver the outstanding elements return .none } - self._state = .finished(sinkError: nil) + if newWritability && !inDelegateOutcall { + // We became writable again. This means we have to resume all the continuations. + self._state = .writerFinished( + isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: .init(), + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) - return .callDidYieldAndDidTerminate(delegate, elements) + return .resumeContinuations(suspendedYields) + } else if newWritability && inDelegateOutcall { + // We became writable but are in a delegate outcall. + // We just have to store the new writability here. + self._state = .writerFinished( + isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + return .none + } else { + // We became unwritable nothing really to do here + self._state = .writerFinished( + isWritable: newWritability, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + return .none + } case .finished: // We are already finished nothing to do here @@ -840,8 +864,6 @@ extension NIOAsyncWriter { case callDidYield(Delegate) /// Indicates that the calling `Task` should get suspended. case suspendTask - /// Indicates that the method should just return. - case returnNormally /// Indicates the given error should be thrown. case throwError(Error) @@ -866,59 +888,112 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: isWritable, // If we are writable we are going to make an outcall cancelledYields: [], - suspendedYields: [], - elements: .init(), + suspendedYields: .init(), delegate: delegate ) return .init(isWritable: isWritable, delegate: delegate) - case .streaming(let isWritable, var cancelledYields, let suspendedYields, var elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, let suspendedYields, let delegate): + self._state = .modifying + if let index = cancelledYields.firstIndex(of: yieldID) { // We already marked the yield as cancelled. We have to remove it and - // throw an error. - self._state = .modifying - + // throw a CancellationError. cancelledYields.remove(at: index) - if isWritable { - // We are writable so we can yield the elements right away and then - // return normally. + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + delegate: delegate + ) + + return .throwError(CancellationError()) + } else { + // Yield hasn't been marked as cancelled. + + switch (isWritable, inDelegateOutcall) { + case (true, false): self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: true, // We are now making a call to the delegate cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) - return .callDidYield(delegate) - } else { - // We are not writable so we are just going to enqueue the writes - // and return normally. We are not suspending the yield since the Task - // is marked as cancelled. - elements.append(contentsOf: sequence) + return .callDidYield(delegate) + case (true, true), (false, _): self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) + return .suspendTask + } + } + + case .writerFinished(let isWritable, let inDelegateOutcall, let suspendedYields, var cancelledYields, let bufferedYieldIDs, let delegate, let error): + if bufferedYieldIDs.contains(yieldID) { + // This yield was buffered before we became finished so we still have to deliver it + self._state = .modifying + + if let index = cancelledYields.firstIndex(of: yieldID) { + // We already marked the yield as cancelled. We have to remove it and + // throw a CancellationError. + cancelledYields.remove(at: index) + + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) - return .returnNormally + return .throwError(CancellationError()) + } else { + // Yield hasn't been marked as cancelled. + + switch (isWritable, inDelegateOutcall) { + case (true, false): + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: true, // We are now making a call to the delegate + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + + return .callDidYield(delegate) + case (true, true), (false, _): + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + return .suspendTask + } } } else { - // Yield hasn't been marked as cancelled. - // This means we can either call the delegate or suspend - return .init(isWritable: isWritable, delegate: delegate) + // We are already finished and still tried to write something + return .throwError(NIOAsyncWriterError.alreadyFinished()) } - case .writerFinished: - // We are already finished and still tried to write something - return .throwError(NIOAsyncWriterError.alreadyFinished()) - case .finished(let sinkError): // We are already finished and still tried to write something return .throwError(sinkError ?? NIOAsyncWriterError.alreadyFinished()) @@ -932,11 +1007,11 @@ extension NIOAsyncWriter { @inlinable /* fileprivate */ internal mutating func yield( contentsOf sequence: S, - continuation: CheckedContinuation, + continuation: CheckedContinuation, yieldID: YieldID ) where S.Element == Element { switch self._state { - case .streaming(let isWritable, let cancelledYields, var suspendedYields, var elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, var suspendedYields, let delegate): // We have a suspended yield at this point that hasn't been cancelled yet. // We need to store the yield now. @@ -947,13 +1022,12 @@ extension NIOAsyncWriter { continuation: continuation ) suspendedYields.append(suspendedYield) - elements.append(contentsOf: sequence) self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) @@ -968,7 +1042,8 @@ extension NIOAsyncWriter { /// Actions returned by `cancel()`. @usableFromInline enum CancelAction { - case resumeContinuation(CheckedContinuation) + /// Indicates that the continuation should be resumed with a `CancellationError`. + case resumeContinuationWithCancellationError(CheckedContinuation) /// Indicates that nothing should be done. case none } @@ -984,15 +1059,15 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: false, cancelledYields: [yieldID], - suspendedYields: [], - elements: .init(), + suspendedYields: .init(), delegate: delegate ) return .none - case .streaming(let isWritable, var cancelledYields, var suspendedYields, let elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, var suspendedYields, let delegate): if let index = suspendedYields.firstIndex(where: { $0.yieldID == yieldID }) { self._state = .modifying // We have a suspended yield for the id. We need to resume the continuation now. @@ -1006,13 +1081,13 @@ extension NIOAsyncWriter { // We are keeping the elements that the yield produced. self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) - return .resumeContinuation(suspendedYield.continuation) + return .resumeContinuationWithCancellationError(suspendedYield.continuation) } else { self._state = .modifying @@ -1023,16 +1098,63 @@ extension NIOAsyncWriter { cancelledYields.append(yieldID) self._state = .streaming( isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, cancelledYields: cancelledYields, suspendedYields: suspendedYields, - elements: elements, delegate: delegate ) return .none } - case .writerFinished, .finished: + case .writerFinished(let isWritable, let inDelegateOutcall, var suspendedYields, var cancelledYields, let bufferedYieldIDs, let delegate, let error): + guard bufferedYieldIDs.contains(yieldID) else { + return .none + } + if let index = suspendedYields.firstIndex(where: { $0.yieldID == yieldID }) { + self._state = .modifying + // We have a suspended yield for the id. We need to resume the continuation now. + + // Removing can be quite expensive if it produces a gap in the array. + // Since we are not expecting a lot of elements in this array it should be fine + // to just remove. If this turns out to be a performance pitfall, we can + // swap the elements before removing. So that we always remove the last element. + let suspendedYield = suspendedYields.remove(at: index) + + // We are keeping the elements that the yield produced. + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + + return .resumeContinuationWithCancellationError(suspendedYield.continuation) + + } else { + self._state = .modifying + // There is no suspended yield. This can mean that we either already yielded + // or that the call to `yield` is coming afterwards. We need to store + // the ID here. However, if the yield already happened we will never remove the + // stored ID. The only way to avoid doing this would be storing every ID + cancelledYields.append(yieldID) + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + + return .none + } + + case .finished: // We are already finished and there is nothing to do return .none @@ -1047,13 +1169,13 @@ extension NIOAsyncWriter { /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called. case callDidTerminate(Delegate) /// Indicates that all continuations should be resumed. - case resumeContinuations([SuspendedYield]) + case resumeContinuations(_TinyArray) /// Indicates that nothing should be done. case none } @inlinable - /* fileprivate */ internal mutating func writerFinish() -> WriterFinishAction { + /* fileprivate */ internal mutating func writerFinish(error: Error?) -> WriterFinishAction { switch self._state { case .initial(_, let delegate): // Nothing was ever written so we can transition to finished @@ -1061,23 +1183,41 @@ extension NIOAsyncWriter { return .callDidTerminate(delegate) - case .streaming(_, _, let suspendedYields, let elements, let delegate): + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, let delegate): // We are currently streaming and the writer got finished. - if elements.isEmpty { - // We have no elements left and can transition to finished directly - self._state = .finished(sinkError: nil) - - return .callDidTerminate(delegate) + if suspendedYields.isEmpty { + if inDelegateOutcall { + // We are in an outcall already and have to buffer + // the didTerminate call. + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: .init(), + cancelledYields: cancelledYields, + bufferedYieldIDs: .init(), + delegate: delegate, + error: error + ) + return .none + } else { + // We have no elements left and are not in an outcall so we + // can transition to finished directly + self._state = .finished(sinkError: nil) + return .callDidTerminate(delegate) + } } else { - // There are still elements left which we need to deliver once we become writable again + // There are still suspended writer tasks which we need to deliver once we become writable again self._state = .writerFinished( - elements: elements, - delegate: delegate + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: suspendedYields, + cancelledYields: cancelledYields, + bufferedYieldIDs: _TinyArray(suspendedYields.map { $0.yieldID }), + delegate: delegate, + error: error ) - // We are not resuming the continuations with the error here since their elements - // are still queued up. If they try to yield again they will run into an alreadyFinished error - return .resumeContinuations(suspendedYields) + return .none } case .writerFinished, .finished: @@ -1092,11 +1232,8 @@ extension NIOAsyncWriter { /// Actions returned by `sinkFinish()`. @usableFromInline enum SinkFinishAction { - /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called. - case callDidTerminate(Delegate, Error?) - /// Indicates that ``NIOAsyncWriterSinkDelegate/didTerminate(completion:)`` should be called and all - /// continuations should be resumed with the given error. - case resumeContinuationsWithErrorAndCallDidTerminate(Delegate, [SuspendedYield], Error) + /// Indicates that all continuations should be resumed with the given error. + case resumeContinuationsWithError(_TinyArray, Error) /// Indicates that nothing should be done. case none } @@ -1104,30 +1241,29 @@ extension NIOAsyncWriter { @inlinable /* fileprivate */ internal mutating func sinkFinish(error: Error?) -> SinkFinishAction { switch self._state { - case .initial(_, let delegate): + case .initial(_, _): // Nothing was ever written so we can transition to finished self._state = .finished(sinkError: error) - return .callDidTerminate(delegate, error) + return .none - case .streaming(_, _, let suspendedYields, _, let delegate): - // We are currently streaming and the writer got finished. + case .streaming(_, _, _, let suspendedYields, _): + // We are currently streaming and the sink got finished. // We can transition to finished and need to resume all continuations. self._state = .finished(sinkError: error) - - return .resumeContinuationsWithErrorAndCallDidTerminate( - delegate, + return .resumeContinuationsWithError( suspendedYields, error ?? NIOAsyncWriterError.alreadyFinished() ) - case .writerFinished(_, let delegate): - // The writer already finished and we were waiting to become writable again - // The Sink finished before we became writable so we can drop the elements and - // transition to finished + case .writerFinished(_, _, let suspendedYields, _, _, _, _): + // The writer already got finished and the sink got finished too now. + // We can transition to finished and need to resume all continuations. self._state = .finished(sinkError: error) - - return .callDidTerminate(delegate, error) + return .resumeContinuationsWithError( + suspendedYields, + error ?? NIOAsyncWriterError.alreadyFinished() + ) case .finished: // We are already finished and there is nothing to do @@ -1137,5 +1273,76 @@ extension NIOAsyncWriter { preconditionFailure("Invalid state") } } + + /// Actions returned by `sinkFinish()`. + @usableFromInline + enum UnbufferQueuedEventsAction { + case resumeContinuations(_TinyArray) + case callDidTerminate(Delegate, Error?) + } + + @inlinable + /* fileprivate */ internal mutating func unbufferQueuedEvents() -> UnbufferQueuedEventsAction? { + switch self._state { + case .initial: + preconditionFailure("Invalid state") + + case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, let delegate): + precondition(inDelegateOutcall, "We must be in a delegate outcall when we unbuffer events") + // We have to resume the other suspended yields now. + + if suspendedYields.isEmpty { + // There are no other writer suspended writer tasks so we can just return + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: false, + cancelledYields: cancelledYields, + suspendedYields: suspendedYields, + delegate: delegate + ) + return .none + } else { + // We have to resume the other suspended yields now. + self._state = .streaming( + isWritable: isWritable, + inDelegateOutcall: false, + cancelledYields: cancelledYields, + suspendedYields: .init(), + delegate: delegate + ) + return .resumeContinuations(suspendedYields) + } + + case .writerFinished(let isWritable, let inDelegateOutcall, let suspendedYields, let cancelledYields, let bufferedYieldIDs, let delegate, let error): + precondition(inDelegateOutcall, "We must be in a delegate outcall when we unbuffer events") + if suspendedYields.isEmpty { + // We were the last writer task and can now call didTerminate + self._state = .finished(sinkError: nil) + return .callDidTerminate(delegate, error) + } else { + // There are still other writer tasks that need to be resumed + self._state = .modifying + + + self._state = .writerFinished( + isWritable: isWritable, + inDelegateOutcall: inDelegateOutcall, + suspendedYields: .init(), + cancelledYields: cancelledYields, + bufferedYieldIDs: bufferedYieldIDs, + delegate: delegate, + error: error + ) + + return .resumeContinuations(suspendedYields) + } + + case .finished: + return .none + + case .modifying: + preconditionFailure("Invalid state") + } + } } } diff --git a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift index 2574d42d92..7b852dd6b6 100644 --- a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift +++ b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift @@ -85,6 +85,38 @@ public struct NIOThrowingAsyncSequenceProducer< self._internalClass._storage } + /// Initializes a new ``NIOThrowingAsyncSequenceProducer`` and a ``NIOThrowingAsyncSequenceProducer/Source``. + /// + /// - Important: This method returns a struct containing a ``NIOThrowingAsyncSequenceProducer/Source`` and + /// a ``NIOThrowingAsyncSequenceProducer``. The source MUST be held by the caller and + /// used to signal new elements or finish. The sequence MUST be passed to the actual consumer and MUST NOT be held by the + /// caller. This is due to the fact that deiniting the sequence is used as part of a trigger to terminate the underlying source. + /// + /// - Parameters: + /// - elementType: The element type of the sequence. + /// - failureType: The failure type of the sequence. Must be `Swift.Error` + /// - backPressureStrategy: The back-pressure strategy of the sequence. + /// - finishOnDeinit: Indicates if ``NIOThrowingAsyncSequenceProducer/Source/finish()`` should be called on deinit of the. + /// We do not recommend to rely on deinit based resource tear down. + /// - delegate: The delegate of the sequence + /// - Returns: A ``NIOThrowingAsyncSequenceProducer/Source`` and a ``NIOThrowingAsyncSequenceProducer``. + @inlinable + public static func makeSequence( + elementType: Element.Type = Element.self, + failureType: Failure.Type = Error.self, + backPressureStrategy: Strategy, + finishOnDeinit: Bool, + delegate: Delegate + ) -> NewSequence where Failure == Error { + let sequence = Self( + backPressureStrategy: backPressureStrategy, + delegate: delegate + ) + let source = Source(storage: sequence._storage, finishOnDeinit: finishOnDeinit) + + return .init(source: source, sequence: sequence) + } + /// Initializes a new ``NIOThrowingAsyncSequenceProducer`` and a ``NIOThrowingAsyncSequenceProducer/Source``. /// /// - Important: This method returns a struct containing a ``NIOThrowingAsyncSequenceProducer/Source`` and @@ -110,7 +142,7 @@ public struct NIOThrowingAsyncSequenceProducer< backPressureStrategy: backPressureStrategy, delegate: delegate ) - let source = Source(storage: sequence._storage) + let source = Source(storage: sequence._storage, finishOnDeinit: true) return .init(source: source, sequence: sequence) } @@ -129,6 +161,7 @@ public struct NIOThrowingAsyncSequenceProducer< /// - delegate: The delegate of the sequence /// - Returns: A ``NIOThrowingAsyncSequenceProducer/Source`` and a ``NIOThrowingAsyncSequenceProducer``. @inlinable + @available(*, deprecated, renamed: "makeSequence(elementType:failureType:backPressureStrategy:finishOnDeinit:delegate:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") public static func makeSequence( elementType: Element.Type = Element.self, failureType: Failure.Type = Error.self, @@ -139,7 +172,7 @@ public struct NIOThrowingAsyncSequenceProducer< backPressureStrategy: backPressureStrategy, delegate: delegate ) - let source = Source(storage: sequence._storage) + let source = Source(storage: sequence._storage, finishOnDeinit: true) return .init(source: source, sequence: sequence) } @@ -149,13 +182,14 @@ public struct NIOThrowingAsyncSequenceProducer< internal static func makeNonThrowingSequence( elementType: Element.Type = Element.self, backPressureStrategy: Strategy, + finishOnDeinit: Bool, delegate: Delegate ) -> NewSequence where Failure == Never { let sequence = Self( backPressureStrategy: backPressureStrategy, delegate: delegate ) - let source = Source(storage: sequence._storage) + let source = Source(storage: sequence._storage, finishOnDeinit: finishOnDeinit) return .init(source: source, sequence: sequence) } @@ -238,15 +272,23 @@ extension NIOThrowingAsyncSequenceProducer { @usableFromInline internal let _storage: Storage + @usableFromInline + internal let _finishOnDeinit: Bool + @inlinable - init(storage: Storage) { + init(storage: Storage, finishOnDeinit: Bool) { self._storage = storage + self._finishOnDeinit = finishOnDeinit } @inlinable deinit { - // We need to call finish here to resume any suspended continuation. - self._storage.finish(nil) + if !self._finishOnDeinit && !self._storage.isFinished { + preconditionFailure("Deinited NIOAsyncSequenceProducer.Source without calling source.finish()") + } else { + // We need to call finish here to resume any suspended continuation. + self._storage.finish(nil) + } } } @@ -259,8 +301,8 @@ extension NIOThrowingAsyncSequenceProducer { } @usableFromInline - /* fileprivate */ internal init(storage: Storage) { - self._internalClass = .init(storage: storage) + /* fileprivate */ internal init(storage: Storage, finishOnDeinit: Bool) { + self._internalClass = .init(storage: storage, finishOnDeinit: finishOnDeinit) } /// The result of a call to ``NIOThrowingAsyncSequenceProducer/Source/yield(_:)``. @@ -356,6 +398,14 @@ extension NIOThrowingAsyncSequenceProducer { /// The delegate. @usableFromInline /* private */ internal var _delegate: Delegate? + /// Hook used in testing. + @usableFromInline + internal var _didSuspend: (() -> Void)? + + @inlinable + var isFinished: Bool { + self._lock.withLock { self._stateMachine.isFinished } + } @usableFromInline /* fileprivate */ internal init( @@ -414,63 +464,65 @@ extension NIOThrowingAsyncSequenceProducer { @inlinable /* fileprivate */ internal func yield(_ sequence: S) -> Source.YieldResult where S.Element == Element { - self._lock.withLock { - let action = self._stateMachine.yield(sequence) + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let action = self._lock.withLock { + self._stateMachine.yield(sequence) + } - switch action { - case .returnProduceMore: - return .produceMore + switch action { + case .returnProduceMore: + return .produceMore - case .returnStopProducing: - return .stopProducing + case .returnStopProducing: + return .stopProducing - case .returnDropped: - return .dropped + case .returnDropped: + return .dropped - case .resumeContinuationAndReturnProduceMore(let continuation, let element): - // It is safe to resume the continuation while holding the lock - // since the task will get enqueued on its executor and the resume method - // is returning immediately - continuation.resume(returning: element) + case .resumeContinuationAndReturnProduceMore(let continuation, let element): + continuation.resume(returning: element) - return .produceMore + return .produceMore - case .resumeContinuationAndReturnStopProducing(let continuation, let element): - // It is safe to resume the continuation while holding the lock - // since the task will get enqueued on its executor and the resume method - // is returning immediately - continuation.resume(returning: element) + case .resumeContinuationAndReturnStopProducing(let continuation, let element): + continuation.resume(returning: element) - return .stopProducing - } + return .stopProducing } } @inlinable /* fileprivate */ internal func finish(_ failure: Failure?) { - let delegate: Delegate? = self._lock.withLock { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let (delegate, action): (Delegate?, NIOThrowingAsyncSequenceProducer.StateMachine.FinishAction) = self._lock.withLock { let action = self._stateMachine.finish(failure) - + switch action { - case .resumeContinuationWithFailureAndCallDidTerminate(let continuation, let failure): + case .resumeContinuationWithFailureAndCallDidTerminate: let delegate = self._delegate self._delegate = nil + return (delegate, action) - // It is safe to resume the continuation while holding the lock - // since the task will get enqueued on its executor and the resume method - // is returning immediately - switch failure { - case .some(let error): - continuation.resume(throwing: error) - case .none: - continuation.resume(returning: nil) - } - - return delegate + case .none: + return (nil, action) + } + } + switch action { + case .resumeContinuationWithFailureAndCallDidTerminate(let continuation, let failure): + switch failure { + case .some(let error): + continuation.resume(throwing: error) case .none: - return nil + continuation.resume(returning: nil) } + + case .none: + break } delegate?.didTerminate() @@ -546,10 +598,14 @@ extension NIOThrowingAsyncSequenceProducer { case .none: self._lock.unlock() } + self._didSuspend?() } } } onCancel: { - let delegate: Delegate? = self._lock.withLock { + // We must not resume the continuation while holding the lock + // because it can deadlock in combination with the underlying ulock + // in cases where we race with a cancellation handler + let (delegate, action): (Delegate?, NIOThrowingAsyncSequenceProducer.StateMachine.CancelledAction) = self._lock.withLock { let action = self._stateMachine.cancelled() switch action { @@ -557,33 +613,42 @@ extension NIOThrowingAsyncSequenceProducer { let delegate = self._delegate self._delegate = nil - return delegate - - case .resumeContinuationWithCancellationErrorAndCallDidTerminate(let continuation): - // We have deprecated the generic Failure type in the public API and Failure should - // now be `Swift.Error`. However, if users have not migrated to the new API they could - // still use a custom generic Error type and this cast might fail. - // In addition, we use `NIOThrowingAsyncSequenceProducer` in the implementation of the - // non-throwing variant `NIOAsyncSequenceProducer` where `Failure` will be `Never` and - // this cast will fail as well. - // Everything is marked @inlinable and the Failure type is known at compile time, - // therefore this cast should be optimised away in release build. - if let failure = CancellationError() as? Failure { - continuation.resume(throwing: failure) - } else { - continuation.resume(returning: nil) - } + return (delegate, action) + case .resumeContinuationWithCancellationErrorAndCallDidTerminate: let delegate = self._delegate self._delegate = nil - return delegate + return (delegate, action) case .none: - return nil + return (nil, action) } } + switch action { + case .callDidTerminate: + break + + case .resumeContinuationWithCancellationErrorAndCallDidTerminate(let continuation): + // We have deprecated the generic Failure type in the public API and Failure should + // now be `Swift.Error`. However, if users have not migrated to the new API they could + // still use a custom generic Error type and this cast might fail. + // In addition, we use `NIOThrowingAsyncSequenceProducer` in the implementation of the + // non-throwing variant `NIOAsyncSequenceProducer` where `Failure` will be `Never` and + // this cast will fail as well. + // Everything is marked @inlinable and the Failure type is known at compile time, + // therefore this cast should be optimised away in release build. + if let failure = CancellationError() as? Failure { + continuation.resume(throwing: failure) + } else { + continuation.resume(returning: nil) + } + + case .none: + break + } + delegate?.didTerminate() } } @@ -634,6 +699,19 @@ extension NIOThrowingAsyncSequenceProducer { @usableFromInline /* private */ internal var _state: State + @inlinable + var isFinished: Bool { + switch self._state { + case .initial, .streaming: + return false + case .cancelled, .sourceFinished, .finished: + return true + case .modifying: + preconditionFailure("Invalid state") + } + } + + /// Initializes a new `StateMachine`. /// /// We are passing and holding the back-pressure strategy here because diff --git a/Sources/NIOCore/BSDSocketAPI.swift b/Sources/NIOCore/BSDSocketAPI.swift index c370644eca..07ba7bebcb 100644 --- a/Sources/NIOCore/BSDSocketAPI.swift +++ b/Sources/NIOCore/BSDSocketAPI.swift @@ -68,8 +68,13 @@ import Musl #endif import CNIOLinux +#if os(Android) +private let sysInet_ntop: @convention(c) (CInt, UnsafeRawPointer, UnsafeMutablePointer, socklen_t) -> UnsafePointer? = inet_ntop +private let sysInet_pton: @convention(c) (CInt, UnsafePointer, UnsafeMutableRawPointer) -> CInt = inet_pton +#else private let sysInet_ntop: @convention(c) (CInt, UnsafeRawPointer?, UnsafeMutablePointer?, socklen_t) -> UnsafePointer? = inet_ntop private let sysInet_pton: @convention(c) (CInt, UnsafePointer?, UnsafeMutableRawPointer?) -> CInt = inet_pton +#endif #elseif canImport(Darwin) import Darwin diff --git a/Sources/NIOCore/ChannelHandler.swift b/Sources/NIOCore/ChannelHandler.swift index 1e630cdbd4..93f11f537b 100644 --- a/Sources/NIOCore/ChannelHandler.swift +++ b/Sources/NIOCore/ChannelHandler.swift @@ -343,89 +343,3 @@ extension RemovableChannelHandler { context.leavePipeline(removalToken: removalToken) } } - -/// The result of protocol negotiation. -@_spi(AsyncChannel) -public struct NIOProtocolNegotiationResult { - fileprivate enum Result { - /// Indicates that the protocol negotiation finished. - case finished(NegotiationResult) - /// Indicates that protocol negotiation has been deferred to the next handler. - case deferredResult(EventLoopFuture>) - } - - private let _result: Result - - /// The final result of protocol negotiation. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - var result: NegotiationResult { - get async throws { - switch self._result { - case .finished(let negotiationResult): - return negotiationResult - case .deferredResult(let eventLoopFuture): - return try await eventLoopFuture.flatMap { $0._result.resolve(on: eventLoopFuture.eventLoop) }.get() - } - } - } - - /// Intializes a new ``NIOProtocolNegotiationResult`` with a final result. - /// - /// - Parameter result: The final result of protocol negotiation. - public init(result: NegotiationResult) { - self._result = .finished(result) - } - - /// Intializes a new ``NIOProtocolNegotiationResult`` with a deferred result. - /// - /// - Parameter deferredResult: The deferred result. - public init(deferredResult: EventLoopFuture>) { - self._result = .deferredResult(deferredResult) - } - - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @available(*, deprecated, renamed: "getResult") - public func waitForFinalResult() async throws -> NegotiationResult { - try await self.result - } -} - -extension EventLoopFuture { - /// Get the result/error from the protocol negotiation. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func getResult() async throws -> NegotiationResult where Value == NIOProtocolNegotiationResult { - try await self.get().result - } -} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult.Result { - fileprivate func resolve(on eventLoop: EventLoop) -> EventLoopFuture { - Self.resolve(on: eventLoop, result: self) - } - - fileprivate static func resolve(on eventLoop: EventLoop, result: Self) -> EventLoopFuture { - switch result { - case .finished(let negotiationResult): - return eventLoop.makeSucceededFuture(negotiationResult) - - case .deferredResult(let future): - return future.flatMap { result in - return resolve(on: eventLoop, result: result._result) - } - } - } -} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult: Equatable where NegotiationResult: Equatable {} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult: Sendable where NegotiationResult: Sendable {} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult.Result: Equatable where NegotiationResult: Equatable {} - -@_spi(AsyncChannel) -extension NIOProtocolNegotiationResult.Result: Sendable where NegotiationResult: Sendable {} diff --git a/Sources/NIOCore/ChannelHandlers.swift b/Sources/NIOCore/ChannelHandlers.swift index 89dbc594d6..18a9318f4b 100644 --- a/Sources/NIOCore/ChannelHandlers.swift +++ b/Sources/NIOCore/ChannelHandlers.swift @@ -105,10 +105,8 @@ public final class AcceptBackoffHandler: ChannelDuplexHandler, RemovableChannelH } } -#if swift(>=5.7) @available(*, unavailable) extension AcceptBackoffHandler: Sendable {} -#endif /** ChannelHandler implementation which enforces back-pressure by stopping to read from the remote peer when it cannot write back fast enough. @@ -157,10 +155,8 @@ public final class BackPressureHandler: ChannelDuplexHandler, RemovableChannelHa } } -#if swift(>=5.7) @available(*, unavailable) extension BackPressureHandler: Sendable {} -#endif /// Triggers an IdleStateEvent when a Channel has not performed read, write, or both operation for a while. public final class IdleStateHandler: ChannelDuplexHandler, RemovableChannelHandler { @@ -347,7 +343,5 @@ public final class IdleStateHandler: ChannelDuplexHandler, RemovableChannelHandl } } -#if swift(>=5.7) @available(*, unavailable) extension IdleStateHandler: Sendable {} -#endif diff --git a/Sources/NIOCore/ChannelOption.swift b/Sources/NIOCore/ChannelOption.swift index 5f1f2f6adc..074561544f 100644 --- a/Sources/NIOCore/ChannelOption.swift +++ b/Sources/NIOCore/ChannelOption.swift @@ -19,7 +19,7 @@ public protocol ChannelOption: Equatable, _NIOPreconcurrencySendable { } public typealias SocketOptionName = Int32 -#if os(Linux) || os(Android) +#if (os(Linux) || os(Android)) && !canImport(Musl) public typealias SocketOptionLevel = Int public typealias SocketOptionValue = Int #else diff --git a/Sources/NIOCore/ChannelPipeline.swift b/Sources/NIOCore/ChannelPipeline.swift index d9d2c72ecf..246d86c0f0 100644 --- a/Sources/NIOCore/ChannelPipeline.swift +++ b/Sources/NIOCore/ChannelPipeline.swift @@ -946,9 +946,7 @@ public final class ChannelPipeline: ChannelInvoker { } } -#if swift(>=5.7) extension ChannelPipeline: @unchecked Sendable {} -#endif extension ChannelPipeline { /// Adds the provided channel handlers to the pipeline in the order given, taking account @@ -1274,10 +1272,8 @@ extension ChannelPipeline { } } -#if swift(>=5.7) @available(*, unavailable) extension ChannelPipeline.SynchronousOperations: Sendable {} -#endif extension ChannelPipeline { /// A `Position` within the `ChannelPipeline` used to insert handlers into the `ChannelPipeline`. @@ -1296,10 +1292,8 @@ extension ChannelPipeline { } } -#if swift(>=5.7) @available(*, unavailable) extension ChannelPipeline.Position: Sendable {} -#endif /// Special `ChannelHandler` that forwards all events to the `Channel.Unsafe` implementation. /* private but tests */ final class HeadChannelHandler: _ChannelOutboundHandler { @@ -1853,10 +1847,8 @@ public final class ChannelHandlerContext: ChannelInvoker { } } -#if swift(>=5.7) @available(*, unavailable) extension ChannelHandlerContext: Sendable {} -#endif extension ChannelHandlerContext { /// A `RemovalToken` is handed to a `RemovableChannelHandler` when its `removeHandler` function is invoked. A diff --git a/Sources/NIOCore/Codec.swift b/Sources/NIOCore/Codec.swift index 5d79ab39b7..a46beac870 100644 --- a/Sources/NIOCore/Codec.swift +++ b/Sources/NIOCore/Codec.swift @@ -485,10 +485,8 @@ public final class ByteToMessageHandler { } } -#if swift(>=5.7) @available(*, unavailable) extension ByteToMessageHandler: Sendable {} -#endif // MARK: ByteToMessageHandler: Test Helpers extension ByteToMessageHandler { @@ -776,10 +774,8 @@ public final class MessageToByteHandler: ChannelO } } -#if swift(>=5.7) @available(*, unavailable) extension MessageToByteHandler: Sendable {} -#endif extension MessageToByteHandler { public func handlerAdded(context: ChannelHandlerContext) { diff --git a/Sources/NIOCore/Docs.docc/swift-concurrency.md b/Sources/NIOCore/Docs.docc/swift-concurrency.md index 90a260f282..7198e18260 100644 --- a/Sources/NIOCore/Docs.docc/swift-concurrency.md +++ b/Sources/NIOCore/Docs.docc/swift-concurrency.md @@ -2,17 +2,21 @@ This article explains how to interface between NIO and Swift Concurrency. -NIO was created before native Concurrency support in Swift existed, hence, NIO had to solve -a few problems that have solutions in the language today. Since the introduction of Swift Concurrency, -NIO has added numerous features to make the interop between NIO's ``Channel`` eventing system and Swift's -Concurrency primitives as easy as possible. +NIO was created before native Concurrency support in Swift existed, hence, NIO +had to solve a few problems that have solutions in the language today. Since the +introduction of Swift Concurrency, NIO has added numerous features to make the +interop between NIO's ``Channel`` eventing system and Swift's Concurrency +primitives as easy as possible. -### EventLoopFuture bridges +### EventLoopFuture/Promise bridges -The first bridges that NIO introduced added methods on ``EventLoopFuture`` and ``EventLoopPromise`` -to enable communication between Concurrency and NIO. These methods are ``EventLoopFuture/get()`` and ``EventLoopPromise/completeWithTask(_:)``. +The first bridges that NIO introduced added methods on ``EventLoopFuture`` and +``EventLoopPromise`` to enable communication between Concurrency and NIO. These +methods are ``EventLoopFuture/get()`` and +``EventLoopPromise/completeWithTask(_:)``. -> Warning: The future ``EventLoopFuture/get()`` method does not support task cancellation. +> Warning: The future ``EventLoopFuture/get()`` method does not support task +> cancellation. Here is a small example of how these work: @@ -29,259 +33,282 @@ promise.completeWithTask { let result = try await promise.futureResult.get() ``` -> Note: The `completeWithTask` method creates an unstructured task under the hood. +> Note: The `completeWithTask` method creates an unstructured task under the +> hood. ### Channel bridges -The ``EventLoopFuture`` and ``EventLoopPromise`` bridges already allow async code to interact with -some of NIO's types. However, they only work where we have request-response-like interfaces. -On the other hand, NIO's ``Channel`` type contains a ``ChannelPipeline`` which can be roughly -described as a bi-directional streaming pipeline. To bridge such a pipeline into Concurrency required -new types. Importantly, these types need to uphold the channel's back-pressure and writability guarantees. -NIO introduced the ``NIOThrowingAsyncSequenceProducer``, ``NIOAsyncSequenceProducer`` and the ``NIOAsyncWriter`` -which form the foundation to bridge a ``Channel``. -On top of these foundational types, NIO provides the `NIOAsyncChannel` which is used to wrap a -``Channel`` to produce an interface that can be consumed directly from Swift Concurrency. The following -sections cover the details of the foundational types and how the `NIOAsyncChannel` works. +The ``EventLoopFuture`` and ``EventLoopPromise`` bridges already allow async +code to interact with some of NIO's types. However, they only work where we have +request-response-like interfaces. On the other hand, NIO's ``Channel`` type +contains a ``ChannelPipeline`` which can be roughly described as a +bi-directional streaming pipeline. To bridge such a pipeline into Concurrency +required new types. Importantly, these types need to uphold the channel's +back pressure and writability guarantees. NIO introduced the +``NIOThrowingAsyncSequenceProducer``, ``NIOAsyncSequenceProducer`` and the +``NIOAsyncChannelOutboundWriter`` which form the foundation to bridge a ``Channel``. On top of +these foundational types, NIO provides the `NIOAsyncChannel` which is used to +wrap a ``Channel`` to produce an interface that can be consumed directly from +Swift Concurrency. The following sections cover the details of the foundational +types and how the `NIOAsyncChannel` works. #### NIOThrowingAsyncSequenceProducer and NIOAsyncSequenceProducer -The ``NIOThrowingAsyncSequenceProducer`` and ``NIOAsyncSequenceProducer`` are asynchronous sequences -similar to Swift's `AsyncStream`. Their purpose is to provide a back-pressured bridge between a -synchronous producer and an asynchronous consumer. These types are highly configurable and generic which -makes them usable in a lot of places with very good performance; however, at the same time they are -not the easiest types to hold. We recommend that you **never** expose them in public API but rather -wrap them in your own async sequence. +The ``NIOThrowingAsyncSequenceProducer`` and ``NIOAsyncSequenceProducer`` are +asynchronous sequences similar to Swift's `AsyncStream`. Their purpose is to +provide a back pressured bridge between a synchronous producer and an +asynchronous consumer. These types are highly configurable and generic which +makes them usable in a lot of places with very good performance; however, at the +same time they are not the easiest types to hold. We recommend that you +**never** expose them in public API but rather wrap them in your own async +sequence. #### NIOAsyncWriter -The ``NIOAsyncWriter`` is used for bridging from an asynchronous producer to a synchronous consumer. -It also has back-pressure support which allows the consumer to stop the producer by suspending the +The ``NIOAsyncChannelOutboundWriter`` is used for bridging from an asynchronous producer to a +synchronous consumer. It also has back pressure support which allows the +consumer to stop the producer by suspending the ``NIOAsyncWriter/yield(contentsOf:)`` method. - -> Important: Everything below this is currently not public API but can be tested it by using `@_spi(AsyncChannel) import`. -The APIs might change until they become publicly available. - #### NIOAsyncChannel -The above types are used to bridge both the read and write side of a ``Channel`` into Swift Concurrency. -This can be done by wrapping a ``Channel`` via the `NIOAsyncChannel/init(synchronouslyWrapping:backpressureStrategy:isOutboundHalfClosureEnabled:inboundType:outboundType:)` -initializer. Under the hood, this initializer adds two channel handlers to the end of the channel's pipeline. -These handlers bridge the read and write side of the channel. Additionally, the handlers work together -to close the channel once both the reading and the writing have finished. - +The above types are used to bridge both the read and write side of a ``Channel`` +into Swift Concurrency. This can be done by wrapping a ``Channel`` via the +`NIOAsyncChannel/init(synchronouslyWrapping:configuration:)` +initializer. Under the hood, this initializer adds two channel handlers to the +end of the channel's pipeline. These handlers bridge the read and write side of +the channel. Additionally, the handlers work together to close the channel once +both the reading and the writing have finished. -This is how you can wrap an existing channel into a `NIOAsyncChannel`, consume the inbound data and -echo it back outbound. +This is how you can wrap an existing channel into a `NIOAsyncChannel`, consume +the inbound data and echo it back outbound. ```swift let channel = ... -let asyncChannel = try NIOAsyncChannel(synchronouslyWrapping: channel, inboundType: ByteBuffer.self, outboundType: ByteBuffer.self) +let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) -for try await inboundData in asyncChannel.inboundStream { - try await asyncChannel.outboundWriter.write(inboundData) +try await asyncChannel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + try await outbound.write(inboundData) + } } ``` -The above code works nicely; however, you must be very careful at what point you wrap your channel -otherwise you might lose some reads. For example your channel might be created by a `ServerBootstrap` -for a new inbound connection. The channel might start to produce reads as soon as it registered its -IO which happens after your channel initializer ran. To avoid potentially losing reads the channel -must be wrapped before it registered its IO. -Another example is when the channel contains a handler that does protocol negotiation. Protocol negotiation handlers -are usually waiting for some data to be exchanged before deciding what protocol to chose. Afterwards, they -often modify the channel's pipeline and add the protocol appropriate handlers to it. This is another -case where wrapping of the `Channel` into a `NIOAsyncChannel` needs to happen at the right time to avoid -losing reads. +The above code works nicely; however, you must be very careful at what point you +wrap your channel otherwise you might lose some reads. For example your channel +might be created by a `ServerBootstrap` for a new inbound connection. The +channel might start to produce reads as soon as it registered its IO which +happens after your channel initializer ran. To avoid potentially losing reads +the channel must be wrapped before it registered its IO. Another example is when +the channel contains a handler that does protocol negotiation. Protocol +negotiation handlers are usually waiting for some data to be exchanged before +deciding what protocol to chose. Afterwards, they often modify the channel's +pipeline and add the protocol appropriate handlers to it. This is another case +where wrapping of the `Channel` into a `NIOAsyncChannel` needs to happen at the +right time to avoid losing reads. -### Async bootstrap +### Asynchronous bootstrap methods -NIO offers three different kind of bootstraps `ServerBootstrap`, `ClientBootstrap` and `DatagramBootstrap`. -The next section is going to focus on how to use the methods of these three types to bootstrap connections -using `NIOAsyncChannel`. +NIO offers a multitude of bootstraps. To avoid the above problems +and enable a seamless experience when using NIO from Swift Concurrency, +the bootstraps gained new generic asynchronous methods. +The next section is going to focus on how to use the methods to boostrap a TCP +server and client. #### ServerBootstrap -The server bootstrap is used to create a new TCP based server. Once any of the bind methods on the `ServerBootstrap` -is called, a new listening socket is created to handle new inbound TCP connections. Let's take a look -at the new `NIOAsyncChannel` based bind methods. +The server bootstrap is used to create a new TCP based server. Once any of the +bind methods on the `ServerBootstrap` is called, a new listening socket is +created to handle new inbound TCP connections. Let's use the new methods +to setup a TCP server and configure a `NIOAsyncChannel` for each inbound +connection. ```swift let serverChannel = try await ServerBootstrap(group: eventLoopGroup) .bind( host: "127.0.0.1", - port: 0, - childChannelInboundType: ByteBuffer.self, - childChannelOutboundType: ByteBuffer.self - ) + port: 1234 + ) { childChannel in + // This closure is called for every inbound connection + childChannel.eventLoop.makeCompletedFuture { + return try NIOAsyncChannel( + synchronouslyWrapping: childChannel + ) + } + } try await withThrowingDiscardingTaskGroup { group in - for try await connectionChannel in serverChannel.inboundStream { - group.addTask { - do { - for try await inboundData in connectionChannel.inboundStream { - try await connectionChannel.outboundWriter.write(inboundData) + try await serverChannel.executeThenClose { serverChannelInbound in + for try await connectionChannel in serverChannelInbound { + group.addTask { + do { + try await connectionChannel.executeThenClose { connectionChannelInbound, connectionChannelOutbound in + for try await inboundData in connectionChannelInbound { + // Let's echo back all inbound data + try await connectionChannelOutbound.write(inboundData) + } + } + } catch { + // Handle errors } - } catch { - // Handle errors } } } } ``` -In the above code, we are bootstrapping a new TCP server which we assign to `serverChannel`. -The `serverChannel` is a `NIOAsyncChannel` whose inbound type is a `NIOAsyncChannel` and whose -outbound type is `Never`. This is due to the fact that each inbound connection gets its own separate child channel. -The inbound and outbound types of each inbound connection is `ByteBuffer` as specified in the bootstrap. -Afterwards, we handle each inbound connection in separate child tasks and echo the data back. - -> Important: Make sure to use discarding task groups which automatically reap finished child tasks. -Normal task groups will result in a memory leak since they do not reap their child tasks automatically. +In the above code, we are bootstrapping a new TCP server which we assign to +`serverChannel`. Furthermore, in the trailing closure of `bind` we are +configuring the pipeline of each inbound connection. In our example, we are +wrapping each child channel in a `NIOAsyncChannel`. The resulting +`serverChannel` is a `NIOAsyncChannel` whose inbound type is a `NIOAsyncChannel` +and whose outbound type is `Never`. This is due to the fact that each inbound +connection gets its own separate child channel. The inbound and outbound types +of each inbound connection is `ByteBuffer` as specified in the bootstrap. +Afterwards, we handle each inbound connection in separate child tasks and echo +the data back. + +> Important: Make sure to use discarding task groups which automatically reap +finished child tasks. Normal task groups will result in a memory leak since they +do not reap their child tasks automatically. #### ClientBootstrap -The client bootstrap is used to create a new TCP based client. Let's take a look at the new -`NIOAsyncChannel` based connect methods. +The client bootstrap is used to create a new TCP based client. Let's take a look +how to bootstrap a TCP connection and send some data to the server. ```swift let clientChannel = try await ClientBootstrap(group: eventLoopGroup) .connect( host: "127.0.0.1", - port: 0, - channelInboundType: ByteBuffer.self, - channelOutboundType: ByteBuffer.self - ) + port: 1234 + ) { channel in + channel.eventLoop.makeCompletedFuture { + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel + ) + } + } -clientChannel.outboundWriter.write(ByteBuffer(string: "hello")) +try await clientChannel.executeThenClose { inbound, outbound in + try await outbound.write(ByteBuffer(string: "hello")) -for try await inboundData in clientChannel.inboundStream { - print(inboundData) + for try await inboundData in inbound { + print(inboundData) + } } ``` -#### DatagramBootstrap -> Important: Support for `DatagramBootstrap` with `NIOAsyncChannel` hasn't landed yet. - -#### Protocol negotiation - -The above bootstrap methods work great in the case where we know the types of the resulting channels -at compile time. However, as mentioned previously protocol negotiation is another case where the timing -of wrapping the ``Channel`` is important that we haven't covered with the `bind` methods that take -an inbound and outbound type yet. -To solve the problem of protocol negotiation, NIO introduced a new ``ChannelHandler`` protocol called -`NIOProtocolNegotiationHandler`. This protocol requires a single future property `NIOProtocolNegotiationHandler/protocolNegotiationResult` -that is completed once the handler is finished with protocol negotiation. In the successful case, -the future can either indicate that protocol negotiation is fully done by returning `NIOProtocolNegotiationResult/finished(_:)` or -indicate that further protocol negotiation needs to be done by returning `NIOProtocolNegotiationResult/deferredResult(_:)`. -Additionally, the various bootstraps provide another set of `bind()`/`connect()` methods that handle protocol negotiation. -Let's walk through how to setup a `ServerBootstrap` with protocol negotiation. - -First, we have to define our negotiation result. For this example, we are negotiating between a -`String` based and `UInt8` based channel. Additionally, we also need an error that we can throw -if protocol negotiation failed. +#### Dynamic pipeline modifications + +The above bootstrap methods work great in the case where we know the types of +the resulting channels at compile time. However, there are some scenarios where +the type is only determined at runtime. Such cases include +[Application-Layer-Protocol-Negotiation](https://en.wikipedia.org/wiki/Application-Layer_Protocol_Negotiation) +or [HTTP protocol +upgrades](https://en.wikipedia.org/wiki/HTTP/1.1_Upgrade_header). To support +those scenarios it is essential that channel handlers that dynamically configure +the pipeline carry type information which allows us to runtime to determine how +the pipeline was configured at runtime. To support this NIO introduced multiple +new `ChannelHandler` and corresponding pipeline configuration methods. Those +types are: + +1. `NIOTypedApplicationProtocolNegotiationHandler` for TLS based ALPN +2. `NIOTypedHTTPServerUpgradeHandler` and + `configureUpgradableHTTPServerPipeline` for server-side HTTP protocol + upgrades +2. `NIOTypedHTTPClientUpgradeHandler` and + `configureUpgradableHTTPClientPipeline` for client-side HTTP protocol + upgrades + +All of those types have one thing in common - they are generic over the result +of the dynamic pipeline configuration. This allows users to exhaustively switch +over the result and correctly handle each case. The following example +demonstrates how this works for a client-side websocket upgrade. + ```swift -enum NegotiationResult { - case string(NIOAsyncChannel) - case byte(NIOAsyncChannel) +enum UpgradeResult { + case websocket(NIOAsyncChannel) + case notUpgraded } -struct ProtocolNegotiationError: Error {} -``` - -Next, we have to setup our bootstrap. We are adding a `NIOTypedApplicationProtocolNegotiationHandler` -to each child channel's pipeline. This handler listens for user inbound events of the type `TLSUserEvent` -and then calls the provided closure with the result. In our example, we are handling either `string` -or `byte` application protocols. Importantly, we now have to wrap the channel into a `NIOAsyncChannel` ourselves and -return the finished `NIOProtocolNegotiationResult`. -```swift -let serverBoostrap = try await ServerBootstrap(group: eventLoopGroup) - .childChannelInitializer { channel in +let upgradeResult: EventLoopFuture = try await ClientBootstrap(group: eventLoopGroup) + .connect( + host: "127.0.0.1", + port: 1234 + ) { channel in channel.eventLoop.makeCompletedFuture { - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: channel.eventLoop) { alpnResult, channel in - switch alpnResult { - case .negotiated(let alpn): - switch alpn { - case "string": - return channel.eventLoop.makeCompletedFuture { - let asyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel, - isOutboundHalfClosureEnabled: true, - inboundType: String.self, - outboundType: String.self - ) - - return NIOProtocolNegotiationResult.finished(NegotiationResult.string(asyncChannel)) - } - case "byte": - return channel.eventLoop.makeCompletedFuture { - let asyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel, - isOutboundHalfClosureEnabled: true, - inboundType: UInt8.self, - outboundType: UInt8.self - ) - - return NIOProtocolNegotiationResult.finished(NegotiationResult.byte(asyncChannel)) - } - default: - return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError()) + // Configure the websocket upgrader + let upgrader = NIOTypedWebSocketClientUpgrader( + upgradePipelineHandler: { channel, _ in + // This configures the pipeline after the websocket upgrade was successful. + // We are wrapping the pipeline in a NIOAsyncChannel. + channel.eventLoop.makeCompletedFuture { + let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) + return UpgradeResult.websocket(asyncChannel) } - case .fallback: - return channel.eventLoop.makeFailedFuture(ProtocolNegotiationError()) } - } + ) + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "0") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: [upgrader], + notUpgradingCompletionHandler: { channel in + channel.eventLoop.makeCompletedFuture { + return UpgradeResult.notUpgraded + } + } + ) - try channel.pipeline.syncOperations.addHandler(negotiationHandler) + let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) + ) + + return upgradeResult } } ``` -Lastly, we can now bind the `serverChannel` and handle the incoming connections. In the code below, -you can see that our server channel is now a `NIOAsyncChannel` of `NegotiationResult`s instead of -child channels. -```swift -let serverChannel = serverBootstrap.bind( - host: "127.0.0.1", - port: 1995, - protocolNegotiationHandlerType: NIOTypedApplicationProtocolNegotiationHandler.self -) +After having configured the pipeline to negotiate a websocket upgrade. We can +switch over the the `upgradeResult`. Importantly, we have to `await` the +`upgradeResult` first since it has to be negotiated on the connection. -try await withThrowingDiscardingTaskGroup { group in - for try await negotiationResult in serverChannel.inboundStream { - group.addTask { - do { - switch negotiationResult { - case .string(let channel): - for try await inboundData in channel.inboundStream { - try await channel.outboundWriter.write(inboundData) - } - case .byte(let channel): - for try await value in channel.inboundStream { - try await channel.outboundWriter.write(inboundData) - } - } - } catch { - // Handle errors - } - } - } +``` +switch try await upgradeResult.get() { +case .websocket(let websocketChannel): + print("Handling websocket connection") + try await self.handleWebsocketChannel(websocketChannel) + print("Done handling websocket connection") +case .notUpgraded: + // The upgrade to websocket did not succeed. + print("Upgrade declined") } ``` - ### General guidance #### Where should your code live? -Before the introduction of Swift Concurrency both implementations of network protocols and business logic -were often written inside ``ChannelHandler``s. This made it easier to get started; however, it came with -some downsides. First, implementing business logic inside channel handlers requires the business logic to -also handle all of the invariants that the ``ChannelHandler`` protocol brings with it. This often requires -writing complex state machines. Additionally, the business logic becomes very tied to NIO and hard to -port between different systems. -Because of the above reasons we recommend to implement your business logic using Swift Concurrency primitives and the -`NIOAsyncChannel` based bootstraps. Network protocol implementation should still be implemented as +Before the introduction of Swift Concurrency both implementations of network +protocols and business logic were often written inside ``ChannelHandler``s. This +made it easier to get started; however, it came with some downsides. First, +implementing business logic inside channel handlers requires the business logic +to also handle all of the invariants that the ``ChannelHandler`` protocol brings +with it. This often requires writing complex state machines. Additionally, the +business logic becomes very tied to NIO and hard to port between different +systems. Because of the above reasons we recommend to implement your business +logic using Swift Concurrency primitives and the `NIOAsyncChannel` based +bootstraps. Network protocol implementation should still be implemented as ``ChannelHandler``s. diff --git a/Sources/NIOCore/ActorExecutor.swift b/Sources/NIOCore/EventLoop+SerialExecutor.swift similarity index 92% rename from Sources/NIOCore/ActorExecutor.swift rename to Sources/NIOCore/EventLoop+SerialExecutor.swift index 2842181e4a..f157701778 100644 --- a/Sources/NIOCore/ActorExecutor.swift +++ b/Sources/NIOCore/EventLoop+SerialExecutor.swift @@ -25,6 +25,9 @@ public protocol NIOSerialEventLoopExecutor: EventLoop, SerialExecutor { } extension NIOSerialEventLoopExecutor { @inlinable public func enqueue(_ job: consuming ExecutorJob) { + // By default we are just going to use execute to run the job + // this is quite heavy since it allocates the closure for + // every single job. let unownedJob = UnownedJob(job) self.execute { unownedJob.runSynchronously(on: self.asUnownedSerialExecutor()) @@ -62,10 +65,7 @@ final class NIODefaultSerialEventLoopExecutor { extension NIODefaultSerialEventLoopExecutor: SerialExecutor { @inlinable public func enqueue(_ job: consuming ExecutorJob) { - let unownedJob = UnownedJob(job) - self.loop.execute { - unownedJob.runSynchronously(on: self.asUnownedSerialExecutor()) - } + self.loop.enqueue(job) } @inlinable diff --git a/Sources/NIOCore/EventLoop.swift b/Sources/NIOCore/EventLoop.swift index 6675f118f9..098e6f1652 100644 --- a/Sources/NIOCore/EventLoop.swift +++ b/Sources/NIOCore/EventLoop.swift @@ -23,28 +23,16 @@ import CNIOLinux /// A `Scheduled` allows the user to either `cancel()` the execution of the scheduled task (if possible) or obtain a reference to the `EventLoopFuture` that /// will be notified once the execution is complete. public struct Scheduled { - #if swift(>=5.7) @usableFromInline typealias CancelationCallback = @Sendable () -> Void - #else - @usableFromInline typealias CancelationCallback = () -> Void - #endif /* private but usableFromInline */ @usableFromInline let _promise: EventLoopPromise /* private but usableFromInline */ @usableFromInline let _cancellationTask: CancelationCallback - #if swift(>=5.7) @inlinable @preconcurrency public init(promise: EventLoopPromise, cancellationTask: @escaping @Sendable () -> Void) { self._promise = promise self._cancellationTask = cancellationTask } - #else - @inlinable - public init(promise: EventLoopPromise, cancellationTask: @escaping () -> Void) { - self._promise = promise - self._cancellationTask = cancellationTask - } - #endif /// Try to cancel the execution of the scheduled task. /// @@ -63,19 +51,13 @@ public struct Scheduled { } } -#if swift(>=5.7) extension Scheduled: Sendable where T: Sendable {} -#endif /// Returned once a task was scheduled to be repeatedly executed on the `EventLoop`. /// /// A `RepeatedTask` allows the user to `cancel()` the repeated scheduling of further tasks. public final class RepeatedTask { - #if swift(>=5.7) typealias RepeatedTaskCallback = @Sendable (RepeatedTask) -> EventLoopFuture - #else - typealias RepeatedTaskCallback = (RepeatedTask) -> EventLoopFuture - #endif private let delay: TimeAmount private let eventLoop: EventLoop private let cancellationPromise: EventLoopPromise? @@ -196,9 +178,7 @@ public final class RepeatedTask { } } -#if swift(>=5.7) extension RepeatedTask: @unchecked Sendable {} -#endif /// An iterator over the `EventLoop`s forming an `EventLoopGroup`. /// @@ -226,9 +206,7 @@ public struct EventLoopIterator: Sequence, IteratorProtocol { } } -#if swift(>=5.7) extension EventLoopIterator: Sendable {} -#endif /// An EventLoop processes IO / tasks in an endless loop for `Channel`s until it's closed. /// @@ -270,16 +248,10 @@ public protocol EventLoop: EventLoopGroup { /// If it is necessary for correctness to confirm that you're on an event loop, prefer ``preconditionInEventLoop(file:line:)-7ukrq``. var inEventLoop: Bool { get } - #if swift(>=5.7) /// Submit a given task to be executed by the `EventLoop` @preconcurrency func execute(_ task: @escaping @Sendable () -> Void) - #else - /// Submit a given task to be executed by the `EventLoop` - func execute(_ task: @escaping () -> Void) - #endif - - #if swift(>=5.7) + /// Submit a given task to be executed by the `EventLoop`. Once the execution is complete the returned `EventLoopFuture` is notified. /// /// - parameters: @@ -287,16 +259,6 @@ public protocol EventLoop: EventLoopGroup { /// - returns: `EventLoopFuture` that is notified once the task was executed. @preconcurrency func submit(_ task: @escaping @Sendable () throws -> T) -> EventLoopFuture - #else - /// Submit a given task to be executed by the `EventLoop`. Once the execution is complete the returned `EventLoopFuture` is notified. - /// - /// - parameters: - /// - task: The closure that will be submitted to the `EventLoop` for execution. - /// - returns: `EventLoopFuture` that is notified once the task was executed. - func submit(_ task: @escaping () throws -> T) -> EventLoopFuture - #endif - - #if swift(>=5.7) /// Schedule a `task` that is executed by this `EventLoop` at the given time. /// @@ -309,20 +271,6 @@ public protocol EventLoop: EventLoopGroup { @discardableResult @preconcurrency func scheduleTask(deadline: NIODeadline, _ task: @escaping @Sendable () throws -> T) -> Scheduled - #else - /// Schedule a `task` that is executed by this `EventLoop` at the given time. - /// - /// - parameters: - /// - task: The synchronous task to run. As with everything that runs on the `EventLoop`, it must not block. - /// - returns: A `Scheduled` object which may be used to cancel the task if it has not yet run, or to wait - /// on the completion of the task. - /// - /// - note: You can only cancel a task before it has started executing. - @discardableResult - func scheduleTask(deadline: NIODeadline, _ task: @escaping () throws -> T) -> Scheduled - #endif - - #if swift(>=5.7) /// Schedule a `task` that is executed by this `EventLoop` after the given amount of time. /// @@ -336,19 +284,6 @@ public protocol EventLoop: EventLoopGroup { @discardableResult @preconcurrency func scheduleTask(in: TimeAmount, _ task: @escaping @Sendable () throws -> T) -> Scheduled - #else - /// Schedule a `task` that is executed by this `EventLoop` after the given amount of time. - /// - /// - parameters: - /// - task: The synchronous task to run. As with everything that runs on the `EventLoop`, it must not block. - /// - returns: A `Scheduled` object which may be used to cancel the task if it has not yet run, or to wait - /// on the completion of the task. - /// - /// - note: You can only cancel a task before it has started executing. - /// - note: The `in` value is clamped to a maximum when running on a Darwin-kernel. - @discardableResult - func scheduleTask(in: TimeAmount, _ task: @escaping () throws -> T) -> Scheduled - #endif /// Asserts that the current thread is the one tied to this `EventLoop`. /// Otherwise, the process will be abnormally terminated as per the semantics of `preconditionFailure(_:file:line:)`. @@ -382,6 +317,10 @@ public protocol EventLoop: EventLoopGroup { /// implementation returns a ``NIODefaultSerialEventLoopExecutor`` instead, which provides suboptimal performance. @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) var executor: any SerialExecutor { get } + + /// Submit a job to be executed by the `EventLoop` + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + func enqueue(_ job: consuming ExecutorJob) #endif /// Must crash if it is not safe to call `wait()` on an `EventLoopFuture`. @@ -446,6 +385,18 @@ extension EventLoop { public var executor: any SerialExecutor { NIODefaultSerialEventLoopExecutor(self) } + + @inlinable + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + public func enqueue(_ job: consuming ExecutorJob) { + // By default we are just going to use execute to run the job + // this is quite heavy since it allocates the closure for + // every single job. + let unownedJob = UnownedJob(job) + self.execute { + unownedJob.runSynchronously(on: self.executor.asUnownedSerialExecutor()) + } + } #endif } @@ -485,9 +436,11 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of microseconds this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func microseconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * 1000) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000)) } /// Creates a new `TimeAmount` for the given amount of milliseconds. @@ -495,9 +448,11 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of milliseconds this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func milliseconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * (1000 * 1000)) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000)) } /// Creates a new `TimeAmount` for the given amount of seconds. @@ -505,9 +460,11 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of seconds this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func seconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * (1000 * 1000 * 1000)) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000)) } /// Creates a new `TimeAmount` for the given amount of minutes. @@ -515,9 +472,11 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of minutes this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func minutes(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * (1000 * 1000 * 1000 * 60)) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000 * 60)) } /// Creates a new `TimeAmount` for the given amount of hours. @@ -525,9 +484,27 @@ public struct TimeAmount: Hashable, Sendable { /// - parameters: /// - amount: the amount of hours this `TimeAmount` represents. /// - returns: the `TimeAmount` for the given amount. + /// + /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func hours(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount * (1000 * 1000 * 1000 * 60 * 60)) + return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000 * 60 * 60)) + } + + /// Converts `amount` to nanoseconds multiplying it by `multiplier`. The return value is capped to `Int64.max` if the multiplication overflows and `Int64.min` if it underflows. + /// + /// - parameters: + /// - amount: the amount to be converted to nanoseconds. + /// - multiplier: the multiplier that converts the given amount to nanoseconds. + /// - returns: the amount converted to nanoseconds within [Int64.min, Int64.max]. + @inlinable + static func _cappedNanoseconds(amount: Int64, multiplier: Int64) -> Int64 { + let nanosecondsMultiplication = amount.multipliedReportingOverflow(by: multiplier) + if nanosecondsMultiplication.overflow { + return amount >= 0 ? .max : .min + } else { + return nanosecondsMultiplication.partialValue + } } } @@ -725,7 +702,6 @@ extension NIODeadline { } extension EventLoop { - #if swift(>=5.7) /// Submit `task` to be run on this `EventLoop`. /// /// The returned `EventLoopFuture` will be completed when `task` has finished running. It will be succeeded with @@ -740,22 +716,7 @@ extension EventLoop { _submit(task) } @usableFromInline typealias SubmitCallback = @Sendable () throws -> T - #else - /// Submit `task` to be run on this `EventLoop`. - /// - /// The returned `EventLoopFuture` will be completed when `task` has finished running. It will be succeeded with - /// `task`'s return value or failed if the execution of `task` threw an error. - /// - /// - parameters: - /// - task: The synchronous task to run. As everything that runs on the `EventLoop`, it must not block. - /// - returns: An `EventLoopFuture` containing the result of `task`'s execution. - @inlinable - public func submit(_ task: @escaping () throws -> T) -> EventLoopFuture { - _submit(task) - } - @usableFromInline typealias SubmitCallback = () throws -> T - #endif - + @inlinable func _submit(_ task: @escaping SubmitCallback) -> EventLoopFuture { let promise: EventLoopPromise = makePromise(file: #fileID, line: #line) @@ -771,7 +732,6 @@ extension EventLoop { return promise.futureResult } - #if swift(>=5.7) /// Submit `task` to be run on this `EventLoop`. /// /// The returned `EventLoopFuture` will be completed when `task` has finished running. It will be identical to @@ -786,28 +746,12 @@ extension EventLoop { self._flatSubmit(task) } @usableFromInline typealias FlatSubmitCallback = @Sendable () -> EventLoopFuture - #else - /// Submit `task` to be run on this `EventLoop`. - /// - /// The returned `EventLoopFuture` will be completed when `task` has finished running. It will be identical to - /// the `EventLoopFuture` returned by `task`. - /// - /// - parameters: - /// - task: The asynchronous task to run. As with everything that runs on the `EventLoop`, it must not block. - /// - returns: An `EventLoopFuture` identical to the `EventLoopFuture` returned from `task`. - @inlinable - public func flatSubmit(_ task: @escaping () -> EventLoopFuture) -> EventLoopFuture { - self._flatSubmit(task) - } - @usableFromInline typealias FlatSubmitCallback = () -> EventLoopFuture - #endif - + @inlinable func _flatSubmit(_ task: @escaping FlatSubmitCallback) -> EventLoopFuture { self.submit(task).flatMap { $0 } } - - #if swift(>=5.7) + /// Schedule a `task` that is executed by this `EventLoop` at the given time. /// /// - parameters: @@ -828,28 +772,7 @@ extension EventLoop { self._flatScheduleTask(deadline: deadline, file: file, line: line, task) } @usableFromInline typealias FlatScheduleTaskDeadlineCallback = () throws -> EventLoopFuture - #else - /// Schedule a `task` that is executed by this `EventLoop` at the given time. - /// - /// - parameters: - /// - task: The asynchronous task to run. As with everything that runs on the `EventLoop`, it must not block. - /// - returns: A `Scheduled` object which may be used to cancel the task if it has not yet run, or to wait - /// on the full execution of the task, including its returned `EventLoopFuture`. - /// - /// - note: You can only cancel a task before it has started executing. - @discardableResult - @inlinable - public func flatScheduleTask( - deadline: NIODeadline, - file: StaticString = #fileID, - line: UInt = #line, - _ task: @escaping () throws -> EventLoopFuture - ) -> Scheduled { - self._flatScheduleTask(deadline: deadline, file: file, line: line, task) - } - @usableFromInline typealias FlatScheduleTaskDeadlineCallback = () throws -> EventLoopFuture - #endif - + @discardableResult @inlinable func _flatScheduleTask( @@ -865,7 +788,6 @@ extension EventLoop { return .init(promise: promise, cancellationTask: { scheduled.cancel() }) } - #if swift(>=5.7) /// Schedule a `task` that is executed by this `EventLoop` after the given amount of time. /// /// - parameters: @@ -885,29 +807,9 @@ extension EventLoop { ) -> Scheduled { self._flatScheduleTask(in: delay, file: file, line: line, task) } - @usableFromInline typealias FlatScheduleTaskDelayCallback = @Sendable () throws -> EventLoopFuture - #else - /// Schedule a `task` that is executed by this `EventLoop` after the given amount of time. - /// - /// - parameters: - /// - task: The asynchronous task to run. As everything that runs on the `EventLoop`, it must not block. - /// - returns: A `Scheduled` object which may be used to cancel the task if it has not yet run, or to wait - /// on the full execution of the task, including its returned `EventLoopFuture`. - /// - /// - note: You can only cancel a task before it has started executing. - @discardableResult - @inlinable - public func flatScheduleTask( - in delay: TimeAmount, - file: StaticString = #fileID, - line: UInt = #line, - _ task: @escaping () throws -> EventLoopFuture - ) -> Scheduled { - self._flatScheduleTask(in: delay, file: file, line: line, task) - } - @usableFromInline typealias FlatScheduleTaskDelayCallback = () throws -> EventLoopFuture - #endif + @usableFromInline typealias FlatScheduleTaskDelayCallback = @Sendable () throws -> EventLoopFuture + @inlinable func _flatScheduleTask( in delay: TimeAmount, @@ -998,7 +900,6 @@ extension EventLoop { // Do nothing } - #if swift(>=5.7) /// Schedule a repeated task to be executed by the `EventLoop` with a fixed delay between the end and start of each /// task. @@ -1019,14 +920,14 @@ extension EventLoop { ) -> RepeatedTask { self._scheduleRepeatedTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) } - typealias ScheduleRepeatedTaskCallback = @Sendable (RepeatedTask) throws -> Void - #else + /// Schedule a repeated task to be executed by the `EventLoop` with a fixed delay between the end and start of each /// task. /// /// - parameters: /// - initialDelay: The delay after which the first task is executed. /// - delay: The delay between the end of one task and the start of the next. + /// - maximumAllowableJitter: Exclusive upper bound of jitter range added to the `delay` parameter. /// - promise: If non-nil, a promise to fulfill when the task is cancelled and all execution is complete. /// - task: The closure that will be executed. /// - return: `RepeatedTask` @@ -1034,14 +935,16 @@ extension EventLoop { public func scheduleRepeatedTask( initialDelay: TimeAmount, delay: TimeAmount, + maximumAllowableJitter: TimeAmount, notifying promise: EventLoopPromise? = nil, - _ task: @escaping (RepeatedTask) throws -> Void + _ task: @escaping @Sendable (RepeatedTask) throws -> Void ) -> RepeatedTask { - self._scheduleRepeatedTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) + let jitteredInitialDelay = Self._getJitteredDelay(delay: initialDelay, maximumAllowableJitter: maximumAllowableJitter) + let jitteredDelay = Self._getJitteredDelay(delay: delay, maximumAllowableJitter: maximumAllowableJitter) + return self.scheduleRepeatedTask(initialDelay: jitteredInitialDelay, delay: jitteredDelay, notifying: promise, task) } - typealias ScheduleRepeatedTaskCallback = (RepeatedTask) throws -> Void - #endif - + typealias ScheduleRepeatedTaskCallback = @Sendable (RepeatedTask) throws -> Void + func _scheduleRepeatedTask( initialDelay: TimeAmount, delay: TimeAmount, @@ -1059,7 +962,6 @@ extension EventLoop { return self.scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, notifying: promise, futureTask) } - #if swift(>=5.7) /// Schedule a repeated asynchronous task to be executed by the `EventLoop` with a fixed delay between the end and /// start of each task. /// @@ -1086,8 +988,7 @@ extension EventLoop { ) -> RepeatedTask { self._scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) } - typealias ScheduleRepeatedAsyncTaskCallback = @Sendable (RepeatedTask) -> EventLoopFuture - #else + /// Schedule a repeated asynchronous task to be executed by the `EventLoop` with a fixed delay between the end and /// start of each task. /// @@ -1099,6 +1000,7 @@ extension EventLoop { /// - parameters: /// - initialDelay: The delay after which the first task is executed. /// - delay: The delay between the end of one task and the start of the next. + /// - maximumAllowableJitter: Exclusive upper bound of jitter range added to the `delay` parameter. /// - promise: If non-nil, a promise to fulfill when the task is cancelled and all execution is complete. /// - task: The closure that will be executed. Task will keep repeating regardless of whether the future /// gets fulfilled with success or error. @@ -1108,14 +1010,16 @@ extension EventLoop { public func scheduleRepeatedAsyncTask( initialDelay: TimeAmount, delay: TimeAmount, + maximumAllowableJitter: TimeAmount, notifying promise: EventLoopPromise? = nil, - _ task: @escaping (RepeatedTask) -> EventLoopFuture + _ task: @escaping @Sendable (RepeatedTask) -> EventLoopFuture ) -> RepeatedTask { - self._scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) + let jitteredInitialDelay = Self._getJitteredDelay(delay: initialDelay, maximumAllowableJitter: maximumAllowableJitter) + let jitteredDelay = Self._getJitteredDelay(delay: delay, maximumAllowableJitter: maximumAllowableJitter) + return self._scheduleRepeatedAsyncTask(initialDelay: jitteredInitialDelay, delay: jitteredDelay, notifying: promise, task) } - typealias ScheduleRepeatedAsyncTaskCallback = (RepeatedTask) -> EventLoopFuture - #endif - + typealias ScheduleRepeatedAsyncTaskCallback = @Sendable (RepeatedTask) -> EventLoopFuture + func _scheduleRepeatedAsyncTask( initialDelay: TimeAmount, delay: TimeAmount, @@ -1126,7 +1030,21 @@ extension EventLoop { repeated.begin(in: initialDelay) return repeated } - + + /// Adds a random amount of `.nanoseconds` (within `.zero.. TimeAmount { + let jitter = TimeAmount.nanoseconds(Int64.random(in: .zero.. EventLoop - #if swift(>=5.7) /// Shuts down the eventloop gracefully. This function is clearly an outlier in that it uses a completion /// callback instead of an EventLoopFuture. The reason for that is that NIO's EventLoopFutures will call back on an event loop. /// The virtue of this function is to shut the event loop down. To work around that we call back on a DispatchQueue /// instead. @preconcurrency func shutdownGracefully(queue: DispatchQueue, _ callback: @Sendable @escaping (Error?) -> Void) - #else - /// Shuts down the eventloop gracefully. This function is clearly an outlier in that it uses a completion - /// callback instead of an EventLoopFuture. The reason for that is that NIO's EventLoopFutures will call back on an event loop. - /// The virtue of this function is to shut the event loop down. To work around that we call back on a DispatchQueue - /// instead. - func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) - #endif - + /// Returns an `EventLoopIterator` over the `EventLoop`s in this `EventLoopGroup`. /// /// - returns: `EventLoopIterator` @@ -1245,27 +1155,14 @@ extension EventLoopGroup { } extension EventLoopGroup { - #if swift(>=5.7) @preconcurrency public func shutdownGracefully(_ callback: @escaping @Sendable (Error?) -> Void) { self.shutdownGracefully(queue: .global(), callback) } - #else - public func shutdownGracefully(_ callback: @escaping (Error?) -> Void) { - self.shutdownGracefully(queue: .global(), callback) - } - #endif - - #if swift(>=5.7) @available(*, noasync, message: "this can end up blocking the calling thread", renamed: "shutdownGracefully()") public func syncShutdownGracefully() throws { try self._syncShutdownGracefully() } - #else - public func syncShutdownGracefully() throws { - try self._syncShutdownGracefully() - } - #endif private func _syncShutdownGracefully() throws { self._preconditionSafeToSyncShutdown(file: #fileID, line: #line) @@ -1306,9 +1203,7 @@ public enum NIOEventLoopGroupProvider { case createNew } -#if swift(>=5.7) extension NIOEventLoopGroupProvider: Sendable {} -#endif /// Different `Error`s that are specific to `EventLoop` operations / implementations. public enum EventLoopError: Error { diff --git a/Sources/NIOCore/EventLoopFuture.swift b/Sources/NIOCore/EventLoopFuture.swift index 6c79a08ef9..185e00fecc 100644 --- a/Sources/NIOCore/EventLoopFuture.swift +++ b/Sources/NIOCore/EventLoopFuture.swift @@ -24,13 +24,8 @@ import Dispatch /// This eliminates recursion when processing `flatMap()` chains. @usableFromInline internal struct CallbackList { - #if swift(>=5.7) @usableFromInline internal typealias Element = @Sendable () -> CallbackList - #else - @usableFromInline - internal typealias Element = () -> CallbackList - #endif @usableFromInline internal var firstCallback: Optional @usableFromInline @@ -196,7 +191,7 @@ public struct EventLoopPromise { public func fail(_ error: Error) { self._resolve(value: .failure(error)) } - + /// Complete the promise with the passed in `EventLoopFuture`. /// /// This method is equivalent to invoking `future.cascade(to: promise)`, @@ -259,7 +254,6 @@ public struct EventLoopPromise { } } - /// Holder for a result that will be provided later. /// /// Functions that promise to do work asynchronously can return an `EventLoopFuture`. @@ -447,7 +441,6 @@ extension EventLoopFuture: Equatable { // 'flatMap' and 'map' implementations. This is really the key of the entire system. extension EventLoopFuture { - #if swift(>=5.7) /// When the current `EventLoopFuture` is fulfilled, run the provided callback, /// which will provide a new `EventLoopFuture`. /// @@ -481,41 +474,7 @@ extension EventLoopFuture { self._flatMap(callback) } @usableFromInline typealias FlatMapCallback = @Sendable (Value) -> EventLoopFuture - #else - /// When the current `EventLoopFuture` is fulfilled, run the provided callback, - /// which will provide a new `EventLoopFuture`. - /// - /// This allows you to dynamically dispatch new asynchronous tasks as phases in a - /// longer series of processing steps. Note that you can use the results of the - /// current `EventLoopFuture` when determining how to dispatch the next operation. - /// - /// This works well when you have APIs that already know how to return `EventLoopFuture`s. - /// You can do something with the result of one and just return the next future: - /// - /// ``` - /// let d1 = networkRequest(args).future() - /// let d2 = d1.flatMap { t -> EventLoopFuture in - /// . . . something with t . . . - /// return netWorkRequest(args) - /// } - /// d2.whenSuccess { u in - /// NSLog("Result of second request: \(u)") - /// } - /// ``` - /// - /// Note: In a sense, the `EventLoopFuture` is returned before it's created. - /// - /// - parameters: - /// - callback: Function that will receive the value of this `EventLoopFuture` and return - /// a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value. - @inlinable - public func flatMap(_ callback: @escaping (Value) -> EventLoopFuture) -> EventLoopFuture { - self._flatMap(callback) - } - @usableFromInline typealias FlatMapCallback = (Value) -> EventLoopFuture - #endif - + @inlinable func _flatMap(_ callback: @escaping FlatMapCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -537,8 +496,7 @@ extension EventLoopFuture { } return next.futureResult } - - #if swift(>=5.7) + /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which /// performs a synchronous computation and returns a new value of type `NewValue`. The provided /// callback may optionally `throw`. @@ -559,28 +517,7 @@ extension EventLoopFuture { self._flatMapThrowing(callback) } @usableFromInline typealias FlatMapThrowingCallback = @Sendable (Value) throws -> NewValue - #else - /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which - /// performs a synchronous computation and returns a new value of type `NewValue`. The provided - /// callback may optionally `throw`. - /// - /// Operations performed in `flatMapThrowing` should not block, or they will block the entire - /// event loop. `flatMapThrowing` is intended for use when you have a data-driven function that - /// performs a simple data transformation that can potentially error. - /// - /// If your callback function throws, the returned `EventLoopFuture` will error. - /// - /// - parameters: - /// - callback: Function that will receive the value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value. - @inlinable - public func flatMapThrowing(_ callback: @escaping (Value) throws -> NewValue) -> EventLoopFuture { - self._flatMapThrowing(callback) - } - @usableFromInline typealias FlatMapThrowingCallback = (Value) throws -> NewValue - #endif - + @inlinable func _flatMapThrowing(_ callback: @escaping FlatMapThrowingCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -599,8 +536,7 @@ extension EventLoopFuture { } return next.futureResult } - - #if swift(>=5.7) + /// When the current `EventLoopFuture` is in an error state, run the provided callback, which /// may recover from the error and returns a new value of type `Value`. The provided callback may optionally `throw`, /// in which case the `EventLoopFuture` will be in a failed state with the new thrown error. @@ -621,28 +557,7 @@ extension EventLoopFuture { self._flatMapErrorThrowing(callback) } @usableFromInline typealias FlatMapErrorThrowingCallback = @Sendable (Error) throws -> Value - #else - /// When the current `EventLoopFuture` is in an error state, run the provided callback, which - /// may recover from the error and returns a new value of type `Value`. The provided callback may optionally `throw`, - /// in which case the `EventLoopFuture` will be in a failed state with the new thrown error. - /// - /// Operations performed in `flatMapErrorThrowing` should not block, or they will block the entire - /// event loop. `flatMapErrorThrowing` is intended for use when you have the ability to synchronously - /// recover from errors. - /// - /// If your callback function throws, the returned `EventLoopFuture` will error. - /// - /// - parameters: - /// - callback: Function that will receive the error value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value or a rethrown error. - @inlinable - public func flatMapErrorThrowing(_ callback: @escaping (Error) throws -> Value) -> EventLoopFuture { - self._flatMapErrorThrowing(callback) - } - @usableFromInline typealias FlatMapErrorThrowingCallback = (Error) throws -> Value - #endif - + @inlinable func _flatMapErrorThrowing(_ callback: @escaping FlatMapErrorThrowingCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -662,7 +577,6 @@ extension EventLoopFuture { return next.futureResult } - #if swift(>=5.7) /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which /// performs a synchronous computation and returns a new value of type `NewValue`. /// @@ -695,48 +609,11 @@ extension EventLoopFuture { self._map(callback) } @usableFromInline typealias MapCallback = @Sendable (Value) -> (NewValue) - #else - /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which - /// performs a synchronous computation and returns a new value of type `NewValue`. - /// - /// Operations performed in `map` should not block, or they will block the entire event - /// loop. `map` is intended for use when you have a data-driven function that performs - /// a simple data transformation that cannot error. - /// - /// If you have a data-driven function that can throw, you should use `flatMapThrowing` - /// instead. - /// - /// ``` - /// let future1 = eventually() - /// let future2 = future1.map { T -> U in - /// ... stuff ... - /// return u - /// } - /// let future3 = future2.map { U -> V in - /// ... stuff ... - /// return v - /// } - /// ``` - /// - /// - parameters: - /// - callback: Function that will receive the value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value. - @inlinable - public func map(_ callback: @escaping (Value) -> (NewValue)) -> EventLoopFuture { - self._map(callback) - } - @usableFromInline typealias MapCallback = (Value) -> (NewValue) - #endif - + @inlinable func _map(_ callback: @escaping MapCallback) -> EventLoopFuture { if NewValue.self == Value.self && NewValue.self == Void.self { - #if swift(>=5.7) self.whenSuccess(callback as! @Sendable (Value) -> Void) - #else - self.whenSuccess(callback as! (Value) -> Void) - #endif return self as! EventLoopFuture } else { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -747,7 +624,6 @@ extension EventLoopFuture { } } - #if swift(>=5.7) /// When the current `EventLoopFuture` is in an error state, run the provided callback, which /// may recover from the error by returning an `EventLoopFuture`. The callback is intended to potentially /// recover from the error by returning a new `EventLoopFuture` that will eventually contain the recovered @@ -765,25 +641,7 @@ extension EventLoopFuture { self._flatMapError(callback) } @usableFromInline typealias FlatMapErrorCallback = @Sendable (Error) -> EventLoopFuture - #else - /// When the current `EventLoopFuture` is in an error state, run the provided callback, which - /// may recover from the error by returning an `EventLoopFuture`. The callback is intended to potentially - /// recover from the error by returning a new `EventLoopFuture` that will eventually contain the recovered - /// result. - /// - /// If the callback cannot recover it should return a failed `EventLoopFuture`. - /// - /// - parameters: - /// - callback: Function that will receive the error value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the recovered value. - @inlinable - public func flatMapError(_ callback: @escaping (Error) -> EventLoopFuture) -> EventLoopFuture { - self._flatMapError(callback) - } - @usableFromInline typealias FlatMapErrorCallback = (Error) -> EventLoopFuture - #endif - + @inlinable func _flatMapError(_ callback: @escaping FlatMapErrorCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -806,7 +664,6 @@ extension EventLoopFuture { return next.futureResult } - #if swift(>=5.7) /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which /// performs a synchronous computation and returns either a new value (of type `NewValue`) or /// an error depending on the `Result` returned by the closure. @@ -826,27 +683,7 @@ extension EventLoopFuture { self._flatMapResult(body) } @usableFromInline typealias FlatMapResultCallback = @Sendable (Value) -> Result - #else - /// When the current `EventLoopFuture` is fulfilled, run the provided callback, which - /// performs a synchronous computation and returns either a new value (of type `NewValue`) or - /// an error depending on the `Result` returned by the closure. - /// - /// Operations performed in `flatMapResult` should not block, or they will block the entire - /// event loop. `flatMapResult` is intended for use when you have a data-driven function that - /// performs a simple data transformation that can potentially error. - /// - /// - /// - parameters: - /// - body: Function that will receive the value of this `EventLoopFuture` and return - /// a new value or error lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the eventual value. - @inlinable - public func flatMapResult(_ body: @escaping (Value) -> Result) -> EventLoopFuture { - self._flatMapResult(body) - } - @usableFromInline typealias FlatMapResultCallback = (Value) -> Result - #endif - + @inlinable func _flatMapResult(_ body: @escaping FlatMapResultCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -866,7 +703,6 @@ extension EventLoopFuture { return next.futureResult } - #if swift(>=5.7) /// When the current `EventLoopFuture` is in an error state, run the provided callback, which /// can recover from the error and return a new value of type `Value`. The provided callback may not `throw`, /// so this function should be used when the error is always recoverable. @@ -885,26 +721,7 @@ extension EventLoopFuture { self._recover(callback) } @usableFromInline typealias RecoverCallback = @Sendable (Error) -> Value - #else - /// When the current `EventLoopFuture` is in an error state, run the provided callback, which - /// can recover from the error and return a new value of type `Value`. The provided callback may not `throw`, - /// so this function should be used when the error is always recoverable. - /// - /// Operations performed in `recover` should not block, or they will block the entire - /// event loop. `recover` is intended for use when you have the ability to synchronously - /// recover from errors. - /// - /// - parameters: - /// - callback: Function that will receive the error value of this `EventLoopFuture` and return - /// a new value lifted into a new `EventLoopFuture`. - /// - returns: A future that will receive the recovered value. - @inlinable - public func recover(_ callback: @escaping (Error) -> Value) -> EventLoopFuture { - self._recover(callback) - } - @usableFromInline typealias RecoverCallback = (Error) -> Value - #endif - + @inlinable func _recover(_ callback: @escaping RecoverCallback) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) @@ -919,11 +736,7 @@ extension EventLoopFuture { return next.futureResult } - #if swift(>=5.7) @usableFromInline typealias AddCallbackCallback = @Sendable () -> CallbackList - #else - @usableFromInline typealias AddCallbackCallback = () -> CallbackList - #endif /// Add a callback. If there's already a value, invoke it and return the resulting list of new callback functions. @inlinable internal func _addCallback(_ callback: @escaping AddCallbackCallback) -> CallbackList { @@ -934,8 +747,7 @@ extension EventLoopFuture { } return callback() } - - #if swift(>=5.7) + /// Add a callback. If there's already a value, run as much of the chain as we can. @inlinable @preconcurrency // TODO: We want to remove @preconcurrency but it results in more allocations in 1000_udpconnections @@ -943,15 +755,7 @@ extension EventLoopFuture { self._internalWhenComplete(callback) } @usableFromInline typealias InternalWhenCompleteCallback = @Sendable () -> CallbackList - #else - /// Add a callback. If there's already a value, run as much of the chain as we can. - @inlinable - internal func _whenComplete(_ callback: @escaping () -> CallbackList) { - self._internalWhenComplete(callback) - } - @usableFromInline typealias InternalWhenCompleteCallback = () -> CallbackList - #endif - + /// Add a callback. If there's already a value, run as much of the chain as we can. @inlinable internal func _internalWhenComplete(_ callback: @escaping InternalWhenCompleteCallback) { @@ -964,7 +768,6 @@ extension EventLoopFuture { } } - #if swift(>=5.7) /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has a success result. /// @@ -981,24 +784,7 @@ extension EventLoopFuture { self._whenSuccess(callback) } @usableFromInline typealias WhenSuccessCallback = @Sendable (Value) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has a success result. - /// - /// An observer callback cannot return a value, meaning that this function cannot be chained - /// from. If you are attempting to create a computation pipeline, consider `map` or `flatMap`. - /// If you find yourself passing the results from this `EventLoopFuture` to a new `EventLoopPromise` - /// in the body of this function, consider using `cascade` instead. - /// - /// - parameters: - /// - callback: The callback that is called with the successful result of the `EventLoopFuture`. - @inlinable - public func whenSuccess(_ callback: @escaping (Value) -> Void) { - self._whenSuccess(callback) - } - @usableFromInline typealias WhenSuccessCallback = (Value) -> Void - #endif - + @inlinable func _whenSuccess(_ callback: @escaping WhenSuccessCallback) { self._whenComplete { @@ -1008,8 +794,7 @@ extension EventLoopFuture { return CallbackList() } } - - #if swift(>=5.7) + /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has a failure result. /// @@ -1026,24 +811,7 @@ extension EventLoopFuture { self._whenFailure(callback) } @usableFromInline typealias WhenFailureCallback = @Sendable (Error) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has a failure result. - /// - /// An observer callback cannot return a value, meaning that this function cannot be chained - /// from. If you are attempting to create a computation pipeline, consider `recover` or `flatMapError`. - /// If you find yourself passing the results from this `EventLoopFuture` to a new `EventLoopPromise` - /// in the body of this function, consider using `cascade` instead. - /// - /// - parameters: - /// - callback: The callback that is called with the failed result of the `EventLoopFuture`. - @inlinable - public func whenFailure(_ callback: @escaping (Error) -> Void) { - self._whenFailure(callback) - } - @usableFromInline typealias WhenFailureCallback = (Error) -> Void - #endif - + @inlinable func _whenFailure(_ callback: @escaping WhenFailureCallback) { self._whenComplete { @@ -1054,7 +822,6 @@ extension EventLoopFuture { } } - #if swift(>=5.7) /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has any result. /// @@ -1066,18 +833,6 @@ extension EventLoopFuture { self._publicWhenComplete(callback) } @usableFromInline typealias WhenCompleteCallback = @Sendable (Result) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has any result. - /// - /// - parameters: - /// - callback: The callback that is called when the `EventLoopFuture` is fulfilled. - @inlinable - public func whenComplete(_ callback: @escaping (Result) -> Void) { - self._publicWhenComplete(callback) - } - @usableFromInline typealias WhenCompleteCallback = (Result) -> Void - #endif @inlinable func _publicWhenComplete(_ callback: @escaping WhenCompleteCallback) { self._whenComplete { @@ -1223,7 +978,6 @@ extension EventLoopFuture { // MARK: wait extension EventLoopFuture { - #if swift(>=5.7) /// Wait for the resolution of this `EventLoopFuture` by blocking the current thread until it /// resolves. /// @@ -1242,25 +996,6 @@ extension EventLoopFuture { public func wait(file: StaticString = #file, line: UInt = #line) throws -> Value { return try self._wait(file: file, line: line) } - #else - /// Wait for the resolution of this `EventLoopFuture` by blocking the current thread until it - /// resolves. - /// - /// If the `EventLoopFuture` resolves with a value, that value is returned from `wait()`. If - /// the `EventLoopFuture` resolves with an error, that error will be thrown instead. - /// `wait()` will block whatever thread it is called on, so it must not be called on event loop - /// threads: it is primarily useful for testing, or for building interfaces between blocking - /// and non-blocking code. - /// - /// This is also forbidden in async contexts: prefer ``EventLoopFuture/get``. - /// - /// - returns: The value of the `EventLoopFuture` when it completes. - /// - throws: The error value of the `EventLoopFuture` if it errors. - @inlinable - public func wait(file: StaticString = #file, line: UInt = #line) throws -> Value { - return try self._wait(file: file, line: line) - } - #endif @inlinable func _wait(file: StaticString, line: UInt) throws -> Value { @@ -1289,7 +1024,6 @@ extension EventLoopFuture { // MARK: fold extension EventLoopFuture { - #if swift(>=5.7) /// Returns a new `EventLoopFuture` that fires only when this `EventLoopFuture` and /// all the provided `futures` complete. It then provides the result of folding the value of this /// `EventLoopFuture` with the values of all the provided `futures`. @@ -1315,33 +1049,7 @@ extension EventLoopFuture { self._fold(futures, with: combiningFunction) } @usableFromInline typealias FoldCallback = @Sendable (Value, OtherValue) -> EventLoopFuture - #else - /// Returns a new `EventLoopFuture` that fires only when this `EventLoopFuture` and - /// all the provided `futures` complete. It then provides the result of folding the value of this - /// `EventLoopFuture` with the values of all the provided `futures`. - /// - /// This function is suited when you have APIs that already know how to return `EventLoopFuture`s. - /// - /// The returned `EventLoopFuture` will fail as soon as the a failure is encountered in any of the - /// `futures` (or in this one). However, the failure will not occur until all preceding - /// `EventLoopFutures` have completed. At the point the failure is encountered, all subsequent - /// `EventLoopFuture` objects will no longer be waited for. This function therefore fails fast: once - /// a failure is encountered, it will immediately fail the overall EventLoopFuture. - /// - /// - parameters: - /// - futures: An array of `EventLoopFuture` to wait for. - /// - with: A function that will be used to fold the values of two `EventLoopFuture`s and return a new value wrapped in an `EventLoopFuture`. - /// - returns: A new `EventLoopFuture` with the folded value whose callbacks run on `self.eventLoop`. - @inlinable - public func fold( - _ futures: [EventLoopFuture], - with combiningFunction: @escaping (Value, OtherValue) -> EventLoopFuture - ) -> EventLoopFuture { - self._fold(futures, with: combiningFunction) - } - @usableFromInline typealias FoldCallback = (Value, OtherValue) -> EventLoopFuture - #endif - + @inlinable func _fold( _ futures: [EventLoopFuture], @@ -1375,7 +1083,6 @@ extension EventLoopFuture { // MARK: reduce extension EventLoopFuture { - #if swift(>=5.7) /// Returns a new `EventLoopFuture` that fires only when all the provided futures complete. /// The new `EventLoopFuture` contains the result of reducing the `initialResult` with the /// values of the `[EventLoopFuture]`. @@ -1406,38 +1113,7 @@ extension EventLoopFuture { Self._reduce(initialResult, futures, on: eventLoop, nextPartialResult) } @usableFromInline typealias ReduceCallback = @Sendable (Value, InputValue) -> Value - #else - /// Returns a new `EventLoopFuture` that fires only when all the provided futures complete. - /// The new `EventLoopFuture` contains the result of reducing the `initialResult` with the - /// values of the `[EventLoopFuture]`. - /// - /// This function makes copies of the result for each EventLoopFuture, for a version which avoids - /// making copies, check out `reduce(into:)`. - /// - /// The returned `EventLoopFuture` will fail as soon as a failure is encountered in any of the - /// `futures`. However, the failure will not occur until all preceding - /// `EventLoopFutures` have completed. At the point the failure is encountered, all subsequent - /// `EventLoopFuture` objects will no longer be waited for. This function therefore fails fast: once - /// a failure is encountered, it will immediately fail the overall `EventLoopFuture`. - /// - /// - parameters: - /// - initialResult: An initial result to begin the reduction. - /// - futures: An array of `EventLoopFuture` to wait for. - /// - eventLoop: The `EventLoop` on which the new `EventLoopFuture` callbacks will fire. - /// - nextPartialResult: The bifunction used to produce partial results. - /// - returns: A new `EventLoopFuture` with the reduced value. - @inlinable - public static func reduce( - _ initialResult: Value, - _ futures: [EventLoopFuture], - on eventLoop: EventLoop, - _ nextPartialResult: @escaping (Value, InputValue) -> Value - ) -> EventLoopFuture { - Self._reduce(initialResult, futures, on: eventLoop, nextPartialResult) - } - @usableFromInline typealias ReduceCallback = (Value, InputValue) -> Value - #endif - + @inlinable static func _reduce( _ initialResult: Value, @@ -1454,7 +1130,6 @@ extension EventLoopFuture { return body } - #if swift(>=5.7) /// Returns a new `EventLoopFuture` that fires only when all the provided futures complete. /// The new `EventLoopFuture` contains the result of combining the `initialResult` with the /// values of the `[EventLoopFuture]`. This function is analogous to the standard library's @@ -1483,36 +1158,7 @@ extension EventLoopFuture { Self._reduce(into: initialResult, futures, on: eventLoop, updateAccumulatingResult) } @usableFromInline typealias ReduceIntoCallback = @Sendable (inout Value, InputValue) -> Void - #else - /// Returns a new `EventLoopFuture` that fires only when all the provided futures complete. - /// The new `EventLoopFuture` contains the result of combining the `initialResult` with the - /// values of the `[EventLoopFuture]`. This function is analogous to the standard library's - /// `reduce(into:)`, which does not make copies of the result type for each `EventLoopFuture`. - /// - /// The returned `EventLoopFuture` will fail as soon as a failure is encountered in any of the - /// `futures`. However, the failure will not occur until all preceding - /// `EventLoopFutures` have completed. At the point the failure is encountered, all subsequent - /// `EventLoopFuture` objects will no longer be waited for. This function therefore fails fast: once - /// a failure is encountered, it will immediately fail the overall `EventLoopFuture`. - /// - /// - parameters: - /// - initialResult: An initial result to begin the reduction. - /// - futures: An array of `EventLoopFuture` to wait for. - /// - eventLoop: The `EventLoop` on which the new `EventLoopFuture` callbacks will fire. - /// - updateAccumulatingResult: The bifunction used to combine partialResults with new elements. - /// - returns: A new `EventLoopFuture` with the combined value. - @inlinable - public static func reduce( - into initialResult: Value, - _ futures: [EventLoopFuture], - on eventLoop: EventLoop, - _ updateAccumulatingResult: @escaping (inout Value, InputValue) -> Void - ) -> EventLoopFuture { - Self._reduce(into: initialResult, futures, on: eventLoop, updateAccumulatingResult) - } - @usableFromInline typealias ReduceIntoCallback = (inout Value, InputValue) -> Void - #endif - + @inlinable static func _reduce( into initialResult: Value, @@ -1611,15 +1257,9 @@ extension EventLoopFuture { let reduced = eventLoop.makePromise(of: Void.self) let results: UnsafeMutableTransferBox<[Value?]> = .init(.init(repeating: nil, count: futures.count)) - #if swift(>=5.7) let callback = { @Sendable (index: Int, result: Value) in results.wrappedValue[index] = result } - #else - let callback = { (index: Int, result: Value) in - results.wrappedValue[index] = result - } - #endif if eventLoop.inEventLoop { self._reduceSuccesses0(reduced, futures, eventLoop, onValue: callback) @@ -1640,12 +1280,8 @@ extension EventLoopFuture { } } } - - #if swift(>=5.7) + @usableFromInline typealias ReduceSuccessCallback = @Sendable (Int, InputValue) -> Void - #else - @usableFromInline typealias ReduceSuccessCallback = (Int, InputValue) -> Void - #endif /// Loops through the futures array and attaches callbacks to execute `onValue` on the provided `EventLoop` when /// they succeed. The `onValue` will receive the index of the future that fulfilled the provided `Result`. /// @@ -1773,17 +1409,11 @@ extension EventLoopFuture { promise: EventLoopPromise<[Result]>) { let eventLoop = promise.futureResult.eventLoop let reduced = eventLoop.makePromise(of: Void.self) - + let results: UnsafeMutableTransferBox<[Result]> = .init(.init(repeating: .failure(OperationPlaceholderError()), count: futures.count)) - #if swift(>=5.7) let callback = { @Sendable (index: Int, result: Result) in results.wrappedValue[index] = result } - #else - let callback = { (index: Int, result: Result) in - results.wrappedValue[index] = result - } - #endif if eventLoop.inEventLoop { self._reduceCompletions0(reduced, futures, eventLoop, onResult: callback) @@ -1808,13 +1438,9 @@ extension EventLoopFuture { } } } - - #if swift(>=5.7) + @usableFromInline typealias ReduceCompletions = @Sendable (Int, Result) -> Void - #else - @usableFromInline typealias ReduceCompletions = (Int, Result) -> Void - #endif - + /// Loops through the futures array and attaches callbacks to execute `onResult` on the provided `EventLoop` when /// they complete. The `onResult` will receive the index of the future that fulfilled the provided `Result`. /// @@ -1889,7 +1515,6 @@ extension EventLoopFuture { // MARK: always extension EventLoopFuture { - #if swift(>=5.7) /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has any result. /// @@ -1902,20 +1527,7 @@ extension EventLoopFuture { self._always(callback) } @usableFromInline typealias AlwaysCallback = @Sendable (Result) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has any result. - /// - /// - parameters: - /// - callback: the callback that is called when the `EventLoopFuture` is fulfilled. - /// - returns: the current `EventLoopFuture` - @inlinable - public func always(_ callback: @escaping (Result) -> Void) -> EventLoopFuture { - self._always(callback) - } - @usableFromInline typealias AlwaysCallback = (Result) -> Void - #endif - + @inlinable func _always(_ callback: @escaping AlwaysCallback) -> EventLoopFuture { self.whenComplete { result in callback(result) } @@ -1974,8 +1586,7 @@ extension EventLoopFuture { return value } } - - #if swift(>=5.7) + /// Unwrap an `EventLoopFuture` where its type parameter is an `Optional`. /// /// Unwraps a future returning a new `EventLoopFuture` with either: the value returned by the closure passed in @@ -1998,30 +1609,7 @@ extension EventLoopFuture { self._unwrap(orElse: callback) } @usableFromInline typealias UnwrapCallback = @Sendable () -> NewValue - #else - /// Unwrap an `EventLoopFuture` where its type parameter is an `Optional`. - /// - /// Unwraps a future returning a new `EventLoopFuture` with either: the value returned by the closure passed in - /// the `orElse` parameter when the future resolved with value Optional.none, or the same value otherwise. For example: - /// ``` - /// var x = 2 - /// promise.futureResult.unwrap(orElse: { x * 2 }).wait() - /// ``` - /// - /// - parameters: - /// - orElse: a closure that returns the value of the returned `EventLoopFuture` when then resolved future's value - /// is `Optional.some()`. - /// - returns: an new `EventLoopFuture` with new type parameter `NewValue` and with the value returned by the closure - /// passed in the `orElse` parameter. - @inlinable - public func unwrap( - orElse callback: @escaping () -> NewValue - ) -> EventLoopFuture where Value == Optional { - self._unwrap(orElse: callback) - } - @usableFromInline typealias UnwrapCallback = () -> NewValue - #endif - + @inlinable func _unwrap( orElse callback: @escaping UnwrapCallback @@ -2038,7 +1626,6 @@ extension EventLoopFuture { // MARK: may block extension EventLoopFuture { - #if swift(>=5.7) /// Chain an `EventLoopFuture` providing the result of a IO / task that may block. For example: /// /// promise.futureResult.flatMapBlocking(onto: DispatchQueue.global()) { value in Int @@ -2058,27 +1645,7 @@ extension EventLoopFuture { self._flatMapBlocking(onto: queue, callbackMayBlock) } @usableFromInline typealias FlatMapBlockingCallback = @Sendable (Value) throws -> NewValue - #else - /// Chain an `EventLoopFuture` providing the result of a IO / task that may block. For example: - /// - /// promise.futureResult.flatMapBlocking(onto: DispatchQueue.global()) { value in Int - /// blockingTask(value) - /// } - /// - /// - parameters: - /// - onto: the `DispatchQueue` on which the blocking IO / task specified by `callbackMayBlock` is scheduled. - /// - callbackMayBlock: Function that will receive the value of this `EventLoopFuture` and return - /// a new `EventLoopFuture`. - @inlinable - public func flatMapBlocking( - onto queue: DispatchQueue, - _ callbackMayBlock: @escaping (Value) throws -> NewValue - ) -> EventLoopFuture { - self._flatMapBlocking(onto: queue, callbackMayBlock) - } - @usableFromInline typealias FlatMapBlockingCallback = (Value) throws -> NewValue - #endif - + @inlinable func _flatMapBlocking( onto queue: DispatchQueue, @@ -2088,7 +1655,7 @@ extension EventLoopFuture { queue.asyncWithFuture(eventLoop: self.eventLoop) { try callbackMayBlock(result) } } } - + /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has a success result. The observer callback is permitted to block. /// @@ -2106,8 +1673,7 @@ extension EventLoopFuture { queue.async { callbackMayBlock(value) } } } - - #if swift(>=5.7) + /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has a failure result. The observer callback is permitted to block. /// @@ -2125,34 +1691,14 @@ extension EventLoopFuture { self._whenFailureBlocking(onto: queue, callbackMayBlock) } @usableFromInline typealias WhenFailureBlockingCallback = @Sendable (Error) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has a failure result. The observer callback is permitted to block. - /// - /// An observer callback cannot return a value, meaning that this function cannot be chained - /// from. If you are attempting to create a computation pipeline, consider `recover` or `flatMapError`. - /// If you find yourself passing the results from this `EventLoopFuture` to a new `EventLoopPromise` - /// in the body of this function, consider using `cascade` instead. - /// - /// - parameters: - /// - onto: the `DispatchQueue` on which the blocking IO / task specified by `callbackMayBlock` is scheduled. - /// - callbackMayBlock: The callback that is called with the failed result of the `EventLoopFuture`. - @inlinable - public func whenFailureBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping (Error) -> Void) { - self._whenFailureBlocking(onto: queue, callbackMayBlock) - } - @usableFromInline typealias WhenFailureBlockingCallback = (Error) -> Void - #endif - + @inlinable func _whenFailureBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping WhenFailureBlockingCallback) { self.whenFailure { err in queue.async { callbackMayBlock(err) } } } - - #if swift(>=5.7) /// Adds an observer callback to this `EventLoopFuture` that is called when the /// `EventLoopFuture` has any result. The observer callback is permitted to block. /// @@ -2165,20 +1711,7 @@ extension EventLoopFuture { self._whenCompleteBlocking(onto: queue, callbackMayBlock) } @usableFromInline typealias WhenCompleteBlocking = @Sendable (Result) -> Void - #else - /// Adds an observer callback to this `EventLoopFuture` that is called when the - /// `EventLoopFuture` has any result. The observer callback is permitted to block. - /// - /// - parameters: - /// - onto: the `DispatchQueue` on which the blocking IO / task specified by `callbackMayBlock` is scheduled. - /// - callbackMayBlock: The callback that is called when the `EventLoopFuture` is fulfilled. - @inlinable - public func whenCompleteBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping (Result) -> Void) { - self._whenCompleteBlocking(onto: queue, callbackMayBlock) - } - @usableFromInline typealias WhenCompleteBlocking = (Result) -> Void - #endif - + @inlinable func _whenCompleteBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping WhenCompleteBlocking) { self.whenComplete { value in @@ -2283,7 +1816,6 @@ public struct _NIOEventLoopFutureIdentifier: Hashable, Sendable { self.opaqueID = _NIOEventLoopFutureIdentifier.obfuscatePointerValue(future: future) } - private static func obfuscatePointerValue(future: EventLoopFuture) -> UInt { // Note: // 1. 0xbf15ca5d is randomly picked such that it fits into both 32 and 64 bit address spaces diff --git a/Sources/NIOCore/IO.swift b/Sources/NIOCore/IO.swift index 4377154317..3fdd7a10f9 100644 --- a/Sources/NIOCore/IO.swift +++ b/Sources/NIOCore/IO.swift @@ -28,8 +28,10 @@ import typealias WinSDK.WORD internal func MAKELANGID(_ p: WORD, _ s: WORD) -> DWORD { return DWORD((s << 10) | p) } -#elseif os(Linux) || os(Android) +#elseif canImport(Glibc) import Glibc +#elseif canImport(Musl) +import Musl #elseif canImport(Darwin) import Darwin #else diff --git a/Sources/NIOCore/Interfaces.swift b/Sources/NIOCore/Interfaces.swift index 1c656db020..5fcbbf83ab 100644 --- a/Sources/NIOCore/Interfaces.swift +++ b/Sources/NIOCore/Interfaces.swift @@ -123,7 +123,7 @@ public final class NIONetworkInterface { } #else internal init?(_ caddr: ifaddrs) { - self.name = String(cString: caddr.ifa_name) + self.name = String(cString: caddr.ifa_name!) guard caddr.ifa_addr != nil else { return nil @@ -414,7 +414,7 @@ extension NIONetworkDevice { } #else internal init?(_ caddr: ifaddrs) { - self.name = String(cString: caddr.ifa_name) + self.name = String(cString: caddr.ifa_name!) self.address = caddr.ifa_addr.flatMap { $0.convert() } self.netmask = caddr.ifa_netmask.flatMap { $0.convert() } diff --git a/Sources/NIOCore/SingleStepByteToMessageDecoder.swift b/Sources/NIOCore/SingleStepByteToMessageDecoder.swift index 5cda5a5794..2341a0550e 100644 --- a/Sources/NIOCore/SingleStepByteToMessageDecoder.swift +++ b/Sources/NIOCore/SingleStepByteToMessageDecoder.swift @@ -274,10 +274,8 @@ public final class NIOSingleStepByteToMessageProcessor=5.7) @available(*, unavailable) extension NIOSingleStepByteToMessageProcessor: Sendable {} -#endif // MARK: NIOSingleStepByteToMessageProcessor Public API extension NIOSingleStepByteToMessageProcessor { diff --git a/Sources/NIOCore/SystemCallHelpers.swift b/Sources/NIOCore/SystemCallHelpers.swift index dc9b457a19..b74092a139 100644 --- a/Sources/NIOCore/SystemCallHelpers.swift +++ b/Sources/NIOCore/SystemCallHelpers.swift @@ -43,11 +43,16 @@ private let sysOpenWithMode: @convention(c) (UnsafePointer, CInt, NIOPOSI private let sysLseek: @convention(c) (CInt, off_t, CInt) -> off_t = lseek private let sysRead: @convention(c) (CInt, UnsafeMutableRawPointer?, size_t) -> size_t = read #endif -private let sysIfNameToIndex: @convention(c) (UnsafePointer?) -> CUnsignedInt = if_nametoindex +#if os(Android) +private let sysIfNameToIndex: @convention(c) (UnsafePointer) -> CUnsignedInt = if_nametoindex +private let sysGetifaddrs: @convention(c) (UnsafeMutablePointer?>) -> CInt = getifaddrs +#else +private let sysIfNameToIndex: @convention(c) (UnsafePointer?) -> CUnsignedInt = if_nametoindex #if !os(Windows) private let sysGetifaddrs: @convention(c) (UnsafeMutablePointer?>?) -> CInt = getifaddrs #endif +#endif private func isUnacceptableErrno(_ code: Int32) -> Bool { switch code { @@ -173,7 +178,7 @@ enum SystemCalls { @inline(never) internal static func if_nametoindex(_ name: UnsafePointer?) throws -> CUnsignedInt { return try syscall(blocking: false) { - sysIfNameToIndex(name) + sysIfNameToIndex(name!) }.result } diff --git a/Sources/NIOCore/UniversalBootstrapSupport.swift b/Sources/NIOCore/UniversalBootstrapSupport.swift index 01ab959e1a..b89594a201 100644 --- a/Sources/NIOCore/UniversalBootstrapSupport.swift +++ b/Sources/NIOCore/UniversalBootstrapSupport.swift @@ -15,7 +15,6 @@ /// `NIOClientTCPBootstrapProtocol` is implemented by various underlying transport mechanisms. Typically, /// this will be the BSD Sockets API implemented by `ClientBootstrap`. public protocol NIOClientTCPBootstrapProtocol { - #if swift(>=5.7) /// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -35,28 +34,7 @@ public protocol NIOClientTCPBootstrapProtocol { /// - handler: A closure that initializes the provided `Channel`. @preconcurrency func channelInitializer(_ handler: @escaping @Sendable (Channel) -> EventLoopFuture) -> Self - #else - /// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// The connected `Channel` will operate on `ByteBuffer` as inbound and `IOData` as outbound messages. - /// - /// - warning: The `handler` closure may be invoked _multiple times_ so it's usually the right choice to instantiate - /// `ChannelHandler`s within `handler`. The reason `handler` may be invoked multiple times is that to - /// successfully set up a connection multiple connections might be setup in the process. Assuming a - /// hostname that resolves to both IPv4 and IPv6 addresses, NIO will follow - /// [_Happy Eyeballs_](https://en.wikipedia.org/wiki/Happy_Eyeballs) and race both an IPv4 and an IPv6 - /// connection. It is possible that both connections get fully established before the IPv4 connection - /// will be closed again because the IPv6 connection 'won the race'. Therefore the `channelInitializer` - /// might be called multiple times and it's important not to share stateful `ChannelHandler`s in more - /// than one `Channel`. - /// - /// - parameters: - /// - handler: A closure that initializes the provided `Channel`. - func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture) -> Self - #endif - - #if swift(>=5.7) + /// Sets the protocol handlers that will be added to the front of the `ChannelPipeline` right after the /// `channelInitializer` has been called. /// @@ -65,15 +43,6 @@ public protocol NIOClientTCPBootstrapProtocol { /// `protocolHandlers` to add the required `ChannelHandler`s for many TLS implementations. @preconcurrency func protocolHandlers(_ handlers: @escaping @Sendable () -> [ChannelHandler]) -> Self - #else - /// Sets the protocol handlers that will be added to the front of the `ChannelPipeline` right after the - /// `channelInitializer` has been called. - /// - /// Per bootstrap, you can only set the `protocolHandlers` once. Typically, `protocolHandlers` are used for the TLS - /// implementation. Most notably, `NIOClientTCPBootstrap`, NIO's "universal bootstrap" abstraction, uses - /// `protocolHandlers` to add the required `ChannelHandler`s for many TLS implementations. - func protocolHandlers(_ handlers: @escaping () -> [ChannelHandler]) -> Self - #endif /// Specifies a `ChannelOption` to be applied to the `SocketChannel`. /// @@ -81,7 +50,7 @@ public protocol NIOClientTCPBootstrapProtocol { /// - option: The option to be applied. /// - value: The value for the option. func channelOption(_ option: Option, value: Option.Value) -> Self - + /// Apply any understood convenience options to the bootstrap, removing them from the set of options if they are consumed. /// Method is optional to implement and should never be directly called by users. /// - parameters: @@ -198,7 +167,7 @@ public struct NIOClientTCPBootstrap { self.underlyingBootstrap = bootstrap self.tlsEnablerTypeErased = tlsEnabler } - + internal init(_ original : NIOClientTCPBootstrap, updating underlying : NIOClientTCPBootstrapProtocol) { self.underlyingBootstrap = underlying self.tlsEnablerTypeErased = original.tlsEnablerTypeErased diff --git a/Sources/NIOCrashTester/CrashTests+EventLoop.swift b/Sources/NIOCrashTester/CrashTests+EventLoop.swift index 658d5f36ae..7ce7f4a3ef 100644 --- a/Sources/NIOCrashTester/CrashTests+EventLoop.swift +++ b/Sources/NIOCrashTester/CrashTests+EventLoop.swift @@ -11,7 +11,9 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + #if !canImport(Darwin) || os(macOS) +import Dispatch import NIOCore import NIOPosix @@ -177,5 +179,34 @@ struct EventLoopCrashTests { ) { NIOSingletons.groupLoopCountSuggestion = -1 } + + #if compiler(>=5.9) // We only support Concurrency executor take-over on 5.9+ + let testInstallingSingletonMTELGAsConcurrencyExecutorWorksButOnlyOnce = CrashTest( + regex: #"Fatal error: Must be called only once"# + ) { + guard NIOSingletons.unsafeTryInstallSingletonPosixEventLoopGroupAsConcurrencyGlobalExecutor() else { + print("Installation failed, that's unexpected -> let's not crash") + return + } + + // Yes, this pattern is bad abuse but this is a crash test, we don't mind. + let semaphoreAbuse = DispatchSemaphore(value: 0) + if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) { + Task { + precondition(MultiThreadedEventLoopGroup.currentEventLoop != nil) + try await Task.sleep(nanoseconds: 123) + precondition(MultiThreadedEventLoopGroup.currentEventLoop != nil) + semaphoreAbuse.signal() + } + } else { + semaphoreAbuse.signal() + } + semaphoreAbuse.wait() + print("Okay, worked") + + // This should crash + _ = NIOSingletons.unsafeTryInstallSingletonPosixEventLoopGroupAsConcurrencyGlobalExecutor() + } + #endif // compiler(>=5.9) } -#endif +#endif // !canImport(Darwin) || os(macOS) diff --git a/Sources/NIOCrashTester/OutputGrepper.swift b/Sources/NIOCrashTester/OutputGrepper.swift index ee99921bd1..a226a38bf7 100644 --- a/Sources/NIOCrashTester/OutputGrepper.swift +++ b/Sources/NIOCrashTester/OutputGrepper.swift @@ -22,7 +22,6 @@ internal struct OutputGrepper { internal static func make(group: EventLoopGroup) -> OutputGrepper { let processToChannel = Pipe() - let deadPipe = Pipe() // just so we have an output... let eventLoop = group.next() let outputPromise = eventLoop.makePromise(of: ProgramOutput.self) @@ -34,13 +33,10 @@ internal struct OutputGrepper { channel.pipeline.addHandlers([ByteToMessageHandler(NewlineFramer()), GrepHandler(promise: outputPromise)]) } - .takingOwnershipOfDescriptors(input: dup(processToChannel.fileHandleForReading.fileDescriptor), - output: dup(deadPipe.fileHandleForWriting.fileDescriptor)) + .takingOwnershipOfDescriptor(input: dup(processToChannel.fileHandleForReading.fileDescriptor)) let processOutputPipe = NIOFileHandle(descriptor: dup(processToChannel.fileHandleForWriting.fileDescriptor)) processToChannel.fileHandleForReading.closeFile() processToChannel.fileHandleForWriting.closeFile() - deadPipe.fileHandleForReading.closeFile() - deadPipe.fileHandleForWriting.closeFile() channelFuture.cascadeFailure(to: outputPromise) return OutputGrepper(result: outputPromise.futureResult, processOutputPipe: processOutputPipe) diff --git a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift index 271707ad83..0ee8a685dc 100644 --- a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift +++ b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift @@ -375,9 +375,23 @@ extension ByteBufferAllocator { } // MARK: - Conformances +#if swift(>=5.8) +#if $RetroactiveAttribute +extension ByteBufferView: @retroactive ContiguousBytes {} +extension ByteBufferView: @retroactive DataProtocol {} +extension ByteBufferView: @retroactive MutableDataProtocol {} +#else extension ByteBufferView: ContiguousBytes {} +extension ByteBufferView: DataProtocol {} +extension ByteBufferView: MutableDataProtocol {} +#endif +#else +extension ByteBufferView: ContiguousBytes {} +extension ByteBufferView: DataProtocol {} +extension ByteBufferView: MutableDataProtocol {} +#endif -extension ByteBufferView: DataProtocol { +extension ByteBufferView { public typealias Regions = CollectionOfOne public var regions: CollectionOfOne { @@ -385,8 +399,6 @@ extension ByteBufferView: DataProtocol { } } -extension ByteBufferView: MutableDataProtocol {} - // MARK: - Data extension Data { diff --git a/Sources/NIOHTTP1/HTTPEncoder.swift b/Sources/NIOHTTP1/HTTPEncoder.swift index da3fa33a69..c79283ef87 100644 --- a/Sources/NIOHTTP1/HTTPEncoder.swift +++ b/Sources/NIOHTTP1/HTTPEncoder.swift @@ -15,26 +15,24 @@ import NIOCore private func writeChunk(wrapOutboundOut: (IOData) -> NIOAny, context: ChannelHandlerContext, isChunked: Bool, chunk: IOData, promise: EventLoopPromise?) { - let (mW1, mW2, mW3): (EventLoopPromise?, EventLoopPromise?, EventLoopPromise?) - - switch (isChunked, promise) { - case (true, .some(let p)): - /* chunked encoding and the user's interested: we need three promises and need to cascade into the users promise */ - let (w1, w2, w3) = (context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise) - w1.futureResult.and(w2.futureResult).and(w3.futureResult).map { (_: ((((), ()), ()))) in }.cascade(to: p) - (mW1, mW2, mW3) = (w1, w2, w3) - case (false, .some(let p)): - /* not chunked, so just use the user's promise for the actual data */ - (mW1, mW2, mW3) = (nil, p, nil) - case (_, .none): - /* user isn't interested, let's not bother even allocating promises */ - (mW1, mW2, mW3) = (nil, nil, nil) - } - let readableBytes = chunk.readableBytes - /* we don't want to copy the chunk unnecessarily and therefore call write an annoyingly large number of times */ - if isChunked { + // we don't want to copy the chunk unnecessarily and therefore call write an annoyingly large number of times + // we also don't frame empty chunks as they would otherwise end the response stream + // we still need to write the empty IODate to complete the promise in the right order but is otherwise a no-op. + if isChunked && readableBytes > 0 { + let (mW1, mW2, mW3): (EventLoopPromise?, EventLoopPromise?, EventLoopPromise?) + + if let p = promise { + /* chunked encoding and the user's interested: we need three promises and need to cascade into the users promise */ + let (w1, w2, w3) = (context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise) + w1.futureResult.and(w2.futureResult).and(w3.futureResult).map { (_: ((((), ()), ()))) in }.cascade(to: p) + (mW1, mW2, mW3) = (w1, w2, w3) + } else { + /* user isn't interested, let's not bother even allocating promises */ + (mW1, mW2, mW3) = (nil, nil, nil) + } + var buffer = context.channel.allocator.buffer(capacity: 32) let len = String(readableBytes, radix: 16) buffer.writeString(len) @@ -47,7 +45,7 @@ private func writeChunk(wrapOutboundOut: (IOData) -> NIOAny, context: ChannelHan buffer.moveReaderIndex(forwardBy: buffer.readableBytes - 2) context.write(wrapOutboundOut(.byteBuffer(buffer)), promise: mW3) } else { - context.write(wrapOutboundOut(chunk), promise: mW2) + context.write(wrapOutboundOut(chunk), promise: promise) } } @@ -195,12 +193,6 @@ public final class HTTPRequestEncoder: ChannelOutboundHandler, RemovableChannelH buffer.write(request: request) }, context: context, headers: request.headers, promise: promise) case .body(let bodyPart): - guard bodyPart.readableBytes > 0 else { - // Empty writes shouldn't send any bytes in chunked or identity encoding. - context.write(self.wrapOutboundOut(bodyPart), promise: promise) - return - } - writeChunk(wrapOutboundOut: self.wrapOutboundOut, context: context, isChunked: self.isChunked, chunk: bodyPart, promise: promise) case .end(let trailers): writeTrailers(wrapOutboundOut: self.wrapOutboundOut, context: context, isChunked: self.isChunked, trailers: trailers, promise: promise) diff --git a/Sources/NIOHTTP1/HTTPPipelineSetup.swift b/Sources/NIOHTTP1/HTTPPipelineSetup.swift index 251143f956..c433665afc 100644 --- a/Sources/NIOHTTP1/HTTPPipelineSetup.swift +++ b/Sources/NIOHTTP1/HTTPPipelineSetup.swift @@ -14,19 +14,11 @@ import NIOCore -#if swift(>=5.7) /// Configuration required to configure a HTTP client pipeline for upgrade. /// /// See the documentation for `HTTPClientUpgradeHandler` for details on these /// properties. public typealias NIOHTTPClientUpgradeConfiguration = (upgraders: [NIOHTTPClientProtocolUpgrader], completionHandler: @Sendable (ChannelHandlerContext) -> Void) -#else -/// Configuration required to configure a HTTP client pipeline for upgrade. -/// -/// See the documentation for `HTTPClientUpgradeHandler` for details on these -/// properties. -public typealias NIOHTTPClientUpgradeConfiguration = (upgraders: [NIOHTTPClientProtocolUpgrader], completionHandler: (ChannelHandlerContext) -> Void) -#endif /// Configuration required to configure a HTTP server pipeline for upgrade. /// @@ -35,11 +27,7 @@ public typealias NIOHTTPClientUpgradeConfiguration = (upgraders: [NIOHTTPClientP @available(*, deprecated, renamed: "NIOHTTPServerUpgradeConfiguration") public typealias HTTPUpgradeConfiguration = NIOHTTPServerUpgradeConfiguration -#if swift(>=5.7) public typealias NIOHTTPServerUpgradeConfiguration = (upgraders: [HTTPServerProtocolUpgrader], completionHandler: @Sendable (ChannelHandlerContext) -> Void) -#else -public typealias NIOHTTPServerUpgradeConfiguration = (upgraders: [HTTPServerProtocolUpgrader], completionHandler: (ChannelHandlerContext) -> Void) -#endif extension ChannelPipeline { /// Configure a `ChannelPipeline` for use as a HTTP client. @@ -56,7 +44,6 @@ extension ChannelPipeline { withClientUpgrade: nil) } - #if swift(>=5.7) /// Configure a `ChannelPipeline` for use as a HTTP client with a client upgrader configuration. /// /// - parameters: @@ -78,29 +65,7 @@ extension ChannelPipeline { withClientUpgrade: upgrade ) } - #else - /// Configure a `ChannelPipeline` for use as a HTTP client with a client upgrader configuration. - /// - /// - parameters: - /// - position: The position in the `ChannelPipeline` where to add the HTTP client handlers. Defaults to `.last`. - /// - leftOverBytesStrategy: The strategy to use when dealing with leftover bytes after removing the `HTTPDecoder` - /// from the pipeline. - /// - upgrade: Add a `HTTPClientUpgradeHandler` to the pipeline, configured for - /// HTTP upgrade. Should be a tuple of an array of `HTTPClientProtocolUpgrader` and - /// the upgrade completion handler. See the documentation on `HTTPClientUpgradeHandler` - /// for more details. - /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. - public func addHTTPClientHandlers(position: Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration?) -> EventLoopFuture { - self._addHTTPClientHandlers( - position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - withClientUpgrade: upgrade - ) - } - #endif - + private func _addHTTPClientHandlers(position: Position = .last, leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration?) -> EventLoopFuture { @@ -206,7 +171,6 @@ extension ChannelPipeline { return future } - #if swift(>=5.7) /// Configure a `ChannelPipeline` for use as a HTTP server. /// /// This function knows how to set up all first-party HTTP channel handlers appropriately @@ -244,44 +208,6 @@ extension ChannelPipeline { withErrorHandling: errorHandling ) } - #else - /// Configure a `ChannelPipeline` for use as a HTTP server. - /// - /// This function knows how to set up all first-party HTTP channel handlers appropriately - /// for server use. It supports the following features: - /// - /// 1. Providing assistance handling clients that pipeline HTTP requests, using the - /// `HTTPServerPipelineHandler`. - /// 2. Supporting HTTP upgrade, using the `HTTPServerUpgradeHandler`. - /// - /// This method will likely be extended in future with more support for other first-party - /// features. - /// - /// - parameters: - /// - position: Where in the pipeline to add the HTTP server handlers, defaults to `.last`. - /// - pipelining: Whether to provide assistance handling HTTP clients that pipeline - /// their requests. Defaults to `true`. If `false`, users will need to handle - /// clients that pipeline themselves. - /// - upgrade: Whether to add a `HTTPServerUpgradeHandler` to the pipeline, configured for - /// HTTP upgrade. Defaults to `nil`, which will not add the handler to the pipeline. If - /// provided should be a tuple of an array of `HTTPServerProtocolUpgrader` and the upgrade - /// completion handler. See the documentation on `HTTPServerUpgradeHandler` for more - /// details. - /// - errorHandling: Whether to provide assistance handling protocol errors (e.g. - /// failure to parse the HTTP request) by sending 400 errors. Defaults to `true`. - /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. - public func configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true) -> EventLoopFuture { - self._configureHTTPServerPipeline( - position: position, - withPipeliningAssistance: pipelining, - withServerUpgrade: upgrade, - withErrorHandling: errorHandling - ) - } - #endif /// Configure a `ChannelPipeline` for use as a HTTP server. /// @@ -410,7 +336,6 @@ extension ChannelPipeline { } extension ChannelPipeline.SynchronousOperations { - #if swift(>=5.7) /// Configure a `ChannelPipeline` for use as a HTTP client with a client upgrader configuration. /// /// - important: This **must** be called on the Channel's event loop. @@ -433,29 +358,6 @@ extension ChannelPipeline.SynchronousOperations { withClientUpgrade: upgrade ) } - #else - /// Configure a `ChannelPipeline` for use as a HTTP client with a client upgrader configuration. - /// - /// - important: This **must** be called on the Channel's event loop. - /// - parameters: - /// - position: The position in the `ChannelPipeline` where to add the HTTP client handlers. Defaults to `.last`. - /// - leftOverBytesStrategy: The strategy to use when dealing with leftover bytes after removing the `HTTPDecoder` - /// from the pipeline. - /// - upgrade: Add a `HTTPClientUpgradeHandler` to the pipeline, configured for - /// HTTP upgrade. Should be a tuple of an array of `HTTPClientProtocolUpgrader` and - /// the upgrade completion handler. See the documentation on `HTTPClientUpgradeHandler` - /// for more details. - /// - throws: If the pipeline could not be configured. - public func addHTTPClientHandlers(position: ChannelPipeline.Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil) throws { - try self._addHTTPClientHandlers( - position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - withClientUpgrade: upgrade - ) - } - #endif /// Configure a `ChannelPipeline` for use as a HTTP client. /// @@ -558,7 +460,6 @@ extension ChannelPipeline.SynchronousOperations { try self.addHandlers(handlers, position: position) } - #if swift(>=5.7) /// Configure a `ChannelPipeline` for use as a HTTP server. /// /// This function knows how to set up all first-party HTTP channel handlers appropriately @@ -597,45 +498,6 @@ extension ChannelPipeline.SynchronousOperations { withErrorHandling: errorHandling ) } - #else - /// Configure a `ChannelPipeline` for use as a HTTP server. - /// - /// This function knows how to set up all first-party HTTP channel handlers appropriately - /// for server use. It supports the following features: - /// - /// 1. Providing assistance handling clients that pipeline HTTP requests, using the - /// `HTTPServerPipelineHandler`. - /// 2. Supporting HTTP upgrade, using the `HTTPServerUpgradeHandler`. - /// - /// This method will likely be extended in future with more support for other first-party - /// features. - /// - /// - important: This **must** be called on the Channel's event loop. - /// - parameters: - /// - position: Where in the pipeline to add the HTTP server handlers, defaults to `.last`. - /// - pipelining: Whether to provide assistance handling HTTP clients that pipeline - /// their requests. Defaults to `true`. If `false`, users will need to handle - /// clients that pipeline themselves. - /// - upgrade: Whether to add a `HTTPServerUpgradeHandler` to the pipeline, configured for - /// HTTP upgrade. Defaults to `nil`, which will not add the handler to the pipeline. If - /// provided should be a tuple of an array of `HTTPServerProtocolUpgrader` and the upgrade - /// completion handler. See the documentation on `HTTPServerUpgradeHandler` for more - /// details. - /// - errorHandling: Whether to provide assistance handling protocol errors (e.g. - /// failure to parse the HTTP request) by sending 400 errors. Defaults to `true`. - /// - throws: If the pipeline could not be configured. - public func configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true) throws { - try self._configureHTTPServerPipeline( - position: position, - withPipeliningAssistance: pipelining, - withServerUpgrade: upgrade, - withErrorHandling: errorHandling - ) - } - #endif /// Configure a `ChannelPipeline` for use as a HTTP server. /// diff --git a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift index 489eef8954..1e68af848b 100644 --- a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift +++ b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift @@ -13,15 +13,6 @@ //===----------------------------------------------------------------------===// import NIOCore -/// A utility function that runs the body code only in debug builds, without -/// emitting compiler warnings. -/// -/// This is currently the only way to do this in Swift: see -/// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. -internal func debugOnly(_ body: () -> Void) { - assert({ body(); return true }()) -} - /// A `ChannelHandler` that handles HTTP pipelining by buffering inbound data until a /// response has been sent. /// @@ -56,6 +47,10 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha public typealias OutboundIn = HTTPServerResponsePart public typealias OutboundOut = HTTPServerResponsePart + // If this is true AND we're in a debug build, crash the program when an invariant in violated + // Otherwise, we will try to handle the situation as cleanly as possible + internal var failOnPreconditions: Bool = true + public init() { self.nextExpectedInboundMessage = nil self.nextExpectedOutboundMessage = nil @@ -66,6 +61,59 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } } + private enum ConnectionStateAction { + /// A precondition has been violated. Should send an error down the pipeline + case warnPreconditionViolated(message: String) + + /// A further state change was attempted when a precondition has already been violated. + /// Should force close this connection + case forceCloseConnection + + /// Nothing to do + case none + } + + public struct ConnectionStateError: Error, CustomStringConvertible, Hashable { + enum Base: Hashable, CustomStringConvertible { + /// A precondition was violated + case preconditionViolated(message: String) + + var description: String { + switch self { + case .preconditionViolated(let message): + return "Precondition violated \(message)" + } + } + } + + private var base: Base + private var file: String + private var line: Int + + private init(base: Base, file: String, line: Int) { + self.base = base + self.file = file + self.line = line + } + + public static func ==(lhs: ConnectionStateError, rhs: ConnectionStateError) -> Bool { + lhs.base == rhs.base + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(self.base) + } + + /// A precondition was violated + public static func preconditionViolated(message: String, file: String = #fileID, line: Int = #line) -> Self { + .init(base: .preconditionViolated(message: message), file: file, line: line) + } + + public var description: String { + "\(self.base) file \(self.file) line \(self.line)" + } + } + /// The state of the HTTP connection. private enum ConnectionState { /// We are waiting for a HTTP response to complete before we @@ -92,53 +140,76 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha /// never suppress reads again. case sentCloseOutput - mutating func requestHeadReceived() { + /// The user has violated an invariant. We should refuse further IO now + case preconditionFailed + + mutating func requestHeadReceived() -> ConnectionStateAction { switch self { + case .preconditionFailed: + return .forceCloseConnection case .idle: self = .requestAndResponseEndPending + return .none case .requestAndResponseEndPending, .responseEndPending, .requestEndPending, .sentCloseOutputRequestEndPending, .sentCloseOutput: - preconditionFailure("received request head in state \(self)") + let message = "received request head in state \(self)" + self = .preconditionFailed + return .warnPreconditionViolated(message: message) } } - mutating func responseEndReceived() { + mutating func responseEndReceived() -> ConnectionStateAction { switch self { + case .preconditionFailed: + return .forceCloseConnection case .responseEndPending: // Got the response we were waiting for. self = .idle + return .none case .requestAndResponseEndPending: // We got a response while still receiving a request, which we have to // wait for. self = .requestEndPending + return .none case .sentCloseOutput, .sentCloseOutputRequestEndPending: // This is a user error: they have sent close(mode: .output), but are continuing to write. // The write will fail, so we can allow it to pass. - () + return .none case .requestEndPending, .idle: - preconditionFailure("Unexpectedly received a response in state \(self)") + let message = "Unexpectedly received a response in state \(self)" + self = .preconditionFailed + return .warnPreconditionViolated(message: message) } } - mutating func requestEndReceived() { + mutating func requestEndReceived() -> ConnectionStateAction { switch self { + case .preconditionFailed: + return .forceCloseConnection case .requestEndPending: // Got the request end we were waiting for. self = .idle + return .none case .requestAndResponseEndPending: // We got a request and the response isn't done, wait for the // response. self = .responseEndPending + return .none case .sentCloseOutputRequestEndPending: // Got the request end we were waiting for. self = .sentCloseOutput + return .none case .responseEndPending, .idle, .sentCloseOutput: - preconditionFailure("Received second request") + let message = "Received second request" + self = .preconditionFailed + return .warnPreconditionViolated(message: message) } } mutating func closeOutputSent() { switch self { + case .preconditionFailed: + break case .idle, .responseEndPending: self = .sentCloseOutput case .requestEndPending, .requestAndResponseEndPending: @@ -221,35 +292,37 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha self.eventBuffer.append(.channelRead(data)) return } else { - self.deliverOneMessage(context: context, data: data) + let connectionStateAction = self.deliverOneMessage(context: context, data: data) + _ = self.handleConnectionStateAction(context: context, action: connectionStateAction, promise: nil) } } - private func deliverOneMessage(context: ChannelHandlerContext, data: NIOAny) { - assert(self.lifecycleState != .quiescingLastRequestEndReceived && - self.lifecycleState != .quiescingCompleted, - "deliverOneMessage called in lifecycle illegal state \(self.lifecycleState)") + private func deliverOneMessage(context: ChannelHandlerContext, data: NIOAny) -> ConnectionStateAction { + self.checkAssertion(self.lifecycleState != .quiescingLastRequestEndReceived && + self.lifecycleState != .quiescingCompleted, + "deliverOneMessage called in lifecycle illegal state \(self.lifecycleState)") let msg = self.unwrapInboundIn(data) debugOnly { switch msg { case .head: - assert(self.nextExpectedInboundMessage == .head) + self.checkAssertion(self.nextExpectedInboundMessage == .head) self.nextExpectedInboundMessage = .bodyOrEnd case .body: - assert(self.nextExpectedInboundMessage == .bodyOrEnd) + self.checkAssertion(self.nextExpectedInboundMessage == .bodyOrEnd) case .end: - assert(self.nextExpectedInboundMessage == .bodyOrEnd) + self.checkAssertion(self.nextExpectedInboundMessage == .bodyOrEnd) self.nextExpectedInboundMessage = .head } } + let action: ConnectionStateAction switch msg { case .head: - self.state.requestHeadReceived() + action = self.state.requestHeadReceived() case .end: // New request is complete. We don't want any more data from now on. - self.state.requestEndReceived() + action = self.state.requestEndReceived() if self.lifecycleState == .quiescingWaitingForRequestEnd { self.lifecycleState = .quiescingLastRequestEndReceived @@ -260,10 +333,11 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha context.close(promise: nil) } case .body: - () + action = .none } context.fireChannelRead(data) + return action } private func deliverOneError(context: ChannelHandlerContext, error: Error) { @@ -280,14 +354,16 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { case is ChannelShouldQuiesceEvent: - assert(self.lifecycleState == .acceptingEvents, - "unexpected lifecycle state when receiving ChannelShouldQuiesceEvent: \(self.lifecycleState)") + self.checkAssertion(self.lifecycleState == .acceptingEvents, + "unexpected lifecycle state when receiving ChannelShouldQuiesceEvent: \(self.lifecycleState)") switch self.state { case .responseEndPending: // we're not in the middle of a request, let's just shut the door self.lifecycleState = .quiescingLastRequestEndReceived self.eventBuffer.removeAll() - case .idle, .sentCloseOutput: + case .preconditionFailed, + // An invariant has been violated already, this time we close the connection + .idle, .sentCloseOutput: // we're completely idle, let's just close self.lifecycleState = .quiescingCompleted self.eventBuffer.removeAll() @@ -324,20 +400,20 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - assert(self.state != .requestEndPending, - "Received second response while waiting for first one to complete") + self.checkAssertion(self.state != .requestEndPending, + "Received second response while waiting for first one to complete") debugOnly { let res = self.unwrapOutboundIn(data) switch res { case .head(let head) where head.isInformational: - assert(self.nextExpectedOutboundMessage == .head) + self.checkAssertion(self.nextExpectedOutboundMessage == .head) case .head: - assert(self.nextExpectedOutboundMessage == .head) + self.checkAssertion(self.nextExpectedOutboundMessage == .head) self.nextExpectedOutboundMessage = .bodyOrEnd case .body: - assert(self.nextExpectedOutboundMessage == .bodyOrEnd) + self.checkAssertion(self.nextExpectedOutboundMessage == .bodyOrEnd) case .end: - assert(self.nextExpectedOutboundMessage == .bodyOrEnd) + self.checkAssertion(self.nextExpectedOutboundMessage == .bodyOrEnd) self.nextExpectedOutboundMessage = .head } } @@ -367,7 +443,7 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha case .quiescingCompleted: // Uh, why are we writing more data here? We'll write it, but it should be guaranteed // to fail. - assertionFailure("Wrote in quiescing completed state") + self.assertionFailed("Wrote in quiescing completed state") context.write(data, promise: promise) } case .body, .head: @@ -375,7 +451,10 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } if startReadingAgain { - self.state.responseEndReceived() + let connectionStateAction = self.state.responseEndReceived() + if self.handleConnectionStateAction(context: context, action: connectionStateAction, promise: promise) { + return + } self.deliverPendingRequests(context: context) self.startReading(context: context) } @@ -477,6 +556,30 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } } + /// - Returns: True if an error was sent, ie the caller should not continue + private func handleConnectionStateAction( + context: ChannelHandlerContext, + action: ConnectionStateAction, + promise: EventLoopPromise? + ) -> Bool { + switch action { + case .warnPreconditionViolated(let message): + self.assertionFailed(message) + let error = ConnectionStateError.preconditionViolated(message: message) + self.deliverOneError(context: context, error: error) + promise?.fail(error) + return true + case .forceCloseConnection: + let message = "The connection has been forcefully closed because further IO was attempted after a precondition was violated" + let error = ConnectionStateError.preconditionViolated(message: message) + promise?.fail(error) + self.close(context: context, mode: .all, promise: nil) + return true + case .none: + return false + } + } + /// A response has been sent: we can now start passing reads through /// again if there are no further pending requests, and send any read() /// call we may have swallowed. @@ -503,8 +606,10 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha switch event { case .channelRead(let read): - self.deliverOneMessage(context: context, data: read) - deliveredRead = true + let connectionStateAction = self.deliverOneMessage(context: context, data: read) + if !self.handleConnectionStateAction(context: context, action: connectionStateAction, promise: nil) { + deliveredRead = true + } case .error(let error): self.deliverOneError(context: context, error: error) case .halfClose: @@ -557,6 +662,34 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha self.eventBuffer.removeSubrange(firstHead...) } + + /// A utility function that runs the body code only in debug builds, without + /// emitting compiler warnings. + /// + /// This is currently the only way to do this in Swift: see + /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. + private func debugOnly(_ body: () -> Void) { + self.checkAssertion({ body(); return true }()) + } + + /// Calls assertionFailure if and only if `self.failOnPreconditions` is true. This allows us to avoid terminating the program in tests + private func assertionFailed(_ message: @autoclosure () -> String, file: StaticString = #file, line: UInt = #line) { + if self.failOnPreconditions { + assertionFailure(message(), file: file, line: line) + } + } + + /// Calls assert if and only if `self.failOnPreconditions` is true. This allows us to avoid terminating the program in tests + private func checkAssertion( + _ closure: @autoclosure () -> Bool, + _ message: @autoclosure () -> String = String(), + file: StaticString = #file, + line: UInt = #line + ) { + if self.failOnPreconditions { + assert(closure(), message(), file: file, line: line) + } + } } @available(*, unavailable) diff --git a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift new file mode 100644 index 0000000000..9021062488 --- /dev/null +++ b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift @@ -0,0 +1,250 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || swift(>=5.10) +import NIOCore + +// MARK: - Server pipeline configuration + +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOUpgradableHTTPServerPipelineConfiguration { + /// Whether to provide assistance handling HTTP clients that pipeline + /// their requests. Defaults to `true`. If `false`, users will need to handle clients that pipeline themselves. + public var enablePipelining = true + + /// Whether to provide assistance handling protocol errors (e.g. failure to parse the HTTP + /// request) by sending 400 errors. Defaults to `true`. + public var enableErrorHandling = true + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableResponseHeaderValidation = true + + /// The configuration for the ``HTTPResponseEncoder``. + public var encoderConfiguration = HTTPResponseEncoder.Configuration() + + /// The configuration for the ``NIOTypedHTTPServerUpgradeHandler``. + public var upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPServerPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Assistance handling clients that pipeline HTTP requests. + /// 2. Assistance handling protocol errors. + /// 3. Outbound header fields validation to protect against response splitting attacks. + public init( + upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration + ) { + self.upgradeConfiguration = upgradeConfiguration + } +} + +extension ChannelPipeline { + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPServerPipeline( + configuration: NIOUpgradableHTTPServerPipelineConfiguration + ) -> EventLoopFuture> { + self._configureUpgradableHTTPServerPipeline( + configuration: configuration + ) + } + + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + private func _configureUpgradableHTTPServerPipeline( + configuration: NIOUpgradableHTTPServerPipelineConfiguration + ) -> EventLoopFuture> { + let future: EventLoopFuture> + + if self.eventLoop.inEventLoop { + let result = Result, Error> { + try self.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: configuration + ) + } + future = self.eventLoop.makeCompletedFuture(result) + } else { + future = self.eventLoop.submit { + try self.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: configuration + ) + } + } + + return future + } +} + +extension ChannelPipeline.SynchronousOperations { + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPServerPipeline( + configuration: NIOUpgradableHTTPServerPipelineConfiguration + ) throws -> EventLoopFuture { + self.eventLoop.assertInEventLoop() + + let responseEncoder = HTTPResponseEncoder(configuration: configuration.encoderConfiguration) + let requestDecoder = ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)) + + var extraHTTPHandlers = [RemovableChannelHandler]() + extraHTTPHandlers.reserveCapacity(4) + extraHTTPHandlers.append(requestDecoder) + + try self.addHandler(responseEncoder) + try self.addHandler(requestDecoder) + + if configuration.enablePipelining { + let pipeliningHandler = HTTPServerPipelineHandler() + try self.addHandler(pipeliningHandler) + extraHTTPHandlers.append(pipeliningHandler) + } + + if configuration.enableResponseHeaderValidation { + let headerValidationHandler = NIOHTTPResponseHeadersValidator() + try self.addHandler(headerValidationHandler) + extraHTTPHandlers.append(headerValidationHandler) + } + + if configuration.enableErrorHandling { + let errorHandler = HTTPServerProtocolErrorHandler() + try self.addHandler(errorHandler) + extraHTTPHandlers.append(errorHandler) + } + + let upgrader = NIOTypedHTTPServerUpgradeHandler( + httpEncoder: responseEncoder, + extraHTTPHandlers: extraHTTPHandlers, + upgradeConfiguration: configuration.upgradeConfiguration + ) + try self.addHandler(upgrader) + + return upgrader.upgradeResultFuture + } +} + +// MARK: - Client pipeline configuration + +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOUpgradableHTTPClientPipelineConfiguration { + /// The strategy to use when dealing with leftover bytes after removing the ``HTTPDecoder`` from the pipeline. + public var leftOverBytesStrategy = RemoveAfterUpgradeStrategy.dropBytes + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableOutboundHeaderValidation = true + + /// The configuration for the ``HTTPRequestEncoder``. + public var encoderConfiguration = HTTPRequestEncoder.Configuration() + + /// The configuration for the ``NIOTypedHTTPClientUpgradeHandler``. + public var upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPClientPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Outbound header fields validation to protect against response splitting attacks. + public init( + upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration + ) { + self.upgradeConfiguration = upgradeConfiguration + } +} + +extension ChannelPipeline { + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPClientPipeline( + configuration: NIOUpgradableHTTPClientPipelineConfiguration + ) -> EventLoopFuture> { + self._configureUpgradableHTTPClientPipeline(configuration: configuration) + } + + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + private func _configureUpgradableHTTPClientPipeline( + configuration: NIOUpgradableHTTPClientPipelineConfiguration + ) -> EventLoopFuture> { + let future: EventLoopFuture> + + if self.eventLoop.inEventLoop { + let result = Result, Error> { + try self.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: configuration + ) + } + future = self.eventLoop.makeCompletedFuture(result) + } else { + future = self.eventLoop.submit { + try self.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: configuration + ) + } + } + + return future + } +} + +extension ChannelPipeline.SynchronousOperations { + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPClientPipeline( + configuration: NIOUpgradableHTTPClientPipelineConfiguration + ) throws -> EventLoopFuture { + self.eventLoop.assertInEventLoop() + + let requestEncoder = HTTPRequestEncoder(configuration: configuration.encoderConfiguration) + let responseDecoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: configuration.leftOverBytesStrategy)) + var httpHandlers = [RemovableChannelHandler]() + httpHandlers.reserveCapacity(3) + httpHandlers.append(requestEncoder) + httpHandlers.append(responseDecoder) + + try self.addHandler(requestEncoder) + try self.addHandler(responseDecoder) + + if configuration.enableOutboundHeaderValidation { + let headerValidationHandler = NIOHTTPRequestHeadersValidator() + try self.addHandler(headerValidationHandler) + httpHandlers.append(headerValidationHandler) + } + + let upgrader = NIOTypedHTTPClientUpgradeHandler( + httpHandlers: httpHandlers, + upgradeConfiguration: configuration.upgradeConfiguration + ) + try self.addHandler(upgrader) + + return upgrader.upgradeResultFuture + } +} +#endif diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift new file mode 100644 index 0000000000..c683b61b3e --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift @@ -0,0 +1,285 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2013 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || swift(>=5.10) +import NIOCore + +/// An object that implements `NIOTypedHTTPClientProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a client-side channel. +/// It has the option of denying this upgrade based upon the server response. +public protocol NIOTypedHTTPClientProtocolUpgrader { + associatedtype UpgradeResult: Sendable + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol requires in the request to successfully upgrade. + /// These header fields will be added to the outbound request's "Connection" header field. + /// It is the responsibility of the custom headers call to actually add these required headers. + var requiredUpgradeHeaders: [String] { get } + + /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. + func addCustom(upgradeRequestHeaders: inout HTTPHeaders) + + /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. + func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel + /// pipeline to add whatever channel handlers are required. + /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. + func upgrade(channel: Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPClientUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOTypedHTTPClientUpgradeConfiguration { + /// The initial request head that is sent out once the channel becomes active. + public var upgradeRequestHead: HTTPRequestHead + + /// The array of potential upgraders. + public var upgraders: [any NIOTypedHTTPClientProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + + public init( + upgradeRequestHead: HTTPRequestHead, + upgraders: [any NIOTypedHTTPClientProtocolUpgrader], + notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture + ) { + precondition(upgraders.count > 0, "A minimum of one protocol upgrader must be specified.") + self.upgradeRequestHead = upgradeRequestHead + self.upgraders = upgraders + self.notUpgradingCompletionHandler = notUpgradingCompletionHandler + } +} + +/// A client-side channel handler that sends a HTTP upgrade handshake request to perform a HTTP-upgrade. +/// This handler will add all appropriate headers to perform an upgrade to +/// the a protocol. It may add headers for a set of protocols in preference order. +/// If the upgrade fails (i.e. response is not 101 Switching Protocols), this handler simply +/// removes itself from the pipeline. If the upgrade is successful, it upgrades the pipeline to the new protocol. +/// +/// The request sends an order of preference to request which protocol it would like to use for the upgrade. +/// It will only upgrade to the protocol that is returned first in the list and does not currently +/// have the capability to upgrade to multiple simultaneous layered protocols. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public final class NIOTypedHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableChannelHandler { + public typealias OutboundIn = HTTPClientRequestPart + public typealias OutboundOut = HTTPClientRequestPart + public typealias InboundIn = HTTPClientResponsePart + public typealias InboundOut = HTTPClientResponsePart + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: EventLoopFuture { + self.upgradeResultPromise.futureResult + } + + private let upgradeRequestHead: HTTPRequestHead + private let httpHandlers: [RemovableChannelHandler] + private let notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + private var stateMachine: NIOTypedHTTPClientUpgraderStateMachine + private var _upgradeResultPromise: EventLoopPromise? + private var upgradeResultPromise: EventLoopPromise { + precondition( + self._upgradeResultPromise != nil, + "Tried to access the upgrade result before the handler was added to a pipeline" + ) + return self._upgradeResultPromise! + } + + /// Create a ``NIOTypedHTTPClientUpgradeHandler``. + /// + /// - Parameters: + /// - httpHandlers: All `RemovableChannelHandler` objects which will be removed from the pipeline + /// once the upgrade response is sent. This is used to ensure that the pipeline will be in a clean state + /// after the upgrade. It should include any handlers that are directly related to handling HTTP. + /// At the very least this should include the `HTTPEncoder` and `HTTPDecoder`, but should also include + /// any other handler that cannot tolerate receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init( + httpHandlers: [RemovableChannelHandler], + upgradeConfiguration: NIOTypedHTTPClientUpgradeConfiguration + ) { + self.httpHandlers = httpHandlers + var upgradeRequestHead = upgradeConfiguration.upgradeRequestHead + Self.addHeaders( + to: &upgradeRequestHead, + upgraders: upgradeConfiguration.upgraders + ) + self.upgradeRequestHead = upgradeRequestHead + self.stateMachine = .init(upgraders: upgradeConfiguration.upgraders) + self.notUpgradingCompletionHandler = upgradeConfiguration.notUpgradingCompletionHandler + } + + public func handlerAdded(context: ChannelHandlerContext) { + self._upgradeResultPromise = context.eventLoop.makePromise(of: UpgradeResult.self) + } + + public func handlerRemoved(context: ChannelHandlerContext) { + switch self.stateMachine.handlerRemoved() { + case .failUpgradePromise: + self.upgradeResultPromise.fail(ChannelError.inappropriateOperationForState) + case .none: + break + } + } + + public func channelActive(context: ChannelHandlerContext) { + switch self.stateMachine.channelActive() { + case .writeUpgradeRequest: + context.write(self.wrapOutboundOut(.head(self.upgradeRequestHead)), promise: nil) + context.write(self.wrapOutboundOut(.body(.byteBuffer(.init()))), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + + case .none: + break + } + } + + private static func addHeaders( + to requestHead: inout HTTPRequestHead, + upgraders: [any NIOTypedHTTPClientProtocolUpgrader] + ) { + let requiredHeaders = ["upgrade"] + upgraders.flatMap { $0.requiredUpgradeHeaders } + requestHead.headers.add(name: "Connection", value: requiredHeaders.joined(separator: ",")) + + let allProtocols = upgraders.map { $0.supportedProtocol.lowercased() } + requestHead.headers.add(name: "Upgrade", value: allProtocols.joined(separator: ",")) + + // Allow each upgrader the chance to add custom headers. + for upgrader in upgraders { + upgrader.addCustom(upgradeRequestHeaders: &requestHead.headers) + } + } + + public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + switch self.stateMachine.write() { + case .failWrite(let error): + promise?.fail(error) + + case .forwardWrite: + context.write(data, promise: promise) + } + } + + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.stateMachine.channelReadData(data) { + case .unwrapData: + let responsePart = self.unwrapInboundIn(data) + self.channelRead(context: context, responsePart: responsePart) + + case .fireChannelRead: + context.fireChannelRead(data) + + case .none: + break + } + } + + private func channelRead(context: ChannelHandlerContext, responsePart: HTTPClientResponsePart) { + switch self.stateMachine.channelReadResponsePart(responsePart) { + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .runNotUpgradingInitializer: + self.notUpgradingCompletionHandler(context.channel) + .hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result) + } + + case .startUpgrading(let upgrader, let responseHead): + self.startUpgrading( + context: context, + upgrader: upgrader, + responseHead: responseHead + ) + + case .none: + break + } + } + + private func startUpgrading( + context: ChannelHandlerContext, + upgrader: any NIOTypedHTTPClientProtocolUpgrader, + responseHead: HTTPResponseHead + ) { + // Before we start the upgrade we have to remove the HTTPEncoder and HTTPDecoder handlers from the + // pipeline, to prevent them parsing any more data. We'll buffer the incoming data until that completes. + self.removeHTTPHandlers(context: context) + .flatMap { + upgrader.upgrade(channel: context.channel, upgradeResponse: responseHead) + }.hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result) + } + } + + private func upgradingHandlerCompleted( + context: ChannelHandlerContext, + _ result: Result + ) { + switch self.stateMachine.upgradingHandlerCompleted(result) { + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .fireErrorCaughtAndStartUnbuffering(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + self.unbuffer(context: context) + + case .startUnbuffering(let value): + self.upgradeResultPromise.succeed(value) + self.unbuffer(context: context) + + case .removeHandler(let value): + self.upgradeResultPromise.succeed(value) + context.pipeline.removeHandler(self, promise: nil) + + case .none: + break + } + } + + private func unbuffer(context: ChannelHandlerContext) { + while true { + switch self.stateMachine.unbuffer() { + case .fireChannelRead(let data): + context.fireChannelRead(data) + + case .fireChannelReadCompleteAndRemoveHandler: + context.fireChannelReadComplete() + context.pipeline.removeHandler(self, promise: nil) + return + } + } + } + + /// Removes any extra HTTP-related handlers from the channel pipeline. + private func removeHTTPHandlers(context: ChannelHandlerContext) -> EventLoopFuture { + guard self.httpHandlers.count > 0 else { + return context.eventLoop.makeSucceededFuture(()) + } + + let removeFutures = self.httpHandlers.map { context.pipeline.removeHandler($0) } + return .andAllSucceed(removeFutures, on: context.eventLoop) + } +} +#endif diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift new file mode 100644 index 0000000000..6e9c696811 --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift @@ -0,0 +1,335 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || swift(>=5.10) +import DequeModule +import NIOCore + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +struct NIOTypedHTTPClientUpgraderStateMachine { + @usableFromInline + enum State { + /// The state before we received a TLSUserEvent. We are just forwarding any read at this point. + case initial(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) + + /// The request has been sent. We are waiting for the upgrade response. + case awaitingUpgradeResponseHead(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) + + @usableFromInline + struct AwaitingUpgradeResponseEnd { + var upgrader: any NIOTypedHTTPClientProtocolUpgrader + var responseHead: HTTPResponseHead + } + /// We received the response head and are just waiting for the response end. + case awaitingUpgradeResponseEnd(AwaitingUpgradeResponseEnd) + + @usableFromInline + struct Upgrading { + var buffer: Deque + } + /// We are either running the upgrading handler. + case upgrading(Upgrading) + + @usableFromInline + struct Unbuffering { + var buffer: Deque + } + case unbuffering(Unbuffering) + + case finished + + case modifying + } + + private var state: State + + init(upgraders: [any NIOTypedHTTPClientProtocolUpgrader]) { + self.state = .initial(upgraders: upgraders) + } + + @usableFromInline + enum HandlerRemovedAction { + case failUpgradePromise + } + + @inlinable + mutating func handlerRemoved() -> HandlerRemovedAction? { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading, .unbuffering: + self.state = .finished + return .failUpgradePromise + + case .finished: + return .none + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelActiveAction { + case writeUpgradeRequest + } + + @inlinable + mutating func channelActive() -> ChannelActiveAction? { + switch self.state { + case .initial(let upgraders): + self.state = .awaitingUpgradeResponseHead(upgraders: upgraders) + return .writeUpgradeRequest + + case .finished: + return nil + + case .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .unbuffering, .upgrading: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum WriteAction { + case failWrite(Error) + case forwardWrite + } + + @usableFromInline + func write() -> WriteAction { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading: + return .failWrite(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade) + + case .unbuffering, .finished: + return .forwardWrite + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelReadDataAction { + case unwrapData + case fireChannelRead + } + + @inlinable + mutating func channelReadData(_ data: NIOAny) -> ChannelReadDataAction? { + switch self.state { + case .initial: + return .unwrapData + + case .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd: + return .unwrapData + + case .upgrading(var upgrading): + // We got a read while running upgrading. + // We have to buffer the read to unbuffer it afterwards + self.state = .modifying + upgrading.buffer.append(data) + self.state = .upgrading(upgrading) + return nil + + case .unbuffering(var unbuffering): + self.state = .modifying + unbuffering.buffer.append(data) + self.state = .unbuffering(unbuffering) + return nil + + case .finished: + return .fireChannelRead + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + + @usableFromInline + enum ChannelReadResponsePartAction { + case fireErrorCaughtAndRemoveHandler(Error) + case runNotUpgradingInitializer + case startUpgrading( + upgrader: any NIOTypedHTTPClientProtocolUpgrader, + responseHeaders: HTTPResponseHead + ) + } + + @inlinable + mutating func channelReadResponsePart(_ responsePart: HTTPClientResponsePart) -> ChannelReadResponsePartAction? { + switch self.state { + case .initial: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + case .awaitingUpgradeResponseHead(let upgraders): + // We should decide if we can upgrade based on the first response header: if we aren't upgrading, + // by the time the body comes in we should be out of the pipeline. That means that if we don't think we're + // upgrading, the only thing we should see is a response head. Anything else in an error. + guard case .head(let response) = responsePart else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.invalidHTTPOrdering) + } + + // Assess whether the server has accepted our upgrade request. + guard case .switchingProtocols = response.status else { + var buffer = Deque() + buffer.append(.init(responsePart)) + self.state = .upgrading(.init(buffer: buffer)) + return .runNotUpgradingInitializer + } + + // Ok, we have a HTTP response. Check if it's an upgrade confirmation. + // If it's not, we want to pass it on and remove ourselves from the channel pipeline. + let acceptedProtocols = response.headers[canonicalForm: "upgrade"] + + // At the moment we only upgrade to the first protocol returned from the server. + guard let protocolName = acceptedProtocols.first?.lowercased() else { + // There are no upgrade protocols returned. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + + let matchingUpgrader = upgraders + .first(where: { $0.supportedProtocol.lowercased() == protocolName }) + + guard let upgrader = matchingUpgrader else { + // There is no upgrader for this protocol. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + + guard upgrader.shouldAllowUpgrade(upgradeResponse: response) else { + // The upgrader says no. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + + // We received the response head and decided that we can upgrade. + // We now need to wait for the response end and then we can perform the upgrade + self.state = .awaitingUpgradeResponseEnd(.init( + upgrader: upgrader, + responseHead: response + )) + return .none + + case .awaitingUpgradeResponseEnd(let awaitingUpgradeResponseEnd): + switch responsePart { + case .head: + // We got two HTTP response heads. + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.invalidHTTPOrdering) + + case .body: + // We tolerate body parts to be send but just ignore them + return .none + + case .end: + // We got the response end and can now run the upgrader. + self.state = .upgrading(.init(buffer: .init())) + return .startUpgrading( + upgrader: awaitingUpgradeResponseEnd.upgrader, + responseHeaders: awaitingUpgradeResponseEnd.responseHead + ) + } + + case .upgrading, .unbuffering, .finished: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum UpgradingHandlerCompletedAction { + case fireErrorCaughtAndStartUnbuffering(Error) + case removeHandler(UpgradeResult) + case fireErrorCaughtAndRemoveHandler(Error) + case startUnbuffering(UpgradeResult) + } + + @inlinable + mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .unbuffering: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + case .upgrading(let upgrading): + switch result { + case .success(let value): + if !upgrading.buffer.isEmpty { + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .startUnbuffering(value) + } else { + self.state = .finished + return .removeHandler(value) + } + + case .failure(let error): + if !upgrading.buffer.isEmpty { + // So we failed to upgrade. There is nothing really that we can do here. + // We are unbuffering the reads but there shouldn't be any handler in the pipeline + // that expects a specific type of reads anyhow. + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .fireErrorCaughtAndStartUnbuffering(error) + } else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(error) + } + } + + case .finished: + // We have to tolerate this + return nil + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + } + } + + @usableFromInline + enum UnbufferAction { + case fireChannelRead(NIOAny) + case fireChannelReadCompleteAndRemoveHandler + } + + @inlinable + mutating func unbuffer() -> UnbufferAction { + switch self.state { + case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .upgrading, .finished: + preconditionFailure("Invalid state \(self.state)") + + case .unbuffering(var unbuffering): + self.state = .modifying + + if let element = unbuffering.buffer.popFirst() { + self.state = .unbuffering(unbuffering) + + return .fireChannelRead(element) + } else { + self.state = .finished + + return .fireChannelReadCompleteAndRemoveHandler + } + + case .modifying: + fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") + + } + } +} +#endif diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift new file mode 100644 index 0000000000..b6a90b1294 --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift @@ -0,0 +1,371 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || swift(>=5.10) +import NIOCore + +/// An object that implements `NIOTypedHTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a server-side channel. +public protocol NIOTypedHTTPServerProtocolUpgrader { + associatedtype UpgradeResult: Sendable + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol needs in the request to successfully upgrade. These header fields + /// will be provided to the handler when it is asked to handle the upgrade. They will also be validated + /// against the inbound request's `Connection` header field. + var requiredUpgradeHeaders: [String] { get } + + /// Builds the upgrade response headers. Should return any headers that need to be supplied to the client + /// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should + /// return a failed future. + func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline + /// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received + /// data will be buffered. + func upgrade( + channel: Channel, + upgradeRequest: HTTPRequestHead + ) -> EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPServerUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOTypedHTTPServerUpgradeConfiguration { + /// The array of potential upgraders. + public var upgraders: [any NIOTypedHTTPServerProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + + public init( + upgraders: [any NIOTypedHTTPServerProtocolUpgrader], + notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture + ) { + self.upgraders = upgraders + self.notUpgradingCompletionHandler = notUpgradingCompletionHandler + } +} + +/// A server-side channel handler that receives HTTP requests and optionally performs an HTTP-upgrade. +/// +/// Removes itself from the channel pipeline after the first inbound request on the connection, regardless of +/// whether the upgrade succeeded or not. +/// +/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade +/// on any request but the first on a connection. This is primarily to handle clients that pipeline: it's +/// sufficiently difficult to ensure that the upgrade happens at a safe time while dealing with pipelined +/// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine: +/// the odds of someone needing to upgrade midway through the lifetime of a connection are very low. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public final class NIOTypedHTTPServerUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler { + public typealias InboundIn = HTTPServerRequestPart + public typealias InboundOut = HTTPServerRequestPart + public typealias OutboundOut = HTTPServerResponsePart + + private let upgraders: [String: any NIOTypedHTTPServerProtocolUpgrader] + private let notUpgradingCompletionHandler: @Sendable (Channel) -> EventLoopFuture + private let httpEncoder: HTTPResponseEncoder + private let extraHTTPHandlers: [RemovableChannelHandler] + private var stateMachine = NIOTypedHTTPServerUpgraderStateMachine() + + private var _upgradeResultPromise: EventLoopPromise? + private var upgradeResultPromise: EventLoopPromise { + precondition( + self._upgradeResultPromise != nil, + "Tried to access the upgrade result before the handler was added to a pipeline" + ) + return self._upgradeResultPromise! + } + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: EventLoopFuture { + self.upgradeResultPromise.futureResult + } + + /// Create a ``NIOTypedHTTPServerUpgradeHandler``. + /// + /// - Parameters: + /// - httpEncoder: The ``HTTPResponseEncoder`` encoding responses from this handler and which will + /// be removed from the pipeline once the upgrade response is sent. This is used to ensure + /// that the pipeline will be in a clean state after upgrade. + /// - extraHTTPHandlers: Any other handlers that are directly related to handling HTTP. At the very least + /// this should include the `HTTPDecoder`, but should also include any other handler that cannot tolerate + /// receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init( + httpEncoder: HTTPResponseEncoder, + extraHTTPHandlers: [RemovableChannelHandler], + upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration + ) { + var upgraderMap = [String: any NIOTypedHTTPServerProtocolUpgrader]() + for upgrader in upgradeConfiguration.upgraders { + upgraderMap[upgrader.supportedProtocol.lowercased()] = upgrader + } + self.upgraders = upgraderMap + self.notUpgradingCompletionHandler = upgradeConfiguration.notUpgradingCompletionHandler + self.httpEncoder = httpEncoder + self.extraHTTPHandlers = extraHTTPHandlers + } + + public func handlerAdded(context: ChannelHandlerContext) { + self._upgradeResultPromise = context.eventLoop.makePromise(of: UpgradeResult.self) + } + + public func handlerRemoved(context: ChannelHandlerContext) { + switch self.stateMachine.handlerRemoved() { + case .failUpgradePromise: + self.upgradeResultPromise.fail(ChannelError.inappropriateOperationForState) + case .none: + break + } + } + + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { + switch self.stateMachine.channelReadData(data) { + case .unwrapData: + let requestPart = self.unwrapInboundIn(data) + self.channelRead(context: context, requestPart: requestPart) + + case .fireChannelRead: + context.fireChannelRead(data) + + case .none: + break + } + } + + private func channelRead(context: ChannelHandlerContext, requestPart: HTTPServerRequestPart) { + switch self.stateMachine.channelReadRequestPart(requestPart) { + case .failUpgradePromise(let error): + self.upgradeResultPromise.fail(error) + + case .runNotUpgradingInitializer: + self.notUpgradingCompletionHandler(context.channel) + .hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: nil) + } + + case .findUpgrader(let head, let requestedProtocols, let allHeaderNames, let connectionHeader): + let protocolIterator = requestedProtocols.makeIterator() + self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: head, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ).whenComplete { result in + context.eventLoop.assertInEventLoop() + self.findingUpgradeCompleted(context: context, requestHead: head, result) + } + + case .startUpgrading(let upgrader, let requestHead, let responseHeaders, let proto): + self.startUpgrading( + context: context, + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto + ) + + case .none: + break + } + } + + private func upgradingHandlerCompleted( + context: ChannelHandlerContext, + _ result: Result, + requestHeadAndProtocol: (HTTPRequestHead, String)? + ) { + switch self.stateMachine.upgradingHandlerCompleted(result) { + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .fireErrorCaughtAndStartUnbuffering(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + self.unbuffer(context: context) + + case .startUnbuffering(let value): + if let requestHeadAndProtocol = requestHeadAndProtocol { + context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) + } + self.upgradeResultPromise.succeed(value) + self.unbuffer(context: context) + + case .removeHandler(let value): + if let requestHeadAndProtocol = requestHeadAndProtocol { + context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) + } + self.upgradeResultPromise.succeed(value) + context.pipeline.removeHandler(self, promise: nil) + + case .none: + break + } + } + + /// Attempt to upgrade a single protocol. + /// + /// Will recurse through `protocolIterator` if upgrade fails. + private func handleUpgradeForProtocol( + context: ChannelHandlerContext, + protocolIterator: Array.Iterator, + request: HTTPRequestHead, + allHeaderNames: Set, + connectionHeader: Set + ) -> EventLoopFuture<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?> { + // We want a local copy of the protocol iterator. We'll pass it to the next invocation of the function. + var protocolIterator = protocolIterator + guard let proto = protocolIterator.next() else { + // We're done! No suitable protocol for upgrade. + return context.eventLoop.makeSucceededFuture(nil) + } + + guard let upgrader = self.upgraders[proto.lowercased()] else { + return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + } + + let requiredHeaders = Set(upgrader.requiredUpgradeHeaders.map { $0.lowercased() }) + guard requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader) else { + return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + } + + let responseHeaders = self.buildUpgradeHeaders(protocol: proto) + return upgrader.buildUpgradeResponse( + channel: context.channel, + upgradeRequest: request, + initialResponseHeaders: responseHeaders + ) + .hop(to: context.eventLoop) + .map { (upgrader, $0, proto) } + .flatMapError { error in + // No upgrade here. We want to fire the error down the pipeline, and then try another loop iteration. + context.fireErrorCaught(error) + return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + } + } + + private func findingUpgradeCompleted( + context: ChannelHandlerContext, + requestHead: HTTPRequestHead, + _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> + ) { + switch self.stateMachine.findingUpgraderCompleted(requestHead: requestHead, result) { + case .startUpgrading(let upgrader, let responseHeaders, let proto): + self.startUpgrading( + context: context, + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto + ) + + case .runNotUpgradingInitializer: + self.notUpgradingCompletionHandler(context.channel) + .hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: nil) + } + + case .fireErrorCaughtAndStartUnbuffering(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + self.unbuffer(context: context) + + case .fireErrorCaughtAndRemoveHandler(let error): + self.upgradeResultPromise.fail(error) + context.fireErrorCaught(error) + context.pipeline.removeHandler(self, promise: nil) + + case .none: + break + } + } + + private func startUpgrading( + context: ChannelHandlerContext, + upgrader: any NIOTypedHTTPServerProtocolUpgrader, + requestHead: HTTPRequestHead, + responseHeaders: HTTPHeaders, + proto: String + ) { + // Before we finish the upgrade we have to remove the HTTPDecoder and any other non-Encoder HTTP + // handlers from the pipeline, to prevent them parsing any more data. We'll buffer the data until + // that completes. + // While there are a lot of Futures involved here it's quite possible that all of this code will + // actually complete synchronously: we just want to program for the possibility that it won't. + // Once that's done, we send the upgrade response, then remove the HTTP encoder, then call the + // internal handler, then call the user code, and then finally when the user code is done we do + // our final cleanup steps, namely we replay the received data we buffered in the meantime and + // then remove ourselves from the pipeline. + self.removeExtraHandlers(context: context).flatMap { + self.sendUpgradeResponse(context: context, responseHeaders: responseHeaders) + }.flatMap { + context.pipeline.removeHandler(self.httpEncoder) + }.flatMap { () -> EventLoopFuture in + return upgrader.upgrade(channel: context.channel, upgradeRequest: requestHead) + }.hop(to: context.eventLoop) + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: (requestHead, proto)) + } + } + + /// Sends the 101 Switching Protocols response for the pipeline. + private func sendUpgradeResponse(context: ChannelHandlerContext, responseHeaders: HTTPHeaders) -> EventLoopFuture { + var response = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) + response.headers = responseHeaders + return context.writeAndFlush(wrapOutboundOut(HTTPServerResponsePart.head(response))) + } + + /// Builds the initial mandatory HTTP headers for HTTP upgrade responses. + private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders { + return HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)]) + } + + /// Removes any extra HTTP-related handlers from the channel pipeline. + private func removeExtraHandlers(context: ChannelHandlerContext) -> EventLoopFuture { + guard self.extraHTTPHandlers.count > 0 else { + return context.eventLoop.makeSucceededFuture(()) + } + + return .andAllSucceed(self.extraHTTPHandlers.map { context.pipeline.removeHandler($0) }, + on: context.eventLoop) + } + + private func unbuffer(context: ChannelHandlerContext) { + while true { + switch self.stateMachine.unbuffer() { + case .fireChannelRead(let data): + context.fireChannelRead(data) + + case .fireChannelReadCompleteAndRemoveHandler: + context.fireChannelReadComplete() + context.pipeline.removeHandler(self, promise: nil) + return + } + } + } +} +#endif diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift new file mode 100644 index 0000000000..bc2536f7c8 --- /dev/null +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift @@ -0,0 +1,386 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +#if !canImport(Darwin) || swift(>=5.10) +import DequeModule +import NIOCore + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +struct NIOTypedHTTPServerUpgraderStateMachine { + @usableFromInline + enum State { + /// The state before we received a TLSUserEvent. We are just forwarding any read at this point. + case initial + + @usableFromInline + struct AwaitingUpgrader { + var seenFirstRequest: Bool + var buffer: Deque + } + + /// The request head has been received. We're currently running the future chain awaiting an upgrader. + case awaitingUpgrader(AwaitingUpgrader) + + @usableFromInline + struct UpgraderReady { + var upgrader: any NIOTypedHTTPServerProtocolUpgrader + var requestHead: HTTPRequestHead + var responseHeaders: HTTPHeaders + var proto: String + var buffer: Deque + } + + /// We have an upgrader, which means we can begin upgrade we are just waiting for the request end. + case upgraderReady(UpgraderReady) + + @usableFromInline + struct Upgrading { + var buffer: Deque + } + /// We are either running the upgrading handler. + case upgrading(Upgrading) + + @usableFromInline + struct Unbuffering { + var buffer: Deque + } + case unbuffering(Unbuffering) + + case finished + + case modifying + } + + private var state = State.initial + + @usableFromInline + enum HandlerRemovedAction { + case failUpgradePromise + } + + @inlinable + mutating func handlerRemoved() -> HandlerRemovedAction? { + switch self.state { + case .initial, .awaitingUpgrader, .upgraderReady, .upgrading, .unbuffering: + self.state = .finished + return .failUpgradePromise + + case .finished: + return .none + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelReadDataAction { + case unwrapData + case fireChannelRead + } + + @inlinable + mutating func channelReadData(_ data: NIOAny) -> ChannelReadDataAction? { + switch self.state { + case .initial: + return .unwrapData + + case .awaitingUpgrader(var awaitingUpgrader): + if awaitingUpgrader.seenFirstRequest { + // We should buffer the data since we have seen the full request. + self.state = .modifying + awaitingUpgrader.buffer.append(data) + self.state = .awaitingUpgrader(awaitingUpgrader) + return nil + } else { + // We shouldn't buffer. This means we are still expecting HTTP parts. + return .unwrapData + } + + case .upgraderReady: + // We have not seen the end of the HTTP request so this + // data is probably an HTTP request part. + return .unwrapData + + case .unbuffering(var unbuffering): + self.state = .modifying + unbuffering.buffer.append(data) + self.state = .unbuffering(unbuffering) + return nil + + case .finished: + return .fireChannelRead + + case .upgrading(var upgrading): + // We got a read while running ugprading. + // We have to buffer the read to unbuffer it afterwards + self.state = .modifying + upgrading.buffer.append(data) + self.state = .upgrading(upgrading) + return nil + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum ChannelReadRequestPartAction { + case failUpgradePromise(Error) + case runNotUpgradingInitializer + case startUpgrading( + upgrader: any NIOTypedHTTPServerProtocolUpgrader, + requestHead: HTTPRequestHead, + responseHeaders: HTTPHeaders, + proto: String + ) + case findUpgrader( + head: HTTPRequestHead, + requestedProtocols: [String], + allHeaderNames: Set, + connectionHeader: Set + ) + } + + @inlinable + mutating func channelReadRequestPart(_ requestPart: HTTPServerRequestPart) -> ChannelReadRequestPartAction? { + switch self.state { + case .initial: + guard case .head(let head) = requestPart else { + // The first data that we saw was not a head. This is a protocol error and we are just going to + // fail upgrading + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + } + + // Ok, we have a HTTP head. Check if it's an upgrade. + let requestedProtocols = head.headers[canonicalForm: "upgrade"].map(String.init) + guard requestedProtocols.count > 0 else { + // We have to buffer now since we got the request head but are not upgrading. + // The user is configuring the HTTP pipeline now. + var buffer = Deque() + buffer.append(NIOAny(requestPart)) + self.state = .upgrading(.init(buffer: buffer)) + return .runNotUpgradingInitializer + } + + // We can now transition to awaiting the upgrader. This means that we are trying to + // find an upgrade that can handle requested protocols. We are not buffering because + // we are waiting for the request end. + self.state = .awaitingUpgrader(.init(seenFirstRequest: false, buffer: .init())) + + let connectionHeader = Set(head.headers[canonicalForm: "connection"].map { $0.lowercased() }) + let allHeaderNames = Set(head.headers.map { $0.name.lowercased() }) + + return .findUpgrader( + head: head, + requestedProtocols: requestedProtocols, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) + + case .awaitingUpgrader(let awaitingUpgrader): + switch (awaitingUpgrader.seenFirstRequest, requestPart) { + case (true, _): + // This is weird we are seeing more requests parts after we have seen an end + // Let's fail upgrading + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case (false, .head): + // This is weird we are seeing another head but haven't seen the end for the request before + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case (false, .body): + // This is weird we are seeing body parts for a request that indicated that it wanted + // to upgrade. + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case (false, .end): + // Okay we got the end as expected. Just gotta store this in our state. + self.state = .awaitingUpgrader(.init(seenFirstRequest: true, buffer: awaitingUpgrader.buffer)) + return nil + } + + case .upgraderReady(let upgraderReady): + switch requestPart { + case .head: + // This is weird we are seeing another head but haven't seen the end for the request before + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case .body: + // This is weird we are seeing body parts for a request that indicated that it wanted + // to upgrade. + return .failUpgradePromise(HTTPServerUpgradeErrors.invalidHTTPOrdering) + + case .end: + // Okay we got the end as expected and our upgrader is ready so let's start upgrading + self.state = .upgrading(.init(buffer: upgraderReady.buffer)) + return .startUpgrading( + upgrader: upgraderReady.upgrader, + requestHead: upgraderReady.requestHead, + responseHeaders: upgraderReady.responseHeaders, + proto: upgraderReady.proto + ) + } + + case .upgrading, .unbuffering, .finished: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum UpgradingHandlerCompletedAction { + case fireErrorCaughtAndStartUnbuffering(Error) + case removeHandler(UpgradeResult) + case fireErrorCaughtAndRemoveHandler(Error) + case startUnbuffering(UpgradeResult) + } + + @inlinable + mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { + switch self.state { + case .initial: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .upgrading(let upgrading): + switch result { + case .success(let value): + if !upgrading.buffer.isEmpty { + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .startUnbuffering(value) + } else { + self.state = .finished + return .removeHandler(value) + } + + case .failure(let error): + if !upgrading.buffer.isEmpty { + // So we failed to upgrade. There is nothing really that we can do here. + // We are unbuffering the reads but there shouldn't be any handler in the pipeline + // that expects a specific type of reads anyhow. + self.state = .unbuffering(.init(buffer: upgrading.buffer)) + return .fireErrorCaughtAndStartUnbuffering(error) + } else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(error) + } + } + + case .finished: + // We have to tolerate this + return nil + + case .awaitingUpgrader, .upgraderReady, .unbuffering: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum FindingUpgraderCompletedAction { + case startUpgrading(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String) + case runNotUpgradingInitializer + case fireErrorCaughtAndStartUnbuffering(Error) + case fireErrorCaughtAndRemoveHandler(Error) + } + + @inlinable + mutating func findingUpgraderCompleted( + requestHead: HTTPRequestHead, + _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> + ) -> FindingUpgraderCompletedAction? { + switch self.state { + case .initial, .upgraderReady: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .awaitingUpgrader(let awaitingUpgrader): + switch result { + case .success(.some((let upgrader, let responseHeaders, let proto))): + if awaitingUpgrader.seenFirstRequest { + // We have seen the end of the request. So we can upgrade now. + self.state = .upgrading(.init(buffer: awaitingUpgrader.buffer)) + return .startUpgrading(upgrader: upgrader, responseHeaders: responseHeaders, proto: proto) + } else { + // We have not yet seen the end so we have to wait until that happens + self.state = .upgraderReady(.init( + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto, + buffer: awaitingUpgrader.buffer + )) + return nil + } + + case .success(.none): + // There was no upgrader to handle the request. We just run the not upgrading + // initializer now. + self.state = .upgrading(.init(buffer: awaitingUpgrader.buffer)) + return .runNotUpgradingInitializer + + case .failure(let error): + if !awaitingUpgrader.buffer.isEmpty { + self.state = .unbuffering(.init(buffer: awaitingUpgrader.buffer)) + return .fireErrorCaughtAndStartUnbuffering(error) + } else { + self.state = .finished + return .fireErrorCaughtAndRemoveHandler(error) + } + } + + case .upgrading, .unbuffering, .finished: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + } + } + + @usableFromInline + enum UnbufferAction { + case fireChannelRead(NIOAny) + case fireChannelReadCompleteAndRemoveHandler + } + + @inlinable + mutating func unbuffer() -> UnbufferAction { + switch self.state { + case .initial, .awaitingUpgrader, .upgraderReady, .upgrading, .finished: + preconditionFailure("Invalid state \(self.state)") + + case .unbuffering(var unbuffering): + self.state = .modifying + + if let element = unbuffering.buffer.popFirst() { + self.state = .unbuffering(unbuffering) + + return .fireChannelRead(element) + } else { + self.state = .finished + + return .fireChannelReadCompleteAndRemoveHandler + } + + case .modifying: + fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") + + } + } + +} +#endif diff --git a/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift b/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift index e4e7387a5d..489af7b4a1 100644 --- a/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift +++ b/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift @@ -34,6 +34,7 @@ final class ByteBufferReadWriteMultipleIntegersBenchmark: func run() throws -> Int { var result: I = 0 for _ in 0..: func run() throws -> Int { var result: I = 0 for _ in 0...makeWriter(isWritable: true, delegate: self.delegate) + let newWriter = NIOAsyncWriter.makeWriter(isWritable: true, finishOnDeinit: false, delegate: self.delegate) self.writer = newWriter.writer self.sink = newWriter.sink } func setUp() async throws {} - func tearDown() {} + func tearDown() { + self.writer.finish() + } func run() async throws -> Int { for i in 0..! + private var serverEventLoop: EventLoop! final class ServerHandler: ChannelInboundHandler { public typealias InboundIn = ByteBuffer public typealias OutboundOut = ByteBuffer - private var channel: Channel! + private var context: ChannelHandlerContext! public func channelActive(context: ChannelHandlerContext) { - self.channel = context.channel + self.context = context } public func send(_ message: ByteBuffer, times count: Int) { for _ in 0..? - init(_ benchmark: TCPThroughputBenchmark) { - self.benchmark = benchmark + init() { self.messagesReceived = 0 } + func prepareRun(expectedMessages: Int, promise: EventLoopPromise) { + self.expectedMessages = expectedMessages + self.completionPromise = promise + } + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.messagesReceived += 1 - if (self.benchmark.messages == self.messagesReceived) { - self.benchmark.isDonePromise.succeed() + + if (self.expectedMessages == self.messagesReceived) { + let promise = self.completionPromise + self.messagesReceived = 0 + self.expectedMessages = nil + self.completionPromise = nil + + promise!.succeed() } } } @@ -95,12 +106,12 @@ final class TCPThroughputBenchmark: Benchmark { func setUp() throws { self.group = MultiThreadedEventLoopGroup(numberOfThreads: 4) - let connectionEstablished: EventLoopPromise = self.group.next().makePromise() + let connectionEstablished: EventLoopPromise = self.group.next().makePromise() self.serverChannel = try ServerBootstrap(group: self.group) .childChannelInitializer { channel in self.serverHandler = ServerHandler() - connectionEstablished.succeed() + connectionEstablished.succeed(channel.eventLoop) return channel.pipeline.addHandler(self.serverHandler) } .bind(host: "127.0.0.1", port: 0) @@ -109,7 +120,7 @@ final class TCPThroughputBenchmark: Benchmark { self.clientChannel = try ClientBootstrap(group: group) .channelInitializer { channel in channel.pipeline.addHandler(ByteToMessageHandler(StreamDecoder())).flatMap { _ in - channel.pipeline.addHandler(ClientHandler(self)) + channel.pipeline.addHandler(ClientHandler()) } } .connect(to: serverChannel.localAddress!) @@ -122,7 +133,7 @@ final class TCPThroughputBenchmark: Benchmark { } self.message = message - try connectionEstablished.futureResult.wait() + self.serverEventLoop = try connectionEstablished.futureResult.wait() } func tearDown() { @@ -132,12 +143,22 @@ final class TCPThroughputBenchmark: Benchmark { } func run() throws -> Int { - self.isDonePromise = self.group.next().makePromise() - Task { - self.serverHandler.send(self.message, times: self.messages) + let isDonePromise = self.clientChannel.eventLoop.makePromise(of: Void.self) + let clientChannel = self.clientChannel! + let expectedMessages = self.messages + + try clientChannel.eventLoop.submit { + try clientChannel.pipeline.syncOperations.handler(type: ClientHandler.self).prepareRun(expectedMessages: expectedMessages, promise: isDonePromise) + }.wait() + + let serverHandler = self.serverHandler! + let message = self.message! + let messages = self.messages + + self.serverEventLoop.execute { + serverHandler.send(message, times: messages) } - try self.isDonePromise.futureResult.wait() - self.isDonePromise = nil + try isDonePromise.futureResult.wait() return 0 } } diff --git a/Sources/NIOPosix/BSDSocketAPICommon.swift b/Sources/NIOPosix/BSDSocketAPICommon.swift index 539136cc26..721143a990 100644 --- a/Sources/NIOPosix/BSDSocketAPICommon.swift +++ b/Sources/NIOPosix/BSDSocketAPICommon.swift @@ -83,8 +83,8 @@ extension NIOBSDSocket.SocketType { internal static let stream: NIOBSDSocket.SocketType = NIOBSDSocket.SocketType(rawValue: SOCK_STREAM) #endif - - #if os(Linux) + + #if os(Linux) && !canImport(Musl) internal static let raw: NIOBSDSocket.SocketType = NIOBSDSocket.SocketType(rawValue: CInt(SOCK_RAW.rawValue)) #else diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index 6e92f1c157..64d679c503 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOCore +import NIOCore #if os(Windows) import ucrt @@ -25,13 +25,8 @@ import struct WinSDK.DWORD import struct WinSDK.HANDLE #endif -#if swift(>=5.7) /// The type of all `channelInitializer` callbacks. internal typealias ChannelInitializerCallback = @Sendable (Channel) -> EventLoopFuture -#else -/// The type of all `channelInitializer` callbacks. -internal typealias ChannelInitializerCallback = (Channel) -> EventLoopFuture -#endif /// Common functionality for all NIO on sockets bootstraps. internal enum NIOOnSocketsBootstraps { @@ -151,7 +146,6 @@ public final class ServerBootstrap { self.enableMPTCP = false } - #if swift(>=5.7) /// Initialize the `ServerSocketChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -166,23 +160,7 @@ public final class ServerBootstrap { self.serverChannelInit = initializer return self } - #else - /// Initialize the `ServerSocketChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// The `ServerSocketChannel` uses the accepted `Channel`s as inbound messages. - /// - /// - note: To set the initializer for the accepted `SocketChannel`s, look at `ServerBootstrap.childChannelInitializer`. - /// - /// - parameters: - /// - initializer: A closure that initializes the provided `Channel`. - public func serverChannelInitializer(_ initializer: @escaping (Channel) -> EventLoopFuture) -> Self { - self.serverChannelInit = initializer - return self - } - #endif - #if swift(>=5.7) /// Initialize the accepted `SocketChannel`s with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. Note that if the `initializer` fails then the error will be /// fired in the *parent* channel. @@ -202,26 +180,6 @@ public final class ServerBootstrap { self.childChannelInit = initializer return self } - #else - /// Initialize the accepted `SocketChannel`s with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. Note that if the `initializer` fails then the error will be - /// fired in the *parent* channel. - /// - /// - warning: The `initializer` will be invoked once for every accepted connection. Therefore it's usually the - /// right choice to instantiate stateful `ChannelHandler`s within the closure to make sure they are not - /// accidentally shared across `Channel`s. There are expert use-cases where stateful handler need to be - /// shared across `Channel`s in which case the user is responsible to synchronise the state access - /// appropriately. - /// - /// The accepted `Channel` will operate on `ByteBuffer` as inbound and `IOData` as outbound messages. - /// - /// - parameters: - /// - initializer: A closure that initializes the provided `Channel`. - public func childChannelInitializer(_ initializer: @escaping (Channel) -> EventLoopFuture) -> Self { - self.childChannelInit = initializer - return self - } - #endif /// Specifies a `ChannelOption` to be applied to the `ServerSocketChannel`. /// @@ -511,23 +469,22 @@ extension ServerBootstrap { /// - Parameters: /// - host: The host to bind on. /// - port: The port to bind on. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( host: String, port: Int, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { let address = try SocketAddress.makeAddressResolvingHost(host, port: port) return try await bind( to: address, - serverBackpressureStrategy: serverBackpressureStrategy, + serverBackPressureStrategy: serverBackPressureStrategy, childChannelInitializer: childChannelInitializer ) } @@ -536,15 +493,14 @@ extension ServerBootstrap { /// /// - Parameters: /// - address: The `SocketAddress` to bind on. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( to address: SocketAddress, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { return try await bind0( @@ -556,7 +512,7 @@ extension ServerBootstrap { enableMPTCP: enableMPTCP ) }, - serverBackpressureStrategy: serverBackpressureStrategy, + serverBackPressureStrategy: serverBackPressureStrategy, childChannelInitializer: childChannelInitializer, registration: { serverChannel in serverChannel.registerAndDoSynchronously { serverChannel in @@ -572,16 +528,15 @@ extension ServerBootstrap { /// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist, /// unless `cleanupExistingSocketFile`is set to `true`. /// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( unixDomainSocketPath: String, cleanupExistingSocketFile: Bool = false, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { if cleanupExistingSocketFile { @@ -592,25 +547,58 @@ extension ServerBootstrap { return try await self.bind( to: address, - serverBackpressureStrategy: serverBackpressureStrategy, + serverBackPressureStrategy: serverBackPressureStrategy, childChannelInitializer: childChannelInitializer ) } + /// Bind the `ServerSocketChannel` to a VSOCK socket. + /// + /// - Parameters: + /// - vsockAddress: The VSOCK socket address to bind on. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind( + to vsockAddress: VsockAddress, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> NIOAsyncChannel { + func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel { + try ServerSocketChannel(eventLoop: eventLoop, group: childEventLoopGroup, protocolFamily: .vsock, enableMPTCP: enableMPTCP) + } + + return try await self.bind0( + makeServerChannel: makeChannel, + serverBackPressureStrategy: serverBackPressureStrategy, + childChannelInitializer: childChannelInitializer + ) { channel in + channel.register().flatMap { + let promise = channel.eventLoop.makePromise(of: Void.self) + channel.triggerUserOutboundEvent0( + VsockChannelEvents.BindToAddress(vsockAddress), + promise: promise + ) + return promise.futureResult + } + }.get() + } + /// Use the existing bound socket file descriptor. /// /// - Parameters: /// - socket: The _Unix file descriptor_ representing the bound stream socket. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( _ socket: NIOBSDSocket.Handle, cleanupExistingSocketFile: Bool = false, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { return try await bind0( @@ -624,7 +612,7 @@ extension ServerBootstrap { group: childEventLoopGroup ) }, - serverBackpressureStrategy: serverBackpressureStrategy, + serverBackPressureStrategy: serverBackPressureStrategy, childChannelInitializer: childChannelInitializer, registration: { serverChannel in let promise = serverChannel.eventLoop.makePromise(of: Void.self) @@ -637,9 +625,9 @@ extension ServerBootstrap { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) private func bind0( makeServerChannel: @escaping (SelectableEventLoop, EventLoopGroup, Bool) throws -> ServerSocketChannel, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - registration: @escaping @Sendable (Channel) -> EventLoopFuture + registration: @escaping @Sendable (ServerSocketChannel) -> EventLoopFuture ) -> EventLoopFuture> { let eventLoop = self.group.next() let childEventLoopGroup = self.childGroup @@ -665,9 +653,9 @@ extension ServerBootstrap { name: "AcceptHandler" ) let asyncChannel = try NIOAsyncChannel - .wrapAsyncChannelWithTransformations( - synchronouslyWrapping: serverChannel, - backpressureStrategy: serverBackpressureStrategy, + ._wrapAsyncChannelWithTransformations( + wrappingChannelSynchronously: serverChannel, + backPressureStrategy: serverBackPressureStrategy, channelReadTransformation: { channel -> EventLoopFuture in // The channelReadTransformation is run on the EL of the server channel // We have to make sure that we execute child channel initializer on the @@ -747,11 +735,7 @@ private extension Channel { /// The connected `SocketChannel` will operate on `ByteBuffer` as inbound and on `IOData` as outbound messages. public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { private let group: EventLoopGroup - #if swift(>=5.7) private var protocolHandlers: Optional<@Sendable () -> [ChannelHandler]> - #else - private var protocolHandlers: Optional<() -> [ChannelHandler]> - #endif private var _channelInitializer: ChannelInitializerCallback private var channelInitializer: ChannelInitializerCallback { if let protocolHandlers = self.protocolHandlers { @@ -807,7 +791,6 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { self.enableMPTCP = false } - #if swift(>=5.7) /// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -830,31 +813,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { self._channelInitializer = handler return self } - #else - /// Initialize the connected `SocketChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// The connected `Channel` will operate on `ByteBuffer` as inbound and `IOData` as outbound messages. - /// - /// - warning: The `handler` closure may be invoked _multiple times_ so it's usually the right choice to instantiate - /// `ChannelHandler`s within `handler`. The reason `handler` may be invoked multiple times is that to - /// successfully set up a connection multiple connections might be setup in the process. Assuming a - /// hostname that resolves to both IPv4 and IPv6 addresses, NIO will follow - /// [_Happy Eyeballs_](https://en.wikipedia.org/wiki/Happy_Eyeballs) and race both an IPv4 and an IPv6 - /// connection. It is possible that both connections get fully established before the IPv4 connection - /// will be closed again because the IPv6 connection 'won the race'. Therefore the `channelInitializer` - /// might be called multiple times and it's important not to share stateful `ChannelHandler`s in more - /// than one `Channel`. - /// - /// - parameters: - /// - handler: A closure that initializes the provided `Channel`. - public func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture) -> Self { - self._channelInitializer = handler - return self - } - #endif - #if swift(>=5.7) /// Sets the protocol handlers that will be added to the front of the `ChannelPipeline` right after the /// `channelInitializer` has been called. /// @@ -867,19 +826,6 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { self.protocolHandlers = handlers return self } - #else - /// Sets the protocol handlers that will be added to the front of the `ChannelPipeline` right after the - /// `channelInitializer` has been called. - /// - /// Per bootstrap, you can only set the `protocolHandlers` once. Typically, `protocolHandlers` are used for the TLS - /// implementation. Most notably, `NIOClientTCPBootstrap`, NIO's "universal bootstrap" abstraction, uses - /// `protocolHandlers` to add the required `ChannelHandler`s for many TLS implementations. - public func protocolHandlers(_ handlers: @escaping () -> [ChannelHandler]) -> Self { - precondition(self.protocolHandlers == nil, "protocol handlers can only be set once") - self.protocolHandlers = handlers - return self - } - #endif /// Specifies a `ChannelOption` to be applied to the `SocketChannel`. /// @@ -1028,6 +974,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { /// - address: The VSOCK address to connect to. /// - returns: An `EventLoopFuture` for when the `Channel` is connected. public func connect(to address: VsockAddress) -> EventLoopFuture { + let connectTimeout = self.connectTimeout return self.initializeAndRegisterNewChannel( eventLoop: self.group.next(), protocolFamily: .vsock @@ -1035,8 +982,8 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { let connectPromise = channel.eventLoop.makePromise(of: Void.self) channel.triggerUserOutboundEvent(VsockChannelEvents.ConnectToAddress( address), promise: connectPromise) - let cancelTask = channel.eventLoop.scheduleTask(in: self.connectTimeout) { - connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout)) + let cancelTask = channel.eventLoop.scheduleTask(in: connectTimeout) { + connectPromise.fail(ChannelError.connectTimeout(connectTimeout)) channel.close(promise: nil) } connectPromise.futureResult.whenComplete { (_: Result) in @@ -1160,7 +1107,6 @@ extension ClientBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( host: String, port: Int, @@ -1186,7 +1132,6 @@ extension ClientBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( to address: SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1211,7 +1156,6 @@ extension ClientBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( unixDomainSocketPath: String, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1223,6 +1167,42 @@ extension ClientBootstrap { ) } + /// Specify the VSOCK address to connect to for the `Channel`. + /// + /// - Parameters: + /// - address: The VSOCK address to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect( + to address: VsockAddress, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + let connectTimeout = self.connectTimeout + return try await self.initializeAndRegisterNewChannel( + eventLoop: self.group.next(), + protocolFamily: NIOBSDSocket.ProtocolFamily.vsock, + channelInitializer: channelInitializer, + postRegisterTransformation: { result, eventLoop in + return eventLoop.makeSucceededFuture(result) + } + ) { channel in + let connectPromise = channel.eventLoop.makePromise(of: Void.self) + channel.triggerUserOutboundEvent(VsockChannelEvents.ConnectToAddress( address), promise: connectPromise) + + let cancelTask = channel.eventLoop.scheduleTask(in: connectTimeout) { + connectPromise.fail(ChannelError.connectTimeout(connectTimeout)) + channel.close(promise: nil) + } + connectPromise.futureResult.whenComplete { (_: Result) in + cancelTask.cancel() + } + + return connectPromise.futureResult + }.get().1 + } + /// Use the existing connected socket file descriptor. /// /// - Parameters: @@ -1231,7 +1211,6 @@ extension ClientBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func withConnectedSocket( _ socket: NIOBSDSocket.Handle, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1444,7 +1423,6 @@ public final class DatagramBootstrap { self.channelInitializer = nil } - #if swift(>=5.7) /// Initialize the bound `DatagramChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -1455,17 +1433,6 @@ public final class DatagramBootstrap { self.channelInitializer = handler return self } - #else - /// Initialize the bound `DatagramChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// - parameters: - /// - handler: A closure that initializes the provided `Channel`. - public func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture) -> Self { - self.channelInitializer = handler - return self - } - #endif /// Specifies a `ChannelOption` to be applied to the `DatagramChannel`. /// @@ -1672,7 +1639,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func withBoundSocket( _ socket: NIOBSDSocket.Handle, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1703,7 +1669,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( host: String, port: Int, @@ -1728,7 +1693,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( to address: SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1754,7 +1718,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( unixDomainSocketPath: String, cleanupExistingSocketFile: Bool = false, @@ -1784,7 +1747,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( host: String, port: Int, @@ -1809,7 +1771,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( to address: SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -1833,7 +1794,6 @@ extension DatagramBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( unixDomainSocketPath: String, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -2007,7 +1967,6 @@ public final class NIOPipeBootstrap { self.channelInitializer = nil } - #if swift(>=5.7) /// Initialize the connected `PipeChannel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -2021,20 +1980,6 @@ public final class NIOPipeBootstrap { self.channelInitializer = handler return self } - #else - /// Initialize the connected `PipeChannel` with `initializer`. The most common task in initializer is to add - /// `ChannelHandler`s to the `ChannelPipeline`. - /// - /// The connected `Channel` will operate on `ByteBuffer` as inbound and outbound messages. Please note that - /// `IOData.fileRegion` is _not_ supported for `PipeChannel`s because `sendfile` only works on sockets. - /// - /// - parameters: - /// - handler: A closure that initializes the provided `Channel`. - public func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture) -> Self { - self.channelInitializer = handler - return self - } - #endif /// Specifies a `ChannelOption` to be applied to the `PipeChannel`. /// @@ -2119,53 +2064,63 @@ public final class NIOPipeBootstrap { /// - output: The _Unix file descriptor_ for the output (ie. the write side). /// - Returns: an `EventLoopFuture` to deliver the `Channel`. public func takingOwnershipOfDescriptors(input: CInt, output: CInt) -> EventLoopFuture { - precondition(input >= 0 && output >= 0 && input != output, - "illegal file descriptor pair. The file descriptors \(input), \(output) " + - "must be distinct and both positive integers.") - let eventLoop = group.next() - do { - try self.validateFileDescriptorIsNotAFile(input) - try self.validateFileDescriptorIsNotAFile(output) - } catch { - return eventLoop.makeFailedFuture(error) - } + self._takingOwnershipOfDescriptors(input: input, output: output) + } - let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) } - let channel: PipeChannel - do { - let inputFH = NIOFileHandle(descriptor: input) - let outputFH = NIOFileHandle(descriptor: output) - channel = try PipeChannel(eventLoop: eventLoop as! SelectableEventLoop, - inputPipe: inputFH, - outputPipe: outputFH) - } catch { - return eventLoop.makeFailedFuture(error) - } + /// Create the `PipeChannel` with the provided input file descriptor. + /// + /// The input file descriptor must be distinct. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `input` when the `Channel` + /// becomes inactive. You _must not_ do any further operations to `input`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing them. + /// + /// - Parameters: + /// - input: The _Unix file descriptor_ for the input (ie. the read side). + /// - Returns: an `EventLoopFuture` to deliver the `Channel`. + public func takingOwnershipOfDescriptor( + input: CInt + ) -> EventLoopFuture { + self._takingOwnershipOfDescriptors(input: input, output: nil) + } - func setupChannel() -> EventLoopFuture { - eventLoop.assertInEventLoop() - return self._channelOptions.applyAllChannelOptions(to: channel).flatMap { - channelInitializer(channel) - }.flatMap { - eventLoop.assertInEventLoop() - let promise = eventLoop.makePromise(of: Void.self) - channel.registerAlreadyConfigured0(promise: promise) - return promise.futureResult - }.map { - channel - }.flatMapError { error in - channel.close0(error: error, mode: .all, promise: nil) - return channel.eventLoop.makeFailedFuture(error) - } - } + /// Create the `PipeChannel` with the provided output file descriptor. + /// + /// The output file descriptor must be distinct. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `output` when the `Channel` + /// becomes inactive. You _must not_ do any further operations to `output`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing them. + /// + /// - Parameters: + /// - output: The _Unix file descriptor_ for the output (ie. the write side). + /// - Returns: an `EventLoopFuture` to deliver the `Channel`. + public func takingOwnershipOfDescriptor( + output: CInt + ) -> EventLoopFuture { + self._takingOwnershipOfDescriptors(input: nil, output: output) + } - if eventLoop.inEventLoop { - return setupChannel() - } else { - return eventLoop.flatSubmit { - setupChannel() + private func _takingOwnershipOfDescriptors(input: CInt?, output: CInt?) -> EventLoopFuture { + let channelInitializer: @Sendable (Channel) -> EventLoopFuture = { + let eventLoop = self.group.next() + let channelInitializer = self.channelInitializer + return { channel in + if let channelInitializer = channelInitializer { + return channelInitializer(channel).map { channel } + } else { + return eventLoop.makeSucceededFuture(channel) + } } - } + + }() + return self._takingOwnershipOfDescriptors( + input: input, + output: output, + channelInitializer: channelInitializer + ) } @available(*, deprecated, renamed: "takingOwnershipOfDescriptor(inputOutput:)") @@ -2197,7 +2152,6 @@ extension NIOPipeBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func takingOwnershipOfDescriptor( inputOutput: CInt, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture @@ -2216,12 +2170,10 @@ extension NIOPipeBootstrap { throw error } } - + /// Create the `PipeChannel` with the provided input and output file descriptors. /// - /// The input and output file descriptors must be distinct. If you have a single file descriptor, consider using - /// `ClientBootstrap.withConnectedSocket(descriptor:)` if it's a socket or - /// `NIOPipeBootstrap.takingOwnershipOfDescriptor` if it is not a socket. + /// The input and output file descriptors must be distinct. /// /// - Note: If this method returns a succeeded future, SwiftNIO will close `input` and `output` /// when the `Channel` becomes inactive. You _must not_ do any further operations `input` or @@ -2236,7 +2188,6 @@ extension NIOPipeBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func takingOwnershipOfDescriptors( input: CInt, output: CInt, @@ -2245,49 +2196,128 @@ extension NIOPipeBootstrap { try await self._takingOwnershipOfDescriptors( input: input, output: output, - channelInitializer: channelInitializer, - postRegisterTransformation: { $0.makeSucceededFuture($1) } + channelInitializer: channelInitializer ) } - + + /// Create the `PipeChannel` with the provided input file descriptor. + /// + /// The input file descriptor must be distinct. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `input` when the `Channel` + /// becomes inactive. You _must not_ do any further operations to `input`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing them. + /// + /// - Parameters: + /// - input: The _Unix file descriptor_ for the input (ie. the read side). + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) // Should become private - public func _takingOwnershipOfDescriptors( + public func takingOwnershipOfDescriptor( input: CInt, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + try await self._takingOwnershipOfDescriptors( + input: input, + output: nil, + channelInitializer: channelInitializer + ) + } + + /// Create the `PipeChannel` with the provided output file descriptor. + /// + /// The output file descriptor must be distinct. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `output` when the `Channel` + /// becomes inactive. You _must not_ do any further operations to `output`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing them. + /// + /// - Parameters: + /// - output: The _Unix file descriptor_ for the output (ie. the write side). + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func takingOwnershipOfDescriptor( output: CInt, - channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (EventLoop, ChannelInitializerResult) -> EventLoopFuture - ) async throws -> PostRegistrationTransformationResult { - precondition(input >= 0 && output >= 0 && input != output, - "illegal file descriptor pair. The file descriptors \(input), \(output) " + + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + try await self._takingOwnershipOfDescriptors( + input: nil, + output: output, + channelInitializer: channelInitializer + ) + } + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + func _takingOwnershipOfDescriptors( + input: CInt?, + output: CInt?, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> ChannelInitializerResult { + try await self._takingOwnershipOfDescriptors( + input: input, + output: output, + channelInitializer: channelInitializer + ).get() + } + + func _takingOwnershipOfDescriptors( + input: CInt?, + output: CInt?, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) -> EventLoopFuture { + precondition(input ?? 0 >= 0 && output ?? 0 >= 0 && input != output, + "illegal file descriptor pair. The file descriptors \(String(describing: input)), \(String(describing: output)) " + "must be distinct and both positive integers.") + precondition(!(input == nil && output == nil), "Either input or output has to be set") let eventLoop = group.next() - try self.validateFileDescriptorIsNotAFile(input) - try self.validateFileDescriptorIsNotAFile(output) + let channelOptions = self._channelOptions - let channelInitializer = { (channel: Channel) -> EventLoopFuture in - let initializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) } - return initializer(channel).flatMap { channelInitializer(channel) } + let channel: PipeChannel + let inputFileHandle: NIOFileHandle? + let outputFileHandle: NIOFileHandle? + do { + if let input = input { + try self.validateFileDescriptorIsNotAFile(input) + } + if let output = output { + try self.validateFileDescriptorIsNotAFile(output) + } + + inputFileHandle = input.flatMap { NIOFileHandle(descriptor: $0) } + outputFileHandle = output.flatMap { NIOFileHandle(descriptor: $0) } + channel = try PipeChannel( + eventLoop: eventLoop as! SelectableEventLoop, + inputPipe: inputFileHandle, + outputPipe: outputFileHandle + ) + } catch { + return eventLoop.makeFailedFuture(error) } - let inputFileHandle = NIOFileHandle(descriptor: input) - let outputFileHandle = NIOFileHandle(descriptor: output) - let channel = try PipeChannel( - eventLoop: eventLoop as! SelectableEventLoop, - inputPipe: inputFileHandle, - outputPipe: outputFileHandle - ) @Sendable - func setupChannel() -> EventLoopFuture { + func setupChannel() -> EventLoopFuture { eventLoop.assertInEventLoop() - return self._channelOptions.applyAllChannelOptions(to: channel).flatMap { _ -> EventLoopFuture in + return channelOptions.applyAllChannelOptions(to: channel).flatMap { _ -> EventLoopFuture in channelInitializer(channel) }.flatMap { result in eventLoop.assertInEventLoop() let promise = eventLoop.makePromise(of: Void.self) channel.registerAlreadyConfigured0(promise: promise) - return promise.futureResult.flatMap { postRegisterTransformation(eventLoop, result) } + return promise.futureResult.map { result } + }.flatMap { result -> EventLoopFuture in + if inputFileHandle == nil { + return channel.close(mode: .input).map { result } + } + if outputFileHandle == nil { + return channel.close(mode: .output).map { result } + } + return channel.selectableEventLoop.makeSucceededFuture(result) }.flatMapError { error in channel.close0(error: error, mode: .all, promise: nil) return channel.eventLoop.makeFailedFuture(error) @@ -2295,11 +2325,11 @@ extension NIOPipeBootstrap { } if eventLoop.inEventLoop { - return try await setupChannel().get() + return setupChannel() } else { - return try await eventLoop.flatSubmit { + return eventLoop.flatSubmit { setupChannel() - }.get() + } } } } diff --git a/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift b/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift index f9214a5f91..622122dea6 100644 --- a/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift +++ b/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift @@ -53,12 +53,8 @@ typealias ThreadInitializer = (NIOThread) -> Void /// test. A good place to start a `MultiThreadedEventLoopGroup` is the `setUp` method of your `XCTestCase` /// subclass, a good place to shut it down is the `tearDown` method. public final class MultiThreadedEventLoopGroup: EventLoopGroup { - #if swift(>=5.7) private typealias ShutdownGracefullyCallback = @Sendable (Error?) -> Void - #else - private typealias ShutdownGracefullyCallback = (Error?) -> Void - #endif - + private enum RunState { case running case closing([(DispatchQueue, ShutdownGracefullyCallback)]) @@ -262,7 +258,6 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { } } - #if swift(>=5.7) /// Shut this `MultiThreadedEventLoopGroup` down which causes the `EventLoop`s and their associated threads to be /// shut down and release their resources. /// @@ -277,22 +272,7 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { public func shutdownGracefully(queue: DispatchQueue, _ handler: @escaping @Sendable (Error?) -> Void) { self._shutdownGracefully(queue: queue, handler) } - #else - /// Shut this `MultiThreadedEventLoopGroup` down which causes the `EventLoop`s and their associated threads to be - /// shut down and release their resources. - /// - /// Even though calling `shutdownGracefully` more than once should be avoided, it is safe to do so and execution - /// of the `handler` is guaranteed. - /// - /// - parameters: - /// - queue: The `DispatchQueue` to run `handler` on when the shutdown operation completes. - /// - handler: The handler which is called after the shutdown operation completes. The parameter will be `nil` - /// on success and contain the `Error` otherwise. - public func shutdownGracefully(queue: DispatchQueue, _ handler: @escaping (Error?) -> Void) { - self._shutdownGracefully(queue: queue, handler) - } - #endif - + private func _shutdownGracefully(queue: DispatchQueue, _ handler: @escaping ShutdownGracefullyCallback) { guard self.canBeShutDown else { queue.async { @@ -422,8 +402,36 @@ extension MultiThreadedEventLoopGroup: CustomStringConvertible { } } +#if compiler(>=5.9) +@usableFromInline +struct ErasedUnownedJob { + @usableFromInline + let erasedJob: Any + + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + init(job: UnownedJob) { + self.erasedJob = job + } + + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + @inlinable + var unownedJob: UnownedJob { + // This force-cast is safe since we only store an UnownedJob + self.erasedJob as! UnownedJob + } +} +#endif + @usableFromInline internal struct ScheduledTask { + @usableFromInline + enum UnderlyingTask { + case function(() -> Void) + #if compiler(>=5.9) + case unownedJob(ErasedUnownedJob) + #endif + } + /// The id of the scheduled task. /// /// - Important: This id has two purposes. First, it is used to give this struct an identity so that we can implement ``Equatable`` @@ -431,21 +439,32 @@ internal struct ScheduledTask { /// This means, the ids need to be unique for a given ``SelectableEventLoop`` and they need to be in ascending order. @usableFromInline let id: UInt64 - let task: () -> Void - private let failFn: (Error) ->() + let task: UnderlyingTask + private let failFn: ((Error) ->())? @usableFromInline internal let readyTime: NIODeadline @usableFromInline init(id: UInt64, _ task: @escaping () -> Void, _ failFn: @escaping (Error) -> Void, _ time: NIODeadline) { self.id = id - self.task = task + self.task = .function(task) self.failFn = failFn self.readyTime = time } + #if compiler(>=5.9) + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + @usableFromInline + init(id: UInt64, job: consuming ExecutorJob, readyTime: NIODeadline) { + self.id = id + self.task = .unownedJob(.init(job: UnownedJob(job))) + self.readyTime = readyTime + self.failFn = nil + } + #endif + func fail(_ error: Error) { - failFn(error) + failFn?(error) } } diff --git a/Sources/NIOPosix/NIOThreadPool.swift b/Sources/NIOPosix/NIOThreadPool.swift index b57c699bec..39db189196 100644 --- a/Sources/NIOPosix/NIOThreadPool.swift +++ b/Sources/NIOPosix/NIOThreadPool.swift @@ -18,7 +18,7 @@ import NIOConcurrencyHelpers /// Errors that may be thrown when executing work on a `NIOThreadPool` public enum NIOThreadPoolError { - + /// The `NIOThreadPool` was not active. public struct ThreadPoolInactive: Error { public init() {} @@ -55,14 +55,9 @@ public final class NIOThreadPool { /// The `WorkItem` was cancelled and will not be processed by the `NIOThreadPool`. case cancelled } - - #if swift(>=5.7) + /// The work that should be done by the `NIOThreadPool`. public typealias WorkItem = @Sendable (WorkItemState) -> Void - #else - /// The work that should be done by the `NIOThreadPool`. - public typealias WorkItem = (WorkItemState) -> Void - #endif private enum State { /// The `NIOThreadPool` is already stopped. case stopped @@ -78,7 +73,6 @@ public final class NIOThreadPool { private let numberOfThreads: Int private let canBeStopped: Bool - #if swift(>=5.7) /// Gracefully shutdown this `NIOThreadPool`. All tasks will be run before shutdown will take place. /// /// - parameters: @@ -88,17 +82,7 @@ public final class NIOThreadPool { public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping @Sendable (Error?) -> Void) { self._shutdownGracefully(queue: queue, callback) } - #else - /// Gracefully shutdown this `NIOThreadPool`. All tasks will be run before shutdown will take place. - /// - /// - parameters: - /// - queue: The `DispatchQueue` used to executed the callback - /// - callback: The function to be executed once the shutdown is complete. - public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { - self._shutdownGracefully(queue: queue, callback) - } - #endif - + private func _shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { guard self.canBeStopped else { queue.async { @@ -135,10 +119,9 @@ public final class NIOThreadPool { callback(nil) } } - - - #if swift(>=5.7) + + /// Submit a `WorkItem` to process. /// /// - note: This is a low-level method, in most cases the `runIfActive` method should be used. @@ -149,17 +132,6 @@ public final class NIOThreadPool { public func submit(_ body: @escaping WorkItem) { self._submit(body) } - #else - /// Submit a `WorkItem` to process. - /// - /// - note: This is a low-level method, in most cases the `runIfActive` method should be used. - /// - /// - parameters: - /// - body: The `WorkItem` to process by the `NIOThreadPool`. - public func submit(_ body: @escaping WorkItem) { - self._submit(body) - } - #endif private func _submit(_ body: @escaping WorkItem) { let item = self.lock.withLock { () -> WorkItem? in @@ -176,7 +148,7 @@ public final class NIOThreadPool { /* if item couldn't be added run it immediately indicating that it couldn't be run */ item.map { $0(.cancelled) } } - + /// Initialize a `NIOThreadPool` thread pool with `numberOfThreads` threads. /// /// - parameters: @@ -290,8 +262,7 @@ public final class NIOThreadPool { extension NIOThreadPool: @unchecked Sendable {} extension NIOThreadPool { - - #if swift(>=5.7) + /// Runs the submitted closure if the thread pool is still active, otherwise fails the promise. /// The closure will be run on the thread pool so can do blocking work. /// @@ -303,19 +274,7 @@ extension NIOThreadPool { public func runIfActive(eventLoop: EventLoop, _ body: @escaping @Sendable () throws -> T) -> EventLoopFuture { self._runIfActive(eventLoop: eventLoop, body) } - #else - /// Runs the submitted closure if the thread pool is still active, otherwise fails the promise. - /// The closure will be run on the thread pool so can do blocking work. - /// - /// - parameters: - /// - eventLoop: The `EventLoop` the returned `EventLoopFuture` will fire on. - /// - body: The closure which performs some blocking work to be done on the thread pool. - /// - returns: The `EventLoopFuture` of `promise` fulfilled with the result (or error) of the passed closure. - public func runIfActive(eventLoop: EventLoop, _ body: @escaping () throws -> T) -> EventLoopFuture { - self._runIfActive(eventLoop: eventLoop, body) - } - #endif - + private func _runIfActive(eventLoop: EventLoop, _ body: @escaping () throws -> T) -> EventLoopFuture { let promise = eventLoop.makePromise(of: T.self) self.submit { shouldRun in @@ -331,19 +290,36 @@ extension NIOThreadPool { } return promise.futureResult } + + /// Runs the submitted closure if the thread pool is still active, otherwise throw an error. + /// The closure will be run on the thread pool so can do blocking work. + /// + /// - parameters: + /// - body: The closure which performs some blocking work to be done on the thread pool. + /// - returns: result of the passed closure. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func runIfActive(_ body: @escaping @Sendable () throws -> T) async throws -> T { + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + self.submit { shouldRun in + guard case shouldRun = NIOThreadPool.WorkItemState.active else { + cont.resume(throwing: NIOThreadPoolError.ThreadPoolInactive()) + return + } + do { + try cont.resume(returning: body()) + } catch { + cont.resume(throwing: error) + } + } + } + } } extension NIOThreadPool { - #if swift(>=5.7) @preconcurrency public func shutdownGracefully(_ callback: @escaping @Sendable (Error?) -> Void) { self.shutdownGracefully(queue: .global(), callback) } - #else - public func shutdownGracefully(_ callback: @escaping (Error?) -> Void) { - self.shutdownGracefully(queue: .global(), callback) - } - #endif /// Shuts down the thread pool gracefully. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @@ -360,16 +336,10 @@ extension NIOThreadPool { } } - #if swift(>=5.7) @available(*, noasync, message: "this can end up blocking the calling thread", renamed: "shutdownGracefully()") public func syncShutdownGracefully() throws { try self._syncShutdownGracefully() } - #else - public func syncShutdownGracefully() throws { - try self._syncShutdownGracefully() - } - #endif private func _syncShutdownGracefully() throws { let errorStorageLock = NIOLock() @@ -390,4 +360,4 @@ extension NIOThreadPool { } } } -} +} \ No newline at end of file diff --git a/Sources/NIOPosix/NonBlockingFileIO.swift b/Sources/NIOPosix/NonBlockingFileIO.swift index 7fe20193e3..63830a8e0c 100644 --- a/Sources/NIOPosix/NonBlockingFileIO.swift +++ b/Sources/NIOPosix/NonBlockingFileIO.swift @@ -15,7 +15,7 @@ import NIOCore import NIOConcurrencyHelpers -/// `NonBlockingFileIO` is a helper that allows you to read files without blocking the calling thread. +/// ``NonBlockingFileIO`` is a helper that allows you to read files without blocking the calling thread. /// /// It is worth noting that `kqueue`, `epoll` or `poll` returning claiming a file is readable does not mean that the /// data is already available in the kernel's memory. In other words, a `read` from a file can still block even if @@ -25,19 +25,19 @@ import NIOConcurrencyHelpers /// - [`epoll`](http://man7.org/linux/man-pages/man7/epoll.7.html): "epoll is simply a faster poll(2), and can be used wherever the latter is used since it shares the same semantics." /// - [`kqueue`](https://www.freebsd.org/cgi/man.cgi?query=kqueue&sektion=2): "Returns when the file pointer is not at the end of file." /// -/// `NonBlockingFileIO` helps to work around this issue by maintaining its own thread pool that is used to read the data +/// ``NonBlockingFileIO`` helps to work around this issue by maintaining its own thread pool that is used to read the data /// from the files into memory. It will then hand the (in-memory) data back which makes it available without the possibility /// of blocking. public struct NonBlockingFileIO: Sendable { - /// The default and recommended size for `NonBlockingFileIO`'s thread pool. + /// The default and recommended size for ``NonBlockingFileIO``'s thread pool. public static let defaultThreadPoolSize = 2 /// The default and recommended chunk size. public static let defaultChunkSize = 128*1024 - /// `NonBlockingFileIO` errors. + /// ``NonBlockingFileIO`` errors. public enum Error: Swift.Error { - /// `NonBlockingFileIO` is meant to be used with file descriptors that are set to the default (blocking) mode. + /// ``NonBlockingFileIO`` is meant to be used with file descriptors that are set to the default (blocking) mode. /// It doesn't make sense to use it with a file descriptor where `O_NONBLOCK` is set therefore this error is /// raised when that was requested. case descriptorSetToNonBlocking @@ -45,16 +45,15 @@ public struct NonBlockingFileIO: Sendable { private let threadPool: NIOThreadPool - /// Initialize a `NonBlockingFileIO` which uses the `NIOThreadPool`. + /// Initialize a ``NonBlockingFileIO`` which uses the `NIOThreadPool`. /// /// - parameters: /// - threadPool: The `NIOThreadPool` that will be used for all the IO. public init(threadPool: NIOThreadPool) { self.threadPool = threadPool } - - #if swift(>=5.7) - /// Read a `FileRegion` in chunks of `chunkSize` bytes on `NonBlockingFileIO`'s private thread + + /// Read a `FileRegion` in chunks of `chunkSize` bytes on ``NonBlockingFileIO``'s private thread /// pool which is separate from any `EventLoop` thread. /// /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `fileRegion.readableBytes` is greater than @@ -89,45 +88,8 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop, chunkHandler: chunkHandler) } - #else - /// Read a `FileRegion` in chunks of `chunkSize` bytes on `NonBlockingFileIO`'s private thread - /// pool which is separate from any `EventLoop` thread. - /// - /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `fileRegion.readableBytes` is greater than - /// zero and there are enough bytes available `chunkHandler` will be called `1 + |_ fileRegion.readableBytes / chunkSize _|` - /// times, delivering `chunkSize` bytes each time. If less than `fileRegion.readableBytes` bytes can be read from the file, - /// `chunkHandler` will be called less often with the last invocation possibly being of less than `chunkSize` bytes. - /// - /// The allocation and reading of a subsequent chunk will only be attempted when `chunkHandler` succeeds. - /// - /// This method will not use the file descriptor's seek pointer which means there is no danger of reading from the - /// same `FileRegion` in multiple threads. - /// - /// - parameters: - /// - fileRegion: The file region to read. - /// - chunkSize: The size of the individual chunks to deliver. - /// - allocator: A `ByteBufferAllocator` used to allocate space for the chunks. - /// - eventLoop: The `EventLoop` to call `chunkHandler` on. - /// - chunkHandler: Called for every chunk read. The next chunk will be read upon successful completion of the returned `EventLoopFuture`. If the returned `EventLoopFuture` fails, the overall operation is aborted. - /// - returns: An `EventLoopFuture` which is the result of the overall operation. If either the reading of `fileHandle` or `chunkHandler` fails, the `EventLoopFuture` will fail too. If the reading of `fileHandle` as well as `chunkHandler` always succeeded, the `EventLoopFuture` will succeed too. - public func readChunked(fileRegion: FileRegion, - chunkSize: Int = NonBlockingFileIO.defaultChunkSize, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, - chunkHandler: @escaping (ByteBuffer) -> EventLoopFuture) -> EventLoopFuture { - let readableBytes = fileRegion.readableBytes - return self.readChunked(fileHandle: fileRegion.fileHandle, - fromOffset: Int64(fileRegion.readerIndex), - byteCount: readableBytes, - chunkSize: chunkSize, - allocator: allocator, - eventLoop: eventLoop, - chunkHandler: chunkHandler) - } - #endif - #if swift(>=5.7) - /// Read `byteCount` bytes in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread + /// Read `byteCount` bytes in chunks of `chunkSize` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread /// pool which is separate from any `EventLoop` thread. /// /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `byteCount` is greater than @@ -165,48 +127,8 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop, chunkHandler: chunkHandler) } - #else - /// Read `byteCount` bytes in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread - /// pool which is separate from any `EventLoop` thread. - /// - /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `byteCount` is greater than - /// zero and there are enough bytes available `chunkHandler` will be called `1 + |_ byteCount / chunkSize _|` - /// times, delivering `chunkSize` bytes each time. If less than `byteCount` bytes can be read from `descriptor`, - /// `chunkHandler` will be called less often with the last invocation possibly being of less than `chunkSize` bytes. - /// - /// The allocation and reading of a subsequent chunk will only be attempted when `chunkHandler` succeeds. - /// - /// - note: `readChunked(fileRegion:chunkSize:allocator:eventLoop:chunkHandler:)` should be preferred as it uses - /// `FileRegion` object instead of raw `NIOFileHandle`s. In case you do want to use raw `NIOFileHandle`s, - /// please consider using `readChunked(fileHandle:fromOffset:chunkSize:allocator:eventLoop:chunkHandler:)` - /// because it doesn't use the file descriptor's seek pointer (which may be shared with other file - /// descriptors and even across processes.) - /// - /// - parameters: - /// - fileHandle: The `NIOFileHandle` to read from. - /// - byteCount: The number of bytes to read from `fileHandle`. - /// - chunkSize: The size of the individual chunks to deliver. - /// - allocator: A `ByteBufferAllocator` used to allocate space for the chunks. - /// - eventLoop: The `EventLoop` to call `chunkHandler` on. - /// - chunkHandler: Called for every chunk read. The next chunk will be read upon successful completion of the returned `EventLoopFuture`. If the returned `EventLoopFuture` fails, the overall operation is aborted. - /// - returns: An `EventLoopFuture` which is the result of the overall operation. If either the reading of `fileHandle` or `chunkHandler` fails, the `EventLoopFuture` will fail too. If the reading of `fileHandle` as well as `chunkHandler` always succeeded, the `EventLoopFuture` will succeed too. - public func readChunked(fileHandle: NIOFileHandle, - byteCount: Int, - chunkSize: Int = NonBlockingFileIO.defaultChunkSize, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, chunkHandler: @escaping (ByteBuffer) -> EventLoopFuture) -> EventLoopFuture { - return self.readChunked0(fileHandle: fileHandle, - fromOffset: nil, - byteCount: byteCount, - chunkSize: chunkSize, - allocator: allocator, - eventLoop: eventLoop, - chunkHandler: chunkHandler) - } - #endif - #if swift(>=5.7) - /// Read `byteCount` bytes from offset `fileOffset` in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread + /// Read `byteCount` bytes from offset `fileOffset` in chunks of `chunkSize` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread /// pool which is separate from any `EventLoop` thread. /// /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `byteCount` is greater than @@ -246,53 +168,8 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop, chunkHandler: chunkHandler) } - #else - /// Read `byteCount` bytes from offset `fileOffset` in chunks of `chunkSize` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread - /// pool which is separate from any `EventLoop` thread. - /// - /// `chunkHandler` will be called on `eventLoop` for every chunk that was read. Assuming `byteCount` is greater than - /// zero and there are enough bytes available `chunkHandler` will be called `1 + |_ byteCount / chunkSize _|` - /// times, delivering `chunkSize` bytes each time. If less than `byteCount` bytes can be read from `descriptor`, - /// `chunkHandler` will be called less often with the last invocation possibly being of less than `chunkSize` bytes. - /// - /// The allocation and reading of a subsequent chunk will only be attempted when `chunkHandler` succeeds. - /// - /// This method will not use the file descriptor's seek pointer which means there is no danger of reading from the - /// same `NIOFileHandle` in multiple threads. - /// - /// - note: `readChunked(fileRegion:chunkSize:allocator:eventLoop:chunkHandler:)` should be preferred as it uses - /// `FileRegion` object instead of raw `NIOFileHandle`s. - /// - /// - parameters: - /// - fileHandle: The `NIOFileHandle` to read from. - /// - byteCount: The number of bytes to read from `fileHandle`. - /// - chunkSize: The size of the individual chunks to deliver. - /// - allocator: A `ByteBufferAllocator` used to allocate space for the chunks. - /// - eventLoop: The `EventLoop` to call `chunkHandler` on. - /// - chunkHandler: Called for every chunk read. The next chunk will be read upon successful completion of the returned `EventLoopFuture`. If the returned `EventLoopFuture` fails, the overall operation is aborted. - /// - returns: An `EventLoopFuture` which is the result of the overall operation. If either the reading of `fileHandle` or `chunkHandler` fails, the `EventLoopFuture` will fail too. If the reading of `fileHandle` as well as `chunkHandler` always succeeded, the `EventLoopFuture` will succeed too. - public func readChunked(fileHandle: NIOFileHandle, - fromOffset fileOffset: Int64, - byteCount: Int, - chunkSize: Int = NonBlockingFileIO.defaultChunkSize, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, - chunkHandler: @escaping (ByteBuffer) -> EventLoopFuture) -> EventLoopFuture { - return self.readChunked0(fileHandle: fileHandle, - fromOffset: fileOffset, - byteCount: byteCount, - chunkSize: chunkSize, - allocator: allocator, - eventLoop: eventLoop, - chunkHandler: chunkHandler) - } - #endif - - #if swift(>=5.7) + private typealias ReadChunkHandler = @Sendable (ByteBuffer) -> EventLoopFuture - #else - private typealias ReadChunkHandler = (ByteBuffer) -> EventLoopFuture - #endif private func readChunked0(fileHandle: NIOFileHandle, fromOffset: Int64?, @@ -303,7 +180,7 @@ public struct NonBlockingFileIO: Sendable { precondition(chunkSize > 0, "chunkSize must be > 0 (is \(chunkSize))") let remainingReads = 1 + (byteCount / chunkSize) let lastReadSize = byteCount % chunkSize - + let promise = eventLoop.makePromise(of: Void.self) func _read(remainingReads: Int, bytesReadSoFar: Int64) { @@ -325,7 +202,7 @@ public struct NonBlockingFileIO: Sendable { return } let bytesRead = Int64(buffer.readableBytes) - chunkHandler(buffer).whenComplete { result in + chunkHandler(buffer).hop(to: eventLoop).whenComplete { result in switch result { case .success(_): eventLoop.assertInEventLoop() @@ -348,7 +225,7 @@ public struct NonBlockingFileIO: Sendable { return promise.futureResult } - /// Read a `FileRegion` in `NonBlockingFileIO`'s private thread pool which is separate from any `EventLoop` thread. + /// Read a `FileRegion` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. /// /// The returned `ByteBuffer` will not have less than `fileRegion.readableBytes` unless we hit end-of-file in which /// case the `ByteBuffer` will contain the bytes available to read. @@ -373,15 +250,15 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop) } - /// Read `byteCount` bytes from `fileHandle` in `NonBlockingFileIO`'s private thread pool which is separate from any `EventLoop` thread. + /// Read `byteCount` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. /// /// The returned `ByteBuffer` will not have less than `byteCount` bytes unless we hit end-of-file in which /// case the `ByteBuffer` will contain the bytes available to read. /// /// - note: Only use this function for small enough `byteCount`s as it will need to allocate enough memory to hold `byteCount` bytes. - /// - note: `read(fileRegion:allocator:eventLoop:)` should be preferred as it uses `FileRegion` object instead of + /// - note: ``read(fileRegion:allocator:eventLoop:)`` should be preferred as it uses `FileRegion` object instead of /// raw `NIOFileHandle`s. In case you do want to use raw `NIOFileHandle`s, - /// please consider using `read(fileHandle:fromOffset:byteCount:allocator:eventLoop:)` + /// please consider using ``read(fileHandle:fromOffset:byteCount:allocator:eventLoop:)`` /// because it doesn't use the file descriptor's seek pointer (which may be shared with other file /// descriptors and even across processes.) /// @@ -402,7 +279,7 @@ public struct NonBlockingFileIO: Sendable { eventLoop: eventLoop) } - /// Read `byteCount` bytes starting at `fileOffset` from `fileHandle` in `NonBlockingFileIO`'s private thread pool + /// Read `byteCount` bytes starting at `fileOffset` from `fileHandle` in ``NonBlockingFileIO``'s private thread pool /// which is separate from any `EventLoop` thread. /// /// The returned `ByteBuffer` will not have less than `byteCount` bytes unless we hit end-of-file in which @@ -412,7 +289,7 @@ public struct NonBlockingFileIO: Sendable { /// same `fileHandle` in multiple threads. /// /// - note: Only use this function for small enough `byteCount`s as it will need to allocate enough memory to hold `byteCount` bytes. - /// - note: `read(fileRegion:allocator:eventLoop:)` should be preferred as it uses `FileRegion` object instead of raw `NIOFileHandle`s. + /// - note: ``read(fileRegion:allocator:eventLoop:)`` should be preferred as it uses `FileRegion` object instead of raw `NIOFileHandle`s. /// /// - parameters: /// - fileHandle: The `NIOFileHandle` to read. @@ -443,40 +320,49 @@ public struct NonBlockingFileIO: Sendable { } let byteCount = rawByteCount < Int32.max ? rawByteCount : size_t(Int32.max) - var buf = allocator.buffer(capacity: byteCount) return self.threadPool.runIfActive(eventLoop: eventLoop) { () -> ByteBuffer in - var bytesRead = 0 - while bytesRead < byteCount { - let n = try buf.writeWithUnsafeMutableBytes(minimumWritableBytes: byteCount - bytesRead) { ptr in - let res = try fileHandle.withUnsafeFileDescriptor { descriptor -> IOResult in - if let offset = fromOffset { - return try Posix.pread(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - bytesRead, - offset: off_t(offset) + off_t(bytesRead)) - } else { - return try Posix.read(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - bytesRead) - } - } - switch res { - case .processed(let n): - assert(n >= 0, "read claims to have read a negative number of bytes \(n)") - return n - case .wouldBlock: - throw Error.descriptorSetToNonBlocking + try self.readSync(fileHandle: fileHandle, fromOffset: fromOffset, byteCount: byteCount, allocator: allocator) + } + } + + private func readSync( + fileHandle: NIOFileHandle, + fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems + byteCount: Int, + allocator: ByteBufferAllocator + ) throws -> ByteBuffer { + var bytesRead = 0 + var buf = allocator.buffer(capacity: byteCount) + while bytesRead < byteCount { + let n = try buf.writeWithUnsafeMutableBytes(minimumWritableBytes: byteCount - bytesRead) { ptr in + let res = try fileHandle.withUnsafeFileDescriptor { descriptor -> IOResult in + if let offset = fromOffset { + return try Posix.pread(descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - bytesRead, + offset: off_t(offset) + off_t(bytesRead)) + } else { + return try Posix.read(descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - bytesRead) } } - if n == 0 { - // EOF - break - } else { - bytesRead += n + switch res { + case .processed(let n): + assert(n >= 0, "read claims to have read a negative number of bytes \(n)") + return n + case .wouldBlock: + throw Error.descriptorSetToNonBlocking } } - return buf + if n == 0 { + // EOF + break + } else { + bytesRead += n + } } + return buf } /// Changes the file size of `fileHandle` to `size`. @@ -499,12 +385,12 @@ public struct NonBlockingFileIO: Sendable { } } - /// Returns the length of the file associated with `fileHandle`. + /// Returns the length of the file in bytes associated with `fileHandle`. /// /// - parameters: /// - fileHandle: The `NIOFileHandle` to read from. /// - eventLoop: The `EventLoop` to create the returned `EventLoopFuture` from. - /// - returns: An `EventLoopFuture` which is fulfilled if the write was successful or fails on error. + /// - returns: An `EventLoopFuture` which is fulfilled with the length of the file in bytes if the write was successful or fails on error. public func readFileSize(fileHandle: NIOFileHandle, eventLoop: EventLoop) -> EventLoopFuture { return self.threadPool.runIfActive(eventLoop: eventLoop) { @@ -517,7 +403,7 @@ public struct NonBlockingFileIO: Sendable { } } - /// Write `buffer` to `fileHandle` in `NonBlockingFileIO`'s private thread pool which is separate from any `EventLoop` thread. + /// Write `buffer` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. /// /// - parameters: /// - fileHandle: The `NIOFileHandle` to write to. @@ -530,7 +416,7 @@ public struct NonBlockingFileIO: Sendable { return self.write0(fileHandle: fileHandle, toOffset: nil, buffer: buffer, eventLoop: eventLoop) } - /// Write `buffer` starting from `toOffset` to `fileHandle` in `NonBlockingFileIO`'s private thread pool which is separate from any `EventLoop` thread. + /// Write `buffer` starting from `toOffset` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. /// /// - parameters: /// - fileHandle: The `NIOFileHandle` to write to. @@ -556,35 +442,44 @@ public struct NonBlockingFileIO: Sendable { } return self.threadPool.runIfActive(eventLoop: eventLoop) { - var buf = buffer - - var offsetAccumulator: Int = 0 - repeat { - let n = try buf.readWithUnsafeReadableBytes { ptr in - precondition(ptr.count == byteCount - offsetAccumulator) - let res: IOResult = try fileHandle.withUnsafeFileDescriptor { descriptor in - if let toOffset = toOffset { - return try Posix.pwrite(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - offsetAccumulator, - offset: off_t(toOffset + Int64(offsetAccumulator))) - } else { - return try Posix.write(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - offsetAccumulator) - } - } - switch res { - case .processed(let n): - assert(n >= 0, "write claims to have written a negative number of bytes \(n)") - return n - case .wouldBlock: - throw Error.descriptorSetToNonBlocking + try self.writeSync(fileHandle: fileHandle, byteCount: byteCount, toOffset: toOffset, buffer: buffer) + } + } + + private func writeSync( + fileHandle: NIOFileHandle, + byteCount: Int, + toOffset: Int64?, + buffer: ByteBuffer + ) throws { + var buf = buffer + + var offsetAccumulator: Int = 0 + repeat { + let n = try buf.readWithUnsafeReadableBytes { ptr in + precondition(ptr.count == byteCount - offsetAccumulator) + let res: IOResult = try fileHandle.withUnsafeFileDescriptor { descriptor in + if let toOffset = toOffset { + return try Posix.pwrite(descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - offsetAccumulator, + offset: off_t(toOffset + Int64(offsetAccumulator))) + } else { + return try Posix.write(descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - offsetAccumulator) } } - offsetAccumulator += n - } while offsetAccumulator < byteCount - } + switch res { + case .processed(let n): + assert(n >= 0, "write claims to have written a negative number of bytes \(n)") + return n + case .wouldBlock: + throw Error.descriptorSetToNonBlocking + } + } + offsetAccumulator += n + } while offsetAccumulator < byteCount } /// Open the file at `path` for reading on a private thread pool which is separate from any `EventLoop` thread. @@ -835,3 +730,377 @@ public struct NIODirectoryEntry: Hashable { } } #endif + +extension NonBlockingFileIO { + /// Read a `FileRegion` in ``NonBlockingFileIO``'s private thread pool. + /// + /// The returned `ByteBuffer` will not have less than the minimum of `fileRegion.readableBytes` and `UInt32.max` unless we hit + /// end-of-file in which case the `ByteBuffer` will contain the bytes available to read. + /// + /// This method will not use the file descriptor's seek pointer which means there is no danger of reading from the + /// same `FileRegion` in multiple threads. + /// + /// - note: Only use this function for small enough `FileRegion`s as it will need to allocate enough memory to hold `fileRegion.readableBytes` bytes. + /// - note: In most cases you should prefer one of the `readChunked` functions. + /// + /// - parameters: + /// - fileRegion: The file region to read. + /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. + /// - returns: ByteBuffer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func read(fileRegion: FileRegion, allocator: ByteBufferAllocator) async throws -> ByteBuffer { + let readableBytes = fileRegion.readableBytes + return try await self.read( + fileHandle: fileRegion.fileHandle, + fromOffset: Int64(fileRegion.readerIndex), + byteCount: readableBytes, + allocator: allocator + ) + } + + /// Read `byteCount` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread pool. + /// + /// The returned `ByteBuffer` will not have less than `byteCount` bytes unless we hit end-of-file in which + /// case the `ByteBuffer` will contain the bytes available to read. + /// + /// - note: Only use this function for small enough `byteCount`s as it will need to allocate enough memory to hold `byteCount` bytes. + /// - note: ``read(fileRegion:allocator:eventLoop:)`` should be preferred as it uses `FileRegion` object instead of + /// raw `NIOFileHandle`s. In case you do want to use raw `NIOFileHandle`s, + /// please consider using ``read(fileHandle:fromOffset:byteCount:allocator:eventLoop:)`` + /// because it doesn't use the file descriptor's seek pointer (which may be shared with other file + /// descriptors and even across processes.) + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to read. + /// - byteCount: The number of bytes to read from `fileHandle`. + /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. + /// - returns: ByteBuffer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func read( + fileHandle: NIOFileHandle, + byteCount: Int, + allocator: ByteBufferAllocator + ) async throws-> ByteBuffer { + return try await self.read0( + fileHandle: fileHandle, + fromOffset: nil, + byteCount: byteCount, + allocator: allocator + ) + } + + /// Read `byteCount` bytes starting at `fileOffset` from `fileHandle` in ``NonBlockingFileIO``'s private thread pool + ///. + /// + /// The returned `ByteBuffer` will not have less than `byteCount` bytes unless we hit end-of-file in which + /// case the `ByteBuffer` will contain the bytes available to read. + /// + /// This method will not use the file descriptor's seek pointer which means there is no danger of reading from the + /// same `fileHandle` in multiple threads. + /// + /// - note: Only use this function for small enough `byteCount`s as it will need to allocate enough memory to hold `byteCount` bytes. + /// - note: ``read(fileRegion:allocator:eventLoop:)`` should be preferred as it uses `FileRegion` object instead of raw `NIOFileHandle`s. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to read. + /// - fileOffset: The offset to read from. + /// - byteCount: The number of bytes to read from `fileHandle`. + /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. + /// - returns: ByteBuffer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func read(fileHandle: NIOFileHandle, + fromOffset fileOffset: Int64, + byteCount: Int, + allocator: ByteBufferAllocator) async throws -> ByteBuffer { + return try await self.read0(fileHandle: fileHandle, + fromOffset: fileOffset, + byteCount: byteCount, + allocator: allocator) + } + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + private func read0(fileHandle: NIOFileHandle, + fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems + byteCount rawByteCount: Int, + allocator: ByteBufferAllocator) async throws -> ByteBuffer { + guard rawByteCount > 0 else { + return allocator.buffer(capacity: 0) + } + let byteCount = rawByteCount < Int32.max ? rawByteCount : size_t(Int32.max) + + return try await self.threadPool.runIfActive { () -> ByteBuffer in + try self.readSync(fileHandle: fileHandle, fromOffset: fromOffset, byteCount: byteCount, allocator: allocator) + } + } + + /// Changes the file size of `fileHandle` to `size`. + /// + /// If `size` is smaller than the current file size, the remaining bytes will be truncated and are lost. If `size` + /// is larger than the current file size, the gap will be filled with zero bytes. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to write to. + /// - size: The new file size in bytes to write. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func changeFileSize(fileHandle: NIOFileHandle, + size: Int64) async throws { + return try await self.threadPool.runIfActive { + try fileHandle.withUnsafeFileDescriptor { descriptor -> Void in + try Posix.ftruncate(descriptor: descriptor, size: off_t(size)) + } + } + } + + /// Returns the length of the file associated with `fileHandle`. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to read from. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func readFileSize(fileHandle: NIOFileHandle) async throws -> Int64 { + return try await self.threadPool.runIfActive { + return try fileHandle.withUnsafeFileDescriptor { descriptor in + let curr = try Posix.lseek(descriptor: descriptor, offset: 0, whence: SEEK_CUR) + let eof = try Posix.lseek(descriptor: descriptor, offset: 0, whence: SEEK_END) + try Posix.lseek(descriptor: descriptor, offset: curr, whence: SEEK_SET) + return Int64(eof) + } + } + } + + /// Write `buffer` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to write to. + /// - buffer: The `ByteBuffer` to write. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func write(fileHandle: NIOFileHandle, + buffer: ByteBuffer) async throws { + return try await self.write0(fileHandle: fileHandle, toOffset: nil, buffer: buffer) + } + + /// Write `buffer` starting from `toOffset` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool. + /// + /// - parameters: + /// - fileHandle: The `NIOFileHandle` to write to. + /// - toOffset: The file offset to write to. + /// - buffer: The `ByteBuffer` to write. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func write(fileHandle: NIOFileHandle, + toOffset: Int64, + buffer: ByteBuffer) async throws { + return try await self.write0(fileHandle: fileHandle, toOffset: toOffset, buffer: buffer) + } + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + private func write0(fileHandle: NIOFileHandle, + toOffset: Int64?, + buffer: ByteBuffer) async throws { + let byteCount = buffer.readableBytes + + guard byteCount > 0 else { + return + } + + return try await self.threadPool.runIfActive { + try self.writeSync(fileHandle: fileHandle, byteCount: byteCount, toOffset: toOffset, buffer: buffer) + } + } + + /// Open file at `path` and query its size on a private thread pool, run an operation given + /// the resulting file region and then close the file handle. + /// + /// The will return the result of the operation. + /// + /// - note: This function opens a file and queries it size which are both blocking operations + /// + /// - parameters: + /// - path: The path of the file to be opened for reading. + /// - body: operation to run with file handle and region + /// - returns: return value of operation + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withFileRegion( + path: String, + _ body: (_ fileRegion: FileRegion) async throws -> Result + ) async throws -> Result { + let fileRegion = try await self.threadPool.runIfActive { + let fh = try NIOFileHandle(path: path) + do { + let fr = try FileRegion(fileHandle: fh) + return UnsafeTransfer(fr) + } catch { + _ = try? fh.close() + throw error + } + } + let result: Result + do { + result = try await body(fileRegion.wrappedValue) + } catch { + try fileRegion.wrappedValue.fileHandle.close() + throw error + } + try fileRegion.wrappedValue.fileHandle.close() + return result + } + + /// Open file at `path` on a private thread pool, run an operation given the file handle and then close the file handle. + /// + /// This function will return the result of the operation. + /// + /// - parameters: + /// - path: The path of the file to be opened for writing. + /// - mode: File access mode. + /// - flags: Additional POSIX flags. + /// - returns: return value of operation + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withFileHandle( + path: String, + mode: NIOFileHandle.Mode, + flags: NIOFileHandle.Flags = .default, + _ body: (NIOFileHandle) async throws -> Result + ) async throws -> Result { + let fileHandle = try await self.threadPool.runIfActive { + return try UnsafeTransfer(NIOFileHandle(path: path, mode: mode, flags: flags)) + } + let result: Result + do { + result = try await body(fileHandle.wrappedValue) + } catch { + try fileHandle.wrappedValue.close() + throw error + } + try fileHandle.wrappedValue.close() + return result + } + +#if !os(Windows) + + /// Returns information about a file at `path` on a private thread pool. + /// + /// - note: If `path` is a symlink, information about the link, not the file it points to. + /// + /// - parameters: + /// - path: The path of the file to get information about. + /// - returns: file information. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func lstat(path: String) async throws -> stat { + return try await self.threadPool.runIfActive { + var s = stat() + try Posix.lstat(pathname: path, outStat: &s) + return s + } + } + + /// Creates a symbolic link to a `destination` file at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the link. + /// - destination: Target path where this link will point to. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func symlink(path: String, to destination: String) async throws { + return try await self.threadPool.runIfActive { + try Posix.symlink(pathname: path, destination: destination) + } + } + + /// Returns target of the symbolic link at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the link to read. + /// - returns: link target. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func readlink(path: String) async throws -> String { + return try await self.threadPool.runIfActive { + let maxLength = Int(PATH_MAX) + let pointer = UnsafeMutableBufferPointer.allocate(capacity: maxLength) + defer { + pointer.deallocate() + } + let length = try Posix.readlink(pathname: path, outPath: pointer.baseAddress!, outPathSize: maxLength) + return String(decoding: UnsafeRawBufferPointer(pointer).prefix(length), as: UTF8.self) + } + } + + /// Removes symbolic link at `path` on a private thread pool which is separate from any `EventLoop` thread. + /// + /// - parameters: + /// - path: The path of the link to remove. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func unlink(path: String) async throws { + return try await self.threadPool.runIfActive { + try Posix.unlink(pathname: path) + } + } + + + /// Creates directory at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the directory to be created. + /// - withIntermediateDirectories: Whether intermediate directories should be created. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func createDirectory(path: String, withIntermediateDirectories createIntermediates: Bool = false, mode: NIOPOSIXFileMode) async throws { + return try await self.threadPool.runIfActive { + if createIntermediates { + #if canImport(Darwin) + try Posix.mkpath_np(pathname: path, mode: mode) + #else + try self.createDirectory0(path, mode: mode) + #endif + } else { + try Posix.mkdir(pathname: path, mode: mode) + } + } + } + + /// List contents of the directory at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the directory to list the content of. + /// - returns: The directory entries. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func listDirectory(path: String) async throws -> [NIODirectoryEntry] { + return try await self.threadPool.runIfActive { + let dir = try Posix.opendir(pathname: path) + var entries: [NIODirectoryEntry] = [] + do { + while let entry = try Posix.readdir(dir: dir) { + let name = withUnsafeBytes(of: entry.pointee.d_name) { pointer -> String in + let ptr = pointer.baseAddress!.assumingMemoryBound(to: CChar.self) + return String(cString: ptr) + } + entries.append(NIODirectoryEntry(ino: UInt64(entry.pointee.d_ino), type: entry.pointee.d_type, name: name)) + } + try? Posix.closedir(dir: dir) + } catch { + try? Posix.closedir(dir: dir) + throw error + } + return entries + } + } + + /// Renames the file at `path` to `newName` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the file to be renamed. + /// - newName: New file name. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func rename(path: String, newName: String) async throws { + return try await self.threadPool.runIfActive() { + try Posix.rename(pathname: path, newName: newName) + } + } + + /// Removes the file at `path` on a private thread pool. + /// + /// - parameters: + /// - path: The path of the file to be removed. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func remove(path: String) async throws { + return try await self.threadPool.runIfActive() { + try Posix.remove(pathname: path) + } + } +#endif +} \ No newline at end of file diff --git a/Sources/NIOPosix/PipeChannel.swift b/Sources/NIOPosix/PipeChannel.swift index 053f5f14a4..069fdfcc40 100644 --- a/Sources/NIOPosix/PipeChannel.swift +++ b/Sources/NIOPosix/PipeChannel.swift @@ -21,14 +21,18 @@ final class PipeChannel: BaseStreamSocketChannel { case output } - init(eventLoop: SelectableEventLoop, - inputPipe: NIOFileHandle, - outputPipe: NIOFileHandle) throws { + init( + eventLoop: SelectableEventLoop, + inputPipe: NIOFileHandle?, + outputPipe: NIOFileHandle? + ) throws { self.pipePair = try PipePair(inputFD: inputPipe, outputFD: outputPipe) - try super.init(socket: self.pipePair, - parent: nil, - eventLoop: eventLoop, - recvAllocator: AdaptiveRecvByteBufferAllocator()) + try super.init( + socket: self.pipePair, + parent: nil, + eventLoop: eventLoop, + recvAllocator: AdaptiveRecvByteBufferAllocator() + ) } func registrationForInput(interested: SelectorEventSet, registrationID: SelectorRegistrationID) -> NIORegistration { @@ -56,50 +60,62 @@ final class PipeChannel: BaseStreamSocketChannel { } override func register(selector: Selector, interested: SelectorEventSet) throws { - try selector.register(selectable: self.pipePair.inputFD, - interested: interested.intersection([.read, .reset]), - makeRegistration: self.registrationForInput) - try selector.register(selectable: self.pipePair.outputFD, - interested: interested.intersection([.write, .reset]), - makeRegistration: self.registrationForOutput) + if let inputFD = self.pipePair.inputFD { + try selector.register( + selectable: inputFD, + interested: interested.intersection([.read, .reset]), + makeRegistration: self.registrationForInput + ) + } + if let outputFD = self.pipePair.outputFD { + try selector.register( + selectable: outputFD, + interested: interested.intersection([.write, .reset]), + makeRegistration: self.registrationForOutput + ) + } } override func deregister(selector: Selector, mode: CloseMode) throws { - if (mode == .all || mode == .input) && self.pipePair.inputFD.isOpen { - try selector.deregister(selectable: self.pipePair.inputFD) + if let inputFD = self.pipePair.inputFD, (mode == .all || mode == .input) && inputFD.isOpen { + try selector.deregister(selectable: inputFD) } - if (mode == .all || mode == .output) && self.pipePair.outputFD.isOpen { - try selector.deregister(selectable: self.pipePair.outputFD) + if let outputFD = self.pipePair.outputFD, (mode == .all || mode == .output) && outputFD.isOpen { + try selector.deregister(selectable: outputFD) } } override func reregister(selector: Selector, interested: SelectorEventSet) throws { - if self.pipePair.inputFD.isOpen { - try selector.reregister(selectable: self.pipePair.inputFD, - interested: interested.intersection([.read, .reset])) + if let inputFD = self.pipePair.inputFD, inputFD.isOpen { + try selector.reregister( + selectable: inputFD, + interested: interested.intersection([.read, .reset]) + ) } - if self.pipePair.outputFD.isOpen { - try selector.reregister(selectable: self.pipePair.outputFD, - interested: interested.intersection([.write, .reset])) + if let outputFD = self.pipePair.outputFD, outputFD.isOpen { + try selector.reregister( + selectable: outputFD, + interested: interested.intersection([.write, .reset]) + ) } } override func readEOF() { super.readEOF() - guard self.pipePair.inputFD.isOpen else { + guard let inputFD = self.pipePair.inputFD, inputFD.isOpen else { return } try! self.selectableEventLoop.deregister(channel: self, mode: .input) - try! self.pipePair.inputFD.close() + try! inputFD.close() } override func writeEOF() { - guard self.pipePair.outputFD.isOpen else { + guard let outputFD = self.pipePair.outputFD, outputFD.isOpen else { return } try! self.selectableEventLoop.deregister(channel: self, mode: .output) - try! self.pipePair.outputFD.close() + try! outputFD.close() } override func shutdownSocket(mode: CloseMode) throws { diff --git a/Sources/NIOPosix/PipePair.swift b/Sources/NIOPosix/PipePair.swift index 921fdab3ce..cb9a92d4b3 100644 --- a/Sources/NIOPosix/PipePair.swift +++ b/Sources/NIOPosix/PipePair.swift @@ -38,14 +38,14 @@ extension SelectableFileHandle: Selectable { final class PipePair: SocketProtocol { typealias SelectableType = SelectableFileHandle - let inputFD: SelectableFileHandle - let outputFD: SelectableFileHandle + let inputFD: SelectableFileHandle? + let outputFD: SelectableFileHandle? - init(inputFD: NIOFileHandle, outputFD: NIOFileHandle) throws { - self.inputFD = SelectableFileHandle(inputFD) - self.outputFD = SelectableFileHandle(outputFD) + init(inputFD: NIOFileHandle?, outputFD: NIOFileHandle?) throws { + self.inputFD = inputFD.flatMap { SelectableFileHandle($0) } + self.outputFD = outputFD.flatMap { SelectableFileHandle($0) } try self.ignoreSIGPIPE() - for fileHandle in [inputFD, outputFD] { + for fileHandle in [inputFD, outputFD].compactMap({ $0 }) { try fileHandle.withUnsafeFileDescriptor { try NIOFileHandle.setNonBlocking(fileDescriptor: $0) } @@ -53,7 +53,7 @@ final class PipePair: SocketProtocol { } func ignoreSIGPIPE() throws { - for fileHandle in [self.inputFD, self.outputFD] { + for fileHandle in [self.inputFD, self.outputFD].compactMap({ $0 }) { try fileHandle.withUnsafeHandle { try PipePair.ignoreSIGPIPE(descriptor: $0) } @@ -61,7 +61,7 @@ final class PipePair: SocketProtocol { } var description: String { - return "PipePair { in=\(inputFD), out=\(outputFD) }" + return "PipePair { in=\(String(describing: inputFD)), out=\(String(describing: inputFD)) }" } func connect(to address: SocketAddress) throws -> Bool { @@ -73,19 +73,28 @@ final class PipePair: SocketProtocol { } func write(pointer: UnsafeRawBufferPointer) throws -> IOResult { - return try self.outputFD.withUnsafeHandle { + guard let outputFD = self.outputFD else { + fatalError("Internal inconsistency inside NIO. Please file a bug") + } + return try outputFD.withUnsafeHandle { try Posix.write(descriptor: $0, pointer: pointer.baseAddress!, size: pointer.count) } } func writev(iovecs: UnsafeBufferPointer) throws -> IOResult { - return try self.outputFD.withUnsafeHandle { + guard let outputFD = self.outputFD else { + fatalError("Internal inconsistency inside NIO. Please file a bug") + } + return try outputFD.withUnsafeHandle { try Posix.writev(descriptor: $0, iovecs: iovecs) } } func read(pointer: UnsafeMutableRawBufferPointer) throws -> IOResult { - return try self.inputFD.withUnsafeHandle { + guard let inputFD = self.inputFD else { + fatalError("Internal inconsistency inside NIO. Please file a bug") + } + return try inputFD.withUnsafeHandle { try Posix.read(descriptor: $0, pointer: pointer.baseAddress!, size: pointer.count) } } @@ -119,30 +128,30 @@ final class PipePair: SocketProtocol { func shutdown(how: Shutdown) throws { switch how { case .RD: - try self.inputFD.close() + try self.inputFD?.close() case .WR: - try self.outputFD.close() + try self.outputFD?.close() case .RDWR: try self.close() } } var isOpen: Bool { - return self.inputFD.isOpen || self.outputFD.isOpen + return self.inputFD?.isOpen ?? false || self.outputFD?.isOpen ?? false } func close() throws { - guard self.inputFD.isOpen || self.outputFD.isOpen else { + guard self.isOpen else { throw ChannelError.alreadyClosed } let r1 = Result { - if self.inputFD.isOpen { - try self.inputFD.close() + if let inputFD = self.inputFD, inputFD.isOpen { + try inputFD.close() } } let r2 = Result { - if self.outputFD.isOpen { - try self.outputFD.close() + if let outputFD = self.outputFD, outputFD.isOpen { + try outputFD.close() } } try r1.get() diff --git a/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift b/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift new file mode 100644 index 0000000000..ad50430c9f --- /dev/null +++ b/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift @@ -0,0 +1,123 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Atomics +import NIOCore + +#if compiler(>=5.9) +private protocol SilenceWarning { + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + func enqueue(_ job: UnownedJob) +} +@available(macOS 14, *) +extension SelectableEventLoop: SilenceWarning {} +#endif + +private let _haveWeTakenOverTheConcurrencyPool = ManagedAtomic(false) +extension NIOSingletons { + /// Install ``MultiThreadedEventLoopGroup/singleton`` as Swift Concurrency's global executor. + /// + /// This allows to use Swift Concurrency and retain high-performance I/O alleviating the otherwise necessary thread switches between + /// Swift Concurrency's own global pool and a place (like an `EventLoop`) that allows to perform I/O + /// + /// This method uses an atomic compare and exchange to install the hook (and makes sure it's not already set). This unilateral atomic memory + /// operation doesn't guarantee anything because another piece of code could have done the same without using atomic operations. But we + /// do our best. + /// + /// - warning: You may only call this method from the main thread. + /// - warning: You may only call this method once. + @discardableResult + public static func unsafeTryInstallSingletonPosixEventLoopGroupAsConcurrencyGlobalExecutor() -> Bool { + #if /* minimum supported */ compiler(>=5.9) && /* maximum tested */ swift(<5.11) + guard #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) else { + return false + } + + typealias ConcurrencyEnqueueGlobalHook = @convention(thin) ( + UnownedJob, @convention(thin) (UnownedJob) -> Void + ) -> Void + + guard + _haveWeTakenOverTheConcurrencyPool.compareExchange( + expected: false, + desired: true, + ordering: .relaxed + ).exchanged + else { + fatalError("Must be called only once") + } + + #if canImport(Darwin) + guard pthread_main_np() == 1 else { + fatalError("Must be called from the main thread") + } + #endif + + // Unsafe 1: We pretend that the hook's type is actually fully equivalent to `ConcurrencyEnqueueGlobalHook` + // @convention(thin) (UnownedJob, @convention(thin) (UnownedJob) -> Void) -> Void + // which isn't formally guaranteed. + let concurrencyEnqueueGlobalHookPtr = dlsym( + dlopen(nil, RTLD_NOW), + "swift_task_enqueueGlobal_hook" + )?.assumingMemoryBound(to: UnsafeRawPointer?.AtomicRepresentation.self) + guard let concurrencyEnqueueGlobalHookPtr = concurrencyEnqueueGlobalHookPtr else { + return false + } + + // We will use an atomic operation to swap the pointers aiming to protect against other code that attempts + // to swap the pointer. This isn't guaranteed to work as we can't be sure that the other code will actually + // use atomic compares and exchanges to. Nevertheless, we're doing our best. + let concurrencyEnqueueGlobalHookAtomic = UnsafeAtomic(at: concurrencyEnqueueGlobalHookPtr) + // note: We don't need to destroy this atomic as we're borrowing the storage from `dlsym`. + + return withUnsafeTemporaryAllocation( + of: ConcurrencyEnqueueGlobalHook.self, + capacity: 1 + ) { enqueueOnNIOPtr -> Bool in + // Unsafe 2: We mandate that we're actually getting _the_ function pointer to the closure below which + // isn't formally guaranteed by Swift. + enqueueOnNIOPtr.baseAddress!.initialize(to: { job, _ in + // This formally picks a random EventLoop from the singleton group. However, `EventLoopGroup.any()` + // attempts to be sticky. So if we're already in an `EventLoop` that's part of the singleton + // `EventLoopGroup`, we'll get that one and be very fast (avoid a thread switch). + let targetEL = MultiThreadedEventLoopGroup.singleton.any() + + (targetEL.executor as! any SilenceWarning).enqueue(job) + }) + + // Unsafe 3: We mandate that the function pointer can be reinterpreted as `UnsafeRawPointer` which isn't + // formally guaranteed by Swift + return enqueueOnNIOPtr.baseAddress!.withMemoryRebound( + to: UnsafeRawPointer.self, + capacity: 1 + ) { enqueueOnNIOPtr in + // Unsafe 4: We just pretend that we're the only ones in the world pulling this trick (or at least + // that the others also use a `compareExchange`)... + guard concurrencyEnqueueGlobalHookAtomic.compareExchange( + expected: nil, + desired: enqueueOnNIOPtr.pointee, + ordering: .relaxed + ).exchanged else { + return false + } + + // nice, everything worked. + return true + } + } + #else + return false + #endif + } +} diff --git a/Sources/NIOPosix/RawSocketBootstrap.swift b/Sources/NIOPosix/RawSocketBootstrap.swift index 9712bd3a9b..9847f17e1f 100644 --- a/Sources/NIOPosix/RawSocketBootstrap.swift +++ b/Sources/NIOPosix/RawSocketBootstrap.swift @@ -11,7 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOCore +import NIOCore /// A `RawSocketBootstrap` is an easy way to interact with IP based protocols other then TCP and UDP. /// @@ -204,7 +204,6 @@ extension NIORawSocketBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func bind( host: String, ipProtocol: NIOIPProtocol, @@ -227,7 +226,6 @@ extension NIORawSocketBootstrap { /// method. /// - Returns: The result of the channel initializer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) public func connect( host: String, ipProtocol: NIOIPProtocol, diff --git a/Sources/NIOPosix/SelectableEventLoop.swift b/Sources/NIOPosix/SelectableEventLoop.swift index 52339eccee..3d498a968f 100644 --- a/Sources/NIOPosix/SelectableEventLoop.swift +++ b/Sources/NIOPosix/SelectableEventLoop.swift @@ -299,6 +299,20 @@ Further information: }, .now())) } + #if compiler(>=5.9) + @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) + @usableFromInline + func enqueue(_ job: consuming ExecutorJob) { + let scheduledTask = ScheduledTask( + id: self.scheduledTaskCounter.loadThenWrappingIncrement(ordering: .relaxed), + job: job, + readyTime: .now() + ) + // nothing we can do if we fail enqueuing here. + try? self._schedule0(scheduledTask) + } + #endif + /// Add the `ScheduledTask` to be executed. @usableFromInline internal func _schedule0(_ task: ScheduledTask) throws { @@ -515,7 +529,19 @@ Further information: for task in self.tasksCopy { /* for macOS: in case any calls we make to Foundation put objects into an autoreleasepool */ withAutoReleasePool { - task.task() + switch task.task { + case .function(let function): + function() + + #if compiler(>=5.9) + case .unownedJob(let erasedUnownedJob): + if #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) { + erasedUnownedJob.unownedJob.runSynchronously(on: self.asUnownedSerialExecutor()) + } else { + fatalError("Tried to run an UnownedJob without runtime support") + } + #endif + } } } // Drop everything (but keep the capacity) so we can fill it again on the next iteration. diff --git a/Sources/NIOPosix/SocketProtocols.swift b/Sources/NIOPosix/SocketProtocols.swift index 62f1c46569..90cca51fad 100644 --- a/Sources/NIOPosix/SocketProtocols.swift +++ b/Sources/NIOPosix/SocketProtocols.swift @@ -73,7 +73,13 @@ protocol SocketProtocol: BaseSocketProtocol { // This is a lazily initialised global variable that when read for the first time, will ignore SIGPIPE. private let globallyIgnoredSIGPIPE: Bool = { /* no F_SETNOSIGPIPE on Linux :( */ + #if canImport(Glibc) _ = Glibc.signal(SIGPIPE, SIG_IGN) + #elseif canImport(Musl) + _ = Musl.signal(SIGPIPE, SIG_IGN) + #else + #error("Don't know which stdlib to use") + #endif return true }() #endif diff --git a/Sources/NIOPosix/System.swift b/Sources/NIOPosix/System.swift index 025e60f0ed..6fea9a56f8 100644 --- a/Sources/NIOPosix/System.swift +++ b/Sources/NIOPosix/System.swift @@ -90,7 +90,7 @@ func sysRecvFrom_wrapper(sockfd: CInt, buf: UnsafeMutableRawPointer, len: CLong, return recvfrom(sockfd, buf, len, flags, src_addr, addrlen) // src_addr is 'UnsafeMutablePointer', but it need to be 'UnsafePointer' } func sysWritev_wrapper(fd: CInt, iov: UnsafePointer?, iovcnt: CInt) -> CLong { - return CLong(writev(fd, iov, iovcnt)) // cast 'Int32' to 'CLong' + return CLong(writev(fd, iov!, iovcnt)) // cast 'Int32' to 'CLong' } private let sysWritev = sysWritev_wrapper #elseif !os(Windows) @@ -106,59 +106,40 @@ private let sysGetpeername: @convention(c) (CInt, UnsafeMutablePointer private let sysGetsockname: @convention(c) (CInt, UnsafeMutablePointer?, UnsafeMutablePointer?) -> CInt = getsockname #endif +#if os(Android) +private let sysIfNameToIndex: @convention(c) (UnsafePointer) -> CUnsignedInt = if_nametoindex +#else private let sysIfNameToIndex: @convention(c) (UnsafePointer?) -> CUnsignedInt = if_nametoindex +#endif #if !os(Windows) private let sysSocketpair: @convention(c) (CInt, CInt, CInt, UnsafeMutablePointer?) -> CInt = socketpair #endif -#if os(Linux) && !canImport(Musl) -private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer) -> CInt = fstat -private let sysStat: @convention(c) (UnsafePointer, UnsafeMutablePointer) -> CInt = stat -private let sysLstat: @convention(c) (UnsafePointer, UnsafeMutablePointer) -> CInt = lstat -private let sysSymlink: @convention(c) (UnsafePointer, UnsafePointer) -> CInt = symlink -private let sysReadlink: @convention(c) (UnsafePointer, UnsafeMutablePointer, Int) -> CLong = readlink -private let sysUnlink: @convention(c) (UnsafePointer) -> CInt = unlink -private let sysMkdir: @convention(c) (UnsafePointer, mode_t) -> CInt = mkdir -private let sysOpendir: @convention(c) (UnsafePointer) -> OpaquePointer? = opendir -private let sysReaddir: @convention(c) (OpaquePointer) -> UnsafeMutablePointer? = readdir -private let sysClosedir: @convention(c) (OpaquePointer) -> CInt = closedir -private let sysRename: @convention(c) (UnsafePointer?, UnsafePointer?) -> CInt = rename -private let sysRemove: @convention(c) (UnsafePointer?) -> CInt = remove -#elseif canImport(Darwin) || os(Android) -private let sysFstat: @convention(c) (CInt, UnsafeMutablePointer?) -> CInt = fstat -private let sysStat: @convention(c) (UnsafePointer?, UnsafeMutablePointer?) -> CInt = stat -private let sysLstat: @convention(c) (UnsafePointer?, UnsafeMutablePointer?) -> CInt = lstat -private let sysSymlink: @convention(c) (UnsafePointer?, UnsafePointer?) -> CInt = symlink -private let sysReadlink: @convention(c) (UnsafePointer?, UnsafeMutablePointer?, Int) -> CLong = readlink -private let sysUnlink: @convention(c) (UnsafePointer?) -> CInt = unlink -private let sysMkdir: @convention(c) (UnsafePointer?, mode_t) -> CInt = mkdir -#if os(Android) -private let sysOpendir: @convention(c) (UnsafePointer?) -> OpaquePointer? = opendir -private let sysReaddir: @convention(c) (OpaquePointer?) -> UnsafeMutablePointer? = readdir -private let sysClosedir: @convention(c) (OpaquePointer?) -> CInt = closedir -#else -private let sysMkpath: @convention(c) (UnsafePointer?, mode_t) -> CInt = mkpath_np -private let sysOpendir: @convention(c) (UnsafePointer?) -> UnsafeMutablePointer? = opendir -private let sysReaddir: @convention(c) (UnsafeMutablePointer?) -> UnsafeMutablePointer? = readdir -private let sysClosedir: @convention(c) (UnsafeMutablePointer?) -> CInt = closedir -#endif -private let sysRename: @convention(c) (UnsafePointer?, UnsafePointer?) -> CInt = rename -private let sysRemove: @convention(c) (UnsafePointer?) -> CInt = remove +#if os(Linux) || os(Android) || canImport(Darwin) +private let sysFstat = fstat +private let sysStat = stat +private let sysLstat = lstat +private let sysSymlink = symlink +private let sysReadlink = readlink +private let sysUnlink = unlink +private let sysMkdir = mkdir +private let sysOpendir = opendir +private let sysReaddir = readdir +private let sysClosedir = closedir +private let sysRename = rename +private let sysRemove = remove #endif #if os(Linux) || os(Android) -private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIOLinux_sendmmsg -private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIOLinux_recvmmsg +private let sysSendMmsg = CNIOLinux_sendmmsg +private let sysRecvMmsg = CNIOLinux_recvmmsg #elseif canImport(Darwin) private let sysKevent = kevent -private let sysSendMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt) -> CInt = CNIODarwin_sendmmsg -private let sysRecvMmsg: @convention(c) (CInt, UnsafeMutablePointer?, CUnsignedInt, CInt, UnsafeMutablePointer?) -> CInt = CNIODarwin_recvmmsg +private let sysMkpath = mkpath_np +private let sysSendMmsg = CNIODarwin_sendmmsg +private let sysRecvMmsg = CNIODarwin_recvmmsg #endif #if !os(Windows) -#if canImport(Musl) -private let sysIoctl: @convention(c) (CInt, CInt, UnsafeMutableRawPointer) -> CInt = ioctl -#else private let sysIoctl: @convention(c) (CInt, CUnsignedLong, UnsafeMutableRawPointer) -> CInt = ioctl -#endif // canImport(Musl) #endif // !os(Windows) private func isUnacceptableErrno(_ code: Int32) -> Bool { @@ -732,7 +713,7 @@ internal enum Posix { @inline(never) internal static func if_nametoindex(_ name: UnsafePointer?) throws -> CUnsignedInt { return try syscall(blocking: false) { - sysIfNameToIndex(name) + sysIfNameToIndex(name!) }.result } diff --git a/Sources/NIOPosix/Thread.swift b/Sources/NIOPosix/Thread.swift index be392dec4d..68b486c72a 100644 --- a/Sources/NIOPosix/Thread.swift +++ b/Sources/NIOPosix/Thread.swift @@ -17,7 +17,7 @@ import CNIOLinux #endif enum LowLevelThreadOperations { - + } protocol ThreadOps { @@ -185,8 +185,7 @@ public final class ThreadSpecificVariable { self.currentValue = value } - - #if swift(>=5.7) + /// The value for the current thread. @available(*, noasync, message: "threads can change between suspension points and therefore the thread specific value too") public var currentValue: Value? { @@ -197,18 +196,7 @@ public final class ThreadSpecificVariable { self.set(newValue) } } - #else - /// The value for the current thread. - public var currentValue: Value? { - get { - self.get() - } - set { - self.set(newValue) - } - } - #endif - + /// Get the current value for the calling thread. func get() -> Value? { guard let raw = self.key.get() else { return nil } @@ -218,7 +206,7 @@ public final class ThreadSpecificVariable { .takeUnretainedValue() .value.1 as! Value) } - + /// Set the current value for the calling threads. The `currentValue` for all other threads remains unchanged. func set(_ newValue: Value?) { if let raw = self.key.get() { diff --git a/Sources/NIOPosix/ThreadPosix.swift b/Sources/NIOPosix/ThreadPosix.swift index b6e0ed4a30..852f08f634 100644 --- a/Sources/NIOPosix/ThreadPosix.swift +++ b/Sources/NIOPosix/ThreadPosix.swift @@ -19,7 +19,11 @@ import CNIOLinux private let sys_pthread_getname_np = CNIOLinux_pthread_getname_np private let sys_pthread_setname_np = CNIOLinux_pthread_setname_np +#if os(Android) +private typealias ThreadDestructor = @convention(c) (UnsafeMutableRawPointer) -> UnsafeMutableRawPointer +#else private typealias ThreadDestructor = @convention(c) (UnsafeMutableRawPointer?) -> UnsafeMutableRawPointer? +#endif #elseif canImport(Darwin) private let sys_pthread_getname_np = pthread_getname_np // Emulate the same method signature as pthread_setname_np on Linux. @@ -111,7 +115,11 @@ enum ThreadOpsPosix: ThreadOps { body(NIOThread(handle: hThread, desiredName: name)) + #if os(Android) + return UnsafeMutableRawPointer(bitPattern: 0xdeadbee)! + #else return nil + #endif }, args: argv0) precondition(res == 0, "Unable to create thread: \(res)") diff --git a/Sources/NIOPosix/UnsafeTransfer.swift b/Sources/NIOPosix/UnsafeTransfer.swift new file mode 100644 index 0000000000..daef8dacc0 --- /dev/null +++ b/Sources/NIOPosix/UnsafeTransfer.swift @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2021-2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// ``UnsafeTransfer`` can be used to make non-`Sendable` values `Sendable`. +/// As the name implies, the usage of this is unsafe because it disables the sendable checking of the compiler. +/// It can be used similar to `@unsafe Sendable` but for values instead of types. +@usableFromInline +struct UnsafeTransfer { + @usableFromInline + var wrappedValue: Wrapped + + @inlinable + init(_ wrappedValue: Wrapped) { + self.wrappedValue = wrappedValue + } +} + +extension UnsafeTransfer: @unchecked Sendable {} + +extension UnsafeTransfer: Equatable where Wrapped: Equatable {} +extension UnsafeTransfer: Hashable where Wrapped: Hashable {} + diff --git a/Sources/NIOPosix/VsockAddress.swift b/Sources/NIOPosix/VsockAddress.swift index 87e93e91b2..8e15e893e2 100644 --- a/Sources/NIOPosix/VsockAddress.swift +++ b/Sources/NIOPosix/VsockAddress.swift @@ -220,12 +220,11 @@ extension VsockAddress.ContextID { let fd = socketFD #elseif os(Linux) || os(Android) let request = CNIOLinux_IOCTL_VM_SOCKETS_GET_LOCAL_CID - let fd = try! Posix.open(file: "/dev/vsock", oFlag: O_RDONLY | O_CLOEXEC) + let fd = try Posix.open(file: "/dev/vsock", oFlag: O_RDONLY | O_CLOEXEC) defer { try! Posix.close(descriptor: fd) } #endif var cid = Self.any.rawValue try Posix.ioctl(fd: fd, request: request, ptr: &cid) - precondition(cid != Self.any.rawValue) return Self(rawValue: cid) } } diff --git a/Sources/NIOTCPEchoClient/Client.swift b/Sources/NIOTCPEchoClient/Client.swift index 0bbb1dba3e..16f2d5da4e 100644 --- a/Sources/NIOTCPEchoClient/Client.swift +++ b/Sources/NIOTCPEchoClient/Client.swift @@ -11,11 +11,12 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if swift(>=5.9) -@_spi(AsyncChannel) import NIOCore -@_spi(AsyncChannel) import NIOPosix -@available(macOS 14, *) +#if compiler(>=5.9) +import NIOCore +import NIOPosix + +@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) @main struct Client { /// The host to connect to. @@ -59,7 +60,7 @@ struct Client { try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(NewlineDelimiterCoder())) return try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: NIOAsyncChannel.Configuration( inboundType: String.self, outboundType: String.self @@ -68,16 +69,18 @@ struct Client { } } - print("Connection(\(number)): Writing request") - try await channel.outboundWriter.write("Hello on connection \(number)") + try await channel.executeThenClose { inbound, outbound in + print("Connection(\(number)): Writing request") + try await outbound.write("Hello on connection \(number)") - for try await inboundData in channel.inboundStream { - print("Connection(\(number)): Received response (\(inboundData))") + for try await inboundData in inbound { + print("Connection(\(number)): Received response (\(inboundData))") - // We only expect a single response so we can exit here. - // Once, we exit out of this loop and the references to the `NIOAsyncChannel` are dropped - // the connection is going to close itself. - break + // We only expect a single response so we can exit here. + // Once, we exit out of this loop and the references to the `NIOAsyncChannel` are dropped + // the connection is going to close itself. + break + } } } } diff --git a/Sources/NIOTCPEchoServer/Server.swift b/Sources/NIOTCPEchoServer/Server.swift index 1d38049058..1ccfccc33e 100644 --- a/Sources/NIOTCPEchoServer/Server.swift +++ b/Sources/NIOTCPEchoServer/Server.swift @@ -11,11 +11,12 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -#if swift(>=5.9) -@_spi(AsyncChannel) import NIOCore -@_spi(AsyncChannel) import NIOPosix -@available(macOS 14, *) +#if compiler(>=5.9) +import NIOCore +import NIOPosix + +@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) @main struct Server { /// The server's host. @@ -48,7 +49,7 @@ struct Server { try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(NewlineDelimiterCoder())) return try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: NIOAsyncChannel.Configuration( inboundType: String.self, outboundType: String.self @@ -64,11 +65,13 @@ struct Server { // the results of the group we need the group to automatically discard them; otherwise, this // would result in a memory leak over time. try await withThrowingDiscardingTaskGroup { group in - for try await connectionChannel in channel.inboundStream { - group.addTask { - print("Handling new connection") - await self.handleConnection(channel: connectionChannel) - print("Done handling connection") + try await channel.executeThenClose { inbound in + for try await connectionChannel in inbound { + group.addTask { + print("Handling new connection") + await self.handleConnection(channel: connectionChannel) + print("Done handling connection") + } } } } @@ -80,9 +83,11 @@ struct Server { // We do this since we don't want to tear down the whole server when a single connection // encounters an error. do { - for try await inboundData in channel.inboundStream { - print("Received request (\(inboundData))") - try await channel.outboundWriter.write(inboundData) + try await channel.executeThenClose { inbound, outbound in + for try await inboundData in inbound { + print("Received request (\(inboundData))") + try await outbound.write(inboundData) + } } } catch { print("Hit error: \(error)") diff --git a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift index 903f665974..d3e63fcf47 100644 --- a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift +++ b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift @@ -12,9 +12,9 @@ // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOCore +import NIOCore -/// A helper ``ChannelInboundHandler`` that makes it easy to swap channel pipelines +/// A helper `ChannelInboundHandler` that makes it easy to swap channel pipelines /// based on the result of an ALPN negotiation. /// /// The standard pattern used by applications that want to use ALPN is to select @@ -26,48 +26,43 @@ /// /// The user of this channel handler provides a single closure that is called with /// an ``ALPNResult`` when the ALPN negotiation is complete. Based on that result -/// the user is free to reconfigure the ``ChannelPipeline`` as required, and should -/// return an ``EventLoopFuture`` that will complete when the pipeline is reconfigured. +/// the user is free to reconfigure the `ChannelPipeline` as required, and should +/// return an `EventLoopFuture` that will complete when the pipeline is reconfigured. /// -/// Until the ``EventLoopFuture`` completes, this channel handler will buffer inbound -/// data. When the ``EventLoopFuture`` completes, the buffered data will be replayed +/// Until the `EventLoopFuture` completes, this channel handler will buffer inbound +/// data. When the `EventLoopFuture` completes, the buffered data will be replayed /// down the channel. Then, finally, this channel handler will automatically remove /// itself from the channel pipeline, leaving the pipeline in its final /// configuration. /// /// Importantly, this is a typed variant of the ``ApplicationProtocolNegotiationHandler`` and allows the user to /// specify a type that must be returned from the supplied closure. The result will then be used to succeed the ``NIOTypedApplicationProtocolNegotiationHandler/protocolNegotiationResult`` -/// promise. This allows us to construct pipelines that include protocol negotiation handlers and be able to bridge them into ``NIOAsyncChannel`` +/// promise. This allows us to construct pipelines that include protocol negotiation handlers and be able to bridge them into `NIOAsyncChannel` /// based bootstraps. -@_spi(AsyncChannel) public final class NIOTypedApplicationProtocolNegotiationHandler: ChannelInboundHandler, RemovableChannelHandler { - @_spi(AsyncChannel) public typealias InboundIn = Any - @_spi(AsyncChannel) public typealias InboundOut = Any - @_spi(AsyncChannel) - public var protocolNegotiationResult: EventLoopFuture> { + public var protocolNegotiationResult: EventLoopFuture { return self.negotiatedPromise.futureResult } - private var negotiatedPromise: EventLoopPromise> { + private var negotiatedPromise: EventLoopPromise { precondition(self._negotiatedPromise != nil, "Tried to access the protocol negotiation result before the handler was added to a pipeline") return self._negotiatedPromise! } - private var _negotiatedPromise: EventLoopPromise>? + private var _negotiatedPromise: EventLoopPromise? - private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture> - private var stateMachine = ProtocolNegotiationHandlerStateMachine>() + private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture + private var stateMachine = ProtocolNegotiationHandlerStateMachine() /// Create an `ApplicationProtocolNegotiationHandler` with the given completion /// callback. /// /// - Parameter alpnCompleteHandler: The closure that will fire when ALPN /// negotiation has completed. - @_spi(AsyncChannel) - public init(alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture>) { + public init(alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture) { self.completionHandler = alpnCompleteHandler } @@ -76,8 +71,7 @@ public final class NIOTypedApplicationProtocolNegotiationHandler EventLoopFuture>) { + public convenience init(alpnCompleteHandler: @escaping (ALPNResult) -> EventLoopFuture) { self.init { result, _ in alpnCompleteHandler(result) } @@ -97,7 +91,6 @@ public final class NIOTypedApplicationProtocolNegotiationHandler, Error>) { + private func userFutureCompleted(context: ChannelHandlerContext, result: Result) { switch self.stateMachine.userFutureCompleted(with: result) { case .fireErrorCaughtAndRemoveHandler(let error): self.negotiatedPromise.fail(error) diff --git a/Sources/NIOTestUtils/NIOHTTP1TestServer.swift b/Sources/NIOTestUtils/NIOHTTP1TestServer.swift index 1bddd9057f..c787c3e7d5 100644 --- a/Sources/NIOTestUtils/NIOHTTP1TestServer.swift +++ b/Sources/NIOTestUtils/NIOHTTP1TestServer.swift @@ -167,6 +167,7 @@ private final class AggregateBodyHandler: ChannelInboundHandler { /// XCTAssertNoThrow(XCTAssertEqual(responseBody, try requestComplete.wait())) public final class NIOHTTP1TestServer { private let eventLoop: EventLoop + private let aggregateBody: Bool // all protected by eventLoop private let inboundBuffer: BlockingQueue = .init() private var currentClientChannel: Channel? = nil @@ -213,7 +214,11 @@ public final class NIOHTTP1TestServer { return } channel.pipeline.configureHTTPServerPipeline().flatMap { - channel.pipeline.addHandler(AggregateBodyHandler()) + if self.aggregateBody { + return channel.pipeline.addHandler(AggregateBodyHandler()) + } else { + return self.eventLoop.makeSucceededVoidFuture() + } }.flatMap { channel.pipeline.addHandler(WebServerHandler(webServer: self)) }.whenSuccess { @@ -221,8 +226,13 @@ public final class NIOHTTP1TestServer { } } - public init(group: EventLoopGroup) { + public convenience init(group: EventLoopGroup) { + self.init(group: group, aggregateBody: true) + } + + public init(group: EventLoopGroup, aggregateBody: Bool) { self.eventLoop = group.next() + self.aggregateBody = aggregateBody self.serverChannel = try! ServerBootstrap(group: self.eventLoop) .childChannelOption(ChannelOptions.autoRead, value: false) diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index 5e7df19c4f..a9e456f857 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -25,7 +25,6 @@ public typealias NIOWebClientSocketUpgrader = NIOWebSocketClientUpgrader /// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the /// pipeline to remove the HTTP `ChannelHandler`s. public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { - /// RFC 6455 specs this as the required entry in the Upgrade header. public let supportedProtocol: String = "websocket" /// None of the websocket headers are actually defined as 'required'. @@ -34,7 +33,7 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { private let requestKey: String private let maxFrameSize: Int private let automaticErrorHandling: Bool - private let upgradePipelineHandler: (Channel, HTTPResponseHead) -> EventLoopFuture + private let upgradePipelineHandler: @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture /// - Parameters: /// - requestKey: sent to the server in the `Sec-WebSocket-Key` HTTP header. Default is random request key. @@ -45,7 +44,7 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { requestKey: String = randomRequestKey(), maxFrameSize: Int = 1 << 14, automaticErrorHandling: Bool = true, - upgradePipelineHandler: @escaping (Channel, HTTPResponseHead) -> EventLoopFuture + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture ) { precondition(requestKey != "", "The request key must contain a valid Sec-WebSocket-Key") precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") @@ -57,48 +56,80 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { /// Add additional headers that are needed for a WebSocket upgrade request. public func addCustom(upgradeRequestHeaders: inout HTTPHeaders) { - upgradeRequestHeaders.add(name: "Sec-WebSocket-Key", value: self.requestKey) - upgradeRequestHeaders.add(name: "Sec-WebSocket-Version", value: "13") + _addCustom(upgradeRequestHeaders: &upgradeRequestHeaders, requestKey: self.requestKey) } - /// Allow or deny the upgrade based on the upgrade HTTP response - /// headers containing the correct accept key. public func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { - - let acceptValueHeader = upgradeResponse.headers["Sec-WebSocket-Accept"] + _shouldAllowUpgrade(upgradeResponse: upgradeResponse, requestKey: self.requestKey) + } - guard acceptValueHeader.count == 1 else { - return false - } + public func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + _upgrade( + channel: context.channel, + upgradeResponse: upgradeResponse, + maxFrameSize: self.maxFrameSize, + enableAutomaticErrorHandling: self.automaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) + } +} + +#if !canImport(Darwin) || swift(>=5.10) +/// A `NIOTypedHTTPClientProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request. +/// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the +/// pipeline to remove the HTTP `ChannelHandler`s. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public final class NIOTypedWebSocketClientUpgrader: NIOTypedHTTPClientProtocolUpgrader { + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String = "websocket" + /// None of the websocket headers are actually defined as 'required'. + public let requiredUpgradeHeaders: [String] = [] + + private let requestKey: String + private let maxFrameSize: Int + private let enableAutomaticErrorHandling: Bool + private let upgradePipelineHandler: @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture - // Validate the response key in 'Sec-WebSocket-Accept'. - var hasher = SHA1() - hasher.update(string: self.requestKey) - hasher.update(string: magicWebSocketGUID) - let expectedAcceptValue = String(base64Encoding: hasher.finish()) + /// - Parameters: + /// - requestKey: Sent to the server in the `Sec-WebSocket-Key` HTTP header. Default is random request key. + /// - maxFrameSize: Largest incoming `WebSocketFrame` size in bytes. Default is 16,384 bytes. + /// - enableAutomaticErrorHandling: If true, adds `WebSocketProtocolErrorHandler` to the channel pipeline to catch and respond to WebSocket protocol errors. Default is true. + /// - upgradePipelineHandler: Called once the upgrade was successful. + public init( + requestKey: String = NIOWebSocketClientUpgrader.randomRequestKey(), + maxFrameSize: Int = 1 << 14, + enableAutomaticErrorHandling: Bool = true, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture + ) { + precondition(requestKey != "", "The request key must contain a valid Sec-WebSocket-Key") + precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") + self.requestKey = requestKey + self.upgradePipelineHandler = upgradePipelineHandler + self.maxFrameSize = maxFrameSize + self.enableAutomaticErrorHandling = enableAutomaticErrorHandling + } - return expectedAcceptValue == acceptValueHeader[0] + public func addCustom(upgradeRequestHeaders: inout NIOHTTP1.HTTPHeaders) { + _addCustom(upgradeRequestHeaders: &upgradeRequestHeaders, requestKey: self.requestKey) } - /// Called when the upgrade response has been flushed and it is safe to mutate the channel - /// pipeline. Adds channel handlers for websocket frame encoding, decoding and errors. - public func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + public func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { + _shouldAllowUpgrade(upgradeResponse: upgradeResponse, requestKey: self.requestKey) + } - var upgradeFuture = context.pipeline.addHandler(WebSocketFrameEncoder()).flatMap { - context.pipeline.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: self.maxFrameSize))) - } - - if self.automaticErrorHandling { - upgradeFuture = upgradeFuture.flatMap { - context.pipeline.addHandler(WebSocketProtocolErrorHandler()) - } - } - - return upgradeFuture.flatMap { - self.upgradePipelineHandler(context.channel, upgradeResponse) - } + public func upgrade(channel: Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + _upgrade( + channel: channel, + upgradeResponse: upgradeResponse, + maxFrameSize: self.maxFrameSize, + enableAutomaticErrorHandling: self.enableAutomaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) } } +#endif @available(*, unavailable) extension NIOWebSocketClientUpgrader: Sendable {} @@ -128,3 +159,48 @@ extension NIOWebSocketClientUpgrader { return NIOWebSocketClientUpgrader.randomRequestKey(using: &generator) } } + +/// Add additional headers that are needed for a WebSocket upgrade request. +private func _addCustom(upgradeRequestHeaders: inout HTTPHeaders, requestKey: String) { + upgradeRequestHeaders.add(name: "Sec-WebSocket-Key", value: requestKey) + upgradeRequestHeaders.add(name: "Sec-WebSocket-Version", value: "13") +} + +/// Allow or deny the upgrade based on the upgrade HTTP response +/// headers containing the correct accept key. +private func _shouldAllowUpgrade(upgradeResponse: HTTPResponseHead, requestKey: String) -> Bool { + let acceptValueHeader = upgradeResponse.headers["Sec-WebSocket-Accept"] + + guard acceptValueHeader.count == 1 else { + return false + } + + // Validate the response key in 'Sec-WebSocket-Accept'. + var hasher = SHA1() + hasher.update(string: requestKey) + hasher.update(string: magicWebSocketGUID) + let expectedAcceptValue = String(base64Encoding: hasher.finish()) + + return expectedAcceptValue == acceptValueHeader[0] +} + +/// Called when the upgrade response has been flushed and it is safe to mutate the channel +/// pipeline. Adds channel handlers for websocket frame encoding, decoding and errors. +private func _upgrade( + channel: Channel, + upgradeResponse: HTTPResponseHead, + maxFrameSize: Int, + enableAutomaticErrorHandling: Bool, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture +) -> EventLoopFuture { + return channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder()) + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize))) + if enableAutomaticErrorHandling { + try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) + } + } + .flatMap { + upgradePipelineHandler(channel, upgradeResponse) + } +} diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 4d1f77f6a9..0672bc4a06 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -64,15 +64,10 @@ fileprivate extension HTTPHeaders { public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unchecked Sendable { // This type *is* Sendable but we can't express that properly until Swift 5.7. In the meantime // the conformance is `@unchecked`. - - #if swift(>=5.7) + // FIXME: remove @unchecked when 5.7 is the minimum supported version. private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture - #else - private typealias ShouldUpgrade = (Channel, HTTPRequestHead) -> EventLoopFuture - private typealias UpgradePipelineHandler = (Channel, HTTPRequestHead) -> EventLoopFuture - #endif /// RFC 6455 specs this as the required entry in the Upgrade header. public let supportedProtocol: String = "websocket" @@ -86,7 +81,6 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch private let maxFrameSize: Int private let automaticErrorHandling: Bool - #if swift(>=5.7) /// Create a new `NIOWebSocketServerUpgrader`. /// /// - parameters: @@ -112,34 +106,7 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch self.init(maxFrameSize: 1 << 14, automaticErrorHandling: automaticErrorHandling, shouldUpgrade: shouldUpgrade, upgradePipelineHandler: upgradePipelineHandler) } - #else - /// Create a new `NIOWebSocketServerUpgrader`. - /// - /// - parameters: - /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol - /// errors by sending error responses and closing the connection. Defaults to `true`, - /// may be set to `false` if the user wishes to handle their own errors. - /// - shouldUpgrade: A callback that determines whether the websocket request should be - /// upgraded. This callback is responsible for creating a `HTTPHeaders` object with - /// any headers that it needs on the response *except for* the `Upgrade`, `Connection`, - /// and `Sec-WebSocket-Accept` headers, which this upgrader will handle. Should return - /// an `EventLoopFuture` containing `nil` if the upgrade should be refused. - /// - upgradePipelineHandler: A function that will be called once the upgrade response is - /// flushed, and that is expected to mutate the `Channel` appropriately to handle the - /// websocket protocol. This only needs to add the user handlers: the - /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the - /// pipeline automatically. - public convenience init( - automaticErrorHandling: Bool = true, - shouldUpgrade: @escaping (Channel, HTTPRequestHead) -> EventLoopFuture, - upgradePipelineHandler: @escaping (Channel, HTTPRequestHead) -> EventLoopFuture - ) { - self.init(maxFrameSize: 1 << 14, automaticErrorHandling: automaticErrorHandling, - shouldUpgrade: shouldUpgrade, upgradePipelineHandler: upgradePipelineHandler) - } - #endif - #if swift(>=5.7) /// Create a new `NIOWebSocketServerUpgrader`. /// /// - parameters: @@ -174,78 +141,152 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch upgradePipelineHandler: upgradePipelineHandler ) } - #else - /// Create a new `NIOWebSocketServerUpgrader`. + + private init( + _maxFrameSize maxFrameSize: Int, + automaticErrorHandling: Bool, + shouldUpgrade: @escaping ShouldUpgrade, + upgradePipelineHandler: @escaping UpgradePipelineHandler + ) { + precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") + self.shouldUpgrade = shouldUpgrade + self.upgradePipelineHandler = upgradePipelineHandler + self.maxFrameSize = maxFrameSize + self.automaticErrorHandling = automaticErrorHandling + } + + public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { + return _buildUpgradeResponse( + channel: channel, + upgradeRequest: upgradeRequest, + initialResponseHeaders: initialResponseHeaders, + shouldUpgrade: self.shouldUpgrade + ) + } + + public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + _upgrade( + channel: context.channel, + upgradeRequest: upgradeRequest, + maxFrameSize: self.maxFrameSize, + automaticErrorHandling: self.automaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) + } +} + +#if !canImport(Darwin) || swift(>=5.10) +/// A `NIOTypedHTTPServerProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// Users may frequently want to offer multiple websocket endpoints on the same port. For this +/// reason, this `WebServerSocketUpgrader` only knows how to do the required parts of the upgrade and to +/// complete the handshake. Users are expected to provide a callback that examines the HTTP headers +/// (including the path) and determines whether this is a websocket upgrade request that is acceptable +/// to them. +/// +/// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to +/// remove the HTTP `ChannelHandler`s. +public final class NIOTypedWebSocketServerUpgrader: NIOTypedHTTPServerProtocolUpgrader, Sendable { + private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String = "websocket" + + /// We deliberately do not actually set any required headers here, because the websocket + /// spec annoyingly does not actually force the client to send these in the Upgrade header, + /// which NIO requires. We check for these manually. + public let requiredUpgradeHeaders: [String] = [] + + private let shouldUpgrade: ShouldUpgrade + private let upgradePipelineHandler: UpgradePipelineHandler + private let maxFrameSize: Int + private let enableAutomaticErrorHandling: Bool + + /// Create a new ``NIOTypedWebSocketServerUpgrader``. /// - /// - parameters: - /// - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the + /// - Parameters: + /// - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the /// remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but /// this is an objectively unreasonable maximum value (on AMD64 systems it is not /// possible to even. Users may set this to any value up to `UInt32.max`. - /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol + /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol /// errors by sending error responses and closing the connection. Defaults to `true`, /// may be set to `false` if the user wishes to handle their own errors. - /// - shouldUpgrade: A callback that determines whether the websocket request should be + /// - shouldUpgrade: A callback that determines whether the websocket request should be /// upgraded. This callback is responsible for creating a `HTTPHeaders` object with /// any headers that it needs on the response *except for* the `Upgrade`, `Connection`, /// and `Sec-WebSocket-Accept` headers, which this upgrader will handle. Should return /// an `EventLoopFuture` containing `nil` if the upgrade should be refused. - /// - upgradePipelineHandler: A function that will be called once the upgrade response is + /// - enableAutomaticErrorHandling: A function that will be called once the upgrade response is /// flushed, and that is expected to mutate the `Channel` appropriately to handle the /// websocket protocol. This only needs to add the user handlers: the /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the /// pipeline automatically. - public convenience init( - maxFrameSize: Int, - automaticErrorHandling: Bool = true, - shouldUpgrade: @escaping (Channel, HTTPRequestHead) -> EventLoopFuture, - upgradePipelineHandler: @escaping (Channel, HTTPRequestHead) -> EventLoopFuture - ) { - self.init( - _maxFrameSize: maxFrameSize, - automaticErrorHandling: automaticErrorHandling, - shouldUpgrade: shouldUpgrade, - upgradePipelineHandler: upgradePipelineHandler - ) - } - #endif - - private init( - _maxFrameSize maxFrameSize: Int, - automaticErrorHandling: Bool, - shouldUpgrade: @escaping ShouldUpgrade, - upgradePipelineHandler: @escaping UpgradePipelineHandler + public init( + maxFrameSize: Int = 1 << 14, + enableAutomaticErrorHandling: Bool = true, + shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture ) { precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size") self.shouldUpgrade = shouldUpgrade self.upgradePipelineHandler = upgradePipelineHandler self.maxFrameSize = maxFrameSize - self.automaticErrorHandling = automaticErrorHandling + self.enableAutomaticErrorHandling = enableAutomaticErrorHandling } - public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { - let key: String - let version: String - - do { - key = try upgradeRequest.headers.nonListHeader("Sec-WebSocket-Key") - version = try upgradeRequest.headers.nonListHeader("Sec-WebSocket-Version") - } catch { - return channel.eventLoop.makeFailedFuture(error) - } + public func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { + _buildUpgradeResponse( + channel: channel, + upgradeRequest: upgradeRequest, + initialResponseHeaders: initialResponseHeaders, + shouldUpgrade: self.shouldUpgrade + ) + } - // The version must be 13. - guard version == "13" else { - return channel.eventLoop.makeFailedFuture(NIOWebSocketUpgradeError.invalidUpgradeHeader) - } + public func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + _upgrade( + channel: channel, + upgradeRequest: upgradeRequest, + maxFrameSize: self.maxFrameSize, + automaticErrorHandling: self.enableAutomaticErrorHandling, + upgradePipelineHandler: self.upgradePipelineHandler + ) + } +} +#endif - return self.shouldUpgrade(channel, upgradeRequest).flatMapThrowing { extraHeaders in - guard let extraHeaders = extraHeaders else { +private func _buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders, + shouldUpgrade: @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture +) -> EventLoopFuture { + let key: String + let version: String + + do { + key = try upgradeRequest.headers.nonListHeader("Sec-WebSocket-Key") + version = try upgradeRequest.headers.nonListHeader("Sec-WebSocket-Version") + } catch { + return channel.eventLoop.makeFailedFuture(error) + } + + // The version must be 13. + guard version == "13" else { + return channel.eventLoop.makeFailedFuture(NIOWebSocketUpgradeError.invalidUpgradeHeader) + } + + return shouldUpgrade(channel, upgradeRequest) + .flatMapThrowing { extraHeaders in + guard var extraHeaders = extraHeaders else { throw NIOWebSocketUpgradeError.unsupportedWebSocketTarget } - return extraHeaders - }.map { (extraHeaders: HTTPHeaders) in - var extraHeaders = extraHeaders // Cool, we're good to go! Let's do our upgrade. We do this by concatenating the magic // GUID to the base64-encoded key and taking a SHA1 hash of the result. @@ -263,23 +304,27 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch return extraHeaders } - } - - public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { - /// We never use the automatic error handling feature of the WebSocketFrameDecoder: we always use the separate channel - /// handler. - var upgradeFuture = context.pipeline.addHandler(WebSocketFrameEncoder()).flatMap { - context.pipeline.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: self.maxFrameSize))) - } +} - if self.automaticErrorHandling { - upgradeFuture = upgradeFuture.flatMap { - context.pipeline.addHandler(WebSocketProtocolErrorHandler()) - } - } +private func _upgrade( + channel: Channel, + upgradeRequest: HTTPRequestHead, + maxFrameSize: Int, + automaticErrorHandling: Bool, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture +) -> EventLoopFuture { + /// We never use the automatic error handling feature of the WebSocketFrameDecoder: we always use the separate channel + /// handler. + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder()) + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize)) + ) - return upgradeFuture.flatMap { - self.upgradePipelineHandler(context.channel, upgradeRequest) + if automaticErrorHandling { + try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) } + }.flatMap { + upgradePipelineHandler(channel, upgradeRequest) } } diff --git a/Sources/NIOWebSocketClient/Client.swift b/Sources/NIOWebSocketClient/Client.swift new file mode 100644 index 0000000000..0bf42d3b16 --- /dev/null +++ b/Sources/NIOWebSocketClient/Client.swift @@ -0,0 +1,147 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if (!canImport(Darwin) && compiler(>=5.9)) || (canImport(Darwin) && compiler(>=5.10)) +import NIOCore +import NIOPosix +import NIOHTTP1 +import NIOWebSocket + +@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) +@main +struct Client { + /// The host to connect to. + private let host: String + /// The port to connect to. + private let port: Int + /// The client's event loop group. + private let eventLoopGroup: MultiThreadedEventLoopGroup + + enum UpgradeResult { + case websocket(NIOAsyncChannel) + case notUpgraded + } + + static func main() async throws { + let client = Client( + host: "localhost", + port: 8888, + eventLoopGroup: .singleton + ) + try await client.run() + } + + /// This method starts the client and tries to setup a WebSocket connection. + func run() async throws { + let upgradeResult: EventLoopFuture = try await ClientBootstrap(group: self.eventLoopGroup) + .connect( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + let upgrader = NIOTypedWebSocketClientUpgrader( + upgradePipelineHandler: { (channel, _) in + channel.eventLoop.makeCompletedFuture { + let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) + return UpgradeResult.websocket(asyncChannel) + } + } + ) + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "0") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let clientUpgradeConfiguration = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: [upgrader], + notUpgradingCompletionHandler: { channel in + channel.eventLoop.makeCompletedFuture { + return UpgradeResult.notUpgraded + } + } + ) + + let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) + ) + + return negotiationResultFuture + } + } + + // We are awaiting and handling the upgrade result now. + try await self.handleUpgradeResult(upgradeResult) + } + + /// This method handles the upgrade result. + private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async throws { + switch try await upgradeResult.get() { + case .websocket(let websocketChannel): + print("Handling websocket connection") + try await self.handleWebsocketChannel(websocketChannel) + print("Done handling websocket connection") + case .notUpgraded: + // The upgrade to websocket did not succeed. We are just exiting in this case. + print("Upgrade declined") + } + } + + private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { + // We are sending a ping frame and then + // start to handle all inbound frames. + + let pingFrame = WebSocketFrame(fin: true, opcode: .ping, data: ByteBuffer(string: "Hello!")) + try await channel.executeThenClose { inbound, outbound in + try await outbound.write(pingFrame) + + for try await frame in inbound { + switch frame.opcode { + case .pong: + print("Received pong: \(String(buffer: frame.data))") + + case .text: + print("Received: \(String(buffer: frame.data))") + + case .connectionClose: + // Handle a received close frame. We're just going to close by returning from this method. + print("Received Close instruction from server") + return + case .binary, .continuation, .ping: + // We ignore these frames. + break + default: + // Unknown frames are errors. + return + } + } + } + } +} + +#else +@main +struct Server { + static func main() { + fatalError("Requires at least Swift 5.9") + } +} +#endif diff --git a/Sources/NIOWebSocketClient/main.swift b/Sources/NIOWebSocketClient/main.swift deleted file mode 100644 index 1f3adecc4f..0000000000 --- a/Sources/NIOWebSocketClient/main.swift +++ /dev/null @@ -1,233 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -import NIOCore -import NIOPosix -import NIOHTTP1 -import NIOWebSocket - -print("Establishing connection.") - -enum ConnectTo { - case ip(host: String, port: Int) - case unixDomainSocket(path: String) -} - -// The HTTP handler to be used to initiate the request. -// This initial request will be adapted by the WebSocket upgrader to contain the upgrade header parameters. -// Channel read will only be called if the upgrade fails. - -private final class HTTPInitialRequestHandler: ChannelInboundHandler, RemovableChannelHandler { - public typealias InboundIn = HTTPClientResponsePart - public typealias OutboundOut = HTTPClientRequestPart - - public let target: ConnectTo - - public init(target: ConnectTo) { - self.target = target - } - - public func channelActive(context: ChannelHandlerContext) { - print("Client connected to \(context.remoteAddress!)") - - // We are connected. It's time to send the message to the server to initialize the upgrade dance. - var headers = HTTPHeaders() - if case let .ip(host: host, port: port) = target { - headers.add(name: "Host", value: "\(host):\(port)") - } - headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") - headers.add(name: "Content-Length", value: "\(0)") - - let requestHead = HTTPRequestHead(version: .http1_1, - method: .GET, - uri: "/", - headers: headers) - - context.write(self.wrapOutboundOut(.head(requestHead)), promise: nil) - - let body = HTTPClientRequestPart.body(.byteBuffer(ByteBuffer())) - context.write(self.wrapOutboundOut(body), promise: nil) - - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) - } - - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - - let clientResponse = self.unwrapInboundIn(data) - - print("Upgrade failed") - - switch clientResponse { - case .head(let responseHead): - print("Received status: \(responseHead.status)") - case .body(let byteBuffer): - let string = String(buffer: byteBuffer) - print("Received: '\(string)' back from the server.") - case .end: - print("Closing channel.") - context.close(promise: nil) - } - } - - public func handlerRemoved(context: ChannelHandlerContext) { - print("HTTP handler removed.") - } - - public func errorCaught(context: ChannelHandlerContext, error: Error) { - print("error: ", error) - - // As we are not really interested getting notified on success or failure - // we just pass nil as promise to reduce allocations. - context.close(promise: nil) - } -} - -// The web socket handler to be used once the upgrade has occurred. -// One added, it sends a ping-pong round trip with "Hello World" data. -// It also listens for any text frames from the server and prints them. - -private final class WebSocketPingPongHandler: ChannelInboundHandler { - typealias InboundIn = WebSocketFrame - typealias OutboundOut = WebSocketFrame - - let testFrameData: String = "Hello World" - - // This is being hit, channel active won't be called as it is already added. - public func handlerAdded(context: ChannelHandlerContext) { - print("WebSocket handler added.") - self.pingTestFrameData(context: context) - } - - public func handlerRemoved(context: ChannelHandlerContext) { - print("WebSocket handler removed.") - } - - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let frame = self.unwrapInboundIn(data) - - switch frame.opcode { - case .pong: - self.pong(context: context, frame: frame) - case .text: - var data = frame.unmaskedData - let text = data.readString(length: data.readableBytes) ?? "" - print("Websocket: Received \(text)") - case .connectionClose: - self.receivedClose(context: context, frame: frame) - case .binary, .continuation, .ping: - // We ignore these frames. - break - default: - // Unknown frames are errors. - self.closeOnError(context: context) - } - } - - public func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - } - - private func receivedClose(context: ChannelHandlerContext, frame: WebSocketFrame) { - // Handle a received close frame. We're just going to close. - print("Received Close instruction from server") - context.close(promise: nil) - } - - private func pingTestFrameData(context: ChannelHandlerContext) { - let buffer = context.channel.allocator.buffer(string: self.testFrameData) - let frame = WebSocketFrame(fin: true, opcode: .ping, data: buffer) - context.write(self.wrapOutboundOut(frame), promise: nil) - } - - private func pong(context: ChannelHandlerContext, frame: WebSocketFrame) { - var frameData = frame.data - if let frameDataString = frameData.readString(length: self.testFrameData.count) { - print("Websocket: Received: \(frameDataString)") - } - } - - private func closeOnError(context: ChannelHandlerContext) { - // We have hit an error, we want to close. We do that by sending a close frame and then - // shutting down the write side of the connection. The server will respond with a close of its own. - var data = context.channel.allocator.buffer(capacity: 2) - data.write(webSocketErrorCode: .protocolError) - let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) - context.write(self.wrapOutboundOut(frame)).whenComplete { (_: Result) in - context.close(mode: .output, promise: nil) - } - } -} - -// First argument is the program path -let arguments = CommandLine.arguments -let arg1 = arguments.dropFirst().first -let arg2 = arguments.dropFirst(2).first - -let defaultHost = "::1" -let defaultPort: Int = 8888 - -let connectTarget: ConnectTo -switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { -case (.some(let h), _ , .some(let p)): - /* we got two arguments, let's interpret that as host and port */ - connectTarget = .ip(host: h, port: p) -case (.some(let portString), .none, _): - /* couldn't parse as number, expecting unix domain socket path */ - connectTarget = .unixDomainSocket(path: portString) -case (_, .some(let p), _): - /* only one argument --> port */ - connectTarget = .ip(host: defaultHost, port: p) -default: - connectTarget = .ip(host: defaultHost, port: defaultPort) -} - -let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) -let bootstrap = ClientBootstrap(group: group) - // Enable SO_REUSEADDR. - .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .channelInitializer { channel in - - let httpHandler = HTTPInitialRequestHandler(target: connectTarget) - - let websocketUpgrader = NIOWebSocketClientUpgrader(requestKey: "OfS0wDaT5NoxF2gqm7Zj2YtetzM=", - upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in - channel.pipeline.addHandler(WebSocketPingPongHandler()) - }) - - let config: NIOHTTPClientUpgradeConfiguration = ( - upgraders: [ websocketUpgrader ], - completionHandler: { _ in - channel.pipeline.removeHandler(httpHandler, promise: nil) - }) - - return channel.pipeline.addHTTPClientHandlers(withClientUpgrade: config).flatMap { - channel.pipeline.addHandler(httpHandler) - } -} -defer { - try! group.syncShutdownGracefully() -} - -let channel = try { () -> Channel in - switch connectTarget { - case .ip(let host, let port): - return try bootstrap.connect(host: host, port: port).wait() - case .unixDomainSocket(let path): - return try bootstrap.connect(unixDomainSocketPath: path).wait() - } -}() - -// Will be closed after we echo-ed back to the server. -try channel.closeFuture.wait() - -print("Client closed") diff --git a/Sources/NIOWebSocketServer/Server.swift b/Sources/NIOWebSocketServer/Server.swift new file mode 100644 index 0000000000..560fe16e39 --- /dev/null +++ b/Sources/NIOWebSocketServer/Server.swift @@ -0,0 +1,291 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if (!canImport(Darwin) && compiler(>=5.9)) || (canImport(Darwin) && compiler(>=5.10)) +import NIOCore +import NIOPosix +import NIOHTTP1 +import NIOWebSocket + +let websocketResponse = """ + + + + + Swift NIO WebSocket Test Page + + + +

WebSocket Stream

+
+ + +""" + +@available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) +@main +struct Server { + /// The server's host. + private let host: String + /// The server's port. + private let port: Int + /// The server's event loop group. + private let eventLoopGroup: MultiThreadedEventLoopGroup + + private static let responseBody = ByteBuffer(string: websocketResponse) + + enum UpgradeResult { + case websocket(NIOAsyncChannel) + case notUpgraded(NIOAsyncChannel>) + } + + static func main() async throws { + let server = Server( + host: "localhost", + port: 8888, + eventLoopGroup: .singleton + ) + try await server.run() + } + + /// This method starts the server and handles incoming connections. + func run() async throws { + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: self.eventLoopGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .bind( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + let upgrader = NIOTypedWebSocketServerUpgrader( + shouldUpgrade: { (channel, head) in + channel.eventLoop.makeSucceededFuture(HTTPHeaders()) + }, + upgradePipelineHandler: { (channel, _) in + channel.eventLoop.makeCompletedFuture { + let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) + return UpgradeResult.websocket(asyncChannel) + } + } + ) + + let serverUpgradeConfiguration = NIOTypedHTTPServerUpgradeConfiguration( + upgraders: [upgrader], + notUpgradingCompletionHandler: { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponsePartHandler()) + let asyncChannel = try NIOAsyncChannel>(wrappingChannelSynchronously: channel) + return UpgradeResult.notUpgraded(asyncChannel) + } + } + ) + + let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: .init(upgradeConfiguration: serverUpgradeConfiguration) + ) + + return negotiationResultFuture + } + } + + // We are handling each incoming connection in a separate child task. It is important + // to use a discarding task group here which automatically discards finished child tasks. + // A normal task group retains all child tasks and their outputs in memory until they are + // consumed by iterating the group or by exiting the group. Since, we are never consuming + // the results of the group we need the group to automatically discard them; otherwise, this + // would result in a memory leak over time. + try await withThrowingDiscardingTaskGroup { group in + try await channel.executeThenClose { inbound in + for try await upgradeResult in inbound { + group.addTask { + await self.handleUpgradeResult(upgradeResult) + } + } + } + } + } + + /// This method handles a single connection by echoing back all inbound data. + private func handleUpgradeResult(_ upgradeResult: EventLoopFuture) async { + // Note that this method is non-throwing and we are catching any error. + // We do this since we don't want to tear down the whole server when a single connection + // encounters an error. + do { + switch try await upgradeResult.get() { + case .websocket(let websocketChannel): + print("Handling websocket connection") + try await self.handleWebsocketChannel(websocketChannel) + print("Done handling websocket connection") + case .notUpgraded(let httpChannel): + print("Handling HTTP connection") + try await self.handleHTTPChannel(httpChannel) + print("Done handling HTTP connection") + } + } catch { + print("Hit error: \(error)") + } + } + + private func handleWebsocketChannel(_ channel: NIOAsyncChannel) async throws { + try await channel.executeThenClose { inbound, outbound in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + for try await frame in inbound { + switch frame.opcode { + case .ping: + print("Received ping") + var frameData = frame.data + let maskingKey = frame.maskKey + + if let maskingKey = maskingKey { + frameData.webSocketUnmask(maskingKey) + } + + let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) + try await outbound.write(responseFrame) + + case .connectionClose: + // This is an unsolicited close. We're going to send a response frame and + // then, when we've sent it, close up shop. We should send back the close code the remote + // peer sent us, unless they didn't send one at all. + print("Received close") + var data = frame.unmaskedData + let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() + let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) + try await outbound.write(closeFrame) + return + case .binary, .continuation, .pong: + // We ignore these frames. + break + default: + // Unknown frames are errors. + return + } + } + } + + group.addTask { + // This is our main business logic where we are just sending the current time + // every second. + while true { + // We can't really check for error here, but it's also not the purpose of the + // example so let's not worry about it. + let theTime = ContinuousClock().now + var buffer = channel.channel.allocator.buffer(capacity: 12) + buffer.writeString("\(theTime)") + + let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) + + print("Sending time") + try await outbound.write(frame) + try await Task.sleep(for: .seconds(1)) + } + } + + try await group.next() + group.cancelAll() + } + } + } + + + private func handleHTTPChannel(_ channel: NIOAsyncChannel>) async throws { + try await channel.executeThenClose { inbound, outbound in + for try await requestPart in inbound { + // We're not interested in request bodies here: we're just serving up GET responses + // to get the client to initiate a websocket request. + guard case .head(let head) = requestPart else { + return + } + + // GETs only. + guard case .GET = head.method else { + try await self.respond405(writer: outbound) + return + } + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/html") + headers.add(name: "Content-Length", value: String(Self.responseBody.readableBytes)) + headers.add(name: "Connection", value: "close") + let responseHead = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .ok, + headers: headers + ) + + try await outbound.write( + contentsOf: [ + .head(responseHead), + .body(Self.responseBody), + .end(nil) + ] + ) + } + } + } + + private func respond405(writer: NIOAsyncChannelOutboundWriter>) async throws { + var headers = HTTPHeaders() + headers.add(name: "Connection", value: "close") + headers.add(name: "Content-Length", value: "0") + let head = HTTPResponseHead( + version: .http1_1, + status: .methodNotAllowed, + headers: headers + ) + + try await writer.write( + contentsOf: [ + .head(head), + .end(nil) + ] + ) + } +} + +final class HTTPByteBufferResponsePartHandler: ChannelOutboundHandler { + typealias OutboundIn = HTTPPart + typealias OutboundOut = HTTPServerResponsePart + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let part = self.unwrapOutboundIn(data) + switch part { + case .head(let head): + context.write(self.wrapOutboundOut(.head(head)), promise: promise) + case .body(let buffer): + context.write(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise) + case .end(let trailers): + context.write(self.wrapOutboundOut(.end(trailers)), promise: promise) + } + } +} + +#else +@main +struct Server { + static func main() { + fatalError("Requires at least Swift 5.9") + } +} +#endif diff --git a/Sources/NIOWebSocketServer/main.swift b/Sources/NIOWebSocketServer/main.swift deleted file mode 100644 index d8bb6592e6..0000000000 --- a/Sources/NIOWebSocketServer/main.swift +++ /dev/null @@ -1,282 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -import NIOCore -import NIOPosix -import NIOHTTP1 -import NIOWebSocket - -let websocketResponse = """ - - - - - Swift NIO WebSocket Test Page - - - -

WebSocket Stream

-
- - -""" - -private final class HTTPHandler: ChannelInboundHandler, RemovableChannelHandler { - typealias InboundIn = HTTPServerRequestPart - typealias OutboundOut = HTTPServerResponsePart - - private var responseBody: ByteBuffer! - - func handlerAdded(context: ChannelHandlerContext) { - self.responseBody = context.channel.allocator.buffer(string: websocketResponse) - } - - func handlerRemoved(context: ChannelHandlerContext) { - self.responseBody = nil - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let reqPart = self.unwrapInboundIn(data) - - // We're not interested in request bodies here: we're just serving up GET responses - // to get the client to initiate a websocket request. - guard case .head(let head) = reqPart else { - return - } - - // GETs only. - guard case .GET = head.method else { - self.respond405(context: context) - return - } - - var headers = HTTPHeaders() - headers.add(name: "Content-Type", value: "text/html") - headers.add(name: "Content-Length", value: String(self.responseBody.readableBytes)) - headers.add(name: "Connection", value: "close") - let responseHead = HTTPResponseHead(version: .init(major: 1, minor: 1), - status: .ok, - headers: headers) - context.write(self.wrapOutboundOut(.head(responseHead)), promise: nil) - context.write(self.wrapOutboundOut(.body(.byteBuffer(self.responseBody))), promise: nil) - context.write(self.wrapOutboundOut(.end(nil))).whenComplete { (_: Result) in - context.close(promise: nil) - } - context.flush() - } - - private func respond405(context: ChannelHandlerContext) { - var headers = HTTPHeaders() - headers.add(name: "Connection", value: "close") - headers.add(name: "Content-Length", value: "0") - let head = HTTPResponseHead(version: .http1_1, - status: .methodNotAllowed, - headers: headers) - context.write(self.wrapOutboundOut(.head(head)), promise: nil) - context.write(self.wrapOutboundOut(.end(nil))).whenComplete { (_: Result) in - context.close(promise: nil) - } - context.flush() - } -} - -private final class WebSocketTimeHandler: ChannelInboundHandler { - typealias InboundIn = WebSocketFrame - typealias OutboundOut = WebSocketFrame - - private var awaitingClose: Bool = false - - public func handlerAdded(context: ChannelHandlerContext) { - self.sendTime(context: context) - } - - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let frame = self.unwrapInboundIn(data) - - switch frame.opcode { - case .connectionClose: - self.receivedClose(context: context, frame: frame) - case .ping: - self.pong(context: context, frame: frame) - case .text: - var data = frame.unmaskedData - let text = data.readString(length: data.readableBytes) ?? "" - print(text) - case .binary, .continuation, .pong: - // We ignore these frames. - break - default: - // Unknown frames are errors. - self.closeOnError(context: context) - } - } - - public func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - } - - private func sendTime(context: ChannelHandlerContext) { - guard context.channel.isActive else { return } - - // We can't send if we sent a close message. - guard !self.awaitingClose else { return } - - // We can't really check for error here, but it's also not the purpose of the - // example so let's not worry about it. - let theTime = NIODeadline.now().uptimeNanoseconds - var buffer = context.channel.allocator.buffer(capacity: 12) - buffer.writeString("\(theTime)") - - let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) - context.writeAndFlush(self.wrapOutboundOut(frame)).map { - context.eventLoop.scheduleTask(in: .seconds(1), { self.sendTime(context: context) }) - }.whenFailure { (_: Error) in - context.close(promise: nil) - } - } - - private func receivedClose(context: ChannelHandlerContext, frame: WebSocketFrame) { - // Handle a received close frame. In websockets, we're just going to send the close - // frame and then close, unless we already sent our own close frame. - if awaitingClose { - // Cool, we started the close and were waiting for the user. We're done. - context.close(promise: nil) - } else { - // This is an unsolicited close. We're going to send a response frame and - // then, when we've sent it, close up shop. We should send back the close code the remote - // peer sent us, unless they didn't send one at all. - var data = frame.unmaskedData - let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() - let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) - _ = context.write(self.wrapOutboundOut(closeFrame)).map { () in - context.close(promise: nil) - } - } - } - - private func pong(context: ChannelHandlerContext, frame: WebSocketFrame) { - var frameData = frame.data - let maskingKey = frame.maskKey - - if let maskingKey = maskingKey { - frameData.webSocketUnmask(maskingKey) - } - - let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) - context.write(self.wrapOutboundOut(responseFrame), promise: nil) - } - - private func closeOnError(context: ChannelHandlerContext) { - // We have hit an error, we want to close. We do that by sending a close frame and then - // shutting down the write side of the connection. - var data = context.channel.allocator.buffer(capacity: 2) - data.write(webSocketErrorCode: .protocolError) - let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) - context.write(self.wrapOutboundOut(frame)).whenComplete { (_: Result) in - context.close(mode: .output, promise: nil) - } - awaitingClose = true - } -} - -let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - -let upgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel: Channel, head: HTTPRequestHead) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, - upgradePipelineHandler: { (channel: Channel, _: HTTPRequestHead) in - channel.pipeline.addHandler(WebSocketTimeHandler()) - }) - -let bootstrap = ServerBootstrap(group: group) - // Specify backlog and enable SO_REUSEADDR for the server itself - .serverChannelOption(ChannelOptions.backlog, value: 256) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - - // Set the handlers that are applied to the accepted Channels - .childChannelInitializer { channel in - let httpHandler = HTTPHandler() - let config: NIOHTTPServerUpgradeConfiguration = ( - upgraders: [ upgrader ], - completionHandler: { _ in - channel.pipeline.removeHandler(httpHandler, promise: nil) - } - ) - return channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: config).flatMap { - channel.pipeline.addHandler(httpHandler) - } - } - - // Enable SO_REUSEADDR for the accepted Channels - .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - -defer { - try! group.syncShutdownGracefully() -} - -// First argument is the program path -let arguments = CommandLine.arguments -let arg1 = arguments.dropFirst().first -let arg2 = arguments.dropFirst(2).first - -let defaultHost = "localhost" -let defaultPort = 8888 - -enum BindTo { - case ip(host: String, port: Int) - case unixDomainSocket(path: String) -} - -let bindTarget: BindTo -switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { -case (.some(let h), _ , .some(let p)): - /* we got two arguments, let's interpret that as host and port */ - bindTarget = .ip(host: h, port: p) - -case (let portString?, .none, _): - // Couldn't parse as number, expecting unix domain socket path. - bindTarget = .unixDomainSocket(path: portString) - -case (_, let p?, _): - // Only one argument --> port. - bindTarget = .ip(host: defaultHost, port: p) - -default: - bindTarget = .ip(host: defaultHost, port: defaultPort) -} - -let channel = try { () -> Channel in - switch bindTarget { - case .ip(let host, let port): - return try bootstrap.bind(host: host, port: port).wait() - case .unixDomainSocket(let path): - return try bootstrap.bind(unixDomainSocketPath: path).wait() - } -}() - -guard let localAddress = channel.localAddress else { - fatalError("Address was unable to bind. Please check that the socket was not closed or that the address family was understood.") -} -print("Server started and listening on \(localAddress)") - -// This will never unblock as we don't close the ServerChannel -try channel.closeFuture.wait() - -print("Server closed") diff --git a/Sources/_NIODataStructures/_TinyArray.swift b/Sources/_NIODataStructures/_TinyArray.swift new file mode 100644 index 0000000000..bc1c154b6a --- /dev/null +++ b/Sources/_NIODataStructures/_TinyArray.swift @@ -0,0 +1,356 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftCertificates open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftCertificates project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftCertificates project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/// ``TinyArray`` is a ``RandomAccessCollection`` optimised to store zero or one ``Element``. +/// It supports arbitrary many elements but if only up to one ``Element`` is stored it does **not** allocate separate storage on the heap +/// and instead stores the ``Element`` inline. +public struct _TinyArray { + @usableFromInline + enum Storage { + case one(Element) + case arbitrary([Element]) + } + + @usableFromInline + var storage: Storage +} + +// MARK: - TinyArray "public" interface + +extension _TinyArray: Equatable where Element: Equatable {} +extension _TinyArray: Hashable where Element: Hashable {} +extension _TinyArray: Sendable where Element: Sendable {} + +extension _TinyArray: RandomAccessCollection { + public typealias Element = Element + + public typealias Index = Int + + @inlinable + public func makeIterator() -> Iterator { + return Iterator(storage: self.storage) + } + + public struct Iterator: IteratorProtocol { + @usableFromInline + let _storage: Storage + @usableFromInline + var _index: Index + + @usableFromInline + init(storage: Storage) { + self._storage = storage + self._index = storage.startIndex + } + + @inlinable + public mutating func next() -> Element? { + if self._index == self._storage.endIndex { + return nil + } + + defer { + self._index &+= 1 + } + + return self._storage[self._index] + } + } + + @inlinable + public subscript(position: Int) -> Element { + get { + self.storage[position] + } + } + + @inlinable + public var startIndex: Int { + self.storage.startIndex + } + + @inlinable + public var endIndex: Int { + self.storage.endIndex + } +} + +extension _TinyArray { + @inlinable + public init(_ elements: some Sequence) { + self.storage = .init(elements) + } + + @inlinable + public init(_ elements: some Sequence>) throws { + self.storage = try .init(elements) + } + + @inlinable + public init() { + self.storage = .init() + } + + @inlinable + public mutating func append(_ newElement: Element) { + self.storage.append(newElement) + } + + @inlinable + public mutating func append(contentsOf newElements: some Sequence) { + self.storage.append(contentsOf: newElements) + } + + @discardableResult + @inlinable + public mutating func remove(at index: Int) -> Element { + self.storage.remove(at: index) + } + + @inlinable + public mutating func removeAll(where shouldBeRemoved: (Element) throws -> Bool) rethrows { + try self.storage.removeAll(where: shouldBeRemoved) + } + + @inlinable + public mutating func sort(by areInIncreasingOrder: (Element, Element) throws -> Bool) rethrows { + try self.storage.sort(by: areInIncreasingOrder) + } +} + +// MARK: - TinyArray.Storage "private" implementation + +extension _TinyArray.Storage: Equatable where Element: Equatable { + @inlinable + static func == (lhs: Self, rhs: Self) -> Bool { + switch (lhs, rhs) { + case (.one(let lhs), .one(let rhs)): + return lhs == rhs + case (.arbitrary(let lhs), .arbitrary(let rhs)): + // we don't use lhs.elementsEqual(rhs) so we can hit the fast path from Array + // if both arrays share the same underlying storage: https://github.com/apple/swift/blob/b42019005988b2d13398025883e285a81d323efa/stdlib/public/core/Array.swift#L1775 + return lhs == rhs + + case (.one(let element), .arbitrary(let array)), + (.arbitrary(let array), .one(let element)): + guard array.count == 1 else { + return false + } + return element == array[0] + + } + } +} +extension _TinyArray.Storage: Hashable where Element: Hashable { + @inlinable + func hash(into hasher: inout Hasher) { + // same strategy as Array: https://github.com/apple/swift/blob/b42019005988b2d13398025883e285a81d323efa/stdlib/public/core/Array.swift#L1801 + hasher.combine(count) + for element in self { + hasher.combine(element) + } + } +} +extension _TinyArray.Storage: Sendable where Element: Sendable {} + +extension _TinyArray.Storage: RandomAccessCollection { + @inlinable + subscript(position: Int) -> Element { + get { + switch self { + case .one(let element): + guard position == 0 else { + fatalError("index \(position) out of bounds") + } + return element + case .arbitrary(let elements): + return elements[position] + } + } + } + + @inlinable + var startIndex: Int { + 0 + } + + @inlinable + var endIndex: Int { + switch self { + case .one: return 1 + case .arbitrary(let elements): return elements.endIndex + } + } +} + +extension _TinyArray.Storage { + @inlinable + init(_ elements: some Sequence) { + self = .arbitrary([]) + self.append(contentsOf: elements) + } + + @inlinable + init(_ newElements: some Sequence>) throws { + var iterator = newElements.makeIterator() + guard let firstElement = try iterator.next()?.get() else { + self = .arbitrary([]) + return + } + guard let secondElement = try iterator.next()?.get() else { + // newElements just contains a single element + // and we hit the fast path + self = .one(firstElement) + return + } + + var elements: [Element] = [] + elements.reserveCapacity(newElements.underestimatedCount) + elements.append(firstElement) + elements.append(secondElement) + while let nextElement = try iterator.next()?.get() { + elements.append(nextElement) + } + self = .arbitrary(elements) + } + + @inlinable + init() { + self = .arbitrary([]) + } + + @inlinable + mutating func append(_ newElement: Element) { + self.append(contentsOf: CollectionOfOne(newElement)) + } + + @inlinable + mutating func append(contentsOf newElements: some Sequence) { + switch self { + case .one(let firstElement): + var iterator = newElements.makeIterator() + guard let secondElement = iterator.next() else { + // newElements is empty, nothing to do + return + } + var elements: [Element] = [] + elements.reserveCapacity(1 + newElements.underestimatedCount) + elements.append(firstElement) + elements.append(secondElement) + elements.appendRemainingElements(from: &iterator) + self = .arbitrary(elements) + + case .arbitrary(var elements): + if elements.isEmpty { + // if `self` is currently empty and `newElements` just contains a single + // element, we skip allocating an array and set `self` to `.one(firstElement)` + var iterator = newElements.makeIterator() + guard let firstElement = iterator.next() else { + // newElements is empty, nothing to do + return + } + guard let secondElement = iterator.next() else { + // newElements just contains a single element + // and we hit the fast path + self = .one(firstElement) + return + } + elements.reserveCapacity(elements.count + newElements.underestimatedCount) + elements.append(firstElement) + elements.append(secondElement) + elements.appendRemainingElements(from: &iterator) + self = .arbitrary(elements) + + } else { + elements.append(contentsOf: newElements) + self = .arbitrary(elements) + } + + } + } + + @discardableResult + @inlinable + mutating func remove(at index: Int) -> Element { + switch self { + case .one(let oldElement): + guard index == 0 else { + fatalError("index \(index) out of bounds") + } + self = .arbitrary([]) + return oldElement + + case .arbitrary(var elements): + defer { + self = .arbitrary(elements) + } + return elements.remove(at: index) + + } + } + + @inlinable + mutating func removeAll(where shouldBeRemoved: (Element) throws -> Bool) rethrows { + switch self { + case .one(let oldElement): + if try shouldBeRemoved(oldElement) { + self = .arbitrary([]) + } + + case .arbitrary(var elements): + defer { + self = .arbitrary(elements) + } + return try elements.removeAll(where: shouldBeRemoved) + + } + } + + @inlinable + mutating func sort(by areInIncreasingOrder: (Element, Element) throws -> Bool) rethrows { + switch self { + case .one: + // a collection of just one element is always sorted, nothing to do + break + case .arbitrary(var elements): + defer { + self = .arbitrary(elements) + } + + try elements.sort(by: areInIncreasingOrder) + } + } +} + +extension Array { + @inlinable + mutating func appendRemainingElements(from iterator: inout some IteratorProtocol) { + while let nextElement = iterator.next() { + append(nextElement) + } + } +} diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift index f3d60baee6..94e17ec593 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) @testable import NIOCore +@testable import NIOCore import XCTest +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelInboundStreamTests: XCTestCase { func testTestingStream() async throws { let (stream, source) = NIOAsyncChannelInboundStream.makeTestingStream() diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift index 1a58a54358..68446e2a34 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) @testable import NIOCore +@testable import NIOCore import XCTest +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelOutboundWriterTests: XCTestCase { func testTestingWriter() async throws { let (writer, sink) = NIOAsyncChannelOutboundWriter.makeTestingWriter() diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 170add83fe..cafddee308 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -13,186 +13,115 @@ //===----------------------------------------------------------------------===// import Atomics import NIOConcurrencyHelpers -@_spi(AsyncChannel) @testable import NIOCore +@testable import NIOCore import NIOEmbedded import XCTest +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelTests: XCTestCase { - func testAsyncChannelBasicFunctionality() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + func testAsyncChannelCloseOnWrite() async throws { + final class CloseOnWriteHandler: ChannelOutboundHandler { + typealias OutboundIn = String + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + context.close(promise: promise) + } + } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try channel.pipeline.syncOperations.addHandler(CloseOnWriteHandler()) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } - var iterator = wrapped.inboundStream.makeAsyncIterator() - try await channel.writeInbound("hello") - let firstRead = try await iterator.next() - XCTAssertEqual(firstRead, "hello") - - try await channel.writeInbound("world") - let secondRead = try await iterator.next() - XCTAssertEqual(secondRead, "world") - - try await channel.testingEventLoop.executeInContext { - channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + try await wrapped.executeThenClose { _, outbound in + try await outbound.write("Test") } - - let thirdRead = try await iterator.next() - XCTAssertNil(thirdRead) - - try await channel.close() } - func testAsyncChannelBasicWrites() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + func testAsyncChannelBasicFunctionality() async throws { let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } - try await wrapped.outboundWriter.write("hello") - try await wrapped.outboundWriter.write("world") - - let firstRead = try await channel.waitForOutboundWrite(as: String.self) - let secondRead = try await channel.waitForOutboundWrite(as: String.self) - - XCTAssertEqual(firstRead, "hello") - XCTAssertEqual(secondRead, "world") - - try await channel.close() - } - - func testDroppingTheWriterClosesTheWriteSideOfTheChannel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } - let channel = NIOAsyncTestingChannel() - let closeRecorder = CloseRecorder() - try await channel.pipeline.addHandler(closeRecorder) + try await wrapped.executeThenClose { inbound, _ in + var iterator = inbound.makeAsyncIterator() + try await channel.writeInbound("hello") + let firstRead = try await iterator.next() + XCTAssertEqual(firstRead, "hello") - let inboundReader: NIOAsyncChannelInboundStream - - do { - let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel( - synchronouslyWrapping: channel, - configuration: .init( - isOutboundHalfClosureEnabled: true, - inboundType: Never.self, - outboundType: Never.self - ) - ) - } - inboundReader = wrapped.inboundStream + try await channel.writeInbound("world") + let secondRead = try await iterator.next() + XCTAssertEqual(secondRead, "world") try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(0, closeRecorder.outboundCloses) + channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) } - } - await channel.testingEventLoop.run() - - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(1, closeRecorder.outboundCloses) + let thirdRead = try await iterator.next() + XCTAssertNil(thirdRead) } - - // Just use this to keep the inbound reader alive. - withExtendedLifetime(inboundReader) {} - channel.close(promise: nil) } - func testDroppingTheWriterDoesntCloseTheWriteSideOfTheChannelIfHalfClosureIsDisabled() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + func testAsyncChannelBasicWrites() async throws { let channel = NIOAsyncTestingChannel() - let closeRecorder = CloseRecorder() - try await channel.pipeline.addHandler(closeRecorder) - - let inboundReader: NIOAsyncChannelInboundStream + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } - do { - let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel( - synchronouslyWrapping: channel, - configuration: .init( - isOutboundHalfClosureEnabled: false, - inboundType: Never.self, - outboundType: Never.self - ) - ) - } - inboundReader = wrapped.inboundStream + try await wrapped.executeThenClose { _, outbound in + try await outbound.write("hello") + try await outbound.write("world") - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(0, closeRecorder.outboundCloses) - } - } + let firstRead = try await channel.waitForOutboundWrite(as: String.self) + let secondRead = try await channel.waitForOutboundWrite(as: String.self) - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(0, closeRecorder.outboundCloses) + XCTAssertEqual(firstRead, "hello") + XCTAssertEqual(secondRead, "world") } - - // Just use this to keep the inbound reader alive. - withExtendedLifetime(inboundReader) {} - channel.close(promise: nil) } - func testDroppingTheWriterFirstLeadsToChannelClosureWhenReaderIsAlsoDropped() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + func testFinishingTheWriterClosesTheWriteSideOfTheChannel() async throws { let channel = NIOAsyncTestingChannel() let closeRecorder = CloseRecorder() - try await channel.pipeline.addHandler(CloseSuppressor()) try await channel.pipeline.addHandler(closeRecorder) - do { - let inboundReader: NIOAsyncChannelInboundStream - - do { - let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel( - synchronouslyWrapping: channel, - configuration: .init( - isOutboundHalfClosureEnabled: true, - inboundType: Never.self, - outboundType: Never.self - ) - ) - } - inboundReader = wrapped.inboundStream + let wrapped = try await channel.testingEventLoop.executeInContext { + try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init( + isOutboundHalfClosureEnabled: true, + inboundType: Never.self, + outboundType: Never.self + ) + ) + } - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(0, closeRecorder.allCloses) - } - } + try await wrapped.executeThenClose { inbound, outbound in + outbound.finish() await channel.testingEventLoop.run() - // First we see half-closure. try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(1, closeRecorder.allCloses) + XCTAssertEqual(1, closeRecorder.outboundCloses) } // Just use this to keep the inbound reader alive. - withExtendedLifetime(inboundReader) {} - } + withExtendedLifetime(inbound) {} - // Now the inbound reader is dead, we see full closure. - try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(2, closeRecorder.allCloses) } - - try await channel.closeIgnoringSuppression() } - func testDroppingEverythingClosesTheChannel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + func testDroppingEverythingDoesntCloseTheChannel() async throws { let channel = NIOAsyncTestingChannel() let closeRecorder = CloseRecorder() try await channel.pipeline.addHandler(CloseSuppressor()) try await channel.pipeline.addHandler(closeRecorder) do { - let wrapped = try await channel.testingEventLoop.executeInContext { + _ = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: .init( isOutboundHalfClosureEnabled: false, inboundType: Never.self, @@ -204,24 +133,20 @@ final class AsyncChannelTests: XCTestCase { try await channel.testingEventLoop.executeInContext { XCTAssertEqual(0, closeRecorder.allCloses) } - - // Just use this to keep the wrapper alive until here. - withExtendedLifetime(wrapped) {} } // Now that everything is dead, we see full closure. try await channel.testingEventLoop.executeInContext { - XCTAssertEqual(1, closeRecorder.allCloses) + XCTAssertEqual(0, closeRecorder.allCloses) } try await channel.closeIgnoringSuppression() } func testReadsArePropagated() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await channel.writeInbound("hello") @@ -230,15 +155,16 @@ final class AsyncChannelTests: XCTestCase { try await channel.close().get() - let reads = try await Array(wrapped.inboundStream) - XCTAssertEqual(reads, ["hello"]) + try await wrapped.executeThenClose { inbound, _ in + let reads = try await Array(inbound) + XCTAssertEqual(reads, ["hello"]) + } } func testErrorsArePropagatedButAfterReads() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await channel.writeInbound("hello") @@ -246,20 +172,21 @@ final class AsyncChannelTests: XCTestCase { channel.pipeline.fireErrorCaught(TestError.bang) } - var iterator = wrapped.inboundStream.makeAsyncIterator() - let first = try await iterator.next() - XCTAssertEqual(first, "hello") + try await wrapped.executeThenClose { inbound, _ in + var iterator = inbound.makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first, "hello") - try await XCTAssertThrowsError(await iterator.next()) { error in - XCTAssertEqual(error as? TestError, .bang) + try await XCTAssertThrowsError(await iterator.next()) { error in + XCTAssertEqual(error as? TestError, .bang) + } } } func testChannelBecomingNonWritableDelaysWriters() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await channel.testingEventLoop.executeInContext { @@ -271,9 +198,11 @@ final class AsyncChannelTests: XCTestCase { await withThrowingTaskGroup(of: Void.self) { group in group.addTask { - try await wrapped.outboundWriter.write("hello") - lock.withLockedValue { - XCTAssertTrue($0) + try await wrapped.executeThenClose { _, outbound in + try await outbound.write("hello") + lock.withLockedValue { + XCTAssertTrue($0) + } } } @@ -288,18 +217,15 @@ final class AsyncChannelTests: XCTestCase { } } } - - try await channel.close().get() } func testBufferDropsReadsIfTheReaderIsGone() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() try await channel.pipeline.addHandler(CloseSuppressor()).get() do { // Create the NIOAsyncChannel, then drop it. The handler will still be in the pipeline. _ = try await channel.testingEventLoop.executeInContext { - _ = try NIOAsyncChannel(synchronouslyWrapping: channel) + _ = try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } @@ -317,16 +243,15 @@ final class AsyncChannelTests: XCTestCase { try await channel.closeIgnoringSuppression() } - func testManagingBackpressure() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + func testManagingBackPressure() async throws { let channel = NIOAsyncTestingChannel() let readCounter = ReadCounter() try await channel.pipeline.addHandler(readCounter) let wrapped = try await channel.testingEventLoop.executeInContext { try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: .init( - backpressureStrategy: .init(lowWatermark: 2, highWatermark: 4), + backPressureStrategy: .init(lowWatermark: 2, highWatermark: 4), inboundType: Void.self, outboundType: Never.self ) @@ -377,83 +302,84 @@ final class AsyncChannelTests: XCTestCase { } XCTAssertEqual(readCounter.readCount, 6) - // Now consume three elements from the pipeline. This should not unbuffer the read, as 3 elements remain. - var reader = wrapped.inboundStream.makeAsyncIterator() - for _ in 0..<3 { - try await XCTAsyncAssertNotNil(await reader.next()) - } - await channel.testingEventLoop.run() - XCTAssertEqual(readCounter.readCount, 6) + try await wrapped.executeThenClose { inbound, outbound in + // Now consume three elements from the pipeline. This should not unbuffer the read, as 3 elements remain. + var reader = inbound.makeAsyncIterator() + for _ in 0..<3 { + try await XCTAsyncAssertNotNil(await reader.next()) + } + await channel.testingEventLoop.run() + XCTAssertEqual(readCounter.readCount, 6) - // Removing the next element should trigger an automatic read. - try await XCTAsyncAssertNotNil(await reader.next()) - await channel.testingEventLoop.run() - XCTAssertEqual(readCounter.readCount, 7) + // Removing the next element should trigger an automatic read. + try await XCTAsyncAssertNotNil(await reader.next()) + await channel.testingEventLoop.run() + XCTAssertEqual(readCounter.readCount, 7) - // Reads now work again, even if more data arrives. - try await channel.testingEventLoop.executeInContext { - channel.pipeline.read() - channel.pipeline.read() - channel.pipeline.read() + // Reads now work again, even if more data arrives. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() - channel.pipeline.fireChannelRead(NIOAny(())) - channel.pipeline.fireChannelReadComplete() + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelReadComplete() - channel.pipeline.read() - channel.pipeline.read() - channel.pipeline.read() - } - XCTAssertEqual(readCounter.readCount, 13) + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 13) - // The next reads arriving pushes us past the limit again. - // This time we won't read. - try await channel.testingEventLoop.executeInContext { - channel.pipeline.fireChannelRead(NIOAny(())) - channel.pipeline.fireChannelRead(NIOAny(())) - channel.pipeline.fireChannelReadComplete() - } - XCTAssertEqual(readCounter.readCount, 13) + // The next reads arriving pushes us past the limit again. + // This time we won't read. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelRead(NIOAny(())) + channel.pipeline.fireChannelReadComplete() + } + XCTAssertEqual(readCounter.readCount, 13) - // This time we'll consume 4 more elements, and we won't find a read at all. - for _ in 0..<4 { - try await XCTAsyncAssertNotNil(await reader.next()) - } - await channel.testingEventLoop.run() - XCTAssertEqual(readCounter.readCount, 13) + // This time we'll consume 4 more elements, and we won't find a read at all. + for _ in 0..<4 { + try await XCTAsyncAssertNotNil(await reader.next()) + } + await channel.testingEventLoop.run() + XCTAssertEqual(readCounter.readCount, 13) - // But the next reads work fine. - try await channel.testingEventLoop.executeInContext { - channel.pipeline.read() - channel.pipeline.read() - channel.pipeline.read() + // But the next reads work fine. + try await channel.testingEventLoop.executeInContext { + channel.pipeline.read() + channel.pipeline.read() + channel.pipeline.read() + } + XCTAssertEqual(readCounter.readCount, 16) } - XCTAssertEqual(readCounter.readCount, 16) } func testCanWrapAChannelSynchronously() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } - var iterator = wrapped.inboundStream.makeAsyncIterator() - try await channel.writeInbound("hello") - let firstRead = try await iterator.next() - XCTAssertEqual(firstRead, "hello") - - try await wrapped.outboundWriter.write("world") - let write = try await channel.waitForOutboundWrite(as: String.self) - XCTAssertEqual(write, "world") + try await wrapped.executeThenClose { inbound, outbound in + var iterator = inbound.makeAsyncIterator() + try await channel.writeInbound("hello") + let firstRead = try await iterator.next() + XCTAssertEqual(firstRead, "hello") - try await channel.testingEventLoop.executeInContext { - channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) - } + try await outbound.write("world") + let write = try await channel.waitForOutboundWrite(as: String.self) + XCTAssertEqual(write, "world") - let secondRead = try await iterator.next() - XCTAssertNil(secondRead) + try await channel.testingEventLoop.executeInContext { + channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) + } - try await channel.close() + let secondRead = try await iterator.next() + XCTAssertNil(secondRead) + } } } @@ -489,6 +415,7 @@ private final class CloseSuppressor: ChannelOutboundHandler, RemovableChannelHan } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncTestingChannel { fileprivate func closeIgnoringSuppression() async throws { try await self.pipeline.context(handlerType: CloseSuppressor.self).flatMap { @@ -519,6 +446,7 @@ private enum TestError: Error { case bang } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension Array { fileprivate init(_ sequence: AS) async throws where AS.Element == Self.Element { self = [] diff --git a/Tests/NIOCoreTests/AsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequenceTests.swift index 07c51b34d3..9a000a79ef 100644 --- a/Tests/NIOCoreTests/AsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequenceTests.swift @@ -25,9 +25,9 @@ fileprivate struct TestCase { } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncSequenceCollectTests: XCTestCase { func testAsyncSequenceCollect() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } let testCases = [ TestCase([ [], diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift index 114b06c576..db2471c43b 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -import NIOCore +@testable import NIOCore import XCTest +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class MockNIOElementStreamBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategy, @unchecked Sendable { enum Event { case didYield @@ -48,6 +49,7 @@ final class MockNIOElementStreamBackPressureStrategy: NIOAsyncSequenceProducerBa } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class MockNIOBackPressuredStreamSourceDelegate: NIOAsyncSequenceProducerDelegate, @unchecked Sendable { enum Event { case produceMore @@ -102,6 +104,7 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let result = NIOAsyncSequenceProducer.makeSequence( elementType: Int.self, backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: false, delegate: self.delegate ) self.source = result.source @@ -112,6 +115,8 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { self.backPressureStrategy = nil self.delegate = nil self.sequence = nil + self.source.finish() + self.source = nil super.tearDown() } @@ -261,11 +266,14 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } func testFinish_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._throwingSequence._storage._didSuspend = { suspended.fulfill() } + async let element = sequence.first { _ in true } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) self.source.finish() @@ -307,39 +315,83 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { // MARK: - Source Deinited func testSourceDeinited_whenInitial() async { - self.source = nil + var newSequence: NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil + + source = nil + XCTAssertNil(source) + XCTAssertNotNil(sequence) } func testSourceDeinited_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered - let sequence = try XCTUnwrap(self.sequence) + var newSequence: NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil + let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + let suspended = expectation(description: "task suspended") + sequence!._throwingSequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { - let element = await sequence.first { _ in true } + let element = await sequence!.first { _ in true } return element } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - self.source = nil + source = nil return try await group.next() ?? nil } XCTAssertEqual(element, nil) + XCTAssertNil(source) XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferEmpty() async throws { - _ = self.source.yield(contentsOf: []) + var newSequence: NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil - self.source = nil + _ = source!.yield(contentsOf: []) + + source = nil - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return await sequence.first { _ in true } + return await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -350,14 +402,27 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferNotEmpty() async throws { - _ = self.source.yield(contentsOf: [1]) + var newSequence: NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil - self.source = nil + _ = source!.yield(contentsOf: [1]) + + source = nil - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return await sequence.first { _ in true } + return await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -371,14 +436,18 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { // MARK: - Task cancel func testTaskCancel_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._throwingSequence._storage._didSuspend = { suspended.fulfill() } + + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) task.cancel() let value = await task.value @@ -387,41 +456,51 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } func testTaskCancel_whenStreaming_andNotSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + let resumed = expectation(description: "task resumed") + let cancelled = expectation(description: "task cancelled") + + sequence._throwingSequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() - let value = await iterator.next() - // Sleeping here a bit to make sure we hit the case where - // we are streaming and still retain the iterator. - try? await Task.sleep(nanoseconds: 1_000_000) + let value = await iterator.next() + resumed.fulfill() + await fulfillment(of: [cancelled], timeout: 1) return value } - try await Task.sleep(nanoseconds: 2_000_000) - + await fulfillment(of: [suspended], timeout: 1) _ = self.source.yield(contentsOf: [1]) + await fulfillment(of: [resumed], timeout: 1) task.cancel() + cancelled.fulfill() + let value = await task.value XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) XCTAssertEqual(value, 1) } func testTaskCancel_whenSourceFinished() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._throwingSequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) self.source.finish() + XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) task.cancel() let value = await task.value @@ -430,15 +509,17 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { func testTaskCancel_whenStreaming_andTaskIsAlreadyCancelled() async throws { let sequence = try XCTUnwrap(self.sequence) + + let cancelled = expectation(description: "task cancelled") + let task: Task = Task { - // We are sleeping here to allow some time for us to cancel the task. - // Once the Task is cancelled we will call `next()` - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) let iterator = sequence.makeAsyncIterator() return await iterator.next() } task.cancel() + cancelled.fulfill() let value = await task.value @@ -614,6 +695,7 @@ fileprivate func XCTAssertEqualWithoutAutoclosure( XCTAssertEqual(expression1, expression2, message(), file: file, line: line) } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension AsyncSequence { /// Collect all elements in the sequence into an array. fileprivate func collect() async rethrows -> [Element] { diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift index 3acb2c0ffc..9b93841252 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift @@ -13,27 +13,46 @@ //===----------------------------------------------------------------------===// import DequeModule -import NIOCore +@testable import NIOCore import XCTest +import NIOConcurrencyHelpers private struct SomeError: Error, Hashable {} private final class MockAsyncWriterDelegate: NIOAsyncWriterSinkDelegate, @unchecked Sendable { typealias Element = String - var didYieldCallCount = 0 + var _didYieldCallCount = NIOLockedValueBox(0) + var didYieldCallCount: Int { + self._didYieldCallCount.withLockedValue { $0 } + } var didYieldHandler: ((Deque) -> Void)? func didYield(contentsOf sequence: Deque) { - self.didYieldCallCount += 1 + self._didYieldCallCount.withLockedValue { $0 += 1 } if let didYieldHandler = self.didYieldHandler { didYieldHandler(sequence) } } - var didTerminateCallCount = 0 + var _didSuspendCallCount = NIOLockedValueBox(0) + var didSuspendCallCount: Int { + self._didSuspendCallCount.withLockedValue { $0 } + } + var didSuspendHandler: (() -> Void)? + func didSuspend() { + self._didSuspendCallCount.withLockedValue { $0 += 1 } + if let didSuspendHandler = self.didSuspendHandler { + didSuspendHandler() + } + } + + var _didTerminateCallCount = NIOLockedValueBox(0) + var didTerminateCallCount: Int { + self._didTerminateCallCount.withLockedValue { $0 } + } var didTerminateHandler: ((Error?) -> Void)? func didTerminate(error: Error?) { - self.didTerminateCallCount += 1 + self._didTerminateCallCount.withLockedValue { $0 += 1 } if let didTerminateHandler = self.didTerminateHandler { didTerminateHandler(error) } @@ -53,13 +72,21 @@ final class NIOAsyncWriterTests: XCTestCase { let newWriter = NIOAsyncWriter.makeWriter( elementType: String.self, isWritable: true, + finishOnDeinit: false, delegate: self.delegate ) self.writer = newWriter.writer self.sink = newWriter.sink + self.sink._storage._didSuspend = self.delegate.didSuspend } override func tearDown() { + if let writer = self.writer { + writer.finish() + } + if let sink = self.sink { + sink.finish() + } self.delegate = nil self.writer = nil self.sink = nil @@ -67,7 +94,21 @@ final class NIOAsyncWriterTests: XCTestCase { super.tearDown() } + func assert( + suspendCallCount: Int, + yieldCallCount: Int, + terminateCallCount: Int, + file: StaticString = #filePath, + line: UInt = #line + ) { + XCTAssertEqual(self.delegate.didSuspendCallCount, suspendCallCount, "Unexpeced suspends", file: file, line: line) + XCTAssertEqual(self.delegate.didYieldCallCount, yieldCallCount, "Unexpected yields", file: file, line: line) + XCTAssertEqual(self.delegate.didTerminateCallCount, terminateCallCount, "Unexpected terminates", file: file, line: line) + } + func testMultipleConcurrentWrites() async throws { + var elements = 0 + self.delegate.didYieldHandler = { elements += $0.count } let task1 = Task { [writer] in for i in 0...9 { try await writer!.yield("message\(i)") @@ -88,52 +129,74 @@ final class NIOAsyncWriterTests: XCTestCase { try await task2.value try await task3.value - XCTAssertEqual(self.delegate.didYieldCallCount, 30) + XCTAssertEqual(elements, 30) } - func testWriterCoalescesWrites() async throws { - var writes = [Deque]() - self.delegate.didYieldHandler = { - writes.append($0) - } - self.sink.setWritability(to: false) - + func testMultipleConcurrentBatchWrites() async throws { + var elements = 0 + self.delegate.didYieldHandler = { elements += $0.count } let task1 = Task { [writer] in - try await writer!.yield("message1") + for i in 0...9 { + try await writer!.yield(contentsOf: ["message\(i).1", "message\(i).2"]) + } } - task1.cancel() - try await task1.value - let task2 = Task { [writer] in - try await writer!.yield("message2") + for i in 10...19 { + try await writer!.yield(contentsOf: ["message\(i).1", "message\(i).2"]) + } } - task2.cancel() - try await task2.value - let task3 = Task { [writer] in - try await writer!.yield("message3") + for i in 20...29 { + try await writer!.yield(contentsOf: ["message\(i).1", "message\(i).2"]) + } } - task3.cancel() - try await task3.value - self.sink.setWritability(to: true) + try await task1.value + try await task2.value + try await task3.value - XCTAssertEqual(writes, [Deque(["message1", "message2", "message3"])]) + XCTAssertEqual(elements, 60) } // MARK: - WriterDeinitialized func testWriterDeinitialized_whenInitial() async throws { - self.writer = nil + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + elementType: String.self, + isWritable: true, + finishOnDeinit: true, + delegate: self.delegate + ) + let sink = newWriter!.sink + var writer: NIOAsyncWriter? = newWriter!.writer + newWriter = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + writer = nil + + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) + XCTAssertNil(writer) + + sink.finish() } func testWriterDeinitialized_whenStreaming() async throws { - try await writer.yield("message1") - self.writer = nil + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + elementType: String.self, + isWritable: true, + finishOnDeinit: true, + delegate: self.delegate + ) + let sink = newWriter!.sink + var writer: NIOAsyncWriter? = newWriter!.writer + newWriter = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + try await writer!.yield("message1") + writer = nil + + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 1) + XCTAssertNil(writer) + + sink.finish() } func testWriterDeinitialized_whenWriterFinished() async throws { @@ -141,18 +204,17 @@ final class NIOAsyncWriterTests: XCTestCase { self.writer.finish() self.writer = nil - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 1) } func testWriterDeinitialized_whenFinished() async throws { self.sink.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) self.writer = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } // MARK: - ToggleWritability @@ -160,15 +222,18 @@ final class NIOAsyncWriterTests: XCTestCase { func testSetWritability_whenInitial() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message1") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 0) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) } func testSetWritability_whenStreaming_andBecomingUnwritable() async throws { @@ -177,69 +242,88 @@ final class NIOAsyncWriterTests: XCTestCase { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message2") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) } func testSetWritability_whenStreaming_andBecomingWritable() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + let resumed = expectation(description: "yield completed") + Task { [writer] in try await writer!.yield("message2") + resumed.fulfill() } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.sink.setWritability(to: true) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + await fulfillment(of: [resumed], timeout: 1) + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) } func testSetWritability_whenStreaming_andSettingSameWritability() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message1") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) // Setting the writability to the same state again shouldn't change anything self.sink.setWritability(to: false) - XCTAssertEqual(self.delegate.didYieldCallCount, 0) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) } func testSetWritability_whenWriterFinished() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + let resumed = expectation(description: "yield completed") + Task { [writer] in try await writer!.yield("message1") + resumed.fulfill() } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.writer.finish() - XCTAssertEqual(self.delegate.didYieldCallCount, 0) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) self.sink.setWritability(to: true) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + await fulfillment(of: [resumed], timeout: 1) + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 1) } func testSetWritability_whenFinished() async throws { @@ -247,7 +331,7 @@ final class NIOAsyncWriterTests: XCTestCase { self.sink.setWritability(to: false) - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } // MARK: - Yield @@ -255,84 +339,98 @@ final class NIOAsyncWriterTests: XCTestCase { func testYield_whenInitial_andWritable() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) } func testYield_whenInitial_andNotWritable() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message2") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) } func testYield_whenStreaming_andWritable() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) try await self.writer.yield("message2") - XCTAssertEqual(self.delegate.didYieldCallCount, 2) + self.assert(suspendCallCount: 0, yieldCallCount: 2, terminateCallCount: 0) } func testYield_whenStreaming_andNotWritable() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message2") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) } func testYield_whenStreaming_andYieldCancelled() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) + + let cancelled = expectation(description: "task cancelled") let task = Task { [writer] in - // Sleeping here a bit to delay the call to yield - // The idea is that we call yield once the Task is - // already cancelled - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) try await writer!.yield("message2") } task.cancel() + cancelled.fulfill() - await XCTAssertNoThrow(try await task.value) - XCTAssertEqual(self.delegate.didYieldCallCount, 2) + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError) + } + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) } func testYield_whenWriterFinished() async throws { self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + Task { [writer] in try await writer!.yield("message1") } - // Sleep a bit so that the other Task suspends on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.writer.finish() await XCTAssertThrowsError(try await self.writer.yield("message1")) { error in XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) } - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) } func testYield_whenFinished() async throws { @@ -341,7 +439,7 @@ final class NIOAsyncWriterTests: XCTestCase { await XCTAssertThrowsError(try await self.writer.yield("message1")) { error in XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) } - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } func testYield_whenFinishedError() async throws { @@ -350,69 +448,80 @@ final class NIOAsyncWriterTests: XCTestCase { await XCTAssertThrowsError(try await self.writer.yield("message1")) { error in XCTAssertTrue(error is SomeError) } - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } // MARK: - Cancel func testCancel_whenInitial() async throws { + let cancelled = expectation(description: "task cancelled") + let task = Task { [writer] in - // Sleeping here a bit to delay the call to yield - // The idea is that we call yield once the Task is - // already cancelled - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) try await writer!.yield("message1") } task.cancel() + cancelled.fulfill() - await XCTAssertNoThrow(try await task.value) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError) + } + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } func testCancel_whenStreaming_andCancelBeforeYield() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) + + let cancelled = expectation(description: "task cancelled") let task = Task { [writer] in - // Sleeping here a bit to delay the call to yield - // The idea is that we call yield once the Task is - // already cancelled - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) try await writer!.yield("message2") } task.cancel() + cancelled.fulfill() - await XCTAssertNoThrow(try await task.value) - XCTAssertEqual(self.delegate.didYieldCallCount, 2) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError) + } + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) } func testCancel_whenStreaming_andCancelAfterSuspendedYield() async throws { try await self.writer.yield("message1") - XCTAssertEqual(self.delegate.didYieldCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) self.sink.setWritability(to: false) + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } + let task = Task { [writer] in try await writer!.yield("message2") } - // Sleeping here to give the task enough time to suspend on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) task.cancel() - await XCTAssertNoThrow(try await task.value) - XCTAssertEqual(self.delegate.didYieldCallCount, 1) - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + await XCTAssertThrowsError(try await task.value) { error in + XCTAssertTrue(error is CancellationError) + } + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) self.sink.setWritability(to: true) - XCTAssertEqual(self.delegate.didYieldCallCount, 2) + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 0) } func testCancel_whenFinished() async throws { @@ -420,23 +529,20 @@ final class NIOAsyncWriterTests: XCTestCase { XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + let cancelled = expectation(description: "task cancelled") + let task = Task { [writer] in - // Sleeping here a bit to delay the call to yield - // The idea is that we call yield once the Task is - // already cancelled - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) try await writer!.yield("message1") } - // Sleeping here to give the task enough time to suspend on the yield - try await Task.sleep(nanoseconds: 1_000_000) - task.cancel() + cancelled.fulfill() - XCTAssertEqual(self.delegate.didYieldCallCount, 0) await XCTAssertThrowsError(try await task.value) { error in XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) } + XCTAssertEqual(self.delegate.didYieldCallCount, 0) } // MARK: - Writer Finish @@ -444,13 +550,13 @@ final class NIOAsyncWriterTests: XCTestCase { func testWriterFinish_whenInitial() async throws { self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) } func testWriterFinish_whenInitial_andFailure() async throws { self.writer.finish(error: SomeError()) - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) } func testWriterFinish_whenStreaming() async throws { @@ -458,62 +564,102 @@ final class NIOAsyncWriterTests: XCTestCase { self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 1) } func testWriterFinish_whenStreaming_AndBufferedElements() async throws { // We are setting up a suspended yield here to check that it gets resumed self.sink.setWritability(to: false) + + let suspended = expectation(description: "suspended on yield") + self.delegate.didSuspendHandler = { + suspended.fulfill() + } let task = Task { [writer] in try await writer!.yield("message1") } - - // Sleeping here to give the task enough time to suspend on the yield - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 0) + self.assert(suspendCallCount: 1, yieldCallCount: 0, terminateCallCount: 0) + + // We have to become writable again to unbuffer the yield + self.sink.setWritability(to: true) + await XCTAssertNoThrow(try await task.value) + + self.assert(suspendCallCount: 1, yieldCallCount: 1, terminateCallCount: 1) } func testWriterFinish_whenFinished() { // This tests just checks that finishing again is a no-op self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) } // MARK: - Sink Finish func testSinkFinish_whenInitial() async throws { - self.sink = nil + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + elementType: String.self, + isWritable: true, + finishOnDeinit: true, + delegate: self.delegate + ) + var sink: NIOAsyncWriter.Sink? = newWriter!.sink + let writer = newWriter!.writer + newWriter = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + sink = nil + + XCTAssertNil(sink) + XCTAssertNotNil(writer) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 0) } func testSinkFinish_whenStreaming() async throws { - Task { [writer] in - try await writer!.yield("message1") - } + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + elementType: String.self, + isWritable: true, + finishOnDeinit: true, + delegate: self.delegate + ) + var sink: NIOAsyncWriter.Sink? = newWriter!.sink + let writer = newWriter!.writer + newWriter = nil - try await Task.sleep(nanoseconds: 1_000_000) + try await writer.yield("message1") - self.sink = nil + sink = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + XCTAssertNil(sink) + self.assert(suspendCallCount: 0, yieldCallCount: 1, terminateCallCount: 0) } func testSinkFinish_whenFinished() async throws { self.writer.finish() - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) self.sink = nil - XCTAssertEqual(self.delegate.didTerminateCallCount, 1) + self.assert(suspendCallCount: 0, yieldCallCount: 0, terminateCallCount: 1) + } +} + +#if !canImport(Darwin) && swift(<5.9.2) +extension XCTestCase { + func fulfillment( + of expectations: [XCTestExpectation], + timeout seconds: TimeInterval, + enforceOrder enforceOrderOfFulfillment: Bool = false + ) async { + wait(for: expectations, timeout: seconds) } } +#endif diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift index ebd38f87db..3e71d71988 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -import NIOCore +@testable import NIOCore import XCTest @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) @@ -41,6 +41,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { elementType: Int.self, failureType: Error.self, backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: false, delegate: self.delegate ) self.source = result.source @@ -51,6 +52,8 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { self.backPressureStrategy = nil self.delegate = nil self.sequence = nil + self.source.finish() + self.source = nil super.tearDown() } @@ -100,15 +103,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testYield_whenStreaming_andSuspended_andStopProducing() async throws { self.backPressureStrategy.didYieldHandler = { _ in false } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) let result = self.source.yield(contentsOf: [1]) @@ -125,15 +131,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testYield_whenStreaming_andSuspended_andProduceMore() async throws { self.backPressureStrategy.didYieldHandler = { _ in true } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) let result = self.source.yield(contentsOf: [1]) @@ -150,15 +159,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testYieldEmptySequence_whenStreaming_andSuspended_andStopProducing() async throws { self.backPressureStrategy.didYieldHandler = { _ in false } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) - try await withThrowingTaskGroup(of: Void.self) { group in + await withThrowingTaskGroup(of: Void.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) let result = self.source.yield(contentsOf: []) @@ -172,15 +184,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testYieldEmptySequence_whenStreaming_andSuspended_andProduceMore() async throws { self.backPressureStrategy.didYieldHandler = { _ in true } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) - try await withThrowingTaskGroup(of: Void.self) { group in + await withThrowingTaskGroup(of: Void.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) + XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) let result = self.source.yield(contentsOf: []) @@ -228,16 +243,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testFinish_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { let element = try await sequence.first { _ in true } return element } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.source.finish() @@ -309,15 +326,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testFinishError_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) await XCTAssertThrowsError(try await withThrowingTaskGroup(of: Void.self) { group in + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + group.addTask { _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) self.source.finish(ChannelError.alreadyClosed) @@ -383,39 +402,88 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { // MARK: - Source Deinited func testSourceDeinited_whenInitial() async { - self.source = nil + var newSequence: NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil + + source = nil + + XCTAssertNil(source) + XCTAssertNotNil(sequence) } func testSourceDeinited_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered - let sequence = try XCTUnwrap(self.sequence) + var newSequence: NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil + let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in + + let suspended = expectation(description: "task suspended") + sequence!._storage._didSuspend = { suspended.fulfill() } + group.addTask { - let element = try await sequence.first { _ in true } + let element = try await sequence!.first { _ in true } return element } - try await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [suspended], timeout: 1) - self.source = nil + source = nil return try await group.next() ?? nil } XCTAssertEqual(element, nil) + XCTAssertNil(source) XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferEmpty() async throws { - _ = self.source.yield(contentsOf: []) + var newSequence: NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil - self.source = nil + _ = source!.yield(contentsOf: []) + + source = nil - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await sequence.first { _ in true } + return try await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -426,14 +494,28 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferNotEmpty() async throws { - _ = self.source.yield(contentsOf: [1]) + var newSequence: NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) + let sequence = newSequence?.sequence + var source = newSequence?.source + newSequence = nil - self.source = nil + _ = source!.yield(contentsOf: [1]) + + source = nil - let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await sequence.first { _ in true } + return try await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -447,14 +529,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { // MARK: - Task cancel func testTaskCancel_whenStreaming_andSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return try await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) task.cancel() let result = await task.result @@ -467,8 +552,6 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { @available(*, deprecated, message: "tests the deprecated custom generic failure type") func testTaskCancel_whenStreaming_andSuspended_withCustomErrorType() async throws { struct CustomError: Error {} - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let backPressureStrategy = MockNIOElementStreamBackPressureStrategy() let delegate = MockNIOBackPressuredStreamSourceDelegate() let new = NIOThrowingAsyncSequenceProducer.makeSequence( @@ -478,11 +561,16 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { delegate: delegate ) let sequence = new.sequence + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return try await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) task.cancel() let result = await task.result @@ -494,36 +582,47 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testTaskCancel_whenStreaming_andNotSuspended() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + let suspended = expectation(description: "task suspended") + let resumed = expectation(description: "task resumed") + let cancelled = expectation(description: "task cancelled") + + sequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() let element = try await iterator.next() - - // Sleeping here to give the other Task a chance to cancel this one. - try? await Task.sleep(nanoseconds: 1_000_000) + resumed.fulfill() + await fulfillment(of: [cancelled], timeout: 1) return element } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) _ = self.source.yield(contentsOf: [1]) + await fulfillment(of: [resumed], timeout: 1) + task.cancel() + cancelled.fulfill() + let value = try await task.value XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) XCTAssertEqual(value, 1) } func testTaskCancel_whenSourceFinished() async throws { - // We are registering our demand and sleeping a bit to make - // sure our task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + let task: Task = Task { let iterator = sequence.makeAsyncIterator() return try await iterator.next() } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) self.source.finish() XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) @@ -534,15 +633,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testTaskCancel_whenStreaming_andTaskIsAlreadyCancelled() async throws { let sequence = try XCTUnwrap(self.sequence) + + let cancelled = expectation(description: "task cancelled") + let task: Task = Task { - // We are sleeping here to allow some time for us to cancel the task. - // Once the Task is cancelled we will call `next()` - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) let iterator = sequence.makeAsyncIterator() return try await iterator.next() } task.cancel() + cancelled.fulfill() let result = await task.result @@ -563,16 +664,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { delegate: delegate ) let sequence = new.sequence + + let cancelled = expectation(description: "task cancelled") + let task: Task = Task { - // We are sleeping here to allow some time for us to cancel the task. - // Once the Task is cancelled we will call `next()` - try? await Task.sleep(nanoseconds: 1_000_000) + await fulfillment(of: [cancelled], timeout: 1) let iterator = sequence.makeAsyncIterator() return try await iterator.next() } task.cancel() - + cancelled.fulfill() + let result = await task.result try withExtendedLifetime(new.source) { XCTAssertNil(try result.get()) @@ -583,14 +686,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testNext_whenInitial_whenDemand() async throws { self.backPressureStrategy.didNextHandler = { _ in true } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + Task { // Would prefer to use async let _ here but that is not allowed yet _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.produceMore]) @@ -598,14 +704,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testNext_whenInitial_whenNoDemand() async throws { self.backPressureStrategy.didNextHandler = { _ in false } - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + Task { // Would prefer to use async let _ here but that is not allowed yet _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) } @@ -615,14 +724,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { _ = self.source.yield(contentsOf: []) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didYield]) - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + Task { // Would prefer to use async let _ here but that is not allowed yet _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) } @@ -632,14 +744,17 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { _ = self.source.yield(contentsOf: []) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didYield]) - // We are registering our demand and sleeping a bit to make - // sure the other child task runs when the demand is registered let sequence = try XCTUnwrap(self.sequence) + + let suspended = expectation(description: "task suspended") + sequence._storage._didSuspend = { suspended.fulfill() } + Task { // Would prefer to use async let _ here but that is not allowed yet _ = try await sequence.first { _ in true } } - try await Task.sleep(nanoseconds: 1_000_000) + + await fulfillment(of: [suspended], timeout: 1) XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(1).collect(), [.didNext]) } @@ -786,6 +901,7 @@ fileprivate func XCTAssertEqualWithoutAutoclosure( XCTAssertEqual(expression1, expression2, message(), file: file, line: line) } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension AsyncSequence { /// Collect all elements in the sequence into an array. fileprivate func collect() async rethrows -> [Element] { diff --git a/Tests/NIOCoreTests/TimeAmountTests.swift b/Tests/NIOCoreTests/TimeAmountTests.swift index dbee1d4831..7fb6a4efe8 100644 --- a/Tests/NIOCoreTests/TimeAmountTests.swift +++ b/Tests/NIOCoreTests/TimeAmountTests.swift @@ -43,4 +43,22 @@ class TimeAmountTests: XCTestCase { lhs -= rhs XCTAssertEqual(lhs, .nanoseconds(0)) } + + func testTimeAmountCappedOverflow() { + let overflowCap = TimeAmount.nanoseconds(Int64.max) + XCTAssertEqual(TimeAmount.microseconds(.max), overflowCap) + XCTAssertEqual(TimeAmount.milliseconds(.max), overflowCap) + XCTAssertEqual(TimeAmount.seconds(.max), overflowCap) + XCTAssertEqual(TimeAmount.minutes(.max), overflowCap) + XCTAssertEqual(TimeAmount.hours(.max), overflowCap) + } + + func testTimeAmountCappedUnderflow() { + let underflowCap = TimeAmount.nanoseconds(.min) + XCTAssertEqual(TimeAmount.microseconds(.min), underflowCap) + XCTAssertEqual(TimeAmount.milliseconds(.min), underflowCap) + XCTAssertEqual(TimeAmount.seconds(.min), underflowCap) + XCTAssertEqual(TimeAmount.minutes(.min), underflowCap) + XCTAssertEqual(TimeAmount.hours(.min), underflowCap) + } } diff --git a/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift b/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift index c3122e9b84..5a40874674 100644 --- a/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift +++ b/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift @@ -17,9 +17,9 @@ import Atomics import NIOCore @testable import NIOEmbedded +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) class AsyncTestingChannelTests: XCTestCase { func testSingleHandlerInit() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } class Handler: ChannelInboundHandler { typealias InboundIn = Never } @@ -29,7 +29,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testEmptyInit() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } class Handler: ChannelInboundHandler { typealias InboundIn = Never @@ -43,7 +42,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testMultipleHandlerInit() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } class Handler: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = Never let identifier: String @@ -67,7 +65,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWaitForInboundWrite() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let task = Task { try await XCTAsyncAssertEqual(try await channel.waitForInboundWrite(), 1) @@ -82,7 +79,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWaitForMultipleInboundWritesInParallel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let task = Task { let task1 = Task { try await channel.waitForInboundWrite(as: Int.self) } @@ -102,7 +98,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWaitForOutboundWrite() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let task = Task { try await XCTAsyncAssertEqual(try await channel.waitForOutboundWrite(), 1) @@ -117,7 +112,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWaitForMultipleOutboundWritesInParallel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let task = Task { let task1 = Task { try await channel.waitForOutboundWrite(as: Int.self) } @@ -137,7 +131,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteOutboundByteBuffer() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() var buf = channel.allocator.buffer(capacity: 1024) buf.writeString("hello") @@ -158,7 +151,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteOutboundByteBufferMultipleTimes() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() var buf = channel.allocator.buffer(capacity: 1024) buf.writeString("hello") @@ -179,7 +171,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteInboundByteBuffer() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() var buf = channel.allocator.buffer(capacity: 1024) buf.writeString("hello") @@ -192,7 +183,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteInboundByteBufferMultipleTimes() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() var buf = channel.allocator.buffer(capacity: 1024) buf.writeString("hello") @@ -213,7 +203,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteInboundByteBufferReThrow() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(ExceptionThrowingInboundHandler()).wait()) await XCTAsyncAssertThrowsError(try await channel.writeInbound("msg")) { error in @@ -223,7 +212,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteOutboundByteBufferReThrow() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(ExceptionThrowingOutboundHandler()).wait()) await XCTAsyncAssertThrowsError(try await channel.writeOutbound("msg")) { error in @@ -233,7 +221,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testReadOutboundWrongTypeThrows() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try await XCTAsyncAssertTrue(await channel.writeOutbound("hello").isFull) do { @@ -248,7 +235,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testReadInboundWrongTypeThrows() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try await XCTAsyncAssertTrue(await channel.writeInbound("hello").isFull) do { @@ -263,7 +249,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWrongTypesWithFastpathTypes() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let buffer = channel.allocator.buffer(capacity: 0) @@ -312,7 +297,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testCloseMultipleTimesThrows() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try await XCTAsyncAssertTrue(await channel.finish().isClean) @@ -326,7 +310,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testCloseOnInactiveIsOk() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let inactiveHandler = CloseInChannelInactiveHandler() XCTAssertNoThrow(try channel.pipeline.addHandler(inactiveHandler).wait()) @@ -337,7 +320,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testEmbeddedLifecycle() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let handler = ChannelLifecycleHandler() XCTAssertEqual(handler.currentState, .unregistered) @@ -383,7 +365,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testEmbeddedChannelAndPipelineAndChannelCoreShareTheEventLoop() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let pipelineEventLoop = channel.pipeline.eventLoop XCTAssert(pipelineEventLoop === channel.eventLoop) @@ -392,7 +373,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSendingAnythingOnEmbeddedChannel() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let buffer = ByteBufferAllocator().buffer(capacity: 5) let socketAddress = try SocketAddress(unixDomainSocketPath: "path") @@ -411,7 +391,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testActiveWhenConnectPromiseFiresAndInactiveWhenClosePromiseFires() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() XCTAssertFalse(channel.isActive) let connectPromise = channel.eventLoop.makePromise(of: Void.self) @@ -431,7 +410,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testWriteWithoutFlushDoesNotWrite() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let buf = ByteBuffer(bytes: [1]) @@ -445,7 +423,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSetLocalAddressAfterSuccessfulBind() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let bindPromise = channel.eventLoop.makePromise(of: Void.self) @@ -459,7 +436,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSetRemoteAddressAfterSuccessfulConnect() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let connectPromise = channel.eventLoop.makePromise(of: Void.self) @@ -473,7 +449,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testUnprocessedOutboundUserEventFailsOnEmbeddedChannel() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() XCTAssertThrowsError(try channel.triggerUserOutboundEvent("event").wait()) { (error: Error) in @@ -487,7 +462,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testEmbeddedChannelWritabilityIsWritable() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let opaqueChannel: Channel = channel @@ -500,7 +474,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testFinishWithRecursivelyScheduledTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() let invocations = AtomicCounter() @@ -518,7 +491,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSyncOptionsAreSupported() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try channel.testingEventLoop.submit { let options = channel.syncOptions @@ -530,7 +502,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testGetChannelOptionAutoReadIsSupported() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try channel.testingEventLoop.submit { let options = channel.syncOptions @@ -541,7 +512,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSetGetChannelOptionAllowRemoteHalfClosureIsSupported() throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() try channel.testingEventLoop.submit { let options = channel.syncOptions @@ -559,7 +529,6 @@ class AsyncTestingChannelTests: XCTestCase { } func testSecondFinishThrows() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let channel = NIOAsyncTestingChannel() _ = try await channel.finish() await XCTAsyncAssertThrowsError(try await channel.finish()) diff --git a/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift b/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift index 334d1c50e1..0e67bd99d2 100644 --- a/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift +++ b/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift @@ -19,9 +19,9 @@ import Atomics private class EmbeddedTestError: Error { } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class NIOAsyncTestingEventLoopTests: XCTestCase { func testExecuteDoesNotImmediatelyRunTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() try await loop.executeInContext { @@ -33,7 +33,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecuteWillRunAllTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let runCount = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() loop.execute { runCount.wrappingIncrement(ordering: .relaxed) } @@ -50,7 +49,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecuteWillRunTasksAddedRecursively() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() @@ -80,7 +78,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecuteRunsImmediately() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() loop.execute { callbackRan.store(true, ordering: .relaxed) } @@ -99,7 +96,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testTasksScheduledAfterRunDontRun() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() loop.scheduleTask(deadline: loop.now) { callbackRan.store(true, ordering: .relaxed) } @@ -120,7 +116,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testSubmitRunsImmediately() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() _ = loop.submit { callbackRan.store(true, ordering: .relaxed) } @@ -139,7 +134,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testSyncShutdownGracefullyRunsTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() loop.scheduleTask(deadline: loop.now) { callbackRan.store(true, ordering: .relaxed) } @@ -154,7 +148,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testShutdownGracefullyRunsTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackRan = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() loop.scheduleTask(deadline: loop.now) { callbackRan.store(true, ordering: .relaxed) } @@ -169,7 +162,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testCanControlTime() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let callbackCount = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() _ = loop.scheduleTask(in: .nanoseconds(5)) { @@ -198,7 +190,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testCanScheduleMultipleTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() for index in 1...10 { @@ -219,7 +210,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecutedTasksFromScheduledOnesAreRun() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() _ = loop.scheduleTask(in: .nanoseconds(5)) { @@ -240,7 +230,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksFromScheduledTasksProperlySchedule() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() _ = loop.scheduleTask(in: .nanoseconds(5)) { @@ -280,7 +269,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksFromExecutedTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let sentinel = ManagedAtomic(0) let loop = NIOAsyncTestingEventLoop() loop.execute { @@ -299,7 +287,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testCancellingScheduledTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let loop = NIOAsyncTestingEventLoop() let task = loop.scheduleTask(in: .nanoseconds(10), { XCTFail("Cancelled task ran") }) _ = loop.scheduleTask(in: .nanoseconds(5)) { @@ -310,7 +297,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksFuturesFire() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let fired = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() let task = loop.scheduleTask(in: .nanoseconds(5)) { true } @@ -323,7 +309,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksFuturesError() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let err = EmbeddedTestError() let fired = ManagedAtomic(false) let loop = NIOAsyncTestingEventLoop() @@ -345,7 +330,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testTaskOrdering() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } // This test validates that the ordering of task firing on NIOAsyncTestingEventLoop via // advanceTime(by:) is the same as on MultiThreadedEventLoopGroup: specifically, that tasks run via // schedule that expire "now" all run at the same time, and that any work they schedule is run @@ -426,7 +410,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testCancelledScheduledTasksDoNotHoldOnToRunClosure() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() defer { XCTAssertNoThrow(try eventLoop.syncShutdownGracefully()) @@ -466,7 +449,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testDrainScheduledTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let tasksRun = ManagedAtomic(0) let startTime = eventLoop.now @@ -486,7 +468,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testDrainScheduledTasksDoesNotRunNewlyScheduledTasks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let tasksRun = ManagedAtomic(0) @@ -503,7 +484,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testAdvanceTimeToDeadline() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let deadline = NIODeadline.uptimeNanoseconds(0) + .seconds(42) @@ -517,7 +497,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testWeCantTimeTravelByAdvancingTimeToThePast() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let tasksRun = ManagedAtomic(0) @@ -539,7 +518,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testExecuteInOrder() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let counter = ManagedAtomic(0) @@ -563,7 +541,6 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { } func testScheduledTasksInOrder() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let eventLoop = NIOAsyncTestingEventLoop() let counter = ManagedAtomic(0) diff --git a/Tests/NIOHTTP1Tests/ContentLengthTests.swift b/Tests/NIOHTTP1Tests/ContentLengthTests.swift new file mode 100644 index 0000000000..b295df6f47 --- /dev/null +++ b/Tests/NIOHTTP1Tests/ContentLengthTests.swift @@ -0,0 +1,131 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import XCTest +import NIOCore +import NIOEmbedded +import NIOHTTP1 + +final class ContentLengthTests: XCTestCase { + + /// Client receives a response longer than the content-length header + func testResponseContentTooLong() throws { + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHTTPClientHandlers() + defer { + _ = try? channel.finish() + } + // Receive a response with a content-length header of 2 but a body of more than 2 bytes + let badResponse = "HTTP/1.1 200 OK\r\nServer: foo\r\nContent-Length: 2\r\n\r\ntoo many bytes" + + XCTAssertThrowsError(try channel.sendRequestAndReceiveResponse(response: badResponse)) { error in + XCTAssertEqual(error as? HTTPParserError, .invalidConstant) + } + + channel.embeddedEventLoop.run() + } + + /// Client receives a response shorter than the content-length header + func testResponseContentTooShort() throws { + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHTTPClientHandlers() + defer { + _ = try? channel.finish() + } + // Receive a response with a content-length header of 100 but a body of less than 100 bytes + let badResponse = "HTTP/1.1 200 OK\r\nServer: foo\r\nContent-Length: 100\r\n\r\nnot many bytes" + + // First is successful, it just waits for more bytes + XCTAssertNoThrow(try channel.sendRequestAndReceiveResponse(response: badResponse)) + // It is waiting for 100-14 = 86 more bytes + // We will send the same response again (75 bytes) + // The client will consider this as part of the body of the previous response. No error expected + XCTAssertNoThrow(try channel.sendRequestAndReceiveResponse(response: badResponse)) + // Now the client is expected only 86-75 = 11 bytes. We wil send the same 75 byte request again + // An error is expected because everything from the 12th byte forward will be parsed as a new message, which isn't well formed + XCTAssertThrowsError(try channel.sendRequestAndReceiveResponse(response: badResponse)) { error in + XCTAssertEqual(error as? HTTPParserError, .invalidConstant) + } + + channel.embeddedEventLoop.run() + } + + /// Server receives a request longer than the content-length header + func testRequestContentTooLong() throws { + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.configureHTTPServerPipeline() + defer { + _ = try? channel.finish() + } + // Receive a request with a content-length header of 2 but a body of more than 2 bytes + let badRequest = "POST / HTTP/1.1\r\nContent-Length: 2\r\n\r\nhello" + // First one is fine, the extra bytes will be treated as the next request + XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) + // Which means the next request is now malformed + XCTAssertThrowsError(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) { error in + XCTAssertEqual(error as? HTTPParserError, .invalidMethod) + } + + channel.embeddedEventLoop.run() + } + + /// Server receives a request shorter than the content-length header + func testRequestContentTooShort() throws { + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.configureHTTPServerPipeline() + defer { + _ = try? channel.finish() + } + // Receive a request with a content-length header of 100 but a body of less + let badRequest = "POST / HTTP/1.1\r\nContent-Length: 100\r\n\r\nnot many bytes" + // First one is fine, server will wait for 100-14 (86) further bytes to come + XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: false)) + // The full request (60 bytes) will be treated as the body of the original request + XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: false)) + // The original request is still 26 bytes short. Sending the request once more will complete it + XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) + // The leftover bytes from the previous write (we wrote 100 bytes where it wanted 26) will form a new malformed request + XCTAssertThrowsError(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) { error in + XCTAssertEqual(error as? HTTPParserError, .invalidMethod) + } + + channel.embeddedEventLoop.run() + } +} + +extension EmbeddedChannel { + /// Do a request-response cycle + /// Asserts that sending the request won't fail + /// Throws if receiving the response fails + fileprivate func sendRequestAndReceiveResponse(response: String) throws { + // Send a request + XCTAssertNoThrow(try self.writeOutbound(HTTPClientRequestPart.head(.init(version: .http1_1, method: .GET, uri: "/")))) + XCTAssertNoThrow(try self.writeOutbound(HTTPClientRequestPart.end(nil))) + // Receive a response + try self.writeInbound(ByteBuffer(string: response)) + } + + /// Do a response-request cycle + /// Throws if receiving the request fails + /// Asserts that sending the response won't fail + fileprivate func receiveRequestAndSendResponse(request: String, sendResponse: Bool) throws { + // Receive a request + try self.writeInbound(ByteBuffer(string: request)) + // Send a response + if sendResponse { + XCTAssertNoThrow(try self.writeOutbound(HTTPServerResponsePart.head(.init(version: .http1_1, status: .ok)))) + XCTAssertNoThrow(try self.writeOutbound(HTTPServerResponsePart.end(nil))) + } + } +} diff --git a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift index 42de8b87d5..a0cda42f73 100644 --- a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift @@ -32,31 +32,15 @@ extension EmbeddedChannel { } } -private func setUpClientChannel(clientHTTPHandler: RemovableChannelHandler, - clientUpgraders: [NIOHTTPClientProtocolUpgrader], - _ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void) throws -> EmbeddedChannel { - - let channel = EmbeddedChannel() - - let config: NIOHTTPClientUpgradeConfiguration = ( - upgraders: clientUpgraders, - completionHandler: { context in - channel.pipeline.removeHandler(clientHTTPHandler, promise: nil) - upgradeCompletionHandler(context) - }) - - try channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config).flatMap({ - channel.pipeline.addHandler(clientHTTPHandler) - }).wait() - - try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) - .wait() - - return channel -} +#if !canImport(Darwin) || swift(>=5.10) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader, NIOTypedHTTPClientProtocolUpgrader where UpgradeResult == Bool {} +#else +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader {} +#endif -private final class SuccessfulClientUpgrader: NIOHTTPClientProtocolUpgrader { - +private final class SuccessfulClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] fileprivate let upgradeHeaders: [(String,String)] @@ -87,10 +71,15 @@ private final class SuccessfulClientUpgrader: NIOHTTPClientProtocolUpgrader { self.upgradeContextResponseCallCount += 1 return context.channel.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + self.upgradeContextResponseCallCount += 1 + return channel.eventLoop.makeSucceededFuture(true) + } } -private final class ExplodingClientUpgrader: NIOHTTPClientProtocolUpgrader { - +private final class ExplodingClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { + fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] fileprivate let upgradeHeaders: [(String,String)] @@ -118,10 +107,15 @@ private final class ExplodingClientUpgrader: NIOHTTPClientProtocolUpgrader { XCTFail("Upgrade should not be called.") return context.channel.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + XCTFail("Upgrade should not be called.") + return channel.eventLoop.makeSucceededFuture(false) + } } -private final class DenyingClientUpgrader: NIOHTTPClientProtocolUpgrader { - +private final class DenyingClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { + fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] fileprivate let upgradeHeaders: [(String,String)] @@ -152,9 +146,14 @@ private final class DenyingClientUpgrader: NIOHTTPClientProtocolUpgrader { XCTFail("Upgrade should not be called.") return context.channel.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + XCTFail("Upgrade should not be called.") + return channel.eventLoop.makeSucceededFuture(false) + } } -private final class UpgradeDelayClientUpgrader: NIOHTTPClientProtocolUpgrader { +private final class UpgradeDelayClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] @@ -188,7 +187,14 @@ private final class UpgradeDelayClientUpgrader: NIOHTTPClientProtocolUpgrader { context.pipeline.addHandler(self.upgradedHandler) } } - + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + self.upgradePromise = channel.eventLoop.makePromise() + return self.upgradePromise!.futureResult.flatMap { + channel.pipeline.addHandler(self.upgradedHandler) + }.map { _ in true} + } + fileprivate func unblockUpgrade() { self.upgradePromise!.succeed(()) } @@ -278,8 +284,45 @@ private final class RecordingHTTPHandler: ChannelInboundHandler, RemovableChanne } } +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +private func assertPipelineContainsUpgradeHandler(channel: Channel) { + let handler = try? channel.pipeline.syncOperations.handler(type: NIOHTTPClientUpgradeHandler.self) + + #if !canImport(Darwin) || swift(>=5.10) + let typedHandler = try? channel.pipeline.syncOperations.handler(type: NIOTypedHTTPClientUpgradeHandler.self) + XCTAssertTrue(handler != nil || typedHandler != nil) + #else + XCTAssertTrue(handler != nil) + #endif +} + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) class HTTPClientUpgradeTestCase: XCTestCase { - + func setUpClientChannel( + clientHTTPHandler: RemovableChannelHandler, + clientUpgraders: [any TypedAndUntypedHTTPClientProtocolUpgrader], + _ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void + ) throws -> EmbeddedChannel { + + let channel = EmbeddedChannel() + + let config: NIOHTTPClientUpgradeConfiguration = ( + upgraders: clientUpgraders, + completionHandler: { context in + channel.pipeline.removeHandler(clientHTTPHandler, promise: nil) + upgradeCompletionHandler(context) + }) + + try channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config).flatMap({ + channel.pipeline.addHandler(clientHTTPHandler) + }).wait() + + try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + .wait() + + return channel + } + // MARK: Test basic happy path requests and responses. func testSimpleUpgradeSucceeds() throws { @@ -315,13 +358,10 @@ class HTTPClientUpgradeTestCase: XCTestCase { } // Validate the pipeline still has http handlers. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: NIOHTTPClientUpgradeHandler.self)) - + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) + assertPipelineContainsUpgradeHandler(channel: clientChannel) + // Push the successful server response. let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" @@ -400,8 +440,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { let clientUpgrader = SuccessfulClientUpgrader(forProtocol: upgradeProtocol, upgradeHeaders: clientHeaders) - let clientUpgraders: [NIOHTTPClientProtocolUpgrader] = [unusedClientUpgrader, clientUpgrader] - + let clientUpgraders: [any TypedAndUntypedHTTPClientProtocolUpgrader] = [unusedClientUpgrader, clientUpgrader] + // The process should kick-off independently by sending the upgrade request to the server. let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(), clientUpgraders: clientUpgraders) { (context) in @@ -456,7 +496,7 @@ class HTTPClientUpgradeTestCase: XCTestCase { } } - final class AddHandlerClientUpgrader: NIOHTTPClientProtocolUpgrader { + final class AddHandlerClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { fileprivate let requiredUpgradeHeaders: [String] = [] fileprivate let supportedProtocol: String fileprivate let handler: T @@ -475,6 +515,10 @@ class HTTPClientUpgradeTestCase: XCTestCase { func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { return context.pipeline.addHandler(handler) } + + func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + return channel.pipeline.addHandler(handler).map { _ in true } + } } var upgradeHandlerCallbackFired = false @@ -540,10 +584,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { clientChannel.embeddedEventLoop.run() // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) @@ -582,10 +624,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { // Should fail with error (response is malformed) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) @@ -625,10 +665,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { // Should fail with error (response is malformed) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertLessThanOrEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) @@ -670,10 +708,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { // Should fail with error (response is malformed) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) // Check that the HTTP handler received its response. XCTAssertLessThanOrEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) @@ -717,10 +753,8 @@ class HTTPClientUpgradeTestCase: XCTestCase { // Should fail with error (response is denied) and remove upgrader from pipeline. // Check that the http elements are not removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertContains(handlerType: ByteToMessageHandler.self)) + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) @@ -922,3 +956,235 @@ class HTTPClientUpgradeTestCase: XCTestCase { } } } + +#if !canImport(Darwin) || swift(>=5.10) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { + override func setUpClientChannel( + clientHTTPHandler: RemovableChannelHandler, + clientUpgraders: [any TypedAndUntypedHTTPClientProtocolUpgrader], + _ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void + ) throws -> EmbeddedChannel { + + let channel = EmbeddedChannel() + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "\(0)") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let upgraders: [any NIOTypedHTTPClientProtocolUpgrader] = Array(clientUpgraders.map { $0 as! any NIOTypedHTTPClientProtocolUpgrader }) + + let config = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: upgraders + ) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(clientHTTPHandler) + }.map { _ in + false + } + } + var configuration = NIOUpgradableHTTPClientPipelineConfiguration(upgradeConfiguration: config) + configuration.leftOverBytesStrategy = .forwardBytes + let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: configuration) + let context = try channel.pipeline.syncOperations.context(handlerType: NIOTypedHTTPClientUpgradeHandler.self) + + try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + .wait() + upgradeResult.whenSuccess { result in + if result { + upgradeCompletionHandler(context) + } + } + + return channel + } + + // - MARK: The following tests are all overridden from the base class since they slightly differ in behaviour + + override func testUpgradeOnlyHandlesKnownProtocols() throws { + var upgradeHandlerCallbackFired = false + + let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") + let clientHandler = RecordingHTTPHandler() + + // The process should kick-off independently by sending the upgrade request to the server. + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: unknownProtocol\r\n\r\n" + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) + } + + clientChannel.embeddedEventLoop.run() + + // Should fail with error (response is malformed) and remove upgrader from pipeline. + + // Check that the http elements are not removed from the pipeline. + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) + + // Check that the HTTP handler received its response. + XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) + // Check an error is reported + XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) + + XCTAssertFalse(upgradeHandlerCallbackFired) + + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + } + + override func testUpgradeResponseCanBeRejectedByClientUpgrader() throws { + let upgradeProtocol = "myProto" + + var upgradeHandlerCallbackFired = false + + let clientUpgrader = DenyingClientUpgrader(forProtocol: upgradeProtocol) + let clientHandler = RecordingHTTPHandler() + + // The process should kick-off independently by sending the upgrade request to the server. + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .upgraderDeniedUpgrade) + } + + clientChannel.embeddedEventLoop.run() + + // Should fail with error (response is denied) and remove upgrader from pipeline. + + // Check that the http elements are not removed from the pipeline. + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) + + XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) + + // Check that the HTTP handler received its response. + XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) + + // Check an error is reported + XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) + + XCTAssertFalse(upgradeHandlerCallbackFired) + + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + } + + override func testFiresOutboundErrorDuringAddingHandlers() throws { + let upgradeProtocol = "myProto" + var errorOnAdditionalChannelWrite: Error? + var upgradeHandlerCallbackFired = false + + let clientUpgrader = UpgradeDelayClientUpgrader(forProtocol: upgradeProtocol) + let clientHandler = RecordingHTTPHandler() + + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { (context) in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + // Push the successful server response. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) + + let promise = clientChannel.eventLoop.makePromise(of: Void.self) + + promise.futureResult.whenFailure() { error in + errorOnAdditionalChannelWrite = error + } + + // Send another outbound request during the upgrade. + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let secondRequest: HTTPClientRequestPart = .head(requestHead) + clientChannel.writeAndFlush(secondRequest, promise: promise) + + clientChannel.embeddedEventLoop.run() + + let promiseError = errorOnAdditionalChannelWrite as! NIOHTTPClientUpgradeError + XCTAssertEqual(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade, promiseError) + + // Soundness check that the upgrade was delayed. + XCTAssertEqual(0, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) + + // Upgrade now. + clientUpgrader.unblockUpgrade() + clientChannel.embeddedEventLoop.run() + + // Check that the upgrade was still successful, despite the interruption. + XCTAssert(upgradeHandlerCallbackFired) + XCTAssertEqual(1, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) + } + + override func testUpgradeResponseMissingAllProtocols() throws { + var upgradeHandlerCallbackFired = false + + let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") + let clientHandler = RecordingHTTPHandler() + + // The process should kick-off independently by sending the upgrade request to the server. + let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader]) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true + } + defer { + XCTAssertNoThrow(try clientChannel.finish()) + } + + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\n\r\n" + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) + } + + clientChannel.embeddedEventLoop.run() + + // Should fail with error (response is malformed) and remove upgrader from pipeline. + + // Check that the http elements are not removed from the pipeline. + clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) + clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) + + // Check that the HTTP handler received its response. + XCTAssertLessThanOrEqual(0, clientHandler.channelReadChannelHandlerContextDataCallCount) + // Check an error is reported + XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) + + XCTAssertFalse(upgradeHandlerCallbackFired) + + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + } +} +#endif diff --git a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift index c3ab6fbcee..dc6b82fdfb 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift @@ -251,6 +251,25 @@ class HTTPServerClientTest : XCTestCase { self.sentEnd = true self.maybeClose(context: context) } + case "/zero-length-body-part": + + let r = HTTPServerResponsePart.head(.init(version: req.version, status: .ok)) + context.write(self.wrapOutboundOut(r)).whenFailure { error in + XCTFail("unexpected error \(error)") + } + + context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(ByteBuffer())))).whenFailure { error in + XCTFail("unexpected error \(error)") + } + context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(ByteBuffer(string: "Hello World"))))).whenFailure { error in + XCTFail("unexpected error \(error)") + } + context.write(self.wrapOutboundOut(.end(nil))).recover { error in + XCTFail("unexpected error \(error)") + }.whenComplete { (_: Result) in + self.sentEnd = true + self.maybeClose(context: context) + } default: XCTFail("received request to unknown URI \(req.uri)") } @@ -437,6 +456,53 @@ class HTTPServerClientTest : XCTestCase { func testSimpleGetTrailersFileRegion() throws { try testSimpleGetTrailers(.fileRegion) } + + func testSimpleGetChunkedEncodingWithZeroLengthBodyPart() throws { + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + var expectedHeaders = HTTPHeaders() + expectedHeaders.add(name: "transfer-encoding", value: "chunked") + + let accumulation = HTTPClientResponsePartAssertHandler(.http1_1, .ok, expectedHeaders, "Hello World") + + let httpHandler = SimpleHTTPServer(.byteBuffer) + let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + + // Set the handlers that are appled to the accepted Channels + .childChannelInitializer { channel in + // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true).flatMap { + channel.pipeline.addHandler(httpHandler) + } + }.bind(host: "127.0.0.1", port: 0).wait()) + + defer { + XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) + } + + let clientChannel = try assertNoThrowWithValue(ClientBootstrap(group: group) + .channelInitializer { channel in + channel.pipeline.addHTTPClientHandlers().flatMap { + channel.pipeline.addHandler(accumulation) + } + } + .connect(to: serverChannel.localAddress!) + .wait()) + + defer { + XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) + } + + var head = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/zero-length-body-part") + head.headers.add(name: "Host", value: "apple.com") + clientChannel.write(NIOAny(HTTPClientRequestPart.head(head)), promise: nil) + try clientChannel.writeAndFlush(NIOAny(HTTPClientRequestPart.end(nil))).wait() + accumulation.syncWaitForCompletion() + } private func testSimpleGetTrailers(_ mode: SendMode) throws { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) diff --git a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift index d03d119713..572a1a5ce1 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift @@ -1010,7 +1010,7 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) } - func testServerCanRespondProcessingMultipleTimes() throws { + func testServerCanRespondProcessingMultipleTimes() throws { // Send in a request. XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) @@ -1181,4 +1181,52 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try self.channel.writeOutbound(HTTPServerResponsePart.body(.byteBuffer(ByteBuffer())))) XCTAssertNoThrow(try self.channel.writeOutbound(HTTPServerResponsePart.end(nil))) } + + func testSendingHeadTwiceGivesError() throws { + self.pipelineHandler.failOnPreconditions = false + // Sending a head once is normal + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) + // Sending a head twice is an error + XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "received request head in state requestAndResponseEndPending")) + } + } + + func testServerRespondToNothing() { + self.pipelineHandler.failOnPreconditions = false + + // Writing an end whilst in state idle is an error + XCTAssertThrowsError(try self.channel.writeOutbound(HTTPServerResponsePart.end(nil))) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Unexpectedly received a response in state idle")) + } + + // Calling finish surfaces the error again + XCTAssertThrowsError(try self.channel.finish()) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Unexpectedly received a response in state idle")) + } + } + + func testServerRequestEndFirstIsError() { + self.pipelineHandler.failOnPreconditions = false + // End sending a request which was never started + XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Received second request")) + } + } + + func testForcefulShutdownWhenViolatedPrecondition() { + self.pipelineHandler.failOnPreconditions = false + + // End sending a request which was never started + XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in + XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Received second request")) + } + // The handler should now refuse further io, and forcefully shutdown + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) + + self.channel.embeddedEventLoop.run() + + // Ensure the channel is closed + XCTAssertNoThrow(try self.channel.closeFuture.wait()) + } } diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index d9821d6a5d..5b48485751 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -13,7 +13,6 @@ //===----------------------------------------------------------------------===// import XCTest -import Dispatch import NIOCore import NIOEmbedded @testable import NIOPosix @@ -35,11 +34,20 @@ extension ChannelPipeline { } } - fileprivate func assertContainsUpgrader() throws { - try self.assertContains(handlerType: HTTPServerUpgradeHandler.self) + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + fileprivate func assertContainsUpgrader() { + #if !canImport(Darwin) || swift(>=5.10) + do { + _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() + } catch { + self.assertContains(handlerType: HTTPServerUpgradeHandler.self) + } + #else + self.assertContains(handlerType: HTTPServerUpgradeHandler.self) + #endif } - func assertContains(handlerType: Handler.Type) throws { + func assertContains(handlerType: Handler.Type) { XCTAssertNoThrow(try self.context(handlerType: handlerType).wait(), "did not find handler") } @@ -51,6 +59,7 @@ extension ChannelPipeline { // Waits up to 1 second for the upgrader to be removed by polling the pipeline // every 50ms checking for the handler. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) fileprivate func waitForUpgraderToBeRemoved() throws { for _ in 0..<20 { do { @@ -58,8 +67,19 @@ extension ChannelPipeline { // handler present, keep waiting usleep(50) } catch ChannelPipelineError.notFound { - // No upgrader, we're good. + #if !canImport(Darwin) || swift(>=5.10) + // Checking if the typed variant is present + do { + _ = try self.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self).wait() + // handler present, keep waiting + usleep(50) + } catch ChannelPipelineError.notFound { + // No upgrader, we're good. + return + } + #else return + #endif } } @@ -83,15 +103,12 @@ extension EmbeddedChannel { } } -#if swift(>=5.7) private typealias UpgradeCompletionHandler = @Sendable (ChannelHandlerContext) -> Void -#else -private typealias UpgradeCompletionHandler = (ChannelHandlerContext) -> Void -#endif +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) private func serverHTTPChannelWithAutoremoval(group: EventLoopGroup, pipelining: Bool, - upgraders: [HTTPServerProtocolUpgrader], + upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], extraHandlers: [ChannelHandler], _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (Channel, EventLoopFuture) { let p = group.next().makePromise(of: Channel.self) @@ -141,20 +158,6 @@ private func connectedClientChannel(group: EventLoopGroup, serverAddress: Socket .wait() } -private func setUpTestWithAutoremoval(pipelining: Bool = false, - upgraders: [HTTPServerProtocolUpgrader], - extraHandlers: [ChannelHandler], - _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (EventLoopGroup, Channel, Channel, Channel) { - let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval(group: group, - pipelining: pipelining, - upgraders: upgraders, - extraHandlers: extraHandlers, - upgradeCompletionHandler) - let clientChannel = try connectedClientChannel(group: group, serverAddress: serverChannel.localAddress!) - return (group, serverChannel, clientChannel, try connectedServerChannelFuture.wait()) -} - internal func assertResponseIs(response: String, expectedResponseLine: String, expectedResponseHeaders: [String]) { var lines = response.split(separator: "\r\n", omittingEmptySubsequences: false).map { String($0) } @@ -179,7 +182,15 @@ internal func assertResponseIs(response: String, expectedResponseLine: String, e XCTAssertEqual(lines.count, 0) } -private class ExplodingUpgrader: HTTPServerProtocolUpgrader { +#if !canImport(Darwin) || swift(>=5.10) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader, NIOTypedHTTPServerProtocolUpgrader where UpgradeResult == Bool {} +#else +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader {} +#endif + +private class ExplodingUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] @@ -201,9 +212,14 @@ private class ExplodingUpgrader: HTTPServerProtocolUpgrader { XCTFail("upgrade called") return context.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + XCTFail("upgrade called") + return channel.eventLoop.makeSucceededFuture(true) + } } -private class UpgraderSaysNo: HTTPServerProtocolUpgrader { +private class UpgraderSaysNo: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] = [] @@ -223,9 +239,14 @@ private class UpgraderSaysNo: HTTPServerProtocolUpgrader { XCTFail("upgrade called") return context.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + XCTFail("upgrade called") + return channel.eventLoop.makeSucceededFuture(true) + } } -private class SuccessfulUpgrader: HTTPServerProtocolUpgrader { +private class SuccessfulUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] private let onUpgradeComplete: (HTTPRequestHead) -> () @@ -260,13 +281,18 @@ private class SuccessfulUpgrader: HTTPServerProtocolUpgrader { self.onUpgradeComplete(upgradeRequest) return context.eventLoop.makeSucceededFuture(()) } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + self.onUpgradeComplete(upgradeRequest) + return channel.eventLoop.makeSucceededFuture(true) + } } -private class DelayedUnsuccessfulUpgrader: HTTPServerProtocolUpgrader { +private class DelayedUnsuccessfulUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] - private var upgradePromise: EventLoopPromise? + private var upgradePromise: EventLoopPromise? init(forProtocol `protocol`: String) { self.supportedProtocol = `protocol` @@ -281,22 +307,32 @@ private class DelayedUnsuccessfulUpgrader: HTTPServerProtocolUpgrader { func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { self.upgradePromise = context.eventLoop.makePromise() - return self.upgradePromise!.futureResult + return self.upgradePromise!.futureResult.map { _ in } } func unblockUpgrade(withError error: Error) { self.upgradePromise!.fail(error) } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + self.upgradePromise = channel.eventLoop.makePromise(of: Bool.self) + return self.upgradePromise!.futureResult + } } -private class UpgradeDelayer: HTTPServerProtocolUpgrader { +private class UpgradeDelayer: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] = [] - private var upgradePromise: EventLoopPromise? + private var upgradePromise: EventLoopPromise? + private let upgradeRequestedPromise: EventLoopPromise - public init(forProtocol `protocol`: String) { + /// - Parameters: + /// - protocol: The protocol this upgrader knows how to support. + /// - upgradeRequestedPromise: Will be fulfilled when upgrade() is called + init(forProtocol `protocol`: String, upgradeRequestedPromise: EventLoopPromise) { self.supportedProtocol = `protocol` + self.upgradeRequestedPromise = upgradeRequestedPromise } public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { @@ -307,11 +343,18 @@ private class UpgradeDelayer: HTTPServerProtocolUpgrader { public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { self.upgradePromise = context.eventLoop.makePromise() - return self.upgradePromise!.futureResult + upgradeRequestedPromise.succeed() + return self.upgradePromise!.futureResult.map { _ in } } public func unblockUpgrade() { - self.upgradePromise!.succeed(()) + self.upgradePromise!.succeed(true) + } + + func upgrade(channel: Channel, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { + self.upgradePromise = channel.eventLoop.makePromise() + self.upgradeRequestedPromise.succeed() + return self.upgradePromise!.futureResult } } @@ -395,16 +438,33 @@ private class ReentrantReadOnChannelReadCompleteHandler: ChannelInboundHandler { } } +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) class HTTPServerUpgradeTestCase: XCTestCase { + + static let eventLoop = MultiThreadedEventLoopGroup.singleton.next() + + fileprivate func setUpTestWithAutoremoval(pipelining: Bool = false, + upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], + extraHandlers: [ChannelHandler], + notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, + _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (Channel, Channel, Channel) { + let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval(group: Self.eventLoop, + pipelining: pipelining, + upgraders: upgraders, + extraHandlers: extraHandlers, + upgradeCompletionHandler) + let clientChannel = try connectedClientChannel(group: Self.eventLoop, serverAddress: serverChannel.localAddress!) + return (serverChannel, clientChannel, try connectedServerChannelFuture.wait()) + } + func testUpgradeWithoutUpgrade() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\n\r\n" @@ -415,14 +475,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeAfterInitialRequest() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } // This request fires a subsequent upgrade in immediately. It should also be ignored. @@ -470,7 +529,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { upgraderCbFired.wrappedValue = true } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { (context) in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -479,11 +538,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -510,14 +567,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeRequiresCorrectHeaders() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nConnection: upgrade\r\nUpgrade: myproto\r\n\r\n" @@ -528,14 +584,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeRequiresHeadersInConnection() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } // This request is missing a 'Kafkaesque' connection header. @@ -547,14 +602,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeOnlyHandlesKnownProtocols() throws { - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], extraHandlers: []) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { XCTAssertNoThrow(try client.close().wait()) XCTAssertNoThrow(try server.close().wait()) - XCTAssertNoThrow(try group.syncShutdownGracefully()) } let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nConnection: upgrade\r\nUpgrade: something-else\r\n\r\n" @@ -576,7 +630,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { upgraderCbFired.wrappedValue = true } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], extraHandlers: []) { context in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -585,11 +639,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -624,16 +676,14 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: [eventSaver.wrappedValue]) { context in XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -681,7 +731,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { } let errorCatcher = ErrorSaver() - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], extraHandlers: [errorCatcher]) { context in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -690,11 +740,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -732,15 +780,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { func testUpgradeIsCaseInsensitive() throws { let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["WeIrDcAsE"]) { req in } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { context in context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -762,19 +808,12 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testDelayedUpgradeBehaviour() throws { - let g = DispatchGroup() - g.enter() + let upgradeRequestPromise = Self.eventLoop.makePromise(of: Void.self) + let upgrader = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) + let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: []) { context in } - let upgrader = UpgradeDelayer(forProtocol: "myproto") - let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { context in - g.leave() - } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - - let completePromise = group.next().makePromise(of: Void.self) + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = SingleHTTPResponseAccumulator { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -788,12 +827,12 @@ class HTTPServerUpgradeTestCase: XCTestCase { let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - g.wait() - // Ok, we don't think this upgrade should have succeeded yet, but neither should it have failed. We want to // dispatch onto the server event loop and check that the channel still contains the upgrade handler. - try connectedServer.pipeline.assertContainsUpgrader() + connectedServer.pipeline.assertContainsUpgrader() + // Wait for the upgrade function to be called + try upgradeRequestPromise.futureResult.wait() // Ok, let's unblock the upgrade now. The machinery should do its thing. try server.eventLoop.submit { upgrader.unblockUpgrade() @@ -804,21 +843,15 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testBuffersInboundDataDuringDelayedUpgrade() throws { - let g = DispatchGroup() - g.enter() - - let upgrader = UpgradeDelayer(forProtocol: "myproto") + let upgradeRequestPromise = Self.eventLoop.makePromise(of: Void.self) + let upgrader = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) let dataRecorder = DataRecorder() - let (group, server, client, _) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: [dataRecorder]) { context in - g.leave() - } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } + let (server, client, _) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: [dataRecorder]) { context in } - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -828,14 +861,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - // This request is safe to upgrade, but is immediately followed by non-HTTP data that will probably - // blow up the HTTP parser. + // This request is safe to upgrade, but is immediately followed by non-HTTP data. let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) - // Wait for the upgrade machinery to run. - g.wait() - // Ok, send the application data in. let appData = "supersecretawesome data definitely not http\r\nawesome\r\ndata\ryeah" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: appData))).wait()) @@ -880,7 +909,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Upgrade has been requested but not proceeded. XCTAssertTrue(upgradeRequested) - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self))) // Ok, now we can upgrade. Upgrader should be out of the pipeline, and we should have seen the 101 response. @@ -928,14 +957,14 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Upgrade has been requested but not proceeded for the failing protocol. XCTAssertEqual(upgradingProtocol, "failingProtocol") - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))) XCTAssertNoThrow(try channel.throwIfErrorCaught()) // Ok, now we'll fail the promise. This will catch an error, but the upgrade won't happen: instead, the second handler will be fired. failingProtocolPromise.fail(No.no) XCTAssertEqual(upgradingProtocol, "myproto") - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))) XCTAssertThrowsError(try channel.throwIfErrorCaught()) { error in @@ -978,7 +1007,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Upgrade has been requested but not proceeded. XCTAssertTrue(upgradeRequested) - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))) XCTAssertNoThrow(try channel.throwIfErrorCaught()) @@ -1039,7 +1068,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Upgrade has been requested but not proceeded. XCTAssertTrue(upgradeRequested) - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))) XCTAssertNoThrow(try channel.throwIfErrorCaught()) @@ -1095,17 +1124,14 @@ class HTTPServerUpgradeTestCase: XCTestCase { func testRemovesAllHTTPRelatedHandlersAfterUpgrade() throws { let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { req in } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(pipelining: true, + let (_, client, connectedServer) = try setUpTestWithAutoremoval(pipelining: true, upgraders: [upgrader], extraHandlers: []) { context in } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } // First, validate the pipeline is right. - XCTAssertNoThrow(try connectedServer.pipeline.assertContains(handlerType: ByteToMessageHandler.self)) - XCTAssertNoThrow(try connectedServer.pipeline.assertContains(handlerType: HTTPResponseEncoder.self)) - XCTAssertNoThrow(try connectedServer.pipeline.assertContains(handlerType: HTTPServerPipelineHandler.self)) + connectedServer.pipeline.assertContains(handlerType: ByteToMessageHandler.self) + connectedServer.pipeline.assertContains(handlerType: HTTPResponseEncoder.self) + connectedServer.pipeline.assertContains(handlerType: HTTPServerPipelineHandler.self) // This request is safe to upgrade. let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" @@ -1206,7 +1232,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { let firstByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let secondByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let allDonePromise = promiseGroup.next().makePromise(of: Void.self) - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { (context) in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -1216,11 +1242,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { secondByteDonePromise: secondByteDonePromise, allDonePromise: allDonePromise)) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - - let completePromise = group.next().makePromise(of: Void.self) + + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -1263,7 +1286,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertNoThrow(try channel.finish()) } - let delayer = UpgradeDelayer(forProtocol: "myproto") + let upgradeRequestPromise = Self.eventLoop.makePromise(of: Void.self) + let delayer = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) defer { delayer.unblockUpgrade() } @@ -1275,7 +1299,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { channel.embeddedEventLoop.run() // Upgrade has been requested but not proceeded. - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(try XCTAssertNil(channel.readInbound(as: ByteBuffer.self))) // The 101 has been sent. @@ -1311,7 +1335,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertNoThrow(try channel.finish()) } - let delayer = UpgradeDelayer(forProtocol: "myproto") + let upgradeRequestPromise = Self.eventLoop.makePromise(of: Void.self) + let delayer = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) defer { delayer.unblockUpgrade() } @@ -1324,7 +1349,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { channel.embeddedEventLoop.run() // Upgrade has been requested but not proceeded. - XCTAssertNoThrow(try channel.pipeline.assertContainsUpgrader()) + channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(try XCTAssertNil(channel.readInbound(as: ByteBuffer.self))) // The 101 has been sent. @@ -1376,7 +1401,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { upgraderCbFired.wrappedValue = true } - let (group, _, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], extraHandlers: []) { (context) in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) @@ -1385,11 +1410,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { // We're closing the connection now. context.close(promise: nil) } - defer { - XCTAssertNoThrow(try group.syncShutdownGracefully()) - } - let completePromise = group.next().makePromise(of: Void.self) + let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") assertResponseIs(response: resultString, @@ -1451,7 +1473,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Nothing should have been forwarded. XCTAssertTrue(dataRecorder.receivedData().isEmpty) // The upgrade handler should still be in the pipeline. - try channel.pipeline.assertContainsUpgrader() + channel.pipeline.assertContainsUpgrader() } func testFailedUpgradeResponseWriteThrowsError() throws { @@ -1496,7 +1518,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Nothing should have been forwarded. XCTAssertTrue(dataRecorder.receivedData().isEmpty) // The upgrade handler should still be in the pipeline. - try channel.pipeline.assertContainsUpgrader() + channel.pipeline.assertContainsUpgrader() } func testFailedUpgraderThrowsError() throws { @@ -1539,6 +1561,508 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Nothing should have been forwarded. XCTAssertTrue(dataRecorder.receivedData().isEmpty) // The upgrade handler should still be in the pipeline. - try channel.pipeline.assertContainsUpgrader() + channel.pipeline.assertContainsUpgrader() + } +} + +#if !canImport(Darwin) || swift(>=5.10) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { + fileprivate override func setUpTestWithAutoremoval( + pipelining: Bool = false, + upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], + extraHandlers: [ChannelHandler], + notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, + _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler + ) throws -> (Channel, Channel, Channel) { + let connectionChannelPromise = Self.eventLoop.makePromise(of: Channel.self) + let serverChannelFuture = ServerBootstrap(group: Self.eventLoop) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + connectionChannelPromise.succeed(channel) + var configuration = NIOUpgradableHTTPServerPipelineConfiguration( + upgradeConfiguration: .init( + upgraders: upgraders.map { $0 as! any NIOTypedHTTPServerProtocolUpgrader }, + notUpgradingCompletionHandler: { notUpgradingHandler?($0) ?? $0.eventLoop.makeSucceededFuture(false) } + ) + ) + configuration.enablePipelining = pipelining + return try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline(configuration: configuration) + .flatMap { result in + if result { + return channel.pipeline.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self) + .map { + upgradeCompletionHandler($0) + } + } else { + return channel.eventLoop.makeSucceededVoidFuture() + } + } + } + .flatMap { _ in + let futureResults = extraHandlers.map { channel.pipeline.addHandler($0) } + return EventLoopFuture.andAllSucceed(futureResults, on: channel.eventLoop) + } + }.bind(host: "127.0.0.1", port: 0) + let clientChannel = try connectedClientChannel(group: Self.eventLoop, serverAddress: serverChannelFuture.wait().localAddress!) + return (try serverChannelFuture.wait(), clientChannel, try connectionChannelPromise.futureResult.wait()) + } + + func testNotUpgrading() throws { + let notUpgraderCbFired = UnsafeMutableTransferBox(false) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { _ in } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [], + notUpgradingHandler: { channel in + notUpgraderCbFired.wrappedValue = true + // We're closing the connection now. + channel.close(promise: nil) + return channel.eventLoop.makeSucceededFuture(true) + } + ) { _ in } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + XCTAssertEqual(resultString, "") + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: notmyproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that the not upgrader got called. + XCTAssert(notUpgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + } + + // - MARK: The following tests are all overridden from the base class since they slightly differ in behaviour + + override func testSimpleUpgradeSucceeds() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + // This is called before completion block. + upgradeRequest.wrappedValue = req + upgradeHandlerCbFired.wrappedValue = true + + XCTAssert(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { (context) in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + } + + override func testUpgradeRespectsClientPreference() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + let explodingUpgrader = ExplodingUpgrader(forProtocol: "exploder") + let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: []) { context in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto, exploder\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + } + + override func testUpgraderCanRejectUpgradeForPersonalReasons() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + let explodingUpgrader = UpgraderSaysNo(forProtocol: "noproto") + let successfulUpgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + let errorCatcher = ErrorSaver() + + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: [errorCatcher]) { context in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: noproto,myproto\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.waitForUpgraderToBeRemoved() + + // And we want to confirm we saved the error. + XCTAssertEqual(errorCatcher.errors.count, 1) + + switch(errorCatcher.errors[0]) { + case UpgraderSaysNo.No.no: + break + default: + XCTFail("Unexpected error: \(errorCatcher.errors[0])") + } + } + + override func testUpgradeWithUpgradePayloadInlineWithRequestWorks() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + enum ReceivedTheWrongThingError: Error { case error } + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + + class CheckWeReadInlineAndExtraData: ChannelDuplexHandler { + typealias InboundIn = ByteBuffer + typealias OutboundIn = Never + typealias OutboundOut = Never + + enum State { + case fresh + case added + case inlineDataRead + case extraDataRead + case closed + } + + private let firstByteDonePromise: EventLoopPromise + private let secondByteDonePromise: EventLoopPromise + private let allDonePromise: EventLoopPromise + private var state = State.fresh + + init(firstByteDonePromise: EventLoopPromise, + secondByteDonePromise: EventLoopPromise, + allDonePromise: EventLoopPromise) { + self.firstByteDonePromise = firstByteDonePromise + self.secondByteDonePromise = secondByteDonePromise + self.allDonePromise = allDonePromise + } + + func handlerAdded(context: ChannelHandlerContext) { + XCTAssertEqual(.fresh, self.state) + self.state = .added + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buf = self.unwrapInboundIn(data) + XCTAssertEqual(1, buf.readableBytes) + let stringRead = buf.readString(length: buf.readableBytes) + switch self.state { + case .added: + XCTAssertEqual("A", stringRead) + self.state = .inlineDataRead + if stringRead == .some("A") { + self.firstByteDonePromise.succeed(()) + } else { + self.firstByteDonePromise.fail(ReceivedTheWrongThingError.error) + } + case .inlineDataRead: + XCTAssertEqual("B", stringRead) + self.state = .extraDataRead + context.channel.close(promise: nil) + if stringRead == .some("B") { + self.secondByteDonePromise.succeed(()) + } else { + self.secondByteDonePromise.fail(ReceivedTheWrongThingError.error) + } + default: + XCTFail("channel read in wrong state \(self.state)") + } + } + + func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + XCTAssertEqual(.extraDataRead, self.state) + self.state = .closed + context.close(mode: mode, promise: promise) + + self.allDonePromise.succeed(()) + } + } + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let promiseGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try promiseGroup.syncShutdownGracefully()) + } + let firstByteDonePromise = promiseGroup.next().makePromise(of: Void.self) + let secondByteDonePromise = promiseGroup.next().makePromise(of: Void.self) + let allDonePromise = promiseGroup.next().makePromise(of: Void.self) + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: []) { (context) in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + _ = context.channel.pipeline.addHandler(CheckWeReadInlineAndExtraData(firstByteDonePromise: firstByteDonePromise, + secondByteDonePromise: secondByteDonePromise, + allDonePromise: allDonePromise)) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + var request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + request += "A" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + XCTAssertNoThrow(try firstByteDonePromise.futureResult.wait() as Void) + + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: "B"))).wait()) + + XCTAssertNoThrow(try secondByteDonePromise.futureResult.wait() as Void) + + XCTAssertNoThrow(try allDonePromise.futureResult.wait() as Void) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + + XCTAssertNoThrow(try allDonePromise.futureResult.wait()) + } + + override func testWeTolerateUpgradeFuturesFromWrongEventLoops() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let upgradeRequest = UnsafeMutableTransferBox(nil) + let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) + let upgraderCbFired = UnsafeMutableTransferBox(false) + let otherELG = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try otherELG.syncShutdownGracefully()) + } + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", + requiringHeaders: ["kafkaesque"], + buildUpgradeResponseFuture: { + // this is the wrong EL + otherELG.next().makeSucceededFuture($1) + }) { req in + upgradeRequest.wrappedValue = req + XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) + upgraderCbFired.wrappedValue = true + } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: []) { (context) in + // This is called before the upgrader gets called. + XCTAssertNotNil(upgradeRequest.wrappedValue) + upgradeHandlerCbFired.wrappedValue = true + + // We're closing the connection now. + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired.wrappedValue) + XCTAssert(upgraderCbFired.wrappedValue) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + } + + override func testUpgradeFiresUserEvent() throws { + // This test is different since we call the completionHandler after the upgrader + // modified the pipeline in the typed version. + let eventSaver = UnsafeTransfer(UserEventSaver()) + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { req in + XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) + } + + let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: [eventSaver.wrappedValue]) { context in + XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) + context.close(promise: nil) + } + + + let completePromise = Self.eventLoop.makePromise(of: Void.self) + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(()) + } + XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) + + // This request is safe to upgrade. + let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade,kafkaesque\r\n\r\n" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we should have received one user event. We schedule this onto the + // event loop to guarantee thread safety. + XCTAssertNoThrow(try connectedServer.eventLoop.scheduleTask(deadline: .now()) { + XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) + if case .upgradeComplete(let proto, let req) = eventSaver.wrappedValue.events[0] { + XCTAssertEqual(proto, "myproto") + XCTAssertEqual(req.method, .OPTIONS) + XCTAssertEqual(req.uri, "*") + XCTAssertEqual(req.version, .http1_1) + } else { + XCTFail("Unexpected event: \(eventSaver.wrappedValue.events[0])") + } + }.futureResult.wait()) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.waitForUpgraderToBeRemoved() } } +#endif diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 7e9061b88d..1fb1bdd810 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -13,10 +13,10 @@ //===----------------------------------------------------------------------===// import NIOConcurrencyHelpers -@_spi(AsyncChannel) @testable import NIOCore -@_spi(AsyncChannel) @testable import NIOPosix +@testable import NIOCore +@testable import NIOPosix import XCTest -@_spi(AsyncChannel) import NIOTLS +import NIOTLS private final class IPHeaderRemoverHandler: ChannelInboundHandler { typealias InboundIn = AddressedEnvelope @@ -191,6 +191,7 @@ private final class AddressedEnvelopingHandler: ChannelDuplexHandler { } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelBootstrapTests: XCTestCase { enum NegotiationResult { case string(NIOAsyncChannel) @@ -224,7 +225,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) return try NIOAsyncChannel( - synchronouslyWrapping: channel, + wrappingChannelSynchronously: channel, configuration: .init( inboundType: String.self, outboundType: String.self @@ -239,16 +240,22 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { _ in - for try await childChannel in channel.inboundStream { - for try await value in childChannel.inboundStream { - continuation.yield(.string(value)) + try await channel.executeThenClose { inbound in + for try await childChannel in inbound { + try await childChannel.executeThenClose { childChannelInbound, _ in + for try await value in childChannelInbound { + continuation.yield(.string(value)) + } + } } } } } let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!) - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.executeThenClose { _, outbound in + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await iterator.next(), .string("hello")) @@ -262,7 +269,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try! eventLoopGroup.syncShutdownGracefully() } - let channel: NIOAsyncChannel>, Never> = try await ServerBootstrap(group: eventLoopGroup) + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: eventLoopGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .childChannelOption(ChannelOptions.autoRead, value: true) .bind( @@ -280,16 +287,22 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inboundStream { - group.addTask { - switch try await negotiationResult.getResult() { - case .string(let channel): - for try await value in channel.inboundStream { - continuation.yield(.string(value)) - } - case .byte(let channel): - for try await value in channel.inboundStream { - continuation.yield(.byte(value)) + try await channel.executeThenClose { inbound in + for try await negotiationResult in inbound { + group.addTask { + switch try await negotiationResult.get() { + case .string(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.string(value)) + } + } + case .byte(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.byte(value)) + } + } } } } @@ -302,11 +315,13 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: channel.channel.localAddress!.port!, proposedALPN: .string ) - let stringNegotiationResult = try await stringNegotiationResultFuture.getResult() + let stringNegotiationResult = try await stringNegotiationResultFuture.get() switch stringNegotiationResult { case .string(let stringChannel): - // This is the actual content - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -317,13 +332,15 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: channel.channel.localAddress!.port!, proposedALPN: .byte ) - let byteNegotiationResult = try await byteNegotiationResultFuture.getResult() + let byteNegotiationResult = try await byteNegotiationResultFuture.get() switch byteNegotiationResult { case .string: preconditionFailure() case .byte(let byteChannel): - // This is the actual content - try await byteChannel.outboundWriter.write(UInt8(8)) + try await byteChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write(UInt8(8)) + } await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) } @@ -337,7 +354,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try! eventLoopGroup.syncShutdownGracefully() } - let channel: NIOAsyncChannel>, Never> = try await ServerBootstrap(group: eventLoopGroup) + let channel: NIOAsyncChannel>, Never> = try await ServerBootstrap(group: eventLoopGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .bind( host: "127.0.0.1", @@ -354,16 +371,22 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inboundStream { - group.addTask { - switch try await negotiationResult.getResult() { - case .string(let channel): - for try await value in channel.inboundStream { - continuation.yield(.string(value)) - } - case .byte(let channel): - for try await value in channel.inboundStream { - continuation.yield(.byte(value)) + try await channel.executeThenClose { inbound in + for try await negotiationResult in inbound { + group.addTask { + switch try await negotiationResult.get().get() { + case .string(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.string(value)) + } + } + case .byte(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.byte(value)) + } + } } } } @@ -377,10 +400,12 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedOuterALPN: .string, proposedInnerALPN: .string ) - switch try await stringStringNegotiationResult.getResult() { + switch try await stringStringNegotiationResult.get().get() { case .string(let stringChannel): - // This is the actual content - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -392,10 +417,12 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedOuterALPN: .byte, proposedInnerALPN: .string ) - switch try await byteStringNegotiationResult.getResult() { + switch try await byteStringNegotiationResult.get().get() { case .string(let stringChannel): - // This is the actual content - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -407,12 +434,14 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedOuterALPN: .byte, proposedInnerALPN: .byte ) - switch try await byteByteNegotiationResult.getResult() { + switch try await byteByteNegotiationResult.get().get() { case .string: preconditionFailure() case .byte(let byteChannel): - // This is the actual content - try await byteChannel.outboundWriter.write(UInt8(8)) + try await byteChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write(UInt8(8)) + } await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) } @@ -422,12 +451,14 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedOuterALPN: .string, proposedInnerALPN: .byte ) - switch try await stringByteNegotiationResult.getResult() { + switch try await stringByteNegotiationResult.get().get() { case .string: preconditionFailure() case .byte(let byteChannel): - // This is the actual content - try await byteChannel.outboundWriter.write(UInt8(8)) + try await byteChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write(UInt8(8)) + } await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) } @@ -460,7 +491,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { } let channels = NIOLockedValueBox<[Channel]>([Channel]()) - let channel: NIOAsyncChannel>, Never> = try await ServerBootstrap(group: eventLoopGroup) + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: eventLoopGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .serverChannelInitializer { channel in channel.eventLoop.makeCompletedFuture { @@ -483,16 +514,22 @@ final class AsyncChannelBootstrapTests: XCTestCase { group.addTask { try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inboundStream { - group.addTask { - switch try await negotiationResult.getResult() { - case .string(let channel): - for try await value in channel.inboundStream { - continuation.yield(.string(value)) - } - case .byte(let channel): - for try await value in channel.inboundStream { - continuation.yield(.byte(value)) + try await channel.executeThenClose { inbound in + for try await negotiationResult in inbound { + group.addTask { + switch try await negotiationResult.get() { + case .string(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.string(value)) + } + } + case .byte(let channel): + try await channel.executeThenClose { inbound, _ in + for try await value in inbound { + continuation.yield(.byte(value)) + } + } } } } @@ -506,7 +543,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedALPN: .unknown ) await XCTAssertThrowsError( - try await failedProtocolNegotiation.getResult() + try await failedProtocolNegotiation.get() ) // Let's check that we can still open a new connection @@ -515,10 +552,12 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: channel.channel.localAddress!.port!, proposedALPN: .string ) - switch try await stringNegotiationResult.getResult() { + switch try await stringNegotiationResult.get() { case .string(let stringChannel): - // This is the actual content - try await stringChannel.outboundWriter.write("hello") + try await stringChannel.executeThenClose { _, outbound in + // This is the actual content + try await outbound.write("hello") + } await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) case .byte: preconditionFailure() @@ -549,14 +588,18 @@ final class AsyncChannelBootstrapTests: XCTestCase { eventLoopGroup: eventLoopGroup, port: serverChannel.channel.localAddress!.port! ) - var serverInboundIterator = serverChannel.inboundStream.makeAsyncIterator() - var clientInboundIterator = clientChannel.inboundStream.makeAsyncIterator() + try await serverChannel.executeThenClose { serverChannelInbound, serverChannelOutbound in + try await clientChannel.executeThenClose { clientChannelInbound, clientChannelOutbound in + var serverInboundIterator = serverChannelInbound.makeAsyncIterator() + var clientInboundIterator = clientChannelInbound.makeAsyncIterator() - try await clientChannel.outboundWriter.write("request") - try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") + try await clientChannelOutbound.write("request") + try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") - try await serverChannel.outboundWriter.write("response") - try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") + try await serverChannelOutbound.write("response") + try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") + } + } } func testDatagramBootstrap_withProtocolNegotiation_andHostPort() async throws { @@ -575,7 +618,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { let port = channel.localAddress!.port! try await channel.close() - try await withThrowingTaskGroup(of: EventLoopFuture>.self) { group in + try await withThrowingTaskGroup(of: EventLoopFuture.self) { group in group.addTask { // We have to use a fixed port here since we only get the channel once protocol negotiation is done try await self.makeUDPServerChannelWithProtocolNegotiation( @@ -599,16 +642,20 @@ final class AsyncChannelBootstrapTests: XCTestCase { let firstNegotiationResult = try await group.next() let secondNegotiationResult = try await group.next() - switch (try await firstNegotiationResult?.getResult(), try await secondNegotiationResult?.getResult()) { + switch (try await firstNegotiationResult?.get(), try await secondNegotiationResult?.get()) { case (.string(let firstChannel), .string(let secondChannel)): - var firstInboundIterator = firstChannel.inboundStream.makeAsyncIterator() - var secondInboundIterator = secondChannel.inboundStream.makeAsyncIterator() + try await firstChannel.executeThenClose { firstChannelInbound, firstChannelOutbound in + try await secondChannel.executeThenClose { secondChannelInbound, secondChannelOutbound in + var firstInboundIterator = firstChannelInbound.makeAsyncIterator() + var secondInboundIterator = secondChannelInbound.makeAsyncIterator() - try await firstChannel.outboundWriter.write("request") - try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") + try await firstChannelOutbound.write("request") + try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") - try await secondChannel.outboundWriter.write("response") - try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") + try await secondChannelOutbound.write("response") + try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") + } + } default: preconditionFailure() @@ -620,85 +667,256 @@ final class AsyncChannelBootstrapTests: XCTestCase { func testPipeBootstrap() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let (pipe1ReadFH, pipe1WriteFH, pipe2ReadFH, pipe2WriteFH) = self.makePipeFileDescriptors() - let toChannel = FileHandle(fileDescriptor: pipe1WriteFH, closeOnDealloc: false) - let fromChannel = FileHandle(fileDescriptor: pipe2ReadFH, closeOnDealloc: false) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + let (pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD) = self.makePipeFileDescriptors() let channel: NIOAsyncChannel + let toChannel: NIOAsyncChannel + let fromChannel: NIOAsyncChannel do { channel = try await NIOPipeBootstrap(group: eventLoopGroup) .takingOwnershipOfDescriptors( - input: pipe1ReadFH, - output: pipe2WriteFH + input: pipe1ReadFD, + output: pipe2WriteFD ) { channel in channel.eventLoop.makeCompletedFuture { - try NIOAsyncChannel(synchronouslyWrapping: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } catch { - [pipe1ReadFH, pipe1WriteFH, pipe2ReadFH, pipe2WriteFH].forEach { try? SystemCalls.close(descriptor: $0) } + try [pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD].forEach { try SystemCalls.close(descriptor: $0) } throw error } - var inboundIterator = channel.inboundStream.makeAsyncIterator() + do { + toChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + output: pipe1WriteFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + } catch { + try [pipe1WriteFD, pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + do { + fromChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + input: pipe2ReadFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + } catch { + try [pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + try await channel.executeThenClose { channelInbound, channelOutbound in + try await fromChannel.executeThenClose { fromChannelInbound, _ in + try await toChannel.executeThenClose { _, toChannelOutbound in + var inboundIterator = channelInbound.makeAsyncIterator() + var fromChannelInboundIterator = fromChannelInbound.makeAsyncIterator() + + try await toChannelOutbound.write(.init(string: "Request")) + try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) + + let response = ByteBuffer(string: "Response") + try await channelOutbound.write(response) + try await XCTAsyncAssertEqual(try await fromChannelInboundIterator.next(), response) + } + } + } + } + + func testPipeBootstrap_whenInputNil() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + let (pipe1ReadFD, pipe1WriteFD) = self.makePipeFileDescriptors() + let channel: NIOAsyncChannel + let fromChannel: NIOAsyncChannel do { - try toChannel.writeBytes(.init(string: "Request")) - try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) + channel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + output: pipe1WriteFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + } catch { + try [pipe1ReadFD, pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } - let response = ByteBuffer(string: "Response") - try await channel.outboundWriter.write(response) - XCTAssertEqual(try fromChannel.readBytes(ofExactLength: response.readableBytes), Array(buffer: response)) + do { + fromChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + input: pipe1ReadFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } } catch { - // We only got to close the FDs that are not owned by the PipeChannel - [pipe1WriteFH, pipe2ReadFH].forEach { try? SystemCalls.close(descriptor: $0) } + try [pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } throw error } + + try await channel.executeThenClose { channelInbound, channelOutbound in + try await fromChannel.executeThenClose { fromChannelInbound, _ in + var inboundIterator = channelInbound.makeAsyncIterator() + var fromChannelInboundIterator = fromChannelInbound.makeAsyncIterator() + + try await XCTAsyncAssertEqual(try await inboundIterator.next(), nil) + + let response = ByteBuffer(string: "Response") + try await channelOutbound.write(response) + try await XCTAsyncAssertEqual(try await fromChannelInboundIterator.next(), response) + } + } } - func testPipeBootstrap_withProtocolNegotiation() async throws { + func testPipeBootstrap_whenOutputNil() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - let (pipe1ReadFH, pipe1WriteFH, pipe2ReadFH, pipe2WriteFH) = self.makePipeFileDescriptors() - let toChannel = FileHandle(fileDescriptor: pipe1WriteFH, closeOnDealloc: false) - let fromChannel = FileHandle(fileDescriptor: pipe2ReadFH, closeOnDealloc: false) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + let (pipe1ReadFD, pipe1WriteFD) = self.makePipeFileDescriptors() + let channel: NIOAsyncChannel + let toChannel: NIOAsyncChannel - try await withThrowingTaskGroup(of: EventLoopFuture>.self) { group in - group.addTask { - do { - return try await NIOPipeBootstrap(group: eventLoopGroup) - .takingOwnershipOfDescriptors( - input: pipe1ReadFH, - output: pipe2WriteFH - ) { channel in - return channel.eventLoop.makeCompletedFuture { - return try self.configureProtocolNegotiationHandlers(channel: channel) - } - } - } catch { - [pipe1ReadFH, pipe1WriteFH, pipe2ReadFH, pipe2WriteFH].forEach { try? SystemCalls.close(descriptor: $0) } - throw error + do { + channel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + input: pipe1ReadFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + } catch { + try [pipe1ReadFD, pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + do { + toChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + output: pipe1WriteFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + } catch { + try [pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + try await channel.executeThenClose { channelInbound, channelOutbound in + try await toChannel.executeThenClose { _, toChannelOutbound in + var inboundIterator = channelInbound.makeAsyncIterator() + + try await toChannelOutbound.write(.init(string: "Request")) + try await XCTAsyncAssertEqual(try await inboundIterator.next(), ByteBuffer(string: "Request")) + + let response = ByteBuffer(string: "Response") + await XCTAsyncAssertThrowsError(try await channelOutbound.write(response)) { error in + XCTAssertEqual(error as? NIOAsyncWriterError, .alreadyFinished()) } } + } + } - try toChannel.writeBytes(.init(string: "alpn:string\nHello\n")) - let negotiationResult = try await group.next() - switch try await negotiationResult?.getResult() { - case .string(let channel): - var inboundIterator = channel.inboundStream.makeAsyncIterator() - do { - try await XCTAsyncAssertEqual(try await inboundIterator.next(), "Hello") + func testPipeBootstrap_withProtocolNegotiation() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + let (pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD) = self.makePipeFileDescriptors() + let negotiationResult: EventLoopFuture + let toChannel: NIOAsyncChannel + let fromChannel: NIOAsyncChannel - let response = ByteBuffer(string: "Response") - try await channel.outboundWriter.write("Response") - XCTAssertEqual(try fromChannel.readBytes(ofExactLength: response.readableBytes), Array(buffer: response)) - } catch { - // We only got to close the FDs that are not owned by the PipeChannel - [pipe1WriteFH, pipe2ReadFH].forEach { try? SystemCalls.close(descriptor: $0) } - throw error + do { + negotiationResult = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptors( + input: pipe1ReadFD, + output: pipe2WriteFD + ) { channel in + return channel.eventLoop.makeCompletedFuture { + return try self.configureProtocolNegotiationHandlers(channel: channel) + } } + } catch { + try [pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } - case .byte, nil: - fatalError() + do { + toChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + output: pipe1WriteFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + } catch { + try [pipe1WriteFD, pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + do { + fromChannel = try await NIOPipeBootstrap(group: eventLoopGroup) + .takingOwnershipOfDescriptor( + input: pipe2ReadFD + ) { channel in + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + } catch { + try [pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + throw error + } + + try await fromChannel.executeThenClose { fromChannelInbound, _ in + try await toChannel.executeThenClose { _, toChannelOutbound in + var fromChannelInboundIterator = fromChannelInbound.makeAsyncIterator() + + try await toChannelOutbound.write(.init(string: "alpn:string\nHello\n")) + switch try await negotiationResult.get() { + case .string(let channel): + try await channel.executeThenClose { channelInbound, channelOutbound in + var inboundIterator = channelInbound.makeAsyncIterator() + do { + try await XCTAsyncAssertEqual(try await inboundIterator.next(), "Hello") + + let expectedResponse = ByteBuffer(string: "Response\n") + try await channelOutbound.write("Response") + let response = try await fromChannelInboundIterator.next() + XCTAssertEqual(response, expectedResponse) + } catch { + // We only got to close the FDs that are not owned by the PipeChannel + [pipe1WriteFD, pipe2ReadFD].forEach { try? SystemCalls.close(descriptor: $0) } + throw error + } + } + + case .byte: + fatalError() + } } } } @@ -715,14 +933,18 @@ final class AsyncChannelBootstrapTests: XCTestCase { let serverChannel = try await self.makeRawSocketServerChannel(eventLoopGroup: eventLoopGroup) let clientChannel = try await self.makeRawSocketClientChannel(eventLoopGroup: eventLoopGroup) - var serverInboundIterator = serverChannel.inboundStream.makeAsyncIterator() - var clientInboundIterator = clientChannel.inboundStream.makeAsyncIterator() + try await serverChannel.executeThenClose { serverChannelInbound, serverChannelOutbound in + try await clientChannel.executeThenClose { clientChannelInbound, clientChannelOutbound in + var serverInboundIterator = serverChannelInbound.makeAsyncIterator() + var clientInboundIterator = clientChannelInbound.makeAsyncIterator() - try await clientChannel.outboundWriter.write("request") - try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") + try await clientChannelOutbound.write("request") + try await XCTAsyncAssertEqual(try await serverInboundIterator.next(), "request") - try await serverChannel.outboundWriter.write("response") - try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") + try await serverChannelOutbound.write("response") + try await XCTAsyncAssertEqual(try await clientInboundIterator.next(), "response") + } + } } func testRawSocketBootstrap_withProtocolNegotiation() async throws { @@ -732,7 +954,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try! eventLoopGroup.syncShutdownGracefully() } - try await withThrowingTaskGroup(of: EventLoopFuture>.self) { group in + try await withThrowingTaskGroup(of: EventLoopFuture.self) { group in group.addTask { // We have to use a fixed port here since we only get the channel once protocol negotiation is done try await self.makeRawSocketServerChannelWithProtocolNegotiation( @@ -753,16 +975,20 @@ final class AsyncChannelBootstrapTests: XCTestCase { let firstNegotiationResult = try await group.next() let secondNegotiationResult = try await group.next() - switch (try await firstNegotiationResult?.getResult(), try await secondNegotiationResult?.getResult()) { + switch (try await firstNegotiationResult?.get(), try await secondNegotiationResult?.get()) { case (.string(let firstChannel), .string(let secondChannel)): - var firstInboundIterator = firstChannel.inboundStream.makeAsyncIterator() - var secondInboundIterator = secondChannel.inboundStream.makeAsyncIterator() + try await firstChannel.executeThenClose { firstChannelInbound, firstChannelOutbound in + try await secondChannel.executeThenClose { secondChannelInbound, secondChannelOutbound in + var firstInboundIterator = firstChannelInbound.makeAsyncIterator() + var secondInboundIterator = secondChannelInbound.makeAsyncIterator() - try await firstChannel.outboundWriter.write("request") - try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") + try await firstChannelOutbound.write("request") + try await XCTAsyncAssertEqual(try await secondInboundIterator.next(), "request") - try await secondChannel.outboundWriter.write("response") - try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") + try await secondChannelOutbound.write("response") + try await XCTAsyncAssertEqual(try await firstInboundIterator.next(), "response") + } + } default: preconditionFailure() @@ -770,20 +996,97 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } + // MARK: VSock + + func testVSock() async throws { + try XCTSkipUnless(System.supportsVsock, "No vsock transport available") + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 3) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + + let port = VsockAddress.Port(1234) + + let serverChannel = try await ServerBootstrap(group: eventLoopGroup) + .bind( + to: VsockAddress(cid: .any, port: port) + ) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + + #if canImport(Darwin) + let connectAddress = VsockAddress(cid: .any, port: port) + #elseif os(Linux) || os(Android) + let connectAddress = VsockAddress(cid: .local, port: port) + #endif + + try await withThrowingTaskGroup(of: Void.self) { group in + let (stream, continuation) = AsyncStream.makeStream() + var iterator = stream.makeAsyncIterator() + + group.addTask { + try await withThrowingTaskGroup(of: Void.self) { _ in + try await serverChannel.executeThenClose { inbound in + for try await childChannel in inbound { + try await childChannel.executeThenClose { childChannelInbound, _ in + for try await value in childChannelInbound { + continuation.yield(.string(value)) + } + } + } + } + } + } + + let stringChannel = try await ClientBootstrap(group: eventLoopGroup) + .connect(to: connectAddress) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) + } + } + try await stringChannel.executeThenClose { _, outbound in + try await outbound.write("hello") + } + + await XCTAsyncAssertEqual(await iterator.next(), .string("hello")) + + group.cancelAll() + } + } + + // MARK: - Test Helpers - private func makePipeFileDescriptors() -> (pipe1ReadFH: Int32, pipe1WriteFH: Int32, pipe2ReadFH: Int32, pipe2WriteFH: Int32) { - var pipe1FDs: [Int32] = [-1, -1] + private func makePipeFileDescriptors() -> (pipe1ReadFD: CInt, pipe1WriteFD: CInt, pipe2ReadFD: CInt, pipe2WriteFD: CInt) { + var pipe1FDs: [CInt] = [-1, -1] pipe1FDs.withUnsafeMutableBufferPointer { ptr in XCTAssertEqual(0, pipe(ptr.baseAddress!)) } - var pipe2FDs: [Int32] = [-1, -1] + var pipe2FDs: [CInt] = [-1, -1] pipe2FDs.withUnsafeMutableBufferPointer { ptr in XCTAssertEqual(0, pipe(ptr.baseAddress!)) } return (pipe1FDs[0], pipe1FDs[1], pipe2FDs[0], pipe2FDs[1]) } + private func makePipeFileDescriptors() -> (pipeReadFD: CInt, pipeWriteFD: CInt) { + var pipeFDs: [CInt] = [-1, -1] + pipeFDs.withUnsafeMutableBufferPointer { ptr in + XCTAssertEqual(0, pipe(ptr.baseAddress!)) + } + return (pipeFDs[0], pipeFDs[1]) + } + + + private func makeRawSocketServerChannel(eventLoopGroup: EventLoopGroup) async throws -> NIOAsyncChannel { try await NIORawSocketBootstrap(group: eventLoopGroup) .bind( @@ -796,7 +1099,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder(inboundID: 1))) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder(outboundID: 2))) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } @@ -813,14 +1116,14 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder(inboundID: 2))) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder(outboundID: 1))) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } private func makeRawSocketServerChannelWithProtocolNegotiation( eventLoopGroup: EventLoopGroup - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { try await NIORawSocketBootstrap(group: eventLoopGroup) .bind( host: "127.0.0.1", @@ -837,7 +1140,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { private func makeRawSocketClientChannelWithProtocolNegotiation( eventLoopGroup: EventLoopGroup, proposedALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { try await NIORawSocketBootstrap(group: eventLoopGroup) .connect( host: "127.0.0.1", @@ -861,7 +1164,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } @@ -870,7 +1173,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { eventLoopGroup: EventLoopGroup, port: Int, proposedALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { return try await ClientBootstrap(group: eventLoopGroup) .connect( to: .init(ipAddress: "127.0.0.1", port: port) @@ -886,7 +1189,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: Int, proposedOuterALPN: TLSUserEventHandler.ALPN, proposedInnerALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture> { return try await ClientBootstrap(group: eventLoopGroup) .connect( to: .init(ipAddress: "127.0.0.1", port: port) @@ -912,7 +1215,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } @@ -921,7 +1224,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { eventLoopGroup: EventLoopGroup, port: Int, proposedALPN: TLSUserEventHandler.ALPN? = nil - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { try await DatagramBootstrap(group: eventLoopGroup) .bind( host: "127.0.0.1", @@ -945,7 +1248,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } @@ -954,7 +1257,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { eventLoopGroup: EventLoopGroup, port: Int, proposedALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { + ) async throws -> EventLoopFuture { try await DatagramBootstrap(group: eventLoopGroup) .connect( host: "127.0.0.1", @@ -973,7 +1276,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedALPN: TLSUserEventHandler.ALPN? = nil, inboundID: UInt8? = nil, outboundID: UInt8? = nil - ) throws -> EventLoopFuture> { + ) throws -> EventLoopFuture { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder(inboundID: inboundID))) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder(outboundID: outboundID))) try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedALPN)) @@ -985,11 +1288,11 @@ final class AsyncChannelBootstrapTests: XCTestCase { channel: Channel, proposedOuterALPN: TLSUserEventHandler.ALPN? = nil, proposedInnerALPN: TLSUserEventHandler.ALPN? = nil - ) throws -> EventLoopFuture> { + ) throws -> EventLoopFuture> { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN)) - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in + let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler> { alpnResult, channel in switch alpnResult { case .negotiated(let alpn): switch alpn { @@ -998,14 +1301,14 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) - return NIOProtocolNegotiationResult(deferredResult: negotiationFuture) + return negotiationFuture } case "byte": return channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) let negotiationHandler = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) - return NIOProtocolNegotiationResult(deferredResult: negotiationHandler) + return negotiationHandler } default: return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } @@ -1019,7 +1322,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { } @discardableResult - private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture> { + private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture { let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in switch alpnResult { case .negotiated(let alpn): @@ -1028,20 +1331,20 @@ final class AsyncChannelBootstrapTests: XCTestCase { return channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) let asyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel + wrappingChannelSynchronously: channel ) - return NIOProtocolNegotiationResult(result: .string(asyncChannel)) + return .string(asyncChannel) } case "byte": return channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(ByteBufferToByteHandler()) let asyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel + wrappingChannelSynchronously: channel ) - return NIOProtocolNegotiationResult(result: .byte(asyncChannel)) + return .byte(asyncChannel) } default: return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } @@ -1056,6 +1359,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension AsyncStream { fileprivate static func makeStream( of elementType: Element.Type = Element.self, diff --git a/Tests/NIOPosixTests/EventLoopTest.swift b/Tests/NIOPosixTests/EventLoopTest.swift index 52403b0bf5..75956c55b3 100644 --- a/Tests/NIOPosixTests/EventLoopTest.swift +++ b/Tests/NIOPosixTests/EventLoopTest.swift @@ -382,7 +382,38 @@ public final class EventLoopTest : XCTestCase { eventLoop.advanceTime(by: .hours(10)) XCTAssertEqual(5, counter) } + + func testScheduledRepeatedAsyncTaskIsJittered() throws { + let initialDelay = TimeAmount.minutes(5) + let delay = TimeAmount.minutes(2) + let maximumAllowableJitter = TimeAmount.minutes(1) + let counter = ManagedAtomic(0) + let loop = EmbeddedEventLoop() + + _ = loop.scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, maximumAllowableJitter: maximumAllowableJitter, { RepeatedTask in + counter.wrappingIncrement(ordering: .relaxed) + let p = loop.makePromise(of: Void.self) + loop.scheduleTask(in: .milliseconds(10)) { + p.succeed(()) + } + return p.futureResult + }) + + for _ in 0..<10 { + // just running shouldn't do anything + loop.run() + } + let timeRange = TimeAmount.hours(1) + // Due to jittered delays is not possible to exactly know how many tasks will be executed in a given time range, + // instead calculate a range representing an estimate of the number of tasks executed during that given time range. + let minNumberOfExecutedTasks = (timeRange.nanoseconds - initialDelay.nanoseconds) / (delay.nanoseconds + maximumAllowableJitter.nanoseconds) + let maxNumberOfExecutedTasks = (timeRange.nanoseconds - initialDelay.nanoseconds) / delay.nanoseconds + 1 + + loop.advanceTime(by: timeRange) + XCTAssertTrue((minNumberOfExecutedTasks...maxNumberOfExecutedTasks).contains(counter.load(ordering: .relaxed))) + } + public func testEventLoopGroupMakeIterator() throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) defer { @@ -803,6 +834,27 @@ public final class EventLoopTest : XCTestCase { semaphore.signal() XCTAssertEqual(XCTWaiter.wait(for: [expect1, expect2], timeout: 0.5), .completed) } + + func testRepeatedTaskIsJittered() throws { + let initialDelay = TimeAmount.minutes(5) + let delay = TimeAmount.minutes(2) + let maximumAllowableJitter = TimeAmount.minutes(1) + let counter = ManagedAtomic(0) + let loop = EmbeddedEventLoop() + + _ = loop.scheduleRepeatedTask(initialDelay: initialDelay, delay: delay, maximumAllowableJitter: maximumAllowableJitter, { RepeatedTask in + counter.wrappingIncrement(ordering: .relaxed) + }) + + let timeRange = TimeAmount.hours(1) + // Due to jittered delays is not possible to exactly know how many tasks will be executed in a given time range, + // instead calculate a range representing an estimate of the number of tasks executed during that given time range. + let minNumberOfExecutedTasks = (timeRange.nanoseconds - initialDelay.nanoseconds) / (delay.nanoseconds + maximumAllowableJitter.nanoseconds) + let maxNumberOfExecutedTasks = (timeRange.nanoseconds - initialDelay.nanoseconds) / delay.nanoseconds + 1 + + loop.advanceTime(by: timeRange) + XCTAssertTrue((minNumberOfExecutedTasks...maxNumberOfExecutedTasks).contains(counter.load(ordering: .relaxed))) + } func testCancelledScheduledTasksDoNotHoldOnToRunClosure() { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) diff --git a/Tests/NIOPosixTests/NIOThreadPoolTest.swift b/Tests/NIOPosixTests/NIOThreadPoolTest.swift index a36e4794ac..3a4d772781 100644 --- a/Tests/NIOPosixTests/NIOThreadPoolTest.swift +++ b/Tests/NIOPosixTests/NIOThreadPoolTest.swift @@ -14,10 +14,12 @@ import XCTest @testable import NIOPosix +import Atomics import Dispatch import NIOConcurrencyHelpers import NIOEmbedded +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) class NIOThreadPoolTest: XCTestCase { func testThreadNamesAreSetUp() { let numberOfThreads = 11 @@ -110,8 +112,52 @@ class NIOThreadPoolTest: XCTestCase { } } + func testAsyncThreadPool() async throws { + let numberOfThreads = 1 + let pool = NIOThreadPool(numberOfThreads: numberOfThreads) + pool.start() + do { + let hitCount = ManagedAtomic(false) + try await pool.runIfActive { + hitCount.store(true, ordering: .relaxed) + } + XCTAssertEqual(hitCount.load(ordering: .relaxed), true) + } catch {} + try await pool.shutdownGracefully() + } + + func testAsyncThreadPoolErrorPropagation() async throws { + struct ThreadPoolError: Error {} + let numberOfThreads = 1 + let pool = NIOThreadPool(numberOfThreads: numberOfThreads) + pool.start() + do { + try await pool.runIfActive { + throw ThreadPoolError() + } + XCTFail("Should not get here as closure sent to runIfActive threw an error") + } catch { + XCTAssertNotNil(error as? ThreadPoolError, "Error thrown should be of type ThreadPoolError") + } + try await pool.shutdownGracefully() + } + + func testAsyncThreadPoolNotActiveError() async throws { + struct ThreadPoolError: Error {} + let numberOfThreads = 1 + let pool = NIOThreadPool(numberOfThreads: numberOfThreads) + do { + try await pool.runIfActive { + throw ThreadPoolError() + } + XCTFail("Should not get here as thread pool isn't active") + } catch { + XCTAssertNotNil(error as? NIOThreadPoolError.ThreadPoolInactive, "Error thrown should be of type ThreadPoolError") + } + try await pool.shutdownGracefully() + } + func testAsyncShutdownWorks() async throws { - guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { throw XCTSkip() } let threadPool = NIOThreadPool(numberOfThreads: 17) let eventLoop = NIOAsyncTestingEventLoop() diff --git a/Tests/NIOPosixTests/NonBlockingFileIOTest.swift b/Tests/NIOPosixTests/NonBlockingFileIOTest.swift index 6c4ff1b1d9..cd5d4d5f34 100644 --- a/Tests/NIOPosixTests/NonBlockingFileIOTest.swift +++ b/Tests/NIOPosixTests/NonBlockingFileIOTest.swift @@ -1016,4 +1016,634 @@ class NonBlockingFileIOTest: XCTestCase { } }) } + + func testChunkedReadingToleratesChunkHandlersWithForeignEventLoops() throws { + let content = "hello" + let contentBytes = Array(content.utf8) + var numCalls = 0 + let otherGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + try! otherGroup.syncShutdownGracefully() + } + try withTemporaryFile(content: content) { (fileHandle, path) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 0, endIndex: 5) + try self.fileIO.readChunked(fileRegion: fr, + chunkSize: 1, + allocator: self.allocator, + eventLoop: self.eventLoop) { buf in + var buf = buf + XCTAssertTrue(self.eventLoop.inEventLoop) + XCTAssertEqual(1, buf.readableBytes) + XCTAssertEqual(contentBytes[numCalls], buf.readBytes(length: 1)?.first!) + numCalls += 1 + return otherGroup.next().makeSucceededFuture(()) + }.wait() + } + XCTAssertEqual(content.utf8.count, numCalls) + } + +} + + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NonBlockingFileIOTest { + func testAsyncBasicFileIOWorks() async throws { + let content = "hello" + try await withTemporaryFile(content: content) { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 0, endIndex: 5) + var buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(content.utf8.count, buf.readableBytes) + XCTAssertEqual(content, buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncOffsetWorks() async throws { + let content = "hello" + try await withTemporaryFile(content: content) { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 3, endIndex: 5) + var buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(2, buf.readableBytes) + XCTAssertEqual("lo", buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncOffsetBeyondEOF() async throws { + let content = "hello" + try await withTemporaryFile(content: content) { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 3000, endIndex: 3001) + var buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(0, buf.readableBytes) + XCTAssertEqual("", buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncEmptyReadWorks() async throws { + try await withTemporaryFile { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 0, endIndex: 0) + let buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(0, buf.readableBytes) + } + } + + func testAsyncReadingShortWorks() async throws { + let content = "hello" + try await withTemporaryFile(content: "hello") { (fileHandle, _) -> Void in + let fr = FileRegion(fileHandle: fileHandle, readerIndex: 0, endIndex: 10) + var buf = try await self.fileIO.read(fileRegion: fr, + allocator: self.allocator) + XCTAssertEqual(content.utf8.count, buf.readableBytes) + XCTAssertEqual(content, buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncDoesNotBlockTheThreadOrEventLoop() async throws { + try await withPipe { readFH, writeFH in + async let byteBufferTask = try await self.fileIO.read( + fileHandle: readFH, + byteCount: 10, + allocator: self.allocator) + do { + try await self.threadPool.runIfActive { + try writeFH.withUnsafeFileDescriptor { writeFD in + _ = try Posix.write(descriptor: writeFD, pointer: "X", size: 1) + } + try writeFH.close() + } + var buf = try await byteBufferTask + XCTAssertEqual(1, buf.readableBytes) + XCTAssertEqual("X", buf.readString(length: buf.readableBytes)) + } + return [readFH] + } + } + + func testAsyncGettingErrorWhenEventLoopGroupIsShutdown() async throws { + try await self.threadPool.shutdownGracefully() + + try await withPipe { readFH, writeFH in + do { + _ = try await self.fileIO.read( + fileHandle: readFH, + byteCount: 1, + allocator: self.allocator) + XCTFail("testAsyncGettingErrorWhenEventLoopGroupIsShutdown: fileIO.read should throw an error") + } catch { + XCTAssertTrue(error is NIOThreadPoolError.ThreadPoolInactive) + } + return [readFH, writeFH] + } + } + + func testAsyncReadDoesNotReadShort() async throws { + try await withPipe { readFH, writeFH in + async let bufferTask = try await self.fileIO.read(fileHandle: readFH, + byteCount: 10, + allocator: self.allocator) + for i in 0..<10 { + try await Task.sleep(nanoseconds: 5_000_000) + try await self.threadPool.runIfActive { + try writeFH.withUnsafeFileDescriptor { writeFD in + _ = try Posix.write(descriptor: writeFD, pointer: "\(i)", size: 1) + } + } + } + try writeFH.close() + + var buf = try await bufferTask + XCTAssertEqual(10, buf.readableBytes) + XCTAssertEqual("0123456789", buf.readString(length: buf.readableBytes)) + return [readFH] + } + } + + func testAsyncReadMoreThanIntMaxBytesDoesntThrow() async throws { + try XCTSkipIf(MemoryLayout.size == MemoryLayout.size) + // here we try to read way more data back from the file than it contains but it serves the purpose + // even on a small file the OS will return EINVAL if you try to read > INT_MAX bytes + try await withTemporaryFile(content: "some-dummy-content", { (filehandle, path) -> Void in + let content = try await self.fileIO.read(fileHandle: filehandle, byteCount:Int(Int32.max)+10, allocator: .init()) + XCTAssertEqual(String(buffer: content), "some-dummy-content") + }) + } + + func testAsyncReadingFileSize() async throws { + try await withTemporaryFile(content: "0123456789") { (fileHandle, _) -> Void in + let size = try await self.fileIO.readFileSize(fileHandle: fileHandle) + XCTAssertEqual(size, 10) + } + } + + func testAsyncChangeFileSizeShrink() async throws { + try await withTemporaryFile(content: "0123456789") { (fileHandle, _) -> Void in + try await self.fileIO.changeFileSize(fileHandle: fileHandle, + size: 1) + let fileRegion = try FileRegion(fileHandle: fileHandle) + var buf = try await self.fileIO.read(fileRegion: fileRegion, + allocator: self.allocator) + XCTAssertEqual("0", buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncChangeFileSizeGrow() async throws { + try await withTemporaryFile(content: "0123456789") { (fileHandle, _) -> Void in + try await self.fileIO.changeFileSize(fileHandle: fileHandle, + size: 100) + let fileRegion = try FileRegion(fileHandle: fileHandle) + var buf = try await self.fileIO.read(fileRegion: fileRegion, + allocator: self.allocator) + let zeros = (1...90).map { _ in UInt8(0) } + guard let bytes = buf.readBytes(length: buf.readableBytes)?.suffix(from: 10) else { + XCTFail("readBytes(length:) should not be nil") + return + } + XCTAssertEqual(zeros, Array(bytes)) + } + } + + func testAsyncWriting() async throws { + try await withTemporaryFile(content: "") { (fileHandle, path) in + var buffer = self.allocator.buffer(capacity: 3) + buffer.writeStaticString("123") + + try await self.fileIO.write(fileHandle: fileHandle, + buffer: buffer) + let offset = try fileHandle.withUnsafeFileDescriptor { + try Posix.lseek(descriptor: $0, offset: 0, whence: SEEK_SET) + } + XCTAssertEqual(offset, 0) + + let readBuffer = try await self.fileIO.read(fileHandle: fileHandle, + byteCount: 3, + allocator: self.allocator) + XCTAssertEqual(readBuffer.getString(at: 0, length: 3), "123") + } + } + + func testAsyncWriteMultipleTimes() async throws { + try await withTemporaryFile(content: "AAA") { (fileHandle, path) in + var buffer = self.allocator.buffer(capacity: 3) + buffer.writeStaticString("xxx") + + for i in 0 ..< 3 { + buffer.writeString("\(i)") + try await self.fileIO.write(fileHandle: fileHandle, + buffer: buffer) + } + let offset = try fileHandle.withUnsafeFileDescriptor { + try Posix.lseek(descriptor: $0, offset: 0, whence: SEEK_SET) + } + XCTAssertEqual(offset, 0) + + let expectedOutput = "xxx0xxx01xxx012" + let readBuffer = try await self.fileIO.read(fileHandle: fileHandle, + byteCount: expectedOutput.utf8.count, + allocator: self.allocator) + XCTAssertEqual(expectedOutput, String(decoding: readBuffer.readableBytesView, as: Unicode.UTF8.self)) + } + } + + func testAsyncWritingWithOffset() async throws { + try await withTemporaryFile(content: "hello") { (fileHandle, _) -> Void in + var buffer = self.allocator.buffer(capacity: 3) + buffer.writeStaticString("123") + + try await self.fileIO.write(fileHandle: fileHandle, + toOffset: 1, + buffer: buffer) + let offset = try fileHandle.withUnsafeFileDescriptor { + try Posix.lseek(descriptor: $0, offset: 0, whence: SEEK_SET) + } + XCTAssertEqual(offset, 0) + + var readBuffer = try await self.fileIO.read(fileHandle: fileHandle, + byteCount: 5, + allocator: self.allocator) + XCTAssertEqual(5, readBuffer.readableBytes) + XCTAssertEqual("h123o", readBuffer.readString(length: readBuffer.readableBytes)) + } + } + + // This is undefined behavior and may cause different + // results on other platforms. Please add #if:s according + // to platform requirements. + func testAsyncWritingBeyondEOF() async throws { + try await withTemporaryFile(content: "hello") { (fileHandle, _) -> Void in + var buffer = self.allocator.buffer(capacity: 3) + buffer.writeStaticString("123") + + try await self.fileIO.write(fileHandle: fileHandle, + toOffset: 6, + buffer: buffer) + + let fileRegion = try FileRegion(fileHandle: fileHandle) + var buf = try await self.fileIO.read(fileRegion: fileRegion, + allocator: self.allocator) + XCTAssertEqual(9, buf.readableBytes) + XCTAssertEqual("hello", buf.readString(length: 5)) + XCTAssertEqual([ UInt8(0) ], buf.readBytes(length: 1)) + XCTAssertEqual("123", buf.readString(length: buf.readableBytes)) + } + } + + func testAsyncFileOpenWorks() async throws { + let content = "123" + try await withTemporaryFile(content: content) { (fileHandle, path) -> Void in + try await self.fileIO.withFileRegion(path: path) { fr in + try fr.fileHandle.withUnsafeFileDescriptor { fd in + XCTAssertGreaterThanOrEqual(fd, 0) + } + XCTAssertTrue(fr.fileHandle.isOpen) + XCTAssertEqual(0, fr.readerIndex) + XCTAssertEqual(3, fr.endIndex) + } + } + } + + func testAsyncFileOpenWorksWithEmptyFile() async throws { + let content = "" + try await withTemporaryFile(content: content) { (fileHandle, path) -> Void in + try await self.fileIO.withFileRegion(path: path) { fr in + try fr.fileHandle.withUnsafeFileDescriptor { fd in + XCTAssertGreaterThanOrEqual(fd, 0) + } + XCTAssertTrue(fr.fileHandle.isOpen) + XCTAssertEqual(0, fr.readerIndex) + XCTAssertEqual(0, fr.endIndex) + } + } + } + + func testAsyncFileOpenFails() async throws { + do { + _ = try await self.fileIO.withFileRegion(path: "/dev/null/this/does/not/exist") { _ in} + XCTFail("should've thrown") + } catch let e as IOError where e.errnoCode == ENOTDIR { + // OK + } catch { + XCTFail("wrong error: \(error)") + } + } + + func testAsyncOpeningFilesForWriting() async throws { + try await withTemporaryDirectory { dir in + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: .write, + flags: .allowFileCreation() + ) { _ in } + } + } + + func testAsyncOpeningFilesForWritingFailsIfWeDontAllowItExplicitly() async throws { + do { + try await withTemporaryDirectory { dir in + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: .write, + flags: .default + ) { _ in } + } + XCTFail("testAsyncOpeningFilesForWritingFailsIfWeDontAllowItExplicitly: openFile should fail") + } catch { + XCTAssertEqual(ENOENT, (error as? IOError)?.errnoCode) + } + } + + func testAsyncOpeningFilesForWritingDoesNotAllowReading() async throws { + try await withTemporaryDirectory { dir in + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: .write, + flags: .allowFileCreation() + ) { fileHandle in + XCTAssertEqual(-1 /* read must fail */, + try fileHandle.withUnsafeFileDescriptor { fd -> ssize_t in + var data: UInt8 = 0 + return withUnsafeMutableBytes(of: &data) { ptr in + read(fd, ptr.baseAddress, ptr.count) + } + }) + } + } + } + + func testAsyncOpeningFilesForWritingAndReading() async throws { + try await withTemporaryDirectory { dir in + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .allowFileCreation() + ) { fileHandle in + XCTAssertEqual(0 /* read should read EOF */, + try fileHandle.withUnsafeFileDescriptor { fd -> ssize_t in + var data: UInt8 = 0 + return withUnsafeMutableBytes(of: &data) { ptr in + read(fd, ptr.baseAddress, ptr.count) + } + }) + } + } + } + + func testAsyncOpeningFilesForWritingDoesNotImplyTruncation() async throws { + try await withTemporaryDirectory { dir in + // open 1 + write + do { + try await self.fileIO.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .allowFileCreation() + ) { fileHandle in + try fileHandle.withUnsafeFileDescriptor { fd in + var data = UInt8(ascii: "X") + XCTAssertEqual(IOResult.processed(1), + try withUnsafeBytes(of: &data) { ptr in + try Posix.write(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) + }) + } + } + } + + // open 2 + write again + read + do { + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .default + ) { fileHandle in + try fileHandle.withUnsafeFileDescriptor { fd in + try Posix.lseek(descriptor: fd, offset: 0, whence: SEEK_END) + var data = UInt8(ascii: "Y") + XCTAssertEqual(IOResult.processed(1), + try withUnsafeBytes(of: &data) { ptr in + try Posix.write(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) + }) + } + XCTAssertEqual(2 /* both bytes */, + try fileHandle.withUnsafeFileDescriptor { fd -> ssize_t in + var data: UInt16 = 0 + try Posix.lseek(descriptor: fd, offset: 0, whence: SEEK_SET) + let readReturn = withUnsafeMutableBytes(of: &data) { ptr in + read(fd, ptr.baseAddress, ptr.count) + } + XCTAssertEqual(UInt16(bigEndian: (UInt16(UInt8(ascii: "X")) << 8) | UInt16(UInt8(ascii: "Y"))), + data) + return readReturn + }) + } + } + } + } + + func testAsyncOpeningFilesForWritingCanUseTruncation() async throws { + try await withTemporaryDirectory { dir in + // open 1 + write + do { + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .allowFileCreation() + ) { fileHandle in + try fileHandle.withUnsafeFileDescriptor { fd in + var data = UInt8(ascii: "X") + XCTAssertEqual(IOResult.processed(1), + try withUnsafeBytes(of: &data) { ptr in + try Posix.write(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) + }) + } + } + } + // open 2 (with truncation) + write again + read + do { + try await self.fileIO!.withFileHandle( + path: "\(dir)/file", + mode: [.write, .read], + flags: .posix(flags: O_TRUNC, mode: 0) + ) { fileHandle in + try fileHandle.withUnsafeFileDescriptor { fd in + try Posix.lseek(descriptor: fd, offset: 0, whence: SEEK_END) + var data = UInt8(ascii: "Y") + XCTAssertEqual(IOResult.processed(1), + try withUnsafeBytes(of: &data) { ptr in + try Posix.write(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) + }) + } + XCTAssertEqual(1 /* read should read just one byte because we truncated the file */, + try fileHandle.withUnsafeFileDescriptor { fd -> ssize_t in + var data: UInt16 = 0 + try Posix.lseek(descriptor: fd, offset: 0, whence: SEEK_SET) + let readReturn = withUnsafeMutableBytes(of: &data) { ptr in + read(fd, ptr.baseAddress, ptr.count) + } + XCTAssertEqual(UInt16(bigEndian: UInt16(UInt8(ascii: "Y")) << 8), data) + return readReturn + }) + } + } + } + } + + func testAsyncReadFromOffset() async throws { + try await withTemporaryFile(content: "hello world") { (fileHandle, path) in + let buffer = try await self.fileIO.read(fileHandle: fileHandle, + fromOffset: 6, + byteCount: 5, + allocator: ByteBufferAllocator()) + let string = String(decoding: buffer.readableBytesView, as: Unicode.UTF8.self) + XCTAssertEqual("world", string) + } + } + + func testAsyncReadFromOffsetAfterEOFDeliversExactlyOneChunk() async throws { + try await withTemporaryFile(content: "hello world") { (fileHandle, path) in + let readableBytes = try await self.fileIO.read( + fileHandle: fileHandle, + fromOffset: 100, + byteCount: 5, + allocator: .init() + ).readableBytes + XCTAssertEqual(0,readableBytes) + } + } + + func testAsyncReadFromEOFDeliversExactlyOneChunk() async throws { + try await withTemporaryFile(content: "") { (fileHandle, path) in + let readableBytes = try await self.fileIO.read( + fileHandle: fileHandle, + byteCount: 5, + allocator: .init() + ).readableBytes + XCTAssertEqual(0, readableBytes) + } + } + + func testAsyncThrowsErrorOnUnstartedPool() async throws { + await withTemporaryFile(content: "hello, world") { fileHandle, path in + let threadPool = NIOThreadPool(numberOfThreads: 1) + let fileIO = NonBlockingFileIO(threadPool: threadPool) + do { + try await fileIO.withFileRegion(path: path) { _ in } + XCTFail("testAsyncThrowsErrorOnUnstartedPool: openFile should throw an error") + } catch { + } + } + } + + func testAsyncLStat() async throws { + try await withTemporaryFile(content: "hello, world") { _, path in + let stat = try await self.fileIO.lstat(path: path) + XCTAssertEqual(12, stat.st_size) + XCTAssertEqual(S_IFREG, S_IFMT & stat.st_mode) + } + + try await withTemporaryDirectory { path in + let stat = try await self.fileIO.lstat(path: path) + XCTAssertEqual(S_IFDIR, S_IFMT & stat.st_mode) + } + } + + func testAsyncSymlink() async throws { + try await withTemporaryFile(content: "hello, world") { _, path in + let symlink = "\(path).symlink" + try await self.fileIO.symlink(path: symlink, to: path) + + let link = try await self.fileIO.readlink(path: symlink) + XCTAssertEqual(path, link) + let stat = try await self.fileIO.lstat(path: symlink) + XCTAssertEqual(S_IFLNK, S_IFMT & stat.st_mode) + + try await self.fileIO.unlink(path: symlink) + do { + _ = try await self.fileIO.lstat(path: symlink) + XCTFail("testAsyncSymlink: lstat should throw an error after unlink") + } catch { + XCTAssertEqual(ENOENT, (error as? IOError)?.errnoCode) + } + } + } + + func testAsyncCreateDirectory() async throws { + try await withTemporaryDirectory { path in + let dir = "\(path)/f1/f2///f3" + try await self.fileIO.createDirectory(path: dir, withIntermediateDirectories: true, mode: S_IRWXU) + + let stat = try await self.fileIO.lstat(path: dir) + XCTAssertEqual(S_IFDIR, S_IFMT & stat.st_mode) + + try await self.fileIO.createDirectory(path: "\(dir)/f4", withIntermediateDirectories: false, mode: S_IRWXU) + + let stat2 = try await self.fileIO.lstat(path: dir) + XCTAssertEqual(S_IFDIR, S_IFMT & stat2.st_mode) + + let dir3 = "\(path)/f4/." + try await self.fileIO.createDirectory(path: dir3, withIntermediateDirectories: true, mode: S_IRWXU) + } + } + + func testAsyncListDirectory() async throws { + try await withTemporaryDirectory { path in + let file = "\(path)/file" + try await self.fileIO.withFileHandle( + path: file, + mode: .write, + flags: .allowFileCreation() + ) { handle in + let list = try await self.fileIO.listDirectory(path: path) + XCTAssertEqual([".", "..", "file"], list.sorted(by: { $0.name < $1.name }).map(\.name)) + } + } + } + + func testAsyncRename() async throws { + try await withTemporaryDirectory { path in + let file = "\(path)/file" + try await self.fileIO.withFileHandle( + path: file, + mode: .write, + flags: .allowFileCreation() + ) { handle in + let stat = try await self.fileIO.lstat(path: file) + XCTAssertEqual(S_IFREG, S_IFMT & stat.st_mode) + + let new = "\(path).new" + try await self.fileIO.rename(path: file, newName: new) + + let stat2 = try await self.fileIO.lstat(path: new) + XCTAssertEqual(S_IFREG, S_IFMT & stat2.st_mode) + + do { + _ = try await self.fileIO.lstat(path: file) + XCTFail("testAsyncRename: lstat should throw an error after file renamed") + } catch { + XCTAssertEqual(ENOENT, (error as? IOError)?.errnoCode) + } + } + } + } + + func testAsyncRemove() async throws { + try await withTemporaryDirectory { path in + let file = "\(path)/file" + try await self.fileIO.withFileHandle( + path: file, + mode: .write, + flags: .allowFileCreation() + ) { handle in + let stat = try await self.fileIO.lstat(path: file) + XCTAssertEqual(S_IFREG, S_IFMT & stat.st_mode) + + try await self.fileIO.remove(path: file) + do { + _ = try await self.fileIO.lstat(path: file) + XCTFail("testAsyncRemove: lstat should throw an error after file removed") + } catch { + XCTAssertEqual(ENOENT, (error as? IOError)?.errnoCode) + } + } + } + } } diff --git a/Tests/NIOPosixTests/SerialExecutorTests.swift b/Tests/NIOPosixTests/SerialExecutorTests.swift index 84f5249222..85619b0ca3 100644 --- a/Tests/NIOPosixTests/SerialExecutorTests.swift +++ b/Tests/NIOPosixTests/SerialExecutorTests.swift @@ -37,14 +37,12 @@ actor EventLoopBoundActor { } #endif +@available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) final class SerialExecutorTests: XCTestCase { private func _testBasicExecutorFitsOnEventLoop(loop1: EventLoop, loop2: EventLoop) async throws { #if compiler(<5.9) throw XCTSkip("Custom executors are only supported in 5.9") #else - guard #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) else { - throw XCTSkip("Custom executors not available on this platform") - } let testActor = EventLoopBoundActor(loop: loop1) await testActor.assertInLoop(loop1) diff --git a/Tests/NIOPosixTests/SystemCallWrapperHelpers.swift b/Tests/NIOPosixTests/SystemCallWrapperHelpers.swift index ae1ed7ac15..24b2086587 100644 --- a/Tests/NIOPosixTests/SystemCallWrapperHelpers.swift +++ b/Tests/NIOPosixTests/SystemCallWrapperHelpers.swift @@ -19,10 +19,10 @@ import Foundation public func measureRunTime(_ body: () throws -> Int) rethrows -> TimeInterval { func measureOne(_ body: () throws -> Int) rethrows -> TimeInterval { - let start = Date() + let start = DispatchTime.now().uptimeNanoseconds _ = try body() - let end = Date() - return end.timeIntervalSince(start) + let end = DispatchTime.now().uptimeNanoseconds + return Double(end - start)/1_000_000 } _ = try measureOne(body) diff --git a/Tests/NIOPosixTests/TestUtils.swift b/Tests/NIOPosixTests/TestUtils.swift index 273b6b145c..0c38f94d8f 100644 --- a/Tests/NIOPosixTests/TestUtils.swift +++ b/Tests/NIOPosixTests/TestUtils.swift @@ -31,6 +31,14 @@ extension System { #if canImport(Darwin) || os(Linux) || os(Android) guard let socket = try? Socket(protocolFamily: .vsock, type: .stream) else { return false } XCTAssertNoThrow(try socket.close()) +#if !canImport(Darwin) + do { + let fd = try Posix.open(file: "/dev/vsock", oFlag: O_RDONLY | O_CLOEXEC) + try Posix.close(descriptor: fd) + } catch { + return false + } +#endif return true #else return false @@ -60,6 +68,29 @@ func withPipe(_ body: (NIOCore.NIOFileHandle, NIOCore.NIOFileHandle) throws -> [ } } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +func withPipe(_ body: (NIOCore.NIOFileHandle, NIOCore.NIOFileHandle) async throws -> [NIOCore.NIOFileHandle]) async throws { + var fds: [Int32] = [-1, -1] + fds.withUnsafeMutableBufferPointer { ptr in + XCTAssertEqual(0, pipe(ptr.baseAddress!)) + } + let readFH = NIOFileHandle(descriptor: fds[0]) + let writeFH = NIOFileHandle(descriptor: fds[1]) + var toClose: [NIOFileHandle] = [readFH, writeFH] + var error: Error? = nil + do { + toClose = try await body(readFH, writeFH) + } catch let err { + error = err + } + try toClose.forEach { fh in + XCTAssertNoThrow(try fh.close()) + } + if let error = error { + throw error + } +} + func withTemporaryDirectory(_ body: (String) throws -> T) rethrows -> T { let dir = createTemporaryDirectory() defer { @@ -68,6 +99,15 @@ func withTemporaryDirectory(_ body: (String) throws -> T) rethrows -> T { return try body(dir) } +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +func withTemporaryDirectory(_ body: (String) async throws -> T) async rethrows -> T { + let dir = createTemporaryDirectory() + defer { + try? FileManager.default.removeItem(atPath: dir) + } + return try await body(dir) +} + /// This function creates a filename that can be used for a temporary UNIX domain socket path. /// /// If the temporary directory is too long to store a UNIX domain socket path, it will `chdir` into the temporary @@ -131,6 +171,35 @@ func withTemporaryFile(content: String? = nil, _ body: (NIOCore.NIOFileHandle } return try body(fileHandle, path) } + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +func withTemporaryFile(content: String? = nil, _ body: @escaping @Sendable (NIOCore.NIOFileHandle, String) async throws -> T) async rethrows -> T { + let (fd, path) = openTemporaryFile() + let fileHandle = NIOFileHandle(descriptor: fd) + defer { + XCTAssertNoThrow(try fileHandle.close()) + XCTAssertEqual(0, unlink(path)) + } + if let content = content { + try Array(content.utf8).withUnsafeBufferPointer { ptr in + var toWrite = ptr.count + var start = ptr.baseAddress! + while toWrite > 0 { + let res = try Posix.write(descriptor: fd, pointer: start, size: toWrite) + switch res { + case .processed(let written): + toWrite -= written + start = start + written + case .wouldBlock: + XCTFail("unexpectedly got .wouldBlock from a file") + continue + } + } + XCTAssertEqual(0, lseek(fd, 0, SEEK_SET)) + } + } + return try await body(fileHandle, path) +} var temporaryDirectory: String { get { #if targetEnvironment(simulator) diff --git a/Tests/NIOSingletonsTests/GlobalSingletonsTests.swift b/Tests/NIOSingletonsTests/GlobalSingletonsTests.swift index 0c534568d1..d55fab898e 100644 --- a/Tests/NIOSingletonsTests/GlobalSingletonsTests.swift +++ b/Tests/NIOSingletonsTests/GlobalSingletonsTests.swift @@ -18,11 +18,13 @@ import NIOPosix import Foundation final class NIOSingletonsTests: XCTestCase { + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) func testSingletonMultiThreadedEventLoopWorks() async throws { let works = try await MultiThreadedEventLoopGroup.singleton.any().submit { "yes" }.get() XCTAssertEqual(works, "yes") } + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) func testSingletonBlockingPoolWorks() async throws { let works = try await NIOThreadPool.singleton.runIfActive( eventLoop: MultiThreadedEventLoopGroup.singleton.any() diff --git a/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift b/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift index 99bc3e795d..36fc7bd04a 100644 --- a/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift +++ b/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) import NIOTLS -@_spi(AsyncChannel) import NIOCore +import NIOTLS +import NIOCore import NIOEmbedded import XCTest import NIOTestUtils @@ -31,7 +31,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let channel = EmbeddedChannel() let handler = NIOTypedApplicationProtocolNegotiationHandler { result, channel in - return channel.eventLoop.makeSucceededFuture(.init(result: (.negotiated(result)))) + return channel.eventLoop.makeSucceededFuture(.negotiated(result)) } try channel.pipeline.addHandler(handler).wait() try channel.pipeline.removeHandler(handler).wait() @@ -49,14 +49,14 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { called = true XCTAssertEqual(result, self.negotiatedResult) XCTAssertTrue(emChannel === channel) - return loop.makeSucceededFuture(.init(result: (.negotiated(result)))) + return loop.makeSucceededFuture(.negotiated(result)) } try emChannel.pipeline.addHandler(handler).wait() emChannel.pipeline.fireUserInboundEventTriggered(negotiatedEvent) XCTAssertTrue(called) - XCTAssertEqual(try handler.protocolNegotiationResult.wait(), .init(result: (.negotiated(negotiatedResult)))) + XCTAssertEqual(try handler.protocolNegotiationResult.wait(), .negotiated(negotiatedResult)) } func testIgnoresUnknownUserEvents() throws { @@ -65,7 +65,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let handler = NIOTypedApplicationProtocolNegotiationHandler { result in XCTFail("Negotiation fired") - return loop.makeSucceededFuture(.init(result: (.failed))) + return loop.makeSucceededFuture(.failed) } try channel.pipeline.addHandler(handler).wait() @@ -86,7 +86,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let handler = NIOTypedApplicationProtocolNegotiationHandler { result in XCTFail("Should not be called") - return loop.makeSucceededFuture(.init(result: (.failed))) + return loop.makeSucceededFuture(.failed) } try channel.pipeline.addHandler(handler).wait() @@ -101,7 +101,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { func testBufferingWhileWaitingForFuture() throws { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) + let continuePromise = loop.makePromise(of: NegotiationResult.self) let handler = NIOTypedApplicationProtocolNegotiationHandler { result in return continuePromise.futureResult @@ -119,7 +119,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { XCTAssertNoThrow(XCTAssertNil(try channel.readInbound())) // Complete the pipeline swap. - continuePromise.succeed(.init(result: (.failed))) + continuePromise.succeed(.failed) // Now everything should have been unbuffered. XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "writes")) @@ -132,7 +132,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { func testNothingBufferedDoesNotFireReadCompleted() throws { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) + let continuePromise = loop.makePromise(of: NegotiationResult.self) let handler = NIOTypedApplicationProtocolNegotiationHandler { result in continuePromise.futureResult @@ -150,7 +150,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { // Now satisfy the future, which forces data unbuffering. As we haven't buffered any data, // readComplete should not be fired. - continuePromise.succeed(.init(result: (.failed))) + continuePromise.succeed(.failed) XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 0) XCTAssertTrue(try channel.finish().isClean) @@ -159,7 +159,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { func testUnbufferingFiresReadCompleted() throws { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) + let continuePromise = loop.makePromise(of: NegotiationResult.self) let handler = NIOTypedApplicationProtocolNegotiationHandler { result in continuePromise.futureResult @@ -179,7 +179,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 1) // Now satisfy the future, which forces data unbuffering. This should fire readComplete. - continuePromise.succeed(.init(result: (.failed))) + continuePromise.succeed(.failed) XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write")) XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 2) @@ -190,7 +190,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { func testUnbufferingHandlesReentrantReads() throws { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) + let continuePromise = loop.makePromise(of: NegotiationResult.self) let handler = NIOTypedApplicationProtocolNegotiationHandler { result in continuePromise.futureResult @@ -211,7 +211,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { XCTAssertEqual(eventCounterHandler.channelReadCompleteCalls, 1) // Now satisfy the future, which forces data unbuffering. This should fire readComplete. - continuePromise.succeed(.init(result: .failed)) + continuePromise.succeed(.failed) XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write")) XCTAssertNoThrow(XCTAssertEqual(try channel.readInbound()!, "a write")) diff --git a/Tests/NIOTestUtilsTests/NIOHTTP1TestServerTest.swift b/Tests/NIOTestUtilsTests/NIOHTTP1TestServerTest.swift index 9d453a243d..10741197b1 100644 --- a/Tests/NIOTestUtilsTests/NIOHTTP1TestServerTest.swift +++ b/Tests/NIOTestUtilsTests/NIOHTTP1TestServerTest.swift @@ -313,6 +313,43 @@ class NIOHTTP1TestServerTest: XCTestCase { XCTAssertNotNil(channel) XCTAssertNoThrow(try channel.closeFuture.wait()) } + + func testReceiveBodyWithoutAggregation() { + let testServer = NIOHTTP1TestServer(group: self.group, aggregateBody: false) + + let responsePromise = self.group.next().makePromise(of: String.self) + var channel: Channel! + XCTAssertNoThrow(channel = try self.connect(serverPort: testServer.serverPort, + responsePromise: responsePromise).wait()) + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/uri", headers: headers) + channel.writeAndFlush(NIOAny(HTTPClientRequestPart.head(requestHead)), promise: nil) + XCTAssertNoThrow(try testServer.receiveHeadAndVerify { head in + XCTAssertEqual(head.uri, "/uri") + XCTAssertEqual(head.headers["Content-Type"], ["text/plain; charset=utf-8"]) + }) + XCTAssertNoThrow(try testServer.writeOutbound(.head(.init(version: .http1_1, status: .ok)))) + + for _ in 0..<10 { + channel.writeAndFlush(NIOAny(HTTPClientRequestPart.body(.byteBuffer(ByteBuffer(string: "ping")))), promise: nil) + XCTAssertNoThrow(try testServer.receiveBodyAndVerify { buffer in + XCTAssertEqual(String(buffer: buffer), "ping") + }) + XCTAssertNoThrow(try testServer.writeOutbound(.body(.byteBuffer(ByteBuffer(string: "pong"))))) + } + + channel.writeAndFlush(NIOAny(HTTPClientRequestPart.end(nil)), promise: nil) + XCTAssertNoThrow(try testServer.receiveEndAndVerify { trailers in + XCTAssertNil(trailers) + }) + XCTAssertNoThrow(try testServer.writeOutbound(.end(nil))) + + XCTAssertNoThrow(try testServer.stop()) + XCTAssertNotNil(channel) + XCTAssertNoThrow(try channel.closeFuture.wait()) + } } private final class TestHTTPHandler: ChannelInboundHandler { diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index fd974e253b..1e64a27544 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -146,12 +146,12 @@ private class WebSocketRecorderHandler: ChannelInboundHandler, ChannelOutboundHa } } +private func basicRequest(path: String = "/") -> String { + return "GET \(path) HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0" +} + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) class WebSocketClientEndToEndTests: XCTestCase { - - private func basicRequest(path: String = "/") -> String { - return "GET \(path) HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0" - } - func testSimpleUpgradeSucceeds() throws { var upgradeHandlerCallbackFired = false @@ -173,7 +173,7 @@ class WebSocketClientEndToEndTests: XCTestCase { // Read the server request. if let requestString = try clientChannel.readByteBufferOutputAsString() { - XCTAssertEqual(requestString, self.basicRequest() + "\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Key: \(requestKey)\r\nSec-WebSocket-Version: 13\r\n\r\n") + XCTAssertEqual(requestString, basicRequest() + "\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Key: \(requestKey)\r\nSec-WebSocket-Version: 13\r\n\r\n") } else { XCTFail() } @@ -284,7 +284,7 @@ class WebSocketClientEndToEndTests: XCTestCase { XCTAssertNoThrow(try clientChannel.close().wait()) } - private func runSuccessfulUpgrade() throws -> (EmbeddedChannel, WebSocketRecorderHandler) { + fileprivate func runSuccessfulUpgrade() throws -> (EmbeddedChannel, WebSocketRecorderHandler) { let handler = WebSocketRecorderHandler() @@ -404,3 +404,216 @@ class WebSocketClientEndToEndTests: XCTestCase { XCTAssertNoThrow(try clientChannel.close().wait()) } } + +#if !canImport(Darwin) || swift(>=5.10) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedWebSocketClientEndToEndTests: WebSocketClientEndToEndTests { + func setUpClientChannel( + clientUpgraders: [any NIOTypedHTTPClientProtocolUpgrader], + notUpgradingCompletionHandler: @Sendable @escaping (Channel) -> EventLoopFuture + ) throws -> (EmbeddedChannel, EventLoopFuture) { + + let channel = EmbeddedChannel() + + var headers = HTTPHeaders() + headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") + headers.add(name: "Content-Length", value: "\(0)") + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + + let config = NIOTypedHTTPClientUpgradeConfiguration( + upgradeRequestHead: requestHead, + upgraders: clientUpgraders, + notUpgradingCompletionHandler: notUpgradingCompletionHandler + ) + + let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: .init(upgradeConfiguration: config)) + + try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + .wait() + + return (channel, upgradeResult) + } + + override func testSimpleUpgradeSucceeds() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Read the server request. + if let requestString = try clientChannel.readByteBufferOutputAsString() { + XCTAssertEqual(requestString, basicRequest() + "\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Key: \(requestKey)\r\nSec-WebSocket-Version: 13\r\n\r\n") + } else { + XCTFail() + } + + // Push the successful server response. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" + + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) + + clientChannel.embeddedEventLoop.run() + + // Once upgraded, validate the http pipeline has been removed. + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: HTTPRequestEncoder.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: ByteToMessageHandler.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + // Check that the pipeline now has the correct websocket handlers added. + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: WebSocketFrameEncoder.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: ByteToMessageHandler.self)) + XCTAssertNoThrow(try clientChannel.pipeline + .assertContains(handlerType: WebSocketRecorderHandler.self)) + + try upgradeResult.wait() + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + } + + override func testRejectUpgradeIfMissingAcceptKey() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response but with a missing accept key. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n" + + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + + XCTAssertThrowsError(try upgradeResult.wait()) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + } + + override func testRejectUpgradeIfIncorrectAcceptKey() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + let responseKey = "notACorrectKeyL1am=F1y=nn=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response but with an incorrect response key. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" + + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + + XCTAssertThrowsError(try upgradeResult.wait()) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.upgraderDeniedUpgrade) + } + } + + override func testRejectUpgradeIfNotWebsocket() throws { + let requestKey = "OfS0wDaT5NoxF2gqm7Zj2YtetzM=" + let responseKey = "yKEqitDFPE81FyIhKTm+ojBqigk=" + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: requestKey, + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(WebSocketRecorderHandler()) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response with an incorrect protocol. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: myProtocol\r\nSec-WebSocket-Accept:\(responseKey)\r\n\r\n" + + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + + // Close the pipeline. + XCTAssertNoThrow(try clientChannel.close().wait()) + + XCTAssertThrowsError(try upgradeResult.wait()) { error in + XCTAssertEqual(error as? NIOHTTPClientUpgradeError, NIOHTTPClientUpgradeError.responseProtocolNotFound) + } + } + + override fileprivate func runSuccessfulUpgrade() throws -> (EmbeddedChannel, WebSocketRecorderHandler) { + let handler = WebSocketRecorderHandler() + + let basicUpgrader = NIOTypedWebSocketClientUpgrader( + requestKey: "OfS0wDaT5NoxF2gqm7Zj2YtetzM=", + upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in + channel.pipeline.addHandler(handler) + }) + + // The process should kick-off independently by sending the upgrade request to the server. + let (clientChannel, upgradeResult) = try setUpClientChannel( + clientUpgraders: [basicUpgrader], + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + + // Push the successful server response. + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept:yKEqitDFPE81FyIhKTm+ojBqigk=\r\n\r\n" + + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) + + clientChannel.embeddedEventLoop.run() + + // We now have a successful upgrade, clear the output channels read to test the frames. + XCTAssertNoThrow(try clientChannel.readOutbound(as: ByteBuffer.self)) + + clientChannel.embeddedEventLoop.run() + + try upgradeResult.wait() + + return (clientChannel, handler) + } +} +#endif diff --git a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift index a5ec8e6dc2..d73d1f21dc 100644 --- a/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketServerEndToEndTests.swift @@ -112,10 +112,38 @@ private class WebSocketRecorderHandler: ChannelInboundHandler { } } +struct WebSocketServerUpgraderConfiguration { + let maxFrameSize: Int + let automaticErrorHandling: Bool + let shouldUpgrade: @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + let upgradePipelineHandler: @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + + @preconcurrency + init( + maxFrameSize: Int = 1 << 14, + automaticErrorHandling: Bool = true, + shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture, + upgradePipelineHandler: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture + ) { + self.maxFrameSize = maxFrameSize + self.automaticErrorHandling = automaticErrorHandling + self.shouldUpgrade = shouldUpgrade + self.upgradePipelineHandler = upgradePipelineHandler + } +} + class WebSocketServerEndToEndTests: XCTestCase { - private func createTestFixtures(upgraders: [NIOWebSocketServerUpgrader]) -> (loop: EmbeddedEventLoop, serverChannel: EmbeddedChannel, clientChannel: EmbeddedChannel) { + func createTestFixtures( + upgraders: [WebSocketServerUpgraderConfiguration] + ) -> (loop: EmbeddedEventLoop, serverChannel: EmbeddedChannel, clientChannel: EmbeddedChannel) { let loop = EmbeddedEventLoop() let serverChannel = EmbeddedChannel(loop: loop) + let upgraders = upgraders.map { NIOWebSocketServerUpgrader( + maxFrameSize: $0.maxFrameSize, + automaticErrorHandling: $0.automaticErrorHandling, + shouldUpgrade: $0.shouldUpgrade, + upgradePipelineHandler: $0.upgradePipelineHandler + )} XCTAssertNoThrow(try serverChannel.pipeline.configureHTTPServerPipeline( withServerUpgrade: (upgraders: upgraders as [HTTPServerProtocolUpgrader], completionHandler: { (context: ChannelHandlerContext) in } ) ).wait()) @@ -129,7 +157,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testBasicUpgradeDance() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -151,7 +179,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testUpgradeWithProtocolName() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -170,7 +198,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testCanRejectUpgrade() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(nil) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(nil) }, upgradePipelineHandler: { (channel, req) in XCTFail("Should not have called") return channel.eventLoop.makeSucceededFuture(()) @@ -201,7 +229,7 @@ class WebSocketServerEndToEndTests: XCTestCase { var acceptPromise: EventLoopPromise? = nil var upgradeComplete = false - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in acceptPromise = channel.eventLoop.makePromise() return acceptPromise!.futureResult }, @@ -240,7 +268,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testRequiresVersion13() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -262,7 +290,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testRequiresVersionHeader() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -284,7 +312,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testRequiresKeyHeader() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.eventLoop.makeSucceededFuture(()) }) let (loop, server, client) = self.createTestFixtures(upgraders: [basicUpgrader]) defer { @@ -306,7 +334,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testUpgradeMayAddCustomHeaders() throws { - let upgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in + let upgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in var hdrs = HTTPHeaders() hdrs.add(name: "TestHeader", value: "TestValue") return channel.eventLoop.makeSucceededFuture(hdrs) @@ -329,8 +357,8 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testMayRegisterMultipleWebSocketEndpoints() throws { - func buildHandler(path: String) -> NIOWebSocketServerUpgrader { - return NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in + func buildHandler(path: String) -> WebSocketServerUpgraderConfiguration { + return WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in guard head.uri == "/\(path)" else { return channel.eventLoop.makeSucceededFuture(nil) } var hdrs = HTTPHeaders() hdrs.add(name: "Target", value: path) @@ -360,7 +388,7 @@ class WebSocketServerEndToEndTests: XCTestCase { func testSendAFewFrames() throws { let recorder = WebSocketRecorderHandler() - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.pipeline.addHandler(recorder) @@ -398,7 +426,7 @@ class WebSocketServerEndToEndTests: XCTestCase { } func testMaxFrameSize() throws { - let basicUpgrader = NIOWebSocketServerUpgrader(maxFrameSize: 16, shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(maxFrameSize: 16, shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in return channel.eventLoop.makeSucceededFuture(()) }) @@ -423,7 +451,7 @@ class WebSocketServerEndToEndTests: XCTestCase { func testAutomaticErrorHandling() throws { let recorder = WebSocketRecorderHandler() - let basicUpgrader = NIOWebSocketServerUpgrader(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, + let basicUpgrader = WebSocketServerUpgraderConfiguration(shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.pipeline.addHandler(recorder) @@ -461,7 +489,7 @@ class WebSocketServerEndToEndTests: XCTestCase { func testNoAutomaticErrorHandling() throws { let recorder = WebSocketRecorderHandler() - let basicUpgrader = NIOWebSocketServerUpgrader(automaticErrorHandling: false, + let basicUpgrader = WebSocketServerUpgraderConfiguration(automaticErrorHandling: false, shouldUpgrade: { (channel, head) in channel.eventLoop.makeSucceededFuture(HTTPHeaders()) }, upgradePipelineHandler: { (channel, req) in channel.pipeline.addHandler(recorder) @@ -498,3 +526,32 @@ class WebSocketServerEndToEndTests: XCTestCase { XCTAssertNoThrow(XCTAssertEqual([], try server.readAllOutboundBytes())) } } + +#if !canImport(Darwin) || swift(>=5.10) +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final class TypedWebSocketServerEndToEndTests: WebSocketServerEndToEndTests { + override func createTestFixtures( + upgraders: [WebSocketServerUpgraderConfiguration] + ) -> (loop: EmbeddedEventLoop, serverChannel: EmbeddedChannel, clientChannel: EmbeddedChannel) { + let loop = EmbeddedEventLoop() + let serverChannel = EmbeddedChannel(loop: loop) + let upgraders = upgraders.map { NIOTypedWebSocketServerUpgrader( + maxFrameSize: $0.maxFrameSize, + enableAutomaticErrorHandling: $0.automaticErrorHandling, + shouldUpgrade: $0.shouldUpgrade, + upgradePipelineHandler: $0.upgradePipelineHandler + )} + + XCTAssertNoThrow(try serverChannel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: .init( + upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration( + upgraders: upgraders, + notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() } + ) + ) + )) + let clientChannel = EmbeddedChannel(loop: loop) + return (loop: loop, serverChannel: serverChannel, clientChannel: clientChannel) + } +} +#endif diff --git a/dev/update-benchmark-thresholds.sh b/dev/update-benchmark-thresholds.sh new file mode 100755 index 0000000000..9a37298f1a --- /dev/null +++ b/dev/update-benchmark-thresholds.sh @@ -0,0 +1,41 @@ +#!/bin/bash +##===----------------------------------------------------------------------===## +## +## This source file is part of the SwiftNIO open source project +## +## Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +## Licensed under Apache License v2.0 +## +## See LICENSE.txt for license information +## See CONTRIBUTORS.txt for the list of SwiftNIO project authors +## +## SPDX-License-Identifier: Apache-2.0 +## +##===----------------------------------------------------------------------===## +##===----------------------------------------------------------------------===## +## +## This source file is part of the SwiftCertificates open source project +## +## Copyright (c) 2023 Apple Inc. and the SwiftCertificates project authors +## Licensed under Apache License v2.0 +## +## See LICENSE.txt for license information +## See CONTRIBUTORS.txt for the list of SwiftCertificates project authors +## +## SPDX-License-Identifier: Apache-2.0 +## +##===----------------------------------------------------------------------===## + +set -eu +set -o pipefail + +here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +target_repo=${2-"$here/.."} + +for f in 57 58 59 510 -nightly; do + echo "swift$f" + + docker_file=$(if [[ "$f" == "-nightly" ]]; then f=main; fi && ls "$target_repo/docker/docker-compose."*"$f"*".yaml") + + docker-compose -f docker/docker-compose.yaml -f $docker_file run update-benchmark-baseline +done diff --git a/docker/Dockerfile b/docker/Dockerfile index 86bb2bfbd6..3de64437b4 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,5 +1,5 @@ ARG swift_version=5.7 -ARG ubuntu_version=bionic +ARG ubuntu_version=jammy ARG base_image=swift:$swift_version-$ubuntu_version FROM $base_image # needed to do again after FROM due to docker limitation @@ -19,6 +19,9 @@ RUN apt-get update && apt-get install -y lsof dnsutils netcat-openbsd net-tools # ruby RUN apt-get update && apt-get install -y ruby ruby-dev libsqlite3-dev build-essential +# install jemalloc for running allocation benchmarks +RUN apt-get update & apt-get install -y libjemalloc-dev + # tools RUN mkdir -p $HOME/.tools RUN echo 'export PATH="$HOME/.tools:$PATH"' >> $HOME/.profile diff --git a/docker/docker-compose.2004.56.yaml b/docker/docker-compose.2204.510.yaml similarity index 66% rename from docker/docker-compose.2004.56.yaml rename to docker/docker-compose.2204.510.yaml index f26cac2fec..7016963d72 100644 --- a/docker/docker-compose.2004.56.yaml +++ b/docker/docker-compose.2204.510.yaml @@ -3,28 +3,28 @@ version: "3" services: runtime-setup: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 build: args: - ubuntu_version: "focal" - swift_version: "5.6" + base_image: "swiftlang/swift:nightly-5.10-jammy" unit-tests: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 integration-tests: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 documentation-check: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 test: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 environment: - - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=22 + - SWIFT_VERSION=5.10 + - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=21 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 - - MAX_ALLOCS_ALLOWED_1000_addHandlers_sync=39050 + - MAX_ALLOCS_ALLOWED_1000_addHandlers_sync=38050 - MAX_ALLOCS_ALLOWED_1000_addRemoveHandlers_handlercontext=8050 - MAX_ALLOCS_ALLOWED_1000_addRemoveHandlers_handlername=8050 - MAX_ALLOCS_ALLOWED_1000_addRemoveHandlers_handlertype=8050 @@ -35,13 +35,13 @@ services: - MAX_ALLOCS_ALLOWED_1000_getHandlers=8050 - MAX_ALLOCS_ALLOWED_1000_getHandlers_sync=36 - MAX_ALLOCS_ALLOWED_1000_reqs_1_conn=26400 - - MAX_ALLOCS_ALLOWED_1000_rst_connections=151050 - - MAX_ALLOCS_ALLOWED_1000_tcpbootstraps=4050 - - MAX_ALLOCS_ALLOWED_1000_tcpconnections=156050 + - MAX_ALLOCS_ALLOWED_1000_rst_connections=149050 + - MAX_ALLOCS_ALLOWED_1000_tcpbootstraps=3050 + - MAX_ALLOCS_ALLOWED_1000_tcpconnections=157050 - MAX_ALLOCS_ALLOWED_1000_udp_reqs=6050 - MAX_ALLOCS_ALLOWED_1000_udpbootstraps=2050 - MAX_ALLOCS_ALLOWED_1000_udpconnections=77050 - - MAX_ALLOCS_ALLOWED_1_reqs_1000_conn=404000 + - MAX_ALLOCS_ALLOWED_1_reqs_1000_conn=396000 - MAX_ALLOCS_ALLOWED_bytebuffer_lots_of_rw=2050 - MAX_ALLOCS_ALLOWED_creating_10000_headers=0 - MAX_ALLOCS_ALLOWED_decode_1000_ws_frames=2050 @@ -61,26 +61,35 @@ services: - MAX_ALLOCS_ALLOWED_get_100000_headers_canonical_form_trimming_whitespace_from_long_string=700050 - MAX_ALLOCS_ALLOWED_get_100000_headers_canonical_form_trimming_whitespace_from_short_string=700050 - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=2050 - - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=339 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 - - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=60100 - - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=60050 - - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=86 + - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 + - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=343 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 + - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 + - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 + - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 - MAX_ALLOCS_ALLOWED_udp_1000_reqs_1_conn=6200 - - MAX_ALLOCS_ALLOWED_udp_1_reqs_1000_conn=161050 + - MAX_ALLOCS_ALLOWED_udp_1_reqs_1000_conn=167050 - FORCE_TEST_DISCOVERY=--enable-test-discovery - WARN_AS_ERROR_ARG=-Xswiftc -warnings-as-errors + - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error # - SANITIZER_ARG=--sanitize=thread # TSan broken still performance-test: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 + + update-benchmark-baseline: + image: swift-nio:22.04-5.10 + environment: + - SWIFT_VERSION=5.10 shell: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 echo: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 http: - image: swift-nio:20.04-5.6 + image: swift-nio:22.04-5.10 + + cxx-interop-build: + image: swift-nio:22.04-5.10 diff --git a/docker/docker-compose.2204.57.yaml b/docker/docker-compose.2204.57.yaml index 31df38fc98..415223ba5b 100644 --- a/docker/docker-compose.2204.57.yaml +++ b/docker/docker-compose.2204.57.yaml @@ -21,6 +21,7 @@ services: test: image: swift-nio:22.04-5.7 environment: + - SWIFT_VERSION=5.7 - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=22 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 @@ -63,7 +64,7 @@ services: - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=341 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 @@ -76,6 +77,11 @@ services: performance-test: image: swift-nio:22.04-5.7 + update-benchmark-baseline: + image: swift-nio:22.04-5.7 + environment: + - SWIFT_VERSION=5.7 + shell: image: swift-nio:22.04-5.7 diff --git a/docker/docker-compose.2204.58.yaml b/docker/docker-compose.2204.58.yaml index f4910ba24b..af365fa86e 100644 --- a/docker/docker-compose.2204.58.yaml +++ b/docker/docker-compose.2204.58.yaml @@ -21,6 +21,7 @@ services: test: image: swift-nio:22.04-5.8 environment: + - SWIFT_VERSION=5.8 - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=22 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 @@ -63,7 +64,7 @@ services: - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=341 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 @@ -77,6 +78,11 @@ services: performance-test: image: swift-nio:22.04-5.8 + update-benchmark-baseline: + image: swift-nio:22.04-5.8 + environment: + - SWIFT_VERSION=5.8 + shell: image: swift-nio:22.04-5.8 diff --git a/docker/docker-compose.2204.59.yaml b/docker/docker-compose.2204.59.yaml index baf70a9674..889c4cb8b8 100644 --- a/docker/docker-compose.2204.59.yaml +++ b/docker/docker-compose.2204.59.yaml @@ -6,7 +6,8 @@ services: image: swift-nio:22.04-5.9 build: args: - base_image: "swiftlang/swift:nightly-5.9-jammy" + ubuntu_version: "jammy" + swift_version: "5.9" unit-tests: image: swift-nio:22.04-5.9 @@ -20,6 +21,7 @@ services: test: image: swift-nio:22.04-5.9 environment: + - SWIFT_VERSION=5.9 - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=21 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 @@ -62,7 +64,7 @@ services: - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=350 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 @@ -76,6 +78,11 @@ services: performance-test: image: swift-nio:22.04-5.9 + update-benchmark-baseline: + image: swift-nio:22.04-5.9 + environment: + - SWIFT_VERSION=5.9 + shell: image: swift-nio:22.04-5.9 @@ -84,3 +91,6 @@ services: http: image: swift-nio:22.04-5.9 + + cxx-interop-build: + image: swift-nio:22.04-5.9 diff --git a/docker/docker-compose.2204.main.yaml b/docker/docker-compose.2204.main.yaml index e4da57992a..6198161e61 100644 --- a/docker/docker-compose.2204.main.yaml +++ b/docker/docker-compose.2204.main.yaml @@ -20,6 +20,7 @@ services: test: image: swift-nio:22.04-main environment: + - SWIFT_VERSION=main - MAX_ALLOCS_ALLOWED_10000000_asyncsequenceproducer=21 - MAX_ALLOCS_ALLOWED_1000000_asyncwriter=1000050 - MAX_ALLOCS_ALLOWED_1000_addHandlers=45050 @@ -62,7 +63,7 @@ services: - MAX_ALLOCS_ALLOWED_modifying_1000_circular_buffer_elements=0 - MAX_ALLOCS_ALLOWED_modifying_byte_buffer_view=6050 - MAX_ALLOCS_ALLOWED_ping_pong_1000_reqs_1_conn=343 - - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=140050 + - MAX_ALLOCS_ALLOWED_read_10000_chunks_from_file=130050 - MAX_ALLOCS_ALLOWED_schedule_10000_tasks=50100 - MAX_ALLOCS_ALLOWED_schedule_and_run_10000_tasks=50050 - MAX_ALLOCS_ALLOWED_scheduling_10000_executions=85 @@ -76,6 +77,11 @@ services: performance-test: image: swift-nio:22.04-main + update-benchmark-baseline: + image: swift-nio:22.04-main + environment: + - SWIFT_VERSION=main + shell: image: swift-nio:22.04-main diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 8bf9b4adb6..0c3c846e81 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -16,8 +16,8 @@ services: depends_on: [runtime-setup] volumes: - ~/.ssh:/root/.ssh - - ..:/code:z - working_dir: /code + - ..:/swift-nio:z + working_dir: /swift-nio cap_drop: - CAP_NET_RAW - CAP_NET_BIND_SERVICE @@ -40,12 +40,20 @@ services: test: <<: *common - command: /bin/bash -xcl "uname -a && swift -version && swift $${SWIFT_TEST_VERB-test} $${FORCE_TEST_DISCOVERY-} $${WARN_AS_ERROR_ARG-} $${SANITIZER_ARG-} $${IMPORT_CHECK_ARG-} && ./scripts/integration_tests.sh $${INTEGRATION_TESTS_ARG-}" + command: /bin/bash -xcl "uname -a && swift -version && swift $${SWIFT_TEST_VERB-test} $${FORCE_TEST_DISCOVERY-} $${WARN_AS_ERROR_ARG-} $${SANITIZER_ARG-} $${IMPORT_CHECK_ARG-} && ./scripts/integration_tests.sh $${INTEGRATION_TESTS_ARG-} && cd Benchmarks && swift package benchmark baseline check --check-absolute-path Thresholds/$${SWIFT_VERSION-}/" performance-test: <<: *common command: /bin/bash -xcl "swift build -c release -Xswiftc -Xllvm -Xswiftc -align-all-functions=5 -Xswiftc -Xllvm -Xswiftc -align-all-blocks=5 && ./.build/release/NIOPerformanceTester" + update-benchmark-baseline: + <<: *common + command: /bin/bash -xcl "cd Benchmarks && swift package --disable-sandbox --scratch-path .build/$${SWIFT_VERSION-}/ --allow-writing-to-package-directory benchmark --format metricP90AbsoluteThresholds --path Thresholds/$${SWIFT_VERSION-}/" + + cxx-interop-build: + <<: *common + command: /bin/bash -xcl "./scripts/cxx-interop-compatibility.sh" + # util shell: diff --git a/docs/public-async-nio-apis.md b/docs/public-async-nio-apis.md new file mode 100644 index 0000000000..2a3ff118c1 --- /dev/null +++ b/docs/public-async-nio-apis.md @@ -0,0 +1,1121 @@ +# Async NIO bridges + +This is a summary of all the new APIs we introduced to make NIO and the various +network protocols work with new async interfaces. The intention of this document +is to quickly outline the incompatibilities in the original API and then show a +holistic view over all the new APIs. + +## What APIs are needed? + +The first problem that we had to tackle was bridging NIO's `Channel` to Swift +Concurrency. To do so we introduced two new foundational types - the +`NIOAsyncSequenceProducer` and the `NIOAsyncWriter`. Those allow us to bridge +the read and the write side of the `Channel` while propagating the back-pressure +across the bridge. + +On top of those two types, we built out the `NIOAsyncChannel` which allows users +to bridge a `Channel` into Swift Concurrency. To do this it inserts two channel +handlers which bridge the read and write side using the +`NIOAsyncSequenceProducer` and the `NIOAsyncWriter`. + +Next up we had to look at the bootstraps. Here the import part is that the +`Channel`s **must** be wrapped at the correct timing otherwise there is the +potential that reads might be dropped. This is not problematic for most of the +bootstraps since they call their various `channelInitializer`s and +`childChannelInitializer`s at the right time. However, there was one tricky +bootstrap - `ServerBootstrap`. The `ServerBootstrap` multiplexes the incoming +connections and we have to make sure that the wrapping of the child channels +happens at the correct time. Additionally, the new bootstrap APIs **must** be +able to relay the type information of the configured channels to the +`bind`/`connect` methods. + +The next thing we had to tackle was networking protocols that dynamically +re-configure the `ChannelPipeline`. The two examples that we provide +implementations for are HTTP/1 protocol upgrades and Application Protocol +Negotiation (ALPN) via TLS. Similar to the bootstraps we have to ensure that the +type information is retained so that users can correctly identify which +reconfiguration path has been taken. + +Lastly, we had to look at how to handle protocols that multiplex, like HTTP/2. +Multiplexing protocols need to expose a typed async interface to consume new +inbound connections/streams and to open new outbound connections/streams where +applicable. + +## Proposed APIs + +This section contains all the new APIs that we are adding and gives us a holistic +overview to review them. + +### `NIOAsyncChannel` + +```swift +/// Wraps a NIO ``Channel`` object into a form suitable for use in Swift Concurrency. +/// +/// ``NIOAsyncChannel`` abstracts the notion of a NIO ``Channel`` into something that +/// can safely be used in a structured concurrency context. In particular, this exposes +/// the following functionality: +/// +/// - reads are presented as an `AsyncSequence` +/// - writes can be written to with async functions on a writer, providing back pressure +/// - channels can be closed seamlessly +/// +/// This type does not replace the full complexity of NIO's ``Channel``. In particular, it +/// does not expose the following functionality: +/// +/// - user events +/// - traditional NIO back pressure such as writability signals and the ``Channel/read()`` call +/// +/// Users are encouraged to separate their ``ChannelHandler``s into those that implement +/// protocol-specific logic (such as parsers and encoders) and those that implement business +/// logic. Protocol-specific logic should be implemented as a ``ChannelHandler``, while business +/// logic should use ``NIOAsyncChannel`` to consume and produce data to the network. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct NIOAsyncChannel : Sendable where Inbound : Sendable, Outbound : Sendable { + public struct Configuration : Sendable { + /// The back pressure strategy of the ``NIOAsyncChannel/inbound``. + public var backPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark + + /// If outbound half closure should be enabled. Outbound half closure is triggered once + /// the ``NIOAsyncChannelOutboundWriter`` is either finished or deinitialized. + public var isOutboundHalfClosureEnabled: Bool + + /// The ``NIOAsyncChannel/inbound`` message's type. + public var inboundType: Inbound.Type + + /// The ``NIOAsyncChannel/outbound`` message's type. + public var outboundType: Outbound.Type + + /// Initializes a new ``NIOAsyncChannel/Configuration``. + /// + /// - Parameters: + /// - backPressureStrategy: The back pressure strategy of the ``NIOAsyncChannel/inbound``. Defaults + /// to a watermarked strategy (lowWatermark: 2, highWatermark: 10). + /// - isOutboundHalfClosureEnabled: If outbound half closure should be enabled. Outbound half closure is triggered once + /// the ``NIOAsyncChannelOutboundWriter`` is either finished or deinitialized. Defaults to `false`. + /// - inboundType: The ``NIOAsyncChannel/inbound`` message's type. + /// - outboundType: The ``NIOAsyncChannel/outbound`` message's type. + public init(backPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark = .init(lowWatermark: 2, highWatermark: 10), isOutboundHalfClosureEnabled: Bool = false, inboundType: Inbound.Type = Inbound.self, outboundType: Outbound.Type = Outbound.self) + } + + /// The underlying channel being wrapped by this ``NIOAsyncChannel``. + public let channel: NIOCore.Channel + + /// The stream of inbound messages. + /// + /// - Important: The `inbound` stream is a unicast `AsyncSequence` and only one iterator can be created. + public let inbound: NIOCore.NIOAsyncChannelInboundStream + + /// The writer for writing outbound messages. + public let outbound: NIOCore.NIOAsyncChannelOutboundWriter + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel``. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - configuration: The ``NIOAsyncChannel``s configuration. + @inlinable public init(synchronouslyWrapping channel: NIOCore.Channel, configuration: NIOCore.NIOAsyncChannel.Configuration = .init()) throws + + /// Initializes a new ``NIOAsyncChannel`` wrapping a ``Channel`` where the outbound type is `Never`. + /// + /// This initializer will finish the ``NIOAsyncChannel/outboundWriter`` immediately. + /// + /// - Important: This **must** be called on the channel's event loop otherwise this init will crash. This is necessary because + /// we must install the handlers before any other event in the pipeline happens otherwise we might drop reads. + /// + /// - Parameters: + /// - channel: The ``Channel`` to wrap. + /// - configuration: The ``NIOAsyncChannel``s configuration. + @inlinable public init(synchronouslyWrapping channel: NIOCore.Channel, configuration: NIOCore.NIOAsyncChannel.Configuration = .init()) throws where Outbound == Never + + /// This method is only used from our server bootstrap to allow us to run the child channel initializer + /// at the right moment. + /// + /// - Important: This is not considered stable API and should not be used. + @inlinable public static func _wrapAsyncChannelWithTransformations(synchronouslyWrapping channel: NIOCore.Channel, backPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, isOutboundHalfClosureEnabled: Bool = false, channelReadTransformation: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) throws -> NIOCore.NIOAsyncChannel where Outbound == Never +} + +/// The inbound message asynchronous sequence of a ``NIOAsyncChannel``. +/// +/// This is a unicast async sequence that allows a single iterator to be created. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct NIOAsyncChannelInboundStream : Sendable where Inbound : Sendable { + /// A source used for driving a ``NIOAsyncChannelInboundStream`` during tests. + public struct TestSource { + /// Yields the element to the inbound stream. + /// + /// - Parameter element: The element to yield to the inbound stream. + @inlinable public func yield(_ element: Inbound) + + /// Finished the inbound stream. + /// + /// - Parameter error: The error to throw, or nil, to finish normally. + @inlinable public func finish(throwing error: Error? = nil) + } + + /// Creates a new stream with a source for testing. + /// + /// This is useful for writing unit tests where you want to drive a ``NIOAsyncChannelInboundStream``. + /// + /// - Returns: A tuple containing the input stream and a test source to drive it. + @inlinable public static func makeTestingStream() -> (NIOCore.NIOAsyncChannelInboundStream, NIOCore.NIOAsyncChannelInboundStream.TestSource) +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOAsyncChannelInboundStream : AsyncSequence { + /// The type of element produced by this asynchronous sequence. + public typealias Element = Inbound + + /// The type of asynchronous iterator that produces elements of this + /// asynchronous sequence. + public struct AsyncIterator : AsyncIteratorProtocol { + /// Asynchronously advances to the next element and returns it, or ends the + /// sequence if there is no next element. + /// + /// - Returns: The next element, if it exists, or `nil` to signal the end of + /// the sequence. + @inlinable public mutating func next() async throws -> NIOCore.NIOAsyncChannelInboundStream.Element? + } + + /// Creates the asynchronous iterator that produces elements of this + /// asynchronous sequence. + /// + /// - Returns: An instance of the `AsyncIterator` type used to produce + /// elements of the asynchronous sequence. + @inlinable public func makeAsyncIterator() -> NIOCore.NIOAsyncChannelInboundStream.AsyncIterator +} + +/// A ``NIOAsyncChannelOutboundWriter`` is used to write and flush new outbound messages in a channel. +/// +/// The writer acts as a bridge between the Concurrency and NIO world. It allows to write and flush messages into the +/// underlying ``Channel``. Furthermore, it respects back-pressure of the channel by suspending the calls to write until +/// the channel becomes writable again. +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +public struct NIOAsyncChannelOutboundWriter : Sendable where OutboundOut : Sendable { + /// An `AsyncSequence` backing a ``NIOAsyncChannelOutboundWriter`` for testing purposes. + public struct TestSink : AsyncSequence { + /// The type of element produced by this asynchronous sequence. + public typealias Element = OutboundOut + + /// Creates the asynchronous iterator that produces elements of this + /// asynchronous sequence. + /// + /// - Returns: An instance of the `AsyncIterator` type used to produce + /// elements of the asynchronous sequence. + public func makeAsyncIterator() -> NIOCore.NIOAsyncChannelOutboundWriter.TestSink.AsyncIterator + + /// The type of asynchronous iterator that produces elements of this + /// asynchronous sequence. + public struct AsyncIterator : AsyncIteratorProtocol { + /// Asynchronously advances to the next element and returns it, or ends the + /// sequence if there is no next element. + /// + /// - Returns: The next element, if it exists, or `nil` to signal the end of + /// the sequence. + public mutating func next() async -> NIOCore.NIOAsyncChannelOutboundWriter.TestSink.Element? + } + } + + /// Creates a new ``NIOAsyncChannelOutboundWriter`` backed by a ``NIOAsyncChannelOutboundWriter/TestSink``. + /// This is mostly useful for testing purposes where one wants to observe the written data. + @inlinable public static func makeTestingWriter() -> (NIOCore.NIOAsyncChannelOutboundWriter, NIOCore.NIOAsyncChannelOutboundWriter.TestSink) + + /// Send a write into the ``ChannelPipeline`` and flush it right away. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable public func write(_ data: OutboundOut) async throws + + /// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable public func write(contentsOf sequence: Writes) async throws where OutboundOut == Writes.Element, Writes : Sequence + + /// Send an asynchronous sequence of writes into the ``ChannelPipeline``. + /// + /// This will flush after every write. + /// + /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. + @inlinable public func write(contentsOf sequence: Writes) async throws where OutboundOut == Writes.Element, Writes : AsyncSequence + + /// Finishes the writer. + /// + /// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it. + public func finish() +} + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOAsyncChannelOutboundWriter.TestSink : Sendable {} +``` + +### Bootstraps + +```swift +extension ClientBootstrap { + /// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established. + /// + /// - Parameters: + /// - host: The host to connect to. + /// - port: The port to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(host: String, port: Int, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Specify the `address` to connect to for the TCP `Channel` that will be established. + /// + /// - Parameters: + /// - address: The address to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(to address: NIOCore.SocketAddress, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established. + /// + /// - Parameters: + /// - unixDomainSocketPath: The _Unix domain socket_ path to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(unixDomainSocketPath: String, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Use the existing connected socket file descriptor. + /// + /// - Parameters: + /// - descriptor: The _Unix file descriptor_ representing the connected stream socket. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withConnectedSocket(_ socket: NIOCore.NIOBSDSocket.Handle, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable +} + +extension ServerBootstrap { + /// Bind the `ServerSocketChannel` to the `host` and `port` parameters. + /// + /// - Parameters: + /// - host: The host to bind on. + /// - port: The port to bind on. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(host: String, port: Int, serverBackPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> NIOCore.NIOAsyncChannel where Output : Sendable + + /// Bind the `ServerSocketChannel` to the `address` parameter. + /// + /// - Parameters: + /// - address: The `SocketAddress` to bind on. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(to address: NIOCore.SocketAddress, serverBackPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> NIOCore.NIOAsyncChannel where Output : Sendable + + /// Bind the `ServerSocketChannel` to a UNIX Domain Socket. + /// + /// - Parameters: + /// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist, + /// unless `cleanupExistingSocketFile`is set to `true`. + /// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool = false, serverBackPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> NIOCore.NIOAsyncChannel where Output : Sendable + + /// Use the existing bound socket file descriptor. + /// + /// - Parameters: + /// - socket: The _Unix file descriptor_ representing the bound stream socket. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(_ socket: NIOCore.NIOBSDSocket.Handle, cleanupExistingSocketFile: Bool = false, serverBackPressureStrategy: NIOCore.NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> NIOCore.NIOAsyncChannel where Output : Sendable +} + +extension DatagramBootstrap { + /// Use the existing bound socket file descriptor. + /// + /// - Parameters: + /// - socket: The _Unix file descriptor_ representing the bound stream socket. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withBoundSocket(_ socket: NIOCore.NIOBSDSocket.Handle, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Bind the `DatagramChannel` to `host` and `port`. + /// + /// - Parameters: + /// - host: The host to bind on. + /// - port: The port to bind on. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(host: String, port: Int, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Bind the `DatagramChannel` to the `address`. + /// + /// - Parameters: + /// - address: The `SocketAddress` to bind on. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(to address: NIOCore.SocketAddress, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Bind the `DatagramChannel` to the `unixDomainSocketPath`. + /// + /// - Parameters: + /// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. The`unixDomainSocketPath` must not exist, + /// unless `cleanupExistingSocketFile`is set to `true`. + /// - cleanupExistingSocketFile: Whether to cleanup an existing socket file at `unixDomainSocketPath`. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(unixDomainSocketPath: String, cleanupExistingSocketFile: Bool = false, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Connect the `DatagramChannel` to `host` and `port`. + /// + /// - Parameters: + /// - host: The host to connect to. + /// - port: The port to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(host: String, port: Int, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Connect the `DatagramChannel` to the `address`. + /// + /// - Parameters: + /// - address: The `SocketAddress` to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(to address: NIOCore.SocketAddress, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Connect the `DatagramChannel` to the `unixDomainSocketPath`. + /// + /// - Parameters: + /// - unixDomainSocketPath: The path of the UNIX Domain Socket to connect to. `path` must not exist, it will be created by the system. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(unixDomainSocketPath: String, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable +} + +extension NIOPipeBootstrap { + + /// Create the `PipeChannel` with the provided file descriptor which is used for both input & output. + /// + /// This method is useful for specialilsed use-cases where you want to use `NIOPipeBootstrap` for say a serial line. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `inputOutput` when the `Channel` + /// becomes inactive. You _must not_ do any further operations with `inputOutput`, including `close`. + /// If this method returns a failed future, you still own the file descriptor and are responsible for + /// closing it. + /// + /// - Parameters: + /// - inputOutput: The _Unix file descriptor_ for the input & output. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func takingOwnershipOfDescriptor(inputOutput: CInt, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Create the `PipeChannel` with the provided input and output file descriptors. + /// + /// The input and output file descriptors must be distinct. If you have a single file descriptor, consider using + /// `ClientBootstrap.withConnectedSocket(descriptor:)` if it's a socket or + /// `NIOPipeBootstrap.takingOwnershipOfDescriptor` if it is not a socket. + /// + /// - Note: If this method returns a succeeded future, SwiftNIO will close `input` and `output` + /// when the `Channel` becomes inactive. You _must not_ do any further operations `input` or + /// `output`, including `close`. + /// If this method returns a failed future, you still own the file descriptors and are responsible for + /// closing them. + /// + /// - Parameters: + /// - input: The _Unix file descriptor_ for the input (ie. the read side). + /// - output: The _Unix file descriptor_ for the output (ie. the write side). + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func takingOwnershipOfDescriptors(input: CInt, output: CInt, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable +} + +extension NIORawSocketBootstrap { + + /// Bind the `Channel` to `host`. + /// All packets or errors matching the `ipProtocol` specified are passed to the resulting `Channel`. + /// + /// - Parameters: + /// - host: The host to bind on. + /// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind(host: String, ipProtocol: NIOCore.NIOIPProtocol, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable + + /// Connect the `Channel` to `host`. + /// + /// - Parameters: + /// - host: The host to connect to. + /// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect(host: String, ipProtocol: NIOCore.NIOIPProtocol, channelInitializer: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) async throws -> Output where Output : Sendable +} +``` + +### HTTPUpgrade + +```swift +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOUpgradableHTTPClientPipelineConfiguration where UpgradeResult : Sendable { + + /// The strategy to use when dealing with leftover bytes after removing the ``HTTPDecoder`` from the pipeline. + public var leftOverBytesStrategy: NIOHTTP1.RemoveAfterUpgradeStrategy + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableOutboundHeaderValidation: Bool + + /// The configuration for the ``HTTPRequestEncoder``. + public var encoderConfiguration: NIOHTTP1.HTTPRequestEncoder.Configuration + + /// The configuration for the ``NIOTypedHTTPClientUpgradeHandler``. + public var upgradeConfiguration: NIOHTTP1.NIOTypedHTTPClientUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPClientPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Outbound header fields validation to protect against response splitting attacks. + public init(upgradeConfiguration: NIOHTTP1.NIOTypedHTTPClientUpgradeConfiguration) +} + +/// Configuration for an upgradable HTTP pipeline. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOUpgradableHTTPServerPipelineConfiguration where UpgradeResult : Sendable { + + /// Whether to provide assistance handling HTTP clients that pipeline + /// their requests. Defaults to `true`. If `false`, users will need to handle clients that pipeline themselves. + public var enablePipelining: Bool + + /// Whether to provide assistance handling protocol errors (e.g. failure to parse the HTTP + /// request) by sending 400 errors. Defaults to `true`. + public var enableErrorHandling: Bool + + /// Whether to validate outbound response headers to confirm that they are + /// spec compliant. Defaults to `true`. + public var enableResponseHeaderValidation: Bool + + /// The configuration for the ``HTTPResponseEncoder``. + public var encoderConfiguration: NIOHTTP1.HTTPResponseEncoder.Configuration + + /// The configuration for the ``NIOTypedHTTPServerUpgradeHandler``. + public var upgradeConfiguration: NIOHTTP1.NIOTypedHTTPServerUpgradeConfiguration + + /// Initializes a new ``NIOUpgradableHTTPServerPipelineConfiguration`` with default values. + /// + /// The current defaults provide the following features: + /// 1. Assistance handling clients that pipeline HTTP requests. + /// 2. Assistance handling protocol errors. + /// 3. Outbound header fields validation to protect against response splitting attacks. + public init(upgradeConfiguration: NIOHTTP1.NIOTypedHTTPServerUpgradeConfiguration) +} + +extension ChannelPipeline { + + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPServerPipeline(configuration: NIOHTTP1.NIOUpgradableHTTPServerPipelineConfiguration) -> NIOCore.EventLoopFuture> where UpgradeResult : Sendable +} + +extension ChannelPipeline.SynchronousOperations { + + /// Configure a `ChannelPipeline` for use as an HTTP server. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPServerPipeline(configuration: NIOHTTP1.NIOUpgradableHTTPServerPipelineConfiguration) throws -> NIOCore.EventLoopFuture where UpgradeResult : Sendable +} + +extension ChannelPipeline { + + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that will fire when the pipeline is configured. The future contains an `EventLoopFuture` + /// that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPClientPipeline(configuration: NIOHTTP1.NIOUpgradableHTTPClientPipelineConfiguration) -> NIOCore.EventLoopFuture> where UpgradeResult : Sendable +} + +extension ChannelPipeline.SynchronousOperations { + + /// Configure a `ChannelPipeline` for use as an HTTP client. + /// + /// - Parameters: + /// - configuration: The HTTP pipeline's configuration. + /// - Returns: An `EventLoopFuture` that is fired once the pipeline has been upgraded or not and contains the `UpgradeResult`. + @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) + public func configureUpgradableHTTPClientPipeline(configuration: NIOHTTP1.NIOUpgradableHTTPClientPipelineConfiguration) throws -> NIOCore.EventLoopFuture where UpgradeResult : Sendable +} + +/// An object that implements `NIOTypedHTTPClientProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a client-side channel. +/// It has the option of denying this upgrade based upon the server response. +public protocol NIOTypedHTTPClientProtocolUpgrader { + + associatedtype UpgradeResult : Sendable + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol requires in the request to successfully upgrade. + /// These header fields will be added to the outbound request's "Connection" header field. + /// It is the responsibility of the custom headers call to actually add these required headers. + var requiredUpgradeHeaders: [String] { get } + + /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. + func addCustom(upgradeRequestHeaders: inout NIOHTTP1.HTTPHeaders) + + /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. + func shouldAllowUpgrade(upgradeResponse: NIOHTTP1.HTTPResponseHead) -> Bool + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel + /// pipeline to add whatever channel handlers are required. + /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. + func upgrade(channel: NIOCore.Channel, upgradeResponse: NIOHTTP1.HTTPResponseHead) -> NIOCore.EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPClientUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOTypedHTTPClientUpgradeConfiguration where UpgradeResult : Sendable { + + /// The initial request head that is sent out once the channel becomes active. + public var upgradeRequestHead: NIOHTTP1.HTTPRequestHead + + /// The array of potential upgraders. + public var upgraders: [NIOHTTP1.NIOTypedHTTPClientProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture + + public init(upgradeRequestHead: NIOHTTP1.HTTPRequestHead, upgraders: [NIOHTTP1.NIOTypedHTTPClientProtocolUpgrader], notUpgradingCompletionHandler: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) +} + +/// A client-side channel handler that sends a HTTP upgrade handshake request to perform a HTTP-upgrade. +/// This handler will add all appropriate headers to perform an upgrade to +/// the a protocol. It may add headers for a set of protocols in preference order. +/// If the upgrade fails (i.e. response is not 101 Switching Protocols), this handler simply +/// removes itself from the pipeline. If the upgrade is successful, it upgrades the pipeline to the new protocol. +/// +/// The request sends an order of preference to request which protocol it would like to use for the upgrade. +/// It will only upgrade to the protocol that is returned first in the list and does not currently +/// have the capability to upgrade to multiple simultaneous layered protocols. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final public class NIOTypedHTTPClientUpgradeHandler : NIOCore.ChannelDuplexHandler, NIOCore.RemovableChannelHandler where UpgradeResult : Sendable { + + /// The type of the outbound data which is wrapped in `NIOAny`. + public typealias OutboundIn = NIOHTTP1.HTTPClientRequestPart + + /// The type of the outbound data which will be forwarded to the next `ChannelOutboundHandler` in the `ChannelPipeline`. + public typealias OutboundOut = NIOHTTP1.HTTPClientRequestPart + + /// The type of the inbound data which is wrapped in `NIOAny`. + public typealias InboundIn = NIOHTTP1.HTTPClientResponsePart + + /// The type of the inbound data which will be forwarded to the next `ChannelInboundHandler` in the `ChannelPipeline`. + public typealias InboundOut = NIOHTTP1.HTTPClientResponsePart + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: NIOCore.EventLoopFuture { get } + + /// Create a ``NIOTypedHTTPClientUpgradeHandler``. + /// + /// - Parameters: + /// - httpHandlers: All `RemovableChannelHandler` objects which will be removed from the pipeline + /// once the upgrade response is sent. This is used to ensure that the pipeline will be in a clean state + /// after the upgrade. It should include any handlers that are directly related to handling HTTP. + /// At the very least this should include the `HTTPEncoder` and `HTTPDecoder`, but should also include + /// any other handler that cannot tolerate receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init(httpHandlers: [NIOCore.RemovableChannelHandler], upgradeConfiguration: NIOHTTP1.NIOTypedHTTPClientUpgradeConfiguration) + + /// Called when this `ChannelHandler` is added to the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerAdded(context: NIOCore.ChannelHandlerContext) + + /// Called when this `ChannelHandler` is removed from the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerRemoved(context: NIOCore.ChannelHandlerContext) + + /// Called when the `Channel` has become active, and is able to send and receive data. + /// + /// This should call `context.fireChannelActive` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func channelActive(context: NIOCore.ChannelHandlerContext) + + /// Called to request a write operation. The write operation will write the messages through the + /// `ChannelPipeline`. Those are then ready to be flushed to the actual `Channel` when + /// `Channel.flush` or `ChannelHandlerContext.flush` is called. + /// + /// This should call `context.write` to forward the operation to the next `_ChannelOutboundHandler` in the `ChannelPipeline` or + /// complete the `EventLoopPromise` to let the caller know that the operation completed. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - data: The data to write through the `Channel`, wrapped in a `NIOAny`. + /// - promise: The `EventLoopPromise` which should be notified once the operation completes, or nil if no notification should take place. + public func write(context: NIOCore.ChannelHandlerContext, data: NIOCore.NIOAny, promise: NIOCore.EventLoopPromise?) + + /// Called when some data has been read from the remote peer. + /// + /// This should call `context.fireChannelRead` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - data: The data read from the remote peer, wrapped in a `NIOAny`. + public func channelRead(context: NIOCore.ChannelHandlerContext, data: NIOCore.NIOAny) +} + +/// An object that implements `NIOTypedHTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to +/// a protocol on a server-side channel. +public protocol NIOTypedHTTPServerProtocolUpgrader { + + associatedtype UpgradeResult : Sendable + + /// The protocol this upgrader knows how to support. + var supportedProtocol: String { get } + + /// All the header fields the protocol needs in the request to successfully upgrade. These header fields + /// will be provided to the handler when it is asked to handle the upgrade. They will also be validated + /// against the inbound request's `Connection` header field. + var requiredUpgradeHeaders: [String] { get } + + /// Builds the upgrade response headers. Should return any headers that need to be supplied to the client + /// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should + /// return a failed future. + func buildUpgradeResponse(channel: NIOCore.Channel, upgradeRequest: NIOHTTP1.HTTPRequestHead, initialResponseHeaders: NIOHTTP1.HTTPHeaders) -> NIOCore.EventLoopFuture + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline + /// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received + /// data will be buffered. + func upgrade(channel: NIOCore.Channel, upgradeRequest: NIOHTTP1.HTTPRequestHead) -> NIOCore.EventLoopFuture +} + +/// The upgrade configuration for the ``NIOTypedHTTPServerUpgradeHandler``. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +public struct NIOTypedHTTPServerUpgradeConfiguration where UpgradeResult : Sendable { + + /// The array of potential upgraders. + public var upgraders: [NIOHTTP1.NIOTypedHTTPServerProtocolUpgrader] + + /// A closure that is run once it is determined that no protocol upgrade is happening. This can be used + /// to configure handlers that expect HTTP. + public var notUpgradingCompletionHandler: @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture + + public init(upgraders: [NIOHTTP1.NIOTypedHTTPServerProtocolUpgrader], notUpgradingCompletionHandler: @escaping @Sendable (NIOCore.Channel) -> NIOCore.EventLoopFuture) +} + +/// A server-side channel handler that receives HTTP requests and optionally performs an HTTP-upgrade. +/// +/// Removes itself from the channel pipeline after the first inbound request on the connection, regardless of +/// whether the upgrade succeeded or not. +/// +/// This handler behaves a bit differently from its Netty counterpart because it does not allow upgrade +/// on any request but the first on a connection. This is primarily to handle clients that pipeline: it's +/// sufficiently difficult to ensure that the upgrade happens at a safe time while dealing with pipelined +/// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine: +/// the odds of someone needing to upgrade midway through the lifetime of a connection are very low. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final public class NIOTypedHTTPServerUpgradeHandler : NIOCore.ChannelInboundHandler, NIOCore.RemovableChannelHandler where UpgradeResult : Sendable { + + /// The type of the inbound data which is wrapped in `NIOAny`. + public typealias InboundIn = NIOHTTP1.HTTPServerRequestPart + + /// The type of the inbound data which will be forwarded to the next `ChannelInboundHandler` in the `ChannelPipeline`. + public typealias InboundOut = NIOHTTP1.HTTPServerRequestPart + + /// The type of the outbound data which will be forwarded to the next `ChannelOutboundHandler` in the `ChannelPipeline`. + public typealias OutboundOut = NIOHTTP1.HTTPServerResponsePart + + /// The upgrade future which will be completed once protocol upgrading has been done. + public var upgradeResultFuture: NIOCore.EventLoopFuture { get } + + /// Create a ``NIOTypedHTTPServerUpgradeHandler``. + /// + /// - Parameters: + /// - httpEncoder: The ``HTTPResponseEncoder`` encoding responses from this handler and which will + /// be removed from the pipeline once the upgrade response is sent. This is used to ensure + /// that the pipeline will be in a clean state after upgrade. + /// - extraHTTPHandlers: Any other handlers that are directly related to handling HTTP. At the very least + /// this should include the `HTTPDecoder`, but should also include any other handler that cannot tolerate + /// receiving non-HTTP data. + /// - upgradeConfiguration: The upgrade configuration. + public init(httpEncoder: NIOHTTP1.HTTPResponseEncoder, extraHTTPHandlers: [NIOCore.RemovableChannelHandler], upgradeConfiguration: NIOHTTP1.NIOTypedHTTPServerUpgradeConfiguration) + + /// Called when this `ChannelHandler` is added to the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerAdded(context: NIOCore.ChannelHandlerContext) + + /// Called when this `ChannelHandler` is removed from the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerRemoved(context: NIOCore.ChannelHandlerContext) + + /// Called when some data has been read from the remote peer. + /// + /// This should call `context.fireChannelRead` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - data: The data read from the remote peer, wrapped in a `NIOAny`. + public func channelRead(context: NIOCore.ChannelHandlerContext, data: NIOCore.NIOAny) +} +``` + +### Websocket + +```swift +/// A `NIOTypedHTTPClientProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request. +/// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the +/// pipeline to remove the HTTP `ChannelHandler`s. +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +final public class NIOTypedWebSocketClientUpgrader : NIOHTTP1.NIOTypedHTTPClientProtocolUpgrader where UpgradeResult : Sendable { + + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String + + /// None of the websocket headers are actually defined as 'required'. + public let requiredUpgradeHeaders: [String] + + /// - Parameters: + /// - requestKey: Sent to the server in the `Sec-WebSocket-Key` HTTP header. Default is random request key. + /// - maxFrameSize: Largest incoming `WebSocketFrame` size in bytes. Default is 16,384 bytes. + /// - enableAutomaticErrorHandling: If true, adds `WebSocketProtocolErrorHandler` to the channel pipeline to catch and respond to WebSocket protocol errors. Default is true. + /// - upgradePipelineHandler: Called once the upgrade was successful. + public init(requestKey: String = NIOWebSocketClientUpgrader.randomRequestKey(), maxFrameSize: Int = 1 << 14, enableAutomaticErrorHandling: Bool = true, upgradePipelineHandler: @escaping @Sendable (NIOCore.Channel, NIOHTTP1.HTTPResponseHead) -> NIOCore.EventLoopFuture) + + /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. + public func addCustom(upgradeRequestHeaders: inout NIOHTTP1.HTTPHeaders) + + /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. + public func shouldAllowUpgrade(upgradeResponse: NIOHTTP1.HTTPResponseHead) -> Bool + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel + /// pipeline to add whatever channel handlers are required. + /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. + public func upgrade(channel: NIOCore.Channel, upgradeResponse: NIOHTTP1.HTTPResponseHead) -> NIOCore.EventLoopFuture +} + +/// A `NIOTypedHTTPServerProtocolUpgrader` that knows how to do the WebSocket upgrade dance. +/// +/// Users may frequently want to offer multiple websocket endpoints on the same port. For this +/// reason, this `WebServerSocketUpgrader` only knows how to do the required parts of the upgrade and to +/// complete the handshake. Users are expected to provide a callback that examines the HTTP headers +/// (including the path) and determines whether this is a websocket upgrade request that is acceptable +/// to them. +/// +/// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to +/// remove the HTTP `ChannelHandler`s. +final public class NIOTypedWebSocketServerUpgrader : NIOHTTP1.NIOTypedHTTPServerProtocolUpgrader, Sendable where UpgradeResult : Sendable { + + /// RFC 6455 specs this as the required entry in the Upgrade header. + public let supportedProtocol: String + + /// We deliberately do not actually set any required headers here, because the websocket + /// spec annoyingly does not actually force the client to send these in the Upgrade header, + /// which NIO requires. We check for these manually. + public let requiredUpgradeHeaders: [String] + + /// Create a new ``NIOTypedWebSocketServerUpgrader``. + /// + /// - Parameters: + /// - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the + /// remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but + /// this is an objectively unreasonable maximum value (on AMD64 systems it is not + /// possible to even. Users may set this to any value up to `UInt32.max`. + /// - automaticErrorHandling: Whether the pipeline should automatically handle protocol + /// errors by sending error responses and closing the connection. Defaults to `true`, + /// may be set to `false` if the user wishes to handle their own errors. + /// - shouldUpgrade: A callback that determines whether the websocket request should be + /// upgraded. This callback is responsible for creating a `HTTPHeaders` object with + /// any headers that it needs on the response *except for* the `Upgrade`, `Connection`, + /// and `Sec-WebSocket-Accept` headers, which this upgrader will handle. Should return + /// an `EventLoopFuture` containing `nil` if the upgrade should be refused. + /// - enableAutomaticErrorHandling: A function that will be called once the upgrade response is + /// flushed, and that is expected to mutate the `Channel` appropriately to handle the + /// websocket protocol. This only needs to add the user handlers: the + /// `WebSocketFrameEncoder` and `WebSocketFrameDecoder` will have been added to the + /// pipeline automatically. + public init(maxFrameSize: Int = 1 << 14, enableAutomaticErrorHandling: Bool = true, shouldUpgrade: @escaping @Sendable (NIOCore.Channel, NIOHTTP1.HTTPRequestHead) -> NIOCore.EventLoopFuture, upgradePipelineHandler: @escaping @Sendable (NIOCore.Channel, NIOHTTP1.HTTPRequestHead) -> NIOCore.EventLoopFuture) + + /// Builds the upgrade response headers. Should return any headers that need to be supplied to the client + /// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should + /// return a failed future. + public func buildUpgradeResponse(channel: NIOCore.Channel, upgradeRequest: NIOHTTP1.HTTPRequestHead, initialResponseHeaders: NIOHTTP1.HTTPHeaders) -> NIOCore.EventLoopFuture + + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline + /// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received + /// data will be buffered. + public func upgrade(channel: NIOCore.Channel, upgradeRequest: NIOHTTP1.HTTPRequestHead) -> NIOCore.EventLoopFuture +} +``` + +### ALPN + +```swift +/// A helper ``ChannelInboundHandler`` that makes it easy to swap channel pipelines +/// based on the result of an ALPN negotiation. +/// +/// The standard pattern used by applications that want to use ALPN is to select +/// an application protocol based on the result, optionally falling back to some +/// default protocol. To do this in SwiftNIO requires that the channel pipeline be +/// reconfigured based on the result of the ALPN negotiation. This channel handler +/// encapsulates that logic in a generic form that doesn't depend on the specific +/// TLS implementation in use by using ``TLSUserEvent`` +/// +/// The user of this channel handler provides a single closure that is called with +/// an ``ALPNResult`` when the ALPN negotiation is complete. Based on that result +/// the user is free to reconfigure the ``ChannelPipeline`` as required, and should +/// return an ``EventLoopFuture`` that will complete when the pipeline is reconfigured. +/// +/// Until the ``EventLoopFuture`` completes, this channel handler will buffer inbound +/// data. When the ``EventLoopFuture`` completes, the buffered data will be replayed +/// down the channel. Then, finally, this channel handler will automatically remove +/// itself from the channel pipeline, leaving the pipeline in its final +/// configuration. +/// +/// Importantly, this is a typed variant of the ``ApplicationProtocolNegotiationHandler`` and allows the user to +/// specify a type that must be returned from the supplied closure. The result will then be used to succeed the ``NIOTypedApplicationProtocolNegotiationHandler/protocolNegotiationResult`` +/// promise. This allows us to construct pipelines that include protocol negotiation handlers and be able to bridge them into ``NIOAsyncChannel`` +/// based bootstraps. +final public class NIOTypedApplicationProtocolNegotiationHandler : NIOCore.ChannelInboundHandler, NIOCore.RemovableChannelHandler { + + /// The type of the inbound data which is wrapped in `NIOAny`. + public typealias InboundIn = Any + + /// The type of the inbound data which will be forwarded to the next `ChannelInboundHandler` in the `ChannelPipeline`. + public typealias InboundOut = Any + + public var protocolNegotiationResult: NIOCore.EventLoopFuture { get } + + /// Create an `ApplicationProtocolNegotiationHandler` with the given completion + /// callback. + /// + /// - Parameter alpnCompleteHandler: The closure that will fire when ALPN + /// negotiation has completed. + public init(alpnCompleteHandler: @escaping (NIOTLS.ALPNResult, NIOCore.Channel) -> NIOCore.EventLoopFuture) + + /// Create an `ApplicationProtocolNegotiationHandler` with the given completion + /// callback. + /// + /// - Parameter alpnCompleteHandler: The closure that will fire when ALPN + /// negotiation has completed. + public convenience init(alpnCompleteHandler: @escaping (NIOTLS.ALPNResult) -> NIOCore.EventLoopFutureNegotiationResult>) + + /// Called when this `ChannelHandler` is added to the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerAdded(context: NIOCore.ChannelHandlerContext) + + /// Called when this `ChannelHandler` is removed from the `ChannelPipeline`. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func handlerRemoved(context: NIOCore.ChannelHandlerContext) + + /// Called when a user inbound event has been triggered. + /// + /// This should call `context.fireUserInboundEventTriggered` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - event: The event. + public func userInboundEventTriggered(context: NIOCore.ChannelHandlerContext, event: Any) + + /// Called when some data has been read from the remote peer. + /// + /// This should call `context.fireChannelRead` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + /// - data: The data read from the remote peer, wrapped in a `NIOAny`. + public func channelRead(context: NIOCore.ChannelHandlerContext, data: NIOCore.NIOAny) + + /// Called when the `Channel` has become inactive and is no longer able to send and receive data. + /// + /// This should call `context.fireChannelInactive` to forward the operation to the next `_ChannelInboundHandler` in the `ChannelPipeline` if you want to allow the next handler to also handle the event. + /// + /// - parameters: + /// - context: The `ChannelHandlerContext` which this `ChannelHandler` belongs to. + public func channelInactive(context: NIOCore.ChannelHandlerContext) +} +``` + +### HTTP/2.0 + +```swift +extension NIOHTTP2Handler { + /// A variant of `NIOHTTP2Handler.StreamMultiplexer` which creates a child channel for each HTTP/2 stream and + /// provides access to inbound HTTP/2 streams. + /// + /// In general in NIO applications it is helpful to consider each HTTP/2 stream as an + /// independent stream of HTTP/2 frames. This multiplexer achieves this by creating a + /// number of in-memory `HTTP2StreamChannel` objects, one for each stream. These operate + /// on ``HTTP2Frame/FramePayload`` objects as their base communication + /// atom, as opposed to the regular NIO `SelectableChannel` objects which use `ByteBuffer` + /// and `IOData`. + /// + /// Outbound stream channel objects are initialized upon creation using the supplied `streamStateInitializer` which returns a type + /// `Output`. This type may be `HTTP2Frame` or changed to any other type. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public struct AsyncStreamMultiplexer { + /// Create a stream channel initialized with the provided closure + public func createStreamChannel(_ initializer: @escaping NIOChannelInitializerWithOutput) async throws -> Output + } +} + +/// `NIOHTTP2InboundStreamChannels` provides access to inbound stream channels as a generic `AsyncSequence`. +/// They make use of generics to allow for wrapping the stream `Channel`s, for example as `NIOAsyncChannel`s or protocol negotiation objects. +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +public struct NIOHTTP2InboundStreamChannels: AsyncSequence { + public struct AsyncIterator: AsyncIteratorProtocol { + public typealias Element = Output + + public mutating func next() async throws -> Output? + } + + public typealias Element = Output + + public func makeAsyncIterator() -> AsyncIterator +} + +extension Channel { + /// Configures a `ChannelPipeline` to speak HTTP/2 and sets up mapping functions so that it may be interacted with from concurrent code. + /// + /// In general this is not entirely useful by itself, as HTTP/2 is a negotiated protocol. This helper does not handle negotiation. + /// Instead, this simply adds the handler required to speak HTTP/2 after negotiation has completed, or when agreed by prior knowledge. + /// Use this function to setup a HTTP/2 pipeline if you wish to use async sequence abstractions over inbound and outbound streams. + /// Using this rather than implementing a similar function yourself allows that pipeline to evolve without breaking your code. + /// + /// - Parameters: + /// - mode: The mode this pipeline will operate in, server or client. + /// - configuration: The settings that will be used when establishing the connection and new streams. + /// - inboundStreamInitializer: A closure that will be called whenever the remote peer initiates a new stream. + /// The output of this closure is the element type of the returned multiplexer + /// - Returns: An `EventLoopFuture` containing the `AsyncStreamMultiplexer` inserted into this pipeline, which can + /// be used to initiate new streams and iterate over inbound HTTP/2 stream channels. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public func configureAsyncHTTP2Pipeline( + mode: NIOHTTP2Handler.ParserMode, + configuration: NIOHTTP2Handler.Configuration = .init(), + inboundStreamInitializer: @escaping NIOChannelInitializerWithOutput + ) -> EventLoopFuture> + + /// Configures a `ChannelPipeline` to speak either HTTP/1.1 or HTTP/2 according to what can be negotiated with the client. + /// + /// This helper takes care of configuring the server pipeline such that it negotiates whether to + /// use HTTP/1.1 or HTTP/2. + /// + /// This function doesn't configure the TLS handler. Callers of this function need to add a TLS + /// handler appropriately configured to perform protocol negotiation. + /// + /// - Parameters: + /// - http2Configuration: The settings that will be used when establishing the HTTP/2 connections and new HTTP/2 streams. + /// - http1ConnectionInitializer: An optional callback that will be invoked only when the negotiated protocol + /// is HTTP/1.1 to configure the connection channel. + /// - http2ConnectionInitializer: An optional callback that will be invoked only when the negotiated protocol + /// is HTTP/2 to configure the connection channel. + /// - http2InboundStreamInitializer: A closure that will be called whenever the remote peer initiates a new stream. + /// The output of this closure is the element type of the returned multiplexer + /// - Returns: An `EventLoopFuture` containing a ``NIOTypedApplicationProtocolNegotiationHandler`` that completes when the channel + /// is ready to negotiate. This can then be used to access the protocol negotiation result which may itself + /// be waited on to retrieve the result of the negotiation. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public func configureAsyncHTTPServerPipeline( + http2Configuration: NIOHTTP2Handler.Configuration = .init(), + http1ConnectionInitializer: @escaping NIOChannelInitializerWithOutput, + http2ConnectionInitializer: @escaping NIOChannelInitializerWithOutput, + http2InboundStreamInitializer: @escaping NIOChannelInitializerWithOutput + ) -> EventLoopFuture) + >>> + +extension ChannelPipeline.SynchronousOperations { + /// Configures a `ChannelPipeline` to speak HTTP/2 and sets up mapping functions so that it may be interacted with from concurrent code. + /// + /// This operation **must** be called on the event loop. + /// + /// In general this is not entirely useful by itself, as HTTP/2 is a negotiated protocol. This helper does not handle negotiation. + /// Instead, this simply adds the handler required to speak HTTP/2 after negotiation has completed, or when agreed by prior knowledge. + /// Use this function to setup a HTTP/2 pipeline if you wish to use async sequence abstractions over inbound and outbound streams, + /// as it allows that pipeline to evolve without breaking your code. + /// + /// - Parameters: + /// - mode: The mode this pipeline will operate in, server or client. + /// - configuration: The settings that will be used when establishing the connection and new streams. + /// - inboundStreamInitializer: A closure that will be called whenever the remote peer initiates a new stream. + /// The output of this closure is the element type of the returned multiplexer + /// - Returns: An `EventLoopFuture` containing the `AsyncStreamMultiplexer` inserted into this pipeline, which can + /// be used to initiate new streams and iterate over inbound HTTP/2 stream channels. + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + public func configureAsyncHTTP2Pipeline( + mode: NIOHTTP2Handler.ParserMode, + configuration: NIOHTTP2Handler.Configuration = .init(), + inboundStreamInitializer: @escaping NIOChannelInitializerWithOutput + ) throws -> NIOHTTP2Handler.AsyncStreamMultiplexer +} + +/// `NIONegotiatedHTTPVersion` is a generic negotiation result holder for HTTP/1.1 and HTTP/2 +public enum NIONegotiatedHTTPVersion { + case http1_1(HTTP1Output) + case http2(HTTP2Output) +} +``` diff --git a/scripts/cxx-interop-compatibility.sh b/scripts/cxx-interop-compatibility.sh new file mode 100755 index 0000000000..4109810786 --- /dev/null +++ b/scripts/cxx-interop-compatibility.sh @@ -0,0 +1,79 @@ +#!/bin/bash +##===----------------------------------------------------------------------===## +## +## This source file is part of the SwiftNIO open source project +## +## Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +## Licensed under Apache License v2.0 +## +## See LICENSE.txt for license information +## See CONTRIBUTORS.txt for the list of SwiftNIO project authors +## +## SPDX-License-Identifier: Apache-2.0 +## +##===----------------------------------------------------------------------===## + +set -eu + +sourcedir=$(pwd) +workingdir=$(mktemp -d) +projectname=$(basename $workingdir) + +cd $workingdir +swift package init + +cat << EOF > Package.swift +// swift-tools-version: 5.9 + +import PackageDescription + +let package = Package( + name: "interop", + products: [ + .library( + name: "interop", + targets: ["interop"] + ), + ], + dependencies: [ + .package(path: "$sourcedir") + ], + targets: [ + .target( + name: "interop", + // Depend on all products of swift-nio to make sure they're all + // compatible with cxx interop. + dependencies: [ + .product(name: "NIO", package: "swift-nio"), + .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + .product(name: "NIOTLS", package: "swift-nio"), + .product(name: "NIOEmbedded", package: "swift-nio"), + .product(name: "NIOPosix", package: "swift-nio"), + .product(name: "NIOHTTP1", package: "swift-nio"), + .product(name: "NIOFoundationCompat", package: "swift-nio"), + .product(name: "NIOWebSocket", package: "swift-nio"), + .product(name: "NIOTestUtils", package: "swift-nio"), + .product(name: "_NIOConcurrency", package: "swift-nio") + ], + swiftSettings: [.interoperabilityMode(.Cxx)] + ) + ] +) +EOF + +cat << EOF > Sources/$projectname/$(echo $projectname | tr . _).swift +import NIO +import NIOCore +import NIOConcurrencyHelpers +import NIOTLS +import NIOEmbedded +import NIOPosix +import NIOHTTP1 +import NIOFoundationCompat +import NIOWebSocket +import NIOTestUtils +import _NIOConcurrency +EOF + +swift build