diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 568f12a6db..1e95efced1 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -10,7 +10,6 @@ jobs: uses: ./.github/workflows/pull_request_soundness.yml with: license_header_check_project_name: "SwiftNIO" - format_check_enabled: false call-pull-request-unit-tests-workflow: name: Unit tests diff --git a/.github/workflows/pull_request_soundness.yml b/.github/workflows/pull_request_soundness.yml index d537ac763f..26bc7fbb7a 100644 --- a/.github/workflows/pull_request_soundness.yml +++ b/.github/workflows/pull_request_soundness.yml @@ -123,5 +123,12 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Run format check - run: swift format lint --parallel --recursive --strict + - name: Mark the workspace as safe + # https://github.com/actions/checkout/issues/766 + run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + - name: Run format lint check + run: swift format lint --strict --recursive --parallel . + - name: Run format and check for modified files + run: | + swift format format --parallel --recursive --in-place . + git diff-index --quiet HEAD diff --git a/.swift-format b/.swift-format new file mode 100644 index 0000000000..7fa06fb305 --- /dev/null +++ b/.swift-format @@ -0,0 +1,62 @@ +{ + "version" : 1, + "indentation" : { + "spaces" : 4 + }, + "tabWidth" : 4, + "fileScopedDeclarationPrivacy" : { + "accessLevel" : "private" + }, + "spacesAroundRangeFormationOperators" : false, + "indentConditionalCompilationBlocks" : false, + "indentSwitchCaseLabels" : false, + "lineBreakAroundMultilineExpressionChainComponents" : false, + "lineBreakBeforeControlFlowKeywords" : false, + "lineBreakBeforeEachArgument" : true, + "lineBreakBeforeEachGenericRequirement" : true, + "lineLength" : 120, + "maximumBlankLines" : 1, + "respectsExistingLineBreaks" : true, + "prioritizeKeepingFunctionOutputTogether" : true, + "rules" : { + "AllPublicDeclarationsHaveDocumentation" : false, + "AlwaysUseLiteralForEmptyCollectionInit" : false, + "AlwaysUseLowerCamelCase" : false, + "AmbiguousTrailingClosureOverload" : true, + "BeginDocumentationCommentWithOneLineSummary" : false, + "DoNotUseSemicolons" : true, + "DontRepeatTypeInStaticProperties" : true, + "FileScopedDeclarationPrivacy" : true, + "FullyIndirectEnum" : true, + "GroupNumericLiterals" : true, + "IdentifiersMustBeASCII" : true, + "NeverForceUnwrap" : false, + "NeverUseForceTry" : false, + "NeverUseImplicitlyUnwrappedOptionals" : false, + "NoAccessLevelOnExtensionDeclaration" : true, + "NoAssignmentInExpressions" : true, + "NoBlockComments" : true, + "NoCasesWithOnlyFallthrough" : true, + "NoEmptyTrailingClosureParentheses" : true, + "NoLabelsInCasePatterns" : true, + "NoLeadingUnderscores" : false, + "NoParensAroundConditions" : true, + "NoVoidReturnOnFunctionSignature" : true, + "OmitExplicitReturns" : true, + "OneCasePerLine" : true, + "OneVariableDeclarationPerLine" : true, + "OnlyOneTrailingClosureArgument" : true, + "OrderedImports" : true, + "ReplaceForEachWithForLoop" : true, + "ReturnVoidInsteadOfEmptyTuple" : true, + "UseEarlyExits" : false, + "UseExplicitNilCheckInConditions" : false, + "UseLetInEveryBoundCaseVariable" : false, + "UseShorthandTypeNames" : true, + "UseSingleLinePropertyGetter" : false, + "UseSynthesizedInitializer" : false, + "UseTripleSlashForDocumentationComments" : true, + "UseWhereClausesInForLoops" : false, + "ValidateDocumentationComments" : false + } +} diff --git a/Benchmarks/Benchmarks/NIOCoreBenchmarks/Benchmarks.swift b/Benchmarks/Benchmarks/NIOCoreBenchmarks/Benchmarks.swift index 39fa33c7df..562de24eba 100644 --- a/Benchmarks/Benchmarks/NIOCoreBenchmarks/Benchmarks.swift +++ b/Benchmarks/Benchmarks/NIOCoreBenchmarks/Benchmarks.swift @@ -18,7 +18,7 @@ import NIOEmbedded let benchmarks = { let defaultMetrics: [BenchmarkMetric] = [ - .mallocCountTotal, + .mallocCountTotal ] Benchmark( @@ -28,7 +28,7 @@ let benchmarks = { // Elide the cost of the 'EmbeddedChannel'. It's only used for its pipeline. var channels: [EmbeddedChannel] = [] channels.reserveCapacity(benchmark.scaledIterations.count) - for _ in 0 ..< benchmark.scaledIterations.count { + for _ in 0.. Void) -> Void var swiftTaskEnqueueGlobalHook: EnqueueGlobalHook? { - get { _swiftTaskEnqueueGlobalHook.pointee } - set { _swiftTaskEnqueueGlobalHook.pointee = newValue } + get { _swiftTaskEnqueueGlobalHook.pointee } + set { _swiftTaskEnqueueGlobalHook.pointee = newValue } } private let _swiftTaskEnqueueGlobalHook: UnsafeMutablePointer = - dlsym(dlopen(nil, RTLD_LAZY), "swift_task_enqueueGlobal_hook").assumingMemoryBound(to: EnqueueGlobalHook?.self) + dlsym(dlopen(nil, RTLD_LAZY), "swift_task_enqueueGlobal_hook").assumingMemoryBound(to: EnqueueGlobalHook?.self) diff --git a/Benchmarks/Package.swift b/Benchmarks/Package.swift index 6748b30a41..13bef8d521 100644 --- a/Benchmarks/Package.swift +++ b/Benchmarks/Package.swift @@ -5,7 +5,7 @@ import PackageDescription let package = Package( name: "benchmarks", platforms: [ - .macOS("14"), + .macOS("14") ], dependencies: [ .package(path: "../"), diff --git a/IntegrationTests/allocation-counter-tests-framework/template/AtomicCounter/Package.swift b/IntegrationTests/allocation-counter-tests-framework/template/AtomicCounter/Package.swift index c1872d9e69..c37cdd56ff 100644 --- a/IntegrationTests/allocation-counter-tests-framework/template/AtomicCounter/Package.swift +++ b/IntegrationTests/allocation-counter-tests-framework/template/AtomicCounter/Package.swift @@ -19,12 +19,13 @@ import PackageDescription let package = Package( name: "AtomicCounter", products: [ - .library(name: "AtomicCounter", type: .dynamic, targets: ["AtomicCounter"]), + .library(name: "AtomicCounter", type: .dynamic, targets: ["AtomicCounter"]) ], - dependencies: [ ], + dependencies: [], targets: [ .target( name: "AtomicCounter", - dependencies: []), + dependencies: [] + ) ] ) diff --git a/IntegrationTests/allocation-counter-tests-framework/template/HookedFunctionsDoHook/Package.swift b/IntegrationTests/allocation-counter-tests-framework/template/HookedFunctionsDoHook/Package.swift index a00d52ab52..c05d6b9e8c 100644 --- a/IntegrationTests/allocation-counter-tests-framework/template/HookedFunctionsDoHook/Package.swift +++ b/IntegrationTests/allocation-counter-tests-framework/template/HookedFunctionsDoHook/Package.swift @@ -19,12 +19,12 @@ import PackageDescription let package = Package( name: "HookedFunctions", products: [ - .library(name: "HookedFunctions", type: .dynamic, targets: ["HookedFunctions"]), + .library(name: "HookedFunctions", type: .dynamic, targets: ["HookedFunctions"]) ], dependencies: [ - .package(url: "../AtomicCounter/", branch: "main"), + .package(url: "../AtomicCounter/", branch: "main") ], targets: [ - .target(name: "HookedFunctions", dependencies: ["AtomicCounter"]), + .target(name: "HookedFunctions", dependencies: ["AtomicCounter"]) ] ) diff --git a/IntegrationTests/allocation-counter-tests-framework/template/HookedFunctionsDoNotHook/Package.swift b/IntegrationTests/allocation-counter-tests-framework/template/HookedFunctionsDoNotHook/Package.swift index a00d52ab52..c05d6b9e8c 100644 --- a/IntegrationTests/allocation-counter-tests-framework/template/HookedFunctionsDoNotHook/Package.swift +++ b/IntegrationTests/allocation-counter-tests-framework/template/HookedFunctionsDoNotHook/Package.swift @@ -19,12 +19,12 @@ import PackageDescription let package = Package( name: "HookedFunctions", products: [ - .library(name: "HookedFunctions", type: .dynamic, targets: ["HookedFunctions"]), + .library(name: "HookedFunctions", type: .dynamic, targets: ["HookedFunctions"]) ], dependencies: [ - .package(url: "../AtomicCounter/", branch: "main"), + .package(url: "../AtomicCounter/", branch: "main") ], targets: [ - .target(name: "HookedFunctions", dependencies: ["AtomicCounter"]), + .target(name: "HookedFunctions", dependencies: ["AtomicCounter"]) ] ) diff --git a/IntegrationTests/allocation-counter-tests-framework/template/scaffolding.swift b/IntegrationTests/allocation-counter-tests-framework/template/scaffolding.swift index 39519c2ac7..ca102aa1e8 100644 --- a/IntegrationTests/allocation-counter-tests-framework/template/scaffolding.swift +++ b/IntegrationTests/allocation-counter-tests-framework/template/scaffolding.swift @@ -12,8 +12,9 @@ // //===----------------------------------------------------------------------===// -import Foundation import AtomicCounter +import Foundation + #if canImport(Darwin) import Darwin #elseif canImport(Glibc) @@ -24,7 +25,7 @@ import Glibc func waitForThreadsToQuiesce(shouldReachZero: Bool) { func getUnfreed() -> Int { - return AtomicCounter.read_malloc_counter() - AtomicCounter.read_free_counter() + AtomicCounter.read_malloc_counter() - AtomicCounter.read_free_counter() } var oldNumberOfUnfreed = getUnfreed() @@ -35,7 +36,8 @@ func waitForThreadsToQuiesce(shouldReachZero: Bool) { return } count += 1 - usleep(shouldReachZero ? 50_000 : 200_000) // allocs/frees happen on multiple threads, allow some cool down time + // allocs/frees happen on multiple threads, allow some cool down time + usleep(shouldReachZero ? 50_000 : 200_000) let newNumberOfUnfreed = getUnfreed() if oldNumberOfUnfreed == newNumberOfUnfreed && (!shouldReachZero || newNumberOfUnfreed <= 0) { // nothing happened in the last 100ms, let's assume everything's @@ -57,7 +59,11 @@ struct Measurement { } extension Array where Element == Measurement { - private func printIntegerMetric(_ keyPath: KeyPath, description desc: String, metricName k: String) { + private func printIntegerMetric( + _ keyPath: KeyPath, + description desc: String, + metricName k: String + ) { let vs = self.map { $0[keyPath: keyPath] } print("\(desc).\(k): \(vs.min() ?? -1)") } @@ -90,13 +96,13 @@ func measureAll(trackFDs: Bool, _ fn: () -> Int) -> [Measurement] { AtomicCounter.begin_tracking_fds() } -#if canImport(Darwin) + #if canImport(Darwin) autoreleasepool { _ = fn() } -#else + #else _ = fn() -#endif + #endif waitForThreadsToQuiesce(shouldReachZero: !throwAway) let frees = AtomicCounter.read_free_counter() let mallocs = AtomicCounter.read_malloc_counter() @@ -121,7 +127,7 @@ func measureAll(trackFDs: Bool, _ fn: () -> Int) -> [Measurement] { ) } - _ = measureOne(throwAway: true, trackFDs: trackFDs, fn) /* pre-heat and throw away */ + _ = measureOne(throwAway: true, trackFDs: trackFDs, fn) // pre-heat and throw away var measurements: [Measurement] = [] for _ in 0..<10 { @@ -132,19 +138,19 @@ func measureAll(trackFDs: Bool, _ fn: () -> Int) -> [Measurement] { return measurements } -func measureAndPrint(desc: String, trackFDs: Bool, fn: () -> Int) -> Void { +func measureAndPrint(desc: String, trackFDs: Bool, fn: () -> Int) { let measurements = measureAll(trackFDs: trackFDs, fn) measurements.printTotalAllocations(description: desc) measurements.printRemainingAllocations(description: desc) measurements.printTotalAllocatedBytes(description: desc) measurements.printLeakedFDs(description: desc) - + print("DEBUG: \(measurements)") } public func measure(identifier: String, trackFDs: Bool = false, _ body: () -> Int) { measureAndPrint(desc: identifier, trackFDs: trackFDs) { - return body() + body() } } @@ -160,7 +166,7 @@ func measureAll(trackFDs: Bool, _ fn: @escaping () async -> Int) -> [Measurement } group.wait() } - + if trackFDs { AtomicCounter.begin_tracking_fds() } @@ -169,13 +175,13 @@ func measureAll(trackFDs: Bool, _ fn: @escaping () async -> Int) -> [Measurement AtomicCounter.reset_malloc_counter() AtomicCounter.reset_malloc_bytes_counter() -#if canImport(Darwin) + #if canImport(Darwin) autoreleasepool { run(fn) } -#else + #else run(fn) -#endif + #endif waitForThreadsToQuiesce(shouldReachZero: !throwAway) let frees = AtomicCounter.read_free_counter() let mallocs = AtomicCounter.read_malloc_counter() @@ -200,7 +206,7 @@ func measureAll(trackFDs: Bool, _ fn: @escaping () async -> Int) -> [Measurement ) } - _ = measureOne(throwAway: true, trackFDs: trackFDs, fn) /* pre-heat and throw away */ + _ = measureOne(throwAway: true, trackFDs: trackFDs, fn) // pre-heat and throw away var measurements: [Measurement] = [] for _ in 0..<10 { @@ -212,7 +218,7 @@ func measureAll(trackFDs: Bool, _ fn: @escaping () async -> Int) -> [Measurement } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -func measureAndPrint(desc: String, trackFDs: Bool, fn: @escaping () async -> Int) -> Void { +func measureAndPrint(desc: String, trackFDs: Bool, fn: @escaping () async -> Int) { let measurements = measureAll(trackFDs: trackFDs, fn) measurements.printTotalAllocations(description: desc) measurements.printRemainingAllocations(description: desc) diff --git a/IntegrationTests/tests_04_performance/test_01_resources/shared.swift b/IntegrationTests/tests_04_performance/test_01_resources/shared.swift index 239cd0bad7..8f540d8651 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/shared.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/shared.swift @@ -14,54 +14,56 @@ import Foundation import NIOCore -import NIOPosix import NIOHTTP1 +import NIOPosix let localhostPickPort = try! SocketAddress.makeAddressResolvingHost("127.0.0.1", port: 0) let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) final class RepeatedRequests: ChannelInboundHandler { - typealias InboundIn = HTTPClientResponsePart - typealias OutboundOut = HTTPClientRequestPart + typealias InboundIn = HTTPClientResponsePart + typealias OutboundOut = HTTPClientRequestPart - private let numberOfRequests: Int - private var remainingNumberOfRequests: Int - private let isDonePromise: EventLoopPromise - static var requestHead: HTTPRequestHead { + private let numberOfRequests: Int + private var remainingNumberOfRequests: Int + private let isDonePromise: EventLoopPromise + static var requestHead: HTTPRequestHead { var head = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/allocation-test-1") - head.headers.add(name: "Host", value: "foo-\(ObjectIdentifier(self)).com") - return head - } - - init(numberOfRequests: Int, eventLoop: EventLoop) { - self.remainingNumberOfRequests = numberOfRequests - self.numberOfRequests = numberOfRequests - self.isDonePromise = eventLoop.makePromise() - } - - func wait() throws -> Int { - let reqs = try self.isDonePromise.futureResult.wait() - precondition(reqs == self.numberOfRequests) - return reqs - } - - func errorCaught(context: ChannelHandlerContext, error: Error) { - context.channel.close(promise: nil) - self.isDonePromise.fail(error) - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let respPart = Self.unwrapInboundIn(data) - if case .end(nil) = respPart { - if self.remainingNumberOfRequests <= 0 { - context.channel.close().map { self.numberOfRequests - self.remainingNumberOfRequests }.cascade(to: self.isDonePromise) - } else { - self.remainingNumberOfRequests -= 1 - context.write(Self.wrapOutboundOut(.head(RepeatedRequests.requestHead)), promise: nil) - context.writeAndFlush(Self.wrapOutboundOut(.end(nil)), promise: nil) - } - } - } + head.headers.add(name: "Host", value: "foo-\(ObjectIdentifier(self)).com") + return head + } + + init(numberOfRequests: Int, eventLoop: EventLoop) { + self.remainingNumberOfRequests = numberOfRequests + self.numberOfRequests = numberOfRequests + self.isDonePromise = eventLoop.makePromise() + } + + func wait() throws -> Int { + let reqs = try self.isDonePromise.futureResult.wait() + precondition(reqs == self.numberOfRequests) + return reqs + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + context.channel.close(promise: nil) + self.isDonePromise.fail(error) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let respPart = self.unwrapInboundIn(data) + if case .end(nil) = respPart { + if self.remainingNumberOfRequests <= 0 { + context.channel.close().map { self.numberOfRequests - self.remainingNumberOfRequests }.cascade( + to: self.isDonePromise + ) + } else { + self.remainingNumberOfRequests -= 1 + context.write(self.wrapOutboundOut(.head(RepeatedRequests.requestHead)), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + } + } } private final class SimpleHTTPServer: ChannelInboundHandler { @@ -94,10 +96,13 @@ private final class SimpleHTTPServer: ChannelInboundHandler { } public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - if case .head(let req) = Self.unwrapInboundIn(data), req.uri == "/allocation-test-1" { - context.write(Self.wrapOutboundOut(.head(self.responseHead)), promise: nil) - context.write(Self.wrapOutboundOut(.body(.byteBuffer(self.responseBody(allocator: context.channel.allocator)))), promise: nil) - context.writeAndFlush(Self.wrapOutboundOut(.end(nil)), promise: nil) + if case .head(let req) = self.unwrapInboundIn(data), req.uri == "/allocation-test-1" { + context.write(self.wrapOutboundOut(.head(self.responseHead)), promise: nil) + context.write( + self.wrapOutboundOut(.body(.byteBuffer(self.responseBody(allocator: context.channel.allocator)))), + promise: nil + ) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } } } @@ -106,8 +111,10 @@ func doRequests(group: EventLoopGroup, number numberOfRequests: Int) throws -> I let serverChannel = try ServerBootstrap(group: group) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, - withErrorHandling: false).flatMap { + channel.pipeline.configureHTTPServerPipeline( + withPipeliningAssistance: true, + withErrorHandling: false + ).flatMap { channel.pipeline.addHandler(SimpleHTTPServer()) } }.bind(to: localhostPickPort).wait() @@ -116,7 +123,6 @@ func doRequests(group: EventLoopGroup, number numberOfRequests: Int) throws -> I try! serverChannel.close().wait() } - let repeatedRequestsHandler = RepeatedRequests(numberOfRequests: numberOfRequests, eventLoop: group.next()) let clientChannel = try ClientBootstrap(group: group) @@ -176,42 +182,42 @@ enum UDPShared { public typealias InboundIn = AddressedEnvelope public typealias OutboundOut = AddressedEnvelope private var repetitionsRemaining: Int - + private let remoteAddress: SocketAddress - + init(remoteAddress: SocketAddress, numberOfRepetitions: Int) { self.remoteAddress = remoteAddress self.repetitionsRemaining = numberOfRepetitions } - + public func channelActive(context: ChannelHandlerContext) { // Channel is available. It's time to send the message to the server to initialize the ping-pong sequence. self.sendSomeDataIfDesiredOrClose(context: context) } - + private func sendSomeDataIfDesiredOrClose(context: ChannelHandlerContext) { if repetitionsRemaining > 0 { repetitionsRemaining -= 1 - + // Set the transmission data. let line = "Something to send there and back again." let buffer = context.channel.allocator.buffer(string: line) - + // Forward the data. let envolope = AddressedEnvelope(remoteAddress: remoteAddress, data: buffer) - - context.writeAndFlush(Self.wrapOutboundOut(envolope), promise: nil) + + context.writeAndFlush(self.wrapOutboundOut(envolope), promise: nil) } else { // We're all done - hurrah! context.close(promise: nil) } } - + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { // Got back a response - maybe send some more. self.sendSomeDataIfDesiredOrClose(context: context) } - + public func errorCaught(context: ChannelHandlerContext, error: Error) { // Errors should never happen. fatalError("EchoHandlerClient received errorCaught") @@ -222,20 +228,24 @@ enum UDPShared { let serverChannel = try DatagramBootstrap(group: group) // Set the handlers that are applied to the bound channel .channelInitializer { channel in - return channel.pipeline.addHandler(EchoHandler()) + channel.pipeline.addHandler(EchoHandler()) } .bind(to: localhostPickPort).wait() defer { try! serverChannel.close().wait() } - + let remoteAddress = serverChannel.localAddress! let clientChannel = try DatagramBootstrap(group: group) .channelInitializer { channel in - channel.pipeline.addHandler(EchoHandlerClient(remoteAddress: remoteAddress, - numberOfRepetitions: numberOfRequests)) + channel.pipeline.addHandler( + EchoHandlerClient( + remoteAddress: remoteAddress, + numberOfRepetitions: numberOfRequests + ) + ) } .bind(to: localhostPickPort).wait() diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_10000000_asyncsequenceproducer.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_10000000_asyncsequenceproducer.swift index bdaeefb5c3..30af0df7cf 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_10000000_asyncsequenceproducer.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_10000000_asyncsequenceproducer.swift @@ -15,10 +15,12 @@ import NIOCore @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -fileprivate typealias SequenceProducer = NIOAsyncSequenceProducer +private typealias SequenceProducer = NIOAsyncSequenceProducer< + Int, NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, Delegate +> @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -fileprivate final class Delegate: NIOAsyncSequenceProducerDelegate, @unchecked Sendable { +private final class Delegate: NIOAsyncSequenceProducerDelegate, @unchecked Sendable { private let elements = Array(repeating: 1, count: 1000) var source: SequenceProducer.Source! @@ -36,7 +38,10 @@ func run(identifier: String) { } measure(identifier: identifier) { let delegate = Delegate() - let producer = SequenceProducer.makeSequence(backPressureStrategy: .init(lowWatermark: 100, highWatermark: 500), delegate: delegate) + let producer = SequenceProducer.makeSequence( + backPressureStrategy: .init(lowWatermark: 100, highWatermark: 500), + delegate: delegate + ) let sequence = producer.sequence delegate.source = producer.source @@ -44,7 +49,7 @@ func run(identifier: String) { for await i in sequence { counter += i - if counter == 10000000 { + if counter == 10_000_000 { return counter } } diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000000_asyncwriter.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000000_asyncwriter.swift index 62229e123a..bb87284d32 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000000_asyncwriter.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000000_asyncwriter.swift @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// -import NIOCore import DequeModule +import NIOCore @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -fileprivate struct Delegate: NIOAsyncWriterSinkDelegate, Sendable { +private struct Delegate: NIOAsyncWriterSinkDelegate, Sendable { typealias Element = Int func didYield(contentsOf sequence: Deque) {} @@ -33,10 +33,10 @@ func run(identifier: String) { let newWriter = NIOAsyncWriter.makeWriter(isWritable: true, delegate: delegate) let writer = newWriter.writer - for i in 0..<1000000 { + for i in 0..<1_000_000 { try! await writer.yield(i) } - return 1000000 + return 1_000_000 } } diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addHandlers.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addHandlers.swift index c2df1293bc..a8e789fa8d 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addHandlers.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addHandlers.swift @@ -15,7 +15,7 @@ import NIOCore import NIOEmbedded -fileprivate final class SimpleHandler: ChannelInboundHandler { +private final class SimpleHandler: ChannelInboundHandler { typealias InboundIn = NIOAny } @@ -29,7 +29,7 @@ func run(identifier: String) { } try! channel.pipeline.addHandlers([ SimpleHandler(), - SimpleHandler() + SimpleHandler(), ]).wait() } return iterations diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addHandlers_sync.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addHandlers_sync.swift index 127aedce9c..b97304956e 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addHandlers_sync.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addHandlers_sync.swift @@ -15,7 +15,7 @@ import NIOCore import NIOEmbedded -fileprivate final class SimpleHandler: ChannelInboundHandler { +private final class SimpleHandler: ChannelInboundHandler { typealias InboundIn = NIOAny } @@ -28,8 +28,8 @@ func run(identifier: String) { _ = try! channel.finish() } try! channel.pipeline.syncOperations.addHandlers([ - SimpleHandler(), - SimpleHandler() + SimpleHandler(), + SimpleHandler(), ]) } return iterations diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addRemoveHandlers.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addRemoveHandlers.swift index 4dc6eab941..8aef13916a 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addRemoveHandlers.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_addRemoveHandlers.swift @@ -15,7 +15,7 @@ import NIOCore import NIOEmbedded -fileprivate final class RemovableHandler: ChannelInboundHandler, RemovableChannelHandler { +private final class RemovableHandler: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = NIOAny static let name: String = "RemovableHandler" @@ -32,7 +32,10 @@ fileprivate final class RemovableHandler: ChannelInboundHandler, RemovableChanne } @inline(__always) -private func addRemoveBench(iterations: Int, _ removalOperation: (Channel, RemovableHandler) -> EventLoopFuture) -> Int { +private func addRemoveBench( + iterations: Int, + _ removalOperation: (Channel, RemovableHandler) -> EventLoopFuture +) -> Int { let channel = EmbeddedChannel() defer { _ = try! channel.finish() diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_autoReadGetAndSet.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_autoReadGetAndSet.swift index a508856411..8f143e1a87 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_autoReadGetAndSet.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_autoReadGetAndSet.swift @@ -22,8 +22,8 @@ func run(identifier: String) { } let server = try! ServerBootstrap(group: group) - .bind(host: "127.0.0.1", port: 0) - .wait() + .bind(host: "127.0.0.1", port: 0) + .wait() defer { try! server.close().wait() } diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_getHandlers.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_getHandlers.swift index ecf41f6694..5b3d5925c8 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_getHandlers.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_getHandlers.swift @@ -15,7 +15,7 @@ import NIOCore import NIOEmbedded -fileprivate final class SimpleHandler: ChannelInboundHandler { +private final class SimpleHandler: ChannelInboundHandler { typealias InboundIn = NIOAny } diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_getHandlers_sync.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_getHandlers_sync.swift index d05dc2084e..7414bba32f 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_getHandlers_sync.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_getHandlers_sync.swift @@ -15,7 +15,7 @@ import NIOCore import NIOEmbedded -fileprivate final class SimpleHandler: ChannelInboundHandler { +private final class SimpleHandler: ChannelInboundHandler { typealias InboundIn = NIOAny } diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_rst_connections.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_rst_connections.swift index 9b620718ad..3e8bf65af8 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_rst_connections.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_rst_connections.swift @@ -35,7 +35,7 @@ func run(identifier: String) { let serverConnection = try! ServerBootstrap(group: group) .bind(host: "localhost", port: 0) .wait() - + let serverAddress = serverConnection.localAddress! let clientBootstrap = ClientBootstrap(group: group) .channelInitializer { channel in @@ -48,7 +48,7 @@ func run(identifier: String) { let iterations = 1000 for _ in 0.. EventLoopFuture in writeWaitAndClose(clientChannel: clientChannel, buffer: buffer) @@ -82,9 +82,9 @@ func run(identifier: String) { } } -fileprivate func writeWaitAndClose(clientChannel: Channel, buffer: ByteBuffer) -> EventLoopFuture { +private func writeWaitAndClose(clientChannel: Channel, buffer: ByteBuffer) -> EventLoopFuture { // Send a byte to make sure everything is really open. - return clientChannel.writeAndFlush(buffer).flatMap { + clientChannel.writeAndFlush(buffer).flatMap { clientChannel.closeFuture } } diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_udp_reqs.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_udp_reqs.swift index 619b746823..f87d529add 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_udp_reqs.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_udp_reqs.swift @@ -15,7 +15,7 @@ import NIOCore import NIOPosix -fileprivate final class ServerEchoHandler: ChannelInboundHandler { +private final class ServerEchoHandler: ChannelInboundHandler { public typealias InboundIn = AddressedEnvelope public typealias OutboundOut = AddressedEnvelope @@ -35,19 +35,19 @@ fileprivate final class ServerEchoHandler: ChannelInboundHandler { } } -fileprivate final class ClientHandler: ChannelInboundHandler { +private final class ClientHandler: ChannelInboundHandler { public typealias InboundIn = AddressedEnvelope public typealias OutboundOut = AddressedEnvelope - + private let remoteAddress: SocketAddress - + init(remoteAddress: SocketAddress) { self.remoteAddress = remoteAddress } - + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { // If we still have iterations to do send some more data. - if (self.iterationsOutstanding > 0) { + if self.iterationsOutstanding > 0 { self.iterationsOutstanding -= 1 sendBytes(clientChannel: context.channel) } else { @@ -63,10 +63,10 @@ fileprivate final class ClientHandler: ChannelInboundHandler { public func errorCaught(context: ChannelHandlerContext, error: Error) { fatalError() } - + var iterationsOutstanding = 0 var whenDone: EventLoopPromise? = nil - + private func sendBytes(clientChannel: Channel) { var buffer = clientChannel.allocator.buffer(capacity: 1) buffer.writeInteger(3, as: UInt8.self) @@ -75,10 +75,10 @@ fileprivate final class ClientHandler: ChannelInboundHandler { let envelope = AddressedEnvelope(remoteAddress: remoteAddress, data: buffer, metadata: metadata) clientChannel.writeAndFlush(Self.wrapOutboundOut(envelope), promise: nil) } - + func sendBytesAndWaitForReply(clientChannel: Channel) -> Int { let numberOfIterations = 1000 - + // Setup for iteration. self.iterationsOutstanding = numberOfIterations self.whenDone = clientChannel.eventLoop.makePromise() @@ -95,7 +95,7 @@ func run(identifier: String) { .channelOption(ChannelOptions.explicitCongestionNotification, value: true) // Set the handlers that are applied to the bound channel .channelInitializer { channel in - return channel.pipeline.addHandler(ServerEchoHandler()) + channel.pipeline.addHandler(ServerEchoHandler()) } .bind(to: localhostPickPort).wait() defer { @@ -114,9 +114,8 @@ func run(identifier: String) { defer { try! clientChannel.close().wait() } - + measure(identifier: identifier) { clientHandler.sendBytesAndWaitForReply(clientChannel: clientChannel) } } - diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_udpbootstraps.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_udpbootstraps.swift index 333f3ee99f..ed5f074912 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_1000_udpbootstraps.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_1000_udpbootstraps.swift @@ -15,7 +15,7 @@ import NIOCore import NIOPosix -fileprivate final class DoNothingHandler: ChannelInboundHandler { +private final class DoNothingHandler: ChannelInboundHandler { public typealias InboundIn = ByteBuffer public typealias OutboundOut = ByteBuffer } @@ -23,7 +23,7 @@ fileprivate final class DoNothingHandler: ChannelInboundHandler { func run(identifier: String) { measure(identifier: identifier) { let numberOfIterations = 1000 - for _ in 0 ..< numberOfIterations { + for _ in 0.. - + var completionFuture: EventLoopFuture { - return self.completed.futureResult + self.completed.futureResult } - + init(numberOfReadsExpected: Int, completionPromise: EventLoopPromise) { self.readsRemaining = numberOfReadsExpected self.completed = completionPromise } - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.readsRemaining -= 1 if self.readsRemaining <= 0 { @@ -41,13 +41,15 @@ fileprivate final class CountReadsHandler: ChannelInboundHandler { func run(identifier: String) { let numberOfIterations = 1000 - - let serverHandler = CountReadsHandler(numberOfReadsExpected: numberOfIterations, - completionPromise: group.next().makePromise()) + + let serverHandler = CountReadsHandler( + numberOfReadsExpected: numberOfIterations, + completionPromise: group.next().makePromise() + ) let serverChannel = try! DatagramBootstrap(group: group) // Set the handlers that are applied to the bound channel .channelInitializer { channel in - return channel.pipeline.addHandler(serverHandler) + channel.pipeline.addHandler(serverHandler) } .bind(to: localhostPickPort).wait() defer { @@ -55,13 +57,13 @@ func run(identifier: String) { } let remoteAddress = serverChannel.localAddress! - + let clientBootstrap = DatagramBootstrap(group: group) measure(identifier: identifier) { let buffer = ByteBuffer(integer: 1, as: UInt8.self) - for _ in 0 ..< numberOfIterations { - try! clientBootstrap.bind(to: localhostPickPort).flatMap { clientChannel -> EventLoopFuture in + for _ in 0.. EventLoopFuture in // Send a byte to make sure everything is really open. let envelope = AddressedEnvelope(remoteAddress: remoteAddress, data: buffer) return clientChannel.writeAndFlush(envelope).flatMap { @@ -73,4 +75,3 @@ func run(identifier: String) { return numberOfIterations } } - diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_bytebuffer_lots_of_rw.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_bytebuffer_lots_of_rw.swift index 1954f46880..f92e87977f 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_bytebuffer_lots_of_rw.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_bytebuffer_lots_of_rw.swift @@ -25,7 +25,7 @@ func run(identifier: String) { let substring = Substring("A") @inline(never) func doWrites(buffer: inout ByteBuffer, dispatchData: DispatchData, substring: Substring) { - /* these ones are zero allocations */ + // these ones are zero allocations // buffer.writeBytes(foundationData) // see SR-7542 buffer.writeBytes([0x41]) buffer.writeBytes("A".utf8) @@ -33,15 +33,15 @@ func run(identifier: String) { buffer.writeStaticString("A") buffer.writeInteger(0x41, as: UInt8.self) - /* those down here should be one allocation each (on Linux) */ - buffer.writeBytes(dispatchData) // see https://bugs.swift.org/browse/SR-9597 + // those down here should be one allocation each (on Linux) + buffer.writeBytes(dispatchData) // see https://bugs.swift.org/browse/SR-9597 - /* these here are one allocation on all platforms */ + // these here are one allocation on all platforms buffer.writeSubstring(substring) } @inline(never) func doReads(buffer: inout ByteBuffer) { - /* these ones are zero allocations */ + // these ones are zero allocations let val = buffer.readInteger(as: UInt8.self) precondition(0x41 == val, "\(val!)") var slice = buffer.readSlice(length: 1) @@ -51,13 +51,13 @@ func run(identifier: String) { precondition(ptr[0] == 0x41) } - /* those down here should be one allocation each */ + // those down here should be one allocation each let arr = buffer.readBytes(length: 1) precondition([0x41] == arr!, "\(arr!)") let str = buffer.readString(length: 1) precondition("A" == str, "\(str!)") } - for _ in 0..<1000 { + for _ in 0..<1000 { doWrites(buffer: &buffer, dispatchData: dispatchData, substring: substring) doReads(buffer: &buffer) } diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_decode_1000_ws_frames.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_decode_1000_ws_frames.swift index 7c9d6ab239..b9ee98c997 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_decode_1000_ws_frames.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_decode_1000_ws_frames.swift @@ -30,7 +30,7 @@ func run(identifier: String) { let channel = EmbeddedChannel() try! channel.pipeline.addHandler(ByteToMessageHandler(WebSocketFrameDecoder())).wait() try! channel.pipeline.addHandler(UnboxingChannelHandler()).wait() - let data = ByteBuffer(bytes: [0x81, 0x00]) // empty websocket + let data = ByteBuffer(bytes: [0x81, 0x00]) // empty websocket measure(identifier: identifier) { for _ in 0..<1000 { diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_encode_1000_ws_frames.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_encode_1000_ws_frames.swift index b33d5ac5a2..1f8ece92b2 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_encode_1000_ws_frames.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_encode_1000_ws_frames.swift @@ -16,7 +16,13 @@ import NIOCore import NIOEmbedded import NIOWebSocket -func doSendFramesHoldingBuffer(channel: EmbeddedChannel, number numberOfFrameSends: Int, data originalData: [UInt8], spareBytesAtFront: Int, mask: WebSocketMaskingKey? = nil) throws -> Int { +func doSendFramesHoldingBuffer( + channel: EmbeddedChannel, + number numberOfFrameSends: Int, + data originalData: [UInt8], + spareBytesAtFront: Int, + mask: WebSocketMaskingKey? = nil +) throws -> Int { var data = channel.allocator.buffer(capacity: originalData.count + spareBytesAtFront) data.moveWriterIndex(forwardBy: spareBytesAtFront) data.moveReaderIndex(forwardBy: spareBytesAtFront) @@ -35,8 +41,13 @@ func doSendFramesHoldingBuffer(channel: EmbeddedChannel, number numberOfFrameSen return numberOfFrameSends } - -func doSendFramesNewBuffer(channel: EmbeddedChannel, number numberOfFrameSends: Int, data originalData: [UInt8], spareBytesAtFront: Int, mask: WebSocketMaskingKey? = nil) throws -> Int { +func doSendFramesNewBuffer( + channel: EmbeddedChannel, + number numberOfFrameSends: Int, + data originalData: [UInt8], + spareBytesAtFront: Int, + mask: WebSocketMaskingKey? = nil +) throws -> Int { for _ in 0.. EventLoopFuture in // This call allocates a new Future, and // so does flatMap(), so this is two Futures. - return loop.makeSucceededFuture(r + 1) + loop.makeSucceededFuture(r + 1) }.flatMapThrowing { (r: Int) -> Int in // flatMapThrowing allocates a new Future, and calls `flatMap` // which also allocates, so this is two. - return r + 2 + r + 2 }.map { (r: Int) -> Int in // map allocates a new future, and calls `flatMap` which // also allocates, so this is two. - return r + 2 + r + 2 }.flatMapThrowing { (r: Int) -> Int in // flatMapThrowing allocates a future on the error path and // calls `flatMap`, which also allocates, so this is two. @@ -40,7 +40,7 @@ func run(identifier: String) { }.flatMapError { (err: Error) -> EventLoopFuture in // This call allocates a new Future, and so does flatMapError, // so this is two Futures. - return loop.makeFailedFuture(err) + loop.makeFailedFuture(err) }.flatMapErrorThrowing { (err: Error) -> Int in // flatMapError allocates a new Future, and calls flatMapError, // so this is two Futures @@ -48,7 +48,7 @@ func run(identifier: String) { }.recover { (err: Error) -> Int in // recover allocates a future, and calls flatMapError, so // this is two Futures. - return 1 + 1 } p.succeed(0) @@ -65,10 +65,10 @@ func run(identifier: String) { // and(result:) allocate two. let f = p1.futureResult - .and(p2.futureResult) - .and(p3.futureResult) - .and(value: 1) - .and(value: 1) + .and(p2.futureResult) + .and(p3.futureResult) + .and(value: 1) + .and(value: 1) p1.succeed(1) p2.succeed(1) @@ -76,7 +76,7 @@ func run(identifier: String) { _ = try! f.wait() } let el = EmbeddedEventLoop() - for _ in 0..<1000 { + for _ in 0..<1000 { doThenAndFriends(loop: el) doAnd(loop: el) } diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_modifying_1000_circular_buffer_elements.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_modifying_1000_circular_buffer_elements.swift index f6231248ff..7eecba9762 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_modifying_1000_circular_buffer_elements.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_modifying_1000_circular_buffer_elements.swift @@ -15,7 +15,7 @@ import NIOCore func run(identifier: String) { - var buffer = CircularBuffer>(initialCapacity: 100) + var buffer = CircularBuffer<[Int]>(initialCapacity: 100) for _ in 0..<100 { buffer.append([]) } diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_ping_pong_1000_reqs_1_conn.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_ping_pong_1000_reqs_1_conn.swift index 99fb379001..ebcad8f0b5 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_ping_pong_1000_reqs_1_conn.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_ping_pong_1000_reqs_1_conn.swift @@ -35,8 +35,12 @@ private final class PongDecoder: ByteToMessageDecoder { } } - public func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState { - return .needMoreData + public func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState { + .needMoreData } } @@ -69,9 +73,8 @@ private final class PingHandler: ChannelInboundHandler { } public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - var buf = Self.unwrapInboundIn(data) - if buf.readableBytes == 1 && - buf.readInteger(as: UInt8.self) == PongHandler.pongCode { + var buf = self.unwrapInboundIn(data) + if buf.readableBytes == 1 && buf.readInteger(as: UInt8.self) == PongHandler.pongCode { if self.remainingNumberOfRequests > 0 { self.remainingNumberOfRequests -= 1 context.writeAndFlush(Self.wrapOutboundOut(self.pingBuffer), promise: nil) diff --git a/IntegrationTests/tests_04_performance/test_01_resources/test_read_10000_chunks_from_file.swift b/IntegrationTests/tests_04_performance/test_01_resources/test_read_10000_chunks_from_file.swift index e4550f15e8..06ae6deed0 100644 --- a/IntegrationTests/tests_04_performance/test_01_resources/test_read_10000_chunks_from_file.swift +++ b/IntegrationTests/tests_04_performance/test_01_resources/test_read_10000_chunks_from_file.swift @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// -import Foundation import Dispatch +import Foundation +import NIOConcurrencyHelpers import NIOCore import NIOPosix -import NIOConcurrencyHelpers func run(identifier: String) { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) @@ -37,27 +37,33 @@ func run(identifier: String) { var fileBuffer = allocator.buffer(capacity: numberOfChunks) fileBuffer.writeString(String(repeating: "X", count: numberOfChunks)) let path = NSTemporaryDirectory() + "/\(UUID())" - let fileHandle = try! NIOFileHandle(path: path, - mode: [.write, .read], - flags: .allowFileCreation(posixMode: 0o600)) + let fileHandle = try! NIOFileHandle( + path: path, + mode: [.write, .read], + flags: .allowFileCreation(posixMode: 0o600) + ) defer { unlink(path) } - try! fileIO.write(fileHandle: fileHandle, - buffer: fileBuffer, - eventLoop: loop).wait() + try! fileIO.write( + fileHandle: fileHandle, + buffer: fileBuffer, + eventLoop: loop + ).wait() let numberOfBytes = NIOAtomic.makeAtomic(value: 0) measure(identifier: identifier) { numberOfBytes.store(0) - try! fileIO.readChunked(fileHandle: fileHandle, - fromOffset: 0, - byteCount: numberOfChunks, - chunkSize: 1, - allocator: allocator, - eventLoop: loop) { buffer in - numberOfBytes.add(buffer.readableBytes) - return loop.makeSucceededFuture(()) + try! fileIO.readChunked( + fileHandle: fileHandle, + fromOffset: 0, + byteCount: numberOfChunks, + chunkSize: 1, + allocator: allocator, + eventLoop: loop + ) { buffer in + numberOfBytes.add(buffer.readableBytes) + return loop.makeSucceededFuture(()) }.wait() precondition(numberOfBytes.load() == numberOfChunks, "\(numberOfBytes.load()), \(numberOfChunks)") return numberOfBytes.load() diff --git a/Package.swift b/Package.swift index 1e683691b8..89277a203b 100644 --- a/Package.swift +++ b/Package.swift @@ -18,9 +18,9 @@ import PackageDescription let swiftAtomics: PackageDescription.Target.Dependency = .product(name: "Atomics", package: "swift-atomics") let swiftCollections: PackageDescription.Target.Dependency = .product(name: "DequeModule", package: "swift-collections") let swiftSystem: PackageDescription.Target.Dependency = .product( - name: "SystemPackage", - package: "swift-system", - condition: .when(platforms: [.macOS, .iOS, .tvOS, .watchOS, .linux, .android]) + name: "SystemPackage", + package: "swift-system", + condition: .when(platforms: [.macOS, .iOS, .tvOS, .watchOS, .linux, .android]) ) // This doesn't work when cross-compiling: the privacy manifest will be included in the Bundle and @@ -121,7 +121,7 @@ let package = Package( name: "CNIOAtomics", dependencies: [], cSettings: [ - .define("_GNU_SOURCE"), + .define("_GNU_SOURCE") ] ), .target( @@ -132,14 +132,14 @@ let package = Package( name: "CNIOLinux", dependencies: [], cSettings: [ - .define("_GNU_SOURCE"), + .define("_GNU_SOURCE") ] ), .target( name: "CNIODarwin", dependencies: [], cSettings: [ - .define("__APPLE_USE_RFC_3542"), + .define("__APPLE_USE_RFC_3542") ] ), .target( @@ -149,7 +149,7 @@ let package = Package( .target( name: "NIOConcurrencyHelpers", dependencies: [ - "CNIOAtomics", + "CNIOAtomics" ] ), .target( @@ -159,7 +159,7 @@ let package = Package( "NIOCore", "NIOConcurrencyHelpers", "CNIOLLHTTP", - swiftCollections + swiftCollections, ] ), .target( @@ -169,14 +169,14 @@ let package = Package( "NIOCore", "NIOHTTP1", "CNIOSHA1", - "_NIOBase64" + "_NIOBase64", ] ), .target( name: "CNIOLLHTTP", cSettings: [ - .define("_GNU_SOURCE"), - .define("LLHTTP_STRICT_MODE") + .define("_GNU_SOURCE"), + .define("LLHTTP_STRICT_MODE"), ] ), .target( @@ -218,14 +218,14 @@ let package = Package( .target( name: "NIOFileSystem", dependencies: [ - "_NIOFileSystem", + "_NIOFileSystem" ], path: "Sources/_NIOFileSystemExported" ), .target( name: "_NIOFileSystemFoundationCompat", dependencies: [ - "_NIOFileSystem", + "_NIOFileSystem" ], path: "Sources/NIOFileSystemFoundationCompat" ), @@ -503,7 +503,7 @@ let package = Package( // Contains known files and directory structures used // for the integration tests. Exclude the whole tree from // the build. - "Test Data", + "Test Data" ] ), .testTarget( @@ -512,7 +512,7 @@ let package = Package( "_NIOFileSystem", "_NIOFileSystemFoundationCompat", ] - ) + ), ] ) diff --git a/Snippets/NIOFileSystemTour.swift b/Snippets/NIOFileSystemTour.swift index f79135b45b..e394c272cc 100644 --- a/Snippets/NIOFileSystemTour.swift +++ b/Snippets/NIOFileSystemTour.swift @@ -1,103 +1,103 @@ // snippet.hide -import _NIOFileSystem + import NIOCore +import _NIOFileSystem @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -func main() async throws -{ - // snippet.show +func main() async throws { + // snippet.show - // NIOFileSystem provides access to the local file system via the FileSystem - // type which is available as a global shared instance. - let fileSystem = FileSystem.shared + // NIOFileSystem provides access to the local file system via the FileSystem + // type which is available as a global shared instance. + let fileSystem = FileSystem.shared - // Files can be inspected by using 'info': - if let info = try await fileSystem.info(forFileAt: "/Users/hal9000/demise-of-dave.txt") { - print("demise-of-dave.txt has type '\(info.type)'") - } else { - print("demise-of-dave.txt doesn't exist") - } + // Files can be inspected by using 'info': + if let info = try await fileSystem.info(forFileAt: "/Users/hal9000/demise-of-dave.txt") { + print("demise-of-dave.txt has type '\(info.type)'") + } else { + print("demise-of-dave.txt doesn't exist") + } - // Let's find out what's in that file. - do { - // Reading a whole file requires a limit. If the file is larger than the limit - // then an error is thrown. This avoids accidentally consuming too much memory - // if the file is larger than expected. - let plan = try await ByteBuffer( - contentsOf: "/Users/hal9000/demise-of-dave.txt", - maximumSizeAllowed: .mebibytes(1) - ) - print("Plan for Dave's demise:", String(decoding: plan.readableBytesView, as: UTF8.self)) - } catch let error as FileSystemError where error.code == .notFound { - // All errors thrown by the module have type FileSystemError (or - // Swift.CancellationError). It looks like the file doesn't exist. Let's - // create it now. - // - // The code above for reading the file is shorthand for opening the file in - // read-only mode and then reading its contents. The FileSystemProtocol - // has a few different 'withFileHandle' methods for opening a file in different - // modes. Let's open a file for writing, creating it at the same time. - try await fileSystem.withFileHandle( - forWritingAt: "/Users/hal9000/demise-of-dave.txt", - options: .newFile(replaceExisting: false) - ) { file in - let plan = ByteBuffer(string: "TODO...") - try await file.write(contentsOf: plan.readableBytesView, toAbsoluteOffset: 0) + // Let's find out what's in that file. + do { + // Reading a whole file requires a limit. If the file is larger than the limit + // then an error is thrown. This avoids accidentally consuming too much memory + // if the file is larger than expected. + let plan = try await ByteBuffer( + contentsOf: "/Users/hal9000/demise-of-dave.txt", + maximumSizeAllowed: .mebibytes(1) + ) + print("Plan for Dave's demise:", String(decoding: plan.readableBytesView, as: UTF8.self)) + } catch let error as FileSystemError where error.code == .notFound { + // All errors thrown by the module have type FileSystemError (or + // Swift.CancellationError). It looks like the file doesn't exist. Let's + // create it now. + // + // The code above for reading the file is shorthand for opening the file in + // read-only mode and then reading its contents. The FileSystemProtocol + // has a few different 'withFileHandle' methods for opening a file in different + // modes. Let's open a file for writing, creating it at the same time. + try await fileSystem.withFileHandle( + forWritingAt: "/Users/hal9000/demise-of-dave.txt", + options: .newFile(replaceExisting: false) + ) { file in + let plan = ByteBuffer(string: "TODO...") + try await file.write(contentsOf: plan.readableBytesView, toAbsoluteOffset: 0) + } } - } - // Directories can be opened like regular files but they cannot be read from or - // written to. However, their contents can be listed: - let path: FilePath? = try await fileSystem.withDirectoryHandle(atPath: "/Users/hal9000/Music") { directory in - for try await entry in directory.listContents() { - if entry.name.extension == "mp3", entry.name.stem.contains("daisy") { - // Found it! - return entry.path - } + // Directories can be opened like regular files but they cannot be read from or + // written to. However, their contents can be listed: + let path: FilePath? = try await fileSystem.withDirectoryHandle(atPath: "/Users/hal9000/Music") { directory in + for try await entry in directory.listContents() { + if entry.name.extension == "mp3", entry.name.stem.contains("daisy") { + // Found it! + return entry.path + } + } + // No luck. + return nil } - // No luck. - return nil - } - if let path = path { - print("Found file at '\(path)'") - } + if let path = path { + print("Found file at '\(path)'") + } - // The file system can also be used to perform the following operations on files - // and directories: - // - copy, - // - remove, - // - rename, and - // - replace. - // - // Here's an example of copying a directory: - try await fileSystem.copyItem(at: "/Users/hal9000/Music", to: "/Volumes/Tardis/Music") + // The file system can also be used to perform the following operations on files + // and directories: + // - copy, + // - remove, + // - rename, and + // - replace. + // + // Here's an example of copying a directory: + try await fileSystem.copyItem(at: "/Users/hal9000/Music", to: "/Volumes/Tardis/Music") - // Symbolic links can also be created (and read with 'destinationOfSymbolicLink(at:)'). - try await fileSystem.createSymbolicLink(at: "/Users/hal9000/Backup", withDestination: "/Volumes/Tardis") + // Symbolic links can also be created (and read with 'destinationOfSymbolicLink(at:)'). + try await fileSystem.createSymbolicLink(at: "/Users/hal9000/Backup", withDestination: "/Volumes/Tardis") - // Opening a symbolic link opens its destination so in most cases there's no - // need to read the destination of a symbolic link: - try await fileSystem.withDirectoryHandle(atPath: "/Users/hal9000/Backup") { directory in - // Beyond listing the contents of a directory, the directory handle provides a - // number of other functions, many of which are also available on regular file - // handles. - // - // This includes getting information about a file, such as its permissions, last access time, - // and last modification time: - let info = try await directory.info() - print("The directory has permissions '\(info.permissions)'") + // Opening a symbolic link opens its destination so in most cases there's no + // need to read the destination of a symbolic link: + try await fileSystem.withDirectoryHandle(atPath: "/Users/hal9000/Backup") { directory in + // Beyond listing the contents of a directory, the directory handle provides a + // number of other functions, many of which are also available on regular file + // handles. + // + // This includes getting information about a file, such as its permissions, last access time, + // and last modification time: + let info = try await directory.info() + print("The directory has permissions '\(info.permissions)'") - // Where supported, the extended attributes of a file can also be accessed, read, and modified: - for attribute in try await directory.attributeNames() { - let value = try await directory.valueForAttribute(attribute) - print("Extended attribute '\(attribute)' has value '\(value)'") - } + // Where supported, the extended attributes of a file can also be accessed, read, and modified: + for attribute in try await directory.attributeNames() { + let value = try await directory.valueForAttribute(attribute) + print("Extended attribute '\(attribute)' has value '\(value)'") + } - // Once this closure returns the file system will close the directory handle freeing - // any resources required to access it such as file descriptors. Handles can also be opened - // with the 'openFile' and 'openDirectory' APIs but that places the onus you to close the - // handle at an appropriate time to avoid leaking resources. - } - // snippet.end + // Once this closure returns the file system will close the directory handle freeing + // any resources required to access it such as file descriptors. Handles can also be opened + // with the 'openFile' and 'openDirectory' APIs but that places the onus you to close the + // handle at an appropriate time to avoid leaking resources. + } + // snippet.end } diff --git a/Sources/NIOAsyncAwaitDemo/AsyncChannelIO.swift b/Sources/NIOAsyncAwaitDemo/AsyncChannelIO.swift index c4e6ca69d7..e5612e1f69 100644 --- a/Sources/NIOAsyncAwaitDemo/AsyncChannelIO.swift +++ b/Sources/NIOAsyncAwaitDemo/AsyncChannelIO.swift @@ -24,7 +24,8 @@ struct AsyncChannelIO { } func start() async throws -> AsyncChannelIO { - try await channel.pipeline.addHandler(RequestResponseHandler()).get() + try await channel.pipeline.addHandler(RequestResponseHandler()) + .get() return self } diff --git a/Sources/NIOAsyncAwaitDemo/FullRequestResponse.swift b/Sources/NIOAsyncAwaitDemo/FullRequestResponse.swift index 944e132b49..f0d03ac61d 100644 --- a/Sources/NIOAsyncAwaitDemo/FullRequestResponse.swift +++ b/Sources/NIOAsyncAwaitDemo/FullRequestResponse.swift @@ -63,7 +63,6 @@ public final class RequestResponseHandler: ChannelDuplexHandl private var state: State = .operational private var promiseBuffer: CircularBuffer> - /// Create a new `RequestResponseHandler`. /// /// - parameters: @@ -83,7 +82,7 @@ public final class RequestResponseHandler: ChannelDuplexHandl case .operational: let promiseBuffer = self.promiseBuffer self.promiseBuffer.removeAll() - promiseBuffer.forEach { promise in + for promise in promiseBuffer { promise.fail(ChannelError.eof) } } @@ -112,8 +111,8 @@ public final class RequestResponseHandler: ChannelDuplexHandl let promiseBuffer = self.promiseBuffer self.promiseBuffer.removeAll() context.close(promise: nil) - promiseBuffer.forEach { - $0.fail(error) + for promise in promiseBuffer { + promise.fail(error) } } diff --git a/Sources/NIOAsyncAwaitDemo/main.swift b/Sources/NIOAsyncAwaitDemo/main.swift index a93dd98305..ce7bcd126e 100644 --- a/Sources/NIOAsyncAwaitDemo/main.swift +++ b/Sources/NIOAsyncAwaitDemo/main.swift @@ -11,18 +11,25 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + +import Dispatch import NIOCore -import NIOPosix import NIOHTTP1 -import Dispatch +import NIOPosix @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -func makeHTTPChannel(host: String, port: Int, group: EventLoopGroup) async throws -> AsyncChannelIO { +func makeHTTPChannel( + host: String, + port: Int, + group: EventLoopGroup +) async throws -> AsyncChannelIO { let channel = try await ClientBootstrap(group: group) .channelInitializer { channel in channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHTTPClientHandlers() - try channel.pipeline.syncOperations.addHandler(NIOHTTPClientResponseAggregator(maxContentLength: 1_000_000)) + try channel.pipeline.syncOperations.addHandler( + NIOHTTPClientResponseAggregator(maxContentLength: 1_000_000) + ) try channel.pipeline.syncOperations.addHandler(MakeFullRequestHandler()) } } @@ -39,17 +46,25 @@ func main() async { print("OK, connected to \(channel)") print("Sending request 1", terminator: "") - let response1 = try await channel.sendRequest(HTTPRequestHead(version: .http1_1, - method: .GET, - uri: "/base64/SGVsbG8gV29ybGQsIGZyb20gSFRUUEJpbiEgCg==", - headers: ["host": "httpbin.org"])) + let response1 = try await channel.sendRequest( + HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/base64/SGVsbG8gV29ybGQsIGZyb20gSFRUUEJpbiEgCg==", + headers: ["host": "httpbin.org"] + ) + ) print(", response:", String(buffer: response1.body ?? ByteBuffer())) print("Sending request 2", terminator: "") - let response2 = try await channel.sendRequest(HTTPRequestHead(version: .http1_1, - method: .GET, - uri: "/get", - headers: ["host": "httpbin.org"])) + let response2 = try await channel.sendRequest( + HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/get", + headers: ["host": "httpbin.org"] + ) + ) print(", response:", String(buffer: response2.body ?? ByteBuffer())) try await channel.close() diff --git a/Sources/NIOChatClient/main.swift b/Sources/NIOChatClient/main.swift index 1251b11fa9..399d671088 100644 --- a/Sources/NIOChatClient/main.swift +++ b/Sources/NIOChatClient/main.swift @@ -20,7 +20,7 @@ private final class ChatHandler: ChannelInboundHandler { private func printByte(_ byte: UInt8) { #if os(Android) - print(Character(UnicodeScalar(byte)), terminator:"") + print(Character(UnicodeScalar(byte)), terminator: "") #else fputc(Int32(byte), stdout) #endif @@ -68,14 +68,14 @@ enum ConnectTo { let connectTarget: ConnectTo switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { -case (.some(let h), _ , .some(let p)): - /* we got two arguments, let's interpret that as host and port */ +case (.some(let h), _, .some(let p)): + // we got two arguments, let's interpret that as host and port connectTarget = .ip(host: h, port: p) case (.some(let portString), .none, _): - /* couldn't parse as number, expecting unix domain socket path */ + // couldn't parse as number, expecting unix domain socket path connectTarget = .unixDomainSocket(path: portString) case (_, .some(let p), _): - /* only one argument --> port */ + // only one argument --> port connectTarget = .ip(host: defaultHost, port: p) default: connectTarget = .ip(host: defaultHost, port: defaultPort) diff --git a/Sources/NIOChatServer/main.swift b/Sources/NIOChatServer/main.swift index e45af00e7d..df9f6e9655 100644 --- a/Sources/NIOChatServer/main.swift +++ b/Sources/NIOChatServer/main.swift @@ -11,9 +11,10 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + +import Dispatch import NIOCore import NIOPosix -import Dispatch private let newLine = "\n".utf8.first! @@ -51,28 +52,36 @@ final class ChatHandler: ChannelInboundHandler { // All access to channels is guarded by channelsSyncQueue. private let channelsSyncQueue = DispatchQueue(label: "channelsQueue") private var channels: [ObjectIdentifier: Channel] = [:] - + public func channelActive(context: ChannelHandlerContext) { let remoteAddress = context.remoteAddress! let channel = context.channel self.channelsSyncQueue.async { // broadcast the message to all the connected clients except the one that just became active. - self.writeToAll(channels: self.channels, allocator: channel.allocator, message: "(ChatServer) - New client connected with address: \(remoteAddress)\n") - + self.writeToAll( + channels: self.channels, + allocator: channel.allocator, + message: "(ChatServer) - New client connected with address: \(remoteAddress)\n" + ) + self.channels[ObjectIdentifier(channel)] = channel } - + var buffer = channel.allocator.buffer(capacity: 64) buffer.writeString("(ChatServer) - Welcome to: \(context.localAddress!)\n") context.writeAndFlush(Self.wrapOutboundOut(buffer), promise: nil) } - + public func channelInactive(context: ChannelHandlerContext) { let channel = context.channel self.channelsSyncQueue.async { if self.channels.removeValue(forKey: ObjectIdentifier(channel)) != nil { // Broadcast the message to all the connected clients except the one that just was disconnected. - self.writeToAll(channels: self.channels, allocator: channel.allocator, message: "(ChatServer) - Client disconnected\n") + self.writeToAll( + channels: self.channels, + allocator: channel.allocator, + message: "(ChatServer) - Client disconnected\n" + ) } } } @@ -100,12 +109,14 @@ final class ChatHandler: ChannelInboundHandler { } private func writeToAll(channels: [ObjectIdentifier: Channel], allocator: ByteBufferAllocator, message: String) { - let buffer = allocator.buffer(string: message) + let buffer = allocator.buffer(string: message) self.writeToAll(channels: channels, buffer: buffer) } private func writeToAll(channels: [ObjectIdentifier: Channel], buffer: ByteBuffer) { - channels.forEach { $0.value.writeAndFlush(buffer, promise: nil) } + for channel in channels { + channel.value.writeAndFlush(buffer, promise: nil) + } } } @@ -154,8 +165,8 @@ enum BindTo { let bindTarget: BindTo switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { -case (.some(let h), _ , .some(let p)): - /* we got two arguments, let's interpret that as host and port */ +case (.some(let h), _, .some(let p)): + // we got two arguments, let's interpret that as host and port bindTarget = .ip(host: h, port: p) case (let portString?, .none, _): @@ -180,7 +191,9 @@ let channel = try { () -> Channel in }() guard let localAddress = channel.localAddress else { - fatalError("Address was unable to bind. Please check that the socket was not closed or that the address family was understood.") + fatalError( + "Address was unable to bind. Please check that the socket was not closed or that the address family was understood." + ) } print("Server started and listening on \(localAddress)") diff --git a/Sources/NIOConcurrencyHelpers/NIOAtomic.swift b/Sources/NIOConcurrencyHelpers/NIOAtomic.swift index 345b5931b2..c718b5105e 100644 --- a/Sources/NIOConcurrencyHelpers/NIOAtomic.swift +++ b/Sources/NIOConcurrencyHelpers/NIOAtomic.swift @@ -33,145 +33,153 @@ public protocol NIOAtomicPrimitive { extension Bool: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic__Bool public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic__Bool_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic__Bool_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic__Bool_add - public static let nio_atomic_sub = catmc_nio_atomic__Bool_sub - public static let nio_atomic_exchange = catmc_nio_atomic__Bool_exchange - public static let nio_atomic_load = catmc_nio_atomic__Bool_load - public static let nio_atomic_store = catmc_nio_atomic__Bool_store + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic__Bool_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic__Bool_add + public static let nio_atomic_sub = catmc_nio_atomic__Bool_sub + public static let nio_atomic_exchange = catmc_nio_atomic__Bool_exchange + public static let nio_atomic_load = catmc_nio_atomic__Bool_load + public static let nio_atomic_store = catmc_nio_atomic__Bool_store } extension Int8: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_int_least8_t - public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_int_least8_t_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_int_least8_t_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_int_least8_t_add - public static let nio_atomic_sub = catmc_nio_atomic_int_least8_t_sub - public static let nio_atomic_exchange = catmc_nio_atomic_int_least8_t_exchange - public static let nio_atomic_load = catmc_nio_atomic_int_least8_t_load - public static let nio_atomic_store = catmc_nio_atomic_int_least8_t_store + public static let nio_atomic_create_with_existing_storage = + catmc_nio_atomic_int_least8_t_create_with_existing_storage + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_int_least8_t_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_int_least8_t_add + public static let nio_atomic_sub = catmc_nio_atomic_int_least8_t_sub + public static let nio_atomic_exchange = catmc_nio_atomic_int_least8_t_exchange + public static let nio_atomic_load = catmc_nio_atomic_int_least8_t_load + public static let nio_atomic_store = catmc_nio_atomic_int_least8_t_store } extension UInt8: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_uint_least8_t - public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_uint_least8_t_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_uint_least8_t_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_uint_least8_t_add - public static let nio_atomic_sub = catmc_nio_atomic_uint_least8_t_sub - public static let nio_atomic_exchange = catmc_nio_atomic_uint_least8_t_exchange - public static let nio_atomic_load = catmc_nio_atomic_uint_least8_t_load - public static let nio_atomic_store = catmc_nio_atomic_uint_least8_t_store + public static let nio_atomic_create_with_existing_storage = + catmc_nio_atomic_uint_least8_t_create_with_existing_storage + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_uint_least8_t_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_uint_least8_t_add + public static let nio_atomic_sub = catmc_nio_atomic_uint_least8_t_sub + public static let nio_atomic_exchange = catmc_nio_atomic_uint_least8_t_exchange + public static let nio_atomic_load = catmc_nio_atomic_uint_least8_t_load + public static let nio_atomic_store = catmc_nio_atomic_uint_least8_t_store } extension Int16: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_int_least16_t - public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_int_least16_t_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_int_least16_t_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_int_least16_t_add - public static let nio_atomic_sub = catmc_nio_atomic_int_least16_t_sub - public static let nio_atomic_exchange = catmc_nio_atomic_int_least16_t_exchange - public static let nio_atomic_load = catmc_nio_atomic_int_least16_t_load - public static let nio_atomic_store = catmc_nio_atomic_int_least16_t_store + public static let nio_atomic_create_with_existing_storage = + catmc_nio_atomic_int_least16_t_create_with_existing_storage + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_int_least16_t_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_int_least16_t_add + public static let nio_atomic_sub = catmc_nio_atomic_int_least16_t_sub + public static let nio_atomic_exchange = catmc_nio_atomic_int_least16_t_exchange + public static let nio_atomic_load = catmc_nio_atomic_int_least16_t_load + public static let nio_atomic_store = catmc_nio_atomic_int_least16_t_store } extension UInt16: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_uint_least16_t - public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_uint_least16_t_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_uint_least16_t_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_uint_least16_t_add - public static let nio_atomic_sub = catmc_nio_atomic_uint_least16_t_sub - public static let nio_atomic_exchange = catmc_nio_atomic_uint_least16_t_exchange - public static let nio_atomic_load = catmc_nio_atomic_uint_least16_t_load - public static let nio_atomic_store = catmc_nio_atomic_uint_least16_t_store + public static let nio_atomic_create_with_existing_storage = + catmc_nio_atomic_uint_least16_t_create_with_existing_storage + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_uint_least16_t_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_uint_least16_t_add + public static let nio_atomic_sub = catmc_nio_atomic_uint_least16_t_sub + public static let nio_atomic_exchange = catmc_nio_atomic_uint_least16_t_exchange + public static let nio_atomic_load = catmc_nio_atomic_uint_least16_t_load + public static let nio_atomic_store = catmc_nio_atomic_uint_least16_t_store } extension Int32: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_int_least32_t - public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_int_least32_t_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_int_least32_t_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_int_least32_t_add - public static let nio_atomic_sub = catmc_nio_atomic_int_least32_t_sub - public static let nio_atomic_exchange = catmc_nio_atomic_int_least32_t_exchange - public static let nio_atomic_load = catmc_nio_atomic_int_least32_t_load - public static let nio_atomic_store = catmc_nio_atomic_int_least32_t_store + public static let nio_atomic_create_with_existing_storage = + catmc_nio_atomic_int_least32_t_create_with_existing_storage + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_int_least32_t_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_int_least32_t_add + public static let nio_atomic_sub = catmc_nio_atomic_int_least32_t_sub + public static let nio_atomic_exchange = catmc_nio_atomic_int_least32_t_exchange + public static let nio_atomic_load = catmc_nio_atomic_int_least32_t_load + public static let nio_atomic_store = catmc_nio_atomic_int_least32_t_store } extension UInt32: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_uint_least32_t - public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_uint_least32_t_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_uint_least32_t_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_uint_least32_t_add - public static let nio_atomic_sub = catmc_nio_atomic_uint_least32_t_sub - public static let nio_atomic_exchange = catmc_nio_atomic_uint_least32_t_exchange - public static let nio_atomic_load = catmc_nio_atomic_uint_least32_t_load - public static let nio_atomic_store = catmc_nio_atomic_uint_least32_t_store + public static let nio_atomic_create_with_existing_storage = + catmc_nio_atomic_uint_least32_t_create_with_existing_storage + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_uint_least32_t_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_uint_least32_t_add + public static let nio_atomic_sub = catmc_nio_atomic_uint_least32_t_sub + public static let nio_atomic_exchange = catmc_nio_atomic_uint_least32_t_exchange + public static let nio_atomic_load = catmc_nio_atomic_uint_least32_t_load + public static let nio_atomic_store = catmc_nio_atomic_uint_least32_t_store } extension Int64: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_long_long public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_long_long_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_long_long_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_long_long_add - public static let nio_atomic_sub = catmc_nio_atomic_long_long_sub - public static let nio_atomic_exchange = catmc_nio_atomic_long_long_exchange - public static let nio_atomic_load = catmc_nio_atomic_long_long_load - public static let nio_atomic_store = catmc_nio_atomic_long_long_store + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_long_long_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_long_long_add + public static let nio_atomic_sub = catmc_nio_atomic_long_long_sub + public static let nio_atomic_exchange = catmc_nio_atomic_long_long_exchange + public static let nio_atomic_load = catmc_nio_atomic_long_long_load + public static let nio_atomic_store = catmc_nio_atomic_long_long_store } extension UInt64: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_unsigned_long_long - public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_unsigned_long_long_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_unsigned_long_long_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_unsigned_long_long_add - public static let nio_atomic_sub = catmc_nio_atomic_unsigned_long_long_sub - public static let nio_atomic_exchange = catmc_nio_atomic_unsigned_long_long_exchange - public static let nio_atomic_load = catmc_nio_atomic_unsigned_long_long_load - public static let nio_atomic_store = catmc_nio_atomic_unsigned_long_long_store + public static let nio_atomic_create_with_existing_storage = + catmc_nio_atomic_unsigned_long_long_create_with_existing_storage + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_unsigned_long_long_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_unsigned_long_long_add + public static let nio_atomic_sub = catmc_nio_atomic_unsigned_long_long_sub + public static let nio_atomic_exchange = catmc_nio_atomic_unsigned_long_long_exchange + public static let nio_atomic_load = catmc_nio_atomic_unsigned_long_long_load + public static let nio_atomic_store = catmc_nio_atomic_unsigned_long_long_store } #if os(Windows) extension Int: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_intptr_t public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_intptr_t_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_intptr_t_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_intptr_t_add - public static let nio_atomic_sub = catmc_nio_atomic_intptr_t_sub - public static let nio_atomic_exchange = catmc_nio_atomic_intptr_t_exchange - public static let nio_atomic_load = catmc_nio_atomic_intptr_t_load - public static let nio_atomic_store = catmc_nio_atomic_intptr_t_store + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_intptr_t_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_intptr_t_add + public static let nio_atomic_sub = catmc_nio_atomic_intptr_t_sub + public static let nio_atomic_exchange = catmc_nio_atomic_intptr_t_exchange + public static let nio_atomic_load = catmc_nio_atomic_intptr_t_load + public static let nio_atomic_store = catmc_nio_atomic_intptr_t_store } extension UInt: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_uintptr_t public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_uintptr_t_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_uintptr_t_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_uintptr_t_add - public static let nio_atomic_sub = catmc_nio_atomic_uintptr_t_sub - public static let nio_atomic_exchange = catmc_nio_atomic_uintptr_t_exchange - public static let nio_atomic_load = catmc_nio_atomic_uintptr_t_load - public static let nio_atomic_store = catmc_nio_atomic_uintptr_t_store + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_uintptr_t_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_uintptr_t_add + public static let nio_atomic_sub = catmc_nio_atomic_uintptr_t_sub + public static let nio_atomic_exchange = catmc_nio_atomic_uintptr_t_exchange + public static let nio_atomic_load = catmc_nio_atomic_uintptr_t_load + public static let nio_atomic_store = catmc_nio_atomic_uintptr_t_store } #else extension Int: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_long public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_long_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_long_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_long_add - public static let nio_atomic_sub = catmc_nio_atomic_long_sub - public static let nio_atomic_exchange = catmc_nio_atomic_long_exchange - public static let nio_atomic_load = catmc_nio_atomic_long_load - public static let nio_atomic_store = catmc_nio_atomic_long_store + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_long_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_long_add + public static let nio_atomic_sub = catmc_nio_atomic_long_sub + public static let nio_atomic_exchange = catmc_nio_atomic_long_exchange + public static let nio_atomic_load = catmc_nio_atomic_long_load + public static let nio_atomic_store = catmc_nio_atomic_long_store } extension UInt: NIOAtomicPrimitive { public typealias AtomicWrapper = catmc_nio_atomic_unsigned_long - public static let nio_atomic_create_with_existing_storage = catmc_nio_atomic_unsigned_long_create_with_existing_storage - public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_unsigned_long_compare_and_exchange - public static let nio_atomic_add = catmc_nio_atomic_unsigned_long_add - public static let nio_atomic_sub = catmc_nio_atomic_unsigned_long_sub - public static let nio_atomic_exchange = catmc_nio_atomic_unsigned_long_exchange - public static let nio_atomic_load = catmc_nio_atomic_unsigned_long_load - public static let nio_atomic_store = catmc_nio_atomic_unsigned_long_store + public static let nio_atomic_create_with_existing_storage = + catmc_nio_atomic_unsigned_long_create_with_existing_storage + public static let nio_atomic_compare_and_exchange = catmc_nio_atomic_unsigned_long_compare_and_exchange + public static let nio_atomic_add = catmc_nio_atomic_unsigned_long_add + public static let nio_atomic_sub = catmc_nio_atomic_unsigned_long_sub + public static let nio_atomic_exchange = catmc_nio_atomic_unsigned_long_exchange + public static let nio_atomic_load = catmc_nio_atomic_unsigned_long_load + public static let nio_atomic_store = catmc_nio_atomic_unsigned_long_store } #endif @@ -193,7 +201,7 @@ extension UInt: NIOAtomicPrimitive { /// By necessity, all atomic values are references: after all, it makes no /// sense to talk about managing an atomic value when each time it's modified /// the thread that modified it gets a local copy! -@available(*, deprecated, message:"please use ManagedAtomic from https://github.com/apple/swift-atomics instead") +@available(*, deprecated, message: "please use ManagedAtomic from https://github.com/apple/swift-atomics instead") public final class NIOAtomic { @usableFromInline typealias Manager = ManagedBufferPointer @@ -225,8 +233,8 @@ public final class NIOAtomic { /// match the current value and so no exchange occurred. @inlinable public func compareAndExchange(expected: T, desired: T) -> Bool { - return Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { - return T.nio_atomic_compare_and_exchange($0, expected, desired) + Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { + T.nio_atomic_compare_and_exchange($0, expected, desired) } } @@ -241,8 +249,8 @@ public final class NIOAtomic { @inlinable @discardableResult public func add(_ rhs: T) -> T { - return Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { - return T.nio_atomic_add($0, rhs) + Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { + T.nio_atomic_add($0, rhs) } } @@ -257,8 +265,8 @@ public final class NIOAtomic { @inlinable @discardableResult public func sub(_ rhs: T) -> T { - return Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { - return T.nio_atomic_sub($0, rhs) + Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { + T.nio_atomic_sub($0, rhs) } } @@ -272,8 +280,8 @@ public final class NIOAtomic { /// - Returns: The value previously held by this object. @inlinable public func exchange(with value: T) -> T { - return Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { - return T.nio_atomic_exchange($0, value) + Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { + T.nio_atomic_exchange($0, value) } } @@ -286,8 +294,8 @@ public final class NIOAtomic { /// - Returns: The value of this object @inlinable public func load() -> T { - return Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { - return T.nio_atomic_load($0) + Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { + T.nio_atomic_load($0) } } @@ -299,9 +307,9 @@ public final class NIOAtomic { /// /// - Parameter value: The new value to set the object to. @inlinable - public func store(_ value: T) -> Void { - return Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { - return T.nio_atomic_store($0, value) + public func store(_ value: T) { + Manager(unsafeBufferObject: self).withUnsafeMutablePointerToElements { + T.nio_atomic_store($0, value) } } diff --git a/Sources/NIOConcurrencyHelpers/NIOLock.swift b/Sources/NIOConcurrencyHelpers/NIOLock.swift index db1a8f9811..0c3d06319d 100644 --- a/Sources/NIOConcurrencyHelpers/NIOLock.swift +++ b/Sources/NIOConcurrencyHelpers/NIOLock.swift @@ -34,61 +34,61 @@ typealias LockPrimitive = pthread_mutex_t #endif @usableFromInline -enum LockOperations { } +enum LockOperations {} extension LockOperations { @inlinable static func create(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) InitializeSRWLock(mutex) -#else + #else var attr = pthread_mutexattr_t() pthread_mutexattr_init(&attr) debugOnly { pthread_mutexattr_settype(&attr, .init(PTHREAD_MUTEX_ERRORCHECK)) } - + let err = pthread_mutex_init(mutex, &attr) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } - + @inlinable static func destroy(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) // SRWLOCK does not need to be free'd -#else + #else let err = pthread_mutex_destroy(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } - + @inlinable static func lock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) AcquireSRWLockExclusive(mutex) -#else + #else let err = pthread_mutex_lock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } - + @inlinable static func unlock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) ReleaseSRWLockExclusive(mutex) -#else + #else let err = pthread_mutex_unlock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } } @@ -122,49 +122,49 @@ extension LockOperations { // See also: https://github.com/apple/swift/pull/40000 @usableFromInline final class LockStorage: ManagedBuffer { - + @inlinable static func create(value: Value) -> Self { let buffer = Self.create(minimumCapacity: 1) { _ in - return value + value } let storage = unsafeDowncast(buffer, to: Self.self) - + storage.withUnsafeMutablePointers { _, lockPtr in LockOperations.create(lockPtr) } - + return storage } - + @inlinable func lock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.lock(lockPtr) } } - + @inlinable func unlock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.unlock(lockPtr) } } - + @inlinable deinit { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.destroy(lockPtr) } } - + @inlinable func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { try self.withUnsafeMutablePointerToElements { lockPtr in - return try body(lockPtr) + try body(lockPtr) } } - + @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { try self.withUnsafeMutablePointers { valuePtr, lockPtr in @@ -175,7 +175,7 @@ final class LockStorage: ManagedBuffer { } } -extension LockStorage: @unchecked Sendable { } +extension LockStorage: @unchecked Sendable {} /// A threading lock based on `libpthread` instead of `libdispatch`. /// @@ -188,7 +188,7 @@ extension LockStorage: @unchecked Sendable { } public struct NIOLock { @usableFromInline internal let _storage: LockStorage - + /// Create a new lock. @inlinable public init() { @@ -215,7 +215,7 @@ public struct NIOLock { @inlinable internal func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { - return try self._storage.withLockPrimitive(body) + try self._storage.withLockPrimitive(body) } } @@ -238,7 +238,7 @@ extension NIOLock { } @inlinable - public func withLockVoid(_ body: () throws -> Void) rethrows -> Void { + public func withLockVoid(_ body: () throws -> Void) rethrows { try self.withLock(body) } } diff --git a/Sources/NIOConcurrencyHelpers/NIOLockedValueBox.swift b/Sources/NIOConcurrencyHelpers/NIOLockedValueBox.swift index 06cf88529f..dc5916b6dd 100644 --- a/Sources/NIOConcurrencyHelpers/NIOLockedValueBox.swift +++ b/Sources/NIOConcurrencyHelpers/NIOLockedValueBox.swift @@ -22,7 +22,7 @@ /// acquire/release the lock in the correct place. ``NIOLockedValueBox`` makes /// that much easier. public struct NIOLockedValueBox { - + @usableFromInline internal let _storage: LockStorage @@ -35,7 +35,7 @@ public struct NIOLockedValueBox { /// Access the `Value`, allowing mutation of it. @inlinable public func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { - return try self._storage.withLockedValue(mutate) + try self._storage.withLockedValue(mutate) } /// Provides an unsafe view over the lock and its value. @@ -72,7 +72,7 @@ public struct NIOLockedValueBox { public func withValueAssumingLockIsAcquired( _ mutate: (_ value: inout Value) throws -> Result ) rethrows -> Result { - return try self._storage.withUnsafeMutablePointerToHeader { value in + try self._storage.withUnsafeMutablePointerToHeader { value in try mutate(&value.pointee) } } diff --git a/Sources/NIOConcurrencyHelpers/atomics.swift b/Sources/NIOConcurrencyHelpers/atomics.swift index 034747cff5..1eb276c8c9 100644 --- a/Sources/NIOConcurrencyHelpers/atomics.swift +++ b/Sources/NIOConcurrencyHelpers/atomics.swift @@ -16,14 +16,14 @@ import CNIOAtomics #if canImport(Darwin) import Darwin -fileprivate func sys_sched_yield() { +private func sys_sched_yield() { pthread_yield_np() } #elseif os(Windows) import ucrt import WinSDK -fileprivate func sys_sched_yield() { - Sleep(0) +private func sys_sched_yield() { + Sleep(0) } #else #if canImport(Glibc) @@ -34,7 +34,7 @@ import Musl #error("The concurrency atomics module was unable to identify your C library.") #endif -fileprivate func sys_sched_yield() { +private func sys_sched_yield() { _ = sched_yield() } #endif @@ -86,7 +86,7 @@ public struct UnsafeEmbeddedAtomic { /// match the current value and so no exchange occurred. @inlinable public func compareAndExchange(expected: T, desired: T) -> Bool { - return T.atomic_compare_and_exchange(self.value, expected, desired) + T.atomic_compare_and_exchange(self.value, expected, desired) } /// Atomically adds `rhs` to this object. @@ -100,7 +100,7 @@ public struct UnsafeEmbeddedAtomic { @discardableResult @inlinable public func add(_ rhs: T) -> T { - return T.atomic_add(self.value, rhs) + T.atomic_add(self.value, rhs) } /// Atomically subtracts `rhs` from this object. @@ -114,7 +114,7 @@ public struct UnsafeEmbeddedAtomic { @discardableResult @inlinable public func sub(_ rhs: T) -> T { - return T.atomic_sub(self.value, rhs) + T.atomic_sub(self.value, rhs) } /// Atomically exchanges `value` for the current value of this object. @@ -127,7 +127,7 @@ public struct UnsafeEmbeddedAtomic { /// - Returns: The value previously held by this object. @inlinable public func exchange(with value: T) -> T { - return T.atomic_exchange(self.value, value) + T.atomic_exchange(self.value, value) } /// Atomically loads and returns the value of this object. @@ -139,7 +139,7 @@ public struct UnsafeEmbeddedAtomic { /// - Returns: The value of this object @inlinable public func load() -> T { - return T.atomic_load(self.value) + T.atomic_load(self.value) } /// Atomically replaces the value of this object with `value`. @@ -150,7 +150,7 @@ public struct UnsafeEmbeddedAtomic { /// /// - Parameter value: The new value to set the object to. @inlinable - public func store(_ value: T) -> Void { + public func store(_ value: T) { T.atomic_store(self.value, value) } @@ -181,7 +181,7 @@ public struct UnsafeEmbeddedAtomic { /// By necessity, all atomic values are references: after all, it makes no /// sense to talk about managing an atomic value when each time it's modified /// the thread that modified it gets a local copy! -@available(*, deprecated, message:"please use ManagedAtomic from https://github.com/apple/swift-atomics instead") +@available(*, deprecated, message: "please use ManagedAtomic from https://github.com/apple/swift-atomics instead") public final class Atomic { @usableFromInline internal let embedded: UnsafeEmbeddedAtomic @@ -209,7 +209,7 @@ public final class Atomic { /// match the current value and so no exchange occurred. @inlinable public func compareAndExchange(expected: T, desired: T) -> Bool { - return self.embedded.compareAndExchange(expected: expected, desired: desired) + self.embedded.compareAndExchange(expected: expected, desired: desired) } /// Atomically adds `rhs` to this object. @@ -223,7 +223,7 @@ public final class Atomic { @discardableResult @inlinable public func add(_ rhs: T) -> T { - return self.embedded.add(rhs) + self.embedded.add(rhs) } /// Atomically subtracts `rhs` from this object. @@ -237,7 +237,7 @@ public final class Atomic { @discardableResult @inlinable public func sub(_ rhs: T) -> T { - return self.embedded.sub(rhs) + self.embedded.sub(rhs) } /// Atomically exchanges `value` for the current value of this object. @@ -250,7 +250,7 @@ public final class Atomic { /// - Returns: The value previously held by this object. @inlinable public func exchange(with value: T) -> T { - return self.embedded.exchange(with: value) + self.embedded.exchange(with: value) } /// Atomically loads and returns the value of this object. @@ -262,7 +262,7 @@ public final class Atomic { /// - Returns: The value of this object @inlinable public func load() -> T { - return self.embedded.load() + self.embedded.load() } /// Atomically replaces the value of this object with `value`. @@ -273,7 +273,7 @@ public final class Atomic { /// /// - Parameter value: The new value to set the object to. @inlinable - public func store(_ value: T) -> Void { + public func store(_ value: T) { self.embedded.store(value) } @@ -299,147 +299,147 @@ public protocol AtomicPrimitive { } extension Bool: AtomicPrimitive { - public static let atomic_create = catmc_atomic__Bool_create - public static let atomic_destroy = catmc_atomic__Bool_destroy + public static let atomic_create = catmc_atomic__Bool_create + public static let atomic_destroy = catmc_atomic__Bool_destroy public static let atomic_compare_and_exchange = catmc_atomic__Bool_compare_and_exchange - public static let atomic_add = catmc_atomic__Bool_add - public static let atomic_sub = catmc_atomic__Bool_sub - public static let atomic_exchange = catmc_atomic__Bool_exchange - public static let atomic_load = catmc_atomic__Bool_load - public static let atomic_store = catmc_atomic__Bool_store + public static let atomic_add = catmc_atomic__Bool_add + public static let atomic_sub = catmc_atomic__Bool_sub + public static let atomic_exchange = catmc_atomic__Bool_exchange + public static let atomic_load = catmc_atomic__Bool_load + public static let atomic_store = catmc_atomic__Bool_store } extension Int8: AtomicPrimitive { - public static let atomic_create = catmc_atomic_int_least8_t_create - public static let atomic_destroy = catmc_atomic_int_least8_t_destroy + public static let atomic_create = catmc_atomic_int_least8_t_create + public static let atomic_destroy = catmc_atomic_int_least8_t_destroy public static let atomic_compare_and_exchange = catmc_atomic_int_least8_t_compare_and_exchange - public static let atomic_add = catmc_atomic_int_least8_t_add - public static let atomic_sub = catmc_atomic_int_least8_t_sub - public static let atomic_exchange = catmc_atomic_int_least8_t_exchange - public static let atomic_load = catmc_atomic_int_least8_t_load - public static let atomic_store = catmc_atomic_int_least8_t_store + public static let atomic_add = catmc_atomic_int_least8_t_add + public static let atomic_sub = catmc_atomic_int_least8_t_sub + public static let atomic_exchange = catmc_atomic_int_least8_t_exchange + public static let atomic_load = catmc_atomic_int_least8_t_load + public static let atomic_store = catmc_atomic_int_least8_t_store } extension UInt8: AtomicPrimitive { - public static let atomic_create = catmc_atomic_uint_least8_t_create - public static let atomic_destroy = catmc_atomic_uint_least8_t_destroy + public static let atomic_create = catmc_atomic_uint_least8_t_create + public static let atomic_destroy = catmc_atomic_uint_least8_t_destroy public static let atomic_compare_and_exchange = catmc_atomic_uint_least8_t_compare_and_exchange - public static let atomic_add = catmc_atomic_uint_least8_t_add - public static let atomic_sub = catmc_atomic_uint_least8_t_sub - public static let atomic_exchange = catmc_atomic_uint_least8_t_exchange - public static let atomic_load = catmc_atomic_uint_least8_t_load - public static let atomic_store = catmc_atomic_uint_least8_t_store + public static let atomic_add = catmc_atomic_uint_least8_t_add + public static let atomic_sub = catmc_atomic_uint_least8_t_sub + public static let atomic_exchange = catmc_atomic_uint_least8_t_exchange + public static let atomic_load = catmc_atomic_uint_least8_t_load + public static let atomic_store = catmc_atomic_uint_least8_t_store } extension Int16: AtomicPrimitive { - public static let atomic_create = catmc_atomic_int_least16_t_create - public static let atomic_destroy = catmc_atomic_int_least16_t_destroy + public static let atomic_create = catmc_atomic_int_least16_t_create + public static let atomic_destroy = catmc_atomic_int_least16_t_destroy public static let atomic_compare_and_exchange = catmc_atomic_int_least16_t_compare_and_exchange - public static let atomic_add = catmc_atomic_int_least16_t_add - public static let atomic_sub = catmc_atomic_int_least16_t_sub - public static let atomic_exchange = catmc_atomic_int_least16_t_exchange - public static let atomic_load = catmc_atomic_int_least16_t_load - public static let atomic_store = catmc_atomic_int_least16_t_store + public static let atomic_add = catmc_atomic_int_least16_t_add + public static let atomic_sub = catmc_atomic_int_least16_t_sub + public static let atomic_exchange = catmc_atomic_int_least16_t_exchange + public static let atomic_load = catmc_atomic_int_least16_t_load + public static let atomic_store = catmc_atomic_int_least16_t_store } extension UInt16: AtomicPrimitive { - public static let atomic_create = catmc_atomic_uint_least16_t_create - public static let atomic_destroy = catmc_atomic_uint_least16_t_destroy + public static let atomic_create = catmc_atomic_uint_least16_t_create + public static let atomic_destroy = catmc_atomic_uint_least16_t_destroy public static let atomic_compare_and_exchange = catmc_atomic_uint_least16_t_compare_and_exchange - public static let atomic_add = catmc_atomic_uint_least16_t_add - public static let atomic_sub = catmc_atomic_uint_least16_t_sub - public static let atomic_exchange = catmc_atomic_uint_least16_t_exchange - public static let atomic_load = catmc_atomic_uint_least16_t_load - public static let atomic_store = catmc_atomic_uint_least16_t_store + public static let atomic_add = catmc_atomic_uint_least16_t_add + public static let atomic_sub = catmc_atomic_uint_least16_t_sub + public static let atomic_exchange = catmc_atomic_uint_least16_t_exchange + public static let atomic_load = catmc_atomic_uint_least16_t_load + public static let atomic_store = catmc_atomic_uint_least16_t_store } extension Int32: AtomicPrimitive { - public static let atomic_create = catmc_atomic_int_least32_t_create - public static let atomic_destroy = catmc_atomic_int_least32_t_destroy + public static let atomic_create = catmc_atomic_int_least32_t_create + public static let atomic_destroy = catmc_atomic_int_least32_t_destroy public static let atomic_compare_and_exchange = catmc_atomic_int_least32_t_compare_and_exchange - public static let atomic_add = catmc_atomic_int_least32_t_add - public static let atomic_sub = catmc_atomic_int_least32_t_sub - public static let atomic_exchange = catmc_atomic_int_least32_t_exchange - public static let atomic_load = catmc_atomic_int_least32_t_load - public static let atomic_store = catmc_atomic_int_least32_t_store + public static let atomic_add = catmc_atomic_int_least32_t_add + public static let atomic_sub = catmc_atomic_int_least32_t_sub + public static let atomic_exchange = catmc_atomic_int_least32_t_exchange + public static let atomic_load = catmc_atomic_int_least32_t_load + public static let atomic_store = catmc_atomic_int_least32_t_store } extension UInt32: AtomicPrimitive { - public static let atomic_create = catmc_atomic_uint_least32_t_create - public static let atomic_destroy = catmc_atomic_uint_least32_t_destroy + public static let atomic_create = catmc_atomic_uint_least32_t_create + public static let atomic_destroy = catmc_atomic_uint_least32_t_destroy public static let atomic_compare_and_exchange = catmc_atomic_uint_least32_t_compare_and_exchange - public static let atomic_add = catmc_atomic_uint_least32_t_add - public static let atomic_sub = catmc_atomic_uint_least32_t_sub - public static let atomic_exchange = catmc_atomic_uint_least32_t_exchange - public static let atomic_load = catmc_atomic_uint_least32_t_load - public static let atomic_store = catmc_atomic_uint_least32_t_store + public static let atomic_add = catmc_atomic_uint_least32_t_add + public static let atomic_sub = catmc_atomic_uint_least32_t_sub + public static let atomic_exchange = catmc_atomic_uint_least32_t_exchange + public static let atomic_load = catmc_atomic_uint_least32_t_load + public static let atomic_store = catmc_atomic_uint_least32_t_store } extension Int64: AtomicPrimitive { - public static let atomic_create = catmc_atomic_long_long_create - public static let atomic_destroy = catmc_atomic_long_long_destroy + public static let atomic_create = catmc_atomic_long_long_create + public static let atomic_destroy = catmc_atomic_long_long_destroy public static let atomic_compare_and_exchange = catmc_atomic_long_long_compare_and_exchange - public static let atomic_add = catmc_atomic_long_long_add - public static let atomic_sub = catmc_atomic_long_long_sub - public static let atomic_exchange = catmc_atomic_long_long_exchange - public static let atomic_load = catmc_atomic_long_long_load - public static let atomic_store = catmc_atomic_long_long_store + public static let atomic_add = catmc_atomic_long_long_add + public static let atomic_sub = catmc_atomic_long_long_sub + public static let atomic_exchange = catmc_atomic_long_long_exchange + public static let atomic_load = catmc_atomic_long_long_load + public static let atomic_store = catmc_atomic_long_long_store } extension UInt64: AtomicPrimitive { - public static let atomic_create = catmc_atomic_unsigned_long_long_create - public static let atomic_destroy = catmc_atomic_unsigned_long_long_destroy + public static let atomic_create = catmc_atomic_unsigned_long_long_create + public static let atomic_destroy = catmc_atomic_unsigned_long_long_destroy public static let atomic_compare_and_exchange = catmc_atomic_unsigned_long_long_compare_and_exchange - public static let atomic_add = catmc_atomic_unsigned_long_long_add - public static let atomic_sub = catmc_atomic_unsigned_long_long_sub - public static let atomic_exchange = catmc_atomic_unsigned_long_long_exchange - public static let atomic_load = catmc_atomic_unsigned_long_long_load - public static let atomic_store = catmc_atomic_unsigned_long_long_store + public static let atomic_add = catmc_atomic_unsigned_long_long_add + public static let atomic_sub = catmc_atomic_unsigned_long_long_sub + public static let atomic_exchange = catmc_atomic_unsigned_long_long_exchange + public static let atomic_load = catmc_atomic_unsigned_long_long_load + public static let atomic_store = catmc_atomic_unsigned_long_long_store } #if os(Windows) extension Int: AtomicPrimitive { - public static let atomic_create = catmc_atomic_intptr_t_create - public static let atomic_destroy = catmc_atomic_intptr_t_destroy + public static let atomic_create = catmc_atomic_intptr_t_create + public static let atomic_destroy = catmc_atomic_intptr_t_destroy public static let atomic_compare_and_exchange = catmc_atomic_intptr_t_compare_and_exchange - public static let atomic_add = catmc_atomic_intptr_t_add - public static let atomic_sub = catmc_atomic_intptr_t_sub - public static let atomic_exchange = catmc_atomic_intptr_t_exchange - public static let atomic_load = catmc_atomic_intptr_t_load - public static let atomic_store = catmc_atomic_intptr_t_store + public static let atomic_add = catmc_atomic_intptr_t_add + public static let atomic_sub = catmc_atomic_intptr_t_sub + public static let atomic_exchange = catmc_atomic_intptr_t_exchange + public static let atomic_load = catmc_atomic_intptr_t_load + public static let atomic_store = catmc_atomic_intptr_t_store } extension UInt: AtomicPrimitive { - public static let atomic_create = catmc_atomic_uintptr_t_create - public static let atomic_destroy = catmc_atomic_uintptr_t_destroy + public static let atomic_create = catmc_atomic_uintptr_t_create + public static let atomic_destroy = catmc_atomic_uintptr_t_destroy public static let atomic_compare_and_exchange = catmc_atomic_uintptr_t_compare_and_exchange - public static let atomic_add = catmc_atomic_uintptr_t_add - public static let atomic_sub = catmc_atomic_uintptr_t_sub - public static let atomic_exchange = catmc_atomic_uintptr_t_exchange - public static let atomic_load = catmc_atomic_uintptr_t_load - public static let atomic_store = catmc_atomic_uintptr_t_store + public static let atomic_add = catmc_atomic_uintptr_t_add + public static let atomic_sub = catmc_atomic_uintptr_t_sub + public static let atomic_exchange = catmc_atomic_uintptr_t_exchange + public static let atomic_load = catmc_atomic_uintptr_t_load + public static let atomic_store = catmc_atomic_uintptr_t_store } #else extension Int: AtomicPrimitive { - public static let atomic_create = catmc_atomic_long_create - public static let atomic_destroy = catmc_atomic_long_destroy + public static let atomic_create = catmc_atomic_long_create + public static let atomic_destroy = catmc_atomic_long_destroy public static let atomic_compare_and_exchange = catmc_atomic_long_compare_and_exchange - public static let atomic_add = catmc_atomic_long_add - public static let atomic_sub = catmc_atomic_long_sub - public static let atomic_exchange = catmc_atomic_long_exchange - public static let atomic_load = catmc_atomic_long_load - public static let atomic_store = catmc_atomic_long_store + public static let atomic_add = catmc_atomic_long_add + public static let atomic_sub = catmc_atomic_long_sub + public static let atomic_exchange = catmc_atomic_long_exchange + public static let atomic_load = catmc_atomic_long_load + public static let atomic_store = catmc_atomic_long_store } extension UInt: AtomicPrimitive { - public static let atomic_create = catmc_atomic_unsigned_long_create - public static let atomic_destroy = catmc_atomic_unsigned_long_destroy + public static let atomic_create = catmc_atomic_unsigned_long_create + public static let atomic_destroy = catmc_atomic_unsigned_long_destroy public static let atomic_compare_and_exchange = catmc_atomic_unsigned_long_compare_and_exchange - public static let atomic_add = catmc_atomic_unsigned_long_add - public static let atomic_sub = catmc_atomic_unsigned_long_sub - public static let atomic_exchange = catmc_atomic_unsigned_long_exchange - public static let atomic_load = catmc_atomic_unsigned_long_load - public static let atomic_store = catmc_atomic_unsigned_long_store + public static let atomic_add = catmc_atomic_unsigned_long_add + public static let atomic_sub = catmc_atomic_unsigned_long_sub + public static let atomic_exchange = catmc_atomic_unsigned_long_exchange + public static let atomic_load = catmc_atomic_unsigned_long_load + public static let atomic_store = catmc_atomic_unsigned_long_store } #endif @@ -447,7 +447,11 @@ extension UInt: AtomicPrimitive { /// /// - warning: The use of `AtomicBox` should be avoided because it requires an implementation of a spin-lock /// (more precisely a CAS loop) to operate correctly. -@available(*, deprecated, message: "AtomicBox is deprecated without replacement because the original implementation doesn't work.") +@available( + *, + deprecated, + message: "AtomicBox is deprecated without replacement because the original implementation doesn't work." +) public final class AtomicBox { private let storage: NIOAtomic @@ -482,7 +486,7 @@ public final class AtomicBox { /// - Returns: `True` if the exchange occurred, or `False` if `expected` did not /// match the current value and so no exchange occurred. public func compareAndExchange(expected: T, desired: T) -> Bool { - return withExtendedLifetime(desired) { + withExtendedLifetime(desired) { let expectedPtr = Unmanaged.passUnretained(expected) let desiredPtr = Unmanaged.passUnretained(desired) let expectedPtrBits = UInt(bitPattern: expectedPtr.toOpaque()) @@ -573,7 +577,7 @@ public final class AtomicBox { // step 3: Now, let's exchange it back into the store let casWorked = self.storage.compareAndExchange(expected: 0, desired: ptrBits) - precondition(casWorked) // this _has_ to work because `0` means we own it exclusively. + precondition(casWorked) // this _has_ to work because `0` means we own it exclusively. return value } @@ -587,7 +591,7 @@ public final class AtomicBox { /// 100% CPU load. /// /// - Parameter value: The new value to set the object to. - public func store(_ value: T) -> Void { + public func store(_ value: T) { _ = self.exchange(with: value) } } diff --git a/Sources/NIOConcurrencyHelpers/lock.swift b/Sources/NIOConcurrencyHelpers/lock.swift index 5df4af7b15..326603775f 100644 --- a/Sources/NIOConcurrencyHelpers/lock.swift +++ b/Sources/NIOConcurrencyHelpers/lock.swift @@ -33,19 +33,19 @@ import Musl /// `SRWLOCK` type. @available(*, deprecated, renamed: "NIOLock") public final class Lock { -#if os(Windows) + #if os(Windows) fileprivate let mutex: UnsafeMutablePointer = UnsafeMutablePointer.allocate(capacity: 1) -#else + #else fileprivate let mutex: UnsafeMutablePointer = UnsafeMutablePointer.allocate(capacity: 1) -#endif + #endif /// Create a new lock. public init() { -#if os(Windows) + #if os(Windows) InitializeSRWLock(self.mutex) -#else + #else var attr = pthread_mutexattr_t() pthread_mutexattr_init(&attr) debugOnly { @@ -54,16 +54,16 @@ public final class Lock { let err = pthread_mutex_init(self.mutex, &attr) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } deinit { -#if os(Windows) + #if os(Windows) // SRWLOCK does not need to be free'd -#else + #else let err = pthread_mutex_destroy(self.mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif mutex.deallocate() } @@ -72,12 +72,12 @@ public final class Lock { /// Whenever possible, consider using `withLock` instead of this method and /// `unlock`, to simplify lock handling. public func lock() { -#if os(Windows) + #if os(Windows) AcquireSRWLockExclusive(self.mutex) -#else + #else let err = pthread_mutex_lock(self.mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } /// Release the lock. @@ -85,12 +85,12 @@ public final class Lock { /// Whenever possible, consider using `withLock` instead of this method and /// `lock`, to simplify lock handling. public func unlock() { -#if os(Windows) + #if os(Windows) ReleaseSRWLockExclusive(self.mutex) -#else + #else let err = pthread_mutex_unlock(self.mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } /// Acquire the lock for the duration of the given block. @@ -112,7 +112,7 @@ public final class Lock { // specialise Void return (for performance) @inlinable - public func withLockVoid(_ body: () throws -> Void) rethrows -> Void { + public func withLockVoid(_ body: () throws -> Void) rethrows { try self.withLock(body) } } @@ -124,13 +124,13 @@ public final class Lock { public final class ConditionLock { private var _value: T private let mutex: NIOLock -#if os(Windows) + #if os(Windows) private let cond: UnsafeMutablePointer = UnsafeMutablePointer.allocate(capacity: 1) -#else + #else private let cond: UnsafeMutablePointer = UnsafeMutablePointer.allocate(capacity: 1) -#endif + #endif /// Create the lock, and initialize the state variable to `value`. /// @@ -138,21 +138,21 @@ public final class ConditionLock { public init(value: T) { self._value = value self.mutex = NIOLock() -#if os(Windows) + #if os(Windows) InitializeConditionVariable(self.cond) -#else + #else let err = pthread_cond_init(self.cond, nil) precondition(err == 0, "\(#function) failed in pthread_cond with error \(err)") -#endif + #endif } deinit { -#if os(Windows) + #if os(Windows) // condition variables do not need to be explicitly destroyed -#else + #else let err = pthread_cond_destroy(self.cond) precondition(err == 0, "\(#function) failed in pthread_cond with error \(err)") -#endif + #endif self.cond.deallocate() } @@ -190,13 +190,13 @@ public final class ConditionLock { break } self.mutex.withLockPrimitive { mutex in -#if os(Windows) + #if os(Windows) let result = SleepConditionVariableSRW(self.cond, mutex, INFINITE, 0) precondition(result, "\(#function) failed in SleepConditionVariableSRW with error \(GetLastError())") -#else + #else let err = pthread_cond_wait(self.cond, mutex) precondition(err == 0, "\(#function) failed in pthread_cond with error \(err)") -#endif + #endif } } } @@ -212,7 +212,7 @@ public final class ConditionLock { public func lock(whenValue wantedValue: T, timeoutSeconds: Double) -> Bool { precondition(timeoutSeconds >= 0) -#if os(Windows) + #if os(Windows) var dwMilliseconds: DWORD = DWORD(timeoutSeconds * 1000) self.lock() @@ -222,10 +222,14 @@ public final class ConditionLock { } let dwWaitStart = timeGetTime() - if !SleepConditionVariableSRW(self.cond, self.mutex._storage.mutex, - dwMilliseconds, 0) { + if !SleepConditionVariableSRW( + self.cond, + self.mutex._storage.mutex, + dwMilliseconds, + 0 + ) { let dwError = GetLastError() - if (dwError == ERROR_TIMEOUT) { + if dwError == ERROR_TIMEOUT { self.unlock() return false } @@ -235,18 +239,20 @@ public final class ConditionLock { // NOTE: this may be a spurious wakeup, adjust the timeout accordingly dwMilliseconds = dwMilliseconds - (timeGetTime() - dwWaitStart) } -#else - let nsecPerSec: Int64 = 1000000000 + #else + let nsecPerSec: Int64 = 1_000_000_000 self.lock() - /* the timeout as a (seconds, nano seconds) pair */ + // the timeout as a (seconds, nano seconds) pair let timeoutNS = Int64(timeoutSeconds * Double(nsecPerSec)) var curTime = timeval() gettimeofday(&curTime, nil) let allNSecs: Int64 = timeoutNS + Int64(curTime.tv_usec) * 1000 - var timeoutAbs = timespec(tv_sec: curTime.tv_sec + Int((allNSecs / nsecPerSec)), - tv_nsec: Int(allNSecs % nsecPerSec)) + var timeoutAbs = timespec( + tv_sec: curTime.tv_sec + Int((allNSecs / nsecPerSec)), + tv_nsec: Int(allNSecs % nsecPerSec) + ) assert(timeoutAbs.tv_nsec >= 0 && timeoutAbs.tv_nsec < Int(nsecPerSec)) assert(timeoutAbs.tv_sec >= curTime.tv_sec) return self.mutex.withLockPrimitive { mutex -> Bool in @@ -265,7 +271,7 @@ public final class ConditionLock { } } } -#endif + #endif } /// Release the lock, setting the state variable to `newValue`. @@ -275,12 +281,12 @@ public final class ConditionLock { public func unlock(withValue newValue: T) { self._value = newValue self.unlock() -#if os(Windows) + #if os(Windows) WakeAllConditionVariable(self.cond) -#else + #else let err = pthread_cond_broadcast(self.cond) precondition(err == 0, "\(#function) failed in pthread_cond with error \(err)") -#endif + #endif } } @@ -291,7 +297,12 @@ public final class ConditionLock { /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. @inlinable internal func debugOnly(_ body: () -> Void) { - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } @available(*, deprecated) diff --git a/Sources/NIOCore/AddressedEnvelope.swift b/Sources/NIOCore/AddressedEnvelope.swift index d705009a0e..47b167a766 100644 --- a/Sources/NIOCore/AddressedEnvelope.swift +++ b/Sources/NIOCore/AddressedEnvelope.swift @@ -27,19 +27,19 @@ public struct AddressedEnvelope { self.remoteAddress = remoteAddress self.data = data } - + public init(remoteAddress: SocketAddress, data: DataType, metadata: Metadata?) { self.remoteAddress = remoteAddress self.data = data self.metadata = metadata } - + /// Any metadata associated with an `AddressedEnvelope` public struct Metadata: Hashable, Sendable { /// Details of any congestion state. public var ecnState: NIOExplicitCongestionNotificationState public var packetInfo: NIOPacketInfo? - + public init(ecnState: NIOExplicitCongestionNotificationState) { self.ecnState = ecnState self.packetInfo = nil @@ -54,7 +54,7 @@ public struct AddressedEnvelope { extension AddressedEnvelope: CustomStringConvertible { public var description: String { - return "AddressedEnvelope { remoteAddress: \(self.remoteAddress), data: \(self.data) }" + "AddressedEnvelope { remoteAddress: \(self.remoteAddress), data: \(self.data) }" } } diff --git a/Sources/NIOCore/AsyncAwaitSupport.swift b/Sources/NIOCore/AsyncAwaitSupport.swift index f235d8cfd7..abdc20c2d6 100644 --- a/Sources/NIOCore/AsyncAwaitSupport.swift +++ b/Sources/NIOCore/AsyncAwaitSupport.swift @@ -20,7 +20,7 @@ extension EventLoopFuture { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @inlinable public func get() async throws -> Value { - return try await withUnsafeThrowingContinuation { (cont: UnsafeContinuation, Error>) in + try await withUnsafeThrowingContinuation { (cont: UnsafeContinuation, Error>) in self.whenComplete { result in switch result { case .success(let value): @@ -38,7 +38,7 @@ extension EventLoopGroup { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @inlinable public func shutdownGracefully() async throws { - return try await withCheckedThrowingContinuation { cont in + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in self.shutdownGracefully { error in if let error = error { cont.resume(throwing: error) @@ -97,7 +97,7 @@ extension Channel { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @inlinable public func getOption(_ option: Option) async throws -> Option.Value { - return try await self.getOption(option).get() + try await self.getOption(option).get() } } @@ -162,9 +162,11 @@ extension ChannelOutboundInvoker { extension ChannelPipeline { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @preconcurrency - public func addHandler(_ handler: ChannelHandler & Sendable, - name: String? = nil, - position: ChannelPipeline.Position = .last) async throws { + public func addHandler( + _ handler: ChannelHandler & Sendable, + name: String? = nil, + position: ChannelPipeline.Position = .last + ) async throws { try await self.addHandler(handler, name: name, position: position).get() } @@ -185,43 +187,62 @@ extension ChannelPipeline { } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @available(*, deprecated, message: "ChannelHandlerContext is not Sendable and it is therefore not safe to be used outside of its EventLoop") + @available( + *, + deprecated, + message: + "ChannelHandlerContext is not Sendable and it is therefore not safe to be used outside of its EventLoop" + ) @preconcurrency public func context(handler: ChannelHandler & Sendable) async throws -> ChannelHandlerContext { - return try await self.context(handler: handler).map { UnsafeTransfer($0) }.get().wrappedValue + try await self.context(handler: handler).map { UnsafeTransfer($0) }.get().wrappedValue } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @available(*, deprecated, message: "ChannelHandlerContext is not Sendable and it is therefore not safe to be used outside of its EventLoop") + @available( + *, + deprecated, + message: + "ChannelHandlerContext is not Sendable and it is therefore not safe to be used outside of its EventLoop" + ) public func context(name: String) async throws -> ChannelHandlerContext { - return try await self.context(name: name).map { UnsafeTransfer($0) }.get().wrappedValue + try await self.context(name: name).map { UnsafeTransfer($0) }.get().wrappedValue } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @available(*, deprecated, message: "ChannelHandlerContext is not Sendable and it is therefore not safe to be used outside of its EventLoop") + @available( + *, + deprecated, + message: + "ChannelHandlerContext is not Sendable and it is therefore not safe to be used outside of its EventLoop" + ) @inlinable public func context(handlerType: Handler.Type) async throws -> ChannelHandlerContext { - return try await self.context(handlerType: handlerType).map { UnsafeTransfer($0) }.get().wrappedValue + try await self.context(handlerType: handlerType).map { UnsafeTransfer($0) }.get().wrappedValue } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @preconcurrency - public func addHandlers(_ handlers: [ChannelHandler & Sendable], - position: ChannelPipeline.Position = .last) async throws { + public func addHandlers( + _ handlers: [ChannelHandler & Sendable], + position: ChannelPipeline.Position = .last + ) async throws { try await self.addHandlers(handlers, position: position).map { UnsafeTransfer($0) }.get().wrappedValue } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @preconcurrency - public func addHandlers(_ handlers: (ChannelHandler & Sendable)..., - position: ChannelPipeline.Position = .last) async throws { + public func addHandlers( + _ handlers: (ChannelHandler & Sendable)..., + position: ChannelPipeline.Position = .last + ) async throws { try await self.addHandlers(handlers, position: position) } } public struct NIOTooManyBytesError: Error, Hashable { - public init() {} - } + public init() {} +} @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension AsyncSequence where Element: RandomAccessCollection, Element.Element == UInt8 { @@ -246,7 +267,7 @@ extension AsyncSequence where Element: RandomAccessCollection, Element.Element = accumulationBuffer.writeBytes(fragment) } } - + /// Accumulates an `AsyncSequence` of `RandomAccessCollection`s into a single ``ByteBuffer``. /// - Parameters: /// - maxBytes: The maximum number of bytes this method is allowed to accumulate @@ -289,7 +310,7 @@ extension AsyncSequence where Element == ByteBuffer { accumulationBuffer.writeImmutableBuffer(fragment) } } - + /// Accumulates an `AsyncSequence` of ``ByteBuffer``s into a single ``ByteBuffer``. /// - Parameters: /// - maxBytes: The maximum number of bytes this method is allowed to accumulate @@ -309,7 +330,7 @@ extension AsyncSequence where Element == ByteBuffer { guard head.readableBytes <= maxBytes else { throw NIOTooManyBytesError() } - + let tail = AsyncSequenceFromIterator(iterator) // it is guaranteed that // `maxBytes >= 0 && head.readableBytes >= 0 && head.readableBytes <= maxBytes` @@ -324,13 +345,13 @@ extension AsyncSequence where Element == ByteBuffer { @usableFromInline struct AsyncSequenceFromIterator: AsyncSequence { @usableFromInline typealias Element = AsyncIterator.Element - + @usableFromInline var iterator: AsyncIterator - + @inlinable init(_ iterator: AsyncIterator) { self.iterator = iterator } - + @inlinable func makeAsyncIterator() -> AsyncIterator { self.iterator } @@ -339,7 +360,9 @@ struct AsyncSequenceFromIterator: AsyncSeq @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension EventLoop { @inlinable - public func makeFutureWithTask(_ body: @Sendable @escaping () async throws -> Return) -> EventLoopFuture { + public func makeFutureWithTask( + _ body: @Sendable @escaping () async throws -> Return + ) -> EventLoopFuture { let promise = self.makePromise(of: Return.self) promise.completeWithTask(body) return promise.futureResult diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift index eda10858ab..10354f0760 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannel.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannel.swift @@ -58,7 +58,10 @@ public struct NIOAsyncChannel: Sendable { /// - inboundType: The ``NIOAsyncChannel/inbound`` message's type. /// - outboundType: The ``NIOAsyncChannel/outbound`` message's type. public init( - backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark = .init(lowWatermark: 2, highWatermark: 10), + backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark = .init( + lowWatermark: 2, + highWatermark: 10 + ), isOutboundHalfClosureEnabled: Bool = false, inboundType: Inbound.Type = Inbound.self, outboundType: Outbound.Type = Outbound.self @@ -147,7 +150,12 @@ public struct NIOAsyncChannel: Sendable { /// - Parameters: /// - channel: The ``Channel`` to wrap. /// - configuration: The ``NIOAsyncChannel``s configuration. - @available(*, deprecated, renamed: "init(wrappingChannelSynchronously:configuration:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + @available( + *, + deprecated, + renamed: "init(wrappingChannelSynchronously:configuration:)", + message: "This method has been deprecated since it defaults to deinit based resource teardown" + ) @inlinable public init( synchronouslyWrapping channel: Channel, @@ -173,7 +181,12 @@ public struct NIOAsyncChannel: Sendable { /// - channel: The ``Channel`` to wrap. /// - configuration: The ``NIOAsyncChannel``s configuration. @inlinable - @available(*, deprecated, renamed: "init(wrappingChannelSynchronously:configuration:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + @available( + *, + deprecated, + renamed: "init(wrappingChannelSynchronously:configuration:)", + message: "This method has been deprecated since it defaults to deinit based resource teardown" + ) public init( synchronouslyWrapping channel: Channel, configuration: Configuration = .init() @@ -206,7 +219,11 @@ public struct NIOAsyncChannel: Sendable { /// /// - Important: This is not considered stable API and should not be used. @inlinable - @available(*, deprecated, message: "This method has been deprecated since it defaults to deinit based resource teardown") + @available( + *, + deprecated, + message: "This method has been deprecated since it defaults to deinit based resource teardown" + ) public static func _wrapAsyncChannelWithTransformations( synchronouslyWrapping channel: Channel, backPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, @@ -214,12 +231,14 @@ public struct NIOAsyncChannel: Sendable { channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> NIOAsyncChannel where Outbound == Never { channel.eventLoop.preconditionInEventLoop() - let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) = try channel._syncAddAsyncHandlersWithTransformations( - backPressureStrategy: backPressureStrategy, - isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, - closeOnDeinit: true, - channelReadTransformation: channelReadTransformation - ) + let (inboundStream, outboundWriter): + (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) = + try channel._syncAddAsyncHandlersWithTransformations( + backPressureStrategy: backPressureStrategy, + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: true, + channelReadTransformation: channelReadTransformation + ) outboundWriter.finish() @@ -242,12 +261,14 @@ public struct NIOAsyncChannel: Sendable { channelReadTransformation: @Sendable @escaping (Channel) -> EventLoopFuture ) throws -> NIOAsyncChannel where Outbound == Never { channel.eventLoop.preconditionInEventLoop() - let (inboundStream, outboundWriter): (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) = try channel._syncAddAsyncHandlersWithTransformations( - backPressureStrategy: backPressureStrategy, - isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, - closeOnDeinit: false, - channelReadTransformation: channelReadTransformation - ) + let (inboundStream, outboundWriter): + (NIOAsyncChannelInboundStream, NIOAsyncChannelOutboundWriter) = + try channel._syncAddAsyncHandlersWithTransformations( + backPressureStrategy: backPressureStrategy, + isOutboundHalfClosureEnabled: isOutboundHalfClosureEnabled, + closeOnDeinit: false, + channelReadTransformation: channelReadTransformation + ) outboundWriter.finish() @@ -264,7 +285,8 @@ public struct NIOAsyncChannel: Sendable { /// /// - Parameter body: A closure that gets scoped access to the inbound and outbound. public func executeThenClose( - _ body: (_ inbound: NIOAsyncChannelInboundStream, _ outbound: NIOAsyncChannelOutboundWriter) async throws -> Result + _ body: (_ inbound: NIOAsyncChannelInboundStream, _ outbound: NIOAsyncChannelOutboundWriter) + async throws -> Result ) async throws -> Result { let result: Result do { @@ -290,7 +312,11 @@ public struct NIOAsyncChannel: Sendable { } return result } +} +// swift-format-ignore: AmbiguousTrailingClosureOverload +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOAsyncChannel { /// Provides scoped access to the inbound side of the underlying ``Channel``. /// /// - Important: After this method returned the underlying ``Channel`` will be closed. diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift index 5f76313a73..63d243b74e 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift @@ -147,7 +147,6 @@ extension NIOAsyncChannelHandler: ChannelInboundHandler { context.fireChannelInactive() } - @inlinable func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { @@ -298,7 +297,6 @@ extension NIOAsyncChannelHandler { } } - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncChannelHandler { @inlinable @@ -341,7 +339,9 @@ struct NIOAsyncChannelHandlerProducerDelegate: @unchecked Sendable, NIOAsyncSequ let _produceMore: () -> Void @inlinable - init(handler: NIOAsyncChannelHandler) { + init( + handler: NIOAsyncChannelHandler + ) { self.eventLoop = handler.eventLoop self._didTerminate = handler._didTerminate self._produceMore = handler._produceMore diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift index 7134359a2c..7bda21cfdd 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift @@ -151,7 +151,7 @@ extension NIOAsyncChannelInboundStream: AsyncSequence { @inlinable public func makeAsyncIterator() -> AsyncIterator { - return AsyncIterator(self._backing) + AsyncIterator(self._backing) } } diff --git a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift index 8ef5f12cf8..dfdeeb0fda 100644 --- a/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift +++ b/Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift @@ -144,7 +144,8 @@ public struct NIOAsyncChannelOutboundWriter: Sendable { /// /// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again. @inlinable - public func write(contentsOf sequence: Writes) async throws where Writes.Element == OutboundOut { + public func write(contentsOf sequence: Writes) async throws + where Writes.Element == OutboundOut { for try await data in sequence { try await self.write(data) } diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift index c724f7798f..d8a41e434b 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncSequenceProducer.swift @@ -112,7 +112,7 @@ public struct NIOAsyncSequenceProducer< public let sequence: NIOAsyncSequenceProducer @usableFromInline - /* fileprivate */ internal init( + internal init( source: Source, sequence: NIOAsyncSequenceProducer ) { @@ -122,12 +122,13 @@ public struct NIOAsyncSequenceProducer< } @usableFromInline - /* private */ internal let _throwingSequence: NIOThrowingAsyncSequenceProducer< - Element, - Never, - Strategy, - Delegate - > + internal let _throwingSequence: + NIOThrowingAsyncSequenceProducer< + Element, + Never, + Strategy, + Delegate + > /// Initializes a new ``NIOAsyncSequenceProducer`` and a ``NIOAsyncSequenceProducer/Source``. /// @@ -175,7 +176,12 @@ public struct NIOAsyncSequenceProducer< /// - delegate: The delegate of the sequence /// - Returns: A ``NIOAsyncSequenceProducer/Source`` and a ``NIOAsyncSequenceProducer``. @inlinable - @available(*, deprecated, renamed: "makeSequence(elementType:backPressureStrategy:finishOnDeinit:delegate:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + @available( + *, + deprecated, + renamed: "makeSequence(elementType:backPressureStrategy:finishOnDeinit:delegate:)", + message: "This method has been deprecated since it defaults to deinit based resource teardown" + ) public static func makeSequence( elementType: Element.Type = Element.self, backPressureStrategy: Strategy, @@ -194,7 +200,7 @@ public struct NIOAsyncSequenceProducer< } @inlinable - /* private */ internal init( + internal init( throwingSequence: NIOThrowingAsyncSequenceProducer ) { self._throwingSequence = throwingSequence @@ -212,12 +218,13 @@ extension NIOAsyncSequenceProducer: AsyncSequence { extension NIOAsyncSequenceProducer { public struct AsyncIterator: AsyncIteratorProtocol { @usableFromInline - /* private */ internal let _throwingIterator: NIOThrowingAsyncSequenceProducer< - Element, - Never, - Strategy, - Delegate - >.AsyncIterator + internal let _throwingIterator: + NIOThrowingAsyncSequenceProducer< + Element, + Never, + Strategy, + Delegate + >.AsyncIterator fileprivate init( throwingIterator: NIOThrowingAsyncSequenceProducer< @@ -233,7 +240,7 @@ extension NIOAsyncSequenceProducer { @inlinable public func next() async -> Element? { // this call will only throw if cancelled and we want to just return nil in that case - return try? await self._throwingIterator.next() + try? await self._throwingIterator.next() } } } @@ -253,10 +260,10 @@ extension NIOAsyncSequenceProducer { >.Source @usableFromInline - /* private */ internal var _throwingSource: ThrowingSource + internal var _throwingSource: ThrowingSource @usableFromInline - /* fileprivate */ internal init(throwingSource: ThrowingSource) { + internal init(throwingSource: ThrowingSource) { self._throwingSource = throwingSource } diff --git a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift index 163dcaf076..64b74b1818 100644 --- a/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift +++ b/Sources/NIOCore/AsyncSequences/NIOAsyncWriter.swift @@ -96,7 +96,7 @@ public struct NIOAsyncWriterError: Error, Hashable, CustomStringConvertible { @inlinable public static func == (lhs: NIOAsyncWriterError, rhs: NIOAsyncWriterError) -> Bool { - return lhs._code == rhs._code + lhs._code == rhs._code } @inlinable @@ -147,7 +147,7 @@ public struct NIOAsyncWriter< public let writer: NIOAsyncWriter @inlinable - /* fileprivate */ internal init( + internal init( sink: Sink, writer: NIOAsyncWriter ) { @@ -158,7 +158,7 @@ public struct NIOAsyncWriter< /// This class is needed to hook the deinit to observe once all references to the ``NIOAsyncWriter`` are dropped. @usableFromInline - /* fileprivate */ internal final class InternalClass: Sendable { + internal final class InternalClass: Sendable { @usableFromInline internal let _storage: Storage @@ -183,10 +183,10 @@ public struct NIOAsyncWriter< } @usableFromInline - /* private */ internal let _internalClass: InternalClass + internal let _internalClass: InternalClass @inlinable - /* private */ internal var _storage: Storage { + internal var _storage: Storage { self._internalClass._storage } @@ -203,7 +203,12 @@ public struct NIOAsyncWriter< /// - delegate: The delegate of the writer. /// - Returns: A ``NIOAsyncWriter/NewWriter``. @inlinable - @available(*, deprecated, renamed: "makeWriter(elementType:isWritable:finishOnDeinit:delegate:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + @available( + *, + deprecated, + renamed: "makeWriter(elementType:isWritable:finishOnDeinit:delegate:)", + message: "This method has been deprecated since it defaults to deinit based resource teardown" + ) public static func makeWriter( elementType: Element.Type = Element.self, isWritable: Bool, @@ -251,7 +256,7 @@ public struct NIOAsyncWriter< } @inlinable - /* private */ internal init( + internal init( isWritable: Bool, finishOnDeinit: Bool, delegate: Delegate @@ -337,9 +342,9 @@ extension NIOAsyncWriter { public struct Sink { /// This class is needed to hook the deinit to observe once all references to the ``NIOAsyncWriter/Sink`` are dropped. @usableFromInline - /* fileprivate */ internal final class InternalClass: Sendable { + internal final class InternalClass: Sendable { @usableFromInline - /* fileprivate */ internal let _storage: Storage + internal let _storage: Storage @usableFromInline internal let _finishOnDeinit: Bool @@ -362,10 +367,10 @@ extension NIOAsyncWriter { } @usableFromInline - /* private */ internal let _internalClass: InternalClass + internal let _internalClass: InternalClass @inlinable - /* private */ internal var _storage: Storage { + internal var _storage: Storage { self._internalClass._storage } @@ -414,7 +419,7 @@ extension NIOAsyncWriter { extension NIOAsyncWriter { /// This is the underlying storage of the writer. The goal of this is to synchronize the access to all state. @usableFromInline - /* fileprivate */ internal struct Storage: Sendable { + internal struct Storage: Sendable { /// Internal type to generate unique yield IDs. /// /// This type has reference semantics. @@ -424,7 +429,7 @@ extension NIOAsyncWriter { @usableFromInline struct YieldID: Equatable, Sendable { @usableFromInline - /* private */ internal var value: UInt64 + internal var value: UInt64 @inlinable init(value: UInt64) { @@ -447,10 +452,10 @@ extension NIOAsyncWriter { /// The counter used to assign an ID to all our yields. @usableFromInline - /* private */ internal let _yieldIDGenerator = YieldIDGenerator() + internal let _yieldIDGenerator = YieldIDGenerator() /// The state machine. @usableFromInline - /* private */ internal let _state: NIOLockedValueBox + internal let _state: NIOLockedValueBox @usableFromInline struct State: Sendable { @@ -485,7 +490,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal init( + internal init( isWritable: Bool, delegate: Delegate ) { @@ -494,7 +499,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal func setWritability(to writability: Bool) { + internal func setWritability(to writability: Bool) { // We must not resume the continuation while holding the lock // because it can deadlock in combination with the underlying ulock // in cases where we race with a cancellation handler @@ -504,7 +509,9 @@ extension NIOAsyncWriter { switch action { case .resumeContinuations(let suspendedYields): - suspendedYields.forEach { $0.continuation.resume(returning: .retry) } + for yield in suspendedYields { + yield .continuation.resume(returning: .retry) + } case .none: return @@ -512,7 +519,8 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal func yield(contentsOf sequence: S) async throws where S.Element == Element { + internal func yield(contentsOf sequence: S) async throws + where S.Element == Element { let yieldID = self._yieldIDGenerator.generateUniqueYieldID() while true { switch try await self._yield(contentsOf: sequence, yieldID: yieldID) { @@ -525,7 +533,10 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal func _yield(contentsOf sequence: S, yieldID: StateMachine.YieldID?) async throws -> StateMachine.YieldResult where S.Element == Element { + internal func _yield( + contentsOf sequence: S, + yieldID: StateMachine.YieldID? + ) async throws -> StateMachine.YieldResult where S.Element == Element { let yieldID = yieldID ?? self._yieldIDGenerator.generateUniqueYieldID() return try await withTaskCancellationHandler { @@ -550,7 +561,8 @@ extension NIOAsyncWriter { throw error case .suspendTask: - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in let didSuspend = unsafe.withValueAssumingLockIsAcquired { $0.stateMachine.yield(continuation: continuation, yieldID: yieldID) return $0.didSuspend @@ -579,7 +591,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal func yield(element: Element) async throws { + internal func yield(element: Element) async throws { let yieldID = self._yieldIDGenerator.generateUniqueYieldID() while true { switch try await self._yield(element: element, yieldID: yieldID) { @@ -592,7 +604,10 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal func _yield(element: Element, yieldID: StateMachine.YieldID?) async throws -> StateMachine.YieldResult { + internal func _yield( + element: Element, + yieldID: StateMachine.YieldID? + ) async throws -> StateMachine.YieldResult { let yieldID = yieldID ?? self._yieldIDGenerator.generateUniqueYieldID() return try await withTaskCancellationHandler { @@ -617,7 +632,8 @@ extension NIOAsyncWriter { throw error case .suspendTask: - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in let didSuspend = unsafe.withValueAssumingLockIsAcquired { $0.stateMachine.yield(continuation: continuation, yieldID: yieldID) return $0.didSuspend @@ -645,7 +661,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal func writerFinish(error: Error?) { + internal func writerFinish(error: Error?) { // We must not resume the continuation while holding the lock // because it can deadlock in combination with the underlying ulock // in cases where we race with a cancellation handler @@ -658,7 +674,9 @@ extension NIOAsyncWriter { delegate.didTerminate(error: error) case .resumeContinuations(let suspendedYields): - suspendedYields.forEach { $0.continuation.resume(returning: .retry) } + for yield in suspendedYields { + yield .continuation.resume(returning: .retry) + } case .none: break @@ -666,7 +684,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal func sinkFinish(error: Error?) { + internal func sinkFinish(error: Error?) { // We must not resume the continuation while holding the lock // because it can deadlock in combination with the underlying ulock // in cases where we race with a cancellation handler @@ -676,23 +694,26 @@ extension NIOAsyncWriter { switch action { case .resumeContinuationsWithError(let suspendedYields, let error): - suspendedYields.forEach { $0.continuation.resume(throwing: error) } + for yield in suspendedYields { + yield .continuation.resume(throwing: error) + } case .none: break } } - @inlinable - /* fileprivate */ internal func unbufferQueuedEvents() { - while let action = self._state.withLockedValue({ $0.stateMachine.unbufferQueuedEvents()}) { + internal func unbufferQueuedEvents() { + while let action = self._state.withLockedValue({ $0.stateMachine.unbufferQueuedEvents() }) { switch action { case .callDidTerminate(let delegate, let error): delegate.didTerminate(error: error) case .resumeContinuations(let suspendedYields): - suspendedYields.forEach { $0.continuation.resume(returning: .retry) } + for yield in suspendedYields { + yield .continuation.resume(returning: .retry) + } return } } @@ -703,12 +724,12 @@ extension NIOAsyncWriter { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOAsyncWriter { @usableFromInline - /* private */ internal struct StateMachine: Sendable { + internal struct StateMachine: Sendable { @usableFromInline typealias YieldID = Storage.YieldIDGenerator.YieldID /// This is a small helper struct to encapsulate the two different values for a suspended yield. @usableFromInline - /* private */ internal struct SuspendedYield: Sendable { + internal struct SuspendedYield: Sendable { /// The yield's ID. @usableFromInline var yieldID: YieldID @@ -725,7 +746,7 @@ extension NIOAsyncWriter { } /// The internal result of a yield. @usableFromInline - /* private */ internal enum YieldResult { + internal enum YieldResult { /// Indicates that the elements got yielded to the sink. case yielded /// Indicates that the yield should be retried. @@ -734,7 +755,7 @@ extension NIOAsyncWriter { /// The current state of our ``NIOAsyncWriter``. @usableFromInline - /* private */ internal enum State: Sendable, CustomStringConvertible { + internal enum State: Sendable, CustomStringConvertible { /// The initial state before either a call to ``NIOAsyncWriter/yield(contentsOf:)`` or /// ``NIOAsyncWriter/finish(completion:)`` happened. case initial( @@ -778,9 +799,19 @@ extension NIOAsyncWriter { case .initial(let isWritable, _): return "initial(isWritable: \(isWritable))" case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, _): - return "streaming(isWritable: \(isWritable), inDelegateOutcall: \(inDelegateOutcall), cancelledYields: \(cancelledYields.count), suspendedYields: \(suspendedYields.count))" - case .writerFinished(let isWritable, let inDelegateOutcall, let suspendedYields, let cancelledYields, let bufferedYieldIDs, _, _): - return "writerFinished(isWritable: \(isWritable), inDelegateOutcall: \(inDelegateOutcall), suspendedYields: \(suspendedYields.count), cancelledYields: \(cancelledYields.count), bufferedYieldIDs: \(bufferedYieldIDs.count)" + return + "streaming(isWritable: \(isWritable), inDelegateOutcall: \(inDelegateOutcall), cancelledYields: \(cancelledYields.count), suspendedYields: \(suspendedYields.count))" + case .writerFinished( + let isWritable, + let inDelegateOutcall, + let suspendedYields, + let cancelledYields, + let bufferedYieldIDs, + _, + _ + ): + return + "writerFinished(isWritable: \(isWritable), inDelegateOutcall: \(inDelegateOutcall), suspendedYields: \(suspendedYields.count), cancelledYields: \(cancelledYields.count), bufferedYieldIDs: \(bufferedYieldIDs.count)" case .finished: return "finished" case .modifying: @@ -791,7 +822,7 @@ extension NIOAsyncWriter { /// The state machine's current state. @usableFromInline - /* private */ internal var _state: State + internal var _state: State @inlinable internal var isWriterFinished: Bool { @@ -817,7 +848,6 @@ extension NIOAsyncWriter { } } - @inlinable init( isWritable: Bool, @@ -834,7 +864,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal mutating func setWritability(to newWritability: Bool) -> SetWritabilityAction? { + internal mutating func setWritability(to newWritability: Bool) -> SetWritabilityAction? { switch self._state { case .initial(_, let delegate): // We just need to store the new writability state @@ -842,7 +872,13 @@ extension NIOAsyncWriter { return .none - case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, let delegate): + case .streaming( + let isWritable, + let inDelegateOutcall, + let cancelledYields, + let suspendedYields, + let delegate + ): if isWritable == newWritability { // The writability didn't change so we can just early exit here return .none @@ -882,7 +918,15 @@ extension NIOAsyncWriter { return .none } - case .writerFinished(_, let inDelegateOutcall, let suspendedYields, let cancelledYields, let bufferedYieldIDs, let delegate, let error): + case .writerFinished( + _, + let inDelegateOutcall, + let suspendedYields, + let cancelledYields, + let bufferedYieldIDs, + let delegate, + let error + ): if !newWritability { // We are not writable so we can't deliver the outstanding elements return .none @@ -958,7 +1002,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal mutating func yield( + internal mutating func yield( yieldID: YieldID ) -> YieldAction { switch self._state { @@ -967,7 +1011,7 @@ extension NIOAsyncWriter { self._state = .streaming( isWritable: isWritable, - inDelegateOutcall: isWritable, // If we are writable we are going to make an outcall + inDelegateOutcall: isWritable, // If we are writable we are going to make an outcall cancelledYields: [], suspendedYields: .init(), delegate: delegate @@ -975,7 +1019,13 @@ extension NIOAsyncWriter { return .init(isWritable: isWritable, delegate: delegate) - case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, let suspendedYields, let delegate): + case .streaming( + let isWritable, + let inDelegateOutcall, + var cancelledYields, + let suspendedYields, + let delegate + ): self._state = .modifying if let index = cancelledYields.firstIndex(of: yieldID) { @@ -999,7 +1049,7 @@ extension NIOAsyncWriter { case (true, false): self._state = .streaming( isWritable: isWritable, - inDelegateOutcall: true, // We are now making a call to the delegate + inDelegateOutcall: true, // We are now making a call to the delegate cancelledYields: cancelledYields, suspendedYields: suspendedYields, delegate: delegate @@ -1018,7 +1068,15 @@ extension NIOAsyncWriter { } } - case .writerFinished(let isWritable, let inDelegateOutcall, let suspendedYields, var cancelledYields, let bufferedYieldIDs, let delegate, let error): + case .writerFinished( + let isWritable, + let inDelegateOutcall, + let suspendedYields, + var cancelledYields, + let bufferedYieldIDs, + let delegate, + let error + ): if bufferedYieldIDs.contains(yieldID) { // This yield was buffered before we became finished so we still have to deliver it self._state = .modifying @@ -1046,7 +1104,7 @@ extension NIOAsyncWriter { case (true, false): self._state = .writerFinished( isWritable: isWritable, - inDelegateOutcall: true, // We are now making a call to the delegate + inDelegateOutcall: true, // We are now making a call to the delegate suspendedYields: suspendedYields, cancelledYields: cancelledYields, bufferedYieldIDs: bufferedYieldIDs, @@ -1084,12 +1142,18 @@ extension NIOAsyncWriter { /// This method is called as a result of the above `yield` method if it decided that the task needs to get suspended. @inlinable - /* fileprivate */ internal mutating func yield( + internal mutating func yield( continuation: CheckedContinuation, yieldID: YieldID ) { switch self._state { - case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, var suspendedYields, let delegate): + case .streaming( + let isWritable, + let inDelegateOutcall, + let cancelledYields, + var suspendedYields, + let delegate + ): // We have a suspended yield at this point that hasn't been cancelled yet. // We need to store the yield now. @@ -1127,7 +1191,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal mutating func cancel( + internal mutating func cancel( yieldID: YieldID ) -> CancelAction { switch self._state { @@ -1145,7 +1209,13 @@ extension NIOAsyncWriter { return .none - case .streaming(let isWritable, let inDelegateOutcall, var cancelledYields, var suspendedYields, let delegate): + case .streaming( + let isWritable, + let inDelegateOutcall, + var cancelledYields, + var suspendedYields, + let delegate + ): if let index = suspendedYields.firstIndex(where: { $0.yieldID == yieldID }) { self._state = .modifying // We have a suspended yield for the id. We need to resume the continuation now. @@ -1185,7 +1255,15 @@ extension NIOAsyncWriter { return .none } - case .writerFinished(let isWritable, let inDelegateOutcall, var suspendedYields, var cancelledYields, let bufferedYieldIDs, let delegate, let error): + case .writerFinished( + let isWritable, + let inDelegateOutcall, + var suspendedYields, + var cancelledYields, + let bufferedYieldIDs, + let delegate, + let error + ): guard bufferedYieldIDs.contains(yieldID) else { return .none } @@ -1253,7 +1331,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal mutating func writerFinish(error: Error?) -> WriterFinishAction { + internal mutating func writerFinish(error: Error?) -> WriterFinishAction { switch self._state { case .initial(_, let delegate): // Nothing was ever written so we can transition to finished @@ -1261,7 +1339,13 @@ extension NIOAsyncWriter { return .callDidTerminate(delegate) - case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, let delegate): + case .streaming( + let isWritable, + let inDelegateOutcall, + let cancelledYields, + let suspendedYields, + let delegate + ): // We are currently streaming and the writer got finished. if suspendedYields.isEmpty { if inDelegateOutcall { @@ -1317,7 +1401,7 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal mutating func sinkFinish(error: Error?) -> SinkFinishAction { + internal mutating func sinkFinish(error: Error?) -> SinkFinishAction { switch self._state { case .initial(_, _): // Nothing was ever written so we can transition to finished @@ -1360,12 +1444,18 @@ extension NIOAsyncWriter { } @inlinable - /* fileprivate */ internal mutating func unbufferQueuedEvents() -> UnbufferQueuedEventsAction? { + internal mutating func unbufferQueuedEvents() -> UnbufferQueuedEventsAction? { switch self._state { case .initial: preconditionFailure("Invalid state") - case .streaming(let isWritable, let inDelegateOutcall, let cancelledYields, let suspendedYields, let delegate): + case .streaming( + let isWritable, + let inDelegateOutcall, + let cancelledYields, + let suspendedYields, + let delegate + ): precondition(inDelegateOutcall, "We must be in a delegate outcall when we unbuffer events") // We have to resume the other suspended yields now. @@ -1391,7 +1481,15 @@ extension NIOAsyncWriter { return .resumeContinuations(suspendedYields) } - case .writerFinished(let isWritable, let inDelegateOutcall, let suspendedYields, let cancelledYields, let bufferedYieldIDs, let delegate, let error): + case .writerFinished( + let isWritable, + let inDelegateOutcall, + let suspendedYields, + let cancelledYields, + let bufferedYieldIDs, + let delegate, + let error + ): precondition(inDelegateOutcall, "We must be in a delegate outcall when we unbuffer events") if suspendedYields.isEmpty { // We were the last writer task and can now call didTerminate @@ -1401,7 +1499,6 @@ extension NIOAsyncWriter { // There are still other writer tasks that need to be resumed self._state = .modifying - self._state = .writerFinished( isWritable: isWritable, inDelegateOutcall: inDelegateOutcall, diff --git a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift index 0477c01969..7b95ee86f2 100644 --- a/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift +++ b/Sources/NIOCore/AsyncSequences/NIOThrowingAsyncSequenceProducer.swift @@ -49,7 +49,7 @@ public struct NIOThrowingAsyncSequenceProducer< public let sequence: NIOThrowingAsyncSequenceProducer @usableFromInline - /* fileprivate */ internal init( + internal init( source: Source, sequence: NIOThrowingAsyncSequenceProducer ) { @@ -62,7 +62,7 @@ public struct NIOThrowingAsyncSequenceProducer< /// /// If we get move-only types we should be able to drop this class and use the `deinit` of the ``AsyncIterator`` struct itself. @usableFromInline - /* fileprivate */ internal final class InternalClass: Sendable { + internal final class InternalClass: Sendable { @usableFromInline internal let _storage: Storage @@ -78,10 +78,10 @@ public struct NIOThrowingAsyncSequenceProducer< } @usableFromInline - /* private */ internal let _internalClass: InternalClass + internal let _internalClass: InternalClass @usableFromInline - /* private */ internal var _storage: Storage { + internal var _storage: Storage { self._internalClass._storage } @@ -130,7 +130,11 @@ public struct NIOThrowingAsyncSequenceProducer< /// - backPressureStrategy: The back-pressure strategy of the sequence. /// - delegate: The delegate of the sequence /// - Returns: A ``NIOThrowingAsyncSequenceProducer/Source`` and a ``NIOThrowingAsyncSequenceProducer``. - @available(*, deprecated, message: "Support for a generic Failure type is deprecated. Failure type must be `any Swift.Error`.") + @available( + *, + deprecated, + message: "Support for a generic Failure type is deprecated. Failure type must be `any Swift.Error`." + ) @inlinable public static func makeSequence( elementType: Element.Type = Element.self, @@ -161,7 +165,12 @@ public struct NIOThrowingAsyncSequenceProducer< /// - delegate: The delegate of the sequence /// - Returns: A ``NIOThrowingAsyncSequenceProducer/Source`` and a ``NIOThrowingAsyncSequenceProducer``. @inlinable - @available(*, deprecated, renamed: "makeSequence(elementType:failureType:backPressureStrategy:finishOnDeinit:delegate:)", message: "This method has been deprecated since it defaults to deinit based resource teardown") + @available( + *, + deprecated, + renamed: "makeSequence(elementType:failureType:backPressureStrategy:finishOnDeinit:delegate:)", + message: "This method has been deprecated since it defaults to deinit based resource teardown" + ) public static func makeSequence( elementType: Element.Type = Element.self, failureType: Failure.Type = Error.self, @@ -195,7 +204,7 @@ public struct NIOThrowingAsyncSequenceProducer< } @inlinable - /* private */ internal init( + internal init( backPressureStrategy: Strategy, delegate: Delegate ) { @@ -221,9 +230,9 @@ extension NIOThrowingAsyncSequenceProducer { /// /// If we get move-only types we should be able to drop this class and use the `deinit` of the ``AsyncIterator`` struct itself. @usableFromInline - /* private */ internal final class InternalClass: Sendable { + internal final class InternalClass: Sendable { @usableFromInline - /* private */ internal let _storage: Storage + internal let _storage: Storage fileprivate init(storage: Storage) { self._storage = storage @@ -236,13 +245,13 @@ extension NIOThrowingAsyncSequenceProducer { } @inlinable - /* fileprivate */ internal func next() async throws -> Element? { + internal func next() async throws -> Element? { try await self._storage.next() } } @usableFromInline - /* private */ internal let _internalClass: InternalClass + internal let _internalClass: InternalClass fileprivate init(storage: Storage) { self._internalClass = InternalClass(storage: storage) @@ -268,7 +277,7 @@ extension NIOThrowingAsyncSequenceProducer { /// /// - Important: This is safe to be unchecked ``Sendable`` since the `storage` is ``Sendable`` and `immutable`. @usableFromInline - /* fileprivate */ internal final class InternalClass: Sendable { + internal final class InternalClass: Sendable { @usableFromInline internal let _storage: Storage @@ -293,15 +302,15 @@ extension NIOThrowingAsyncSequenceProducer { } @usableFromInline - /* private */ internal let _internalClass: InternalClass + internal let _internalClass: InternalClass @usableFromInline - /* private */ internal var _storage: Storage { + internal var _storage: Storage { self._internalClass._storage } @usableFromInline - /* fileprivate */ internal init(storage: Storage, finishOnDeinit: Bool) { + internal init(storage: Storage, finishOnDeinit: Bool) { self._internalClass = .init(storage: storage, finishOnDeinit: finishOnDeinit) } @@ -388,7 +397,7 @@ extension NIOThrowingAsyncSequenceProducer { extension NIOThrowingAsyncSequenceProducer { /// This is the underlying storage of the sequence. The goal of this is to synchronize the access to all state. @usableFromInline - /* fileprivate */ internal struct Storage: Sendable { + internal struct Storage: Sendable { @usableFromInline struct State: Sendable { @usableFromInline @@ -426,7 +435,7 @@ extension NIOThrowingAsyncSequenceProducer { } @usableFromInline - /* fileprivate */ internal init( + internal init( backPressureStrategy: Strategy, delegate: Delegate ) { @@ -438,7 +447,7 @@ extension NIOThrowingAsyncSequenceProducer { } @inlinable - /* fileprivate */ internal func sequenceDeinitialized() { + internal func sequenceDeinitialized() { let delegate: Delegate? = self._state.withLockedValue { let action = $0.stateMachine.sequenceDeinitialized() @@ -457,14 +466,14 @@ extension NIOThrowingAsyncSequenceProducer { } @inlinable - /* fileprivate */ internal func iteratorInitialized() { + internal func iteratorInitialized() { self._state.withLockedValue { $0.stateMachine.iteratorInitialized() } } @inlinable - /* fileprivate */ internal func iteratorDeinitialized() { + internal func iteratorDeinitialized() { let delegate: Delegate? = self._state.withLockedValue { let action = $0.stateMachine.iteratorDeinitialized() @@ -484,7 +493,8 @@ extension NIOThrowingAsyncSequenceProducer { } @inlinable - /* fileprivate */ internal func yield(_ sequence: S) -> Source.YieldResult where S.Element == Element { + internal func yield(_ sequence: S) -> Source.YieldResult + where S.Element == Element { // We must not resume the continuation while holding the lock // because it can deadlock in combination with the underlying ulock // in cases where we race with a cancellation handler @@ -515,23 +525,24 @@ extension NIOThrowingAsyncSequenceProducer { } @inlinable - /* fileprivate */ internal func finish(_ failure: Failure?) { + internal func finish(_ failure: Failure?) { // We must not resume the continuation while holding the lock // because it can deadlock in combination with the underlying ulock // in cases where we race with a cancellation handler - let (delegate, action): (Delegate?, NIOThrowingAsyncSequenceProducer.StateMachine.FinishAction) = self._state.withLockedValue { - let action = $0.stateMachine.finish(failure) + let (delegate, action): (Delegate?, NIOThrowingAsyncSequenceProducer.StateMachine.FinishAction) = self + ._state.withLockedValue { + let action = $0.stateMachine.finish(failure) - switch action { - case .resumeContinuationWithFailureAndCallDidTerminate: - let delegate = $0.delegate - $0.delegate = nil - return (delegate, action) + switch action { + case .resumeContinuationWithFailureAndCallDidTerminate: + let delegate = $0.delegate + $0.delegate = nil + return (delegate, action) - case .none: - return (nil, action) + case .none: + return (nil, action) + } } - } switch action { case .resumeContinuationWithFailureAndCallDidTerminate(let continuation, let failure): @@ -550,7 +561,7 @@ extension NIOThrowingAsyncSequenceProducer { } @inlinable - /* fileprivate */ internal func next() async throws -> Element? { + internal func next() async throws -> Element? { try await withTaskCancellationHandler { () async throws -> Element? in let unsafe = self._state.unsafe unsafe.lock() @@ -614,7 +625,8 @@ extension NIOThrowingAsyncSequenceProducer { case .suspendTask: // It is safe to hold the lock across this method // since the closure is guaranteed to be run straight away - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in let (action, callDidSuspend) = unsafe.withValueAssumingLockIsAcquired { let action = $0.stateMachine.next(for: continuation) let callDidSuspend = $0.didSuspend != nil @@ -644,26 +656,27 @@ extension NIOThrowingAsyncSequenceProducer { // We must not resume the continuation while holding the lock // because it can deadlock in combination with the underlying ulock // in cases where we race with a cancellation handler - let (delegate, action): (Delegate?, NIOThrowingAsyncSequenceProducer.StateMachine.CancelledAction) = self._state.withLockedValue { - let action = $0.stateMachine.cancelled() + let (delegate, action): (Delegate?, NIOThrowingAsyncSequenceProducer.StateMachine.CancelledAction) = + self._state.withLockedValue { + let action = $0.stateMachine.cancelled() - switch action { - case .callDidTerminate: - let delegate = $0.delegate - $0.delegate = nil + switch action { + case .callDidTerminate: + let delegate = $0.delegate + $0.delegate = nil - return (delegate, action) + return (delegate, action) - case .resumeContinuationWithCancellationErrorAndCallDidTerminate: - let delegate = $0.delegate - $0.delegate = nil + case .resumeContinuationWithCancellationErrorAndCallDidTerminate: + let delegate = $0.delegate + $0.delegate = nil - return (delegate, action) + return (delegate, action) - case .none: - return (nil, action) + case .none: + return (nil, action) + } } - } switch action { case .callDidTerminate: @@ -697,9 +710,9 @@ extension NIOThrowingAsyncSequenceProducer { @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) extension NIOThrowingAsyncSequenceProducer { @usableFromInline - /* private */ internal struct StateMachine: Sendable { + internal struct StateMachine: Sendable { @usableFromInline - /* private */ internal enum State: Sendable { + internal enum State: Sendable { /// The initial state before either a call to `yield()` or a call to `next()` happened case initial( backPressureStrategy: Strategy, @@ -736,7 +749,7 @@ extension NIOThrowingAsyncSequenceProducer { /// The state machine's current state. @usableFromInline - /* private */ internal var _state: State + internal var _state: State @inlinable var isFinished: Bool { @@ -750,7 +763,6 @@ extension NIOThrowingAsyncSequenceProducer { } } - /// Initializes a new `StateMachine`. /// /// We are passing and holding the back-pressure strategy here because @@ -778,18 +790,18 @@ extension NIOThrowingAsyncSequenceProducer { mutating func sequenceDeinitialized() -> SequenceDeinitializedAction { switch self._state { case .initial(_, iteratorInitialized: false), - .streaming(_, _, _, _, iteratorInitialized: false), - .sourceFinished(_, iteratorInitialized: false, _), - .cancelled(iteratorInitialized: false): + .streaming(_, _, _, _, iteratorInitialized: false), + .sourceFinished(_, iteratorInitialized: false, _), + .cancelled(iteratorInitialized: false): // No iterator was created so we can transition to finished right away. self._state = .finished(iteratorInitialized: false) return .callDidTerminate case .initial(_, iteratorInitialized: true), - .streaming(_, _, _, _, iteratorInitialized: true), - .sourceFinished(_, iteratorInitialized: true, _), - .cancelled(iteratorInitialized: true): + .streaming(_, _, _, _, iteratorInitialized: true), + .sourceFinished(_, iteratorInitialized: true, _), + .cancelled(iteratorInitialized: true): // An iterator was created and we deinited the sequence. // This is an expected pattern and we just continue on normal. return .none @@ -808,10 +820,10 @@ extension NIOThrowingAsyncSequenceProducer { mutating func iteratorInitialized() { switch self._state { case .initial(_, iteratorInitialized: true), - .streaming(_, _, _, _, iteratorInitialized: true), - .sourceFinished(_, iteratorInitialized: true, _), - .cancelled(iteratorInitialized: true), - .finished(iteratorInitialized: true): + .streaming(_, _, _, _, iteratorInitialized: true), + .sourceFinished(_, iteratorInitialized: true, _), + .cancelled(iteratorInitialized: true), + .finished(iteratorInitialized: true): // Our sequence is a unicast sequence and does not support multiple AsyncIterator's fatalError("NIOThrowingAsyncSequenceProducer allows only a single AsyncIterator to be created") @@ -868,16 +880,16 @@ extension NIOThrowingAsyncSequenceProducer { mutating func iteratorDeinitialized() -> IteratorDeinitializedAction { switch self._state { case .initial(_, iteratorInitialized: false), - .streaming(_, _, _, _, iteratorInitialized: false), - .sourceFinished(_, iteratorInitialized: false, _), - .cancelled(iteratorInitialized: false): + .streaming(_, _, _, _, iteratorInitialized: false), + .sourceFinished(_, iteratorInitialized: false, _), + .cancelled(iteratorInitialized: false): // An iterator needs to be initialized before it can be deinitialized. preconditionFailure("Internal inconsistency") case .initial(_, iteratorInitialized: true), - .streaming(_, _, _, _, iteratorInitialized: true), - .sourceFinished(_, iteratorInitialized: true, _), - .cancelled(iteratorInitialized: true): + .streaming(_, _, _, _, iteratorInitialized: true), + .sourceFinished(_, iteratorInitialized: true, _), + .cancelled(iteratorInitialized: true): // An iterator was created and deinited. Since we only support // a single iterator we can now transition to finish and inform the delegate. self._state = .finished(iteratorInitialized: true) @@ -917,7 +929,10 @@ extension NIOThrowingAsyncSequenceProducer { case returnDropped @usableFromInline - init(shouldProduceMore: Bool, continuationAndElement: (CheckedContinuation, Element)? = nil) { + init( + shouldProduceMore: Bool, + continuationAndElement: (CheckedContinuation, Element)? = nil + ) { switch (shouldProduceMore, continuationAndElement) { case (true, .none): self = .returnProduceMore @@ -957,7 +972,13 @@ extension NIOThrowingAsyncSequenceProducer { return .init(shouldProduceMore: shouldProduceMore) - case .streaming(var backPressureStrategy, var buffer, .some(let continuation), let hasOutstandingDemand, let iteratorInitialized): + case .streaming( + var backPressureStrategy, + var buffer, + .some(let continuation), + let hasOutstandingDemand, + let iteratorInitialized + ): // The buffer should always be empty if we hold a continuation precondition(buffer.isEmpty, "Expected an empty buffer") @@ -982,7 +1003,7 @@ extension NIOThrowingAsyncSequenceProducer { self._state = .streaming( backPressureStrategy: backPressureStrategy, buffer: buffer, - continuation: nil, // Setting this to nil since we are resuming the continuation + continuation: nil, // Setting this to nil since we are resuming the continuation hasOutstandingDemand: shouldProduceMore, iteratorInitialized: iteratorInitialized ) @@ -1167,7 +1188,13 @@ extension NIOThrowingAsyncSequenceProducer { // We have multiple AsyncIterators iterating the sequence preconditionFailure("This should never happen since we only allow a single Iterator to be created") - case .streaming(var backPressureStrategy, var buffer, .none, let hasOutstandingDemand, let iteratorInitialized): + case .streaming( + var backPressureStrategy, + var buffer, + .none, + let hasOutstandingDemand, + let iteratorInitialized + ): self._state = .modifying if let element = buffer.popFirst() { @@ -1252,7 +1279,13 @@ extension NIOThrowingAsyncSequenceProducer { // We are transitioning away from the initial state in `next()` preconditionFailure("Invalid state") - case .streaming(var backPressureStrategy, let buffer, .none, let hasOutstandingDemand, let iteratorInitialized): + case .streaming( + var backPressureStrategy, + let buffer, + .none, + let hasOutstandingDemand, + let iteratorInitialized + ): precondition(buffer.isEmpty, "Expected an empty buffer") self._state = .modifying diff --git a/Sources/NIOCore/BSDSocketAPI.swift b/Sources/NIOCore/BSDSocketAPI.swift index 7b6627db14..ddbbb83557 100644 --- a/Sources/NIOCore/BSDSocketAPI.swift +++ b/Sources/NIOCore/BSDSocketAPI.swift @@ -69,16 +69,21 @@ import Musl import CNIOLinux #if os(Android) -private let sysInet_ntop: @convention(c) (CInt, UnsafeRawPointer, UnsafeMutablePointer, socklen_t) -> UnsafePointer? = inet_ntop +private let sysInet_ntop: + @convention(c) (CInt, UnsafeRawPointer, UnsafeMutablePointer, socklen_t) -> UnsafePointer? = inet_ntop private let sysInet_pton: @convention(c) (CInt, UnsafePointer, UnsafeMutableRawPointer) -> CInt = inet_pton #else -private let sysInet_ntop: @convention(c) (CInt, UnsafeRawPointer?, UnsafeMutablePointer?, socklen_t) -> UnsafePointer? = inet_ntop +private let sysInet_ntop: + @convention(c) (CInt, UnsafeRawPointer?, UnsafeMutablePointer?, socklen_t) -> UnsafePointer? = + inet_ntop private let sysInet_pton: @convention(c) (CInt, UnsafePointer?, UnsafeMutableRawPointer?) -> CInt = inet_pton #endif #elseif canImport(Darwin) import Darwin -private let sysInet_ntop: @convention(c) (CInt, UnsafeRawPointer?, UnsafeMutablePointer?, socklen_t) -> UnsafePointer? = inet_ntop +private let sysInet_ntop: + @convention(c) (CInt, UnsafeRawPointer?, UnsafeMutablePointer?, socklen_t) -> UnsafePointer? = + inet_ntop private let sysInet_pton: @convention(c) (CInt, UnsafePointer?, UnsafeMutableRawPointer?) -> CInt = inet_pton #else #error("The BSD Socket module was unable to identify your C library.") @@ -99,11 +104,11 @@ let SO_RCVTIMEO = CNIOLinux_SO_RCVTIMEO #endif public enum NIOBSDSocket { -#if os(Windows) + #if os(Windows) public typealias Handle = SOCKET -#else + #else public typealias Handle = CInt -#endif + #endif } extension NIOBSDSocket { @@ -178,74 +183,73 @@ extension NIOBSDSocket.Option: Hashable { extension NIOBSDSocket.AddressFamily { /// Address for IP version 4. public static let inet: NIOBSDSocket.AddressFamily = - NIOBSDSocket.AddressFamily(rawValue: AF_INET) + NIOBSDSocket.AddressFamily(rawValue: AF_INET) /// Address for IP version 6. public static let inet6: NIOBSDSocket.AddressFamily = - NIOBSDSocket.AddressFamily(rawValue: AF_INET6) + NIOBSDSocket.AddressFamily(rawValue: AF_INET6) /// Unix local to host address. public static let unix: NIOBSDSocket.AddressFamily = - NIOBSDSocket.AddressFamily(rawValue: AF_UNIX) + NIOBSDSocket.AddressFamily(rawValue: AF_UNIX) } // Protocol Family extension NIOBSDSocket.ProtocolFamily { /// IP network 4 protocol. public static let inet: NIOBSDSocket.ProtocolFamily = - NIOBSDSocket.ProtocolFamily(rawValue: PF_INET) + NIOBSDSocket.ProtocolFamily(rawValue: PF_INET) /// IP network 6 protocol. public static let inet6: NIOBSDSocket.ProtocolFamily = - NIOBSDSocket.ProtocolFamily(rawValue: PF_INET6) + NIOBSDSocket.ProtocolFamily(rawValue: PF_INET6) /// UNIX local to the host. public static let unix: NIOBSDSocket.ProtocolFamily = - NIOBSDSocket.ProtocolFamily(rawValue: PF_UNIX) + NIOBSDSocket.ProtocolFamily(rawValue: PF_UNIX) } #if !os(Windows) - extension NIOBSDSocket.ProtocolFamily { - /// UNIX local to the host, alias for `PF_UNIX` (`.unix`) - public static let local: NIOBSDSocket.ProtocolFamily = - NIOBSDSocket.ProtocolFamily(rawValue: PF_LOCAL) - } +extension NIOBSDSocket.ProtocolFamily { + /// UNIX local to the host, alias for `PF_UNIX` (`.unix`) + public static let local: NIOBSDSocket.ProtocolFamily = + NIOBSDSocket.ProtocolFamily(rawValue: PF_LOCAL) +} #endif - // Option Level extension NIOBSDSocket.OptionLevel { /// Socket options that apply only to IP sockets. #if os(Linux) || os(Android) - public static let ip: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: CInt(IPPROTO_IP)) + public static let ip: NIOBSDSocket.OptionLevel = + NIOBSDSocket.OptionLevel(rawValue: CInt(IPPROTO_IP)) #else - public static let ip: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: IPPROTO_IP) + public static let ip: NIOBSDSocket.OptionLevel = + NIOBSDSocket.OptionLevel(rawValue: IPPROTO_IP) #endif /// Socket options that apply only to IPv6 sockets. #if os(Linux) || os(Android) - public static let ipv6: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: CInt(IPPROTO_IPV6)) + public static let ipv6: NIOBSDSocket.OptionLevel = + NIOBSDSocket.OptionLevel(rawValue: CInt(IPPROTO_IPV6)) #elseif os(Windows) - public static let ipv6: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: IPPROTO_IPV6.rawValue) + public static let ipv6: NIOBSDSocket.OptionLevel = + NIOBSDSocket.OptionLevel(rawValue: IPPROTO_IPV6.rawValue) #else - public static let ipv6: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: IPPROTO_IPV6) + public static let ipv6: NIOBSDSocket.OptionLevel = + NIOBSDSocket.OptionLevel(rawValue: IPPROTO_IPV6) #endif /// Socket options that apply only to TCP sockets. #if os(Linux) || os(Android) - public static let tcp: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: CInt(IPPROTO_TCP)) + public static let tcp: NIOBSDSocket.OptionLevel = + NIOBSDSocket.OptionLevel(rawValue: CInt(IPPROTO_TCP)) #elseif os(Windows) - public static let tcp: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: IPPROTO_TCP.rawValue) + public static let tcp: NIOBSDSocket.OptionLevel = + NIOBSDSocket.OptionLevel(rawValue: IPPROTO_TCP.rawValue) #else - public static let tcp: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: IPPROTO_TCP) + public static let tcp: NIOBSDSocket.OptionLevel = + NIOBSDSocket.OptionLevel(rawValue: IPPROTO_TCP) #endif /// Socket options that apply to MPTCP sockets. @@ -255,15 +259,15 @@ extension NIOBSDSocket.OptionLevel { /// Socket options that apply to all sockets. public static let socket: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: SOL_SOCKET) + NIOBSDSocket.OptionLevel(rawValue: SOL_SOCKET) /// Socket options that apply only to UDP sockets. #if os(Linux) || os(Android) public static let udp: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: CInt(IPPROTO_UDP)) + NIOBSDSocket.OptionLevel(rawValue: CInt(IPPROTO_UDP)) #else public static let udp: NIOBSDSocket.OptionLevel = - NIOBSDSocket.OptionLevel(rawValue: IPPROTO_UDP) + NIOBSDSocket.OptionLevel(rawValue: IPPROTO_UDP) #endif } @@ -271,72 +275,72 @@ extension NIOBSDSocket.OptionLevel { extension NIOBSDSocket.Option { /// Add a multicast group membership. public static let ip_add_membership: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IP_ADD_MEMBERSHIP) + NIOBSDSocket.Option(rawValue: IP_ADD_MEMBERSHIP) /// Drop a multicast group membership. public static let ip_drop_membership: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IP_DROP_MEMBERSHIP) + NIOBSDSocket.Option(rawValue: IP_DROP_MEMBERSHIP) /// Set the interface for outgoing multicast packets. public static let ip_multicast_if: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IP_MULTICAST_IF) + NIOBSDSocket.Option(rawValue: IP_MULTICAST_IF) /// Control multicast loopback. public static let ip_multicast_loop: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IP_MULTICAST_LOOP) + NIOBSDSocket.Option(rawValue: IP_MULTICAST_LOOP) /// Control multicast time-to-live. public static let ip_multicast_ttl: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IP_MULTICAST_TTL) + NIOBSDSocket.Option(rawValue: IP_MULTICAST_TTL) /// The IPv4 layer generates an IP header when sending a packet /// unless the ``ip_hdrincl`` socket option is enabled on the socket. /// When it is enabled, the packet must contain an IP header. For /// receiving, the IP header is always included in the packet. public static let ip_hdrincl: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IP_HDRINCL) + NIOBSDSocket.Option(rawValue: IP_HDRINCL) } // IPv6 Options extension NIOBSDSocket.Option { /// Add an IPv6 group membership. public static let ipv6_join_group: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IPV6_JOIN_GROUP) + NIOBSDSocket.Option(rawValue: IPV6_JOIN_GROUP) /// Drop an IPv6 group membership. public static let ipv6_leave_group: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IPV6_LEAVE_GROUP) + NIOBSDSocket.Option(rawValue: IPV6_LEAVE_GROUP) /// Specify the maximum number of router hops for an IPv6 packet. public static let ipv6_multicast_hops: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IPV6_MULTICAST_HOPS) + NIOBSDSocket.Option(rawValue: IPV6_MULTICAST_HOPS) /// Set the interface for outgoing multicast packets. public static let ipv6_multicast_if: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IPV6_MULTICAST_IF) + NIOBSDSocket.Option(rawValue: IPV6_MULTICAST_IF) /// Control multicast loopback. public static let ipv6_multicast_loop: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IPV6_MULTICAST_LOOP) + NIOBSDSocket.Option(rawValue: IPV6_MULTICAST_LOOP) /// Indicates if a socket created for the `AF_INET6` address family is /// restricted to IPv6 only. public static let ipv6_v6only: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IPV6_V6ONLY) + NIOBSDSocket.Option(rawValue: IPV6_V6ONLY) } // TCP Options extension NIOBSDSocket.Option { /// Disables the Nagle algorithm for send coalescing. public static let tcp_nodelay: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: TCP_NODELAY) + NIOBSDSocket.Option(rawValue: TCP_NODELAY) } #if os(Linux) || os(FreeBSD) || os(Android) extension NIOBSDSocket.Option { /// Get information about the TCP connection. public static let tcp_info: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: TCP_INFO) + NIOBSDSocket.Option(rawValue: TCP_INFO) } #endif @@ -344,7 +348,7 @@ extension NIOBSDSocket.Option { extension NIOBSDSocket.Option { /// Get information about the TCP connection. public static let tcp_connection_info: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: TCP_CONNECTION_INFO) + NIOBSDSocket.Option(rawValue: TCP_CONNECTION_INFO) } #endif @@ -398,14 +402,18 @@ extension NIOBSDSocket.Option { extension NIOBSDSocket.Option { /// Indicate when to generate timestamps. public static let so_timestamp: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: SO_TIMESTAMP) + NIOBSDSocket.Option(rawValue: SO_TIMESTAMP) } #endif extension NIOBSDSocket { // Sadly this was defined on BSDSocket, and we need it for SocketAddress. @inline(never) - internal static func inet_pton(addressFamily: NIOBSDSocket.AddressFamily, addressDescription: UnsafePointer, address: UnsafeMutableRawPointer) throws { + internal static func inet_pton( + addressFamily: NIOBSDSocket.AddressFamily, + addressDescription: UnsafePointer, + address: UnsafeMutableRawPointer + ) throws { #if os(Windows) // TODO(compnerd) use `InetPtonW` to ensure that we handle unicode properly switch WinSDK.inet_pton(addressFamily.rawValue, addressDescription, address) { @@ -424,12 +432,22 @@ extension NIOBSDSocket { @discardableResult @inline(never) - internal static func inet_ntop(addressFamily: NIOBSDSocket.AddressFamily, addressBytes: UnsafeRawPointer, addressDescription: UnsafeMutablePointer, addressDescriptionLength: socklen_t) throws -> UnsafePointer { + internal static func inet_ntop( + addressFamily: NIOBSDSocket.AddressFamily, + addressBytes: UnsafeRawPointer, + addressDescription: UnsafeMutablePointer, + addressDescriptionLength: socklen_t + ) throws -> UnsafePointer { #if os(Windows) // TODO(compnerd) use `InetNtopW` to ensure that we handle unicode properly - guard let result = WinSDK.inet_ntop(addressFamily.rawValue, addressBytes, - addressDescription, - Int(addressDescriptionLength)) else { + guard + let result = WinSDK.inet_ntop( + addressFamily.rawValue, + addressBytes, + addressDescription, + Int(addressDescriptionLength) + ) + else { throw IOError(windows: GetLastError(), reason: "inet_ntop") } return result diff --git a/Sources/NIOCore/ByteBuffer-aux.swift b/Sources/NIOCore/ByteBuffer-aux.swift index 7765394651..fcfa74ce3b 100644 --- a/Sources/NIOCore/ByteBuffer-aux.swift +++ b/Sources/NIOCore/ByteBuffer-aux.swift @@ -36,7 +36,7 @@ extension ByteBuffer { // this is not technically correct because we shouldn't just bind // the memory to `UInt8` but it's not a real issue either and we // need to work around https://bugs.swift.org/browse/SR-9604 - Array(UnsafeRawBufferPointer(fastRebase: ptr[range]).bindMemory(to: UInt8.self)) + [UInt8](UnsafeRawBufferPointer(fastRebase: ptr[range]).bindMemory(to: UInt8.self)) } } @@ -79,8 +79,13 @@ extension ByteBuffer { @inlinable public mutating func setStaticString(_ string: StaticString, at index: Int) -> Int { // please do not replace the code below with code that uses `string.withUTF8Buffer { ... }` (see SR-7541) - return self.setBytes(UnsafeRawBufferPointer(start: string.utf8Start, - count: string.utf8CodeUnitCount), at: index) + self.setBytes( + UnsafeRawBufferPointer( + start: string.utf8Start, + count: string.utf8CodeUnitCount + ), + at: index + ) } // MARK: String APIs @@ -96,7 +101,7 @@ extension ByteBuffer { self._moveWriterIndex(forwardBy: written) return written } - + /// Write `string` into this `ByteBuffer` null terminated using UTF-8 encoding, moving the writer index forward appropriately. /// /// - parameters: @@ -144,7 +149,7 @@ extension ByteBuffer { return self._setStringSlowpath(string, at: index) } } - + /// Write `string` null terminated into this `ByteBuffer` at `index` using UTF-8 encoding. Does not move the writer index. /// /// - parameters: @@ -176,7 +181,7 @@ extension ByteBuffer { return String(decoding: UnsafeRawBufferPointer(fastRebase: pointer[range]), as: Unicode.UTF8.self) } } - + /// Get the string at `index` from this `ByteBuffer` decoding using the UTF-8 encoding. Does not move the reader index. /// The selected bytes must be readable or else `nil` will be returned. /// @@ -216,7 +221,7 @@ extension ByteBuffer { self._moveReaderIndex(forwardBy: length) return result } - + /// Read a null terminated string off this `ByteBuffer`, decoding it as `String` using the UTF-8 encoding. Move the reader index /// forward by the string's length and its null terminator. /// @@ -228,7 +233,7 @@ extension ByteBuffer { return nil } let result = self.readString(length: stringLength) - self.moveReaderIndex(forwardBy: 1) // move forward by null terminator + self.moveReaderIndex(forwardBy: 1) // move forward by null terminator return result } @@ -245,7 +250,7 @@ extension ByteBuffer { self._moveWriterIndex(forwardBy: written) return written } - + /// Write `substring` into this `ByteBuffer` at `index` using UTF-8 encoding. Does not move the writer index. /// /// - parameters: @@ -267,7 +272,7 @@ extension ByteBuffer { return self.setString(String(substring), at: index) } } - + // MARK: DispatchData APIs /// Write `dispatchData` into this `ByteBuffer`, moving the writer index forward appropriately. /// @@ -295,7 +300,7 @@ extension ByteBuffer { self.reserveCapacity(index + allBytesCount) self.withVeryUnsafeMutableBytes { destCompleteStorage in assert(destCompleteStorage.count >= index + allBytesCount) - let dest = destCompleteStorage[index ..< index + allBytesCount] + let dest = destCompleteStorage[index..(_ body: (UnsafeRawBufferPointer) throws -> (Int, T)) rethrows -> T { - let (bytesRead, ret) = try self.withUnsafeReadableBytes({ try body($0) }) - self._moveReaderIndex(forwardBy: bytesRead) - return ret - } - /// Yields a mutable buffer pointer containing this `ByteBuffer`'s readable bytes. You may modify the yielded bytes. /// Will move the reader index by the number of bytes returned by `body` but leave writer index as it was. /// @@ -377,27 +366,14 @@ extension ByteBuffer { /// - returns: The number of bytes read. @discardableResult @inlinable - public mutating func readWithUnsafeMutableReadableBytes(_ body: (UnsafeMutableRawBufferPointer) throws -> Int) rethrows -> Int { + public mutating func readWithUnsafeMutableReadableBytes( + _ body: (UnsafeMutableRawBufferPointer) throws -> Int + ) rethrows -> Int { let bytesRead = try self.withUnsafeMutableReadableBytes({ try body($0) }) self._moveReaderIndex(forwardBy: bytesRead) return bytesRead } - /// Yields a mutable buffer pointer containing this `ByteBuffer`'s readable bytes. You may modify the yielded bytes. - /// Will move the reader index by the number of bytes `body` returns in the first tuple component but leave writer index as it was. - /// - /// - warning: Do not escape the pointer from the closure for later use. - /// - /// - parameters: - /// - body: The closure that will accept the yielded bytes and returns the number of bytes it processed along with some other value. - /// - returns: The value `body` returned in the second tuple component. - @inlinable - public mutating func readWithUnsafeMutableReadableBytes(_ body: (UnsafeMutableRawBufferPointer) throws -> (Int, T)) rethrows -> T { - let (bytesRead, ret) = try self.withUnsafeMutableReadableBytes({ try body($0) }) - self._moveReaderIndex(forwardBy: bytesRead) - return ret - } - /// Copy `buffer`'s readable bytes into this `ByteBuffer` starting at `index`. Does not move any of the reader or writer indices. /// /// - parameters: @@ -407,7 +383,7 @@ extension ByteBuffer { @discardableResult @available(*, deprecated, renamed: "setBuffer(_:at:)") public mutating func set(buffer: ByteBuffer, at index: Int) -> Int { - return self.setBuffer(buffer, at: index) + self.setBuffer(buffer, at: index) } /// Copy `buffer`'s readable bytes into this `ByteBuffer` starting at `index`. Does not move any of the reader or writer indices. @@ -419,7 +395,7 @@ extension ByteBuffer { @discardableResult @inlinable public mutating func setBuffer(_ buffer: ByteBuffer, at index: Int) -> Int { - return buffer.withUnsafeReadableBytes{ p in + buffer.withUnsafeReadableBytes { p in self.setBytes(p, at: index) } } @@ -464,7 +440,7 @@ extension ByteBuffer { self._moveWriterIndex(forwardBy: written) return written } - + /// Writes `byte` `count` times. Moves the writer index forward by the number of bytes written. /// /// - parameter byte: The `UInt8` byte to repeat. @@ -477,7 +453,7 @@ extension ByteBuffer { self._moveWriterIndex(forwardBy: written) return written } - + /// Sets the given `byte` `count` times at the given `index`. Will reserve more memory if necessary. Does not move the writer index. /// /// - parameter byte: The `UInt8` byte to repeat. @@ -489,7 +465,7 @@ extension ByteBuffer { precondition(count >= 0, "Can't write fewer than 0 bytes") self.reserveCapacity(index + count) self.withVeryUnsafeMutableBytes { pointer in - let dest = UnsafeMutableRawBufferPointer(fastRebase: pointer[index ..< index+count]) + let dest = UnsafeMutableRawBufferPointer(fastRebase: pointer[index.. ByteBuffer { - return self.getSlice(at: self.readerIndex, length: self.readableBytes)! // must work, bytes definitely in the buffer + // must work, bytes definitely in the buffer// must work, bytes definitely in the buffer + self.getSlice(at: self.readerIndex, length: self.readableBytes)! } /// Slice `length` bytes off this `ByteBuffer` and move the reader index forward by `length`. @@ -536,6 +513,43 @@ extension ByteBuffer { } } +// swift-format-ignore: AmbiguousTrailingClosureOverload +extension ByteBuffer { + /// Yields a mutable buffer pointer containing this `ByteBuffer`'s readable bytes. You may modify the yielded bytes. + /// Will move the reader index by the number of bytes `body` returns in the first tuple component but leave writer index as it was. + /// + /// - warning: Do not escape the pointer from the closure for later use. + /// + /// - parameters: + /// - body: The closure that will accept the yielded bytes and returns the number of bytes it processed along with some other value. + /// - returns: The value `body` returned in the second tuple component. + @inlinable + public mutating func readWithUnsafeMutableReadableBytes( + _ body: (UnsafeMutableRawBufferPointer) throws -> (Int, T) + ) rethrows -> T { + let (bytesRead, ret) = try self.withUnsafeMutableReadableBytes({ try body($0) }) + self._moveReaderIndex(forwardBy: bytesRead) + return ret + } + + /// Yields an immutable buffer pointer containing this `ByteBuffer`'s readable bytes. Will move the reader index + /// by the number of bytes `body` returns in the first tuple component. + /// + /// - warning: Do not escape the pointer from the closure for later use. + /// + /// - parameters: + /// - body: The closure that will accept the yielded bytes and returns the number of bytes it processed along with some other value. + /// - returns: The value `body` returned in the second tuple component. + @inlinable + public mutating func readWithUnsafeReadableBytes( + _ body: (UnsafeRawBufferPointer) throws -> (Int, T) + ) rethrows -> T { + let (bytesRead, ret) = try self.withUnsafeReadableBytes({ try body($0) }) + self._moveReaderIndex(forwardBy: bytesRead) + return ret + } +} + extension ByteBuffer { /// Return an empty `ByteBuffer` allocated with `ByteBufferAllocator()`. /// @@ -749,9 +763,11 @@ extension ByteBufferAllocator { /// /// - returns: The `ByteBuffer` containing the written bytes. @inlinable - public func buffer(integer: I, - endianness: Endianness = .big, - as: I.Type = I.self) -> ByteBuffer { + public func buffer( + integer: I, + endianness: Endianness = .big, + as: I.Type = I.self + ) -> ByteBuffer { var buffer = self.buffer(capacity: MemoryLayout.size) buffer.writeInteger(integer, endianness: endianness, as: `as`) return buffer @@ -799,7 +815,6 @@ extension ByteBufferAllocator { } } - extension Optional where Wrapped == ByteBuffer { /// If `nil`, replace `self` with `.some(buffer)`. If non-`nil`, write `buffer`'s readable bytes into the /// `ByteBuffer` starting at `writerIndex`. diff --git a/Sources/NIOCore/ByteBuffer-conversions.swift b/Sources/NIOCore/ByteBuffer-conversions.swift index ef197aefe3..d1f4f0eef3 100644 --- a/Sources/NIOCore/ByteBuffer-conversions.swift +++ b/Sources/NIOCore/ByteBuffer-conversions.swift @@ -15,7 +15,7 @@ import Dispatch extension Array where Element == UInt8 { - + /// Creates a `[UInt8]` from the given buffer. The entire readable portion of the buffer will be read. /// - parameter buffer: The buffer to read. @inlinable @@ -23,11 +23,11 @@ extension Array where Element == UInt8 { var buffer = buffer self = buffer.readBytes(length: buffer.readableBytes)! } - + } extension String { - + /// Creates a `String` from a given `ByteBuffer`. The entire readable portion of the buffer will be read. /// - parameter buffer: The buffer to read. @inlinable @@ -49,7 +49,7 @@ extension String { } extension DispatchData { - + /// Creates a `DispatchData` from a given `ByteBuffer`. The entire readable portion of the buffer will be read. /// - parameter buffer: The buffer to read. @inlinable @@ -57,5 +57,5 @@ extension DispatchData { var buffer = buffer self = buffer.readDispatchData(length: buffer.readableBytes)! } - + } diff --git a/Sources/NIOCore/ByteBuffer-core.swift b/Sources/NIOCore/ByteBuffer-core.swift index 96ad63c028..8027fcc637 100644 --- a/Sources/NIOCore/ByteBuffer-core.swift +++ b/Sources/NIOCore/ByteBuffer-core.swift @@ -25,7 +25,8 @@ import Musl #endif @usableFromInline let sysMalloc: @convention(c) (size_t) -> UnsafeMutableRawPointer? = malloc -@usableFromInline let sysRealloc: @convention(c) (UnsafeMutableRawPointer?, size_t) -> UnsafeMutableRawPointer? = realloc +@usableFromInline let sysRealloc: @convention(c) (UnsafeMutableRawPointer?, size_t) -> UnsafeMutableRawPointer? = + realloc /// Xcode 13 GM shipped with a bug in the SDK that caused `free`'s first argument to be annotated as /// non-nullable. To that end, we define a thunk through to `free` that matches that constraint, as we @@ -42,19 +43,19 @@ struct _ByteBufferSlice: Sendable { @usableFromInline private(set) var upperBound: ByteBuffer._Index @usableFromInline private(set) var _begin: _UInt24 @inlinable var lowerBound: ByteBuffer._Index { - return UInt32(self._begin) + UInt32(self._begin) } @inlinable var count: Int { // Safe: the only constructors that set this enforce that upperBound > lowerBound, so // this cannot underflow. - return Int(self.upperBound &- self.lowerBound) + Int(self.upperBound &- self.lowerBound) } @inlinable init() { self._begin = .init(0) self.upperBound = .init(0) } @inlinable static var maxSupportedLowerBound: ByteBuffer._Index { - return ByteBuffer._Index(_UInt24.max) + ByteBuffer._Index(_UInt24.max) } } @@ -68,7 +69,7 @@ extension _ByteBufferSlice { extension _ByteBufferSlice: CustomStringConvertible { @usableFromInline var description: String { - return "_ByteBufferSlice { \(self.lowerBound)..<\(self.upperBound) }" + "_ByteBufferSlice { \(self.lowerBound)..<\(self.upperBound) }" } } @@ -82,17 +83,21 @@ public struct ByteBufferAllocator: Sendable { /// therefore it's recommended to reuse `ByteBufferAllocators` where possible instead of creating fresh ones in /// many places. @inlinable public init() { - self.init(hookedMalloc: { sysMalloc($0) }, - hookedRealloc: { sysRealloc($0, $1) }, - hookedFree: { sysFree($0) }, - hookedMemcpy: { $0.copyMemory(from: $1, byteCount: $2) }) + self.init( + hookedMalloc: { sysMalloc($0) }, + hookedRealloc: { sysRealloc($0, $1) }, + hookedFree: { sysFree($0) }, + hookedMemcpy: { $0.copyMemory(from: $1, byteCount: $2) } + ) } @inlinable - internal init(hookedMalloc: @escaping @convention(c) (size_t) -> UnsafeMutableRawPointer?, - hookedRealloc: @escaping @convention(c) (UnsafeMutableRawPointer?, size_t) -> UnsafeMutableRawPointer?, - hookedFree: @escaping @convention(c) (UnsafeMutableRawPointer) -> Void, - hookedMemcpy: @escaping @convention(c) (UnsafeMutableRawPointer, UnsafeRawPointer, size_t) -> Void) { + internal init( + hookedMalloc: @escaping @convention(c) (size_t) -> UnsafeMutableRawPointer?, + hookedRealloc: @escaping @convention(c) (UnsafeMutableRawPointer?, size_t) -> UnsafeMutableRawPointer?, + hookedFree: @escaping @convention(c) (UnsafeMutableRawPointer) -> Void, + hookedMemcpy: @escaping @convention(c) (UnsafeMutableRawPointer, UnsafeRawPointer, size_t) -> Void + ) { self.malloc = hookedMalloc self.realloc = hookedRealloc self.free = hookedFree @@ -118,20 +123,24 @@ public struct ByteBufferAllocator: Sendable { } @usableFromInline - internal static let zeroCapacityWithDefaultAllocator = ByteBuffer(allocator: ByteBufferAllocator(), startingCapacity: 0) + internal static let zeroCapacityWithDefaultAllocator = ByteBuffer( + allocator: ByteBufferAllocator(), + startingCapacity: 0 + ) @usableFromInline internal let malloc: @convention(c) (size_t) -> UnsafeMutableRawPointer? - @usableFromInline internal let realloc: @convention(c) (UnsafeMutableRawPointer?, size_t) -> UnsafeMutableRawPointer? + @usableFromInline internal let realloc: + @convention(c) (UnsafeMutableRawPointer?, size_t) -> UnsafeMutableRawPointer? @usableFromInline internal let free: @convention(c) (UnsafeMutableRawPointer) -> Void @usableFromInline internal let memcpy: @convention(c) (UnsafeMutableRawPointer, UnsafeRawPointer, size_t) -> Void } @inlinable func _toCapacity(_ value: Int) -> ByteBuffer._Capacity { - return ByteBuffer._Capacity(truncatingIfNeeded: value) + ByteBuffer._Capacity(truncatingIfNeeded: value) } @inlinable func _toIndex(_ value: Int) -> ByteBuffer._Index { - return ByteBuffer._Index(truncatingIfNeeded: value) + ByteBuffer._Index(truncatingIfNeeded: value) } /// `ByteBuffer` stores contiguously allocated raw bytes. It is a random and sequential accessible sequence of zero or @@ -283,28 +292,30 @@ public struct ByteBuffer { @inlinable var fullSlice: _ByteBufferSlice { - return _ByteBufferSlice(0.. UnsafeMutableRawPointer { let ptr = allocator.malloc(size_t(bytes))! - /* bind the memory so we can assume it elsewhere to be bound to UInt8 */ + // bind the memory so we can assume it elsewhere to be bound to UInt8 ptr.bindMemory(to: UInt8.self, capacity: Int(bytes)) return ptr } @inlinable func allocateStorage() -> _Storage { - return self.allocateStorage(capacity: self.capacity) + self.allocateStorage(capacity: self.capacity) } @inlinable func allocateStorage(capacity: _Capacity) -> _Storage { let newCapacity = capacity == 0 ? 0 : capacity.nextPowerOf2ClampedToMax() - return _Storage(bytesNoCopy: _Storage._allocateAndPrepareRawMemory(bytes: newCapacity, allocator: self.allocator), - capacity: newCapacity, - allocator: self.allocator) + return _Storage( + bytesNoCopy: _Storage._allocateAndPrepareRawMemory(bytes: newCapacity, allocator: self.allocator), + capacity: newCapacity, + allocator: self.allocator + ) } @inlinable @@ -319,7 +330,7 @@ public struct ByteBuffer { func reallocStorage(capacity minimumNeededCapacity: _Capacity) { let newCapacity = minimumNeededCapacity.nextPowerOf2ClampedToMax() let ptr = self.allocator.realloc(self.bytes, size_t(newCapacity))! - /* bind the memory so we can assume it elsewhere to be bound to UInt8 */ + // bind the memory so we can assume it elsewhere to be bound to UInt8 ptr.bindMemory(to: UInt8.self, capacity: Int(newCapacity)) self.bytes = ptr self.capacity = newCapacity @@ -333,15 +344,17 @@ public struct ByteBuffer { static func reallocated(minimumCapacity: _Capacity, allocator: Allocator) -> _Storage { let newCapacity = minimumCapacity == 0 ? 0 : minimumCapacity.nextPowerOf2ClampedToMax() // TODO: Use realloc if possible - return _Storage(bytesNoCopy: _Storage._allocateAndPrepareRawMemory(bytes: newCapacity, allocator: allocator), - capacity: newCapacity, - allocator: allocator) + return _Storage( + bytesNoCopy: _Storage._allocateAndPrepareRawMemory(bytes: newCapacity, allocator: allocator), + capacity: newCapacity, + allocator: allocator + ) } func dumpBytes(slice: Slice, offset: Int, length: Int) -> String { var desc = "[" let bytes = UnsafeRawBufferPointer(start: self.bytes, count: Int(self.capacity)) - for byte in bytes[Int(slice.lowerBound) + offset ..< Int(slice.lowerBound) + offset + length] { + for byte in bytes[Int(slice.lowerBound) + offset..= 0, "illegal slice: negative lower bound: \(self._slice.lowerBound)") - assert(self._slice.upperBound <= self._storage.capacity, "illegal slice: upper bound (\(self._slice.upperBound)) exceeds capacity: \(self._storage.capacity)") + assert( + self._slice.upperBound <= self._storage.capacity, + "illegal slice: upper bound (\(self._slice.upperBound)) exceeds capacity: \(self._storage.capacity)" + ) } // MARK: Internal API @@ -487,10 +505,13 @@ public struct ByteBuffer { @inline(never) @inlinable @_specialize(where Bytes == CircularBuffer) - mutating func _setSlowPath(bytes: Bytes, at index: _Index) -> _Capacity where Bytes.Element == UInt8 { + mutating func _setSlowPath(bytes: Bytes, at index: _Index) -> _Capacity + where Bytes.Element == UInt8 { func ensureCapacityAndReturnStorageBase(capacity: Int) -> UnsafeMutablePointer { self._ensureAvailableCapacity(_Capacity(capacity), at: index) - let newBytesPtr = UnsafeMutableRawBufferPointer(fastRebase: self._slicedStorageBuffer[Int(index) ..< Int(index) + Int(capacity)]) + let newBytesPtr = UnsafeMutableRawBufferPointer( + fastRebase: self._slicedStorageBuffer[Int(index)..(_ bytes: Bytes, at index: _Index) -> _Capacity where Bytes.Element == UInt8 { + mutating func _setBytes(_ bytes: Bytes, at index: _Index) -> _Capacity + where Bytes.Element == UInt8 { if let written = bytes.withContiguousStorageIfAvailable({ bytes in self._setBytes(UnsafeRawBufferPointer(bytes), at: index) }) { @@ -537,20 +561,20 @@ public struct ByteBuffer { /// trigger a copy of the bytes. @inlinable public var writableBytes: Int { // this cannot over/overflow because both values are positive and writerIndex<=slice.count, checked on ingestion - return Int(_toCapacity(self._slice.count) &- self._writerIndex) + Int(_toCapacity(self._slice.count) &- self._writerIndex) } /// The number of bytes readable (`readableBytes` = `writerIndex` - `readerIndex`). @inlinable public var readableBytes: Int { // this cannot under/overflow because both are positive and writer >= reader (checked on ingestion of bytes). - return Int(self._writerIndex &- self._readerIndex) + Int(self._writerIndex &- self._readerIndex) } /// The current capacity of the storage of this `ByteBuffer`, this is not constant and does _not_ signify the number /// of bytes that have been written to this `ByteBuffer`. @inlinable public var capacity: Int { - return self._slice.count + self._slice.count } /// The current capacity of the underlying storage of this `ByteBuffer`. @@ -558,7 +582,7 @@ public struct ByteBuffer { /// buffer until new data is written. @inlinable public var storageCapacity: Int { - return self._storage.fullSlice.count + self._storage.fullSlice.count } /// Reserves enough space to store the specified number of bytes. @@ -597,7 +621,7 @@ public struct ByteBuffer { /// - Parameter minimumWritableBytes: The minimum number of writable bytes this buffer must have. @inlinable public mutating func reserveCapacity(minimumWritableBytes: Int) { - return self.reserveCapacity(self.writerIndex + minimumWritableBytes) + self.reserveCapacity(self.writerIndex + minimumWritableBytes) } @inlinable @@ -610,8 +634,10 @@ public struct ByteBuffer { @inlinable var _slicedStorageBuffer: UnsafeMutableRawBufferPointer { - return UnsafeMutableRawBufferPointer(start: self._storage.bytes.advanced(by: Int(self._slice.lowerBound)), - count: self._slice.count) + UnsafeMutableRawBufferPointer( + start: self._storage.bytes.advanced(by: Int(self._slice.lowerBound)), + count: self._slice.count + ) } /// Yields a mutable buffer pointer containing this `ByteBuffer`'s readable bytes. You may modify those bytes. @@ -622,7 +648,9 @@ public struct ByteBuffer { /// - body: The closure that will accept the yielded bytes. /// - returns: The value returned by `body`. @inlinable - public mutating func withUnsafeMutableReadableBytes(_ body: (UnsafeMutableRawBufferPointer) throws -> T) rethrows -> T { + public mutating func withUnsafeMutableReadableBytes( + _ body: (UnsafeMutableRawBufferPointer) throws -> T + ) rethrows -> T { self._copyStorageAndRebaseIfNeeded() // this is safe because we always know that readerIndex >= writerIndex let range = Range(uncheckedBounds: (lower: self.readerIndex, upper: self.writerIndex)) @@ -640,7 +668,9 @@ public struct ByteBuffer { /// - body: The closure that will accept the yielded bytes and return the number of bytes written. /// - returns: The number of bytes written. @inlinable - public mutating func withUnsafeMutableWritableBytes(_ body: (UnsafeMutableRawBufferPointer) throws -> T) rethrows -> T { + public mutating func withUnsafeMutableWritableBytes( + _ body: (UnsafeMutableRawBufferPointer) throws -> T + ) rethrows -> T { self._copyStorageAndRebaseIfNeeded() return try body(.init(fastRebase: self._slicedStorageBuffer.dropFirst(self.writerIndex))) } @@ -655,7 +685,10 @@ public struct ByteBuffer { /// - returns: The number of bytes written. @discardableResult @inlinable - public mutating func writeWithUnsafeMutableBytes(minimumWritableBytes: Int, _ body: (UnsafeMutableRawBufferPointer) throws -> Int) rethrows -> Int { + public mutating func writeWithUnsafeMutableBytes( + minimumWritableBytes: Int, + _ body: (UnsafeMutableRawBufferPointer) throws -> Int + ) rethrows -> Int { if minimumWritableBytes > 0 { self.reserveCapacity(minimumWritableBytes: minimumWritableBytes) } @@ -664,11 +697,18 @@ public struct ByteBuffer { return bytesWritten } - @available(*, deprecated, message: "please use writeWithUnsafeMutableBytes(minimumWritableBytes:_:) instead to ensure sufficient write capacity.") + @available( + *, + deprecated, + message: + "please use writeWithUnsafeMutableBytes(minimumWritableBytes:_:) instead to ensure sufficient write capacity." + ) @discardableResult @inlinable - public mutating func writeWithUnsafeMutableBytes(_ body: (UnsafeMutableRawBufferPointer) throws -> Int) rethrows -> Int { - return try self.writeWithUnsafeMutableBytes(minimumWritableBytes: 0, { try body($0) }) + public mutating func writeWithUnsafeMutableBytes( + _ body: (UnsafeMutableRawBufferPointer) throws -> Int + ) rethrows -> Int { + try self.writeWithUnsafeMutableBytes(minimumWritableBytes: 0, { try body($0) }) } /// This vends a pointer to the storage of the `ByteBuffer`. It's marked as _very unsafe_ because it might contain @@ -677,7 +717,7 @@ public struct ByteBuffer { /// - warning: Do not escape the pointer from the closure for later use. @inlinable public func withVeryUnsafeBytes(_ body: (UnsafeRawBufferPointer) throws -> T) rethrows -> T { - return try body(.init(self._slicedStorageBuffer)) + try body(.init(self._slicedStorageBuffer)) } /// This vends a pointer to the storage of the `ByteBuffer`. It's marked as _very unsafe_ because it might contain @@ -685,8 +725,10 @@ public struct ByteBuffer { /// /// - warning: Do not escape the pointer from the closure for later use. @inlinable - public mutating func withVeryUnsafeMutableBytes(_ body: (UnsafeMutableRawBufferPointer) throws -> T) rethrows -> T { - self._copyStorageAndRebaseIfNeeded() // this will trigger a CoW if necessary + public mutating func withVeryUnsafeMutableBytes( + _ body: (UnsafeMutableRawBufferPointer) throws -> T + ) rethrows -> T { + self._copyStorageAndRebaseIfNeeded() // this will trigger a CoW if necessary return try body(.init(self._slicedStorageBuffer)) } @@ -716,7 +758,9 @@ public struct ByteBuffer { /// - body: The closure that will accept the yielded bytes and the `storageManagement`. /// - returns: The value returned by `body`. @inlinable - public func withUnsafeReadableBytesWithStorageManagement(_ body: (UnsafeRawBufferPointer, Unmanaged) throws -> T) rethrows -> T { + public func withUnsafeReadableBytesWithStorageManagement( + _ body: (UnsafeRawBufferPointer, Unmanaged) throws -> T + ) rethrows -> T { let storageReference: Unmanaged = Unmanaged.passUnretained(self._storage) // This is safe, writerIndex >= readerIndex let range = Range(uncheckedBounds: (lower: self.readerIndex, upper: self.writerIndex)) @@ -725,7 +769,9 @@ public struct ByteBuffer { /// See `withUnsafeReadableBytesWithStorageManagement` and `withVeryUnsafeBytes`. @inlinable - public func withVeryUnsafeBytesWithStorageManagement(_ body: (UnsafeRawBufferPointer, Unmanaged) throws -> T) rethrows -> T { + public func withVeryUnsafeBytesWithStorageManagement( + _ body: (UnsafeRawBufferPointer, Unmanaged) throws -> T + ) rethrows -> T { let storageReference: Unmanaged = Unmanaged.passUnretained(self._storage) return try body(.init(self._slicedStorageBuffer), storageReference) } @@ -754,13 +800,16 @@ public struct ByteBuffer { /// not readable in the initial `ByteBuffer`. @inlinable public func getSlice(at index: Int, length: Int) -> ByteBuffer? { - return self.getSlice_inlineAlways(at: index, length: length) + self.getSlice_inlineAlways(at: index, length: length) } @inline(__always) @inlinable internal func getSlice_inlineAlways(at index: Int, length: Int) -> ByteBuffer? { - guard index >= 0 && length >= 0 && index >= self.readerIndex && length <= self.writerIndex && index <= self.writerIndex &- length else { + guard + index >= 0 && length >= 0 && index >= self.readerIndex && length <= self.writerIndex + && index <= self.writerIndex &- length + else { return nil } let index = _toIndex(index) @@ -789,8 +838,14 @@ public struct ByteBuffer { // 2. `length` <= `self.writerIndex` (see `guard`s) // 3. `sliceStartIndex` + `self._slice.count` is always safe (because that's `self._slice.upperBound`. // - The range construction is safe because `length` >= 0 (see `guard` at the beginning of the function). - new._slice = _ByteBufferSlice(Range(uncheckedBounds: (lower: sliceStartIndex, - upper: sliceStartIndex &+ length))) + new._slice = _ByteBufferSlice( + Range( + uncheckedBounds: ( + lower: sliceStartIndex, + upper: sliceStartIndex &+ length + ) + ) + ) new._moveReaderIndex(to: 0) new._moveWriterIndex(to: length) return new @@ -815,8 +870,10 @@ public struct ByteBuffer { if isKnownUniquelyReferenced(&self._storage) { self._storage.bytes.advanced(by: Int(self._slice.lowerBound)) - .copyMemory(from: self._storage.bytes.advanced(by: Int(self._slice.lowerBound + self._readerIndex)), - byteCount: self.readableBytes) + .copyMemory( + from: self._storage.bytes.advanced(by: Int(self._slice.lowerBound + self._readerIndex)), + byteCount: self.readableBytes + ) let indexShift = self._readerIndex self._moveReaderIndex(to: 0) self._moveWriterIndex(to: self._writerIndex - indexShift) @@ -830,14 +887,14 @@ public struct ByteBuffer { /// newly allocated `ByteBuffer`. @inlinable public var readerIndex: Int { - return Int(self._readerIndex) + Int(self._readerIndex) } /// The write index or the number of bytes previously written to this `ByteBuffer`. `writerIndex` is `0` for a /// newly allocated `ByteBuffer`. @inlinable public var writerIndex: Int { - return Int(self._writerIndex) + Int(self._writerIndex) } /// Set both reader index and writer index to `0`. This will reset the state of this `ByteBuffer` to the state @@ -869,7 +926,7 @@ public struct ByteBuffer { public mutating func clear(minimumCapacity: UInt32) { self.clear(minimumCapacity: Int(minimumCapacity)) } - + /// Set both reader index and writer index to `0`. This will reset the state of this `ByteBuffer` to the state /// of a freshly allocated one, if possible without allocations. This is the cheapest way to recycle a `ByteBuffer` /// for a new use-case. @@ -883,7 +940,7 @@ public struct ByteBuffer { public mutating func clear(minimumCapacity: Int) { precondition(minimumCapacity >= 0, "Cannot have a minimum capacity < 0") precondition(minimumCapacity <= _Capacity.max, "Minimum capacity must be <= \(_Capacity.max)") - + let minimumCapacity = _Capacity(minimumCapacity) if !isKnownUniquelyReferenced(&self._storage) { self._storage = self._storage.allocateStorage(capacity: minimumCapacity) @@ -906,7 +963,7 @@ extension ByteBuffer: CustomStringConvertible, CustomDebugStringConvertible { /// /// - returns: A description of this `ByteBuffer`. public var description: String { - return """ + """ ByteBuffer { \ readerIndex: \(self.readerIndex), \ writerIndex: \(self.writerIndex), \ @@ -928,24 +985,24 @@ extension ByteBuffer: CustomStringConvertible, CustomDebugStringConvertible { /// /// - returns: A description of this `ByteBuffer` useful for debugging. public var debugDescription: String { - return "\(self.description)\nreadable bytes (max 1k): \(self._storage.dumpBytes(slice: self._slice, offset: self.readerIndex, length: min(1024, self.readableBytes)))" + "\(self.description)\nreadable bytes (max 1k): \(self._storage.dumpBytes(slice: self._slice, offset: self.readerIndex, length: min(1024, self.readableBytes)))" } } -/* change types to the user visible `Int` */ +// change types to the user visible `Int` extension ByteBuffer { /// Copy the collection of `bytes` into the `ByteBuffer` at `index`. Does not move the writer index. @discardableResult @inlinable public mutating func setBytes(_ bytes: Bytes, at index: Int) -> Int where Bytes.Element == UInt8 { - return Int(self._setBytes(bytes, at: _toIndex(index))) + Int(self._setBytes(bytes, at: _toIndex(index))) } /// Copy `bytes` into the `ByteBuffer` at `index`. Does not move the writer index. @discardableResult @inlinable public mutating func setBytes(_ bytes: UnsafeRawBufferPointer, at index: Int) -> Int { - return Int(self._setBytes(bytes, at: _toIndex(index))) + Int(self._setBytes(bytes, at: _toIndex(index))) } /// Move the reader index forward by `offset` bytes. @@ -958,7 +1015,10 @@ extension ByteBuffer { @inlinable public mutating func moveReaderIndex(forwardBy offset: Int) { let newIndex = self._readerIndex + _toIndex(offset) - precondition(newIndex >= 0 && newIndex <= writerIndex, "new readerIndex: \(newIndex), expected: range(0, \(writerIndex))") + precondition( + newIndex >= 0 && newIndex <= writerIndex, + "new readerIndex: \(newIndex), expected: range(0, \(writerIndex))" + ) self._moveReaderIndex(to: newIndex) } @@ -972,7 +1032,10 @@ extension ByteBuffer { @inlinable public mutating func moveReaderIndex(to offset: Int) { let newIndex = _toIndex(offset) - precondition(newIndex >= 0 && newIndex <= writerIndex, "new readerIndex: \(newIndex), expected: range(0, \(writerIndex))") + precondition( + newIndex >= 0 && newIndex <= writerIndex, + "new readerIndex: \(newIndex), expected: range(0, \(writerIndex))" + ) self._moveReaderIndex(to: newIndex) } @@ -986,7 +1049,10 @@ extension ByteBuffer { @inlinable public mutating func moveWriterIndex(forwardBy offset: Int) { let newIndex = self._writerIndex + _toIndex(offset) - precondition(newIndex >= 0 && newIndex <= _toCapacity(self._slice.count),"new writerIndex: \(newIndex), expected: range(0, \(_toCapacity(self._slice.count)))") + precondition( + newIndex >= 0 && newIndex <= _toCapacity(self._slice.count), + "new writerIndex: \(newIndex), expected: range(0, \(_toCapacity(self._slice.count)))" + ) self._moveWriterIndex(to: newIndex) } @@ -1000,7 +1066,10 @@ extension ByteBuffer { @inlinable public mutating func moveWriterIndex(to offset: Int) { let newIndex = _toIndex(offset) - precondition(newIndex >= 0 && newIndex <= _toCapacity(self._slice.count),"new writerIndex: \(newIndex), expected: range(0, \(_toCapacity(self._slice.count)))") + precondition( + newIndex >= 0 && newIndex <= _toCapacity(self._slice.count), + "new writerIndex: \(newIndex), expected: range(0, \(_toCapacity(self._slice.count)))" + ) self._moveWriterIndex(to: newIndex) } } @@ -1061,11 +1130,11 @@ extension ByteBuffer { } } -extension ByteBuffer.CopyBytesError: Hashable { } +extension ByteBuffer.CopyBytesError: Hashable {} extension ByteBuffer.CopyBytesError: CustomDebugStringConvertible { public var debugDescription: String { - return String(describing: self.baseError) + String(describing: self.baseError) } } @@ -1074,7 +1143,7 @@ extension ByteBuffer: Equatable { /// Compare two `ByteBuffer` values. Two `ByteBuffer` values are considered equal if the readable bytes are equal. @inlinable - public static func ==(lhs: ByteBuffer, rhs: ByteBuffer) -> Bool { + public static func == (lhs: ByteBuffer, rhs: ByteBuffer) -> Bool { guard lhs.readableBytes == rhs.readableBytes else { return false } @@ -1143,7 +1212,7 @@ extension ByteBuffer { return nil } - let upperBound = indexFromReaderIndex &+ length // safe, can't overflow, we checked it above. + let upperBound = indexFromReaderIndex &+ length // safe, can't overflow, we checked it above. // uncheckedBounds is safe because `length` is >= 0, so the lower bound will always be lower/equal to upper return Range(uncheckedBounds: (lower: indexFromReaderIndex, upper: upperBound)) diff --git a/Sources/NIOCore/ByteBuffer-hexdump.swift b/Sources/NIOCore/ByteBuffer-hexdump.swift index de5ed26bf0..fa5dde44e8 100644 --- a/Sources/NIOCore/ByteBuffer-hexdump.swift +++ b/Sources/NIOCore/ByteBuffer-hexdump.swift @@ -131,7 +131,7 @@ extension ByteBuffer { result += String(repeating: " ", count: 60 - result.count) // Right column renders the 16 bytes line as ASCII characters, or "." if the character is not printable. - let printableRange = UInt8(ascii: " ") ..< UInt8(ascii: "~") + let printableRange = UInt8(ascii: " ").. String { - switch(format.value) { + switch format.value { case .plain(let maxBytes): if let maxBytes = maxBytes { return self.hexDumpPlain(maxBytes: maxBytes) @@ -263,4 +263,3 @@ extension ByteBuffer { } } } - diff --git a/Sources/NIOCore/ByteBuffer-int.swift b/Sources/NIOCore/ByteBuffer-int.swift index 4277ec6f6e..cc5d23de46 100644 --- a/Sources/NIOCore/ByteBuffer-int.swift +++ b/Sources/NIOCore/ByteBuffer-int.swift @@ -14,7 +14,7 @@ extension ByteBuffer { @inlinable - func _toEndianness (value: T, endianness: Endianness) -> T { + func _toEndianness(value: T, endianness: Endianness) -> T { switch endianness { case .little: return value.littleEndian @@ -48,7 +48,11 @@ extension ByteBuffer { /// - returns: An integer value deserialized from this `ByteBuffer` or `nil` if the bytes of interest are not /// readable. @inlinable - public func getInteger(at index: Int, endianness: Endianness = Endianness.big, as: T.Type = T.self) -> T? { + public func getInteger( + at index: Int, + endianness: Endianness = Endianness.big, + as: T.Type = T.self + ) -> T? { guard let range = self.rangeWithinReadableBytes(index: index, length: MemoryLayout.size) else { return nil } @@ -78,9 +82,11 @@ extension ByteBuffer { /// - returns: The number of bytes written. @discardableResult @inlinable - public mutating func writeInteger(_ integer: T, - endianness: Endianness = .big, - as: T.Type = T.self) -> Int { + public mutating func writeInteger( + _ integer: T, + endianness: Endianness = .big, + as: T.Type = T.self + ) -> Int { let bytesWritten = self.setInteger(integer, at: self.writerIndex, endianness: endianness) self._moveWriterIndex(forwardBy: bytesWritten) return Int(bytesWritten) @@ -96,7 +102,12 @@ extension ByteBuffer { /// - returns: The number of bytes written. @discardableResult @inlinable - public mutating func setInteger(_ integer: T, at index: Int, endianness: Endianness = .big, as: T.Type = T.self) -> Int { + public mutating func setInteger( + _ integer: T, + at index: Int, + endianness: Endianness = .big, + as: T.Type = T.self + ) -> Int { var value = _toEndianness(value: integer, endianness: endianness) return Swift.withUnsafeBytes(of: &value) { ptr in self.setBytes(ptr, at: index) @@ -166,7 +177,7 @@ public enum Endianness: Sendable { public static let host: Endianness = hostEndianness0() private static func hostEndianness0() -> Endianness { - let number: UInt32 = 0x12345678 + let number: UInt32 = 0x1234_5678 return number == number.bigEndian ? .big : .little } @@ -176,5 +187,3 @@ public enum Endianness: Sendable { /// little endian, the least significant byte (LSB) is at the lowest address case little } - - diff --git a/Sources/NIOCore/ByteBuffer-lengthPrefix.swift b/Sources/NIOCore/ByteBuffer-lengthPrefix.swift index 300717294f..2941e7c467 100644 --- a/Sources/NIOCore/ByteBuffer-lengthPrefix.swift +++ b/Sources/NIOCore/ByteBuffer-lengthPrefix.swift @@ -19,9 +19,13 @@ extension ByteBuffer { case messageCouldNotBeReadSuccessfully } private var baseError: BaseError - - public static let messageLengthDoesNotFitExactlyIntoRequiredIntegerFormat: LengthPrefixError = .init(baseError: .messageLengthDoesNotFitExactlyIntoRequiredIntegerFormat) - public static let messageCouldNotBeReadSuccessfully: LengthPrefixError = .init(baseError: .messageCouldNotBeReadSuccessfully) + + public static let messageLengthDoesNotFitExactlyIntoRequiredIntegerFormat: LengthPrefixError = .init( + baseError: .messageLengthDoesNotFitExactlyIntoRequiredIntegerFormat + ) + public static let messageCouldNotBeReadSuccessfully: LengthPrefixError = .init( + baseError: .messageCouldNotBeReadSuccessfully + ) } } @@ -41,44 +45,44 @@ extension ByteBuffer { writeMessage: (inout ByteBuffer) throws -> Int ) throws -> Int where Integer: FixedWidthInteger { var totalBytesWritten = 0 - + let lengthPrefixIndex = self.writerIndex // Write a zero as a placeholder which will later be overwritten by the actual number of bytes written totalBytesWritten += self.writeInteger(.zero, endianness: endianness, as: Integer.self) - + let startWriterIndex = self.writerIndex let messageLength = try writeMessage(&self) let endWriterIndex = self.writerIndex - + totalBytesWritten += messageLength - + let actualBytesWritten = endWriterIndex - startWriterIndex assert( - actualBytesWritten == messageLength, + actualBytesWritten == messageLength, "writeMessage returned \(messageLength) bytes, but actually \(actualBytesWritten) bytes were written, but they should be the same" ) - + guard let lengthPrefix = Integer(exactly: messageLength) else { throw LengthPrefixError.messageLengthDoesNotFitExactlyIntoRequiredIntegerFormat } - + self.setInteger(lengthPrefix, at: lengthPrefixIndex, endianness: endianness, as: Integer.self) - + return totalBytesWritten } } extension ByteBuffer { - /// Reads an `Integer` from `self`, reads a slice of that length and passes it to `readMessage`. + /// Reads an `Integer` from `self`, reads a slice of that length and passes it to `readMessage`. /// It is checked that `readMessage` returns a non-nil value. - /// + /// /// The `readerIndex` is **not** moved forward if the length prefix could not be read or `self` does not contain enough bytes. Otherwise `readerIndex` is moved forward even if `readMessage` throws or returns nil. /// - Parameters: /// - endianness: The endianness of the length prefix `Integer` in this `ByteBuffer` (defaults to big endian). /// - integer: the desired `Integer` type used to read the length prefix /// - readMessage: A closure that takes a `ByteBuffer` slice which contains the message after the length prefix /// - Throws: if `readMessage` returns nil - /// - Returns: `nil` if the length prefix could not be read, + /// - Returns: `nil` if the length prefix could not be read, /// the length prefix is negative or /// the buffer does not contain enough bytes to read a message of this length. /// Otherwise the result of `readMessage`. @@ -96,14 +100,14 @@ extension ByteBuffer { } return result } - + /// Reads an `Integer` from `self` and reads a slice of that length from `self` and returns it. - /// + /// /// If nil is returned, `readerIndex` is **not** moved forward. /// - Parameters: /// - endianness: The endianness of the length prefix `Integer` in this `ByteBuffer` (defaults to big endian). /// - integer: the desired `Integer` type used to read the length prefix - /// - Returns: `nil` if the length prefix could not be read, + /// - Returns: `nil` if the length prefix could not be read, /// the length prefix is negative or /// the buffer does not contain enough bytes to read a message of this length. /// Otherwise the message after the length prefix. @@ -112,19 +116,20 @@ extension ByteBuffer { endianness: Endianness = .big, as integer: Integer.Type ) -> ByteBuffer? where Integer: FixedWidthInteger { - guard let result = self.getLengthPrefixedSlice(at: self.readerIndex, endianness: endianness, as: Integer.self) else { + guard let result = self.getLengthPrefixedSlice(at: self.readerIndex, endianness: endianness, as: Integer.self) + else { return nil } self._moveReaderIndex(forwardBy: MemoryLayout.size + result.readableBytes) return result } - + /// Gets an `Integer` from `self` and gets a slice of that length from `self` and returns it. - /// + /// /// - Parameters: /// - endianness: The endianness of the length prefix `Integer` in this `ByteBuffer` (defaults to big endian). /// - integer: the desired `Integer` type used to get the length prefix - /// - Returns: `nil` if the length prefix could not be read, + /// - Returns: `nil` if the length prefix could not be read, /// the length prefix is negative or /// the buffer does not contain enough bytes to read a message of this length. /// Otherwise the message after the length prefix. @@ -135,12 +140,12 @@ extension ByteBuffer { as integer: Integer.Type ) -> ByteBuffer? where Integer: FixedWidthInteger { guard let lengthPrefix = self.getInteger(at: index, endianness: endianness, as: Integer.self), - let messageLength = Int(exactly: lengthPrefix), - let messageBuffer = self.getSlice(at: index + MemoryLayout.size, length: messageLength) + let messageLength = Int(exactly: lengthPrefix), + let messageBuffer = self.getSlice(at: index + MemoryLayout.size, length: messageLength) else { return nil } - + return messageBuffer } } diff --git a/Sources/NIOCore/ByteBuffer-multi-int.swift b/Sources/NIOCore/ByteBuffer-multi-int.swift index d602d947f4..175f7d3fc6 100644 --- a/Sources/NIOCore/ByteBuffer-multi-int.swift +++ b/Sources/NIOCore/ByteBuffer-multi-int.swift @@ -17,7 +17,10 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2).Type = (T1, T2).self) -> (T1, T2)? { + public mutating func readMultipleIntegers( + endianness: Endianness = .big, + as: (T1, T2).Type = (T1, T2).self + ) -> (T1, T2)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -30,7 +33,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -53,7 +56,12 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, endianness: Endianness = .big, as: (T1, T2).Type = (T1, T2).self) -> Int { + public mutating func writeMultipleIntegers( + _ value1: T1, + _ value2: T2, + endianness: Endianness = .big, + as: (T1, T2).Type = (T1, T2).self + ) -> Int { var v1: T1 var v2: T2 switch endianness { @@ -71,7 +79,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -83,7 +91,10 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3).Type = (T1, T2, T3).self) -> (T1, T2, T3)? { + public mutating func readMultipleIntegers( + endianness: Endianness = .big, + as: (T1, T2, T3).Type = (T1, T2, T3).self + ) -> (T1, T2, T3)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -98,7 +109,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -125,7 +136,13 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, endianness: Endianness = .big, as: (T1, T2, T3).Type = (T1, T2, T3).self) -> Int { + public mutating func writeMultipleIntegers( + _ value1: T1, + _ value2: T2, + _ value3: T3, + endianness: Endianness = .big, + as: (T1, T2, T3).Type = (T1, T2, T3).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -147,7 +164,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -161,7 +178,12 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4).Type = (T1, T2, T3, T4).self) -> (T1, T2, T3, T4)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger + >(endianness: Endianness = .big, as: (T1, T2, T3, T4).Type = (T1, T2, T3, T4).self) -> (T1, T2, T3, T4)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -178,7 +200,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -209,7 +231,19 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, endianness: Endianness = .big, as: (T1, T2, T3, T4).Type = (T1, T2, T3, T4).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + endianness: Endianness = .big, + as: (T1, T2, T3, T4).Type = (T1, T2, T3, T4).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -235,7 +269,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -251,7 +285,14 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5).Type = (T1, T2, T3, T4, T5).self) -> (T1, T2, T3, T4, T5)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger + >(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5).Type = (T1, T2, T3, T4, T5).self) -> (T1, T2, T3, T4, T5)? + { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -270,7 +311,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -298,14 +339,31 @@ extension ByteBuffer { case .big: return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5)) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5).Type = (T1, T2, T3, T4, T5).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5).Type = (T1, T2, T3, T4, T5).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -335,7 +393,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -353,7 +411,17 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6).Type = (T1, T2, T3, T4, T5, T6).self) -> (T1, T2, T3, T4, T5, T6)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6).Type = (T1, T2, T3, T4, T5, T6).self + ) -> (T1, T2, T3, T4, T5, T6)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -374,7 +442,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -404,16 +472,38 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6).Type = (T1, T2, T3, T4, T5, T6).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6).Type = (T1, T2, T3, T4, T5, T6).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -447,7 +537,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -467,7 +557,18 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7).Type = (T1, T2, T3, T4, T5, T6, T7).self) -> (T1, T2, T3, T4, T5, T6, T7)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7).Type = (T1, T2, T3, T4, T5, T6, T7).self + ) -> (T1, T2, T3, T4, T5, T6, T7)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -490,7 +591,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -524,16 +625,40 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6), T7(bigEndian: v7)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6), T7(bigEndian: v7) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, _ value7: T7, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7).Type = (T1, T2, T3, T4, T5, T6, T7).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + _ value7: T7, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7).Type = (T1, T2, T3, T4, T5, T6, T7).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -571,7 +696,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -593,7 +718,19 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8).Type = (T1, T2, T3, T4, T5, T6, T7, T8).self) -> (T1, T2, T3, T4, T5, T6, T7, T8)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8).Type = (T1, T2, T3, T4, T5, T6, T7, T8).self + ) -> (T1, T2, T3, T4, T5, T6, T7, T8)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -618,7 +755,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -656,16 +793,42 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, _ value7: T7, _ value8: T8, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8).Type = (T1, T2, T3, T4, T5, T6, T7, T8).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + _ value7: T7, + _ value8: T8, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8).Type = (T1, T2, T3, T4, T5, T6, T7, T8).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -707,7 +870,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -731,7 +894,20 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9).self) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9).self + ) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -758,7 +934,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -800,16 +976,45 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), T9(littleEndian: v9)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), + T9(littleEndian: v9) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, _ value7: T7, _ value8: T8, _ value9: T9, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + _ value7: T7, + _ value8: T8, + _ value9: T9, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -855,7 +1060,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -881,7 +1086,21 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self + ) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -910,7 +1129,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -956,16 +1175,47 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), T9(littleEndian: v9), T10(littleEndian: v10)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), + T9(littleEndian: v9), T10(littleEndian: v10) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, _ value7: T7, _ value8: T8, _ value9: T9, _ value10: T10, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + _ value7: T7, + _ value8: T8, + _ value9: T9, + _ value10: T10, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -1015,7 +1265,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -1043,7 +1293,22 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self + ) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -1074,7 +1339,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -1124,16 +1389,50 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), T11(bigEndian: v11)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), + T11(bigEndian: v11) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), + T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, _ value7: T7, _ value8: T8, _ value9: T9, _ value10: T10, _ value11: T11, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + _ value7: T7, + _ value8: T8, + _ value9: T9, + _ value10: T10, + _ value11: T11, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -1187,7 +1486,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -1217,7 +1516,25 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger, + T12: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type = ( + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 + ).self + ) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -1250,7 +1567,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -1304,16 +1621,54 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), T11(bigEndian: v11), T12(bigEndian: v12)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), + T11(bigEndian: v11), T12(bigEndian: v12) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11), T12(littleEndian: v12)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), + T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11), T12(littleEndian: v12) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, _ value7: T7, _ value8: T8, _ value9: T9, _ value10: T10, _ value11: T11, _ value12: T12, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger, + T12: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + _ value7: T7, + _ value8: T8, + _ value9: T9, + _ value10: T10, + _ value11: T11, + _ value12: T12, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type = ( + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 + ).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -1371,7 +1726,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -1403,7 +1758,26 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger, + T12: FixedWidthInteger, + T13: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type = ( + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13 + ).self + ) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -1438,7 +1812,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -1496,16 +1870,57 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), T11(bigEndian: v11), T12(bigEndian: v12), T13(bigEndian: v13)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), + T11(bigEndian: v11), T12(bigEndian: v12), T13(bigEndian: v13) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11), T12(littleEndian: v12), T13(littleEndian: v13)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), + T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11), T12(littleEndian: v12), + T13(littleEndian: v13) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, _ value7: T7, _ value8: T8, _ value9: T9, _ value10: T10, _ value11: T11, _ value12: T12, _ value13: T13, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger, + T12: FixedWidthInteger, + T13: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + _ value7: T7, + _ value8: T8, + _ value9: T9, + _ value10: T10, + _ value11: T11, + _ value12: T12, + _ value13: T13, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type = ( + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13 + ).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -1567,7 +1982,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -1601,7 +2016,27 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger, + T12: FixedWidthInteger, + T13: FixedWidthInteger, + T14: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type = ( + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14 + ).self + ) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -1638,7 +2073,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -1700,16 +2135,59 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), T11(bigEndian: v11), T12(bigEndian: v12), T13(bigEndian: v13), T14(bigEndian: v14)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), + T11(bigEndian: v11), T12(bigEndian: v12), T13(bigEndian: v13), T14(bigEndian: v14) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11), T12(littleEndian: v12), T13(littleEndian: v13), T14(littleEndian: v14)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), + T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11), T12(littleEndian: v12), + T13(littleEndian: v13), T14(littleEndian: v14) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, _ value7: T7, _ value8: T8, _ value9: T9, _ value10: T10, _ value11: T11, _ value12: T12, _ value13: T13, _ value14: T14, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger, + T12: FixedWidthInteger, + T13: FixedWidthInteger, + T14: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + _ value7: T7, + _ value8: T8, + _ value9: T9, + _ value10: T10, + _ value11: T11, + _ value12: T12, + _ value13: T13, + _ value14: T14, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type = ( + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14 + ).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -1775,7 +2253,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) @@ -1811,7 +2289,28 @@ extension ByteBuffer { @inlinable @_alwaysEmitIntoClient - public mutating func readMultipleIntegers(endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15).self) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15)? { + public mutating func readMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger, + T12: FixedWidthInteger, + T13: FixedWidthInteger, + T14: FixedWidthInteger, + T15: FixedWidthInteger + >( + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15).Type = ( + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15 + ).self + ) -> (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15)? { var bytesRequired: Int = MemoryLayout.size bytesRequired &+= MemoryLayout.size bytesRequired &+= MemoryLayout.size @@ -1850,7 +2349,7 @@ extension ByteBuffer { var offset = 0 self.readWithUnsafeReadableBytes { ptr -> Int in assert(ptr.count >= bytesRequired) - let basePtr = ptr.baseAddress! // safe, ptr is non-empty + let basePtr = ptr.baseAddress! // safe, ptr is non-empty withUnsafeMutableBytes(of: &v1) { destPtr in destPtr.baseAddress!.copyMemory(from: basePtr + offset, byteCount: MemoryLayout.size) } @@ -1916,16 +2415,61 @@ extension ByteBuffer { } switch endianness { case .big: - return (T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), T11(bigEndian: v11), T12(bigEndian: v12), T13(bigEndian: v13), T14(bigEndian: v14), T15(bigEndian: v15)) + return ( + T1(bigEndian: v1), T2(bigEndian: v2), T3(bigEndian: v3), T4(bigEndian: v4), T5(bigEndian: v5), + T6(bigEndian: v6), T7(bigEndian: v7), T8(bigEndian: v8), T9(bigEndian: v9), T10(bigEndian: v10), + T11(bigEndian: v11), T12(bigEndian: v12), T13(bigEndian: v13), T14(bigEndian: v14), T15(bigEndian: v15) + ) case .little: - return (T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11), T12(littleEndian: v12), T13(littleEndian: v13), T14(littleEndian: v14), T15(littleEndian: v15)) + return ( + T1(littleEndian: v1), T2(littleEndian: v2), T3(littleEndian: v3), T4(littleEndian: v4), + T5(littleEndian: v5), T6(littleEndian: v6), T7(littleEndian: v7), T8(littleEndian: v8), + T9(littleEndian: v9), T10(littleEndian: v10), T11(littleEndian: v11), T12(littleEndian: v12), + T13(littleEndian: v13), T14(littleEndian: v14), T15(littleEndian: v15) + ) } } @inlinable @_alwaysEmitIntoClient @discardableResult - public mutating func writeMultipleIntegers(_ value1: T1, _ value2: T2, _ value3: T3, _ value4: T4, _ value5: T5, _ value6: T6, _ value7: T7, _ value8: T8, _ value9: T9, _ value10: T10, _ value11: T11, _ value12: T12, _ value13: T13, _ value14: T14, _ value15: T15, endianness: Endianness = .big, as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15).Type = (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15).self) -> Int { + public mutating func writeMultipleIntegers< + T1: FixedWidthInteger, + T2: FixedWidthInteger, + T3: FixedWidthInteger, + T4: FixedWidthInteger, + T5: FixedWidthInteger, + T6: FixedWidthInteger, + T7: FixedWidthInteger, + T8: FixedWidthInteger, + T9: FixedWidthInteger, + T10: FixedWidthInteger, + T11: FixedWidthInteger, + T12: FixedWidthInteger, + T13: FixedWidthInteger, + T14: FixedWidthInteger, + T15: FixedWidthInteger + >( + _ value1: T1, + _ value2: T2, + _ value3: T3, + _ value4: T4, + _ value5: T5, + _ value6: T6, + _ value7: T7, + _ value8: T8, + _ value9: T9, + _ value10: T10, + _ value11: T11, + _ value12: T12, + _ value13: T13, + _ value14: T14, + _ value15: T15, + endianness: Endianness = .big, + as: (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15).Type = ( + T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15 + ).self + ) -> Int { var v1: T1 var v2: T2 var v3: T3 @@ -1995,7 +2539,7 @@ extension ByteBuffer { return self.writeWithUnsafeMutableBytes(minimumWritableBytes: spaceNeeded) { ptr -> Int in assert(ptr.count >= spaceNeeded) var offset = 0 - let basePtr = ptr.baseAddress! // safe: pointer is non zero length + let basePtr = ptr.baseAddress! // safe: pointer is non zero length (basePtr + offset).copyMemory(from: &v1, byteCount: MemoryLayout.size) offset = offset &+ MemoryLayout.size (basePtr + offset).copyMemory(from: &v2, byteCount: MemoryLayout.size) diff --git a/Sources/NIOCore/ByteBuffer-views.swift b/Sources/NIOCore/ByteBuffer-views.swift index 2731338513..9c0c0ab488 100644 --- a/Sources/NIOCore/ByteBuffer-views.swift +++ b/Sources/NIOCore/ByteBuffer-views.swift @@ -21,8 +21,8 @@ public struct ByteBufferView: RandomAccessCollection, Sendable { public typealias Index = Int public typealias SubSequence = ByteBufferView - /* private but usableFromInline */ @usableFromInline var _buffer: ByteBuffer - /* private but usableFromInline */ @usableFromInline var _range: Range + @usableFromInline var _buffer: ByteBuffer + @usableFromInline var _range: Range @inlinable internal init(buffer: ByteBuffer, range: Range) { @@ -34,37 +34,41 @@ public struct ByteBufferView: RandomAccessCollection, Sendable { /// Creates a `ByteBufferView` from the readable bytes of the given `buffer`. @inlinable public init(_ buffer: ByteBuffer) { - self = ByteBufferView(buffer: buffer, range: buffer.readerIndex ..< buffer.writerIndex) + self = ByteBufferView(buffer: buffer, range: buffer.readerIndex..(_ body: (UnsafeRawBufferPointer) throws -> R) rethrows -> R { - return try self._buffer.withVeryUnsafeBytes { ptr in - try body(UnsafeRawBufferPointer(start: ptr.baseAddress!.advanced(by: self._range.lowerBound), - count: self._range.count)) + try self._buffer.withVeryUnsafeBytes { ptr in + try body( + UnsafeRawBufferPointer( + start: ptr.baseAddress!.advanced(by: self._range.lowerBound), + count: self._range.count + ) + ) } } @inlinable public var startIndex: Index { - return self._range.lowerBound + self._range.lowerBound } @inlinable public var endIndex: Index { - return self._range.upperBound + self._range.upperBound } @inlinable public func index(after i: Index) -> Index { - return i + 1 + i + 1 } @inlinable public var count: Int { // Unchecked is safe here: Range enforces that upperBound is strictly greater than // lower bound, and we guarantee that _range.lowerBound >= 0. - return self._range.upperBound &- self._range.lowerBound + self._range.upperBound &- self._range.lowerBound } @inlinable @@ -73,7 +77,7 @@ public struct ByteBufferView: RandomAccessCollection, Sendable { guard position >= self._range.lowerBound && position < self._range.upperBound else { preconditionFailure("index \(position) out of range") } - return self._buffer.getInteger(at: position)! // range check above + return self._buffer.getInteger(at: position)! // range check above } set { guard position >= self._range.lowerBound && position < self._range.upperBound else { @@ -86,7 +90,7 @@ public struct ByteBufferView: RandomAccessCollection, Sendable { @inlinable public subscript(range: Range) -> ByteBufferView { get { - return ByteBufferView(buffer: self._buffer, range: range) + ByteBufferView(buffer: self._buffer, range: range) } set { self.replaceSubrange(range, with: newValue) @@ -95,35 +99,41 @@ public struct ByteBufferView: RandomAccessCollection, Sendable { @inlinable public func withContiguousStorageIfAvailable(_ body: (UnsafeBufferPointer) throws -> R) rethrows -> R? { - return try self.withUnsafeBytes { bytes in - return try body(bytes.bindMemory(to: UInt8.self)) + try self.withUnsafeBytes { bytes in + try body(bytes.bindMemory(to: UInt8.self)) } } @inlinable - public func _customIndexOfEquatableElement(_ element : Element) -> Index?? { - return .some(self.withUnsafeBytes { ptr -> Index? in - return ptr.firstIndex(of: element).map { $0 + self._range.lowerBound } - }) + public func _customIndexOfEquatableElement(_ element: Element) -> Index?? { + .some( + self.withUnsafeBytes { ptr -> Index? in + ptr.firstIndex(of: element).map { $0 + self._range.lowerBound } + } + ) } @inlinable public func _customLastIndexOfEquatableElement(_ element: Element) -> Index?? { - return .some(self.withUnsafeBytes { ptr -> Index? in - return ptr.lastIndex(of: element).map { $0 + self._range.lowerBound } - }) + .some( + self.withUnsafeBytes { ptr -> Index? in + ptr.lastIndex(of: element).map { $0 + self._range.lowerBound } + } + ) } - + @inlinable public func _customContainsEquatableElement(_ element: Element) -> Bool? { - return .some(self.withUnsafeBytes { ptr -> Bool in - return ptr.contains(element) - }) + .some( + self.withUnsafeBytes { ptr -> Bool in + ptr.contains(element) + } + ) } @inlinable public func _copyContents( - initializing ptr: UnsafeMutableBufferPointer + initializing ptr: UnsafeMutableBufferPointer ) -> (Iterator, UnsafeMutableBufferPointer.Index) { precondition(ptr.count >= self.count) @@ -141,10 +151,10 @@ public struct ByteBufferView: RandomAccessCollection, Sendable { // These are implemented as no-ops for performance reasons. @inlinable public func _failEarlyRangeCheck(_ index: Index, bounds: Range) {} - + @inlinable public func _failEarlyRangeCheck(_ index: Index, bounds: ClosedRange) {} - + @inlinable public func _failEarlyRangeCheck(_ range: Range, bounds: Range) {} } @@ -176,7 +186,7 @@ extension ByteBufferView: RangeReplaceableCollection { // ``CollectionOfOne`` has no witness for // ``Sequence.withContiguousStorageIfAvailable(_:)``. so we do this instead: self._buffer.setInteger(byte, at: self._range.upperBound) - self._range = self._range.lowerBound ..< self._range.upperBound.advanced(by: 1) + self._range = self._range.lowerBound..(contentsOf bytes: Bytes) where Bytes.Element == UInt8 { let written = self._buffer.setBytes(bytes, at: self._range.upperBound) - self._range = self._range.lowerBound ..< self._range.upperBound.advanced(by: written) + self._range = self._range.lowerBound..(_ subrange: Range, with newElements: C) where ByteBufferView.Element == C.Element { - precondition(subrange.startIndex >= self.startIndex && subrange.endIndex <= self.endIndex, - "subrange out of bounds") + public mutating func replaceSubrange(_ subrange: Range, with newElements: C) + where ByteBufferView.Element == C.Element { + precondition( + subrange.startIndex >= self.startIndex && subrange.endIndex <= self.endIndex, + "subrange out of bounds" + ) if newElements.count == subrange.count { self._buffer.setBytes(newElements, at: subrange.startIndex) @@ -201,9 +214,11 @@ extension ByteBufferView: RangeReplaceableCollection { // Remove the unwanted bytes between the newly copied bytes and the end of the subrange. // try! is fine here: the copied range is within the view and the length can't be negative. - try! self._buffer.copyBytes(at: subrange.endIndex, - to: subrange.startIndex.advanced(by: newElements.count), - length: subrange.endIndex.distance(to: self._buffer.writerIndex)) + try! self._buffer.copyBytes( + at: subrange.endIndex, + to: subrange.startIndex.advanced(by: newElements.count), + length: subrange.endIndex.distance(to: self._buffer.writerIndex) + ) // Shorten the range. let removedBytes = subrange.count - newElements.count @@ -212,9 +227,11 @@ extension ByteBufferView: RangeReplaceableCollection { } else { // Make space for the new elements. // try! is fine here: the copied range is within the view and the length can't be negative. - try! self._buffer.copyBytes(at: subrange.endIndex, - to: subrange.startIndex.advanced(by: newElements.count), - length: subrange.endIndex.distance(to: self._buffer.writerIndex)) + try! self._buffer.copyBytes( + at: subrange.endIndex, + to: subrange.startIndex.advanced(by: newElements.count), + length: subrange.endIndex.distance(to: self._buffer.writerIndex) + ) // Replace the bytes. self._buffer.setBytes(newElements, at: subrange.startIndex) @@ -222,7 +239,7 @@ extension ByteBufferView: RangeReplaceableCollection { // Widen the range. let additionalByteCount = newElements.count - subrange.count self._buffer.moveWriterIndex(forwardBy: additionalByteCount) - self._range = self._range.startIndex ..< self._range.endIndex.advanced(by: additionalByteCount) + self._range = self._range.startIndex.. Bool { guard lhs._range.count == rhs._range.count else { - return false + return false } // A well-formed ByteBufferView can never have a range that is out-of-bounds of the backing ByteBuffer. // As a result, these getSlice calls can never fail, and we'd like to know it if they do. let leftBufferSlice = lhs._buffer.getSlice(at: lhs._range.startIndex, length: lhs._range.count)! let rightBufferSlice = rhs._buffer.getSlice(at: rhs._range.startIndex, length: rhs._range.count)! - + return leftBufferSlice == rightBufferSlice } } diff --git a/Sources/NIOCore/Channel.swift b/Sources/NIOCore/Channel.swift index b2523e955f..7b2a49c6fb 100644 --- a/Sources/NIOCore/Channel.swift +++ b/Sources/NIOCore/Channel.swift @@ -150,7 +150,7 @@ public protocol Channel: AnyObject, ChannelOutboundInvoker, _NIOPreconcurrencySe extension Channel { /// Default implementation: `NIOSynchronousChannelOptions` are not supported. public var syncOptions: NIOSynchronousChannelOptions? { - return nil + nil } } @@ -210,7 +210,6 @@ extension Channel { } } - /// Provides special extension to make writing data to the `Channel` easier by removing the need to wrap data in `NIOAny` manually. extension Channel { @@ -218,7 +217,7 @@ extension Channel { /// /// - seealso: `ChannelOutboundInvoker.write`. public func write(_ any: T) -> EventLoopFuture { - return self.write(NIOAny(any)) + self.write(NIOAny(any)) } /// Write data into the `Channel`, automatically wrapping with `NIOAny`. @@ -232,10 +231,9 @@ extension Channel { /// /// - seealso: `ChannelOutboundInvoker.writeAndFlush`. public func writeAndFlush(_ any: T) -> EventLoopFuture { - return self.writeAndFlush(NIOAny(any)) + self.writeAndFlush(NIOAny(any)) } - /// Write and flush data into the `Channel`, automatically wrapping with `NIOAny`. /// /// - seealso: `ChannelOutboundInvoker.writeAndFlush`. @@ -262,7 +260,7 @@ extension ChannelCore { /// - returns: The content of the `NIOAny`. @inlinable public func unwrapData(_ data: NIOAny, as: T.Type = T.self) -> T { - return data.forceAs() + data.forceAs() } /// Attempts to unwrap the given `NIOAny` as a specific concrete type. @@ -284,7 +282,7 @@ extension ChannelCore { /// are doing something _extremely_ unusual. @inlinable public func tryUnwrapData(_ data: NIOAny, as: T.Type = T.self) -> T? { - return data.tryAs() + data.tryAs() } /// Removes the `ChannelHandler`s from the `ChannelPipeline` belonging to `channel`, and @@ -384,7 +382,7 @@ extension ChannelError { static let _unremovableHandler: any Error = ChannelError.unremovableHandler } -extension ChannelError: Equatable { } +extension ChannelError: Equatable {} /// The removal of a `ChannelHandler` using `ChannelPipeline.removeHandler` has been attempted more than once. public struct NIOAttemptedToRemoveHandlerMultipleTimesError: Error {} diff --git a/Sources/NIOCore/ChannelHandlers.swift b/Sources/NIOCore/ChannelHandlers.swift index b63afc12dd..c2cf3c39c4 100644 --- a/Sources/NIOCore/ChannelHandlers.swift +++ b/Sources/NIOCore/ChannelHandlers.swift @@ -15,7 +15,6 @@ // // - /// A `ChannelHandler` that implements a backoff for a `ServerChannel` when accept produces an `IOError`. /// These errors are often recoverable by reducing the rate at which we call accept. public final class AcceptBackoffHandler: ChannelDuplexHandler, RemovableChannelHandler { @@ -28,7 +27,7 @@ public final class AcceptBackoffHandler: ChannelDuplexHandler, RemovableChannelH /// Default implementation used as `backoffProvider` which delays accept by 1 second. public static func defaultBackoffProvider(error: IOError) -> TimeAmount? { - return .seconds(1) + .seconds(1) } /// Create a new instance @@ -108,10 +107,8 @@ public final class AcceptBackoffHandler: ChannelDuplexHandler, RemovableChannelH @available(*, unavailable) extension AcceptBackoffHandler: Sendable {} -/** - ChannelHandler implementation which enforces back-pressure by stopping to read from the remote peer when it cannot write back fast enough. - It will start reading again once pending data was written. -*/ +/// ChannelHandler implementation which enforces back-pressure by stopping to read from the remote peer when it cannot write back fast enough. +/// It will start reading again once pending data was written. public final class BackPressureHandler: ChannelDuplexHandler, RemovableChannelHandler { public typealias OutboundIn = NIOAny public typealias InboundIn = ByteBuffer @@ -121,7 +118,7 @@ public final class BackPressureHandler: ChannelDuplexHandler, RemovableChannelHa private var pendingRead = false private var writable: Bool = true - public init() { } + public init() {} public func read(context: ChannelHandlerContext) { if writable { @@ -218,7 +215,7 @@ public final class IdleStateHandler: ChannelDuplexHandler, RemovableChannelHandl } public func channelReadComplete(context: ChannelHandlerContext) { - if (readTimeout != nil || allTimeout != nil) && reading { + if (readTimeout != nil || allTimeout != nil) && reading { lastReadTime = .now() reading = false } @@ -246,32 +243,41 @@ public final class IdleStateHandler: ChannelDuplexHandler, RemovableChannelHandl } private func makeReadTimeoutTask(_ context: ChannelHandlerContext, _ timeout: TimeAmount) -> (() -> Void) { - return { - guard self.shouldReschedule(context) else { + { + guard self.shouldReschedule(context) else { return } if self.reading { - self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask(in: timeout, self.makeReadTimeoutTask(context, timeout)) + self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask( + in: timeout, + self.makeReadTimeoutTask(context, timeout) + ) return } let diff = .now() - self.lastReadTime if diff >= timeout { // Reader is idle - set a new timeout and trigger an event through the pipeline - self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask(in: timeout, self.makeReadTimeoutTask(context, timeout)) + self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask( + in: timeout, + self.makeReadTimeoutTask(context, timeout) + ) context.fireUserInboundEventTriggered(IdleStateEvent.read) } else { // Read occurred before the timeout - set a new timeout with shorter delay. - self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask(deadline: self.lastReadTime + timeout, self.makeReadTimeoutTask(context, timeout)) + self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask( + deadline: self.lastReadTime + timeout, + self.makeReadTimeoutTask(context, timeout) + ) } } } private func makeWriteTimeoutTask(_ context: ChannelHandlerContext, _ timeout: TimeAmount) -> (() -> Void) { - return { - guard self.shouldReschedule(context) else { + { + guard self.shouldReschedule(context) else { return } @@ -280,24 +286,33 @@ public final class IdleStateHandler: ChannelDuplexHandler, RemovableChannelHandl if diff >= timeout { // Writer is idle - set a new timeout and notify the callback. - self.scheduledWriterTask = context.eventLoop.assumeIsolated().scheduleTask(in: timeout, self.makeWriteTimeoutTask(context, timeout)) + self.scheduledWriterTask = context.eventLoop.assumeIsolated().scheduleTask( + in: timeout, + self.makeWriteTimeoutTask(context, timeout) + ) context.fireUserInboundEventTriggered(IdleStateEvent.write) } else { // Write occurred before the timeout - set a new timeout with shorter delay. - self.scheduledWriterTask = context.eventLoop.assumeIsolated().scheduleTask(deadline: self.lastWriteCompleteTime + timeout, self.makeWriteTimeoutTask(context, timeout)) + self.scheduledWriterTask = context.eventLoop.assumeIsolated().scheduleTask( + deadline: self.lastWriteCompleteTime + timeout, + self.makeWriteTimeoutTask(context, timeout) + ) } } } private func makeAllTimeoutTask(_ context: ChannelHandlerContext, _ timeout: TimeAmount) -> (() -> Void) { - return { - guard self.shouldReschedule(context) else { + { + guard self.shouldReschedule(context) else { return } if self.reading { - self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask(in: timeout, self.makeAllTimeoutTask(context, timeout)) + self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask( + in: timeout, + self.makeAllTimeoutTask(context, timeout) + ) return } let lastRead = self.lastReadTime @@ -307,17 +322,27 @@ public final class IdleStateHandler: ChannelDuplexHandler, RemovableChannelHandl let diff = .now() - latestLast if diff >= timeout { // Reader is idle - set a new timeout and trigger an event through the pipeline - self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask(in: timeout, self.makeAllTimeoutTask(context, timeout)) + self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask( + in: timeout, + self.makeAllTimeoutTask(context, timeout) + ) context.fireUserInboundEventTriggered(IdleStateEvent.all) } else { // Read occurred before the timeout - set a new timeout with shorter delay. - self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask(deadline: latestLast + timeout, self.makeAllTimeoutTask(context, timeout)) + self.scheduledReaderTask = context.eventLoop.assumeIsolated().scheduleTask( + deadline: latestLast + timeout, + self.makeAllTimeoutTask(context, timeout) + ) } } } - private func schedule(_ context: ChannelHandlerContext, _ amount: TimeAmount?, _ body: @escaping (ChannelHandlerContext, TimeAmount) -> (() -> Void) ) -> Scheduled? { + private func schedule( + _ context: ChannelHandlerContext, + _ amount: TimeAmount?, + _ body: @escaping (ChannelHandlerContext, TimeAmount) -> (() -> Void) + ) -> Scheduled? { if let timeout = amount { return context.eventLoop.assumeIsolated().scheduleTask(in: timeout, body(context, timeout)) } diff --git a/Sources/NIOCore/ChannelInvoker.swift b/Sources/NIOCore/ChannelInvoker.swift index ec07740318..bd91c5160b 100644 --- a/Sources/NIOCore/ChannelInvoker.swift +++ b/Sources/NIOCore/ChannelInvoker.swift @@ -103,7 +103,11 @@ extension ChannelOutboundInvoker { /// - parameters: /// - to: the `SocketAddress` to which we should bind the `Channel`. /// - returns: the future which will be notified once the operation completes. - public func bind(to address: SocketAddress, file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { + public func bind( + to address: SocketAddress, + file: StaticString = #fileID, + line: UInt = #line + ) -> EventLoopFuture { let promise = makePromise(file: file, line: line) bind(to: address, promise: promise) return promise.futureResult @@ -113,7 +117,11 @@ extension ChannelOutboundInvoker { /// - parameters: /// - to: the `SocketAddress` to which we should connect the `Channel`. /// - returns: the future which will be notified once the operation completes. - public func connect(to address: SocketAddress, file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { + public func connect( + to address: SocketAddress, + file: StaticString = #fileID, + line: UInt = #line + ) -> EventLoopFuture { let promise = makePromise(file: file, line: line) connect(to: address, promise: promise) return promise.futureResult @@ -138,7 +146,8 @@ extension ChannelOutboundInvoker { /// - parameters: /// - data: the data to write /// - returns: the future which will be notified once the `write` operation completes. - public func writeAndFlush(_ data: NIOAny, file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { + public func writeAndFlush(_ data: NIOAny, file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture + { let promise = makePromise(file: file, line: line) writeAndFlush(data, promise: promise) return promise.futureResult @@ -149,7 +158,8 @@ extension ChannelOutboundInvoker { /// - parameters: /// - mode: the `CloseMode` that is used /// - returns: the future which will be notified once the operation completes. - public func close(mode: CloseMode = .all, file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { + public func close(mode: CloseMode = .all, file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture + { let promise = makePromise(file: file, line: line) close(mode: mode, promise: promise) return promise.futureResult @@ -160,14 +170,18 @@ extension ChannelOutboundInvoker { /// - parameters: /// - event: the event itself. /// - returns: the future which will be notified once the operation completes. - public func triggerUserOutboundEvent(_ event: Any, file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { + public func triggerUserOutboundEvent( + _ event: Any, + file: StaticString = #fileID, + line: UInt = #line + ) -> EventLoopFuture { let promise = makePromise(file: file, line: line) triggerUserOutboundEvent(event, promise: promise) return promise.futureResult } private func makePromise(file: StaticString = #fileID, line: UInt = #line) -> EventLoopPromise { - return eventLoop.makePromise(file: file, line: line) + eventLoop.makePromise(file: file, line: line) } } @@ -228,7 +242,7 @@ public protocol ChannelInboundInvoker { } /// A protocol that signals that outbound and inbound events are triggered by this invoker. -public protocol ChannelInvoker: ChannelOutboundInvoker, ChannelInboundInvoker { } +public protocol ChannelInvoker: ChannelOutboundInvoker, ChannelInboundInvoker {} /// Specify what kind of close operation is requested. public enum CloseMode: Sendable { diff --git a/Sources/NIOCore/ChannelOption.swift b/Sources/NIOCore/ChannelOption.swift index 009d47289c..a4e0746a04 100644 --- a/Sources/NIOCore/ChannelOption.swift +++ b/Sources/NIOCore/ChannelOption.swift @@ -20,11 +20,11 @@ public protocol ChannelOption: Equatable, _NIOPreconcurrencySendable { public typealias SocketOptionName = Int32 #if (os(Linux) || os(Android)) && !canImport(Musl) - public typealias SocketOptionLevel = Int - public typealias SocketOptionValue = Int +public typealias SocketOptionLevel = Int +public typealias SocketOptionValue = Int #else - public typealias SocketOptionLevel = CInt - public typealias SocketOptionValue = CInt +public typealias SocketOptionLevel = CInt +public typealias SocketOptionValue = CInt #endif @available(*, deprecated, renamed: "ChannelOptions.Types.SocketOption") @@ -77,7 +77,7 @@ extension ChannelOptions { public var level: SocketOptionLevel { get { - return SocketOptionLevel(optionLevel.rawValue) + SocketOptionLevel(optionLevel.rawValue) } set { self.optionLevel = NIOBSDSocket.OptionLevel(rawValue: CInt(newValue)) @@ -85,7 +85,7 @@ extension ChannelOptions { } public var name: SocketOptionName { get { - return SocketOptionName(optionName.rawValue) + SocketOptionName(optionName.rawValue) } set { self.optionName = NIOBSDSocket.Option(rawValue: CInt(newValue)) @@ -93,15 +93,15 @@ extension ChannelOptions { } #if !os(Windows) - /// Create a new `SocketOption`. - /// - /// - parameters: - /// - level: The level for the option as defined in `man setsockopt`, e.g. SO_SOCKET. - /// - name: The name of the option as defined in `man setsockopt`, e.g. `SO_REUSEADDR`. - public init(level: SocketOptionLevel, name: SocketOptionName) { - self.optionLevel = NIOBSDSocket.OptionLevel(rawValue: CInt(level)) - self.optionName = NIOBSDSocket.Option(rawValue: CInt(name)) - } + /// Create a new `SocketOption`. + /// + /// - parameters: + /// - level: The level for the option as defined in `man setsockopt`, e.g. SO_SOCKET. + /// - name: The name of the option as defined in `man setsockopt`, e.g. `SO_REUSEADDR`. + public init(level: SocketOptionLevel, name: SocketOptionName) { + self.optionLevel = NIOBSDSocket.OptionLevel(rawValue: CInt(level)) + self.optionName = NIOBSDSocket.Option(rawValue: CInt(name)) + } #endif /// Create a new `SocketOption`. @@ -188,7 +188,7 @@ extension ChannelOptions { public struct DatagramVectorReadMessageCountOption: ChannelOption, Sendable { public typealias Value = Int - public init() { } + public init() {} } /// ``DatagramSegmentSize`` controls the `UDP_SEGMENT` socket option (sometimes reffered to as 'GSO') which allows for @@ -201,7 +201,7 @@ extension ChannelOptions { /// Setting this option to zero disables segmentation offload. public struct DatagramSegmentSize: ChannelOption, Sendable { public typealias Value = CInt - public init() { } + public init() {} } /// ``DatagramReceiveOffload`` sets the `UDP_GRO` socket option which allows for datagrams to be accumulated @@ -214,7 +214,7 @@ extension ChannelOptions { /// The default allocator for datagram channels uses fixed sized buffers of 2048 bytes. public struct DatagramReceiveOffload: ChannelOption, Sendable { public typealias Value = Bool - public init() { } + public init() {} } /// When set to true IP level ECN information will be reported through `AddressedEnvelope.Metadata` @@ -293,23 +293,27 @@ extension ChannelOptions { /// Provides `ChannelOption`s to be used with a `Channel`, `Bootstrap` or `ServerBootstrap`. public struct ChannelOptions: Sendable { #if !os(Windows) - public static let socket: @Sendable (SocketOptionLevel, SocketOptionName) -> ChannelOptions.Types.SocketOption = { (level: SocketOptionLevel, name: SocketOptionName) -> Types.SocketOption in - .init(level: NIOBSDSocket.OptionLevel(rawValue: CInt(level)), name: NIOBSDSocket.Option(rawValue: CInt(name))) - } + public static let socket: @Sendable (SocketOptionLevel, SocketOptionName) -> ChannelOptions.Types.SocketOption = { + (level: SocketOptionLevel, name: SocketOptionName) -> Types.SocketOption in + .init(level: NIOBSDSocket.OptionLevel(rawValue: CInt(level)), name: NIOBSDSocket.Option(rawValue: CInt(name))) + } #endif /// - seealso: `SocketOption`. - public static let socketOption: @Sendable (NIOBSDSocket.Option) -> ChannelOptions.Types.SocketOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in + public static let socketOption: @Sendable (NIOBSDSocket.Option) -> ChannelOptions.Types.SocketOption = { + (name: NIOBSDSocket.Option) -> Types.SocketOption in .init(level: .socket, name: name) } /// - seealso: `SocketOption`. - public static let ipOption: @Sendable (NIOBSDSocket.Option) -> ChannelOptions.Types.SocketOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in + public static let ipOption: @Sendable (NIOBSDSocket.Option) -> ChannelOptions.Types.SocketOption = { + (name: NIOBSDSocket.Option) -> Types.SocketOption in .init(level: .ip, name: name) } /// - seealso: `SocketOption`. - public static let tcpOption: @Sendable (NIOBSDSocket.Option) -> ChannelOptions.Types.SocketOption = { (name: NIOBSDSocket.Option) -> Types.SocketOption in + public static let tcpOption: @Sendable (NIOBSDSocket.Option) -> ChannelOptions.Types.SocketOption = { + (name: NIOBSDSocket.Option) -> Types.SocketOption in .init(level: .tcp, name: name) } @@ -361,7 +365,11 @@ extension ChannelOptions { /// `Channel` that needs to store `ChannelOption`s. public struct Storage: Sendable { @usableFromInline - internal var _storage: [(any ChannelOption, (any Sendable, @Sendable (Channel) -> (any ChannelOption, any Sendable) -> EventLoopFuture))] + internal var _storage: + [( + any ChannelOption, + (any Sendable, @Sendable (Channel) -> (any ChannelOption, any Sendable) -> EventLoopFuture) + )] public init() { self._storage = [] @@ -377,8 +385,8 @@ extension ChannelOptions { public mutating func append(key newKey: Option, value newValue: Option.Value) { @Sendable func applier(_ t: Channel) -> (any ChannelOption, any Sendable) -> EventLoopFuture { - return { (option, value) in - return t.setOption(option as! Option, value: value as! Option.Value) + { (option, value) in + t.setOption(option as! Option, value: value as! Option.Value) } } var hasSet = false @@ -407,7 +415,17 @@ extension ChannelOptions { let it = self._storage.makeIterator() @Sendable - func applyNext(iterator: IndexingIterator<[(any ChannelOption, (any Sendable, @Sendable (any Channel) -> (any ChannelOption, any Sendable) -> EventLoopFuture))]>) { + func applyNext( + iterator: IndexingIterator< + [( + any ChannelOption, + ( + any Sendable, + @Sendable (any Channel) -> (any ChannelOption, any Sendable) -> EventLoopFuture + ) + )] + > + ) { var iterator = iterator guard let (key, (value, applier)) = iterator.next() else { // If we reached the end, everything is applied. diff --git a/Sources/NIOCore/ChannelPipeline.swift b/Sources/NIOCore/ChannelPipeline.swift index e5ad620a05..aa0ab19275 100644 --- a/Sources/NIOCore/ChannelPipeline.swift +++ b/Sources/NIOCore/ChannelPipeline.swift @@ -167,9 +167,11 @@ public final class ChannelPipeline: ChannelInvoker { /// - handler: the `ChannelHandler` to add /// - position: The position in the `ChannelPipeline` to add `handler`. Defaults to `.last`. /// - returns: the `EventLoopFuture` which will be notified once the `ChannelHandler` was added. - public func addHandler(_ handler: ChannelHandler, - name: String? = nil, - position: ChannelPipeline.Position = .last) -> EventLoopFuture { + public func addHandler( + _ handler: ChannelHandler, + name: String? = nil, + position: ChannelPipeline.Position = .last + ) -> EventLoopFuture { let future: EventLoopFuture if self.eventLoop.inEventLoop { @@ -192,9 +194,11 @@ public final class ChannelPipeline: ChannelInvoker { /// - name: the name to use for the `ChannelHandler` when it's added. If none is specified a name will be generated. /// - position: The position in the `ChannelPipeline` to add `handler`. Defaults to `.last`. /// - returns: the result of adding this handler - either success or failure with an error code if this could not be completed. - fileprivate func addHandlerSync(_ handler: ChannelHandler, - name: String? = nil, - position: ChannelPipeline.Position = .last) -> Result { + fileprivate func addHandlerSync( + _ handler: ChannelHandler, + name: String? = nil, + position: ChannelPipeline.Position = .last + ) -> Result { self.eventLoop.assertInEventLoop() if self.destroyed { @@ -203,25 +207,33 @@ public final class ChannelPipeline: ChannelInvoker { switch position { case .first: - return self.add0(name: name, - handler: handler, - relativeContext: head!, - operation: self.add0(context:after:)) + return self.add0( + name: name, + handler: handler, + relativeContext: head!, + operation: self.add0(context:after:) + ) case .last: - return self.add0(name: name, - handler: handler, - relativeContext: tail!, - operation: self.add0(context:before:)) + return self.add0( + name: name, + handler: handler, + relativeContext: tail!, + operation: self.add0(context:before:) + ) case .before(let beforeHandler): - return self.add0(name: name, - handler: handler, - relativeHandler: beforeHandler, - operation: self.add0(context:before:)) + return self.add0( + name: name, + handler: handler, + relativeHandler: beforeHandler, + operation: self.add0(context:before:) + ) case .after(let afterHandler): - return self.add0(name: name, - handler: handler, - relativeHandler: afterHandler, - operation: self.add0(context:after:)) + return self.add0( + name: name, + handler: handler, + relativeHandler: afterHandler, + operation: self.add0(context:after:) + ) } } @@ -241,10 +253,12 @@ public final class ChannelPipeline: ChannelInvoker { /// inserted relative to. /// - operation: A callback that will insert `handler` relative to `relativeHandler`. /// - returns: the result of adding this handler - either success or failure with an error code if this could not be completed. - private func add0(name: String?, - handler: ChannelHandler, - relativeHandler: ChannelHandler, - operation: (ChannelHandlerContext, ChannelHandlerContext) -> Void) -> Result { + private func add0( + name: String?, + handler: ChannelHandler, + relativeHandler: ChannelHandler, + operation: (ChannelHandlerContext, ChannelHandlerContext) -> Void + ) -> Result { self.eventLoop.assertInEventLoop() if self.destroyed { return .failure(ChannelError._ioOnClosedChannel) @@ -273,10 +287,12 @@ public final class ChannelPipeline: ChannelInvoker { /// inserted relative to. /// - operation: A callback that will insert `handler` relative to `relativeHandler`. /// - returns: the result of adding this handler - either success or failure with an error code if this could not be completed. - private func add0(name: String?, - handler: ChannelHandler, - relativeContext: ChannelHandlerContext, - operation: (ChannelHandlerContext, ChannelHandlerContext) -> Void) -> Result { + private func add0( + name: String?, + handler: ChannelHandler, + relativeContext: ChannelHandlerContext, + operation: (ChannelHandlerContext, ChannelHandlerContext) -> Void + ) -> Result { self.eventLoop.assertInEventLoop() if self.destroyed { @@ -458,7 +474,7 @@ public final class ChannelPipeline: ChannelInvoker { /// - handler: the `ChannelHandler` for which the `ChannelHandlerContext` should be returned /// - returns: the `ChannelHandlerContext` that belongs to the `ChannelHandler`, if one exists. fileprivate func contextSync(handler: ChannelHandler) -> Result { - return self._contextSync({ $0.handler === handler }) + self._contextSync({ $0.handler === handler }) } /// Returns the `ChannelHandlerContext` that belongs to a `ChannelHandler`. @@ -487,7 +503,7 @@ public final class ChannelPipeline: ChannelInvoker { /// - Parameter name: The name of the `ChannelHandler` to find. /// - Returns: the `ChannelHandlerContext` that belongs to the `ChannelHandler`, if one exists. fileprivate func contextSync(name: String) -> Result { - return self._contextSync({ $0.name == name }) + self._contextSync({ $0.name == name }) } /// Returns the `ChannelHandlerContext` that belongs to a `ChannelHandler` of the given type. @@ -527,7 +543,7 @@ public final class ChannelPipeline: ChannelInvoker { /// Returns if the ``ChannelHandler`` of the given type is contained in the pipeline. /// /// - Parameters: - /// - name: The name of the handler. + /// - name: The name of the handler. /// - Returns: An ``EventLoopFuture`` that is succeeded if a handler of the given type is contained in the pipeline. Otherwise /// the future will be failed with an error. @inlinable @@ -541,14 +557,16 @@ public final class ChannelPipeline: ChannelInvoker { /// - Important: This must be called on the `EventLoop`. /// - Parameter handlerType: The type of handler to search for. /// - Returns: the `ChannelHandlerContext` that belongs to the `ChannelHandler`, if one exists. - @inlinable // should be fileprivate - internal func _contextSync(handlerType: Handler.Type) -> Result { - return self._contextSync({ $0.handler is Handler }) + @inlinable // should be fileprivate + internal func _contextSync( + handlerType: Handler.Type + ) -> Result { + self._contextSync({ $0.handler is Handler }) } /// Synchronously finds a `ChannelHandlerContext` in the `ChannelPipeline`. /// - Important: This must be called on the `EventLoop`. - @usableFromInline // should be fileprivate + @usableFromInline // should be fileprivate internal func _contextSync(_ body: (ChannelHandlerContext) -> Bool) -> Result { self.eventLoop.assertInEventLoop() @@ -815,11 +833,11 @@ public final class ChannelPipeline: ChannelInvoker { // These methods are expected to only be called from within the EventLoop private var firstOutboundCtx: ChannelHandlerContext? { - return self.tail?.prev + self.tail?.prev } private var firstInboundCtx: ChannelHandlerContext? { - return self.head?.next + self.head?.next } private func close0(mode: CloseMode, promise: EventLoopPromise?) { @@ -946,7 +964,7 @@ public final class ChannelPipeline: ChannelInvoker { } private var inEventLoop: Bool { - return eventLoop.inEventLoop + eventLoop.inEventLoop } /// Create `ChannelPipeline` for a given `Channel`. This method should never be called by the end-user @@ -958,11 +976,19 @@ public final class ChannelPipeline: ChannelInvoker { public init(channel: Channel) { self._channel = channel self.eventLoop = channel.eventLoop - self.head = nil // we need to initialise these to `nil` so we can use `self` in the lines below - self.tail = nil // we need to initialise these to `nil` so we can use `self` in the lines below - - self.head = ChannelHandlerContext(name: HeadChannelHandler.name, handler: HeadChannelHandler.sharedInstance, pipeline: self) - self.tail = ChannelHandlerContext(name: TailChannelHandler.name, handler: TailChannelHandler.sharedInstance, pipeline: self) + self.head = nil // we need to initialise these to `nil` so we can use `self` in the lines below + self.tail = nil // we need to initialise these to `nil` so we can use `self` in the lines below + + self.head = ChannelHandlerContext( + name: HeadChannelHandler.name, + handler: HeadChannelHandler.sharedInstance, + pipeline: self + ) + self.tail = ChannelHandlerContext( + name: TailChannelHandler.name, + handler: TailChannelHandler.sharedInstance, + pipeline: self + ) self.head?.next = self.tail self.tail?.prev = self.head } @@ -979,8 +1005,10 @@ extension ChannelPipeline { /// - position: The position in the `ChannelPipeline` to add `handlers`. Defaults to `.last`. /// /// - returns: A future that will be completed when all of the supplied `ChannelHandler`s were added. - public func addHandlers(_ handlers: [ChannelHandler], - position: ChannelPipeline.Position = .last) -> EventLoopFuture { + public func addHandlers( + _ handlers: [ChannelHandler], + position: ChannelPipeline.Position = .last + ) -> EventLoopFuture { let future: EventLoopFuture if self.eventLoop.inEventLoop { @@ -1002,9 +1030,11 @@ extension ChannelPipeline { /// - position: The position in the `ChannelPipeline` to add `handlers`. Defaults to `.last`. /// /// - returns: A future that will be completed when all of the supplied `ChannelHandler`s were added. - public func addHandlers(_ handlers: ChannelHandler..., - position: ChannelPipeline.Position = .last) -> EventLoopFuture { - return self.addHandlers(handlers, position: position) + public func addHandlers( + _ handlers: ChannelHandler..., + position: ChannelPipeline.Position = .last + ) -> EventLoopFuture { + self.addHandlers(handlers, position: position) } /// Synchronously adds the provided `ChannelHandler`s to the pipeline in the order given, taking @@ -1015,8 +1045,10 @@ extension ChannelPipeline { /// - handlers: The array of `ChannelHandler`s to add. /// - position: The position in the `ChannelPipeline` to add the handlers. /// - Returns: A result representing whether the handlers were added or not. - fileprivate func addHandlersSync(_ handlers: [ChannelHandler], - position: ChannelPipeline.Position) -> Result { + fileprivate func addHandlersSync( + _ handlers: [ChannelHandler], + position: ChannelPipeline.Position + ) -> Result { switch position { case .first, .after: return self._addHandlersSync(handlers.reversed(), position: position) @@ -1032,8 +1064,10 @@ extension ChannelPipeline { /// - handlers: A sequence of handlers to add. /// - position: The position in the `ChannelPipeline` to add the handlers. /// - Returns: A result representing whether the handlers were added or not. - private func _addHandlersSync(_ handlers: Handlers, - position: ChannelPipeline.Position) -> Result where Handlers.Element == ChannelHandler { + private func _addHandlersSync( + _ handlers: Handlers, + position: ChannelPipeline.Position + ) -> Result where Handlers.Element == ChannelHandler { self.eventLoop.assertInEventLoop() for handler in handlers { @@ -1066,7 +1100,7 @@ extension ChannelPipeline { /// The `EventLoop` of the `Channel` this synchronous operations view corresponds to. public var eventLoop: EventLoop { - return self._pipeline.eventLoop + self._pipeline.eventLoop } /// Add a handler to the pipeline. @@ -1076,9 +1110,11 @@ extension ChannelPipeline { /// - handler: The handler to add. /// - name: The name to use for the `ChannelHandler` when it's added. If no name is specified the one will be generated. /// - position: The position in the `ChannelPipeline` to add `handler`. Defaults to `.last`. - public func addHandler(_ handler: ChannelHandler, - name: String? = nil, - position: ChannelPipeline.Position = .last) throws { + public func addHandler( + _ handler: ChannelHandler, + name: String? = nil, + position: ChannelPipeline.Position = .last + ) throws { try self._pipeline.addHandlerSync(handler, name: name, position: position).get() } @@ -1088,8 +1124,10 @@ extension ChannelPipeline { /// - Parameters: /// - handlers: The handlers to add. /// - position: The position in the `ChannelPipeline` to add `handlers`. Defaults to `.last`. - public func addHandlers(_ handlers: [ChannelHandler], - position: ChannelPipeline.Position = .last) throws { + public func addHandlers( + _ handlers: [ChannelHandler], + position: ChannelPipeline.Position = .last + ) throws { try self._pipeline.addHandlersSync(handlers, position: position).get() } @@ -1099,8 +1137,10 @@ extension ChannelPipeline { /// - Parameters: /// - handlers: The handlers to add. /// - position: The position in the `ChannelPipeline` to add `handlers`. Defaults to `.last`. - public func addHandlers(_ handlers: ChannelHandler..., - position: ChannelPipeline.Position = .last) throws { + public func addHandlers( + _ handlers: ChannelHandler..., + position: ChannelPipeline.Position = .last + ) throws { try self._pipeline.addHandlersSync(handlers, position: position).get() } @@ -1128,7 +1168,7 @@ extension ChannelPipeline { /// - Parameter handler: The handler belonging to the context to fetch. /// - Returns: The `ChannelHandlerContext` associated with the handler. public func context(handler: ChannelHandler) throws -> ChannelHandlerContext { - return try self._pipeline._contextSync({ $0.handler === handler }).get() + try self._pipeline._contextSync({ $0.handler === handler }).get() } /// Returns the `ChannelHandlerContext` for the handler with the given name, if one exists. @@ -1137,7 +1177,7 @@ extension ChannelPipeline { /// - Parameter name: The name of the handler whose context is being fetched. /// - Returns: The `ChannelHandlerContext` associated with the handler. public func context(name: String) throws -> ChannelHandlerContext { - return try self._pipeline.contextSync(name: name).get() + try self._pipeline.contextSync(name: name).get() } /// Returns the `ChannelHandlerContext` for the handler of given type, if one exists. @@ -1147,7 +1187,7 @@ extension ChannelPipeline { /// - Returns: The `ChannelHandlerContext` associated with the handler. @inlinable public func context(handlerType: Handler.Type) throws -> ChannelHandlerContext { - return try self._pipeline._contextSync(handlerType: handlerType).get() + try self._pipeline._contextSync(handlerType: handlerType).get() } /// Returns the `ChannelHandler` of the given type from the `ChannelPipeline`, if it exists. @@ -1156,7 +1196,7 @@ extension ChannelPipeline { /// - Returns: A `ChannelHandler` of the given type if one exists in the `ChannelPipeline`. @inlinable public func handler(type _: Handler.Type) throws -> Handler { - return try self._pipeline._handlerSync(type: Handler.self).get() + try self._pipeline._handlerSync(type: Handler.self).get() } /// Fires `channelRegistered` from the head to the tail. @@ -1307,7 +1347,7 @@ extension ChannelPipeline { /// Returns a view of operations which can be performed synchronously on this pipeline. All /// operations **must** be called on the event loop. public var syncOperations: SynchronousOperations { - return SynchronousOperations(pipeline: self) + SynchronousOperations(pipeline: self) } } @@ -1335,12 +1375,12 @@ extension ChannelPipeline { extension ChannelPipeline.Position: Sendable {} /// Special `ChannelHandler` that forwards all events to the `Channel.Unsafe` implementation. -/* private but tests */ final class HeadChannelHandler: _ChannelOutboundHandler { +final class HeadChannelHandler: _ChannelOutboundHandler { static let name = "head" static let sharedInstance = HeadChannelHandler() - private init() { } + private init() {} func register(context: ChannelHandlerContext, promise: EventLoopPromise?) { context.channel._channelCore.register0(promise: promise) @@ -1376,9 +1416,9 @@ extension ChannelPipeline.Position: Sendable {} } -private extension CloseMode { +extension CloseMode { /// Returns the error to fail outstanding operations writes with. - var error: any Error { + fileprivate var error: any Error { switch self { case .all: return ChannelError._ioOnClosedChannel @@ -1391,12 +1431,12 @@ private extension CloseMode { } /// Special `ChannelInboundHandler` which will consume all inbound events. -/* private but tests */ final class TailChannelHandler: _ChannelInboundHandler { +final class TailChannelHandler: _ChannelInboundHandler { static let name = "tail" static let sharedInstance = TailChannelHandler() - private init() { } + private init() {} func channelRegistered(context: ChannelHandlerContext) { // Discard @@ -1461,11 +1501,11 @@ public final class ChannelHandlerContext: ChannelInvoker { public let pipeline: ChannelPipeline public var channel: Channel { - return self.pipeline.channel + self.pipeline.channel } public var handler: ChannelHandler { - return self.inboundHandler ?? self.outboundHandler! + self.inboundHandler ?? self.outboundHandler! } public var remoteAddress: SocketAddress? { @@ -1495,7 +1535,7 @@ public final class ChannelHandlerContext: ChannelInvoker { } public var eventLoop: EventLoop { - return self.pipeline.eventLoop + self.pipeline.eventLoop } public let name: String @@ -1512,7 +1552,10 @@ public final class ChannelHandlerContext: ChannelInvoker { self.outboundHandler = handler as? _ChannelOutboundHandler self.next = nil self.prev = nil - precondition(self.inboundHandler != nil || self.outboundHandler != nil, "ChannelHandlers need to either be inbound or outbound") + precondition( + self.inboundHandler != nil || self.outboundHandler != nil, + "ChannelHandlers need to either be inbound or outbound" + ) } /// Send a `channelRegistered` event to the next (inbound) `ChannelHandler` in the `ChannelPipeline`. @@ -1788,7 +1831,7 @@ public final class ChannelHandlerContext: ChannelInvoker { } } - fileprivate func invokeBind(to address: SocketAddress, promise: EventLoopPromise?) { + fileprivate func invokeBind(to address: SocketAddress, promise: EventLoopPromise?) { self.eventLoop.assertInEventLoop() if let outboundHandler = self.outboundHandler { @@ -1916,8 +1959,10 @@ extension ChannelHandlerContext { return } self.userTriggeredRemovalStarted = true - (self.handler as! RemovableChannelHandler).removeHandler(context: self, - removalToken: .init(promise: promise)) + (self.handler as! RemovableChannelHandler).removeHandler( + context: self, + removalToken: .init(promise: promise) + ) } } @@ -1933,15 +1978,17 @@ extension ChannelPipeline: CustomDebugStringConvertible { // var desc = ["ChannelPipeline[\(ObjectIdentifier(self))]:"] let debugInfos = self.collectHandlerDebugInfos() - let maxIncomingTypeNameCount = debugInfos.filter { $0.isIncoming } + let maxIncomingTypeNameCount = + debugInfos.filter { $0.isIncoming } .map { $0.typeName.count } .max() ?? 0 - let maxOutgoingTypeNameCount = debugInfos.filter { $0.isOutgoing } + let maxOutgoingTypeNameCount = + debugInfos.filter { $0.isOutgoing } .map { $0.typeName.count } .max() ?? 0 func whitespace(count: Int) -> String { - return String(repeating: " ", count: count) + String(repeating: " ", count: count) } if debugInfos.isEmpty { @@ -1979,9 +2026,11 @@ extension ChannelPipeline: CustomDebugStringConvertible { /// - type: the type of `ChannelHandler` to return. @inlinable public func handler(type _: Handler.Type) -> EventLoopFuture { - return self.context(handlerType: Handler.self).map { context in + self.context(handlerType: Handler.self).map { context in guard let typedContext = context.handler as? Handler else { - preconditionFailure("Expected channel handler of type \(Handler.self), got \(type(of: context.handler)) instead.") + preconditionFailure( + "Expected channel handler of type \(Handler.self), got \(type(of: context.handler)) instead." + ) } return typedContext @@ -1993,11 +2042,13 @@ extension ChannelPipeline: CustomDebugStringConvertible { /// - Important: This must be called on the `EventLoop`. /// - Parameters: /// - type: the type of `ChannelHandler` to return. - @inlinable // should be fileprivate + @inlinable // should be fileprivate internal func _handlerSync(type _: Handler.Type) -> Result { - return self._contextSync(handlerType: Handler.self).map { context in + self._contextSync(handlerType: Handler.self).map { context in guard let typedContext = context.handler as? Handler else { - preconditionFailure("Expected channel handler of type \(Handler.self), got \(type(of: context.handler)) instead.") + preconditionFailure( + "Expected channel handler of type \(Handler.self), got \(type(of: context.handler)) instead." + ) } return typedContext } @@ -2007,13 +2058,13 @@ extension ChannelPipeline: CustomDebugStringConvertible { let handler: ChannelHandler let name: String var isIncoming: Bool { - return self.handler is _ChannelInboundHandler + self.handler is _ChannelInboundHandler } var isOutgoing: Bool { - return self.handler is _ChannelOutboundHandler + self.handler is _ChannelOutboundHandler } var typeName: String { - return "\(type(of: self.handler))" + "\(type(of: self.handler))" } } diff --git a/Sources/NIOCore/CircularBuffer.swift b/Sources/NIOCore/CircularBuffer.swift index 1b09440194..360b3dfb0d 100644 --- a/Sources/NIOCore/CircularBuffer.swift +++ b/Sources/NIOCore/CircularBuffer.swift @@ -27,7 +27,7 @@ public struct CircularBuffer: CustomStringConvertible { @inlinable internal var mask: Int { - return self._buffer.count &- 1 + self._buffer.count &- 1 } @inlinable @@ -42,17 +42,17 @@ public struct CircularBuffer: CustomStringConvertible { @inlinable internal func indexBeforeHeadIdx() -> Int { - return self.indexAdvanced(index: self.headBackingIndex, by: -1) + self.indexAdvanced(index: self.headBackingIndex, by: -1) } @inlinable internal func indexBeforeTailIdx() -> Int { - return self.indexAdvanced(index: self.tailBackingIndex, by: -1) + self.indexAdvanced(index: self.tailBackingIndex, by: -1) } @inlinable internal func indexAdvanced(index: Int, by: Int) -> Int { - return (index &+ by) & self.mask + (index &+ by) & self.mask } /// An opaque `CircularBuffer` index. @@ -69,7 +69,7 @@ public struct CircularBuffer: CustomStringConvertible { @inlinable internal var backingIndex: Int { - return Int(self._backingIndex) + Int(self._backingIndex) } @inlinable @@ -85,9 +85,8 @@ public struct CircularBuffer: CustomStringConvertible { @inlinable public static func == (lhs: Index, rhs: Index) -> Bool { - return lhs._backingIndex == rhs._backingIndex && - lhs._backingCheck == rhs._backingCheck && - lhs.isIndexGEQHeadIndex == rhs.isIndexGEQHeadIndex + lhs._backingIndex == rhs._backingIndex && lhs._backingCheck == rhs._backingCheck + && lhs.isIndexGEQHeadIndex == rhs.isIndexGEQHeadIndex } @inlinable @@ -105,8 +104,8 @@ public struct CircularBuffer: CustomStringConvertible { @usableFromInline internal func isValidIndex(for ring: CircularBuffer) -> Bool { - return self._backingCheck == _UInt24.max || Int(self._backingCheck) == ring.count - } + self._backingCheck == _UInt24.max || Int(self._backingCheck) == ring.count + } } } @@ -128,13 +127,13 @@ extension CircularBuffer: Collection, MutableCollection { /// - Returns: The index value immediately after `i`. @inlinable public func index(after: Index) -> Index { - return self.index(after, offsetBy: 1) + self.index(after, offsetBy: 1) } /// Returns the index before `index`. @inlinable public func index(before: Index) -> Index { - return self.index(before, offsetBy: -1) + self.index(before, offsetBy: -1) } /// Accesses the element at the specified index. @@ -152,15 +151,19 @@ extension CircularBuffer: Collection, MutableCollection { @inlinable public subscript(position: Index) -> Element { get { - assert(position.isValidIndex(for: self), - "illegal index used, index was for CircularBuffer with count \(position._backingCheck), " + - "but actual count is \(self.count)") + assert( + position.isValidIndex(for: self), + "illegal index used, index was for CircularBuffer with count \(position._backingCheck), " + + "but actual count is \(self.count)" + ) return self._buffer[position.backingIndex]! } set { - assert(position.isValidIndex(for: self), - "illegal index used, index was for CircularBuffer with count \(position._backingCheck), " + - "but actual count is \(self.count)") + assert( + position.isValidIndex(for: self), + "illegal index used, index was for CircularBuffer with count \(position._backingCheck), " + + "but actual count is \(self.count)" + ) self._buffer[position.backingIndex] = newValue } } @@ -170,9 +173,11 @@ extension CircularBuffer: Collection, MutableCollection { /// If the `CircularBuffer` is empty, `startIndex` is equal to `endIndex`. @inlinable public var startIndex: Index { - return .init(backingIndex: self.headBackingIndex, - backingCount: self.count, - backingIndexOfHead: self.headBackingIndex) + .init( + backingIndex: self.headBackingIndex, + backingCount: self.count, + backingIndexOfHead: self.headBackingIndex + ) } /// The `CircularBuffer`'s "past the end" position---that is, the position one @@ -186,9 +191,11 @@ extension CircularBuffer: Collection, MutableCollection { /// If the `CircularBuffer` is empty, `endIndex` is equal to `startIndex`. @inlinable public var endIndex: Index { - return .init(backingIndex: self.tailBackingIndex, - backingCount: self.count, - backingIndexOfHead: self.headBackingIndex) + .init( + backingIndex: self.tailBackingIndex, + backingCount: self.count, + backingIndexOfHead: self.headBackingIndex + ) } /// Returns the distance between two indices. @@ -251,14 +258,14 @@ extension CircularBuffer: Collection, MutableCollection { return (self[self.endIndex..) {} - + @inlinable public func _failEarlyRangeCheck(_ index: Index, bounds: ClosedRange) {} - + @inlinable public func _failEarlyRangeCheck(_ range: Range, bounds: Range) {} } @@ -294,9 +301,11 @@ extension CircularBuffer: RandomAccessCollection { /// value of `distance`. @inlinable public func index(_ i: Index, offsetBy distance: Int) -> Index { - return .init(backingIndex: (i.backingIndex &+ distance) & self.mask, - backingCount: self.count, - backingIndexOfHead: self.headBackingIndex) + .init( + backingIndex: (i.backingIndex &+ distance) & self.mask, + backingCount: self.count, + backingIndexOfHead: self.headBackingIndex + ) } @inlinable @@ -345,7 +354,7 @@ extension CircularBuffer { public mutating func append(_ value: Element) { self._buffer[self.tailBackingIndex] = value self.advanceTailIdx(by: 1) - + if self.headBackingIndex == self.tailBackingIndex { // No more room left for another append so grow the buffer now. self._doubleCapacity() @@ -424,24 +433,24 @@ extension CircularBuffer { self._buffer = newBacking assert(self.verifyInvariants()) } - + /// Return element `offset` from first element. /// /// *O(1)* @inlinable public subscript(offset offset: Int) -> Element { get { - return self[self.index(self.startIndex, offsetBy: offset)] + self[self.index(self.startIndex, offsetBy: offset)] } set { self[self.index(self.startIndex, offsetBy: offset)] = newValue } } - + /// Returns whether the ring is empty. @inlinable public var isEmpty: Bool { - return self.headBackingIndex == self.tailBackingIndex + self.headBackingIndex == self.tailBackingIndex } /// Returns the number of element in the ring. @@ -457,7 +466,7 @@ extension CircularBuffer { /// The total number of elements that the ring can contain without allocating new storage. @inlinable public var capacity: Int { - return self._buffer.count + self._buffer.count } /// Removes all members from the circular buffer whist keeping the capacity. @@ -474,7 +483,6 @@ extension CircularBuffer { assert(self.verifyInvariants()) } - /// Modify the element at `index`. /// /// This function exists to provide a method of modifying the element in its underlying backing storage, instead @@ -491,10 +499,13 @@ extension CircularBuffer { /// - index: The index of the object that should be modified. If this index is invalid this function will trap. /// - modifyFunc: The function to apply to the modified object. @inlinable - public mutating func modify(_ index: Index, _ modifyFunc: (inout Element) throws -> Result) rethrows -> Result { - return try modifyFunc(&self._buffer[index.backingIndex]!) + public mutating func modify( + _ index: Index, + _ modifyFunc: (inout Element) throws -> Result + ) rethrows -> Result { + try modifyFunc(&self._buffer[index.backingIndex]!) } - + // MARK: CustomStringConvertible implementation /// Returns a human readable description of the ring. public var description: String { @@ -572,14 +583,13 @@ extension CircularBuffer: RangeReplaceableCollection { public mutating func removeLast(_ k: Int) { precondition(k <= self.count, "Number of elements to drop bigger than the amount of elements in the buffer.") var idx = self.tailBackingIndex - for _ in 0 ..< k { + for _ in 0..(_ subrange: Range, with newElements: C) where Element == C.Element { - precondition(subrange.lowerBound >= self.startIndex && subrange.upperBound <= self.endIndex, - "Subrange out of bounds") - assert(subrange.lowerBound.isValidIndex(for: self), - "illegal index used, index was for CircularBuffer with count \(subrange.lowerBound._backingCheck), " + - "but actual count is \(self.count)") - assert(subrange.upperBound.isValidIndex(for: self), - "illegal index used, index was for CircularBuffer with count \(subrange.upperBound._backingCheck), " + - "but actual count is \(self.count)") + public mutating func replaceSubrange(_ subrange: Range, with newElements: C) + where Element == C.Element { + precondition( + subrange.lowerBound >= self.startIndex && subrange.upperBound <= self.endIndex, + "Subrange out of bounds" + ) + assert( + subrange.lowerBound.isValidIndex(for: self), + "illegal index used, index was for CircularBuffer with count \(subrange.lowerBound._backingCheck), " + + "but actual count is \(self.count)" + ) + assert( + subrange.upperBound.isValidIndex(for: self), + "illegal index used, index was for CircularBuffer with count \(subrange.upperBound._backingCheck), " + + "but actual count is \(self.count)" + ) let subrangeCount = self.distance(from: subrange.lowerBound, to: subrange.upperBound) @@ -674,14 +691,14 @@ extension CircularBuffer: RangeReplaceableCollection { self.removeSubrange(subrange) } else { var newBuffer: ContiguousArray = [] - let neededNewCapacity = self.count + newElements.count - subrangeCount + 1 /* always one spare */ + let neededNewCapacity = self.count + newElements.count - subrangeCount + 1 // always one spare let newCapacity = Swift.max(self.capacity, neededNewCapacity.nextPowerOf2()) newBuffer.reserveCapacity(newCapacity) // This mapping is required due to an inconsistent ability to append sequences of non-optional // to optional sequences. // https://bugs.swift.org/browse/SR-7921 - newBuffer.append(contentsOf: self[self.startIndex ..< subrange.lowerBound].lazy.map { $0 }) + newBuffer.append(contentsOf: self[self.startIndex.. Element { - assert(position.isValidIndex(for: self), - "illegal index used, index was for CircularBuffer with count \(position._backingCheck), " + - "but actual count is \(self.count)") + assert( + position.isValidIndex(for: self), + "illegal index used, index was for CircularBuffer with count \(position._backingCheck), " + + "but actual count is \(self.count)" + ) defer { assert(self.verifyInvariants()) } @@ -807,13 +826,13 @@ extension CircularBuffer { } internal func testOnly_verifyInvariantsForNonSlices() -> Bool { - return self.verifyInvariants() && self.unreachableAreNil() + self.verifyInvariants() && self.unreachableAreNil() } } extension CircularBuffer: Equatable where Element: Equatable { - public static func ==(lhs: CircularBuffer, rhs: CircularBuffer) -> Bool { - return lhs.count == rhs.count && zip(lhs, rhs).allSatisfy(==) + public static func == (lhs: CircularBuffer, rhs: CircularBuffer) -> Bool { + lhs.count == rhs.count && zip(lhs, rhs).allSatisfy(==) } } diff --git a/Sources/NIOCore/Codec.swift b/Sources/NIOCore/Codec.swift index a67e840fa6..6b9d59c4f7 100644 --- a/Sources/NIOCore/Codec.swift +++ b/Sources/NIOCore/Codec.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// - /// State of the current decoding process. public enum DecodingState: Sendable { /// Continue decoding. @@ -43,7 +42,6 @@ extension ByteToMessageDecoderError { } } - /// `ByteToMessageDecoder`s decode bytes in a stream-like fashion from `ByteBuffer` to another message type. /// /// ### Purpose @@ -63,7 +61,7 @@ extension ByteToMessageDecoderError { /// ### Implementing ByteToMessageDecoder /// /// A type that implements `ByteToMessageDecoder` may implement two methods: decode and decodeLast. Implementations -/// must implement decode: if they do not implement decodeLast, a default implementation will be used that +/// must implement decode: if they do not implement decodeLast, a default implementation will be used that /// simply calls decode. /// /// `decode` is the main decoding method, and is the one that will be called most often. `decode` is invoked @@ -176,7 +174,11 @@ public protocol ByteToMessageDecoder { /// - seenEOF: `true` if EOF has been seen. Usually if this is `false` the handler has been removed. /// - returns: `DecodingState.continue` if we should continue calling this method or `DecodingState.needMoreData` if it should be called /// again when more data is present in the `ByteBuffer`. - mutating func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState + mutating func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState /// Called once this `ByteToMessageDecoder` is removed from the `ChannelPipeline`. /// @@ -237,15 +239,19 @@ extension ByteToMessageDecoder { @inlinable public func wrapInboundOut(_ value: InboundOut) -> NIOAny { - return NIOAny(value) + NIOAny(value) } @inlinable public static func wrapInboundOut(_ value: InboundOut) -> NIOAny { - return NIOAny(value) + NIOAny(value) } - public mutating func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState { + public mutating func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState { while try self.decode(context: context, buffer: &buffer) == .continue {} return .needMoreData } @@ -297,7 +303,7 @@ extension B2MDBuffer { case .ready where self.buffers.count > 0: var buffer = self.buffers.removeFirst() buffer.writeBuffers(self.buffers) - self.buffers.removeAll(keepingCapacity: self.buffers.capacity < 16) // don't grow too much + self.buffers.removeAll(keepingCapacity: self.buffers.capacity < 16) // don't grow too much if buffer.readableBytes > 0 || allowEmptyBuffer { self.state = .processingInProgress return .available(buffer) @@ -314,7 +320,7 @@ extension B2MDBuffer { } } - mutating func finishProcessing(remainder buffer: inout ByteBuffer) -> Void { + mutating func finishProcessing(remainder buffer: inout ByteBuffer) { assert(self.state == .processingInProgress) self.state = .ready if buffer.readableBytes == 0 && self.buffers.isEmpty { @@ -326,7 +332,8 @@ extension B2MDBuffer { } else { buffer.discardReadBytes() buffer.writeBuffers(self.buffers) - self.buffers.removeAll(keepingCapacity: self.buffers.capacity < 16) // don't grow too much + // don't grow too much + self.buffers.removeAll(keepingCapacity: self.buffers.capacity < 16) self.buffers.append(buffer) } } @@ -339,8 +346,8 @@ extension B2MDBuffer { } // MARK: B2MDBuffer Helpers -private extension ByteBuffer { - mutating func writeBuffers(_ buffers: CircularBuffer) { +extension ByteBuffer { + fileprivate mutating func writeBuffers(_ buffers: CircularBuffer) { guard buffers.count > 0 else { return } @@ -355,8 +362,8 @@ private extension ByteBuffer { } } -private extension B2MDBuffer { - func _testOnlyOneBuffer() -> ByteBuffer? { +extension B2MDBuffer { + fileprivate func _testOnlyOneBuffer() -> ByteBuffer? { switch self.buffers.count { case 0: return nil @@ -452,12 +459,15 @@ public final class ByteToMessageHandler { } } - internal private(set) var decoder: Decoder? // only `nil` if we're already decoding (ie. we're re-entered) + // only `nil` if we're already decoding (ie. we're re-entered) + internal private(set) var decoder: Decoder? private let maximumBufferSize: Int? - private var queuedWrites = CircularBuffer(initialCapacity: 1) // queues writes received whilst we're already decoding (re-entrant write) + // queues writes received whilst we're already decoding (re-entrant write) + private var queuedWrites = CircularBuffer(initialCapacity: 1) private var state: State = .active { willSet { - assert(!self.state.isFinalState, "illegal state on state set: \(self.state)") // we can never leave final states + // we can never leave final states + assert(!self.state.isFinalState, "illegal state on state set: \(self.state)") } } private var removalState: RemovalState = .notAddedToPipeline @@ -484,8 +494,10 @@ public final class ByteToMessageHandler { deinit { if self.removalState != .notAddedToPipeline { // we have been added to the pipeline, if not, we don't need to check our state. - assert(self.removalState == .handlerRemovedCalled, - "illegal state in deinit: removalState = \(self.removalState)") + assert( + self.removalState == .handlerRemovedCalled, + "illegal state in deinit: removalState = \(self.removalState)" + ) assert(self.state.isFinalState, "illegal state in deinit: state = \(self.state)") } } @@ -497,7 +509,7 @@ extension ByteToMessageHandler: Sendable {} // MARK: ByteToMessageHandler: Test Helpers extension ByteToMessageHandler { internal var cumulationBuffer: ByteBuffer? { - return self.buffer._testOnlyOneBuffer() + self.buffer._testOnlyOneBuffer() } } @@ -514,11 +526,13 @@ extension ByteToMessageHandler: CanDequeueWrites where Decoder: WriteObservingBy } } - // MARK: ByteToMessageHandler's Main API extension ByteToMessageHandler { - @inline(__always) // allocations otherwise (reconsider with Swift 5.1) - private func withNextBuffer(allowEmptyBuffer: Bool, _ body: (inout Decoder, inout ByteBuffer) throws -> DecodingState) rethrows -> B2MDBuffer.BufferProcessingResult { + @inline(__always) // allocations otherwise (reconsider with Swift 5.1) + private func withNextBuffer( + allowEmptyBuffer: Bool, + _ body: (inout Decoder, inout ByteBuffer) throws -> DecodingState + ) rethrows -> B2MDBuffer.BufferProcessingResult { switch self.buffer.startProcessing(allowEmptyBuffer: allowEmptyBuffer) { case .bufferAlreadyBeingProcessed: return .cannotProcessReentrantly @@ -528,7 +542,8 @@ extension ByteToMessageHandler { var possiblyReclaimBytes = false var decoder: Decoder? = nil swap(&decoder, &self.decoder) - assert(decoder != nil) // self.decoder only `nil` if we're being re-entered, but .available means we're not + // self.decoder only `nil` if we're being re-entered, but .available means we're not + assert(decoder != nil) defer { swap(&decoder, &self.decoder) if buffer.readableBytes > 0 && possiblyReclaimBytes { @@ -575,7 +590,10 @@ extension ByteToMessageHandler { } } - private func decodeLoop(context: ChannelHandlerContext, decodeMode: DecodeMode) throws -> B2MDBuffer.BufferProcessingResult { + private func decodeLoop( + context: ChannelHandlerContext, + decodeMode: DecodeMode + ) throws -> B2MDBuffer.BufferProcessingResult { assert(!self.state.isError) var allowEmptyBuffer = decodeMode == .last while (self.state.isActive && self.removalState == .notBeingRemoved) || decodeMode == .last { @@ -588,7 +606,9 @@ extension ByteToMessageHandler { allowEmptyBuffer = false decoderResult = try decoder.decodeLast(context: context, buffer: &buffer, seenEOF: self.seenEOF) } - if decoderResult == .needMoreData, let maximumBufferSize = self.maximumBufferSize, buffer.readableBytes > maximumBufferSize { + if decoderResult == .needMoreData, let maximumBufferSize = self.maximumBufferSize, + buffer.readableBytes > maximumBufferSize + { throw ByteToMessageDecoderError.PayloadTooLargeError() } return decoderResult @@ -600,7 +620,7 @@ extension ByteToMessageHandler { case .didProcess(.needMoreData): if self.queuedWrites.count > 0 { self.tryDecodeWrites() - continue // we might have received more, so let's spin once more + continue // we might have received more, so let's spin once more } else { return .didProcess(.needMoreData) } @@ -612,7 +632,6 @@ extension ByteToMessageHandler { } } - // MARK: ByteToMessageHandler: ChannelInboundHandler extension ByteToMessageHandler: ChannelInboundHandler { @@ -623,11 +642,10 @@ extension ByteToMessageHandler: ChannelInboundHandler { self.removalState = .notBeingRemoved self.buffer = B2MDBuffer(emptyByteBuffer: context.channel.allocator.buffer(capacity: 0)) // here we can force it because we know that the decoder isn't in use if we're just adding this handler - self.selfAsCanDequeueWrites = self as? CanDequeueWrites // we need to cache this as it allocates. + self.selfAsCanDequeueWrites = self as? CanDequeueWrites // we need to cache this as it allocates. self.decoder!.decoderAdded(context: context) } - public func handlerRemoved(context: ChannelHandlerContext) { // very likely, the removal state is `.notBeingRemoved` or `.removalCompleted` here but we can't assert it // because the pipeline might be torn down during the formal removal process. @@ -656,14 +674,14 @@ extension ByteToMessageHandler: ChannelInboundHandler { case .didProcess: switch self.state { case .active: - () // cool, all normal + () // cool, all normal case .done, .error: - () // fair, all done already + () // fair, all done already case .leftoversNeedProcessing: // seems like we received a `channelInactive` or `handlerRemoved` whilst we were processing a read switch try self.decodeLoop(context: context, decodeMode: .last) { case .didProcess: - () // expected and cool + () // expected and cool case .cannotProcessReentrantly: preconditionFailure("bug in NIO: non-reentrant decode loop couldn't run \(self), \(self.state)") } @@ -698,7 +716,8 @@ extension ByteToMessageHandler: ChannelInboundHandler { } } -extension ByteToMessageHandler: ChannelOutboundHandler, _ChannelOutboundHandler where Decoder: WriteObservingByteToMessageDecoder { +extension ByteToMessageHandler: ChannelOutboundHandler, _ChannelOutboundHandler +where Decoder: WriteObservingByteToMessageDecoder { public typealias OutboundIn = Decoder.OutboundIn public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { if self.decoder != nil { @@ -785,8 +804,10 @@ extension MessageToByteHandler: Sendable {} extension MessageToByteHandler { public func handlerAdded(context: ChannelHandlerContext) { - precondition(self.state.readyToBeAddedToChannel, - "illegal state when adding to Channel: \(self.state)") + precondition( + self.state.readyToBeAddedToChannel, + "illegal state when adding to Channel: \(self.state)" + ) self.state = .operational self.buffer = context.channel.allocator.buffer(capacity: 256) } diff --git a/Sources/NIOCore/ConvenienceOptionSupport.swift b/Sources/NIOCore/ConvenienceOptionSupport.swift index af93b79803..0ef1c61519 100644 --- a/Sources/NIOCore/ConvenienceOptionSupport.swift +++ b/Sources/NIOCore/ConvenienceOptionSupport.swift @@ -20,7 +20,7 @@ extension NIOClientTCPBootstrapProtocol { /// - returns: The updated bootstrap with and options applied. public func _applyChannelConvenienceOptions(_ options: inout ChannelOptions.TCPConvenienceOptions) -> Self { // Default is to consume no options and not update self. - return self + self } } @@ -34,8 +34,10 @@ extension NIOClientTCPBootstrap { var optionsRemaining = options // First give the underlying a chance to consume options. let withUnderlyingOverrides = - NIOClientTCPBootstrap(self, - updating: underlyingBootstrap._applyChannelConvenienceOptions(&optionsRemaining)) + NIOClientTCPBootstrap( + self, + updating: underlyingBootstrap._applyChannelConvenienceOptions(&optionsRemaining) + ) // Default apply any remaining options. return optionsRemaining.applyFallbackMapping(withUnderlyingOverrides) } @@ -84,11 +86,11 @@ extension ChannelOptions { /// A TCP channel option which can be applied to a bootstrap using convenience notation. public struct TCPConvenienceOption: Hashable, Sendable { fileprivate var data: ConvenienceOption - + private init(_ data: ConvenienceOption) { self.data = data } - + fileprivate enum ConvenienceOption: Hashable { case allowLocalEndpointReuse case disableAutoRead @@ -101,17 +103,17 @@ extension ChannelOptions { extension ChannelOptions.TCPConvenienceOption { /// Allow immediately reusing a local address. public static let allowLocalEndpointReuse = ChannelOptions.TCPConvenienceOption(.allowLocalEndpointReuse) - + /// The user will manually call `Channel.read` once all the data is read from the transport. public static let disableAutoRead = ChannelOptions.TCPConvenienceOption(.disableAutoRead) - + /// Allows users to configure whether the `Channel` will close itself when its remote /// peer shuts down its send stream, or whether it will remain open. If set to `false` (the default), the `Channel` /// will be closed automatically if the remote peer shuts down its send stream. If set to true, the `Channel` will /// not be closed: instead, a `ChannelEvent.inboundClosed` user event will be sent on the `ChannelPipeline`, /// and no more data will be received. public static let allowRemoteHalfClosure = - ChannelOptions.TCPConvenienceOption(.allowRemoteHalfClosure) + ChannelOptions.TCPConvenienceOption(.allowRemoteHalfClosure) } extension ChannelOptions { @@ -120,7 +122,7 @@ extension ChannelOptions { var allowLocalEndpointReuse = false var disableAutoRead = false var allowRemoteHalfClosure = false - + /// Construct from an array literal. @inlinable public init(arrayLiteral elements: TCPConvenienceOption...) { @@ -128,7 +130,7 @@ extension ChannelOptions { self.add(element) } } - + @usableFromInline mutating func add(_ element: TCPConvenienceOption) { switch element.data { @@ -140,7 +142,7 @@ extension ChannelOptions { self.disableAutoRead = true } } - + /// Caller is consuming the knowledge that `allowLocalEndpointReuse` was set or not. /// The setting will nolonger be set after this call. /// - Returns: If `allowLocalEndpointReuse` was set. @@ -150,7 +152,7 @@ extension ChannelOptions { } return Types.ConvenienceOptionValue(flag: self.allowLocalEndpointReuse) } - + /// Caller is consuming the knowledge that disableAutoRead was set or not. /// The setting will nolonger be set after this call. /// - Returns: If disableAutoRead was set. @@ -160,7 +162,7 @@ extension ChannelOptions { } return Types.ConvenienceOptionValue(flag: self.disableAutoRead) } - + /// Caller is consuming the knowledge that allowRemoteHalfClosure was set or not. /// The setting will nolonger be set after this call. /// - Returns: If allowRemoteHalfClosure was set. @@ -170,7 +172,7 @@ extension ChannelOptions { } return Types.ConvenienceOptionValue(flag: self.allowRemoteHalfClosure) } - + mutating func applyFallbackMapping(_ universalBootstrap: NIOClientTCPBootstrap) -> NIOClientTCPBootstrap { var result = universalBootstrap if self.consumeAllowLocalEndpointReuse().isSet { diff --git a/Sources/NIOCore/DeadChannel.swift b/Sources/NIOCore/DeadChannel.swift index 256e3480ac..7c0964973d 100644 --- a/Sources/NIOCore/DeadChannel.swift +++ b/Sources/NIOCore/DeadChannel.swift @@ -80,7 +80,7 @@ internal final class DeadChannel: Channel, @unchecked Sendable { let pipeline: ChannelPipeline public var closeFuture: EventLoopFuture { - return self.eventLoop.makeSucceededFuture(()) + self.eventLoop.makeSucceededFuture(()) } internal init(pipeline: ChannelPipeline) { @@ -90,25 +90,25 @@ internal final class DeadChannel: Channel, @unchecked Sendable { // This is `Channel` API so must be thread-safe. var allocator: ByteBufferAllocator { - return ByteBufferAllocator() + ByteBufferAllocator() } var localAddress: SocketAddress? { - return nil + nil } var remoteAddress: SocketAddress? { - return nil + nil } let parent: Channel? = nil func setOption(_ option: Option, value: Option.Value) -> EventLoopFuture { - return self.pipeline.eventLoop.makeFailedFuture(ChannelError._ioOnClosedChannel) + self.pipeline.eventLoop.makeFailedFuture(ChannelError._ioOnClosedChannel) } func getOption(_ option: Option) -> EventLoopFuture { - return eventLoop.makeFailedFuture(ChannelError._ioOnClosedChannel) + eventLoop.makeFailedFuture(ChannelError._ioOnClosedChannel) } let isWritable = false diff --git a/Sources/NIOCore/EventLoop+Deprecated.swift b/Sources/NIOCore/EventLoop+Deprecated.swift index 62fcafef73..e2321ceb74 100644 --- a/Sources/NIOCore/EventLoop+Deprecated.swift +++ b/Sources/NIOCore/EventLoop+Deprecated.swift @@ -15,13 +15,21 @@ extension EventLoop { @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func makeFailedFuture(_ error: Error, file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { - return self.makeFailedFuture(error) + public func makeFailedFuture( + _ error: Error, + file: StaticString = #fileID, + line: UInt = #line + ) -> EventLoopFuture { + self.makeFailedFuture(error) } @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func makeSucceededFuture(_ value: Success, file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { - return self.makeSucceededFuture(value) + public func makeSucceededFuture( + _ value: Success, + file: StaticString = #fileID, + line: UInt = #line + ) -> EventLoopFuture { + self.makeSucceededFuture(value) } } diff --git a/Sources/NIOCore/EventLoop+SerialExecutor.swift b/Sources/NIOCore/EventLoop+SerialExecutor.swift index f157701778..b3d0434133 100644 --- a/Sources/NIOCore/EventLoop+SerialExecutor.swift +++ b/Sources/NIOCore/EventLoop+SerialExecutor.swift @@ -19,7 +19,7 @@ /// Implementers of `EventLoop` should consider conforming to this protocol as /// well on Swift 5.9 and later. @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) -public protocol NIOSerialEventLoopExecutor: EventLoop, SerialExecutor { } +public protocol NIOSerialEventLoopExecutor: EventLoop, SerialExecutor {} @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) extension NIOSerialEventLoopExecutor { diff --git a/Sources/NIOCore/EventLoop.swift b/Sources/NIOCore/EventLoop.swift index 7bfe8fa309..1f38c73a31 100644 --- a/Sources/NIOCore/EventLoop.swift +++ b/Sources/NIOCore/EventLoop.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -import NIOConcurrencyHelpers import Dispatch +import NIOConcurrencyHelpers + #if os(Linux) import CNIOLinux -#endif // os(Linux) +#endif // os(Linux) /// Returned once a task was scheduled on the `EventLoop` for later execution. /// @@ -24,8 +25,8 @@ import CNIOLinux /// will be notified once the execution is complete. public struct Scheduled { @usableFromInline typealias CancelationCallback = @Sendable () -> Void - /* private but usableFromInline */ @usableFromInline let _promise: EventLoopPromise - /* private but usableFromInline */ @usableFromInline let _cancellationTask: CancelationCallback + @usableFromInline let _promise: EventLoopPromise + @usableFromInline let _cancellationTask: CancelationCallback @inlinable @preconcurrency @@ -47,7 +48,7 @@ public struct Scheduled { /// Returns the `EventLoopFuture` which will be notified once the execution of the scheduled task completes. @inlinable public var futureResult: EventLoopFuture { - return self._promise.futureResult + self._promise.futureResult } } @@ -202,7 +203,7 @@ public struct EventLoopIterator: Sequence, IteratorProtocol { /// /// - returns: The next `EventLoop` if a next element exists; otherwise, `nil`. public mutating func next() -> EventLoop? { - return self.eventLoops.next() + self.eventLoops.next() } } @@ -247,7 +248,7 @@ public protocol EventLoop: EventLoopGroup { /// /// If it is necessary for correctness to confirm that you're on an event loop, prefer ``preconditionInEventLoop(file:line:)-7ukrq``. var inEventLoop: Bool { get } - + /// Submit a given task to be executed by the `EventLoop` @preconcurrency func execute(_ task: @escaping @Sendable () -> Void) @@ -361,7 +362,7 @@ public protocol EventLoop: EventLoopGroup { extension EventLoop { /// Default implementation of `makeSucceededVoidFuture`: Return a fresh future (which will allocate). public func makeSucceededVoidFuture() -> EventLoopFuture { - return EventLoopFuture(eventLoop: self, value: ()) + EventLoopFuture(eventLoop: self, value: ()) } public func _preconditionSafeToWait(file: StaticString, line: UInt) { @@ -374,8 +375,9 @@ extension EventLoop { } /// Default implementation of `_promiseCompleted`: does nothing. - public func _promiseCompleted(futureIdentifier: _NIOEventLoopFutureIdentifier) -> (file: StaticString, line: UInt)? { - return nil + public func _promiseCompleted(futureIdentifier: _NIOEventLoopFutureIdentifier) -> (file: StaticString, line: UInt)? + { + nil } } @@ -402,7 +404,7 @@ extension EventLoop { extension EventLoopGroup { public var description: String { - return String(describing: self) + String(describing: self) } } @@ -416,7 +418,7 @@ public struct TimeAmount: Hashable, Sendable { /// The nanoseconds representation of the `TimeAmount`. public let nanoseconds: Int64 - /* private but */ @inlinable + @inlinable init(_ nanoseconds: Int64) { self.nanoseconds = nanoseconds } @@ -428,7 +430,7 @@ public struct TimeAmount: Hashable, Sendable { /// - returns: the `TimeAmount` for the given amount. @inlinable public static func nanoseconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(amount) + TimeAmount(amount) } /// Creates a new `TimeAmount` for the given amount of microseconds. @@ -440,7 +442,7 @@ public struct TimeAmount: Hashable, Sendable { /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func microseconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000)) + TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000)) } /// Creates a new `TimeAmount` for the given amount of milliseconds. @@ -452,7 +454,7 @@ public struct TimeAmount: Hashable, Sendable { /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func milliseconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000)) + TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000)) } /// Creates a new `TimeAmount` for the given amount of seconds. @@ -464,7 +466,7 @@ public struct TimeAmount: Hashable, Sendable { /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func seconds(_ amount: Int64) -> TimeAmount { - return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000)) + TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000)) } /// Creates a new `TimeAmount` for the given amount of minutes. @@ -476,7 +478,7 @@ public struct TimeAmount: Hashable, Sendable { /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func minutes(_ amount: Int64) -> TimeAmount { - return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000 * 60)) + TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000 * 60)) } /// Creates a new `TimeAmount` for the given amount of hours. @@ -488,9 +490,9 @@ public struct TimeAmount: Hashable, Sendable { /// - note: returns `TimeAmount(.max)` if the amount overflows when converted to nanoseconds and `TimeAmount(.min)` if it underflows. @inlinable public static func hours(_ amount: Int64) -> TimeAmount { - return TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000 * 60 * 60)) + TimeAmount(_cappedNanoseconds(amount: amount, multiplier: 1000 * 1000 * 1000 * 60 * 60)) } - + /// Converts `amount` to nanoseconds multiplying it by `multiplier`. The return value is capped to `Int64.max` if the multiplication overflows and `Int64.min` if it underflows. /// /// - parameters: @@ -511,7 +513,7 @@ public struct TimeAmount: Hashable, Sendable { extension TimeAmount: Comparable { @inlinable public static func < (lhs: TimeAmount, rhs: TimeAmount) -> Bool { - return lhs.nanoseconds < rhs.nanoseconds + lhs.nanoseconds < rhs.nanoseconds } } @@ -519,37 +521,37 @@ extension TimeAmount: AdditiveArithmetic { /// The zero value for `TimeAmount`. @inlinable public static var zero: TimeAmount { - return TimeAmount.nanoseconds(0) + TimeAmount.nanoseconds(0) } @inlinable public static func + (lhs: TimeAmount, rhs: TimeAmount) -> TimeAmount { - return TimeAmount(lhs.nanoseconds + rhs.nanoseconds) + TimeAmount(lhs.nanoseconds + rhs.nanoseconds) } @inlinable - public static func +=(lhs: inout TimeAmount, rhs: TimeAmount) { + public static func += (lhs: inout TimeAmount, rhs: TimeAmount) { lhs = lhs + rhs } @inlinable public static func - (lhs: TimeAmount, rhs: TimeAmount) -> TimeAmount { - return TimeAmount(lhs.nanoseconds - rhs.nanoseconds) + TimeAmount(lhs.nanoseconds - rhs.nanoseconds) } @inlinable - public static func -=(lhs: inout TimeAmount, rhs: TimeAmount) { + public static func -= (lhs: inout TimeAmount, rhs: TimeAmount) { lhs = lhs - rhs } @inlinable public static func * (lhs: T, rhs: TimeAmount) -> TimeAmount { - return TimeAmount(Int64(lhs) * rhs.nanoseconds) + TimeAmount(Int64(lhs) * rhs.nanoseconds) } @inlinable public static func * (lhs: TimeAmount, rhs: T) -> TimeAmount { - return TimeAmount(lhs.nanoseconds * Int64(rhs)) + TimeAmount(lhs.nanoseconds * Int64(rhs)) } } @@ -575,7 +577,7 @@ public struct NIODeadline: Equatable, Hashable, Sendable { public typealias Value = UInt64 // This really should be an UInt63 but we model it as Int64 with >=0 assert - /* private but */ @usableFromInline var _uptimeNanoseconds: Int64 { + @usableFromInline var _uptimeNanoseconds: Int64 { didSet { assert(self._uptimeNanoseconds >= 0) } @@ -584,18 +586,17 @@ public struct NIODeadline: Equatable, Hashable, Sendable { /// The nanoseconds since boot representation of the `NIODeadline`. @inlinable public var uptimeNanoseconds: UInt64 { - return .init(self._uptimeNanoseconds) + .init(self._uptimeNanoseconds) } public static let distantPast = NIODeadline(0) public static let distantFuture = NIODeadline(.init(Int64.max)) - /* private but */ @inlinable init(_ nanoseconds: Int64) { + @inlinable init(_ nanoseconds: Int64) { precondition(nanoseconds >= 0) self._uptimeNanoseconds = nanoseconds } - /// Getting the time is a very common operation so it warrants optimization. /// /// Prior to this function, NIO relied on `DispatchTime.now()`, on all platforms. In addition to @@ -608,31 +609,31 @@ public struct NIODeadline: Equatable, Hashable, Sendable { /// - TODO: Investigate optimizing the call to `DispatchTime.now()` away on other platforms too. @inlinable static func timeNow() -> UInt64 { -#if os(Linux) + #if os(Linux) var ts = timespec() clock_gettime(CLOCK_MONOTONIC, &ts) /// We use unsafe arithmetic here because `UInt64.max` nanoseconds is more than 580 years, /// and the odds that this code will still be running 530 years from now is very, very low, /// so as a practical matter this will never overflow. return UInt64(ts.tv_sec) &* 1_000_000_000 &+ UInt64(ts.tv_nsec) -#else // os(Linux) + #else // os(Linux) return DispatchTime.now().uptimeNanoseconds -#endif // os(Linux) + #endif // os(Linux) } @inlinable public static func now() -> NIODeadline { - return NIODeadline.uptimeNanoseconds(timeNow()) + NIODeadline.uptimeNanoseconds(timeNow()) } @inlinable public static func uptimeNanoseconds(_ nanoseconds: UInt64) -> NIODeadline { - return NIODeadline(Int64(min(UInt64(Int64.max), nanoseconds))) + NIODeadline(Int64(min(UInt64(Int64.max), nanoseconds))) } @inlinable public static func == (lhs: NIODeadline, rhs: NIODeadline) -> Bool { - return lhs.uptimeNanoseconds == rhs.uptimeNanoseconds + lhs.uptimeNanoseconds == rhs.uptimeNanoseconds } @inlinable @@ -644,19 +645,19 @@ public struct NIODeadline: Equatable, Hashable, Sendable { extension NIODeadline: Comparable { @inlinable public static func < (lhs: NIODeadline, rhs: NIODeadline) -> Bool { - return lhs.uptimeNanoseconds < rhs.uptimeNanoseconds + lhs.uptimeNanoseconds < rhs.uptimeNanoseconds } @inlinable public static func > (lhs: NIODeadline, rhs: NIODeadline) -> Bool { - return lhs.uptimeNanoseconds > rhs.uptimeNanoseconds + lhs.uptimeNanoseconds > rhs.uptimeNanoseconds } } extension NIODeadline: CustomStringConvertible { @inlinable public var description: String { - return self.uptimeNanoseconds.description + self.uptimeNanoseconds.description } } @@ -665,7 +666,7 @@ extension NIODeadline { public static func - (lhs: NIODeadline, rhs: NIODeadline) -> TimeAmount { // This won't ever crash, NIODeadlines are guaranteed to be within 0 ..< 2^63-1 nanoseconds so the result can // definitely be stored in a TimeAmount (which is an Int64). - return .nanoseconds(Int64(lhs.uptimeNanoseconds) - Int64(rhs.uptimeNanoseconds)) + .nanoseconds(Int64(lhs.uptimeNanoseconds) - Int64(rhs.uptimeNanoseconds)) } @inlinable @@ -674,7 +675,7 @@ extension NIODeadline { let overflow: Bool (partial, overflow) = Int64(lhs.uptimeNanoseconds).addingReportingOverflow(rhs.nanoseconds) if overflow { - assert(rhs.nanoseconds > 0) // this certainly must have overflowed towards +infinity + assert(rhs.nanoseconds > 0) // this certainly must have overflowed towards +infinity return NIODeadline.distantFuture } guard partial >= 0 else { @@ -807,7 +808,7 @@ extension EventLoop { ) -> Scheduled { self._flatScheduleTask(in: delay, file: file, line: line, task) } - + @usableFromInline typealias FlatScheduleTaskDelayCallback = @Sendable () throws -> EventLoopFuture @inlinable @@ -826,8 +827,12 @@ extension EventLoop { /// Creates and returns a new `EventLoopPromise` that will be notified using this `EventLoop` as execution `NIOThread`. @inlinable - public func makePromise(of type: T.Type = T.self, file: StaticString = #fileID, line: UInt = #line) -> EventLoopPromise { - return EventLoopPromise(eventLoop: self, file: file, line: line) + public func makePromise( + of type: T.Type = T.self, + file: StaticString = #fileID, + line: UInt = #line + ) -> EventLoopPromise { + EventLoopPromise(eventLoop: self, file: file, line: line) } /// Creates and returns a new `EventLoopFuture` that is already marked as failed. Notifications will be done using this `EventLoop` as execution `NIOThread`. @@ -837,7 +842,7 @@ extension EventLoop { /// - returns: a failed `EventLoopFuture`. @inlinable public func makeFailedFuture(_ error: Error) -> EventLoopFuture { - return EventLoopFuture(eventLoop: self, error: error) + EventLoopFuture(eventLoop: self, error: error) } /// Creates and returns a new `EventLoopFuture` that is already marked as success. Notifications will be done using this `EventLoop` as execution `NIOThread`. @@ -885,21 +890,20 @@ extension EventLoop { /// /// - returns: Itself, because an `EventLoop` forms a singular `EventLoopGroup`. public func next() -> EventLoop { - return self + self } /// An `EventLoop` forms a singular `EventLoopGroup`, returning itself as 'any' `EventLoop`. /// /// - returns: Itself, because an `EventLoop` forms a singular `EventLoopGroup`. public func any() -> EventLoop { - return self + self } /// Close this `EventLoop`. public func close() throws { // Do nothing } - /// Schedule a repeated task to be executed by the `EventLoop` with a fixed delay between the end and start of each /// task. @@ -920,7 +924,7 @@ extension EventLoop { ) -> RepeatedTask { self._scheduleRepeatedTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) } - + /// Schedule a repeated task to be executed by the `EventLoop` with a fixed delay between the end and start of each /// task. /// @@ -939,9 +943,17 @@ extension EventLoop { notifying promise: EventLoopPromise? = nil, _ task: @escaping @Sendable (RepeatedTask) throws -> Void ) -> RepeatedTask { - let jitteredInitialDelay = Self._getJitteredDelay(delay: initialDelay, maximumAllowableJitter: maximumAllowableJitter) + let jitteredInitialDelay = Self._getJitteredDelay( + delay: initialDelay, + maximumAllowableJitter: maximumAllowableJitter + ) let jitteredDelay = Self._getJitteredDelay(delay: delay, maximumAllowableJitter: maximumAllowableJitter) - return self.scheduleRepeatedTask(initialDelay: jitteredInitialDelay, delay: jitteredDelay, notifying: promise, task) + return self.scheduleRepeatedTask( + initialDelay: jitteredInitialDelay, + delay: jitteredDelay, + notifying: promise, + task + ) } typealias ScheduleRepeatedTaskCallback = @Sendable (RepeatedTask) throws -> Void @@ -961,7 +973,7 @@ extension EventLoop { } return self.scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, notifying: promise, futureTask) } - + /// Schedule a repeated asynchronous task to be executed by the `EventLoop` with a fixed delay between the end and /// start of each task. /// @@ -988,7 +1000,7 @@ extension EventLoop { ) -> RepeatedTask { self._scheduleRepeatedAsyncTask(initialDelay: initialDelay, delay: delay, notifying: promise, task) } - + /// Schedule a repeated asynchronous task to be executed by the `EventLoop` with a fixed delay between the end and /// start of each task. /// @@ -1014,9 +1026,17 @@ extension EventLoop { notifying promise: EventLoopPromise? = nil, _ task: @escaping @Sendable (RepeatedTask) -> EventLoopFuture ) -> RepeatedTask { - let jitteredInitialDelay = Self._getJitteredDelay(delay: initialDelay, maximumAllowableJitter: maximumAllowableJitter) + let jitteredInitialDelay = Self._getJitteredDelay( + delay: initialDelay, + maximumAllowableJitter: maximumAllowableJitter + ) let jitteredDelay = Self._getJitteredDelay(delay: delay, maximumAllowableJitter: maximumAllowableJitter) - return self._scheduleRepeatedAsyncTask(initialDelay: jitteredInitialDelay, delay: jitteredDelay, notifying: promise, task) + return self._scheduleRepeatedAsyncTask( + initialDelay: jitteredInitialDelay, + delay: jitteredDelay, + notifying: promise, + task + ) } typealias ScheduleRepeatedAsyncTaskCallback = @Sendable (RepeatedTask) -> EventLoopFuture @@ -1043,14 +1063,14 @@ extension EventLoop { maximumAllowableJitter: TimeAmount ) -> TimeAmount { let jitter = TimeAmount.nanoseconds(Int64.random(in: .zero.. EventLoopIterator { - return EventLoopIterator([self]) + EventLoopIterator([self]) } /// Asserts that the current thread is the one tied to this `EventLoop`. @@ -1127,7 +1147,7 @@ public protocol EventLoopGroup: AnyObject, _NIOPreconcurrencySendable { /// The rule of thumb is: If you are trying to do _load balancing_, use `next()`. If you just want to create a new /// future or kick off some operation, use `any()`. func any() -> EventLoop - + /// Shuts down the eventloop gracefully. This function is clearly an outlier in that it uses a completion /// callback instead of an EventLoopFuture. The reason for that is that NIO's EventLoopFutures will call back on an event loop. /// The virtue of this function is to shut the event loop down. To work around that we call back on a DispatchQueue @@ -1150,7 +1170,7 @@ extension EventLoopGroup { /// The default implementation of `any()` just returns the `next()` EventLoop but it's highly recommended to /// override this and return the current `EventLoop` if possible. public func any() -> EventLoop { - return self.next() + self.next() } } diff --git a/Sources/NIOCore/EventLoopFuture+AssumeIsolated.swift b/Sources/NIOCore/EventLoopFuture+AssumeIsolated.swift index 339d1d42ec..bd7df99c56 100644 --- a/Sources/NIOCore/EventLoopFuture+AssumeIsolated.swift +++ b/Sources/NIOCore/EventLoopFuture+AssumeIsolated.swift @@ -120,7 +120,7 @@ struct IsolatedEventLoop { /// Returns the wrapped event loop. @inlinable func nonisolated() -> any EventLoop { - return self._wrapped + self._wrapped } } extension EventLoop { @@ -421,8 +421,8 @@ extension EventLoopFuture { @inlinable func unwrap( orReplace replacement: NewValue - ) -> EventLoopFuture.Isolated where Value == Optional { - return self.map { (value) -> NewValue in + ) -> EventLoopFuture.Isolated where Value == NewValue? { + self.map { (value) -> NewValue in guard let value = value else { return replacement } @@ -447,8 +447,8 @@ extension EventLoopFuture { @inlinable func unwrap( orElse callback: @escaping () -> NewValue - ) -> EventLoopFuture.Isolated where Value == Optional { - return self.map { (value) -> NewValue in + ) -> EventLoopFuture.Isolated where Value == NewValue? { + self.map { (value) -> NewValue in guard let value = value else { return callback() } @@ -459,7 +459,7 @@ extension EventLoopFuture { /// Returns the wrapped event loop future. @inlinable func nonisolated() -> EventLoopFuture { - return self._wrapped + self._wrapped } } @@ -471,7 +471,6 @@ extension EventLoopFuture { } } - extension EventLoopPromise { /// A struct wrapping an ``EventLoopPromise`` that ensures all calls to any method on this struct /// are coming from the event loop of the promise. @@ -480,7 +479,6 @@ extension EventLoopPromise { @usableFromInline let _wrapped: EventLoopPromise - /// Deliver a successful result to the associated `EventLoopFuture` object. /// /// - parameters: @@ -514,7 +512,7 @@ extension EventLoopPromise { /// Returns the wrapped event loop promise. @inlinable func nonisolated() -> EventLoopPromise { - return self._wrapped + self._wrapped } } @@ -525,4 +523,3 @@ extension EventLoopPromise { return Isolated(_wrapped: self) } } - diff --git a/Sources/NIOCore/EventLoopFuture+Deprecated.swift b/Sources/NIOCore/EventLoopFuture+Deprecated.swift index 75cfb07162..6883c9ffd9 100644 --- a/Sources/NIOCore/EventLoopFuture+Deprecated.swift +++ b/Sources/NIOCore/EventLoopFuture+Deprecated.swift @@ -15,63 +15,91 @@ extension EventLoopFuture { @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func flatMap(file: StaticString = #fileID, line: UInt = #line, _ callback: @escaping (Value) -> EventLoopFuture) -> EventLoopFuture { - return self.flatMap(callback) + public func flatMap( + file: StaticString = #fileID, + line: UInt = #line, + _ callback: @escaping (Value) -> EventLoopFuture + ) -> EventLoopFuture { + self.flatMap(callback) } @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func flatMapThrowing(file: StaticString = #fileID, - line: UInt = #line, - _ callback: @escaping (Value) throws -> NewValue) -> EventLoopFuture { - return self.flatMapThrowing(callback) + public func flatMapThrowing( + file: StaticString = #fileID, + line: UInt = #line, + _ callback: @escaping (Value) throws -> NewValue + ) -> EventLoopFuture { + self.flatMapThrowing(callback) } @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func flatMapErrorThrowing(file: StaticString = #fileID, line: UInt = #line, _ callback: @escaping (Error) throws -> Value) -> EventLoopFuture { - return self.flatMapErrorThrowing(callback) + public func flatMapErrorThrowing( + file: StaticString = #fileID, + line: UInt = #line, + _ callback: @escaping (Error) throws -> Value + ) -> EventLoopFuture { + self.flatMapErrorThrowing(callback) } @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func map(file: StaticString = #fileID, line: UInt = #line, _ callback: @escaping (Value) -> (NewValue)) -> EventLoopFuture { - return self.map(callback) + public func map( + file: StaticString = #fileID, + line: UInt = #line, + _ callback: @escaping (Value) -> (NewValue) + ) -> EventLoopFuture { + self.map(callback) } @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func flatMapError(file: StaticString = #fileID, line: UInt = #line, _ callback: @escaping (Error) -> EventLoopFuture) -> EventLoopFuture { - return self.flatMapError(callback) + public func flatMapError( + file: StaticString = #fileID, + line: UInt = #line, + _ callback: @escaping (Error) -> EventLoopFuture + ) -> EventLoopFuture { + self.flatMapError(callback) } @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func flatMapResult(file: StaticString = #fileID, - line: UInt = #line, - _ body: @escaping (Value) -> Result) -> EventLoopFuture { - return self.flatMapResult(body) + public func flatMapResult( + file: StaticString = #fileID, + line: UInt = #line, + _ body: @escaping (Value) -> Result + ) -> EventLoopFuture { + self.flatMapResult(body) } @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func recover(file: StaticString = #fileID, line: UInt = #line, _ callback: @escaping (Error) -> Value) -> EventLoopFuture { - return self.recover(callback) + public func recover( + file: StaticString = #fileID, + line: UInt = #line, + _ callback: @escaping (Error) -> Value + ) -> EventLoopFuture { + self.recover(callback) } @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func and(_ other: EventLoopFuture, - file: StaticString = #fileID, - line: UInt = #line) -> EventLoopFuture<(Value, OtherValue)> { - return self.and(other) + public func and( + _ other: EventLoopFuture, + file: StaticString = #fileID, + line: UInt = #line + ) -> EventLoopFuture<(Value, OtherValue)> { + self.and(other) } @inlinable @available(*, deprecated, message: "Please don't pass file:line:, there's no point.") - public func and(value: OtherValue, - file: StaticString = #fileID, - line: UInt = #line) -> EventLoopFuture<(Value, OtherValue)> { - return self.and(value: value) + public func and( + value: OtherValue, + file: StaticString = #fileID, + line: UInt = #line + ) -> EventLoopFuture<(Value, OtherValue)> { + self.and(value: value) } } diff --git a/Sources/NIOCore/EventLoopFuture+WithEventLoop.swift b/Sources/NIOCore/EventLoopFuture+WithEventLoop.swift index 4d105fa6e5..bf76a0e97e 100644 --- a/Sources/NIOCore/EventLoopFuture+WithEventLoop.swift +++ b/Sources/NIOCore/EventLoopFuture+WithEventLoop.swift @@ -41,7 +41,9 @@ extension EventLoopFuture { /// - returns: A future that will receive the eventual value. @inlinable @preconcurrency - public func flatMapWithEventLoop(_ callback: @escaping @Sendable (Value, EventLoop) -> EventLoopFuture) -> EventLoopFuture { + public func flatMapWithEventLoop( + _ callback: @escaping @Sendable (Value, EventLoop) -> EventLoopFuture + ) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) self._whenComplete { [eventLoop = self.eventLoop] in switch self._value! { @@ -61,7 +63,7 @@ extension EventLoopFuture { } return next.futureResult } - + /// When the current `EventLoopFuture` is in an error state, run the provided callback, which /// may recover from the error by returning an `EventLoopFuture`. The callback is intended to potentially /// recover from the error by returning a new `EventLoopFuture` that will eventually contain the recovered @@ -75,7 +77,9 @@ extension EventLoopFuture { /// - returns: A future that will receive the recovered value. @inlinable @preconcurrency - public func flatMapErrorWithEventLoop(_ callback: @escaping @Sendable (Error, EventLoop) -> EventLoopFuture) -> EventLoopFuture { + public func flatMapErrorWithEventLoop( + _ callback: @escaping @Sendable (Error, EventLoop) -> EventLoopFuture + ) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) self._whenComplete { [eventLoop = self.eventLoop] in switch self._value! { @@ -119,7 +123,8 @@ extension EventLoopFuture { with combiningFunction: @escaping @Sendable (Value, OtherValue, EventLoop) -> EventLoopFuture ) -> EventLoopFuture { func fold0(eventLoop: EventLoop) -> EventLoopFuture { - let body = futures.reduce(self) { (f1: EventLoopFuture, f2: EventLoopFuture) -> EventLoopFuture in + let body = futures.reduce(self) { + (f1: EventLoopFuture, f2: EventLoopFuture) -> EventLoopFuture in let newFuture = f1.and(f2).flatMap { (args: (Value, OtherValue)) -> EventLoopFuture in let (f1Value, f2Value) = args self.eventLoop.assertInEventLoop() diff --git a/Sources/NIOCore/EventLoopFuture.swift b/Sources/NIOCore/EventLoopFuture.swift index 219bbe18d6..65b18c048c 100644 --- a/Sources/NIOCore/EventLoopFuture.swift +++ b/Sources/NIOCore/EventLoopFuture.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -import NIOConcurrencyHelpers import Dispatch +import NIOConcurrencyHelpers /// Internal list of callbacks. /// @@ -158,9 +158,11 @@ public struct EventLoopPromise { @inlinable internal static func makeUnleakablePromise(eventLoop: EventLoop, line: UInt = #line) -> EventLoopPromise { - return EventLoopPromise(eventLoop: eventLoop, - file: "BUG in SwiftNIO (please report), unleakable promise leaked.", - line: line) + EventLoopPromise( + eventLoop: eventLoop, + file: "BUG in SwiftNIO (please report), unleakable promise leaked.", + line: line + ) } /// General initializer @@ -196,7 +198,7 @@ public struct EventLoopPromise { /// /// This method is equivalent to invoking `future.cascade(to: promise)`, /// but sometimes may read better than its cascade counterpart. - /// + /// /// - parameters: /// - future: The future whose value will be used to succeed or fail this promise. /// - seealso: `EventLoopFuture.cascade(to:)` @@ -250,7 +252,7 @@ public struct EventLoopPromise { /// - returns: The callback list to run. @inlinable internal func _setValue(value: Result) -> CallbackList { - return self.futureResult._setValue(value: value) + self.futureResult._setValue(value: value) } } @@ -432,8 +434,8 @@ public final class EventLoopFuture { } extension EventLoopFuture: Equatable { - public static func ==(lhs: EventLoopFuture, rhs: EventLoopFuture) -> Bool { - return lhs === rhs + public static func == (lhs: EventLoopFuture, rhs: EventLoopFuture) -> Bool { + lhs === rhs } } @@ -470,7 +472,9 @@ extension EventLoopFuture { /// - returns: A future that will receive the eventual value. @inlinable @preconcurrency - public func flatMap(_ callback: @escaping @Sendable (Value) -> EventLoopFuture) -> EventLoopFuture { + public func flatMap( + _ callback: @escaping @Sendable (Value) -> EventLoopFuture + ) -> EventLoopFuture { self._flatMap(callback) } @usableFromInline typealias FlatMapCallback = @Sendable (Value) -> EventLoopFuture @@ -513,13 +517,17 @@ extension EventLoopFuture { /// - returns: A future that will receive the eventual value. @inlinable @preconcurrency - public func flatMapThrowing(_ callback: @escaping @Sendable (Value) throws -> NewValue) -> EventLoopFuture { + public func flatMapThrowing( + _ callback: @escaping @Sendable (Value) throws -> NewValue + ) -> EventLoopFuture { self._flatMapThrowing(callback) } @usableFromInline typealias FlatMapThrowingCallback = @Sendable (Value) throws -> NewValue @inlinable - func _flatMapThrowing(_ callback: @escaping FlatMapThrowingCallback) -> EventLoopFuture { + func _flatMapThrowing( + _ callback: @escaping FlatMapThrowingCallback + ) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) self._whenComplete { switch self._value! { @@ -553,7 +561,8 @@ extension EventLoopFuture { /// - returns: A future that will receive the eventual value or a rethrown error. @inlinable @preconcurrency - public func flatMapErrorThrowing(_ callback: @escaping @Sendable (Error) throws -> Value) -> EventLoopFuture { + public func flatMapErrorThrowing(_ callback: @escaping @Sendable (Error) throws -> Value) -> EventLoopFuture + { self._flatMapErrorThrowing(callback) } @usableFromInline typealias FlatMapErrorThrowingCallback = @Sendable (Error) throws -> Value @@ -618,7 +627,7 @@ extension EventLoopFuture { } else { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) self._whenComplete { - return next._setValue(value: self._value!.map(callback)) + next._setValue(value: self._value!.map(callback)) } return next.futureResult } @@ -637,7 +646,9 @@ extension EventLoopFuture { /// - returns: A future that will receive the recovered value. @inlinable @preconcurrency - public func flatMapError(_ callback: @escaping @Sendable (Error) -> EventLoopFuture) -> EventLoopFuture { + public func flatMapError( + _ callback: @escaping @Sendable (Error) -> EventLoopFuture + ) -> EventLoopFuture { self._flatMapError(callback) } @usableFromInline typealias FlatMapErrorCallback = @Sendable (Error) -> EventLoopFuture @@ -679,13 +690,19 @@ extension EventLoopFuture { /// - returns: A future that will receive the eventual value. @inlinable @preconcurrency - public func flatMapResult(_ body: @escaping @Sendable (Value) -> Result) -> EventLoopFuture { + public func flatMapResult( + _ body: @escaping @Sendable (Value) -> Result + ) -> EventLoopFuture { self._flatMapResult(body) } - @usableFromInline typealias FlatMapResultCallback = @Sendable (Value) -> Result + @usableFromInline typealias FlatMapResultCallback = @Sendable (Value) -> Result< + NewValue, SomeError + > @inlinable - func _flatMapResult(_ body: @escaping FlatMapResultCallback) -> EventLoopFuture { + func _flatMapResult( + _ body: @escaping FlatMapResultCallback + ) -> EventLoopFuture { let next = EventLoopPromise.makeUnleakablePromise(eventLoop: self.eventLoop) self._whenComplete { switch self._value! { @@ -750,7 +767,8 @@ extension EventLoopFuture { /// Add a callback. If there's already a value, run as much of the chain as we can. @inlinable - @preconcurrency // TODO: We want to remove @preconcurrency but it results in more allocations in 1000_udpconnections + // TODO: We want to remove @preconcurrency but it results in more allocations in 1000_udpconnections + @preconcurrency internal func _whenComplete(_ callback: @escaping @Sendable () -> CallbackList) { self._internalWhenComplete(callback) } @@ -865,7 +883,7 @@ extension EventLoopFuture { @inlinable public func and(_ other: EventLoopFuture) -> EventLoopFuture<(Value, OtherValue)> { let promise = EventLoopPromise<(Value, OtherValue)>.makeUnleakablePromise(eventLoop: self.eventLoop) - let box: UnsafeMutableTransferBox<(t:Value?, u: OtherValue?)> = .init((nil, nil)) + let box: UnsafeMutableTransferBox<(t: Value?, u: OtherValue?)> = .init((nil, nil)) assert(self.eventLoop === promise.futureResult.eventLoop) self._whenComplete { () -> CallbackList in @@ -905,7 +923,7 @@ extension EventLoopFuture { /// This is just syntactic sugar for `future.and(loop.makeSucceedFuture(value))`. @inlinable public func and(value: OtherValue) -> EventLoopFuture<(Value, OtherValue)> { - return self.and(EventLoopFuture(eventLoop: self.eventLoop, value: value)) + self.and(EventLoopFuture(eventLoop: self.eventLoop, value: value)) } } @@ -994,7 +1012,7 @@ extension EventLoopFuture { @available(*, noasync, message: "wait() can block indefinitely, prefer get()", renamed: "get()") @inlinable public func wait(file: StaticString = #file, line: UInt = #line) throws -> Value { - return try self._wait(file: file, line: line) + try self._wait(file: file, line: line) } @inlinable @@ -1012,7 +1030,7 @@ extension EventLoopFuture { lock.lock(whenValue: 1) lock.unlock() - switch(v.wrappedValue!) { + switch v.wrappedValue! { case .success(let result): return result case .failure(let error): @@ -1056,7 +1074,8 @@ extension EventLoopFuture { with combiningFunction: @escaping FoldCallback ) -> EventLoopFuture { func fold0() -> EventLoopFuture { - let body = futures.reduce(self) { (f1: EventLoopFuture, f2: EventLoopFuture) -> EventLoopFuture in + let body = futures.reduce(self) { + (f1: EventLoopFuture, f2: EventLoopFuture) -> EventLoopFuture in let newFuture = f1.and(f2).flatMap { (args: (Value, OtherValue)) -> EventLoopFuture in let (f1Value, f2Value) = args self.eventLoop.assertInEventLoop() @@ -1106,9 +1125,9 @@ extension EventLoopFuture { @inlinable public static func reduce( _ initialResult: Value, - _ futures: [EventLoopFuture], - on eventLoop: EventLoop, - _ nextPartialResult: @escaping @Sendable (Value, InputValue) -> Value + _ futures: [EventLoopFuture], + on eventLoop: EventLoop, + _ nextPartialResult: @escaping @Sendable (Value, InputValue) -> Value ) -> EventLoopFuture { Self._reduce(initialResult, futures, on: eventLoop, nextPartialResult) } @@ -1203,7 +1222,10 @@ extension EventLoopFuture { /// - on: The `EventLoop` on which the new `EventLoopFuture` callbacks will execute on. /// - Returns: A new `EventLoopFuture` that waits for the other futures to succeed. @inlinable - public static func andAllSucceed(_ futures: [EventLoopFuture], on eventLoop: EventLoop) -> EventLoopFuture { + public static func andAllSucceed( + _ futures: [EventLoopFuture], + on eventLoop: EventLoop + ) -> EventLoopFuture { let promise = eventLoop.makePromise(of: Void.self) EventLoopFuture.andAllSucceed(futures, promise: promise) return promise.futureResult @@ -1238,7 +1260,10 @@ extension EventLoopFuture { /// - futures: An array of homogenous `EventLoopFuture`s to wait on for fulfilled values. /// - on: The `EventLoop` on which the new `EventLoopFuture` callbacks will fire. /// - Returns: A new `EventLoopFuture` with all of the values fulfilled by the provided futures. - public static func whenAllSucceed(_ futures: [EventLoopFuture], on eventLoop: EventLoop) -> EventLoopFuture<[Value]> { + public static func whenAllSucceed( + _ futures: [EventLoopFuture], + on eventLoop: EventLoop + ) -> EventLoopFuture<[Value]> { let promise = eventLoop.makePromise(of: [Value].self) EventLoopFuture.whenAllSucceed(futures, promise: promise) return promise.futureResult @@ -1272,7 +1297,7 @@ extension EventLoopFuture { reduced.futureResult.whenComplete { result in switch result { case .success: - // verify that all operations have been completed + // verify that all operations have been completed assert(!results.wrappedValue.contains(where: { $0 == nil })) promise.succeed(results.wrappedValue.map { $0! }) case .failure(let error): @@ -1321,7 +1346,8 @@ extension EventLoopFuture { // in the "futures" to pass their result to the caller for (index, future) in futures.enumerated() { if future.eventLoop.inEventLoop, - let result = future._value { + let result = future._value + { // Fast-track already-fulfilled results without the overhead of calling `whenComplete`. This can yield a // ~20% performance improvement in the case of large arrays where all elements are already fulfilled. processResult(index, result) @@ -1350,7 +1376,10 @@ extension EventLoopFuture { /// - on: The `EventLoop` on which the new `EventLoopFuture` callbacks will execute on. /// - Returns: A new `EventLoopFuture` that succeeds after all futures complete. @inlinable - public static func andAllComplete(_ futures: [EventLoopFuture], on eventLoop: EventLoop) -> EventLoopFuture { + public static func andAllComplete( + _ futures: [EventLoopFuture], + on eventLoop: EventLoop + ) -> EventLoopFuture { let promise = eventLoop.makePromise(of: Void.self) EventLoopFuture.andAllComplete(futures, promise: promise) return promise.futureResult @@ -1390,8 +1419,10 @@ extension EventLoopFuture { /// - on: The `EventLoop` on which the new `EventLoopFuture` callbacks will fire. /// - Returns: A new `EventLoopFuture` with all the results of the provided futures. @inlinable - public static func whenAllComplete(_ futures: [EventLoopFuture], - on eventLoop: EventLoop) -> EventLoopFuture<[Result]> { + public static func whenAllComplete( + _ futures: [EventLoopFuture], + on eventLoop: EventLoop + ) -> EventLoopFuture<[Result]> { let promise = eventLoop.makePromise(of: [Result].self) EventLoopFuture.whenAllComplete(futures, promise: promise) return promise.futureResult @@ -1405,12 +1436,16 @@ extension EventLoopFuture { /// - futures: An array of homogenous `EventLoopFuture`s to gather results from. /// - promise: The `EventLoopPromise` to complete with the result of the futures. @inlinable - public static func whenAllComplete(_ futures: [EventLoopFuture], - promise: EventLoopPromise<[Result]>) { + public static func whenAllComplete( + _ futures: [EventLoopFuture], + promise: EventLoopPromise<[Result]> + ) { let eventLoop = promise.futureResult.eventLoop let reduced = eventLoop.makePromise(of: Void.self) - let results: UnsafeMutableTransferBox<[Result]> = .init(.init(repeating: .failure(OperationPlaceholderError()), count: futures.count)) + let results: UnsafeMutableTransferBox<[Result]> = .init( + .init(repeating: .failure(OperationPlaceholderError()), count: futures.count) + ) let callback = { @Sendable (index: Int, result: Result) in results.wrappedValue[index] = result } @@ -1427,10 +1462,12 @@ extension EventLoopFuture { switch result { case .success: // verify that all operations have been completed - assert(!results.wrappedValue.contains(where: { - guard case let .failure(error) = $0 else { return false } - return error is OperationPlaceholderError - })) + assert( + !results.wrappedValue.contains(where: { + guard case let .failure(error) = $0 else { return false } + return error is OperationPlaceholderError + }) + ) promise.succeed(results.wrappedValue) case .failure(let error): @@ -1474,7 +1511,8 @@ extension EventLoopFuture { // in the "futures" to pass their result to the caller for (index, future) in futures.enumerated() { if future.eventLoop.inEventLoop, - let result = future._value { + let result = future._value + { // Fast-track already-fulfilled results without the overhead of calling `whenComplete`. This can yield a // ~30% performance improvement in the case of large arrays where all elements are already fulfilled. processResult(index, result) @@ -1519,7 +1557,7 @@ extension EventLoopFuture { /// `EventLoopFuture` has any result. /// /// - parameters: - /// - callback: the callback that is called when the `EventLoopFuture` is fulfilled. + /// - callback: the callback that is called when the `EventLoopFuture` is fulfilled. /// - returns: the current `EventLoopFuture` @inlinable @preconcurrency @@ -1557,8 +1595,8 @@ extension EventLoopFuture { /// future. /// - throws: the `Error` passed in the `orError` parameter when the resolved future's value is `Optional.none`. @inlinable - public func unwrap(orError error: Error) -> EventLoopFuture where Value == Optional { - return self.flatMapThrowing { (value) throws -> NewValue in + public func unwrap(orError error: Error) -> EventLoopFuture where Value == NewValue? { + self.flatMapThrowing { (value) throws -> NewValue in guard let value = value else { throw error } @@ -1578,12 +1616,13 @@ extension EventLoopFuture { /// - orReplace: the value of the returned `EventLoopFuture` when then resolved future's value is `Optional.some()`. /// - returns: an new `EventLoopFuture` with new type parameter `NewValue` and the value passed in the `orReplace` parameter. @inlinable - public func unwrap(orReplace replacement: NewValue) -> EventLoopFuture where Value == Optional { - return self.map { (value) -> NewValue in + public func unwrap(orReplace replacement: NewValue) -> EventLoopFuture + where Value == NewValue? { + self.map { (value) -> NewValue in guard let value = value else { return replacement } - return value + return value } } @@ -1605,7 +1644,7 @@ extension EventLoopFuture { @preconcurrency public func unwrap( orElse callback: @escaping @Sendable () -> NewValue - ) -> EventLoopFuture where Value == Optional { + ) -> EventLoopFuture where Value == NewValue? { self._unwrap(orElse: callback) } @usableFromInline typealias UnwrapCallback = @Sendable () -> NewValue @@ -1613,8 +1652,8 @@ extension EventLoopFuture { @inlinable func _unwrap( orElse callback: @escaping UnwrapCallback - ) -> EventLoopFuture where Value == Optional { - return self.map { (value) -> NewValue in + ) -> EventLoopFuture where Value == NewValue? { + self.map { (value) -> NewValue in guard let value = value else { return callback() } @@ -1623,7 +1662,7 @@ extension EventLoopFuture { } } -// MARK: may block +// MARK: may block extension EventLoopFuture { /// Chain an `EventLoopFuture` providing the result of a IO / task that may block. For example: @@ -1651,7 +1690,7 @@ extension EventLoopFuture { onto queue: DispatchQueue, _ callbackMayBlock: @escaping FlatMapBlockingCallback ) -> EventLoopFuture { - return self.flatMap { result in + self.flatMap { result in queue.asyncWithFuture(eventLoop: self.eventLoop) { try callbackMayBlock(result) } } } @@ -1687,7 +1726,8 @@ extension EventLoopFuture { /// - callbackMayBlock: The callback that is called with the failed result of the `EventLoopFuture`. @inlinable @preconcurrency - public func whenFailureBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping @Sendable (Error) -> Void) { + public func whenFailureBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping @Sendable (Error) -> Void) + { self._whenFailureBlocking(onto: queue, callbackMayBlock) } @usableFromInline typealias WhenFailureBlockingCallback = @Sendable (Error) -> Void @@ -1707,7 +1747,10 @@ extension EventLoopFuture { /// - callbackMayBlock: The callback that is called when the `EventLoopFuture` is fulfilled. @inlinable @preconcurrency - public func whenCompleteBlocking(onto queue: DispatchQueue, _ callbackMayBlock: @escaping @Sendable (Result) -> Void) { + public func whenCompleteBlocking( + onto queue: DispatchQueue, + _ callbackMayBlock: @escaping @Sendable (Result) -> Void + ) { self._whenCompleteBlocking(onto: queue, callbackMayBlock) } @usableFromInline typealias WhenCompleteBlocking = @Sendable (Result) -> Void @@ -1733,7 +1776,7 @@ extension EventLoopFuture { /// - line: The line this function was called on, for debugging purposes. @inlinable public func assertSuccess(file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { - return self.always { result in + self.always { result in switch result { case .success: () @@ -1752,7 +1795,7 @@ extension EventLoopFuture { /// - line: The line this function was called on, for debugging purposes. @inlinable public func assertFailure(file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { - return self.always { result in + self.always { result in switch result { case .success(let value): assertionFailure("Expected failure, but got success: \(value)", file: file, line: line) @@ -1772,7 +1815,7 @@ extension EventLoopFuture { /// - line: The line this function was called on, for debugging purposes. @inlinable public func preconditionSuccess(file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { - return self.always { result in + self.always { result in switch result { case .success: () @@ -1792,7 +1835,7 @@ extension EventLoopFuture { /// - line: The line this function was called on, for debugging purposes. @inlinable public func preconditionFailure(file: StaticString = #fileID, line: UInt = #line) -> EventLoopFuture { - return self.always { result in + self.always { result in switch result { case .success(let value): Swift.preconditionFailure("Expected failure, but got success: \(value)", file: file, line: line) @@ -1821,17 +1864,17 @@ public struct _NIOEventLoopFutureIdentifier: Hashable, Sendable { // 1. 0xbf15ca5d is randomly picked such that it fits into both 32 and 64 bit address spaces // 2. XOR with 0xbf15ca5d so that Memory Graph Debugger and other memory debugging tools // won't see it as a reference. - return UInt(bitPattern: ObjectIdentifier(future)) ^ 0xbf15ca5d + UInt(bitPattern: ObjectIdentifier(future)) ^ 0xbf15_ca5d } } // EventLoopPromise is a reference type, but by its very nature is Sendable (if its Value is). -extension EventLoopPromise: Sendable where Value: Sendable { } +extension EventLoopPromise: Sendable where Value: Sendable {} // EventLoopFuture is a reference type, but it is Sendable (if its Value is). However, we enforce // that by way of the guarantees of the EventLoop protocol, so the compiler cannot // check it. -extension EventLoopFuture: @unchecked Sendable where Value: Sendable { } +extension EventLoopFuture: @unchecked Sendable where Value: Sendable {} extension EventLoopPromise where Value == Void { // Deliver a successful result to the associated `EventLoopFuture` object. @@ -1849,7 +1892,8 @@ extension Optional { /// to `promise`. /// /// - Parameter promise: The promise to set or cascade to. - public mutating func setOrCascade(to promise: EventLoopPromise?) where Wrapped == EventLoopPromise { + public mutating func setOrCascade(to promise: EventLoopPromise?) + where Wrapped == EventLoopPromise { guard let promise = promise else { return } switch self { diff --git a/Sources/NIOCore/FileDescriptor.swift b/Sources/NIOCore/FileDescriptor.swift index b97d874b10..b9d2c8997b 100644 --- a/Sources/NIOCore/FileDescriptor.swift +++ b/Sources/NIOCore/FileDescriptor.swift @@ -31,4 +31,3 @@ public protocol FileDescriptor { /// Close this `FileDescriptor`. func close() throws } - diff --git a/Sources/NIOCore/FileHandle.swift b/Sources/NIOCore/FileHandle.swift index b9de95175a..f7018dde5a 100644 --- a/Sources/NIOCore/FileHandle.swift +++ b/Sources/NIOCore/FileHandle.swift @@ -51,7 +51,10 @@ public final class NIOFileHandle: FileDescriptor { } deinit { - assert(!self.isOpen, "leaked open NIOFileHandle(descriptor: \(self.descriptor)). Call `close()` to close or `takeDescriptorOwnership()` to take ownership and close by some other means.") + assert( + !self.isOpen, + "leaked open NIOFileHandle(descriptor: \(self.descriptor)). Call `close()` to close or `takeDescriptorOwnership()` to take ownership and close by some other means." + ) } /// Duplicates this `NIOFileHandle`. This means that a new `NIOFileHandle` object with a new underlying file descriptor @@ -61,7 +64,7 @@ public final class NIOFileHandle: FileDescriptor { /// /// - returns: A new `NIOFileHandle` with a fresh underlying file descriptor but shared seek pointer. public func duplicate() throws -> NIOFileHandle { - return try withUnsafeFileDescriptor { fd in + try withUnsafeFileDescriptor { fd in NIOFileHandle(descriptor: try SystemCalls.dup(descriptor: fd)) } } @@ -132,18 +135,18 @@ extension NIOFileHandle { public static let `default` = Flags(posixMode: 0, posixFlags: 0) -#if os(Windows) + #if os(Windows) public static let defaultPermissions = _S_IREAD | _S_IWRITE -#else + #else public static let defaultPermissions = S_IWUSR | S_IRUSR | S_IRGRP | S_IROTH -#endif + #endif /// Allows file creation when opening file for writing. File owner is set to the effective user ID of the process. /// /// - parameters: /// - posixMode: `file mode` applied when file is created. Default permissions are: read and write for fileowner, read for owners group and others. public static func allowFileCreation(posixMode: NIOPOSIXFileMode = defaultPermissions) -> Flags { - return Flags(posixMode: posixMode, posixFlags: O_CREAT) + Flags(posixMode: posixMode, posixFlags: O_CREAT) } /// Allows the specification of POSIX flags (e.g. `O_TRUNC`) and mode (e.g. `S_IWUSR`) @@ -153,7 +156,7 @@ extension NIOFileHandle { /// - mode: The POSIX mode (the third parameter for `open(2)`). /// - returns: A `NIOFileHandle.Mode` equivalent to the given POSIX flags and mode. public static func posix(flags: CInt, mode: NIOPOSIXFileMode) -> Flags { - return Flags(posixMode: mode, posixFlags: flags) + Flags(posixMode: mode, posixFlags: flags) } } @@ -164,11 +167,11 @@ extension NIOFileHandle { /// - mode: Access mode. Default mode is `.read`. /// - flags: Additional POSIX flags. public convenience init(path: String, mode: Mode = .read, flags: Flags = .default) throws { -#if os(Windows) + #if os(Windows) let fl = mode.posixFlags | flags.posixFlags | _O_NOINHERIT -#else + #else let fl = mode.posixFlags | flags.posixFlags | O_CLOEXEC -#endif + #endif let fd = try SystemCalls.open(file: path, oFlag: fl, mode: flags.posixMode) self.init(descriptor: fd) } @@ -186,6 +189,6 @@ extension NIOFileHandle { extension NIOFileHandle: CustomStringConvertible { public var description: String { - return "FileHandle { descriptor: \(self.descriptor) }" + "FileHandle { descriptor: \(self.descriptor) }" } } diff --git a/Sources/NIOCore/FileRegion.swift b/Sources/NIOCore/FileRegion.swift index 72db1b13a7..e2595586ba 100644 --- a/Sources/NIOCore/FileRegion.swift +++ b/Sources/NIOCore/FileRegion.swift @@ -23,7 +23,6 @@ import Musl #error("The File Region module was unable to identify your C library.") #endif - /// A `FileRegion` represent a readable portion usually created to be sent over the network. /// /// Usually a `FileRegion` will allow the underlying transport to use `sendfile` to transfer its content and so allows transferring @@ -47,7 +46,7 @@ public struct FileRegion { /// The current reader index of this `FileRegion` private(set) public var readerIndex: Int { get { - return Int(self._readerIndex) + Int(self._readerIndex) } set { self._readerIndex = _UInt56(newValue) @@ -56,7 +55,7 @@ public struct FileRegion { /// The end index of this `FileRegion`. public var endIndex: Int { - return Int(self._endIndex) + Int(self._endIndex) } /// Create a new `FileRegion` from an open `NIOFileHandle`. @@ -75,7 +74,7 @@ public struct FileRegion { /// The number of readable bytes within this FileRegion (taking the `readerIndex` and `endIndex` into account). public var readableBytes: Int { - return endIndex - readerIndex + endIndex - readerIndex } /// Move the readerIndex forward by `offset`. @@ -106,13 +105,13 @@ extension FileRegion { } extension FileRegion: Equatable { - public static func ==(lhs: FileRegion, rhs: FileRegion) -> Bool { - return lhs.fileHandle === rhs.fileHandle && lhs.readerIndex == rhs.readerIndex && lhs.endIndex == rhs.endIndex + public static func == (lhs: FileRegion, rhs: FileRegion) -> Bool { + lhs.fileHandle === rhs.fileHandle && lhs.readerIndex == rhs.readerIndex && lhs.endIndex == rhs.endIndex } } extension FileRegion: CustomStringConvertible { public var description: String { - return "FileRegion { handle: \(self.fileHandle), readerIndex: \(self.readerIndex), endIndex: \(self.endIndex) }" + "FileRegion { handle: \(self.fileHandle), readerIndex: \(self.readerIndex), endIndex: \(self.endIndex) }" } } diff --git a/Sources/NIOCore/GlobalSingletons.swift b/Sources/NIOCore/GlobalSingletons.swift index 79e752a74e..384dfc6d71 100644 --- a/Sources/NIOCore/GlobalSingletons.swift +++ b/Sources/NIOCore/GlobalSingletons.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import Atomics + #if canImport(Darwin) import Darwin #elseif os(Windows) @@ -49,8 +50,10 @@ extension NIOSingletons { } get { - return Self.getTrustworthyThreadCount(rawStorage: globalRawSuggestedLoopCount, - environmentVariable: "NIO_SINGLETON_GROUP_LOOP_COUNT") + Self.getTrustworthyThreadCount( + rawStorage: globalRawSuggestedLoopCount, + environmentVariable: "NIO_SINGLETON_GROUP_LOOP_COUNT" + ) } } @@ -67,8 +70,10 @@ extension NIOSingletons { } get { - return Self.getTrustworthyThreadCount(rawStorage: globalRawSuggestedBlockingThreadCount, - environmentVariable: "NIO_SINGLETON_BLOCKING_POOL_THREAD_COUNT") + Self.getTrustworthyThreadCount( + rawStorage: globalRawSuggestedBlockingThreadCount, + environmentVariable: "NIO_SINGLETON_BLOCKING_POOL_THREAD_COUNT" + ) } } @@ -79,9 +84,11 @@ extension NIOSingletons { /// - note: This value must be set _before_ any singletons are used and must only be set once. public static var singletonsEnabledSuggestion: Bool { get { - let (exchanged, original) = globalRawSingletonsEnabled.compareExchange(expected: 0, - desired: 1, - ordering: .relaxed) + let (exchanged, original) = globalRawSingletonsEnabled.compareExchange( + expected: 0, + desired: 1, + ordering: .relaxed + ) if exchanged { // Never been set, we're committing to the default (enabled). assert(original == 0) @@ -96,15 +103,19 @@ extension NIOSingletons { set { let intRepresentation = newValue ? 1 : -1 - let (exchanged, _) = globalRawSingletonsEnabled.compareExchange(expected: 0, - desired: intRepresentation, - ordering: .relaxed) + let (exchanged, _) = globalRawSingletonsEnabled.compareExchange( + expected: 0, + desired: intRepresentation, + ordering: .relaxed + ) guard exchanged else { - fatalError(""" - Bug in user code: Global singleton enabled suggestion has been changed after \ - user or has been changed more than once. Either is an error, you must set this value very \ - early and only once. - """) + fatalError( + """ + Bug in user code: Global singleton enabled suggestion has been changed after \ + user or has been changed more than once. Either is an error, you must set this value very \ + early and only once. + """ + ) } } } @@ -124,19 +135,25 @@ extension NIOSingletons { // to 5. let (exchanged, _) = rawStorage.compareExchange(expected: 0, desired: -userValue, ordering: .relaxed) guard exchanged else { - fatalError(""" - Bug in user code: Global singleton suggested loop/thread count has been changed after \ - user or has been changed more than once. Either is an error, you must set this value very early \ - and only once. - """) + fatalError( + """ + Bug in user code: Global singleton suggested loop/thread count has been changed after \ + user or has been changed more than once. Either is an error, you must set this value very early \ + and only once. + """ + ) } } private static func validateTrustedThreadCount(_ threadCount: Int) { - assert(threadCount > 0, - "BUG IN NIO, please report: negative suggested loop/thread count: \(threadCount)") - assert(threadCount <= 1024, - "BUG IN NIO, please report: overly big suggested loop/thread count: \(threadCount)") + assert( + threadCount > 0, + "BUG IN NIO, please report: negative suggested loop/thread count: \(threadCount)" + ) + assert( + threadCount <= 1024, + "BUG IN NIO, please report: overly big suggested loop/thread count: \(threadCount)" + ) } private static func getTrustworthyThreadCount(rawStorage: ManagedAtomic, environmentVariable: String) -> Int { @@ -144,15 +161,15 @@ extension NIOSingletons { let rawSuggestion = rawStorage.load(ordering: .relaxed) switch rawSuggestion { - case 0: // == 0 + case 0: // == 0 // Not set by user, not yet finalised, let's try to get it from the env var and fall back to // `System.coreCount`. let envVarString = getenv(environmentVariable).map { String(cString: $0) } returnedValueUnchecked = envVarString.flatMap(Int.init) ?? System.coreCount - case .min ..< 0: // < 0 + case .min..<0: // < 0 // Untrusted and unchecked user value. Let's invert and then sanitise/check. returnedValueUnchecked = -rawSuggestion - case 1 ... .max: // > 0 + case 1 ... .max: // > 0 // Trustworthy value that has been evaluated and sanitised before. let returnValue = rawSuggestion Self.validateTrustedThreadCount(returnValue) @@ -167,9 +184,11 @@ extension NIOSingletons { Self.validateTrustedThreadCount(returnValue) // Store it for next time. - let (exchanged, _) = rawStorage.compareExchange(expected: rawSuggestion, - desired: returnValue, - ordering: .relaxed) + let (exchanged, _) = rawStorage.compareExchange( + expected: rawSuggestion, + desired: returnValue, + ordering: .relaxed + ) if !exchanged { // We lost the race, this must mean it has been concurrently set correctly so we can safely recurse // and try again. diff --git a/Sources/NIOCore/IO.swift b/Sources/NIOCore/IO.swift index 3fdd7a10f9..49eb52e4d4 100644 --- a/Sources/NIOCore/IO.swift +++ b/Sources/NIOCore/IO.swift @@ -26,7 +26,7 @@ import typealias WinSDK.WCHAR import typealias WinSDK.WORD internal func MAKELANGID(_ p: WORD, _ s: WORD) -> DWORD { - return DWORD((s << 10) | p) + DWORD((s << 10) | p) } #elseif canImport(Glibc) import Glibc @@ -49,15 +49,19 @@ public struct IOError: Swift.Error { /// The actual reason (in an human-readable form) for this `IOError`. private var failureDescription: String - @available(*, deprecated, message: "NIO no longer uses FailureDescription, use IOError.description for a human-readable error description") + @available( + *, + deprecated, + message: "NIO no longer uses FailureDescription, use IOError.description for a human-readable error description" + ) public var reason: FailureDescription { - return .reason(self.failureDescription) + .reason(self.failureDescription) } private enum Error { #if os(Windows) - case windows(DWORD) - case winsock(CInt) + case windows(DWORD) + case winsock(CInt) #endif case errno(CInt) } @@ -70,13 +74,13 @@ public struct IOError: Swift.Error { case .errno(let code): return code #if os(Windows) - default: - fatalError("IOError domain is not `errno`") + default: + fatalError("IOError domain is not `errno`") #endif } } -#if os(Windows) + #if os(Windows) public init(windows code: DWORD, reason: String) { self.error = .windows(code) self.failureDescription = reason @@ -86,7 +90,7 @@ public struct IOError: Swift.Error { self.error = .winsock(code) self.failureDescription = reason } -#endif + #endif /// Creates a new `IOError`` /// @@ -126,9 +130,10 @@ private func reasonForError(errnoCode: CInt, reason: String) -> String { #if os(Windows) private func reasonForWinError(_ code: DWORD) -> String { - let dwFlags: DWORD = DWORD(FORMAT_MESSAGE_ALLOCATE_BUFFER) - | DWORD(FORMAT_MESSAGE_FROM_SYSTEM) - | DWORD(FORMAT_MESSAGE_IGNORE_INSERTS) + let dwFlags: DWORD = + DWORD(FORMAT_MESSAGE_ALLOCATE_BUFFER) + | DWORD(FORMAT_MESSAGE_FROM_SYSTEM) + | DWORD(FORMAT_MESSAGE_IGNORE_INSERTS) var buffer: UnsafeMutablePointer? // We use `FORMAT_MESSAGE_ALLOCATE_BUFFER` in flags which means that the @@ -136,9 +141,15 @@ private func reasonForWinError(_ code: DWORD) -> String { // expects a `LPWSTR` and expects the user to type-pun in this case. let dwResult: DWORD = withUnsafeMutablePointer(to: &buffer) { $0.withMemoryRebound(to: WCHAR.self, capacity: 2) { - FormatMessageW(dwFlags, nil, code, - MAKELANGID(WORD(LANG_NEUTRAL), WORD(SUBLANG_DEFAULT)), - $0, 0, nil) + FormatMessageW( + dwFlags, + nil, + code, + MAKELANGID(WORD(LANG_NEUTRAL), WORD(SUBLANG_DEFAULT)), + $0, + 0, + nil + ) } } guard dwResult > 0, let message = buffer else { @@ -151,11 +162,11 @@ private func reasonForWinError(_ code: DWORD) -> String { extension IOError: CustomStringConvertible { public var description: String { - return self.localizedDescription + self.localizedDescription } public var localizedDescription: String { -#if os(Windows) + #if os(Windows) switch self.error { case .errno(let errno): return reasonForError(errnoCode: errno, reason: self.failureDescription) @@ -164,9 +175,9 @@ extension IOError: CustomStringConvertible { case .winsock(let code): return reasonForWinError(DWORD(code)) } -#else + #else return reasonForError(errnoCode: self.errnoCode, reason: self.failureDescription) -#endif + #endif } } @@ -181,7 +192,7 @@ enum CoreIOResult: Equatable { case processed(T) } -internal extension CoreIOResult where T: FixedWidthInteger { +extension CoreIOResult where T: FixedWidthInteger { var result: T { switch self { case .processed(let value): @@ -191,4 +202,3 @@ internal extension CoreIOResult where T: FixedWidthInteger { } } } - diff --git a/Sources/NIOCore/IPProtocol.swift b/Sources/NIOCore/IPProtocol.swift index 1452fb2b4a..a967cf2636 100644 --- a/Sources/NIOCore/IPProtocol.swift +++ b/Sources/NIOCore/IPProtocol.swift @@ -19,7 +19,7 @@ public struct NIOIPProtocol: RawRepresentable, Hashable, Sendable { public typealias RawValue = UInt8 public var rawValue: RawValue - + @inlinable public init(rawValue: RawValue) { self.rawValue = rawValue @@ -169,7 +169,7 @@ extension NIOIPProtocol: CustomStringConvertible { default: return nil } } - + public var description: String { let name = self.name ?? "Unknown Protocol" return "\(name) - \(rawValue)" diff --git a/Sources/NIOCore/IntegerBitPacking.swift b/Sources/NIOCore/IntegerBitPacking.swift index 6eff57f9e2..40b6dd52d6 100644 --- a/Sources/NIOCore/IntegerBitPacking.swift +++ b/Sources/NIOCore/IntegerBitPacking.swift @@ -19,11 +19,15 @@ enum _IntegerBitPacking {} extension _IntegerBitPacking { @inlinable - static func packUU(_ left: Left, - _ right: Right, - type: Result.Type = Result.self) -> Result { + static func packUU< + Left: FixedWidthInteger & UnsignedInteger, + Right: FixedWidthInteger & UnsignedInteger, + Result: FixedWidthInteger & UnsignedInteger + >( + _ left: Left, + _ right: Right, + type: Result.Type = Result.self + ) -> Result { assert(MemoryLayout.size + MemoryLayout.size <= MemoryLayout.size) let resultLeft = Result(left) @@ -34,11 +38,15 @@ extension _IntegerBitPacking { } @inlinable - static func unpackUU(_ input: Input, - leftType: Left.Type = Left.self, - rightType: Right.Type = Right.self) -> (Left, Right) { + static func unpackUU< + Input: FixedWidthInteger & UnsignedInteger, + Left: FixedWidthInteger & UnsignedInteger, + Right: FixedWidthInteger & UnsignedInteger + >( + _ input: Input, + leftType: Left.Type = Left.self, + rightType: Right.Type = Right.self + ) -> (Left, Right) { assert(MemoryLayout.size + MemoryLayout.size <= MemoryLayout.size) let leftMask = Input(Left.max) @@ -57,7 +65,7 @@ enum IntegerBitPacking {} extension IntegerBitPacking { @inlinable static func packUInt32UInt16UInt8(_ left: UInt32, _ middle: UInt16, _ right: UInt8) -> UInt64 { - return _IntegerBitPacking.packUU( + _IntegerBitPacking.packUU( _IntegerBitPacking.packUU(right, middle, type: UInt32.self), left ) @@ -72,27 +80,27 @@ extension IntegerBitPacking { @inlinable static func packUInt8UInt8(_ left: UInt8, _ right: UInt8) -> UInt16 { - return _IntegerBitPacking.packUU(left, right) + _IntegerBitPacking.packUU(left, right) } @inlinable static func unpackUInt8UInt8(_ value: UInt16) -> (UInt8, UInt8) { - return _IntegerBitPacking.unpackUU(value) + _IntegerBitPacking.unpackUU(value) } @inlinable static func packUInt16UInt8(_ left: UInt16, _ right: UInt8) -> UInt32 { - return _IntegerBitPacking.packUU(left, right) + _IntegerBitPacking.packUU(left, right) } @inlinable static func unpackUInt16UInt8(_ value: UInt32) -> (UInt16, UInt8) { - return _IntegerBitPacking.unpackUU(value) + _IntegerBitPacking.unpackUU(value) } @inlinable static func packUInt32CInt(_ left: UInt32, _ right: CInt) -> UInt64 { - return _IntegerBitPacking.packUU(left, UInt32(truncatingIfNeeded: right)) + _IntegerBitPacking.packUU(left, UInt32(truncatingIfNeeded: right)) } @inlinable diff --git a/Sources/NIOCore/IntegerTypes.swift b/Sources/NIOCore/IntegerTypes.swift index 1448ab3877..6de5018b9e 100644 --- a/Sources/NIOCore/IntegerTypes.swift +++ b/Sources/NIOCore/IntegerTypes.swift @@ -50,18 +50,17 @@ extension Int { } } - extension _UInt24: Equatable { @inlinable - public static func ==(lhs: _UInt24, rhs: _UInt24) -> Bool { - return lhs._backing == rhs._backing + public static func == (lhs: _UInt24, rhs: _UInt24) -> Bool { + lhs._backing == rhs._backing } } extension _UInt24: CustomStringConvertible { @usableFromInline var description: String { - return UInt32(self).description + UInt32(self).description } } @@ -77,7 +76,7 @@ struct _UInt56: Sendable { static let bitWidth: Int = 56 - private static let initializeUInt64 : UInt64 = (1 << 56) - 1 + private static let initializeUInt64: UInt64 = (1 << 56) - 1 static let max: _UInt56 = .init(initializeUInt64) static let min: _UInt56 = .init(0) } @@ -90,9 +89,11 @@ extension _UInt56 { extension UInt64 { init(_ value: _UInt56) { - self = IntegerBitPacking.packUInt32UInt16UInt8(value._backing.0, - value._backing.1, - value._backing.2) + self = IntegerBitPacking.packUInt32UInt16UInt8( + value._backing.0, + value._backing.1, + value._backing.2 + ) } } @@ -104,13 +105,13 @@ extension Int { extension _UInt56: Equatable { @inlinable - public static func ==(lhs: _UInt56, rhs: _UInt56) -> Bool { - return lhs._backing == rhs._backing + public static func == (lhs: _UInt56, rhs: _UInt56) -> Bool { + lhs._backing == rhs._backing } } extension _UInt56: CustomStringConvertible { var description: String { - return UInt64(self).description + UInt64(self).description } } diff --git a/Sources/NIOCore/Interfaces.swift b/Sources/NIOCore/Interfaces.swift index c4b6eb001c..cc70fd2ea7 100644 --- a/Sources/NIOCore/Interfaces.swift +++ b/Sources/NIOCore/Interfaces.swift @@ -43,8 +43,8 @@ import typealias WinSDK.UINT8 #endif #if !os(Windows) -private extension ifaddrs { - var dstaddr: UnsafeMutablePointer? { +extension ifaddrs { + fileprivate var dstaddr: UnsafeMutablePointer? { #if os(Linux) || os(Android) return self.ifa_ifu.ifu_dstaddr #elseif canImport(Darwin) @@ -52,7 +52,7 @@ private extension ifaddrs { #endif } - var broadaddr: UnsafeMutablePointer? { + fileprivate var broadaddr: UnsafeMutablePointer? { #if os(Linux) || os(Android) return self.ifa_ifu.ifu_broadaddr #elseif canImport(Darwin) @@ -93,11 +93,15 @@ public final class NIONetworkInterface: Sendable { /// The index of the interface, as provided by `if_nametoindex`. public let interfaceIndex: Int -#if os(Windows) - internal init?(_ pAdapter: UnsafeMutablePointer, - _ pAddress: UnsafeMutablePointer) { - self.name = String(decodingCString: pAdapter.pointee.FriendlyName, - as: UTF16.self) + #if os(Windows) + internal init?( + _ pAdapter: UnsafeMutablePointer, + _ pAddress: UnsafeMutablePointer + ) { + self.name = String( + decodingCString: pAdapter.pointee.FriendlyName, + as: UTF16.self + ) guard let address = pAddress.pointee.Address.lpSockaddr.convert() else { return nil } @@ -121,7 +125,7 @@ public final class NIONetworkInterface: Sendable { self.pointToPointDestinationAddress = nil self.multicastSupported = false } -#else + #else internal init?(_ caddr: ifaddrs) { self.name = String(cString: caddr.ifa_name!) @@ -163,7 +167,7 @@ public final class NIONetworkInterface: Sendable { return nil } } -#endif + #endif } @available(*, deprecated, renamed: "NIONetworkDevice") @@ -177,13 +181,11 @@ extension NIONetworkInterface: CustomDebugStringConvertible { @available(*, deprecated, renamed: "NIONetworkDevice") extension NIONetworkInterface: Equatable { - public static func ==(lhs: NIONetworkInterface, rhs: NIONetworkInterface) -> Bool { - return lhs.name == rhs.name && - lhs.address == rhs.address && - lhs.netmask == rhs.netmask && - lhs.broadcastAddress == rhs.broadcastAddress && - lhs.pointToPointDestinationAddress == rhs.pointToPointDestinationAddress && - lhs.interfaceIndex == rhs.interfaceIndex + public static func == (lhs: NIONetworkInterface, rhs: NIONetworkInterface) -> Bool { + lhs.name == rhs.name && lhs.address == rhs.address && lhs.netmask == rhs.netmask + && lhs.broadcastAddress == rhs.broadcastAddress + && lhs.pointToPointDestinationAddress == rhs.pointToPointDestinationAddress + && lhs.interfaceIndex == rhs.interfaceIndex } } @@ -212,7 +214,7 @@ public struct NIONetworkDevice { /// The name of the network device. public var name: String { get { - return self.backing.name + self.backing.name } set { self.uniquifyIfNeeded() @@ -223,7 +225,7 @@ public struct NIONetworkDevice { /// The address associated with the given network device. public var address: SocketAddress? { get { - return self.backing.address + self.backing.address } set { self.uniquifyIfNeeded() @@ -234,7 +236,7 @@ public struct NIONetworkDevice { /// The netmask associated with this address, if any. public var netmask: SocketAddress? { get { - return self.backing.netmask + self.backing.netmask } set { self.uniquifyIfNeeded() @@ -246,7 +248,7 @@ public struct NIONetworkDevice { /// interfaces do not, especially those that have a `pointToPointDestinationAddress`. public var broadcastAddress: SocketAddress? { get { - return self.backing.broadcastAddress + self.backing.broadcastAddress } set { self.uniquifyIfNeeded() @@ -259,7 +261,7 @@ public struct NIONetworkDevice { /// instead. public var pointToPointDestinationAddress: SocketAddress? { get { - return self.backing.pointToPointDestinationAddress + self.backing.pointToPointDestinationAddress } set { self.uniquifyIfNeeded() @@ -270,7 +272,7 @@ public struct NIONetworkDevice { /// If the Interface supports Multicast public var multicastSupported: Bool { get { - return self.backing.multicastSupported + self.backing.multicastSupported } set { self.uniquifyIfNeeded() @@ -281,7 +283,7 @@ public struct NIONetworkDevice { /// The index of the interface, as provided by `if_nametoindex`. public var interfaceIndex: Int { get { - return self.backing.interfaceIndex + self.backing.interfaceIndex } set { self.uniquifyIfNeeded() @@ -294,15 +296,17 @@ public struct NIONetworkDevice { /// This constructor will fail if NIO does not understand the format of the underlying /// socket address family. This is quite common: for example, Linux will return AF_PACKET /// addressed interfaces on most platforms, which NIO does not currently understand. -#if os(Windows) - internal init?(_ pAdapter: UnsafeMutablePointer, - _ pAddress: UnsafeMutablePointer) { + #if os(Windows) + internal init?( + _ pAdapter: UnsafeMutablePointer, + _ pAddress: UnsafeMutablePointer + ) { guard let backing = Backing(pAdapter, pAddress) else { return nil } self.backing = backing } -#else + #else internal init?(_ caddr: ifaddrs) { guard let backing = Backing(caddr) else { return nil @@ -310,9 +314,9 @@ public struct NIONetworkDevice { self.backing = backing } -#endif + #endif -#if !os(Windows) + #if !os(Windows) /// Convert a `NIONetworkInterface` to a `NIONetworkDevice`. As `NIONetworkDevice`s are a superset of `NIONetworkInterface`s, /// it is always possible to perform this conversion. @available(*, deprecated, message: "This is a compatibility helper, and will be removed in a future release") @@ -327,15 +331,17 @@ public struct NIONetworkDevice { interfaceIndex: interface.interfaceIndex ) } -#endif - - public init(name: String, - address: SocketAddress?, - netmask: SocketAddress?, - broadcastAddress: SocketAddress?, - pointToPointDestinationAddress: SocketAddress, - multicastSupported: Bool, - interfaceIndex: Int) { + #endif + + public init( + name: String, + address: SocketAddress?, + netmask: SocketAddress?, + broadcastAddress: SocketAddress?, + pointToPointDestinationAddress: SocketAddress, + multicastSupported: Bool, + interfaceIndex: Int + ) { self.backing = Backing( name: name, address: address, @@ -387,11 +393,15 @@ extension NIONetworkDevice { /// This constructor will fail if NIO does not understand the format of the underlying /// socket address family. This is quite common: for example, Linux will return AF_PACKET /// addressed interfaces on most platforms, which NIO does not currently understand. -#if os(Windows) - internal init?(_ pAdapter: UnsafeMutablePointer, - _ pAddress: UnsafeMutablePointer) { - self.name = String(decodingCString: pAdapter.pointee.FriendlyName, - as: UTF16.self) + #if os(Windows) + internal init?( + _ pAdapter: UnsafeMutablePointer, + _ pAddress: UnsafeMutablePointer + ) { + self.name = String( + decodingCString: pAdapter.pointee.FriendlyName, + as: UTF16.self + ) self.address = pAddress.pointee.Address.lpSockaddr.convert() switch pAddress.pointee.Address.lpSockaddr.pointee.sa_family { @@ -412,7 +422,7 @@ extension NIONetworkDevice { self.pointToPointDestinationAddress = nil self.multicastSupported = false } -#else + #else internal init?(_ caddr: ifaddrs) { self.name = String(cString: caddr.ifa_name!) self.address = caddr.ifa_addr.flatMap { $0.convert() } @@ -436,7 +446,7 @@ extension NIONetworkDevice { return nil } } -#endif + #endif init(copying original: Backing) { self.name = original.name @@ -448,13 +458,15 @@ extension NIONetworkDevice { self.interfaceIndex = original.interfaceIndex } - init(name: String, - address: SocketAddress?, - netmask: SocketAddress?, - broadcastAddress: SocketAddress?, - pointToPointDestinationAddress: SocketAddress?, - multicastSupported: Bool, - interfaceIndex: Int) { + init( + name: String, + address: SocketAddress?, + netmask: SocketAddress?, + broadcastAddress: SocketAddress?, + pointToPointDestinationAddress: SocketAddress?, + multicastSupported: Bool, + interfaceIndex: Int + ) { self.name = name self.address = address self.netmask = netmask @@ -476,13 +488,11 @@ extension NIONetworkDevice: CustomDebugStringConvertible { // Sadly, as this is class-backed we cannot synthesise the implementation. extension NIONetworkDevice: Equatable { - public static func ==(lhs: NIONetworkDevice, rhs: NIONetworkDevice) -> Bool { - return lhs.name == rhs.name && - lhs.address == rhs.address && - lhs.netmask == rhs.netmask && - lhs.broadcastAddress == rhs.broadcastAddress && - lhs.pointToPointDestinationAddress == rhs.pointToPointDestinationAddress && - lhs.interfaceIndex == rhs.interfaceIndex + public static func == (lhs: NIONetworkDevice, rhs: NIONetworkDevice) -> Bool { + lhs.name == rhs.name && lhs.address == rhs.address && lhs.netmask == rhs.netmask + && lhs.broadcastAddress == rhs.broadcastAddress + && lhs.pointToPointDestinationAddress == rhs.pointToPointDestinationAddress + && lhs.interfaceIndex == rhs.interfaceIndex } } @@ -496,4 +506,3 @@ extension NIONetworkDevice: Hashable { hasher.combine(self.interfaceIndex) } } - diff --git a/Sources/NIOCore/Linux.swift b/Sources/NIOCore/Linux.swift index 73a0e62557..3f944ee726 100644 --- a/Sources/NIOCore/Linux.swift +++ b/Sources/NIOCore/Linux.swift @@ -30,7 +30,7 @@ enum Linux { var buf = ByteBufferAllocator().buffer(capacity: 1024) try buf.writeWithUnsafeMutableBytes(minimumWritableBytes: buf.capacity) { ptr in let res = try fh.withUnsafeFileDescriptor { fd -> CoreIOResult in - return try SystemCalls.read(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) + try SystemCalls.read(descriptor: fd, pointer: ptr.baseAddress!, size: ptr.count) } switch res { case .processed(let n): @@ -62,8 +62,10 @@ enum Linux { /// Get the available core count according to cgroup1 restrictions. /// Round up to the next whole number. - static func coreCountCgroup1Restriction(quota quotaPath: String = Linux.cfsQuotaPath, - period periodPath: String = Linux.cfsPeriodPath) -> Int? { + static func coreCountCgroup1Restriction( + quota quotaPath: String = Linux.cfsQuotaPath, + period periodPath: String = Linux.cfsPeriodPath + ) -> Int? { guard let quota = try? Int(firstLineOfFile(path: quotaPath)), quota > 0 @@ -72,18 +74,18 @@ enum Linux { let period = try? Int(firstLineOfFile(path: periodPath)), period > 0 else { return nil } - return (quota - 1 + period) / period // always round up if fractional CPU quota requested + return (quota - 1 + period) / period // always round up if fractional CPU quota requested } /// Get the available core count according to cgroup2 restrictions. /// Round up to the next whole number. static func coreCountCgroup2Restriction(cpuMaxPath: String = Linux.cfsCpuMaxPath) -> Int? { guard let maxDetails = try? firstLineOfFile(path: cpuMaxPath), - let spaceIndex = maxDetails.firstIndex(of: " "), - let quota = Int(maxDetails[maxDetails.startIndex ..< spaceIndex]), - let period = Int(maxDetails[maxDetails.index(after: spaceIndex) ..< maxDetails.endIndex]) + let spaceIndex = maxDetails.firstIndex(of: " "), + let quota = Int(maxDetails[maxDetails.startIndex..: CustomStringConvertible { @usableFromInline internal var _buffer: CircularBuffer - @usableFromInline internal var _markedIndexOffset: Int? /* nil: nothing marked */ + @usableFromInline internal var _markedIndexOffset: Int? // nil: nothing marked /// Create a new instance. /// @@ -60,24 +60,24 @@ public struct MarkedCircularBuffer: CustomStringConvertible { /// The first element in the buffer. @inlinable public var first: Element? { - return self._buffer.first + self._buffer.first } /// If the buffer is empty. @inlinable public var isEmpty: Bool { - return self._buffer.isEmpty + self._buffer.isEmpty } /// The number of elements in the buffer. @inlinable public var count: Int { - return self._buffer.count + self._buffer.count } @inlinable public var description: String { - return self._buffer.description + self._buffer.description } // MARK: Marking @@ -119,13 +119,13 @@ public struct MarkedCircularBuffer: CustomStringConvertible { /// Returns the marked element. @inlinable public var markedElement: Element? { - return self.markedElementIndex.map { self._buffer[$0] } + self.markedElementIndex.map { self._buffer[$0] } } /// Returns true if the buffer has been marked at all. @inlinable public var hasMark: Bool { - return self._markedIndexOffset != nil + self._markedIndexOffset != nil } } @@ -136,20 +136,20 @@ extension MarkedCircularBuffer: Collection, MutableCollection { @inlinable public func index(after i: Index) -> Index { - return self._buffer.index(after: i) + self._buffer.index(after: i) } @inlinable - public var startIndex: Index { return self._buffer.startIndex } + public var startIndex: Index { self._buffer.startIndex } @inlinable - public var endIndex: Index { return self._buffer.endIndex } + public var endIndex: Index { self._buffer.endIndex } /// Retrieves the element at the given index from the buffer, without removing it. @inlinable public subscript(index: Index) -> Element { get { - return self._buffer[index] + self._buffer[index] } set { self._buffer[index] = newValue @@ -159,7 +159,7 @@ extension MarkedCircularBuffer: Collection, MutableCollection { @inlinable public subscript(bounds: Range) -> SubSequence { get { - return self._buffer[bounds] + self._buffer[bounds] } set { var index = bounds.lowerBound @@ -176,17 +176,17 @@ extension MarkedCircularBuffer: Collection, MutableCollection { extension MarkedCircularBuffer: RandomAccessCollection { @inlinable public func index(_ i: Index, offsetBy distance: Int) -> Index { - return self._buffer.index(i, offsetBy: distance) + self._buffer.index(i, offsetBy: distance) } @inlinable public func distance(from start: Index, to end: Index) -> Int { - return self._buffer.distance(from: start, to: end) + self._buffer.distance(from: start, to: end) } @inlinable public func index(before i: Index) -> Index { - return self._buffer.index(before: i) + self._buffer.index(before: i) } } diff --git a/Sources/NIOCore/MulticastChannel.swift b/Sources/NIOCore/MulticastChannel.swift index 9669617d3c..129fe6f46f 100644 --- a/Sources/NIOCore/MulticastChannel.swift +++ b/Sources/NIOCore/MulticastChannel.swift @@ -25,7 +25,7 @@ public protocol MulticastChannel: Channel { /// `nil` if you are not interested in the result of the operation. func joinGroup(_ group: SocketAddress, promise: EventLoopPromise?) -#if !os(Windows) + #if !os(Windows) /// Request that the `MulticastChannel` join the multicast group given by `group` on the interface /// given by `interface`. /// @@ -36,7 +36,7 @@ public protocol MulticastChannel: Channel { /// `nil` if you are not interested in the result of the operation. @available(*, deprecated, renamed: "joinGroup(_:device:promise:)") func joinGroup(_ group: SocketAddress, interface: NIONetworkInterface?, promise: EventLoopPromise?) -#endif + #endif /// Request that the `MulticastChannel` join the multicast group given by `group` on the device /// given by `device`. @@ -56,7 +56,7 @@ public protocol MulticastChannel: Channel { /// `nil` if you are not interested in the result of the operation. func leaveGroup(_ group: SocketAddress, promise: EventLoopPromise?) -#if !os(Windows) + #if !os(Windows) /// Request that the `MulticastChannel` leave the multicast group given by `group` on the interface /// given by `interface`. /// @@ -67,7 +67,7 @@ public protocol MulticastChannel: Channel { /// `nil` if you are not interested in the result of the operation. @available(*, deprecated, renamed: "leaveGroup(_:device:promise:)") func leaveGroup(_ group: SocketAddress, interface: NIONetworkInterface?, promise: EventLoopPromise?) -#endif + #endif /// Request that the `MulticastChannel` leave the multicast group given by `group` on the device /// given by `device`. @@ -80,7 +80,6 @@ public protocol MulticastChannel: Channel { func leaveGroup(_ group: SocketAddress, device: NIONetworkDevice?, promise: EventLoopPromise?) } - // MARK:- Default implementations for MulticastChannel extension MulticastChannel { public func joinGroup(_ group: SocketAddress, promise: EventLoopPromise?) { @@ -93,14 +92,14 @@ extension MulticastChannel { return promise.futureResult } -#if !os(Windows) + #if !os(Windows) @available(*, deprecated, renamed: "joinGroup(_:device:)") public func joinGroup(_ group: SocketAddress, interface: NIONetworkInterface?) -> EventLoopFuture { let promise = self.eventLoop.makePromise(of: Void.self) self.joinGroup(group, interface: interface, promise: promise) return promise.futureResult } -#endif + #endif public func joinGroup(_ group: SocketAddress, device: NIONetworkDevice?) -> EventLoopFuture { let promise = self.eventLoop.makePromise(of: Void.self) @@ -118,14 +117,14 @@ extension MulticastChannel { return promise.futureResult } -#if !os(Windows) + #if !os(Windows) @available(*, deprecated, renamed: "leaveGroup(_:device:)") public func leaveGroup(_ group: SocketAddress, interface: NIONetworkInterface?) -> EventLoopFuture { let promise = self.eventLoop.makePromise(of: Void.self) self.leaveGroup(group, interface: interface, promise: promise) return promise.futureResult } -#endif + #endif public func leaveGroup(_ group: SocketAddress, device: NIONetworkDevice?) -> EventLoopFuture { let promise = self.eventLoop.makePromise(of: Void.self) @@ -176,4 +175,3 @@ public struct NIOMulticastNotSupportedError: Error { public struct NIOMulticastNotImplementedError: Error { public init() {} } - diff --git a/Sources/NIOCore/NIOAny.swift b/Sources/NIOCore/NIOAny.swift index c489e2b7de..1f687b134b 100644 --- a/Sources/NIOCore/NIOAny.swift +++ b/Sources/NIOCore/NIOAny.swift @@ -44,7 +44,7 @@ /// } public struct NIOAny { @usableFromInline - /* private but _versioned */ let _storage: _NIOAny + let _storage: _NIOAny /// Wrap a value in a `NIOAny`. In most cases you should not create a `NIOAny` directly using this constructor. /// The abstraction that accepts values of type `NIOAny` must also provide a mechanism to do the wrapping. An @@ -98,7 +98,9 @@ public struct NIOAny { if let v = tryAsByteBuffer() { return v } else { - fatalError("tried to decode as type \(ByteBuffer.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)") + fatalError( + "tried to decode as type \(ByteBuffer.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)" + ) } } @@ -122,7 +124,9 @@ public struct NIOAny { if let v = tryAsIOData() { return v } else { - fatalError("tried to decode as type \(IOData.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)") + fatalError( + "tried to decode as type \(IOData.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)" + ) } } @@ -146,7 +150,9 @@ public struct NIOAny { if let v = tryAsFileRegion() { return v } else { - fatalError("tried to decode as type \(FileRegion.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)") + fatalError( + "tried to decode as type \(FileRegion.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)" + ) } } @@ -170,7 +176,9 @@ public struct NIOAny { if let e = tryAsByteEnvelope() { return e } else { - fatalError("tried to decode as type \(AddressedEnvelope.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)") + fatalError( + "tried to decode as type \(AddressedEnvelope.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)" + ) } } @@ -197,7 +205,9 @@ public struct NIOAny { if let v = tryAsOther(type: type) { return v } else { - fatalError("tried to decode as type \(T.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)") + fatalError( + "tried to decode as type \(T.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType) with contents \(self._storage)" + ) } } @@ -262,6 +272,6 @@ extension NIOAny: Sendable {} extension NIOAny: CustomStringConvertible { public var description: String { - return "NIOAny { \(self.asAny()) }" + "NIOAny { \(self.asAny()) }" } } diff --git a/Sources/NIOCore/NIOCloseOnErrorHandler.swift b/Sources/NIOCore/NIOCloseOnErrorHandler.swift index 53e78134ea..1f105545ba 100644 --- a/Sources/NIOCore/NIOCloseOnErrorHandler.swift +++ b/Sources/NIOCore/NIOCloseOnErrorHandler.swift @@ -12,15 +12,14 @@ // //===----------------------------------------------------------------------===// - /// A `ChannelInboundHandler` that closes the channel when an error is caught public final class NIOCloseOnErrorHandler: ChannelInboundHandler, Sendable { public typealias InboundIn = NIOAny - + /// Initialize a `NIOCloseOnErrorHandler` public init() {} - + public func errorCaught(context: ChannelHandlerContext, error: Error) { context.fireErrorCaught(error) context.close(promise: nil) diff --git a/Sources/NIOCore/NIOLoopBound.swift b/Sources/NIOCore/NIOLoopBound.swift index 6ce3c465dd..9b013ddc54 100644 --- a/Sources/NIOCore/NIOLoopBound.swift +++ b/Sources/NIOCore/NIOLoopBound.swift @@ -28,7 +28,7 @@ public struct NIOLoopBound: @unchecked Sendable { public let _eventLoop: EventLoop @usableFromInline - /* private */ var _value: Value + var _value: Value /// Initialise a ``NIOLoopBound`` to `value` with the precondition that the code is running on `eventLoop`. @inlinable @@ -75,7 +75,7 @@ public final class NIOLoopBoundBox: @unchecked Sendable { public let _eventLoop: EventLoop @usableFromInline - /* private */var _value: Value + var _value: Value @inlinable internal init(_value value: Value, uncheckedEventLoop eventLoop: EventLoop) { @@ -96,7 +96,7 @@ public final class NIOLoopBoundBox: @unchecked Sendable { public static func makeEmptyBox( valueType: NonOptionalValue.Type = NonOptionalValue.self, eventLoop: EventLoop - ) -> NIOLoopBoundBox where Optional == Value { + ) -> NIOLoopBoundBox where NonOptionalValue? == Value { // Here, we -- possibly surprisingly -- do not precondition being on the EventLoop. This is okay for a few // reasons: // - We write the `Optional.none` value which we know is _not_ a value of the potentially non-Sendable type @@ -104,7 +104,7 @@ public final class NIOLoopBoundBox: @unchecked Sendable { // - Because of Swift's Definitive Initialisation (DI), we know that we did write `self._value` before `init` // returns. // - The only way to ever write (or read indeed) `self._value` is by proving to be inside the `EventLoop`. - return .init(_value: nil, uncheckedEventLoop: eventLoop) + .init(_value: nil, uncheckedEventLoop: eventLoop) } /// Initialise a ``NIOLoopBoundBox`` by sending a `Sendable` value, validly callable off `eventLoop`. @@ -124,7 +124,7 @@ public final class NIOLoopBoundBox: @unchecked Sendable { // - Because of Swift's Definitive Initialisation (DI), we know that we did write `self._value` before `init` // returns. // - The only way to ever write (or read indeed) `self._value` is by proving to be inside the `EventLoop`. - return .init(_value: value, uncheckedEventLoop: eventLoop) + .init(_value: value, uncheckedEventLoop: eventLoop) } /// Access the `value` with the precondition that the code is running on `eventLoop`. @@ -142,4 +142,3 @@ public final class NIOLoopBoundBox: @unchecked Sendable { } } } - diff --git a/Sources/NIOCore/NIOSendable.swift b/Sources/NIOCore/NIOSendable.swift index 8d7cc66b26..f07ec670dd 100644 --- a/Sources/NIOCore/NIOSendable.swift +++ b/Sources/NIOCore/NIOSendable.swift @@ -27,7 +27,7 @@ public typealias NIOPreconcurrencySendable = _NIOPreconcurrencySendable struct UnsafeTransfer { @usableFromInline var wrappedValue: Wrapped - + @inlinable init(_ wrappedValue: Wrapped) { self.wrappedValue = wrappedValue @@ -46,7 +46,7 @@ extension UnsafeTransfer: Hashable where Wrapped: Hashable {} final class UnsafeMutableTransferBox { @usableFromInline var wrappedValue: Wrapped - + @inlinable init(_ wrappedValue: Wrapped) { self.wrappedValue = wrappedValue @@ -54,4 +54,3 @@ final class UnsafeMutableTransferBox { } extension UnsafeMutableTransferBox: @unchecked Sendable {} - diff --git a/Sources/NIOCore/RecvByteBufferAllocator.swift b/Sources/NIOCore/RecvByteBufferAllocator.swift index 6521fe4ab8..ddce4d757e 100644 --- a/Sources/NIOCore/RecvByteBufferAllocator.swift +++ b/Sources/NIOCore/RecvByteBufferAllocator.swift @@ -31,7 +31,7 @@ public protocol RecvByteBufferAllocator: _NIOPreconcurrencySendable { extension RecvByteBufferAllocator { // Default implementation to maintain API compatability. public func nextBufferSize() -> Int? { - return nil + nil } } @@ -46,17 +46,17 @@ public struct FixedSizeRecvByteBufferAllocator: RecvByteBufferAllocator { public mutating func record(actualReadBytes: Int) -> Bool { // Returns false as we always allocate the same size of buffers. - return false + false } public func buffer(allocator: ByteBufferAllocator) -> ByteBuffer { - return allocator.buffer(capacity: self.capacity) + allocator.buffer(capacity: self.capacity) } } extension FixedSizeRecvByteBufferAllocator { public func nextBufferSize() -> Int? { - return self.capacity + self.capacity } } @@ -91,7 +91,7 @@ public struct AdaptiveRecvByteBufferAllocator: RecvByteBufferAllocator { } public func buffer(allocator: ByteBufferAllocator) -> ByteBuffer { - return allocator.buffer(capacity: self.nextReceiveBufferSize) + allocator.buffer(capacity: self.nextReceiveBufferSize) } public mutating func record(actualReadBytes: Int) -> Bool { @@ -116,8 +116,9 @@ public struct AdaptiveRecvByteBufferAllocator: RecvByteBufferAllocator { } else { self.decreaseNow = true } - } else if actualReadBytes >= self.nextReceiveBufferSize && upperBound <= self.maximum && - self.nextReceiveBufferSize != upperBound { + } else if actualReadBytes >= self.nextReceiveBufferSize && upperBound <= self.maximum + && self.nextReceiveBufferSize != upperBound + { self.nextReceiveBufferSize = upperBound self.decreaseNow = false mayGrow = true @@ -131,6 +132,6 @@ public struct AdaptiveRecvByteBufferAllocator: RecvByteBufferAllocator { extension AdaptiveRecvByteBufferAllocator { public func nextBufferSize() -> Int? { - return self.nextReceiveBufferSize + self.nextReceiveBufferSize } } diff --git a/Sources/NIOCore/SingleStepByteToMessageDecoder.swift b/Sources/NIOCore/SingleStepByteToMessageDecoder.swift index 467b44ab2b..1fed8965ac 100644 --- a/Sources/NIOCore/SingleStepByteToMessageDecoder.swift +++ b/Sources/NIOCore/SingleStepByteToMessageDecoder.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// - /// A simplified version of `ByteToMessageDecoder` that can generate zero or one messages for each invocation of `decode` or `decodeLast`. /// Having `decode` and `decodeLast` return an optional message avoids re-entrancy problems, since the functions relinquish exclusive access /// to the `ByteBuffer` when returning. This allows for greatly simplified processing. @@ -51,7 +50,6 @@ public protocol NIOSingleStepByteToMessageDecoder: ByteToMessageDecoder { mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> InboundOut? } - // MARK: NIOSingleStepByteToMessageDecoder: ByteToMessageDecoder extension NIOSingleStepByteToMessageDecoder { public mutating func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState { @@ -63,7 +61,11 @@ extension NIOSingleStepByteToMessageDecoder { } } - public mutating func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState { + public mutating func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState { if let message = try self.decodeLast(buffer: &buffer, seenEOF: seenEOF) { context.fireChannelRead(Self.wrapInboundOut(message)) return .continue @@ -73,7 +75,6 @@ extension NIOSingleStepByteToMessageDecoder { } } - /// `NIOSingleStepByteToMessageProcessor` uses a `NIOSingleStepByteToMessageDecoder` to produce messages /// from a stream of incoming bytes. It works like `ByteToMessageHandler` but may be used outside of the channel pipeline. This allows /// processing of wrapped protocols in a general way. @@ -238,7 +239,11 @@ public final class NIOSingleStepByteToMessageProcessor Void) throws { + func _decodeLoop( + decodeMode: DecodeMode, + seenEOF: Bool = false, + _ messageReceiver: (Decoder.InboundOut) throws -> Void + ) throws { // we want to call decodeLast once with an empty buffer if we have nothing if decodeMode == .last && (self._buffer == nil || self._buffer!.readableBytes == 0) { var emptyBuffer = self._buffer == nil ? ByteBuffer() : self._buffer! diff --git a/Sources/NIOCore/SocketAddresses.swift b/Sources/NIOCore/SocketAddresses.swift index a0cf8f0ee5..0c52ba7a8e 100644 --- a/Sources/NIOCore/SocketAddresses.swift +++ b/Sources/NIOCore/SocketAddresses.swift @@ -37,10 +37,10 @@ import struct WinSDK.sockaddr_un import typealias WinSDK.u_short -fileprivate typealias in_addr = WinSDK.IN_ADDR -fileprivate typealias in6_addr = WinSDK.IN6_ADDR -fileprivate typealias in_port_t = WinSDK.u_short -fileprivate typealias sa_family_t = WinSDK.ADDRESS_FAMILY +private typealias in_addr = WinSDK.IN_ADDR +private typealias in6_addr = WinSDK.IN6_ADDR +private typealias in_port_t = WinSDK.u_short +private typealias sa_family_t = WinSDK.ADDRESS_FAMILY #elseif canImport(Darwin) import Darwin #elseif os(Linux) || os(FreeBSD) || os(Android) @@ -70,7 +70,7 @@ extension SocketAddressError { /// Unable to parse a given IP ByteBuffer public struct FailedToParseIPByteBuffer: Error, Hashable { public var address: ByteBuffer - + public init(address: ByteBuffer) { self.address = address } @@ -85,10 +85,10 @@ public enum SocketAddress: CustomStringConvertible, Sendable { private let _storage: Box<(address: sockaddr_in, host: String)> /// The libc socket address for an IPv4 address. - public var address: sockaddr_in { return _storage.value.address } + public var address: sockaddr_in { _storage.value.address } /// The host this address is for, if known. - public var host: String { return _storage.value.host } + public var host: String { _storage.value.host } fileprivate init(address: sockaddr_in, host: String) { self._storage = Box((address: address, host: host)) @@ -100,10 +100,10 @@ public enum SocketAddress: CustomStringConvertible, Sendable { private let _storage: Box<(address: sockaddr_in6, host: String)> /// The libc socket address for an IPv6 address. - public var address: sockaddr_in6 { return _storage.value.address } + public var address: sockaddr_in6 { _storage.value.address } /// The host this address is for, if known. - public var host: String { return _storage.value.host } + public var host: String { _storage.value.host } fileprivate init(address: sockaddr_in6, host: String) { self._storage = Box((address: address, host: host)) @@ -115,7 +115,7 @@ public enum SocketAddress: CustomStringConvertible, Sendable { private let _storage: Box /// The libc socket address for a Unix Domain Socket. - public var address: sockaddr_un { return _storage.value } + public var address: sockaddr_un { _storage.value } fileprivate init(address: sockaddr_un) { self._storage = Box(address) @@ -165,7 +165,7 @@ public enum SocketAddress: CustomStringConvertible, Sendable { @available(*, deprecated, renamed: "SocketAddress.protocol") public var protocolFamily: Int32 { - return Int32(self.protocol.rawValue) + Int32(self.protocol.rawValue) } /// Returns the protocol family as defined in `man 2 socket` of this `SocketAddress`. @@ -228,7 +228,7 @@ public enum SocketAddress: CustomStringConvertible, Sendable { } } } - + /// Get the pathname of a UNIX domain socket as a string public var pathname: String? { switch self { @@ -364,14 +364,14 @@ public enum SocketAddress: CustomStringConvertible, Sendable { addr.sin6_scope_id = 0 return .v6(.init(address: addr, host: "")) } catch { - // If `inet_pton` fails as an IPv6 address (and has failed as an - // IPv4 address above), we will throw an error below. + // If `inet_pton` fails as an IPv6 address (and has failed as an + // IPv4 address above), we will throw an error below. } throw SocketAddressError.failedToParseIPString(ipAddress) } } - + /// Create a new `SocketAddress` for an IP address in ByteBuffer form. /// /// - parameters: @@ -381,7 +381,7 @@ public enum SocketAddress: CustomStringConvertible, Sendable { /// - throws: may throw `SocketAddressError.failedToParseIPByteBuffer` if the IP address cannot be parsed. public init(packedIPAddress: ByteBuffer, port: Int) throws { let packed = packedIPAddress.readableBytesView - + switch packedIPAddress.readableBytes { case 4: var ipv4Addr = sockaddr_in() @@ -411,7 +411,7 @@ public enum SocketAddress: CustomStringConvertible, Sendable { internal init(ipv4MaskForPrefix prefix: Int) { precondition((0...32).contains(prefix)) - let packedAddress = (UInt32(0xFFFFFFFF) << (32 - prefix)).bigEndian + let packedAddress = (UInt32(0xFFFF_FFFF) << (32 - prefix)).bigEndian var ipv4Addr = sockaddr_in() ipv4Addr.sin_family = sa_family_t(AF_INET) ipv4Addr.sin_port = 0 @@ -433,9 +433,9 @@ public enum SocketAddress: CustomStringConvertible, Sendable { // This defends against the possibility of a greater-than-/64 subnet, which would produce a negative shift // operand which is absolutely not what we want. let highShift = min(prefix, 64) - let packedAddressHigh = (UInt64(0xFFFFFFFFFFFFFFFF) << (64 - highShift)).bigEndian + let packedAddressHigh = (UInt64(0xFFFF_FFFF_FFFF_FFFF) << (64 - highShift)).bigEndian - let packedAddressLow = (UInt64(0xFFFFFFFFFFFFFFFF) << (128 - prefix)).bigEndian + let packedAddressLow = (UInt64(0xFFFF_FFFF_FFFF_FFFF) << (128 - prefix)).bigEndian let packedAddress = (packedAddressHigh, packedAddressLow) var ipv6Addr = sockaddr_in6() @@ -455,9 +455,9 @@ public enum SocketAddress: CustomStringConvertible, Sendable { /// - returns: the `SocketAddress` for the host / port pair. /// - throws: a `SocketAddressError.unknown` if we could not resolve the `host`, or `SocketAddressError.unsupported` if the address itself is not supported (yet). public static func makeAddressResolvingHost(_ host: String, port: Int) throws -> SocketAddress { -#if os(Windows) + #if os(Windows) return try host.withCString(encodedAs: UTF16.self) { wszHost in - return try String(port).withCString(encodedAs: UTF16.self) { wszPort in + try String(port).withCString(encodedAs: UTF16.self) { wszPort in var pResult: UnsafeMutablePointer? guard GetAddrInfoW(wszHost, wszPort, nil, &pResult) == 0 else { @@ -482,10 +482,10 @@ public enum SocketAddress: CustomStringConvertible, Sendable { throw SocketAddressError.unsupported } } -#else + #else var info: UnsafeMutablePointer? - /* FIXME: this is blocking! */ + // FIXME: this is blocking! if getaddrinfo(host, String(port), nil, &info) != 0 { throw SocketAddressError.unknown(host: host, port: port) } @@ -507,34 +507,36 @@ public enum SocketAddress: CustomStringConvertible, Sendable { throw SocketAddressError.unsupported } } else { - /* this is odd, getaddrinfo returned NULL */ + // this is odd, getaddrinfo returned NULL throw SocketAddressError.unsupported } -#endif + #endif } } /// We define an extension on `SocketAddress` that gives it an elementwise equatable conformance, using /// only the elements defined on the structure in their man pages (excluding lengths). extension SocketAddress: Equatable { - public static func ==(lhs: SocketAddress, rhs: SocketAddress) -> Bool { + public static func == (lhs: SocketAddress, rhs: SocketAddress) -> Bool { switch (lhs, rhs) { case (.v4(let addr1), .v4(let addr2)): -#if os(Windows) - return addr1.address.sin_family == addr2.address.sin_family && - addr1.address.sin_port == addr2.address.sin_port && - addr1.address.sin_addr.S_un.S_addr == addr2.address.sin_addr.S_un.S_addr -#else - return addr1.address.sin_family == addr2.address.sin_family && - addr1.address.sin_port == addr2.address.sin_port && - addr1.address.sin_addr.s_addr == addr2.address.sin_addr.s_addr -#endif + #if os(Windows) + return addr1.address.sin_family == addr2.address.sin_family + && addr1.address.sin_port == addr2.address.sin_port + && addr1.address.sin_addr.S_un.S_addr == addr2.address.sin_addr.S_un.S_addr + #else + return addr1.address.sin_family == addr2.address.sin_family + && addr1.address.sin_port == addr2.address.sin_port + && addr1.address.sin_addr.s_addr == addr2.address.sin_addr.s_addr + #endif case (.v6(let addr1), .v6(let addr2)): - guard addr1.address.sin6_family == addr2.address.sin6_family && - addr1.address.sin6_port == addr2.address.sin6_port && - addr1.address.sin6_flowinfo == addr2.address.sin6_flowinfo && - addr1.address.sin6_scope_id == addr2.address.sin6_scope_id else { - return false + guard + addr1.address.sin6_family == addr2.address.sin6_family + && addr1.address.sin6_port == addr2.address.sin6_port + && addr1.address.sin6_flowinfo == addr2.address.sin6_flowinfo + && addr1.address.sin6_scope_id == addr2.address.sin6_scope_id + else { + return false } var s6addr1 = addr1.address.sin6_addr @@ -547,14 +549,13 @@ extension SocketAddress: Equatable { let bufferSize = MemoryLayout.size(ofValue: addr1.address.sun_path) - // Swift implicitly binds the memory for homogeneous tuples to both the tuple type and the element type. // This allows us to use assumingMemoryBound(to:) for managing the types. However, we add a static assertion here to validate // that the element type _really is_ what we're assuming it to be. assert(Swift.type(of: addr1.address.sun_path.0) == CChar.self) assert(Swift.type(of: addr2.address.sun_path.0) == CChar.self) return withUnsafePointer(to: addr1.address.sun_path) { sunpath1 in - return withUnsafePointer(to: addr2.address.sun_path) { sunpath2 in + withUnsafePointer(to: addr2.address.sun_path) { sunpath2 in let typedSunpath1 = UnsafeRawPointer(sunpath1).assumingMemoryBound(to: CChar.self) let typedSunpath2 = UnsafeRawPointer(sunpath2).assumingMemoryBound(to: CChar.self) return strncmp(typedSunpath1, typedSunpath2, bufferSize) == 0 @@ -594,11 +595,11 @@ extension SocketAddress: Hashable { hasher.combine(1) hasher.combine(v4Addr.address.sin_family) hasher.combine(v4Addr.address.sin_port) -#if os(Windows) + #if os(Windows) hasher.combine(v4Addr.address.sin_addr.S_un.S_addr) -#else + #else hasher.combine(v4Addr.address.sin_addr.s_addr) -#endif + #endif case .v6(let v6Addr): hasher.combine(2) hasher.combine(v6Addr.address.sin6_family) @@ -612,7 +613,6 @@ extension SocketAddress: Hashable { } } - extension SocketAddress { /// Whether this `SocketAddress` corresponds to a multicast address. public var isMulticast: Bool { @@ -624,15 +624,15 @@ extension SocketAddress { // For IPv4 a multicast address is in the range 224.0.0.0/4. // The easy way to check if this is the case is to just mask off // the address. -#if os(Windows) + #if os(Windows) let v4WireAddress = v4Addr.address.sin_addr.S_un.S_addr let mask = UInt32(0xF000_0000).bigEndian let subnet = UInt32(0xE000_0000).bigEndian -#else + #else let v4WireAddress = v4Addr.address.sin_addr.s_addr let mask = in_addr_t(0xF000_0000 as UInt32).bigEndian let subnet = in_addr_t(0xE000_0000 as UInt32).bigEndian -#endif + #endif return v4WireAddress & mask == subnet case .v6(let v6Addr): // For IPv6 a multicast address is in the range ff00::/8. @@ -649,13 +649,22 @@ protocol SockAddrProtocol { } /// Returns a description for the given address. -internal func descriptionForAddress(family: NIOBSDSocket.AddressFamily, bytes: UnsafeRawPointer, length byteCount: Int) throws -> String { +internal func descriptionForAddress( + family: NIOBSDSocket.AddressFamily, + bytes: UnsafeRawPointer, + length byteCount: Int +) throws -> String { var addressBytes: [Int8] = Array(repeating: 0, count: byteCount) - return try addressBytes.withUnsafeMutableBufferPointer { (addressBytesPtr: inout UnsafeMutableBufferPointer) -> String in - try NIOBSDSocket.inet_ntop(addressFamily: family, addressBytes: bytes, - addressDescription: addressBytesPtr.baseAddress!, - addressDescriptionLength: socklen_t(byteCount)) - return addressBytesPtr.baseAddress!.withMemoryRebound(to: UInt8.self, capacity: byteCount) { addressBytesPtr -> String in + return try addressBytes.withUnsafeMutableBufferPointer { + (addressBytesPtr: inout UnsafeMutableBufferPointer) -> String in + try NIOBSDSocket.inet_ntop( + addressFamily: family, + addressBytes: bytes, + addressDescription: addressBytesPtr.baseAddress!, + addressDescriptionLength: socklen_t(byteCount) + ) + return addressBytesPtr.baseAddress!.withMemoryRebound(to: UInt8.self, capacity: byteCount) { + addressBytesPtr -> String in String(cString: addressBytesPtr) } } @@ -663,14 +672,14 @@ internal func descriptionForAddress(family: NIOBSDSocket.AddressFamily, bytes: U extension sockaddr_in: SockAddrProtocol { func withSockAddr(_ body: (UnsafePointer, Int) throws -> R) rethrows -> R { - return try withUnsafeBytes(of: self) { p in + try withUnsafeBytes(of: self) { p in try body(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) } } /// Returns a description of the `sockaddr_in`. func addressDescription() -> String { - return withUnsafePointer(to: self.sin_addr) { addrPtr in + withUnsafePointer(to: self.sin_addr) { addrPtr in // this uses inet_ntop which is documented to only fail if family is not AF_INET or AF_INET6 (or ENOSPC) try! descriptionForAddress(family: .inet, bytes: addrPtr, length: Int(INET_ADDRSTRLEN)) } @@ -679,14 +688,14 @@ extension sockaddr_in: SockAddrProtocol { extension sockaddr_in6: SockAddrProtocol { func withSockAddr(_ body: (UnsafePointer, Int) throws -> R) rethrows -> R { - return try withUnsafeBytes(of: self) { p in + try withUnsafeBytes(of: self) { p in try body(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) } } /// Returns a description of the `sockaddr_in6`. func addressDescription() -> String { - return withUnsafePointer(to: self.sin6_addr) { addrPtr in + withUnsafePointer(to: self.sin6_addr) { addrPtr in // this uses inet_ntop which is documented to only fail if family is not AF_INET or AF_INET6 (or ENOSPC) try! descriptionForAddress(family: .inet6, bytes: addrPtr, length: Int(INET6_ADDRSTRLEN)) } @@ -695,7 +704,7 @@ extension sockaddr_in6: SockAddrProtocol { extension sockaddr_un: SockAddrProtocol { func withSockAddr(_ body: (UnsafePointer, Int) throws -> R) rethrows -> R { - return try withUnsafeBytes(of: self) { p in + try withUnsafeBytes(of: self) { p in try body(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) } } @@ -703,7 +712,7 @@ extension sockaddr_un: SockAddrProtocol { extension sockaddr_storage: SockAddrProtocol { func withSockAddr(_ body: (UnsafePointer, Int) throws -> R) rethrows -> R { - return try withUnsafeBytes(of: self) { p in + try withUnsafeBytes(of: self) { p in try body(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) } } @@ -714,21 +723,23 @@ extension sockaddr_storage: SockAddrProtocol { // the compiler falls over when we try to access them from test code. As these functions // exist purely to make the behaviours accessible from test code, we name them truly awfully. func __testOnly_addressDescription(_ addr: sockaddr_in) -> String { - return addr.addressDescription() + addr.addressDescription() } func __testOnly_addressDescription(_ addr: sockaddr_in6) -> String { - return addr.addressDescription() + addr.addressDescription() } func __testOnly_withSockAddr( - _ addr: sockaddr_in, _ body: (UnsafePointer, Int) throws -> ReturnType + _ addr: sockaddr_in, + _ body: (UnsafePointer, Int) throws -> ReturnType ) rethrows -> ReturnType { - return try addr.withSockAddr(body) + try addr.withSockAddr(body) } func __testOnly_withSockAddr( - _ addr: sockaddr_in6, _ body: (UnsafePointer, Int) throws -> ReturnType + _ addr: sockaddr_in6, + _ body: (UnsafePointer, Int) throws -> ReturnType ) rethrows -> ReturnType { - return try addr.withSockAddr(body) + try addr.withSockAddr(body) } diff --git a/Sources/NIOCore/SocketOptionProvider.swift b/Sources/NIOCore/SocketOptionProvider.swift index fbb272ac0a..821a6f9fa6 100644 --- a/Sources/NIOCore/SocketOptionProvider.swift +++ b/Sources/NIOCore/SocketOptionProvider.swift @@ -58,20 +58,24 @@ public protocol SocketOptionProvider: _NIOPreconcurrencySendable { var eventLoop: EventLoop { get } #if !os(Windows) - /// Set a socket option for a given level and name to the specified value. - /// - /// This function is not memory-safe: if you set the generic type parameter incorrectly, - /// this function will still execute, and this can cause you to incorrectly interpret memory - /// and thereby read uninitialized or invalid memory. If at all possible, please use one of - /// the safe functions defined by this protocol. - /// - /// - parameters: - /// - level: The socket option level, e.g. `SOL_SOCKET` or `IPPROTO_IP`. - /// - name: The name of the socket option, e.g. `SO_REUSEADDR`. - /// - value: The value to set the socket option to. - /// - returns: An `EventLoopFuture` that fires when the option has been set, - /// or if an error has occurred. - func unsafeSetSocketOption(level: SocketOptionLevel, name: SocketOptionName, value: Value) -> EventLoopFuture + /// Set a socket option for a given level and name to the specified value. + /// + /// This function is not memory-safe: if you set the generic type parameter incorrectly, + /// this function will still execute, and this can cause you to incorrectly interpret memory + /// and thereby read uninitialized or invalid memory. If at all possible, please use one of + /// the safe functions defined by this protocol. + /// + /// - parameters: + /// - level: The socket option level, e.g. `SOL_SOCKET` or `IPPROTO_IP`. + /// - name: The name of the socket option, e.g. `SO_REUSEADDR`. + /// - value: The value to set the socket option to. + /// - returns: An `EventLoopFuture` that fires when the option has been set, + /// or if an error has occurred. + func unsafeSetSocketOption( + level: SocketOptionLevel, + name: SocketOptionName, + value: Value + ) -> EventLoopFuture #endif /// Set a socket option for a given level and name to the specified value. @@ -87,22 +91,26 @@ public protocol SocketOptionProvider: _NIOPreconcurrencySendable { /// - value: The value to set the socket option to. /// - returns: An `EventLoopFuture` that fires when the option has been set, /// or if an error has occurred. - func unsafeSetSocketOption(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: Value) -> EventLoopFuture + func unsafeSetSocketOption( + level: NIOBSDSocket.OptionLevel, + name: NIOBSDSocket.Option, + value: Value + ) -> EventLoopFuture #if !os(Windows) - /// Obtain the value of the socket option for the given level and name. - /// - /// This function is not memory-safe: if you set the generic type parameter incorrectly, - /// this function will still execute, and this can cause you to incorrectly interpret memory - /// and thereby read uninitialized or invalid memory. If at all possible, please use one of - /// the safe functions defined by this protocol. - /// - /// - parameters: - /// - level: The socket option level, e.g. `SOL_SOCKET` or `IPPROTO_IP`. - /// - name: The name of the socket option, e.g. `SO_REUSEADDR`. - /// - returns: An `EventLoopFuture` containing the value of the socket option, or - /// any error that occurred while retrieving the socket option. - func unsafeGetSocketOption(level: SocketOptionLevel, name: SocketOptionName) -> EventLoopFuture + /// Obtain the value of the socket option for the given level and name. + /// + /// This function is not memory-safe: if you set the generic type parameter incorrectly, + /// this function will still execute, and this can cause you to incorrectly interpret memory + /// and thereby read uninitialized or invalid memory. If at all possible, please use one of + /// the safe functions defined by this protocol. + /// + /// - parameters: + /// - level: The socket option level, e.g. `SOL_SOCKET` or `IPPROTO_IP`. + /// - name: The name of the socket option, e.g. `SO_REUSEADDR`. + /// - returns: An `EventLoopFuture` containing the value of the socket option, or + /// any error that occurred while retrieving the socket option. + func unsafeGetSocketOption(level: SocketOptionLevel, name: SocketOptionName) -> EventLoopFuture #endif /// Obtain the value of the socket option for the given level and name. @@ -117,19 +125,33 @@ public protocol SocketOptionProvider: _NIOPreconcurrencySendable { /// - name: The name of the socket option, e.g. `SO_REUSEADDR`. /// - returns: An `EventLoopFuture` containing the value of the socket option, or /// any error that occurred while retrieving the socket option. - func unsafeGetSocketOption(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option) -> EventLoopFuture + func unsafeGetSocketOption( + level: NIOBSDSocket.OptionLevel, + name: NIOBSDSocket.Option + ) -> EventLoopFuture } #if !os(Windows) - extension SocketOptionProvider { - func unsafeSetSocketOption(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: Value) -> EventLoopFuture { - return self.unsafeSetSocketOption(level: SocketOptionLevel(level.rawValue), name: SocketOptionName(name.rawValue), value: value) - } +extension SocketOptionProvider { + func unsafeSetSocketOption( + level: NIOBSDSocket.OptionLevel, + name: NIOBSDSocket.Option, + value: Value + ) -> EventLoopFuture { + self.unsafeSetSocketOption( + level: SocketOptionLevel(level.rawValue), + name: SocketOptionName(name.rawValue), + value: value + ) + } - func unsafeGetSocketOption(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option) -> EventLoopFuture { - return self.unsafeGetSocketOption(level: SocketOptionLevel(level.rawValue), name: SocketOptionName(name.rawValue)) - } + func unsafeGetSocketOption( + level: NIOBSDSocket.OptionLevel, + name: NIOBSDSocket.Option + ) -> EventLoopFuture { + self.unsafeGetSocketOption(level: SocketOptionLevel(level.rawValue), name: SocketOptionName(name.rawValue)) } +} #endif // MARK:- Safe helper methods. @@ -147,7 +169,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` that fires when the option has been set, /// or if an error has occurred. public func setSoLinger(_ value: linger) -> EventLoopFuture { - return self.unsafeSetSocketOption(level: .socket, name: .so_linger, value: value) + self.unsafeSetSocketOption(level: .socket, name: .so_linger, value: value) } /// Gets the value of the socket option SO_LINGER. @@ -155,7 +177,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` containing the value of the socket option, or /// any error that occurred while retrieving the socket option. public func getSoLinger() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .socket, name: .so_linger) + self.unsafeGetSocketOption(level: .socket, name: .so_linger) } /// Sets the socket option IP_MULTICAST_IF to `value`. @@ -165,7 +187,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` that fires when the option has been set, /// or if an error has occurred. public func setIPMulticastIF(_ value: in_addr) -> EventLoopFuture { - return self.unsafeSetSocketOption(level: .ip, name: .ip_multicast_if, value: value) + self.unsafeSetSocketOption(level: .ip, name: .ip_multicast_if, value: value) } /// Gets the value of the socket option IP_MULTICAST_IF. @@ -173,7 +195,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` containing the value of the socket option, or /// any error that occurred while retrieving the socket option. public func getIPMulticastIF() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .ip, name: .ip_multicast_if) + self.unsafeGetSocketOption(level: .ip, name: .ip_multicast_if) } /// Sets the socket option IP_MULTICAST_TTL to `value`. @@ -183,7 +205,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` that fires when the option has been set, /// or if an error has occurred. public func setIPMulticastTTL(_ value: CUnsignedChar) -> EventLoopFuture { - return self.unsafeSetSocketOption(level: .ip, name: .ip_multicast_ttl, value: value) + self.unsafeSetSocketOption(level: .ip, name: .ip_multicast_ttl, value: value) } /// Gets the value of the socket option IP_MULTICAST_TTL. @@ -191,7 +213,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` containing the value of the socket option, or /// any error that occurred while retrieving the socket option. public func getIPMulticastTTL() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .ip, name: .ip_multicast_ttl) + self.unsafeGetSocketOption(level: .ip, name: .ip_multicast_ttl) } /// Sets the socket option IP_MULTICAST_LOOP to `value`. @@ -201,7 +223,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` that fires when the option has been set, /// or if an error has occurred. public func setIPMulticastLoop(_ value: CUnsignedChar) -> EventLoopFuture { - return self.unsafeSetSocketOption(level: .ip, name: .ip_multicast_loop, value: value) + self.unsafeSetSocketOption(level: .ip, name: .ip_multicast_loop, value: value) } /// Gets the value of the socket option IP_MULTICAST_LOOP. @@ -209,7 +231,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` containing the value of the socket option, or /// any error that occurred while retrieving the socket option. public func getIPMulticastLoop() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .ip, name: .ip_multicast_loop) + self.unsafeGetSocketOption(level: .ip, name: .ip_multicast_loop) } /// Sets the socket option IPV6_MULTICAST_IF to `value`. @@ -219,7 +241,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` that fires when the option has been set, /// or if an error has occurred. public func setIPv6MulticastIF(_ value: CUnsignedInt) -> EventLoopFuture { - return self.unsafeSetSocketOption(level: .ipv6, name: .ipv6_multicast_if, value: value) + self.unsafeSetSocketOption(level: .ipv6, name: .ipv6_multicast_if, value: value) } /// Gets the value of the socket option IPV6_MULTICAST_IF. @@ -227,7 +249,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` containing the value of the socket option, or /// any error that occurred while retrieving the socket option. public func getIPv6MulticastIF() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .ipv6, name: .ipv6_multicast_if) + self.unsafeGetSocketOption(level: .ipv6, name: .ipv6_multicast_if) } /// Sets the socket option IPV6_MULTICAST_HOPS to `value`. @@ -237,7 +259,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` that fires when the option has been set, /// or if an error has occurred. public func setIPv6MulticastHops(_ value: CInt) -> EventLoopFuture { - return self.unsafeSetSocketOption(level: .ipv6, name: .ipv6_multicast_hops, value: value) + self.unsafeSetSocketOption(level: .ipv6, name: .ipv6_multicast_hops, value: value) } /// Gets the value of the socket option IPV6_MULTICAST_HOPS. @@ -245,7 +267,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` containing the value of the socket option, or /// any error that occurred while retrieving the socket option. public func getIPv6MulticastHops() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .ipv6, name: .ipv6_multicast_hops) + self.unsafeGetSocketOption(level: .ipv6, name: .ipv6_multicast_hops) } /// Sets the socket option IPV6_MULTICAST_LOOP to `value`. @@ -255,7 +277,7 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` that fires when the option has been set, /// or if an error has occurred. public func setIPv6MulticastLoop(_ value: CUnsignedInt) -> EventLoopFuture { - return self.unsafeSetSocketOption(level: .ipv6, name: .ipv6_multicast_loop, value: value) + self.unsafeSetSocketOption(level: .ipv6, name: .ipv6_multicast_loop, value: value) } /// Gets the value of the socket option IPV6_MULTICAST_LOOP. @@ -263,42 +285,42 @@ extension SocketOptionProvider { /// - returns: An `EventLoopFuture` containing the value of the socket option, or /// any error that occurred while retrieving the socket option. public func getIPv6MulticastLoop() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .ipv6, name: .ipv6_multicast_loop) + self.unsafeGetSocketOption(level: .ipv6, name: .ipv6_multicast_loop) } #if os(Linux) || os(FreeBSD) || os(Android) - /// Gets the value of the socket option TCP_INFO. - /// - /// This socket option cannot be set. - /// - /// - returns: An `EventLoopFuture` containing the value of the socket option, or - /// any error that occurred while retrieving the socket option. - public func getTCPInfo() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .tcp, name: .tcp_info) - } + /// Gets the value of the socket option TCP_INFO. + /// + /// This socket option cannot be set. + /// + /// - returns: An `EventLoopFuture` containing the value of the socket option, or + /// any error that occurred while retrieving the socket option. + public func getTCPInfo() -> EventLoopFuture { + self.unsafeGetSocketOption(level: .tcp, name: .tcp_info) + } #endif #if canImport(Darwin) - /// Gets the value of the socket option TCP_CONNECTION_INFO. - /// - /// This socket option cannot be set. - /// - /// - returns: An `EventLoopFuture` containing the value of the socket option, or - /// any error that occurred while retrieving the socket option. - public func getTCPConnectionInfo() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .tcp, name: .tcp_connection_info) - } + /// Gets the value of the socket option TCP_CONNECTION_INFO. + /// + /// This socket option cannot be set. + /// + /// - returns: An `EventLoopFuture` containing the value of the socket option, or + /// any error that occurred while retrieving the socket option. + public func getTCPConnectionInfo() -> EventLoopFuture { + self.unsafeGetSocketOption(level: .tcp, name: .tcp_connection_info) + } #endif #if os(Linux) - /// Gets the value of the socket option MPTCP_INFO. - /// - /// This socket option cannot be set. - /// - /// - returns: An `EventLoopFuture` containing the value of the socket option, or - /// any error that occurred while retrieving the socket option. - public func getMPTCPInfo() -> EventLoopFuture { - return self.unsafeGetSocketOption(level: .mptcp, name: .mptcp_info) - } + /// Gets the value of the socket option MPTCP_INFO. + /// + /// This socket option cannot be set. + /// + /// - returns: An `EventLoopFuture` containing the value of the socket option, or + /// any error that occurred while retrieving the socket option. + public func getMPTCPInfo() -> EventLoopFuture { + self.unsafeGetSocketOption(level: .mptcp, name: .mptcp_info) + } #endif } diff --git a/Sources/NIOCore/SystemCallHelpers.swift b/Sources/NIOCore/SystemCallHelpers.swift index cbf3b53a0f..47ae2ca8cb 100644 --- a/Sources/NIOCore/SystemCallHelpers.swift +++ b/Sources/NIOCore/SystemCallHelpers.swift @@ -63,31 +63,35 @@ private func isUnacceptableErrno(_ code: Int32) -> Bool { } } -private func preconditionIsNotUnacceptableErrno(err: CInt, where function: String) -> Void { +private func preconditionIsNotUnacceptableErrno(err: CInt, where function: String) { // strerror is documented to return "Unknown error: ..." for illegal value so it won't ever fail - precondition(!isUnacceptableErrno(err), "unacceptable errno \(err) \(String(cString: strerror(err)!)) in \(function))") + precondition( + !isUnacceptableErrno(err), + "unacceptable errno \(err) \(String(cString: strerror(err)!)) in \(function))" + ) } -/* - * Sorry, we really try hard to not use underscored attributes. In this case - * however we seem to break the inlining threshold which makes a system call - * take twice the time, ie. we need this exception. - */ +// Sorry, we really try hard to not use underscored attributes. In this case +// however we seem to break the inlining threshold which makes a system call +// take twice the time, ie. we need this exception. @inline(__always) @discardableResult -internal func syscall(blocking: Bool, - where function: String = #function, - _ body: () throws -> T) - throws -> CoreIOResult { +internal func syscall( + blocking: Bool, + where function: String = #function, + _ body: () throws -> T +) + throws -> CoreIOResult +{ while true { let res = try body() if res == -1 { -#if os(Windows) + #if os(Windows) var err: CInt = 0 ucrt._get_errno(&err) -#else + #else let err = errno -#endif + #endif switch (err, blocking) { case (EINTR, _): continue @@ -106,7 +110,7 @@ enum SystemCalls { @discardableResult @inline(never) internal static func dup(descriptor: CInt) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysDup(descriptor) }.result } @@ -115,12 +119,12 @@ enum SystemCalls { internal static func close(descriptor: CInt) throws { let res = sysClose(descriptor) if res == -1 { -#if os(Windows) + #if os(Windows) var err: CInt = 0 ucrt._get_errno(&err) -#else + #else let err = errno -#endif + #endif // There is really nothing "good" we can do when EINTR was reported on close. // So just ignore it and "assume" everything is fine == we closed the file descriptor. @@ -136,48 +140,59 @@ enum SystemCalls { } @inline(never) - internal static func open(file: UnsafePointer, oFlag: CInt, - mode: NIOPOSIXFileMode) throws -> CInt { -#if os(Windows) + internal static func open( + file: UnsafePointer, + oFlag: CInt, + mode: NIOPOSIXFileMode + ) throws -> CInt { + #if os(Windows) return try syscall(blocking: false) { var fh: CInt = -1 let _ = ucrt._sopen_s(&fh, file, oFlag, _SH_DENYNO, mode) return fh }.result -#else + #else return try syscall(blocking: false) { sysOpenWithMode(file, oFlag, mode) }.result -#endif + #endif } @discardableResult @inline(never) internal static func lseek(descriptor: CInt, offset: off_t, whence: CInt) throws -> off_t { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysLseek(descriptor, offset, whence) }.result } -#if os(Windows) + #if os(Windows) @inline(never) - internal static func read(descriptor: CInt, pointer: UnsafeMutableRawPointer, size: CUnsignedInt) throws -> CoreIOResult { - return try syscall(blocking: true) { + internal static func read( + descriptor: CInt, + pointer: UnsafeMutableRawPointer, + size: CUnsignedInt + ) throws -> CoreIOResult { + try syscall(blocking: true) { sysRead(descriptor, pointer, size) } } -#else + #else @inline(never) - internal static func read(descriptor: CInt, pointer: UnsafeMutableRawPointer, size: size_t) throws -> CoreIOResult { - return try syscall(blocking: true) { + internal static func read( + descriptor: CInt, + pointer: UnsafeMutableRawPointer, + size: size_t + ) throws -> CoreIOResult { + try syscall(blocking: true) { sysRead(descriptor, pointer, size) } } -#endif + #endif @inline(never) internal static func if_nametoindex(_ name: UnsafePointer?) throws -> CUnsignedInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysIfNameToIndex(name!) }.result } diff --git a/Sources/NIOCore/TimeAmount+Duration.swift b/Sources/NIOCore/TimeAmount+Duration.swift index 1dfeecb975..fb3011fe5f 100644 --- a/Sources/NIOCore/TimeAmount+Duration.swift +++ b/Sources/NIOCore/TimeAmount+Duration.swift @@ -33,7 +33,7 @@ extension Swift.Duration { } @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -internal extension Swift.Duration { +extension Swift.Duration { /// The duration represented as nanoseconds, clamped to maximum expressible value. var nanosecondsClamped: Int64 { let components = self.components diff --git a/Sources/NIOCore/TypeAssistedChannelHandler.swift b/Sources/NIOCore/TypeAssistedChannelHandler.swift index ce0a055d90..15dd42c442 100644 --- a/Sources/NIOCore/TypeAssistedChannelHandler.swift +++ b/Sources/NIOCore/TypeAssistedChannelHandler.swift @@ -28,12 +28,12 @@ public protocol _EmittingChannelHandler { extension _EmittingChannelHandler { @inlinable public func wrapOutboundOut(_ value: OutboundOut) -> NIOAny { - return NIOAny(value) + NIOAny(value) } @inlinable public static func wrapOutboundOut(_ value: OutboundOut) -> NIOAny { - return NIOAny(value) + NIOAny(value) } } @@ -60,22 +60,22 @@ public protocol ChannelInboundHandler: _ChannelInboundHandler, _EmittingChannelH extension ChannelInboundHandler { @inlinable public func unwrapInboundIn(_ value: NIOAny) -> InboundIn { - return value.forceAs() + value.forceAs() } @inlinable public func wrapInboundOut(_ value: InboundOut) -> NIOAny { - return NIOAny(value) + NIOAny(value) } @inlinable public static func unwrapInboundIn(_ value: NIOAny) -> InboundIn { - return value.forceAs() + value.forceAs() } @inlinable public static func wrapInboundOut(_ value: InboundOut) -> NIOAny { - return NIOAny(value) + NIOAny(value) } } @@ -95,12 +95,12 @@ public protocol ChannelOutboundHandler: _ChannelOutboundHandler, _EmittingChanne extension ChannelOutboundHandler { @inlinable public func unwrapOutboundIn(_ value: NIOAny) -> OutboundIn { - return value.forceAs() + value.forceAs() } @inlinable public static func unwrapOutboundIn(_ value: NIOAny) -> OutboundIn { - return value.forceAs() + value.forceAs() } } diff --git a/Sources/NIOCore/UniversalBootstrapSupport.swift b/Sources/NIOCore/UniversalBootstrapSupport.swift index 3cc6b5a1b7..16a04f7733 100644 --- a/Sources/NIOCore/UniversalBootstrapSupport.swift +++ b/Sources/NIOCore/UniversalBootstrapSupport.swift @@ -154,21 +154,25 @@ public struct NIOClientTCPBootstrap { /// - parameters: /// - bootstrap: The underlying bootstrap to use. /// - tls: The TLS implementation to use, needs to be compatible with `Bootstrap`. - public init(_ bootstrap: Bootstrap, tls: TLS) where TLS.Bootstrap == Bootstrap { + public init< + Bootstrap: NIOClientTCPBootstrapProtocol, + TLS: NIOClientTLSProvider + >(_ bootstrap: Bootstrap, tls: TLS) where TLS.Bootstrap == Bootstrap { self.underlyingBootstrap = bootstrap self.tlsEnablerTypeErased = { bootstrap in - return tls.enableTLS(bootstrap as! TLS.Bootstrap) + tls.enableTLS(bootstrap as! TLS.Bootstrap) } } - private init(_ bootstrap: NIOClientTCPBootstrapProtocol, - tlsEnabler: @escaping (NIOClientTCPBootstrapProtocol) -> NIOClientTCPBootstrapProtocol) { + private init( + _ bootstrap: NIOClientTCPBootstrapProtocol, + tlsEnabler: @escaping (NIOClientTCPBootstrapProtocol) -> NIOClientTCPBootstrapProtocol + ) { self.underlyingBootstrap = bootstrap self.tlsEnablerTypeErased = tlsEnabler } - internal init(_ original : NIOClientTCPBootstrap, updating underlying : NIOClientTCPBootstrapProtocol) { + internal init(_ original: NIOClientTCPBootstrap, updating underlying: NIOClientTCPBootstrapProtocol) { self.underlyingBootstrap = underlying self.tlsEnablerTypeErased = original.tlsEnablerTypeErased } @@ -190,9 +194,13 @@ public struct NIOClientTCPBootstrap { /// /// - parameters: /// - handler: A closure that initializes the provided `Channel`. - public func channelInitializer(_ handler: @escaping @Sendable (Channel) -> EventLoopFuture) -> NIOClientTCPBootstrap { - return NIOClientTCPBootstrap(self.underlyingBootstrap.channelInitializer(handler), - tlsEnabler: self.tlsEnablerTypeErased) + public func channelInitializer( + _ handler: @escaping @Sendable (Channel) -> EventLoopFuture + ) -> NIOClientTCPBootstrap { + NIOClientTCPBootstrap( + self.underlyingBootstrap.channelInitializer(handler), + tlsEnabler: self.tlsEnablerTypeErased + ) } /// Specifies a `ChannelOption` to be applied to the `SocketChannel`. @@ -201,15 +209,19 @@ public struct NIOClientTCPBootstrap { /// - option: The option to be applied. /// - value: The value for the option. public func channelOption(_ option: Option, value: Option.Value) -> NIOClientTCPBootstrap { - return NIOClientTCPBootstrap(self.underlyingBootstrap.channelOption(option, value: value), - tlsEnabler: self.tlsEnablerTypeErased) + NIOClientTCPBootstrap( + self.underlyingBootstrap.channelOption(option, value: value), + tlsEnabler: self.tlsEnablerTypeErased + ) } /// - parameters: /// - timeout: The timeout that will apply to the connection attempt. public func connectTimeout(_ timeout: TimeAmount) -> NIOClientTCPBootstrap { - return NIOClientTCPBootstrap(self.underlyingBootstrap.connectTimeout(timeout), - tlsEnabler: self.tlsEnablerTypeErased) + NIOClientTCPBootstrap( + self.underlyingBootstrap.connectTimeout(timeout), + tlsEnabler: self.tlsEnablerTypeErased + ) } /// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established. @@ -219,7 +231,7 @@ public struct NIOClientTCPBootstrap { /// - port: The port to connect to. /// - returns: An `EventLoopFuture` to deliver the `Channel` when connected. public func connect(host: String, port: Int) -> EventLoopFuture { - return self.underlyingBootstrap.connect(host: host, port: port) + self.underlyingBootstrap.connect(host: host, port: port) } /// Specify the `address` to connect to for the TCP `Channel` that will be established. @@ -228,7 +240,7 @@ public struct NIOClientTCPBootstrap { /// - address: The address to connect to. /// - returns: An `EventLoopFuture` to deliver the `Channel` when connected. public func connect(to address: SocketAddress) -> EventLoopFuture { - return self.underlyingBootstrap.connect(to: address) + self.underlyingBootstrap.connect(to: address) } /// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established. @@ -237,14 +249,15 @@ public struct NIOClientTCPBootstrap { /// - unixDomainSocketPath: The _Unix domain socket_ path to connect to. /// - returns: An `EventLoopFuture` to deliver the `Channel` when connected. public func connect(unixDomainSocketPath: String) -> EventLoopFuture { - return self.underlyingBootstrap.connect(unixDomainSocketPath: unixDomainSocketPath) + self.underlyingBootstrap.connect(unixDomainSocketPath: unixDomainSocketPath) } - @discardableResult public func enableTLS() -> NIOClientTCPBootstrap { - return NIOClientTCPBootstrap(self.tlsEnablerTypeErased(self.underlyingBootstrap), - tlsEnabler: self.tlsEnablerTypeErased) + NIOClientTCPBootstrap( + self.tlsEnablerTypeErased(self.underlyingBootstrap), + tlsEnabler: self.tlsEnablerTypeErased + ) } } diff --git a/Sources/NIOCore/Utilities.swift b/Sources/NIOCore/Utilities.swift index f00ae18ecc..e192be8f9a 100644 --- a/Sources/NIOCore/Utilities.swift +++ b/Sources/NIOCore/Utilities.swift @@ -48,7 +48,12 @@ import Darwin @inlinable internal func debugOnly(_ body: () -> Void) { // FIXME: duplicated with NIO. - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } /// Allows to "box" another value. @@ -74,15 +79,17 @@ public enum System { /// /// - returns: The logical core count on the system. public static var coreCount: Int { -#if os(Windows) + #if os(Windows) var dwLength: DWORD = 0 _ = GetLogicalProcessorInformation(nil, &dwLength) let alignment: Int = MemoryLayout.alignment let pBuffer: UnsafeMutableRawPointer = - UnsafeMutableRawPointer.allocate(byteCount: Int(dwLength), - alignment: alignment) + UnsafeMutableRawPointer.allocate( + byteCount: Int(dwLength), + alignment: alignment + ) defer { pBuffer.deallocate() } @@ -90,18 +97,22 @@ public enum System { let dwSLPICount: Int = Int(dwLength) / MemoryLayout.stride let pSLPI: UnsafeMutablePointer = - pBuffer.bindMemory(to: SYSTEM_LOGICAL_PROCESSOR_INFORMATION.self, - capacity: dwSLPICount) + pBuffer.bindMemory( + to: SYSTEM_LOGICAL_PROCESSOR_INFORMATION.self, + capacity: dwSLPICount + ) let bResult: Bool = GetLogicalProcessorInformation(pSLPI, &dwLength) precondition(bResult, "GetLogicalProcessorInformation: \(GetLastError())") - return UnsafeBufferPointer(start: pSLPI, - count: dwSLPICount) - .filter { $0.Relationship == RelationProcessorCore } - .map { $0.ProcessorMask.nonzeroBitCount } - .reduce(0, +) -#elseif os(Linux) || os(Android) + return UnsafeBufferPointer( + start: pSLPI, + count: dwSLPICount + ) + .filter { $0.Relationship == RelationProcessorCore } + .map { $0.ProcessorMask.nonzeroBitCount } + .reduce(0, +) + #elseif os(Linux) || os(Android) if let quota2 = Linux.coreCountCgroup2Restriction() { return quota2 } else if let quota = Linux.coreCountCgroup1Restriction() { @@ -111,12 +122,12 @@ public enum System { } else { return sysconf(CInt(_SC_NPROCESSORS_ONLN)) } -#else + #else return sysconf(CInt(_SC_NPROCESSORS_ONLN)) -#endif + #endif } -#if !os(Windows) + #if !os(Windows) /// A utility function that enumerates the available network interfaces on this machine. /// /// This function returns values that are true for a brief snapshot in time. These results can @@ -146,7 +157,7 @@ public enum System { return interfaces } -#endif + #endif /// A utility function that enumerates the available network devices on this machine. /// @@ -160,7 +171,7 @@ public enum System { var devices: [NIONetworkDevice] = [] devices.reserveCapacity(12) // Arbitrary choice. -#if os(Windows) + #if os(Windows) var ulSize: ULONG = 0 _ = GetAdaptersAddresses(ULONG(AF_UNSPEC), 0, nil, nil, &ulSize) @@ -172,8 +183,13 @@ public enum System { } let ulResult: ULONG = - GetAdaptersAddresses(ULONG(AF_UNSPEC), 0, nil, pBuffer.baseAddress, - &ulSize) + GetAdaptersAddresses( + ULONG(AF_UNSPEC), + 0, + nil, + pBuffer.baseAddress, + &ulSize + ) guard ulResult == ERROR_SUCCESS else { throw IOError(windows: ulResult, reason: "GetAdaptersAddresses") } @@ -193,7 +209,7 @@ public enum System { } pAdapter = pAdapter!.pointee.Next } -#else + #else var interface: UnsafeMutablePointer? = nil try SystemCalls.getifaddrs(&interface) let originalInterface = interface @@ -208,7 +224,7 @@ public enum System { interface = concreteInterface.pointee.ifa_next } -#endif + #endif return devices } } diff --git a/Sources/NIOCrashTester/CrashTests+EventLoop.swift b/Sources/NIOCrashTester/CrashTests+EventLoop.swift index d3b0234eec..17b5a56f4b 100644 --- a/Sources/NIOCrashTester/CrashTests+EventLoop.swift +++ b/Sources/NIOCrashTester/CrashTests+EventLoop.swift @@ -17,7 +17,7 @@ import Dispatch import NIOCore import NIOPosix -fileprivate let group = MultiThreadedEventLoopGroup(numberOfThreads: 2) +private let group = MultiThreadedEventLoopGroup(numberOfThreads: 2) struct EventLoopCrashTests { let testMultiThreadedELGCrashesOnZeroThreads = CrashTest( @@ -82,9 +82,9 @@ struct EventLoopCrashTests { exit(2) } func f() { - el.scheduleTask(in: .nanoseconds(0)) { [f /* to make 5.1 compiler not crash */] in + el.scheduleTask(in: .nanoseconds(0)) { [f] in f() - }.futureResult.whenFailure { [f /* to make 5.1 compiler not crash */] error in + }.futureResult.whenFailure { [f] error in guard case .some(.shutdown) = error as? EventLoopError else { exit(3) } @@ -180,7 +180,7 @@ struct EventLoopCrashTests { NIOSingletons.groupLoopCountSuggestion = -1 } - #if compiler(>=5.9) && swift(<5.11) // We only support Concurrency executor take-over on 5.9-5.10, as versions greater than 5.10 have not been properly tested. + #if compiler(>=5.9) && swift(<5.11) // We only support Concurrency executor take-over on 5.9-5.10, as versions greater than 5.10 have not been properly tested. let testInstallingSingletonMTELGAsConcurrencyExecutorWorksButOnlyOnce = CrashTest( regex: #"Fatal error: Must be called only once"# ) { @@ -207,6 +207,6 @@ struct EventLoopCrashTests { // This should crash _ = NIOSingletons.unsafeTryInstallSingletonPosixEventLoopGroupAsConcurrencyGlobalExecutor() } - #endif // compiler(>=5.9) && swift(<5.11) + #endif // compiler(>=5.9) && swift(<5.11) } -#endif // !canImport(Darwin) || os(macOS) +#endif // !canImport(Darwin) || os(macOS) diff --git a/Sources/NIOCrashTester/CrashTests+HTTP.swift b/Sources/NIOCrashTester/CrashTests+HTTP.swift index 5fb87e3184..2773754c23 100644 --- a/Sources/NIOCrashTester/CrashTests+HTTP.swift +++ b/Sources/NIOCrashTester/CrashTests+HTTP.swift @@ -18,28 +18,44 @@ import NIOHTTP1 struct HTTPCrashTests { let testEncodingChunkedAndContentLengthForRequestsCrashes = CrashTest( - regex: "Assertion failed: illegal HTTP sent: HTTPRequestHead .* contains both a content-length and transfer-encoding:chunked", + regex: + "Assertion failed: illegal HTTP sent: HTTPRequestHead .* contains both a content-length and transfer-encoding:chunked", { let channel = EmbeddedChannel(handler: HTTPRequestEncoder()) _ = try? channel.writeAndFlush( HTTPClientRequestPart.head( - HTTPRequestHead(version: .http1_1, - method: .POST, - uri: "/", - headers: ["content-Length": "1", - "transfer-Encoding": "chunked"]))).wait() - }) + HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: [ + "content-Length": "1", + "transfer-Encoding": "chunked", + ] + ) + ) + ).wait() + } + ) let testEncodingChunkedAndContentLengthForResponseCrashes = CrashTest( - regex: "Assertion failed: illegal HTTP sent: HTTPResponseHead .* contains both a content-length and transfer-encoding:chunked", + regex: + "Assertion failed: illegal HTTP sent: HTTPResponseHead .* contains both a content-length and transfer-encoding:chunked", { let channel = EmbeddedChannel(handler: HTTPResponseEncoder()) _ = try? channel.writeAndFlush( HTTPServerResponsePart.head( - HTTPResponseHead(version: .http1_1, - status: .ok, - headers: ["content-Length": "1", - "transfer-Encoding": "chunked"]))).wait() - }) + HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: [ + "content-Length": "1", + "transfer-Encoding": "chunked", + ] + ) + ) + ).wait() + } + ) } #endif diff --git a/Sources/NIOCrashTester/CrashTests+LoopBound.swift b/Sources/NIOCrashTester/CrashTests+LoopBound.swift index c813557e66..69f4c071db 100644 --- a/Sources/NIOCrashTester/CrashTests+LoopBound.swift +++ b/Sources/NIOCrashTester/CrashTests+LoopBound.swift @@ -15,20 +15,20 @@ import NIOCore import NIOPosix -fileprivate let group = MultiThreadedEventLoopGroup(numberOfThreads: 2) +private let group = MultiThreadedEventLoopGroup(numberOfThreads: 2) struct LoopBoundTests { #if !canImport(Darwin) || os(macOS) let testInitChecksEventLoop = CrashTest( regex: "NIOCore/NIOLoopBound.swift:[0-9]+: Precondition failed" ) { - _ = NIOLoopBound(1, eventLoop: group.any()) // BOOM + _ = NIOLoopBound(1, eventLoop: group.any()) // BOOM } let testInitOfBoxChecksEventLoop = CrashTest( regex: "NIOCore/NIOLoopBound.swift:[0-9]+: Precondition failed" ) { - _ = NIOLoopBoundBox(1, eventLoop: group.any()) // BOOM + _ = NIOLoopBoundBox(1, eventLoop: group.any()) // BOOM } let testGetChecksEventLoop = CrashTest( @@ -38,7 +38,7 @@ struct LoopBoundTests { let sendable = try? loop.submit { NIOLoopBound(1, eventLoop: loop) }.wait() - _ = sendable?.value // BOOM + _ = sendable?.value // BOOM } let testGetOfBoxChecksEventLoop = CrashTest( @@ -48,7 +48,7 @@ struct LoopBoundTests { let sendable = try? loop.submit { NIOLoopBoundBox(1, eventLoop: loop) }.wait() - _ = sendable?.value // BOOM + _ = sendable?.value // BOOM } let testSetChecksEventLoop = CrashTest( diff --git a/Sources/NIOCrashTester/CrashTests+System.swift b/Sources/NIOCrashTester/CrashTests+System.swift index e9899c6893..3147ff081a 100644 --- a/Sources/NIOCrashTester/CrashTests+System.swift +++ b/Sources/NIOCrashTester/CrashTests+System.swift @@ -16,12 +16,14 @@ import NIOPosix import Foundation -fileprivate let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) +private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) struct SystemCrashTests { let testEBADFIsUnacceptable = CrashTest( - regex: "Precondition failed: unacceptable errno \(EBADF) Bad file descriptor in", { + regex: "Precondition failed: unacceptable errno \(EBADF) Bad file descriptor in", + { _ = try? NIOPipeBootstrap(group: group).takingOwnershipOfDescriptors(input: .max, output: .max - 1).wait() - }) + } + ) } #endif diff --git a/Sources/NIOCrashTester/OutputGrepper.swift b/Sources/NIOCrashTester/OutputGrepper.swift index 56dfb68a79..7802fe83b9 100644 --- a/Sources/NIOCrashTester/OutputGrepper.swift +++ b/Sources/NIOCrashTester/OutputGrepper.swift @@ -12,8 +12,9 @@ // //===----------------------------------------------------------------------===// import NIOCore -import NIOPosix import NIOFoundationCompat +import NIOPosix + import class Foundation.Pipe internal struct OutputGrepper { @@ -31,8 +32,10 @@ internal struct OutputGrepper { .channelOption(ChannelOptions.allowRemoteHalfClosure, value: true) .channelInitializer { channel in channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandlers([ByteToMessageHandler(NewlineFramer()), - GrepHandler(promise: outputPromise)]) + try channel.pipeline.syncOperations.addHandlers([ + ByteToMessageHandler(NewlineFramer()), + GrepHandler(promise: outputPromise), + ]) } } .takingOwnershipOfDescriptor(input: dup(processToChannel.fileHandleForReading.fileDescriptor)) @@ -40,8 +43,10 @@ internal struct OutputGrepper { processToChannel.fileHandleForReading.closeFile() processToChannel.fileHandleForWriting.closeFile() channelFuture.cascadeFailure(to: outputPromise) - return OutputGrepper(result: outputPromise.futureResult, - processOutputPipe: processOutputPipe) + return OutputGrepper( + result: outputPromise.futureResult, + processOutputPipe: processOutputPipe + ) } } @@ -63,9 +68,9 @@ private final class GrepHandler: ChannelInboundHandler { func channelRead(context: ChannelHandlerContext, data: NIOAny) { let line = Self.unwrapInboundIn(data) - if line.lowercased().contains("fatal error") || - line.lowercased().contains("precondition failed") || - line.lowercased().contains("assertion failed") { + if line.lowercased().contains("fatal error") || line.lowercased().contains("precondition failed") + || line.lowercased().contains("assertion failed") + { self.promise.succeed(line) context.close(promise: nil) } diff --git a/Sources/NIOCrashTester/main.swift b/Sources/NIOCrashTester/main.swift index 9297abb822..6234947aea 100644 --- a/Sources/NIOCrashTester/main.swift +++ b/Sources/NIOCrashTester/main.swift @@ -33,14 +33,14 @@ struct CrashTest { extension Process { var binaryPath: String? { get { - if #available(macOS 10.13, /* Linux */ *) { + if #available(macOS 10.13, *) { return self.executableURL?.path } else { return self.launchPath } } set { - if #available(macOS 10.13, /* Linux */ *) { + if #available(macOS 10.13, *) { self.executableURL = newValue.map { URL(fileURLWithPath: $0) } } else { self.launchPath = newValue @@ -76,14 +76,14 @@ func main() throws { } func allTestsForSuite(_ testSuite: String) -> [(String, CrashTest)] { - return crashTestSuites[testSuite].map { testSuiteObject in + crashTestSuites[testSuite].map { testSuiteObject in Mirror(reflecting: testSuiteObject) .children .filter { $0.label?.starts(with: "test") ?? false } .compactMap { crashTestDescriptor in crashTestDescriptor.label.flatMap { label in (crashTestDescriptor.value as? CrashTest).map { crashTest in - return (label, crashTest) + (label, crashTest) } } } @@ -91,14 +91,16 @@ func main() throws { } func findCrashTest(_ testName: String, suite: String) -> CrashTest? { - return allTestsForSuite(suite) + allTestsForSuite(suite) .first(where: { $0.0 == testName })? .1 } - func interpretOutput(_ result: Result, - regex: String, - runResult: RunResult) throws -> InterpretedRunResult { + func interpretOutput( + _ result: Result, + regex: String, + runResult: RunResult + ) throws -> InterpretedRunResult { struct NoOutputFound: Error {} #if arch(i386) || arch(x86_64) let expectedSignal = SIGILL @@ -107,12 +109,12 @@ func main() throws { #else #error("unknown CPU architecture for which we don't know the expected signal for a crash") #endif - guard case .signal(Int(expectedSignal)) = runResult else { + guard case .signal(Int(expectedSignal)) = runResult else { return .unexpectedRunResult(runResult) } let output = try result.get() - if output.range(of: regex, options: .regularExpression) != nil { + if output.range(of: regex, options: .regularExpression) != nil { return .crashedAsExpected } else { return .regexDidNotMatch(regex: regex, output: output) @@ -163,20 +165,26 @@ func main() throws { let result: Result = Result { try grepper.result.wait() } - return try interpretOutput(result, - regex: crashTest.crashRegex, - runResult: process.terminationReason == .exit ? - .exit(Int(process.terminationStatus)) : - .signal(Int(process.terminationStatus))) + return try interpretOutput( + result, + regex: crashTest.crashRegex, + runResult: process.terminationReason == .exit + ? .exit(Int(process.terminationStatus)) : .signal(Int(process.terminationStatus)) + ) } var failedTests = 0 func runAndEval(_ test: String, suite: String) throws { print("running crash test \(suite).\(test)", terminator: " ") switch try runCrashTest(test, suite: suite, binary: CommandLine.arguments.first!) { - case .regexDidNotMatch(regex: let regex, output: let output): - print("FAILED: regex did not match output", "regex: \(regex)", "output: \(output)", - separator: "\n", terminator: "") + case .regexDidNotMatch(let regex, let output): + print( + "FAILED: regex did not match output", + "regex: \(regex)", + "output: \(output)", + separator: "\n", + terminator: "" + ) failedTests += 1 case .unexpectedRunResult(let runResult): print("FAILED: unexpected run result: \(runResult)") @@ -207,8 +215,9 @@ func main() throws { } case .some("_exec"): if let testSuiteName = CommandLine.arguments.dropFirst(2).first, - let testName = CommandLine.arguments.dropFirst(3).first, - let crashTest = findCrashTest(testName, suite: testSuiteName) { + let testName = CommandLine.arguments.dropFirst(3).first, + let crashTest = findCrashTest(testName, suite: testSuiteName) + { crashTest.runTest() } else { fatalError("can't find/create test for \(Array(CommandLine.arguments.dropFirst(2)))") diff --git a/Sources/NIOEchoClient/main.swift b/Sources/NIOEchoClient/main.swift index ed22302eaf..5e846a1cba 100644 --- a/Sources/NIOEchoClient/main.swift +++ b/Sources/NIOEchoClient/main.swift @@ -22,10 +22,10 @@ private final class EchoHandler: ChannelInboundHandler { public typealias OutboundOut = ByteBuffer private var sendBytes = 0 private var receiveBuffer: ByteBuffer = ByteBuffer() - + public func channelActive(context: ChannelHandlerContext) { print("Client connected to \(context.remoteAddress?.description ?? "unknown")") - + // We are connected. It's time to send the message to the server to initialize the ping-pong sequence. let buffer = context.channel.allocator.buffer(string: line) self.sendBytes = buffer.readableBytes @@ -36,7 +36,7 @@ private final class EchoHandler: ChannelInboundHandler { var unwrappedInboundData = Self.unwrapInboundIn(data) self.sendBytes -= unwrappedInboundData.readableBytes receiveBuffer.writeBuffer(&unwrappedInboundData) - + if self.sendBytes == 0 { let string = String(buffer: receiveBuffer) print("Received: '\(string)' back from the server, closing channel.") @@ -81,19 +81,21 @@ enum ConnectTo { let connectTarget: ConnectTo switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { case (_, .some(let cid), .some(let port)): - /* we got two arguments (Int, Int), let's interpret that as vsock cid and port */ - connectTarget = .vsock(VsockAddress( - cid: VsockAddress.ContextID(cid), - port: VsockAddress.Port(port) - )) + // we got two arguments (Int, Int), let's interpret that as vsock cid and port + connectTarget = .vsock( + VsockAddress( + cid: VsockAddress.ContextID(cid), + port: VsockAddress.Port(port) + ) + ) case (.some(let h), .none, .some(let p)): - /* we got two arguments (String, Int), let's interpret that as host and port */ + // we got two arguments (String, Int), let's interpret that as host and port connectTarget = .ip(host: h, port: p) case (.some(let portString), .none, .none): - /* we got one argument (String), let's interpret that as unix domain socket path */ + // we got one argument (String), let's interpret that as unix domain socket path connectTarget = .unixDomainSocket(path: portString) case (_, .some(let p), _): - /* we got one argument (Int), let's interpret that as port on default host */ + // we got one argument (Int), let's interpret that as port on default host connectTarget = .ip(host: defaultHost, port: p) default: connectTarget = .ip(host: defaultHost, port: defaultPort) diff --git a/Sources/NIOEchoServer/main.swift b/Sources/NIOEchoServer/main.swift index 9547d96e63..111ca580f3 100644 --- a/Sources/NIOEchoServer/main.swift +++ b/Sources/NIOEchoServer/main.swift @@ -77,19 +77,21 @@ enum BindTo { let bindTarget: BindTo switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { case (_, .some(let cid), .some(let port)): - /* we got two arguments (Int, Int), let's interpret that as vsock cid and port */ - bindTarget = .vsock(VsockAddress( - cid: VsockAddress.ContextID(cid), - port: VsockAddress.Port(port) - )) + // we got two arguments (Int, Int), let's interpret that as vsock cid and port + bindTarget = .vsock( + VsockAddress( + cid: VsockAddress.ContextID(cid), + port: VsockAddress.Port(port) + ) + ) case (.some(let h), _, .some(let p)): - /* we got two arguments (String, Int), let's interpret that as host and port */ + // we got two arguments (String, Int), let's interpret that as host and port bindTarget = .ip(host: h, port: p) case (.some(let pathString), .none, .none): - /* we got one argument (String), let's interpret that unix domain socket path */ + // we got one argument (String), let's interpret that unix domain socket path bindTarget = .unixDomainSocket(path: pathString) case (_, .some(let p), .none): - /* we got one argument (Int), let's interpret that as port on default host */ + // we got one argument (Int), let's interpret that as port on default host bindTarget = .ip(host: defaultHost, port: p) default: bindTarget = .ip(host: defaultHost, port: defaultPort) diff --git a/Sources/NIOEmbedded/AsyncTestingChannel.swift b/Sources/NIOEmbedded/AsyncTestingChannel.swift index 3324db25d7..11900d8353 100644 --- a/Sources/NIOEmbedded/AsyncTestingChannel.swift +++ b/Sources/NIOEmbedded/AsyncTestingChannel.swift @@ -110,7 +110,7 @@ public final class NIOAsyncTestingChannel: Channel { /// `true` if the ``NIOAsyncTestingChannel`` if there was unconsumed inbound, outbound, or pending outbound data left /// on the `Channel` when it was `finish`ed. public var hasLeftOvers: Bool { - return !self.isClean + !self.isClean } } @@ -140,7 +140,7 @@ public final class NIOAsyncTestingChannel: Channel { /// Returns `true` if the buffer was non-empty. public var isFull: Bool { - return !self.isEmpty + !self.isEmpty } } @@ -159,7 +159,7 @@ public final class NIOAsyncTestingChannel: Channel { } public static func == (lhs: WrongTypeError, rhs: WrongTypeError) -> Bool { - return lhs.expected == rhs.expected && lhs.actual == rhs.actual + lhs.expected == rhs.expected && lhs.actual == rhs.actual } } @@ -168,12 +168,12 @@ public final class NIOAsyncTestingChannel: Channel { /// An active ``NIOAsyncTestingChannel`` can be closed by calling `close` or ``finish()`` on the ``NIOAsyncTestingChannel``. /// /// - note: An ``NIOAsyncTestingChannel`` starts _inactive_ and can be activated, for example by calling `connect`. - public var isActive: Bool { return channelcore.isActive } + public var isActive: Bool { channelcore.isActive } /// - see: `ChannelOptions.Types.AllowRemoteHalfClosureOption` public var allowRemoteHalfClosure: Bool { get { - return channelcore.allowRemoteHalfClosure + channelcore.allowRemoteHalfClosure } set { channelcore.allowRemoteHalfClosure = newValue @@ -181,14 +181,14 @@ public final class NIOAsyncTestingChannel: Channel { } /// - see: `Channel.closeFuture` - public var closeFuture: EventLoopFuture { return channelcore.closePromise.futureResult } + public var closeFuture: EventLoopFuture { channelcore.closePromise.futureResult } /// - see: `Channel.allocator` public let allocator: ByteBufferAllocator = ByteBufferAllocator() /// - see: `Channel.eventLoop` public var eventLoop: EventLoop { - return self.testingEventLoop + self.testingEventLoop } /// Returns the ``NIOAsyncTestingEventLoop`` that this ``NIOAsyncTestingChannel`` uses. This will return the same instance as @@ -201,7 +201,7 @@ public final class NIOAsyncTestingChannel: Channel { // This is only written once, from a single thread, and never written again, so it's _technically_ thread-safe. Most methods cannot safely // be used from multiple threads, but `isActive`, `isOpen`, `eventLoop`, and `closeFuture` can all safely be used from any thread. Just. @usableFromInline - /*private but usableFromInline */ var channelcore: EmbeddedChannelCore! + var channelcore: EmbeddedChannelCore! /// Guards any of the getters/setters that can be accessed from any thread. private let stateLock: NIOLock = NIOLock() @@ -219,18 +219,18 @@ public final class NIOAsyncTestingChannel: Channel { /// - see: `Channel._channelCore` public var _channelCore: ChannelCore { - return channelcore + channelcore } /// - see: `Channel.pipeline` public var pipeline: ChannelPipeline { - return _pipeline + _pipeline } /// - see: `Channel.isWritable` public var isWritable: Bool { get { - return self.stateLock.withLock { self._isWritable } + self.stateLock.withLock { self._isWritable } } set { self.stateLock.withLock { () -> Void in @@ -242,7 +242,7 @@ public final class NIOAsyncTestingChannel: Channel { /// - see: `Channel.localAddress` public var localAddress: SocketAddress? { get { - return self.stateLock.withLock { self._localAddress } + self.stateLock.withLock { self._localAddress } } set { self.stateLock.withLock { () -> Void in @@ -254,7 +254,7 @@ public final class NIOAsyncTestingChannel: Channel { /// - see: `Channel.remoteAddress` public var remoteAddress: SocketAddress? { get { - return self.stateLock.withLock { self._remoteAddress } + self.stateLock.withLock { self._remoteAddress } } set { self.stateLock.withLock { () -> Void in @@ -282,7 +282,8 @@ public final class NIOAsyncTestingChannel: Channel { /// - parameters: /// - handler: The `ChannelHandler` to add to the `ChannelPipeline` before register. /// - loop: The ``NIOAsyncTestingEventLoop`` to use. - public convenience init(handler: ChannelHandler, loop: NIOAsyncTestingEventLoop = NIOAsyncTestingEventLoop()) async { + public convenience init(handler: ChannelHandler, loop: NIOAsyncTestingEventLoop = NIOAsyncTestingEventLoop()) async + { await self.init(handlers: [handler], loop: loop) } @@ -293,7 +294,10 @@ public final class NIOAsyncTestingChannel: Channel { /// - parameters: /// - handlers: The `ChannelHandler`s to add to the `ChannelPipeline` before register. /// - loop: The ``NIOAsyncTestingEventLoop`` to use. - public convenience init(handlers: [ChannelHandler], loop: NIOAsyncTestingEventLoop = NIOAsyncTestingEventLoop()) async { + public convenience init( + handlers: [ChannelHandler], + loop: NIOAsyncTestingEventLoop = NIOAsyncTestingEventLoop() + ) async { self.init(loop: loop) try! await self._pipeline.addHandlers(handlers) @@ -334,9 +338,11 @@ public final class NIOAsyncTestingChannel: Channel { if c.outboundBuffer.isEmpty && c.inboundBuffer.isEmpty && c.pendingOutboundBuffer.isEmpty { return .clean } else { - return .leftOvers(inbound: c.inboundBuffer, - outbound: c.outboundBuffer, - pendingOutbound: c.pendingOutboundBuffer.map { $0.0 }) + return .leftOvers( + inbound: c.inboundBuffer, + outbound: c.outboundBuffer, + pendingOutbound: c.pendingOutboundBuffer.map { $0.0 } + ) } } } @@ -351,7 +357,7 @@ public final class NIOAsyncTestingChannel: Channel { /// writes) this will be ``LeftOverState/clean``. If there are any unconsumed inbound, outbound, or pending outbound /// events, the ``NIOAsyncTestingChannel`` will returns those as ``LeftOverState/leftOvers(inbound:outbound:pendingOutbound:)``. public func finish() async throws -> LeftOverState { - return try await self.finish(acceptAlreadyClosed: false) + try await self.finish(acceptAlreadyClosed: false) } /// If available, this method reads one element of type `T` out of the ``NIOAsyncTestingChannel``'s outbound buffer. If the @@ -395,9 +401,11 @@ public final class NIOAsyncTestingChannel: Channel { return } self.channelcore.outboundBufferConsumer.append { element in - continuation.resume(with: Result { - try self._cast(element) - }) + continuation.resume( + with: Result { + try self._cast(element) + } + ) } } catch { continuation.resume(throwing: error) @@ -443,9 +451,11 @@ public final class NIOAsyncTestingChannel: Channel { return } self.channelcore.inboundBufferConsumer.append { element in - continuation.resume(with: Result { - try self._cast(element) - }) + continuation.resume( + with: Result { + try self._cast(element) + } + ) } } catch { continuation.resume(throwing: error) @@ -488,7 +498,7 @@ public final class NIOAsyncTestingChannel: Channel { try await self.writeAndFlush(NIOAny(data)) return try await self.testingEventLoop.executeInContext { - return self.channelcore.outboundBuffer.isEmpty ? .empty : .full(self.channelcore.outboundBuffer) + self.channelcore.outboundBuffer.isEmpty ? .empty : .full(self.channelcore.outboundBuffer) } } @@ -510,7 +520,6 @@ public final class NIOAsyncTestingChannel: Channel { } } - @inlinable func _readFromBuffer(buffer: inout CircularBuffer) throws -> T? { self.testingEventLoop.preconditionInEventLoop() @@ -524,7 +533,10 @@ public final class NIOAsyncTestingChannel: Channel { @inlinable func _cast(_ element: NIOAny, to: T.Type = T.self) throws -> T { guard let t = self._channelCore.tryUnwrapData(element, as: T.self) else { - throw WrongTypeError(expected: T.self, actual: type(of: self._channelCore.tryUnwrapData(element, as: Any.self)!)) + throw WrongTypeError( + expected: T.self, + actual: type(of: self._channelCore.tryUnwrapData(element, as: Any.self)!) + ) } return t } @@ -552,7 +564,7 @@ public final class NIOAsyncTestingChannel: Channel { /// - see: `Channel.getOption` @inlinable - public func getOption(_ option: Option) -> EventLoopFuture { + public func getOption(_ option: Option) -> EventLoopFuture { if self.eventLoop.inEventLoop { return self.eventLoop.makeSucceededFuture(self.getOptionSync(option)) } else { @@ -633,7 +645,7 @@ public final class NIOAsyncTestingChannel: Channel { } public final var syncOptions: NIOSynchronousChannelOptions? { - return SynchronousOptions(channel: self) + SynchronousOptions(channel: self) } } @@ -645,7 +657,7 @@ public final class NIOAsyncTestingChannel: Channel { // in a channel pipeline _are_ `Sendable`, and because these objects only carry NIOAnys in cases // where the `Channel` itself no longer holds a reference to these objects. @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension NIOAsyncTestingChannel.LeftOverState: @unchecked Sendable { } +extension NIOAsyncTestingChannel.LeftOverState: @unchecked Sendable {} @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -extension NIOAsyncTestingChannel.BufferState: @unchecked Sendable { } +extension NIOAsyncTestingChannel.BufferState: @unchecked Sendable {} diff --git a/Sources/NIOEmbedded/AsyncTestingEventLoop.swift b/Sources/NIOEmbedded/AsyncTestingEventLoop.swift index bdc9423c85..58520c34f1 100644 --- a/Sources/NIOEmbedded/AsyncTestingEventLoop.swift +++ b/Sources/NIOEmbedded/AsyncTestingEventLoop.swift @@ -14,10 +14,9 @@ import Atomics import Dispatch -import _NIODataStructures -import NIOCore import NIOConcurrencyHelpers - +import NIOCore +import _NIODataStructures /// An `EventLoop` that is thread safe and whose execution is fully controlled /// by the user. @@ -64,12 +63,12 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { /// As we need to access this from any thread, we store this as an atomic. private let _now = ManagedAtomic(0) internal var now: NIODeadline { - return NIODeadline.uptimeNanoseconds(self._now.load(ordering: .relaxed)) + NIODeadline.uptimeNanoseconds(self._now.load(ordering: .relaxed)) } /// This is used to derive an identifier for this loop. private var thisLoopID: ObjectIdentifier { - return ObjectIdentifier(self) + ObjectIdentifier(self) } /// A dispatch specific that we use to determine whether we are on the queue for this @@ -107,7 +106,7 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { /// - see: `EventLoop.inEventLoop` public var inEventLoop: Bool { - return DispatchQueue.getSpecific(key: Self.inQueueKey) == self.thisLoopID + DispatchQueue.getSpecific(key: Self.inQueueKey) == self.thisLoopID } /// Initialize a new `NIOAsyncTestingEventLoop`. @@ -128,13 +127,19 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { ) { dispatchPrecondition(condition: .onQueue(self.queue)) - let task = EmbeddedScheduledTask(id: taskID, readyTime: deadline, insertOrder: self.nextTaskNumber(), task: { - do { - promise.succeed(try task()) - } catch let err { - promise.fail(err) - } - }, promise.fail) + let task = EmbeddedScheduledTask( + id: taskID, + readyTime: deadline, + insertOrder: self.nextTaskNumber(), + task: { + do { + promise.succeed(try task()) + } catch let err { + promise.fail(err) + } + }, + promise.fail + ) self.scheduledTasks.push(task) } @@ -145,15 +150,18 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { let promise: EventLoopPromise = self.makePromise() let taskID = self.scheduledTaskCounter.loadThenWrappingIncrement(ordering: .relaxed) - let scheduled = Scheduled(promise: promise, cancellationTask: { - if self.inEventLoop { - self.removeTask(taskID: taskID) - } else { - self.queue.async { + let scheduled = Scheduled( + promise: promise, + cancellationTask: { + if self.inEventLoop { self.removeTask(taskID: taskID) + } else { + self.queue.async { + self.removeTask(taskID: taskID) + } } } - }) + ) if self.inEventLoop { self.insertTask(taskID: taskID, deadline: deadline, promise: promise, task: task) @@ -168,7 +176,7 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { /// - see: `EventLoop.scheduleTask(in:_:)` @discardableResult public func scheduleTask(in: TimeAmount, _ task: @escaping () throws -> T) -> Scheduled { - return self.scheduleTask(deadline: self.now + `in`, task) + self.scheduleTask(deadline: self.now + `in`, task) } /// On an `NIOAsyncTestingEventLoop`, `execute` will simply use `scheduleTask` with a deadline of _now_. Unlike with the other operations, this will @@ -188,7 +196,8 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { // Now we want to grab all tasks that are ready to execute at the same // time as the first. - while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime { + while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime + { tasks.append(candidateTask) self.scheduledTasks.pop() } @@ -236,7 +245,8 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { // Now we want to grab all tasks that are ready to execute at the same // time as the first. - while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime { + while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime + { tasks.append(candidateTask) self.scheduledTasks.pop() } @@ -269,7 +279,9 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { /// /// Be careful not to try to spin the event loop again from within this callback, however. As long as this function is on the call /// stack the `NIOAsyncTestingEventLoop` cannot progress, and so any attempt to progress it will block until this function returns. - public func executeInContext(_ task: @escaping @Sendable () throws -> ReturnType) async throws -> ReturnType { + public func executeInContext( + _ task: @escaping @Sendable () throws -> ReturnType + ) async throws -> ReturnType { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in self.queue.async { do { @@ -328,8 +340,9 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { self._promiseCreationStore.promiseCreated(futureIdentifier: futureIdentifier, file: file, line: line) } - public func _promiseCompleted(futureIdentifier: _NIOEventLoopFutureIdentifier) -> (file: StaticString, line: UInt)? { - return self._promiseCreationStore.promiseCompleted(futureIdentifier: futureIdentifier) + public func _promiseCompleted(futureIdentifier: _NIOEventLoopFutureIdentifier) -> (file: StaticString, line: UInt)? + { + self._promiseCreationStore.promiseCompleted(futureIdentifier: futureIdentifier) } public func _preconditionSafeToSyncShutdown(file: StaticString, line: UInt) { @@ -352,7 +365,7 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable { // MARK: SerialExecutor conformance #if compiler(>=5.9) @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) -extension NIOAsyncTestingEventLoop: NIOSerialEventLoopExecutor { } +extension NIOAsyncTestingEventLoop: NIOSerialEventLoopExecutor {} #endif /// This is a thread-safe promise creation store. diff --git a/Sources/NIOEmbedded/Embedded.swift b/Sources/NIOEmbedded/Embedded.swift index 4a8ce0d218..8f808e3f02 100644 --- a/Sources/NIOEmbedded/Embedded.swift +++ b/Sources/NIOEmbedded/Embedded.swift @@ -13,20 +13,26 @@ //===----------------------------------------------------------------------===// import Atomics -import NIOConcurrencyHelpers +import DequeModule import Dispatch -import _NIODataStructures +import NIOConcurrencyHelpers import NIOCore -import DequeModule +import _NIODataStructures internal struct EmbeddedScheduledTask { let id: UInt64 let task: () -> Void - let failFn: (Error) -> () + let failFn: (Error) -> Void let readyTime: NIODeadline let insertOrder: UInt64 - init(id: UInt64, readyTime: NIODeadline, insertOrder: UInt64, task: @escaping () -> Void, _ failFn: @escaping (Error) -> ()) { + init( + id: UInt64, + readyTime: NIODeadline, + insertOrder: UInt64, + task: @escaping () -> Void, + _ failFn: @escaping (Error) -> Void + ) { self.id = id self.readyTime = readyTime self.insertOrder = insertOrder @@ -49,7 +55,7 @@ extension EmbeddedScheduledTask: Comparable { } static func == (lhs: EmbeddedScheduledTask, rhs: EmbeddedScheduledTask) -> Bool { - return lhs.id == rhs.id + lhs.id == rhs.id } } @@ -72,7 +78,7 @@ extension EmbeddedScheduledTask: Comparable { /// unsynchronized fashion. public final class EmbeddedEventLoop: EventLoop { /// The current "time" for this event loop. This is an amount in nanoseconds. - /* private but tests */ internal var _now: NIODeadline = .uptimeNanoseconds(0) + internal var _now: NIODeadline = .uptimeNanoseconds(0) private var scheduledTaskCounter: UInt64 = 0 private var scheduledTasks = PriorityQueue() @@ -94,29 +100,38 @@ public final class EmbeddedEventLoop: EventLoop { /// - see: `EventLoop.inEventLoop` public var inEventLoop: Bool { - return true + true } /// Initialize a new `EmbeddedEventLoop`. - public init() { } + public init() {} /// - see: `EventLoop.scheduleTask(deadline:_:)` @discardableResult public func scheduleTask(deadline: NIODeadline, _ task: @escaping () throws -> T) -> Scheduled { let promise: EventLoopPromise = makePromise() self.scheduledTaskCounter += 1 - let task = EmbeddedScheduledTask(id: self.scheduledTaskCounter, readyTime: deadline, insertOrder: self.nextTaskNumber(), task: { - do { - promise.succeed(try task()) - } catch let err { - promise.fail(err) - } - }, promise.fail) + let task = EmbeddedScheduledTask( + id: self.scheduledTaskCounter, + readyTime: deadline, + insertOrder: self.nextTaskNumber(), + task: { + do { + promise.succeed(try task()) + } catch let err { + promise.fail(err) + } + }, + promise.fail + ) let taskId = task.id - let scheduled = Scheduled(promise: promise, cancellationTask: { - self.scheduledTasks.removeFirst { $0.id == taskId } - }) + let scheduled = Scheduled( + promise: promise, + cancellationTask: { + self.scheduledTasks.removeFirst { $0.id == taskId } + } + ) scheduledTasks.push(task) return scheduled } @@ -124,7 +139,7 @@ public final class EmbeddedEventLoop: EventLoop { /// - see: `EventLoop.scheduleTask(in:_:)` @discardableResult public func scheduleTask(in: TimeAmount, _ task: @escaping () throws -> T) -> Scheduled { - return scheduleTask(deadline: self._now + `in`, task) + scheduleTask(deadline: self._now + `in`, task) } /// On an `EmbeddedEventLoop`, `execute` will simply use `scheduleTask` with a deadline of _now_. This means that @@ -163,7 +178,7 @@ public final class EmbeddedEventLoop: EventLoop { // Now we want to grab all tasks that are ready to execute at the same // time as the first. - var tasks = Array() + var tasks = [EmbeddedScheduledTask]() while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime { tasks.append(candidateTask) self.scheduledTasks.pop() @@ -220,7 +235,8 @@ public final class EmbeddedEventLoop: EventLoop { self._promiseCreationStore[futureIdentifier] = (file: file, line: line) } - public func _promiseCompleted(futureIdentifier: _NIOEventLoopFutureIdentifier) -> (file: StaticString, line: UInt)? { + public func _promiseCompleted(futureIdentifier: _NIOEventLoopFutureIdentifier) -> (file: StaticString, line: UInt)? + { precondition(_isDebugAssertConfiguration()) return self._promiseCreationStore.removeValue(forKey: futureIdentifier) } @@ -237,7 +253,9 @@ public final class EmbeddedEventLoop: EventLoop { #if compiler(>=5.9) @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) public var executor: any SerialExecutor { - fatalError("EmbeddedEventLoop is not thread safe and cannot be used as a SerialExecutor. Use NIOAsyncTestingEventLoop instead.") + fatalError( + "EmbeddedEventLoop is not thread safe and cannot be used as a SerialExecutor. Use NIOAsyncTestingEventLoop instead." + ) } #endif } @@ -246,7 +264,7 @@ public final class EmbeddedEventLoop: EventLoop { class EmbeddedChannelCore: ChannelCore { var isOpen: Bool { get { - return self._isOpen.load(ordering: .sequentiallyConsistent) + self._isOpen.load(ordering: .sequentiallyConsistent) } set { self._isOpen.store(newValue, ordering: .sequentiallyConsistent) @@ -255,7 +273,7 @@ class EmbeddedChannelCore: ChannelCore { var isActive: Bool { get { - return self._isActive.load(ordering: .sequentiallyConsistent) + self._isActive.load(ordering: .sequentiallyConsistent) } set { self._isActive.store(newValue, ordering: .sequentiallyConsistent) @@ -264,7 +282,7 @@ class EmbeddedChannelCore: ChannelCore { var allowRemoteHalfClosure: Bool { get { - return self._allowRemoteHalfClosure.load(ordering: .sequentiallyConsistent) + self._allowRemoteHalfClosure.load(ordering: .sequentiallyConsistent) } set { self._allowRemoteHalfClosure.store(newValue, ordering: .sequentiallyConsistent) @@ -289,8 +307,10 @@ class EmbeddedChannelCore: ChannelCore { } deinit { - assert(!self.isOpen && !self.isActive, - "leaked an open EmbeddedChannel, maybe forgot to call channel.finish()?") + assert( + !self.isOpen && !self.isActive, + "leaked an open EmbeddedChannel, maybe forgot to call channel.finish()?" + ) isOpen = false closePromise.succeed(()) } @@ -305,7 +325,9 @@ class EmbeddedChannelCore: ChannelCore { /// Contains the unflushed items that went into the `Channel` @usableFromInline - var pendingOutboundBuffer: MarkedCircularBuffer<(NIOAny, EventLoopPromise?)> = MarkedCircularBuffer(initialCapacity: 16) + var pendingOutboundBuffer: MarkedCircularBuffer<(NIOAny, EventLoopPromise?)> = MarkedCircularBuffer( + initialCapacity: 16 + ) /// Contains the items that travelled the `ChannelPipeline` all the way and hit the tail channel handler. On a /// regular `Channel` these items would be lost. @@ -509,7 +531,7 @@ public final class EmbeddedChannel: Channel { /// `true` if the `EmbeddedChannel` if there was unconsumed inbound, outbound, or pending outbound data left /// on the `Channel` when it was `finish`ed. public var hasLeftOvers: Bool { - return !self.isClean + !self.isClean } } @@ -538,7 +560,7 @@ public final class EmbeddedChannel: Channel { /// Returns `true` if the buffer was non-empty. public var isFull: Bool { - return !self.isEmpty + !self.isEmpty } } @@ -557,7 +579,7 @@ public final class EmbeddedChannel: Channel { } public static func == (lhs: WrongTypeError, rhs: WrongTypeError) -> Bool { - return lhs.expected == rhs.expected && lhs.actual == rhs.actual + lhs.expected == rhs.expected && lhs.actual == rhs.actual } } @@ -566,12 +588,12 @@ public final class EmbeddedChannel: Channel { /// An active `EmbeddedChannel` can be closed by calling `close` or `finish` on the `EmbeddedChannel`. /// /// - note: An `EmbeddedChannel` starts _inactive_ and can be activated, for example by calling `connect`. - public var isActive: Bool { return channelcore.isActive } + public var isActive: Bool { channelcore.isActive } /// - see: `ChannelOptions.Types.AllowRemoteHalfClosureOption` public var allowRemoteHalfClosure: Bool { get { - return channelcore.allowRemoteHalfClosure + channelcore.allowRemoteHalfClosure } set { channelcore.allowRemoteHalfClosure = newValue @@ -579,19 +601,22 @@ public final class EmbeddedChannel: Channel { } /// - see: `Channel.closeFuture` - public var closeFuture: EventLoopFuture { return channelcore.closePromise.futureResult } + public var closeFuture: EventLoopFuture { channelcore.closePromise.futureResult } @usableFromInline - /*private but usableFromInline */ lazy var channelcore: EmbeddedChannelCore = EmbeddedChannelCore(pipeline: self._pipeline, eventLoop: self.eventLoop) + lazy var channelcore: EmbeddedChannelCore = EmbeddedChannelCore( + pipeline: self._pipeline, + eventLoop: self.eventLoop + ) /// - see: `Channel._channelCore` public var _channelCore: ChannelCore { - return channelcore + channelcore } /// - see: `Channel.pipeline` public var pipeline: ChannelPipeline { - return _pipeline + _pipeline } /// - see: `Channel.isWritable` @@ -622,9 +647,11 @@ public final class EmbeddedChannel: Channel { if c.outboundBuffer.isEmpty && c.inboundBuffer.isEmpty && c.pendingOutboundBuffer.isEmpty { return .clean } else { - return .leftOvers(inbound: Array(c.inboundBuffer), - outbound: Array(c.outboundBuffer), - pendingOutbound: c.pendingOutboundBuffer.map { $0.0 }) + return .leftOvers( + inbound: Array(c.inboundBuffer), + outbound: Array(c.outboundBuffer), + pendingOutbound: c.pendingOutboundBuffer.map { $0.0 } + ) } } @@ -638,7 +665,7 @@ public final class EmbeddedChannel: Channel { /// writes) this will be `.clean`. If there are any unconsumed inbound, outbound, or pending outbound /// events, the `EmbeddedChannel` will returns those as `.leftOvers(inbound:outbound:pendingOutbound:)`. public func finish() throws -> LeftOverState { - return try self.finish(acceptAlreadyClosed: false) + try self.finish(acceptAlreadyClosed: false) } private var _pipeline: ChannelPipeline! @@ -648,7 +675,7 @@ public final class EmbeddedChannel: Channel { /// - see: `Channel.eventLoop` public var eventLoop: EventLoop { - return self.embeddedEventLoop + self.embeddedEventLoop } /// Returns the `EmbeddedEventLoop` that this `EmbeddedChannel` uses. This will return the same instance as @@ -692,7 +719,7 @@ public final class EmbeddedChannel: Channel { /// `ChannelHandler`. @inlinable public func readOutbound(as type: T.Type = T.self) throws -> T? { - return try _readFromBuffer(buffer: &channelcore.outboundBuffer) + try _readFromBuffer(buffer: &channelcore.outboundBuffer) } /// If available, this method reads one element of type `T` out of the `EmbeddedChannel`'s inbound buffer. If the @@ -707,7 +734,7 @@ public final class EmbeddedChannel: Channel { /// - note: `EmbeddedChannel.writeInbound` will fire data through the `ChannelPipeline` using `fireChannelRead`. @inlinable public func readInbound(as type: T.Type = T.self) throws -> T? { - return try _readFromBuffer(buffer: &channelcore.inboundBuffer) + try _readFromBuffer(buffer: &channelcore.inboundBuffer) } /// Sends an inbound `channelRead` event followed by a `channelReadComplete` event through the `ChannelPipeline`. @@ -760,7 +787,10 @@ public final class EmbeddedChannel: Channel { } let elem = buffer.removeFirst() guard let t = self._channelCore.tryUnwrapData(elem, as: T.self) else { - throw WrongTypeError(expected: T.self, actual: type(of: self._channelCore.tryUnwrapData(elem, as: Any.self)!)) + throw WrongTypeError( + expected: T.self, + actual: type(of: self._channelCore.tryUnwrapData(elem, as: Any.self)!) + ) } return t } @@ -813,8 +843,8 @@ public final class EmbeddedChannel: Channel { /// - see: `Channel.getOption` @inlinable - public func getOption(_ option: Option) -> EventLoopFuture { - return self.eventLoop.makeSucceededFuture(self.getOptionSync(option)) + public func getOption(_ option: Option) -> EventLoopFuture { + self.eventLoop.makeSucceededFuture(self.getOptionSync(option)) } @inlinable @@ -873,12 +903,12 @@ extension EmbeddedChannel { @inlinable public func getOption(_ option: Option) throws -> Option.Value { - return self.channel.getOptionSync(option) + self.channel.getOptionSync(option) } } public final var syncOptions: NIOSynchronousChannelOptions? { - return SynchronousOptions(channel: self) + SynchronousOptions(channel: self) } } diff --git a/Sources/NIOFileSystem/BufferedReader.swift b/Sources/NIOFileSystem/BufferedReader.swift index 6d0fbdabac..d66aa5d95e 100644 --- a/Sources/NIOFileSystem/BufferedReader.swift +++ b/Sources/NIOFileSystem/BufferedReader.swift @@ -44,7 +44,7 @@ public struct BufferedReader { /// The number of bytes currently in the buffer. public var count: Int { - return self.buffer.readableBytes + self.buffer.readableBytes } internal init(wrapping readableHandle: Handle, initialOffset: Int64, capacity: Int) { @@ -106,20 +106,6 @@ public struct BufferedReader { } } - /// Reads from the current position in the file until `predicate` returns `false` and returns - /// the read bytes. - /// - /// - Parameters: - /// - predicate: A predicate which evaluates to `true` for all bytes returned. - /// - Returns: The bytes read from the file. - /// - Important: This method has been deprecated: use ``read(while:)-8aukk`` instead. - @available(*, deprecated, message: "Use the read(while:) method returning a (ByteBuffer, Bool) tuple instead.") - public mutating func read( - while predicate: (UInt8) -> Bool - ) async throws -> ByteBuffer { - try await self.read(while: predicate).bytes - } - /// Reads from the current position in the file until `predicate` returns `false` and returns /// the read bytes. /// @@ -139,7 +125,7 @@ public struct BufferedReader { let prefix = view[.. { } } +// swift-format-ignore: AmbiguousTrailingClosureOverload +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension BufferedReader { + /// Reads from the current position in the file until `predicate` returns `false` and returns + /// the read bytes. + /// + /// - Parameters: + /// - predicate: A predicate which evaluates to `true` for all bytes returned. + /// - Returns: The bytes read from the file. + /// - Important: This method has been deprecated: use ``read(while:)-8aukk`` instead. + @available(*, deprecated, message: "Use the read(while:) method returning a (ByteBuffer, Bool) tuple instead.") + public mutating func read( + while predicate: (UInt8) -> Bool + ) async throws -> ByteBuffer { + try await self.read(while: predicate).bytes + } +} + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension ReadableFileHandleProtocol { /// Creates a new ``BufferedReader`` for this file handle. @@ -233,7 +237,7 @@ extension ReadableFileHandleProtocol { startingAtAbsoluteOffset initialOffset: Int64 = 0, capacity: ByteCount = .kibibytes(512) ) -> BufferedReader { - return BufferedReader(wrapping: self, initialOffset: initialOffset, capacity: Int(capacity.bytes)) + BufferedReader(wrapping: self, initialOffset: initialOffset, capacity: Int(capacity.bytes)) } } diff --git a/Sources/NIOFileSystem/BufferedWriter.swift b/Sources/NIOFileSystem/BufferedWriter.swift index 9de3cec25c..3f02f22181 100644 --- a/Sources/NIOFileSystem/BufferedWriter.swift +++ b/Sources/NIOFileSystem/BufferedWriter.swift @@ -50,13 +50,13 @@ public struct BufferedWriter { /// /// You can flush the buffer manually by calling ``flush()``. public var bufferedBytes: Int { - return self.buffer.count + self.buffer.count } /// The capacity of the buffer. @_spi(Testing) public var bufferCapacity: Int { - return self.buffer.capacity + self.buffer.capacity } internal init(wrapping writableHandle: Handle, initialOffset: Int64, capacity: Int) { @@ -215,13 +215,13 @@ extension WritableFileHandleProtocol { startingAtAbsoluteOffset initialOffset: Int64 = 0, capacity: ByteCount = .kibibytes(512) ) -> BufferedWriter { - return BufferedWriter( + BufferedWriter( wrapping: self, initialOffset: initialOffset, capacity: Int(capacity.bytes) ) } - + /// Convenience function that creates a buffered reader, executes /// the closure that writes the contents into the buffer and calls 'flush()'. /// @@ -238,7 +238,7 @@ extension WritableFileHandleProtocol { ) async throws -> Result { var bufferedWriter = self.bufferedWriter(startingAtAbsoluteOffset: initialOffset, capacity: capacity) return try await withUncancellableTearDown { - return try await body(&bufferedWriter) + try await body(&bufferedWriter) } tearDown: { _ in try await bufferedWriter.flush() } diff --git a/Sources/NIOFileSystem/ByteCount.swift b/Sources/NIOFileSystem/ByteCount.swift index 05593ff9e9..ef5f443322 100644 --- a/Sources/NIOFileSystem/ByteCount.swift +++ b/Sources/NIOFileSystem/ByteCount.swift @@ -22,7 +22,7 @@ public struct ByteCount: Hashable, Sendable { /// Returns a ``ByteCount`` with a given number of bytes /// - Parameter count: The number of bytes public static func bytes(_ count: Int64) -> ByteCount { - return ByteCount(bytes: count) + ByteCount(bytes: count) } /// Returns a ``ByteCount`` with a given number of kilobytes @@ -31,7 +31,7 @@ public struct ByteCount: Hashable, Sendable { /// /// - Parameter count: The number of kilobytes public static func kilobytes(_ count: Int64) -> ByteCount { - return ByteCount(bytes: 1000 * count) + ByteCount(bytes: 1000 * count) } /// Returns a ``ByteCount`` with a given number of megabytes @@ -40,7 +40,7 @@ public struct ByteCount: Hashable, Sendable { /// /// - Parameter count: The number of megabytes public static func megabytes(_ count: Int64) -> ByteCount { - return ByteCount(bytes: 1000 * 1000 * count) + ByteCount(bytes: 1000 * 1000 * count) } /// Returns a ``ByteCount`` with a given number of gigabytes @@ -49,7 +49,7 @@ public struct ByteCount: Hashable, Sendable { /// /// - Parameter count: The number of gigabytes public static func gigabytes(_ count: Int64) -> ByteCount { - return ByteCount(bytes: 1000 * 1000 * 1000 * count) + ByteCount(bytes: 1000 * 1000 * 1000 * count) } /// Returns a ``ByteCount`` with a given number of kibibytes @@ -58,7 +58,7 @@ public struct ByteCount: Hashable, Sendable { /// /// - Parameter count: The number of kibibytes public static func kibibytes(_ count: Int64) -> ByteCount { - return ByteCount(bytes: 1024 * count) + ByteCount(bytes: 1024 * count) } /// Returns a ``ByteCount`` with a given number of mebibytes @@ -67,7 +67,7 @@ public struct ByteCount: Hashable, Sendable { /// /// - Parameter count: The number of mebibytes public static func mebibytes(_ count: Int64) -> ByteCount { - return ByteCount(bytes: 1024 * 1024 * count) + ByteCount(bytes: 1024 * 1024 * count) } /// Returns a ``ByteCount`` with a given number of gibibytes @@ -76,7 +76,7 @@ public struct ByteCount: Hashable, Sendable { /// /// - Parameter count: The number of gibibytes public static func gibibytes(_ count: Int64) -> ByteCount { - return ByteCount(bytes: 1024 * 1024 * 1024 * count) + ByteCount(bytes: 1024 * 1024 * 1024 * count) } } diff --git a/Sources/NIOFileSystem/DirectoryEntries.swift b/Sources/NIOFileSystem/DirectoryEntries.swift index 08d6a983f4..60734ff402 100644 --- a/Sources/NIOFileSystem/DirectoryEntries.swift +++ b/Sources/NIOFileSystem/DirectoryEntries.swift @@ -40,7 +40,7 @@ public struct DirectoryEntries: AsyncSequence { } public func makeAsyncIterator() -> DirectoryIterator { - return DirectoryIterator(iterator: self.batchedSequence.makeAsyncIterator()) + DirectoryIterator(iterator: self.batchedSequence.makeAsyncIterator()) } /// Returns a sequence of directory entry batches. @@ -49,7 +49,7 @@ public struct DirectoryEntries: AsyncSequence { /// than `DirectoryEntry`. This can enable better performance by reducing the number of /// executor hops. public func batched() -> Batched { - return self.batchedSequence + self.batchedSequence } /// An `AsyncIteratorProtocol` of `DirectoryEntry`. @@ -111,7 +111,7 @@ extension DirectoryEntries { } public func makeAsyncIterator() -> BatchedIterator { - return BatchedIterator(wrapping: self.stream.makeAsyncIterator()) + BatchedIterator(wrapping: self.stream.makeAsyncIterator()) } /// An `AsyncIteratorProtocol` of `Array`. @@ -208,7 +208,7 @@ private struct DirectoryEntryProducer { } private func nextBatch() throws -> [DirectoryEntry] { - return try self.state.withLockedValue { state in + try self.state.withLockedValue { state in try state.next(self.entriesPerBatch) } } @@ -422,7 +422,7 @@ private struct DirectoryEnumerator: Sendable { private mutating func makeReaddirSource( _ handle: SystemFileHandle.SendableView ) -> Result { - return handle._duplicate().mapError { dupError in + handle._duplicate().mapError { dupError in FileSystemError( message: "Unable to open directory stream for '\(handle.path)'.", wrapping: dupError @@ -443,7 +443,7 @@ private struct DirectoryEnumerator: Sendable { private mutating func makeFTSSource( _ handle: SystemFileHandle.SendableView ) -> Result { - return Libc.ftsOpen(handle.path, options: [.noChangeDir, .physical]).mapError { errno in + Libc.ftsOpen(handle.path, options: [.noChangeDir, .physical]).mapError { errno in FileSystemError.open("fts_open", error: errno, path: handle.path, location: .here()) }.map { .fts($0) @@ -652,7 +652,7 @@ private struct DirectoryEnumerator: Sendable { extension UnsafeMutablePointer { fileprivate var path: FilePath { - return FilePath(platformString: self.pointee.fts_path!) + FilePath(platformString: self.pointee.fts_path!) } } diff --git a/Sources/NIOFileSystem/FileChunks.swift b/Sources/NIOFileSystem/FileChunks.swift index 4b87221410..b42d76e18d 100644 --- a/Sources/NIOFileSystem/FileChunks.swift +++ b/Sources/NIOFileSystem/FileChunks.swift @@ -63,7 +63,7 @@ public struct FileChunks: AsyncSequence { } public func makeAsyncIterator() -> FileChunkIterator { - return FileChunkIterator(wrapping: self.stream.makeAsyncIterator()) + FileChunkIterator(wrapping: self.stream.makeAsyncIterator()) } public struct FileChunkIterator: AsyncIteratorProtocol { @@ -161,7 +161,7 @@ private struct FileChunkProducer: Sendable { } private func readNextChunk() throws -> ByteBuffer { - return try self.state.withLockedValue { state in + try self.state.withLockedValue { state in state.produceMore() }.flatMap { if let (descriptor, range) = $0 { diff --git a/Sources/NIOFileSystem/FileHandleProtocol.swift b/Sources/NIOFileSystem/FileHandleProtocol.swift index e7607f7829..d58e33065a 100644 --- a/Sources/NIOFileSystem/FileHandleProtocol.swift +++ b/Sources/NIOFileSystem/FileHandleProtocol.swift @@ -220,7 +220,7 @@ extension ReadableFileHandleProtocol { in range: ClosedRange, chunkLength: ByteCount = .kibibytes(128) ) -> FileChunks { - return self.readChunks(in: Range(range), chunkLength: chunkLength) + self.readChunks(in: Range(range), chunkLength: chunkLength) } /// Returns an asynchronous sequence of chunks read from the file. @@ -235,7 +235,7 @@ extension ReadableFileHandleProtocol { in range: Range, chunkLength: ByteCount = .kibibytes(128) ) -> FileChunks { - return self.readChunks(in: range, chunkLength: chunkLength) + self.readChunks(in: range, chunkLength: chunkLength) } /// Returns an asynchronous sequence of chunks read from the file. @@ -298,7 +298,7 @@ extension ReadableFileHandleProtocol { in range: UnboundedRange, chunkLength: ByteCount = .kibibytes(128) ) -> FileChunks { - return self.readChunks(in: 0.. FileChunks { - return self.readChunks(in: ..., chunkLength: chunkLength) + self.readChunks(in: ..., chunkLength: chunkLength) } } @@ -514,7 +514,7 @@ extension WritableFileHandleProtocol { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension FileHandleProtocol { /// Sets the file's last access time to the given time. - /// + /// /// - Parameter time: The time to which the file's last access time should be set. /// /// - Throws: If there's an error updating the time. If this happens, the original value won't be modified. @@ -630,7 +630,7 @@ extension DirectoryFileHandleProtocol { /// The current (".") and parent ("..") directory entries are not included. The order of entries /// is arbitrary and should not be relied upon. public func listContents() -> DirectoryEntries { - return self.listContents(recursive: false) + self.listContents(recursive: false) } } @@ -663,7 +663,7 @@ extension DirectoryFileHandleProtocol { let handle = try await self.openFile(forReadingAt: path, options: options) return try await withUncancellableTearDown { - return try await body(handle) + try await body(handle) } tearDown: { _ in try await handle.close() } @@ -694,7 +694,7 @@ extension DirectoryFileHandleProtocol { let handle = try await self.openFile(forWritingAt: path, options: options) return try await withUncancellableTearDown { - return try await body(handle) + try await body(handle) } tearDown: { result in switch result { case .success: @@ -729,7 +729,7 @@ extension DirectoryFileHandleProtocol { let handle = try await self.openFile(forReadingAndWritingAt: path, options: options) return try await withUncancellableTearDown { - return try await body(handle) + try await body(handle) } tearDown: { result in switch result { case .success: @@ -755,7 +755,7 @@ extension DirectoryFileHandleProtocol { let handle = try await self.openDirectory(atPath: path, options: options) return try await withUncancellableTearDown { - return try await body(handle) + try await body(handle) } tearDown: { _ in try await handle.close() } diff --git a/Sources/NIOFileSystem/FileSystem.swift b/Sources/NIOFileSystem/FileSystem.swift index 7a55fecd94..16692a7854 100644 --- a/Sources/NIOFileSystem/FileSystem.swift +++ b/Sources/NIOFileSystem/FileSystem.swift @@ -85,7 +85,7 @@ public struct FileSystem: Sendable, FileSystemProtocol { let threadPool = NIOThreadPool(numberOfThreads: numberOfThreads) threadPool.start() // Wait for the thread pool to start. - try? await threadPool.runIfActive { } + try? await threadPool.runIfActive {} self.init(threadPool: threadPool, ownsThreadPool: true) } @@ -263,7 +263,7 @@ public struct FileSystem: Sendable, FileSystemProtocol { public func createTemporaryDirectory( template: FilePath ) async throws -> FilePath { - return try await self.threadPool.runIfActive { + try await self.threadPool.runIfActive { try self._createTemporaryDirectory(template: template).get() } } @@ -284,7 +284,7 @@ public struct FileSystem: Sendable, FileSystemProtocol { forFileAt path: FilePath, infoAboutSymbolicLink: Bool ) async throws -> FileInfo? { - return try await self.threadPool.runIfActive { + try await self.threadPool.runIfActive { try self._info(forFileAt: path, infoAboutSymbolicLink: infoAboutSymbolicLink).get() } } @@ -564,7 +564,7 @@ public struct FileSystem: Sendable, FileSystemProtocol { at linkPath: FilePath, withDestination destinationPath: FilePath ) async throws { - return try await self.threadPool.runIfActive { + try await self.threadPool.runIfActive { try self._createSymbolicLink(at: linkPath, withDestination: destinationPath).get() } } @@ -595,7 +595,7 @@ public struct FileSystem: Sendable, FileSystemProtocol { public func destinationOfSymbolicLink( at path: FilePath ) async throws -> FilePath { - return try await self.threadPool.runIfActive { + try await self.threadPool.runIfActive { try self._destinationOfSymbolicLink(at: path).get() } } @@ -632,7 +632,7 @@ public struct FileSystem: Sendable, FileSystemProtocol { get async throws { #if canImport(Darwin) return try await self.threadPool.runIfActive { - return try Libc.constr(_CS_DARWIN_USER_TEMP_DIR).map { path in + try Libc.constr(_CS_DARWIN_USER_TEMP_DIR).map { path in FilePath(path) }.mapError { errno in FileSystemError.confstr( @@ -694,7 +694,7 @@ extension NIOSingletons { extension FileSystemProtocol where Self == FileSystem { /// A global shared instance of ``FileSystem``. public static var shared: FileSystem { - return FileSystem.shared + FileSystem.shared } } @@ -719,7 +719,7 @@ extension FileSystem { forReadingAt path: FilePath, options: OpenOptions.Read ) -> Result { - return SystemFileHandle.syncOpen( + SystemFileHandle.syncOpen( atPath: path, mode: .readOnly, options: options.descriptorOptions, @@ -736,7 +736,7 @@ extension FileSystem { forWritingAt path: FilePath, options: OpenOptions.Write ) -> Result { - return SystemFileHandle.syncOpen( + SystemFileHandle.syncOpen( atPath: path, mode: .writeOnly, options: options.descriptorOptions, @@ -753,7 +753,7 @@ extension FileSystem { forReadingAndWritingAt path: FilePath, options: OpenOptions.Write ) -> Result { - return SystemFileHandle.syncOpen( + SystemFileHandle.syncOpen( atPath: path, mode: .readWrite, options: options.descriptorOptions, @@ -770,7 +770,7 @@ extension FileSystem { at path: FilePath, options: OpenOptions.Directory ) -> Result { - return SystemFileHandle.syncOpen( + SystemFileHandle.syncOpen( atPath: path, mode: .readOnly, options: options.descriptorOptions, @@ -1001,7 +1001,7 @@ extension FileSystem { path: FilePath, location: FileSystemError.SourceLocation ) -> FileSystemError { - return FileSystemError( + FileSystemError( code: .closed, message: "Can't copy '\(sourcePath)' to '\(destinationPath)', '\(path)' is closed.", cause: nil, @@ -1354,7 +1354,7 @@ extension FileSystem { at linkPath: FilePath, withDestination destinationPath: FilePath ) -> Result { - return Syscall.symlink(to: destinationPath, from: linkPath).mapError { errno in + Syscall.symlink(to: destinationPath, from: linkPath).mapError { errno in FileSystemError.symlink( errno: errno, link: linkPath, diff --git a/Sources/NIOFileSystem/FileSystemError+Syscall.swift b/Sources/NIOFileSystem/FileSystemError+Syscall.swift index a59185bb25..a9309abdae 100644 --- a/Sources/NIOFileSystem/FileSystemError+Syscall.swift +++ b/Sources/NIOFileSystem/FileSystemError+Syscall.swift @@ -541,7 +541,7 @@ extension FileSystemError { @_spi(Testing) public static func fdopendir(errno: Errno, path: FilePath, location: SourceLocation) -> Self { - return FileSystemError( + FileSystemError( code: .unknown, message: "Unable to open directory stream for '\(path)'.", systemCall: "fdopendir", @@ -552,7 +552,7 @@ extension FileSystemError { @_spi(Testing) public static func readdir(errno: Errno, path: FilePath, location: SourceLocation) -> Self { - return FileSystemError( + FileSystemError( code: .unknown, message: "Unable to read directory stream for '\(path)'.", systemCall: "readdir", @@ -563,7 +563,7 @@ extension FileSystemError { @_spi(Testing) public static func ftsRead(errno: Errno, path: FilePath, location: SourceLocation) -> Self { - return FileSystemError( + FileSystemError( code: .unknown, message: "Unable to read FTS stream for '\(path)'.", systemCall: "fts_read", @@ -966,7 +966,7 @@ extension FileSystemError { @_spi(Testing) public static func getcwd(errno: Errno, location: SourceLocation) -> Self { - return FileSystemError( + FileSystemError( code: .unavailable, message: "Can't get current working directory.", systemCall: "getcwd", @@ -977,7 +977,7 @@ extension FileSystemError { @_spi(Testing) public static func confstr(name: String, errno: Errno, location: SourceLocation) -> Self { - return FileSystemError( + FileSystemError( code: .unavailable, message: "Can't get configuration value for '\(name)'.", systemCall: "confstr", @@ -1085,7 +1085,8 @@ extension FileSystemError { case .readOnlyFileSystem: code = .unsupported - message = "Not permitted to change last access or last data modification times for \(path): this is a read-only file system." + message = + "Not permitted to change last access or last data modification times for \(path): this is a read-only file system." case .badFileDescriptor: code = .closed diff --git a/Sources/NIOFileSystem/FileSystemError.swift b/Sources/NIOFileSystem/FileSystemError.swift index 2bfa20ea21..801fcd699c 100644 --- a/Sources/NIOFileSystem/FileSystemError.swift +++ b/Sources/NIOFileSystem/FileSystemError.swift @@ -118,7 +118,7 @@ extension FileSystemError { /// /// - Returns: A multi-line description of the error. public func detailedDescription() -> String { - return self.detailedDescriptionLines().joined(separator: "\n") + self.detailedDescriptionLines().joined(separator: "\n") } } @@ -256,7 +256,7 @@ extension FileSystemError { file: String = #fileID, line: Int = #line ) -> Self { - return SourceLocation(function: function, file: file, line: line) + SourceLocation(function: function, file: file, line: line) } } } diff --git a/Sources/NIOFileSystem/FileSystemProtocol.swift b/Sources/NIOFileSystem/FileSystemProtocol.swift index 0fee8f78a2..9bd024d70a 100644 --- a/Sources/NIOFileSystem/FileSystemProtocol.swift +++ b/Sources/NIOFileSystem/FileSystemProtocol.swift @@ -281,7 +281,7 @@ extension FileSystemProtocol { ) async throws -> Result { let handle = try await self.openFile(forReadingAt: path, options: options) return try await withUncancellableTearDown { - return try await execute(handle) + try await execute(handle) } tearDown: { _ in try await handle.close() } @@ -308,7 +308,7 @@ extension FileSystemProtocol { ) async throws -> Result { let handle = try await self.openFile(forWritingAt: path, options: options) return try await withUncancellableTearDown { - return try await execute(handle) + try await execute(handle) } tearDown: { result in switch result { case .success: @@ -340,7 +340,7 @@ extension FileSystemProtocol { ) async throws -> Result { let handle = try await self.openFile(forReadingAndWritingAt: path, options: options) return try await withUncancellableTearDown { - return try await execute(handle) + try await execute(handle) } tearDown: { _ in try await handle.close() } @@ -361,7 +361,7 @@ extension FileSystemProtocol { ) async throws -> Result { let handle = try await self.openDirectory(atPath: path, options: options) return try await withUncancellableTearDown { - return try await execute(handle) + try await execute(handle) } tearDown: { _ in try await handle.close() } @@ -406,7 +406,7 @@ extension FileSystemProtocol { /// - path: The path to get information about. /// - Returns: Information about the file at the given path or `nil` if no file exists. public func info(forFileAt path: FilePath) async throws -> FileInfo? { - return try await self.info(forFileAt: path, infoAboutSymbolicLink: false) + try await self.info(forFileAt: path, infoAboutSymbolicLink: false) } /// Copies the item at the specified path to a new location. @@ -429,7 +429,7 @@ extension FileSystemProtocol { try await self.copyItem(at: sourcePath, to: destinationPath) { entry, error in throw error } shouldCopyFile: { source, destination in - return true + true } } diff --git a/Sources/NIOFileSystem/Internal/BufferedOrAnyStream.swift b/Sources/NIOFileSystem/Internal/BufferedOrAnyStream.swift index 06a0ca54ab..fc7645bb1c 100644 --- a/Sources/NIOFileSystem/Internal/BufferedOrAnyStream.swift +++ b/Sources/NIOFileSystem/Internal/BufferedOrAnyStream.swift @@ -74,7 +74,7 @@ internal struct AnyAsyncSequence: AsyncSequence { } internal func makeAsyncIterator() -> AsyncIterator { - return self._makeAsyncIterator() + self._makeAsyncIterator() } internal struct AsyncIterator: AsyncIteratorProtocol { @@ -85,7 +85,7 @@ internal struct AnyAsyncSequence: AsyncSequence { } internal mutating func next() async throws -> Element? { - return try await self.iterator.next() as? Element + try await self.iterator.next() as? Element } } } diff --git a/Sources/NIOFileSystem/Internal/BufferedStream.swift b/Sources/NIOFileSystem/Internal/BufferedStream.swift index a08173dd04..6653e4f8c0 100644 --- a/Sources/NIOFileSystem/Internal/BufferedStream.swift +++ b/Sources/NIOFileSystem/Internal/BufferedStream.swift @@ -527,12 +527,12 @@ extension BufferedStream { func didYield(bufferDepth: Int) -> Bool { // We are demanding more until we reach the high watermark - return bufferDepth < self._high + bufferDepth < self._high } func didConsume(bufferDepth: Int) -> Bool { // We start demanding again once we are below the low watermark - return bufferDepth < self._low + bufferDepth < self._low } } @@ -656,7 +656,7 @@ extension BufferedStream { contentsOf sequence: some Sequence ) throws -> Source.WriteResult { let action = self._stateMachine.withCriticalRegion { - return $0.write(sequence) + $0.write(sequence) } switch action { @@ -781,8 +781,8 @@ extension BufferedStream { } func suspendNext() async throws -> Element? { - return try await withTaskCancellationHandler { - return try await withCheckedThrowingContinuation { continuation in + try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { continuation in let action = self._stateMachine.withCriticalRegion { $0.suspendNext(continuation: continuation) } diff --git a/Sources/NIOFileSystem/Internal/Cancellation.swift b/Sources/NIOFileSystem/Internal/Cancellation.swift index 4316c88676..307e78c6be 100644 --- a/Sources/NIOFileSystem/Internal/Cancellation.swift +++ b/Sources/NIOFileSystem/Internal/Cancellation.swift @@ -42,7 +42,7 @@ public func withUncancellableTearDown( result = .failure(error) } - let errorOnlyResult: Result = result.map { _ in return () } + let errorOnlyResult: Result = result.map { _ in () } let tearDownResult: Result = try await withoutCancellation { do { return .success(try await tearDown(errorOnlyResult)) diff --git a/Sources/NIOFileSystem/Internal/System Calls/Errno.swift b/Sources/NIOFileSystem/Internal/System Calls/Errno.swift index 3d0eb07907..0ac8876aff 100644 --- a/Sources/NIOFileSystem/Internal/System Calls/Errno.swift +++ b/Sources/NIOFileSystem/Internal/System Calls/Errno.swift @@ -89,7 +89,7 @@ public func nothingOrErrno( retryOnInterrupt: Bool = true, _ fn: () -> I ) -> Result { - return valueOrErrno(retryOnInterrupt: retryOnInterrupt, fn).map { _ in } + valueOrErrno(retryOnInterrupt: retryOnInterrupt, fn).map { _ in } } /// Returns a `Result` representing the value returned from the given closure diff --git a/Sources/NIOFileSystem/Internal/System Calls/FileDescriptor+Syscalls.swift b/Sources/NIOFileSystem/Internal/System Calls/FileDescriptor+Syscalls.swift index 307e194785..d49c549056 100644 --- a/Sources/NIOFileSystem/Internal/System Calls/FileDescriptor+Syscalls.swift +++ b/Sources/NIOFileSystem/Internal/System Calls/FileDescriptor+Syscalls.swift @@ -108,7 +108,7 @@ extension FileDescriptor { public func listExtendedAttributes( _ buffer: UnsafeMutableBufferPointer? ) -> Result { - return valueOrErrno(retryOnInterrupt: false) { + valueOrErrno(retryOnInterrupt: false) { system_flistxattr(self.rawValue, buffer?.baseAddress, buffer?.count ?? 0) } } @@ -128,7 +128,7 @@ extension FileDescriptor { named name: String, buffer: UnsafeMutableRawBufferPointer? ) -> Result { - return valueOrErrno(retryOnInterrupt: false) { + valueOrErrno(retryOnInterrupt: false) { name.withPlatformString { system_fgetxattr(self.rawValue, $0, buffer?.baseAddress, buffer?.count ?? 0) } @@ -198,7 +198,7 @@ extension FileDescriptor { extension FileDescriptor { func listExtendedAttributes() -> Result<[String], Errno> { // Required capacity is returned if a no buffer is passed to flistxattr. - return self.listExtendedAttributes(nil).flatMap { capacity in + self.listExtendedAttributes(nil).flatMap { capacity in guard capacity > 0 else { // Required capacity is zero: no attributes to read. return .success([]) @@ -226,7 +226,7 @@ extension FileDescriptor { func readExtendedAttribute(named name: String) -> Result<[UInt8], Errno> { // Required capacity is returned if a no buffer is passed to fgetxattr. - return self.getExtendedAttribute(named: name, buffer: nil).flatMap { capacity in + self.getExtendedAttribute(named: name, buffer: nil).flatMap { capacity in guard capacity > 0 else { // Required capacity is zero: no values to read. return .success([]) @@ -263,7 +263,7 @@ extension FileDescriptor { // factor here. However we should investigate whether it's possible to have a pool of // buffers which we can reuse. This would need to be at least as large as the high watermark // of the chunked file for it to be useful. - return Result { + Result { var buffer = ByteBuffer() try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: Int(length)) { buffer in let bufferPointer: UnsafeMutableRawBufferPointer @@ -295,7 +295,7 @@ extension FileDescriptor { contentsOf bytes: some Sequence, toAbsoluteOffset offset: Int64 ) -> Result { - return Result { + Result { Int64(try self.writeAll(toAbsoluteOffset: offset, bytes)) } } @@ -303,7 +303,7 @@ extension FileDescriptor { func write( contentsOf bytes: some Sequence ) -> Result { - return Result { + Result { Int64(try self.writeAll(bytes)) } } diff --git a/Sources/NIOFileSystem/Internal/System Calls/Mocking.swift b/Sources/NIOFileSystem/Internal/System Calls/Mocking.swift index 264e4c92f7..8ca8c18962 100644 --- a/Sources/NIOFileSystem/Internal/System Calls/Mocking.swift +++ b/Sources/NIOFileSystem/Internal/System Calls/Mocking.swift @@ -12,14 +12,10 @@ // //===----------------------------------------------------------------------===// -/* - This source file is part of the Swift System open source project - - Copyright (c) 2020 Apple Inc. and the Swift System project authors - Licensed under Apache License v2.0 with Runtime Library Exception - - See https://swift.org/LICENSE.txt for license information - */ +// This source file is part of the Swift System open source project// +// Copyright (c) 2020 Apple Inc. and the Swift System project authors +// Licensed under Apache License v2.0 with Runtime Library Exception// +// See https://swift.org/LICENSE.txt for license information #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) || os(Linux) || os(Android) import SystemPackage @@ -141,7 +137,7 @@ extension MockingDriver { // Check TLS for mocking @inline(never) private var contextualMockingEnabled: Bool { - return currentMockingDriver != nil + currentMockingDriver != nil } extension MockingDriver { @@ -212,7 +208,7 @@ private func mockImpl(syscall name: String, args: [AnyHashable]) -> CInt { } private func reinterpret(_ args: [AnyHashable?]) -> [AnyHashable] { - return args.map { arg in + args.map { arg in switch arg { case let charPointer as UnsafePointer: return String(_errorCorrectingPlatformString: charPointer) @@ -234,14 +230,14 @@ func mock( syscall name: String = #function, _ args: AnyHashable?... ) -> CInt { - return mockImpl(syscall: name, args: reinterpret(args)) + mockImpl(syscall: name, args: reinterpret(args)) } func mockInt( syscall name: String = #function, _ args: AnyHashable?... ) -> Int { - return Int(mockImpl(syscall: name, args: reinterpret(args))) + Int(mockImpl(syscall: name, args: reinterpret(args))) } #endif // ENABLE_MOCKING @@ -312,7 +308,7 @@ internal func system_strlen(_ s: UnsafeMutablePointer) -> Int { // strlen for the platform string internal func system_platform_strlen(_ s: UnsafePointer) -> Int { - return strlen(s) + strlen(s) } // memset for raw buffers @@ -331,7 +327,7 @@ extension String { _ body: (UnsafePointer) throws -> Result ) rethrows -> Result { // Need to #if because CChar may be signed - return try withCString(body) + try withCString(body) } internal init?(_platformString platformString: UnsafePointer) { @@ -364,6 +360,6 @@ internal func setTLS(_ key: _PlatformTLSKey, _ p: UnsafeMutableRawPointer?) { } internal func getTLS(_ key: _PlatformTLSKey) -> UnsafeMutableRawPointer? { - return pthread_getspecific(key) + pthread_getspecific(key) } #endif diff --git a/Sources/NIOFileSystem/Internal/System Calls/Syscall.swift b/Sources/NIOFileSystem/Internal/System Calls/Syscall.swift index 308c239c0f..aaa03091e6 100644 --- a/Sources/NIOFileSystem/Internal/System Calls/Syscall.swift +++ b/Sources/NIOFileSystem/Internal/System Calls/Syscall.swift @@ -30,7 +30,7 @@ import CNIOLinux public enum Syscall { @_spi(Testing) public static func stat(path: FilePath) -> Result { - return path.withPlatformString { platformPath in + path.withPlatformString { platformPath in var status = CInterop.Stat() return valueOrErrno(retryOnInterrupt: false) { system_stat(platformPath, &status) @@ -42,7 +42,7 @@ public enum Syscall { @_spi(Testing) public static func lstat(path: FilePath) -> Result { - return path.withPlatformString { platformPath in + path.withPlatformString { platformPath in var status = CInterop.Stat() return valueOrErrno(retryOnInterrupt: false) { system_lstat(platformPath, &status) @@ -54,7 +54,7 @@ public enum Syscall { @_spi(Testing) public static func mkdir(at path: FilePath, permissions: FilePermissions) -> Result { - return nothingOrErrno(retryOnInterrupt: false) { + nothingOrErrno(retryOnInterrupt: false) { path.withPlatformString { p in system_mkdir(p, permissions.rawValue) } @@ -63,7 +63,7 @@ public enum Syscall { @_spi(Testing) public static func rename(from old: FilePath, to new: FilePath) -> Result { - return nothingOrErrno(retryOnInterrupt: false) { + nothingOrErrno(retryOnInterrupt: false) { old.withPlatformString { oldPath in new.withPlatformString { newPath in system_rename(oldPath, newPath) @@ -79,7 +79,7 @@ public enum Syscall { to new: FilePath, options: RenameOptions ) -> Result { - return nothingOrErrno(retryOnInterrupt: false) { + nothingOrErrno(retryOnInterrupt: false) { old.withPlatformString { oldPath in new.withPlatformString { newPath in system_renamex_np(oldPath, newPath, options.rawValue) @@ -97,11 +97,11 @@ public enum Syscall { } public static var exclusive: Self { - return Self(rawValue: UInt32(bitPattern: RENAME_EXCL)) + Self(rawValue: UInt32(bitPattern: RENAME_EXCL)) } public static var swap: Self { - return Self(rawValue: UInt32(bitPattern: RENAME_SWAP)) + Self(rawValue: UInt32(bitPattern: RENAME_SWAP)) } } #endif @@ -115,7 +115,7 @@ public enum Syscall { relativeTo newFD: FileDescriptor, flags: RenameAtFlags ) -> Result { - return nothingOrErrno(retryOnInterrupt: false) { + nothingOrErrno(retryOnInterrupt: false) { old.withPlatformString { oldPath in new.withPlatformString { newPath in system_renameat2( @@ -139,11 +139,11 @@ public enum Syscall { } public static var exclusive: Self { - return Self(rawValue: CNIOLinux_RENAME_NOREPLACE) + Self(rawValue: CNIOLinux_RENAME_NOREPLACE) } public static var swap: Self { - return Self(rawValue: CNIOLinux_RENAME_EXCHANGE) + Self(rawValue: CNIOLinux_RENAME_EXCHANGE) } } #endif @@ -178,7 +178,7 @@ public enum Syscall { relativeTo destinationFD: FileDescriptor, flags: LinkAtFlags ) -> Result { - return nothingOrErrno(retryOnInterrupt: false) { + nothingOrErrno(retryOnInterrupt: false) { source.withPlatformString { src in destination.withPlatformString { dst in system_linkat( @@ -199,7 +199,7 @@ public enum Syscall { from source: FilePath, to destination: FilePath ) -> Result { - return nothingOrErrno(retryOnInterrupt: false) { + nothingOrErrno(retryOnInterrupt: false) { source.withPlatformString { src in destination.withPlatformString { dst in system_link(src, dst) @@ -210,7 +210,7 @@ public enum Syscall { @_spi(Testing) public static func unlink(path: FilePath) -> Result { - return nothingOrErrno(retryOnInterrupt: false) { + nothingOrErrno(retryOnInterrupt: false) { path.withPlatformString { ptr in system_unlink(ptr) } @@ -222,7 +222,7 @@ public enum Syscall { to destination: FilePath, from source: FilePath ) -> Result { - return nothingOrErrno(retryOnInterrupt: false) { + nothingOrErrno(retryOnInterrupt: false) { source.withPlatformString { src in destination.withPlatformString { dst in system_symlink(dst, src) @@ -381,7 +381,10 @@ public enum Libc { pathBytes.withUnsafeMutableBufferPointer { pointer in // The array must be terminated with a nil. #if os(Android) - libc_fts_open([pointer.baseAddress!, unsafeBitCast(0, to: UnsafeMutablePointer.self)], options.rawValue) + libc_fts_open( + [pointer.baseAddress!, unsafeBitCast(0, to: UnsafeMutablePointer.self)], + options.rawValue + ) #else libc_fts_open([pointer.baseAddress, nil], options.rawValue) #endif diff --git a/Sources/NIOFileSystem/Internal/System Calls/Syscalls.swift b/Sources/NIOFileSystem/Internal/System Calls/Syscalls.swift index 284d59d779..8c3266124b 100644 --- a/Sources/NIOFileSystem/Internal/System Calls/Syscalls.swift +++ b/Sources/NIOFileSystem/Internal/System Calls/Syscalls.swift @@ -367,21 +367,21 @@ internal func system_futimens( internal func libc_fdopendir( _ fd: FileDescriptor.RawValue ) -> CInterop.DirPointer { - return fdopendir(fd)! + fdopendir(fd)! } /// readdir(3): Returns a pointer to the next directory entry internal func libc_readdir( _ dir: CInterop.DirPointer ) -> UnsafeMutablePointer? { - return readdir(dir) + readdir(dir) } /// readdir(3): Closes the directory stream and frees associated structures internal func libc_closedir( _ dir: CInterop.DirPointer ) -> CInt { - return closedir(dir) + closedir(dir) } /// remove(3): Remove directory entry @@ -418,7 +418,7 @@ internal func libc_getcwd( _ buffer: UnsafeMutablePointer, _ size: Int ) -> UnsafeMutablePointer? { - return getcwd(buffer, size) + getcwd(buffer, size) } /// confstr(3) @@ -428,7 +428,7 @@ internal func libc_confstr( _ buffer: UnsafeMutablePointer, _ size: Int ) -> Int { - return confstr(name, buffer, size) + confstr(name, buffer, size) } #endif @@ -438,14 +438,14 @@ internal func libc_fts_open( _ path: [UnsafeMutablePointer], _ options: CInt ) -> UnsafeMutablePointer { - return fts_open(path, options, nil)! + fts_open(path, options, nil)! } #else internal func libc_fts_open( _ path: [UnsafeMutablePointer?], _ options: CInt ) -> UnsafeMutablePointer { - return fts_open(path, options, nil) + fts_open(path, options, nil) } #endif @@ -453,13 +453,13 @@ internal func libc_fts_open( internal func libc_fts_read( _ fts: UnsafeMutablePointer ) -> UnsafeMutablePointer? { - return fts_read(fts) + fts_read(fts) } /// fts(3) internal func libc_fts_close( _ fts: UnsafeMutablePointer ) -> CInt { - return fts_close(fts) + fts_close(fts) } #endif diff --git a/Sources/NIOFileSystem/Internal/SystemFileHandle.swift b/Sources/NIOFileSystem/Internal/SystemFileHandle.swift index 34c8f62353..6f560c0de0 100644 --- a/Sources/NIOFileSystem/Internal/SystemFileHandle.swift +++ b/Sources/NIOFileSystem/Internal/SystemFileHandle.swift @@ -145,7 +145,7 @@ public final class SystemFileHandle { extension SystemFileHandle.SendableView { /// Returns the file descriptor if it's available; `nil` otherwise. internal func descriptorIfAvailable() -> FileDescriptor? { - return self.lifecycle.withLockedValue { + self.lifecycle.withLockedValue { switch $0 { case let .open(descriptor): return descriptor @@ -196,37 +196,37 @@ extension SystemFileHandle: FileHandleProtocol { // currently using. public func info() async throws -> FileInfo { - return try await self.threadPool.runIfActive { [sendableView] in + try await self.threadPool.runIfActive { [sendableView] in try sendableView._info().get() } } public func replacePermissions(_ permissions: FilePermissions) async throws { - return try await self.threadPool.runIfActive { [sendableView] in + try await self.threadPool.runIfActive { [sendableView] in try sendableView._replacePermissions(permissions) } } public func addPermissions(_ permissions: FilePermissions) async throws -> FilePermissions { - return try await self.threadPool.runIfActive { [sendableView] in + try await self.threadPool.runIfActive { [sendableView] in try sendableView._addPermissions(permissions) } } public func removePermissions(_ permissions: FilePermissions) async throws -> FilePermissions { - return try await self.threadPool.runIfActive { [sendableView] in + try await self.threadPool.runIfActive { [sendableView] in try sendableView._removePermissions(permissions) } } public func attributeNames() async throws -> [String] { - return try await self.threadPool.runIfActive { [sendableView] in + try await self.threadPool.runIfActive { [sendableView] in try sendableView._attributeNames() } } public func valueForAttribute(_ name: String) async throws -> [UInt8] { - return try await self.threadPool.runIfActive { [sendableView] in + try await self.threadPool.runIfActive { [sendableView] in try sendableView._valueForAttribute(name) } } @@ -235,19 +235,19 @@ extension SystemFileHandle: FileHandleProtocol { _ bytes: some (Sendable & RandomAccessCollection), attribute name: String ) async throws { - return try await self.threadPool.runIfActive { [sendableView] in + try await self.threadPool.runIfActive { [sendableView] in try sendableView._updateValueForAttribute(bytes, attribute: name) } } public func removeValueForAttribute(_ name: String) async throws { - return try await self.threadPool.runIfActive { [sendableView] in + try await self.threadPool.runIfActive { [sendableView] in try sendableView._removeValueForAttribute(name) } } public func synchronize() async throws { - return try await self.threadPool.runIfActive { [sendableView] in + try await self.threadPool.runIfActive { [sendableView] in try sendableView._synchronize() } } @@ -257,7 +257,7 @@ extension SystemFileHandle: FileHandleProtocol { ) async throws -> R { try await self.threadPool.runIfActive { [sendableView] in try sendableView._withUnsafeDescriptor { - return try execute($0) + try execute($0) } onUnavailable: { FileSystemError( code: .closed, @@ -270,7 +270,7 @@ extension SystemFileHandle: FileHandleProtocol { } public func detachUnsafeFileDescriptor() throws -> FileDescriptor { - return try self.sendableView.lifecycle.withLockedValue { lifecycle in + try self.sendableView.lifecycle.withLockedValue { lifecycle in switch lifecycle { case let .open(descriptor): lifecycle = .detached @@ -377,18 +377,18 @@ extension SystemFileHandle: FileHandleProtocol { extension SystemFileHandle.SendableView { /// Returns a string in the format: "{message}, the file '{path}' is closed." private func fileIsClosed(_ message: String) -> String { - return "\(message), the file '\(self.path)' is closed." + "\(message), the file '\(self.path)' is closed." } /// Returns a string in the format: "{message} for '{path}'." private func unknown(_ message: String) -> String { - return "\(message) for '\(self.path)'." + "\(message) for '\(self.path)'." } @_spi(Testing) public func _info() -> Result { self._withUnsafeDescriptorResult { descriptor in - return descriptor.status().map { stat in + descriptor.status().map { stat in FileInfo(platformSpecificStatus: stat) }.mapError { errno in .stat("fstat", errno: errno, path: self.path, location: .here()) @@ -507,7 +507,7 @@ extension SystemFileHandle.SendableView { operand: FilePermissions, descriptor: FileDescriptor ) throws { - return try descriptor.changeMode(permissions).mapError { errno in + try descriptor.changeMode(permissions).mapError { errno in FileSystemError.fchmod( operation: operation, operand: operand, @@ -521,8 +521,8 @@ extension SystemFileHandle.SendableView { @_spi(Testing) public func _attributeNames() throws -> [String] { - return try self._withUnsafeDescriptor { descriptor in - return try descriptor.listExtendedAttributes().mapError { errno in + try self._withUnsafeDescriptor { descriptor in + try descriptor.listExtendedAttributes().mapError { errno in FileSystemError.flistxattr(errno: errno, path: self.path, location: .here()) }.get() } onUnavailable: { @@ -537,8 +537,8 @@ extension SystemFileHandle.SendableView { @_spi(Testing) public func _valueForAttribute(_ name: String) throws -> [UInt8] { - return try self._withUnsafeDescriptor { descriptor in - return try descriptor.readExtendedAttribute( + try self._withUnsafeDescriptor { descriptor in + try descriptor.readExtendedAttribute( named: name ).flatMapError { errno -> Result<[UInt8], FileSystemError> in switch errno { @@ -577,7 +577,7 @@ extension SystemFileHandle.SendableView { _ bytes: some RandomAccessCollection, attribute name: String ) throws { - return try self._withUnsafeDescriptor { descriptor in + try self._withUnsafeDescriptor { descriptor in func withUnsafeBufferPointer(_ body: (UnsafeBufferPointer) throws -> Void) throws { try bytes.withContiguousStorageIfAvailable(body) ?? Array(bytes).withUnsafeBufferPointer(body) @@ -611,7 +611,7 @@ extension SystemFileHandle.SendableView { @_spi(Testing) public func _removeValueForAttribute(_ name: String) throws { - return try self._withUnsafeDescriptor { descriptor in + try self._withUnsafeDescriptor { descriptor in try descriptor.removeExtendedAttribute(name).mapError { errno in FileSystemError.fremovexattr( attribute: name, @@ -647,7 +647,7 @@ extension SystemFileHandle.SendableView { } internal func _duplicate() -> Result { - return self._withUnsafeDescriptorResult { descriptor in + self._withUnsafeDescriptorResult { descriptor in Result { try descriptor.duplicate() }.mapError { error in @@ -680,7 +680,11 @@ extension SystemFileHandle.SendableView { } // Materialize then close. - let materializeResult = self._materialize(materialize, descriptor: descriptor, failRenameat2WithEINVAL: failRenameat2WithEINVAL) + let materializeResult = self._materialize( + materialize, + descriptor: descriptor, + failRenameat2WithEINVAL: failRenameat2WithEINVAL + ) return Result { try descriptor.close() @@ -760,7 +764,7 @@ extension SystemFileHandle.SendableView { location: .here() ) }.flatMap { - return linkAtProcFS().mapError { errno in + linkAtProcFS().mapError { errno in FileSystemError.link( errno: errno, from: createdPath, @@ -1019,8 +1023,8 @@ extension SystemFileHandle: ReadableFileHandleProtocol { fromAbsoluteOffset offset: Int64, length: ByteCount ) async throws -> ByteBuffer { - return try await self.threadPool.runIfActive { [sendableView] in - return try sendableView._withUnsafeDescriptor { descriptor in + try await self.threadPool.runIfActive { [sendableView] in + try sendableView._withUnsafeDescriptor { descriptor in try descriptor.readChunk( fromAbsoluteOffset: offset, length: length.bytes @@ -1072,7 +1076,7 @@ extension SystemFileHandle: ReadableFileHandleProtocol { in range: Range, chunkLength size: ByteCount ) -> FileChunks { - return FileChunks(handle: self, chunkLength: size, range: range) + FileChunks(handle: self, chunkLength: size, range: range) } } @@ -1085,8 +1089,8 @@ extension SystemFileHandle: WritableFileHandleProtocol { contentsOf bytes: some (Sequence & Sendable), toAbsoluteOffset offset: Int64 ) async throws -> Int64 { - return try await self.threadPool.runIfActive { [sendableView] in - return try sendableView._withUnsafeDescriptor { descriptor in + try await self.threadPool.runIfActive { [sendableView] in + try sendableView._withUnsafeDescriptor { descriptor in try descriptor.write(contentsOf: bytes, toAbsoluteOffset: offset) .flatMapError { error in if let errno = error as? Errno, errno == .illegalSeek { @@ -1143,7 +1147,7 @@ extension SystemFileHandle: WritableFileHandleProtocol { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension SystemFileHandle.SendableView { func _resize(to size: ByteCount) -> Result<(), FileSystemError> { - return self._withUnsafeDescriptorResult { descriptor in + self._withUnsafeDescriptorResult { descriptor in Result { try descriptor.resize(to: size.bytes, retryOnInterrupt: true) }.mapError { error in @@ -1169,7 +1173,7 @@ extension SystemFileHandle: DirectoryFileHandleProtocol { public typealias ReadWriteFileHandle = SystemFileHandle public func listContents(recursive: Bool) -> DirectoryEntries { - return DirectoryEntries(handle: self, recursive: recursive) + DirectoryEntries(handle: self, recursive: recursive) } public func openFile( @@ -1370,7 +1374,7 @@ extension SystemFileHandle { permissions: FilePermissions?, threadPool: NIOThreadPool ) -> Result { - return Result { + Result { try FileDescriptor.open( path, mode, @@ -1512,7 +1516,7 @@ extension SystemFileHandle { return .success(handle) } catch { - #if canImport(Glibc) || canImport(Musl) + #if canImport(Glibc) || canImport(Musl) // 'O_TMPFILE' isn't supported for the current file system, try again but using // rename instead. if useTemporaryFileIfPossible, let errno = error as? Errno, errno == .notSupported { diff --git a/Sources/NIOFileSystem/OpenOptions.swift b/Sources/NIOFileSystem/OpenOptions.swift index ce7ceaecc6..454826897a 100644 --- a/Sources/NIOFileSystem/OpenOptions.swift +++ b/Sources/NIOFileSystem/OpenOptions.swift @@ -124,7 +124,7 @@ public enum OpenOptions { replaceExisting: Bool, permissions: FilePermissions? = nil ) -> Self { - return Write( + Write( existingFile: replaceExisting ? .truncate : .none, newFile: NewFile(permissions: permissions) ) @@ -143,7 +143,7 @@ public enum OpenOptions { createIfNecessary: Bool, permissions: FilePermissions? = nil ) -> Self { - return Write( + Write( existingFile: .open, newFile: createIfNecessary ? NewFile(permissions: permissions) : nil ) diff --git a/Sources/NIOFileSystemFoundationCompat/Date+FileInfo.swift b/Sources/NIOFileSystemFoundationCompat/Date+FileInfo.swift index 5e302d0930..f65db1f773 100644 --- a/Sources/NIOFileSystemFoundationCompat/Date+FileInfo.swift +++ b/Sources/NIOFileSystemFoundationCompat/Date+FileInfo.swift @@ -27,7 +27,7 @@ extension Date { extension FileInfo.Timespec { /// The UTC time of the timestamp. public var date: Date { - return Date(timespec: self) + Date(timespec: self) } } #endif diff --git a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift index 2c61d9ef92..7aedfb17b0 100644 --- a/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift +++ b/Sources/NIOFoundationCompat/ByteBuffer-foundation.swift @@ -12,9 +12,8 @@ // //===----------------------------------------------------------------------===// -import NIOCore import Foundation - +import NIOCore /// Errors that may be thrown by ByteBuffer methods that call into Foundation. public enum ByteBufferFoundationError: Error { @@ -22,20 +21,17 @@ public enum ByteBufferFoundationError: Error { case failedToEncodeString } - -/* - * This is NIO's `NIOFoundationCompat` module which at the moment only adds `ByteBuffer` utility methods - * for Foundation's `Data` type. - * - * The reason that it's not in the `NIO` module is that we don't want to have any direct Foundation dependencies - * in `NIO` as Foundation is problematic for a few reasons: - * - * - its implementation is different on Linux and on macOS which means our macOS tests might be inaccurate - * - on macOS Foundation is mostly written in ObjC which means the autorelease pool might get populated - * - `swift-corelibs-foundation` (the OSS Foundation used on Linux) links the world which will prevent anyone from - * having static binaries. It can also cause problems in the choice of an SSL library as Foundation already brings - * the platforms OpenSSL in which might cause problems. - */ +// This is NIO's `NIOFoundationCompat` module which at the moment only adds `ByteBuffer` utility methods +// for Foundation's `Data` type. +// +// The reason that it's not in the `NIO` module is that we don't want to have any direct Foundation dependencies +// in `NIO` as Foundation is problematic for a few reasons: +// +// - its implementation is different on Linux and on macOS which means our macOS tests might be inaccurate +// - on macOS Foundation is mostly written in ObjC which means the autorelease pool might get populated +// - `swift-corelibs-foundation` (the OSS Foundation used on Linux) links the world which will prevent anyone from +// having static binaries. It can also cause problems in the choice of an SSL library as Foundation already brings +// the platforms OpenSSL in which might cause problems. extension ByteBuffer { /// Controls how bytes are transferred between `ByteBuffer` and other storage types. @@ -63,10 +59,9 @@ extension ByteBuffer { /// - length: The number of bytes to be read from this `ByteBuffer`. /// - returns: A `Data` value containing `length` bytes or `nil` if there aren't at least `length` bytes readable. public mutating func readData(length: Int) -> Data? { - return self.readData(length: length, byteTransferStrategy: .automatic) + self.readData(length: length, byteTransferStrategy: .automatic) } - /// Read `length` bytes off this `ByteBuffer`, move the reader index forward by `length` bytes and return the result /// as `Data`. /// @@ -76,7 +71,9 @@ extension ByteBuffer { /// of the options. /// - returns: A `Data` value containing `length` bytes or `nil` if there aren't at least `length` bytes readable. public mutating func readData(length: Int, byteTransferStrategy: ByteTransferStrategy) -> Data? { - guard let result = self.getData(at: self.readerIndex, length: length, byteTransferStrategy: byteTransferStrategy) else { + guard + let result = self.getData(at: self.readerIndex, length: length, byteTransferStrategy: byteTransferStrategy) + else { return nil } self.moveReaderIndex(forwardBy: length) @@ -95,7 +92,7 @@ extension ByteBuffer { /// - length: The number of bytes of interest /// - returns: A `Data` value containing the bytes of interest or `nil` if the selected bytes are not readable. public func getData(at index: Int, length: Int) -> Data? { - return self.getData(at: index, length: length, byteTransferStrategy: .automatic) + self.getData(at: index, length: length, byteTransferStrategy: .automatic) } /// Return `length` bytes starting at `index` and return the result as `Data`. This will not change the reader index. @@ -119,18 +116,22 @@ extension ByteBuffer { case .noCopy: doCopy = false case .automatic: - doCopy = length <= 256*1024 + doCopy = length <= 256 * 1024 } return self.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in if doCopy { - return Data(bytes: UnsafeMutableRawPointer(mutating: ptr.baseAddress!.advanced(by: index)), - count: Int(length)) + return Data( + bytes: UnsafeMutableRawPointer(mutating: ptr.baseAddress!.advanced(by: index)), + count: Int(length) + ) } else { _ = storageRef.retain() - return Data(bytesNoCopy: UnsafeMutableRawPointer(mutating: ptr.baseAddress!.advanced(by: index)), - count: Int(length), - deallocator: .custom { _, _ in storageRef.release() }) + return Data( + bytesNoCopy: UnsafeMutableRawPointer(mutating: ptr.baseAddress!.advanced(by: index)), + count: Int(length), + deallocator: .custom { _, _ in storageRef.release() } + ) } } } @@ -227,7 +228,7 @@ extension ByteBuffer { @inlinable @discardableResult public mutating func setContiguousBytes(_ bytes: Bytes, at index: Int) -> Int { - return bytes.withUnsafeBytes { bufferPointer in + bytes.withUnsafeBytes { bufferPointer in self.setBytes(bufferPointer, at: index) } } @@ -277,7 +278,8 @@ extension ByteBuffer { /// are not readable or there were not enough bytes. public func getUUIDBytes(at index: Int) -> UUID? { guard let chunk1 = self.getInteger(at: index, as: UInt64.self), - let chunk2 = self.getInteger(at: index + 8, as: UInt64.self) else { + let chunk2 = self.getInteger(at: index + 8, as: UInt64.self) + else { return nil } @@ -389,7 +391,7 @@ extension ByteBufferView { public typealias Regions = CollectionOfOne public var regions: CollectionOfOne { - return .init(self) + .init(self) } } diff --git a/Sources/NIOFoundationCompat/Codable+ByteBuffer.swift b/Sources/NIOFoundationCompat/Codable+ByteBuffer.swift index b1d30bde39..0f61ad0ae1 100644 --- a/Sources/NIOFoundationCompat/Codable+ByteBuffer.swift +++ b/Sources/NIOFoundationCompat/Codable+ByteBuffer.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -import NIOCore import Foundation +import NIOCore extension ByteBuffer { /// Attempts to decode the `length` bytes from `index` using the `JSONDecoder` `decoder` as `T`. @@ -25,9 +25,12 @@ extension ByteBuffer { /// - length: The number of bytes to decode. /// - returns: The decoded value if successful or `nil` if there are not enough readable bytes available. @inlinable - public func getJSONDecodable(_ type: T.Type, - decoder: JSONDecoder = JSONDecoder(), - at index: Int, length: Int) throws -> T? { + public func getJSONDecodable( + _ type: T.Type, + decoder: JSONDecoder = JSONDecoder(), + at index: Int, + length: Int + ) throws -> T? { guard let data = self.getData(at: index, length: length, byteTransferStrategy: .noCopy) else { return nil } @@ -42,13 +45,19 @@ extension ByteBuffer { /// - length: The number of bytes to decode. /// - returns: The decoded value is successful or `nil` if there are not enough readable bytes available. @inlinable - public mutating func readJSONDecodable(_ type: T.Type, - decoder: JSONDecoder = JSONDecoder(), - length: Int) throws -> T? { - guard let decoded = try self.getJSONDecodable(T.self, - decoder: decoder, - at: self.readerIndex, - length: length) else { + public mutating func readJSONDecodable( + _ type: T.Type, + decoder: JSONDecoder = JSONDecoder(), + length: Int + ) throws -> T? { + guard + let decoded = try self.getJSONDecodable( + T.self, + decoder: decoder, + at: self.readerIndex, + length: length + ) + else { return nil } self.moveReaderIndex(forwardBy: length) @@ -66,9 +75,11 @@ extension ByteBuffer { /// - returns: The number of bytes written. @inlinable @discardableResult - public mutating func setJSONEncodable(_ value: T, - encoder: JSONEncoder = JSONEncoder(), - at index: Int) throws -> Int { + public mutating func setJSONEncodable( + _ value: T, + encoder: JSONEncoder = JSONEncoder(), + at index: Int + ) throws -> Int { let data = try encoder.encode(value) return self.setBytes(data, at: index) } @@ -83,8 +94,10 @@ extension ByteBuffer { /// - returns: The number of bytes written. @inlinable @discardableResult - public mutating func writeJSONEncodable(_ value: T, - encoder: JSONEncoder = JSONEncoder()) throws -> Int { + public mutating func writeJSONEncodable( + _ value: T, + encoder: JSONEncoder = JSONEncoder() + ) throws -> Int { let result = try self.setJSONEncodable(value, encoder: encoder, at: self.writerIndex) self.moveWriterIndex(forwardBy: result) return result @@ -106,10 +119,12 @@ extension JSONDecoder { /// - buffer: The `ByteBuffer` that contains JSON object to decode. /// - returns: The decoded object. public func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T { - return try buffer.getJSONDecodable(T.self, - decoder: self, - at: buffer.readerIndex, - length: buffer.readableBytes)! // must work, enough readable bytes + try buffer.getJSONDecodable( + T.self, + decoder: self, + at: buffer.readerIndex, + length: buffer.readableBytes + )! // must work, enough readable bytes// must work, enough readable bytes } } @@ -119,8 +134,10 @@ extension JSONEncoder { /// - parameters: /// - value: The value to encode as JSON. /// - buffer: The `ByteBuffer` to encode into. - public func encode(_ value: T, - into buffer: inout ByteBuffer) throws { + public func encode( + _ value: T, + into buffer: inout ByteBuffer + ) throws { try buffer.writeJSONEncodable(value, encoder: self) } diff --git a/Sources/NIOFoundationCompat/JSONSerialization+ByteBuffer.swift b/Sources/NIOFoundationCompat/JSONSerialization+ByteBuffer.swift index 5dfc5fb2d9..d8f00438ca 100644 --- a/Sources/NIOFoundationCompat/JSONSerialization+ByteBuffer.swift +++ b/Sources/NIOFoundationCompat/JSONSerialization+ByteBuffer.swift @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// -import NIOCore import Foundation +import NIOCore extension JSONSerialization { - + /// Attempts to derive a Foundation object from a ByteBuffer and return it as `T`. /// /// - parameters: @@ -24,8 +24,10 @@ extension JSONSerialization { /// - options: The reading option used when the parser derives the Foundation type from the ByteBuffer. /// - returns: The Foundation value if successful or `nil` if there was an issue creating the Foundation type. @inlinable - public static func jsonObject(with buffer: ByteBuffer, - options opt: JSONSerialization.ReadingOptions = []) throws -> Any { - return try JSONSerialization.jsonObject(with: Data(buffer: buffer), options: opt) + public static func jsonObject( + with buffer: ByteBuffer, + options opt: JSONSerialization.ReadingOptions = [] + ) throws -> Any { + try JSONSerialization.jsonObject(with: Data(buffer: buffer), options: opt) } } diff --git a/Sources/NIOHTTP1/ByteCollectionUtils.swift b/Sources/NIOHTTP1/ByteCollectionUtils.swift index 75ee72d74f..95b6f8dff5 100644 --- a/Sources/NIOHTTP1/ByteCollectionUtils.swift +++ b/Sources/NIOHTTP1/ByteCollectionUtils.swift @@ -14,20 +14,21 @@ import NIOCore -fileprivate let defaultWhitespaces = [" ", "\t"].map({$0.utf8.first!}) +private let defaultWhitespaces = [" ", "\t"].map({ $0.utf8.first! }) extension ByteBufferView { internal func trim(limitingElements: [UInt8]) -> ByteBufferView { guard let lastNonWhitespaceIndex = self.lastIndex(where: { !limitingElements.contains($0) }), - let firstNonWhitespaceIndex = self.firstIndex(where: { !limitingElements.contains($0) }) else { - // This buffer is entirely trimmed elements, so trim it to nothing. - return self[self.startIndex.. ByteBufferView { - return trim(limitingElements: defaultWhitespaces) + trim(limitingElements: defaultWhitespaces) } } @@ -40,34 +41,34 @@ extension Sequence where Self.Element == UInt8 { /// - Parameter bytes: The string constant in the form of a collection of `UInt8` /// - Returns: Whether the collection contains **EXACTLY** this array or no, but by ignoring case. internal func compareCaseInsensitiveASCIIBytes(to: T) -> Bool - where T.Element == UInt8 { - // fast path: we can get the underlying bytes of both - let maybeMaybeResult = self.withContiguousStorageIfAvailable { lhsBuffer -> Bool? in - to.withContiguousStorageIfAvailable { rhsBuffer in - if lhsBuffer.count != rhsBuffer.count { - return false - } + where T.Element == UInt8 { + // fast path: we can get the underlying bytes of both + let maybeMaybeResult = self.withContiguousStorageIfAvailable { lhsBuffer -> Bool? in + to.withContiguousStorageIfAvailable { rhsBuffer in + if lhsBuffer.count != rhsBuffer.count { + return false + } - for idx in 0 ..< lhsBuffer.count { - // let's hope this gets vectorised ;) - if lhsBuffer[idx] & 0xdf != rhsBuffer[idx] & 0xdf { - return false - } + for idx in 0.. Bool { - return self.utf8.compareCaseInsensitiveASCIIBytes(to: to.utf8) + self.utf8.compareCaseInsensitiveASCIIBytes(to: to.utf8) } } diff --git a/Sources/NIOHTTP1/HTTPDecoder.swift b/Sources/NIOHTTP1/HTTPDecoder.swift index b37c99a965..eb33c620ba 100644 --- a/Sources/NIOHTTP1/HTTPDecoder.swift +++ b/Sources/NIOHTTP1/HTTPDecoder.swift @@ -12,13 +12,13 @@ // //===----------------------------------------------------------------------===// -import NIOCore @_implementationOnly import CNIOLLHTTP +import NIOCore -private extension UnsafeMutablePointer where Pointee == llhttp_t { +extension UnsafeMutablePointer where Pointee == llhttp_t { /// Returns the `KeepAliveState` for the current message that is parsed. - var keepAliveState: KeepAliveState { - return c_nio_llhttp_should_keep_alive(self) == 0 ? .close : .keepAlive + fileprivate var keepAliveState: KeepAliveState { + c_nio_llhttp_should_keep_alive(self) == 0 ? .close : .keepAlive } } @@ -39,7 +39,7 @@ private class BetterHTTPParser { private static let maximumHeaderFieldSize = 80 * 1024 var delegate: HTTPDecoderDelegate! = nil - private var parser: llhttp_t? = llhttp_t() // nil if unaccessible because reference passed away exclusively + private var parser: llhttp_t? = llhttp_t() // nil if unaccessible because reference passed away exclusively private var settings: UnsafeMutablePointer private var decodingState: HTTPDecodingState = .beforeMessageBegin private var firstNonDiscardableOffset: Int? = nil @@ -58,7 +58,7 @@ private class BetterHTTPParser { } private static func fromOpaque(_ opaque: UnsafePointer?) -> BetterHTTPParser { - return Unmanaged.fromOpaque(UnsafeRawPointer(opaque!.pointee.data)).takeUnretainedValue() + Unmanaged.fromOpaque(UnsafeRawPointer(opaque!.pointee.data)).takeUnretainedValue() } init(kind: HTTPDecoderKind) { @@ -70,17 +70,21 @@ private class BetterHTTPParser { return 0 } self.settings.pointee.on_header_field = { opaque, bytes, len in - return BetterHTTPParser.fromOpaque(opaque).didReceiveHeaderFieldData(UnsafeRawBufferPointer(start: bytes, count: len)) + BetterHTTPParser.fromOpaque(opaque).didReceiveHeaderFieldData( + UnsafeRawBufferPointer(start: bytes, count: len) + ) } self.settings.pointee.on_header_value = { opaque, bytes, len in - return BetterHTTPParser.fromOpaque(opaque).didReceiveHeaderValueData(UnsafeRawBufferPointer(start: bytes, count: len)) + BetterHTTPParser.fromOpaque(opaque).didReceiveHeaderValueData( + UnsafeRawBufferPointer(start: bytes, count: len) + ) } self.settings.pointee.on_status = { opaque, bytes, len in BetterHTTPParser.fromOpaque(opaque).didReceiveStatusData(UnsafeRawBufferPointer(start: bytes, count: len)) return 0 } self.settings.pointee.on_url = { opaque, bytes, len in - return BetterHTTPParser.fromOpaque(opaque).didReceiveURLData(UnsafeRawBufferPointer(start: bytes, count: len)) + BetterHTTPParser.fromOpaque(opaque).didReceiveURLData(UnsafeRawBufferPointer(start: bytes, count: len)) } self.settings.pointee.on_chunk_complete = { opaque in BetterHTTPParser.fromOpaque(opaque).didReceiveChunkCompleteNotification() @@ -96,19 +100,21 @@ private class BetterHTTPParser { } self.settings.pointee.on_headers_complete = { opaque in let parser = BetterHTTPParser.fromOpaque(opaque) - switch parser.didReceiveHeadersCompleteNotification(versionMajor: Int(opaque!.pointee.http_major), - versionMinor: Int(opaque!.pointee.http_minor), - statusCode: Int(opaque!.pointee.status_code), - isUpgrade: opaque!.pointee.upgrade != 0, - method: llhttp_method(rawValue: CUnsignedInt(opaque!.pointee.method)), - keepAliveState: opaque!.keepAliveState) { + switch parser.didReceiveHeadersCompleteNotification( + versionMajor: Int(opaque!.pointee.http_major), + versionMinor: Int(opaque!.pointee.http_minor), + statusCode: Int(opaque!.pointee.status_code), + isUpgrade: opaque!.pointee.upgrade != 0, + method: llhttp_method(rawValue: CUnsignedInt(opaque!.pointee.method)), + keepAliveState: opaque!.keepAliveState + ) { case .normal: return 0 case .skipBody: return 1 case .error(let err): parser.httpErrno = err - return -1 // error + return -1 // error } } self.settings.pointee.on_message_complete = { opaque in @@ -145,7 +151,7 @@ private class BetterHTTPParser { let end = start + currentFieldByteLength self.firstNonDiscardableOffset = nil precondition(start >= self.rawBytesView.startIndex && end <= self.rawBytesView.endIndex) - try callout(&self.delegate, .init(rebasing: self.rawBytesView[start ..< end])) + try callout(&self.delegate, .init(rebasing: self.rawBytesView[start.. MessageContinuation { + private func didReceiveHeadersCompleteNotification( + versionMajor: Int, + versionMinor: Int, + statusCode: Int, + isUpgrade: Bool, + method: llhttp_method, + keepAliveState: KeepAliveState + ) -> MessageContinuation { switch self.decodingState { case .headerValue: self.finish { delegate, bytes in @@ -307,7 +315,7 @@ private class BetterHTTPParser { self.richerError = NIOHTTPDecoderError.unsolicitedResponse return .error(HPE_INTERNAL) } - + if 100 <= statusCode && statusCode < 200 && statusCode != 101 { // if the response status is in the range of 100..<200 but not 101 we don't want to // pop the request method. The actual request head is expected with the next HTTP @@ -317,19 +325,22 @@ private class BetterHTTPParser { let method = self.requestHeads.removeFirst().method if method == .HEAD || method == .CONNECT { skipBody = true - } else if statusCode / 100 == 1 || // 1XX codes - statusCode == 204 || statusCode == 304 { + } else if statusCode / 100 == 1 // 1XX codes + || statusCode == 204 || statusCode == 304 + { skipBody = true } } } - let success = self.delegate.didFinishHead(versionMajor: versionMajor, - versionMinor: versionMinor, - isUpgrade: isUpgrade, - method: method, - statusCode: statusCode, - keepAliveState: keepAliveState) + let success = self.delegate.didFinishHead( + versionMajor: versionMajor, + versionMinor: versionMinor, + isUpgrade: isUpgrade, + method: method, + statusCode: statusCode, + keepAliveState: keepAliveState + ) guard success else { return .error(HPE_INVALID_VERSION) } @@ -361,7 +372,7 @@ private class BetterHTTPParser { return 0 } - @inline(__always) // this need to be optimised away + @inline(__always) // this need to be optimised away func withExclusiveHTTPParser(_ body: (UnsafeMutablePointer) -> T) -> T { var parser: llhttp_t? = nil assert(self.parser != nil, "parser must not be nil here, must be a re-entrancy issue") @@ -387,9 +398,11 @@ private class BetterHTTPParser { let startPointer = bytes.baseAddress! + self.httpParserOffset let bytesToRead = bytes.count - self.httpParserOffset - rc = c_nio_llhttp_execute_swift(parserPtr, - startPointer, - bytesToRead) + rc = c_nio_llhttp_execute_swift( + parserPtr, + startPointer, + bytesToRead + ) if rc == HPE_PAUSED_UPGRADE { // This is a special pause. We don't need to stop here (our other code will prevent us @@ -417,7 +430,7 @@ private class BetterHTTPParser { // If we have a richer error than the errno code, and the errno is internal, we'll use it. Otherwise, we use the // error from http_parser. let err = self.httpErrno ?? parserErrno - if (err == HPE_INTERNAL || err == HPE_USER), let richerError = self.richerError { + if err == HPE_INTERNAL || err == HPE_USER, let richerError = self.richerError { throw richerError } else { throw HTTPParserError.httpError(fromCHTTPParserErrno: err)! @@ -449,12 +462,14 @@ private protocol HTTPDecoderDelegate { mutating func didReceiveTrailerName(_ bytes: UnsafeRawBufferPointer) throws mutating func didReceiveTrailerValue(_ bytes: UnsafeRawBufferPointer) mutating func didReceiveURL(_ bytes: UnsafeRawBufferPointer) - mutating func didFinishHead(versionMajor: Int, - versionMinor: Int, - isUpgrade: Bool, - method: llhttp_method, - statusCode: Int, - keepAliveState: KeepAliveState) -> Bool + mutating func didFinishHead( + versionMajor: Int, + versionMinor: Int, + isUpgrade: Bool, + method: llhttp_method, + statusCode: Int, + keepAliveState: KeepAliveState + ) -> Bool mutating func didFinishMessage() } @@ -484,7 +499,8 @@ public enum HTTPDecoderKind: Sendable { case response } -extension HTTPDecoder: WriteObservingByteToMessageDecoder where In == HTTPClientResponsePart, Out == HTTPClientRequestPart { +extension HTTPDecoder: WriteObservingByteToMessageDecoder +where In == HTTPClientResponsePart, Out == HTTPClientRequestPart { public typealias OutboundIn = Out public func write(data: HTTPClientRequestPart) { @@ -520,7 +536,7 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega private let leftOverBytesStrategy: RemoveAfterUpgradeStrategy private let informationalResponseStrategy: NIOInformationalResponseStrategy private let kind: HTTPDecoderKind - private var stopParsing = false // set on upgrade or HTTP version error + private var stopParsing = false // set on upgrade or HTTP version error private var lastResponseHeaderWasInformational = false /// Creates a new instance of `HTTPDecoder`. @@ -531,7 +547,7 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega public convenience init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes) { self.init(leftOverBytesStrategy: leftOverBytesStrategy, informationalResponseStrategy: .drop) } - + /// Creates a new instance of `HTTPDecoder`. /// /// - parameters: @@ -539,7 +555,10 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega /// detected. Note that this does not affect what happens on EOF. /// - informationalResponseStrategy: Should informational responses (like http status 100) be forwarded or dropped. /// Default is `.drop`. This property is only respected when decoding responses. - public init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, informationalResponseStrategy: NIOInformationalResponseStrategy = .drop) { + public init( + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + informationalResponseStrategy: NIOInformationalResponseStrategy = .drop + ) { self.headers.reserveCapacity(16) if In.self == HTTPServerRequestPart.self { self.kind = .request @@ -563,9 +582,13 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega self.buffer!.moveReaderIndex(forwardBy: offset) switch self.kind { case .request: - self.context!.fireChannelRead(NIOAny(HTTPServerRequestPart.body(self.buffer!.readSlice(length: bytes.count)!))) + self.context!.fireChannelRead( + NIOAny(HTTPServerRequestPart.body(self.buffer!.readSlice(length: bytes.count)!)) + ) case .response: - self.context!.fireChannelRead(NIOAny(HTTPClientResponsePart.body(self.buffer!.readSlice(length: bytes.count)!))) + self.context!.fireChannelRead( + NIOAny(HTTPClientResponsePart.body(self.buffer!.readSlice(length: bytes.count)!)) + ) } } @@ -608,12 +631,14 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega self.url = String(decoding: bytes, as: Unicode.UTF8.self) } - func didFinishHead(versionMajor: Int, - versionMinor: Int, - isUpgrade: Bool, - method: llhttp_method, - statusCode: Int, - keepAliveState: KeepAliveState) -> Bool { + func didFinishHead( + versionMajor: Int, + versionMinor: Int, + isUpgrade: Bool, + method: llhttp_method, + statusCode: Int, + keepAliveState: KeepAliveState + ) -> Bool { let message: NIOAny? guard versionMajor == 1 else { @@ -624,13 +649,17 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega switch self.kind { case .request: - let reqHead = HTTPRequestHead(version: .init(major: versionMajor, minor: versionMinor), - method: HTTPMethod.from(httpParserMethod: method), - uri: self.url!, - headers: HTTPHeaders(self.headers, - keepAliveState: keepAliveState)) + let reqHead = HTTPRequestHead( + version: .init(major: versionMajor, minor: versionMinor), + method: HTTPMethod.from(httpParserMethod: method), + uri: self.url!, + headers: HTTPHeaders( + self.headers, + keepAliveState: keepAliveState + ) + ) message = NIOAny(HTTPServerRequestPart.head(reqHead)) - + case .response where (100..<200).contains(statusCode) && statusCode != 101: self.lastResponseHeaderWasInformational = true switch self.informationalResponseStrategy.base { @@ -646,7 +675,7 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega case .drop: message = nil } - + case .response: self.lastResponseHeaderWasInformational = false let resHeadPart = HTTPClientResponsePart.head( @@ -721,7 +750,11 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega return .needMoreData } - public func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState { + public func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState { if !self.stopParsing { while buffer.readableBytes > 0, case .continue = try self.decode(context: context, buffer: &buffer) {} if seenEOF { @@ -763,12 +796,12 @@ public struct NIOInformationalResponseStrategy: Hashable, Sendable { case drop case forward } - + var base: Base private init(_ base: Base) { self.base = base } - + /// Drop the informational response and only forward the "real" response public static let drop = Self(.drop) /// Forward the informational response and then forward the "real" response. This will result in @@ -803,7 +836,7 @@ extension HTTPParserError { case HPE_INVALID_VERSION: return .invalidVersion case HPE_INVALID_HEADER_TOKEN, - HPE_UNEXPECTED_SPACE: + HPE_UNEXPECTED_SPACE: return .invalidHeaderToken case HPE_INVALID_CONTENT_LENGTH: return .invalidContentLength @@ -816,17 +849,17 @@ extension HTTPParserError { case HPE_PAUSED, HPE_PAUSED_UPGRADE, HPE_PAUSED_H2_UPGRADE: return .paused case HPE_INVALID_TRANSFER_ENCODING, - HPE_CR_EXPECTED, - HPE_CB_MESSAGE_BEGIN, - HPE_CB_HEADERS_COMPLETE, - HPE_CB_MESSAGE_COMPLETE, - HPE_CB_CHUNK_HEADER, - HPE_CB_CHUNK_COMPLETE, - HPE_USER, - HPE_CB_URL_COMPLETE, - HPE_CB_STATUS_COMPLETE, - HPE_CB_HEADER_FIELD_COMPLETE, - HPE_CB_HEADER_VALUE_COMPLETE: + HPE_CR_EXPECTED, + HPE_CB_MESSAGE_BEGIN, + HPE_CB_HEADERS_COMPLETE, + HPE_CB_MESSAGE_COMPLETE, + HPE_CB_CHUNK_HEADER, + HPE_CB_CHUNK_COMPLETE, + HPE_USER, + HPE_CB_URL_COMPLETE, + HPE_CB_STATUS_COMPLETE, + HPE_CB_HEADER_FIELD_COMPLETE, + HPE_CB_HEADER_VALUE_COMPLETE: // The downside of enums here, we don't have a case for these. Map them to .unknown for now. return .unknown default: @@ -942,7 +975,6 @@ extension HTTPMethod { } } - /// Errors thrown by `HTTPRequestDecoder` and `HTTPResponseDecoder` in addition to /// `HTTPParserError`. public struct NIOHTTPDecoderError: Error { @@ -953,19 +985,16 @@ public struct NIOHTTPDecoderError: Error { private let baseError: BaseError } - extension NIOHTTPDecoderError { /// A response was received from a server without an associated request having been sent. public static let unsolicitedResponse: NIOHTTPDecoderError = .init(baseError: .unsolicitedResponse) } - -extension NIOHTTPDecoderError: Hashable { } - +extension NIOHTTPDecoderError: Hashable {} extension NIOHTTPDecoderError: CustomDebugStringConvertible { public var debugDescription: String { - return String(describing: self.baseError) + String(describing: self.baseError) } } @@ -977,10 +1006,12 @@ extension HTTPClientResponsePart { keepAliveState: KeepAliveState, headers: [(String, String)] ) -> HTTPClientResponsePart { - HTTPClientResponsePart.head(HTTPResponseHead( - version: .init(major: versionMajor, minor: versionMinor), - status: .init(statusCode: statusCode), - headers: HTTPHeaders(headers, keepAliveState: keepAliveState) - )) + HTTPClientResponsePart.head( + HTTPResponseHead( + version: .init(major: versionMajor, minor: versionMinor), + status: .init(statusCode: statusCode), + headers: HTTPHeaders(headers, keepAliveState: keepAliveState) + ) + ) } } diff --git a/Sources/NIOHTTP1/HTTPEncoder.swift b/Sources/NIOHTTP1/HTTPEncoder.swift index 90146aa974..19bc6109ac 100644 --- a/Sources/NIOHTTP1/HTTPEncoder.swift +++ b/Sources/NIOHTTP1/HTTPEncoder.swift @@ -14,7 +14,13 @@ import NIOCore -private func writeChunk(wrapOutboundOut: (IOData) -> NIOAny, context: ChannelHandlerContext, isChunked: Bool, chunk: IOData, promise: EventLoopPromise?) { +private func writeChunk( + wrapOutboundOut: (IOData) -> NIOAny, + context: ChannelHandlerContext, + isChunked: Bool, + chunk: IOData, + promise: EventLoopPromise? +) { let readableBytes = chunk.readableBytes // we don't want to copy the chunk unnecessarily and therefore call write an annoyingly large number of times @@ -24,15 +30,19 @@ private func writeChunk(wrapOutboundOut: (IOData) -> NIOAny, context: ChannelHan let (mW1, mW2, mW3): (EventLoopPromise?, EventLoopPromise?, EventLoopPromise?) if let p = promise { - /* chunked encoding and the user's interested: we need three promises and need to cascade into the users promise */ - let (w1, w2, w3) = (context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise, context.eventLoop.makePromise() as EventLoopPromise) + // chunked encoding and the user's interested: we need three promises and need to cascade into the users promise + let (w1, w2, w3) = ( + context.eventLoop.makePromise() as EventLoopPromise, + context.eventLoop.makePromise() as EventLoopPromise, + context.eventLoop.makePromise() as EventLoopPromise + ) w1.futureResult.and(w2.futureResult).and(w3.futureResult).map { (_: ((((), ()), ()))) in }.cascade(to: p) (mW1, mW2, mW3) = (w1, w2, w3) } else { - /* user isn't interested, let's not bother even allocating promises */ + // user isn't interested, let's not bother even allocating promises (mW1, mW2, mW3) = (nil, nil, nil) } - + var buffer = context.channel.allocator.buffer(capacity: 32) let len = String(readableBytes, radix: 16) buffer.writeString(len) @@ -49,14 +59,20 @@ private func writeChunk(wrapOutboundOut: (IOData) -> NIOAny, context: ChannelHan } } -private func writeTrailers(wrapOutboundOut: (IOData) -> NIOAny, context: ChannelHandlerContext, isChunked: Bool, trailers: HTTPHeaders?, promise: EventLoopPromise?) { +private func writeTrailers( + wrapOutboundOut: (IOData) -> NIOAny, + context: ChannelHandlerContext, + isChunked: Bool, + trailers: HTTPHeaders?, + promise: EventLoopPromise? +) { switch (isChunked, promise) { case (true, let p): var buffer: ByteBuffer if let trailers = trailers { buffer = context.channel.allocator.buffer(capacity: 256) buffer.writeStaticString("0\r\n") - buffer.write(headers: trailers) // Includes trailing CRLF. + buffer.write(headers: trailers) // Includes trailing CRLF. } else { buffer = context.channel.allocator.buffer(capacity: 8) buffer.writeStaticString("0\r\n\r\n") @@ -75,7 +91,13 @@ private func writeTrailers(wrapOutboundOut: (IOData) -> NIOAny, context: Channel // starting about swift-5.0-DEVELOPMENT-SNAPSHOT-2019-01-20-a, this doesn't get automatically inlined, which costs // 2 extra allocations so we need to help the optimiser out. @inline(__always) -private func writeHead(wrapOutboundOut: (IOData) -> NIOAny, writeStartLine: (inout ByteBuffer) -> Void, context: ChannelHandlerContext, headers: HTTPHeaders, promise: EventLoopPromise?) { +private func writeHead( + wrapOutboundOut: (IOData) -> NIOAny, + writeStartLine: (inout ByteBuffer) -> Void, + context: ChannelHandlerContext, + headers: HTTPHeaders, + promise: EventLoopPromise? +) { var buffer = context.channel.allocator.buffer(capacity: 256) writeStartLine(&buffer) @@ -99,7 +121,11 @@ private enum BodyFraming { /// /// Note that for HTTP/1.0 if there is no Content-Length then the response should be followed /// by connection close. We require that the user send that connection close: we don't do it. -private func correctlyFrameTransportHeaders(hasBody: HTTPMethod.HasBody, headers: inout HTTPHeaders, version: HTTPVersion) -> BodyFraming { +private func correctlyFrameTransportHeaders( + hasBody: HTTPMethod.HasBody, + headers: inout HTTPHeaders, + version: HTTPVersion +) -> BodyFraming { switch hasBody { case .no: headers.remove(name: "content-length") @@ -175,27 +201,50 @@ public final class HTTPRequestEncoder: ChannelOutboundHandler, RemovableChannelH self.configuration = configuration } - public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { switch Self.unwrapOutboundIn(data) { case .head(var request): - assert(!(request.headers.contains(name: "content-length") && - request.headers[canonicalForm: "transfer-encoding"].contains("chunked"[...])), - "illegal HTTP sent: \(request) contains both a content-length and transfer-encoding:chunked") + assert( + !(request.headers.contains(name: "content-length") + && request.headers[canonicalForm: "transfer-encoding"].contains("chunked"[...])), + "illegal HTTP sent: \(request) contains both a content-length and transfer-encoding:chunked" + ) if self.configuration.automaticallySetFramingHeaders { - self.isChunked = correctlyFrameTransportHeaders(hasBody: request.method.hasRequestBody, - headers: &request.headers, version: request.version) == .chunked + self.isChunked = + correctlyFrameTransportHeaders( + hasBody: request.method.hasRequestBody, + headers: &request.headers, + version: request.version + ) == .chunked } else { self.isChunked = messageIsChunked(headers: request.headers, version: request.version) } - writeHead(wrapOutboundOut: Self.wrapOutboundOut, writeStartLine: { buffer in - buffer.write(request: request) - }, context: context, headers: request.headers, promise: promise) + writeHead( + wrapOutboundOut: Self.wrapOutboundOut, + writeStartLine: { buffer in + buffer.write(request: request) + }, + context: context, + headers: request.headers, + promise: promise + ) case .body(let bodyPart): - writeChunk(wrapOutboundOut: Self.wrapOutboundOut, context: context, isChunked: self.isChunked, chunk: bodyPart, promise: promise) + writeChunk( + wrapOutboundOut: Self.wrapOutboundOut, + context: context, + isChunked: self.isChunked, + chunk: bodyPart, + promise: promise + ) case .end(let trailers): - writeTrailers(wrapOutboundOut: Self.wrapOutboundOut, context: context, isChunked: self.isChunked, trailers: trailers, promise: promise) + writeTrailers( + wrapOutboundOut: Self.wrapOutboundOut, + context: context, + isChunked: self.isChunked, + trailers: trailers, + promise: promise + ) } } } @@ -245,24 +294,48 @@ public final class HTTPResponseEncoder: ChannelOutboundHandler, RemovableChannel public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { switch Self.unwrapOutboundIn(data) { case .head(var response): - assert(!(response.headers.contains(name: "content-length") && - response.headers[canonicalForm: "transfer-encoding"].contains("chunked"[...])), - "illegal HTTP sent: \(response) contains both a content-length and transfer-encoding:chunked") + assert( + !(response.headers.contains(name: "content-length") + && response.headers[canonicalForm: "transfer-encoding"].contains("chunked"[...])), + "illegal HTTP sent: \(response) contains both a content-length and transfer-encoding:chunked" + ) if self.configuration.automaticallySetFramingHeaders { - self.isChunked = correctlyFrameTransportHeaders(hasBody: response.status.mayHaveResponseBody ? .yes : .no, - headers: &response.headers, version: response.version) == .chunked + self.isChunked = + correctlyFrameTransportHeaders( + hasBody: response.status.mayHaveResponseBody ? .yes : .no, + headers: &response.headers, + version: response.version + ) == .chunked } else { self.isChunked = messageIsChunked(headers: response.headers, version: response.version) } - writeHead(wrapOutboundOut: Self.wrapOutboundOut, writeStartLine: { buffer in - buffer.write(response: response) - }, context: context, headers: response.headers, promise: promise) + writeHead( + wrapOutboundOut: Self.wrapOutboundOut, + writeStartLine: { buffer in + buffer.write(response: response) + }, + context: context, + headers: response.headers, + promise: promise + ) case .body(let bodyPart): - writeChunk(wrapOutboundOut: Self.wrapOutboundOut, context: context, isChunked: self.isChunked, chunk: bodyPart, promise: promise) + writeChunk( + wrapOutboundOut: Self.wrapOutboundOut, + context: context, + isChunked: self.isChunked, + chunk: bodyPart, + promise: promise + ) case .end(let trailers): - writeTrailers(wrapOutboundOut: Self.wrapOutboundOut, context: context, isChunked: self.isChunked, trailers: trailers, promise: promise) + writeTrailers( + wrapOutboundOut: Self.wrapOutboundOut, + context: context, + isChunked: self.isChunked, + trailers: trailers, + promise: promise + ) } } } @@ -270,14 +343,14 @@ public final class HTTPResponseEncoder: ChannelOutboundHandler, RemovableChannel @available(*, unavailable) extension HTTPResponseEncoder: Sendable {} -private extension ByteBuffer { +extension ByteBuffer { private mutating func write(status: HTTPResponseStatus) { self.writeString(String(status.code)) self.writeWhitespace() self.writeString(status.reasonPhrase) } - mutating func write(response: HTTPResponseHead) { + fileprivate mutating func write(response: HTTPResponseHead) { switch (response.version.major, response.version.minor, response.status) { // Optimization for HTTP/1.0 case (1, 0, .custom(_, _)): @@ -550,7 +623,7 @@ private extension ByteBuffer { } } - mutating func write(request: HTTPRequestHead) { + fileprivate mutating func write(request: HTTPRequestHead) { self.write(method: request.method) self.writeWhitespace() self.writeString(request.uri) @@ -559,7 +632,7 @@ private extension ByteBuffer { self.writeStaticString("\r\n") } - mutating func writeWhitespace() { + fileprivate mutating func writeWhitespace() { self.writeInteger(32, as: UInt8.self) } diff --git a/Sources/NIOHTTP1/HTTPHeaderValidator.swift b/Sources/NIOHTTP1/HTTPHeaderValidator.swift index 74bc677599..ad8a0b584d 100644 --- a/Sources/NIOHTTP1/HTTPHeaderValidator.swift +++ b/Sources/NIOHTTP1/HTTPHeaderValidator.swift @@ -26,7 +26,7 @@ public final class NIOHTTPRequestHeadersValidator: ChannelOutboundHandler, Remov public typealias OutboundIn = HTTPClientRequestPart public typealias OutboundOut = HTTPClientRequestPart - public init() { } + public init() {} public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { switch Self.unwrapOutboundIn(data) { @@ -50,7 +50,6 @@ public final class NIOHTTPRequestHeadersValidator: ChannelOutboundHandler, Remov } } - /// A ChannelHandler to validate that outbound response headers are spec-compliant. /// /// The HTTP RFCs constrain the bytes that are validly present within a HTTP/1.1 header block. @@ -64,7 +63,7 @@ public final class NIOHTTPResponseHeadersValidator: ChannelOutboundHandler, Remo public typealias OutboundIn = HTTPServerResponsePart public typealias OutboundOut = HTTPServerResponsePart - public init() { } + public init() {} public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { switch Self.unwrapOutboundIn(data) { diff --git a/Sources/NIOHTTP1/HTTPHeaders+Validation.swift b/Sources/NIOHTTP1/HTTPHeaders+Validation.swift index d154875c32..b5a29c695f 100644 --- a/Sources/NIOHTTP1/HTTPHeaders+Validation.swift +++ b/Sources/NIOHTTP1/HTTPHeaders+Validation.swift @@ -67,7 +67,7 @@ extension String { /// ``` /// /// We implement this check directly. - fileprivate var isValidHeaderFieldName: Bool { + fileprivate var isValidHeaderFieldName: Bool { let fastResult = self.utf8.withContiguousStorageIfAvailable { ptr in ptr.allSatisfy { $0.isValidHeaderFieldNameByte } } @@ -124,7 +124,7 @@ extension UInt8 { @inline(__always) fileprivate var isValidHeaderFieldNameByte: Bool { switch self { - case UInt8(ascii: "0")...UInt8(ascii: "9"), // DIGIT + case UInt8(ascii: "0")...UInt8(ascii: "9"), // DIGIT UInt8(ascii: "a")...UInt8(ascii: "z"), UInt8(ascii: "A")...UInt8(ascii: "Z"), // ALPHA UInt8(ascii: "!"), UInt8(ascii: "#"), diff --git a/Sources/NIOHTTP1/HTTPPipelineSetup.swift b/Sources/NIOHTTP1/HTTPPipelineSetup.swift index c433665afc..419dcc275d 100644 --- a/Sources/NIOHTTP1/HTTPPipelineSetup.swift +++ b/Sources/NIOHTTP1/HTTPPipelineSetup.swift @@ -18,7 +18,9 @@ import NIOCore /// /// See the documentation for `HTTPClientUpgradeHandler` for details on these /// properties. -public typealias NIOHTTPClientUpgradeConfiguration = (upgraders: [NIOHTTPClientProtocolUpgrader], completionHandler: @Sendable (ChannelHandlerContext) -> Void) +public typealias NIOHTTPClientUpgradeConfiguration = ( + upgraders: [NIOHTTPClientProtocolUpgrader], completionHandler: @Sendable (ChannelHandlerContext) -> Void +) /// Configuration required to configure a HTTP server pipeline for upgrade. /// @@ -27,7 +29,9 @@ public typealias NIOHTTPClientUpgradeConfiguration = (upgraders: [NIOHTTPClientP @available(*, deprecated, renamed: "NIOHTTPServerUpgradeConfiguration") public typealias HTTPUpgradeConfiguration = NIOHTTPServerUpgradeConfiguration -public typealias NIOHTTPServerUpgradeConfiguration = (upgraders: [HTTPServerProtocolUpgrader], completionHandler: @Sendable (ChannelHandlerContext) -> Void) +public typealias NIOHTTPServerUpgradeConfiguration = ( + upgraders: [HTTPServerProtocolUpgrader], completionHandler: @Sendable (ChannelHandlerContext) -> Void +) extension ChannelPipeline { /// Configure a `ChannelPipeline` for use as a HTTP client. @@ -37,11 +41,15 @@ extension ChannelPipeline { /// - leftOverBytesStrategy: The strategy to use when dealing with leftover bytes after removing the `HTTPDecoder` /// from the pipeline. /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. - public func addHTTPClientHandlers(position: Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes) -> EventLoopFuture { - return self.addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - withClientUpgrade: nil) + public func addHTTPClientHandlers( + position: Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes + ) -> EventLoopFuture { + self.addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + withClientUpgrade: nil + ) } /// Configure a `ChannelPipeline` for use as a HTTP client with a client upgrader configuration. @@ -56,9 +64,11 @@ extension ChannelPipeline { /// for more details. /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. @preconcurrency - public func addHTTPClientHandlers(position: Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration?) -> EventLoopFuture { + public func addHTTPClientHandlers( + position: Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? + ) -> EventLoopFuture { self._addHTTPClientHandlers( position: position, leftOverBytesStrategy: leftOverBytesStrategy, @@ -66,23 +76,29 @@ extension ChannelPipeline { ) } - private func _addHTTPClientHandlers(position: Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration?) -> EventLoopFuture { + private func _addHTTPClientHandlers( + position: Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? + ) -> EventLoopFuture { let future: EventLoopFuture if self.eventLoop.inEventLoop { let result = Result { - try self.syncOperations.addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - withClientUpgrade: upgrade) + try self.syncOperations.addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + withClientUpgrade: upgrade + ) } future = self.eventLoop.makeCompletedFuture(result) } else { future = self.eventLoop.submit { - return try self.syncOperations.addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - withClientUpgrade: upgrade) + try self.syncOperations.addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + withClientUpgrade: upgrade + ) } } @@ -102,26 +118,32 @@ extension ChannelPipeline { /// the upgrade completion handler. See the documentation on ``NIOHTTPClientUpgradeHandler`` /// for more details. /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. - public func addHTTPClientHandlers(position: Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - enableOutboundHeaderValidation: Bool = true, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil) -> EventLoopFuture { + public func addHTTPClientHandlers( + position: Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + enableOutboundHeaderValidation: Bool = true, + withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil + ) -> EventLoopFuture { let future: EventLoopFuture if self.eventLoop.inEventLoop { let result = Result { - try self.syncOperations.addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - enableOutboundHeaderValidation: enableOutboundHeaderValidation, - withClientUpgrade: upgrade) + try self.syncOperations.addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + enableOutboundHeaderValidation: enableOutboundHeaderValidation, + withClientUpgrade: upgrade + ) } future = self.eventLoop.makeCompletedFuture(result) } else { future = self.eventLoop.submit { - return try self.syncOperations.addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - enableOutboundHeaderValidation: enableOutboundHeaderValidation, - withClientUpgrade: upgrade) + try self.syncOperations.addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + enableOutboundHeaderValidation: enableOutboundHeaderValidation, + withClientUpgrade: upgrade + ) } } @@ -142,29 +164,35 @@ extension ChannelPipeline { /// the upgrade completion handler. See the documentation on ``NIOHTTPClientUpgradeHandler`` /// for more details. /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. - public func addHTTPClientHandlers(position: Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - enableOutboundHeaderValidation: Bool = true, - encoderConfiguration: HTTPRequestEncoder.Configuration = .init(), - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil) -> EventLoopFuture { + public func addHTTPClientHandlers( + position: Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + enableOutboundHeaderValidation: Bool = true, + encoderConfiguration: HTTPRequestEncoder.Configuration = .init(), + withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil + ) -> EventLoopFuture { let future: EventLoopFuture if self.eventLoop.inEventLoop { let result = Result { - try self.syncOperations.addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - enableOutboundHeaderValidation: enableOutboundHeaderValidation, - encoderConfiguration: encoderConfiguration, - withClientUpgrade: upgrade) + try self.syncOperations.addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + enableOutboundHeaderValidation: enableOutboundHeaderValidation, + encoderConfiguration: encoderConfiguration, + withClientUpgrade: upgrade + ) } future = self.eventLoop.makeCompletedFuture(result) } else { future = self.eventLoop.submit { - return try self.syncOperations.addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - enableOutboundHeaderValidation: enableOutboundHeaderValidation, - encoderConfiguration: encoderConfiguration, - withClientUpgrade: upgrade) + try self.syncOperations.addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + enableOutboundHeaderValidation: enableOutboundHeaderValidation, + encoderConfiguration: encoderConfiguration, + withClientUpgrade: upgrade + ) } } @@ -197,10 +225,12 @@ extension ChannelPipeline { /// failure to parse the HTTP request) by sending 400 errors. Defaults to `true`. /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. @preconcurrency - public func configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true) -> EventLoopFuture { + public func configureHTTPServerPipeline( + position: ChannelPipeline.Position = .last, + withPipeliningAssistance pipelining: Bool = true, + withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, + withErrorHandling errorHandling: Bool = true + ) -> EventLoopFuture { self._configureHTTPServerPipeline( position: position, withPipeliningAssistance: pipelining, @@ -238,11 +268,13 @@ extension ChannelPipeline { /// - headerValidation: Whether to validate outbound request headers to confirm that they meet /// spec compliance. Defaults to `true`. /// - returns: An `EventLoopFuture` that will fire when the pipeline is configured. - public func configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true, - withOutboundHeaderValidation headerValidation: Bool = true) -> EventLoopFuture { + public func configureHTTPServerPipeline( + position: ChannelPipeline.Position = .last, + withPipeliningAssistance pipelining: Bool = true, + withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, + withErrorHandling errorHandling: Bool = true, + withOutboundHeaderValidation headerValidation: Bool = true + ) -> EventLoopFuture { self._configureHTTPServerPipeline( position: position, withPipeliningAssistance: pipelining, @@ -312,22 +344,26 @@ extension ChannelPipeline { if self.eventLoop.inEventLoop { let result = Result { - try self.syncOperations.configureHTTPServerPipeline(position: position, - withPipeliningAssistance: pipelining, - withServerUpgrade: upgrade, - withErrorHandling: errorHandling, - withOutboundHeaderValidation: headerValidation, - withEncoderConfiguration: encoderConfiguration) + try self.syncOperations.configureHTTPServerPipeline( + position: position, + withPipeliningAssistance: pipelining, + withServerUpgrade: upgrade, + withErrorHandling: errorHandling, + withOutboundHeaderValidation: headerValidation, + withEncoderConfiguration: encoderConfiguration + ) } future = self.eventLoop.makeCompletedFuture(result) } else { future = self.eventLoop.submit { - try self.syncOperations.configureHTTPServerPipeline(position: position, - withPipeliningAssistance: pipelining, - withServerUpgrade: upgrade, - withErrorHandling: errorHandling, - withOutboundHeaderValidation: headerValidation, - withEncoderConfiguration: encoderConfiguration) + try self.syncOperations.configureHTTPServerPipeline( + position: position, + withPipeliningAssistance: pipelining, + withServerUpgrade: upgrade, + withErrorHandling: errorHandling, + withOutboundHeaderValidation: headerValidation, + withEncoderConfiguration: encoderConfiguration + ) } } @@ -349,9 +385,11 @@ extension ChannelPipeline.SynchronousOperations { /// for more details. /// - throws: If the pipeline could not be configured. @preconcurrency - public func addHTTPClientHandlers(position: ChannelPipeline.Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil) throws { + public func addHTTPClientHandlers( + position: ChannelPipeline.Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil + ) throws { try self._addHTTPClientHandlers( position: position, leftOverBytesStrategy: leftOverBytesStrategy, @@ -371,14 +409,18 @@ extension ChannelPipeline.SynchronousOperations { /// the upgrade completion handler. See the documentation on ``NIOHTTPClientUpgradeHandler`` /// for more details. /// - throws: If the pipeline could not be configured. - public func addHTTPClientHandlers(position: ChannelPipeline.Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - enableOutboundHeaderValidation: Bool = true, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil) throws { - try self._addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - enableOutboundHeaderValidation: enableOutboundHeaderValidation, - withClientUpgrade: upgrade) + public func addHTTPClientHandlers( + position: ChannelPipeline.Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + enableOutboundHeaderValidation: Bool = true, + withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil + ) throws { + try self._addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + enableOutboundHeaderValidation: enableOutboundHeaderValidation, + withClientUpgrade: upgrade + ) } /// Configure a `ChannelPipeline` for use as a HTTP client. @@ -394,54 +436,70 @@ extension ChannelPipeline.SynchronousOperations { /// the upgrade completion handler. See the documentation on ``NIOHTTPClientUpgradeHandler`` /// for more details. /// - throws: If the pipeline could not be configured. - public func addHTTPClientHandlers(position: ChannelPipeline.Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - enableOutboundHeaderValidation: Bool = true, - encoderConfiguration: HTTPRequestEncoder.Configuration = .init(), - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil) throws { - try self._addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - enableOutboundHeaderValidation: enableOutboundHeaderValidation, - encoderConfiguration: encoderConfiguration, - withClientUpgrade: upgrade) + public func addHTTPClientHandlers( + position: ChannelPipeline.Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + enableOutboundHeaderValidation: Bool = true, + encoderConfiguration: HTTPRequestEncoder.Configuration = .init(), + withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil + ) throws { + try self._addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + enableOutboundHeaderValidation: enableOutboundHeaderValidation, + encoderConfiguration: encoderConfiguration, + withClientUpgrade: upgrade + ) } - private func _addHTTPClientHandlers(position: ChannelPipeline.Position = .last, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, - enableOutboundHeaderValidation: Bool = true, - encoderConfiguration: HTTPRequestEncoder.Configuration = .init(), - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil) throws { + private func _addHTTPClientHandlers( + position: ChannelPipeline.Position = .last, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes, + enableOutboundHeaderValidation: Bool = true, + encoderConfiguration: HTTPRequestEncoder.Configuration = .init(), + withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? = nil + ) throws { // Why two separate functions? With the fast-path (no upgrader, yes header validator) we can promote the Array of handlers // to the stack and skip an allocation. if upgrade != nil || enableOutboundHeaderValidation != true { - try self._addHTTPClientHandlersFallback(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - enableOutboundHeaderValidation: enableOutboundHeaderValidation, - encoderConfiguration: encoderConfiguration, - withClientUpgrade: upgrade) + try self._addHTTPClientHandlersFallback( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + enableOutboundHeaderValidation: enableOutboundHeaderValidation, + encoderConfiguration: encoderConfiguration, + withClientUpgrade: upgrade + ) } else { - try self._addHTTPClientHandlers(position: position, - leftOverBytesStrategy: leftOverBytesStrategy, - encoderConfiguration: encoderConfiguration) + try self._addHTTPClientHandlers( + position: position, + leftOverBytesStrategy: leftOverBytesStrategy, + encoderConfiguration: encoderConfiguration + ) } } - private func _addHTTPClientHandlers(position: ChannelPipeline.Position, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy, - encoderConfiguration: HTTPRequestEncoder.Configuration) throws { + private func _addHTTPClientHandlers( + position: ChannelPipeline.Position, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy, + encoderConfiguration: HTTPRequestEncoder.Configuration + ) throws { self.eventLoop.assertInEventLoop() let requestEncoder = HTTPRequestEncoder(configuration: encoderConfiguration) let responseDecoder = HTTPResponseDecoder(leftOverBytesStrategy: leftOverBytesStrategy) let requestHeaderValidator = NIOHTTPRequestHeadersValidator() - let handlers: [ChannelHandler] = [requestEncoder, ByteToMessageHandler(responseDecoder), requestHeaderValidator] + let handlers: [ChannelHandler] = [ + requestEncoder, ByteToMessageHandler(responseDecoder), requestHeaderValidator, + ] try self.addHandlers(handlers, position: position) } - private func _addHTTPClientHandlersFallback(position: ChannelPipeline.Position, - leftOverBytesStrategy: RemoveAfterUpgradeStrategy, - enableOutboundHeaderValidation: Bool, - encoderConfiguration: HTTPRequestEncoder.Configuration, - withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration?) throws { + private func _addHTTPClientHandlersFallback( + position: ChannelPipeline.Position, + leftOverBytesStrategy: RemoveAfterUpgradeStrategy, + enableOutboundHeaderValidation: Bool, + encoderConfiguration: HTTPRequestEncoder.Configuration, + withClientUpgrade upgrade: NIOHTTPClientUpgradeConfiguration? + ) throws { self.eventLoop.assertInEventLoop() let requestEncoder = HTTPRequestEncoder(configuration: encoderConfiguration) let responseDecoder = HTTPResponseDecoder(leftOverBytesStrategy: leftOverBytesStrategy) @@ -452,9 +510,11 @@ extension ChannelPipeline.SynchronousOperations { } if let upgrade = upgrade { - let upgrader = NIOHTTPClientUpgradeHandler(upgraders: upgrade.upgraders, - httpHandlers: handlers, - upgradeCompletionHandler: upgrade.completionHandler) + let upgrader = NIOHTTPClientUpgradeHandler( + upgraders: upgrade.upgraders, + httpHandlers: handlers, + upgradeCompletionHandler: upgrade.completionHandler + ) handlers.append(upgrader) } @@ -487,10 +547,12 @@ extension ChannelPipeline.SynchronousOperations { /// failure to parse the HTTP request) by sending 400 errors. Defaults to `true`. /// - throws: If the pipeline could not be configured. @preconcurrency - public func configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true) throws { + public func configureHTTPServerPipeline( + position: ChannelPipeline.Position = .last, + withPipeliningAssistance pipelining: Bool = true, + withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, + withErrorHandling errorHandling: Bool = true + ) throws { try self._configureHTTPServerPipeline( position: position, withPipeliningAssistance: pipelining, @@ -529,11 +591,13 @@ extension ChannelPipeline.SynchronousOperations { /// - headerValidation: Whether to validate outbound request headers to confirm that they meet /// spec compliance. Defaults to `true`. /// - throws: If the pipeline could not be configured. - public func configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true, - withOutboundHeaderValidation headerValidation: Bool = true) throws { + public func configureHTTPServerPipeline( + position: ChannelPipeline.Position = .last, + withPipeliningAssistance pipelining: Bool = true, + withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, + withErrorHandling errorHandling: Bool = true, + withOutboundHeaderValidation headerValidation: Bool = true + ) throws { try self._configureHTTPServerPipeline( position: position, withPipeliningAssistance: pipelining, @@ -574,12 +638,14 @@ extension ChannelPipeline.SynchronousOperations { /// spec compliance. Defaults to `true`. /// - encoderConfiguration: The configuration for the ``HTTPRequestEncoder``. /// - throws: If the pipeline could not be configured. - public func configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true, - withOutboundHeaderValidation headerValidation: Bool = true, - withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration) throws { + public func configureHTTPServerPipeline( + position: ChannelPipeline.Position = .last, + withPipeliningAssistance pipelining: Bool = true, + withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, + withErrorHandling errorHandling: Bool = true, + withOutboundHeaderValidation headerValidation: Bool = true, + withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration + ) throws { try self._configureHTTPServerPipeline( position: position, withPipeliningAssistance: pipelining, @@ -590,12 +656,14 @@ extension ChannelPipeline.SynchronousOperations { ) } - private func _configureHTTPServerPipeline(position: ChannelPipeline.Position = .last, - withPipeliningAssistance pipelining: Bool = true, - withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, - withErrorHandling errorHandling: Bool = true, - withOutboundHeaderValidation headerValidation: Bool = true, - withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration = .init()) throws { + private func _configureHTTPServerPipeline( + position: ChannelPipeline.Position = .last, + withPipeliningAssistance pipelining: Bool = true, + withServerUpgrade upgrade: NIOHTTPServerUpgradeConfiguration? = nil, + withErrorHandling errorHandling: Bool = true, + withOutboundHeaderValidation headerValidation: Bool = true, + withEncoderConfiguration encoderConfiguration: HTTPResponseEncoder.Configuration = .init() + ) throws { self.eventLoop.assertInEventLoop() let responseEncoder = HTTPResponseEncoder(configuration: encoderConfiguration) @@ -616,10 +684,12 @@ extension ChannelPipeline.SynchronousOperations { } if let (upgraders, completionHandler) = upgrade { - let upgrader = HTTPServerUpgradeHandler(upgraders: upgraders, - httpEncoder: responseEncoder, - extraHTTPHandlers: Array(handlers.dropFirst()), - upgradeCompletionHandler: completionHandler) + let upgrader = HTTPServerUpgradeHandler( + upgraders: upgraders, + httpEncoder: responseEncoder, + extraHTTPHandlers: Array(handlers.dropFirst()), + upgradeCompletionHandler: completionHandler + ) handlers.append(upgrader) } diff --git a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift index 110a34e7d9..c6626dd0e7 100644 --- a/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift +++ b/Sources/NIOHTTP1/HTTPServerPipelineHandler.swift @@ -96,7 +96,7 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha self.line = line } - public static func ==(lhs: ConnectionStateError, rhs: ConnectionStateError) -> Bool { + public static func == (lhs: ConnectionStateError, rhs: ConnectionStateError) -> Bool { lhs.base == rhs.base } @@ -151,7 +151,7 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha self = .requestAndResponseEndPending return .none case .requestAndResponseEndPending, .responseEndPending, .requestEndPending, - .sentCloseOutputRequestEndPending, .sentCloseOutput: + .sentCloseOutputRequestEndPending, .sentCloseOutput: let message = "received request head in state \(self)" self = .preconditionFailed return .warnPreconditionViolated(message: message) @@ -298,10 +298,11 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } private func deliverOneMessage(context: ChannelHandlerContext, data: NIOAny) -> ConnectionStateAction { - self.checkAssertion(self.lifecycleState != .quiescingLastRequestEndReceived && - self.lifecycleState != .quiescingCompleted, - "deliverOneMessage called in lifecycle illegal state \(self.lifecycleState)") - let msg = Self.unwrapInboundIn(data) + self.checkAssertion( + self.lifecycleState != .quiescingLastRequestEndReceived && self.lifecycleState != .quiescingCompleted, + "deliverOneMessage called in lifecycle illegal state \(self.lifecycleState)" + ) + let msg = self.unwrapInboundIn(data) debugOnly { switch msg { @@ -354,16 +355,18 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { case is ChannelShouldQuiesceEvent: - self.checkAssertion(self.lifecycleState == .acceptingEvents, - "unexpected lifecycle state when receiving ChannelShouldQuiesceEvent: \(self.lifecycleState)") + self.checkAssertion( + self.lifecycleState == .acceptingEvents, + "unexpected lifecycle state when receiving ChannelShouldQuiesceEvent: \(self.lifecycleState)" + ) switch self.state { case .responseEndPending: // we're not in the middle of a request, let's just shut the door self.lifecycleState = .quiescingLastRequestEndReceived self.eventBuffer.removeAll() case .preconditionFailed, - // An invariant has been violated already, this time we close the connection - .idle, .sentCloseOutput: + // An invariant has been violated already, this time we close the connection + .idle, .sentCloseOutput: // we're completely idle, let's just close self.lifecycleState = .quiescingCompleted self.eventBuffer.removeAll() @@ -400,8 +403,10 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - self.checkAssertion(self.state != .requestEndPending, - "Received second response while waiting for first one to complete") + self.checkAssertion( + self.state != .requestEndPending, + "Received second response while waiting for first one to complete" + ) debugOnly { let res = Self.unwrapOutboundIn(data) switch res { @@ -502,7 +507,6 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha } } - switch self.lifecycleState { case .quiescingLastRequestEndReceived, .quiescingWaitingForRequestEnd: context.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) @@ -558,9 +562,9 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha /// - Returns: True if an error was sent, ie the caller should not continue private func handleConnectionStateAction( - context: ChannelHandlerContext, - action: ConnectionStateAction, - promise: EventLoopPromise? + context: ChannelHandlerContext, + action: ConnectionStateAction, + promise: EventLoopPromise? ) -> Bool { switch action { case .warnPreconditionViolated(let message): @@ -570,7 +574,8 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha promise?.fail(error) return true case .forceCloseConnection: - let message = "The connection has been forcefully closed because further IO was attempted after a precondition was violated" + let message = + "The connection has been forcefully closed because further IO was attempted after a precondition was violated" let error = ConnectionStateError.preconditionViolated(message: message) promise?.fail(error) self.close(context: context, mode: .all, promise: nil) @@ -669,7 +674,12 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha /// This is currently the only way to do this in Swift: see /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. private func debugOnly(_ body: () -> Void) { - self.checkAssertion({ body(); return true }()) + self.checkAssertion( + { + body() + return true + }() + ) } /// Calls assertionFailure if and only if `self.failOnPreconditions` is true. This allows us to avoid terminating the program in tests @@ -681,10 +691,10 @@ public final class HTTPServerPipelineHandler: ChannelDuplexHandler, RemovableCha /// Calls assert if and only if `self.failOnPreconditions` is true. This allows us to avoid terminating the program in tests private func checkAssertion( - _ closure: @autoclosure () -> Bool, - _ message: @autoclosure () -> String = String(), - file: StaticString = #file, - line: UInt = #line + _ closure: @autoclosure () -> Bool, + _ message: @autoclosure () -> String = String(), + file: StaticString = #file, + line: UInt = #line ) { if self.failOnPreconditions { assert(closure(), message(), file: file, line: line) diff --git a/Sources/NIOHTTP1/HTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/HTTPServerUpgradeHandler.swift index 535a008288..0c691cc5df 100644 --- a/Sources/NIOHTTP1/HTTPServerUpgradeHandler.swift +++ b/Sources/NIOHTTP1/HTTPServerUpgradeHandler.swift @@ -26,7 +26,6 @@ public enum HTTPServerUpgradeEvents: Sendable { case upgradeComplete(toProtocol: String, upgradeRequest: HTTPRequestHead) } - /// An object that implements `HTTPServerProtocolUpgrader` knows how to handle HTTP upgrade to /// a protocol on a server-side channel. public protocol HTTPServerProtocolUpgrader { @@ -41,7 +40,11 @@ public protocol HTTPServerProtocolUpgrader { /// Builds the upgrade response headers. Should return any headers that need to be supplied to the client /// in the 101 Switching Protocols response. If upgrade cannot proceed for any reason, this function should /// fail the future. - func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture + func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel pipeline /// to add whatever channel handlers are required. Until the returned `EventLoopFuture` succeeds, all received @@ -62,7 +65,7 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha public typealias InboundIn = HTTPServerRequestPart public typealias InboundOut = HTTPServerRequestPart public typealias OutboundOut = HTTPServerResponsePart - + private let upgraders: [String: HTTPServerProtocolUpgrader] private let upgradeCompletionHandler: (ChannelHandlerContext) -> Void @@ -130,9 +133,13 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha // We were re-entrantly called while delivering the request head. We can just pass this through. context.fireChannelRead(data) case .upgradeComplete: - preconditionFailure("Upgrade has completed but we have not seen a whole request and still got re-entrantly called.") + preconditionFailure( + "Upgrade has completed but we have not seen a whole request and still got re-entrantly called." + ) case .upgrading: - preconditionFailure("We think we saw .end before and began upgrading, but somehow we have not set seenFirstRequest") + preconditionFailure( + "We think we saw .end before and began upgrading, but somehow we have not set seenFirstRequest" + ) } } @@ -174,7 +181,7 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha self.upgradeState = .awaitingUpgrader self.handleUpgrade(context: context, request: request, requestedProtocols: requestedProtocols) - .hop(to: context.eventLoop) // the user might return a future from another EventLoop. + .hop(to: context.eventLoop) // the user might return a future from another EventLoop. .whenSuccess { callback in context.eventLoop.assertInEventLoop() if let callback = callback { @@ -182,25 +189,41 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha } else { self.notUpgrading(context: context, data: requestPart) } - } + } } /// The core of the upgrade handling logic. /// /// - returns: An `EventLoopFuture` that will contain a callback to invoke if upgrade is requested, or nil if upgrade has failed. Never returns a failed future. - private func handleUpgrade(context: ChannelHandlerContext, request: HTTPRequestHead, requestedProtocols: [String]) -> EventLoopFuture<(() -> Void)?> { + private func handleUpgrade( + context: ChannelHandlerContext, + request: HTTPRequestHead, + requestedProtocols: [String] + ) -> EventLoopFuture<(() -> Void)?> { let connectionHeader = Set(request.headers[canonicalForm: "connection"].map { $0.lowercased() }) let allHeaderNames = Set(request.headers.map { $0.name.lowercased() }) // We now set off a chain of Futures to try to find a protocol upgrade. While this is blocking, we need to buffer inbound data. let protocolIterator = requestedProtocols.makeIterator() - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + return self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: request, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) } /// Attempt to upgrade a single protocol. /// /// Will recurse through `protocolIterator` if upgrade fails. - private func handleUpgradeForProtocol(context: ChannelHandlerContext, protocolIterator: Array.Iterator, request: HTTPRequestHead, allHeaderNames: Set, connectionHeader: Set) -> EventLoopFuture<(() -> Void)?> { + private func handleUpgradeForProtocol( + context: ChannelHandlerContext, + protocolIterator: Array.Iterator, + request: HTTPRequestHead, + allHeaderNames: Set, + connectionHeader: Set + ) -> EventLoopFuture<(() -> Void)?> { // We want a local copy of the protocol iterator. We'll pass it to the next invocation of the function. var protocolIterator = protocolIterator guard let proto = protocolIterator.next() else { @@ -209,17 +232,33 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha } guard let upgrader = self.upgraders[proto.lowercased()] else { - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + return self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: request, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) } let requiredHeaders = Set(upgrader.requiredUpgradeHeaders.map { $0.lowercased() }) guard requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader) else { - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + return self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: request, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) } let responseHeaders = self.buildUpgradeHeaders(protocol: proto) - return upgrader.buildUpgradeResponse(channel: context.channel, upgradeRequest: request, initialResponseHeaders: responseHeaders).map { finalResponseHeaders in - return { + return upgrader.buildUpgradeResponse( + channel: context.channel, + upgradeRequest: request, + initialResponseHeaders: responseHeaders + ).map { finalResponseHeaders in + { // Ok, we're upgrading. self.upgradeState = .upgrading @@ -233,7 +272,11 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha // our final cleanup steps, namely we replay the received data we buffered in the meantime and // then remove ourselves from the pipeline. self.removeExtraHandlers(context: context).flatMap { - self.sendUpgradeResponse(context: context, upgradeRequest: request, responseHeaders: finalResponseHeaders) + self.sendUpgradeResponse( + context: context, + upgradeRequest: request, + responseHeaders: finalResponseHeaders + ) }.flatMap { context.pipeline.syncOperations.removeHandler(self.httpEncoder) }.flatMap { () -> EventLoopFuture in @@ -242,7 +285,9 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha }.whenComplete { result in switch result { case .success: - context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: proto, upgradeRequest: request)) + context.fireUserInboundEventTriggered( + HTTPServerUpgradeEvents.upgradeComplete(toProtocol: proto, upgradeRequest: request) + ) self.upgradeState = .upgradeComplete // When we remove ourselves we'll be delivering any buffered data. context.pipeline.removeHandler(context: context, promise: nil) @@ -256,7 +301,13 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha }.flatMapError { error in // No upgrade here. We want to fire the error down the pipeline, and then try another loop iteration. context.fireErrorCaught(error) - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + return self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: request, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) } } @@ -275,7 +326,11 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha } /// Sends the 101 Switching Protocols response for the pipeline. - private func sendUpgradeResponse(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead, responseHeaders: HTTPHeaders) -> EventLoopFuture { + private func sendUpgradeResponse( + context: ChannelHandlerContext, + upgradeRequest: HTTPRequestHead, + responseHeaders: HTTPHeaders + ) -> EventLoopFuture { var response = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) response.headers = responseHeaders return context.writeAndFlush(wrapOutboundOut(HTTPServerResponsePart.head(response))) @@ -307,7 +362,7 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha /// Builds the initial mandatory HTTP headers for HTTP upgrade responses. private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders { - return HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)]) + HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)]) } /// Removes any extra HTTP-related handlers from the channel pipeline. @@ -316,8 +371,10 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha return context.eventLoop.makeSucceededFuture(()) } - return .andAllSucceed(self.extraHTTPHandlers.map { context.pipeline.removeHandler($0) }, - on: context.eventLoop) + return .andAllSucceed( + self.extraHTTPHandlers.map { context.pipeline.removeHandler($0) }, + on: context.eventLoop + ) } } diff --git a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift index 9021062488..0285c07ea7 100644 --- a/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift +++ b/Sources/NIOHTTP1/HTTPTypedPipelineSetup.swift @@ -223,7 +223,9 @@ extension ChannelPipeline.SynchronousOperations { self.eventLoop.assertInEventLoop() let requestEncoder = HTTPRequestEncoder(configuration: configuration.encoderConfiguration) - let responseDecoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: configuration.leftOverBytesStrategy)) + let responseDecoder = ByteToMessageHandler( + HTTPResponseDecoder(leftOverBytesStrategy: configuration.leftOverBytesStrategy) + ) var httpHandlers = [RemovableChannelHandler]() httpHandlers.reserveCapacity(3) httpHandlers.append(requestEncoder) diff --git a/Sources/NIOHTTP1/HTTPTypes.swift b/Sources/NIOHTTP1/HTTPTypes.swift index 0237074137..a10abaccb2 100644 --- a/Sources/NIOHTTP1/HTTPTypes.swift +++ b/Sources/NIOHTTP1/HTTPTypes.swift @@ -48,7 +48,7 @@ public struct HTTPRequestHead: Equatable { } func copy() -> _Storage { - return .init(method: self.method, uri: self.uri, version: self.version) + .init(method: self.method, uri: self.uri, version: self.version) } } @@ -61,7 +61,7 @@ public struct HTTPRequestHead: Equatable { /// The HTTP method for this request. public var method: HTTPMethod { get { - return self._storage.method + self._storage.method } set { self.copyStorageIfNotUniquelyReferenced() @@ -72,7 +72,7 @@ public struct HTTPRequestHead: Equatable { // This request's URI. public var uri: String { get { - return self._storage.uri + self._storage.uri } set { self.copyStorageIfNotUniquelyReferenced() @@ -83,7 +83,7 @@ public struct HTTPRequestHead: Equatable { /// The version for this HTTP request. public var version: HTTPVersion { get { - return self._storage.version + self._storage.version } set { self.copyStorageIfNotUniquelyReferenced() @@ -112,11 +112,11 @@ public struct HTTPRequestHead: Equatable { self.init(version: version, method: method, uri: uri, headers: HTTPHeaders()) } - public static func ==(lhs: HTTPRequestHead, rhs: HTTPRequestHead) -> Bool { - return lhs.method == rhs.method && lhs.uri == rhs.uri && lhs.version == rhs.version && lhs.headers == rhs.headers + public static func == (lhs: HTTPRequestHead, rhs: HTTPRequestHead) -> Bool { + lhs.method == rhs.method && lhs.uri == rhs.uri && lhs.version == rhs.version && lhs.headers == rhs.headers } - private mutating func copyStorageIfNotUniquelyReferenced () { + private mutating func copyStorageIfNotUniquelyReferenced() { if !isKnownUniquelyReferenced(&self._storage) { self._storage = self._storage.copy() } @@ -126,12 +126,12 @@ public struct HTTPRequestHead: Equatable { extension HTTPRequestHead: @unchecked Sendable {} /// The parts of a complete HTTP message, representing either a request or a response. -/// +/// /// An HTTP message is made up of: /// - a request or status line with several headers, encoded by a single ``HTTPPart/head(_:)`` part, /// - zero or more ``HTTPPart/body(_:)`` parts, /// - and some optional trailers (represented as headers) in a single ``HTTPPart/end(_:)`` part. -/// +/// /// To indicate that a complete HTTP message has been sent or received, /// an ``HTTPPart/end(_:)`` part must be used, even when no trailers are included. public enum HTTPPart { @@ -139,13 +139,13 @@ public enum HTTPPart { /// /// A single part is always used to encode all headers. case head(HeadT) - + /// A part of an HTTP request or response's body. /// /// Zero or more body parts can be sent or received. The stream is finished when /// an ``HTTPPart/end(_:)`` part is received. case body(BodyT) - + /// The end of an HTTP request or response, optionally containing trailers. /// /// A single part is always used to encode all trailers. @@ -172,7 +172,7 @@ extension HTTPRequestHead { /// Whether this HTTP request is a keep-alive request: that is, whether the /// connection should remain open after the request is complete. public var isKeepAlive: Bool { - return headers.isKeepAlive(version: version) + headers.isKeepAlive(version: version) } } @@ -180,7 +180,7 @@ extension HTTPResponseHead { /// Whether this HTTP response is a keep-alive request: that is, whether the /// connection should remain open after the request is complete. public var isKeepAlive: Bool { - return self.headers.isKeepAlive(version: self.version) + self.headers.isKeepAlive(version: self.version) } } @@ -194,7 +194,7 @@ public struct HTTPResponseHead: Equatable { self.version = version } func copy() -> _Storage { - return .init(status: self.status, version: self.version) + .init(status: self.status, version: self.version) } } @@ -207,7 +207,7 @@ public struct HTTPResponseHead: Equatable { /// The HTTP response status. public var status: HTTPResponseStatus { get { - return self._storage.status + self._storage.status } set { self.copyStorageIfNotUniquelyReferenced() @@ -218,7 +218,7 @@ public struct HTTPResponseHead: Equatable { /// The HTTP version that corresponds to this response. public var version: HTTPVersion { get { - return self._storage.version + self._storage.version } set { self.copyStorageIfNotUniquelyReferenced() @@ -236,11 +236,11 @@ public struct HTTPResponseHead: Equatable { self._storage = _Storage(status: status, version: version) } - public static func ==(lhs: HTTPResponseHead, rhs: HTTPResponseHead) -> Bool { - return lhs.status == rhs.status && lhs.version == rhs.version && lhs.headers == rhs.headers + public static func == (lhs: HTTPResponseHead, rhs: HTTPResponseHead) -> Bool { + lhs.status == rhs.status && lhs.version == rhs.version && lhs.headers == rhs.headers } - private mutating func copyStorageIfNotUniquelyReferenced () { + private mutating func copyStorageIfNotUniquelyReferenced() { if !isKnownUniquelyReferenced(&self._storage) { self._storage = self._storage.copy() } @@ -257,9 +257,9 @@ extension HTTPResponseHead { } } -private extension UInt8 { - var isASCII: Bool { - return self <= 127 +extension UInt8 { + fileprivate var isASCII: Bool { + self <= 127 } } @@ -313,11 +313,11 @@ public struct HTTPHeaders: CustomStringConvertible, ExpressibleByDictionaryLiter internal var keepAliveState: KeepAliveState = .unknown public var description: String { - return self.headers.description + self.headers.description } internal var names: [String] { - return self.headers.map { $0.0 } + self.headers.map { $0.0 } } internal init(_ headers: [(String, String)], keepAliveState: KeepAliveState) { @@ -326,7 +326,7 @@ public struct HTTPHeaders: CustomStringConvertible, ExpressibleByDictionaryLiter } internal func isConnectionHeader(_ name: String) -> Bool { - return name.utf8.compareCaseInsensitiveASCIIBytes(to: "connection".utf8) + name.utf8.compareCaseInsensitiveASCIIBytes(to: "connection".utf8) } /// Construct a `HTTPHeaders` structure. @@ -442,7 +442,7 @@ public struct HTTPHeaders: CustomStringConvertible, ExpressibleByDictionaryLiter /// - Parameter name: The header field name whose values are to be retrieved. /// - Returns: A list of the values for that header field name. public subscript(name: String) -> [String] { - return self.headers.reduce(into: []) { target, lr in + self.headers.reduce(into: []) { target, lr in let (key, value) = lr if key.utf8.compareCaseInsensitiveASCIIBytes(to: name.utf8) { target.append(value) @@ -512,7 +512,7 @@ extension HTTPHeaders { /// The total number of headers that can be contained without allocating new storage. public var capacity: Int { - return self.headers.capacity + self.headers.capacity } /// Reserves enough space to store the specified number of headers. @@ -546,28 +546,28 @@ extension HTTPHeaders: RandomAccessCollection { public struct Index: Comparable { fileprivate let base: Array<(String, String)>.Index public static func < (lhs: Index, rhs: Index) -> Bool { - return lhs.base < rhs.base + lhs.base < rhs.base } } public var startIndex: HTTPHeaders.Index { - return .init(base: self.headers.startIndex) + .init(base: self.headers.startIndex) } public var endIndex: HTTPHeaders.Index { - return .init(base: self.headers.endIndex) + .init(base: self.headers.endIndex) } public func index(before i: HTTPHeaders.Index) -> HTTPHeaders.Index { - return .init(base: self.headers.index(before: i.base)) + .init(base: self.headers.index(before: i.base)) } public func index(after i: HTTPHeaders.Index) -> HTTPHeaders.Index { - return .init(base: self.headers.index(after: i.base)) + .init(base: self.headers.index(after: i.base)) } public subscript(position: HTTPHeaders.Index) -> Element { - return self.headers[position.base] + self.headers[position.base] } } @@ -575,26 +575,26 @@ extension UTF8.CodeUnit { var isASCIIWhitespace: Bool { switch self { case UInt8(ascii: " "), - UInt8(ascii: "\t"): - return true + UInt8(ascii: "\t"): + return true default: - return false + return false } } } extension String { func trimASCIIWhitespace() -> Substring { - return Substring(self).trimWhitespace() + Substring(self).trimWhitespace() } } extension Substring { fileprivate func trimWhitespace() -> Substring { guard let firstNonWhitespace = self.utf8.firstIndex(where: { !$0.isASCIIWhitespace }) else { - // The whole substring is ASCII whitespace. - return Substring() + // The whole substring is ASCII whitespace. + return Substring() } // There must be at least one non-ascii whitespace character, so banging here is safe. @@ -604,7 +604,7 @@ extension Substring { } extension HTTPHeaders: Equatable { - public static func ==(lhs: HTTPHeaders, rhs: HTTPHeaders) -> Bool { + public static func == (lhs: HTTPHeaders, rhs: HTTPHeaders) -> Bool { guard lhs.headers.count == rhs.headers.count else { return false } @@ -699,7 +699,7 @@ public struct HTTPVersion: Equatable, Sendable { /// The major version number. public var major: Int { get { - return Int(self._major) + Int(self._major) } set { self._major = UInt16(newValue) @@ -709,7 +709,7 @@ public struct HTTPVersion: Equatable, Sendable { /// The minor version number. public var minor: Int { get { - return Int(self._minor) + Int(self._minor) } set { self._minor = UInt16(newValue) @@ -756,7 +756,7 @@ extension HTTPParserError: CustomDebugStringConvertible { case .invalidHost: return "invalid host" case .invalidPort: - return "invalid port" + return "invalid port" case .invalidPath: return "invalid path" case .invalidQueryString: @@ -791,7 +791,7 @@ extension HTTPParserError: CustomDebugStringConvertible { public enum HTTPParserError: Error { case invalidCharactersUsed case trailingGarbage - /* from CHTTPParser */ + // from CHTTPParser case invalidEOFState case headerOverflow case closedConnection @@ -951,7 +951,7 @@ extension HTTPResponseStatus { return 510 case .networkAuthenticationRequired: return 511 - case .custom(code: let code, reasonPhrase: _): + case .custom(let code, reasonPhrase: _): return code } } @@ -1090,11 +1090,11 @@ extension HTTPResponseStatus { /// A HTTP response status code. public enum HTTPResponseStatus: Sendable { - /* use custom if you want to use a non-standard response code or - have it available in a (UInt, String) pair from a higher-level web framework. */ + // use custom if you want to use a non-standard response code or + // have it available in a (UInt, String) pair from a higher-level web framework. case custom(code: UInt, reasonPhrase: String) - /* all the codes from http://www.iana.org/assignments/http-status-codes */ + // all the codes from http://www.iana.org/assignments/http-status-codes // 1xx case `continue` @@ -1171,11 +1171,11 @@ public enum HTTPResponseStatus: Sendable { public var mayHaveResponseBody: Bool { switch self { case .`continue`, - .switchingProtocols, - .processing, - .noContent, - .notModified, - .custom where (code < 200) && (code >= 100): + .switchingProtocols, + .processing, + .noContent, + .notModified, + .custom where (code < 200) && (code >= 100): return false default: return true @@ -1321,194 +1321,194 @@ extension HTTPResponseStatus: Hashable {} extension HTTPRequestHead: CustomStringConvertible { public var description: String { - return "HTTPRequestHead { method: \(self.method), uri: \"\(self.uri)\", version: \(self.version), headers: \(self.headers) }" + "HTTPRequestHead { method: \(self.method), uri: \"\(self.uri)\", version: \(self.version), headers: \(self.headers) }" } } extension HTTPResponseStatus: CustomStringConvertible { public var description: String { - return "\(self.code) \(self.reasonPhrase)" + "\(self.code) \(self.reasonPhrase)" } } extension HTTPResponseHead: CustomStringConvertible { public var description: String { - return "HTTPResponseHead { version: \(self.version), status: \(self.status), headers: \(self.headers) }" + "HTTPResponseHead { version: \(self.version), status: \(self.status), headers: \(self.headers) }" } } extension HTTPVersion: CustomStringConvertible { public var description: String { - return "HTTP/\(self.major).\(self.minor)" + "HTTP/\(self.major).\(self.minor)" } } extension HTTPMethod: RawRepresentable { public var rawValue: String { switch self { - case .GET: - return "GET" - case .PUT: - return "PUT" - case .ACL: - return "ACL" - case .HEAD: - return "HEAD" - case .POST: - return "POST" - case .COPY: - return "COPY" - case .LOCK: - return "LOCK" - case .MOVE: - return "MOVE" - case .BIND: - return "BIND" - case .LINK: - return "LINK" - case .PATCH: - return "PATCH" - case .TRACE: - return "TRACE" - case .MKCOL: - return "MKCOL" - case .MERGE: - return "MERGE" - case .PURGE: - return "PURGE" - case .NOTIFY: - return "NOTIFY" - case .SEARCH: - return "SEARCH" - case .UNLOCK: - return "UNLOCK" - case .REBIND: - return "REBIND" - case .UNBIND: - return "UNBIND" - case .REPORT: - return "REPORT" - case .DELETE: - return "DELETE" - case .UNLINK: - return "UNLINK" - case .CONNECT: - return "CONNECT" - case .MSEARCH: - return "MSEARCH" - case .OPTIONS: - return "OPTIONS" - case .PROPFIND: - return "PROPFIND" - case .CHECKOUT: - return "CHECKOUT" - case .PROPPATCH: - return "PROPPATCH" - case .SUBSCRIBE: - return "SUBSCRIBE" - case .MKCALENDAR: - return "MKCALENDAR" - case .MKACTIVITY: - return "MKACTIVITY" - case .UNSUBSCRIBE: - return "UNSUBSCRIBE" - case .SOURCE: - return "SOURCE" - case let .RAW(value): - return value + case .GET: + return "GET" + case .PUT: + return "PUT" + case .ACL: + return "ACL" + case .HEAD: + return "HEAD" + case .POST: + return "POST" + case .COPY: + return "COPY" + case .LOCK: + return "LOCK" + case .MOVE: + return "MOVE" + case .BIND: + return "BIND" + case .LINK: + return "LINK" + case .PATCH: + return "PATCH" + case .TRACE: + return "TRACE" + case .MKCOL: + return "MKCOL" + case .MERGE: + return "MERGE" + case .PURGE: + return "PURGE" + case .NOTIFY: + return "NOTIFY" + case .SEARCH: + return "SEARCH" + case .UNLOCK: + return "UNLOCK" + case .REBIND: + return "REBIND" + case .UNBIND: + return "UNBIND" + case .REPORT: + return "REPORT" + case .DELETE: + return "DELETE" + case .UNLINK: + return "UNLINK" + case .CONNECT: + return "CONNECT" + case .MSEARCH: + return "MSEARCH" + case .OPTIONS: + return "OPTIONS" + case .PROPFIND: + return "PROPFIND" + case .CHECKOUT: + return "CHECKOUT" + case .PROPPATCH: + return "PROPPATCH" + case .SUBSCRIBE: + return "SUBSCRIBE" + case .MKCALENDAR: + return "MKCALENDAR" + case .MKACTIVITY: + return "MKACTIVITY" + case .UNSUBSCRIBE: + return "UNSUBSCRIBE" + case .SOURCE: + return "SOURCE" + case let .RAW(value): + return value } } public init(rawValue: String) { switch rawValue { - case "GET": - self = .GET - case "PUT": - self = .PUT - case "ACL": - self = .ACL - case "HEAD": - self = .HEAD - case "POST": - self = .POST - case "COPY": - self = .COPY - case "LOCK": - self = .LOCK - case "MOVE": - self = .MOVE - case "BIND": - self = .BIND - case "LINK": - self = .LINK - case "PATCH": - self = .PATCH - case "TRACE": - self = .TRACE - case "MKCOL": - self = .MKCOL - case "MERGE": - self = .MERGE - case "PURGE": - self = .PURGE - case "NOTIFY": - self = .NOTIFY - case "SEARCH": - self = .SEARCH - case "UNLOCK": - self = .UNLOCK - case "REBIND": - self = .REBIND - case "UNBIND": - self = .UNBIND - case "REPORT": - self = .REPORT - case "DELETE": - self = .DELETE - case "UNLINK": - self = .UNLINK - case "CONNECT": - self = .CONNECT - case "MSEARCH": - self = .MSEARCH - case "OPTIONS": - self = .OPTIONS - case "PROPFIND": - self = .PROPFIND - case "CHECKOUT": - self = .CHECKOUT - case "PROPPATCH": - self = .PROPPATCH - case "SUBSCRIBE": - self = .SUBSCRIBE - case "MKCALENDAR": - self = .MKCALENDAR - case "MKACTIVITY": - self = .MKACTIVITY - case "UNSUBSCRIBE": - self = .UNSUBSCRIBE - case "SOURCE": - self = .SOURCE - default: - self = .RAW(value: rawValue) + case "GET": + self = .GET + case "PUT": + self = .PUT + case "ACL": + self = .ACL + case "HEAD": + self = .HEAD + case "POST": + self = .POST + case "COPY": + self = .COPY + case "LOCK": + self = .LOCK + case "MOVE": + self = .MOVE + case "BIND": + self = .BIND + case "LINK": + self = .LINK + case "PATCH": + self = .PATCH + case "TRACE": + self = .TRACE + case "MKCOL": + self = .MKCOL + case "MERGE": + self = .MERGE + case "PURGE": + self = .PURGE + case "NOTIFY": + self = .NOTIFY + case "SEARCH": + self = .SEARCH + case "UNLOCK": + self = .UNLOCK + case "REBIND": + self = .REBIND + case "UNBIND": + self = .UNBIND + case "REPORT": + self = .REPORT + case "DELETE": + self = .DELETE + case "UNLINK": + self = .UNLINK + case "CONNECT": + self = .CONNECT + case "MSEARCH": + self = .MSEARCH + case "OPTIONS": + self = .OPTIONS + case "PROPFIND": + self = .PROPFIND + case "CHECKOUT": + self = .CHECKOUT + case "PROPPATCH": + self = .PROPPATCH + case "SUBSCRIBE": + self = .SUBSCRIBE + case "MKCALENDAR": + self = .MKCALENDAR + case "MKACTIVITY": + self = .MKACTIVITY + case "UNSUBSCRIBE": + self = .UNSUBSCRIBE + case "SOURCE": + self = .SOURCE + default: + self = .RAW(value: rawValue) } } } extension HTTPResponseHead { internal var contentLength: Int? { - return headers.contentLength + headers.contentLength } } extension HTTPRequestHead { internal var contentLength: Int? { - return headers.contentLength + headers.contentLength } } extension HTTPHeaders { internal var contentLength: Int? { - return self.first(name: "content-length").flatMap { Int($0) } + self.first(name: "content-length").flatMap { Int($0) } } } diff --git a/Sources/NIOHTTP1/NIOHTTPClientUpgradeHandler.swift b/Sources/NIOHTTP1/NIOHTTPClientUpgradeHandler.swift index 2cb96824a6..b092512f16 100644 --- a/Sources/NIOHTTP1/NIOHTTPClientUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOHTTPClientUpgradeHandler.swift @@ -27,9 +27,9 @@ public struct NIOHTTPClientUpgradeError: Hashable, Error { case receivedResponseBeforeRequestSent case receivedResponseAfterUpgradeCompleted } - + private var code: Code - + private init(_ code: Code) { self.code = code } @@ -38,15 +38,21 @@ public struct NIOHTTPClientUpgradeError: Hashable, Error { public static let invalidHTTPOrdering = NIOHTTPClientUpgradeError(.invalidHTTPOrdering) public static let upgraderDeniedUpgrade = NIOHTTPClientUpgradeError(.upgraderDeniedUpgrade) public static let writingToHandlerDuringUpgrade = NIOHTTPClientUpgradeError(.writingToHandlerDuringUpgrade) - public static let writingToHandlerAfterUpgradeCompleted = NIOHTTPClientUpgradeError(.writingToHandlerAfterUpgradeCompleted) - public static let writingToHandlerAfterUpgradeFailed = NIOHTTPClientUpgradeError(.writingToHandlerAfterUpgradeFailed) + public static let writingToHandlerAfterUpgradeCompleted = NIOHTTPClientUpgradeError( + .writingToHandlerAfterUpgradeCompleted + ) + public static let writingToHandlerAfterUpgradeFailed = NIOHTTPClientUpgradeError( + .writingToHandlerAfterUpgradeFailed + ) public static let receivedResponseBeforeRequestSent = NIOHTTPClientUpgradeError(.receivedResponseBeforeRequestSent) - public static let receivedResponseAfterUpgradeCompleted = NIOHTTPClientUpgradeError(.receivedResponseAfterUpgradeCompleted) + public static let receivedResponseAfterUpgradeCompleted = NIOHTTPClientUpgradeError( + .receivedResponseAfterUpgradeCompleted + ) } extension NIOHTTPClientUpgradeError: CustomStringConvertible { public var description: String { - return String(describing: self.code) + String(describing: self.code) } } @@ -54,21 +60,21 @@ extension NIOHTTPClientUpgradeError: CustomStringConvertible { /// a protocol on a client-side channel. /// It has the option of denying this upgrade based upon the server response. public protocol NIOHTTPClientProtocolUpgrader { - + /// The protocol this upgrader knows how to support. var supportedProtocol: String { get } - + /// All the header fields the protocol requires in the request to successfully upgrade. /// These header fields will be added to the outbound request's "Connection" header field. /// It is the responsibility of the custom headers call to actually add these required headers. var requiredUpgradeHeaders: [String] { get } - + /// Additional headers to be added to the request, beyond the "Upgrade" and "Connection" headers. func addCustom(upgradeRequestHeaders: inout HTTPHeaders) - + /// Gives the receiving upgrader the chance to deny the upgrade based on the upgrade HTTP response. func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool - + /// Called when the upgrade response has been flushed. At this time it is safe to mutate the channel /// pipeline to add whatever channel handlers are required. /// Until the returned `EventLoopFuture` succeeds, all received data will be buffered. @@ -85,24 +91,24 @@ public protocol NIOHTTPClientProtocolUpgrader { /// It will only upgrade to the protocol that is returned first in the list and does not currently /// have the capability to upgrade to multiple simultaneous layered protocols. public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableChannelHandler { - + public typealias OutboundIn = HTTPClientRequestPart public typealias OutboundOut = HTTPClientRequestPart public typealias InboundIn = HTTPClientResponsePart public typealias InboundOut = HTTPClientResponsePart - + private var upgraders: [NIOHTTPClientProtocolUpgrader] private let httpHandlers: [RemovableChannelHandler] private let upgradeCompletionHandler: (ChannelHandlerContext) -> Void - + /// Whether we've already seen the first response from our initial upgrade request. private var seenFirstResponse = false - + private var upgradeState: UpgradeState = .requestRequired - + private var receivedMessages: CircularBuffer = CircularBuffer() - + /// Create a `HTTPClientUpgradeHandler`. /// /// - Parameter upgraders: All `HTTPClientProtocolUpgrader` objects that will add their upgrade request @@ -121,7 +127,7 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC ) { self.init(_upgraders: upgraders, httpHandlers: httpHandlers, upgradeCompletionHandler: upgradeCompletionHandler) } - + private init( _upgraders upgraders: [NIOHTTPClientProtocolUpgrader], httpHandlers: [RemovableChannelHandler], @@ -137,7 +143,7 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { switch self.upgradeState { - + case .requestRequired: let updatedData = self.addHeadersToOutboundOut(data: data) context.write(updatedData, promise: promise) @@ -145,7 +151,7 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC case .awaitingConfirmationResponse: // Still have full http stack. context.write(data, promise: promise) - + case .upgraderReady, .upgrading: promise?.fail(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade) context.fireErrorCaught(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade) @@ -154,7 +160,7 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC // These are most likely messages immediately fired by a new protocol handler. // As that is added last we can just forward them on. context.write(data, promise: promise) - + case .upgradeComplete: // Upgrade complete and this handler should have been removed from the pipeline. promise?.fail(NIOHTTPClientUpgradeError.writingToHandlerAfterUpgradeCompleted) @@ -168,18 +174,18 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC } private func addHeadersToOutboundOut(data: NIOAny) -> NIOAny { - + let interceptedOutgoingRequest = Self.unwrapOutboundIn(data) - + if case .head(var requestHead) = interceptedOutgoingRequest { - + self.upgradeState = .awaitingConfirmationResponse - + self.addConnectionHeaders(to: &requestHead) self.addUpgradeHeaders(to: &requestHead) return Self.wrapOutboundOut(.head(requestHead)) } - + return data } @@ -190,7 +196,7 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC } private func addUpgradeHeaders(to requestHead: inout HTTPRequestHead) { - + let allProtocols = self.upgraders.map { $0.supportedProtocol.lowercased() } requestHead.headers.add(name: "Upgrade", value: allProtocols.joined(separator: ",")) @@ -199,9 +205,9 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC upgrader.addCustom(upgradeRequestHeaders: &requestHead.headers) } } - + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - + guard !self.seenFirstResponse else { // We're waiting for upgrade to complete: buffer this data. self.receivedMessages.append(data) @@ -235,9 +241,9 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC context.fireErrorCaught(NIOHTTPClientUpgradeError.receivedResponseBeforeRequestSent) } } - + private func firstResponseHeadReceived(context: ChannelHandlerContext, responsePart: HTTPClientResponsePart) { - + // We should decide if we're can upgrade based on the first response header: if we aren't upgrading, // by the time the body comes in we should be out of the pipeline. That means that if we don't think we're // upgrading, the only thing we should see is a response head. Anything else in an error. @@ -245,7 +251,7 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC self.notUpgrading(context: context, data: responsePart, error: .invalidHTTPOrdering) return } - + // Assess whether the upgrade response has accepted our upgrade request. guard case .switchingProtocols = response.status else { self.notUpgrading(context: context, data: responsePart, error: nil) @@ -260,44 +266,57 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC self.notUpgrading(context: context, data: responsePart, error: clientError) } } - - private func handleUpgrade(context: ChannelHandlerContext, upgradeResponse response: HTTPResponseHead) throws -> (() -> Void) { + + private func handleUpgrade( + context: ChannelHandlerContext, + upgradeResponse response: HTTPResponseHead + ) throws -> (() -> Void) { // Ok, we have a HTTP response. Check if it's an upgrade confirmation. // If it's not, we want to pass it on and remove ourselves from the channel pipeline. let acceptedProtocols = response.headers[canonicalForm: "upgrade"] - + // At the moment we only upgrade to the first protocol returned from the server. guard let protocolName = acceptedProtocols.first?.lowercased() else { // There are no upgrade protocols returned. throw NIOHTTPClientUpgradeError.responseProtocolNotFound } - return try self.handleUpgradeForProtocol(context: context, - protocolName: protocolName, - response: response) + return try self.handleUpgradeForProtocol( + context: context, + protocolName: protocolName, + response: response + ) } - + /// Attempt to upgrade a single protocol. - private func handleUpgradeForProtocol(context: ChannelHandlerContext, protocolName: String, response: HTTPResponseHead) throws -> (() -> Void) { + private func handleUpgradeForProtocol( + context: ChannelHandlerContext, + protocolName: String, + response: HTTPResponseHead + ) throws -> (() -> Void) { let matchingUpgrader = self.upgraders .first(where: { $0.supportedProtocol.lowercased() == protocolName }) - + guard let upgrader = matchingUpgrader else { // There is no upgrader for this protocol. throw NIOHTTPClientUpgradeError.responseProtocolNotFound } - + guard upgrader.shouldAllowUpgrade(upgradeResponse: response) else { // The upgrader says no. throw NIOHTTPClientUpgradeError.upgraderDeniedUpgrade } - + return self.performUpgrade(context: context, upgrader: upgrader, response: response) } - - private func performUpgrade(context: ChannelHandlerContext, upgrader: NIOHTTPClientProtocolUpgrader, response: HTTPResponseHead) -> () -> Void { + + private func performUpgrade( + context: ChannelHandlerContext, + upgrader: NIOHTTPClientProtocolUpgrader, + response: HTTPResponseHead + ) -> () -> Void { // Before we start the upgrade we have to remove the HTTPEncoder and HTTPDecoder handlers from the // pipeline, to prevent them parsing any more data. We'll buffer the incoming data until that completes. @@ -306,52 +325,52 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC // Once that's done, we call the internal handler, then call the upgrader code, and then finally when the // upgrader code is done, we do our final cleanup steps, namely we replay the received data we // buffered in the meantime and then remove ourselves from the pipeline. - return { + { self.upgradeState = .upgrading - + self.removeHTTPHandlers(context: context) - .map { - // Let the other handlers be removed before continuing with upgrade. - self.upgradeCompletionHandler(context) - self.upgradeState = .upgradingAddingHandlers - } - .flatMap { - upgrader.upgrade(context: context, upgradeResponse: response) - } - .map { - // We unbuffer any buffered data here. - - // If we received any, we fire readComplete. - let fireReadComplete = self.receivedMessages.count > 0 - while self.receivedMessages.count > 0 { - let bufferedPart = self.receivedMessages.removeFirst() - context.fireChannelRead(bufferedPart) + .map { + // Let the other handlers be removed before continuing with upgrade. + self.upgradeCompletionHandler(context) + self.upgradeState = .upgradingAddingHandlers } - if fireReadComplete { - context.fireChannelReadComplete() + .flatMap { + upgrader.upgrade(context: context, upgradeResponse: response) + } + .map { + // We unbuffer any buffered data here. + + // If we received any, we fire readComplete. + let fireReadComplete = self.receivedMessages.count > 0 + while self.receivedMessages.count > 0 { + let bufferedPart = self.receivedMessages.removeFirst() + context.fireChannelRead(bufferedPart) + } + if fireReadComplete { + context.fireChannelReadComplete() + } + + // We wait with the state change until _after_ the channel reads here. + // This is to prevent firing writes in response to these reads after we went to .upgradeComplete + // See: https://github.com/apple/swift-nio/issues/1279 + self.upgradeState = .upgradeComplete + } + .whenComplete { _ in + context.pipeline.removeHandler(context: context, promise: nil) } - - // We wait with the state change until _after_ the channel reads here. - // This is to prevent firing writes in response to these reads after we went to .upgradeComplete - // See: https://github.com/apple/swift-nio/issues/1279 - self.upgradeState = .upgradeComplete - } - .whenComplete { _ in - context.pipeline.removeHandler(context: context, promise: nil) - } } } - + /// Removes any extra HTTP-related handlers from the channel pipeline. private func removeHTTPHandlers(context: ChannelHandlerContext) -> EventLoopFuture { guard self.httpHandlers.count > 0 else { return context.eventLoop.makeSucceededFuture(()) } - + let removeFutures = self.httpHandlers.map { context.pipeline.removeHandler($0) } return .andAllSucceed(removeFutures, on: context.eventLoop) } - + private func gotUpgrader(upgrader: @escaping (() -> Void)) { self.upgradeState = .upgraderReady(upgrader) @@ -362,17 +381,21 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC } } - private func notUpgrading(context: ChannelHandlerContext, data: HTTPClientResponsePart, error: NIOHTTPClientUpgradeError?) { - + private func notUpgrading( + context: ChannelHandlerContext, + data: HTTPClientResponsePart, + error: NIOHTTPClientUpgradeError? + ) { + self.upgradeState = .upgradeFailed - + if let error = error { context.fireErrorCaught(error) } - + assert(self.receivedMessages.isEmpty) context.fireChannelRead(Self.wrapInboundOut(data)) - + // We've delivered the data. We can now remove ourselves, which should happen synchronously. context.pipeline.removeHandler(context: context, promise: nil) } @@ -388,19 +411,19 @@ extension NIOHTTPClientUpgradeHandler { /// Awaiting confirmation response which will allow the upgrade to zero one or more protocols. case awaitingConfirmationResponse - + /// The response head has been received. We have an upgrader, which means we can begin upgrade. case upgraderReady(() -> Void) - + /// The response head has been received. The upgrade is in process. case upgrading - + /// The upgrade is in process and all of the http handlers have been removed. case upgradingAddingHandlers - + /// The upgrade has succeeded, and we are being removed from the pipeline. case upgradeComplete - + /// The upgrade has failed. case upgradeFailed } diff --git a/Sources/NIOHTTP1/NIOHTTPObjectAggregator.swift b/Sources/NIOHTTP1/NIOHTTPObjectAggregator.swift index 0e9f4895ba..3821e6e2b4 100644 --- a/Sources/NIOHTTP1/NIOHTTPObjectAggregator.swift +++ b/Sources/NIOHTTP1/NIOHTTPObjectAggregator.swift @@ -123,7 +123,6 @@ internal enum AggregatorState { } } - mutating func messageEndReceived() throws { switch self { case .receiving: @@ -190,7 +189,7 @@ public final class NIOHTTPServerRequestAggregator: ChannelInboundHandler, Remova private var maxContentLength: Int private var closeOnExpectationFailed: Bool private var state: AggregatorState - + public init(maxContentLength: Int, closeOnExpectationFailed: Bool = false) { precondition(maxContentLength >= 0, "maxContentLength must not be negative") self.maxContentLength = maxContentLength @@ -240,7 +239,11 @@ public final class NIOHTTPServerRequestAggregator: ChannelInboundHandler, Remova } } - private func beginAggregation(context: ChannelHandlerContext, request: HTTPRequestHead, message: InboundIn) -> HTTPResponseHead? { + private func beginAggregation( + context: ChannelHandlerContext, + request: HTTPRequestHead, + message: InboundIn + ) -> HTTPResponseHead? { self.fullMessageHead = request if let contentLength = request.contentLength, contentLength > self.maxContentLength { return self.handleOversizeMessage(message: message) @@ -248,8 +251,12 @@ public final class NIOHTTPServerRequestAggregator: ChannelInboundHandler, Remova return nil } - private func aggregate(context: ChannelHandlerContext, content: inout ByteBuffer, message: InboundIn) -> HTTPResponseHead? { - if (content.readableBytes > self.maxContentLength - self.buffer.readableBytes) { + private func aggregate( + context: ChannelHandlerContext, + content: inout ByteBuffer, + message: InboundIn + ) -> HTTPResponseHead? { + if content.readableBytes > self.maxContentLength - self.buffer.readableBytes { return self.handleOversizeMessage(message: message) } else { self.buffer.writeBuffer(&content) @@ -266,8 +273,10 @@ public final class NIOHTTPServerRequestAggregator: ChannelInboundHandler, Remova aggregated.headers.add(contentsOf: headers) } - let fullMessage = NIOHTTPServerRequestFull(head: aggregated, - body: self.buffer.readableBytes > 0 ? self.buffer : nil) + let fullMessage = NIOHTTPServerRequestFull( + head: aggregated, + body: self.buffer.readableBytes > 0 ? self.buffer : nil + ) self.fullMessageHead = nil self.buffer.clear() context.fireChannelRead(NIOAny(fullMessage)) @@ -278,7 +287,8 @@ public final class NIOHTTPServerRequestAggregator: ChannelInboundHandler, Remova var payloadTooLargeHead = HTTPResponseHead( version: self.fullMessageHead?.version ?? .http1_1, status: .payloadTooLarge, - headers: HTTPHeaders([("content-length", "0")])) + headers: HTTPHeaders([("content-length", "0")]) + ) switch message { case .head(let request): @@ -369,7 +379,7 @@ public final class NIOHTTPClientResponseAggregator: ChannelInboundHandler, Remov } private func aggregate(context: ChannelHandlerContext, content: inout ByteBuffer) throws { - if (content.readableBytes > self.maxContentLength - self.buffer.readableBytes) { + if content.readableBytes > self.maxContentLength - self.buffer.readableBytes { self.state.handlingOversizeMessage() context.fireUserInboundEventTriggered(NIOHTTPObjectAggregatorEvent.httpFrameTooLong) context.fireErrorCaught(NIOHTTPObjectAggregatorError.frameTooLong) @@ -389,7 +399,8 @@ public final class NIOHTTPClientResponseAggregator: ChannelInboundHandler, Remov let fullMessage = NIOHTTPClientResponseFull( head: aggregated, - body: self.buffer.readableBytes > 0 ? self.buffer : nil) + body: self.buffer.readableBytes > 0 ? self.buffer : nil + ) self.fullMessageHead = nil self.buffer.clear() context.fireChannelRead(NIOAny(fullMessage)) diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift index a2be0227b8..30f168c4e4 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift @@ -75,7 +75,9 @@ public struct NIOTypedHTTPClientUpgradeConfiguration { /// It will only upgrade to the protocol that is returned first in the list and does not currently /// have the capability to upgrade to multiple simultaneous layered protocols. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -public final class NIOTypedHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableChannelHandler { +public final class NIOTypedHTTPClientUpgradeHandler: ChannelDuplexHandler, + RemovableChannelHandler +{ public typealias OutboundIn = HTTPClientRequestPart public typealias OutboundOut = HTTPClientRequestPart public typealias InboundIn = HTTPClientResponsePart diff --git a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift index 6e9c696811..be51bd22f5 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPClientUpgraderStateMachine.swift @@ -114,7 +114,7 @@ struct NIOTypedHTTPClientUpgraderStateMachine { case .unbuffering, .finished: return .forwardWrite - + case .modifying: fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") } @@ -157,7 +157,6 @@ struct NIOTypedHTTPClientUpgraderStateMachine { } } - @usableFromInline enum ChannelReadResponsePartAction { case fireErrorCaughtAndRemoveHandler(Error) @@ -202,7 +201,8 @@ struct NIOTypedHTTPClientUpgraderStateMachine { return .fireErrorCaughtAndRemoveHandler(NIOHTTPClientUpgradeError.responseProtocolNotFound) } - let matchingUpgrader = upgraders + let matchingUpgrader = + upgraders .first(where: { $0.supportedProtocol.lowercased() == protocolName }) guard let upgrader = matchingUpgrader else { @@ -219,10 +219,12 @@ struct NIOTypedHTTPClientUpgraderStateMachine { // We received the response head and decided that we can upgrade. // We now need to wait for the response end and then we can perform the upgrade - self.state = .awaitingUpgradeResponseEnd(.init( - upgrader: upgrader, - responseHead: response - )) + self.state = .awaitingUpgradeResponseEnd( + .init( + upgrader: upgrader, + responseHead: response + ) + ) return .none case .awaitingUpgradeResponseEnd(let awaitingUpgradeResponseEnd): @@ -248,7 +250,6 @@ struct NIOTypedHTTPClientUpgraderStateMachine { case .upgrading, .unbuffering, .finished: fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") - case .modifying: fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") } @@ -263,7 +264,8 @@ struct NIOTypedHTTPClientUpgraderStateMachine { } @inlinable - mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { + mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? + { switch self.state { case .initial, .awaitingUpgradeResponseHead, .awaitingUpgradeResponseEnd, .unbuffering: fatalError("Internal inconsistency in HTTPClientUpgradeStateMachine") diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift index 2ca3395890..dd08789a6d 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgradeHandler.swift @@ -75,7 +75,9 @@ public struct NIOTypedHTTPServerUpgradeConfiguration { /// requests that we choose to punt on it entirely and not allow it. As it happens this is mostly fine: /// the odds of someone needing to upgrade midway through the lifetime of a connection are very low. @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -public final class NIOTypedHTTPServerUpgradeHandler: ChannelInboundHandler, RemovableChannelHandler { +public final class NIOTypedHTTPServerUpgradeHandler: ChannelInboundHandler, + RemovableChannelHandler +{ public typealias InboundIn = HTTPServerRequestPart public typealias InboundOut = HTTPServerRequestPart public typealias OutboundOut = HTTPServerResponsePart @@ -101,7 +103,7 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch } /// Create a ``NIOTypedHTTPServerUpgradeHandler``. - /// + /// /// - Parameters: /// - httpEncoder: The ``HTTPResponseEncoder`` encoding responses from this handler and which will /// be removed from the pipeline once the upgrade response is sent. This is used to ensure @@ -143,7 +145,7 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch case .unwrapData: let requestPart = Self.unwrapInboundIn(data) self.channelRead(context: context, requestPart: requestPart) - + case .fireChannelRead: context.fireChannelRead(data) @@ -209,14 +211,24 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch case .startUnbuffering(let value): if let requestHeadAndProtocol = requestHeadAndProtocol { - context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) + context.fireUserInboundEventTriggered( + HTTPServerUpgradeEvents.upgradeComplete( + toProtocol: requestHeadAndProtocol.1, + upgradeRequest: requestHeadAndProtocol.0 + ) + ) } self.upgradeResultPromise.succeed(value) self.unbuffer(context: context) case .removeHandler(let value): if let requestHeadAndProtocol = requestHeadAndProtocol { - context.fireUserInboundEventTriggered(HTTPServerUpgradeEvents.upgradeComplete(toProtocol: requestHeadAndProtocol.1, upgradeRequest: requestHeadAndProtocol.0)) + context.fireUserInboundEventTriggered( + HTTPServerUpgradeEvents.upgradeComplete( + toProtocol: requestHeadAndProtocol.1, + upgradeRequest: requestHeadAndProtocol.0 + ) + ) } self.upgradeResultPromise.succeed(value) context.pipeline.removeHandler(self, promise: nil) @@ -235,7 +247,9 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch request: HTTPRequestHead, allHeaderNames: Set, connectionHeader: Set - ) -> EventLoopFuture<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?> { + ) -> EventLoopFuture< + (upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)? + > { // We want a local copy of the protocol iterator. We'll pass it to the next invocation of the function. var protocolIterator = protocolIterator guard let proto = protocolIterator.next() else { @@ -244,12 +258,24 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch } guard let upgrader = self.upgraders[proto.lowercased()] else { - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + return self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: request, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) } let requiredHeaders = Set(upgrader.requiredUpgradeHeaders.map { $0.lowercased() }) guard requiredHeaders.isSubset(of: allHeaderNames) && requiredHeaders.isSubset(of: connectionHeader) else { - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + return self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: request, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) } let responseHeaders = self.buildUpgradeHeaders(protocol: proto) @@ -263,14 +289,25 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch .flatMapError { error in // No upgrade here. We want to fire the error down the pipeline, and then try another loop iteration. context.fireErrorCaught(error) - return self.handleUpgradeForProtocol(context: context, protocolIterator: protocolIterator, request: request, allHeaderNames: allHeaderNames, connectionHeader: connectionHeader) + return self.handleUpgradeForProtocol( + context: context, + protocolIterator: protocolIterator, + request: request, + allHeaderNames: allHeaderNames, + connectionHeader: connectionHeader + ) } } private func findingUpgradeCompleted( context: ChannelHandlerContext, requestHead: HTTPRequestHead, - _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> + _ result: Result< + ( + upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, + proto: String + )?, Error + > ) { switch self.stateMachine.findingUpgraderCompleted(requestHead: requestHead, result) { case .startUpgrading(let upgrader, let responseHeaders, let proto): @@ -325,15 +362,18 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch }.flatMap { context.pipeline.syncOperations.removeHandler(self.httpEncoder) }.flatMap { () -> EventLoopFuture in - return upgrader.upgrade(channel: context.channel, upgradeRequest: requestHead) + upgrader.upgrade(channel: context.channel, upgradeRequest: requestHead) }.hop(to: context.eventLoop) - .whenComplete { result in - self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: (requestHead, proto)) - } + .whenComplete { result in + self.upgradingHandlerCompleted(context: context, result, requestHeadAndProtocol: (requestHead, proto)) + } } /// Sends the 101 Switching Protocols response for the pipeline. - private func sendUpgradeResponse(context: ChannelHandlerContext, responseHeaders: HTTPHeaders) -> EventLoopFuture { + private func sendUpgradeResponse( + context: ChannelHandlerContext, + responseHeaders: HTTPHeaders + ) -> EventLoopFuture { var response = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) response.headers = responseHeaders return context.writeAndFlush(wrapOutboundOut(HTTPServerResponsePart.head(response))) @@ -341,7 +381,7 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch /// Builds the initial mandatory HTTP headers for HTTP upgrade responses. private func buildUpgradeHeaders(`protocol`: String) -> HTTPHeaders { - return HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)]) + HTTPHeaders([("connection", "upgrade"), ("upgrade", `protocol`)]) } /// Removes any extra HTTP-related handlers from the channel pipeline. @@ -350,8 +390,10 @@ public final class NIOTypedHTTPServerUpgradeHandler: Ch return context.eventLoop.makeSucceededFuture(()) } - return .andAllSucceed(self.extraHTTPHandlers.map { context.pipeline.removeHandler($0) }, - on: context.eventLoop) + return .andAllSucceed( + self.extraHTTPHandlers.map { context.pipeline.removeHandler($0) }, + on: context.eventLoop + ) } private func unbuffer(context: ChannelHandlerContext) { diff --git a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift index bc2536f7c8..80fd018944 100644 --- a/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift +++ b/Sources/NIOHTTP1/NIOTypedHTTPServerUpgraderStateMachine.swift @@ -235,7 +235,6 @@ struct NIOTypedHTTPServerUpgraderStateMachine { case .upgrading, .unbuffering, .finished: fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - case .modifying: fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") } @@ -250,7 +249,8 @@ struct NIOTypedHTTPServerUpgraderStateMachine { } @inlinable - mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? { + mutating func upgradingHandlerCompleted(_ result: Result) -> UpgradingHandlerCompletedAction? + { switch self.state { case .initial: fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") @@ -293,7 +293,11 @@ struct NIOTypedHTTPServerUpgraderStateMachine { @usableFromInline enum FindingUpgraderCompletedAction { - case startUpgrading(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String) + case startUpgrading( + upgrader: any NIOTypedHTTPServerProtocolUpgrader, + responseHeaders: HTTPHeaders, + proto: String + ) case runNotUpgradingInitializer case fireErrorCaughtAndStartUnbuffering(Error) case fireErrorCaughtAndRemoveHandler(Error) @@ -302,7 +306,12 @@ struct NIOTypedHTTPServerUpgraderStateMachine { @inlinable mutating func findingUpgraderCompleted( requestHead: HTTPRequestHead, - _ result: Result<(upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, proto: String)?, Error> + _ result: Result< + ( + upgrader: any NIOTypedHTTPServerProtocolUpgrader, responseHeaders: HTTPHeaders, + proto: String + )?, Error + > ) -> FindingUpgraderCompletedAction? { switch self.state { case .initial, .upgraderReady: @@ -317,13 +326,15 @@ struct NIOTypedHTTPServerUpgraderStateMachine { return .startUpgrading(upgrader: upgrader, responseHeaders: responseHeaders, proto: proto) } else { // We have not yet seen the end so we have to wait until that happens - self.state = .upgraderReady(.init( - upgrader: upgrader, - requestHead: requestHead, - responseHeaders: responseHeaders, - proto: proto, - buffer: awaitingUpgrader.buffer - )) + self.state = .upgraderReady( + .init( + upgrader: upgrader, + requestHead: requestHead, + responseHeaders: responseHeaders, + proto: proto, + buffer: awaitingUpgrader.buffer + ) + ) return nil } @@ -378,7 +389,7 @@ struct NIOTypedHTTPServerUpgraderStateMachine { case .modifying: fatalError("Internal inconsistency in HTTPServerUpgradeStateMachine") - + } } diff --git a/Sources/NIOHTTP1Client/main.swift b/Sources/NIOHTTP1Client/main.swift index 03ef928066..4afd26600b 100644 --- a/Sources/NIOHTTP1Client/main.swift +++ b/Sources/NIOHTTP1Client/main.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// import NIOCore -import NIOPosix import NIOHTTP1 +import NIOPosix print("Please enter line to send to the server") let line = readLine(strippingNewline: true)! @@ -21,38 +21,40 @@ let line = readLine(strippingNewline: true)! private final class HTTPEchoHandler: ChannelInboundHandler { public typealias InboundIn = HTTPClientResponsePart public typealias OutboundOut = HTTPClientRequestPart - + public func channelActive(context: ChannelHandlerContext) { print("Client connected to \(context.remoteAddress!)") - + // We are connected. It's time to send the message to the server to initialize the ping-pong sequence. - + let buffer = context.channel.allocator.buffer(string: line) var headers = HTTPHeaders() headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") headers.add(name: "Content-Length", value: "\(buffer.readableBytes)") - + // This sample only sends an echo request. // The sample server has more functionality which can be easily tested by playing with the URI. // For example, try "/dynamic/count-to-ten" or "/dynamic/client-ip" - - let requestHead = HTTPRequestHead(version: .http1_1, - method: .GET, - uri: "/dynamic/echo", - headers: headers) - + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/dynamic/echo", + headers: headers + ) + context.write(Self.wrapOutboundOut(.head(requestHead)), promise: nil) - + context.write(Self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) - + context.writeAndFlush(Self.wrapOutboundOut(.end(nil)), promise: nil) } public func channelRead(context: ChannelHandlerContext, data: NIOAny) { let clientResponse = Self.unwrapInboundIn(data) - + switch clientResponse { case .head(let responseHead): print("Received status: \(responseHead.status)") @@ -79,8 +81,10 @@ let bootstrap = ClientBootstrap(group: group) // Enable SO_REUSEADDR. .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in - channel.pipeline.addHTTPClientHandlers(position: .first, - leftOverBytesStrategy: .fireError).flatMap { + channel.pipeline.addHTTPClientHandlers( + position: .first, + leftOverBytesStrategy: .fireError + ).flatMap { channel.pipeline.addHandler(HTTPEchoHandler()) } } @@ -103,14 +107,14 @@ enum ConnectTo { let connectTarget: ConnectTo switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { -case (.some(let h), _ , .some(let p)): - /* we got two arguments, let's interpret that as host and port */ +case (.some(let h), _, .some(let p)): + // we got two arguments, let's interpret that as host and port connectTarget = .ip(host: h, port: p) case (.some(let portString), .none, _): - /* couldn't parse as number, expecting unix domain socket path */ + // couldn't parse as number, expecting unix domain socket path connectTarget = .unixDomainSocket(path: portString) case (_, .some(let p), _): - /* only one argument --> port */ + // only one argument --> port connectTarget = .ip(host: defaultHost, port: p) default: connectTarget = .ip(host: defaultHost, port: defaultPort) diff --git a/Sources/NIOHTTP1Server/main.swift b/Sources/NIOHTTP1Server/main.swift index 75b182ae8f..728512fac8 100644 --- a/Sources/NIOHTTP1Server/main.swift +++ b/Sources/NIOHTTP1Server/main.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// import NIOCore -import NIOPosix import NIOHTTP1 +import NIOPosix extension String { func chopPrefix(_ prefix: String) -> String? { @@ -34,7 +34,11 @@ extension String { } } -private func httpResponseHead(request: HTTPRequestHead, status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) -> HTTPResponseHead { +private func httpResponseHead( + request: HTTPRequestHead, + status: HTTPResponseStatus, + headers: HTTPHeaders = HTTPHeaders() +) -> HTTPResponseHead { var head = HTTPResponseHead(version: request.version, status: status, headers: headers) let connectionHeaders: [String] = head.headers[canonicalForm: "connection"].map { $0.lowercased() } @@ -117,19 +121,24 @@ private final class HTTPHandler: ChannelInboundHandler { case .end: self.state.requestComplete() let response = """ - HTTP method: \(self.infoSavedRequestHead!.method)\r - URL: \(self.infoSavedRequestHead!.uri)\r - body length: \(self.infoSavedBodyBytes)\r - headers: \(self.infoSavedRequestHead!.headers)\r - client: \(context.remoteAddress?.description ?? "zombie")\r - IO: SwiftNIO Electric Boogaloo™️\r\n - """ + HTTP method: \(self.infoSavedRequestHead!.method)\r + URL: \(self.infoSavedRequestHead!.uri)\r + body length: \(self.infoSavedBodyBytes)\r + headers: \(self.infoSavedRequestHead!.headers)\r + client: \(context.remoteAddress?.description ?? "zombie")\r + IO: SwiftNIO Electric Boogaloo™️\r\n + """ self.buffer.clear() self.buffer.writeString(response) var headers = HTTPHeaders() headers.add(name: "Content-Length", value: "\(response.utf8.count)") - context.write(Self.wrapOutboundOut(.head(httpResponseHead(request: self.infoSavedRequestHead!, status: .ok, headers: headers))), promise: nil) - context.write(Self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) + context.write( + Self.wrapOutboundOut( + .head(httpResponseHead(request: self.infoSavedRequestHead!, status: .ok, headers: headers)) + ), + promise: nil + ) + context.write(self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) self.completeResponse(context, trailers: nil, promise: nil) } } @@ -147,7 +156,10 @@ private final class HTTPHandler: ChannelInboundHandler { if balloonInMemory { self.buffer.clear() } else { - context.writeAndFlush(Self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), promise: nil) + context.writeAndFlush( + Self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), + promise: nil + ) } case .body(buffer: var buf): if balloonInMemory { @@ -160,8 +172,13 @@ private final class HTTPHandler: ChannelInboundHandler { if balloonInMemory { var headers = HTTPHeaders() headers.add(name: "Content-Length", value: "\(self.buffer.readableBytes)") - context.write(Self.wrapOutboundOut(.head(httpResponseHead(request: self.infoSavedRequestHead!, status: .ok, headers: headers))), promise: nil) - context.write(Self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) + context.write( + Self.wrapOutboundOut( + .head(httpResponseHead(request: self.infoSavedRequestHead!, status: .ok, headers: headers)) + ), + promise: nil + ) + context.write(self.wrapOutboundOut(.body(.byteBuffer(self.buffer))), promise: nil) self.completeResponse(context, trailers: nil, promise: nil) } else { self.completeResponse(context, trailers: nil, promise: nil) @@ -169,12 +186,22 @@ private final class HTTPHandler: ChannelInboundHandler { } } - func handleJustWrite(context: ChannelHandlerContext, request: HTTPServerRequestPart, statusCode: HTTPResponseStatus = .ok, string: String, trailer: (String, String)? = nil, delay: TimeAmount = .nanoseconds(0)) { + func handleJustWrite( + context: ChannelHandlerContext, + request: HTTPServerRequestPart, + statusCode: HTTPResponseStatus = .ok, + string: String, + trailer: (String, String)? = nil, + delay: TimeAmount = .nanoseconds(0) + ) { switch request { case .head(let request): self.keepAlive = request.isKeepAlive self.state.requestReceived() - context.writeAndFlush(Self.wrapOutboundOut(.head(httpResponseHead(request: request, status: statusCode))), promise: nil) + context.writeAndFlush( + Self.wrapOutboundOut(.head(httpResponseHead(request: request, status: statusCode))), + promise: nil + ) case .body(buffer: _): () case .end: @@ -210,7 +237,10 @@ private final class HTTPHandler: ChannelInboundHandler { self.completeResponse(context, trailers: nil, promise: nil) } } - context.writeAndFlush(Self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), promise: nil) + context.writeAndFlush( + Self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), + promise: nil + ) doNext() case .end: self.state.requestComplete() @@ -219,7 +249,12 @@ private final class HTTPHandler: ChannelInboundHandler { } } - func handleMultipleWrites(context: ChannelHandlerContext, request: HTTPServerRequestPart, strings: [String], delay: TimeAmount) { + func handleMultipleWrites( + context: ChannelHandlerContext, + request: HTTPServerRequestPart, + strings: [String], + delay: TimeAmount + ) { switch request { case .head(let request): self.keepAlive = request.isKeepAlive @@ -237,7 +272,10 @@ private final class HTTPHandler: ChannelInboundHandler { } } } - context.writeAndFlush(Self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), promise: nil) + context.writeAndFlush( + Self.wrapOutboundOut(.head(httpResponseHead(request: request, status: .ok))), + promise: nil + ) doNext() case .end: self.state.requestComplete() @@ -249,9 +287,12 @@ private final class HTTPHandler: ChannelInboundHandler { func dynamicHandler(request reqHead: HTTPRequestHead) -> ((ChannelHandlerContext, HTTPServerRequestPart) -> Void)? { if let howLong = reqHead.uri.chopPrefix("/dynamic/write-delay/") { return { context, req in - self.handleJustWrite(context: context, - request: req, string: self.defaultResponse, - delay: Int64(howLong).map { .milliseconds($0) } ?? .seconds(0)) + self.handleJustWrite( + context: context, + request: req, + string: self.defaultResponse, + delay: Int64(howLong).map { .milliseconds($0) } ?? .seconds(0) + ) } } @@ -263,23 +304,57 @@ private final class HTTPHandler: ChannelInboundHandler { case "/dynamic/pid": return { context, req in self.handleJustWrite(context: context, request: req, string: "\(getpid())") } case "/dynamic/write-delay": - return { context, req in self.handleJustWrite(context: context, request: req, string: self.defaultResponse, delay: .milliseconds(100)) } + return { context, req in + self.handleJustWrite( + context: context, + request: req, + string: self.defaultResponse, + delay: .milliseconds(100) + ) + } case "/dynamic/info": return self.handleInfo case "/dynamic/trailers": - return { context, req in self.handleJustWrite(context: context, request: req, string: "\(getpid())\r\n", trailer: ("Trailer-Key", "Trailer-Value")) } + return { context, req in + self.handleJustWrite( + context: context, + request: req, + string: "\(getpid())\r\n", + trailer: ("Trailer-Key", "Trailer-Value") + ) + } case "/dynamic/continuous": return self.handleContinuousWrites case "/dynamic/count-to-ten": - return { self.handleMultipleWrites(context: $0, request: $1, strings: (1...10).map { "\($0)" }, delay: .milliseconds(100)) } + return { + self.handleMultipleWrites( + context: $0, + request: $1, + strings: (1...10).map { "\($0)" }, + delay: .milliseconds(100) + ) + } case "/dynamic/client-ip": - return { context, req in self.handleJustWrite(context: context, request: req, string: "\(context.remoteAddress.debugDescription)") } + return { context, req in + self.handleJustWrite( + context: context, + request: req, + string: "\(context.remoteAddress.debugDescription)" + ) + } default: - return { context, req in self.handleJustWrite(context: context, request: req, statusCode: .notFound, string: "not found") } + return { context, req in + self.handleJustWrite(context: context, request: req, statusCode: .notFound, string: "not found") + } } } - private func handleFile(context: ChannelHandlerContext, request: HTTPServerRequestPart, ioMethod: FileIOMethod, path: String) { + private func handleFile( + context: ChannelHandlerContext, + request: HTTPServerRequestPart, + ioMethod: FileIOMethod, + path: String + ) { self.buffer.clear() func sendErrorResponse(request: HTTPRequestHead, _ error: Error) { @@ -338,15 +413,17 @@ private final class HTTPHandler: ChannelInboundHandler { responseStarted = true context.write(Self.wrapOutboundOut(.head(response)), promise: nil) } - return self.fileIO.readChunked(fileRegion: region, - chunkSize: 32 * 1024, - allocator: context.channel.allocator, - eventLoop: context.eventLoop) { buffer in - if !responseStarted { - responseStarted = true - context.write(Self.wrapOutboundOut(.head(response)), promise: nil) - } - return context.writeAndFlush(Self.wrapOutboundOut(.body(.byteBuffer(buffer)))) + return self.fileIO.readChunked( + fileRegion: region, + chunkSize: 32 * 1024, + allocator: context.channel.allocator, + eventLoop: context.eventLoop + ) { buffer in + if !responseStarted { + responseStarted = true + context.write(Self.wrapOutboundOut(.head(response)), promise: nil) + } + return context.writeAndFlush(Self.wrapOutboundOut(.body(.byteBuffer(buffer)))) }.flatMap { () -> EventLoopFuture in let p = context.eventLoop.makePromise(of: Void.self) self.completeResponse(context, trailers: nil, promise: p) @@ -379,7 +456,7 @@ private final class HTTPHandler: ChannelInboundHandler { _ = try? file.close() } } - } + } case .end: self.state.requestComplete() default: @@ -387,7 +464,11 @@ private final class HTTPHandler: ChannelInboundHandler { } } - private func completeResponse(_ context: ChannelHandlerContext, trailers: HTTPHeaders?, promise: EventLoopPromise?) { + private func completeResponse( + _ context: ChannelHandlerContext, + trailers: HTTPHeaders?, + promise: EventLoopPromise? + ) { self.state.responseComplete() let promise = self.keepAlive ? promise : (promise ?? context.eventLoop.makePromise()) @@ -469,7 +550,7 @@ private final class HTTPHandler: ChannelInboundHandler { } // First argument is the program path -var arguments = CommandLine.arguments.dropFirst(0) // just to get an ArraySlice from [String] +var arguments = CommandLine.arguments.dropFirst(0) // just to get an ArraySlice from [String] var allowHalfClosure = true if arguments.dropFirst().first == .some("--disable-half-closure") { allowHalfClosure = false @@ -493,16 +574,16 @@ let htdocs: String let bindTarget: BindTo switch (arg1, arg1.flatMap(Int.init), arg2, arg2.flatMap(Int.init), arg3) { -case (.some(let h), _ , _, .some(let p), let maybeHtdocs): - /* second arg an integer --> host port [htdocs] */ +case (.some(let h), _, _, .some(let p), let maybeHtdocs): + // second arg an integer --> host port [htdocs] bindTarget = .ip(host: h, port: p) htdocs = maybeHtdocs ?? defaultHtdocs case (_, .some(let p), let maybeHtdocs, _, _): - /* first arg an integer --> port [htdocs] */ + // first arg an integer --> port [htdocs] bindTarget = .ip(host: defaultHost, port: p) htdocs = maybeHtdocs ?? defaultHtdocs case (.some(let portString), .none, let maybeHtdocs, .none, .none): - /* couldn't parse as number --> uds-path-or-stdio [htdocs] */ + // couldn't parse as number --> uds-path-or-stdio [htdocs] if portString == "-" { bindTarget = .stdio } else { @@ -515,7 +596,7 @@ default: } func childChannelInitializer(channel: Channel) -> EventLoopFuture { - return channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { + channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { channel.pipeline.addHandler(HTTPHandler(fileIO: fileIO, htdocsPath: htdocs)) } } @@ -557,7 +638,9 @@ if case .stdio = bindTarget { localAddress = "STDIO" } else { guard let channelLocalAddress = channel.localAddress else { - fatalError("Address was unable to bind. Please check that the socket was not closed or that the address family was understood.") + fatalError( + "Address was unable to bind. Please check that the socket was not closed or that the address family was understood." + ) } localAddress = "\(channelLocalAddress)" } diff --git a/Sources/NIOMulticastChat/main.swift b/Sources/NIOMulticastChat/main.swift index 60cf45486d..60818b0f3f 100644 --- a/Sources/NIOMulticastChat/main.swift +++ b/Sources/NIOMulticastChat/main.swift @@ -32,7 +32,6 @@ private final class ChatMessageDecoder: ChannelInboundHandler { } } - private final class ChatMessageEncoder: ChannelOutboundHandler { public typealias OutboundIn = AddressedEnvelope public typealias OutboundOut = AddressedEnvelope @@ -40,15 +39,18 @@ private final class ChatMessageEncoder: ChannelOutboundHandler { func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { let message = Self.unwrapOutboundIn(data) let buffer = context.channel.allocator.buffer(string: message.data) - context.write(Self.wrapOutboundOut(AddressedEnvelope(remoteAddress: message.remoteAddress, data: buffer)), promise: promise) + context.write( + Self.wrapOutboundOut(AddressedEnvelope(remoteAddress: message.remoteAddress, data: buffer)), + promise: promise + ) } } - // We allow users to specify the interface they want to use here. let targetDevice: NIONetworkDevice? = { if let interfaceAddress = CommandLine.arguments.dropFirst().first, - let targetAddress = try? SocketAddress(ipAddress: interfaceAddress, port: 0) { + let targetAddress = try? SocketAddress(ipAddress: interfaceAddress, port: 0) + { for device in try! System.enumerateDevices() { if device.address == targetAddress { return device @@ -59,7 +61,6 @@ let targetDevice: NIONetworkDevice? = { return nil }() - // For this chat protocol we temporarily squat on 224.1.0.26. This is a reserved multicast IPv4 address, // so your machine is unlikely to have already joined this group. That helps properly demonstrate correct // operation. We use port 7654 because, well, because why not. @@ -70,13 +71,14 @@ let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) var datagramBootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in - return channel.pipeline.addHandler(ChatMessageEncoder()).flatMap { + channel.pipeline.addHandler(ChatMessageEncoder()).flatMap { channel.pipeline.addHandler(ChatMessageDecoder()) } } - // We cast our channel to MulticastChannel to obtain the multicast operations. -let datagramChannel = try datagramBootstrap +// We cast our channel to MulticastChannel to obtain the multicast operations. +let datagramChannel = + try datagramBootstrap .bind(host: "0.0.0.0", port: 7654) .flatMap { channel -> EventLoopFuture in let channel = channel as! MulticastChannel @@ -95,7 +97,9 @@ let datagramChannel = try datagramBootstrap case .some(.unixDomainSocket): preconditionFailure("Should not be possible to create a multicast socket on a unix domain socket") case .none: - preconditionFailure("Should not be possible to create a multicast socket on an interface without an address") + preconditionFailure( + "Should not be possible to create a multicast socket on an interface without an address" + ) } }.wait() diff --git a/Sources/NIOPerformanceTester/Benchmark.swift b/Sources/NIOPerformanceTester/Benchmark.swift index f17a35e649..0ab5f5fc7c 100644 --- a/Sources/NIOPerformanceTester/Benchmark.swift +++ b/Sources/NIOPerformanceTester/Benchmark.swift @@ -26,7 +26,7 @@ func measureAndPrint(desc: String, benchmark bench: B) throws { bench.tearDown() } try measureAndPrint(desc: desc) { - return try bench.run() + try bench.run() } } @@ -48,7 +48,7 @@ func measureAndPrint(desc: String, benchmark bench: B) throws bench.tearDown() } try await measureAndPrint(desc: desc) { - return try await bench.run() + try await bench.run() } } group.leave() diff --git a/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift b/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift index 489af7b4a1..b8a2691101 100644 --- a/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift +++ b/Sources/NIOPerformanceTester/ByteBufferWriteMultipleBenchmarks.swift @@ -67,7 +67,16 @@ final class ByteBufferMultiReadWriteTenIntegersBenchmark: for _ in 0.. Int { try! self.loop.submit { diff --git a/Sources/NIOPerformanceTester/LockBenchmark.swift b/Sources/NIOPerformanceTester/LockBenchmark.swift index 114062bc56..be23c10b43 100644 --- a/Sources/NIOPerformanceTester/LockBenchmark.swift +++ b/Sources/NIOPerformanceTester/LockBenchmark.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -import NIOCore -import NIOPosix import Dispatch import NIOConcurrencyHelpers +import NIOCore +import NIOPosix final class NIOLockBenchmark: Benchmark { private let numberOfThreads: Int @@ -28,23 +28,23 @@ final class NIOLockBenchmark: Benchmark { private var opsDone = 0 private let lock = NIOLock() - + init(numberOfThreads: Int, lockOperationsPerThread: Int) { self.numberOfThreads = numberOfThreads self.lockOperationsPerThread = lockOperationsPerThread self.threadPool = NIOThreadPool(numberOfThreads: numberOfThreads) self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) } - + func setUp() throws { self.threadPool.start() } - + func tearDown() { try! self.threadPool.syncShutdownGracefully() try! self.group.syncShutdownGracefully() } - + func run() throws -> Int { self.lock.withLock { self.opsDone = 0 @@ -53,13 +53,13 @@ final class NIOLockBenchmark: Benchmark { _ = self.threadPool.runIfActive(eventLoop: self.group.next()) { self.sem1.signal() self.sem2.wait() - - for _ in 0 ..< self.lockOperationsPerThread { + + for _ in 0.. + fileprivate typealias SequenceProducer = NIOThrowingAsyncSequenceProducer< + Int, Error, NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, NIOAsyncSequenceProducerBenchmark + > private let iterations: Int private var iterator: SequenceProducer.AsyncIterator! diff --git a/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift b/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift index ab14f33182..3443edef19 100644 --- a/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift +++ b/Sources/NIOPerformanceTester/NIOAsyncWriterSingleWritesBenchmark.swift @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// -import NIOCore -import DequeModule import Atomics +import DequeModule +import NIOCore -fileprivate struct NoOpDelegate: NIOAsyncWriterSinkDelegate, @unchecked Sendable { +private struct NoOpDelegate: NIOAsyncWriterSinkDelegate, @unchecked Sendable { typealias Element = Int let counter = ManagedAtomic(0) @@ -38,7 +38,11 @@ final class NIOAsyncWriterSingleWritesBenchmark: AsyncBenchmark, @unchecked Send init(iterations: Int) { self.iterations = iterations self.delegate = .init() - let newWriter = NIOAsyncWriter.makeWriter(isWritable: true, finishOnDeinit: false, delegate: self.delegate) + let newWriter = NIOAsyncWriter.makeWriter( + isWritable: true, + finishOnDeinit: false, + delegate: self.delegate + ) self.writer = newWriter.writer self.sink = newWriter.sink } diff --git a/Sources/NIOPerformanceTester/RunIfActiveBenchmark.swift b/Sources/NIOPerformanceTester/RunIfActiveBenchmark.swift index f09a266971..7cebce2763 100644 --- a/Sources/NIOPerformanceTester/RunIfActiveBenchmark.swift +++ b/Sources/NIOPerformanceTester/RunIfActiveBenchmark.swift @@ -40,13 +40,13 @@ final class RunIfActiveBenchmark: Benchmark { let semaphore = DispatchSemaphore(value: 0) let eventLoop = MultiThreadedEventLoopGroup.singleton.any() let futures = (0.. Int { try! self.loop.submit { diff --git a/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift b/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift index b661bcf57b..e4b836ba62 100644 --- a/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift +++ b/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift @@ -92,7 +92,7 @@ final class TCPThroughputBenchmark: Benchmark { public func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.messagesReceived += 1 - if (self.expectedMessages == self.messagesReceived) { + if self.expectedMessages == self.messagesReceived { let promise = self.completionPromise self.messagesReceived = 0 @@ -133,9 +133,9 @@ final class TCPThroughputBenchmark: Benchmark { .wait() var message = self.serverChannel.allocator.buffer(capacity: self.messageSize) - message.writeInteger(UInt16(messageSize), as:UInt16.self) + message.writeInteger(UInt16(messageSize), as: UInt16.self) for idx in 0..<(self.messageSize - MemoryLayout.stride) { - message.writeInteger(UInt8(truncatingIfNeeded: idx), endianness:.little, as:UInt8.self) + message.writeInteger(UInt8(truncatingIfNeeded: idx), endianness: .little, as: UInt8.self) } self.message = message @@ -154,7 +154,10 @@ final class TCPThroughputBenchmark: Benchmark { let expectedMessages = self.messages try clientChannel.eventLoop.submit { - try clientChannel.pipeline.syncOperations.handler(type: ClientHandler.self).prepareRun(expectedMessages: expectedMessages, promise: isDonePromise) + try clientChannel.pipeline.syncOperations.handler(type: ClientHandler.self).prepareRun( + expectedMessages: expectedMessages, + promise: isDonePromise + ) }.wait() let serverHandler = self.serverHandler! diff --git a/Sources/NIOPerformanceTester/UDPBenchmark.swift b/Sources/NIOPerformanceTester/UDPBenchmark.swift index a3840b50ad..090b261bb8 100644 --- a/Sources/NIOPerformanceTester/UDPBenchmark.swift +++ b/Sources/NIOPerformanceTester/UDPBenchmark.swift @@ -47,7 +47,7 @@ extension UDPBenchmark: Benchmark { // zero is the same as not applying the option. .channelOption(ChannelOptions.datagramVectorReadMessageCount, value: self.vectorReads) .channelInitializer { channel in - return channel.pipeline.addHandler(EchoHandler()) + channel.pipeline.addHandler(EchoHandler()) } .bind(to: address) .wait() @@ -58,11 +58,15 @@ extension UDPBenchmark: Benchmark { // zero is the same as not applying the option. .channelOption(ChannelOptions.datagramVectorReadMessageCount, value: self.vectorReads) .channelInitializer { channel in - let handler = EchoHandlerClient(eventLoop: channel.eventLoop, - config: .init(remoteAddress: remoteAddress, - request: self.data, - requests: self.numberOfRequests, - writesPerFlush: self.vectorWrites)) + let handler = EchoHandlerClient( + eventLoop: channel.eventLoop, + config: .init( + remoteAddress: remoteAddress, + request: self.data, + requests: self.numberOfRequests, + writesPerFlush: self.vectorWrites + ) + ) return channel.pipeline.addHandler(handler) } .bind(to: address) @@ -82,7 +86,6 @@ extension UDPBenchmark: Benchmark { } } - extension UDPBenchmark { final class EchoHandler: ChannelInboundHandler { typealias InboundIn = AddressedEnvelope @@ -241,7 +244,7 @@ extension UDPBenchmark { self.state.run(requests: self.config.requests, writesPerFlush: self.config.writesPerFlush, promise: promise) let context = self.context! - for _ in 0 ..< self.config.writesPerFlush { + for _ in 0..(remoteAddress: self.config.remoteAddress, data: self.config.request) + let envolope = AddressedEnvelope( + remoteAddress: self.config.remoteAddress, + data: self.config.request + ) context.write(Self.wrapOutboundOut(envolope), promise: nil) if flush { context.flush() diff --git a/Sources/NIOPerformanceTester/WebSocketFrameDecoderBenchmark.swift b/Sources/NIOPerformanceTester/WebSocketFrameDecoderBenchmark.swift index 260b4bd56c..4e1697fb71 100644 --- a/Sources/NIOPerformanceTester/WebSocketFrameDecoderBenchmark.swift +++ b/Sources/NIOPerformanceTester/WebSocketFrameDecoderBenchmark.swift @@ -35,7 +35,9 @@ extension WebSocketFrameDecoderBenchmark: Benchmark { func setUp() throws { self.data = ByteBufferAllocator().webSocketFrame(size: dataSize, maskingKey: maskingKey) - try self.channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: dataSize))) + try self.channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: dataSize)) + ) } func tearDown() { @@ -45,7 +47,7 @@ extension WebSocketFrameDecoderBenchmark: Benchmark { func run() throws -> Int { for _ in 0.. private var frame: Optional - init(dataSize: Int, runCount: Int, dataStrategy: DataStrategy, cowStrategy: CoWStrategy, maskingKeyStrategy: MaskingKeyStrategy) { + init( + dataSize: Int, + runCount: Int, + dataStrategy: DataStrategy, + cowStrategy: CoWStrategy, + maskingKeyStrategy: MaskingKeyStrategy + ) { self.frame = nil self.channel = EmbeddedChannel() self.dataSize = dataSize @@ -37,7 +43,6 @@ final class WebSocketFrameEncoderBenchmark { } } - extension WebSocketFrameEncoderBenchmark { enum DataStrategy { case spaceAtFront @@ -45,7 +50,6 @@ extension WebSocketFrameEncoderBenchmark { } } - extension WebSocketFrameEncoderBenchmark { enum CoWStrategy { case always @@ -53,7 +57,6 @@ extension WebSocketFrameEncoderBenchmark { } } - extension WebSocketFrameEncoderBenchmark { enum MaskingKeyStrategy { case always @@ -61,7 +64,6 @@ extension WebSocketFrameEncoderBenchmark { } } - extension WebSocketFrameEncoderBenchmark: Benchmark { func setUp() throws { // We want the pipeline walk to have some cost. @@ -106,7 +108,6 @@ extension WebSocketFrameEncoderBenchmark: Benchmark { } } - extension ByteBufferAllocator { fileprivate func buffer(size: Int, dataStrategy: WebSocketFrameEncoderBenchmark.DataStrategy) -> ByteBuffer { var data: ByteBuffer @@ -125,13 +126,12 @@ extension ByteBufferAllocator { } } -fileprivate final class NoOpOutboundHandler: ChannelOutboundHandler { +private final class NoOpOutboundHandler: ChannelOutboundHandler { typealias OutboundIn = Any typealias OutboundOut = Any } - -fileprivate final class WriteConsumingHandler: ChannelOutboundHandler { +private final class WriteConsumingHandler: ChannelOutboundHandler { typealias OutboundIn = Any typealias OutboundOut = Never diff --git a/Sources/NIOPerformanceTester/main.swift b/Sources/NIOPerformanceTester/main.swift index e0408e7f15..3b40b7a15c 100644 --- a/Sources/NIOPerformanceTester/main.swift +++ b/Sources/NIOPerformanceTester/main.swift @@ -11,12 +11,15 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + +// swift-format-ignore: AmbiguousTrailingClosureOverload + +import Dispatch import NIOCore -import NIOPosix import NIOEmbedded -import NIOHTTP1 import NIOFoundationCompat -import Dispatch +import NIOHTTP1 +import NIOPosix import NIOWebSocket // Use unbuffered stdout to help detect exactly which test was running in the event of a crash. @@ -25,13 +28,15 @@ setbuf(stdout, nil) // MARK: Test Harness var warning: String = "" -assert({ - print("======================================================") - print("= YOU ARE RUNNING NIOPerformanceTester IN DEBUG MODE =") - print("======================================================") - warning = " <<< DEBUG MODE >>>" - return true - }()) +assert( + { + print("======================================================") + print("= YOU ARE RUNNING NIOPerformanceTester IN DEBUG MODE =") + print("======================================================") + warning = " <<< DEBUG MODE >>>" + return true + }() +) public func measure(_ fn: () throws -> Int) rethrows -> [Double] { func measureOne(_ fn: () throws -> Int) rethrows -> Double { @@ -41,7 +46,7 @@ public func measure(_ fn: () throws -> Int) rethrows -> [Double] { return Double(end - start) / Double(TimeAmount.seconds(1).nanoseconds) } - _ = try measureOne(fn) /* pre-heat and throw away */ + _ = try measureOne(fn) // pre-heat and throw away var measurements = Array(repeating: 0.0, count: 10) for i in 0..<10 { measurements[i] = try measureOne(fn) @@ -52,7 +57,7 @@ public func measure(_ fn: () throws -> Int) rethrows -> [Double] { let limitSet = CommandLine.arguments.dropFirst() -public func measureAndPrint(desc: String, fn: () throws -> Int) rethrows -> Void { +public func measureAndPrint(desc: String, fn: () throws -> Int) rethrows { if limitSet.isEmpty || limitSet.contains(desc) { print("measuring\(warning): \(desc): ", terminator: "") let measurements = try measure(fn) @@ -71,7 +76,7 @@ public func measure(_ fn: () async throws -> Int) async rethrows -> [Double] { return Double(end - start) / Double(TimeAmount.seconds(1).nanoseconds) } - _ = try await measureOne(fn) /* pre-heat and throw away */ + _ = try await measureOne(fn) // pre-heat and throw away var measurements = Array(repeating: 0.0, count: 10) for i in 0..<10 { measurements[i] = try await measureOne(fn) @@ -81,7 +86,7 @@ public func measure(_ fn: () async throws -> Int) async rethrows -> [Double] { } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -public func measureAndPrint(desc: String, fn: () async throws -> Int) async rethrows -> Void { +public func measureAndPrint(desc: String, fn: () async throws -> Int) async rethrows { if limitSet.isEmpty || limitSet.contains(desc) { print("measuring\(warning): \(desc): ", terminator: "") let measurements = try await measure(fn) @@ -149,7 +154,6 @@ private final class SimpleHTTPServer: ChannelInboundHandler { } } - let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) defer { try! group.syncShutdownGracefully() @@ -276,9 +280,9 @@ measureAndPrint(desc: "bytebuffer_write_12MB_short_string_literals") { let bufferSize = 12 * 1024 * 1024 var buffer = ByteBufferAllocator().buffer(capacity: bufferSize) - for _ in 0 ..< 3 { + for _ in 0..<3 { buffer.clear() - for _ in 0 ..< (bufferSize / 4) { + for _ in 0..<(bufferSize / 4) { buffer.writeString("abcd") } } @@ -293,9 +297,9 @@ measureAndPrint(desc: "bytebuffer_write_12MB_short_calculated_strings") { var buffer = ByteBufferAllocator().buffer(capacity: bufferSize) let s = someString(size: 4) - for _ in 0 ..< 1 { + for _ in 0..<1 { buffer.clear() - for _ in 0 ..< (bufferSize / 4) { + for _ in 0..<(bufferSize / 4) { buffer.writeString(s) } } @@ -309,9 +313,9 @@ measureAndPrint(desc: "bytebuffer_write_12MB_medium_string_literals") { let bufferSize = 12 * 1024 * 1024 var buffer = ByteBufferAllocator().buffer(capacity: bufferSize) - for _ in 0 ..< 100 { + for _ in 0..<100 { buffer.clear() - for _ in 0 ..< (bufferSize / 24) { + for _ in 0..<(bufferSize / 24) { buffer.writeString("012345678901234567890123") } } @@ -326,9 +330,9 @@ measureAndPrint(desc: "bytebuffer_write_12MB_medium_calculated_strings") { var buffer = ByteBufferAllocator().buffer(capacity: bufferSize) let s = someString(size: 24) - for _ in 0 ..< 5 { + for _ in 0..<5 { buffer.clear() - for _ in 0 ..< (bufferSize / 24) { + for _ in 0..<(bufferSize / 24) { buffer.writeString(s) } } @@ -343,9 +347,9 @@ measureAndPrint(desc: "bytebuffer_write_12MB_large_calculated_strings") { var buffer = ByteBufferAllocator().buffer(capacity: bufferSize) let s = someString(size: 1024 * 1024) - for _ in 0 ..< 5 { + for _ in 0..<5 { buffer.clear() - for _ in 0 ..< 12 { + for _ in 0..<12 { buffer.writeString(s) } } @@ -363,7 +367,7 @@ measureAndPrint(desc: "bytebuffer_lots_of_rw") { let substring = Substring("A") @inline(never) func doWrites(buffer: inout ByteBuffer, dispatchData: DispatchData, substring: Substring) { - /* all of those should be 0 allocations */ + // all of those should be 0 allocations // buffer.writeBytes(foundationData) // see SR-7542 buffer.writeBytes([0x41]) @@ -376,7 +380,7 @@ measureAndPrint(desc: "bytebuffer_lots_of_rw") { } @inline(never) func doReads(buffer: inout ByteBuffer) { - /* these ones are zero allocations */ + // these ones are zero allocations let val = buffer.readInteger(as: UInt8.self) precondition(0x41 == val, "\(val!)") var slice = buffer.readSlice(length: 1) @@ -386,13 +390,13 @@ measureAndPrint(desc: "bytebuffer_lots_of_rw") { precondition(ptr[0] == 0x41) } - /* those down here should be one allocation each */ + // those down here should be one allocation each let arr = buffer.readBytes(length: 1) precondition([0x41] == arr!, "\(arr!)") let str = buffer.readString(length: 1) precondition("A" == str, "\(str!)") } - for _ in 0 ..< 100_000 { + for _ in 0..<100_000 { doWrites(buffer: &buffer, dispatchData: dispatchData, substring: substring) doReads(buffer: &buffer) } @@ -582,18 +586,20 @@ try measureAndPrint(desc: "no-net_http1_1k_reqs_1_conn") { func handlerAdded(context: ChannelHandlerContext) { self.requestBuffer = context.channel.allocator.buffer(capacity: 512) - self.requestBuffer.writeString(""" - GET /perf-test-2 HTTP/1.1\r - Host: example.com\r - X-Some-Header-1: foo\r - X-Some-Header-2: foo\r - X-Some-Header-3: foo\r - X-Some-Header-4: foo\r - X-Some-Header-5: foo\r - X-Some-Header-6: foo\r - X-Some-Header-7: foo\r - X-Some-Header-8: foo\r\n\r\n - """) + self.requestBuffer.writeString( + """ + GET /perf-test-2 HTTP/1.1\r + Host: example.com\r + X-Some-Header-1: foo\r + X-Some-Header-2: foo\r + X-Some-Header-3: foo\r + X-Some-Header-4: foo\r + X-Some-Header-5: foo\r + X-Some-Header-6: foo\r + X-Some-Header-7: foo\r + X-Some-Header-8: foo\r\n\r\n + """ + ) } func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { @@ -613,7 +619,7 @@ try measureAndPrint(desc: "no-net_http1_1k_reqs_1_conn") { } } - func kickOff(channel: Channel) -> Void { + func kickOff(channel: Channel) { try! (channel as! EmbeddedChannel).writeInbound(self.requestBuffer) } } @@ -627,8 +633,10 @@ try measureAndPrint(desc: "no-net_http1_1k_reqs_1_conn") { requestsDone = reqs done = true } - try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true, - withErrorHandling: true).flatMap { + try channel.pipeline.configureHTTPServerPipeline( + withPipeliningAssistance: true, + withErrorHandling: true + ).flatMap { channel.pipeline.addHandler(SimpleHTTPServer()) }.flatMap { channel.pipeline.addHandler(measuringHandler, position: .first) @@ -700,7 +708,7 @@ measureAndPrint(desc: "future_whenallsucceed_100k_immediately_succeeded_on_loop" let allSucceeded = try! loop.makeSucceededFuture(()).flatMap { _ -> EventLoopFuture<[Int]> in let futures = expected.map { loop.makeSucceededFuture($0) } return EventLoopFuture.whenAllSucceed(futures, on: loop) - }.wait() + }.wait() return allSucceeded.count } @@ -725,11 +733,10 @@ measureAndPrint(desc: "future_whenallsucceed_10k_deferred_on_loop") { promise.succeed(index) } return result - }.wait() + }.wait() return allSucceeded.count } - measureAndPrint(desc: "future_whenallcomplete_100k_immediately_succeeded_off_loop") { let loop = group.next() let expected = Array(0..<100_000) @@ -744,7 +751,7 @@ measureAndPrint(desc: "future_whenallcomplete_100k_immediately_succeeded_on_loop let allSucceeded = try! loop.makeSucceededFuture(()).flatMap { _ -> EventLoopFuture<[Result]> in let futures = expected.map { loop.makeSucceededFuture($0) } return EventLoopFuture.whenAllComplete(futures, on: loop) - }.wait() + }.wait() return allSucceeded.count } @@ -769,7 +776,7 @@ measureAndPrint(desc: "future_whenallcomplete_100k_deferred_on_loop") { promise.succeed(index) } return result - }.wait() + }.wait() return allSucceeded.count } @@ -962,7 +969,7 @@ try measureAndPrint( desc: "circular_buffer_into_byte_buffer_1mb", benchmark: CircularBufferIntoByteBufferBenchmark( iterations: 20, - bufferSize: 1024*1024 + bufferSize: 1024 * 1024 ) ) @@ -970,7 +977,7 @@ try measureAndPrint( desc: "byte_buffer_view_iterator_1mb", benchmark: ByteBufferViewIteratorBenchmark( iterations: 20, - bufferSize: 1024*1024 + bufferSize: 1024 * 1024 ) ) @@ -978,7 +985,7 @@ try measureAndPrint( desc: "byte_buffer_view_contains_12mb", benchmark: ByteBufferViewContainsBenchmark( iterations: 5, - bufferSize: 12*1024*1024 + bufferSize: 12 * 1024 * 1024 ) ) @@ -992,9 +999,12 @@ try measureAndPrint( measureAndPrint(desc: "generate_10k_random_request_keys") { let numKeys = 10_000 - return (0 ..< numKeys).reduce(into: 0, { result, _ in - result &+= NIOWebSocketClientUpgrader.randomRequestKey().count - }) + return (0.. - + + - - -Swift.org - Welcome to Swift.org - - - - - - - - - - - - - - - + + + Swift.org - Welcome to Swift.org + + + + + + + + + + + + + + + - + - + - - + + - - + + - - + + - - - + + + - + - + - -
-

Welcome to Swift.org

+
+

Welcome to Swift.org

-

Welcome to the Swift community. Together we are working to build a programming language to empower everyone to turn their ideas into apps on any platform.

+

Welcome to the Swift community. Together we are working to build a programming language to empower everyone to turn their ideas into apps on any platform.

-

Announced in 2014, the Swift programming language has quickly become one of the fastest growing languages in history. Swift makes it easy to write software that is incredibly fast and safe by design. Our goals for Swift are ambitious: we want to make programming simple things easy, and difficult things possible.

+

Announced in 2014, the Swift programming language has quickly become one of the fastest growing languages in history. Swift makes it easy to write software that is incredibly fast and safe by design. Our goals for Swift are ambitious: we want to make programming simple things easy, and difficult things possible.

-

For students, learning Swift has been a great introduction to modern programming concepts and best practices. And because it is open, their Swift skills will be able to be applied to an even broader range of platforms, from mobile devices to the desktop to the cloud.

+

For students, learning Swift has been a great introduction to modern programming concepts and best practices. And because it is open, their Swift skills will be able to be applied to an even broader range of platforms, from mobile devices to the desktop to the cloud.

-
+
-
- +
+ - -

Swift and the Swift logo are trademarks of Apple Inc.

-

-Privacy Policy -Cookies -

-
+ +

Swift and the Swift logo are trademarks of Apple Inc.

+

+ Privacy Policy + Cookies +

+
- - - - - - -""" + + + + + + + """ // generated using: // curl -s https://swift.org/ | pbcopy let htmlMostlyASCII: String = """ - - + + - - -Swift.org - Welcome to Swift.org - - - - - - - - - - - - - - - + + + Swift.org - Welcome to Swift.org + + + + + + + + + + + + + + + - + - + - - + + - - + + - - + + - - - + + + - + - + - -
-

Welcome to Swift.org

+
+

Welcome to Swift.org

-

Welcome to the Swift community. Together we are working to build a programming language to empower everyone to turn their ideas into apps on any platform.

+

Welcome to the Swift community. Together we are working to build a programming language to empower everyone to turn their ideas into apps on any platform.

-

Announced in 2014, the Swift programming language has quickly become one of the fastest growing languages in history. Swift makes it easy to write software that is incredibly fast and safe by design. Our goals for Swift are ambitious: we want to make programming simple things easy, and difficult things possible.

+

Announced in 2014, the Swift programming language has quickly become one of the fastest growing languages in history. Swift makes it easy to write software that is incredibly fast and safe by design. Our goals for Swift are ambitious: we want to make programming simple things easy, and difficult things possible.

-

For students, learning Swift has been a great introduction to modern programming concepts and best practices. And because it is open, their Swift skills will be able to be applied to an even broader range of platforms, from mobile devices to the desktop to the cloud.

+

For students, learning Swift has been a great introduction to modern programming concepts and best practices. And because it is open, their Swift skills will be able to be applied to an even broader range of platforms, from mobile devices to the desktop to the cloud.

-
+
-
- +
+ - -

Swift and the Swift logo are trademarks of Apple Inc.

-

-Privacy Policy -Cookies -

-
+ +

Swift and the Swift logo are trademarks of Apple Inc.

+

+ Privacy Policy + Cookies +

+
- - - - - - -""" + + + + + + + """ // generated using: // curl -s https://swift.org/ | iconv -c -f utf-8 -t ascii | pbcopy let htmlASCIIOnlyStaticString: StaticString = """ - - + + - - -Swift.org - Welcome to Swift.org - - - - - - - - - - - - - - - + + + Swift.org - Welcome to Swift.org + + + + + + + + + + + + + + + - + - + - - + + - - + + - - + + - - - + + + - + - + - -
-

Welcome to Swift.org

+
+

Welcome to Swift.org

-

Welcome to the Swift community. Together we are working to build a programming language to empower everyone to turn their ideas into apps on any platform.

+

Welcome to the Swift community. Together we are working to build a programming language to empower everyone to turn their ideas into apps on any platform.

-

Announced in 2014, the Swift programming language has quickly become one of the fastest growing languages in history. Swift makes it easy to write software that is incredibly fast and safe by design. Our goals for Swift are ambitious: we want to make programming simple things easy, and difficult things possible.

+

Announced in 2014, the Swift programming language has quickly become one of the fastest growing languages in history. Swift makes it easy to write software that is incredibly fast and safe by design. Our goals for Swift are ambitious: we want to make programming simple things easy, and difficult things possible.

-

For students, learning Swift has been a great introduction to modern programming concepts and best practices. And because it is open, their Swift skills will be able to be applied to an even broader range of platforms, from mobile devices to the desktop to the cloud.

+

For students, learning Swift has been a great introduction to modern programming concepts and best practices. And because it is open, their Swift skills will be able to be applied to an even broader range of platforms, from mobile devices to the desktop to the cloud.

-
+
-
- +
+ - -

Swift and the Swift logo are trademarks of Apple Inc.

-

-Privacy Policy -Cookies -

-
+ +

Swift and the Swift logo are trademarks of Apple Inc.

+

+ Privacy Policy + Cookies +

+
- - - - - - -""" + + + + + + + """ // generated using: // curl -s https://swift.org/ | pbcopy let htmlMostlyASCIIStaticString: StaticString = """ - - + + - - -Swift.org - Welcome to Swift.org - - - - - - - - - - - - - - - + + + Swift.org - Welcome to Swift.org + + + + + + + + + + + + + + + - + - + - - + + - - + + - - + + - - - + + + - + - + - -
-

Welcome to Swift.org

+
+

Welcome to Swift.org

-

Welcome to the Swift community. Together we are working to build a programming language to empower everyone to turn their ideas into apps on any platform.

+

Welcome to the Swift community. Together we are working to build a programming language to empower everyone to turn their ideas into apps on any platform.

-

Announced in 2014, the Swift programming language has quickly become one of the fastest growing languages in history. Swift makes it easy to write software that is incredibly fast and safe by design. Our goals for Swift are ambitious: we want to make programming simple things easy, and difficult things possible.

+

Announced in 2014, the Swift programming language has quickly become one of the fastest growing languages in history. Swift makes it easy to write software that is incredibly fast and safe by design. Our goals for Swift are ambitious: we want to make programming simple things easy, and difficult things possible.

-

For students, learning Swift has been a great introduction to modern programming concepts and best practices. And because it is open, their Swift skills will be able to be applied to an even broader range of platforms, from mobile devices to the desktop to the cloud.

+

For students, learning Swift has been a great introduction to modern programming concepts and best practices. And because it is open, their Swift skills will be able to be applied to an even broader range of platforms, from mobile devices to the desktop to the cloud.

-
+
-
- +
+ - -

Swift and the Swift logo are trademarks of Apple Inc.

-

-Privacy Policy -Cookies -

-
+ +

Swift and the Swift logo are trademarks of Apple Inc.

+

+ Privacy Policy + Cookies +

+
- - - - - - -""" + + + + + + + """ diff --git a/Sources/NIOPosix/BSDSocketAPICommon.swift b/Sources/NIOPosix/BSDSocketAPICommon.swift index 721143a990..538d919afe 100644 --- a/Sources/NIOPosix/BSDSocketAPICommon.swift +++ b/Sources/NIOPosix/BSDSocketAPICommon.swift @@ -38,11 +38,11 @@ internal enum Shutdown: _SocketShutdownProtocol { } extension NIOBSDSocket { -#if os(Windows) + #if os(Windows) internal static let invalidHandle: Handle = INVALID_SOCKET -#else + #else internal static let invalidHandle: Handle = -1 -#endif + #endif } extension NIOBSDSocket { @@ -67,29 +67,29 @@ extension NIOBSDSocket.SocketType { /// Supports datagrams, which are connectionless, unreliable messages of a /// fixed (typically small) maximum length. #if os(Linux) && !canImport(Musl) - internal static let datagram: NIOBSDSocket.SocketType = - NIOBSDSocket.SocketType(rawValue: CInt(SOCK_DGRAM.rawValue)) + internal static let datagram: NIOBSDSocket.SocketType = + NIOBSDSocket.SocketType(rawValue: CInt(SOCK_DGRAM.rawValue)) #else - internal static let datagram: NIOBSDSocket.SocketType = - NIOBSDSocket.SocketType(rawValue: SOCK_DGRAM) + internal static let datagram: NIOBSDSocket.SocketType = + NIOBSDSocket.SocketType(rawValue: SOCK_DGRAM) #endif /// Supports reliable, two-way, connection-based byte streams without /// duplication of data and without preservation of boundaries. #if os(Linux) && !canImport(Musl) - internal static let stream: NIOBSDSocket.SocketType = - NIOBSDSocket.SocketType(rawValue: CInt(SOCK_STREAM.rawValue)) + internal static let stream: NIOBSDSocket.SocketType = + NIOBSDSocket.SocketType(rawValue: CInt(SOCK_STREAM.rawValue)) #else - internal static let stream: NIOBSDSocket.SocketType = - NIOBSDSocket.SocketType(rawValue: SOCK_STREAM) + internal static let stream: NIOBSDSocket.SocketType = + NIOBSDSocket.SocketType(rawValue: SOCK_STREAM) #endif #if os(Linux) && !canImport(Musl) - internal static let raw: NIOBSDSocket.SocketType = - NIOBSDSocket.SocketType(rawValue: CInt(SOCK_RAW.rawValue)) + internal static let raw: NIOBSDSocket.SocketType = + NIOBSDSocket.SocketType(rawValue: CInt(SOCK_RAW.rawValue)) #else - internal static let raw: NIOBSDSocket.SocketType = - NIOBSDSocket.SocketType(rawValue: SOCK_RAW) + internal static let raw: NIOBSDSocket.SocketType = + NIOBSDSocket.SocketType(rawValue: SOCK_RAW) #endif } @@ -102,7 +102,7 @@ extension NIOBSDSocket.Option { /// `ChannelOptions.explicitCongestionNotification` which works for both /// IPv4 and IPv6. static let ip_recv_tos: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IP_RECVTOS) + NIOBSDSocket.Option(rawValue: IP_RECVTOS) /// Request that we are passed destination address and the receiving interface index when /// receiving datagrams. @@ -111,7 +111,7 @@ extension NIOBSDSocket.Option { /// `ChannelOptions.receivePacketInfo` which works for both /// IPv4 and IPv6. static let ip_recv_pktinfo: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: Posix.IP_RECVPKTINFO) + NIOBSDSocket.Option(rawValue: Posix.IP_RECVPKTINFO) } // IPv6 Options @@ -123,7 +123,7 @@ extension NIOBSDSocket.Option { /// `ChannelOptions.explicitCongestionNotification` which works for both /// IPv4 and IPv6. static let ipv6_recv_tclass: NIOBSDSocket.Option = - NIOBSDSocket.Option(rawValue: IPV6_RECVTCLASS) + NIOBSDSocket.Option(rawValue: IPV6_RECVTCLASS) /// Request that we are passed destination address and the receiving interface index when /// receiving datagrams. @@ -179,118 +179,156 @@ extension NIOBSDSocket.ProtocolSubtype { } } - /// This protocol defines the methods that are expected to be found on /// `NIOBSDSocket`. While defined as a protocol there is no expectation that any /// object other than `NIOBSDSocket` will implement this protocol: instead, this /// protocol acts as a reference for what new supported operating systems must /// implement. protocol _BSDSocketProtocol { - static func accept(socket s: NIOBSDSocket.Handle, - address addr: UnsafeMutablePointer?, - address_len addrlen: UnsafeMutablePointer?) throws -> NIOBSDSocket.Handle? - - static func bind(socket s: NIOBSDSocket.Handle, - address addr: UnsafePointer, - address_len namelen: socklen_t) throws + static func accept( + socket s: NIOBSDSocket.Handle, + address addr: UnsafeMutablePointer?, + address_len addrlen: UnsafeMutablePointer? + ) throws -> NIOBSDSocket.Handle? + + static func bind( + socket s: NIOBSDSocket.Handle, + address addr: UnsafePointer, + address_len namelen: socklen_t + ) throws static func close(socket s: NIOBSDSocket.Handle) throws - static func connect(socket s: NIOBSDSocket.Handle, - address name: UnsafePointer, - address_len namelen: socklen_t) throws -> Bool - - static func getpeername(socket s: NIOBSDSocket.Handle, - address name: UnsafeMutablePointer, - address_len namelen: UnsafeMutablePointer) throws - - static func getsockname(socket s: NIOBSDSocket.Handle, - address name: UnsafeMutablePointer, - address_len namelen: UnsafeMutablePointer) throws - - static func getsockopt(socket: NIOBSDSocket.Handle, - level: NIOBSDSocket.OptionLevel, - option_name optname: NIOBSDSocket.Option, - option_value optval: UnsafeMutableRawPointer, - option_len optlen: UnsafeMutablePointer) throws + static func connect( + socket s: NIOBSDSocket.Handle, + address name: UnsafePointer, + address_len namelen: socklen_t + ) throws -> Bool + + static func getpeername( + socket s: NIOBSDSocket.Handle, + address name: UnsafeMutablePointer, + address_len namelen: UnsafeMutablePointer + ) throws + + static func getsockname( + socket s: NIOBSDSocket.Handle, + address name: UnsafeMutablePointer, + address_len namelen: UnsafeMutablePointer + ) throws + + static func getsockopt( + socket: NIOBSDSocket.Handle, + level: NIOBSDSocket.OptionLevel, + option_name optname: NIOBSDSocket.Option, + option_value optval: UnsafeMutableRawPointer, + option_len optlen: UnsafeMutablePointer + ) throws static func listen(socket s: NIOBSDSocket.Handle, backlog: CInt) throws - static func recv(socket s: NIOBSDSocket.Handle, - buffer buf: UnsafeMutableRawPointer, - length len: size_t) throws -> IOResult + static func recv( + socket s: NIOBSDSocket.Handle, + buffer buf: UnsafeMutableRawPointer, + length len: size_t + ) throws -> IOResult // NOTE: this should return a `ssize_t`, however, that is not a standard // type, and defining that type is difficult. Opt to return a `size_t` // which is the same size, but is unsigned. - static func recvmsg(socket: NIOBSDSocket.Handle, - msgHdr: UnsafeMutablePointer, flags: CInt) - throws -> IOResult + static func recvmsg( + socket: NIOBSDSocket.Handle, + msgHdr: UnsafeMutablePointer, + flags: CInt + ) + throws -> IOResult // NOTE: this should return a `ssize_t`, however, that is not a standard // type, and defining that type is difficult. Opt to return a `size_t` // which is the same size, but is unsigned. - static func sendmsg(socket: NIOBSDSocket.Handle, - msgHdr: UnsafePointer, flags: CInt) - throws -> IOResult - - static func send(socket s: NIOBSDSocket.Handle, - buffer buf: UnsafeRawPointer, - length len: size_t) throws -> IOResult - - static func setsockopt(socket: NIOBSDSocket.Handle, - level: NIOBSDSocket.OptionLevel, - option_name optname: NIOBSDSocket.Option, - option_value optval: UnsafeRawPointer, - option_len optlen: socklen_t) throws + static func sendmsg( + socket: NIOBSDSocket.Handle, + msgHdr: UnsafePointer, + flags: CInt + ) + throws -> IOResult + + static func send( + socket s: NIOBSDSocket.Handle, + buffer buf: UnsafeRawPointer, + length len: size_t + ) throws -> IOResult + + static func setsockopt( + socket: NIOBSDSocket.Handle, + level: NIOBSDSocket.OptionLevel, + option_name optname: NIOBSDSocket.Option, + option_value optval: UnsafeRawPointer, + option_len optlen: socklen_t + ) throws static func shutdown(socket: NIOBSDSocket.Handle, how: Shutdown) throws - static func socket(domain af: NIOBSDSocket.ProtocolFamily, - type: NIOBSDSocket.SocketType, - protocolSubtype: NIOBSDSocket.ProtocolSubtype) throws -> NIOBSDSocket.Handle - - static func recvmmsg(socket: NIOBSDSocket.Handle, - msgvec: UnsafeMutablePointer, - vlen: CUnsignedInt, - flags: CInt, - timeout: UnsafeMutablePointer?) throws -> IOResult - - static func sendmmsg(socket: NIOBSDSocket.Handle, - msgvec: UnsafeMutablePointer, - vlen: CUnsignedInt, - flags: CInt) throws -> IOResult + static func socket( + domain af: NIOBSDSocket.ProtocolFamily, + type: NIOBSDSocket.SocketType, + protocolSubtype: NIOBSDSocket.ProtocolSubtype + ) throws -> NIOBSDSocket.Handle + + static func recvmmsg( + socket: NIOBSDSocket.Handle, + msgvec: UnsafeMutablePointer, + vlen: CUnsignedInt, + flags: CInt, + timeout: UnsafeMutablePointer? + ) throws -> IOResult + + static func sendmmsg( + socket: NIOBSDSocket.Handle, + msgvec: UnsafeMutablePointer, + vlen: CUnsignedInt, + flags: CInt + ) throws -> IOResult // NOTE: this should return a `ssize_t`, however, that is not a standard // type, and defining that type is difficult. Opt to return a `size_t` // which is the same size, but is unsigned. - static func pread(socket: NIOBSDSocket.Handle, - pointer: UnsafeMutableRawPointer, - size: size_t, - offset: off_t) throws -> IOResult + static func pread( + socket: NIOBSDSocket.Handle, + pointer: UnsafeMutableRawPointer, + size: size_t, + offset: off_t + ) throws -> IOResult // NOTE: this should return a `ssize_t`, however, that is not a standard // type, and defining that type is difficult. Opt to return a `size_t` // which is the same size, but is unsigned. - static func pwrite(socket: NIOBSDSocket.Handle, - pointer: UnsafeRawPointer, - size: size_t, - offset: off_t) throws -> IOResult - -#if !os(Windows) + static func pwrite( + socket: NIOBSDSocket.Handle, + pointer: UnsafeRawPointer, + size: size_t, + offset: off_t + ) throws -> IOResult + + #if !os(Windows) // NOTE: We do not support this on Windows as WSAPoll behaves differently // from poll with reporting of failed connections (Connect Report 309411), // which recommended that you use NetAPI instead. // // This is safe to exclude as this is a testing-only API. - static func poll(fds: UnsafeMutablePointer, nfds: nfds_t, - timeout: CInt) throws -> CInt -#endif + static func poll( + fds: UnsafeMutablePointer, + nfds: nfds_t, + timeout: CInt + ) throws -> CInt + #endif - static func sendfile(socket s: NIOBSDSocket.Handle, - fd: CInt, - offset: off_t, - len: off_t) throws -> IOResult + static func sendfile( + socket s: NIOBSDSocket.Handle, + fd: CInt, + offset: off_t, + len: off_t + ) throws -> IOResult // MARK: non-BSD APIs added by NIO @@ -301,7 +339,7 @@ protocol _BSDSocketProtocol { /// If this extension is hitting a compile error, your platform is missing one /// of the functions defined above! -extension NIOBSDSocket: _BSDSocketProtocol { } +extension NIOBSDSocket: _BSDSocketProtocol {} /// This protocol defines the methods that are expected to be found on /// `NIOBSDControlMessage`. While defined as a protocol there is no expectation @@ -309,18 +347,26 @@ extension NIOBSDSocket: _BSDSocketProtocol { } /// protocol: instead, this protocol acts as a reference for what new supported /// operating systems must implement. protocol _BSDSocketControlMessageProtocol { - static func firstHeader(inside msghdr: UnsafePointer) - -> UnsafeMutablePointer? - - static func nextHeader(inside msghdr: UnsafeMutablePointer, - after: UnsafeMutablePointer) - -> UnsafeMutablePointer? - - static func data(for header: UnsafePointer) - -> UnsafeRawBufferPointer? - - static func data(for header: UnsafeMutablePointer) - -> UnsafeMutableRawBufferPointer? + static func firstHeader( + inside msghdr: UnsafePointer + ) + -> UnsafeMutablePointer? + + static func nextHeader( + inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer + ) + -> UnsafeMutablePointer? + + static func data( + for header: UnsafePointer + ) + -> UnsafeRawBufferPointer? + + static func data( + for header: UnsafeMutablePointer + ) + -> UnsafeMutableRawBufferPointer? static func length(payloadSize: size_t) -> size_t @@ -329,7 +375,7 @@ protocol _BSDSocketControlMessageProtocol { /// If this extension is hitting a compile error, your platform is missing one /// of the functions defined above! -enum NIOBSDSocketControlMessage: _BSDSocketControlMessageProtocol { } +enum NIOBSDSocketControlMessage: _BSDSocketControlMessageProtocol {} /// The requested UDS path exists and has wrong type (not a socket). public struct UnixDomainSocketPathWrongType: Error {} diff --git a/Sources/NIOPosix/BSDSocketAPIPosix.swift b/Sources/NIOPosix/BSDSocketAPIPosix.swift index 7b77527b6f..852ab5b4d2 100644 --- a/Sources/NIOPosix/BSDSocketAPIPosix.swift +++ b/Sources/NIOPosix/BSDSocketAPIPosix.swift @@ -30,162 +30,212 @@ extension Shutdown { // MARK: Implementation of _BSDSocketProtocol for POSIX systems extension NIOBSDSocket { - static func accept(socket s: NIOBSDSocket.Handle, - address addr: UnsafeMutablePointer?, - address_len addrlen: UnsafeMutablePointer?) throws -> NIOBSDSocket.Handle? { - return try Posix.accept(descriptor: s, addr: addr, len: addrlen) + static func accept( + socket s: NIOBSDSocket.Handle, + address addr: UnsafeMutablePointer?, + address_len addrlen: UnsafeMutablePointer? + ) throws -> NIOBSDSocket.Handle? { + try Posix.accept(descriptor: s, addr: addr, len: addrlen) } - static func bind(socket s: NIOBSDSocket.Handle, - address addr: UnsafePointer, - address_len namelen: socklen_t) throws { - return try Posix.bind(descriptor: s, ptr: addr, bytes: Int(namelen)) + static func bind( + socket s: NIOBSDSocket.Handle, + address addr: UnsafePointer, + address_len namelen: socklen_t + ) throws { + try Posix.bind(descriptor: s, ptr: addr, bytes: Int(namelen)) } static func close(socket s: NIOBSDSocket.Handle) throws { - return try Posix.close(descriptor: s) + try Posix.close(descriptor: s) } - static func connect(socket s: NIOBSDSocket.Handle, - address name: UnsafePointer, - address_len namelen: socklen_t) throws -> Bool { - return try Posix.connect(descriptor: s, addr: name, size: namelen) + static func connect( + socket s: NIOBSDSocket.Handle, + address name: UnsafePointer, + address_len namelen: socklen_t + ) throws -> Bool { + try Posix.connect(descriptor: s, addr: name, size: namelen) } - static func getpeername(socket s: NIOBSDSocket.Handle, - address name: UnsafeMutablePointer, - address_len namelen: UnsafeMutablePointer) throws { - return try Posix.getpeername(socket: s, address: name, addressLength: namelen) + static func getpeername( + socket s: NIOBSDSocket.Handle, + address name: UnsafeMutablePointer, + address_len namelen: UnsafeMutablePointer + ) throws { + try Posix.getpeername(socket: s, address: name, addressLength: namelen) } - static func getsockname(socket s: NIOBSDSocket.Handle, - address name: UnsafeMutablePointer, - address_len namelen: UnsafeMutablePointer) throws { - return try Posix.getsockname(socket: s, address: name, addressLength: namelen) + static func getsockname( + socket s: NIOBSDSocket.Handle, + address name: UnsafeMutablePointer, + address_len namelen: UnsafeMutablePointer + ) throws { + try Posix.getsockname(socket: s, address: name, addressLength: namelen) } - static func getsockopt(socket: NIOBSDSocket.Handle, - level: NIOBSDSocket.OptionLevel, - option_name optname: NIOBSDSocket.Option, - option_value optval: UnsafeMutableRawPointer, - option_len optlen: UnsafeMutablePointer) throws { - return try Posix.getsockopt(socket: socket, - level: level.rawValue, - optionName: optname.rawValue, - optionValue: optval, - optionLen: optlen) + static func getsockopt( + socket: NIOBSDSocket.Handle, + level: NIOBSDSocket.OptionLevel, + option_name optname: NIOBSDSocket.Option, + option_value optval: UnsafeMutableRawPointer, + option_len optlen: UnsafeMutablePointer + ) throws { + try Posix.getsockopt( + socket: socket, + level: level.rawValue, + optionName: optname.rawValue, + optionValue: optval, + optionLen: optlen + ) } static func listen(socket s: NIOBSDSocket.Handle, backlog: CInt) throws { - return try Posix.listen(descriptor: s, backlog: backlog) - } - - static func recv(socket s: NIOBSDSocket.Handle, - buffer buf: UnsafeMutableRawPointer, - length len: size_t) throws -> IOResult { - return try Posix.read(descriptor: s, pointer: buf, size: len) - } - - static func recvmsg(socket: NIOBSDSocket.Handle, - msgHdr: UnsafeMutablePointer, flags: CInt) - throws -> IOResult { - return try Posix.recvmsg(descriptor: socket, msgHdr: msgHdr, flags: flags) - } - - static func sendmsg(socket: NIOBSDSocket.Handle, - msgHdr: UnsafePointer, flags: CInt) - throws -> IOResult { - return try Posix.sendmsg(descriptor: socket, msgHdr: msgHdr, flags: flags) - } - - static func send(socket s: NIOBSDSocket.Handle, - buffer buf: UnsafeRawPointer, - length len: size_t) throws -> IOResult { - return try Posix.write(descriptor: s, pointer: buf, size: len) - } - - static func setsockopt(socket: NIOBSDSocket.Handle, - level: NIOBSDSocket.OptionLevel, - option_name optname: NIOBSDSocket.Option, - option_value optval: UnsafeRawPointer, - option_len optlen: socklen_t) throws { - return try Posix.setsockopt(socket: socket, - level: level.rawValue, - optionName: optname.rawValue, - optionValue: optval, - optionLen: optlen) + try Posix.listen(descriptor: s, backlog: backlog) + } + + static func recv( + socket s: NIOBSDSocket.Handle, + buffer buf: UnsafeMutableRawPointer, + length len: size_t + ) throws -> IOResult { + try Posix.read(descriptor: s, pointer: buf, size: len) + } + + static func recvmsg( + socket: NIOBSDSocket.Handle, + msgHdr: UnsafeMutablePointer, + flags: CInt + ) + throws -> IOResult + { + try Posix.recvmsg(descriptor: socket, msgHdr: msgHdr, flags: flags) + } + + static func sendmsg( + socket: NIOBSDSocket.Handle, + msgHdr: UnsafePointer, + flags: CInt + ) + throws -> IOResult + { + try Posix.sendmsg(descriptor: socket, msgHdr: msgHdr, flags: flags) + } + + static func send( + socket s: NIOBSDSocket.Handle, + buffer buf: UnsafeRawPointer, + length len: size_t + ) throws -> IOResult { + try Posix.write(descriptor: s, pointer: buf, size: len) + } + + static func setsockopt( + socket: NIOBSDSocket.Handle, + level: NIOBSDSocket.OptionLevel, + option_name optname: NIOBSDSocket.Option, + option_value optval: UnsafeRawPointer, + option_len optlen: socklen_t + ) throws { + try Posix.setsockopt( + socket: socket, + level: level.rawValue, + optionName: optname.rawValue, + optionValue: optval, + optionLen: optlen + ) } static func shutdown(socket: NIOBSDSocket.Handle, how: Shutdown) throws { - return try Posix.shutdown(descriptor: socket, how: how) - } - - static func socket(domain af: NIOBSDSocket.ProtocolFamily, - type: NIOBSDSocket.SocketType, - protocolSubtype: NIOBSDSocket.ProtocolSubtype) throws -> NIOBSDSocket.Handle { - return try Posix.socket(domain: af, type: type, protocolSubtype: protocolSubtype) - } - - static func recvmmsg(socket: NIOBSDSocket.Handle, - msgvec: UnsafeMutablePointer, - vlen: CUnsignedInt, - flags: CInt, - timeout: UnsafeMutablePointer?) throws -> IOResult { - return try Posix.recvmmsg(sockfd: socket, - msgvec: msgvec, - vlen: vlen, - flags: flags, - timeout: timeout) - } - - static func sendmmsg(socket: NIOBSDSocket.Handle, - msgvec: UnsafeMutablePointer, - vlen: CUnsignedInt, - flags: CInt) throws -> IOResult { - return try Posix.sendmmsg(sockfd: socket, - msgvec: msgvec, - vlen: vlen, - flags: flags) + try Posix.shutdown(descriptor: socket, how: how) + } + + static func socket( + domain af: NIOBSDSocket.ProtocolFamily, + type: NIOBSDSocket.SocketType, + protocolSubtype: NIOBSDSocket.ProtocolSubtype + ) throws -> NIOBSDSocket.Handle { + try Posix.socket(domain: af, type: type, protocolSubtype: protocolSubtype) + } + + static func recvmmsg( + socket: NIOBSDSocket.Handle, + msgvec: UnsafeMutablePointer, + vlen: CUnsignedInt, + flags: CInt, + timeout: UnsafeMutablePointer? + ) throws -> IOResult { + try Posix.recvmmsg( + sockfd: socket, + msgvec: msgvec, + vlen: vlen, + flags: flags, + timeout: timeout + ) + } + + static func sendmmsg( + socket: NIOBSDSocket.Handle, + msgvec: UnsafeMutablePointer, + vlen: CUnsignedInt, + flags: CInt + ) throws -> IOResult { + try Posix.sendmmsg( + sockfd: socket, + msgvec: msgvec, + vlen: vlen, + flags: flags + ) } // NOTE: this should return a `ssize_t`, however, that is not a standard // type, and defining that type is difficult. Opt to return a `size_t` // which is the same size, but is unsigned. - static func pread(socket: NIOBSDSocket.Handle, - pointer: UnsafeMutableRawPointer, - size: size_t, - offset: off_t) throws -> IOResult { - return try Posix.pread(descriptor: socket, - pointer: pointer, - size: size, - offset: offset) + static func pread( + socket: NIOBSDSocket.Handle, + pointer: UnsafeMutableRawPointer, + size: size_t, + offset: off_t + ) throws -> IOResult { + try Posix.pread( + descriptor: socket, + pointer: pointer, + size: size, + offset: offset + ) } // NOTE: this should return a `ssize_t`, however, that is not a standard // type, and defining that type is difficult. Opt to return a `size_t` // which is the same size, but is unsigned. - static func pwrite(socket: NIOBSDSocket.Handle, - pointer: UnsafeRawPointer, - size: size_t, - offset: off_t) throws -> IOResult { - return try Posix.pwrite(descriptor: socket, pointer: pointer, size: size, offset: offset) + static func pwrite( + socket: NIOBSDSocket.Handle, + pointer: UnsafeRawPointer, + size: size_t, + offset: off_t + ) throws -> IOResult { + try Posix.pwrite(descriptor: socket, pointer: pointer, size: size, offset: offset) } - static func poll(fds: UnsafeMutablePointer, - nfds: nfds_t, - timeout: CInt) throws -> CInt { - return try Posix.poll(fds: fds, nfds: nfds, timeout: timeout) + static func poll( + fds: UnsafeMutablePointer, + nfds: nfds_t, + timeout: CInt + ) throws -> CInt { + try Posix.poll(fds: fds, nfds: nfds, timeout: timeout) } - static func sendfile(socket s: NIOBSDSocket.Handle, - fd: CInt, - offset: off_t, - len: off_t) throws -> IOResult { - return try Posix.sendfile(descriptor: s, fd: fd, offset: offset, count: size_t(len)) + static func sendfile( + socket s: NIOBSDSocket.Handle, + fd: CInt, + offset: off_t, + len: off_t + ) throws -> IOResult { + try Posix.sendfile(descriptor: s, fd: fd, offset: offset, count: size_t(len)) } static func setNonBlocking(socket: NIOBSDSocket.Handle) throws { - return try Posix.setNonBlocking(socket: socket) + try Posix.setNonBlocking(socket: socket) } static func cleanupUnixDomainSocket(atPath path: String) throws { @@ -231,27 +281,39 @@ private let CMSG_LEN = CNIOLinux_CMSG_LEN // MARK: _BSDSocketControlMessageProtocol implementation extension NIOBSDSocketControlMessage { - static func firstHeader(inside msghdr: UnsafePointer) - -> UnsafeMutablePointer? { - return CMSG_FIRSTHDR(msghdr) - } - - static func nextHeader(inside msghdr: UnsafeMutablePointer, - after: UnsafeMutablePointer) - -> UnsafeMutablePointer? { - return CMSG_NXTHDR(msghdr, after) - } - - static func data(for header: UnsafePointer) - -> UnsafeRawBufferPointer? { + static func firstHeader( + inside msghdr: UnsafePointer + ) + -> UnsafeMutablePointer? + { + CMSG_FIRSTHDR(msghdr) + } + + static func nextHeader( + inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer + ) + -> UnsafeMutablePointer? + { + CMSG_NXTHDR(msghdr, after) + } + + static func data( + for header: UnsafePointer + ) + -> UnsafeRawBufferPointer? + { let data = CMSG_DATA(header) let length = size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) return UnsafeRawBufferPointer(start: data, count: Int(length)) } - static func data(for header: UnsafeMutablePointer) - -> UnsafeMutableRawBufferPointer? { + static func data( + for header: UnsafeMutablePointer + ) + -> UnsafeMutableRawBufferPointer? + { let data = CMSG_DATA_MUTABLE(header) let length = size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) @@ -259,11 +321,11 @@ extension NIOBSDSocketControlMessage { } static func length(payloadSize: size_t) -> size_t { - return CMSG_LEN(payloadSize) + CMSG_LEN(payloadSize) } static func space(payloadSize: size_t) -> size_t { - return CMSG_SPACE(payloadSize) + CMSG_SPACE(payloadSize) } } @@ -271,11 +333,13 @@ extension NIOBSDSocket { static func setUDPSegmentSize(_ segmentSize: CInt, socket: NIOBSDSocket.Handle) throws { #if os(Linux) var segmentSize = segmentSize - try Self.setsockopt(socket: socket, - level: .udp, - option_name: .udp_segment, - option_value: &segmentSize, - option_len: socklen_t(MemoryLayout.size)) + try Self.setsockopt( + socket: socket, + level: .udp, + option_name: .udp_segment, + option_value: &segmentSize, + option_len: socklen_t(MemoryLayout.size) + ) #else throw ChannelError._operationUnsupported #endif @@ -286,11 +350,13 @@ extension NIOBSDSocket { var segmentSize: CInt = 0 var optionLength = socklen_t(MemoryLayout.size) try withUnsafeMutablePointer(to: &segmentSize) { segmentSizeBytes in - try Self.getsockopt(socket: socket, - level: .udp, - option_name: .udp_segment, - option_value: segmentSizeBytes, - option_len: &optionLength) + try Self.getsockopt( + socket: socket, + level: .udp, + option_name: .udp_segment, + option_value: segmentSizeBytes, + option_len: &optionLength + ) } return segmentSize #else @@ -301,11 +367,13 @@ extension NIOBSDSocket { static func setUDPReceiveOffload(_ enabled: Bool, socket: NIOBSDSocket.Handle) throws { #if os(Linux) var isEnabled: CInt = enabled ? 1 : 0 - try Self.setsockopt(socket: socket, - level: .udp, - option_name: .udp_gro, - option_value: &isEnabled, - option_len: socklen_t(MemoryLayout.size)) + try Self.setsockopt( + socket: socket, + level: .udp, + option_name: .udp_gro, + option_value: &isEnabled, + option_len: socklen_t(MemoryLayout.size) + ) #else throw ChannelError._operationUnsupported #endif @@ -316,11 +384,13 @@ extension NIOBSDSocket { var enabled: CInt = 0 var optionLength = socklen_t(MemoryLayout.size) try withUnsafeMutablePointer(to: &enabled) { enabledBytes in - try Self.getsockopt(socket: socket, - level: .udp, - option_name: .udp_gro, - option_value: enabledBytes, - option_len: &optionLength) + try Self.getsockopt( + socket: socket, + level: .udp, + option_name: .udp_gro, + option_value: enabledBytes, + option_len: &optionLength + ) } return enabled != 0 #else diff --git a/Sources/NIOPosix/BSDSocketAPIWindows.swift b/Sources/NIOPosix/BSDSocketAPIWindows.swift index 89cf148047..37b79764f6 100644 --- a/Sources/NIOPosix/BSDSocketAPIWindows.swift +++ b/Sources/NIOPosix/BSDSocketAPIWindows.swift @@ -120,39 +120,46 @@ internal typealias sockaddr_in6 = SOCKADDR_IN6 internal typealias sockaddr_un = SOCKADDR_UN internal typealias sockaddr_storage = SOCKADDR_STORAGE - -fileprivate var IOC_IN: DWORD { - 0x8000_0000 +private var IOC_IN: DWORD { + 0x8000_0000 } -fileprivate var IOC_OUT: DWORD { - 0x4000_0000 +private var IOC_OUT: DWORD { + 0x4000_0000 } -fileprivate var IOC_INOUT: DWORD { - IOC_IN | IOC_OUT +private var IOC_INOUT: DWORD { + IOC_IN | IOC_OUT } -fileprivate var IOC_WS2: DWORD { - 0x0800_0000 +private var IOC_WS2: DWORD { + 0x0800_0000 } -fileprivate func _WSAIORW(_ x: DWORD, _ y: DWORD) -> DWORD { - IOC_INOUT | x | y +private func _WSAIORW(_ x: DWORD, _ y: DWORD) -> DWORD { + IOC_INOUT | x | y } -fileprivate var SIO_GET_EXTENSION_FUNCTION_POINTER: DWORD { - _WSAIORW(IOC_WS2, 6) +private var SIO_GET_EXTENSION_FUNCTION_POINTER: DWORD { + _WSAIORW(IOC_WS2, 6) } -fileprivate var WSAID_WSARECVMSG: _GUID { - _GUID(Data1: 0xf689d7c8, Data2: 0x6f1f, Data3: 0x436b, - Data4: (0x8a, 0x53, 0xe5, 0x4f, 0xe3, 0x51, 0xc3, 0x22)) +private var WSAID_WSARECVMSG: _GUID { + _GUID( + Data1: 0xf689_d7c8, + Data2: 0x6f1f, + Data3: 0x436b, + Data4: (0x8a, 0x53, 0xe5, 0x4f, 0xe3, 0x51, 0xc3, 0x22) + ) } -fileprivate var WSAID_WSASENDMSG: _GUID { - _GUID(Data1: 0xa441e712, Data2: 0x754f, Data3: 0x43ca, - Data4: (0x84,0xa7,0x0d,0xee,0x44,0xcf,0x60,0x6d)) +private var WSAID_WSASENDMSG: _GUID { + _GUID( + Data1: 0xa441_e712, + Data2: 0x754f, + Data3: 0x43ca, + Data4: (0x84, 0xa7, 0x0d, 0xee, 0x44, 0xcf, 0x60, 0x6d) + ) } // TODO(compnerd) rather than query the `WSARecvMsg` and `WSASendMsg` on each @@ -180,9 +187,11 @@ extension Shutdown { // MARK: _BSDSocketProtocol implementation extension NIOBSDSocket { @inline(never) - static func accept(socket s: NIOBSDSocket.Handle, - address addr: UnsafeMutablePointer?, - address_len addrlen: UnsafeMutablePointer?) throws -> NIOBSDSocket.Handle? { + static func accept( + socket s: NIOBSDSocket.Handle, + address addr: UnsafeMutablePointer?, + address_len addrlen: UnsafeMutablePointer? + ) throws -> NIOBSDSocket.Handle? { let socket: NIOBSDSocket.Handle = WinSDK.accept(s, addr, addrlen) if socket == WinSDK.INVALID_SOCKET { throw IOError(winsock: WSAGetLastError(), reason: "accept") @@ -191,9 +200,11 @@ extension NIOBSDSocket { } @inline(never) - static func bind(socket s: NIOBSDSocket.Handle, - address addr: UnsafePointer, - address_len namelen: socklen_t) throws { + static func bind( + socket s: NIOBSDSocket.Handle, + address addr: UnsafePointer, + address_len namelen: socklen_t + ) throws { if WinSDK.bind(s, addr, namelen) == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "bind") } @@ -207,9 +218,11 @@ extension NIOBSDSocket { } @inline(never) - static func connect(socket s: NIOBSDSocket.Handle, - address name: UnsafePointer, - address_len namelen: socklen_t) throws -> Bool { + static func connect( + socket s: NIOBSDSocket.Handle, + address name: UnsafePointer, + address_len namelen: socklen_t + ) throws -> Bool { if WinSDK.connect(s, name, namelen) == SOCKET_ERROR { let iResult = WSAGetLastError() if iResult == WSAEWOULDBLOCK { return true } @@ -219,31 +232,42 @@ extension NIOBSDSocket { } @inline(never) - static func getpeername(socket s: NIOBSDSocket.Handle, - address name: UnsafeMutablePointer, - address_len namelen: UnsafeMutablePointer) throws { + static func getpeername( + socket s: NIOBSDSocket.Handle, + address name: UnsafeMutablePointer, + address_len namelen: UnsafeMutablePointer + ) throws { if WinSDK.getpeername(s, name, namelen) == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "getpeername") } } @inline(never) - static func getsockname(socket s: NIOBSDSocket.Handle, - address name: UnsafeMutablePointer, - address_len namelen: UnsafeMutablePointer) throws { + static func getsockname( + socket s: NIOBSDSocket.Handle, + address name: UnsafeMutablePointer, + address_len namelen: UnsafeMutablePointer + ) throws { if WinSDK.getsockname(s, name, namelen) == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "getsockname") } } @inline(never) - static func getsockopt(socket: NIOBSDSocket.Handle, - level: NIOBSDSocket.OptionLevel, - option_name optname: NIOBSDSocket.Option, - option_value optval: UnsafeMutableRawPointer, - option_len optlen: UnsafeMutablePointer) throws { - if CNIOWindows_getsockopt(socket, level.rawValue, optname.rawValue, - optval, optlen) == SOCKET_ERROR { + static func getsockopt( + socket: NIOBSDSocket.Handle, + level: NIOBSDSocket.OptionLevel, + option_name optname: NIOBSDSocket.Option, + option_value optval: UnsafeMutableRawPointer, + option_len optlen: UnsafeMutablePointer + ) throws { + if CNIOWindows_getsockopt( + socket, + level.rawValue, + optname.rawValue, + optval, + optlen + ) == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "getsockopt") } } @@ -256,9 +280,11 @@ extension NIOBSDSocket { } @inline(never) - static func recv(socket s: NIOBSDSocket.Handle, - buffer buf: UnsafeMutableRawPointer, - length len: size_t) throws -> IOResult { + static func recv( + socket s: NIOBSDSocket.Handle, + buffer buf: UnsafeMutableRawPointer, + length len: size_t + ) throws -> IOResult { let iResult: CInt = CNIOWindows_recv(s, buf, CInt(len), 0) if iResult == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "recv") @@ -267,25 +293,36 @@ extension NIOBSDSocket { } @inline(never) - static func recvmsg(socket s: NIOBSDSocket.Handle, - msgHdr lpMsg: UnsafeMutablePointer, - flags: CInt) - throws -> IOResult { + static func recvmsg( + socket s: NIOBSDSocket.Handle, + msgHdr lpMsg: UnsafeMutablePointer, + flags: CInt + ) + throws -> IOResult + { // TODO(compnerd) see comment above var InBuffer = WSAID_WSARECVMSG var pfnWSARecvMsg: LPFN_WSARECVMSG? var cbBytesReturned: DWORD = 0 - if WinSDK.WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, - &InBuffer, DWORD(MemoryLayout.stride(ofValue: InBuffer)), - &pfnWSARecvMsg, - DWORD(MemoryLayout.stride(ofValue: pfnWSARecvMsg)), - &cbBytesReturned, nil, nil) == SOCKET_ERROR { + if WinSDK.WSAIoctl( + s, + SIO_GET_EXTENSION_FUNCTION_POINTER, + &InBuffer, + DWORD(MemoryLayout.stride(ofValue: InBuffer)), + &pfnWSARecvMsg, + DWORD(MemoryLayout.stride(ofValue: pfnWSARecvMsg)), + &cbBytesReturned, + nil, + nil + ) == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "WSAIoctl") } guard let WSARecvMsg = pfnWSARecvMsg else { - throw IOError(windows: DWORD(ERROR_INVALID_FUNCTION), - reason: "recvmsg") + throw IOError( + windows: DWORD(ERROR_INVALID_FUNCTION), + reason: "recvmsg" + ) } var dwNumberOfBytesRecvd: DWORD = 0 @@ -297,40 +334,58 @@ extension NIOBSDSocket { } @inline(never) - static func sendmsg(socket Handle: NIOBSDSocket.Handle, - msgHdr lpMsg: UnsafePointer, - flags dwFlags: CInt) throws -> IOResult { + static func sendmsg( + socket Handle: NIOBSDSocket.Handle, + msgHdr lpMsg: UnsafePointer, + flags dwFlags: CInt + ) throws -> IOResult { // TODO(compnerd) see comment above var InBuffer = WSAID_WSASENDMSG var pfnWSASendMsg: LPFN_WSASENDMSG? var cbBytesReturned: DWORD = 0 - if WinSDK.WSAIoctl(Handle, SIO_GET_EXTENSION_FUNCTION_POINTER, - &InBuffer, DWORD(MemoryLayout.stride(ofValue: InBuffer)), - &pfnWSASendMsg, - DWORD(MemoryLayout.stride(ofValue: pfnWSASendMsg)), - &cbBytesReturned, nil, nil) == SOCKET_ERROR { + if WinSDK.WSAIoctl( + Handle, + SIO_GET_EXTENSION_FUNCTION_POINTER, + &InBuffer, + DWORD(MemoryLayout.stride(ofValue: InBuffer)), + &pfnWSASendMsg, + DWORD(MemoryLayout.stride(ofValue: pfnWSASendMsg)), + &cbBytesReturned, + nil, + nil + ) == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "WSAIoctl") } guard let WSASendMsg = pfnWSASendMsg else { - throw IOError(windows: DWORD(ERROR_INVALID_FUNCTION), - reason: "sendmsg") + throw IOError( + windows: DWORD(ERROR_INVALID_FUNCTION), + reason: "sendmsg" + ) } let lpMsg: LPWSAMSG = UnsafeMutablePointer(mutating: lpMsg) var NumberOfBytesSent: DWORD = 0 // FIXME(compnerd) is the socket guaranteed to not be overlapped? - if WSASendMsg(Handle, lpMsg, DWORD(dwFlags), &NumberOfBytesSent, nil, - nil) == SOCKET_ERROR { + if WSASendMsg( + Handle, + lpMsg, + DWORD(dwFlags), + &NumberOfBytesSent, + nil, + nil + ) == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "sendmsg") } return .processed(size_t(NumberOfBytesSent)) } @inline(never) - static func send(socket s: NIOBSDSocket.Handle, - buffer buf: UnsafeRawPointer, - length len: size_t) throws -> IOResult { + static func send( + socket s: NIOBSDSocket.Handle, + buffer buf: UnsafeRawPointer, + length len: size_t + ) throws -> IOResult { let iResult: CInt = CNIOWindows_send(s, buf, CInt(len), 0) if iResult == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "send") @@ -339,13 +394,20 @@ extension NIOBSDSocket { } @inline(never) - static func setsockopt(socket: NIOBSDSocket.Handle, - level: NIOBSDSocket.OptionLevel, - option_name optname: NIOBSDSocket.Option, - option_value optval: UnsafeRawPointer, - option_len optlen: socklen_t) throws { - if CNIOWindows_setsockopt(socket, level.rawValue, optname.rawValue, - optval, optlen) == SOCKET_ERROR { + static func setsockopt( + socket: NIOBSDSocket.Handle, + level: NIOBSDSocket.OptionLevel, + option_name optname: NIOBSDSocket.Option, + option_value optval: UnsafeRawPointer, + option_len optlen: socklen_t + ) throws { + if CNIOWindows_setsockopt( + socket, + level.rawValue, + optname.rawValue, + optval, + optlen + ) == SOCKET_ERROR { throw IOError(winsock: WSAGetLastError(), reason: "setsockopt") } } @@ -358,9 +420,11 @@ extension NIOBSDSocket { } @inline(never) - static func socket(domain af: NIOBSDSocket.ProtocolFamily, - type: NIOBSDSocket.SocketType, - protocolSubtype: NIOBSDSocket.ProtocolSubtype) throws -> NIOBSDSocket.Handle { + static func socket( + domain af: NIOBSDSocket.ProtocolFamily, + type: NIOBSDSocket.SocketType, + protocolSubtype: NIOBSDSocket.ProtocolSubtype + ) throws -> NIOBSDSocket.Handle { let socket: NIOBSDSocket.Handle = WinSDK.socket(af.rawValue, type.rawValue, protocolSubtype.rawValue) if socket == WinSDK.INVALID_SOCKET { throw IOError(winsock: WSAGetLastError(), reason: "socket") @@ -369,35 +433,51 @@ extension NIOBSDSocket { } @inline(never) - static func recvmmsg(socket: NIOBSDSocket.Handle, - msgvec: UnsafeMutablePointer, - vlen: CUnsignedInt, flags: CInt, - timeout: UnsafeMutablePointer?) - throws -> IOResult { - return .processed(Int(CNIOWindows_recvmmsg(socket, msgvec, vlen, flags, timeout))) + static func recvmmsg( + socket: NIOBSDSocket.Handle, + msgvec: UnsafeMutablePointer, + vlen: CUnsignedInt, + flags: CInt, + timeout: UnsafeMutablePointer? + ) + throws -> IOResult + { + .processed(Int(CNIOWindows_recvmmsg(socket, msgvec, vlen, flags, timeout))) } @inline(never) - static func sendmmsg(socket: NIOBSDSocket.Handle, - msgvec: UnsafeMutablePointer, - vlen: CUnsignedInt, flags: CInt) - throws -> IOResult { - return .processed(Int(CNIOWindows_sendmmsg(socket, msgvec, vlen, flags))) + static func sendmmsg( + socket: NIOBSDSocket.Handle, + msgvec: UnsafeMutablePointer, + vlen: CUnsignedInt, + flags: CInt + ) + throws -> IOResult + { + .processed(Int(CNIOWindows_sendmmsg(socket, msgvec, vlen, flags))) } // NOTE: this should return a `ssize_t`, however, that is not a standard // type, and defining that type is difficult. Opt to return a `size_t` // which is the same size, but is unsigned. @inline(never) - static func pread(socket: NIOBSDSocket.Handle, - pointer: UnsafeMutableRawPointer, - size: size_t, offset: off_t) throws -> IOResult { + static func pread( + socket: NIOBSDSocket.Handle, + pointer: UnsafeMutableRawPointer, + size: size_t, + offset: off_t + ) throws -> IOResult { var ovlOverlapped: OVERLAPPED = OVERLAPPED() - ovlOverlapped.OffsetHigh = DWORD(UInt32(offset >> 32) & 0xffffffff) - ovlOverlapped.Offset = DWORD(UInt32(offset >> 0) & 0xffffffff) + ovlOverlapped.OffsetHigh = DWORD(UInt32(offset >> 32) & 0xffff_ffff) + ovlOverlapped.Offset = DWORD(UInt32(offset >> 0) & 0xffff_ffff) var nNumberOfBytesRead: DWORD = 0 - if !ReadFile(HANDLE(bitPattern: UInt(socket)), pointer, DWORD(size), - &nNumberOfBytesRead, &ovlOverlapped) { + if !ReadFile( + HANDLE(bitPattern: UInt(socket)), + pointer, + DWORD(size), + &nNumberOfBytesRead, + &ovlOverlapped + ) { throw IOError(windows: GetLastError(), reason: "ReadFile") } return .processed(size_t(nNumberOfBytesRead)) @@ -407,33 +487,54 @@ extension NIOBSDSocket { // type, and defining that type is difficult. Opt to return a `size_t` // which is the same size, but is unsigned. @inline(never) - static func pwrite(socket: NIOBSDSocket.Handle, pointer: UnsafeRawPointer, - size: size_t, offset: off_t) throws -> IOResult { + static func pwrite( + socket: NIOBSDSocket.Handle, + pointer: UnsafeRawPointer, + size: size_t, + offset: off_t + ) throws -> IOResult { var ovlOverlapped: OVERLAPPED = OVERLAPPED() - ovlOverlapped.OffsetHigh = DWORD(UInt32(offset >> 32) & 0xffffffff) - ovlOverlapped.Offset = DWORD(UInt32(offset >> 0) & 0xffffffff) + ovlOverlapped.OffsetHigh = DWORD(UInt32(offset >> 32) & 0xffff_ffff) + ovlOverlapped.Offset = DWORD(UInt32(offset >> 0) & 0xffff_ffff) var nNumberOfBytesWritten: DWORD = 0 - if !WriteFile(HANDLE(bitPattern: UInt(socket)), pointer, DWORD(size), - &nNumberOfBytesWritten, &ovlOverlapped) { + if !WriteFile( + HANDLE(bitPattern: UInt(socket)), + pointer, + DWORD(size), + &nNumberOfBytesWritten, + &ovlOverlapped + ) { throw IOError(windows: GetLastError(), reason: "WriteFile") } return .processed(size_t(nNumberOfBytesWritten)) } @inline(never) - static func sendfile(socket s: NIOBSDSocket.Handle, fd: CInt, offset: off_t, - len nNumberOfBytesToWrite: off_t) - throws -> IOResult { + static func sendfile( + socket s: NIOBSDSocket.Handle, + fd: CInt, + offset: off_t, + len nNumberOfBytesToWrite: off_t + ) + throws -> IOResult + { let hFile: HANDLE = HANDLE(bitPattern: ucrt._get_osfhandle(fd))! if hFile == INVALID_HANDLE_VALUE { throw IOError(errnoCode: EBADF, reason: "_get_osfhandle") } var ovlOverlapped: OVERLAPPED = OVERLAPPED() - ovlOverlapped.Offset = DWORD(UInt32(offset >> 0) & 0xffffffff) - ovlOverlapped.OffsetHigh = DWORD(UInt32(offset >> 32) & 0xffffffff) - if !TransmitFile(s, hFile, DWORD(nNumberOfBytesToWrite), 0, - &ovlOverlapped, nil, DWORD(TF_USE_KERNEL_APC)) { + ovlOverlapped.Offset = DWORD(UInt32(offset >> 0) & 0xffff_ffff) + ovlOverlapped.OffsetHigh = DWORD(UInt32(offset >> 32) & 0xffff_ffff) + if !TransmitFile( + s, + hFile, + DWORD(nNumberOfBytesToWrite), + 0, + &ovlOverlapped, + nil, + DWORD(TF_USE_KERNEL_APC) + ) { throw IOError(winsock: WSAGetLastError(), reason: "TransmitFile") } @@ -444,7 +545,7 @@ extension NIOBSDSocket { // Returns nil if mptcp is not supported. static var mptcpProtocolSubtype: Int? { // MPTCP not supported on Windows. - return nil + nil } } @@ -462,30 +563,38 @@ extension NIOBSDSocket { } static func cleanupUnixDomainSocket(atPath path: String) throws { - guard let hFile = (path.withCString(encodedAs: UTF16.self) { - CreateFileW($0, GENERIC_READ, + guard + let hFile = + (path.withCString(encodedAs: UTF16.self) { + CreateFileW( + $0, + GENERIC_READ, DWORD(FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE), - nil, DWORD(OPEN_EXISTING), + nil, + DWORD(OPEN_EXISTING), DWORD(FILE_FLAG_OPEN_REPARSE_POINT | FILE_FLAG_BACKUP_SEMANTICS), - nil) - }) else { + nil + ) + }) + else { throw IOError(windows: DWORD(EBADF), reason: "CreateFileW") } defer { CloseHandle(hFile) } - let ftFileType = GetFileType(hFile) + let ftFileType = GetFileType(hFile) let dwError = GetLastError() guard dwError == NO_ERROR, ftFileType != FILE_TYPE_DISK else { throw IOError(windows: dwError, reason: "GetFileType") } var fiInformation: BY_HANDLE_FILE_INFORMATION = - BY_HANDLE_FILE_INFORMATION() + BY_HANDLE_FILE_INFORMATION() guard GetFileInformationByHandle(hFile, &fiInformation) else { throw IOError(windows: GetLastError(), reason: "GetFileInformationByHandle") } - guard fiInformation.dwFileAttributes & DWORD(FILE_ATTRIBUTE_REPARSE_POINT) == FILE_ATTRIBUTE_REPARSE_POINT else { + guard fiInformation.dwFileAttributes & DWORD(FILE_ATTRIBUTE_REPARSE_POINT) == FILE_ATTRIBUTE_REPARSE_POINT + else { throw UnixDomainSocketPathWrongType() } @@ -493,9 +602,16 @@ extension NIOBSDSocket { var dbReparseDataBuffer: CNIOWindows_REPARSE_DATA_BUFFER = CNIOWindows_REPARSE_DATA_BUFFER() try withUnsafeMutablePointer(to: &dbReparseDataBuffer) { - if !DeviceIoControl(hFile, FSCTL_GET_REPARSE_POINT, nil, 0, $0, - DWORD(MemoryLayout.stride), - &nBytesWritten, nil) { + if !DeviceIoControl( + hFile, + FSCTL_GET_REPARSE_POINT, + nil, + 0, + $0, + DWORD(MemoryLayout.stride), + &nBytesWritten, + nil + ) { throw IOError(windows: GetLastError(), reason: "DeviceIoControl") } } @@ -507,8 +623,12 @@ extension NIOBSDSocket { var fdi: FILE_DISPOSITION_INFO_EX = FILE_DISPOSITION_INFO_EX() fdi.Flags = DWORD(FILE_DISPOSITION_FLAG_DELETE | FILE_DISPOSITION_FLAG_POSIX_SEMANTICS) - if !SetFileInformationByHandle(hFile, FileDispositionInfoEx, &fdi, - DWORD(MemoryLayout.stride)) { + if !SetFileInformationByHandle( + hFile, + FileDispositionInfoEx, + &fdi, + DWORD(MemoryLayout.stride) + ) { throw IOError(windows: GetLastError(), reason: "GetFileInformationByHandle") } } @@ -516,27 +636,39 @@ extension NIOBSDSocket { // MARK: _BSDSocketControlMessageProtocol implementation extension NIOBSDSocketControlMessage { - static func firstHeader(inside msghdr: UnsafePointer) - -> UnsafeMutablePointer? { - return CNIOWindows_CMSG_FIRSTHDR(msghdr) - } - - static func nextHeader(inside msghdr: UnsafeMutablePointer, - after: UnsafeMutablePointer) - -> UnsafeMutablePointer? { - return CNIOWindows_CMSG_NXTHDR(msghdr, after) - } - - static func data(for header: UnsafePointer) - -> UnsafeRawBufferPointer? { + static func firstHeader( + inside msghdr: UnsafePointer + ) + -> UnsafeMutablePointer? + { + CNIOWindows_CMSG_FIRSTHDR(msghdr) + } + + static func nextHeader( + inside msghdr: UnsafeMutablePointer, + after: UnsafeMutablePointer + ) + -> UnsafeMutablePointer? + { + CNIOWindows_CMSG_NXTHDR(msghdr, after) + } + + static func data( + for header: UnsafePointer + ) + -> UnsafeRawBufferPointer? + { let data = CNIOWindows_CMSG_DATA(header) let length = size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) return UnsafeRawBufferPointer(start: data, count: Int(length)) } - static func data(for header: UnsafeMutablePointer) - -> UnsafeMutableRawBufferPointer? { + static func data( + for header: UnsafeMutablePointer + ) + -> UnsafeMutableRawBufferPointer? + { let data = CNIOWindows_CMSG_DATA_MUTABLE(header) let length = size_t(header.pointee.cmsg_len) - NIOBSDSocketControlMessage.length(payloadSize: 0) @@ -544,11 +676,11 @@ extension NIOBSDSocketControlMessage { } static func length(payloadSize: size_t) -> size_t { - return CNIOWindows_CMSG_LEN(payloadSize) + CNIOWindows_CMSG_LEN(payloadSize) } static func space(payloadSize: size_t) -> size_t { - return CNIOWindows_CMSG_SPACE(payloadSize) + CNIOWindows_CMSG_SPACE(payloadSize) } } diff --git a/Sources/NIOPosix/BaseSocket.swift b/Sources/NIOPosix/BaseSocket.swift index 39cd4b8387..489fec44de 100644 --- a/Sources/NIOPosix/BaseSocket.swift +++ b/Sources/NIOPosix/BaseSocket.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -import NIOCore import NIOConcurrencyHelpers +import NIOCore #if os(Windows) import let WinSDK.EAFNOSUPPORT @@ -38,7 +38,7 @@ protocol Registration { // only our rebinding copy here is allowed. extension sockaddr_storage { mutating func withMutableSockAddr(_ body: (UnsafeMutablePointer, Int) throws -> R) rethrows -> R { - return try withUnsafeMutableBytes(of: &self) { p in + try withUnsafeMutableBytes(of: &self) { p in try body(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) } } @@ -117,9 +117,9 @@ class BaseSocket: BaseSocketProtocol { private var descriptor: NIOBSDSocket.Handle public var isOpen: Bool { #if os(Windows) - return descriptor != NIOBSDSocket.invalidHandle + return descriptor != NIOBSDSocket.invalidHandle #else - return descriptor >= 0 + return descriptor >= 0 #endif } @@ -128,7 +128,7 @@ class BaseSocket: BaseSocketProtocol { /// - returns: The local bound address. /// - throws: An `IOError` if the retrieval of the address failed. func localAddress() throws -> SocketAddress { - return try get_addr { + try get_addr { try NIOBSDSocket.getsockname(socket: $0, address: $1, address_len: $2) } } @@ -138,13 +138,15 @@ class BaseSocket: BaseSocketProtocol { /// - returns: The connected address. /// - throws: An `IOError` if the retrieval of the address failed. func remoteAddress() throws -> SocketAddress { - return try get_addr { + try get_addr { try NIOBSDSocket.getpeername(socket: $0, address: $1, address_len: $2) } } /// Internal helper function for retrieval of a `SocketAddress`. - private func get_addr(_ body: (NIOBSDSocket.Handle, UnsafeMutablePointer, UnsafeMutablePointer) throws -> Void) throws -> SocketAddress { + private func get_addr( + _ body: (NIOBSDSocket.Handle, UnsafeMutablePointer, UnsafeMutablePointer) throws -> Void + ) throws -> SocketAddress { var addr = sockaddr_storage() try addr.withMutableSockAddr { addressPtr, size in @@ -177,9 +179,11 @@ class BaseSocket: BaseSocketProtocol { sockType = type.rawValue | Linux.SOCK_NONBLOCK } #endif - let sock = try NIOBSDSocket.socket(domain: protocolFamily, - type: NIOBSDSocket.SocketType(rawValue: sockType), - protocolSubtype: protocolSubtype) + let sock = try NIOBSDSocket.socket( + domain: protocolFamily, + type: NIOBSDSocket.SocketType(rawValue: sockType), + protocolSubtype: protocolSubtype + ) #if !os(Linux) if setNonBlocking { do { @@ -194,21 +198,27 @@ class BaseSocket: BaseSocketProtocol { if protocolFamily == .inet6 { var zero: Int32 = 0 do { - try NIOBSDSocket.setsockopt(socket: sock, level: .ipv6, option_name: .ipv6_v6only, option_value: &zero, option_len: socklen_t(MemoryLayout.size(ofValue: zero))) + try NIOBSDSocket.setsockopt( + socket: sock, + level: .ipv6, + option_name: .ipv6_v6only, + option_value: &zero, + option_len: socklen_t(MemoryLayout.size(ofValue: zero)) + ) } catch let e as IOError { if e.errnoCode != EAFNOSUPPORT { // Ignore error that may be thrown by close. _ = try? NIOBSDSocket.close(socket: sock) throw e } - /* we couldn't enable dual IP4/6 support, that's okay too. */ + // we couldn't enable dual IP4/6 support, that's okay too. } catch let e { fatalError("Unexpected error type \(e)") } } return sock } - + /// Cleanup the unix domain socket. /// /// Deletes the associated file if it exists and has socket type. Does nothing if pathname does not exist. @@ -229,15 +239,16 @@ class BaseSocket: BaseSocketProtocol { /// - descriptor: The file descriptor to wrap. init(socket descriptor: NIOBSDSocket.Handle) throws { #if os(Windows) - precondition(descriptor != NIOBSDSocket.invalidHandle, "invalid socket") + precondition(descriptor != NIOBSDSocket.invalidHandle, "invalid socket") #else - precondition(descriptor >= 0, "invalid socket") + precondition(descriptor >= 0, "invalid socket") #endif self.descriptor = descriptor do { try self.ignoreSIGPIPE() } catch { - self.descriptor = NIOBSDSocket.invalidHandle // We have to unset the fd here, otherwise we'll crash with "leaking open BaseSocket" + // We have to unset the fd here, otherwise we'll crash with "leaking open BaseSocket" + self.descriptor = NIOBSDSocket.invalidHandle throw error } } @@ -256,7 +267,7 @@ class BaseSocket: BaseSocketProtocol { /// /// throws: An `IOError` if the operation failed. final func setNonBlocking() throws { - return try self.withUnsafeHandle { + try self.withUnsafeHandle { try NIOBSDSocket.setNonBlocking(socket: $0) } } @@ -288,7 +299,8 @@ class BaseSocket: BaseSocketProtocol { level: level, option_name: name, option_value: valueBuffer.baseAddress!, - option_len: socklen_t(valueBuffer.count)) + option_len: socklen_t(valueBuffer.count) + ) } } } @@ -302,10 +314,12 @@ class BaseSocket: BaseSocketProtocol { /// - name: The name of the option to set. /// - throws: An `IOError` if the operation failed. func getOption(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option) throws -> T { - return try self.withUnsafeHandle { fd in + try self.withUnsafeHandle { fd in var length = socklen_t(MemoryLayout.size) - let storage = UnsafeMutableRawBufferPointer.allocate(byteCount: MemoryLayout.stride, - alignment: MemoryLayout.alignment) + let storage = UnsafeMutableRawBufferPointer.allocate( + byteCount: MemoryLayout.stride, + alignment: MemoryLayout.alignment + ) // write zeroes into the memory as Linux's getsockopt doesn't zero them out storage.initializeMemory(as: UInt8.self, repeating: 0) let val = storage.bindMemory(to: T.self).baseAddress! @@ -315,7 +329,13 @@ class BaseSocket: BaseSocketProtocol { storage.deallocate() } - try NIOBSDSocket.getsockopt(socket: fd, level: level, option_name: name, option_value: val, option_len: &length) + try NIOBSDSocket.getsockopt( + socket: fd, + level: level, + option_name: name, + option_value: val, + option_len: &length + ) return val.pointee } } @@ -349,7 +369,7 @@ class BaseSocket: BaseSocketProtocol { /// /// - throws: An `IOError` if the operation failed. final func takeDescriptorOwnership() throws -> NIOBSDSocket.Handle { - return try self.withUnsafeHandle { + try self.withUnsafeHandle { self.descriptor = NIOBSDSocket.invalidHandle return $0 } @@ -367,7 +387,7 @@ extension BaseSocket: Selectable { extension BaseSocket: CustomStringConvertible { var description: String { - return "BaseSocket { fd=\(self.descriptor) }" + "BaseSocket { fd=\(self.descriptor) }" } } @@ -376,23 +396,24 @@ extension BaseSocket: CustomStringConvertible { // the compiler falls over when we try to access them from test code. As these functions // exist purely to make the behaviours accessible from test code, we name them truly awfully. func __testOnly_convertSockAddr(_ addr: sockaddr_storage) -> sockaddr_in { - return addr.convert() + addr.convert() } func __testOnly_convertSockAddr(_ addr: sockaddr_storage) -> sockaddr_in6 { - return addr.convert() + addr.convert() } func __testOnly_convertSockAddr(_ addr: sockaddr_storage) -> sockaddr_un { - return addr.convert() + addr.convert() } func __testOnly_convertSockAddr(_ addr: sockaddr_storage) throws -> SocketAddress { - return try addr.convert() + try addr.convert() } func __testOnly_withMutableSockAddr( - _ addr: inout sockaddr_storage, _ body: (UnsafeMutablePointer, Int) throws -> ReturnType + _ addr: inout sockaddr_storage, + _ body: (UnsafeMutablePointer, Int) throws -> ReturnType ) rethrows -> ReturnType { - return try addr.withMutableSockAddr(body) + try addr.withMutableSockAddr(body) } diff --git a/Sources/NIOPosix/BaseSocketChannel+SocketOptionProvider.swift b/Sources/NIOPosix/BaseSocketChannel+SocketOptionProvider.swift index 03eff14974..c50429abd3 100644 --- a/Sources/NIOPosix/BaseSocketChannel+SocketOptionProvider.swift +++ b/Sources/NIOPosix/BaseSocketChannel+SocketOptionProvider.swift @@ -15,12 +15,24 @@ import NIOCore extension BaseSocketChannel: SocketOptionProvider { #if !os(Windows) - func unsafeSetSocketOption(level: SocketOptionLevel, name: SocketOptionName, value: Value) -> EventLoopFuture { - return unsafeSetSocketOption(level: NIOBSDSocket.OptionLevel(rawValue: CInt(level)), name: NIOBSDSocket.Option(rawValue: CInt(name)), value: value) - } + func unsafeSetSocketOption( + level: SocketOptionLevel, + name: SocketOptionName, + value: Value + ) -> EventLoopFuture { + unsafeSetSocketOption( + level: NIOBSDSocket.OptionLevel(rawValue: CInt(level)), + name: NIOBSDSocket.Option(rawValue: CInt(name)), + value: value + ) + } #endif - func unsafeSetSocketOption(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option, value: Value) -> EventLoopFuture { + func unsafeSetSocketOption( + level: NIOBSDSocket.OptionLevel, + name: NIOBSDSocket.Option, + value: Value + ) -> EventLoopFuture { if eventLoop.inEventLoop { let promise = eventLoop.makePromise(of: Void.self) executeAndComplete(promise) { @@ -35,12 +47,18 @@ extension BaseSocketChannel: SocketOptionProvider { } #if !os(Windows) - func unsafeGetSocketOption(level: SocketOptionLevel, name: SocketOptionName) -> EventLoopFuture { - return unsafeGetSocketOption(level: NIOBSDSocket.OptionLevel(rawValue: CInt(level)), name: NIOBSDSocket.Option(rawValue: CInt(name))) - } + func unsafeGetSocketOption(level: SocketOptionLevel, name: SocketOptionName) -> EventLoopFuture { + unsafeGetSocketOption( + level: NIOBSDSocket.OptionLevel(rawValue: CInt(level)), + name: NIOBSDSocket.Option(rawValue: CInt(name)) + ) + } #endif - func unsafeGetSocketOption(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option) -> EventLoopFuture { + func unsafeGetSocketOption( + level: NIOBSDSocket.OptionLevel, + name: NIOBSDSocket.Option + ) -> EventLoopFuture { if eventLoop.inEventLoop { let promise = eventLoop.makePromise(of: Value.self) executeAndComplete(promise) { @@ -59,6 +77,6 @@ extension BaseSocketChannel: SocketOptionProvider { } func getSocketOption0(level: NIOBSDSocket.OptionLevel, name: NIOBSDSocket.Option) throws -> Value { - return try self.socket.getOption(level: level, name: name) + try self.socket.getOption(level: level, name: name) } } diff --git a/Sources/NIOPosix/BaseSocketChannel.swift b/Sources/NIOPosix/BaseSocketChannel.swift index 21ee1f39ea..98cd75a57d 100644 --- a/Sources/NIOPosix/BaseSocketChannel.swift +++ b/Sources/NIOPosix/BaseSocketChannel.swift @@ -12,16 +12,16 @@ // //===----------------------------------------------------------------------===// -import NIOCore -import NIOConcurrencyHelpers import Atomics +import NIOConcurrencyHelpers +import NIOCore private struct SocketChannelLifecycleManager { // MARK: Types private enum State { case fresh - case preRegistered // register() has been run but the selector doesn't know about it yet - case fullyRegistered // fully registered, ie. the selector knows about it + case preRegistered // register() has been run but the selector doesn't know about it yet + case fullyRegistered // fully registered, ie. the selector knows about it case activated case closed } @@ -74,31 +74,36 @@ private struct SocketChannelLifecycleManager { // this is called from Channel's deinit, so don't assert we're on the EventLoop! internal var canBeDestroyed: Bool { - return self.currentState == .closed + self.currentState == .closed } - @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + @inline(__always) internal mutating func beginRegistration() -> ((EventLoopPromise?, ChannelPipeline) -> Void) { - return self.moveState(event: .beginRegistration) + self.moveState(event: .beginRegistration) } - @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + @inline(__always) internal mutating func finishRegistration() -> ((EventLoopPromise?, ChannelPipeline) -> Void) { - return self.moveState(event: .finishRegistration) + self.moveState(event: .finishRegistration) } - @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + @inline(__always) internal mutating func close() -> ((EventLoopPromise?, ChannelPipeline) -> Void) { - return self.moveState(event: .close) + self.moveState(event: .close) } - @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + @inline(__always) internal mutating func activate() -> ((EventLoopPromise?, ChannelPipeline) -> Void) { - return self.moveState(event: .activate) + self.moveState(event: .activate) } // MARK: private API - @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + @inline(__always) private mutating func moveState(event: Event) -> ((EventLoopPromise?, ChannelPipeline) -> Void) { self.eventLoop.assertInEventLoop() @@ -156,16 +161,16 @@ private struct SocketChannelLifecycleManager { } // bad transitions - case (.fresh, .activate), // should go through .registered first - (.preRegistered, .activate), // need to first be fully registered - (.preRegistered, .beginRegistration), // already registered - (.fullyRegistered, .beginRegistration), // already registered - (.activated, .activate), // already activated - (.activated, .beginRegistration), // already fully registered (and activated) - (.activated, .finishRegistration), // already fully registered (and activated) - (.fullyRegistered, .finishRegistration), // already fully registered - (.fresh, .finishRegistration), // need to register lazily first - (.closed, _): // already closed + case (.fresh, .activate), // should go through .registered first + (.preRegistered, .activate), // need to first be fully registered + (.preRegistered, .beginRegistration), // already registered + (.fullyRegistered, .beginRegistration), // already registered + (.activated, .activate), // already activated + (.activated, .beginRegistration), // already fully registered (and activated) + (.activated, .finishRegistration), // already fully registered (and activated) + (.fullyRegistered, .finishRegistration), // already fully registered + (.fresh, .finishRegistration), // need to register lazily first + (.closed, _): // already closed self.badTransition(event: event) } } @@ -223,8 +228,8 @@ class BaseSocketChannel: SelectableChannel, Chan struct AddressCache { // deliberately lets because they must always be updated together (so forcing `init` is useful). - let local: Optional - let remote: Optional + let local: SocketAddress? + let remote: SocketAddress? init(local: SocketAddress?, remote: SocketAddress?) { self.local = local @@ -248,17 +253,19 @@ class BaseSocketChannel: SelectableChannel, Chan var pendingConnect: Optional> var recvBufferPool: PooledRecvBufferAllocator var maxMessagesPerRead: UInt = 4 - private var inFlushNow: Bool = false // Guard against re-entrance of flushNow() method. + private var inFlushNow: Bool = false // Guard against re-entrance of flushNow() method. private var autoRead: Bool = true - // MARK: Variables that are really constants - private var _pipeline: ChannelPipeline! = nil // this is really a constant (set in .init) but needs `self` to be constructed and therefore a `var`. Do not change as this needs to accessed from arbitrary threads + // MARK: Variables that are really constant + // this is really a constant (set in .init) but needs `self` to be constructed and + // therefore a `var`. Do not change as this needs to accessed from arbitrary threads + private var _pipeline: ChannelPipeline! = nil // MARK: Special variables, please read comments. // For reads guarded by _either_ `self._offEventLoopLock` or the EL thread // Writes are guarded by _offEventLoopLock _and_ the EL thread. // PLEASE don't use these directly and use the non-underscored computed properties instead. - private var _addressCache = AddressCache(local: nil, remote: nil) // please use `self.addressesCached` instead + private var _addressCache = AddressCache(local: nil, remote: nil) // please use `self.addressesCached` instead private var _bufferAllocatorCache: ByteBufferAllocator // please use `self.bufferAllocatorCached` instead. // MARK: - Computed properties @@ -269,7 +276,7 @@ class BaseSocketChannel: SelectableChannel, Chan return self._addressCache } else { return self._offEventLoopLock.withLock { - return self._addressCache + self._addressCache } } } @@ -288,7 +295,7 @@ class BaseSocketChannel: SelectableChannel, Chan return self._bufferAllocatorCache } else { return self._offEventLoopLock.withLock { - return self._bufferAllocatorCache + self._bufferAllocatorCache } } } @@ -321,16 +328,16 @@ class BaseSocketChannel: SelectableChannel, Chan } } - public final var _channelCore: ChannelCore { return self } + public final var _channelCore: ChannelCore { self } // This is `Channel` API so must be thread-safe. public final var localAddress: SocketAddress? { - return self.addressesCached.local + self.addressesCached.local } // This is `Channel` API so must be thread-safe. public final var remoteAddress: SocketAddress? { - return self.addressesCached.remote + self.addressesCached.remote } /// `false` if the whole `Channel` is closed and so no more IO operation can be done. @@ -346,31 +353,31 @@ class BaseSocketChannel: SelectableChannel, Chan // This is `Channel` API so must be thread-safe. public var isActive: Bool { - return self.isActiveAtomic.load(ordering: .relaxed) + self.isActiveAtomic.load(ordering: .relaxed) } // This is `Channel` API so must be thread-safe. public final var closeFuture: EventLoopFuture { - return self.closePromise.futureResult + self.closePromise.futureResult } public final var eventLoop: EventLoop { - return selectableEventLoop + selectableEventLoop } // This is `Channel` API so must be thread-safe. public var isWritable: Bool { - return true + true } // This is `Channel` API so must be thread-safe. public final var allocator: ByteBufferAllocator { - return self.bufferAllocatorCached + self.bufferAllocatorCached } // This is `Channel` API so must be thread-safe. public final var pipeline: ChannelPipeline { - return self._pipeline + self._pipeline } // MARK: Methods to override in subclasses. @@ -508,8 +515,10 @@ class BaseSocketChannel: SelectableChannel, Chan } deinit { - assert(self.lifecycleManager.canBeDestroyed, - "leak of open Channel, state: \(String(describing: self.lifecycleManager))") + assert( + self.lifecycleManager.canBeDestroyed, + "leak of open Channel, state: \(String(describing: self.lifecycleManager))" + ) } public final func localAddress0() throws -> SocketAddress { @@ -590,9 +599,11 @@ class BaseSocketChannel: SelectableChannel, Chan return .unregister } - assert((newWriteRegistrationState == .register && self.hasFlushedPendingWrites()) || - (newWriteRegistrationState == .unregister && !self.hasFlushedPendingWrites()), - "illegal flushNow decision: \(newWriteRegistrationState) and \(self.hasFlushedPendingWrites())") + assert( + (newWriteRegistrationState == .register && self.hasFlushedPendingWrites()) + || (newWriteRegistrationState == .unregister && !self.hasFlushedPendingWrites()), + "illegal flushNow decision: \(newWriteRegistrationState) and \(self.hasFlushedPendingWrites())" + ) return newWriteRegistrationState } @@ -879,8 +890,8 @@ class BaseSocketChannel: SelectableChannel, Chan self.cancelWritesOnClose(error: error) // this should be a no-op as we shouldn't have any - errorCallouts.forEach { - $0(self.pipeline) + for callout in errorCallouts { + callout(self.pipeline) } if let connectPromise = self.pendingConnect { @@ -901,7 +912,6 @@ class BaseSocketChannel: SelectableChannel, Chan } } - public final func register0(promise: EventLoopPromise?) { self.eventLoop.assertInEventLoop() @@ -973,7 +983,7 @@ class BaseSocketChannel: SelectableChannel, Chan self.finishWritable() case .register: assert(!self.isOpen || self.interestedEvent.contains(.write)) - () // nothing to do because given that we just received `writable`, we're still registered for writable. + () // nothing to do because given that we just received `writable`, we're still registered for writable. } } @@ -1022,8 +1032,10 @@ class BaseSocketChannel: SelectableChannel, Chan // we can't be not active but still registered here; this would mean that we got a notification about a // channel before we're ready to receive them. - assert(self.lifecycleManager.isRegisteredFully, - "illegal state: \(self): active: \(self.lifecycleManager.isActive), registered: \(self.lifecycleManager.isRegisteredFully)") + assert( + self.lifecycleManager.isRegisteredFully, + "illegal state: \(self): active: \(self.lifecycleManager.isActive), registered: \(self.lifecycleManager.isRegisteredFully)" + ) self.readEOF0() @@ -1081,7 +1093,8 @@ class BaseSocketChannel: SelectableChannel, Chan #if os(Linux) let message: String = "connection reset (no error set)" #else - let message: String = "BUG IN SwiftNIO (possibly #572), please report! Connection reset (no error set)." + let message: String = + "BUG IN SwiftNIO (possibly #572), please report! Connection reset (no error set)." #endif error = IOError(errnoCode: ECONNRESET, reason: message) } @@ -1094,8 +1107,10 @@ class BaseSocketChannel: SelectableChannel, Chan } public final func readable() { - assert(!self.lifecycleManager.hasSeenEOFNotification, - "got a read notification after having already seen .readEOF") + assert( + !self.lifecycleManager.hasSeenEOFNotification, + "got a read notification after having already seen .readEOF" + ) self.readable0() } @@ -1185,7 +1200,7 @@ class BaseSocketChannel: SelectableChannel, Chan /// - err: The `Error` which was thrown by `readFromSocket`. /// - returns: `true` if the `Channel` should be closed, `false` otherwise. func shouldCloseOnReadError(_ err: Error) -> Bool { - return true + true } internal final func updateCachedAddressesFromSocket(updateLocal: Bool = true, updateRemote: Bool = true) { @@ -1260,7 +1275,7 @@ class BaseSocketChannel: SelectableChannel, Chan } private func isWritePending() -> Bool { - return self.interestedEvent.contains(.write) + self.interestedEvent.contains(.write) } private final func safeReregister(interested: SelectorEventSet) { @@ -1364,10 +1379,10 @@ class BaseSocketChannel: SelectableChannel, Chan extension BaseSocketChannel { public struct SynchronousOptions: NIOSynchronousChannelOptions { - @usableFromInline // should be private + @usableFromInline // should be private internal let _channel: BaseSocketChannel - @inlinable // should be fileprivate + @inlinable // should be fileprivate internal init(_channel channel: BaseSocketChannel) { self._channel = channel } @@ -1379,12 +1394,12 @@ extension BaseSocketChannel { @inlinable public func getOption(_ option: Option) throws -> Option.Value { - return try self._channel.getOption0(option) + try self._channel.getOption0(option) } } public final var syncOptions: NIOSynchronousChannelOptions? { - return SynchronousOptions(_channel: self) + SynchronousOptions(_channel: self) } } diff --git a/Sources/NIOPosix/BaseStreamSocketChannel.swift b/Sources/NIOPosix/BaseStreamSocketChannel.swift index 6e9c2a8334..605f551464 100644 --- a/Sources/NIOPosix/BaseStreamSocketChannel.swift +++ b/Sources/NIOPosix/BaseStreamSocketChannel.swift @@ -98,7 +98,7 @@ class BaseStreamSocketChannel: BaseSocketChannel // MARK: BaseSocketChannel's must override API that cannot be further refined by subclasses // This is `Channel` API so must be thread-safe. final override public var isWritable: Bool { - return self.pendingWrites.isWritable + self.pendingWrites.isWritable } final override var isOpen: Bool { @@ -159,19 +159,23 @@ class BaseStreamSocketChannel: BaseSocketChannel } final override func writeToSocket() throws -> OverallWriteResult { - let result = try self.pendingWrites.triggerAppropriateWriteOperations(scalarBufferWriteOperation: { ptr in - guard ptr.count > 0 else { - // No need to call write if the buffer is empty. - return .processed(0) + let result = try self.pendingWrites.triggerAppropriateWriteOperations( + scalarBufferWriteOperation: { ptr in + guard ptr.count > 0 else { + // No need to call write if the buffer is empty. + return .processed(0) + } + // normal write + return try self.socket.write(pointer: ptr) + }, + vectorBufferWriteOperation: { ptrs in + // Gathering write + try self.socket.writev(iovecs: ptrs) + }, + scalarFileWriteOperation: { descriptor, index, endIndex in + try self.socket.sendFile(fd: descriptor, offset: index, count: endIndex - index) } - // normal write - return try self.socket.write(pointer: ptr) - }, vectorBufferWriteOperation: { ptrs in - // Gathering write - try self.socket.writev(iovecs: ptrs) - }, scalarFileWriteOperation: { descriptor, index, endIndex in - try self.socket.sendFile(fd: descriptor, offset: index, count: endIndex - index) - }) + ) return result } @@ -231,7 +235,7 @@ class BaseStreamSocketChannel: BaseSocketChannel } final override func hasFlushedPendingWrites() -> Bool { - return self.pendingWrites.isFlushPending + self.pendingWrites.isFlushPending } final override func markFlushPoint() { diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index b06c8ced4d..6deaae1e1c 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -31,7 +31,7 @@ internal typealias ChannelInitializerCallback = @Sendable (Channel) -> EventLoop /// Common functionality for all NIO on sockets bootstraps. internal enum NIOOnSocketsBootstraps { internal static func isCompatible(group: EventLoopGroup) -> Bool { - return group is SelectableEventLoop || group is MultiThreadedEventLoopGroup + group is SelectableEventLoop || group is MultiThreadedEventLoopGroup } } @@ -98,8 +98,10 @@ public final class ServerBootstrap { /// - group: The `EventLoopGroup` to use for the `bind` of the `ServerSocketChannel` and to accept new `SocketChannel`s with. public convenience init(group: EventLoopGroup) { guard NIOOnSocketsBootstraps.isCompatible(group: group) else { - preconditionFailure("ServerBootstrap is only compatible with MultiThreadedEventLoopGroup and " + - "SelectableEventLoop. You tried constructing one with \(group) which is incompatible.") + preconditionFailure( + "ServerBootstrap is only compatible with MultiThreadedEventLoopGroup and " + + "SelectableEventLoop. You tried constructing one with \(group) which is incompatible." + ) } self.init(validatingGroup: group, childGroup: group)! } @@ -115,10 +117,14 @@ public final class ServerBootstrap { /// - group: The `EventLoopGroup` to use for the `bind` of the `ServerSocketChannel` and to accept new `SocketChannel`s with. /// - childGroup: The `EventLoopGroup` to run the accepted `SocketChannel`s on. public convenience init(group: EventLoopGroup, childGroup: EventLoopGroup) { - guard NIOOnSocketsBootstraps.isCompatible(group: group) && NIOOnSocketsBootstraps.isCompatible(group: childGroup) else { - preconditionFailure("ServerBootstrap is only compatible with MultiThreadedEventLoopGroup and " + - "SelectableEventLoop. You tried constructing one with group: \(group) and " + - "childGroup: \(childGroup) at least one of which is incompatible.") + guard + NIOOnSocketsBootstraps.isCompatible(group: group) && NIOOnSocketsBootstraps.isCompatible(group: childGroup) + else { + preconditionFailure( + "ServerBootstrap is only compatible with MultiThreadedEventLoopGroup and " + + "SelectableEventLoop. You tried constructing one with group: \(group) and " + + "childGroup: \(childGroup) at least one of which is incompatible." + ) } self.init(validatingGroup: group, childGroup: childGroup)! @@ -132,7 +138,9 @@ public final class ServerBootstrap { /// - childGroup: The `EventLoopGroup` to run the accepted `SocketChannel`s on. If `nil`, `group` is used. public init?(validatingGroup group: EventLoopGroup, childGroup: EventLoopGroup? = nil) { let childGroup = childGroup ?? group - guard NIOOnSocketsBootstraps.isCompatible(group: group) && NIOOnSocketsBootstraps.isCompatible(group: childGroup) else { + guard + NIOOnSocketsBootstraps.isCompatible(group: group) && NIOOnSocketsBootstraps.isCompatible(group: childGroup) + else { return nil } @@ -156,7 +164,8 @@ public final class ServerBootstrap { /// - parameters: /// - initializer: A closure that initializes the provided `Channel`. @preconcurrency - public func serverChannelInitializer(_ initializer: @escaping @Sendable (Channel) -> EventLoopFuture) -> Self { + public func serverChannelInitializer(_ initializer: @escaping @Sendable (Channel) -> EventLoopFuture) -> Self + { self.serverChannelInit = initializer return self } @@ -210,7 +219,7 @@ public final class ServerBootstrap { /// - parameters: /// - timeout: The timeout that will apply to the bind attempt. public func bindTimeout(_ timeout: TimeAmount) -> Self { - return self + self } /// Enables multi-path TCP support. @@ -243,8 +252,8 @@ public final class ServerBootstrap { /// - host: The host to bind on. /// - port: The port to bind on. public func bind(host: String, port: Int) -> EventLoopFuture { - return bind0 { - return try SocketAddress.makeAddressResolvingHost(host, port: port) + bind0 { + try SocketAddress.makeAddressResolvingHost(host, port: port) } } @@ -253,7 +262,7 @@ public final class ServerBootstrap { /// - parameters: /// - address: The `SocketAddress` to bind on. public func bind(to address: SocketAddress) -> EventLoopFuture { - return bind0 { address } + bind0 { address } } /// Bind the `ServerSocketChannel` to a UNIX Domain Socket. @@ -261,7 +270,7 @@ public final class ServerBootstrap { /// - parameters: /// - unixDomainSocketPath: The _Unix domain socket_ path to bind to. `unixDomainSocketPath` must not exist, it will be created by the system. public func bind(unixDomainSocketPath: String) -> EventLoopFuture { - return bind0 { + bind0 { try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) } } @@ -289,8 +298,17 @@ public final class ServerBootstrap { /// - parameters: /// - vsockAddress: The VSOCK socket address to bind on. public func bind(to vsockAddress: VsockAddress) -> EventLoopFuture { - func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel { - try ServerSocketChannel(eventLoop: eventLoop, group: childEventLoopGroup, protocolFamily: .vsock, enableMPTCP: enableMPTCP) + func makeChannel( + _ eventLoop: SelectableEventLoop, + _ childEventLoopGroup: EventLoopGroup, + _ enableMPTCP: Bool + ) throws -> ServerSocketChannel { + try ServerSocketChannel( + eventLoop: eventLoop, + group: childEventLoopGroup, + protocolFamily: .vsock, + enableMPTCP: enableMPTCP + ) } return bind0(makeServerChannel: makeChannel) { (eventLoop, serverChannel) in serverChannel.register().flatMap { @@ -305,14 +323,14 @@ public final class ServerBootstrap { } #if !os(Windows) - /// Use the existing bound socket file descriptor. - /// - /// - parameters: - /// - descriptor: The _Unix file descriptor_ representing the bound stream socket. - @available(*, deprecated, renamed: "withBoundSocket(_:)") - public func withBoundSocket(descriptor: CInt) -> EventLoopFuture { - return withBoundSocket(descriptor) - } + /// Use the existing bound socket file descriptor. + /// + /// - parameters: + /// - descriptor: The _Unix file descriptor_ representing the bound stream socket. + @available(*, deprecated, renamed: "withBoundSocket(_:)") + public func withBoundSocket(descriptor: CInt) -> EventLoopFuture { + withBoundSocket(descriptor) + } #endif /// Use the existing bound socket file descriptor. @@ -320,7 +338,11 @@ public final class ServerBootstrap { /// - parameters: /// - descriptor: The _Unix file descriptor_ representing the bound stream socket. public func withBoundSocket(_ socket: NIOBSDSocket.Handle) -> EventLoopFuture { - func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel { + func makeChannel( + _ eventLoop: SelectableEventLoop, + _ childEventLoopGroup: EventLoopGroup, + _ enableMPTCP: Bool + ) throws -> ServerSocketChannel { if enableMPTCP { throw ChannelError._operationUnsupported } @@ -340,11 +362,17 @@ public final class ServerBootstrap { } catch { return group.next().makeFailedFuture(error) } - func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel { - return try ServerSocketChannel(eventLoop: eventLoop, - group: childEventLoopGroup, - protocolFamily: address.protocol, - enableMPTCP: enableMPTCP) + func makeChannel( + _ eventLoop: SelectableEventLoop, + _ childEventLoopGroup: EventLoopGroup, + _ enableMPTCP: Bool + ) throws -> ServerSocketChannel { + try ServerSocketChannel( + eventLoop: eventLoop, + group: childEventLoopGroup, + protocolFamily: address.protocol, + enableMPTCP: enableMPTCP + ) } return bind0(makeServerChannel: makeChannel) { (eventLoop, serverChannel) in @@ -354,7 +382,11 @@ public final class ServerBootstrap { } } - private func bind0(makeServerChannel: (_ eventLoop: SelectableEventLoop, _ childGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel, _ register: @escaping (EventLoop, ServerSocketChannel) -> EventLoopFuture) -> EventLoopFuture { + private func bind0( + makeServerChannel: (_ eventLoop: SelectableEventLoop, _ childGroup: EventLoopGroup, _ enableMPTCP: Bool) throws + -> ServerSocketChannel, + _ register: @escaping (EventLoop, ServerSocketChannel) -> EventLoopFuture + ) -> EventLoopFuture { let eventLoop = self.group.next() let childEventLoopGroup = self.childGroup let serverChannelOptions = self._serverChannelOptions @@ -364,7 +396,11 @@ public final class ServerBootstrap { let serverChannel: ServerSocketChannel do { - serverChannel = try makeServerChannel(eventLoop as! SelectableEventLoop, childEventLoopGroup, self.enableMPTCP) + serverChannel = try makeServerChannel( + eventLoop as! SelectableEventLoop, + childEventLoopGroup, + self.enableMPTCP + ) } catch { return eventLoop.makeFailedFuture(error) } @@ -373,9 +409,13 @@ public final class ServerBootstrap { serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap { serverChannelInit(serverChannel) }.flatMap { - serverChannel.pipeline.addHandler(AcceptHandler(childChannelInitializer: childChannelInit, - childChannelOptions: childChannelOptions), - name: "AcceptHandler") + serverChannel.pipeline.addHandler( + AcceptHandler( + childChannelInitializer: childChannelInit, + childChannelOptions: childChannelOptions + ), + name: "AcceptHandler" + ) }.flatMap { register(eventLoop, serverChannel) }.map { @@ -395,7 +435,10 @@ public final class ServerBootstrap { private let childChannelInit: ((Channel) -> EventLoopFuture)? private let childChannelOptions: ChannelOptions.Storage - init(childChannelInitializer: ((Channel) -> EventLoopFuture)?, childChannelOptions: ChannelOptions.Storage) { + init( + childChannelInitializer: ((Channel) -> EventLoopFuture)?, + childChannelOptions: ChannelOptions.Storage + ) { self.childChannelInit = childChannelInitializer self.childChannelOptions = childChannelOptions } @@ -417,7 +460,7 @@ public final class ServerBootstrap { @inline(__always) func setupChildChannel() -> EventLoopFuture { - return self.childChannelOptions.applyAllChannelOptions(to: accepted).flatMap { () -> EventLoopFuture in + self.childChannelOptions.applyAllChannelOptions(to: accepted).flatMap { () -> EventLoopFuture in childEventLoop.assertInEventLoop() return childChannelInit(accepted) } @@ -442,9 +485,11 @@ public final class ServerBootstrap { if childEventLoop === ctxEventLoop { fireThroughPipeline(setupChildChannel()) } else { - fireThroughPipeline(childEventLoop.flatSubmit { - return setupChildChannel() - }.hop(to: ctxEventLoop)) + fireThroughPipeline( + childEventLoop.flatSubmit { + setupChildChannel() + }.hop(to: ctxEventLoop) + ) } } @@ -503,7 +548,7 @@ extension ServerBootstrap { serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { - return try await bind0( + try await bind0( makeServerChannel: { eventLoop, childEventLoopGroup, enableMPTCP in try ServerSocketChannel( eventLoop: eventLoop, @@ -566,8 +611,17 @@ extension ServerBootstrap { serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { - func makeChannel(_ eventLoop: SelectableEventLoop, _ childEventLoopGroup: EventLoopGroup, _ enableMPTCP: Bool) throws -> ServerSocketChannel { - try ServerSocketChannel(eventLoop: eventLoop, group: childEventLoopGroup, protocolFamily: .vsock, enableMPTCP: enableMPTCP) + func makeChannel( + _ eventLoop: SelectableEventLoop, + _ childEventLoopGroup: EventLoopGroup, + _ enableMPTCP: Bool + ) throws -> ServerSocketChannel { + try ServerSocketChannel( + eventLoop: eventLoop, + group: childEventLoopGroup, + protocolFamily: .vsock, + enableMPTCP: enableMPTCP + ) } return try await self.bind0( @@ -601,7 +655,7 @@ extension ServerBootstrap { serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> NIOAsyncChannel { - return try await bind0( + try await bind0( makeServerChannel: { eventLoop, childEventLoopGroup, enableMPTCP in if enableMPTCP { throw ChannelError._operationUnsupported @@ -628,7 +682,7 @@ extension ServerBootstrap { serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, registration: @escaping @Sendable (ServerSocketChannel) -> EventLoopFuture - ) -> EventLoopFuture> { + ) -> EventLoopFuture> { let eventLoop = self.group.next() let childEventLoopGroup = self.childGroup let serverChannelOptions = self._serverChannelOptions @@ -638,7 +692,11 @@ extension ServerBootstrap { let serverChannel: ServerSocketChannel do { - serverChannel = try makeServerChannel(eventLoop as! SelectableEventLoop, childEventLoopGroup, self.enableMPTCP) + serverChannel = try makeServerChannel( + eventLoop as! SelectableEventLoop, + childEventLoopGroup, + self.enableMPTCP + ) } catch { return eventLoop.makeFailedFuture(error) } @@ -649,7 +707,10 @@ extension ServerBootstrap { }.flatMap { (_) -> EventLoopFuture> in do { try serverChannel.pipeline.syncOperations.addHandler( - AcceptHandler(childChannelInitializer: childChannelInit, childChannelOptions: childChannelOptions), + AcceptHandler( + childChannelInitializer: childChannelInit, + childChannelOptions: childChannelOptions + ), name: "AcceptHandler" ) let asyncChannel = try NIOAsyncChannel @@ -666,8 +727,8 @@ extension ServerBootstrap { } ) return registration(serverChannel) - .map { (_) -> NIOAsyncChannel in asyncChannel - } + .map { (_) -> NIOAsyncChannel in asyncChannel + } } catch { return eventLoop.makeFailedFuture(error) } @@ -684,8 +745,10 @@ extension ServerBootstrap { @available(*, unavailable) extension ServerBootstrap: Sendable {} -private extension Channel { - func registerAndDoSynchronously(_ body: @escaping (Channel) -> EventLoopFuture) -> EventLoopFuture { +extension Channel { + fileprivate func registerAndDoSynchronously( + _ body: @escaping (Channel) -> EventLoopFuture + ) -> EventLoopFuture { // this is pretty delicate at the moment: // In many cases `body` must be _synchronously_ follow `register`, otherwise in our current // implementation, `epoll` will send us `EPOLLHUP`. To have it run synchronously, we need to invoke the @@ -758,8 +821,10 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { /// - group: The `EventLoopGroup` to use. public convenience init(group: EventLoopGroup) { guard NIOOnSocketsBootstraps.isCompatible(group: group) else { - preconditionFailure("ClientBootstrap is only compatible with MultiThreadedEventLoopGroup and " + - "SelectableEventLoop. You tried constructing one with \(group) which is incompatible.") + preconditionFailure( + "ClientBootstrap is only compatible with MultiThreadedEventLoopGroup and " + + "SelectableEventLoop. You tried constructing one with \(group) which is incompatible." + ) } self.init(validatingGroup: group)! } @@ -884,9 +949,15 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { return self } - func makeSocketChannel(eventLoop: EventLoop, - protocolFamily: NIOBSDSocket.ProtocolFamily) throws -> SocketChannel { - return try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, protocolFamily: protocolFamily, enableMPTCP: self.enableMPTCP) + func makeSocketChannel( + eventLoop: EventLoop, + protocolFamily: NIOBSDSocket.ProtocolFamily + ) throws -> SocketChannel { + try SocketChannel( + eventLoop: eventLoop as! SelectableEventLoop, + protocolFamily: protocolFamily, + enableMPTCP: self.enableMPTCP + ) } /// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established. @@ -897,15 +968,21 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { /// - returns: An `EventLoopFuture` to deliver the `Channel` when connected. public func connect(host: String, port: Int) -> EventLoopFuture { let loop = self.group.next() - let resolver = self.resolver ?? GetaddrinfoResolver(loop: loop, - aiSocktype: .stream, - aiProtocol: .tcp) - let connector = HappyEyeballsConnector(resolver: resolver, - loop: loop, - host: host, - port: port, - connectTimeout: self.connectTimeout) { eventLoop, protocolFamily in - return self.initializeAndRegisterNewChannel(eventLoop: eventLoop, protocolFamily: protocolFamily) { + let resolver = + self.resolver + ?? GetaddrinfoResolver( + loop: loop, + aiSocktype: .stream, + aiProtocol: .tcp + ) + let connector = HappyEyeballsConnector( + resolver: resolver, + loop: loop, + host: host, + port: port, + connectTimeout: self.connectTimeout + ) { eventLoop, protocolFamily in + self.initializeAndRegisterNewChannel(eventLoop: eventLoop, protocolFamily: protocolFamily) { $0.eventLoop.makeSucceededFuture(()) } } @@ -926,10 +1003,12 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { return connectPromise.futureResult } - internal func testOnly_connect(injectedChannel: SocketChannel, - to address: SocketAddress) -> EventLoopFuture { - return self.initializeAndRegisterChannel(injectedChannel) { channel in - return self.connect(freshChannel: channel, address: address) + internal func testOnly_connect( + injectedChannel: SocketChannel, + to address: SocketAddress + ) -> EventLoopFuture { + self.initializeAndRegisterChannel(injectedChannel) { channel in + self.connect(freshChannel: channel, address: address) } } @@ -939,9 +1018,11 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { /// - address: The address to connect to. /// - returns: An `EventLoopFuture` to deliver the `Channel` when connected. public func connect(to address: SocketAddress) -> EventLoopFuture { - return self.initializeAndRegisterNewChannel(eventLoop: self.group.next(), - protocolFamily: address.protocol) { channel in - return self.connect(freshChannel: channel, address: address) + self.initializeAndRegisterNewChannel( + eventLoop: self.group.next(), + protocolFamily: address.protocol + ) { channel in + self.connect(freshChannel: channel, address: address) } } @@ -971,7 +1052,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { protocolFamily: .vsock ) { channel in let connectPromise = channel.eventLoop.makePromise(of: Void.self) - channel.triggerUserOutboundEvent(VsockChannelEvents.ConnectToAddress( address), promise: connectPromise) + channel.triggerUserOutboundEvent(VsockChannelEvents.ConnectToAddress(address), promise: connectPromise) let cancelTask = channel.eventLoop.scheduleTask(in: connectTimeout) { connectPromise.fail(ChannelError.connectTimeout(connectTimeout)) @@ -986,15 +1067,15 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { } #if !os(Windows) - /// Use the existing connected socket file descriptor. - /// - /// - parameters: - /// - descriptor: The _Unix file descriptor_ representing the connected stream socket. - /// - returns: an `EventLoopFuture` to deliver the `Channel`. - @available(*, deprecated, renamed: "withConnectedSocket(_:)") - public func withConnectedSocket(descriptor: CInt) -> EventLoopFuture { - return self.withConnectedSocket(descriptor) - } + /// Use the existing connected socket file descriptor. + /// + /// - parameters: + /// - descriptor: The _Unix file descriptor_ representing the connected stream socket. + /// - returns: an `EventLoopFuture` to deliver the `Channel`. + @available(*, deprecated, renamed: "withConnectedSocket(_:)") + public func withConnectedSocket(descriptor: CInt) -> EventLoopFuture { + self.withConnectedSocket(descriptor) + } #endif /// Use the existing connected socket file descriptor. @@ -1036,9 +1117,11 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { } } - private func initializeAndRegisterNewChannel(eventLoop: EventLoop, - protocolFamily: NIOBSDSocket.ProtocolFamily, - _ body: @escaping (Channel) -> EventLoopFuture) -> EventLoopFuture { + private func initializeAndRegisterNewChannel( + eventLoop: EventLoop, + protocolFamily: NIOBSDSocket.ProtocolFamily, + _ body: @escaping (Channel) -> EventLoopFuture + ) -> EventLoopFuture { let channel: SocketChannel do { channel = try self.makeSocketChannel(eventLoop: eventLoop, protocolFamily: protocolFamily) @@ -1048,8 +1131,10 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { return self.initializeAndRegisterChannel(channel, body) } - private func initializeAndRegisterChannel(_ channel: SocketChannel, - _ body: @escaping (Channel) -> EventLoopFuture) -> EventLoopFuture { + private func initializeAndRegisterChannel( + _ channel: SocketChannel, + _ body: @escaping (Channel) -> EventLoopFuture + ) -> EventLoopFuture { let channelInitializer = self.channelInitializer let channelOptions = self._channelOptions let eventLoop = channel.eventLoop @@ -1134,9 +1219,11 @@ extension ClientBootstrap { channelInitializer: channelInitializer, postRegisterTransformation: { output, eventLoop in eventLoop.makeSucceededFuture(output) - }, { channel in - return self.connect(freshChannel: channel, address: address) - }).get().1 + }, + { channel in + self.connect(freshChannel: channel, address: address) + } + ).get().1 } /// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established. @@ -1176,11 +1263,11 @@ extension ClientBootstrap { protocolFamily: NIOBSDSocket.ProtocolFamily.vsock, channelInitializer: channelInitializer, postRegisterTransformation: { result, eventLoop in - return eventLoop.makeSucceededFuture(result) + eventLoop.makeSucceededFuture(result) } ) { channel in let connectPromise = channel.eventLoop.makePromise(of: Void.self) - channel.triggerUserOutboundEvent(VsockChannelEvents.ConnectToAddress( address), promise: connectPromise) + channel.triggerUserOutboundEvent(VsockChannelEvents.ConnectToAddress(address), promise: connectPromise) let cancelTask = channel.eventLoop.scheduleTask(in: connectTimeout) { connectPromise.fail(ChannelError.connectTimeout(connectTimeout)) @@ -1223,13 +1310,17 @@ extension ClientBootstrap { port: Int, eventLoop: EventLoop, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > ) async throws -> PostRegistrationTransformationResult { - let resolver = self.resolver ?? GetaddrinfoResolver( - loop: eventLoop, - aiSocktype: .stream, - aiProtocol: .tcp - ) + let resolver = + self.resolver + ?? GetaddrinfoResolver( + loop: eventLoop, + aiSocktype: .stream, + aiProtocol: .tcp + ) let connector = HappyEyeballsConnector( resolver: resolver, @@ -1238,7 +1329,7 @@ extension ClientBootstrap { port: port, connectTimeout: self.connectTimeout ) { eventLoop, protocolFamily in - return self.initializeAndRegisterNewChannel( + self.initializeAndRegisterNewChannel( eventLoop: eventLoop, protocolFamily: protocolFamily, channelInitializer: channelInitializer, @@ -1255,7 +1346,9 @@ extension ClientBootstrap { eventLoop: EventLoop, socket: NIOBSDSocket.Handle, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > ) async throws -> PostRegistrationTransformationResult { let channel = try SocketChannel(eventLoop: eventLoop as! SelectableEventLoop, socket: socket) @@ -1276,7 +1369,9 @@ extension ClientBootstrap { eventLoop: EventLoop, protocolFamily: NIOBSDSocket.ProtocolFamily, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture, + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + >, _ body: @escaping (Channel) -> EventLoopFuture ) -> EventLoopFuture<(Channel, PostRegistrationTransformationResult)> { let channel: SocketChannel @@ -1300,10 +1395,12 @@ extension ClientBootstrap { channel: SocketChannel, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, registration: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > ) -> EventLoopFuture { let channelInitializer = { channel in - return self.channelInitializer(channel) + self.channelInitializer(channel) .flatMap { channelInitializer(channel) } } let channelOptions = self._channelOptions @@ -1314,11 +1411,13 @@ extension ClientBootstrap { @Sendable func setupChannel() -> EventLoopFuture { eventLoop.assertInEventLoop() - return channelOptions + return + channelOptions .applyAllChannelOptions(to: channel) .flatMap { if let bindTarget = bindTarget { - return channel + return + channel .bind(to: bindTarget) .flatMap { channelInitializer(channel) @@ -1331,7 +1430,8 @@ extension ClientBootstrap { return registration(channel).map { result } - }.flatMap { (result: ChannelInitializerResult) -> EventLoopFuture in + }.flatMap { + (result: ChannelInitializerResult) -> EventLoopFuture in postRegisterTransformation(result, eventLoop) }.flatMapError { error in eventLoop.assertInEventLoop() @@ -1395,8 +1495,10 @@ public final class DatagramBootstrap { /// - group: The `EventLoopGroup` to use. public convenience init(group: EventLoopGroup) { guard NIOOnSocketsBootstraps.isCompatible(group: group) else { - preconditionFailure("DatagramBootstrap is only compatible with MultiThreadedEventLoopGroup and " + - "SelectableEventLoop. You tried constructing one with \(group) which is incompatible.") + preconditionFailure( + "DatagramBootstrap is only compatible with MultiThreadedEventLoopGroup and " + + "SelectableEventLoop. You tried constructing one with \(group) which is incompatible." + ) } self.init(validatingGroup: group)! } @@ -1442,14 +1544,14 @@ public final class DatagramBootstrap { } #if !os(Windows) - /// Use the existing bound socket file descriptor. - /// - /// - parameters: - /// - descriptor: The _Unix file descriptor_ representing the bound datagram socket. - @available(*, deprecated, renamed: "withBoundSocket(_:)") - public func withBoundSocket(descriptor: CInt) -> EventLoopFuture { - return self.withBoundSocket(descriptor) - } + /// Use the existing bound socket file descriptor. + /// + /// - parameters: + /// - descriptor: The _Unix file descriptor_ representing the bound datagram socket. + @available(*, deprecated, renamed: "withBoundSocket(_:)") + public func withBoundSocket(descriptor: CInt) -> EventLoopFuture { + self.withBoundSocket(descriptor) + } #endif /// Use the existing bound socket file descriptor. @@ -1458,7 +1560,7 @@ public final class DatagramBootstrap { /// - descriptor: The _Unix file descriptor_ representing the bound datagram socket. public func withBoundSocket(_ socket: NIOBSDSocket.Handle) -> EventLoopFuture { func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel(eventLoop: eventLoop, socket: socket) + try DatagramChannel(eventLoop: eventLoop, socket: socket) } return withNewChannel(makeChannel: makeChannel) { eventLoop, channel in let promise = eventLoop.makePromise(of: Void.self) @@ -1473,8 +1575,8 @@ public final class DatagramBootstrap { /// - host: The host to bind on. /// - port: The port to bind on. public func bind(host: String, port: Int) -> EventLoopFuture { - return bind0 { - return try SocketAddress.makeAddressResolvingHost(host, port: port) + bind0 { + try SocketAddress.makeAddressResolvingHost(host, port: port) } } @@ -1483,7 +1585,7 @@ public final class DatagramBootstrap { /// - parameters: /// - address: The `SocketAddress` to bind on. public func bind(to address: SocketAddress) -> EventLoopFuture { - return bind0 { address } + bind0 { address } } /// Bind the `DatagramChannel` to a UNIX Domain Socket. @@ -1491,8 +1593,8 @@ public final class DatagramBootstrap { /// - parameters: /// - unixDomainSocketPath: The path of the UNIX Domain Socket to bind on. `path` must not exist, it will be created by the system. public func bind(unixDomainSocketPath: String) -> EventLoopFuture { - return bind0 { - return try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) + bind0 { + try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) } } @@ -1523,9 +1625,11 @@ public final class DatagramBootstrap { return group.next().makeFailedFuture(error) } func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel(eventLoop: eventLoop, - protocolFamily: address.protocol, - protocolSubtype: subtype) + try DatagramChannel( + eventLoop: eventLoop, + protocolFamily: address.protocol, + protocolSubtype: subtype + ) } return withNewChannel(makeChannel: makeChannel) { _, channel in channel.register().flatMap { @@ -1540,8 +1644,8 @@ public final class DatagramBootstrap { /// - host: The host to connect to. /// - port: The port to connect to. public func connect(host: String, port: Int) -> EventLoopFuture { - return connect0 { - return try SocketAddress.makeAddressResolvingHost(host, port: port) + connect0 { + try SocketAddress.makeAddressResolvingHost(host, port: port) } } @@ -1550,7 +1654,7 @@ public final class DatagramBootstrap { /// - parameters: /// - address: The `SocketAddress` to connect to. public func connect(to address: SocketAddress) -> EventLoopFuture { - return connect0 { address } + connect0 { address } } /// Connect the `DatagramChannel` to a UNIX Domain Socket. @@ -1558,8 +1662,8 @@ public final class DatagramBootstrap { /// - parameters: /// - unixDomainSocketPath: The path of the UNIX Domain Socket to connect to. `path` must not exist, it will be created by the system. public func connect(unixDomainSocketPath: String) -> EventLoopFuture { - return connect0 { - return try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) + connect0 { + try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) } } @@ -1572,9 +1676,11 @@ public final class DatagramBootstrap { return group.next().makeFailedFuture(error) } func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel(eventLoop: eventLoop, - protocolFamily: address.protocol, - protocolSubtype: subtype) + try DatagramChannel( + eventLoop: eventLoop, + protocolFamily: address.protocol, + protocolSubtype: subtype + ) } return withNewChannel(makeChannel: makeChannel) { _, channel in channel.register().flatMap { @@ -1583,7 +1689,10 @@ public final class DatagramBootstrap { } } - private func withNewChannel(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture) -> EventLoopFuture { + private func withNewChannel( + makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, + _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture + ) -> EventLoopFuture { let eventLoop = self.group.next() let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) } let channelOptions = self._channelOptions @@ -1635,7 +1744,7 @@ extension DatagramBootstrap { channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> Output { func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel(eventLoop: eventLoop, socket: socket) + try DatagramChannel(eventLoop: eventLoop, socket: socket) } return try await self.makeConfiguredChannel( makeChannel: makeChannel(_:), @@ -1665,7 +1774,7 @@ extension DatagramBootstrap { port: Int, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> Output { - return try await self.bind0( + try await self.bind0( makeSocketAddress: { try SocketAddress.makeAddressResolvingHost(host, port: port) }, @@ -1688,7 +1797,7 @@ extension DatagramBootstrap { to address: SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> Output { - return try await self.bind0( + try await self.bind0( makeSocketAddress: { address }, @@ -1743,7 +1852,7 @@ extension DatagramBootstrap { port: Int, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> Output { - return try await self.connect0( + try await self.connect0( makeSocketAddress: { try SocketAddress.makeAddressResolvingHost(host, port: port) }, @@ -1766,7 +1875,7 @@ extension DatagramBootstrap { to address: SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> Output { - return try await self.connect0( + try await self.connect0( makeSocketAddress: { address }, @@ -1789,7 +1898,7 @@ extension DatagramBootstrap { unixDomainSocketPath: String, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) async throws -> Output { - return try await self.connect0( + try await self.connect0( makeSocketAddress: { try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) }, @@ -1804,13 +1913,15 @@ extension DatagramBootstrap { private func connect0( makeSocketAddress: () throws -> SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > ) async throws -> PostRegistrationTransformationResult { let address = try makeSocketAddress() let subtype = self.proto func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel( + try DatagramChannel( eventLoop: eventLoop, protocolFamily: address.protocol, protocolSubtype: subtype @@ -1833,13 +1944,15 @@ extension DatagramBootstrap { private func bind0( makeSocketAddress: () throws -> SocketAddress, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > ) async throws -> PostRegistrationTransformationResult { let address = try makeSocketAddress() let subtype = self.proto func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel( + try DatagramChannel( eventLoop: eventLoop, protocolFamily: address.protocol, protocolSubtype: subtype @@ -1863,7 +1976,9 @@ extension DatagramBootstrap { makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, registration: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > ) -> EventLoopFuture { let eventLoop = self.group.next() let channelInitializer = { (channel: Channel) -> EventLoopFuture in @@ -1939,8 +2054,10 @@ public final class NIOPipeBootstrap { /// - group: The `EventLoopGroup` to use. public convenience init(group: EventLoopGroup) { guard NIOOnSocketsBootstraps.isCompatible(group: group) else { - preconditionFailure("NIOPipeBootstrap is only compatible with MultiThreadedEventLoopGroup and " + - "SelectableEventLoop. You tried constructing one with \(group) which is incompatible.") + preconditionFailure( + "NIOPipeBootstrap is only compatible with MultiThreadedEventLoopGroup and " + + "SelectableEventLoop. You tried constructing one with \(group) which is incompatible." + ) } self.init(validatingGroup: group)! } @@ -1998,7 +2115,7 @@ public final class NIOPipeBootstrap { } private func validateFileDescriptorIsNotAFile(_ descriptor: CInt) throws { -#if os(Windows) + #if os(Windows) // NOTE: this is a *non-owning* handle, do *NOT* call `CloseHandle` let hFile: HANDLE = HANDLE(bitPattern: _get_osfhandle(descriptor))! if hFile == INVALID_HANDLE_VALUE { @@ -2016,7 +2133,7 @@ public final class NIOPipeBootstrap { default: throw ChannelError._operationUnsupported } -#else + #else var s: stat = .init() try withUnsafeMutablePointer(to: &s) { ptr in try Posix.fstat(descriptor: descriptor, outStat: ptr) @@ -2025,9 +2142,9 @@ public final class NIOPipeBootstrap { case S_IFREG, S_IFDIR, S_IFLNK, S_IFBLK: throw ChannelError._operationUnsupported default: - () // Let's default to ok + () // Let's default to ok } -#endif + #endif } /// Create the `PipeChannel` with the provided file descriptor which is used for both input & output. @@ -2275,9 +2392,11 @@ extension NIOPipeBootstrap { output: CInt?, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture ) -> EventLoopFuture { - precondition(input ?? 0 >= 0 && output ?? 0 >= 0 && input != output, - "illegal file descriptor pair. The file descriptors \(String(describing: input)), \(String(describing: output)) " + - "must be distinct and both positive integers.") + precondition( + input ?? 0 >= 0 && output ?? 0 >= 0 && input != output, + "illegal file descriptor pair. The file descriptors \(String(describing: input)), \(String(describing: output)) " + + "must be distinct and both positive integers." + ) precondition(!(input == nil && output == nil), "Either input or output has to be set") let eventLoop = group.next() let channelOptions = self._channelOptions @@ -2311,11 +2430,11 @@ extension NIOPipeBootstrap { return eventLoop.makeFailedFuture(error) } - @Sendable func setupChannel() -> EventLoopFuture { eventLoop.assertInEventLoop() - return channelOptions.applyAllChannelOptions(to: channel).flatMap { _ -> EventLoopFuture in + return channelOptions.applyAllChannelOptions(to: channel).flatMap { + _ -> EventLoopFuture in channelInitializer(channel) }.flatMap { result in eventLoop.assertInEventLoop() @@ -2350,15 +2469,19 @@ extension NIOPipeBootstrap { extension NIOPipeBootstrap: Sendable {} protocol NIOPipeBootstrapHooks { - func makePipeChannel(eventLoop: SelectableEventLoop, - inputPipe: NIOFileHandle?, - outputPipe: NIOFileHandle?) throws -> PipeChannel + func makePipeChannel( + eventLoop: SelectableEventLoop, + inputPipe: NIOFileHandle?, + outputPipe: NIOFileHandle? + ) throws -> PipeChannel } -fileprivate struct DefaultNIOPipeBootstrapHooks: NIOPipeBootstrapHooks { - func makePipeChannel(eventLoop: SelectableEventLoop, - inputPipe: NIOFileHandle?, - outputPipe: NIOFileHandle?) throws -> PipeChannel { - return try PipeChannel(eventLoop: eventLoop, inputPipe: inputPipe, outputPipe: outputPipe) +private struct DefaultNIOPipeBootstrapHooks: NIOPipeBootstrapHooks { + func makePipeChannel( + eventLoop: SelectableEventLoop, + inputPipe: NIOFileHandle?, + outputPipe: NIOFileHandle? + ) throws -> PipeChannel { + try PipeChannel(eventLoop: eventLoop, inputPipe: inputPipe, outputPipe: outputPipe) } } diff --git a/Sources/NIOPosix/ControlMessage.swift b/Sources/NIOPosix/ControlMessage.swift index 620c356a64..96fe3ea6fc 100644 --- a/Sources/NIOPosix/ControlMessage.swift +++ b/Sources/NIOPosix/ControlMessage.swift @@ -47,8 +47,10 @@ struct UnsafeControlMessageStorage: Collection { /// - msghdrCount: How many `msghdr` structures will be fed from this buffer - we assume 4 Int32 cmsgs for each. static func allocate(msghdrCount: Int) -> UnsafeControlMessageStorage { let bytesPerMessage = Self.bytesPerMessage - let buffer = UnsafeMutableRawBufferPointer.allocate(byteCount: bytesPerMessage * msghdrCount, - alignment: MemoryLayout.alignment) + let buffer = UnsafeMutableRawBufferPointer.allocate( + byteCount: bytesPerMessage * msghdrCount, + alignment: MemoryLayout.alignment + ) return UnsafeControlMessageStorage(bytesPerMessage: bytesPerMessage, buffer: buffer, deallocateBuffer: true) } @@ -56,7 +58,10 @@ struct UnsafeControlMessageStorage: Collection { /// parameter: /// - bytesPerMessage: How many bytes have been allocated for each supported message. /// - buffer: The memory allocated to use for control messages. - static func makeNotOwning(bytesPerMessage: Int, buffer: UnsafeMutableRawBufferPointer) -> UnsafeControlMessageStorage { + static func makeNotOwning( + bytesPerMessage: Int, + buffer: UnsafeMutableRawBufferPointer + ) -> UnsafeControlMessageStorage { precondition(buffer.count >= bytesPerMessage) return UnsafeControlMessageStorage(bytesPerMessage: bytesPerMessage, buffer: buffer, deallocateBuffer: false) } @@ -64,22 +69,26 @@ struct UnsafeControlMessageStorage: Collection { mutating func deallocate() { if self.deallocateBuffer { self.buffer.deallocate() - self.buffer = UnsafeMutableRawBufferPointer(start: UnsafeMutableRawPointer(bitPattern: 0x7eadbeef), count: 0) + self.buffer = UnsafeMutableRawBufferPointer( + start: UnsafeMutableRawPointer(bitPattern: 0x7ead_beef), + count: 0 + ) } } /// Get the part of the buffer for use with a message. public subscript(position: Int) -> UnsafeMutableRawBufferPointer { - return UnsafeMutableRawBufferPointer( - fastRebase: self.buffer[(position * self.bytesPerMessage)..<((position+1) * self.bytesPerMessage)]) + UnsafeMutableRawBufferPointer( + fastRebase: self.buffer[(position * self.bytesPerMessage)..<((position + 1) * self.bytesPerMessage)] + ) } - var startIndex: Int { return 0 } + var startIndex: Int { 0 } - var endIndex: Int { return self.buffer.count / self.bytesPerMessage } + var endIndex: Int { self.buffer.count / self.bytesPerMessage } func index(after: Int) -> Int { - return after + 1 + after + 1 } } @@ -104,12 +113,14 @@ struct UnsafeControlMessageCollection { // Add the `Collection` functionality to UnsafeControlMessageCollection. extension UnsafeControlMessageCollection: Collection { typealias Element = UnsafeControlMessage - + struct Index: Equatable, Comparable { fileprivate var cmsgPointer: UnsafeMutablePointer? - - static func < (lhs: UnsafeControlMessageCollection.Index, - rhs: UnsafeControlMessageCollection.Index) -> Bool { + + static func < ( + lhs: UnsafeControlMessageCollection.Index, + rhs: UnsafeControlMessageCollection.Index + ) -> Bool { // nil is high, as that's the end of the collection. switch (lhs.cmsgPointer, rhs.cmsgPointer) { case (.some(let lhs), .some(let rhs)): @@ -120,12 +131,12 @@ extension UnsafeControlMessageCollection: Collection { return false } } - + fileprivate init(cmsgPointer: UnsafeMutablePointer?) { self.cmsgPointer = cmsgPointer } } - + var startIndex: Index { var messageHeader = self.messageHeader return withUnsafePointer(to: &messageHeader) { messageHeaderPtr in @@ -133,22 +144,28 @@ extension UnsafeControlMessageCollection: Collection { return Index(cmsgPointer: firstCMsg) } } - - var endIndex: Index { return Index(cmsgPointer: nil) } - + + var endIndex: Index { Index(cmsgPointer: nil) } + func index(after: Index) -> Index { var msgHdr = messageHeader return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in - return Index(cmsgPointer: NIOBSDSocketControlMessage.nextHeader(inside: messageHeaderPtr, - after: after.cmsgPointer!)) + Index( + cmsgPointer: NIOBSDSocketControlMessage.nextHeader( + inside: messageHeaderPtr, + after: after.cmsgPointer! + ) + ) } } - + public subscript(position: Index) -> Element { let cmsg = position.cmsgPointer! - return UnsafeControlMessage(level: cmsg.pointee.cmsg_level, - type: cmsg.pointee.cmsg_type, - data: NIOBSDSocketControlMessage.data(for: cmsg)) + return UnsafeControlMessage( + level: cmsg.pointee.cmsg_level, + type: cmsg.pointee.cmsg_type, + data: NIOBSDSocketControlMessage.data(for: cmsg) + ) } } @@ -165,7 +182,7 @@ struct UnsafeReceivedControlBytes { /// Extract information from a collection of control messages. struct ControlMessageParser { - var ecnValue: NIOExplicitCongestionNotificationState = .transportNotCapable // Default + var ecnValue: NIOExplicitCongestionNotificationState = .transportNotCapable // Default var packetInfo: NIOPacketInfo? = nil init(parsing controlMessagesReceived: UnsafeControlMessageCollection) { @@ -173,11 +190,11 @@ struct ControlMessageParser { self.receiveMessage(controlMessage) } } - + #if canImport(Darwin) private static let ipv4TosType = IP_RECVTOS #else - private static let ipv4TosType = IP_TOS // Linux + private static let ipv4TosType = IP_TOS // Linux #endif static func _readCInt(data: UnsafeRawBufferPointer) -> CInt { @@ -189,7 +206,7 @@ struct ControlMessageParser { } return readValue } - + private mutating func receiveMessage(_ controlMessage: UnsafeControlMessage) { if controlMessage.level == IPPROTO_IP { self.receiveIPv4Message(controlMessage) @@ -213,8 +230,10 @@ struct ControlMessageParser { addr.sin_family = sa_family_t(NIOBSDSocket.AddressFamily.inet.rawValue) addr.sin_port = in_port_t(0) addr.sin_addr = info.ipi_addr - self.packetInfo = NIOPacketInfo(destinationAddress: SocketAddress(addr, host: ""), - interfaceIndex: Int(info.ipi_ifindex)) + self.packetInfo = NIOPacketInfo( + destinationAddress: SocketAddress(addr, host: ""), + interfaceIndex: Int(info.ipi_ifindex) + ) } } @@ -235,8 +254,10 @@ struct ControlMessageParser { addr.sin6_flowinfo = 0 addr.sin6_addr = info.ipi6_addr addr.sin6_scope_id = 0 - self.packetInfo = NIOPacketInfo(destinationAddress: SocketAddress(addr, host: ""), - interfaceIndex: Int(info.ipi6_ifindex)) + self.packetInfo = NIOPacketInfo( + destinationAddress: SocketAddress(addr, host: ""), + interfaceIndex: Int(info.ipi6_ifindex) + ) } } } @@ -277,7 +298,7 @@ extension CInt { struct UnsafeOutboundControlBytes { private var controlBytes: UnsafeMutableRawBufferPointer private var writePosition: UnsafeMutableRawBufferPointer.Index - + /// This structure must not outlive `controlBytes` init(controlBytes: UnsafeMutableRawBufferPointer) { self.controlBytes = controlBytes @@ -290,36 +311,40 @@ struct UnsafeOutboundControlBytes { /// Appends a control message. /// PayloadType needs to be trivial (eg CInt) - private mutating func appendGenericControlMessage(level: CInt, - type: CInt, - payload: PayloadType) { + private mutating func appendGenericControlMessage( + level: CInt, + type: CInt, + payload: PayloadType + ) { let writableBuffer = UnsafeMutableRawBufferPointer(fastRebase: self.controlBytes[writePosition...]) - + let requiredSize = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride(ofValue: payload)) precondition(writableBuffer.count >= requiredSize, "Insufficient size for cmsghdr and data") - + let bufferBase = writableBuffer.baseAddress! // Binding to cmsghdr is safe here as this is the only place where we bind to non-Raw. let cmsghdrPtr = bufferBase.bindMemory(to: cmsghdr.self, capacity: 1) cmsghdrPtr.pointee.cmsg_level = level cmsghdrPtr.pointee.cmsg_type = type - cmsghdrPtr.pointee.cmsg_len = .init(NIOBSDSocketControlMessage.length(payloadSize: MemoryLayout.size(ofValue: payload))) - + cmsghdrPtr.pointee.cmsg_len = .init( + NIOBSDSocketControlMessage.length(payloadSize: MemoryLayout.size(ofValue: payload)) + ) + let dataPointer = NIOBSDSocketControlMessage.data(for: cmsghdrPtr)! precondition(dataPointer.count >= MemoryLayout.stride) dataPointer.storeBytes(of: payload, as: PayloadType.self) - + self.writePosition += requiredSize } - + /// The result is only valid while this is valid. var validControlBytes: UnsafeMutableRawBufferPointer { if writePosition == 0 { return UnsafeMutableRawBufferPointer(start: nil, count: 0) } - return UnsafeMutableRawBufferPointer(fastRebase: self.controlBytes[0 ..< self.writePosition]) + return UnsafeMutableRawBufferPointer(fastRebase: self.controlBytes[0...Metadata?, - protocolFamily: NIOBSDSocket.ProtocolFamily?) { + internal mutating func appendExplicitCongestionState( + metadata: AddressedEnvelope.Metadata?, + protocolFamily: NIOBSDSocket.ProtocolFamily? + ) { guard let metadata = metadata else { return } switch protocolFamily { case .some(.inet): - self.appendControlMessage(level: .init(IPPROTO_IP), - type: IP_TOS, - payload: CInt(ecnValue: metadata.ecnState)) + self.appendControlMessage( + level: .init(IPPROTO_IP), + type: IP_TOS, + payload: CInt(ecnValue: metadata.ecnState) + ) case .some(.inet6): - self.appendControlMessage(level: .init(IPPROTO_IPV6), - type: IPV6_TCLASS, - payload: CInt(ecnValue: metadata.ecnState)) + self.appendControlMessage( + level: .init(IPPROTO_IPV6), + type: IPV6_TCLASS, + payload: CInt(ecnValue: metadata.ecnState) + ) default: // Nothing to do - if we get here the user is probably making a mistake. break diff --git a/Sources/NIOPosix/DatagramVectorReadManager.swift b/Sources/NIOPosix/DatagramVectorReadManager.swift index 1378f5ebc3..07f1907dab 100644 --- a/Sources/NIOPosix/DatagramVectorReadManager.swift +++ b/Sources/NIOPosix/DatagramVectorReadManager.swift @@ -29,7 +29,7 @@ struct DatagramVectorReadManager { /// The number of messages that will be read in each syscall. var messageCount: Int { get { - return self.messageVector.count + self.messageVector.count } set { precondition(newValue >= 0) @@ -38,7 +38,10 @@ struct DatagramVectorReadManager { self.sockaddrVector.deinitializeAndDeallocate() self.controlMessageStorage.deallocate() - self.messageVector = .allocateAndInitialize(repeating: MMsgHdr(msg_hdr: msghdr(), msg_len: 0), count: newValue) + self.messageVector = .allocateAndInitialize( + repeating: MMsgHdr(msg_hdr: msghdr(), msg_len: 0), + count: newValue + ) self.ioVector = .allocateAndInitialize(repeating: IOVector(), count: newValue) self.sockaddrVector = .allocateAndInitialize(repeating: sockaddr_storage(), count: newValue) self.controlMessageStorage = UnsafeControlMessageStorage.allocate(msghdrCount: newValue) @@ -60,10 +63,12 @@ struct DatagramVectorReadManager { // FIXME(cory): Right now there's no good API for specifying the various parameters of multi-read, especially how // it should interact with RecvByteBufferAllocator. For now I'm punting on this to see if I can get it working, // but we should design it back. - fileprivate init(messageVector: UnsafeMutableBufferPointer, - ioVector: UnsafeMutableBufferPointer, - sockaddrVector: UnsafeMutableBufferPointer, - controlMessageStorage: UnsafeControlMessageStorage) { + fileprivate init( + messageVector: UnsafeMutableBufferPointer, + ioVector: UnsafeMutableBufferPointer, + sockaddrVector: UnsafeMutableBufferPointer, + controlMessageStorage: UnsafeControlMessageStorage + ) { self.messageVector = messageVector self.ioVector = ioVector self.sockaddrVector = sockaddrVector @@ -89,9 +94,11 @@ struct DatagramVectorReadManager { /// - socket: The underlying socket from which to read. /// - buffer: The single large buffer into which reads will be written. /// - parseControlMessages: Should control messages be reported up using metadata. - func readFromSocket(socket: Socket, - buffer: inout ByteBuffer, - parseControlMessages: Bool) throws -> ReadResult { + func readFromSocket( + socket: Socket, + buffer: inout ByteBuffer, + parseControlMessages: Bool + ) throws -> ReadResult { assert(buffer.readerIndex == 0, "Buffer was not cleared between calls to readFromSocket!") let messageSize = buffer.capacity / self.messageCount @@ -101,8 +108,11 @@ struct DatagramVectorReadManager { // TODO(cory): almost all of this except for the iovec could be done at allocation time. Maybe we should? // First we set up the iovec and save it off. - self.ioVector[i] = IOVector(iov_base: bufferPointer.baseAddress! + (i * messageSize), iov_len: numericCast(messageSize)) - + self.ioVector[i] = IOVector( + iov_base: bufferPointer.baseAddress! + (i * messageSize), + iov_len: numericCast(messageSize) + ) + let controlBytes: UnsafeMutableRawBufferPointer if parseControlMessages { // This will be used in buildMessages below but should not be used beyond return of this function. @@ -136,10 +146,12 @@ struct DatagramVectorReadManager { return .none case .processed(let messagesProcessed): buffer.moveWriterIndex(to: messageSize * messagesProcessed) - return try self.buildMessages(messageCount: messagesProcessed, - sliceSize: messageSize, - buffer: &buffer, - parseControlMessages: parseControlMessages) + return try self.buildMessages( + messageCount: messagesProcessed, + sliceSize: messageSize, + buffer: &buffer, + parseControlMessages: parseControlMessages + ) } } @@ -151,14 +163,16 @@ struct DatagramVectorReadManager { self.controlMessageStorage.deallocate() } - private func buildMessages(messageCount: Int, - sliceSize: Int, - buffer: inout ByteBuffer, - parseControlMessages: Bool) throws -> ReadResult { + private func buildMessages( + messageCount: Int, + sliceSize: Int, + buffer: inout ByteBuffer, + parseControlMessages: Bool + ) throws -> ReadResult { var sliceOffset = buffer.readerIndex var totalReadSize = 0 - var results = Array>() + var results = [AddressedEnvelope]() results.reserveCapacity(messageCount) for i in 0...Metadata? if parseControlMessages { @@ -205,19 +219,26 @@ extension DatagramVectorReadManager { /// - parameters: /// - messageCount: The number of vector reads to support initially. static func allocate(messageCount: Int) -> DatagramVectorReadManager { - let messageVector = UnsafeMutableBufferPointer.allocateAndInitialize(repeating: MMsgHdr(msg_hdr: msghdr(), msg_len: 0), count: messageCount) + let messageVector = UnsafeMutableBufferPointer.allocateAndInitialize( + repeating: MMsgHdr(msg_hdr: msghdr(), msg_len: 0), + count: messageCount + ) let ioVector = UnsafeMutableBufferPointer.allocateAndInitialize(repeating: IOVector(), count: messageCount) - let sockaddrVector = UnsafeMutableBufferPointer.allocateAndInitialize(repeating: sockaddr_storage(), count: messageCount) + let sockaddrVector = UnsafeMutableBufferPointer.allocateAndInitialize( + repeating: sockaddr_storage(), + count: messageCount + ) let controlMessageStorage = UnsafeControlMessageStorage.allocate(msghdrCount: messageCount) - return DatagramVectorReadManager(messageVector: messageVector, - ioVector: ioVector, - sockaddrVector: sockaddrVector, - controlMessageStorage: controlMessageStorage) + return DatagramVectorReadManager( + messageVector: messageVector, + ioVector: ioVector, + sockaddrVector: sockaddrVector, + controlMessageStorage: controlMessageStorage + ) } } - extension Optional where Wrapped == DatagramVectorReadManager { /// Updates the message count of the wrapped `DatagramVectorReadManager` to the new value. /// @@ -239,12 +260,14 @@ extension Optional where Wrapped == DatagramVectorReadManager { } } - extension UnsafeMutableBufferPointer { /// Safely creates an UnsafeMutableBufferPointer that can be used by the rest of the code. It ensures that /// the memory has been bound, allocated, and initialized, such that other Swift code can use it safely without /// worrying. - fileprivate static func allocateAndInitialize(repeating element: Element, count: Int) -> UnsafeMutableBufferPointer { + fileprivate static func allocateAndInitialize( + repeating element: Element, + count: Int + ) -> UnsafeMutableBufferPointer { let newPointer = UnsafeMutableBufferPointer.allocate(capacity: count) newPointer.initialize(repeating: element) return newPointer diff --git a/Sources/NIOPosix/Errors+Any.swift b/Sources/NIOPosix/Errors+Any.swift index d8bc938b89..5b961ec97e 100644 --- a/Sources/NIOPosix/Errors+Any.swift +++ b/Sources/NIOPosix/Errors+Any.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// - import NIOCore // 'any Error' is unconditionally boxed, avoid allocating per use by statically boxing them. diff --git a/Sources/NIOPosix/GetaddrinfoResolver.swift b/Sources/NIOPosix/GetaddrinfoResolver.swift index 633e91d0f2..c5a7e5bfa7 100644 --- a/Sources/NIOPosix/GetaddrinfoResolver.swift +++ b/Sources/NIOPosix/GetaddrinfoResolver.swift @@ -11,6 +11,8 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + +import Dispatch import NIOCore /// A DNS resolver built on top of the libc `getaddrinfo` function. @@ -24,8 +26,6 @@ import NIOCore /// /// This resolver is a single-use object: it can only be used to perform a single host resolution. -import Dispatch - #if os(Linux) || os(FreeBSD) || os(Android) import CNIOLinux #endif @@ -47,7 +47,6 @@ import struct WinSDK.SOCKADDR_IN6 // A thread-specific variable where we store the offload queue if we're on an `SelectableEventLoop`. let offloadQueueTSV = ThreadSpecificVariable() - internal class GetaddrinfoResolver: Resolver { private let v4Future: EventLoopPromise<[SocketAddress]> private let v6Future: EventLoopPromise<[SocketAddress]> @@ -60,8 +59,11 @@ internal class GetaddrinfoResolver: Resolver { /// - loop: The `EventLoop` whose thread this resolver will block. /// - aiSocktype: The sock type to use as hint when calling getaddrinfo. /// - aiProtocol: the protocol to use as hint when calling getaddrinfo. - init(loop: EventLoop, aiSocktype: NIOBSDSocket.SocketType, - aiProtocol: NIOBSDSocket.OptionLevel) { + init( + loop: EventLoop, + aiSocktype: NIOBSDSocket.SocketType, + aiProtocol: NIOBSDSocket.OptionLevel + ) { self.v4Future = loop.makePromise() self.v6Future = loop.makePromise() self.aiSocktype = aiSocktype @@ -79,7 +81,7 @@ internal class GetaddrinfoResolver: Resolver { /// - port: The port we'll be connecting to. /// - returns: An `EventLoopFuture` that fires with the result of the lookup. func initiateAQuery(host: String, port: Int) -> EventLoopFuture<[SocketAddress]> { - return v4Future.futureResult + v4Future.futureResult } /// Initiate a DNS AAAA query for a given host. @@ -120,7 +122,7 @@ internal class GetaddrinfoResolver: Resolver { /// clean up their state. /// /// In the getaddrinfo case this is a no-op, as the resolver blocks. - func cancelQueries() { } + func cancelQueries() {} /// Perform the DNS queries and record the result. /// @@ -128,7 +130,7 @@ internal class GetaddrinfoResolver: Resolver { /// - host: The hostname to do the DNS queries on. /// - port: The port we'll be connecting to. private func resolveBlocking(host: String, port: Int) { -#if os(Windows) + #if os(Windows) host.withCString(encodedAs: UTF16.self) { wszHost in String(port).withCString(encodedAs: UTF16.self) { wszPort in var pResult: UnsafeMutablePointer? @@ -151,7 +153,7 @@ internal class GetaddrinfoResolver: Resolver { } } } -#else + #else var info: UnsafeMutablePointer? var hint = addrinfo() @@ -166,10 +168,10 @@ internal class GetaddrinfoResolver: Resolver { self.parseAndPublishResults(info, host: host) freeaddrinfo(info) } else { - /* this is odd, getaddrinfo returned NULL */ + // this is odd, getaddrinfo returned NULL self.fail(SocketAddressError.unsupported) } -#endif + #endif } /// Parses the DNS results from the `addrinfo` linked list. @@ -177,11 +179,11 @@ internal class GetaddrinfoResolver: Resolver { /// - parameters: /// - info: The pointer to the first of the `addrinfo` structures in the list. /// - host: The hostname we resolved. -#if os(Windows) + #if os(Windows) internal typealias CAddrInfo = ADDRINFOW -#else + #else internal typealias CAddrInfo = addrinfo -#endif + #endif private func parseAndPublishResults(_ info: UnsafeMutablePointer, host: String) { var v4Results: [SocketAddress] = [] diff --git a/Sources/NIOPosix/HappyEyeballs.swift b/Sources/NIOPosix/HappyEyeballs.swift index f58435d77f..ffc1a06139 100644 --- a/Sources/NIOPosix/HappyEyeballs.swift +++ b/Sources/NIOPosix/HappyEyeballs.swift @@ -27,8 +27,8 @@ import NIOCore // We naturally still use an enum to hold our state, but the FSM is now inside a class, which makes the shared // state nature of this FSM a bit clearer. -private extension Array where Element == EventLoopFuture { - mutating func remove(element: Element) { +extension Array where Element == EventLoopFuture { + fileprivate mutating func remove(element: Element) { guard let channelIndex = self.firstIndex(where: { $0 === element }) else { return } @@ -226,7 +226,8 @@ internal final class HappyEyeballsConnector { /// than intended. /// /// The channel builder callback takes an event loop and a protocol family as arguments. - private let channelBuilderCallback: (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)> + private let channelBuilderCallback: + (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)> /// The amount of time to wait for an AAAA response to come in after a A response is /// received. By default this is 50ms. @@ -278,14 +279,18 @@ internal final class HappyEyeballsConnector { private var error: NIOConnectionError @inlinable - init(resolver: Resolver, - loop: EventLoop, - host: String, - port: Int, - connectTimeout: TimeAmount, - resolutionDelay: TimeAmount = .milliseconds(50), - connectionDelay: TimeAmount = .milliseconds(250), - channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)>) { + init( + resolver: Resolver, + loop: EventLoop, + host: String, + port: Int, + connectTimeout: TimeAmount, + resolutionDelay: TimeAmount = .milliseconds(50), + connectionDelay: TimeAmount = .milliseconds(250), + channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture< + (Channel, ChannelBuilderResult) + > + ) { self.resolver = resolver self.loop = loop self.host = host @@ -300,10 +305,16 @@ internal final class HappyEyeballsConnector { self.resolutionPromise = self.loop.makePromise() self.error = NIOConnectionError(host: host, port: port) - precondition(resolutionDelay.nanoseconds > 0, "Resolution delay must be greater than zero, got \(resolutionDelay).") + precondition( + resolutionDelay.nanoseconds > 0, + "Resolution delay must be greater than zero, got \(resolutionDelay)." + ) self.resolutionDelay = resolutionDelay - precondition(connectionDelay >= .milliseconds(100) && connectionDelay <= .milliseconds(2000), "Connection delay must be between 100 and 2000 ms, got \(connectionDelay)") + precondition( + connectionDelay >= .milliseconds(100) && connectionDelay <= .milliseconds(2000), + "Connection delay must be between 100 and 2000 ms, got \(connectionDelay)" + ) self.connectionDelay = connectionDelay } @@ -325,9 +336,10 @@ internal final class HappyEyeballsConnector { port: port, connectTimeout: connectTimeout, resolutionDelay: resolutionDelay, - connectionDelay: connectionDelay) { loop, protocolFamily in - channelBuilderCallback(loop, protocolFamily).map { ($0, ()) } - } + connectionDelay: connectionDelay + ) { loop, protocolFamily in + channelBuilderCallback(loop, protocolFamily).map { ($0, ()) } + } } /// Initiate a DNS resolution attempt using Happy Eyeballs 2. @@ -337,7 +349,9 @@ internal final class HappyEyeballsConnector { func resolveAndConnect() -> EventLoopFuture<(Channel, ChannelBuilderResult)> { // We dispatch ourselves onto the event loop, rather than do all the rest of our processing from outside it. self.loop.execute { - self.timeoutTask = self.loop.scheduleTask(in: self.connectTimeout) { self.processInput(.connectTimeoutElapsed) } + self.timeoutTask = self.loop.scheduleTask(in: self.connectTimeout) { + self.processInput(.connectTimeoutElapsed) + } self.processInput(.resolve) } return resolutionPromise.futureResult @@ -451,12 +465,12 @@ internal final class HappyEyeballsConnector { // ignore these, as our transition into the complete state should have already sent // cleanup messages to all of these things. case (.complete, .resolverACompleted), - (.complete, .resolverAAAACompleted), - (.complete, .connectSuccess), - (.complete, .connectFailed), - (.complete, .connectDelayElapsed), - (.complete, .connectTimeoutElapsed), - (.complete, .resolutionDelayElapsed): + (.complete, .resolverAAAACompleted), + (.complete, .connectSuccess), + (.complete, .connectFailed), + (.complete, .connectDelayElapsed), + (.complete, .connectTimeoutElapsed), + (.complete, .resolutionDelayElapsed): break default: fatalError("Invalid FSM transition attempt: state \(state), input \(input)") @@ -600,7 +614,10 @@ internal final class HappyEyeballsConnector { // The connection attempt failed. If we're in the complete state then there's nothing // to do. Otherwise, notify the state machine of the failure. if case .complete = self.state { - assert(self.pendingConnections.firstIndex { $0 === channelFuture } == nil, "failed but was still in pending connections") + assert( + self.pendingConnections.firstIndex { $0 === channelFuture } == nil, + "failed but was still in pending connections" + ) } else { self.error.connectionErrors.append(SingleConnectionFailure(target: target, error: err)) self.pendingConnections.removeAll { $0 === channelFuture } diff --git a/Sources/NIOPosix/IO.swift b/Sources/NIOPosix/IO.swift index 452749b563..5a9a581583 100644 --- a/Sources/NIOPosix/IO.swift +++ b/Sources/NIOPosix/IO.swift @@ -12,7 +12,7 @@ // //===----------------------------------------------------------------------===// -internal extension IOResult where T: FixedWidthInteger { +extension IOResult where T: FixedWidthInteger { var result: T { switch self { case .processed(let value): diff --git a/Sources/NIOPosix/IntegerBitPacking.swift b/Sources/NIOPosix/IntegerBitPacking.swift index 6d89a9d81e..14532e1b9a 100644 --- a/Sources/NIOPosix/IntegerBitPacking.swift +++ b/Sources/NIOPosix/IntegerBitPacking.swift @@ -17,11 +17,15 @@ enum _IntegerBitPacking {} extension _IntegerBitPacking { @inlinable - static func packUU(_ left: Left, - _ right: Right, - type: Result.Type = Result.self) -> Result { + static func packUU< + Left: FixedWidthInteger & UnsignedInteger, + Right: FixedWidthInteger & UnsignedInteger, + Result: FixedWidthInteger & UnsignedInteger + >( + _ left: Left, + _ right: Right, + type: Result.Type = Result.self + ) -> Result { assert(MemoryLayout.size + MemoryLayout.size <= MemoryLayout.size) let resultLeft = Result(left) @@ -32,11 +36,15 @@ extension _IntegerBitPacking { } @inlinable - static func unpackUU(_ input: Input, - leftType: Left.Type = Left.self, - rightType: Right.Type = Right.self) -> (Left, Right) { + static func unpackUU< + Input: FixedWidthInteger & UnsignedInteger, + Left: FixedWidthInteger & UnsignedInteger, + Right: FixedWidthInteger & UnsignedInteger + >( + _ input: Input, + leftType: Left.Type = Left.self, + rightType: Right.Type = Right.self + ) -> (Left, Right) { assert(MemoryLayout.size + MemoryLayout.size <= MemoryLayout.size) let leftMask = Input(Left.max) @@ -55,7 +63,7 @@ enum IntegerBitPacking {} extension IntegerBitPacking { @inlinable static func packUInt32UInt16UInt8(_ left: UInt32, _ middle: UInt16, _ right: UInt8) -> UInt64 { - return _IntegerBitPacking.packUU( + _IntegerBitPacking.packUU( _IntegerBitPacking.packUU(right, middle, type: UInt32.self), left ) @@ -70,27 +78,27 @@ extension IntegerBitPacking { @inlinable static func packUInt8UInt8(_ left: UInt8, _ right: UInt8) -> UInt16 { - return _IntegerBitPacking.packUU(left, right) + _IntegerBitPacking.packUU(left, right) } @inlinable static func unpackUInt8UInt8(_ value: UInt16) -> (UInt8, UInt8) { - return _IntegerBitPacking.unpackUU(value) + _IntegerBitPacking.unpackUU(value) } @inlinable static func packUInt16UInt8(_ left: UInt16, _ right: UInt8) -> UInt32 { - return _IntegerBitPacking.packUU(left, right) + _IntegerBitPacking.packUU(left, right) } @inlinable static func unpackUInt16UInt8(_ value: UInt32) -> (UInt16, UInt8) { - return _IntegerBitPacking.unpackUU(value) + _IntegerBitPacking.unpackUU(value) } @inlinable static func packUInt32CInt(_ left: UInt32, _ right: CInt) -> UInt64 { - return _IntegerBitPacking.packUU(left, UInt32(truncatingIfNeeded: right)) + _IntegerBitPacking.packUU(left, UInt32(truncatingIfNeeded: right)) } @inlinable diff --git a/Sources/NIOPosix/IntegerTypes.swift b/Sources/NIOPosix/IntegerTypes.swift index e0cf571b9a..3ac37dc0bf 100644 --- a/Sources/NIOPosix/IntegerTypes.swift +++ b/Sources/NIOPosix/IntegerTypes.swift @@ -48,18 +48,17 @@ extension Int { } } - extension _UInt24: Equatable { @inlinable - public static func ==(lhs: _UInt24, rhs: _UInt24) -> Bool { - return lhs._backing == rhs._backing + public static func == (lhs: _UInt24, rhs: _UInt24) -> Bool { + lhs._backing == rhs._backing } } extension _UInt24: CustomStringConvertible { @usableFromInline var description: String { - return UInt32(self).description + UInt32(self).description } } @@ -75,7 +74,7 @@ struct _UInt56 { static let bitWidth: Int = 56 - private static let initializeUInt64 : UInt64 = (1 << 56) - 1 + private static let initializeUInt64: UInt64 = (1 << 56) - 1 static let max: _UInt56 = .init(initializeUInt64) static let min: _UInt56 = .init(0) } @@ -88,9 +87,11 @@ extension _UInt56 { extension UInt64 { init(_ value: _UInt56) { - self = IntegerBitPacking.packUInt32UInt16UInt8(value._backing.0, - value._backing.1, - value._backing.2) + self = IntegerBitPacking.packUInt32UInt16UInt8( + value._backing.0, + value._backing.1, + value._backing.2 + ) } } @@ -102,13 +103,13 @@ extension Int { extension _UInt56: Equatable { @inlinable - public static func ==(lhs: _UInt56, rhs: _UInt56) -> Bool { - return lhs._backing == rhs._backing + public static func == (lhs: _UInt56, rhs: _UInt56) -> Bool { + lhs._backing == rhs._backing } } extension _UInt56: CustomStringConvertible { var description: String { - return UInt64(self).description + UInt64(self).description } } diff --git a/Sources/NIOPosix/Linux.swift b/Sources/NIOPosix/Linux.swift index d5f25f76cd..7fcaed6a89 100644 --- a/Sources/NIOPosix/Linux.swift +++ b/Sources/NIOPosix/Linux.swift @@ -23,7 +23,12 @@ internal enum TimerFd { internal static let TFD_NONBLOCK = CNIOLinux.TFD_NONBLOCK @inline(never) - internal static func timerfd_settime(fd: CInt, flags: CInt, newValue: UnsafePointer, oldValue: UnsafeMutablePointer?) throws { + internal static func timerfd_settime( + fd: CInt, + flags: CInt, + newValue: UnsafePointer, + oldValue: UnsafeMutablePointer? + ) throws { _ = try syscall(blocking: false) { CNIOLinux.timerfd_settime(fd, flags, newValue, oldValue) } @@ -31,7 +36,7 @@ internal enum TimerFd { @inline(never) internal static func timerfd_create(clockId: CInt, flags: CInt) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { CNIOLinux.timerfd_create(clockId, flags) }.result } @@ -44,21 +49,21 @@ internal enum EventFd { @inline(never) internal static func eventfd_write(fd: CInt, value: UInt64) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { CNIOLinux.eventfd_write(fd, value) }.result } @inline(never) internal static func eventfd_read(fd: CInt, value: UnsafeMutablePointer) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { CNIOLinux.eventfd_read(fd, value) }.result } @inline(never) internal static func eventfd(initval: CUnsignedInt, flags: CInt) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { // Note: Please do _not_ remove the `numericCast`, this is to allow compilation in Ubuntu 14.04 and // other Linux distros which ship a glibc from before this commit: // https://sourceware.org/git/?p=glibc.git;a=commitdiff;h=69eb9a183c19e8739065e430758e4d3a2c5e4f1a @@ -76,12 +81,12 @@ internal enum Epoll { internal static let EPOLL_CTL_DEL: CInt = numericCast(CNIOLinux.EPOLL_CTL_DEL) #if os(Android) - internal static let EPOLLIN: CUnsignedInt = 1 //numericCast(CNIOLinux.EPOLLIN) - internal static let EPOLLOUT: CUnsignedInt = 4 //numericCast(CNIOLinux.EPOLLOUT) - internal static let EPOLLERR: CUnsignedInt = 8 // numericCast(CNIOLinux.EPOLLERR) - internal static let EPOLLRDHUP: CUnsignedInt = 8192 //numericCast(CNIOLinux.EPOLLRDHUP) - internal static let EPOLLHUP: CUnsignedInt = 16 //numericCast(CNIOLinux.EPOLLHUP) - internal static let EPOLLET: CUnsignedInt = 2147483648 //numericCast(CNIOLinux.EPOLLET) + internal static let EPOLLIN: CUnsignedInt = 1 //numericCast(CNIOLinux.EPOLLIN) + internal static let EPOLLOUT: CUnsignedInt = 4 //numericCast(CNIOLinux.EPOLLOUT) + internal static let EPOLLERR: CUnsignedInt = 8 // numericCast(CNIOLinux.EPOLLERR) + internal static let EPOLLRDHUP: CUnsignedInt = 8192 //numericCast(CNIOLinux.EPOLLRDHUP) + internal static let EPOLLHUP: CUnsignedInt = 16 //numericCast(CNIOLinux.EPOLLHUP) + internal static let EPOLLET: CUnsignedInt = 2_147_483_648 //numericCast(CNIOLinux.EPOLLET) #elseif canImport(Musl) internal static let EPOLLIN: CUnsignedInt = numericCast(CNIOLinux.EPOLLIN) internal static let EPOLLOUT: CUnsignedInt = numericCast(CNIOLinux.EPOLLOUT) @@ -100,50 +105,66 @@ internal enum Epoll { internal static let ENOENT: CUnsignedInt = numericCast(CNIOLinux.ENOENT) - @inline(never) internal static func epoll_create(size: CInt) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { CNIOLinux.epoll_create(size) }.result } @inline(never) @discardableResult - internal static func epoll_ctl(epfd: CInt, op: CInt, fd: CInt, event: UnsafeMutablePointer) throws -> CInt { - return try syscall(blocking: false) { + internal static func epoll_ctl( + epfd: CInt, + op: CInt, + fd: CInt, + event: UnsafeMutablePointer + ) throws -> CInt { + try syscall(blocking: false) { CNIOLinux.epoll_ctl(epfd, op, fd, event) }.result } @inline(never) - internal static func epoll_wait(epfd: CInt, events: UnsafeMutablePointer, maxevents: CInt, timeout: CInt) throws -> CInt { - return try syscall(blocking: false) { + internal static func epoll_wait( + epfd: CInt, + events: UnsafeMutablePointer, + maxevents: CInt, + timeout: CInt + ) throws -> CInt { + try syscall(blocking: false) { CNIOLinux.epoll_wait(epfd, events, maxevents, timeout) }.result } } internal enum Linux { -#if os(Android) + #if os(Android) static let SOCK_CLOEXEC = Glibc.SOCK_CLOEXEC static let SOCK_NONBLOCK = Glibc.SOCK_NONBLOCK -#elseif canImport(Musl) + #elseif canImport(Musl) static let SOCK_CLOEXEC = Musl.SOCK_CLOEXEC static let SOCK_NONBLOCK = Musl.SOCK_NONBLOCK -#else + #else static let SOCK_CLOEXEC = CInt(bitPattern: Glibc.SOCK_CLOEXEC.rawValue) static let SOCK_NONBLOCK = CInt(bitPattern: Glibc.SOCK_NONBLOCK.rawValue) -#endif + #endif @inline(never) - internal static func accept4(descriptor: CInt, - addr: UnsafeMutablePointer?, - len: UnsafeMutablePointer?, - flags: CInt) throws -> CInt? { - guard case let .processed(fd) = try syscall(blocking: true, { - CNIOLinux.CNIOLinux_accept4(descriptor, addr, len, flags) - }) else { - return nil + internal static func accept4( + descriptor: CInt, + addr: UnsafeMutablePointer?, + len: UnsafeMutablePointer?, + flags: CInt + ) throws -> CInt? { + guard + case let .processed(fd) = try syscall( + blocking: true, + { + CNIOLinux.CNIOLinux_accept4(descriptor, addr, len, flags) + } + ) + else { + return nil } return fd } diff --git a/Sources/NIOPosix/LinuxCPUSet.swift b/Sources/NIOPosix/LinuxCPUSet.swift index 98c92dcfee..ffa3c49672 100644 --- a/Sources/NIOPosix/LinuxCPUSet.swift +++ b/Sources/NIOPosix/LinuxCPUSet.swift @@ -15,82 +15,88 @@ #if os(Linux) || os(Android) import CNIOLinux - /// A set that contains CPU ids to use. - struct LinuxCPUSet { - /// The ids of all the cpus. - let cpuIds: Set +/// A set that contains CPU ids to use. +struct LinuxCPUSet { + /// The ids of all the cpus. + let cpuIds: Set - /// Create a new instance - /// - /// - arguments: - /// - cpuIds: The `Set` of CPU ids. It must be non-empty and can not contain invalid ids. - init(cpuIds: Set) { - precondition(!cpuIds.isEmpty) - self.cpuIds = cpuIds - } - - /// Create a new instance - /// - /// - arguments: - /// - cpuId: The CPU id. - init(_ cpuId: Int) { - let ids: Set = [cpuId] - self.init(cpuIds: ids) - } + /// Create a new instance + /// + /// - arguments: + /// - cpuIds: The `Set` of CPU ids. It must be non-empty and can not contain invalid ids. + init(cpuIds: Set) { + precondition(!cpuIds.isEmpty) + self.cpuIds = cpuIds } - extension LinuxCPUSet: Equatable {} - - /// Linux specific extension to `NIOThread`. - extension NIOThread { - /// Specify the thread-affinity of the `NIOThread` itself. - var affinity: LinuxCPUSet { - get { - var cpuset = cpu_set_t() + /// Create a new instance + /// + /// - arguments: + /// - cpuId: The CPU id. + init(_ cpuId: Int) { + let ids: Set = [cpuId] + self.init(cpuIds: ids) + } +} - // Ensure the cpuset is empty (and so nothing is selected yet). - CNIOLinux_CPU_ZERO(&cpuset) +extension LinuxCPUSet: Equatable {} - let res = self.withUnsafeThreadHandle { p in - CNIOLinux_pthread_getaffinity_np(p, MemoryLayout.size(ofValue: cpuset), &cpuset) - } +/// Linux specific extension to `NIOThread`. +extension NIOThread { + /// Specify the thread-affinity of the `NIOThread` itself. + var affinity: LinuxCPUSet { + get { + var cpuset = cpu_set_t() - precondition(res == 0, "pthread_getaffinity_np failed: \(res)") + // Ensure the cpuset is empty (and so nothing is selected yet). + CNIOLinux_CPU_ZERO(&cpuset) - let set = Set((CInt(0).. __kernel_timespec { var ts = __kernel_timespec() ts.tv_sec = self.nanoseconds / 1_000_000_000 @@ -43,9 +44,10 @@ internal extension TimeAmount { // for the type of event issued (poll/modify/delete). @usableFromInline struct URingUserData { @usableFromInline var fileDescriptor: CInt - @usableFromInline var registrationID: UInt16 // SelectorRegistrationID truncated, only have room for bottom 16 bits (could be expanded to 24 if required) + // SelectorRegistrationID truncated, only have room for bottom 16 bits (could be expanded to 24 if required) + @usableFromInline var registrationID: UInt16 @usableFromInline var eventType: CQEEventType - @usableFromInline var padding: Int8 // reserved for future use + @usableFromInline var padding: Int8 // reserved for future use @inlinable init(registrationID: SelectorRegistrationID, fileDescriptor: CInt, eventType: CQEEventType) { assert(MemoryLayout.size == MemoryLayout.size) @@ -57,9 +59,11 @@ internal extension TimeAmount { @inlinable init(rawValue: UInt64) { let unpacked = IntegerBitPacking.unpackUInt32UInt16UInt8(rawValue) - self = .init(registrationID: SelectorRegistrationID(rawValue: UInt32(unpacked.1)), - fileDescriptor: CInt(unpacked.0), - eventType: CQEEventType(rawValue:unpacked.2)!) + self = .init( + registrationID: SelectorRegistrationID(rawValue: UInt32(unpacked.1)), + fileDescriptor: CInt(unpacked.0), + eventType: CQEEventType(rawValue: unpacked.2)! + ) } } @@ -70,9 +74,11 @@ extension UInt64 { assert(fd >= 0, "\(fd) is not a valid file descriptor") assert(eventType >= 0, "\(eventType) is not a valid eventType") - self = IntegerBitPacking.packUInt32UInt16UInt8(UInt32(truncatingIfNeeded: fd), - uringUserData.registrationID, - eventType) + self = IntegerBitPacking.packUInt32UInt16UInt8( + UInt32(truncatingIfNeeded: fd), + uringUserData.registrationID, + eventType + ) } } @@ -80,9 +86,9 @@ extension UInt64 { internal struct URingEvent { var fd: CInt var pollMask: UInt32 - var registrationID: UInt16 // we just have the truncated lower 16 bits of the registrationID + var registrationID: UInt16 // we just have the truncated lower 16 bits of the registrationID var pollCancelled: Bool - init () { + init() { self.fd = -1 self.pollMask = 0 self.registrationID = 0 @@ -93,7 +99,7 @@ internal struct URingEvent { // This is the key we use for merging events in our internal hashtable struct FDEventKey: Hashable { var fileDescriptor: CInt - var registrationID: UInt16 // we just have the truncated lower 16 bits of the registrationID + var registrationID: UInt16 // we just have the truncated lower 16 bits of the registrationID init(_ f: CInt, _ s: UInt16) { self.fileDescriptor = f @@ -105,20 +111,21 @@ final internal class URing { internal static let POLLIN: CUnsignedInt = numericCast(CNIOLinux.POLLIN) internal static let POLLOUT: CUnsignedInt = numericCast(CNIOLinux.POLLOUT) internal static let POLLERR: CUnsignedInt = numericCast(CNIOLinux.POLLERR) - internal static let POLLRDHUP: CUnsignedInt = CNIOLinux_POLLRDHUP() // numericCast(CNIOLinux.POLLRDHUP) + internal static let POLLRDHUP: CUnsignedInt = CNIOLinux_POLLRDHUP() // numericCast(CNIOLinux.POLLRDHUP) internal static let POLLHUP: CUnsignedInt = numericCast(CNIOLinux.POLLHUP) - internal static let POLLCANCEL: CUnsignedInt = 0xF0000000 // Poll cancelled, need to reregister for singleshot polls + // Poll cancelled, need to reregister for singleshot polls + internal static let POLLCANCEL: CUnsignedInt = 0xF000_0000 private var ring = io_uring() private let ringEntries: CUnsignedInt = 8192 - private let cqeMaxCount: UInt32 = 8192 // this is the max chunk of CQE we take. + private let cqeMaxCount: UInt32 = 8192 // this is the max chunk of CQE we take. var cqes: UnsafeMutablePointer?> - var fdEvents = [FDEventKey : UInt32]() // fd, sequence_identifier : merged event_poll_return + var fdEvents = [FDEventKey: UInt32]() // fd, sequence_identifier : merged event_poll_return var emptyCqe = io_uring_cqe() var fd: CInt { - return ring.ring_fd + ring.ring_fd } static var io_uring_use_multishot_poll: Bool { @@ -129,7 +136,7 @@ final internal class URing { #endif } - func _dumpCqes(_ header:String, count: Int = 1) { + func _dumpCqes(_ header: String, count: Int = 1) { #if SWIFTNIO_IO_URING_DEBUG_DUMP_CQE func _debugPrintCQE(_ s: String) { print("Q [\(NIOThread.current)] " + s) @@ -143,31 +150,32 @@ final internal class URing { for i in 0..?>.allocate(capacity: Int(cqeMaxCount)) - cqes.initialize(repeating:&emptyCqe, count:Int(cqeMaxCount)) + cqes.initialize(repeating: &emptyCqe, count: Int(cqeMaxCount)) } deinit { cqes.deallocate() } - internal func io_uring_queue_init() throws -> () { - if (CNIOLinux.io_uring_queue_init(ringEntries, &ring, 0 ) != 0) - { - throw URingError.uringSetupFailure - } + internal func io_uring_queue_init() throws { + if CNIOLinux.io_uring_queue_init(ringEntries, &ring, 0) != 0 { + throw URingError.uringSetupFailure + } _debugPrint("io_uring_queue_init \(self.ring.ring_fd)") - } + } internal func io_uring_queue_exit() { _debugPrint("io_uring_queue_exit \(self.ring.ring_fd)") @@ -178,8 +186,7 @@ final internal class URing { // modifications - we never want to have to handle retries of // SQE allocation in all places it could possibly occur. // If the SQ ring is full, we may need to submit IO first - func withSQE(_ body: (UnsafeMutablePointer?) throws -> R) rethrows -> R - { + func withSQE(_ body: (UnsafeMutablePointer?) throws -> R) rethrows -> R { // io_uring_submit can fail here due to backpressure from kernel for not reaping CQE:s. // // I think we should consider handling that as a fatalError, as fundamentally the ring size is too small @@ -191,7 +198,7 @@ final internal class URing { // while true { if let sqe = CNIOLinux.io_uring_get_sqe(&ring) { - return try body(sqe) + return try body(sqe) } self.io_uring_flush() } @@ -202,15 +209,14 @@ final internal class URing { // has gone down and we are re-registering polls this means we will silently lose any // entries after the failed fd. Ouch. Proper approach is to use io_uring_sq_ready() in a loop. // See: https://github.com/axboe/liburing/issues/309 - internal func io_uring_flush() { // When using SQPOLL this is basically a NOP + internal func io_uring_flush() { // When using SQPOLL this is basically a NOP var waitingSubmissions: UInt32 = 0 var submissionCount = 0 var retval: CInt waitingSubmissions = CNIOLinux.io_uring_sq_ready(&ring) - loop: while (waitingSubmissions > 0) - { + loop: while waitingSubmissions > 0 { retval = CNIOLinux.io_uring_submit(&ring) submissionCount += 1 @@ -225,21 +231,27 @@ final internal class URing { // trying to get new SQE if the actual SQE queue is full, but // that would be due to user error in usage IMHO and we should fatalError there. case -EAGAIN, -EBUSY: - _debugPrint("io_uring_flush io_uring_submit -EBUSY/-EAGAIN waitingSubmissions[\(waitingSubmissions)] submissionCount[\(submissionCount)]. Breaking out and resubmitting later (whenReady() end).") + _debugPrint( + "io_uring_flush io_uring_submit -EBUSY/-EAGAIN waitingSubmissions[\(waitingSubmissions)] submissionCount[\(submissionCount)]. Breaking out and resubmitting later (whenReady() end)." + ) break loop // -ENOMEM when there is not enough memory to do internal allocations on the kernel side. // Right nog we just loop with a sleep trying to buy time, but could also possibly fatalError here. // See: https://github.com/axboe/liburing/issues/309 case -ENOMEM: - usleep(10_000) // let's not busy loop to give the kernel some time to recover if possible + usleep(10_000) // let's not busy loop to give the kernel some time to recover if possible _debugPrint("io_uring_flush io_uring_submit -ENOMEM \(submissionCount)") case 0: - _debugPrint("io_uring_flush io_uring_submit submitted 0, so far needed submissionCount[\(submissionCount)] waitingSubmissions[\(waitingSubmissions)] submitted [\(retval)] SQE:s this iteration") + _debugPrint( + "io_uring_flush io_uring_submit submitted 0, so far needed submissionCount[\(submissionCount)] waitingSubmissions[\(waitingSubmissions)] submitted [\(retval)] SQE:s this iteration" + ) break case 1...: - _debugPrint("io_uring_flush io_uring_submit needed [\(submissionCount)] submission(s), submitted [\(retval)] SQE:s out of [\(waitingSubmissions)] possible") + _debugPrint( + "io_uring_flush io_uring_submit needed [\(submissionCount)] submission(s), submitted [\(retval)] SQE:s out of [\(waitingSubmissions)] possible" + ) break - default: // other errors + default: // other errors fatalError("Unexpected error [\(retval)] from io_uring_submit ") } @@ -248,81 +260,132 @@ final internal class URing { } // we stuff event type into the upper byte, the next 3 bytes gives us the sequence number (16M before wrap) and final 4 bytes are fd. - internal func io_uring_prep_poll_add(fileDescriptor: CInt, pollMask: UInt32, registrationID: SelectorRegistrationID, submitNow: Bool = true, multishot: Bool = true) -> () { - let bitPattern = UInt64(URingUserData(registrationID: registrationID, fileDescriptor: fileDescriptor, eventType:CQEEventType.poll)) + internal func io_uring_prep_poll_add( + fileDescriptor: CInt, + pollMask: UInt32, + registrationID: SelectorRegistrationID, + submitNow: Bool = true, + multishot: Bool = true + ) { + let bitPattern = UInt64( + URingUserData(registrationID: registrationID, fileDescriptor: fileDescriptor, eventType: CQEEventType.poll) + ) let bitpatternAsPointer = UnsafeMutableRawPointer.init(bitPattern: UInt(bitPattern)) - _debugPrint("io_uring_prep_poll_add fileDescriptor[\(fileDescriptor)] pollMask[\(pollMask)] bitpatternAsPointer[\(String(describing:bitpatternAsPointer))] submitNow[\(submitNow)] multishot[\(multishot)]") + _debugPrint( + "io_uring_prep_poll_add fileDescriptor[\(fileDescriptor)] pollMask[\(pollMask)] bitpatternAsPointer[\(String(describing:bitpatternAsPointer))] submitNow[\(submitNow)] multishot[\(multishot)]" + ) self.withSQE { sqe in CNIOLinux.io_uring_prep_poll_add(sqe, fileDescriptor, pollMask) - CNIOLinux.io_uring_sqe_set_data(sqe, bitpatternAsPointer) // must be done after prep_poll_add, otherwise zeroed out. + // must be done after prep_poll_add, otherwise zeroed out. + CNIOLinux.io_uring_sqe_set_data(sqe, bitpatternAsPointer) if multishot { - sqe!.pointee.len |= IORING_POLL_ADD_MULTI; // turn on multishots, set through environment variable + // turn on multishots, set through environment variable + sqe!.pointee.len |= IORING_POLL_ADD_MULTI } } - + if submitNow { self.io_uring_flush() } } - internal func io_uring_prep_poll_remove(fileDescriptor: CInt, pollMask: UInt32, registrationID: SelectorRegistrationID, submitNow: Bool = true, link: Bool = false) -> () { - let bitPattern = UInt64(URingUserData(registrationID: registrationID, - fileDescriptor: fileDescriptor, - eventType:CQEEventType.poll)) - let userbitPattern = UInt64(URingUserData(registrationID: registrationID, - fileDescriptor: fileDescriptor, - eventType:CQEEventType.pollDelete)) - - _debugPrint("io_uring_prep_poll_remove fileDescriptor[\(fileDescriptor)] pollMask[\(pollMask)] bitpatternAsPointer[\(String(describing: bitPattern))] userBitpatternAsPointer[\(String(describing: userbitPattern))] submitNow[\(submitNow)] link[\(link)]") + internal func io_uring_prep_poll_remove( + fileDescriptor: CInt, + pollMask: UInt32, + registrationID: SelectorRegistrationID, + submitNow: Bool = true, + link: Bool = false + ) { + let bitPattern = UInt64( + URingUserData( + registrationID: registrationID, + fileDescriptor: fileDescriptor, + eventType: CQEEventType.poll + ) + ) + let userbitPattern = UInt64( + URingUserData( + registrationID: registrationID, + fileDescriptor: fileDescriptor, + eventType: CQEEventType.pollDelete + ) + ) + + _debugPrint( + "io_uring_prep_poll_remove fileDescriptor[\(fileDescriptor)] pollMask[\(pollMask)] bitpatternAsPointer[\(String(describing: bitPattern))] userBitpatternAsPointer[\(String(describing: userbitPattern))] submitNow[\(submitNow)] link[\(link)]" + ) self.withSQE { sqe in CNIOLinux.io_uring_prep_poll_remove(sqe, .init(userData: bitPattern)) - CNIOLinux.io_uring_sqe_set_data(sqe, .init(userData: userbitPattern)) // must be done after prep_poll_add, otherwise zeroed out. + // must be done after prep_poll_add, otherwise zeroed out. + CNIOLinux.io_uring_sqe_set_data(sqe, .init(userData: userbitPattern)) if link { CNIOLinux_io_uring_set_link_flag(sqe) } } - + if submitNow { self.io_uring_flush() } } - - // the update/multishot polls are - internal func io_uring_poll_update(fileDescriptor: CInt, newPollmask: UInt32, oldPollmask: UInt32, registrationID: SelectorRegistrationID, submitNow: Bool = true, multishot: Bool = true) -> () { - - let bitpattern = UInt64(URingUserData(registrationID: registrationID, - fileDescriptor: fileDescriptor, - eventType:CQEEventType.poll)) - let userbitPattern = UInt64(URingUserData(registrationID: registrationID, - fileDescriptor: fileDescriptor, - eventType:CQEEventType.pollModify)) - _debugPrint("io_uring_poll_update fileDescriptor[\(fileDescriptor)] oldPollmask[\(oldPollmask)] newPollmask[\(newPollmask)] userBitpatternAsPointer[\(String(describing: userbitPattern))]") + // the update/multishot polls are + internal func io_uring_poll_update( + fileDescriptor: CInt, + newPollmask: UInt32, + oldPollmask: UInt32, + registrationID: SelectorRegistrationID, + submitNow: Bool = true, + multishot: Bool = true + ) { + + let bitpattern = UInt64( + URingUserData( + registrationID: registrationID, + fileDescriptor: fileDescriptor, + eventType: CQEEventType.poll + ) + ) + let userbitPattern = UInt64( + URingUserData( + registrationID: registrationID, + fileDescriptor: fileDescriptor, + eventType: CQEEventType.pollModify + ) + ) + + _debugPrint( + "io_uring_poll_update fileDescriptor[\(fileDescriptor)] oldPollmask[\(oldPollmask)] newPollmask[\(newPollmask)] userBitpatternAsPointer[\(String(describing: userbitPattern))]" + ) self.withSQE { sqe in // "Documentation" for multishot polls and updates here: // https://git.kernel.dk/cgit/linux-block/commit/?h=poll-multiple&id=33021a19e324fb747c2038416753e63fd7cd9266 var flags = IORING_POLL_UPDATE_EVENTS | IORING_POLL_UPDATE_USER_DATA if multishot { - flags |= IORING_POLL_ADD_MULTI // ask for multiple updates + flags |= IORING_POLL_ADD_MULTI // ask for multiple updates } - CNIOLinux.io_uring_prep_poll_update(sqe, .init(userData: bitpattern), .init(userData: bitpattern), newPollmask, flags) + CNIOLinux.io_uring_prep_poll_update( + sqe, + .init(userData: bitpattern), + .init(userData: bitpattern), + newPollmask, + flags + ) CNIOLinux.io_uring_sqe_set_data(sqe, .init(userData: userbitPattern)) } - + if submitNow { self.io_uring_flush() } } - internal func _debugPrint(_ s: @autoclosure () -> String) - { + internal func _debugPrint(_ s: @autoclosure () -> String) { #if SWIFTNIO_IO_URING_DEBUG_URING print("L [\(NIOThread.current)] " + s()) #endif @@ -332,7 +395,7 @@ final internal class URing { // this minimizes amount of events propagating up and allows Selector to discard // events with an old sequence identifier. internal func _process_cqe(events: UnsafeMutablePointer, cqeIndex: Int, multishot: Bool) { - let bitPattern = UInt(bitPattern:io_uring_cqe_get_data(cqes[cqeIndex])) + let bitPattern = UInt(bitPattern: io_uring_cqe_get_data(cqes[cqeIndex])) let uringUserData = URingUserData(rawValue: UInt64(bitPattern)) let result = cqes[cqeIndex]!.pointee.res @@ -342,13 +405,14 @@ final internal class URing { case -ECANCELED: var pollError: UInt32 = 0 assert(uringUserData.fileDescriptor >= 0, "fd must be zero or greater") - if multishot { // -ECANCELED for streaming polls, should signal error + if multishot { // -ECANCELED for streaming polls, should signal error pollError = URing.POLLERR | URing.POLLHUP - } else { // this just signals that Selector just should resubmit a new fresh poll + } else { // this just signals that Selector just should resubmit a new fresh poll pollError = URing.POLLCANCEL } if let current = fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] { - fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] = current | pollError + fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] = + current | pollError } else { fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] = pollError } @@ -365,29 +429,30 @@ final internal class URing { case -EBADF: _debugPrint("Failed poll with -EBADF for cqeIndex[\(cqeIndex)]") break - case ..<0: // other errors + case ..<0: // other errors fatalError("Failed poll with unexpected error (\(result) for cqeIndex[\(cqeIndex)]") break - case 0: // successfull chained add for singleshots, not an event + case 0: // successfull chained add for singleshots, not an event break - default: // positive success + default: // positive success assert(uringUserData.fileDescriptor >= 0, "fd must be zero or greater") let uresult = UInt32(result) if let current = fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] { - fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] = current | uresult + fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] = current | uresult } else { fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] = uresult } } - case .pollModify: // we only get this for multishot modifications + case .pollModify: // we only get this for multishot modifications switch result { - case -ECANCELED: // -ECANCELED for streaming polls, should signal error + case -ECANCELED: // -ECANCELED for streaming polls, should signal error assert(uringUserData.fileDescriptor >= 0, "fd must be zero or greater") - let pollError = URing.POLLERR // URing.POLLERR // (URing.POLLHUP | URing.POLLERR) + let pollError = URing.POLLERR // URing.POLLERR // (URing.POLLHUP | URing.POLLERR) if let current = fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] { - fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] = current | pollError + fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] = + current | pollError } else { fdEvents[FDEventKey(uringUserData.fileDescriptor, uringUserData.registrationID)] = pollError } @@ -402,12 +467,12 @@ final internal class URing { case -EBADF: _debugPrint("Failed pollModify with -EBADF for cqeIndex[\(cqeIndex)]") break - case ..<0: // other errors + case ..<0: // other errors fatalError("Failed pollModify with unexpected error (\(result) for cqeIndex[\(cqeIndex)]") break - case 0: // successfull chained add, not an event + case 0: // successfull chained add, not an event break - default: // positive success + default: // positive success fatalError("pollModify returned > 0") } break @@ -416,7 +481,11 @@ final internal class URing { } } - internal func io_uring_peek_batch_cqe(events: UnsafeMutablePointer, maxevents: UInt32, multishot: Bool = true) -> Int { + internal func io_uring_peek_batch_cqe( + events: UnsafeMutablePointer, + maxevents: UInt32, + multishot: Bool = true + ) -> Int { var eventCount = 0 var currentCqeCount = CNIOLinux.io_uring_peek_batch_cqe(&ring, cqes, cqeMaxCount) @@ -432,19 +501,20 @@ final internal class URing { assert(currentCqeCount >= 0, "currentCqeCount should never be negative") assert(maxevents > 0, "maxevents should be a positive number") - for cqeIndex in 0 ..< currentCqeCount - { - self._process_cqe(events: events, cqeIndex: Int(cqeIndex), multishot:multishot) + for cqeIndex in 0.., error: Int32, multishot: Bool) throws -> Int { + internal func _io_uring_wait_cqe_shared( + events: UnsafeMutablePointer, + error: Int32, + multishot: Bool + ) throws -> Int { var eventCount = 0 switch error { @@ -489,7 +563,7 @@ final internal class URing { self._dumpCqes("_io_uring_wait_cqe_shared") - self._process_cqe(events: events, cqeIndex: 0, multishot:multishot) + self._process_cqe(events: events, cqeIndex: 0, multishot: multishot) CNIOLinux.io_uring_cqe_seen(&ring, cqes[0]) @@ -502,27 +576,36 @@ final internal class URing { _debugPrint("_io_uring_wait_cqe_shared if let firstEvent = fdEvents.first failed") } - fdEvents.removeAll(keepingCapacity: true) // reused for next batch + fdEvents.removeAll(keepingCapacity: true) // reused for next batch return eventCount } - internal func io_uring_wait_cqe(events: UnsafeMutablePointer, maxevents: UInt32, multishot: Bool = true) throws -> Int { + internal func io_uring_wait_cqe( + events: UnsafeMutablePointer, + maxevents: UInt32, + multishot: Bool = true + ) throws -> Int { _debugPrint("io_uring_wait_cqe") let error = CNIOLinux.io_uring_wait_cqe(&ring, cqes) - return try self._io_uring_wait_cqe_shared(events: events, error: error, multishot:multishot) + return try self._io_uring_wait_cqe_shared(events: events, error: error, multishot: multishot) } - internal func io_uring_wait_cqe_timeout(events: UnsafeMutablePointer, maxevents: UInt32, timeout: TimeAmount, multishot: Bool = true) throws -> Int { + internal func io_uring_wait_cqe_timeout( + events: UnsafeMutablePointer, + maxevents: UInt32, + timeout: TimeAmount, + multishot: Bool = true + ) throws -> Int { var ts = timeout.kernelTimespec() _debugPrint("io_uring_wait_cqe_timeout.ETIME milliseconds \(ts)") let error = CNIOLinux.io_uring_wait_cqe_timeout(&ring, cqes, &ts) - return try self._io_uring_wait_cqe_shared(events: events, error: error, multishot:multishot) + return try self._io_uring_wait_cqe_shared(events: events, error: error, multishot: multishot) } } diff --git a/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift b/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift index 81c0811896..6e245b3862 100644 --- a/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift +++ b/Sources/NIOPosix/MultiThreadedEventLoopGroup.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -import NIOCore -import NIOConcurrencyHelpers -import Dispatch import Atomics +import Dispatch +import NIOConcurrencyHelpers +import NIOCore struct NIORegistration: Registration { enum ChannelType { @@ -71,22 +71,26 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { private var runState: RunState = .running private let canBeShutDown: Bool - private static func runTheLoop(thread: NIOThread, - parentGroup: MultiThreadedEventLoopGroup? /* nil iff thread take-over */, - canEventLoopBeShutdownIndividually: Bool, - selectorFactory: @escaping () throws -> NIOPosix.Selector, - initializer: @escaping ThreadInitializer, - metricsDelegate: NIOEventLoopMetricsDelegate?, - _ callback: @escaping (SelectableEventLoop) -> Void) { + private static func runTheLoop( + thread: NIOThread, + parentGroup: MultiThreadedEventLoopGroup?, // nil iff thread take-over + canEventLoopBeShutdownIndividually: Bool, + selectorFactory: @escaping () throws -> NIOPosix.Selector, + initializer: @escaping ThreadInitializer, + metricsDelegate: NIOEventLoopMetricsDelegate?, + _ callback: @escaping (SelectableEventLoop) -> Void + ) { assert(NIOThread.current == thread) initializer(thread) do { - let loop = SelectableEventLoop(thread: thread, - parentGroup: parentGroup, - selector: try selectorFactory(), - canBeShutdownIndividually: canEventLoopBeShutdownIndividually, - metricsDelegate: metricsDelegate) + let loop = SelectableEventLoop( + thread: thread, + parentGroup: parentGroup, + selector: try selectorFactory(), + canBeShutdownIndividually: canEventLoopBeShutdownIndividually, + metricsDelegate: metricsDelegate + ) threadSpecificEventLoop.currentValue = loop defer { threadSpecificEventLoop.currentValue = nil @@ -100,23 +104,27 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { } } - private static func setupThreadAndEventLoop(name: String, - parentGroup: MultiThreadedEventLoopGroup, - selectorFactory: @escaping () throws -> NIOPosix.Selector, - initializer: @escaping ThreadInitializer, - metricsDelegate: NIOEventLoopMetricsDelegate?) -> SelectableEventLoop { + private static func setupThreadAndEventLoop( + name: String, + parentGroup: MultiThreadedEventLoopGroup, + selectorFactory: @escaping () throws -> NIOPosix.Selector, + initializer: @escaping ThreadInitializer, + metricsDelegate: NIOEventLoopMetricsDelegate? + ) -> SelectableEventLoop { let lock = ConditionLock(value: 0) - /* synchronised by `lock` */ + // synchronised by `lock` var _loop: SelectableEventLoop! = nil NIOThread.spawnAndRun(name: name, detachThread: false) { t in - MultiThreadedEventLoopGroup.runTheLoop(thread: t, - parentGroup: parentGroup, - canEventLoopBeShutdownIndividually: false, // part of MTELG - selectorFactory: selectorFactory, - initializer: initializer, - metricsDelegate: metricsDelegate) { l in + MultiThreadedEventLoopGroup.runTheLoop( + thread: t, + parentGroup: parentGroup, + canEventLoopBeShutdownIndividually: false, // part of MTELG + selectorFactory: selectorFactory, + initializer: initializer, + metricsDelegate: metricsDelegate + ) { l in lock.lock(whenValue: 0) _loop = l lock.unlock(withValue: 1) @@ -137,10 +145,12 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { /// - arguments: /// - numberOfThreads: The number of `Threads` to use. public convenience init(numberOfThreads: Int) { - self.init(numberOfThreads: numberOfThreads, - canBeShutDown: true, - metricsDelegate: nil, - selectorFactory: NIOPosix.Selector.init) + self.init( + numberOfThreads: numberOfThreads, + canBeShutDown: true, + metricsDelegate: nil, + selectorFactory: NIOPosix.Selector.init + ) } /// Creates a `MultiThreadedEventLoopGroup` instance which uses `numberOfThreads`. @@ -154,86 +164,120 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { /// - numberOfThreads: The number of `Threads` to use. /// - metricsDelegate: Delegate for collecting information from this eventloop public convenience init(numberOfThreads: Int, metricsDelegate: NIOEventLoopMetricsDelegate) { - self.init(numberOfThreads: numberOfThreads, - canBeShutDown: true, - metricsDelegate: metricsDelegate, - selectorFactory: NIOPosix.Selector.init) + self.init( + numberOfThreads: numberOfThreads, + canBeShutDown: true, + metricsDelegate: metricsDelegate, + selectorFactory: NIOPosix.Selector.init + ) } /// Create a ``MultiThreadedEventLoopGroup`` that cannot be shut down and must not be `deinit`ed. /// /// This is only useful for global singletons. - public static func _makePerpetualGroup(threadNamePrefix: String, - numberOfThreads: Int) -> MultiThreadedEventLoopGroup { - return self.init(numberOfThreads: numberOfThreads, - canBeShutDown: false, - threadNamePrefix: threadNamePrefix, - metricsDelegate: nil, - selectorFactory: NIOPosix.Selector.init) + public static func _makePerpetualGroup( + threadNamePrefix: String, + numberOfThreads: Int + ) -> MultiThreadedEventLoopGroup { + self.init( + numberOfThreads: numberOfThreads, + canBeShutDown: false, + threadNamePrefix: threadNamePrefix, + metricsDelegate: nil, + selectorFactory: NIOPosix.Selector.init + ) } - internal convenience init(numberOfThreads: Int, - metricsDelegate: NIOEventLoopMetricsDelegate?, - selectorFactory: @escaping () throws -> NIOPosix.Selector) { + internal convenience init( + numberOfThreads: Int, + metricsDelegate: NIOEventLoopMetricsDelegate?, + selectorFactory: @escaping () throws -> NIOPosix.Selector + ) { precondition(numberOfThreads > 0, "numberOfThreads must be positive") let initializers: [ThreadInitializer] = Array(repeating: { _ in }, count: numberOfThreads) - self.init(threadInitializers: initializers, canBeShutDown: true, metricsDelegate: metricsDelegate, selectorFactory: selectorFactory) + self.init( + threadInitializers: initializers, + canBeShutDown: true, + metricsDelegate: metricsDelegate, + selectorFactory: selectorFactory + ) } - internal convenience init(numberOfThreads: Int, - canBeShutDown: Bool, - threadNamePrefix: String, - metricsDelegate: NIOEventLoopMetricsDelegate?, - selectorFactory: @escaping () throws -> NIOPosix.Selector) { + internal convenience init( + numberOfThreads: Int, + canBeShutDown: Bool, + threadNamePrefix: String, + metricsDelegate: NIOEventLoopMetricsDelegate?, + selectorFactory: @escaping () throws -> NIOPosix.Selector + ) { precondition(numberOfThreads > 0, "numberOfThreads must be positive") let initializers: [ThreadInitializer] = Array(repeating: { _ in }, count: numberOfThreads) - self.init(threadInitializers: initializers, - canBeShutDown: canBeShutDown, - threadNamePrefix: threadNamePrefix, - metricsDelegate: metricsDelegate, - selectorFactory: selectorFactory) + self.init( + threadInitializers: initializers, + canBeShutDown: canBeShutDown, + threadNamePrefix: threadNamePrefix, + metricsDelegate: metricsDelegate, + selectorFactory: selectorFactory + ) } - internal convenience init(numberOfThreads: Int, - canBeShutDown: Bool, - metricsDelegate: NIOEventLoopMetricsDelegate?, - selectorFactory: @escaping () throws -> NIOPosix.Selector) { + internal convenience init( + numberOfThreads: Int, + canBeShutDown: Bool, + metricsDelegate: NIOEventLoopMetricsDelegate?, + selectorFactory: @escaping () throws -> NIOPosix.Selector + ) { precondition(numberOfThreads > 0, "numberOfThreads must be positive") let initializers: [ThreadInitializer] = Array(repeating: { _ in }, count: numberOfThreads) - self.init(threadInitializers: initializers, - canBeShutDown: canBeShutDown, - metricsDelegate: metricsDelegate, - selectorFactory: selectorFactory) + self.init( + threadInitializers: initializers, + canBeShutDown: canBeShutDown, + metricsDelegate: metricsDelegate, + selectorFactory: selectorFactory + ) } - internal convenience init(threadInitializers: [ThreadInitializer], - metricsDelegate: NIOEventLoopMetricsDelegate?, - selectorFactory: @escaping () throws -> NIOPosix.Selector = NIOPosix.Selector.init) { - self.init(threadInitializers: threadInitializers, canBeShutDown: true, metricsDelegate: metricsDelegate, selectorFactory: selectorFactory) + internal convenience init( + threadInitializers: [ThreadInitializer], + metricsDelegate: NIOEventLoopMetricsDelegate?, + selectorFactory: @escaping () throws -> NIOPosix.Selector = NIOPosix.Selector + .init + ) { + self.init( + threadInitializers: threadInitializers, + canBeShutDown: true, + metricsDelegate: metricsDelegate, + selectorFactory: selectorFactory + ) } /// Creates a `MultiThreadedEventLoopGroup` instance which uses the given `ThreadInitializer`s. One `NIOThread` per `ThreadInitializer` is created and used. /// /// - arguments: /// - threadInitializers: The `ThreadInitializer`s to use. - internal init(threadInitializers: [ThreadInitializer], - canBeShutDown: Bool, - threadNamePrefix: String = "NIO-ELT-", - metricsDelegate: NIOEventLoopMetricsDelegate?, - selectorFactory: @escaping () throws -> NIOPosix.Selector = NIOPosix.Selector.init) { + internal init( + threadInitializers: [ThreadInitializer], + canBeShutDown: Bool, + threadNamePrefix: String = "NIO-ELT-", + metricsDelegate: NIOEventLoopMetricsDelegate?, + selectorFactory: @escaping () throws -> NIOPosix.Selector = NIOPosix.Selector + .init + ) { self.threadNamePrefix = threadNamePrefix let myGroupID = nextEventLoopGroupID.loadThenWrappingIncrement(ordering: .relaxed) self.myGroupID = myGroupID var idx = 0 self.canBeShutDown = canBeShutDown - self.eventLoops = [] // Just so we're fully initialised and can vend `self` to the `SelectableEventLoop`. + self.eventLoops = [] // Just so we're fully initialised and can vend `self` to the `SelectableEventLoop`. self.eventLoops = threadInitializers.map { initializer in // Maximum name length on linux is 16 by default. - let ev = MultiThreadedEventLoopGroup.setupThreadAndEventLoop(name: "\(threadNamePrefix)\(myGroupID)-#\(idx)", - parentGroup: self, - selectorFactory: selectorFactory, - initializer: initializer, - metricsDelegate: metricsDelegate) + let ev = MultiThreadedEventLoopGroup.setupThreadAndEventLoop( + name: "\(threadNamePrefix)\(myGroupID)-#\(idx)", + parentGroup: self, + selectorFactory: selectorFactory, + initializer: initializer, + metricsDelegate: metricsDelegate + ) idx += 1 return ev } @@ -247,18 +291,18 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { /// /// - returns: The current `EventLoop` for the calling thread or `nil` if none is assigned to the thread. public static var currentEventLoop: EventLoop? { - return self.currentSelectableEventLoop + self.currentSelectableEventLoop } internal static var currentSelectableEventLoop: SelectableEventLoop? { - return threadSpecificEventLoop.currentValue + threadSpecificEventLoop.currentValue } /// Returns an `EventLoopIterator` over the `EventLoop`s in this `MultiThreadedEventLoopGroup`. /// /// - returns: `EventLoopIterator` public func makeIterator() -> EventLoopIterator { - return EventLoopIterator(self.eventLoops) + EventLoopIterator(self.eventLoops) } /// Returns the next `EventLoop` from this `MultiThreadedEventLoopGroup`. @@ -267,7 +311,7 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { /// /// - returns: The next `EventLoop` to use. public func next() -> EventLoop { - return eventLoops[abs(index.loadThenWrappingIncrement(ordering: .relaxed) % eventLoops.count)] + eventLoops[abs(index.loadThenWrappingIncrement(ordering: .relaxed) % eventLoops.count)] } /// Returns the current `EventLoop` if we are on an `EventLoop` of this `MultiThreadedEventLoopGroup` instance. @@ -275,8 +319,9 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { /// - returns: The `EventLoop`. public func any() -> EventLoop { if let loop = Self.currentSelectableEventLoop, - // We are on `loop`'s thread, so we may ask for the its parent group. - loop.parentGroupCallableFromThisEventLoopOnly() === self { + // We are on `loop`'s thread, so we may ask for the its parent group. + loop.parentGroupCallableFromThisEventLoopOnly() === self + { // Nice, we can return this. loop.assertInEventLoop() return loop @@ -362,23 +407,26 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { for loop in self.eventLoops { loop.syncFinaliseClose(joinThread: true) } - let (overallError, queueCallbackPairs): (Error?, [(DispatchQueue, ShutdownGracefullyCallback)]) = self.shutdownLock.withLock { - switch self.runState { - case .closed, .running: - preconditionFailure("MultiThreadedEventLoopGroup in illegal state when closing: \(self.runState)") - case .closing(let callbacks): - let overallError: Error? = { - switch result { - case .success: - return nil - case .failure(let error): - return error - } - }() - self.runState = .closed(overallError) - return (overallError, callbacks) + let (overallError, queueCallbackPairs): (Error?, [(DispatchQueue, ShutdownGracefullyCallback)]) = self + .shutdownLock.withLock { + switch self.runState { + case .closed, .running: + preconditionFailure( + "MultiThreadedEventLoopGroup in illegal state when closing: \(self.runState)" + ) + case .closing(let callbacks): + let overallError: Error? = { + switch result { + case .success: + return nil + case .failure(let error): + return error + } + }() + self.runState = .closed(overallError) + return (overallError, callbacks) + } } - } queue.async { handler(overallError) @@ -401,24 +449,31 @@ public final class MultiThreadedEventLoopGroup: EventLoopGroup { /// `EventLoop` reference. Just like usually on the `EventLoop`, do not block in `callback`. public static func withCurrentThreadAsEventLoop(_ callback: @escaping (EventLoop) -> Void) { let callingThread = NIOThread.current - MultiThreadedEventLoopGroup.runTheLoop(thread: callingThread, - parentGroup: nil, - canEventLoopBeShutdownIndividually: true, - selectorFactory: NIOPosix.Selector.init, - initializer: { _ in }, - metricsDelegate: nil) { loop in - loop.assertInEventLoop() - callback(loop) - } + MultiThreadedEventLoopGroup.runTheLoop( + thread: callingThread, + parentGroup: nil, + canEventLoopBeShutdownIndividually: true, + selectorFactory: NIOPosix.Selector.init, + initializer: { _ in }, + metricsDelegate: nil, + { loop in + loop.assertInEventLoop() + callback(loop) + } + ) } public func _preconditionSafeToSyncShutdown(file: StaticString, line: UInt) { if let eventLoop = MultiThreadedEventLoopGroup.currentEventLoop { - preconditionFailure(""" - BUG DETECTED: syncShutdownGracefully() must not be called when on an EventLoop. - Calling syncShutdownGracefully() on any EventLoop can lead to deadlocks. - Current eventLoop: \(eventLoop) - """, file: file, line: line) + preconditionFailure( + """ + BUG DETECTED: syncShutdownGracefully() must not be called when on an EventLoop. + Calling syncShutdownGracefully() on any EventLoop can lead to deadlocks. + Current eventLoop: \(eventLoop) + """, + file: file, + line: line + ) } } } @@ -427,7 +482,7 @@ extension MultiThreadedEventLoopGroup: @unchecked Sendable {} extension MultiThreadedEventLoopGroup: CustomStringConvertible { public var description: String { - return "MultiThreadedEventLoopGroup { threadPattern = \(self.threadNamePrefix)\(self.myGroupID)-#* }" + "MultiThreadedEventLoopGroup { threadPattern = \(self.threadNamePrefix)\(self.myGroupID)-#* }" } } @@ -481,7 +536,7 @@ internal struct ScheduledTask { extension ScheduledTask: CustomStringConvertible { @usableFromInline var description: String { - return "ScheduledTask(readyTime: \(self.readyTime))" + "ScheduledTask(readyTime: \(self.readyTime))" } } @@ -497,7 +552,7 @@ extension ScheduledTask: Comparable { @usableFromInline static func == (lhs: ScheduledTask, rhs: ScheduledTask) -> Bool { - return lhs.id == rhs.id + lhs.id == rhs.id } } diff --git a/Sources/NIOPosix/NIOThreadPool.swift b/Sources/NIOPosix/NIOThreadPool.swift index b2db1239af..5f4ed1e04f 100644 --- a/Sources/NIOPosix/NIOThreadPool.swift +++ b/Sources/NIOPosix/NIOThreadPool.swift @@ -126,10 +126,12 @@ public final class NIOThreadPool { case .running(let items): self.state = .modifying queue.async { - items.forEach { $0.workItem(.cancelled) } + for item in items { + item.workItem(.cancelled) + } } self.state = .shuttingDown(Array(repeating: true, count: numberOfThreads)) - (0..) in self.shutdownGracefully { error in if let error = error { cont.resume(throwing: error) diff --git a/Sources/NIOPosix/NonBlockingFileIO.swift b/Sources/NIOPosix/NonBlockingFileIO.swift index 63830a8e0c..2e1c3d565b 100644 --- a/Sources/NIOPosix/NonBlockingFileIO.swift +++ b/Sources/NIOPosix/NonBlockingFileIO.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -import NIOCore import NIOConcurrencyHelpers +import NIOCore /// ``NonBlockingFileIO`` is a helper that allows you to read files without blocking the calling thread. /// @@ -33,7 +33,7 @@ public struct NonBlockingFileIO: Sendable { public static let defaultThreadPoolSize = 2 /// The default and recommended chunk size. - public static let defaultChunkSize = 128*1024 + public static let defaultChunkSize = 128 * 1024 /// ``NonBlockingFileIO`` errors. public enum Error: Swift.Error { @@ -74,19 +74,23 @@ public struct NonBlockingFileIO: Sendable { /// - chunkHandler: Called for every chunk read. The next chunk will be read upon successful completion of the returned `EventLoopFuture`. If the returned `EventLoopFuture` fails, the overall operation is aborted. /// - returns: An `EventLoopFuture` which is the result of the overall operation. If either the reading of `fileHandle` or `chunkHandler` fails, the `EventLoopFuture` will fail too. If the reading of `fileHandle` as well as `chunkHandler` always succeeded, the `EventLoopFuture` will succeed too. @preconcurrency - public func readChunked(fileRegion: FileRegion, - chunkSize: Int = NonBlockingFileIO.defaultChunkSize, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, - chunkHandler: @escaping @Sendable (ByteBuffer) -> EventLoopFuture) -> EventLoopFuture { + public func readChunked( + fileRegion: FileRegion, + chunkSize: Int = NonBlockingFileIO.defaultChunkSize, + allocator: ByteBufferAllocator, + eventLoop: EventLoop, + chunkHandler: @escaping @Sendable (ByteBuffer) -> EventLoopFuture + ) -> EventLoopFuture { let readableBytes = fileRegion.readableBytes - return self.readChunked(fileHandle: fileRegion.fileHandle, - fromOffset: Int64(fileRegion.readerIndex), - byteCount: readableBytes, - chunkSize: chunkSize, - allocator: allocator, - eventLoop: eventLoop, - chunkHandler: chunkHandler) + return self.readChunked( + fileHandle: fileRegion.fileHandle, + fromOffset: Int64(fileRegion.readerIndex), + byteCount: readableBytes, + chunkSize: chunkSize, + allocator: allocator, + eventLoop: eventLoop, + chunkHandler: chunkHandler + ) } /// Read `byteCount` bytes in chunks of `chunkSize` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread @@ -114,18 +118,23 @@ public struct NonBlockingFileIO: Sendable { /// - chunkHandler: Called for every chunk read. The next chunk will be read upon successful completion of the returned `EventLoopFuture`. If the returned `EventLoopFuture` fails, the overall operation is aborted. /// - returns: An `EventLoopFuture` which is the result of the overall operation. If either the reading of `fileHandle` or `chunkHandler` fails, the `EventLoopFuture` will fail too. If the reading of `fileHandle` as well as `chunkHandler` always succeeded, the `EventLoopFuture` will succeed too. @preconcurrency - public func readChunked(fileHandle: NIOFileHandle, - byteCount: Int, - chunkSize: Int = NonBlockingFileIO.defaultChunkSize, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, chunkHandler: @escaping @Sendable (ByteBuffer) -> EventLoopFuture) -> EventLoopFuture { - return self.readChunked0(fileHandle: fileHandle, - fromOffset: nil, - byteCount: byteCount, - chunkSize: chunkSize, - allocator: allocator, - eventLoop: eventLoop, - chunkHandler: chunkHandler) + public func readChunked( + fileHandle: NIOFileHandle, + byteCount: Int, + chunkSize: Int = NonBlockingFileIO.defaultChunkSize, + allocator: ByteBufferAllocator, + eventLoop: EventLoop, + chunkHandler: @escaping @Sendable (ByteBuffer) -> EventLoopFuture + ) -> EventLoopFuture { + self.readChunked0( + fileHandle: fileHandle, + fromOffset: nil, + byteCount: byteCount, + chunkSize: chunkSize, + allocator: allocator, + eventLoop: eventLoop, + chunkHandler: chunkHandler + ) } /// Read `byteCount` bytes from offset `fileOffset` in chunks of `chunkSize` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread @@ -153,30 +162,37 @@ public struct NonBlockingFileIO: Sendable { /// - chunkHandler: Called for every chunk read. The next chunk will be read upon successful completion of the returned `EventLoopFuture`. If the returned `EventLoopFuture` fails, the overall operation is aborted. /// - returns: An `EventLoopFuture` which is the result of the overall operation. If either the reading of `fileHandle` or `chunkHandler` fails, the `EventLoopFuture` will fail too. If the reading of `fileHandle` as well as `chunkHandler` always succeeded, the `EventLoopFuture` will succeed too. @preconcurrency - public func readChunked(fileHandle: NIOFileHandle, - fromOffset fileOffset: Int64, - byteCount: Int, - chunkSize: Int = NonBlockingFileIO.defaultChunkSize, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, - chunkHandler: @escaping @Sendable (ByteBuffer) -> EventLoopFuture) -> EventLoopFuture { - return self.readChunked0(fileHandle: fileHandle, - fromOffset: fileOffset, - byteCount: byteCount, - chunkSize: chunkSize, - allocator: allocator, - eventLoop: eventLoop, - chunkHandler: chunkHandler) + public func readChunked( + fileHandle: NIOFileHandle, + fromOffset fileOffset: Int64, + byteCount: Int, + chunkSize: Int = NonBlockingFileIO.defaultChunkSize, + allocator: ByteBufferAllocator, + eventLoop: EventLoop, + chunkHandler: @escaping @Sendable (ByteBuffer) -> EventLoopFuture + ) -> EventLoopFuture { + self.readChunked0( + fileHandle: fileHandle, + fromOffset: fileOffset, + byteCount: byteCount, + chunkSize: chunkSize, + allocator: allocator, + eventLoop: eventLoop, + chunkHandler: chunkHandler + ) } private typealias ReadChunkHandler = @Sendable (ByteBuffer) -> EventLoopFuture - private func readChunked0(fileHandle: NIOFileHandle, - fromOffset: Int64?, - byteCount: Int, - chunkSize: Int, - allocator: ByteBufferAllocator, - eventLoop: EventLoop, chunkHandler: @escaping ReadChunkHandler) -> EventLoopFuture { + private func readChunked0( + fileHandle: NIOFileHandle, + fromOffset: Int64?, + byteCount: Int, + chunkSize: Int, + allocator: ByteBufferAllocator, + eventLoop: EventLoop, + chunkHandler: @escaping ReadChunkHandler + ) -> EventLoopFuture { precondition(chunkSize > 0, "chunkSize must be > 0 (is \(chunkSize))") let remainingReads = 1 + (byteCount / chunkSize) let lastReadSize = byteCount % chunkSize @@ -187,11 +203,13 @@ public struct NonBlockingFileIO: Sendable { if remainingReads > 1 || (remainingReads == 1 && lastReadSize > 0) { let readSize = remainingReads > 1 ? chunkSize : lastReadSize assert(readSize > 0) - let readFuture = self.read0(fileHandle: fileHandle, - fromOffset: fromOffset.map { $0 + bytesReadSoFar }, - byteCount: readSize, - allocator: allocator, - eventLoop: eventLoop) + let readFuture = self.read0( + fileHandle: fileHandle, + fromOffset: fromOffset.map { $0 + bytesReadSoFar }, + byteCount: readSize, + allocator: allocator, + eventLoop: eventLoop + ) readFuture.whenComplete { (result) in switch result { case .success(let buffer): @@ -206,16 +224,18 @@ public struct NonBlockingFileIO: Sendable { switch result { case .success(_): eventLoop.assertInEventLoop() - _read(remainingReads: remainingReads - 1, - bytesReadSoFar: bytesReadSoFar + bytesRead) + _read( + remainingReads: remainingReads - 1, + bytesReadSoFar: bytesReadSoFar + bytesRead + ) case .failure(let error): promise.fail(error) } } - case .failure(let error): - promise.fail(error) - } - } + case .failure(let error): + promise.fail(error) + } + } } else { promise.succeed(()) } @@ -241,13 +261,19 @@ public struct NonBlockingFileIO: Sendable { /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. /// - eventLoop: The `EventLoop` to create the returned `EventLoopFuture` from. /// - returns: An `EventLoopFuture` which delivers a `ByteBuffer` if the read was successful or a failure on error. - public func read(fileRegion: FileRegion, allocator: ByteBufferAllocator, eventLoop: EventLoop) -> EventLoopFuture { + public func read( + fileRegion: FileRegion, + allocator: ByteBufferAllocator, + eventLoop: EventLoop + ) -> EventLoopFuture { let readableBytes = fileRegion.readableBytes - return self.read(fileHandle: fileRegion.fileHandle, - fromOffset: Int64(fileRegion.readerIndex), - byteCount: readableBytes, - allocator: allocator, - eventLoop: eventLoop) + return self.read( + fileHandle: fileRegion.fileHandle, + fromOffset: Int64(fileRegion.readerIndex), + byteCount: readableBytes, + allocator: allocator, + eventLoop: eventLoop + ) } /// Read `byteCount` bytes from `fileHandle` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. @@ -268,15 +294,19 @@ public struct NonBlockingFileIO: Sendable { /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. /// - eventLoop: The `EventLoop` to create the returned `EventLoopFuture` from. /// - returns: An `EventLoopFuture` which delivers a `ByteBuffer` if the read was successful or a failure on error. - public func read(fileHandle: NIOFileHandle, - byteCount: Int, - allocator: ByteBufferAllocator, - eventLoop: EventLoop) -> EventLoopFuture { - return self.read0(fileHandle: fileHandle, - fromOffset: nil, - byteCount: byteCount, - allocator: allocator, - eventLoop: eventLoop) + public func read( + fileHandle: NIOFileHandle, + byteCount: Int, + allocator: ByteBufferAllocator, + eventLoop: EventLoop + ) -> EventLoopFuture { + self.read0( + fileHandle: fileHandle, + fromOffset: nil, + byteCount: byteCount, + allocator: allocator, + eventLoop: eventLoop + ) } /// Read `byteCount` bytes starting at `fileOffset` from `fileHandle` in ``NonBlockingFileIO``'s private thread pool @@ -298,36 +328,47 @@ public struct NonBlockingFileIO: Sendable { /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. /// - eventLoop: The `EventLoop` to create the returned `EventLoopFuture` from. /// - returns: An `EventLoopFuture` which delivers a `ByteBuffer` if the read was successful or a failure on error. - public func read(fileHandle: NIOFileHandle, - fromOffset fileOffset: Int64, - byteCount: Int, - allocator: ByteBufferAllocator, - eventLoop: EventLoop) -> EventLoopFuture { - return self.read0(fileHandle: fileHandle, - fromOffset: fileOffset, - byteCount: byteCount, - allocator: allocator, - eventLoop: eventLoop) + public func read( + fileHandle: NIOFileHandle, + fromOffset fileOffset: Int64, + byteCount: Int, + allocator: ByteBufferAllocator, + eventLoop: EventLoop + ) -> EventLoopFuture { + self.read0( + fileHandle: fileHandle, + fromOffset: fileOffset, + byteCount: byteCount, + allocator: allocator, + eventLoop: eventLoop + ) } - private func read0(fileHandle: NIOFileHandle, - fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems - byteCount rawByteCount: Int, - allocator: ByteBufferAllocator, - eventLoop: EventLoop) -> EventLoopFuture { + private func read0( + fileHandle: NIOFileHandle, + fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems + byteCount rawByteCount: Int, + allocator: ByteBufferAllocator, + eventLoop: EventLoop + ) -> EventLoopFuture { guard rawByteCount > 0 else { return eventLoop.makeSucceededFuture(allocator.buffer(capacity: 0)) } let byteCount = rawByteCount < Int32.max ? rawByteCount : size_t(Int32.max) return self.threadPool.runIfActive(eventLoop: eventLoop) { () -> ByteBuffer in - try self.readSync(fileHandle: fileHandle, fromOffset: fromOffset, byteCount: byteCount, allocator: allocator) + try self.readSync( + fileHandle: fileHandle, + fromOffset: fromOffset, + byteCount: byteCount, + allocator: allocator + ) } } private func readSync( fileHandle: NIOFileHandle, - fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems + fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems byteCount: Int, allocator: ByteBufferAllocator ) throws -> ByteBuffer { @@ -337,14 +378,18 @@ public struct NonBlockingFileIO: Sendable { let n = try buf.writeWithUnsafeMutableBytes(minimumWritableBytes: byteCount - bytesRead) { ptr in let res = try fileHandle.withUnsafeFileDescriptor { descriptor -> IOResult in if let offset = fromOffset { - return try Posix.pread(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - bytesRead, - offset: off_t(offset) + off_t(bytesRead)) + return try Posix.pread( + descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - bytesRead, + offset: off_t(offset) + off_t(bytesRead) + ) } else { - return try Posix.read(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - bytesRead) + return try Posix.read( + descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - bytesRead + ) } } switch res { @@ -375,10 +420,12 @@ public struct NonBlockingFileIO: Sendable { /// - size: The new file size in bytes to write. /// - eventLoop: The `EventLoop` to create the returned `EventLoopFuture` from. /// - returns: An `EventLoopFuture` which is fulfilled if the write was successful or fails on error. - public func changeFileSize(fileHandle: NIOFileHandle, - size: Int64, - eventLoop: EventLoop) -> EventLoopFuture<()> { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + public func changeFileSize( + fileHandle: NIOFileHandle, + size: Int64, + eventLoop: EventLoop + ) -> EventLoopFuture<()> { + self.threadPool.runIfActive(eventLoop: eventLoop) { try fileHandle.withUnsafeFileDescriptor { descriptor -> Void in try Posix.ftruncate(descriptor: descriptor, size: off_t(size)) } @@ -391,10 +438,12 @@ public struct NonBlockingFileIO: Sendable { /// - fileHandle: The `NIOFileHandle` to read from. /// - eventLoop: The `EventLoop` to create the returned `EventLoopFuture` from. /// - returns: An `EventLoopFuture` which is fulfilled with the length of the file in bytes if the write was successful or fails on error. - public func readFileSize(fileHandle: NIOFileHandle, - eventLoop: EventLoop) -> EventLoopFuture { - return self.threadPool.runIfActive(eventLoop: eventLoop) { - return try fileHandle.withUnsafeFileDescriptor { descriptor in + public func readFileSize( + fileHandle: NIOFileHandle, + eventLoop: EventLoop + ) -> EventLoopFuture { + self.threadPool.runIfActive(eventLoop: eventLoop) { + try fileHandle.withUnsafeFileDescriptor { descriptor in let curr = try Posix.lseek(descriptor: descriptor, offset: 0, whence: SEEK_CUR) let eof = try Posix.lseek(descriptor: descriptor, offset: 0, whence: SEEK_END) try Posix.lseek(descriptor: descriptor, offset: curr, whence: SEEK_SET) @@ -410,10 +459,12 @@ public struct NonBlockingFileIO: Sendable { /// - buffer: The `ByteBuffer` to write. /// - eventLoop: The `EventLoop` to create the returned `EventLoopFuture` from. /// - returns: An `EventLoopFuture` which is fulfilled if the write was successful or fails on error. - public func write(fileHandle: NIOFileHandle, - buffer: ByteBuffer, - eventLoop: EventLoop) -> EventLoopFuture<()> { - return self.write0(fileHandle: fileHandle, toOffset: nil, buffer: buffer, eventLoop: eventLoop) + public func write( + fileHandle: NIOFileHandle, + buffer: ByteBuffer, + eventLoop: EventLoop + ) -> EventLoopFuture<()> { + self.write0(fileHandle: fileHandle, toOffset: nil, buffer: buffer, eventLoop: eventLoop) } /// Write `buffer` starting from `toOffset` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool which is separate from any `EventLoop` thread. @@ -424,17 +475,21 @@ public struct NonBlockingFileIO: Sendable { /// - buffer: The `ByteBuffer` to write. /// - eventLoop: The `EventLoop` to create the returned `EventLoopFuture` from. /// - returns: An `EventLoopFuture` which is fulfilled if the write was successful or fails on error. - public func write(fileHandle: NIOFileHandle, - toOffset: Int64, - buffer: ByteBuffer, - eventLoop: EventLoop) -> EventLoopFuture<()> { - return self.write0(fileHandle: fileHandle, toOffset: toOffset, buffer: buffer, eventLoop: eventLoop) + public func write( + fileHandle: NIOFileHandle, + toOffset: Int64, + buffer: ByteBuffer, + eventLoop: EventLoop + ) -> EventLoopFuture<()> { + self.write0(fileHandle: fileHandle, toOffset: toOffset, buffer: buffer, eventLoop: eventLoop) } - private func write0(fileHandle: NIOFileHandle, - toOffset: Int64?, - buffer: ByteBuffer, - eventLoop: EventLoop) -> EventLoopFuture<()> { + private func write0( + fileHandle: NIOFileHandle, + toOffset: Int64?, + buffer: ByteBuffer, + eventLoop: EventLoop + ) -> EventLoopFuture<()> { let byteCount = buffer.readableBytes guard byteCount > 0 else { @@ -460,14 +515,18 @@ public struct NonBlockingFileIO: Sendable { precondition(ptr.count == byteCount - offsetAccumulator) let res: IOResult = try fileHandle.withUnsafeFileDescriptor { descriptor in if let toOffset = toOffset { - return try Posix.pwrite(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - offsetAccumulator, - offset: off_t(toOffset + Int64(offsetAccumulator))) + return try Posix.pwrite( + descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - offsetAccumulator, + offset: off_t(toOffset + Int64(offsetAccumulator)) + ) } else { - return try Posix.write(descriptor: descriptor, - pointer: ptr.baseAddress!, - size: byteCount - offsetAccumulator) + return try Posix.write( + descriptor: descriptor, + pointer: ptr.baseAddress!, + size: byteCount - offsetAccumulator + ) } } switch res { @@ -494,7 +553,7 @@ public struct NonBlockingFileIO: Sendable { /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` containing the `NIOFileHandle` and the `FileRegion` comprising the whole file. public func openFile(path: String, eventLoop: EventLoop) -> EventLoopFuture<(NIOFileHandle, FileRegion)> { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + self.threadPool.runIfActive(eventLoop: eventLoop) { let fh = try NIOFileHandle(path: path) do { let fr = try FileRegion(fileHandle: fh) @@ -517,13 +576,18 @@ public struct NonBlockingFileIO: Sendable { /// - flags: Additional POSIX flags. /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` containing the `NIOFileHandle`. - public func openFile(path: String, mode: NIOFileHandle.Mode, flags: NIOFileHandle.Flags = .default, eventLoop: EventLoop) -> EventLoopFuture { - return self.threadPool.runIfActive(eventLoop: eventLoop) { - return try NIOFileHandle(path: path, mode: mode, flags: flags) + public func openFile( + path: String, + mode: NIOFileHandle.Mode, + flags: NIOFileHandle.Flags = .default, + eventLoop: EventLoop + ) -> EventLoopFuture { + self.threadPool.runIfActive(eventLoop: eventLoop) { + try NIOFileHandle(path: path, mode: mode, flags: flags) } } -#if !os(Windows) + #if !os(Windows) /// Returns information about a file at `path` on a private thread pool which is separate from any `EventLoop` thread. /// /// - note: If `path` is a symlink, information about the link, not the file it points to. @@ -533,7 +597,7 @@ public struct NonBlockingFileIO: Sendable { /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` containing file information. public func lstat(path: String, eventLoop: EventLoop) -> EventLoopFuture { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + self.threadPool.runIfActive(eventLoop: eventLoop) { var s = stat() try Posix.lstat(pathname: path, outStat: &s) return s @@ -548,7 +612,7 @@ public struct NonBlockingFileIO: Sendable { /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` which is fulfilled if the rename was successful or fails on error. public func symlink(path: String, to destination: String, eventLoop: EventLoop) -> EventLoopFuture { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + self.threadPool.runIfActive(eventLoop: eventLoop) { try Posix.symlink(pathname: path, destination: destination) } } @@ -560,7 +624,7 @@ public struct NonBlockingFileIO: Sendable { /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` containing link target. public func readlink(path: String, eventLoop: EventLoop) -> EventLoopFuture { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + self.threadPool.runIfActive(eventLoop: eventLoop) { let maxLength = Int(PATH_MAX) let pointer = UnsafeMutableBufferPointer.allocate(capacity: maxLength) defer { @@ -578,7 +642,7 @@ public struct NonBlockingFileIO: Sendable { /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` which is fulfilled if the rename was successful or fails on error. public func unlink(path: String, eventLoop: EventLoop) -> EventLoopFuture { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + self.threadPool.runIfActive(eventLoop: eventLoop) { try Posix.unlink(pathname: path) } } @@ -645,8 +709,13 @@ public struct NonBlockingFileIO: Sendable { /// - withIntermediateDirectories: Whether intermediate directories should be created. /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` which is fulfilled if the rename was successful or fails on error. - public func createDirectory(path: String, withIntermediateDirectories createIntermediates: Bool = false, mode: NIOPOSIXFileMode, eventLoop: EventLoop) -> EventLoopFuture { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + public func createDirectory( + path: String, + withIntermediateDirectories createIntermediates: Bool = false, + mode: NIOPOSIXFileMode, + eventLoop: EventLoop + ) -> EventLoopFuture { + self.threadPool.runIfActive(eventLoop: eventLoop) { if createIntermediates { #if canImport(Darwin) try Posix.mkpath_np(pathname: path, mode: mode) @@ -666,7 +735,7 @@ public struct NonBlockingFileIO: Sendable { /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` containing the directory entries. public func listDirectory(path: String, eventLoop: EventLoop) -> EventLoopFuture<[NIODirectoryEntry]> { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + self.threadPool.runIfActive(eventLoop: eventLoop) { let dir = try Posix.opendir(pathname: path) var entries: [NIODirectoryEntry] = [] do { @@ -675,7 +744,9 @@ public struct NonBlockingFileIO: Sendable { let ptr = pointer.baseAddress!.assumingMemoryBound(to: CChar.self) return String(cString: ptr) } - entries.append(NIODirectoryEntry(ino: UInt64(entry.pointee.d_ino), type: entry.pointee.d_type, name: name)) + entries.append( + NIODirectoryEntry(ino: UInt64(entry.pointee.d_ino), type: entry.pointee.d_type, name: name) + ) } try? Posix.closedir(dir: dir) } catch { @@ -694,7 +765,7 @@ public struct NonBlockingFileIO: Sendable { /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` which is fulfilled if the rename was successful or fails on error. public func rename(path: String, newName: String, eventLoop: EventLoop) -> EventLoopFuture { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + self.threadPool.runIfActive(eventLoop: eventLoop) { try Posix.rename(pathname: path, newName: newName) } } @@ -706,11 +777,11 @@ public struct NonBlockingFileIO: Sendable { /// - eventLoop: The `EventLoop` on which the returned `EventLoopFuture` will fire. /// - returns: An `EventLoopFuture` which is fulfilled if the remove was successful or fails on error. public func remove(path: String, eventLoop: EventLoop) -> EventLoopFuture { - return self.threadPool.runIfActive(eventLoop: eventLoop) { + self.threadPool.runIfActive(eventLoop: eventLoop) { try Posix.remove(pathname: path) } } -#endif + #endif } #if !os(Windows) @@ -734,7 +805,7 @@ public struct NIODirectoryEntry: Hashable { extension NonBlockingFileIO { /// Read a `FileRegion` in ``NonBlockingFileIO``'s private thread pool. /// - /// The returned `ByteBuffer` will not have less than the minimum of `fileRegion.readableBytes` and `UInt32.max` unless we hit + /// The returned `ByteBuffer` will not have less than the minimum of `fileRegion.readableBytes` and `UInt32.max` unless we hit /// end-of-file in which case the `ByteBuffer` will contain the bytes available to read. /// /// This method will not use the file descriptor's seek pointer which means there is no danger of reading from the @@ -780,8 +851,8 @@ extension NonBlockingFileIO { fileHandle: NIOFileHandle, byteCount: Int, allocator: ByteBufferAllocator - ) async throws-> ByteBuffer { - return try await self.read0( + ) async throws -> ByteBuffer { + try await self.read0( fileHandle: fileHandle, fromOffset: nil, byteCount: byteCount, @@ -808,28 +879,39 @@ extension NonBlockingFileIO { /// - allocator: A `ByteBufferAllocator` used to allocate space for the returned `ByteBuffer`. /// - returns: ByteBuffer. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - public func read(fileHandle: NIOFileHandle, - fromOffset fileOffset: Int64, - byteCount: Int, - allocator: ByteBufferAllocator) async throws -> ByteBuffer { - return try await self.read0(fileHandle: fileHandle, - fromOffset: fileOffset, - byteCount: byteCount, - allocator: allocator) + public func read( + fileHandle: NIOFileHandle, + fromOffset fileOffset: Int64, + byteCount: Int, + allocator: ByteBufferAllocator + ) async throws -> ByteBuffer { + try await self.read0( + fileHandle: fileHandle, + fromOffset: fileOffset, + byteCount: byteCount, + allocator: allocator + ) } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - private func read0(fileHandle: NIOFileHandle, - fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems - byteCount rawByteCount: Int, - allocator: ByteBufferAllocator) async throws -> ByteBuffer { + private func read0( + fileHandle: NIOFileHandle, + fromOffset: Int64?, // > 2 GB offset is reasonable on 32-bit systems + byteCount rawByteCount: Int, + allocator: ByteBufferAllocator + ) async throws -> ByteBuffer { guard rawByteCount > 0 else { return allocator.buffer(capacity: 0) } let byteCount = rawByteCount < Int32.max ? rawByteCount : size_t(Int32.max) return try await self.threadPool.runIfActive { () -> ByteBuffer in - try self.readSync(fileHandle: fileHandle, fromOffset: fromOffset, byteCount: byteCount, allocator: allocator) + try self.readSync( + fileHandle: fileHandle, + fromOffset: fromOffset, + byteCount: byteCount, + allocator: allocator + ) } } @@ -842,9 +924,11 @@ extension NonBlockingFileIO { /// - fileHandle: The `NIOFileHandle` to write to. /// - size: The new file size in bytes to write. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - public func changeFileSize(fileHandle: NIOFileHandle, - size: Int64) async throws { - return try await self.threadPool.runIfActive { + public func changeFileSize( + fileHandle: NIOFileHandle, + size: Int64 + ) async throws { + try await self.threadPool.runIfActive { try fileHandle.withUnsafeFileDescriptor { descriptor -> Void in try Posix.ftruncate(descriptor: descriptor, size: off_t(size)) } @@ -857,8 +941,8 @@ extension NonBlockingFileIO { /// - fileHandle: The `NIOFileHandle` to read from. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func readFileSize(fileHandle: NIOFileHandle) async throws -> Int64 { - return try await self.threadPool.runIfActive { - return try fileHandle.withUnsafeFileDescriptor { descriptor in + try await self.threadPool.runIfActive { + try fileHandle.withUnsafeFileDescriptor { descriptor in let curr = try Posix.lseek(descriptor: descriptor, offset: 0, whence: SEEK_CUR) let eof = try Posix.lseek(descriptor: descriptor, offset: 0, whence: SEEK_END) try Posix.lseek(descriptor: descriptor, offset: curr, whence: SEEK_SET) @@ -873,9 +957,11 @@ extension NonBlockingFileIO { /// - fileHandle: The `NIOFileHandle` to write to. /// - buffer: The `ByteBuffer` to write. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - public func write(fileHandle: NIOFileHandle, - buffer: ByteBuffer) async throws { - return try await self.write0(fileHandle: fileHandle, toOffset: nil, buffer: buffer) + public func write( + fileHandle: NIOFileHandle, + buffer: ByteBuffer + ) async throws { + try await self.write0(fileHandle: fileHandle, toOffset: nil, buffer: buffer) } /// Write `buffer` starting from `toOffset` to `fileHandle` in ``NonBlockingFileIO``'s private thread pool. @@ -885,16 +971,20 @@ extension NonBlockingFileIO { /// - toOffset: The file offset to write to. /// - buffer: The `ByteBuffer` to write. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - public func write(fileHandle: NIOFileHandle, - toOffset: Int64, - buffer: ByteBuffer) async throws { - return try await self.write0(fileHandle: fileHandle, toOffset: toOffset, buffer: buffer) + public func write( + fileHandle: NIOFileHandle, + toOffset: Int64, + buffer: ByteBuffer + ) async throws { + try await self.write0(fileHandle: fileHandle, toOffset: toOffset, buffer: buffer) } @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - private func write0(fileHandle: NIOFileHandle, - toOffset: Int64?, - buffer: ByteBuffer) async throws { + private func write0( + fileHandle: NIOFileHandle, + toOffset: Int64?, + buffer: ByteBuffer + ) async throws { let byteCount = buffer.readableBytes guard byteCount > 0 else { @@ -906,7 +996,7 @@ extension NonBlockingFileIO { } } - /// Open file at `path` and query its size on a private thread pool, run an operation given + /// Open file at `path` and query its size on a private thread pool, run an operation given /// the resulting file region and then close the file handle. /// /// The will return the result of the operation. @@ -919,7 +1009,7 @@ extension NonBlockingFileIO { /// - returns: return value of operation @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func withFileRegion( - path: String, + path: String, _ body: (_ fileRegion: FileRegion) async throws -> Result ) async throws -> Result { let fileRegion = try await self.threadPool.runIfActive { @@ -954,13 +1044,13 @@ extension NonBlockingFileIO { /// - returns: return value of operation @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func withFileHandle( - path: String, - mode: NIOFileHandle.Mode, - flags: NIOFileHandle.Flags = .default, + path: String, + mode: NIOFileHandle.Mode, + flags: NIOFileHandle.Flags = .default, _ body: (NIOFileHandle) async throws -> Result ) async throws -> Result { let fileHandle = try await self.threadPool.runIfActive { - return try UnsafeTransfer(NIOFileHandle(path: path, mode: mode, flags: flags)) + try UnsafeTransfer(NIOFileHandle(path: path, mode: mode, flags: flags)) } let result: Result do { @@ -973,7 +1063,7 @@ extension NonBlockingFileIO { return result } -#if !os(Windows) + #if !os(Windows) /// Returns information about a file at `path` on a private thread pool. /// @@ -984,7 +1074,7 @@ extension NonBlockingFileIO { /// - returns: file information. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func lstat(path: String) async throws -> stat { - return try await self.threadPool.runIfActive { + try await self.threadPool.runIfActive { var s = stat() try Posix.lstat(pathname: path, outStat: &s) return s @@ -998,7 +1088,7 @@ extension NonBlockingFileIO { /// - destination: Target path where this link will point to. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func symlink(path: String, to destination: String) async throws { - return try await self.threadPool.runIfActive { + try await self.threadPool.runIfActive { try Posix.symlink(pathname: path, destination: destination) } } @@ -1010,7 +1100,7 @@ extension NonBlockingFileIO { /// - returns: link target. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func readlink(path: String) async throws -> String { - return try await self.threadPool.runIfActive { + try await self.threadPool.runIfActive { let maxLength = Int(PATH_MAX) let pointer = UnsafeMutableBufferPointer.allocate(capacity: maxLength) defer { @@ -1027,20 +1117,23 @@ extension NonBlockingFileIO { /// - path: The path of the link to remove. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func unlink(path: String) async throws { - return try await self.threadPool.runIfActive { + try await self.threadPool.runIfActive { try Posix.unlink(pathname: path) } } - /// Creates directory at `path` on a private thread pool. /// /// - parameters: /// - path: The path of the directory to be created. /// - withIntermediateDirectories: Whether intermediate directories should be created. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - public func createDirectory(path: String, withIntermediateDirectories createIntermediates: Bool = false, mode: NIOPOSIXFileMode) async throws { - return try await self.threadPool.runIfActive { + public func createDirectory( + path: String, + withIntermediateDirectories createIntermediates: Bool = false, + mode: NIOPOSIXFileMode + ) async throws { + try await self.threadPool.runIfActive { if createIntermediates { #if canImport(Darwin) try Posix.mkpath_np(pathname: path, mode: mode) @@ -1060,7 +1153,7 @@ extension NonBlockingFileIO { /// - returns: The directory entries. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func listDirectory(path: String) async throws -> [NIODirectoryEntry] { - return try await self.threadPool.runIfActive { + try await self.threadPool.runIfActive { let dir = try Posix.opendir(pathname: path) var entries: [NIODirectoryEntry] = [] do { @@ -1069,7 +1162,9 @@ extension NonBlockingFileIO { let ptr = pointer.baseAddress!.assumingMemoryBound(to: CChar.self) return String(cString: ptr) } - entries.append(NIODirectoryEntry(ino: UInt64(entry.pointee.d_ino), type: entry.pointee.d_type, name: name)) + entries.append( + NIODirectoryEntry(ino: UInt64(entry.pointee.d_ino), type: entry.pointee.d_type, name: name) + ) } try? Posix.closedir(dir: dir) } catch { @@ -1087,7 +1182,7 @@ extension NonBlockingFileIO { /// - newName: New file name. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func rename(path: String, newName: String) async throws { - return try await self.threadPool.runIfActive() { + try await self.threadPool.runIfActive { try Posix.rename(pathname: path, newName: newName) } } @@ -1098,9 +1193,9 @@ extension NonBlockingFileIO { /// - path: The path of the file to be removed. @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) public func remove(path: String) async throws { - return try await self.threadPool.runIfActive() { + try await self.threadPool.runIfActive { try Posix.remove(pathname: path) } } -#endif -} \ No newline at end of file + #endif +} diff --git a/Sources/NIOPosix/PendingDatagramWritesManager.swift b/Sources/NIOPosix/PendingDatagramWritesManager.swift index 332add9f07..31365ca987 100644 --- a/Sources/NIOPosix/PendingDatagramWritesManager.swift +++ b/Sources/NIOPosix/PendingDatagramWritesManager.swift @@ -11,8 +11,9 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -import NIOCore + import Atomics +import NIOCore private struct PendingDatagramWrite { var data: ByteBuffer @@ -46,14 +47,14 @@ private struct PendingDatagramWrite { } } -fileprivate extension Error { +extension Error { /// Returns whether the error is "recoverable" from the perspective of datagram sending. /// /// - returns: `true` if the error is recoverable, `false` otherwise. - var isRecoverable: Bool { + fileprivate var isRecoverable: Bool { switch self { case let e as IOError where e.errnoCode == EMSGSIZE, - let e as IOError where e.errnoCode == EHOSTUNREACH: + let e as IOError where e.errnoCode == EHOSTUNREACH: return true default: return false @@ -62,15 +63,19 @@ fileprivate extension Error { } /// Does the setup required to trigger a `sendmmsg`. -private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWritesState, - bufferPool: Pool, - msgs: UnsafeMutableBufferPointer, - addresses: UnsafeMutableBufferPointer, - controlMessageStorage: UnsafeControlMessageStorage, - _ body: (UnsafeMutableBufferPointer) throws -> IOResult) throws -> IOResult { +private func doPendingDatagramWriteVectorOperation( + pending: PendingDatagramWritesState, + bufferPool: Pool, + msgs: UnsafeMutableBufferPointer, + addresses: UnsafeMutableBufferPointer, + controlMessageStorage: UnsafeControlMessageStorage, + _ body: (UnsafeMutableBufferPointer) throws -> IOResult +) throws -> IOResult { assert(msgs.count >= Socket.writevLimitIOVectors, "Insufficiently sized buffer for a maximal sendmmsg") - assert(controlMessageStorage.count >= Socket.writevLimitIOVectors, - "Insufficiently sized control message storage for a maximal sendmmsg") + assert( + controlMessageStorage.count >= Socket.writevLimitIOVectors, + "Insufficiently sized control message storage for a maximal sendmmsg" + ) // the numbers of storage refs that we need to decrease later. var c = 0 @@ -85,7 +90,7 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrite // TODO(cory): I can't see this limit documented in a man page anywhere, but it seems // plausible given that a similar limit exists for TCP. For now we assume it's present // in UDP until I can do some research to validate the existence of this limit. - guard (Socket.writevLimitBytes - toWrite >= p.data.readableBytes) else { + guard Socket.writevLimitBytes - toWrite >= p.data.readableBytes else { if c == 0 { // The first buffer is larger than the writev limit. Let's throw, and fall back to linear processing. throw IOError(errnoCode: EMSGSIZE, reason: "synthetic error for overlarge write") @@ -129,7 +134,10 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrite protocolFamily = connectedRemoteAddress.protocol } - iovecs[c] = iovec(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer)) + iovecs[c] = iovec( + iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), + iov_len: numericCast(toWriteForThisBuffer) + ) var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[c]) controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: protocolFamily) @@ -174,7 +182,7 @@ private struct PendingDatagramWritesState { private(set) var remoteAddress: SocketAddress? = nil public var nextWrite: PendingDatagramWrite? { - return self.pendingWrites.first + self.pendingWrites.first } /// Subtract `bytes` from the number of outstanding bytes to write. @@ -197,7 +205,7 @@ private struct PendingDatagramWritesState { } /// Initialise a new, empty `PendingWritesState`. - public init() { } + public init() {} /// Check if there are no outstanding writes. public var isEmpty: Bool { @@ -238,7 +246,10 @@ private struct PendingDatagramWritesState { /// - data: The result of the write operation: namely, for each datagram we attempted to write, the number of bytes we wrote. /// - messages: The vector messages written, if any. /// - returns: A promise and the error that should be sent to it, if any, and a `WriteResult` which indicates if we could write everything or not. - public mutating func didWrite(_ data: IOResult, messages: UnsafeMutableBufferPointer?) -> (DatagramWritePromiseFiller?, OneWriteOperationResult) { + public mutating func didWrite( + _ data: IOResult, + messages: UnsafeMutableBufferPointer? + ) -> (DatagramWritePromiseFiller?, OneWriteOperationResult) { switch data { case .processed(let written): if let messages = messages { @@ -267,7 +278,10 @@ private struct PendingDatagramWritesState { /// - messages: The list of message objects. /// - returns: A closure that the caller _needs_ to run which will fulfill the promises of the writes, and a `WriteResult` that indicates if we could write /// everything or not. - private mutating func didVectorWrite(written: Int, messages: UnsafeMutableBufferPointer) -> (DatagramWritePromiseFiller?, OneWriteOperationResult) { + private mutating func didVectorWrite( + written: Int, + messages: UnsafeMutableBufferPointer + ) -> (DatagramWritePromiseFiller?, OneWriteOperationResult) { // This was a vector write. We wrote `written` number of messages. let writes = messages[messages.startIndex...messages.index(messages.startIndex, offsetBy: written - 1)] var promiseFiller: DatagramWritePromiseFiller? @@ -283,7 +297,7 @@ private struct PendingDatagramWritesState { case (.none, .some(let this)): promiseFiller = this case (.some, .none), - (.none, .none): + (.none, .none): break } } @@ -300,8 +314,10 @@ private struct PendingDatagramWritesState { /// - returns: All the promises that must be fired, and a `WriteResult` that indicates if we could write /// everything or not. private mutating func didScalarWrite(written: Int) -> (DatagramWritePromiseFiller?, OneWriteOperationResult) { - precondition(written <= self.pendingWrites.first!.data.readableBytes, - "Appeared to write more bytes (\(written)) than the datagram contained (\(self.pendingWrites.first!.data.readableBytes))") + precondition( + written <= self.pendingWrites.first!.data.readableBytes, + "Appeared to write more bytes (\(written)) than the datagram contained (\(self.pendingWrites.first!.data.readableBytes))" + ) let writeFiller = self.wroteFirst() // If we no longer have a mark, we wrote everything. let result: OneWriteOperationResult = self.pendingWrites.hasMark ? .writtenPartially : .writtenCompletely @@ -310,7 +326,7 @@ private struct PendingDatagramWritesState { /// Is there a pending flush? public var isFlushPending: Bool { - return self.pendingWrites.hasMark + self.pendingWrites.hasMark } /// Fail all the outstanding writes. @@ -329,7 +345,9 @@ private struct PendingDatagramWritesState { w.promise.map { promises.append($0) } } - promises.forEach { $0.fail(error) } + for promise in promises { + promise.fail(error) + } } /// Returns the best mechanism to write pending data at the current point in time. @@ -339,7 +357,7 @@ private struct PendingDatagramWritesState { return .vectorBufferWrite case .some(let e): // The compiler can't prove this, but it must be so. - assert(self.pendingWrites.distance(from: e, to: self.pendingWrites.startIndex) == 0) + assert(self.pendingWrites.distance(from: e, to: self.pendingWrites.startIndex) == 0) return .scalarBufferWrite default: return .nothingToBeWritten @@ -361,8 +379,12 @@ extension PendingDatagramWritesState { } mutating func next() -> PendingDatagramWrite? { - while let markedIndex = self.markedIndex, self.pendingWrites.pendingWrites.distance(from: self.index, - to: markedIndex) >= 0 { + while let markedIndex = self.markedIndex, + self.pendingWrites.pendingWrites.distance( + from: self.index, + to: markedIndex + ) >= 0 + { let element = self.pendingWrites.pendingWrites[index] index = self.pendingWrites.pendingWrites.index(after: index) return element @@ -373,7 +395,7 @@ extension PendingDatagramWritesState { } var flushedWrites: FlushedDatagramWriteSequence { - return FlushedDatagramWriteSequence(self) + FlushedDatagramWriteSequence(self) } } @@ -387,7 +409,10 @@ final class PendingDatagramWritesManager: PendingWritesManager { private var state = PendingDatagramWritesState() - internal var waterMark: ChannelOptions.Types.WriteBufferWaterMark = ChannelOptions.Types.WriteBufferWaterMark(low: 32 * 1024, high: 64 * 1024) + internal var waterMark: ChannelOptions.Types.WriteBufferWaterMark = ChannelOptions.Types.WriteBufferWaterMark( + low: 32 * 1024, + high: 64 * 1024 + ) internal let channelWritabilityFlag = ManagedAtomic(true) internal var publishedWritability = true internal var writeSpinCount: UInt = 16 @@ -418,20 +443,21 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// Is there a flush pending? var isFlushPending: Bool { - return self.state.isFlushPending + self.state.isFlushPending } /// Are there any outstanding writes currently? var isEmpty: Bool { - return self.state.isEmpty + self.state.isEmpty } private func add(_ pendingWrite: PendingDatagramWrite) -> Bool { assert(self.isOpen) self.state.append(pendingWrite) - if self.state.bytes > waterMark.high && - channelWritabilityFlag.compareExchange(expected: true, desired: false, ordering: .relaxed).exchanged { + if self.state.bytes > waterMark.high + && channelWritabilityFlag.compareExchange(expected: true, desired: false, ordering: .relaxed).exchanged + { // Returns false to signal the Channel became non-writable and we need to notify the user. self.publishedWritability = false return false @@ -450,20 +476,29 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// address of the connected peer, otherwise this function will throw a fatal error. func add(envelope: AddressedEnvelope, promise: EventLoopPromise?) -> Bool { if let remoteAddress = self.state.remoteAddress { - precondition(envelope.remoteAddress == remoteAddress, """ + precondition( + envelope.remoteAddress == remoteAddress, + """ Remote address of AddressedEnvelope does not match remote address of connected socket. - """) - return self.add(PendingDatagramWrite( - data: envelope.data, - promise: promise, - address: nil, - metadata: envelope.metadata)) + """ + ) + return self.add( + PendingDatagramWrite( + data: envelope.data, + promise: promise, + address: nil, + metadata: envelope.metadata + ) + ) } else { - return self.add(PendingDatagramWrite( - data: envelope.data, - promise: promise, - address: envelope.remoteAddress, - metadata: envelope.metadata)) + return self.add( + PendingDatagramWrite( + data: envelope.data, + promise: promise, + address: envelope.remoteAddress, + metadata: envelope.metadata + ) + ) } } @@ -474,16 +509,19 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// - promise: Optionally an `EventLoopPromise` that will get the write operation's result /// - returns: If the `Channel` is still writable after adding the write of `data`. func add(data: ByteBuffer, promise: EventLoopPromise?) -> Bool { - return self.add(PendingDatagramWrite( - data: data, - promise: promise, - address: nil, - metadata: nil)) + self.add( + PendingDatagramWrite( + data: data, + promise: promise, + address: nil, + metadata: nil + ) + ) } /// Returns the best mechanism to write pending data at the current point in time. var currentBestWriteMechanism: WriteMechanism { - return self.state.currentBestWriteMechanism + self.state.currentBestWriteMechanism } /// Triggers the appropriate write operation. This is a fancy way of saying trigger either `sendto` or `sendmmsg`. @@ -493,9 +531,13 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// - scalarWriteOperation: An operation that writes a single, contiguous array of bytes (usually `sendmsg`). /// - vectorWriteOperation: An operation that writes multiple contiguous arrays of bytes (usually `sendmmsg`). /// - returns: The `WriteResult` and whether the `Channel` is now writable. - func triggerAppropriateWriteOperations(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer?, socklen_t, AddressedEnvelope.Metadata?) throws -> IOResult, - vectorWriteOperation: (UnsafeMutableBufferPointer) throws -> IOResult) throws -> OverallWriteResult { - return try self.triggerWriteOperations { writeMechanism in + func triggerAppropriateWriteOperations( + scalarWriteOperation: ( + UnsafeRawBufferPointer, UnsafePointer?, socklen_t, AddressedEnvelope.Metadata? + ) throws -> IOResult, + vectorWriteOperation: (UnsafeMutableBufferPointer) throws -> IOResult + ) throws -> OverallWriteResult { + try self.triggerWriteOperations { writeMechanism in switch writeMechanism { case .scalarBufferWrite: return try triggerScalarBufferWrite(scalarWriteOperation: { try scalarWriteOperation($0, $1, $2, $3) }) @@ -509,7 +551,9 @@ final class PendingDatagramWritesManager: PendingWritesManager { throw error } - return try triggerScalarBufferWrite(scalarWriteOperation: { try scalarWriteOperation($0, $1, $2, $3) }) + return try triggerScalarBufferWrite(scalarWriteOperation: { + try scalarWriteOperation($0, $1, $2, $3) + }) } case .scalarFileWrite: preconditionFailure("PendingDatagramWritesManager was handed a file write") @@ -525,7 +569,10 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// /// - parameters: /// - data: The result of the write operation. - private func didWrite(_ data: IOResult, messages: UnsafeMutableBufferPointer?) -> OneWriteOperationResult { + private func didWrite( + _ data: IOResult, + messages: UnsafeMutableBufferPointer? + ) -> OneWriteOperationResult { let (promise, result) = self.state.didWrite(data, messages: messages) if self.state.bytes < waterMark.low { @@ -563,9 +610,15 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// /// - parameters: /// - scalarWriteOperation: An operation that writes a single, contiguous array of bytes (usually `sendto`). - private func triggerScalarBufferWrite(scalarWriteOperation: (UnsafeRawBufferPointer, UnsafePointer?, socklen_t, AddressedEnvelope.Metadata?) throws -> IOResult) rethrows -> OneWriteOperationResult { - assert(self.state.isFlushPending && self.isOpen && !self.state.isEmpty, - "illegal state for scalar datagram write operation: flushPending: \(self.state.isFlushPending), isOpen: \(self.isOpen), empty: \(self.state.isEmpty)") + private func triggerScalarBufferWrite( + scalarWriteOperation: ( + UnsafeRawBufferPointer, UnsafePointer?, socklen_t, AddressedEnvelope.Metadata? + ) throws -> IOResult + ) rethrows -> OneWriteOperationResult { + assert( + self.state.isFlushPending && self.isOpen && !self.state.isEmpty, + "illegal state for scalar datagram write operation: flushPending: \(self.state.isFlushPending), isOpen: \(self.isOpen), empty: \(self.state.isEmpty)" + ) let pending = self.state.nextWrite! do { let writeResult: IOResult @@ -600,21 +653,29 @@ final class PendingDatagramWritesManager: PendingWritesManager { /// /// - parameters: /// - vectorWriteOperation: The vector write operation to use. Usually `sendmmsg`. - private func triggerVectorBufferWrite(vectorWriteOperation: (UnsafeMutableBufferPointer) throws -> IOResult) throws -> OneWriteOperationResult { - assert(self.state.isFlushPending && self.isOpen && !self.state.isEmpty, - "illegal state for vector datagram write operation: flushPending: \(self.state.isFlushPending), isOpen: \(self.isOpen), empty: \(self.state.isEmpty)") + private func triggerVectorBufferWrite( + vectorWriteOperation: (UnsafeMutableBufferPointer) throws -> IOResult + ) throws -> OneWriteOperationResult { + assert( + self.state.isFlushPending && self.isOpen && !self.state.isEmpty, + "illegal state for vector datagram write operation: flushPending: \(self.state.isFlushPending), isOpen: \(self.isOpen), empty: \(self.state.isEmpty)" + ) let msgBuffer = self.msgBufferPool.get() defer { self.msgBufferPool.put(msgBuffer) } return try msgBuffer.withUnsafePointers { msgs, addresses, controlMessageStorage in - return self.didWrite(try doPendingDatagramWriteVectorOperation(pending: self.state, - bufferPool: self.bufferPool, - msgs: msgs, - addresses: addresses, - controlMessageStorage: controlMessageStorage, - { try vectorWriteOperation($0) }), - messages: msgs) + self.didWrite( + try doPendingDatagramWriteVectorOperation( + pending: self.state, + bufferPool: self.bufferPool, + msgs: msgs, + addresses: addresses, + controlMessageStorage: controlMessageStorage, + { try vectorWriteOperation($0) } + ), + messages: msgs + ) } } diff --git a/Sources/NIOPosix/PendingWritesManager.swift b/Sources/NIOPosix/PendingWritesManager.swift index cda595f0e8..10c22cd204 100644 --- a/Sources/NIOPosix/PendingWritesManager.swift +++ b/Sources/NIOPosix/PendingWritesManager.swift @@ -11,8 +11,9 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -import NIOCore + import Atomics +import NIOCore private struct PendingStreamWrite { var data: IOData @@ -26,9 +27,11 @@ private struct PendingStreamWrite { /// - bufferPool: Pool of buffers to use for iovecs and storageRefs /// - body: The function that actually does the vector write (usually `writev`). /// - returns: A tuple of the number of items attempted to write and the result of the write operation. -private func doPendingWriteVectorOperation(pending: PendingStreamWritesState, - bufferPool: Pool, - _ body: (UnsafeBufferPointer) throws -> IOResult) throws -> (itemCount: Int, writeResult: IOResult) { +private func doPendingWriteVectorOperation( + pending: PendingStreamWritesState, + bufferPool: Pool, + _ body: (UnsafeBufferPointer) throws -> IOResult +) throws -> (itemCount: Int, writeResult: IOResult) { let buffer = bufferPool.get() defer { bufferPool.put(buffer) } @@ -46,7 +49,8 @@ private func doPendingWriteVectorOperation(pending: PendingStreamWritesState, switch p.data { case .byteBuffer(let buffer): // Must not write more than Int32.max in one go. - guard (numberOfUsedStorageSlots == 0) || (Socket.writevLimitBytes - toWrite >= buffer.readableBytes) else { + guard (numberOfUsedStorageSlots == 0) || (Socket.writevLimitBytes - toWrite >= buffer.readableBytes) + else { break loop } let toWriteForThisBuffer = min(Socket.writevLimitBytes, buffer.readableBytes) @@ -54,7 +58,10 @@ private func doPendingWriteVectorOperation(pending: PendingStreamWritesState, buffer.withUnsafeReadableBytesWithStorageManagement { ptr, storageRef in storageRefs[i] = storageRef.retain() - iovecs[i] = IOVector(iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), iov_len: numericCast(toWriteForThisBuffer)) + iovecs[i] = IOVector( + iov_base: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), + iov_len: numericCast(toWriteForThisBuffer) + ) } numberOfUsedStorageSlots += 1 case .fileRegion: @@ -69,7 +76,7 @@ private func doPendingWriteVectorOperation(pending: PendingStreamWritesState, } } let result = try body(UnsafeBufferPointer(start: iovecs.baseAddress!, count: numberOfUsedStorageSlots)) - /* if we hit a limit, we really wanted to write more than we have so the caller should retry us */ + // if we hit a limit, we really wanted to write more than we have so the caller should retry us return (numberOfUsedStorageSlots, result) } } @@ -114,7 +121,7 @@ private struct PendingStreamWritesState { public private(set) var bytes: Int64 = 0 public var flushedChunks: Int { - return self.pendingWrites.markedElementIndex.map { + self.pendingWrites.markedElementIndex.map { self.pendingWrites.distance(from: self.pendingWrites.startIndex, to: $0) + 1 } ?? 0 } @@ -145,7 +152,7 @@ private struct PendingStreamWritesState { } /// Initialise a new, empty `PendingWritesState`. - public init() { } + public init() {} /// Check if there are no outstanding writes. public var isEmpty: Bool { @@ -172,7 +179,7 @@ private struct PendingStreamWritesState { /// Get the outstanding write at `index`. public subscript(index: Int) -> PendingStreamWrite { - return self.pendingWrites[self.pendingWrites.index(self.pendingWrites.startIndex, offsetBy: index)] + self.pendingWrites[self.pendingWrites.index(self.pendingWrites.startIndex, offsetBy: index)] } /// Mark the flush checkpoint. @@ -191,7 +198,10 @@ private struct PendingStreamWritesState { /// - returns: A tuple of a promise and a `OneWriteResult`. The promise is the first promise that needs to be notified of the write result. /// This promise will cascade the result to all other promises that need notifying. If no promises need to be notified, will be `nil`. /// The write result will indicate whether we were able to write everything or not. - public mutating func didWrite(itemCount: Int, result writeResult: IOResult) -> (EventLoopPromise?, OneWriteOperationResult) { + public mutating func didWrite( + itemCount: Int, + result writeResult: IOResult + ) -> (EventLoopPromise?, OneWriteOperationResult) { switch writeResult { case .wouldBlock(0): return (nil, .wouldBlock) @@ -203,7 +213,7 @@ private struct PendingStreamWritesState { let headItemReadableBytes = self.pendingWrites.first!.data.readableBytes if unaccountedWrites >= headItemReadableBytes { unaccountedWrites -= headItemReadableBytes - /* we wrote at least the whole head item, so drop it and succeed the promise */ + // we wrote at least the whole head item, so drop it and succeed the promise if let promise = self.fullyWrittenFirst() { if let p = promise0 { p.futureResult.cascade(to: promise) @@ -212,21 +222,24 @@ private struct PendingStreamWritesState { } } } else { - /* we could only write a part of the head item, so don't drop it but remember what we wrote */ + // we could only write a part of the head item, so don't drop it but remember what we wrote self.partiallyWrittenFirst(bytes: unaccountedWrites) // may try again depending on the writeSpinCount return (promise0, .writtenPartially) } } - assert(unaccountedWrites == 0, "after doing all the accounting for the byte written, \(unaccountedWrites) bytes of unaccounted writes remain.") + assert( + unaccountedWrites == 0, + "after doing all the accounting for the byte written, \(unaccountedWrites) bytes of unaccounted writes remain." + ) return (promise0, .writtenCompletely) } } /// Is there a pending flush? public var isFlushPending: Bool { - return self.pendingWrites.hasMark + self.pendingWrites.hasMark } /// Remove all pending writes and return a `EventLoopPromise` which will cascade notifications to all. @@ -263,8 +276,10 @@ private struct PendingStreamWritesState { } default: let startIndex = self.pendingWrites.startIndex - switch (self.pendingWrites[startIndex].data, - self.pendingWrites[self.pendingWrites.index(after: startIndex)].data) { + switch ( + self.pendingWrites[startIndex].data, + self.pendingWrites[self.pendingWrites.index(after: startIndex)].data + ) { case (.byteBuffer, .byteBuffer): return .vectorBufferWrite case (.byteBuffer, .fileRegion): @@ -283,7 +298,10 @@ final class PendingStreamWritesManager: PendingWritesManager { private var state = PendingStreamWritesState() private let bufferPool: Pool - internal var waterMark: ChannelOptions.Types.WriteBufferWaterMark = ChannelOptions.Types.WriteBufferWaterMark(low: 32 * 1024, high: 64 * 1024) + internal var waterMark: ChannelOptions.Types.WriteBufferWaterMark = ChannelOptions.Types.WriteBufferWaterMark( + low: 32 * 1024, + high: 64 * 1024 + ) internal let channelWritabilityFlag = ManagedAtomic(true) internal var publishedWritability = true @@ -298,12 +316,12 @@ final class PendingStreamWritesManager: PendingWritesManager { /// Is there a flush pending? var isFlushPending: Bool { - return self.state.isFlushPending + self.state.isFlushPending } /// Are there any outstanding writes currently? var isEmpty: Bool { - return self.state.isEmpty + self.state.isEmpty } /// Add a pending write alongside its promise. @@ -316,8 +334,9 @@ final class PendingStreamWritesManager: PendingWritesManager { assert(self.isOpen) self.state.append(.init(data: data, promise: promise)) - if self.state.bytes > waterMark.high && - channelWritabilityFlag.compareExchange(expected: true, desired: false, ordering: .relaxed).exchanged { + if self.state.bytes > waterMark.high + && channelWritabilityFlag.compareExchange(expected: true, desired: false, ordering: .relaxed).exchanged + { // Returns false to signal the Channel became non-writable and we need to notify the user. self.publishedWritability = false return false @@ -327,7 +346,7 @@ final class PendingStreamWritesManager: PendingWritesManager { /// Returns the best mechanism to write pending data at the current point in time. var currentBestWriteMechanism: WriteMechanism { - return self.state.currentBestWriteMechanism + self.state.currentBestWriteMechanism } /// Triggers the appropriate write operation. This is a fancy way of saying trigger either `write`, `writev` or @@ -338,10 +357,12 @@ final class PendingStreamWritesManager: PendingWritesManager { /// - vectorBufferWriteOperation: An operation that writes multiple contiguous arrays of bytes (usually `writev`). /// - scalarFileWriteOperation: An operation that writes a region of a file descriptor (usually `sendfile`). /// - returns: The `OneWriteOperationResult` and whether the `Channel` is now writable. - func triggerAppropriateWriteOperations(scalarBufferWriteOperation: (UnsafeRawBufferPointer) throws -> IOResult, - vectorBufferWriteOperation: (UnsafeBufferPointer) throws -> IOResult, - scalarFileWriteOperation: (CInt, Int, Int) throws -> IOResult) throws -> OverallWriteResult { - return try self.triggerWriteOperations { writeMechanism in + func triggerAppropriateWriteOperations( + scalarBufferWriteOperation: (UnsafeRawBufferPointer) throws -> IOResult, + vectorBufferWriteOperation: (UnsafeBufferPointer) throws -> IOResult, + scalarFileWriteOperation: (CInt, Int, Int) throws -> IOResult + ) throws -> OverallWriteResult { + try self.triggerWriteOperations { writeMechanism in switch writeMechanism { case .scalarBufferWrite: return try triggerScalarBufferWrite({ try scalarBufferWriteOperation($0) }) @@ -377,9 +398,13 @@ final class PendingStreamWritesManager: PendingWritesManager { /// /// - parameters: /// - operation: An operation that writes a single, contiguous array of bytes (usually `write`). - private func triggerScalarBufferWrite(_ operation: (UnsafeRawBufferPointer) throws -> IOResult) throws -> OneWriteOperationResult { - assert(self.state.isFlushPending && !self.state.isEmpty && self.isOpen, - "single write called in illegal state: flush pending: \(self.state.isFlushPending), empty: \(self.state.isEmpty), isOpen: \(self.isOpen)") + private func triggerScalarBufferWrite( + _ operation: (UnsafeRawBufferPointer) throws -> IOResult + ) throws -> OneWriteOperationResult { + assert( + self.state.isFlushPending && !self.state.isEmpty && self.isOpen, + "single write called in illegal state: flush pending: \(self.state.isFlushPending), empty: \(self.state.isEmpty), isOpen: \(self.isOpen)" + ) switch self.state[0].data { case .byteBuffer(let buffer): @@ -393,9 +418,13 @@ final class PendingStreamWritesManager: PendingWritesManager { /// /// - parameters: /// - operation: An operation that writes a region of a file descriptor. - private func triggerScalarFileWrite(_ operation: (CInt, Int, Int) throws -> IOResult) throws -> OneWriteOperationResult { - assert(self.state.isFlushPending && !self.state.isEmpty && self.isOpen, - "single write called in illegal state: flush pending: \(self.state.isFlushPending), empty: \(self.state.isEmpty), isOpen: \(self.isOpen)") + private func triggerScalarFileWrite( + _ operation: (CInt, Int, Int) throws -> IOResult + ) throws -> OneWriteOperationResult { + assert( + self.state.isFlushPending && !self.state.isEmpty && self.isOpen, + "single write called in illegal state: flush pending: \(self.state.isFlushPending), empty: \(self.state.isEmpty), isOpen: \(self.isOpen)" + ) switch self.state[0].data { case .fileRegion(let file): @@ -413,12 +442,18 @@ final class PendingStreamWritesManager: PendingWritesManager { /// /// - parameters: /// - operation: The vector write operation to use. Usually `writev`. - private func triggerVectorBufferWrite(_ operation: (UnsafeBufferPointer) throws -> IOResult) throws -> OneWriteOperationResult { - assert(self.state.isFlushPending && !self.state.isEmpty && self.isOpen, - "vector write called in illegal state: flush pending: \(self.state.isFlushPending), empty: \(self.state.isEmpty), isOpen: \(self.isOpen)") - let result = try doPendingWriteVectorOperation(pending: self.state, - bufferPool: bufferPool, - { try operation($0) }) + private func triggerVectorBufferWrite( + _ operation: (UnsafeBufferPointer) throws -> IOResult + ) throws -> OneWriteOperationResult { + assert( + self.state.isFlushPending && !self.state.isEmpty && self.isOpen, + "vector write called in illegal state: flush pending: \(self.state.isFlushPending), empty: \(self.state.isEmpty), isOpen: \(self.isOpen)" + ) + let result = try doPendingWriteVectorOperation( + pending: self.state, + bufferPool: bufferPool, + { try operation($0) } + ) return self.didWrite(itemCount: result.itemCount, result: result.writeResult) } @@ -469,10 +504,12 @@ internal protocol PendingWritesManager: AnyObject { extension PendingWritesManager { // This is called from `Channel` API so must be thread-safe. var isWritable: Bool { - return self.channelWritabilityFlag.load(ordering: .relaxed) + self.channelWritabilityFlag.load(ordering: .relaxed) } - internal func triggerWriteOperations(triggerOneWriteOperation: (WriteMechanism) throws -> OneWriteOperationResult) throws -> OverallWriteResult { + internal func triggerWriteOperations( + triggerOneWriteOperation: (WriteMechanism) throws -> OneWriteOperationResult + ) throws -> OverallWriteResult { var result = OverallWriteResult(writeResult: .couldNotWriteEverything, writabilityChange: false) writeSpinLoop: for _ in 0...self.writeSpinCount { @@ -512,7 +549,7 @@ extension PendingWritesManager { extension PendingStreamWritesManager: CustomStringConvertible { var description: String { - return "PendingStreamWritesManager { isFlushPending: \(self.isFlushPending), " + - /* */ "writabilityFlag: \(self.channelWritabilityFlag.load(ordering: .relaxed))), state: \(self.state) }" + "PendingStreamWritesManager { isFlushPending: \(self.isFlushPending), " + + "writabilityFlag: \(self.channelWritabilityFlag.load(ordering: .relaxed))), state: \(self.state) }" } } diff --git a/Sources/NIOPosix/PipeChannel.swift b/Sources/NIOPosix/PipeChannel.swift index bd965c7bda..a049cf9f91 100644 --- a/Sources/NIOPosix/PipeChannel.swift +++ b/Sources/NIOPosix/PipeChannel.swift @@ -36,15 +36,20 @@ final class PipeChannel: BaseStreamSocketChannel { } func registrationForInput(interested: SelectorEventSet, registrationID: SelectorRegistrationID) -> NIORegistration { - return NIORegistration(channel: .pipeChannel(self, .input), - interested: interested, - registrationID: registrationID) + NIORegistration( + channel: .pipeChannel(self, .input), + interested: interested, + registrationID: registrationID + ) } - func registrationForOutput(interested: SelectorEventSet, registrationID: SelectorRegistrationID) -> NIORegistration { - return NIORegistration(channel: .pipeChannel(self, .output), - interested: interested, - registrationID: registrationID) + func registrationForOutput(interested: SelectorEventSet, registrationID: SelectorRegistrationID) -> NIORegistration + { + NIORegistration( + channel: .pipeChannel(self, .output), + interested: interested, + registrationID: registrationID + ) } override func connectSocket(to address: SocketAddress) throws -> Bool { @@ -133,6 +138,6 @@ final class PipeChannel: BaseStreamSocketChannel { extension PipeChannel: CustomStringConvertible { var description: String { - return "PipeChannel { \(self.socketDescription), active = \(self.isActive), localAddress = \(self.localAddress.debugDescription), remoteAddress = \(self.remoteAddress.debugDescription) }" + "PipeChannel { \(self.socketDescription), active = \(self.isActive), localAddress = \(self.localAddress.debugDescription), remoteAddress = \(self.remoteAddress.debugDescription) }" } } diff --git a/Sources/NIOPosix/PipePair.swift b/Sources/NIOPosix/PipePair.swift index d3b28bd63a..76b58c2f25 100644 --- a/Sources/NIOPosix/PipePair.swift +++ b/Sources/NIOPosix/PipePair.swift @@ -17,7 +17,7 @@ struct SelectableFileHandle { var handle: NIOFileHandle var isOpen: Bool { - return handle.isOpen + handle.isOpen } init(_ handle: NIOFileHandle) { @@ -31,7 +31,7 @@ struct SelectableFileHandle { extension SelectableFileHandle: Selectable { func withUnsafeHandle(_ body: (CInt) throws -> T) throws -> T { - return try self.handle.withUnsafeFileDescriptor(body) + try self.handle.withUnsafeFileDescriptor(body) } } @@ -61,7 +61,7 @@ final class PipePair: SocketProtocol { } var description: String { - return "PipePair { in=\(String(describing: inputFD)), out=\(String(describing: inputFD)) }" + "PipePair { in=\(String(describing: inputFD)), out=\(String(describing: inputFD)) }" } func connect(to address: SocketAddress) throws -> Bool { @@ -99,17 +99,21 @@ final class PipePair: SocketProtocol { } } - func recvmsg(pointer: UnsafeMutableRawBufferPointer, - storage: inout sockaddr_storage, - storageLen: inout socklen_t, - controlBytes: inout UnsafeReceivedControlBytes) throws -> IOResult { + func recvmsg( + pointer: UnsafeMutableRawBufferPointer, + storage: inout sockaddr_storage, + storageLen: inout socklen_t, + controlBytes: inout UnsafeReceivedControlBytes + ) throws -> IOResult { throw ChannelError._operationUnsupported } - func sendmsg(pointer: UnsafeRawBufferPointer, - destinationPtr: UnsafePointer?, - destinationSize: socklen_t, - controlBytes: UnsafeMutableRawBufferPointer) throws -> IOResult { + func sendmsg( + pointer: UnsafeRawBufferPointer, + destinationPtr: UnsafePointer?, + destinationSize: socklen_t, + controlBytes: UnsafeMutableRawBufferPointer + ) throws -> IOResult { throw ChannelError._operationUnsupported } @@ -137,7 +141,7 @@ final class PipePair: SocketProtocol { } var isOpen: Bool { - return self.inputFD?.isOpen ?? false || self.outputFD?.isOpen ?? false + self.inputFD?.isOpen ?? false || self.outputFD?.isOpen ?? false } func close() throws { diff --git a/Sources/NIOPosix/Pool.swift b/Sources/NIOPosix/Pool.swift index 6d6d10b22e..e46f57537b 100644 --- a/Sources/NIOPosix/Pool.swift +++ b/Sources/NIOPosix/Pool.swift @@ -35,17 +35,15 @@ class Pool { func get() -> Element { if elements.isEmpty { return Element() - } - else { + } else { return elements.removeLast() } } func put(_ e: Element) { - if (elements.count == maxSize) { + if elements.count == maxSize { e.evictedFromPool() - } - else { + } else { elements.append(e) } } @@ -58,7 +56,7 @@ class Pool { /// be bound to a single thread, and ensures that the allocation it stores does not /// get freed before the buffer is out of use. struct PooledBuffer: PoolElement { - private static let sentinelValue = MemorySentinel(0xdeadbeef) + private static let sentinelValue = MemorySentinel(0xdead_beef) private let storage: BackingStorage @@ -72,7 +70,8 @@ struct PooledBuffer: PoolElement { } func withUnsafePointers( - _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer>) throws -> ReturnValue + _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer>) throws -> + ReturnValue ) rethrows -> ReturnValue { defer { self.validateSentinel() @@ -94,7 +93,9 @@ struct PooledBuffer: PoolElement { /// - body: The closure that will accept the yielded pointers and the `storageManagement`. /// - returns: The value returned by `body`. func withUnsafePointersWithStorageManagement( - _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer>, Unmanaged) throws -> ReturnValue + _ body: ( + UnsafeMutableBufferPointer, UnsafeMutableBufferPointer>, Unmanaged + ) throws -> ReturnValue ) rethrows -> ReturnValue { let storageRef: Unmanaged = Unmanaged.passUnretained(self.storage) return try self.storage.withUnsafeMutableTypedPointers { iovecPointer, ownerPointer, _ in @@ -165,27 +166,44 @@ extension PooledBuffer { // Here we set up our memory bindings. let storage = unsafeDowncast(baseStorage, to: Self.self) storage.withUnsafeMutablePointers { headPointer, tailPointer in - UnsafeRawPointer(tailPointer + headPointer.pointee.iovectorOffset).bindMemory(to: IOVector.self, capacity: iovectorCount) - UnsafeRawPointer(tailPointer + headPointer.pointee.bufferOwnersOffset).bindMemory(to: Unmanaged.self, capacity: iovectorCount) - UnsafeRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).bindMemory(to: MemorySentinel.self, capacity: 1) + UnsafeRawPointer(tailPointer + headPointer.pointee.iovectorOffset).bindMemory( + to: IOVector.self, + capacity: iovectorCount + ) + UnsafeRawPointer(tailPointer + headPointer.pointee.bufferOwnersOffset).bindMemory( + to: Unmanaged.self, + capacity: iovectorCount + ) + UnsafeRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).bindMemory( + to: MemorySentinel.self, + capacity: 1 + ) } return storage } func withUnsafeMutableTypedPointers( - _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer>, UnsafeMutablePointer) throws -> ReturnType + _ body: ( + UnsafeMutableBufferPointer, UnsafeMutableBufferPointer>, + UnsafeMutablePointer + ) throws -> ReturnType ) rethrows -> ReturnType { - return try self.withUnsafeMutablePointers { headPointer, tailPointer in - let iovecPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.iovectorOffset).assumingMemoryBound(to: IOVector.self) - let ownersPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.bufferOwnersOffset).assumingMemoryBound(to: Unmanaged.self) - let sentinelPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).assumingMemoryBound(to: MemorySentinel.self) + try self.withUnsafeMutablePointers { headPointer, tailPointer in + let iovecPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.iovectorOffset) + .assumingMemoryBound(to: IOVector.self) + let ownersPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.bufferOwnersOffset) + .assumingMemoryBound(to: Unmanaged.self) + let sentinelPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset) + .assumingMemoryBound(to: MemorySentinel.self) let iovecBufferPointer = UnsafeMutableBufferPointer( - start: iovecPointer, count: headPointer.pointee.iovectorCount + start: iovecPointer, + count: headPointer.pointee.iovectorCount ) let ownersBufferPointer = UnsafeMutableBufferPointer( - start: ownersPointer, count: headPointer.pointee.iovectorCount + start: ownersPointer, + count: headPointer.pointee.iovectorCount ) return try body(iovecBufferPointer, ownersBufferPointer, sentinelPointer) } @@ -206,7 +224,7 @@ extension Int { struct PooledMsgBuffer: PoolElement { private typealias MemorySentinel = UInt32 - private static let sentinelValue = MemorySentinel(0xdeadbeef) + private static let sentinelValue = MemorySentinel(0xdead_beef) private struct PooledMsgBufferHead { let count: Int @@ -247,7 +265,7 @@ struct PooledMsgBuffer: PoolElement { } var memorySentinelOffset: Int { - return self.spaceForMsgHdrs + self.spaceForAddresses + self.spaceForControlData + self.spaceForMsgHdrs + self.spaceForAddresses + self.spaceForControlData } } @@ -261,33 +279,54 @@ struct PooledMsgBuffer: PoolElement { let storage = unsafeDowncast(baseStorage, to: Self.self) storage.withUnsafeMutablePointers { headPointer, tailPointer in - UnsafeRawPointer(tailPointer + headPointer.pointee.msgHdrsOffset).bindMemory(to: MMsgHdr.self, capacity: count) - UnsafeRawPointer(tailPointer + headPointer.pointee.addressesOffset).bindMemory(to: sockaddr_storage.self, capacity: count) + UnsafeRawPointer(tailPointer + headPointer.pointee.msgHdrsOffset).bindMemory( + to: MMsgHdr.self, + capacity: count + ) + UnsafeRawPointer(tailPointer + headPointer.pointee.addressesOffset).bindMemory( + to: sockaddr_storage.self, + capacity: count + ) // space for control message data not needed to be bound - UnsafeRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).bindMemory(to: MemorySentinel.self, capacity: 1) + UnsafeRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).bindMemory( + to: MemorySentinel.self, + capacity: 1 + ) } return storage } func withUnsafeMutableTypedPointers( - _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, UnsafeControlMessageStorage, UnsafeMutablePointer) throws -> ReturnType + _ body: ( + UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, + UnsafeControlMessageStorage, UnsafeMutablePointer + ) throws -> ReturnType ) rethrows -> ReturnType { - return try self.withUnsafeMutablePointers { headPointer, tailPointer in - let msgHdrsPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.msgHdrsOffset).assumingMemoryBound(to: MMsgHdr.self) - let addressesPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.addressesOffset).assumingMemoryBound(to: sockaddr_storage.self) - let controlDataPointer = UnsafeMutableRawBufferPointer(start: tailPointer + headPointer.pointee.controlDataOffset, count: headPointer.pointee.spaceForControlData) - let sentinelPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset).assumingMemoryBound(to: MemorySentinel.self) + try self.withUnsafeMutablePointers { headPointer, tailPointer in + let msgHdrsPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.msgHdrsOffset) + .assumingMemoryBound(to: MMsgHdr.self) + let addressesPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.addressesOffset) + .assumingMemoryBound(to: sockaddr_storage.self) + let controlDataPointer = UnsafeMutableRawBufferPointer( + start: tailPointer + headPointer.pointee.controlDataOffset, + count: headPointer.pointee.spaceForControlData + ) + let sentinelPointer = UnsafeMutableRawPointer(tailPointer + headPointer.pointee.memorySentinelOffset) + .assumingMemoryBound(to: MemorySentinel.self) let msgHdrsBufferPointer = UnsafeMutableBufferPointer( - start: msgHdrsPointer, count: headPointer.pointee.count + start: msgHdrsPointer, + count: headPointer.pointee.count ) let addressesBufferPointer = UnsafeMutableBufferPointer( - start: addressesPointer, count: headPointer.pointee.count + start: addressesPointer, + count: headPointer.pointee.count ) let controlMessageStorage = UnsafeControlMessageStorage.makeNotOwning( bytesPerMessage: UnsafeControlMessageStorage.bytesPerMessage, - buffer: controlDataPointer) + buffer: controlDataPointer + ) return try body(msgHdrsBufferPointer, addressesBufferPointer, controlMessageStorage, sentinelPointer) } } @@ -313,18 +352,24 @@ struct PooledMsgBuffer: PoolElement { } func withUnsafePointers( - _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, UnsafeControlMessageStorage) throws -> ReturnValue + _ body: ( + UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, + UnsafeControlMessageStorage + ) throws -> ReturnValue ) rethrows -> ReturnValue { defer { self.validateSentinel() } return try self.storage.withUnsafeMutableTypedPointers { msgs, addresses, controlMessageStorage, _ in - return try body(msgs, addresses, controlMessageStorage) + try body(msgs, addresses, controlMessageStorage) } } func withUnsafePointersWithStorageManagement( - _ body: (UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, UnsafeControlMessageStorage, Unmanaged) throws -> ReturnValue + _ body: ( + UnsafeMutableBufferPointer, UnsafeMutableBufferPointer, + UnsafeControlMessageStorage, Unmanaged + ) throws -> ReturnValue ) rethrows -> ReturnValue { let storageRef: Unmanaged = Unmanaged.passUnretained(self.storage) return try self.storage.withUnsafeMutableTypedPointers { msgs, addresses, controlMessageStorage, _ in diff --git a/Sources/NIOPosix/PooledRecvBufferAllocator.swift b/Sources/NIOPosix/PooledRecvBufferAllocator.swift index f179d9b38b..f34621cdc3 100644 --- a/Sources/NIOPosix/PooledRecvBufferAllocator.swift +++ b/Sources/NIOPosix/PooledRecvBufferAllocator.swift @@ -102,14 +102,17 @@ internal struct PooledRecvBufferAllocator { } } - private mutating func reuseExistingBuffer(_ body: (inout ByteBuffer) throws -> Result) rethrows -> (ByteBuffer, Result)? { + private mutating func reuseExistingBuffer( + _ body: (inout ByteBuffer) throws -> Result + ) rethrows -> (ByteBuffer, Result)? { if let nextBufferSize = self.recvAllocator.nextBufferSize() { if let result = try self.buffer?.modifyIfUniquelyOwned(minimumCapacity: nextBufferSize, body) { // `result` can only be non-nil if `buffer` is non-nil. return (self.buffer!, result) } else { // Cycle through the buffers starting at the last used buffer. - let resultAndIndex = try self.buffers.loopingFirstIndexWithResult(startingAt: self.lastUsedIndex) { buffer in + let resultAndIndex = try self.buffers.loopingFirstIndexWithResult(startingAt: self.lastUsedIndex) { + buffer in try buffer.modifyIfUniquelyOwned(minimumCapacity: nextBufferSize, body) } @@ -131,8 +134,10 @@ internal struct PooledRecvBufferAllocator { return nil } - private mutating func allocateNewBuffer(using allocator: ByteBufferAllocator, - _ body: (inout ByteBuffer) throws -> Result) rethrows -> (ByteBuffer, Result) { + private mutating func allocateNewBuffer( + using allocator: ByteBufferAllocator, + _ body: (inout ByteBuffer) throws -> Result + ) rethrows -> (ByteBuffer, Result) { // Couldn't reuse a buffer; create a new one and store it if there's capacity. var newBuffer = self.recvAllocator.buffer(allocator: allocator) @@ -173,17 +178,21 @@ internal struct PooledRecvBufferAllocator { } } - private mutating func modifyBuffer(atIndex index: Int, - _ body: (inout ByteBuffer) throws -> Result) rethrows -> (ByteBuffer, Result) { + private mutating func modifyBuffer( + atIndex index: Int, + _ body: (inout ByteBuffer) throws -> Result + ) rethrows -> (ByteBuffer, Result) { let result = try body(&self.buffers[index]) return (self.buffers[index], result) } } extension ByteBuffer { - fileprivate mutating func modifyIfUniquelyOwned(minimumCapacity: Int, - _ body: (inout ByteBuffer) throws -> Result) rethrows -> Result? { - return try self.modifyIfUniquelyOwned { buffer in + fileprivate mutating func modifyIfUniquelyOwned( + minimumCapacity: Int, + _ body: (inout ByteBuffer) throws -> Result + ) rethrows -> Result? { + try self.modifyIfUniquelyOwned { buffer in buffer.clear(minimumCapacity: minimumCapacity) return try body(&buffer) } @@ -197,17 +206,21 @@ extension Array { /// /// - Returns: The result and index of the first element passed to `body` which returned /// non-nil, or `nil` if no such element exists. - fileprivate mutating func loopingFirstIndexWithResult(startingAt middleIndex: Index, - whereNonNil body: (inout Element) throws -> Result?) rethrows -> (Result, Index)? { - if let result = try self.firstIndexWithResult(in: middleIndex ..< self.endIndex, whereNonNil: body) { + fileprivate mutating func loopingFirstIndexWithResult( + startingAt middleIndex: Index, + whereNonNil body: (inout Element) throws -> Result? + ) rethrows -> (Result, Index)? { + if let result = try self.firstIndexWithResult(in: middleIndex..(in indices: Range, - whereNonNil body: (inout Element) throws -> Result?) rethrows -> (Result, Index)? { + private mutating func firstIndexWithResult( + in indices: Range, + whereNonNil body: (inout Element) throws -> Result? + ) rethrows -> (Result, Index)? { for index in indices { if let result = try body(&self[index]) { return (result, index) diff --git a/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift b/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift index 9b6d97d139..cda58f162e 100644 --- a/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift +++ b/Sources/NIOPosix/PosixSingletons+ConcurrencyTakeOver.swift @@ -41,7 +41,8 @@ extension NIOSingletons { /// - warning: You may only call this method once. @discardableResult public static func unsafeTryInstallSingletonPosixEventLoopGroupAsConcurrencyGlobalExecutor() -> Bool { - #if /* minimum supported */ compiler(>=5.9) && /* maximum tested */ compiler(<6.1) + // Guard between the minimum and maximum supported version for the hook + #if compiler(>=5.9) && compiler(<6.1) guard #available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) else { return false } @@ -106,11 +107,13 @@ extension NIOSingletons { ) { enqueueOnNIOPtr in // Unsafe 4: We just pretend that we're the only ones in the world pulling this trick (or at least // that the others also use a `compareExchange`)... - guard concurrencyEnqueueGlobalHookAtomic.compareExchange( - expected: nil, - desired: enqueueOnNIOPtr.pointee, - ordering: .relaxed - ).exchanged else { + guard + concurrencyEnqueueGlobalHookAtomic.compareExchange( + expected: nil, + desired: enqueueOnNIOPtr.pointee, + ordering: .relaxed + ).exchanged + else { return false } diff --git a/Sources/NIOPosix/PosixSingletons.swift b/Sources/NIOPosix/PosixSingletons.swift index f09d88bd5b..0f82ff2920 100644 --- a/Sources/NIOPosix/PosixSingletons.swift +++ b/Sources/NIOPosix/PosixSingletons.swift @@ -20,14 +20,14 @@ extension NIOSingletons { /// /// The number of threads is determined by `NIOSingletons/groupLoopCountSuggestion`. public static var posixEventLoopGroup: MultiThreadedEventLoopGroup { - return singletonMTELG + singletonMTELG } /// A globally shared, lazily initialized ``NIOThreadPool`` that can be used for blocking I/O and other blocking operations. /// /// The number of threads is determined by `NIOSingletons/blockingPoolThreadCountSuggestion`. public static var posixBlockingThreadPool: NIOThreadPool { - return globalPosixBlockingPool + globalPosixBlockingPool } } @@ -50,16 +50,17 @@ extension MultiThreadedEventLoopGroup { /// if any code attempts to use the global singletons. /// public static var singleton: MultiThreadedEventLoopGroup { - return NIOSingletons.posixEventLoopGroup + NIOSingletons.posixEventLoopGroup } } +// swift-format-ignore: DontRepeatTypeInStaticProperties extension EventLoopGroup where Self == MultiThreadedEventLoopGroup { /// A globally shared, singleton ``MultiThreadedEventLoopGroup``. /// /// This provides the same object as ``MultiThreadedEventLoopGroup/singleton``. public static var singletonMultiThreadedEventLoopGroup: Self { - return MultiThreadedEventLoopGroup.singleton + MultiThreadedEventLoopGroup.singleton } } @@ -81,37 +82,41 @@ extension NIOThreadPool { /// `NIOSingletons/singletonsEnabledSuggestion` to `false` which will lead to a forced crash /// if any code attempts to use the global singletons. public static var singleton: NIOThreadPool { - return NIOSingletons.posixBlockingThreadPool + NIOSingletons.posixBlockingThreadPool } } private let singletonMTELG: MultiThreadedEventLoopGroup = { guard NIOSingletons.singletonsEnabledSuggestion else { - fatalError(""" - Cannot create global singleton MultiThreadedEventLoopGroup because the global singletons have been \ - disabled by setting `NIOSingletons.singletonsEnabledSuggestion = false` - """) + fatalError( + """ + Cannot create global singleton MultiThreadedEventLoopGroup because the global singletons have been \ + disabled by setting `NIOSingletons.singletonsEnabledSuggestion = false` + """ + ) } let threadCount = NIOSingletons.groupLoopCountSuggestion - let group = MultiThreadedEventLoopGroup._makePerpetualGroup(threadNamePrefix: "NIO-SGLTN-", - numberOfThreads: threadCount) - _ = Unmanaged.passUnretained(group).retain() // Never gonna give you up, + let group = MultiThreadedEventLoopGroup._makePerpetualGroup( + threadNamePrefix: "NIO-SGLTN-", + numberOfThreads: threadCount + ) + _ = Unmanaged.passUnretained(group).retain() // Never gonna give you up, return group }() private let globalPosixBlockingPool: NIOThreadPool = { guard NIOSingletons.singletonsEnabledSuggestion else { - fatalError(""" - Cannot create global singleton NIOThreadPool because the global singletons have been \ - disabled by setting `NIOSingletons.singletonsEnabledSuggestion = false` - """) + fatalError( + """ + Cannot create global singleton NIOThreadPool because the global singletons have been \ + disabled by setting `NIOSingletons.singletonsEnabledSuggestion = false` + """ + ) } let pool = NIOThreadPool._makePerpetualStartedPool( numberOfThreads: NIOSingletons.blockingPoolThreadCountSuggestion, threadNamePrefix: "SGLTN-TP-#" ) - _ = Unmanaged.passUnretained(pool).retain() // never gonna let you down. + _ = Unmanaged.passUnretained(pool).retain() // never gonna let you down. return pool }() - - diff --git a/Sources/NIOPosix/RawSocketBootstrap.swift b/Sources/NIOPosix/RawSocketBootstrap.swift index 9847f17e1f..5f507fb5e9 100644 --- a/Sources/NIOPosix/RawSocketBootstrap.swift +++ b/Sources/NIOPosix/RawSocketBootstrap.swift @@ -51,8 +51,10 @@ public final class NIORawSocketBootstrap { /// - group: The `EventLoopGroup` to use. public convenience init(group: EventLoopGroup) { guard NIOOnSocketsBootstraps.isCompatible(group: group) else { - preconditionFailure("RawSocketBootstrap is only compatible with MultiThreadedEventLoopGroup and " + - "SelectableEventLoop. You tried constructing one with \(group) which is incompatible.") + preconditionFailure( + "RawSocketBootstrap is only compatible with MultiThreadedEventLoopGroup and " + + "SelectableEventLoop. You tried constructing one with \(group) which is incompatible." + ) } self.init(validatingGroup: group)! } @@ -69,7 +71,7 @@ public final class NIORawSocketBootstrap { self.group = group self.channelInitializer = nil } - + /// Initialize the bound `Channel` with `initializer`. The most common task in initializer is to add /// `ChannelHandler`s to the `ChannelPipeline`. /// @@ -98,12 +100,15 @@ public final class NIORawSocketBootstrap { /// - host: The host to bind on. /// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field. public func bind(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture { - return bind0(ipProtocol: ipProtocol) { - return try SocketAddress.makeAddressResolvingHost(host, port: 0) + bind0(ipProtocol: ipProtocol) { + try SocketAddress.makeAddressResolvingHost(host, port: 0) } } - private func bind0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture { + private func bind0( + ipProtocol: NIOIPProtocol, + _ makeSocketAddress: () throws -> SocketAddress + ) -> EventLoopFuture { let address: SocketAddress do { address = try makeSocketAddress() @@ -112,10 +117,12 @@ public final class NIORawSocketBootstrap { } precondition(address.port == nil || address.port == 0, "port must be 0 or not set") func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel(eventLoop: eventLoop, - protocolFamily: address.protocol, - protocolSubtype: .init(ipProtocol), - socketType: .raw) + try DatagramChannel( + eventLoop: eventLoop, + protocolFamily: address.protocol, + protocolSubtype: .init(ipProtocol), + socketType: .raw + ) } return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in channel.register().flatMap { @@ -130,12 +137,15 @@ public final class NIORawSocketBootstrap { /// - host: The host to connect to. /// - ipProtocol: The IP protocol used in the IP protocol/nextHeader field. public func connect(host: String, ipProtocol: NIOIPProtocol) -> EventLoopFuture { - return connect0(ipProtocol: ipProtocol) { - return try SocketAddress.makeAddressResolvingHost(host, port: 0) + connect0(ipProtocol: ipProtocol) { + try SocketAddress.makeAddressResolvingHost(host, port: 0) } } - private func connect0(ipProtocol: NIOIPProtocol, _ makeSocketAddress: () throws -> SocketAddress) -> EventLoopFuture { + private func connect0( + ipProtocol: NIOIPProtocol, + _ makeSocketAddress: () throws -> SocketAddress + ) -> EventLoopFuture { let address: SocketAddress do { address = try makeSocketAddress() @@ -143,10 +153,12 @@ public final class NIORawSocketBootstrap { return group.next().makeFailedFuture(error) } func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel(eventLoop: eventLoop, - protocolFamily: address.protocol, - protocolSubtype: .init(ipProtocol), - socketType: .raw) + try DatagramChannel( + eventLoop: eventLoop, + protocolFamily: address.protocol, + protocolSubtype: .init(ipProtocol), + socketType: .raw + ) } return withNewChannel(makeChannel: makeChannel) { (eventLoop, channel) in channel.register().flatMap { @@ -155,7 +167,10 @@ public final class NIORawSocketBootstrap { } } - private func withNewChannel(makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture) -> EventLoopFuture { + private func withNewChannel( + makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, + _ bringup: @escaping (EventLoop, DatagramChannel) -> EventLoopFuture + ) -> EventLoopFuture { let eventLoop = self.group.next() let channelInitializer = self.channelInitializer ?? { _ in eventLoop.makeSucceededFuture(()) } let channelOptions = self._channelOptions @@ -216,7 +231,7 @@ extension NIORawSocketBootstrap { postRegisterTransformation: { $1.makeSucceededFuture($0) } ) } - + /// Connect the `Channel` to `host`. /// /// - Parameters: @@ -244,12 +259,14 @@ extension NIORawSocketBootstrap { host: String, ipProtocol: NIOIPProtocol, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > ) async throws -> PostRegistrationTransformationResult { let address = try SocketAddress.makeAddressResolvingHost(host, port: 0) func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel( + try DatagramChannel( eventLoop: eventLoop, protocolFamily: address.protocol, protocolSubtype: .init(ipProtocol), @@ -274,13 +291,15 @@ extension NIORawSocketBootstrap { host: String, ipProtocol: NIOIPProtocol, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > ) async throws -> PostRegistrationTransformationResult { let address = try SocketAddress.makeAddressResolvingHost(host, port: 0) precondition(address.port == nil || address.port == 0, "port must be 0 or not set") func makeChannel(_ eventLoop: SelectableEventLoop) throws -> DatagramChannel { - return try DatagramChannel( + try DatagramChannel( eventLoop: eventLoop, protocolFamily: address.protocol, protocolSubtype: .init(ipProtocol), @@ -305,7 +324,9 @@ extension NIORawSocketBootstrap { makeChannel: (_ eventLoop: SelectableEventLoop) throws -> DatagramChannel, channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, registration: @escaping @Sendable (Channel) -> EventLoopFuture, - postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > ) -> EventLoopFuture { let eventLoop = self.group.next() let channelInitializer = { (channel: Channel) -> EventLoopFuture in diff --git a/Sources/NIOPosix/SelectableChannel.swift b/Sources/NIOPosix/SelectableChannel.swift index 8ecc84b486..20882d915e 100644 --- a/Sources/NIOPosix/SelectableChannel.swift +++ b/Sources/NIOPosix/SelectableChannel.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -import NIOCore import NIOConcurrencyHelpers +import NIOCore /// A `SelectableChannel` is a `Channel` that can be used with a `Selector` which notifies a user when certain events /// are possible. On UNIX a `Selector` is usually an abstraction of `select`, `poll`, `epoll` or `kqueue`. diff --git a/Sources/NIOPosix/SelectableEventLoop.swift b/Sources/NIOPosix/SelectableEventLoop.swift index 8da3951975..c8f9204cfb 100644 --- a/Sources/NIOPosix/SelectableEventLoop.swift +++ b/Sources/NIOPosix/SelectableEventLoop.swift @@ -12,23 +12,23 @@ // //===----------------------------------------------------------------------===// +import Atomics import DequeModule import Dispatch -import NIOCore import NIOConcurrencyHelpers +import NIOCore import _NIODataStructures -import Atomics /// Execute the given closure and ensure we release all auto pools if needed. @inlinable internal func withAutoReleasePool(_ execute: () throws -> T) rethrows -> T { -#if canImport(Darwin) + #if canImport(Darwin) return try autoreleasepool { try execute() } -#else + #else return try execute() -#endif + #endif } /// Information about an EventLoop tick @@ -93,7 +93,7 @@ internal final class SelectableEventLoop: EventLoop { case exitingThread } - /* private but tests */ internal let _selector: NIOPosix.Selector + internal let _selector: NIOPosix.Selector private let thread: NIOThread @usableFromInline // _pendingTaskPop is set to `true` if the event loop is about to pop tasks off the task queue. @@ -110,11 +110,11 @@ internal final class SelectableEventLoop: EventLoop { // for every appended closure. https://bugs.swift.org/browse/SR-15872 private var tasksCopy = ContiguousArray() private static var tasksCopyBatchSize: Int { - return 4096 + 4096 } @usableFromInline - internal var _succeededVoidFuture: Optional> = nil { + internal var _succeededVoidFuture: EventLoopFuture? = nil { didSet { self.assertInEventLoop() } @@ -127,12 +127,16 @@ internal final class SelectableEventLoop: EventLoop { private var externalStateLock: NIOLock { // The assert is here to check that we never try to read the external state on the EventLoop unless we're // shutting down. - assert(!self.inEventLoop || self.internalState != .runningAndAcceptingNewRegistrations, - "lifecycle lock taken whilst up and running and in EventLoop") + assert( + !self.inEventLoop || self.internalState != .runningAndAcceptingNewRegistrations, + "lifecycle lock taken whilst up and running and in EventLoop" + ) return self._externalStateLock } - private var internalState: InternalState = .runningAndAcceptingNewRegistrations // protected by the EventLoop thread - private var externalState: ExternalState = .open // protected by externalStateLock + // protected by the EventLoop thread + private var internalState: InternalState = .runningAndAcceptingNewRegistrations + // protected by externalStateLock + private var externalState: ExternalState = .open let bufferPool: Pool let msgBufferPool: Pool @@ -157,7 +161,9 @@ internal final class SelectableEventLoop: EventLoop { } @usableFromInline - internal func _promiseCompleted(futureIdentifier: _NIOEventLoopFutureIdentifier) -> (file: StaticString, line: UInt)? { + internal func _promiseCompleted( + futureIdentifier: _NIOEventLoopFutureIdentifier + ) -> (file: StaticString, line: UInt)? { precondition(_isDebugAssertConfiguration()) return self.promiseCreationStoreLock.withLock { self._promiseCreationStore.removeValue(forKey: futureIdentifier) @@ -166,16 +172,17 @@ internal final class SelectableEventLoop: EventLoop { @usableFromInline internal func _preconditionSafeToWait(file: StaticString, line: UInt) { - let explainer: () -> String = { """ -BUG DETECTED: wait() must not be called when on an EventLoop. -Calling wait() on any EventLoop can lead to -- deadlocks -- stalling processing of other connections (Channels) that are handled on the EventLoop that wait was called on - -Further information: -- current eventLoop: \(MultiThreadedEventLoopGroup.currentEventLoop.debugDescription) -- event loop associated to future: \(self) -""" + let explainer: () -> String = { + """ + BUG DETECTED: wait() must not be called when on an EventLoop. + Calling wait() on any EventLoop can lead to + - deadlocks + - stalling processing of other connections (Channels) that are handled on the EventLoop that wait was called on + + Further information: + - current eventLoop: \(MultiThreadedEventLoopGroup.currentEventLoop.debugDescription) + - event loop associated to future: \(self) + """ } precondition(!self.inEventLoop, explainer(), file: file, line: line) precondition(MultiThreadedEventLoopGroup.currentEventLoop == nil, explainer(), file: file, line: line) @@ -202,16 +209,18 @@ Further information: } internal var testsOnly_validExternalStateToScheduleTasks: Bool { - return self.externalStateLock.withLock { - return self.validExternalStateToScheduleTasks + self.externalStateLock.withLock { + self.validExternalStateToScheduleTasks } } - internal init(thread: NIOThread, - parentGroup: MultiThreadedEventLoopGroup?, /* nil iff thread take-over */ - selector: NIOPosix.Selector, - canBeShutdownIndividually: Bool, - metricsDelegate: NIOEventLoopMetricsDelegate?) { + internal init( + thread: NIOThread, + parentGroup: MultiThreadedEventLoopGroup?, // nil iff thread take-over + selector: NIOPosix.Selector, + canBeShutdownIndividually: Bool, + metricsDelegate: NIOEventLoopMetricsDelegate? + ) { self.metricsDelegate = metricsDelegate self._parentGroup = parentGroup self._selector = selector @@ -229,10 +238,14 @@ Further information: } deinit { - assert(self.internalState == .exitingThread, - "illegal internal state on deinit: \(self.internalState)") - assert(self.externalState == .resourcesReclaimed, - "illegal external state on shutdown: \(self.externalState)") + assert( + self.internalState == .exitingThread, + "illegal internal state on deinit: \(self.internalState)" + ) + assert( + self.externalState == .resourcesReclaimed, + "illegal external state on shutdown: \(self.externalState)" + ) } /// Is this `SelectableEventLoop` still open (ie. not shutting down or shut down) @@ -281,34 +294,42 @@ Further information: /// - see: `EventLoop.inEventLoop` @usableFromInline internal var inEventLoop: Bool { - return thread.isCurrent + thread.isCurrent } /// - see: `EventLoop.scheduleTask(deadline:_:)` @inlinable internal func scheduleTask(deadline: NIODeadline, _ task: @escaping () throws -> T) -> Scheduled { let promise: EventLoopPromise = self.makePromise() - let task = ScheduledTask(id: self.scheduledTaskCounter.loadThenWrappingIncrement(ordering: .relaxed), { - do { - promise.succeed(try task()) - } catch let err { - promise.fail(err) - } - }, { error in - promise.fail(error) - }, deadline) + let task = ScheduledTask( + id: self.scheduledTaskCounter.loadThenWrappingIncrement(ordering: .relaxed), + { + do { + promise.succeed(try task()) + } catch let err { + promise.fail(err) + } + }, + { error in + promise.fail(error) + }, + deadline + ) let taskId = task.id - let scheduled = Scheduled(promise: promise, cancellationTask: { - self._tasksLock.withLock { () -> Void in - self._scheduledTasks.removeFirst(where: { $0.id == taskId }) + let scheduled = Scheduled( + promise: promise, + cancellationTask: { + self._tasksLock.withLock { () -> Void in + self._scheduledTasks.removeFirst(where: { $0.id == taskId }) + } + // We don't need to wake up the selector here, the scheduled task will never be picked up. Waking up the + // selector would mean that we may be able to recalculate the shutdown to a later date. The cost of not + // doing the recalculation is one potentially unnecessary wakeup which is exactly what we're + // saving here. So in the worst case, we didn't do a performance optimisation, in the best case, we saved + // one wakeup. } - // We don't need to wake up the selector here, the scheduled task will never be picked up. Waking up the - // selector would mean that we may be able to recalculate the shutdown to a later date. The cost of not - // doing the recalculation is one potentially unnecessary wakeup which is exactly what we're - // saving here. So in the worst case, we didn't do a performance optimisation, in the best case, we saved - // one wakeup. - }) + ) do { try self._schedule0(.scheduled(task)) @@ -322,7 +343,7 @@ Further information: /// - see: `EventLoop.scheduleTask(in:_:)` @inlinable internal func scheduleTask(in: TimeAmount, _ task: @escaping () throws -> T) -> Scheduled { - return scheduleTask(deadline: .now() + `in`, task) + scheduleTask(deadline: .now() + `in`, task) } // - see: `EventLoop.execute` @@ -346,8 +367,10 @@ Further information: @usableFromInline internal func _schedule0(_ task: LoopTask) throws { if self.inEventLoop { - precondition(self._validInternalStateToScheduleTasks, - "BUG IN NIO (please report): EventLoop is shutdown, yet we're on the EventLoop.") + precondition( + self._validInternalStateToScheduleTasks, + "BUG IN NIO (please report): EventLoop is shutdown, yet we're on the EventLoop." + ) self._tasksLock.withLock { () -> Void in switch task { @@ -461,7 +484,7 @@ Further information: } private func run(_ task: UnderlyingTask) { - /* for macOS: in case any calls we make to Foundation put objects into an autoreleasepool */ + // for macOS: in case any calls we make to Foundation put objects into an autoreleasepool withAutoReleasePool { switch task { case .function(let function): @@ -484,7 +507,8 @@ Further information: tasksCopy: ContiguousArray, tasksCopyBatchSize: Int, now: NIODeadline, - nextDeadline: NIODeadline) { + nextDeadline: NIODeadline + ) { assert(tasksCopy.count <= tasksCopyBatchSize) // When we exit the loop, we would expect to // * have taskCopy full, or: @@ -537,7 +561,8 @@ Further information: immediateTasks: inout Deque, scheduledTasks: inout PriorityQueue, tasksCopy: inout ContiguousArray, - tasksCopyBatchSize: Int) -> NIODeadline? { + tasksCopyBatchSize: Int + ) -> NIODeadline? { // We expect empty tasksCopy, to put a new batch of tasks into assert(tasksCopy.isEmpty) @@ -556,7 +581,8 @@ Further information: while moreImmediateTasksToConsider || moreScheduledTasksToConsider { // We pick one item from immediateTasks & scheduledTask per iteration of the loop. // This prevents one task queue starving the other. - if moreImmediateTasksToConsider, tasksCopy.count < tasksCopyBatchSize, let task = immediateTasks.popFirst() { + if moreImmediateTasksToConsider, tasksCopy.count < tasksCopyBatchSize, let task = immediateTasks.popFirst() + { tasksCopy.append(task) } else { moreImmediateTasksToConsider = false @@ -585,7 +611,8 @@ Further information: tasksCopy: tasksCopy, tasksCopyBatchSize: tasksCopyBatchSize, now: now, - nextDeadline: nextDeadline) + nextDeadline: nextDeadline + ) } return nextDeadline @@ -595,7 +622,11 @@ Further information: let tickStartTime: NIODeadline = .now() var tasksProcessedInTick = 0 defer { - let tickInfo = NIOEventLoopTickInfo(eventLoopID: selfIdentifier, numberOfTasks: tasksProcessedInTick, startTime: tickStartTime) + let tickInfo = NIOEventLoopTickInfo( + eventLoopID: selfIdentifier, + numberOfTasks: tasksProcessedInTick, + startTime: tickStartTime + ) self.metricsDelegate?.processedTick(info: tickInfo) } while true { @@ -604,7 +635,8 @@ Further information: immediateTasks: &self._immediateTasks, scheduledTasks: &self._scheduledTasks, tasksCopy: &self.tasksCopy, - tasksCopyBatchSize: Self.tasksCopyBatchSize) + tasksCopyBatchSize: Self.tasksCopyBatchSize + ) if self.tasksCopy.isEmpty { // Rare, but it's possible to find no tasks to execute if all scheduled tasks are expiring in the future. self._pendingTaskPop = false @@ -612,7 +644,7 @@ Further information: return deadline } - // all pending tasks are set to occur in the future, so we can stop looping. + // all pending tasks are set to occur in the future, so we can stop looping. if self.tasksCopy.isEmpty { return nextReadyDeadline } @@ -640,7 +672,7 @@ Further information: var drained = false var scheduledTasksCopy = ContiguousArray() var immediateTasksCopy = Deque() - repeat { // We may need to do multiple rounds of this because failing tasks may lead to more work. + repeat { // We may need to do multiple rounds of this because failing tasks may lead to more work. self._tasksLock.withLock { // In this state we never want the selector to be woken again, so we pretend we're permanently running. self._pendingTaskPop = true @@ -684,13 +716,13 @@ Further information: nextReadyDeadline = firstScheduledTask.readyTime } if !self._immediateTasks.isEmpty { - nextReadyDeadline = NIODeadline.now() + nextReadyDeadline = NIODeadline.now() } } let selfIdentifier = ObjectIdentifier(self) while self.internalState != .noLongerRunning && self.internalState != .exitingThread { // Block until there are events to handle or the selector was woken up - /* for macOS: in case any calls we make to Foundation put objects into an autoreleasepool */ + // for macOS: in case any calls we make to Foundation put objects into an autoreleasepool try withAutoReleasePool { try self._selector.whenReady( strategy: currentSelectorStrategy(nextReadyDeadline: nextReadyDeadline), @@ -728,7 +760,7 @@ Further information: internal func initiateClose(queue: DispatchQueue, completionHandler: @escaping (Result) -> Void) { func doClose() { self.assertInEventLoop() - self._parentGroup = nil // break the cycle + self._parentGroup = nil // break the cycle // There should only ever be one call into this function so we need to be up and running, ... assert(self.internalState == .runningAndAcceptingNewRegistrations) self.internalState = .runningButNotAcceptingNewRegistrations @@ -742,7 +774,7 @@ Further information: self.assertInEventLoop() assert(self.internalState == .runningButNotAcceptingNewRegistrations) self.internalState = .noLongerRunning - self.execute {} // force a new event loop tick, so the event loop definitely stops looping very soon. + self.execute {} // force a new event loop tick, so the event loop definitely stops looping very soon. self.externalStateLock.withLock { assert(self.externalState == .closing) self.externalState = .closed @@ -807,7 +839,7 @@ Further information: func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { if self.canBeShutdownIndividually { self.initiateClose(queue: queue) { result in - self.syncFinaliseClose(joinThread: false) // This thread was taken over by somebody else + self.syncFinaliseClose(joinThread: false) // This thread was taken over by somebody else switch result { case .success: callback(nil) @@ -846,13 +878,13 @@ Further information: extension SelectableEventLoop: CustomStringConvertible, CustomDebugStringConvertible { @usableFromInline var description: String { - return "SelectableEventLoop { selector = \(self._selector), thread = \(self.thread) }" + "SelectableEventLoop { selector = \(self._selector), thread = \(self.thread) }" } @usableFromInline var debugDescription: String { - return self._tasksLock.withLock { - return "SelectableEventLoop { selector = \(self._selector), thread = \(self.thread), scheduledTasks = \(self._scheduledTasks.description) }" + self._tasksLock.withLock { + "SelectableEventLoop { selector = \(self._selector), thread = \(self.thread), scheduledTasks = \(self._scheduledTasks.description) }" } } } @@ -860,7 +892,7 @@ extension SelectableEventLoop: CustomStringConvertible, CustomDebugStringConvert // MARK: SerialExecutor conformance #if compiler(>=5.9) @available(macOS 14.0, iOS 17.0, watchOS 10.0, tvOS 17.0, *) -extension SelectableEventLoop: NIOSerialEventLoopExecutor { } +extension SelectableEventLoop: NIOSerialEventLoopExecutor {} #endif @usableFromInline @@ -879,7 +911,9 @@ internal enum LoopTask { @inlinable internal func assertExpression(_ body: () -> Bool) { - assert({ - return body() - }()) + assert( + { + body() + }() + ) } diff --git a/Sources/NIOPosix/SelectorEpoll.swift b/Sources/NIOPosix/SelectorEpoll.swift index a8a9376a39..5b5ea844cc 100644 --- a/Sources/NIOPosix/SelectorEpoll.swift +++ b/Sources/NIOPosix/SelectorEpoll.swift @@ -74,8 +74,8 @@ extension SelectorEventSet { if epollFilters.contains(.readHangup) { filter |= Epoll.EPOLLRDHUP } - assert(filter & Epoll.EPOLLHUP != 0) // both of these are reported - assert(filter & Epoll.EPOLLERR != 0) // always and can't be masked. + assert(filter & Epoll.EPOLLHUP != 0) // both of these are reported + assert(filter & Epoll.EPOLLERR != 0) // always and can't be masked. return filter } @@ -128,21 +128,32 @@ extension Selector: _SelectorBackendProtocol { func initialiseState0() throws { self.selectorFD = try Epoll.epoll_create(size: 128) self.eventFD = try EventFd.eventfd(initval: 0, flags: Int32(EventFd.EFD_CLOEXEC | EventFd.EFD_NONBLOCK)) - self.timerFD = try TimerFd.timerfd_create(clockId: CLOCK_MONOTONIC, flags: Int32(TimerFd.TFD_CLOEXEC | TimerFd.TFD_NONBLOCK)) + self.timerFD = try TimerFd.timerfd_create( + clockId: CLOCK_MONOTONIC, + flags: Int32(TimerFd.TFD_CLOEXEC | TimerFd.TFD_NONBLOCK) + ) self.lifecycleState = .open var ev = Epoll.epoll_event() ev.events = SelectorEventSet.read.epollEventSet - ev.data.u64 = UInt64(EPollUserData(registrationID: .initialRegistrationID, - fileDescriptor: self.eventFD)) + ev.data.u64 = UInt64( + EPollUserData( + registrationID: .initialRegistrationID, + fileDescriptor: self.eventFD + ) + ) try Epoll.epoll_ctl(epfd: self.selectorFD, op: Epoll.EPOLL_CTL_ADD, fd: self.eventFD, event: &ev) var timerev = Epoll.epoll_event() timerev.events = Epoll.EPOLLIN | Epoll.EPOLLERR | Epoll.EPOLLRDHUP - timerev.data.u64 = UInt64(EPollUserData(registrationID: .initialRegistrationID, - fileDescriptor: self.timerFD)) + timerev.data.u64 = UInt64( + EPollUserData( + registrationID: .initialRegistrationID, + fileDescriptor: self.timerFD + ) + ) try Epoll.epoll_ctl(epfd: self.selectorFD, op: Epoll.EPOLL_CTL_ADD, fd: self.timerFD, event: &timerev) } @@ -151,10 +162,12 @@ extension Selector: _SelectorBackendProtocol { assert(self.timerFD == -1, "self.timerFD == \(self.timerFD) in deinitAssertions0, forgot close?") } - func register0(selectable: S, - fileDescriptor: CInt, - interested: SelectorEventSet, - registrationID: SelectorRegistrationID) throws { + func register0( + selectable: S, + fileDescriptor: CInt, + interested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws { var ev = Epoll.epoll_event() ev.events = interested.epollEventSet ev.data.u64 = UInt64(EPollUserData(registrationID: registrationID, fileDescriptor: fileDescriptor)) @@ -162,11 +175,13 @@ extension Selector: _SelectorBackendProtocol { try Epoll.epoll_ctl(epfd: self.selectorFD, op: Epoll.EPOLL_CTL_ADD, fd: fileDescriptor, event: &ev) } - func reregister0(selectable: S, - fileDescriptor: CInt, - oldInterested: SelectorEventSet, - newInterested: SelectorEventSet, - registrationID: SelectorRegistrationID) throws { + func reregister0( + selectable: S, + fileDescriptor: CInt, + oldInterested: SelectorEventSet, + newInterested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws { var ev = Epoll.epoll_event() ev.events = newInterested.epollEventSet ev.data.u64 = UInt64(EPollUserData(registrationID: registrationID, fileDescriptor: fileDescriptor)) @@ -174,17 +189,26 @@ extension Selector: _SelectorBackendProtocol { _ = try Epoll.epoll_ctl(epfd: self.selectorFD, op: Epoll.EPOLL_CTL_MOD, fd: fileDescriptor, event: &ev) } - func deregister0(selectable: S, fileDescriptor: CInt, oldInterested: SelectorEventSet, registrationID: SelectorRegistrationID) throws { + func deregister0( + selectable: S, + fileDescriptor: CInt, + oldInterested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws { var ev = Epoll.epoll_event() _ = try Epoll.epoll_ctl(epfd: self.selectorFD, op: Epoll.EPOLL_CTL_DEL, fd: fileDescriptor, event: &ev) } - + /// Apply the given `SelectorStrategy` and execute `body` once it's complete (which may produce `SelectorEvent`s to handle). /// /// - parameters: /// - strategy: The `SelectorStrategy` to apply /// - body: The function to execute for each `SelectorEvent` that was produced. - func whenReady0(strategy: SelectorStrategy, onLoopBegin loopStart: () -> Void, _ body: (SelectorEvent) throws -> Void) throws -> Void { + func whenReady0( + strategy: SelectorStrategy, + onLoopBegin loopStart: () -> Void, + _ body: (SelectorEvent) throws -> Void + ) throws { assert(self.myThread == NIOThread.current) guard self.lifecycleState == .open else { throw IOError(errnoCode: EBADF, reason: "can't call whenReady for selector as it's \(self.lifecycleState).") @@ -193,7 +217,14 @@ extension Selector: _SelectorBackendProtocol { switch strategy { case .now: - ready = Int(try Epoll.epoll_wait(epfd: self.selectorFD, events: events, maxevents: Int32(eventsCapacity), timeout: 0)) + ready = Int( + try Epoll.epoll_wait( + epfd: self.selectorFD, + events: events, + maxevents: Int32(eventsCapacity), + timeout: 0 + ) + ) case .blockUntilTimeout(let timeAmount): // Only call timerfd_settime if we're not already scheduled one that will cover it. // This guards against calling timerfd_settime if not needed as this is generally speaking @@ -208,7 +239,14 @@ extension Selector: _SelectorBackendProtocol { } fallthrough case .block: - ready = Int(try Epoll.epoll_wait(epfd: self.selectorFD, events: events, maxevents: Int32(eventsCapacity), timeout: -1)) + ready = Int( + try Epoll.epoll_wait( + epfd: self.selectorFD, + events: events, + maxevents: Int32(eventsCapacity), + timeout: -1 + ) + ) } loopStart() @@ -240,7 +278,10 @@ extension Selector: _SelectorBackendProtocol { var selectorEvent = SelectorEventSet(epollEvent: ev) // we can only verify the events for i == 0 as for i > 0 the user might have changed the registrations since then. - assert(i != 0 || selectorEvent.isSubset(of: registration.interested), "selectorEvent: \(selectorEvent), registration: \(registration)") + assert( + i != 0 || selectorEvent.isSubset(of: registration.interested), + "selectorEvent: \(selectorEvent), registration: \(registration)" + ) // in any case we only want what the user is currently registered for & what we got selectorEvent = selectorEvent.intersection(registration.interested) @@ -259,7 +300,7 @@ extension Selector: _SelectorBackendProtocol { /// Close the `Selector`. /// /// After closing the `Selector` it's no longer possible to use it. - public func close0() throws { + public func close0() throws { self.externalSelectorFDLock.withLock { // We try! all of the closes because close can only fail in the following ways: // - EINTR, which we eat in Posix.close @@ -267,26 +308,26 @@ extension Selector: _SelectorBackendProtocol { // - EBADF, which can't happen here because we would crash as EBADF is marked unacceptable // Therefore, we assert here that close will always succeed and if not, that's a NIO bug we need to know // about. - + try! Posix.close(descriptor: self.timerFD) self.timerFD = -1 - + try! Posix.close(descriptor: self.eventFD) self.eventFD = -1 - + try! Posix.close(descriptor: self.selectorFD) self.selectorFD = -1 } } - /* attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) */ + // attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) func wakeup0() throws { assert(NIOThread.current != self.myThread) try self.externalSelectorFDLock.withLock { - guard self.eventFD >= 0 else { - throw EventLoopError.shutdown - } - _ = try EventFd.eventfd_write(fd: self.eventFD, value: 1) + guard self.eventFD >= 0 else { + throw EventLoopError.shutdown + } + _ = try EventFd.eventfd_write(fd: self.eventFD, value: 1) } } } diff --git a/Sources/NIOPosix/SelectorGeneric.swift b/Sources/NIOPosix/SelectorGeneric.swift index 73f8ef3e26..f865021bd7 100644 --- a/Sources/NIOPosix/SelectorGeneric.swift +++ b/Sources/NIOPosix/SelectorGeneric.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -import NIOCore import NIOConcurrencyHelpers +import NIOCore internal enum SelectorLifecycleState { case open @@ -86,7 +86,7 @@ struct SelectorEventSet: OptionSet, Equatable { internal let isEarlyEOFDeliveryWorkingOnThisOS: Bool = { #if canImport(Darwin) - return false // rdar://53656794 , once fixed we need to do an OS version check here. + return false // rdar://53656794 , once fixed we need to do an OS version check here. #else return true #endif @@ -100,29 +100,48 @@ internal let isEarlyEOFDeliveryWorkingOnThisOS: Bool = { protocol _SelectorBackendProtocol { associatedtype R: Registration func initialiseState0() throws - func deinitAssertions0() // allows actual implementation to run some assertions as part of the class deinit - func register0(selectable: S, fileDescriptor: CInt, interested: SelectorEventSet, registrationID: SelectorRegistrationID) throws - func reregister0(selectable: S, fileDescriptor: CInt, oldInterested: SelectorEventSet, newInterested: SelectorEventSet, registrationID: SelectorRegistrationID) throws - func deregister0(selectable: S, fileDescriptor: CInt, oldInterested: SelectorEventSet, registrationID: SelectorRegistrationID) throws - /* attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) */ + func deinitAssertions0() // allows actual implementation to run some assertions as part of the class deinit + func register0( + selectable: S, + fileDescriptor: CInt, + interested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws + func reregister0( + selectable: S, + fileDescriptor: CInt, + oldInterested: SelectorEventSet, + newInterested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws + func deregister0( + selectable: S, + fileDescriptor: CInt, + oldInterested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws + // attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) func wakeup0() throws /// Apply the given `SelectorStrategy` and execute `body` once it's complete (which may produce `SelectorEvent`s to handle). /// /// - parameters: /// - strategy: The `SelectorStrategy` to apply /// - body: The function to execute for each `SelectorEvent` that was produced. - func whenReady0(strategy: SelectorStrategy, onLoopBegin: () -> Void, _ body: (SelectorEvent) throws -> Void) throws -> Void + func whenReady0( + strategy: SelectorStrategy, + onLoopBegin: () -> Void, + _ body: (SelectorEvent) throws -> Void + ) throws func close0() throws } - /// A `Selector` allows a user to register different `Selectable` sources to an underlying OS selector, and for that selector to notify them once IO is ready for them to process. /// /// This implementation offers an consistent API over epoll/liburing (for linux) and kqueue (for Darwin, BSD). /// There are specific subclasses per API type with a shared common superclass providing overall scaffolding. -/* this is deliberately not thread-safe, only the wakeup() function may be called unprotectedly */ -internal class Selector { +// this is deliberately not thread-safe, only the wakeup() function may be called unprotectedly +internal class Selector { var lifecycleState: SelectorLifecycleState var registrations = [Int: R]() var registrationID: SelectorRegistrationID = .initialRegistrationID @@ -132,7 +151,7 @@ internal class Selector { // reads: `self.externalSelectorFDLock` OR access from the EventLoop thread // writes: `self.externalSelectorFDLock` AND access from the EventLoop thread let externalSelectorFDLock = NIOLock() - var selectorFD: CInt = -1 // -1 == we're closed + var selectorFD: CInt = -1 // -1 == we're closed // Here we add the stored properties that are used by the specific backends #if canImport(Darwin) @@ -141,15 +160,15 @@ internal class Selector { #if !SWIFTNIO_USE_IO_URING typealias EventType = Epoll.epoll_event var earliestTimer: NIODeadline = .distantFuture - var eventFD: CInt = -1 // -1 == we're closed - var timerFD: CInt = -1 // -1 == we're closed + var eventFD: CInt = -1 // -1 == we're closed + var timerFD: CInt = -1 // -1 == we're closed #else typealias EventType = URingEvent - var eventFD: CInt = -1 // -1 == we're closed + var eventFD: CInt = -1 // -1 == we're closed var ring = URing() - let multishot = URing.io_uring_use_multishot_poll // if true, we run with streaming multishot polls - let deferReregistrations = true // if true we only flush once at reentring whenReady() - saves syscalls - var deferredReregistrationsPending = false // true if flush needed when reentring whenReady() + let multishot = URing.io_uring_use_multishot_poll // if true, we run with streaming multishot polls + let deferReregistrations = true // if true we only flush once at reentring whenReady() - saves syscalls + var deferredReregistrationsPending = false // true if flush needed when reentring whenReady() #endif #else #error("Unsupported platform, no suitable selector backend (we need kqueue or epoll support)") @@ -174,7 +193,7 @@ internal class Selector { events = Selector.allocateEventsArray(capacity: eventsCapacity) try self.initialiseState0() } - + deinit { self.deinitAssertions0() assert(self.registrations.count == 0, "left-over registrations: \(self.registrations)") @@ -182,7 +201,7 @@ internal class Selector { assert(self.selectorFD == -1, "self.selectorFD == \(self.selectorFD) on Selector deinit, forgot close?") Selector.deallocateEventsArray(events: events, capacity: eventsCapacity) } - + private static func allocateEventsArray(capacity: Int) -> UnsafeMutablePointer { let events: UnsafeMutablePointer = UnsafeMutablePointer.allocate(capacity: capacity) events.initialize(to: EventType()) @@ -193,28 +212,30 @@ internal class Selector { events.deinitialize(count: capacity) events.deallocate() } - + func growEventArrayIfNeeded(ready: Int) { - assert(self.myThread == NIOThread.current) - guard ready == eventsCapacity else { - return - } - Selector.deallocateEventsArray(events: events, capacity: eventsCapacity) - - // double capacity - eventsCapacity = ready << 1 - events = Selector.allocateEventsArray(capacity: eventsCapacity) - } - + assert(self.myThread == NIOThread.current) + guard ready == eventsCapacity else { + return + } + Selector.deallocateEventsArray(events: events, capacity: eventsCapacity) + + // double capacity + eventsCapacity = ready << 1 + events = Selector.allocateEventsArray(capacity: eventsCapacity) + } + /// Register `Selectable` on the `Selector`. /// /// - parameters: /// - selectable: The `Selectable` to register. /// - interested: The `SelectorEventSet` in which we are interested and want to be notified about. /// - makeRegistration: Creates the registration data for the given `SelectorEventSet`. - func register(selectable: S, - interested: SelectorEventSet, - makeRegistration: (SelectorEventSet, SelectorRegistrationID) -> R) throws { + func register( + selectable: S, + interested: SelectorEventSet, + makeRegistration: (SelectorEventSet, SelectorRegistrationID) -> R + ) throws { assert(self.myThread == NIOThread.current) assert(interested.contains(.reset)) guard self.lifecycleState == .open else { @@ -223,10 +244,12 @@ internal class Selector { try selectable.withUnsafeHandle { fd in assert(registrations[Int(fd)] == nil) - try self.register0(selectable: selectable, - fileDescriptor: fd, - interested: interested, - registrationID: self.registrationID) + try self.register0( + selectable: selectable, + fileDescriptor: fd, + interested: interested, + registrationID: self.registrationID + ) let registration = makeRegistration(interested, self.registrationID.nextRegistrationID()) registrations[Int(fd)] = registration } @@ -245,11 +268,13 @@ internal class Selector { assert(interested.contains(.reset), "must register for at least .reset but tried registering for \(interested)") try selectable.withUnsafeHandle { fd in var reg = registrations[Int(fd)]! - try self.reregister0(selectable: selectable, - fileDescriptor: fd, - oldInterested: reg.interested, - newInterested: interested, - registrationID: reg.registrationID) + try self.reregister0( + selectable: selectable, + fileDescriptor: fd, + oldInterested: reg.interested, + newInterested: interested, + registrationID: reg.registrationID + ) reg.interested = interested self.registrations[Int(fd)] = reg } @@ -271,10 +296,12 @@ internal class Selector { guard let reg = registrations.removeValue(forKey: Int(fd)) else { return } - try self.deregister0(selectable: selectable, - fileDescriptor: fd, - oldInterested: reg.interested, - registrationID: reg.registrationID) + try self.deregister0( + selectable: selectable, + fileDescriptor: fd, + oldInterested: reg.interested, + registrationID: reg.registrationID + ) } } @@ -284,7 +311,11 @@ internal class Selector { /// - strategy: The `SelectorStrategy` to apply /// - onLoopBegin: A function executed after the selector returns, just before the main loop begins.. /// - body: The function to execute for each `SelectorEvent` that was produced. - func whenReady(strategy: SelectorStrategy, onLoopBegin loopStart: () -> Void, _ body: (SelectorEvent) throws -> Void) throws -> Void { + func whenReady( + strategy: SelectorStrategy, + onLoopBegin loopStart: () -> Void, + _ body: (SelectorEvent) throws -> Void + ) throws { try self.whenReady0(strategy: strategy, onLoopBegin: loopStart, body) } @@ -301,7 +332,7 @@ internal class Selector { self.registrations.removeAll() } - /* attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) */ + // attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) func wakeup() throws { try self.wakeup0() } @@ -310,7 +341,7 @@ internal class Selector { extension Selector: CustomStringConvertible { var description: String { func makeDescription() -> String { - return "Selector { descriptor = \(self.selectorFD) }" + "Selector { descriptor = \(self.selectorFD) }" } if NIOThread.current == self.myThread { @@ -344,10 +375,13 @@ extension Selector where R == NIORegistration { func closeGently(eventLoop: EventLoop) -> EventLoopFuture { assert(self.myThread == NIOThread.current) guard self.lifecycleState == .open else { - return eventLoop.makeFailedFuture(IOError(errnoCode: EBADF, reason: "can't close selector gently as it's \(self.lifecycleState).")) + return eventLoop.makeFailedFuture( + IOError(errnoCode: EBADF, reason: "can't close selector gently as it's \(self.lifecycleState).") + ) } - let futures: [EventLoopFuture] = self.registrations.map { (_, reg: NIORegistration) -> EventLoopFuture in + let futures: [EventLoopFuture] = self.registrations.map { + (_, reg: NIORegistration) -> EventLoopFuture in // The futures will only be notified (of success) once also the closeFuture of each Channel is notified. // This only happens after all other actions on the Channel is complete and all events are propagated through the // ChannelPipeline. We do this to minimize the risk to left over any tasks / promises that are tied to the @@ -406,11 +440,11 @@ enum SelectorStrategy { @usableFromInline var _rawValue: UInt32 @inlinable var rawValue: UInt32 { - return self._rawValue + self._rawValue } @inlinable static var initialRegistrationID: SelectorRegistrationID { - return SelectorRegistrationID(rawValue: .max) + SelectorRegistrationID(rawValue: .max) } @inlinable mutating func nextRegistrationID() -> SelectorRegistrationID { @@ -424,8 +458,8 @@ enum SelectorStrategy { self._rawValue = rawValue } - @inlinable static func ==(_ lhs: SelectorRegistrationID, _ rhs: SelectorRegistrationID) -> Bool { - return lhs._rawValue == rhs._rawValue + @inlinable static func == (_ lhs: SelectorRegistrationID, _ rhs: SelectorRegistrationID) -> Bool { + lhs._rawValue == rhs._rawValue } @inlinable func hash(into hasher: inout Hasher) { diff --git a/Sources/NIOPosix/SelectorKqueue.swift b/Sources/NIOPosix/SelectorKqueue.swift index bc10fea3f4..dc00367808 100644 --- a/Sources/NIOPosix/SelectorKqueue.swift +++ b/Sources/NIOPosix/SelectorKqueue.swift @@ -62,14 +62,17 @@ extension KQueueEventFilterSet { /// - previousKQueueFilterSet: The previous filter set that is currently registered with kqueue. /// - fileDescriptor: The file descriptor the `kevent`s should be generated to. /// - body: The closure that will then apply the change set. - func calculateKQueueFilterSetChanges(previousKQueueFilterSet: KQueueEventFilterSet, - fileDescriptor: CInt, - registrationID: SelectorRegistrationID, - _ body: (UnsafeMutableBufferPointer) throws -> Void) rethrows { + func calculateKQueueFilterSetChanges( + previousKQueueFilterSet: KQueueEventFilterSet, + fileDescriptor: CInt, + registrationID: SelectorRegistrationID, + _ body: (UnsafeMutableBufferPointer) throws -> Void + ) rethrows { // we only use three filters (EVFILT_READ, EVFILT_WRITE and EVFILT_EXCEPT) so the number of changes would be 3. var kevents = KeventTriple() - let differences = previousKQueueFilterSet.symmetricDifference(self) // contains all the events that need a change (either need to be added or removed) + // contains all the events that need a change (either need to be added or removed) + let differences = previousKQueueFilterSet.symmetricDifference(self) func calculateKQueueChange(event: KQueueEventFilterSet) -> UInt16? { guard differences.contains(event) else { @@ -78,9 +81,16 @@ extension KQueueEventFilterSet { return UInt16(self.contains(event) ? EV_ADD : EV_DELETE) } - for (event, filter) in [(KQueueEventFilterSet.read, EVFILT_READ), (.write, EVFILT_WRITE), (.except, EVFILT_EXCEPT)] { + for (event, filter) in [ + (KQueueEventFilterSet.read, EVFILT_READ), (.write, EVFILT_WRITE), (.except, EVFILT_EXCEPT), + ] { if let flags = calculateKQueueChange(event: event) { - kevents.appendEvent(fileDescriptor: fileDescriptor, filter: filter, flags: flags, registrationID: registrationID) + kevents.appendEvent( + fileDescriptor: fileDescriptor, + filter: filter, + flags: flags, + registrationID: registrationID + ) } } @@ -94,7 +104,7 @@ extension SelectorRegistrationID { } } -/* this is deliberately not thread-safe, only the wakeup() function may be called unprotectedly */ +// this is deliberately not thread-safe, only the wakeup() function may be called unprotectedly extension Selector: _SelectorBackendProtocol { private static func toKQueueTimeSpec(strategy: SelectorStrategy) -> timespec? { switch strategy { @@ -133,12 +143,14 @@ extension Selector: _SelectorBackendProtocol { return } do { - try KQueue.kevent(kq: self.selectorFD, - changelist: keventBuffer.baseAddress!, - nchanges: CInt(keventBuffer.count), - eventlist: nil, - nevents: 0, - timeout: nil) + try KQueue.kevent( + kq: self.selectorFD, + changelist: keventBuffer.baseAddress!, + nchanges: CInt(keventBuffer.count), + eventlist: nil, + nevents: 0, + timeout: nil + ) } catch let err as IOError { if err.errnoCode == EINTR { // See https://www.freebsd.org/cgi/man.cgi?query=kqueue&sektion=2 @@ -149,7 +161,12 @@ extension Selector: _SelectorBackendProtocol { } } - private func kqueueUpdateEventNotifications(selectable: S, interested: SelectorEventSet, oldInterested: SelectorEventSet?, registrationID: SelectorRegistrationID) throws { + private func kqueueUpdateEventNotifications( + selectable: S, + interested: SelectorEventSet, + oldInterested: SelectorEventSet?, + registrationID: SelectorRegistrationID + ) throws { assert(self.myThread == NIOThread.current) let oldKQueueFilters = KQueueEventFilterSet(selectorEventSet: oldInterested ?? ._none) let newKQueueFilters = KQueueEventFilterSet(selectorEventSet: interested) @@ -157,15 +174,17 @@ extension Selector: _SelectorBackendProtocol { assert(oldInterested?.contains(.reset) ?? true) try selectable.withUnsafeHandle { - try newKQueueFilters.calculateKQueueFilterSetChanges(previousKQueueFilterSet: oldKQueueFilters, - fileDescriptor: $0, - registrationID: registrationID, - kqueueApplyEventChangeSet) + try newKQueueFilters.calculateKQueueFilterSetChanges( + previousKQueueFilterSet: oldKQueueFilters, + fileDescriptor: $0, + registrationID: registrationID, + kqueueApplyEventChangeSet + ) } } - + func initialiseState0() throws { - + self.selectorFD = try KQueue.kqueue() self.lifecycleState = .open @@ -180,20 +199,51 @@ extension Selector: _SelectorBackendProtocol { try kqueueApplyEventChangeSet(keventBuffer: UnsafeMutableBufferPointer(start: ptr, count: 1)) } } - + func deinitAssertions0() { } - - func register0(selectable: S, fileDescriptor: CInt, interested: SelectorEventSet, registrationID: SelectorRegistrationID) throws { - try kqueueUpdateEventNotifications(selectable: selectable, interested: interested, oldInterested: nil, registrationID: registrationID) + + func register0( + selectable: S, + fileDescriptor: CInt, + interested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws { + try kqueueUpdateEventNotifications( + selectable: selectable, + interested: interested, + oldInterested: nil, + registrationID: registrationID + ) } - func reregister0(selectable: S, fileDescriptor: CInt, oldInterested: SelectorEventSet, newInterested: SelectorEventSet, registrationID: SelectorRegistrationID) throws { - try kqueueUpdateEventNotifications(selectable: selectable, interested: newInterested, oldInterested: oldInterested, registrationID: registrationID) + func reregister0( + selectable: S, + fileDescriptor: CInt, + oldInterested: SelectorEventSet, + newInterested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws { + try kqueueUpdateEventNotifications( + selectable: selectable, + interested: newInterested, + oldInterested: oldInterested, + registrationID: registrationID + ) } - - func deregister0(selectable: S, fileDescriptor: CInt, oldInterested: SelectorEventSet, registrationID: SelectorRegistrationID) throws { - try kqueueUpdateEventNotifications(selectable: selectable, interested: .reset, oldInterested: oldInterested, registrationID: registrationID) + + func deregister0( + selectable: S, + fileDescriptor: CInt, + oldInterested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws { + try kqueueUpdateEventNotifications( + selectable: selectable, + interested: .reset, + oldInterested: oldInterested, + registrationID: registrationID + ) } /// Apply the given `SelectorStrategy` and execute `body` once it's complete (which may produce `SelectorEvent`s to handle). @@ -201,7 +251,11 @@ extension Selector: _SelectorBackendProtocol { /// - parameters: /// - strategy: The `SelectorStrategy` to apply /// - body: The function to execute for each `SelectorEvent` that was produced. - func whenReady0(strategy: SelectorStrategy, onLoopBegin loopStart: () -> Void, _ body: (SelectorEvent) throws -> Void) throws -> Void { + func whenReady0( + strategy: SelectorStrategy, + onLoopBegin loopStart: () -> Void, + _ body: (SelectorEvent) throws -> Void + ) throws { assert(self.myThread == NIOThread.current) guard self.lifecycleState == .open else { throw IOError(errnoCode: EBADF, reason: "can't call whenReady for selector as it's \(self.lifecycleState).") @@ -209,7 +263,16 @@ extension Selector: _SelectorBackendProtocol { let timespec = Selector.toKQueueTimeSpec(strategy: strategy) let ready = try timespec.withUnsafeOptionalPointer { ts in - Int(try KQueue.kevent(kq: self.selectorFD, changelist: nil, nchanges: 0, eventlist: events, nevents: Int32(eventsCapacity), timeout: ts)) + Int( + try KQueue.kevent( + kq: self.selectorFD, + changelist: nil, + nchanges: 0, + eventlist: events, + nevents: Int32(eventsCapacity), + timeout: ts + ) + ) } loopStart() @@ -219,7 +282,10 @@ extension Selector: _SelectorBackendProtocol { let filter = Int32(ev.filter) let eventRegistrationID = SelectorRegistrationID(kqueueUData: ev.udata) guard Int32(ev.flags) & EV_ERROR == 0 else { - throw IOError(errnoCode: Int32(ev.data), reason: "kevent returned with EV_ERROR set: \(String(describing: ev))") + throw IOError( + errnoCode: Int32(ev.data), + reason: "kevent returned with EV_ERROR set: \(String(describing: ev))" + ) } guard filter != EVFILT_USER, let registration = registrations[Int(ev.ident)] else { continue @@ -231,7 +297,7 @@ extension Selector: _SelectorBackendProtocol { switch filter { case EVFILT_READ: selectorEvent.formUnion(.read) - fallthrough // falling through here as `EVFILT_READ` also delivers `EV_EOF` (meaning `.readEOF`) + fallthrough // falling through here as `EVFILT_READ` also delivers `EV_EOF` (meaning `.readEOF`) case EVFILT_EXCEPT: if Int32(ev.flags) & EV_EOF != 0 && registration.interested.contains(.readEOF) { // we only add `.readEOF` if it happened and the user asked for it @@ -247,7 +313,10 @@ extension Selector: _SelectorBackendProtocol { selectorEvent.formUnion(.reset) } // we can only verify the events for i == 0 as for i > 0 the user might have changed the registrations since then. - assert(i != 0 || selectorEvent.isSubset(of: registration.interested), "selectorEvent: \(selectorEvent), registration: \(registration)") + assert( + i != 0 || selectorEvent.isSubset(of: registration.interested), + "selectorEvent: \(selectorEvent), registration: \(registration)" + ) // in any case we only want what the user is currently registered for & what we got selectorEvent = selectorEvent.intersection(registration.interested) @@ -264,7 +333,7 @@ extension Selector: _SelectorBackendProtocol { /// Close the `Selector`. /// /// After closing the `Selector` it's no longer possible to use it. - func close0() throws { + func close0() throws { self.externalSelectorFDLock.withLock { // We try! all of the closes because close can only fail in the following ways: @@ -275,37 +344,37 @@ extension Selector: _SelectorBackendProtocol { // about. // We limit close to only be for positive FD:s though, as subclasses (e.g. uring) // may already have closed some of these FD:s in their close function. - - + try! Posix.close(descriptor: self.selectorFD) self.selectorFD = -1 } } - /* attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) */ + // attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) func wakeup0() throws { assert(NIOThread.current != self.myThread) try self.externalSelectorFDLock.withLock { - guard self.selectorFD >= 0 else { - throw EventLoopError._shutdown - } - var event = kevent() - event.ident = 0 - event.filter = Int16(EVFILT_USER) - event.fflags = UInt32(NOTE_TRIGGER | NOTE_FFNOP) - event.data = 0 - event.udata = nil - event.flags = 0 - try withUnsafeMutablePointer(to: &event) { ptr in - try self.kqueueApplyEventChangeSet(keventBuffer: UnsafeMutableBufferPointer(start: ptr, count: 1)) - } + guard self.selectorFD >= 0 else { + throw EventLoopError._shutdown + } + var event = kevent() + event.ident = 0 + event.filter = Int16(EVFILT_USER) + event.fflags = UInt32(NOTE_TRIGGER | NOTE_FFNOP) + event.data = 0 + event.udata = nil + event.flags = 0 + try withUnsafeMutablePointer(to: &event) { ptr in + try self.kqueueApplyEventChangeSet(keventBuffer: UnsafeMutableBufferPointer(start: ptr, count: 1)) + } } } } extension kevent { /// Update a kevent for a given filter, file descriptor, and set of flags. - mutating func setEvent(fileDescriptor fd: CInt, filter: CInt, flags: UInt16, registrationID: SelectorRegistrationID) { + mutating func setEvent(fileDescriptor fd: CInt, filter: CInt, flags: UInt16, registrationID: SelectorRegistrationID) + { self.ident = UInt(fd) self.filter = Int16(filter) self.flags = flags @@ -331,7 +400,7 @@ extension kevent { /// set changes. We want to be able to store these kevent objects on the stack, which we historically did with /// unsafe pointers. This object replaces that unsafe code with safe code, and attempts to achieve the same /// performance constraints. -fileprivate struct KeventTriple { +private struct KeventTriple { // We need to store this in a tuple to achieve C-style memory layout. private var kevents = (kevent(), kevent(), kevent()) @@ -364,14 +433,24 @@ fileprivate struct KeventTriple { } } - mutating func appendEvent(fileDescriptor fd: CInt, filter: CInt, flags: UInt16, registrationID: SelectorRegistrationID) { + mutating func appendEvent( + fileDescriptor fd: CInt, + filter: CInt, + flags: UInt16, + registrationID: SelectorRegistrationID + ) { defer { // Unchecked math is safe here: we access through the subscript, which will trap on out-of-bounds value, so we'd trap // well before we overflow. self.initialized &+= 1 } - self[self.initialized].setEvent(fileDescriptor: fd, filter: filter, flags: flags, registrationID: registrationID) + self[self.initialized].setEvent( + fileDescriptor: fd, + filter: filter, + flags: flags, + registrationID: registrationID + ) } mutating func withUnsafeBufferPointer(_ body: (UnsafeMutableBufferPointer) throws -> Void) rethrows { diff --git a/Sources/NIOPosix/SelectorUring.swift b/Sources/NIOPosix/SelectorUring.swift index b418feb7e6..e8c0ba84e9 100644 --- a/Sources/NIOPosix/SelectorUring.swift +++ b/Sources/NIOPosix/SelectorUring.swift @@ -73,8 +73,8 @@ extension SelectorEventSet { if uringFilters.contains(.readHangup) { filter |= URing.POLLRDHUP } - assert(filter & URing.POLLHUP != 0) // both of these are reported - assert(filter & URing.POLLERR != 0) // always and can't be masked. + assert(filter & URing.POLLHUP != 0) // both of these are reported + assert(filter & URing.POLLERR != 0) // always and can't be masked. return filter } @@ -116,10 +116,12 @@ extension Selector: _SelectorBackendProtocol { self.eventFD = try EventFd.eventfd(initval: 0, flags: Int32(EventFd.EFD_CLOEXEC | EventFd.EFD_NONBLOCK)) - ring.io_uring_prep_poll_add(fileDescriptor: self.eventFD, - pollMask: URing.POLLIN, - registrationID:SelectorRegistrationID(rawValue: 0), - multishot:false) // wakeups + ring.io_uring_prep_poll_add( + fileDescriptor: self.eventFD, + pollMask: URing.POLLIN, + registrationID: SelectorRegistrationID(rawValue: 0), + multishot: false + ) // wakeups self.lifecycleState = .open _debugPrint("URingSelector up and running fd [\(self.selectorFD)] wakeups on event_fd [\(self.eventFD)]") @@ -129,60 +131,83 @@ extension Selector: _SelectorBackendProtocol { assert(self.eventFD == -1, "self.eventFD == \(self.eventFD) on deinitAssertions0 deinit, forgot close?") } - func register0(selectable: S, - fileDescriptor: CInt, - interested: SelectorEventSet, - registrationID: SelectorRegistrationID) throws { - _debugPrint("register interested \(interested) uringEventSet [\(interested.uringEventSet)] registrationID[\(registrationID)]") + func register0( + selectable: S, + fileDescriptor: CInt, + interested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws { + _debugPrint( + "register interested \(interested) uringEventSet [\(interested.uringEventSet)] registrationID[\(registrationID)]" + ) self.deferredReregistrationsPending = true - ring.io_uring_prep_poll_add(fileDescriptor: fileDescriptor, - pollMask: interested.uringEventSet, - registrationID: registrationID, - submitNow: !deferReregistrations, - multishot: multishot) + ring.io_uring_prep_poll_add( + fileDescriptor: fileDescriptor, + pollMask: interested.uringEventSet, + registrationID: registrationID, + submitNow: !deferReregistrations, + multishot: multishot + ) } - func reregister0(selectable: S, - fileDescriptor: CInt, - oldInterested: SelectorEventSet, - newInterested: SelectorEventSet, - registrationID: SelectorRegistrationID) throws { - _debugPrint("Re-register old \(oldInterested) new \(newInterested) uringEventSet [\(oldInterested.uringEventSet)] reg.uringEventSet [\(newInterested.uringEventSet)]") + func reregister0( + selectable: S, + fileDescriptor: CInt, + oldInterested: SelectorEventSet, + newInterested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws { + _debugPrint( + "Re-register old \(oldInterested) new \(newInterested) uringEventSet [\(oldInterested.uringEventSet)] reg.uringEventSet [\(newInterested.uringEventSet)]" + ) self.deferredReregistrationsPending = true if multishot { - ring.io_uring_poll_update(fileDescriptor: fileDescriptor, - newPollmask: newInterested.uringEventSet, - oldPollmask: oldInterested.uringEventSet, - registrationID: registrationID, - submitNow: !deferReregistrations, - multishot: true) + ring.io_uring_poll_update( + fileDescriptor: fileDescriptor, + newPollmask: newInterested.uringEventSet, + oldPollmask: oldInterested.uringEventSet, + registrationID: registrationID, + submitNow: !deferReregistrations, + multishot: true + ) } else { - ring.io_uring_prep_poll_remove(fileDescriptor: fileDescriptor, - pollMask: oldInterested.uringEventSet, - registrationID: registrationID, - submitNow:!deferReregistrations, - link: true) // next event linked will cancel if this event fails - - ring.io_uring_prep_poll_add(fileDescriptor: fileDescriptor, - pollMask: newInterested.uringEventSet, - registrationID: registrationID, - submitNow: !deferReregistrations, - multishot: false) + ring.io_uring_prep_poll_remove( + fileDescriptor: fileDescriptor, + pollMask: oldInterested.uringEventSet, + registrationID: registrationID, + submitNow: !deferReregistrations, + link: true + ) // next event linked will cancel if this event fails + + ring.io_uring_prep_poll_add( + fileDescriptor: fileDescriptor, + pollMask: newInterested.uringEventSet, + registrationID: registrationID, + submitNow: !deferReregistrations, + multishot: false + ) } } - func deregister0(selectable: S, fileDescriptor: CInt, oldInterested: SelectorEventSet, registrationID: SelectorRegistrationID) throws { + func deregister0( + selectable: S, + fileDescriptor: CInt, + oldInterested: SelectorEventSet, + registrationID: SelectorRegistrationID + ) throws { _debugPrint("deregister interested \(selectable) reg.interested.uringEventSet [\(oldInterested.uringEventSet)]") self.deferredReregistrationsPending = true - ring.io_uring_prep_poll_remove(fileDescriptor: fileDescriptor, - pollMask: oldInterested.uringEventSet, - registrationID: registrationID, - submitNow:!deferReregistrations) + ring.io_uring_prep_poll_remove( + fileDescriptor: fileDescriptor, + pollMask: oldInterested.uringEventSet, + registrationID: registrationID, + submitNow: !deferReregistrations + ) } - private func shouldRefreshPollForEvent(selectorEvent:SelectorEventSet) -> Bool { + private func shouldRefreshPollForEvent(selectorEvent: SelectorEventSet) -> Bool { if selectorEvent.contains(.read) { // as we don't do exhaustive reads, we need to prod the kernel for // new events, would be even better if we knew if we had read all there is @@ -197,7 +222,11 @@ extension Selector: _SelectorBackendProtocol { /// - parameters: /// - strategy: The `SelectorStrategy` to apply /// - body: The function to execute for each `SelectorEvent` that was produced. - func whenReady0(strategy: SelectorStrategy, onLoopBegin loopStart: () -> Void, _ body: (SelectorEvent) throws -> Void) throws -> Void { + func whenReady0( + strategy: SelectorStrategy, + onLoopBegin loopStart: () -> Void, + _ body: (SelectorEvent) throws -> Void + ) throws { assert(self.myThread == NIOThread.current) guard self.lifecycleState == .open else { throw IOError(errnoCode: EBADF, reason: "can't call whenReady for selector as it's \(self.lifecycleState).") @@ -219,17 +248,32 @@ extension Selector: _SelectorBackendProtocol { switch strategy { case .now: _debugPrint("whenReady.now") - ready = Int(ring.io_uring_peek_batch_cqe(events: events, maxevents: UInt32(eventsCapacity), multishot:multishot)) + ready = Int( + ring.io_uring_peek_batch_cqe(events: events, maxevents: UInt32(eventsCapacity), multishot: multishot) + ) case .blockUntilTimeout(let timeAmount): _debugPrint("whenReady.blockUntilTimeout") - ready = try Int(ring.io_uring_wait_cqe_timeout(events: events, maxevents: UInt32(eventsCapacity), timeout:timeAmount, multishot:multishot)) + ready = try Int( + ring.io_uring_wait_cqe_timeout( + events: events, + maxevents: UInt32(eventsCapacity), + timeout: timeAmount, + multishot: multishot + ) + ) case .block: _debugPrint("whenReady.block") - ready = Int(ring.io_uring_peek_batch_cqe(events: events, maxevents: UInt32(eventsCapacity), multishot:multishot)) // first try to consume any existing + ready = Int( + ring.io_uring_peek_batch_cqe(events: events, maxevents: UInt32(eventsCapacity), multishot: multishot) + ) // first try to consume any existing - if (ready <= 0) // otherwise block (only single supported, but we will use batch peek cqe next run around... + if ready <= 0 // otherwise block (only single supported, but we will use batch peek cqe next run around... { - ready = try ring.io_uring_wait_cqe(events: events, maxevents: UInt32(eventsCapacity), multishot:multishot) + ready = try ring.io_uring_wait_cqe( + events: events, + maxevents: UInt32(eventsCapacity), + multishot: multishot + ) } } @@ -239,27 +283,32 @@ extension Selector: _SelectorBackendProtocol { let event = events[i] switch event.fd { - case self.eventFD: // we don't run these as multishots to avoid tons of events when many wakeups are done - _debugPrint("wakeup successful for event.fd [\(event.fd)]") - var val = EventFd.eventfd_t() - ring.io_uring_prep_poll_add(fileDescriptor: self.eventFD, - pollMask: URing.POLLIN, - registrationID: SelectorRegistrationID(rawValue: 0), - submitNow: false, - multishot: false) - do { - _ = try EventFd.eventfd_read(fd: self.eventFD, value: &val) // consume wakeup event - _debugPrint("read val [\(val)] from event.fd [\(event.fd)]") - } catch { - } + case self.eventFD: // we don't run these as multishots to avoid tons of events when many wakeups are done + _debugPrint("wakeup successful for event.fd [\(event.fd)]") + var val = EventFd.eventfd_t() + ring.io_uring_prep_poll_add( + fileDescriptor: self.eventFD, + pollMask: URing.POLLIN, + registrationID: SelectorRegistrationID(rawValue: 0), + submitNow: false, + multishot: false + ) + do { + _ = try EventFd.eventfd_read(fd: self.eventFD, value: &val) // consume wakeup event + _debugPrint("read val [\(val)] from event.fd [\(event.fd)]") + } catch { + } default: if let registration = registrations[Int(event.fd)] { - _debugPrint("We found a registration for event.fd [\(event.fd)]") // \(registration) + _debugPrint("We found a registration for event.fd [\(event.fd)]") // \(registration) // The io_uring backend only has 16 bits available for the registration id - guard event.registrationID == UInt16(truncatingIfNeeded:registration.registrationID.rawValue) else { - _debugPrint("The event.registrationID [\(event.registrationID)] != registration.selectableregistrationID [\(registration.registrationID)], skipping to next event") + guard event.registrationID == UInt16(truncatingIfNeeded: registration.registrationID.rawValue) + else { + _debugPrint( + "The event.registrationID [\(event.registrationID)] != registration.selectableregistrationID [\(registration.registrationID)], skipping to next event" + ) continue } @@ -273,15 +322,17 @@ extension Selector: _SelectorBackendProtocol { _debugPrint("intersection [\(selectorEvent)]") if selectorEvent.contains(.readEOF) { - _debugPrint("selectorEvent.contains(.readEOF) [\(selectorEvent.contains(.readEOF))]") + _debugPrint("selectorEvent.contains(.readEOF) [\(selectorEvent.contains(.readEOF))]") } - if multishot == false { // must be before guard, otherwise lost wake - ring.io_uring_prep_poll_add(fileDescriptor: event.fd, - pollMask: registration.interested.uringEventSet, - registrationID: registration.registrationID, - submitNow: false, - multishot: false) + if multishot == false { // must be before guard, otherwise lost wake + ring.io_uring_prep_poll_add( + fileDescriptor: event.fd, + pollMask: registration.interested.uringEventSet, + registrationID: registration.registrationID, + submitNow: false, + multishot: false + ) if event.pollCancelled { _debugPrint("Received event.pollCancelled") @@ -289,39 +340,51 @@ extension Selector: _SelectorBackendProtocol { } guard selectorEvent != ._none else { - _debugPrint("selectorEvent != ._none / [\(selectorEvent)] [\(registration.interested)] [\(SelectorEventSet(uringEvent: event.pollMask))] [\(event.pollMask)] [\(event.fd)]") + _debugPrint( + "selectorEvent != ._none / [\(selectorEvent)] [\(registration.interested)] [\(SelectorEventSet(uringEvent: event.pollMask))] [\(event.pollMask)] [\(event.fd)]" + ) continue } // This is only needed due to the edge triggered nature of liburing, possibly // we can get away with only updating (force triggering an event if available) for // partial reads (where we currently give up after N iterations) - if multishot && self.shouldRefreshPollForEvent(selectorEvent:selectorEvent) { // can be after guard as it is multishot - ring.io_uring_poll_update(fileDescriptor: event.fd, - newPollmask: registration.interested.uringEventSet, - oldPollmask: registration.interested.uringEventSet, - registrationID: registration.registrationID, - submitNow: false) + + // can be after guard as it is multishot + if multishot && self.shouldRefreshPollForEvent(selectorEvent: selectorEvent) { + ring.io_uring_poll_update( + fileDescriptor: event.fd, + newPollmask: registration.interested.uringEventSet, + oldPollmask: registration.interested.uringEventSet, + registrationID: registration.registrationID, + submitNow: false + ) } - _debugPrint("running body [\(NIOThread.current)] \(selectorEvent) \(SelectorEventSet(uringEvent: event.pollMask))") + _debugPrint( + "running body [\(NIOThread.current)] \(selectorEvent) \(SelectorEventSet(uringEvent: event.pollMask))" + ) try body((SelectorEvent(io: selectorEvent, registration: registration))) - } else { // remove any polling if we don't have a registration for it - _debugPrint("We had no registration for event.fd [\(event.fd)] event.pollMask [\(event.pollMask)] event.registrationID [\(event.registrationID)], it should be deregistered already") + } else { // remove any polling if we don't have a registration for it + _debugPrint( + "We had no registration for event.fd [\(event.fd)] event.pollMask [\(event.pollMask)] event.registrationID [\(event.registrationID)], it should be deregistered already" + ) if multishot == false { - ring.io_uring_prep_poll_remove(fileDescriptor: event.fd, - pollMask: event.pollMask, - registrationID: SelectorRegistrationID(rawValue: UInt32(event.registrationID)), - submitNow: false) + ring.io_uring_prep_poll_remove( + fileDescriptor: event.fd, + pollMask: event.pollMask, + registrationID: SelectorRegistrationID(rawValue: UInt32(event.registrationID)), + submitNow: false + ) } } } } - self.deferredReregistrationsPending = false // none pending as we will flush here - ring.io_uring_flush() // flush reregisteration of the polls if needed (nop in SQPOLL mode) + self.deferredReregistrationsPending = false // none pending as we will flush here + ring.io_uring_flush() // flush reregisteration of the polls if needed (nop in SQPOLL mode) growEventArrayIfNeeded(ready: ready) } @@ -337,7 +400,7 @@ extension Selector: _SelectorBackendProtocol { // Therefore, we assert here that close will always succeed and if not, that's a NIO bug we need to know // about. - ring.io_uring_queue_exit() // This closes the ring selector fd for us + ring.io_uring_queue_exit() // This closes the ring selector fd for us self.selectorFD = -1 try! Posix.close(descriptor: self.eventFD) @@ -346,14 +409,14 @@ extension Selector: _SelectorBackendProtocol { return } - /* attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) */ + // attention, this may (will!) be called from outside the event loop, ie. can't access mutable shared state (such as `self.open`) func wakeup0() throws { assert(NIOThread.current != self.myThread) try self.externalSelectorFDLock.withLock { - guard self.eventFD >= 0 else { - throw EventLoopError.shutdown - } - _ = try EventFd.eventfd_write(fd: self.eventFD, value: 1) + guard self.eventFD >= 0 else { + throw EventLoopError.shutdown + } + _ = try EventFd.eventfd_write(fd: self.eventFD, value: 1) } } } diff --git a/Sources/NIOPosix/ServerSocket.swift b/Sources/NIOPosix/ServerSocket.swift index 1868d62139..153bd36692 100644 --- a/Sources/NIOPosix/ServerSocket.swift +++ b/Sources/NIOPosix/ServerSocket.swift @@ -15,11 +15,15 @@ import NIOCore /// A server socket that can accept new connections. -/* final but tests */ class ServerSocket: BaseSocket, ServerSocketProtocol { +class ServerSocket: BaseSocket, ServerSocketProtocol { typealias SocketType = ServerSocket private let cleanupOnClose: Bool - public final class func bootstrap(protocolFamily: NIOBSDSocket.ProtocolFamily, host: String, port: Int) throws -> ServerSocket { + public final class func bootstrap( + protocolFamily: NIOBSDSocket.ProtocolFamily, + host: String, + port: Int + ) throws -> ServerSocket { let socket = try ServerSocket(protocolFamily: protocolFamily) try socket.bind(to: SocketAddress.makeAddressResolvingHost(host, port: port)) try socket.listen() @@ -34,8 +38,17 @@ import NIOCore /// argument to the socket syscall. Defaults to 0. /// - setNonBlocking: Set non-blocking mode on the socket. /// - throws: An `IOError` if creation of the socket failed. - init(protocolFamily: NIOBSDSocket.ProtocolFamily, protocolSubtype: NIOBSDSocket.ProtocolSubtype = .default, setNonBlocking: Bool = false) throws { - let sock = try BaseSocket.makeSocket(protocolFamily: protocolFamily, type: .stream, protocolSubtype: protocolSubtype, setNonBlocking: setNonBlocking) + init( + protocolFamily: NIOBSDSocket.ProtocolFamily, + protocolSubtype: NIOBSDSocket.ProtocolSubtype = .default, + setNonBlocking: Bool = false + ) throws { + let sock = try BaseSocket.makeSocket( + protocolFamily: protocolFamily, + type: .stream, + protocolSubtype: protocolSubtype, + setNonBlocking: setNonBlocking + ) switch protocolFamily { case .unix: cleanupOnClose = true @@ -52,10 +65,10 @@ import NIOCore /// - setNonBlocking: Set non-blocking mode on the socket. /// - throws: An `IOError` if socket is invalid. #if !os(Windows) - @available(*, deprecated, renamed: "init(socket:setNonBlocking:)") - convenience init(descriptor: CInt, setNonBlocking: Bool = false) throws { - try self.init(socket: descriptor, setNonBlocking: setNonBlocking) - } + @available(*, deprecated, renamed: "init(socket:setNonBlocking:)") + convenience init(descriptor: CInt, setNonBlocking: Bool = false) throws { + try self.init(socket: descriptor, setNonBlocking: setNonBlocking) + } #endif /// Create a new instance. @@ -90,7 +103,7 @@ import NIOCore /// - returns: A `Socket` once a new connection was established or `nil` if this `ServerSocket` is in non-blocking mode and there is no new connection that can be accepted when this method is called. /// - throws: An `IOError` if the operation failed. func accept(setNonBlocking: Bool = false) throws -> Socket? { - return try withUnsafeHandle { fd in + try withUnsafeHandle { fd in #if os(Linux) let flags: Int32 if setNonBlocking { diff --git a/Sources/NIOPosix/Socket.swift b/Sources/NIOPosix/Socket.swift index 9bd16a80bf..450c3a6173 100644 --- a/Sources/NIOPosix/Socket.swift +++ b/Sources/NIOPosix/Socket.swift @@ -18,7 +18,7 @@ import NIOCore typealias IOVector = iovec // TODO: scattering support -/* final but tests */ class Socket: BaseSocket, SocketProtocol { +class Socket: BaseSocket, SocketProtocol { typealias SocketType = Socket /// The maximum number of bytes to write per `writev` call. @@ -58,10 +58,10 @@ typealias IOVector = iovec /// - setNonBlocking: Set non-blocking mode on the socket. /// - throws: An `IOError` if could not change the socket into non-blocking #if !os(Windows) - @available(*, deprecated, renamed: "init(socket:setNonBlocking:)") - convenience init(descriptor: CInt, setNonBlocking: Bool) throws { - try self.init(socket: descriptor, setNonBlocking: setNonBlocking) - } + @available(*, deprecated, renamed: "init(socket:setNonBlocking:)") + convenience init(descriptor: CInt, setNonBlocking: Bool) throws { + try self.init(socket: descriptor, setNonBlocking: setNonBlocking) + } #endif /// Create a new instance out of an already established socket. @@ -85,10 +85,10 @@ typealias IOVector = iovec /// - parameters: /// - descriptor: The file descriptor to wrap. #if !os(Windows) - @available(*, deprecated, renamed: "init(socket:)") - convenience init(descriptor: CInt) throws { - try self.init(socket: descriptor) - } + @available(*, deprecated, renamed: "init(socket:)") + convenience init(descriptor: CInt) throws { + try self.init(socket: descriptor) + } #endif /// Create a new instance. @@ -109,19 +109,25 @@ typealias IOVector = iovec /// - returns: `true` if the connection attempt completes, `false` if `finishConnect` must be called later to complete the connection attempt. /// - throws: An `IOError` if the operation failed. func connect(to address: SocketAddress) throws -> Bool { - return try withUnsafeHandle { fd in - return try address.withSockAddr { (ptr, size) in - return try NIOBSDSocket.connect(socket: fd, address: ptr, - address_len: socklen_t(size)) + try withUnsafeHandle { fd in + try address.withSockAddr { (ptr, size) in + try NIOBSDSocket.connect( + socket: fd, + address: ptr, + address_len: socklen_t(size) + ) } } } func connect(to address: VsockAddress) throws -> Bool { - return try withUnsafeHandle { fd in - return try address.withSockAddr { (ptr, size) in - return try NIOBSDSocket.connect(socket: fd, address: ptr, - address_len: socklen_t(size)) + try withUnsafeHandle { fd in + try address.withSockAddr { (ptr, size) in + try NIOBSDSocket.connect( + socket: fd, + address: ptr, + address_len: socklen_t(size) + ) } } } @@ -143,9 +149,12 @@ typealias IOVector = iovec /// - returns: The `IOResult` which indicates how much data could be written and if the operation returned before all could be written (because the socket is in non-blocking mode). /// - throws: An `IOError` if the operation failed. func write(pointer: UnsafeRawBufferPointer) throws -> IOResult { - return try withUnsafeHandle { - try NIOBSDSocket.send(socket: $0, buffer: pointer.baseAddress!, - length: pointer.count) + try withUnsafeHandle { + try NIOBSDSocket.send( + socket: $0, + buffer: pointer.baseAddress!, + length: pointer.count + ) } } @@ -156,7 +165,7 @@ typealias IOVector = iovec /// - returns: The `IOResult` which indicates how much data could be written and if the operation returned before all could be written (because the socket is in non-blocking mode). /// - throws: An `IOError` if the operation failed. func writev(iovecs: UnsafeBufferPointer) throws -> IOResult { - return try withUnsafeHandle { + try withUnsafeHandle { try Posix.writev(descriptor: $0, iovecs: iovecs) } } @@ -171,30 +180,42 @@ typealias IOVector = iovec /// - returns: The `IOResult` which indicates how much data could be written and if the operation returned before all could be written /// (because the socket is in non-blocking mode). /// - throws: An `IOError` if the operation failed. - func sendmsg(pointer: UnsafeRawBufferPointer, - destinationPtr: UnsafePointer?, - destinationSize: socklen_t, - controlBytes: UnsafeMutableRawBufferPointer) throws -> IOResult { + func sendmsg( + pointer: UnsafeRawBufferPointer, + destinationPtr: UnsafePointer?, + destinationSize: socklen_t, + controlBytes: UnsafeMutableRawBufferPointer + ) throws -> IOResult { // Dubious const casts - it should be OK as there is no reason why this should get mutated // just bad const declaration below us. - var vec = IOVector(iov_base: UnsafeMutableRawPointer(mutating: pointer.baseAddress!), iov_len: numericCast(pointer.count)) + var vec = IOVector( + iov_base: UnsafeMutableRawPointer(mutating: pointer.baseAddress!), + iov_len: numericCast(pointer.count) + ) let notConstCorrectDestinationPtr = UnsafeMutableRawPointer(mutating: destinationPtr) return try withUnsafeHandle { handle in - return try withUnsafeMutablePointer(to: &vec) { vecPtr in -#if os(Windows) + try withUnsafeMutablePointer(to: &vec) { vecPtr in + #if os(Windows) var messageHeader = - WSAMSG(name: notConstCorrectDestinationPtr - .assumingMemoryBound(to: sockaddr.self), - namelen: destinationSize, - lpBuffers: vecPtr, - dwBufferCount: 1, - Control: WSABUF(len: ULONG(controlBytes.count), - buf: controlBytes.baseAddress? - .bindMemory(to: CHAR.self, - capacity: controlBytes.count)), - dwFlags: 0) -#else + WSAMSG( + name: + notConstCorrectDestinationPtr + .assumingMemoryBound(to: sockaddr.self), + namelen: destinationSize, + lpBuffers: vecPtr, + dwBufferCount: 1, + Control: WSABUF( + len: ULONG(controlBytes.count), + buf: controlBytes.baseAddress? + .bindMemory( + to: CHAR.self, + capacity: controlBytes.count + ) + ), + dwFlags: 0 + ) + #else var messageHeader = msghdr() messageHeader.msg_name = notConstCorrectDestinationPtr messageHeader.msg_namelen = destinationSize @@ -203,7 +224,7 @@ typealias IOVector = iovec messageHeader.msg_control = controlBytes.baseAddress messageHeader.msg_controllen = .init(controlBytes.count) messageHeader.msg_flags = 0 -#endif + #endif return try NIOBSDSocket.sendmsg(socket: handle, msgHdr: &messageHeader, flags: 0) } } @@ -216,7 +237,7 @@ typealias IOVector = iovec /// - returns: The `IOResult` which indicates how much data could be read and if the operation returned before all could be read (because the socket is in non-blocking mode). /// - throws: An `IOError` if the operation failed. func read(pointer: UnsafeMutableRawBufferPointer) throws -> IOResult { - return try withUnsafeHandle { + try withUnsafeHandle { try Posix.read(descriptor: $0, pointer: pointer.baseAddress!, size: pointer.count) } } @@ -231,28 +252,38 @@ typealias IOVector = iovec /// - returns: The `IOResult` which indicates how much data could be received and if the operation returned before all the data could be received /// (because the socket is in non-blocking mode) /// - throws: An `IOError` if the operation failed. - func recvmsg(pointer: UnsafeMutableRawBufferPointer, - storage: inout sockaddr_storage, - storageLen: inout socklen_t, - controlBytes: inout UnsafeReceivedControlBytes) throws -> IOResult { + func recvmsg( + pointer: UnsafeMutableRawBufferPointer, + storage: inout sockaddr_storage, + storageLen: inout socklen_t, + controlBytes: inout UnsafeReceivedControlBytes + ) throws -> IOResult { var vec = IOVector(iov_base: pointer.baseAddress, iov_len: numericCast(pointer.count)) return try withUnsafeMutablePointer(to: &vec) { vecPtr in - return try storage.withMutableSockAddr { (sockaddrPtr, _) in -#if os(Windows) + try storage.withMutableSockAddr { (sockaddrPtr, _) in + #if os(Windows) var messageHeader = - WSAMSG(name: sockaddrPtr, namelen: storageLen, - lpBuffers: vecPtr, dwBufferCount: 1, - Control: WSABUF(len: ULONG(controlBytes.controlBytesBuffer.count), - buf: controlBytes.controlBytesBuffer.baseAddress? - .bindMemory(to: CHAR.self, - capacity: controlBytes.controlBytesBuffer.count)), - dwFlags: 0) + WSAMSG( + name: sockaddrPtr, + namelen: storageLen, + lpBuffers: vecPtr, + dwBufferCount: 1, + Control: WSABUF( + len: ULONG(controlBytes.controlBytesBuffer.count), + buf: controlBytes.controlBytesBuffer.baseAddress? + .bindMemory( + to: CHAR.self, + capacity: controlBytes.controlBytesBuffer.count + ) + ), + dwFlags: 0 + ) defer { // We need to write back the length of the message. storageLen = messageHeader.namelen } -#else + #else var messageHeader = msghdr() messageHeader.msg_name = .init(sockaddrPtr) messageHeader.msg_namelen = storageLen @@ -265,11 +296,11 @@ typealias IOVector = iovec // We need to write back the length of the message. storageLen = messageHeader.msg_namelen } -#endif + #endif let result = try withUnsafeMutablePointer(to: &messageHeader) { messageHeader in - return try withUnsafeHandle { fd in - return try NIOBSDSocket.recvmsg(socket: fd, msgHdr: messageHeader, flags: 0) + try withUnsafeHandle { fd in + try NIOBSDSocket.recvmsg(socket: fd, msgHdr: messageHeader, flags: 0) } } @@ -292,9 +323,13 @@ typealias IOVector = iovec /// - returns: The `IOResult` which indicates how much data could be send and if the operation returned before all could be send (because the socket is in non-blocking mode). /// - throws: An `IOError` if the operation failed. func sendFile(fd: CInt, offset: Int, count: Int) throws -> IOResult { - return try withUnsafeHandle { - try NIOBSDSocket.sendfile(socket: $0, fd: fd, offset: off_t(offset), - len: off_t(count)) + try withUnsafeHandle { + try NIOBSDSocket.sendfile( + socket: $0, + fd: fd, + offset: off_t(offset), + len: off_t(count) + ) } } @@ -305,10 +340,14 @@ typealias IOVector = iovec /// - returns: The `IOResult` which indicates how many messages could be received and if the operation returned before all messages could be received (because the socket is in non-blocking mode). /// - throws: An `IOError` if the operation failed. func recvmmsg(msgs: UnsafeMutableBufferPointer) throws -> IOResult { - return try withUnsafeHandle { - try NIOBSDSocket.recvmmsg(socket: $0, msgvec: msgs.baseAddress!, - vlen: CUnsignedInt(msgs.count), flags: 0, - timeout: nil) + try withUnsafeHandle { + try NIOBSDSocket.recvmmsg( + socket: $0, + msgvec: msgs.baseAddress!, + vlen: CUnsignedInt(msgs.count), + flags: 0, + timeout: nil + ) } } @@ -319,9 +358,13 @@ typealias IOVector = iovec /// - returns: The `IOResult` which indicates how many messages could be send and if the operation returned before all messages could be send (because the socket is in non-blocking mode). /// - throws: An `IOError` if the operation failed. func sendmmsg(msgs: UnsafeMutableBufferPointer) throws -> IOResult { - return try withUnsafeHandle { - try NIOBSDSocket.sendmmsg(socket: $0, msgvec: msgs.baseAddress!, - vlen: CUnsignedInt(msgs.count), flags: 0) + try withUnsafeHandle { + try NIOBSDSocket.sendmmsg( + socket: $0, + msgvec: msgs.baseAddress!, + vlen: CUnsignedInt(msgs.count), + flags: 0 + ) } } @@ -331,7 +374,7 @@ typealias IOVector = iovec /// - how: the mode of `Shutdown`. /// - throws: An `IOError` if the operation failed. func shutdown(how: Shutdown) throws { - return try withUnsafeHandle { + try withUnsafeHandle { try NIOBSDSocket.shutdown(socket: $0, how: how) } } @@ -345,7 +388,7 @@ typealias IOVector = iovec /// Returns the value of the 'UDP_SEGMENT' socket option. func getUDPSegmentSize() throws -> CInt { - return try self.withUnsafeHandle { + try self.withUnsafeHandle { try NIOBSDSocket.getUDPSegmentSize(socket: $0) } } @@ -359,7 +402,7 @@ typealias IOVector = iovec /// Returns the value of the 'UDP_GRO' socket option. func getUDPReceiveOffload() throws -> Bool { - return try self.withUnsafeHandle { + try self.withUnsafeHandle { try NIOBSDSocket.getUDPReceiveOffload(socket: $0) } } diff --git a/Sources/NIOPosix/SocketChannel.swift b/Sources/NIOPosix/SocketChannel.swift index 67c5516378..a299a8ec75 100644 --- a/Sources/NIOPosix/SocketChannel.swift +++ b/Sources/NIOPosix/SocketChannel.swift @@ -28,7 +28,9 @@ import struct WinSDK.socklen_t #endif extension ByteBuffer { - mutating func withMutableWritePointer(body: (UnsafeMutableRawBufferPointer) throws -> IOResult) rethrows -> IOResult { + mutating func withMutableWritePointer( + body: (UnsafeMutableRawBufferPointer) throws -> IOResult + ) rethrows -> IOResult { var singleResult: IOResult! _ = try self.writeWithUnsafeMutableBytes(minimumWritableBytes: 0) { ptr in let localWriteResult = try body(ptr) @@ -50,7 +52,8 @@ extension ByteBuffer { final class SocketChannel: BaseStreamSocketChannel { private var connectTimeout: TimeAmount? = nil - init(eventLoop: SelectableEventLoop, protocolFamily: NIOBSDSocket.ProtocolFamily, enableMPTCP: Bool = false) throws { + init(eventLoop: SelectableEventLoop, protocolFamily: NIOBSDSocket.ProtocolFamily, enableMPTCP: Bool = false) throws + { var protocolSubtype = NIOBSDSocket.ProtocolSubtype.default if enableMPTCP { guard let subtype = NIOBSDSocket.ProtocolSubtype.mptcp else { @@ -58,17 +61,37 @@ final class SocketChannel: BaseStreamSocketChannel { } protocolSubtype = subtype } - let socket = try Socket(protocolFamily: protocolFamily, type: .stream, protocolSubtype: protocolSubtype, setNonBlocking: true) - try super.init(socket: socket, parent: nil, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator()) + let socket = try Socket( + protocolFamily: protocolFamily, + type: .stream, + protocolSubtype: protocolSubtype, + setNonBlocking: true + ) + try super.init( + socket: socket, + parent: nil, + eventLoop: eventLoop, + recvAllocator: AdaptiveRecvByteBufferAllocator() + ) } init(eventLoop: SelectableEventLoop, socket: NIOBSDSocket.Handle) throws { let sock = try Socket(socket: socket, setNonBlocking: true) - try super.init(socket: sock, parent: nil, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator()) + try super.init( + socket: sock, + parent: nil, + eventLoop: eventLoop, + recvAllocator: AdaptiveRecvByteBufferAllocator() + ) } init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws { - try super.init(socket: socket, parent: parent, eventLoop: eventLoop, recvAllocator: AdaptiveRecvByteBufferAllocator()) + try super.init( + socket: socket, + parent: parent, + eventLoop: eventLoop, + recvAllocator: AdaptiveRecvByteBufferAllocator() + ) } override func setOption0(_ option: Option, value: Option.Value) throws { @@ -104,9 +127,11 @@ final class SocketChannel: BaseStreamSocketChannel { } func registrationFor(interested: SelectorEventSet, registrationID: SelectorRegistrationID) -> NIORegistration { - return NIORegistration(channel: .socketChannel(self), - interested: interested, - registrationID: registrationID) + NIORegistration( + channel: .socketChannel(self), + interested: interested, + registrationID: registrationID + ) } private func scheduleConnectTimeout() { @@ -145,11 +170,12 @@ final class SocketChannel: BaseStreamSocketChannel { try self.socket.finishConnect() } - override func register(selector: Selector, interested: SelectorEventSet) throws { - try selector.register(selectable: self.socket, - interested: interested, - makeRegistration: self.registrationFor) + try selector.register( + selectable: self.socket, + interested: interested, + makeRegistration: self.registrationFor + ) } override func deregister(selector: Selector, mode: CloseMode) throws { @@ -172,9 +198,14 @@ final class ServerSocketChannel: BaseSocketChannel { /// The server socket channel is never writable. // This is `Channel` API so must be thread-safe. - override public var isWritable: Bool { return false } + override public var isWritable: Bool { false } - convenience init(eventLoop: SelectableEventLoop, group: EventLoopGroup, protocolFamily: NIOBSDSocket.ProtocolFamily, enableMPTCP: Bool = false) throws { + convenience init( + eventLoop: SelectableEventLoop, + group: EventLoopGroup, + protocolFamily: NIOBSDSocket.ProtocolFamily, + enableMPTCP: Bool = false + ) throws { var protocolSubtype = NIOBSDSocket.ProtocolSubtype.default if enableMPTCP { guard let subtype = NIOBSDSocket.ProtocolSubtype.mptcp else { @@ -182,7 +213,15 @@ final class ServerSocketChannel: BaseSocketChannel { } protocolSubtype = subtype } - try self.init(serverSocket: try ServerSocket(protocolFamily: protocolFamily, protocolSubtype: protocolSubtype, setNonBlocking: true), eventLoop: eventLoop, group: group) + try self.init( + serverSocket: try ServerSocket( + protocolFamily: protocolFamily, + protocolSubtype: protocolSubtype, + setNonBlocking: true + ), + eventLoop: eventLoop, + group: group + ) } init(serverSocket: ServerSocket, eventLoop: SelectableEventLoop, group: EventLoopGroup) throws { @@ -203,9 +242,11 @@ final class ServerSocketChannel: BaseSocketChannel { } func registrationFor(interested: SelectorEventSet, registrationID: SelectorRegistrationID) -> NIORegistration { - return NIORegistration(channel: .serverSocketChannel(self), - interested: interested, - registrationID: registrationID) + NIORegistration( + channel: .serverSocketChannel(self), + interested: interested, + registrationID: registrationID + ) } override func setOption0(_ option: Option, value: Option.Value) throws { @@ -262,7 +303,7 @@ final class ServerSocketChannel: BaseSocketChannel { p.futureResult.map { // It's important to call the methods before we actually notify the original promise for ordering reasons. self.becomeActive0(promise: promise) - }.whenFailure{ error in + }.whenFailure { error in promise?.fail(error) } executeAndComplete(p) { @@ -299,9 +340,11 @@ final class ServerSocketChannel: BaseSocketChannel { readPending = false result = .some do { - let chan = try SocketChannel(socket: accepted, - parent: self, - eventLoop: group.next() as! SelectableEventLoop) + let chan = try SocketChannel( + socket: accepted, + parent: self, + eventLoop: group.next() as! SelectableEventLoop + ) assert(self.isActive) self.pipeline.syncOperations.fireChannelRead(NIOAny(chan)) } catch { @@ -328,10 +371,10 @@ final class ServerSocketChannel: BaseSocketChannel { switch err.errnoCode { case ECONNABORTED, - EMFILE, - ENFILE, - ENOBUFS, - ENOMEM: + EMFILE, + ENFILE, + ENOBUFS, + ENOMEM: // These are errors we may be able to recover from. The user may just want to stop accepting connections for example // or provide some other means of back-pressure. This could be achieved by a custom ChannelDuplexHandler. return false @@ -362,7 +405,7 @@ final class ServerSocketChannel: BaseSocketChannel { } override func hasFlushedPendingWrites() -> Bool { - return false + false } override func bufferPendingWrite(data: NIOAny, promise: EventLoopPromise?) { @@ -374,13 +417,15 @@ final class ServerSocketChannel: BaseSocketChannel { } override func flushNow() -> IONotificationState { - return IONotificationState.unregister + IONotificationState.unregister } override func register(selector: Selector, interested: SelectorEventSet) throws { - try selector.register(selectable: self.socket, - interested: interested, - makeRegistration: self.registrationFor) + try selector.register( + selectable: self.socket, + interested: interested, + makeRegistration: self.registrationFor + ) } override func deregister(selector: Selector, mode: CloseMode) throws { @@ -416,7 +461,7 @@ final class DatagramChannel: BaseSocketChannel { private var vectorReadManager: Optional // This is `Channel` API so must be thread-safe. override public var isWritable: Bool { - return pendingWrites.isWritable + pendingWrites.isWritable } override var isOpen: Bool { @@ -461,8 +506,10 @@ final class DatagramChannel: BaseSocketChannel { throw err } - self.pendingWrites = PendingDatagramWritesManager(bufferPool: eventLoop.bufferPool, - msgBufferPool: eventLoop.msgBufferPool) + self.pendingWrites = PendingDatagramWritesManager( + bufferPool: eventLoop.bufferPool, + msgBufferPool: eventLoop.msgBufferPool + ) try super.init( socket: socket, @@ -476,8 +523,10 @@ final class DatagramChannel: BaseSocketChannel { init(socket: Socket, parent: Channel? = nil, eventLoop: SelectableEventLoop) throws { self.vectorReadManager = nil try socket.setNonBlocking() - self.pendingWrites = PendingDatagramWritesManager(bufferPool: eventLoop.bufferPool, - msgBufferPool: eventLoop.msgBufferPool) + self.pendingWrites = PendingDatagramWritesManager( + bufferPool: eventLoop.bufferPool, + msgBufferPool: eventLoop.msgBufferPool + ) try super.init( socket: socket, parent: parent, @@ -513,14 +562,18 @@ final class DatagramChannel: BaseSocketChannel { switch self.localAddress?.protocol { case .some(.inet): self.reportExplicitCongestionNotifications = true - try self.socket.setOption(level: .ip, - name: .ip_recv_tos, - value: valueAsInt) + try self.socket.setOption( + level: .ip, + name: .ip_recv_tos, + value: valueAsInt + ) case .some(.inet6): self.reportExplicitCongestionNotifications = true - try self.socket.setOption(level: .ipv6, - name: .ipv6_recv_tclass, - value: valueAsInt) + try self.socket.setOption( + level: .ipv6, + name: .ipv6_recv_tclass, + value: valueAsInt + ) default: // Explicit congestion notification is only supported for IP throw ChannelError._operationUnsupported @@ -530,14 +583,18 @@ final class DatagramChannel: BaseSocketChannel { switch self.localAddress?.protocol { case .some(.inet): self.receivePacketInfo = true - try self.socket.setOption(level: .ip, - name: .ip_recv_pktinfo, - value: valueAsInt) + try self.socket.setOption( + level: .ip, + name: .ip_recv_pktinfo, + value: valueAsInt + ) case .some(.inet6): self.receivePacketInfo = true - try self.socket.setOption(level: .ipv6, - name: .ipv6_recv_pktinfo, - value: valueAsInt) + try self.socket.setOption( + level: .ipv6, + name: .ipv6_recv_pktinfo, + value: valueAsInt + ) default: // Receiving packet info is only supported for IP throw ChannelError._operationUnsupported @@ -576,11 +633,17 @@ final class DatagramChannel: BaseSocketChannel { case _ as ChannelOptions.Types.ExplicitCongestionNotificationsOption: switch self.localAddress?.protocol { case .some(.inet): - return try (self.socket.getOption(level: .ip, - name: .ip_recv_tos) != 0) as! Option.Value + return try + (self.socket.getOption( + level: .ip, + name: .ip_recv_tos + ) != 0) as! Option.Value case .some(.inet6): - return try (self.socket.getOption(level: .ipv6, - name: .ipv6_recv_tclass) != 0) as! Option.Value + return try + (self.socket.getOption( + level: .ipv6, + name: .ipv6_recv_tclass + ) != 0) as! Option.Value default: // Explicit congestion notification is only supported for IP throw ChannelError._operationUnsupported @@ -588,11 +651,17 @@ final class DatagramChannel: BaseSocketChannel { case _ as ChannelOptions.Types.ReceivePacketInfo: switch self.localAddress?.protocol { case .some(.inet): - return try (self.socket.getOption(level: .ip, - name: .ip_recv_pktinfo) != 0) as! Option.Value + return try + (self.socket.getOption( + level: .ip, + name: .ip_recv_pktinfo + ) != 0) as! Option.Value case .some(.inet6): - return try (self.socket.getOption(level: .ipv6, - name: .ipv6_recv_pktinfo) != 0) as! Option.Value + return try + (self.socket.getOption( + level: .ipv6, + name: .ipv6_recv_pktinfo + ) != 0) as! Option.Value default: // Receiving packet info is only supported for IP throw ChannelError._operationUnsupported @@ -613,9 +682,11 @@ final class DatagramChannel: BaseSocketChannel { } func registrationFor(interested: SelectorEventSet, registrationID: SelectorRegistrationID) -> NIORegistration { - return NIORegistration(channel: .datagramChannel(self), - interested: interested, - registrationID: registrationID) + NIORegistration( + channel: .datagramChannel(self), + interested: interested, + registrationID: registrationID + ) } override func connectSocket(to address: SocketAddress) throws -> Bool { @@ -624,8 +695,10 @@ final class DatagramChannel: BaseSocketChannel { self.pendingWrites.failAll( error: IOError( errnoCode: EISCONN, - reason: "Socket was connected before flushing pending write."), - close: false) + reason: "Socket was connected before flushing pending write." + ), + close: false + ) } if try self.socket.connect(to: address) { self.pendingWrites.markConnected(to: address) @@ -651,10 +724,12 @@ final class DatagramChannel: BaseSocketChannel { let pooledMsgBuffer = self.selectableEventLoop.msgBufferPool.get() defer { self.selectableEventLoop.msgBufferPool.put(pooledMsgBuffer) } return try pooledMsgBuffer.withUnsafePointers { _, _, controlMessageStorage in - return try self.singleReadFromSocket(controlBytesBuffer: controlMessageStorage[0]) + try self.singleReadFromSocket(controlBytesBuffer: controlMessageStorage[0]) } } else { - return try self.singleReadFromSocket(controlBytesBuffer: UnsafeMutableRawBufferPointer(start: nil, count: 0)) + return try self.singleReadFromSocket( + controlBytesBuffer: UnsafeMutableRawBufferPointer(start: nil, count: 0) + ) } } @@ -671,11 +746,13 @@ final class DatagramChannel: BaseSocketChannel { var controlBytes = UnsafeReceivedControlBytes(controlBytesBuffer: controlBytesBuffer) let (buffer, result) = try self.recvBufferPool.buffer(allocator: self.allocator) { buffer in - return try buffer.withMutableWritePointer { pointer in - try self.socket.recvmsg(pointer: pointer, - storage: &rawAddress, - storageLen: &rawAddressLength, - controlBytes: &controlBytes) + try buffer.withMutableWritePointer { pointer in + try self.socket.recvmsg( + pointer: pointer, + storage: &rawAddress, + storageLen: &rawAddressLength, + controlBytes: &controlBytes + ) } } @@ -689,15 +766,18 @@ final class DatagramChannel: BaseSocketChannel { let metadata: AddressedEnvelope.Metadata? if self.reportExplicitCongestionNotifications || self.receivePacketInfo, - let controlMessagesReceived = controlBytes.receivedControlMessages { + let controlMessagesReceived = controlBytes.receivedControlMessages + { metadata = .init(from: controlMessagesReceived) } else { metadata = nil } - let msg = AddressedEnvelope(remoteAddress: remoteAddress, - data: buffer, - metadata: metadata) + let msg = AddressedEnvelope( + remoteAddress: remoteAddress, + data: buffer, + metadata: metadata + ) assert(self.isActive) self.pipeline.syncOperations.fireChannelRead(NIOAny(msg)) readResult = .some @@ -723,12 +803,14 @@ final class DatagramChannel: BaseSocketChannel { break readLoop } - let (_, result) = try self.recvBufferPool.buffer(allocator: self.allocator) { buffer -> DatagramVectorReadManager.ReadResult in + let (_, result) = try self.recvBufferPool.buffer(allocator: self.allocator) { + buffer -> DatagramVectorReadManager.ReadResult in // This force-unwrap is safe, as we checked whether this is nil in the caller. try vectorReadManager.readFromSocket( socket: self.socket, buffer: &buffer, - parseControlMessages: self.reportExplicitCongestionNotifications || self.receivePacketInfo) + parseControlMessages: self.reportExplicitCongestionNotifications || self.receivePacketInfo + ) } switch result { @@ -765,7 +847,7 @@ final class DatagramChannel: BaseSocketChannel { // - https://bugzilla.redhat.com/show_bug.cgi?id=1375 // - https://lists.gt.net/linux/kernel/39575 case ECONNREFUSED, - ENOMEM: + ENOMEM: // These are errors we may be able to recover from. return false default: @@ -806,13 +888,17 @@ final class DatagramChannel: BaseSocketChannel { } /// Buffer a write in preparation for a flush. - private func bufferPendingAddressedWrite(envelope: AddressedEnvelope, promise: EventLoopPromise?) { + private func bufferPendingAddressedWrite(envelope: AddressedEnvelope, promise: EventLoopPromise?) + { // If the socket is connected, check the remote provided matches the connected address. if let connectedRemoteAddress = self.remoteAddress { guard envelope.remoteAddress == connectedRemoteAddress else { - promise?.fail(DatagramChannelError.WriteOnConnectedSocketWithInvalidAddress( - envelopeRemoteAddress: envelope.remoteAddress, - connectedRemoteAddress: connectedRemoteAddress)) + promise?.fail( + DatagramChannelError.WriteOnConnectedSocketWithInvalidAddress( + envelopeRemoteAddress: envelope.remoteAddress, + connectedRemoteAddress: connectedRemoteAddress + ) + ) return } } @@ -824,7 +910,7 @@ final class DatagramChannel: BaseSocketChannel { } override final func hasFlushedPendingWrites() -> Bool { - return self.pendingWrites.isFlushPending + self.pendingWrites.isFlushPending } /// Mark a flush point. This is called when flush is received, and instructs @@ -848,16 +934,20 @@ final class DatagramChannel: BaseSocketChannel { defer { self.selectableEventLoop.msgBufferPool.put(msgBuffer) } return try msgBuffer.withUnsafePointers { _, _, controlMessageStorage in var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[0]) - controlBytes.appendExplicitCongestionState(metadata: metadata, - protocolFamily: self.localAddress?.protocol) - return try self.socket.sendmsg(pointer: ptr, - destinationPtr: destinationPtr, - destinationSize: destinationSize, - controlBytes: controlBytes.validControlBytes) + controlBytes.appendExplicitCongestionState( + metadata: metadata, + protocolFamily: self.localAddress?.protocol + ) + return try self.socket.sendmsg( + pointer: ptr, + destinationPtr: destinationPtr, + destinationSize: destinationSize, + controlBytes: controlBytes.validControlBytes + ) } }, vectorWriteOperation: { msgs in - return try self.socket.sendmmsg(msgs: msgs) + try self.socket.sendmmsg(msgs: msgs) } ) return result @@ -881,9 +971,11 @@ final class DatagramChannel: BaseSocketChannel { } override func register(selector: Selector, interested: SelectorEventSet) throws { - try selector.register(selectable: self.socket, - interested: interested, - makeRegistration: self.registrationFor) + try selector.register( + selectable: self.socket, + interested: interested, + makeRegistration: self.registrationFor + ) } override func deregister(selector: Selector, mode: CloseMode) throws { @@ -898,19 +990,19 @@ final class DatagramChannel: BaseSocketChannel { extension SocketChannel: CustomStringConvertible { var description: String { - return "SocketChannel { \(self.socketDescription), active = \(self.isActive), localAddress = \(self.localAddress.debugDescription), remoteAddress = \(self.remoteAddress.debugDescription) }" + "SocketChannel { \(self.socketDescription), active = \(self.isActive), localAddress = \(self.localAddress.debugDescription), remoteAddress = \(self.remoteAddress.debugDescription) }" } } extension ServerSocketChannel: CustomStringConvertible { var description: String { - return "ServerSocketChannel { \(self.socketDescription), active = \(self.isActive), localAddress = \(self.localAddress.debugDescription), remoteAddress = \(self.remoteAddress.debugDescription) }" + "ServerSocketChannel { \(self.socketDescription), active = \(self.isActive), localAddress = \(self.localAddress.debugDescription), remoteAddress = \(self.remoteAddress.debugDescription) }" } } extension DatagramChannel: CustomStringConvertible { var description: String { - return "DatagramChannel { \(self.socketDescription), active = \(self.isActive), localAddress = \(self.localAddress.debugDescription), remoteAddress = \(self.remoteAddress.debugDescription) }" + "DatagramChannel { \(self.socketDescription), active = \(self.isActive), localAddress = \(self.localAddress.debugDescription), remoteAddress = \(self.remoteAddress.debugDescription) }" } } @@ -948,14 +1040,24 @@ extension DatagramChannel: MulticastChannel { } } -#if !os(Windows) + #if !os(Windows) @available(*, deprecated, renamed: "joinGroup(_:device:promise:)") func joinGroup(_ group: SocketAddress, interface: NIONetworkInterface?, promise: EventLoopPromise?) { if eventLoop.inEventLoop { - self.performGroupOperation0(group, device: interface.map { NIONetworkDevice($0) }, promise: promise, operation: .join) + self.performGroupOperation0( + group, + device: interface.map { NIONetworkDevice($0) }, + promise: promise, + operation: .join + ) } else { eventLoop.execute { - self.performGroupOperation0(group, device: interface.map { NIONetworkDevice($0) }, promise: promise, operation: .join) + self.performGroupOperation0( + group, + device: interface.map { NIONetworkDevice($0) }, + promise: promise, + operation: .join + ) } } } @@ -963,14 +1065,24 @@ extension DatagramChannel: MulticastChannel { @available(*, deprecated, renamed: "leaveGroup(_:device:promise:)") func leaveGroup(_ group: SocketAddress, interface: NIONetworkInterface?, promise: EventLoopPromise?) { if eventLoop.inEventLoop { - self.performGroupOperation0(group, device: interface.map { NIONetworkDevice($0) }, promise: promise, operation: .leave) + self.performGroupOperation0( + group, + device: interface.map { NIONetworkDevice($0) }, + promise: promise, + operation: .leave + ) } else { eventLoop.execute { - self.performGroupOperation0(group, device: interface.map { NIONetworkDevice($0) }, promise: promise, operation: .leave) + self.performGroupOperation0( + group, + device: interface.map { NIONetworkDevice($0) }, + promise: promise, + operation: .leave + ) } } } -#endif + #endif func joinGroup(_ group: SocketAddress, device: NIONetworkDevice?, promise: EventLoopPromise?) { if eventLoop.inEventLoop { @@ -995,10 +1107,12 @@ extension DatagramChannel: MulticastChannel { /// The implementation of `joinGroup` and `leaveGroup`. /// /// Joining and leaving a multicast group ultimately corresponds to a single, carefully crafted, socket option. - private func performGroupOperation0(_ group: SocketAddress, - device: NIONetworkDevice?, - promise: EventLoopPromise?, - operation: GroupOperation) { + private func performGroupOperation0( + _ group: SocketAddress, + device: NIONetworkDevice?, + promise: EventLoopPromise?, + operation: GroupOperation + ) { self.eventLoop.assertInEventLoop() guard self.isActive else { @@ -1039,20 +1153,37 @@ extension DatagramChannel: MulticastChannel { preconditionFailure("Should not be reachable, UNIX sockets are never multicast addresses") case (.v4(let groupAddress), .some(.v4(let interfaceAddress))): // IPv4Binding with specific target interface. - let multicastRequest = ip_mreq(imr_multiaddr: groupAddress.address.sin_addr, imr_interface: interfaceAddress.address.sin_addr) + let multicastRequest = ip_mreq( + imr_multiaddr: groupAddress.address.sin_addr, + imr_interface: interfaceAddress.address.sin_addr + ) try self.socket.setOption(level: .ip, name: operation.optionName(level: .ip), value: multicastRequest) case (.v4(let groupAddress), .none): // IPv4 binding without target interface. - let multicastRequest = ip_mreq(imr_multiaddr: groupAddress.address.sin_addr, imr_interface: in_addr(s_addr: INADDR_ANY)) + let multicastRequest = ip_mreq( + imr_multiaddr: groupAddress.address.sin_addr, + imr_interface: in_addr(s_addr: INADDR_ANY) + ) try self.socket.setOption(level: .ip, name: operation.optionName(level: .ip), value: multicastRequest) case (.v6(let groupAddress), .some(.v6)): // IPv6 binding with specific target interface. - let multicastRequest = ipv6_mreq(ipv6mr_multiaddr: groupAddress.address.sin6_addr, ipv6mr_interface: UInt32(device!.interfaceIndex)) - try self.socket.setOption(level: .ipv6, name: operation.optionName(level: .ipv6), value: multicastRequest) + let multicastRequest = ipv6_mreq( + ipv6mr_multiaddr: groupAddress.address.sin6_addr, + ipv6mr_interface: UInt32(device!.interfaceIndex) + ) + try self.socket.setOption( + level: .ipv6, + name: operation.optionName(level: .ipv6), + value: multicastRequest + ) case (.v6(let groupAddress), .none): // IPv6 binding with no specific interface requested. let multicastRequest = ipv6_mreq(ipv6mr_multiaddr: groupAddress.address.sin6_addr, ipv6mr_interface: 0) - try self.socket.setOption(level: .ipv6, name: operation.optionName(level: .ipv6), value: multicastRequest) + try self.socket.setOption( + level: .ipv6, + name: operation.optionName(level: .ipv6), + value: multicastRequest + ) case (.v4, .some(.v6)), (.v6, .some(.v4)), (.v4, .some(.unixDomainSocket)), (.v6, .some(.unixDomainSocket)): // Mismatched group and interface address: this is an error. throw ChannelError._badInterfaceAddressFamily diff --git a/Sources/NIOPosix/SocketProtocols.swift b/Sources/NIOPosix/SocketProtocols.swift index 90cca51fad..58236f9911 100644 --- a/Sources/NIOPosix/SocketProtocols.swift +++ b/Sources/NIOPosix/SocketProtocols.swift @@ -48,15 +48,19 @@ protocol SocketProtocol: BaseSocketProtocol { func read(pointer: UnsafeMutableRawBufferPointer) throws -> IOResult - func recvmsg(pointer: UnsafeMutableRawBufferPointer, - storage: inout sockaddr_storage, - storageLen: inout socklen_t, - controlBytes: inout UnsafeReceivedControlBytes) throws -> IOResult - - func sendmsg(pointer: UnsafeRawBufferPointer, - destinationPtr: UnsafePointer?, - destinationSize: socklen_t, - controlBytes: UnsafeMutableRawBufferPointer) throws -> IOResult + func recvmsg( + pointer: UnsafeMutableRawBufferPointer, + storage: inout sockaddr_storage, + storageLen: inout socklen_t, + controlBytes: inout UnsafeReceivedControlBytes + ) throws -> IOResult + + func sendmsg( + pointer: UnsafeRawBufferPointer, + destinationPtr: UnsafePointer?, + destinationSize: socklen_t, + controlBytes: UnsafeMutableRawBufferPointer + ) throws -> IOResult func sendFile(fd: CInt, offset: Int, count: Int) throws -> IOResult @@ -72,7 +76,7 @@ protocol SocketProtocol: BaseSocketProtocol { #if os(Linux) || os(Android) // This is a lazily initialised global variable that when read for the first time, will ignore SIGPIPE. private let globallyIgnoredSIGPIPE: Bool = { - /* no F_SETNOSIGPIPE on Linux :( */ + // no F_SETNOSIGPIPE on Linux :( #if canImport(Glibc) _ = Glibc.signal(SIGPIPE, SIG_IGN) #elseif canImport(Musl) @@ -99,7 +103,7 @@ extension BaseSocketProtocol { do { try Posix.fcntl(descriptor: fd, command: F_SETNOSIGPIPE, value: 1) } catch let error as IOError { - try? Posix.close(descriptor: fd) // don't care about failure here + try? Posix.close(descriptor: fd) // don't care about failure here if error.errnoCode == EINVAL { // Darwin seems to sometimes do this despite the docs claiming it can't happen throw NIOFcntlFailedError() @@ -113,7 +117,7 @@ extension BaseSocketProtocol { #if os(Windows) // Deliberately empty: SIGPIPE just ain't a thing on Windows #else - try ignoreSIGPIPE(descriptor: handle) + try ignoreSIGPIPE(descriptor: handle) #endif } } diff --git a/Sources/NIOPosix/System.swift b/Sources/NIOPosix/System.swift index 8a96793c02..b9f6b9710e 100644 --- a/Sources/NIOPosix/System.swift +++ b/Sources/NIOPosix/System.swift @@ -41,15 +41,17 @@ internal typealias MMsgHdr = CNIOWindows_mmsghdr #endif #if os(Android) -let INADDR_ANY = UInt32(0) // #define INADDR_ANY ((unsigned long int) 0x00000000) +let INADDR_ANY = UInt32(0) // #define INADDR_ANY ((unsigned long int) 0x00000000) let IFF_BROADCAST: CUnsignedInt = numericCast(SwiftGlibc.IFF_BROADCAST.rawValue) let IFF_POINTOPOINT: CUnsignedInt = numericCast(SwiftGlibc.IFF_POINTOPOINT.rawValue) let IFF_MULTICAST: CUnsignedInt = numericCast(SwiftGlibc.IFF_MULTICAST.rawValue) internal typealias in_port_t = UInt16 -extension ipv6_mreq { // http://lkml.iu.edu/hypermail/linux/kernel/0106.1/0080.html - init (ipv6mr_multiaddr: in6_addr, ipv6mr_interface: UInt32) { - self.init(ipv6mr_multiaddr: ipv6mr_multiaddr, - ipv6mr_ifindex: Int32(bitPattern: ipv6mr_interface)) +extension ipv6_mreq { // http://lkml.iu.edu/hypermail/linux/kernel/0106.1/0080.html + init(ipv6mr_multiaddr: in6_addr, ipv6mr_interface: UInt32) { + self.init( + ipv6mr_multiaddr: ipv6mr_multiaddr, + ipv6mr_ifindex: Int32(bitPattern: ipv6mr_interface) + ) } } #if arch(arm) @@ -86,11 +88,20 @@ private let sysPoll = poll #endif #if os(Android) -func sysRecvFrom_wrapper(sockfd: CInt, buf: UnsafeMutableRawPointer, len: CLong, flags: CInt, src_addr: UnsafeMutablePointer, addrlen: UnsafeMutablePointer) -> CLong { - return recvfrom(sockfd, buf, len, flags, src_addr, addrlen) // src_addr is 'UnsafeMutablePointer', but it need to be 'UnsafePointer' +func sysRecvFrom_wrapper( + sockfd: CInt, + buf: UnsafeMutableRawPointer, + len: CLong, + flags: CInt, + src_addr: UnsafeMutablePointer, + addrlen: UnsafeMutablePointer +) -> CLong { + // src_addr is 'UnsafeMutablePointer', but it need to be 'UnsafePointer' + recvfrom(sockfd, buf, len, flags, src_addr, addrlen) + // src_addr is 'UnsafeMutablePointer', but it need to be 'UnsafePointer' } func sysWritev_wrapper(fd: CInt, iov: UnsafePointer?, iovcnt: CInt) -> CLong { - return CLong(writev(fd, iov!, iovcnt)) // cast 'Int32' to 'CLong' + CLong(writev(fd, iov!, iovcnt)) // cast 'Int32' to 'CLong'// cast 'Int32' to 'CLong' } private let sysWritev = sysWritev_wrapper #elseif !os(Windows) @@ -102,8 +113,10 @@ private let sysSendMsg: @convention(c) (CInt, UnsafePointer?, CInt) -> s #endif private let sysDup: @convention(c) (CInt) -> CInt = dup #if !os(Windows) -private let sysGetpeername: @convention(c) (CInt, UnsafeMutablePointer?, UnsafeMutablePointer?) -> CInt = getpeername -private let sysGetsockname: @convention(c) (CInt, UnsafeMutablePointer?, UnsafeMutablePointer?) -> CInt = getsockname +private let sysGetpeername: + @convention(c) (CInt, UnsafeMutablePointer?, UnsafeMutablePointer?) -> CInt = getpeername +private let sysGetsockname: + @convention(c) (CInt, UnsafeMutablePointer?, UnsafeMutablePointer?) -> CInt = getsockname #endif #if os(Android) @@ -196,7 +209,7 @@ private func isUnacceptableErrnoForbiddingEINVAL(_ code: Int32) -> Bool { #if os(Windows) internal func strerror(_ errno: CInt) -> String { - return withUnsafeTemporaryAllocation(of: CChar.self, capacity: 95) { + withUnsafeTemporaryAllocation(of: CChar.self, capacity: 95) { let result = strerror_s($0.baseAddress, $0.count, errno) guard result == 0 else { return "Unknown error: \(errno)" } return String(cString: $0.baseAddress!) @@ -204,54 +217,66 @@ internal func strerror(_ errno: CInt) -> String { } #endif -private func preconditionIsNotUnacceptableErrno(err: CInt, where function: String) -> Void { +private func preconditionIsNotUnacceptableErrno(err: CInt, where function: String) { // strerror is documented to return "Unknown error: ..." for illegal value so it won't ever fail -#if os(Windows) + #if os(Windows) precondition(!isUnacceptableErrno(err), "unacceptable errno \(err) \(strerror(err)) in \(function))") -#else - precondition(!isUnacceptableErrno(err), "unacceptable errno \(err) \(String(cString: strerror(err)!)) in \(function))") -#endif + #else + precondition( + !isUnacceptableErrno(err), + "unacceptable errno \(err) \(String(cString: strerror(err)!)) in \(function))" + ) + #endif } -private func preconditionIsNotUnacceptableErrnoOnClose(err: CInt, where function: String) -> Void { +private func preconditionIsNotUnacceptableErrnoOnClose(err: CInt, where function: String) { // strerror is documented to return "Unknown error: ..." for illegal value so it won't ever fail -#if os(Windows) + #if os(Windows) precondition(!isUnacceptableErrnoOnClose(err), "unacceptable errno \(err) \(strerror(err)) in \(function))") -#else - precondition(!isUnacceptableErrnoOnClose(err), "unacceptable errno \(err) \(String(cString: strerror(err)!)) in \(function))") -#endif + #else + precondition( + !isUnacceptableErrnoOnClose(err), + "unacceptable errno \(err) \(String(cString: strerror(err)!)) in \(function))" + ) + #endif } -private func preconditionIsNotUnacceptableErrnoForbiddingEINVAL(err: CInt, where function: String) -> Void { +private func preconditionIsNotUnacceptableErrnoForbiddingEINVAL(err: CInt, where function: String) { // strerror is documented to return "Unknown error: ..." for illegal value so it won't ever fail -#if os(Windows) - precondition(!isUnacceptableErrnoForbiddingEINVAL(err), "unacceptable errno \(err) \(strerror(err)) in \(function))") -#else - precondition(!isUnacceptableErrnoForbiddingEINVAL(err), "unacceptable errno \(err) \(String(cString: strerror(err)!)) in \(function))") -#endif + #if os(Windows) + precondition( + !isUnacceptableErrnoForbiddingEINVAL(err), + "unacceptable errno \(err) \(strerror(err)) in \(function))" + ) + #else + precondition( + !isUnacceptableErrnoForbiddingEINVAL(err), + "unacceptable errno \(err) \(String(cString: strerror(err)!)) in \(function))" + ) + #endif } - -/* - * Sorry, we really try hard to not use underscored attributes. In this case - * however we seem to break the inlining threshold which makes a system call - * take twice the time, ie. we need this exception. - */ +// Sorry, we really try hard to not use underscored attributes. In this case +// however we seem to break the inlining threshold which makes a system call +// take twice the time, ie. we need this exception. @inline(__always) @discardableResult -internal func syscall(blocking: Bool, - where function: String = #function, - _ body: () throws -> T) - throws -> IOResult { +internal func syscall( + blocking: Bool, + where function: String = #function, + _ body: () throws -> T +) + throws -> IOResult +{ while true { let res = try body() if res == -1 { -#if os(Windows) + #if os(Windows) var err: CInt = 0 _get_errno(&err) -#else + #else let err = errno -#endif + #endif switch (err, blocking) { case (EINTR, _): continue @@ -269,9 +294,12 @@ internal func syscall(blocking: Bool, #if canImport(Darwin) @inline(__always) @discardableResult -internal func syscall(where function: String = #function, - _ body: () throws -> UnsafeMutablePointer?) - throws -> UnsafeMutablePointer { +internal func syscall( + where function: String = #function, + _ body: () throws -> UnsafeMutablePointer? +) + throws -> UnsafeMutablePointer +{ while true { if let res = try body() { return res @@ -290,9 +318,12 @@ internal func syscall(where function: String = #function, #elseif os(Linux) || os(Android) @inline(__always) @discardableResult -internal func syscall(where function: String = #function, - _ body: () throws -> OpaquePointer?) - throws -> OpaquePointer { +internal func syscall( + where function: String = #function, + _ body: () throws -> OpaquePointer? +) + throws -> OpaquePointer +{ while true { if let res = try body() { return res @@ -313,9 +344,12 @@ internal func syscall(where function: String = #function, #if !os(Windows) @inline(__always) @discardableResult -internal func syscallOptional(where function: String = #function, - _ body: () throws -> UnsafeMutablePointer?) - throws -> UnsafeMutablePointer? { +internal func syscallOptional( + where function: String = #function, + _ body: () throws -> UnsafeMutablePointer? +) + throws -> UnsafeMutablePointer? +{ while true { errno = 0 if let res = try body() { @@ -336,25 +370,26 @@ internal func syscallOptional(where function: String = #function, } #endif -/* - * Sorry, we really try hard to not use underscored attributes. In this case - * however we seem to break the inlining threshold which makes a system call - * take twice the time, ie. we need this exception. - */ +// Sorry, we really try hard to not use underscored attributes. In this case +// however we seem to break the inlining threshold which makes a system call +// take twice the time, ie. we need this exception. @inline(__always) @discardableResult -internal func syscallForbiddingEINVAL(where function: String = #function, - _ body: () throws -> T) - throws -> IOResult { +internal func syscallForbiddingEINVAL( + where function: String = #function, + _ body: () throws -> T +) + throws -> IOResult +{ while true { let res = try body() if res == -1 { -#if os(Windows) + #if os(Windows) var err: CInt = 0 _get_errno(&err) -#else + #else let err = errno -#endif + #endif switch err { case EINTR: continue @@ -370,12 +405,12 @@ internal func syscallForbiddingEINVAL(where function: Stri } internal enum Posix { -#if canImport(Darwin) + #if canImport(Darwin) static let UIO_MAXIOV: Int = 1024 static let SHUT_RD: CInt = CInt(Darwin.SHUT_RD) static let SHUT_WR: CInt = CInt(Darwin.SHUT_WR) static let SHUT_RDWR: CInt = CInt(Darwin.SHUT_RDWR) -#elseif os(Linux) || os(FreeBSD) || os(Android) + #elseif os(Linux) || os(FreeBSD) || os(Android) #if canImport(Glibc) static let UIO_MAXIOV: Int = Int(Glibc.UIO_MAXIOV) static let SHUT_RD: CInt = CInt(Glibc.SHUT_RD) @@ -387,7 +422,7 @@ internal enum Posix { static let SHUT_WR: CInt = CInt(Musl.SHUT_WR) static let SHUT_RDWR: CInt = CInt(Musl.SHUT_RDWR) #endif -#else + #else static var UIO_MAXIOV: Int { fatalError("unsupported OS") } @@ -400,15 +435,15 @@ internal enum Posix { static var SHUT_RDWR: Int { fatalError("unsupported OS") } -#endif + #endif -#if canImport(Darwin) + #if canImport(Darwin) static let IPTOS_ECN_NOTECT: CInt = CNIODarwin_IPTOS_ECN_NOTECT static let IPTOS_ECN_MASK: CInt = CNIODarwin_IPTOS_ECN_MASK static let IPTOS_ECN_ECT0: CInt = CNIODarwin_IPTOS_ECN_ECT0 static let IPTOS_ECN_ECT1: CInt = CNIODarwin_IPTOS_ECN_ECT1 static let IPTOS_ECN_CE: CInt = CNIODarwin_IPTOS_ECN_CE -#elseif os(Linux) || os(FreeBSD) || os(Android) + #elseif os(Linux) || os(FreeBSD) || os(Android) #if os(Android) static let IPTOS_ECN_NOTECT: CInt = CInt(CNIOLinux.IPTOS_ECN_NOTECT) #else @@ -418,33 +453,33 @@ internal enum Posix { static let IPTOS_ECN_ECT0: CInt = CInt(CNIOLinux.IPTOS_ECN_ECT0) static let IPTOS_ECN_ECT1: CInt = CInt(CNIOLinux.IPTOS_ECN_ECT1) static let IPTOS_ECN_CE: CInt = CInt(CNIOLinux.IPTOS_ECN_CE) -#elseif os(Windows) + #elseif os(Windows) static let IPTOS_ECN_NOTECT: CInt = CInt(0x00) static let IPTOS_ECN_MASK: CInt = CInt(0x03) static let IPTOS_ECN_ECT0: CInt = CInt(0x02) static let IPTOS_ECN_ECT1: CInt = CInt(0x01) static let IPTOS_ECN_CE: CInt = CInt(0x03) -#endif + #endif -#if canImport(Darwin) + #if canImport(Darwin) static let IP_RECVPKTINFO: CInt = CNIODarwin.IP_RECVPKTINFO static let IP_PKTINFO: CInt = CNIODarwin.IP_PKTINFO static let IPV6_RECVPKTINFO: CInt = CNIODarwin_IPV6_RECVPKTINFO static let IPV6_PKTINFO: CInt = CNIODarwin_IPV6_PKTINFO -#elseif os(Linux) || os(FreeBSD) || os(Android) + #elseif os(Linux) || os(FreeBSD) || os(Android) static let IP_RECVPKTINFO: CInt = CInt(CNIOLinux.IP_PKTINFO) static let IP_PKTINFO: CInt = CInt(CNIOLinux.IP_PKTINFO) static let IPV6_RECVPKTINFO: CInt = CInt(CNIOLinux.IPV6_RECVPKTINFO) static let IPV6_PKTINFO: CInt = CInt(CNIOLinux.IPV6_PKTINFO) -#elseif os(Windows) + #elseif os(Windows) static let IP_PKTINFO: CInt = CInt(WinSDK.IP_PKTINFO) static let IPV6_PKTINFO: CInt = CInt(WinSDK.IPV6_PKTINFO) -#endif + #endif -#if !os(Windows) + #if !os(Windows) @inline(never) internal static func shutdown(descriptor: CInt, how: Shutdown) throws { _ = try syscall(blocking: false) { @@ -456,12 +491,12 @@ internal enum Posix { internal static func close(descriptor: CInt) throws { let res = sysClose(descriptor) if res == -1 { -#if os(Windows) + #if os(Windows) var err: CInt = 0 _get_errno(&err) -#else + #else let err = errno -#endif + #endif // There is really nothing "good" we can do when EINTR was reported on close. // So just ignore it and "assume" everything is fine == we closed the file descriptor. @@ -478,7 +513,7 @@ internal enum Posix { @inline(never) internal static func bind(descriptor: CInt, ptr: UnsafePointer, bytes: Int) throws { - _ = try syscall(blocking: false) { + _ = try syscall(blocking: false) { sysBind(descriptor, ptr, socklen_t(bytes)) } } @@ -487,30 +522,43 @@ internal enum Posix { @discardableResult // TODO: Allow varargs internal static func fcntl(descriptor: CInt, command: CInt, value: CInt) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysFcntl(descriptor, command, value) }.result } @inline(never) - internal static func socket(domain: NIOBSDSocket.ProtocolFamily, type: NIOBSDSocket.SocketType, protocolSubtype: NIOBSDSocket.ProtocolSubtype) throws -> CInt { - return try syscall(blocking: false) { - return sysSocket(domain.rawValue, type.rawValue, protocolSubtype.rawValue) + internal static func socket( + domain: NIOBSDSocket.ProtocolFamily, + type: NIOBSDSocket.SocketType, + protocolSubtype: NIOBSDSocket.ProtocolSubtype + ) throws -> CInt { + try syscall(blocking: false) { + sysSocket(domain.rawValue, type.rawValue, protocolSubtype.rawValue) }.result } @inline(never) - internal static func setsockopt(socket: CInt, level: CInt, optionName: CInt, - optionValue: UnsafeRawPointer, optionLen: socklen_t) throws { + internal static func setsockopt( + socket: CInt, + level: CInt, + optionName: CInt, + optionValue: UnsafeRawPointer, + optionLen: socklen_t + ) throws { _ = try syscall(blocking: false) { sysSetsockopt(socket, level, optionName, optionValue, optionLen) } } @inline(never) - internal static func getsockopt(socket: CInt, level: CInt, optionName: CInt, - optionValue: UnsafeMutableRawPointer, - optionLen: UnsafeMutablePointer) throws { + internal static func getsockopt( + socket: CInt, + level: CInt, + optionName: CInt, + optionValue: UnsafeMutableRawPointer, + optionLen: UnsafeMutablePointer + ) throws { _ = try syscall(blocking: false) { sysGetsockopt(socket, level, optionName, optionValue, optionLen) }.result @@ -524,11 +572,13 @@ internal enum Posix { } @inline(never) - internal static func accept(descriptor: CInt, - addr: UnsafeMutablePointer?, - len: UnsafeMutablePointer?) throws -> CInt? { + internal static func accept( + descriptor: CInt, + addr: UnsafeMutablePointer?, + len: UnsafeMutablePointer? + ) throws -> CInt? { let result: IOResult = try syscall(blocking: true) { - return sysAccept(descriptor, addr, len) + sysAccept(descriptor, addr, len) } if case .processed(let fd) = result { @@ -555,14 +605,14 @@ internal enum Posix { @inline(never) internal static func open(file: UnsafePointer, oFlag: CInt, mode: mode_t) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysOpenWithMode(file, oFlag, mode) }.result } @inline(never) internal static func open(file: UnsafePointer, oFlag: CInt) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysOpen(file, oFlag) }.result } @@ -570,58 +620,80 @@ internal enum Posix { @inline(never) @discardableResult internal static func ftruncate(descriptor: CInt, size: off_t) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysFtruncate(descriptor, size) }.result } - + @inline(never) internal static func write(descriptor: CInt, pointer: UnsafeRawPointer, size: Int) throws -> IOResult { - return try syscall(blocking: true) { + try syscall(blocking: true) { sysWrite(descriptor, pointer, size) } } @inline(never) - internal static func pwrite(descriptor: CInt, pointer: UnsafeRawPointer, size: Int, offset: off_t) throws -> IOResult { - return try syscall(blocking: true) { + internal static func pwrite( + descriptor: CInt, + pointer: UnsafeRawPointer, + size: Int, + offset: off_t + ) throws -> IOResult { + try syscall(blocking: true) { sysPwrite(descriptor, pointer, size, offset) } } -#if !os(Windows) + #if !os(Windows) @inline(never) internal static func writev(descriptor: CInt, iovecs: UnsafeBufferPointer) throws -> IOResult { - return try syscall(blocking: true) { + try syscall(blocking: true) { sysWritev(descriptor, iovecs.baseAddress!, CInt(iovecs.count)) } } -#endif + #endif @inline(never) - internal static func read(descriptor: CInt, pointer: UnsafeMutableRawPointer, size: size_t) throws -> IOResult { - return try syscallForbiddingEINVAL { + internal static func read( + descriptor: CInt, + pointer: UnsafeMutableRawPointer, + size: size_t + ) throws -> IOResult { + try syscallForbiddingEINVAL { sysRead(descriptor, pointer, size) } } @inline(never) - internal static func pread(descriptor: CInt, pointer: UnsafeMutableRawPointer, size: size_t, offset: off_t) throws -> IOResult { - return try syscallForbiddingEINVAL { + internal static func pread( + descriptor: CInt, + pointer: UnsafeMutableRawPointer, + size: size_t, + offset: off_t + ) throws -> IOResult { + try syscallForbiddingEINVAL { sysPread(descriptor, pointer, size, offset) } } @inline(never) - internal static func recvmsg(descriptor: CInt, msgHdr: UnsafeMutablePointer, flags: CInt) throws -> IOResult { - return try syscall(blocking: true) { + internal static func recvmsg( + descriptor: CInt, + msgHdr: UnsafeMutablePointer, + flags: CInt + ) throws -> IOResult { + try syscall(blocking: true) { sysRecvMsg(descriptor, msgHdr, flags) } } - + @inline(never) - internal static func sendmsg(descriptor: CInt, msgHdr: UnsafePointer, flags: CInt) throws -> IOResult { - return try syscall(blocking: true) { + internal static func sendmsg( + descriptor: CInt, + msgHdr: UnsafePointer, + flags: CInt + ) throws -> IOResult { + try syscall(blocking: true) { sysSendMsg(descriptor, msgHdr, flags) } } @@ -629,21 +701,21 @@ internal enum Posix { @discardableResult @inline(never) internal static func lseek(descriptor: CInt, offset: off_t, whence: CInt) throws -> off_t { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysLseek(descriptor, offset, whence) }.result } -#endif + #endif @discardableResult @inline(never) internal static func dup(descriptor: CInt) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysDup(descriptor) }.result } -#if !os(Windows) + #if !os(Windows) // It's not really posix but exists on Linux and MacOS / BSD so just put it here for now to keep it simple @inline(never) internal static func sendfile(descriptor: CInt, fd: CInt, offset: off_t, count: size_t) throws -> IOResult { @@ -651,25 +723,25 @@ internal enum Posix { do { _ = try syscall(blocking: false) { () -> ssize_t in #if canImport(Darwin) - var w: off_t = off_t(count) - let result: CInt = Darwin.sendfile(fd, descriptor, offset, &w, nil, 0) - written = w - return ssize_t(result) + var w: off_t = off_t(count) + let result: CInt = Darwin.sendfile(fd, descriptor, offset, &w, nil, 0) + written = w + return ssize_t(result) #elseif os(Linux) || os(FreeBSD) || os(Android) - var off: off_t = offset - #if canImport(Glibc) - let result: ssize_t = Glibc.sendfile(descriptor, fd, &off, count) - #elseif canImport(Musl) - let result: ssize_t = Musl.sendfile(descriptor, fd, &off, count) - #endif - if result >= 0 { - written = off_t(result) - } else { - written = 0 - } - return result + var off: off_t = offset + #if canImport(Glibc) + let result: ssize_t = Glibc.sendfile(descriptor, fd, &off, count) + #elseif canImport(Musl) + let result: ssize_t = Musl.sendfile(descriptor, fd, &off, count) + #endif + if result >= 0 { + written = off_t(result) + } else { + written = 0 + } + return result #else - fatalError("unsupported OS") + fatalError("unsupported OS") #endif } return .processed(Int(written)) @@ -682,45 +754,64 @@ internal enum Posix { } @inline(never) - internal static func sendmmsg(sockfd: CInt, msgvec: UnsafeMutablePointer, vlen: CUnsignedInt, flags: CInt) throws -> IOResult { - return try syscall(blocking: true) { + internal static func sendmmsg( + sockfd: CInt, + msgvec: UnsafeMutablePointer, + vlen: CUnsignedInt, + flags: CInt + ) throws -> IOResult { + try syscall(blocking: true) { Int(sysSendMmsg(sockfd, msgvec, vlen, flags)) } } @inline(never) - internal static func recvmmsg(sockfd: CInt, msgvec: UnsafeMutablePointer, vlen: CUnsignedInt, flags: CInt, timeout: UnsafeMutablePointer?) throws -> IOResult { - return try syscall(blocking: true) { + internal static func recvmmsg( + sockfd: CInt, + msgvec: UnsafeMutablePointer, + vlen: CUnsignedInt, + flags: CInt, + timeout: UnsafeMutablePointer? + ) throws -> IOResult { + try syscall(blocking: true) { Int(sysRecvMmsg(sockfd, msgvec, vlen, flags, timeout)) } } @inline(never) - internal static func getpeername(socket: CInt, address: UnsafeMutablePointer, addressLength: UnsafeMutablePointer) throws { + internal static func getpeername( + socket: CInt, + address: UnsafeMutablePointer, + addressLength: UnsafeMutablePointer + ) throws { _ = try syscall(blocking: false) { - return sysGetpeername(socket, address, addressLength) + sysGetpeername(socket, address, addressLength) } } @inline(never) - internal static func getsockname(socket: CInt, address: UnsafeMutablePointer, addressLength: UnsafeMutablePointer) throws { + internal static func getsockname( + socket: CInt, + address: UnsafeMutablePointer, + addressLength: UnsafeMutablePointer + ) throws { _ = try syscall(blocking: false) { - return sysGetsockname(socket, address, addressLength) + sysGetsockname(socket, address, addressLength) } } -#endif + #endif @inline(never) internal static func if_nametoindex(_ name: UnsafePointer?) throws -> CUnsignedInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysIfNameToIndex(name!) }.result } -#if !os(Windows) + #if !os(Windows) @inline(never) internal static func poll(fds: UnsafeMutablePointer, nfds: nfds_t, timeout: CInt) throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { sysPoll(fds, nfds, timeout) }.result } @@ -754,8 +845,12 @@ internal enum Posix { } @inline(never) - public static func readlink(pathname: String, outPath: UnsafeMutablePointer, outPathSize: Int) throws -> CLong { - return try syscall(blocking: false) { + public static func readlink( + pathname: String, + outPath: UnsafeMutablePointer, + outPathSize: Int + ) throws -> CLong { + try syscall(blocking: false) { sysReadlink(pathname, outPath, outPathSize) }.result } @@ -774,7 +869,7 @@ internal enum Posix { } } -#if canImport(Darwin) + #if canImport(Darwin) @inline(never) public static func mkpath_np(pathname: String, mode: mode_t) throws { _ = try syscall(blocking: false) { @@ -784,14 +879,14 @@ internal enum Posix { @inline(never) public static func opendir(pathname: String) throws -> UnsafeMutablePointer { - return try syscall { + try syscall { sysOpendir(pathname) } } @inline(never) public static func readdir(dir: UnsafeMutablePointer) throws -> UnsafeMutablePointer? { - return try syscallOptional { + try syscallOptional { sysReaddir(dir) } } @@ -802,17 +897,17 @@ internal enum Posix { sysClosedir(dir) } } -#elseif os(Linux) || os(FreeBSD) || os(Android) + #elseif os(Linux) || os(FreeBSD) || os(Android) @inline(never) public static func opendir(pathname: String) throws -> OpaquePointer { - return try syscall { + try syscall { sysOpendir(pathname) } } @inline(never) public static func readdir(dir: OpaquePointer) throws -> UnsafeMutablePointer? { - return try syscallOptional { + try syscallOptional { sysReaddir(dir) } } @@ -823,7 +918,7 @@ internal enum Posix { sysClosedir(dir) } } -#endif + #endif @inline(never) public static func rename(pathname: String, newName: String) throws { @@ -840,16 +935,18 @@ internal enum Posix { } @inline(never) - internal static func socketpair(domain: NIOBSDSocket.ProtocolFamily, - type: NIOBSDSocket.SocketType, - protocolSubtype: NIOBSDSocket.ProtocolSubtype, - socketVector: UnsafeMutablePointer?) throws { + internal static func socketpair( + domain: NIOBSDSocket.ProtocolFamily, + type: NIOBSDSocket.SocketType, + protocolSubtype: NIOBSDSocket.ProtocolSubtype, + socketVector: UnsafeMutablePointer? + ) throws { _ = try syscall(blocking: false) { sysSocketpair(domain.rawValue, type.rawValue, protocolSubtype.rawValue, socketVector) } } -#endif -#if !os(Windows) + #endif + #if !os(Windows) @inline(never) internal static func ioctl(fd: CInt, request: CUnsignedLong, ptr: UnsafeMutableRawPointer) throws { _ = try syscall(blocking: false) { @@ -857,7 +954,7 @@ internal enum Posix { sysIoctl(fd, numericCast(request), ptr) } } -#endif // !os(Windows) + #endif // !os(Windows) } /// `NIOFcntlFailedError` indicates that NIO was unable to perform an @@ -875,7 +972,7 @@ public struct NIOFcntlFailedError: Error {} public struct NIOFailedToSetSocketNonBlockingError: Error {} #if !os(Windows) -internal extension Posix { +extension Posix { static func setNonBlocking(socket: CInt) throws { let flags = try Posix.fcntl(descriptor: socket, command: F_GETFL, value: 0) do { @@ -899,15 +996,22 @@ internal enum KQueue { @inline(never) internal static func kqueue() throws -> CInt { - return try syscall(blocking: false) { + try syscall(blocking: false) { Darwin.kqueue() }.result } @inline(never) @discardableResult - internal static func kevent(kq: CInt, changelist: UnsafePointer?, nchanges: CInt, eventlist: UnsafeMutablePointer?, nevents: CInt, timeout: UnsafePointer?) throws -> CInt { - return try syscall(blocking: false) { + internal static func kevent( + kq: CInt, + changelist: UnsafePointer?, + nchanges: CInt, + eventlist: UnsafeMutablePointer?, + nevents: CInt, + timeout: UnsafePointer? + ) throws -> CInt { + try syscall(blocking: false) { sysKevent(kq, changelist, nchanges, eventlist, nevents, timeout) }.result } diff --git a/Sources/NIOPosix/Thread.swift b/Sources/NIOPosix/Thread.swift index 68b486c72a..ac4b7730c6 100644 --- a/Sources/NIOPosix/Thread.swift +++ b/Sources/NIOPosix/Thread.swift @@ -66,12 +66,12 @@ final class NIOThread { /// - body: The closure that will accept the `pthread_t`. /// - returns: The value returned by `body`. internal func withUnsafeThreadHandle(_ body: (ThreadOpsSystem.ThreadHandle) throws -> T) rethrows -> T { - return try body(self.handle) + try body(self.handle) } /// Get current name of the `NIOThread` or `nil` if not set. var currentName: String? { - return ThreadOpsSystem.threadName(self.handle) + ThreadOpsSystem.threadName(self.handle) } func join() { @@ -84,8 +84,11 @@ final class NIOThread { /// - name: The name of the `NIOThread` or `nil` if no specific name should be set. /// - body: The function to execute within the spawned `NIOThread`. /// - detach: Whether to detach the thread. If the thread is not detached it must be `join`ed. - static func spawnAndRun(name: String? = nil, detachThread: Bool = true, - body: @escaping (NIOThread) -> Void) { + static func spawnAndRun( + name: String? = nil, + detachThread: Bool = true, + body: @escaping (NIOThread) -> Void + ) { var handle: ThreadOpsSystem.ThreadHandle? = nil // Store everything we want to pass into the c function in a Box so we @@ -98,7 +101,7 @@ final class NIOThread { /// Returns `true` if the calling thread is the same as this one. var isCurrent: Bool { - return ThreadOpsSystem.isCurrentThread(self.handle) + ThreadOpsSystem.isCurrentThread(self.handle) } /// Returns the current running `NIOThread`. @@ -143,7 +146,7 @@ extension NIOThread: CustomStringConvertible { /// `ThreadSpecificVariable` is thread-safe so it can be used with multiple threads at the same time but the value /// returned by `currentValue` is defined per thread. public final class ThreadSpecificVariable { - /* the actual type in there is `Box<(ThreadSpecificVariable, T)>` but we can't use that as C functions can't capture (even types) */ + // the actual type in there is `Box<(ThreadSpecificVariable, T)>` but we can't use that as C functions can't capture (even types) private typealias BoxedType = Box<(AnyObject, AnyObject)> internal class Key { @@ -158,7 +161,7 @@ public final class ThreadSpecificVariable { } public func get() -> UnsafeMutableRawPointer? { - return ThreadOpsSystem.getThreadSpecificValue(self.underlyingKey) + ThreadOpsSystem.getThreadSpecificValue(self.underlyingKey) } public func set(value: UnsafeMutableRawPointer?) { @@ -171,7 +174,7 @@ public final class ThreadSpecificVariable { /// Initialize a new `ThreadSpecificVariable` without a current value (`currentValue == nil`). public init() { self.key = Key(destructor: { - Unmanaged.fromOpaque(($0 as UnsafeMutableRawPointer?)!).release() + Unmanaged.fromOpaque(($0 as UnsafeMutableRawPointer?)!).release() }) } @@ -185,9 +188,12 @@ public final class ThreadSpecificVariable { self.currentValue = value } - /// The value for the current thread. - @available(*, noasync, message: "threads can change between suspension points and therefore the thread specific value too") + @available( + *, + noasync, + message: "threads can change between suspension points and therefore the thread specific value too" + ) public var currentValue: Value? { get { self.get() @@ -201,10 +207,11 @@ public final class ThreadSpecificVariable { func get() -> Value? { guard let raw = self.key.get() else { return nil } // parenthesize the return value to silence the cast warning - return (Unmanaged - .fromOpaque(raw) - .takeUnretainedValue() - .value.1 as! Value) + return + (Unmanaged + .fromOpaque(raw) + .takeUnretainedValue() + .value.1 as! Value) } /// Set the current value for the calling threads. The `currentValue` for all other threads remains unchanged. @@ -219,8 +226,8 @@ public final class ThreadSpecificVariable { extension ThreadSpecificVariable: @unchecked Sendable where Value: Sendable {} extension NIOThread: Equatable { - static func ==(lhs: NIOThread, rhs: NIOThread) -> Bool { - return lhs.withUnsafeThreadHandle { lhs in + static func == (lhs: NIOThread, rhs: NIOThread) -> Bool { + lhs.withUnsafeThreadHandle { lhs in rhs.withUnsafeThreadHandle { rhs in ThreadOpsSystem.compareThreads(lhs, rhs) } diff --git a/Sources/NIOPosix/ThreadPosix.swift b/Sources/NIOPosix/ThreadPosix.swift index 852f08f634..aacaba5b47 100644 --- a/Sources/NIOPosix/ThreadPosix.swift +++ b/Sources/NIOPosix/ThreadPosix.swift @@ -37,24 +37,30 @@ private typealias ThreadDestructor = @convention(c) (UnsafeMutableRawPointer) -> #endif -private func sysPthread_create(handle: UnsafeMutablePointer, - destructor: @escaping ThreadDestructor, - args: UnsafeMutableRawPointer?) -> CInt { +private func sysPthread_create( + handle: UnsafeMutablePointer, + destructor: @escaping ThreadDestructor, + args: UnsafeMutableRawPointer? +) -> CInt { #if canImport(Darwin) return pthread_create(handle, nil, destructor, args) #else #if canImport(Musl) var handleLinux: OpaquePointer? = nil - let result = pthread_create(&handleLinux, - nil, - destructor, - args) + let result = pthread_create( + &handleLinux, + nil, + destructor, + args + ) #else var handleLinux = pthread_t() - let result = pthread_create(&handleLinux, - nil, - destructor, - args) + let result = pthread_create( + &handleLinux, + nil, + destructor, + args + ) #endif handle.pointee = handleLinux return result @@ -88,39 +94,47 @@ enum ThreadOpsPosix: ThreadOps { } } - static func run(handle: inout ThreadOpsSystem.ThreadHandle?, args: Box, detachThread: Bool) { + static func run( + handle: inout ThreadOpsSystem.ThreadHandle?, + args: Box, + detachThread: Bool + ) { let argv0 = Unmanaged.passRetained(args).toOpaque() - let res = sysPthread_create(handle: &handle, destructor: { - // Cast to UnsafeMutableRawPointer? and force unwrap to make the - // same code work on macOS and Linux. - let boxed = Unmanaged - .fromOpaque(($0 as UnsafeMutableRawPointer?)!) - .takeRetainedValue() - let (body, name) = (boxed.value.body, boxed.value.name) - let hThread: ThreadOpsSystem.ThreadHandle = pthread_self() - - if let name = name { - let maximumThreadNameLength: Int - #if os(Linux) || os(Android) - maximumThreadNameLength = 15 - #else - maximumThreadNameLength = .max - #endif - name.prefix(maximumThreadNameLength).withCString { namePtr in - // this is non-critical so we ignore the result here, we've seen - // EPERM in containers. - _ = sys_pthread_setname_np(hThread, namePtr) + let res = sysPthread_create( + handle: &handle, + destructor: { + // Cast to UnsafeMutableRawPointer? and force unwrap to make the + // same code work on macOS and Linux. + let boxed = Unmanaged + .fromOpaque(($0 as UnsafeMutableRawPointer?)!) + .takeRetainedValue() + let (body, name) = (boxed.value.body, boxed.value.name) + let hThread: ThreadOpsSystem.ThreadHandle = pthread_self() + + if let name = name { + let maximumThreadNameLength: Int + #if os(Linux) || os(Android) + maximumThreadNameLength = 15 + #else + maximumThreadNameLength = .max + #endif + name.prefix(maximumThreadNameLength).withCString { namePtr in + // this is non-critical so we ignore the result here, we've seen + // EPERM in containers. + _ = sys_pthread_setname_np(hThread, namePtr) + } } - } - body(NIOThread(handle: hThread, desiredName: name)) + body(NIOThread(handle: hThread, desiredName: name)) - #if os(Android) - return UnsafeMutableRawPointer(bitPattern: 0xdeadbee)! - #else - return nil - #endif - }, args: argv0) + #if os(Android) + return UnsafeMutableRawPointer(bitPattern: 0xdeadbee)! + #else + return nil + #endif + }, + args: argv0 + ) precondition(res == 0, "Unable to create thread: \(res)") if detachThread { @@ -131,11 +145,11 @@ enum ThreadOpsPosix: ThreadOps { } static func isCurrentThread(_ thread: ThreadOpsSystem.ThreadHandle) -> Bool { - return pthread_equal(thread, pthread_self()) != 0 + pthread_equal(thread, pthread_self()) != 0 } static var currentThread: ThreadOpsSystem.ThreadHandle { - return pthread_self() + pthread_self() } static func joinThread(_ thread: ThreadOpsSystem.ThreadHandle) { @@ -156,7 +170,7 @@ enum ThreadOpsPosix: ThreadOps { } static func getThreadSpecificValue(_ key: ThreadSpecificKey) -> UnsafeMutableRawPointer? { - return pthread_getspecific(key) + pthread_getspecific(key) } static func setThreadSpecificValue(key: ThreadSpecificKey, value: UnsafeMutableRawPointer?) { @@ -165,7 +179,7 @@ enum ThreadOpsPosix: ThreadOps { } static func compareThreads(_ lhs: ThreadOpsSystem.ThreadHandle, _ rhs: ThreadOpsSystem.ThreadHandle) -> Bool { - return pthread_equal(lhs, rhs) != 0 + pthread_equal(lhs, rhs) != 0 } } diff --git a/Sources/NIOPosix/ThreadWindows.swift b/Sources/NIOPosix/ThreadWindows.swift index 34873a0c72..2513973d98 100644 --- a/Sources/NIOPosix/ThreadWindows.swift +++ b/Sources/NIOPosix/ThreadWindows.swift @@ -16,7 +16,6 @@ import WinSDK - typealias ThreadOpsSystem = ThreadOpsWindows enum ThreadOpsWindows: ThreadOps { typealias ThreadHandle = HANDLE @@ -32,7 +31,11 @@ enum ThreadOpsWindows: ThreadOps { return string } - static func run(handle: inout ThreadOpsSystem.ThreadHandle?, args: Box, detachThread: Bool) { + static func run( + handle: inout ThreadOpsSystem.ThreadHandle?, + args: Box, + detachThread: Bool + ) { let argv0 = Unmanaged.passRetained(args).toOpaque() // FIXME(compnerd) this should use the `stdcall` calling convention @@ -60,11 +63,11 @@ enum ThreadOpsWindows: ThreadOps { } static func isCurrentThread(_ thread: ThreadOpsSystem.ThreadHandle) -> Bool { - return CompareObjectHandles(thread, GetCurrentThread()) + CompareObjectHandles(thread, GetCurrentThread()) } static var currentThread: ThreadOpsSystem.ThreadHandle { - return GetCurrentThread() + GetCurrentThread() } static func joinThread(_ thread: ThreadOpsSystem.ThreadHandle) { @@ -73,7 +76,7 @@ enum ThreadOpsWindows: ThreadOps { } static func allocateThreadSpecificValue(destructor: @escaping ThreadSpecificKeyDestructor) -> ThreadSpecificKey { - return FlsAlloc(destructor) + FlsAlloc(destructor) } static func deallocateThreadSpecificValue(_ key: ThreadSpecificKey) { @@ -82,7 +85,7 @@ enum ThreadOpsWindows: ThreadOps { } static func getThreadSpecificValue(_ key: ThreadSpecificKey) -> UnsafeMutableRawPointer? { - return FlsGetValue(key) + FlsGetValue(key) } static func setThreadSpecificValue(key: ThreadSpecificKey, value: UnsafeMutableRawPointer?) { @@ -90,7 +93,7 @@ enum ThreadOpsWindows: ThreadOps { } static func compareThreads(_ lhs: ThreadOpsSystem.ThreadHandle, _ rhs: ThreadOpsSystem.ThreadHandle) -> Bool { - return CompareObjectHandles(lhs, rhs) + CompareObjectHandles(lhs, rhs) } } diff --git a/Sources/NIOPosix/UnsafeTransfer.swift b/Sources/NIOPosix/UnsafeTransfer.swift index daef8dacc0..e7a5a8ca12 100644 --- a/Sources/NIOPosix/UnsafeTransfer.swift +++ b/Sources/NIOPosix/UnsafeTransfer.swift @@ -19,7 +19,7 @@ struct UnsafeTransfer { @usableFromInline var wrappedValue: Wrapped - + @inlinable init(_ wrappedValue: Wrapped) { self.wrappedValue = wrappedValue @@ -30,4 +30,3 @@ extension UnsafeTransfer: @unchecked Sendable {} extension UnsafeTransfer: Equatable where Wrapped: Equatable {} extension UnsafeTransfer: Hashable where Wrapped: Hashable {} - diff --git a/Sources/NIOPosix/Utilities.swift b/Sources/NIOPosix/Utilities.swift index 43329b9801..acd67193b2 100644 --- a/Sources/NIOPosix/Utilities.swift +++ b/Sources/NIOPosix/Utilities.swift @@ -19,7 +19,12 @@ /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. @inlinable internal func debugOnly(_ body: () -> Void) { - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } /// Allows to "box" another value. diff --git a/Sources/NIOPosix/VsockAddress.swift b/Sources/NIOPosix/VsockAddress.swift index 8e15e893e2..d4197e2dc6 100644 --- a/Sources/NIOPosix/VsockAddress.swift +++ b/Sources/NIOPosix/VsockAddress.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import NIOCore + #if canImport(Darwin) import CNIODarwin #elseif os(Linux) || os(Android) @@ -23,7 +24,7 @@ import Musl #endif import CNIOLinux #endif -fileprivate let vsockUnimplemented = "VSOCK support is not implemented for this platform" +private let vsockUnimplemented = "VSOCK support is not implemented for this platform" // MARK: - Public API that's available on all platforms. @@ -168,32 +169,32 @@ extension ChannelOptions.Types { extension NIOBSDSocket.AddressFamily { /// Address for vsock. public static var vsock: NIOBSDSocket.AddressFamily { -#if canImport(Darwin) || os(Linux) || os(Android) + #if canImport(Darwin) || os(Linux) || os(Android) NIOBSDSocket.AddressFamily(rawValue: AF_VSOCK) -#else + #else fatalError(vsockUnimplemented) -#endif + #endif } } extension NIOBSDSocket.ProtocolFamily { /// Address for vsock. public static var vsock: NIOBSDSocket.ProtocolFamily { -#if canImport(Darwin) || os(Linux) || os(Android) + #if canImport(Darwin) || os(Linux) || os(Android) NIOBSDSocket.ProtocolFamily(rawValue: PF_VSOCK) -#else + #else fatalError(vsockUnimplemented) -#endif + #endif } } extension VsockAddress { public func withSockAddr(_ body: (UnsafePointer, Int) throws -> T) rethrows -> T { -#if canImport(Darwin) || os(Linux) || os(Android) + #if canImport(Darwin) || os(Linux) || os(Android) return try self.address.withSockAddr({ try body($0, $1) }) -#else + #else fatalError(vsockUnimplemented) -#endif + #endif } } @@ -215,14 +216,14 @@ extension VsockAddress.ContextID { /// /// - Note: On Linux, ``local`` may be a better choice. static func getLocalContextID(_ socketFD: NIOBSDSocket.Handle) throws -> Self { -#if canImport(Darwin) + #if canImport(Darwin) let request = CNIODarwin_IOCTL_VM_SOCKETS_GET_LOCAL_CID let fd = socketFD -#elseif os(Linux) || os(Android) + #elseif os(Linux) || os(Android) let request = CNIOLinux_IOCTL_VM_SOCKETS_GET_LOCAL_CID let fd = try Posix.open(file: "/dev/vsock", oFlag: O_RDONLY | O_CLOEXEC) defer { try! Posix.close(descriptor: fd) } -#endif + #endif var cid = Self.any.rawValue try Posix.ioctl(fd: fd, request: request, ptr: &cid) return Self(rawValue: cid) @@ -231,7 +232,7 @@ extension VsockAddress.ContextID { extension sockaddr_vm { func withSockAddr(_ body: (UnsafePointer, Int) throws -> R) rethrows -> R { - return try withUnsafeBytes(of: self) { p in + try withUnsafeBytes(of: self) { p in try body(p.baseAddress!.assumingMemoryBound(to: sockaddr.self), p.count) } } diff --git a/Sources/NIOTCPEchoServer/Server.swift b/Sources/NIOTCPEchoServer/Server.swift index b2635b25d6..451d687b89 100644 --- a/Sources/NIOTCPEchoServer/Server.swift +++ b/Sources/NIOTCPEchoServer/Server.swift @@ -95,7 +95,6 @@ struct Server { } } - /// A simple newline based encoder and decoder. private final class NewlineDelimiterCoder: ByteToMessageDecoder, MessageToByteEncoder { typealias InboundIn = ByteBuffer diff --git a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift index d3e63fcf47..c4a2984143 100644 --- a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift +++ b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift @@ -39,17 +39,22 @@ import NIOCore /// specify a type that must be returned from the supplied closure. The result will then be used to succeed the ``NIOTypedApplicationProtocolNegotiationHandler/protocolNegotiationResult`` /// promise. This allows us to construct pipelines that include protocol negotiation handlers and be able to bridge them into `NIOAsyncChannel` /// based bootstraps. -public final class NIOTypedApplicationProtocolNegotiationHandler: ChannelInboundHandler, RemovableChannelHandler { +public final class NIOTypedApplicationProtocolNegotiationHandler: ChannelInboundHandler, + RemovableChannelHandler +{ public typealias InboundIn = Any public typealias InboundOut = Any public var protocolNegotiationResult: EventLoopFuture { - return self.negotiatedPromise.futureResult + self.negotiatedPromise.futureResult } private var negotiatedPromise: EventLoopPromise { - precondition(self._negotiatedPromise != nil, "Tried to access the protocol negotiation result before the handler was added to a pipeline") + precondition( + self._negotiatedPromise != nil, + "Tried to access the protocol negotiation result before the handler was added to a pipeline" + ) return self._negotiatedPromise! } private var _negotiatedPromise: EventLoopPromise? @@ -113,7 +118,7 @@ public final class NIOTypedApplicationProtocolNegotiationHandler { @inlinable mutating func userInboundEventTriggered(event: Any) -> UserInboundEventTriggeredAction { - if case .handshakeCompleted(let negotiated) = event as? TLSUserEvent { + if case .handshakeCompleted(let negotiated) = event as? TLSUserEvent { switch self.state { case .initial: self.state = .waitingForUser(buffer: .init()) @@ -171,7 +171,7 @@ struct ProtocolNegotiationHandlerStateMachine { switch self.state { case .initial, .unbuffering, .waitingForUser: self.state = .finished - + case .finished: break } diff --git a/Sources/NIOTLS/SNIHandler.swift b/Sources/NIOTLS/SNIHandler.swift index a522fb545a..18efcac439 100644 --- a/Sources/NIOTLS/SNIHandler.swift +++ b/Sources/NIOTLS/SNIHandler.swift @@ -44,15 +44,15 @@ private enum InternalSNIErrors: Error { case recordIncomplete } -private extension ByteBuffer { - mutating func moveReaderIndexIfPossible(forwardBy distance: Int) throws { +extension ByteBuffer { + fileprivate mutating func moveReaderIndexIfPossible(forwardBy distance: Int) throws { guard self.readableBytes >= distance else { throw InternalSNIErrors.invalidLengthInRecord } self.moveReaderIndex(forwardBy: distance) } - mutating func readIntegerIfPossible() throws -> T { + fileprivate mutating func readIntegerIfPossible() throws -> T { guard let integer: T = self.readInteger() else { throw InternalSNIErrors.invalidLengthInRecord } @@ -60,8 +60,8 @@ private extension ByteBuffer { } } -private extension Sequence where Element == UInt8 { - func decodeStringValidatingASCII() -> String? { +extension Sequence where Element == UInt8 { + fileprivate func decodeStringValidatingASCII() -> String? { var bytesIterator = self.makeIterator() var scalars: [Unicode.Scalar] = [] scalars.reserveCapacity(self.underestimatedCount) @@ -101,14 +101,18 @@ public final class SNIHandler: ByteToMessageDecoder { private let completionHandler: (SNIResult) -> EventLoopFuture private var waitingForUser: Bool - + public init(sniCompleteHandler: @escaping (SNIResult) -> EventLoopFuture) { self.cumulationBuffer = nil self.completionHandler = sniCompleteHandler self.waitingForUser = false } - public func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState { + public func decodeLast( + context: ChannelHandlerContext, + buffer: inout ByteBuffer, + seenEOF: Bool + ) throws -> DecodingState { context.fireChannelRead(NIOAny(buffer)) return .needMoreData } @@ -165,7 +169,9 @@ public final class SNIHandler: ByteToMessageDecoder { // // From this point onwards if we don't have enough data to satisfy a read, this is an error and // we will fall back to let the upper layers handle it. - tempBuffer = tempBuffer.getSlice(at: tempBuffer.readerIndex, length: Int(contentLength))! // length check above + + // length check above + tempBuffer = tempBuffer.getSlice(at: tempBuffer.readerIndex, length: Int(contentLength))! // Now parse the handshake header. If the length of the handshake message is not exactly the // length of this record, something has gone wrong and we should give up. @@ -198,7 +204,7 @@ public final class SNIHandler: ByteToMessageDecoder { } // Check the content type. - let contentType: UInt8 = buffer.readInteger()! // length check above + let contentType: UInt8 = buffer.readInteger()! // length check above guard contentType == tlsContentTypeHandshake else { // Whatever this is, it's not a handshake message, so something has gone // wrong. We're going to fall back to the default handler here and let @@ -207,7 +213,7 @@ public final class SNIHandler: ByteToMessageDecoder { } // Now, check the major version. - let majorVersion: UInt8 = buffer.readInteger()! // length check above + let majorVersion: UInt8 = buffer.readInteger()! // length check above guard majorVersion == 3 else { // A major version of 3 is the major version used for SSLv3 and all subsequent versions // of the protocol. If that's not what this is, we don't know what's happening here. @@ -217,7 +223,7 @@ public final class SNIHandler: ByteToMessageDecoder { // Skip the minor version byte, then grab the content length. buffer.moveReaderIndex(forwardBy: 1) - let contentLength: UInt16 = buffer.readInteger()! // length check above + let contentLength: UInt16 = buffer.readInteger()! // length check above return Int(contentLength) } @@ -254,8 +260,8 @@ public final class SNIHandler: ByteToMessageDecoder { } let handshakeTypeAndLength: UInt32 = buffer.readInteger()! - let handshakeType: UInt8 = UInt8((handshakeTypeAndLength & 0xFF000000) >> 24) - let handshakeLength: UInt32 = handshakeTypeAndLength & 0x00FFFFFF + let handshakeType: UInt8 = UInt8((handshakeTypeAndLength & 0xFF00_0000) >> 24) + let handshakeLength: UInt32 = handshakeTypeAndLength & 0x00FF_FFFF guard handshakeType == handshakeTypeClientHello else { throw InternalSNIErrors.invalidRecord } diff --git a/Sources/NIOTestUtils/ByteToMessageDecoderVerifier.swift b/Sources/NIOTestUtils/ByteToMessageDecoderVerifier.swift index f84b570c5c..5a30b0ebe7 100644 --- a/Sources/NIOTestUtils/ByteToMessageDecoderVerifier.swift +++ b/Sources/NIOTestUtils/ByteToMessageDecoderVerifier.swift @@ -18,13 +18,16 @@ public enum ByteToMessageDecoderVerifier { /// - seealso: verifyDecoder(inputOutputPairs:decoderFactory:) /// /// Verify `ByteToMessageDecoder`s with `String` inputs - public static func verifyDecoder(stringInputOutputPairs: [(String, [Decoder.InboundOut])], - decoderFactory: () -> Decoder) throws where Decoder.InboundOut: Equatable { + public static func verifyDecoder( + stringInputOutputPairs: [(String, [Decoder.InboundOut])], + decoderFactory: () -> Decoder + ) throws where Decoder.InboundOut: Equatable { let alloc = ByteBufferAllocator() - let ioPairs = stringInputOutputPairs.map { (ioPair: (String, [Decoder.InboundOut])) -> (ByteBuffer, [Decoder.InboundOut]) in - return (alloc.buffer(string: ioPair.0), ioPair.1) + let ioPairs = stringInputOutputPairs.map { + (ioPair: (String, [Decoder.InboundOut])) -> (ByteBuffer, [Decoder.InboundOut]) in + (alloc.buffer(string: ioPair.0), ioPair.1) } - + try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: ioPairs, decoderFactory: decoderFactory) } @@ -51,8 +54,10 @@ public enum ByteToMessageDecoderVerifier { /// ] /// XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: expectedInOuts, /// decoderFactory: { ExampleDecoder() })) - public static func verifyDecoder(inputOutputPairs: [(ByteBuffer, [Decoder.InboundOut])], - decoderFactory: () -> Decoder) throws where Decoder.InboundOut: Equatable { + public static func verifyDecoder( + inputOutputPairs: [(ByteBuffer, [Decoder.InboundOut])], + decoderFactory: () -> Decoder + ) throws where Decoder.InboundOut: Equatable { typealias Out = Decoder.InboundOut func verifySimple(channel: RecordingChannel) throws { @@ -60,19 +65,27 @@ public enum ByteToMessageDecoderVerifier { try channel.writeInbound(input) for expectedOutput in expectedOutputs { guard let actualOutput = try channel.readInbound(as: Out.self) else { - throw VerificationError(inputs: channel.inboundWrites, - errorCode: .underProduction(expectedOutput)) + throw VerificationError( + inputs: channel.inboundWrites, + errorCode: .underProduction(expectedOutput) + ) } guard actualOutput == expectedOutput else { - throw VerificationError(inputs: channel.inboundWrites, - errorCode: .wrongProduction(actual: actualOutput, - expected: expectedOutput)) + throw VerificationError( + inputs: channel.inboundWrites, + errorCode: .wrongProduction( + actual: actualOutput, + expected: expectedOutput + ) + ) } } let actualExtraOutput = try channel.readInbound(as: Out.self) guard actualExtraOutput == nil else { - throw VerificationError(inputs: channel.inboundWrites, - errorCode: .overProduction(actualExtraOutput!)) + throw VerificationError( + inputs: channel.inboundWrites, + errorCode: .overProduction(actualExtraOutput!) + ) } } } @@ -91,19 +104,27 @@ public enum ByteToMessageDecoderVerifier { } for expectedOutput in expectedOutputs { guard let actualOutput = try channel.readInbound(as: Out.self) else { - throw VerificationError(inputs: channel.inboundWrites, - errorCode: .underProduction(expectedOutput)) + throw VerificationError( + inputs: channel.inboundWrites, + errorCode: .underProduction(expectedOutput) + ) } guard actualOutput == expectedOutput else { - throw VerificationError(inputs: channel.inboundWrites, - errorCode: .wrongProduction(actual: actualOutput, - expected: expectedOutput)) + throw VerificationError( + inputs: channel.inboundWrites, + errorCode: .wrongProduction( + actual: actualOutput, + expected: expectedOutput + ) + ) } } let actualExtraOutput = try channel.readInbound(as: Out.self) guard actualExtraOutput == nil else { - throw VerificationError(inputs: channel.inboundWrites, - errorCode: .overProduction(actualExtraOutput!)) + throw VerificationError( + inputs: channel.inboundWrites, + errorCode: .overProduction(actualExtraOutput!) + ) } } } @@ -123,13 +144,19 @@ public enum ByteToMessageDecoderVerifier { try channel.writeInbound(overallBuffer) for expectedOutput in overallExpecteds { guard let actualOutput = try channel.readInbound(as: Out.self) else { - throw VerificationError(inputs: channel.inboundWrites, - errorCode: .underProduction(expectedOutput)) + throw VerificationError( + inputs: channel.inboundWrites, + errorCode: .underProduction(expectedOutput) + ) } guard actualOutput == expectedOutput else { - throw VerificationError(inputs: channel.inboundWrites, - errorCode: .wrongProduction(actual: actualOutput, - expected: expectedOutput)) + throw VerificationError( + inputs: channel.inboundWrites, + errorCode: .wrongProduction( + actual: actualOutput, + expected: expectedOutput + ) + ) } } } @@ -142,10 +169,14 @@ public enum ByteToMessageDecoderVerifier { try verifyManyAtOnce(channel: channel) if case .leftOvers(inbound: let ib, outbound: let ob, pendingOutbound: let pob) = try channel.finish() { - throw VerificationError(inputs: channel.inboundWrites, - errorCode: .leftOversOnDeconstructingChannel(inbound: ib, - outbound: ob, - pendingOutbound: pob)) + throw VerificationError( + inputs: channel.inboundWrites, + errorCode: .leftOversOnDeconstructingChannel( + inbound: ib, + outbound: ob, + pendingOutbound: pob + ) + ) } } } @@ -160,7 +191,7 @@ extension ByteToMessageDecoderVerifier { } func readInbound(as type: T.Type = T.self) throws -> T? { - return try self.actualChannel.readInbound() + try self.actualChannel.readInbound() } @discardableResult public func writeInbound(_ data: ByteBuffer) throws -> EmbeddedChannel.BufferState { @@ -169,11 +200,11 @@ extension ByteToMessageDecoderVerifier { } var allocator: ByteBufferAllocator { - return self.actualChannel.allocator + self.actualChannel.allocator } func finish() throws -> EmbeddedChannel.LeftOverState { - return try self.actualChannel.finish() + try self.actualChannel.finish() } } } diff --git a/Sources/NIOTestUtils/EventCounterHandler.swift b/Sources/NIOTestUtils/EventCounterHandler.swift index cef809cc10..7b82293be9 100644 --- a/Sources/NIOTestUtils/EventCounterHandler.swift +++ b/Sources/NIOTestUtils/EventCounterHandler.swift @@ -12,9 +12,9 @@ // //===----------------------------------------------------------------------===// -import NIOCore -import NIOConcurrencyHelpers import Atomics +import NIOConcurrencyHelpers +import NIOCore /// `EventCounterHandler` is a `ChannelHandler` that counts and forwards all the events that it sees coming through /// the `ChannelPipeline`. @@ -58,87 +58,87 @@ extension EventCounterHandler { /// Returns the number of `channelRegistered` events seen so far in the `ChannelPipeline`. public var channelRegisteredCalls: Int { - return self._channelRegisteredCalls.load(ordering: .relaxed) + self._channelRegisteredCalls.load(ordering: .relaxed) } /// Returns the number of `channelUnregistered` events seen so far in the `ChannelPipeline`. public var channelUnregisteredCalls: Int { - return self._channelUnregisteredCalls.load(ordering: .relaxed) + self._channelUnregisteredCalls.load(ordering: .relaxed) } /// Returns the number of `channelActive` events seen so far in the `ChannelPipeline`. public var channelActiveCalls: Int { - return self._channelActiveCalls.load(ordering: .relaxed) + self._channelActiveCalls.load(ordering: .relaxed) } /// Returns the number of `channelInactive` events seen so far in the `ChannelPipeline`. public var channelInactiveCalls: Int { - return self._channelInactiveCalls.load(ordering: .relaxed) + self._channelInactiveCalls.load(ordering: .relaxed) } /// Returns the number of `channelRead` events seen so far in the `ChannelPipeline`. public var channelReadCalls: Int { - return self._channelReadCalls.load(ordering: .relaxed) + self._channelReadCalls.load(ordering: .relaxed) } /// Returns the number of `channelReadComplete` events seen so far in the `ChannelPipeline`. public var channelReadCompleteCalls: Int { - return self._channelReadCompleteCalls.load(ordering: .relaxed) + self._channelReadCompleteCalls.load(ordering: .relaxed) } /// Returns the number of `channelWritabilityChanged` events seen so far in the `ChannelPipeline`. public var channelWritabilityChangedCalls: Int { - return self._channelWritabilityChangedCalls.load(ordering: .relaxed) + self._channelWritabilityChangedCalls.load(ordering: .relaxed) } /// Returns the number of `userInboundEventTriggered` events seen so far in the `ChannelPipeline`. public var userInboundEventTriggeredCalls: Int { - return self._userInboundEventTriggeredCalls.load(ordering: .relaxed) + self._userInboundEventTriggeredCalls.load(ordering: .relaxed) } /// Returns the number of `errorCaught` events seen so far in the `ChannelPipeline`. public var errorCaughtCalls: Int { - return self._errorCaughtCalls.load(ordering: .relaxed) + self._errorCaughtCalls.load(ordering: .relaxed) } /// Returns the number of `register` events seen so far in the `ChannelPipeline`. public var registerCalls: Int { - return self._registerCalls.load(ordering: .relaxed) + self._registerCalls.load(ordering: .relaxed) } /// Returns the number of `bind` events seen so far in the `ChannelPipeline`. public var bindCalls: Int { - return self._bindCalls.load(ordering: .relaxed) + self._bindCalls.load(ordering: .relaxed) } /// Returns the number of `connect` events seen so far in the `ChannelPipeline`. public var connectCalls: Int { - return self._connectCalls.load(ordering: .relaxed) + self._connectCalls.load(ordering: .relaxed) } /// Returns the number of `write` events seen so far in the `ChannelPipeline`. public var writeCalls: Int { - return self._writeCalls.load(ordering: .relaxed) + self._writeCalls.load(ordering: .relaxed) } /// Returns the number of `flush` events seen so far in the `ChannelPipeline`. public var flushCalls: Int { - return self._flushCalls.load(ordering: .relaxed) + self._flushCalls.load(ordering: .relaxed) } /// Returns the number of `read` events seen so far in the `ChannelPipeline`. public var readCalls: Int { - return self._readCalls.load(ordering: .relaxed) + self._readCalls.load(ordering: .relaxed) } /// Returns the number of `close` events seen so far in the `ChannelPipeline`. public var closeCalls: Int { - return self._closeCalls.load(ordering: .relaxed) + self._closeCalls.load(ordering: .relaxed) } /// Returns the number of `triggerUserOutboundEvent` events seen so far in the `ChannelPipeline`. public var triggerUserOutboundEventCalls: Int { - return self._triggerUserOutboundEventCalls.load(ordering: .relaxed) + self._triggerUserOutboundEventCalls.load(ordering: .relaxed) } /// Validate some basic assumptions about the number of events and if any of those assumptions are violated, throw @@ -292,7 +292,7 @@ extension EventCounterHandler: ChannelDuplexHandler { self._channelReadCalls.wrappingIncrement(ordering: .relaxed) context.fireChannelRead(data) } - + /// @see: `_ChannelInboundHandler.channelReadComplete` public func channelReadComplete(context: ChannelHandlerContext) { self._channelReadCompleteCalls.wrappingIncrement(ordering: .relaxed) @@ -310,7 +310,7 @@ extension EventCounterHandler: ChannelDuplexHandler { self._userInboundEventTriggeredCalls.wrappingIncrement(ordering: .relaxed) context.fireUserInboundEventTriggered(event) } - + /// @see: `_ChannelInboundHandler.errorCaught` public func errorCaught(context: ChannelHandlerContext, error: Error) { self._errorCaughtCalls.wrappingIncrement(ordering: .relaxed) diff --git a/Sources/NIOTestUtils/NIOHTTP1TestServer.swift b/Sources/NIOTestUtils/NIOHTTP1TestServer.swift index ec4133bb8a..b4326f5c58 100644 --- a/Sources/NIOTestUtils/NIOHTTP1TestServer.swift +++ b/Sources/NIOTestUtils/NIOHTTP1TestServer.swift @@ -11,10 +11,11 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + +import NIOConcurrencyHelpers import NIOCore -import NIOPosix import NIOHTTP1 -import NIOConcurrencyHelpers +import NIOPosix private final class BlockingQueue { private let condition = ConditionLock(value: false) @@ -36,9 +37,13 @@ private final class BlockingQueue { internal func popFirst(deadline: NIODeadline) throws -> Element { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.condition.lock(whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000)) else { - throw TimeoutError() + guard + self.condition.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) + else { + throw TimeoutError() } let first = self.buffer.removeFirst() self.condition.unlock(withValue: !self.buffer.isEmpty) @@ -48,7 +53,6 @@ private final class BlockingQueue { extension BlockingQueue: @unchecked Sendable where Element: Sendable {} - private final class WebServerHandler: ChannelDuplexHandler { typealias InboundIn = HTTPServerRequestPart typealias OutboundIn = HTTPServerResponsePart @@ -250,13 +254,13 @@ public final class NIOHTTP1TestServer { channel.close(promise: nil) } return channel.eventLoop.makeSucceededFuture(()) - } - .bind(host: "127.0.0.1", port: 0) - .map { channel in - self.handleChannels() - return channel - } - .wait() + } + .bind(host: "127.0.0.1", port: 0) + .map { channel in + self.handleChannels() + return channel + } + .wait() } } @@ -270,8 +274,8 @@ extension NIOHTTP1TestServer { switch self.state { case .channelsAvailable(let channels): self.state = .stopped - channels.forEach { - $0.close(promise: nil) + for channel in channels { + channel.close(promise: nil) } case .waitingForChannel(let promise): self.state = .stopped @@ -357,8 +361,10 @@ extension NIOHTTP1TestServer { /// - deadline: The deadline by which a part must have been received. /// - verify: A closure which can be used to verify the contents of the `HTTPRequestHead`. /// - Throws: If the part was not a `.head` or nothing was read before the deadline. - public func receiveHeadAndVerify(deadline: NIODeadline = .now() + .seconds(10), - _ verify: (HTTPRequestHead) throws -> () = { _ in }) throws { + public func receiveHeadAndVerify( + deadline: NIODeadline = .now() + .seconds(10), + _ verify: (HTTPRequestHead) throws -> Void = { _ in } + ) throws { try verify(self.receiveHead(deadline: deadline)) } @@ -386,12 +392,13 @@ extension NIOHTTP1TestServer { /// - deadline: The deadline by which a part must have been received. /// - verify: A closure which can be used to verify the contents of the `ByteBuffer`. /// - Throws: If the part was not a `.body` or nothing was read before the deadline. - public func receiveBodyAndVerify(deadline: NIODeadline = .now() + .seconds(10), - _ verify: (ByteBuffer) throws -> () = { _ in }) throws { + public func receiveBodyAndVerify( + deadline: NIODeadline = .now() + .seconds(10), + _ verify: (ByteBuffer) throws -> Void = { _ in } + ) throws { try verify(self.receiveBody(deadline: deadline)) } - /// Waits for a message part to be received and checks that it was a `.end` before returning /// the `HTTPHeaders?` it contained. /// @@ -416,8 +423,10 @@ extension NIOHTTP1TestServer { /// - deadline: The deadline by which a part must have been received. /// - verify: A closure which can be used to verify the contents of the `HTTPHeaders?`. /// - Throws: If the part was not a `.end` or nothing was read before the deadline. - public func receiveEndAndVerify(deadline: NIODeadline = .now() + .seconds(10), - _ verify: (HTTPHeaders?) throws -> () = { _ in }) throws { + public func receiveEndAndVerify( + deadline: NIODeadline = .now() + .seconds(10), + _ verify: (HTTPHeaders?) throws -> Void = { _ in } + ) throws { try verify(self.receiveEnd()) } } @@ -430,6 +439,6 @@ public struct NIOHTTP1TestServerError: Error, Hashable, CustomStringConvertible } public var description: String { - return self.reason + self.reason } } diff --git a/Sources/NIOUDPEchoClient/main.swift b/Sources/NIOUDPEchoClient/main.swift index d4b05f4e46..f0348b7cd6 100644 --- a/Sources/NIOUDPEchoClient/main.swift +++ b/Sources/NIOUDPEchoClient/main.swift @@ -21,51 +21,51 @@ private final class EchoHandler: ChannelInboundHandler { public typealias InboundIn = AddressedEnvelope public typealias OutboundOut = AddressedEnvelope private var numBytes = 0 - + private let remoteAddressInitializer: () throws -> SocketAddress - + init(remoteAddressInitializer: @escaping () throws -> SocketAddress) { self.remoteAddressInitializer = remoteAddressInitializer } - + public func channelActive(context: ChannelHandlerContext) { - + do { // Channel is available. It's time to send the message to the server to initialize the ping-pong sequence. - + // Get the server address. let remoteAddress = try self.remoteAddressInitializer() - + // Set the transmission data. let buffer = context.channel.allocator.buffer(string: line) self.numBytes = buffer.readableBytes - + // Forward the data. let envelope = AddressedEnvelope(remoteAddress: remoteAddress, data: buffer) - + context.writeAndFlush(Self.wrapOutboundOut(envelope), promise: nil) - + } catch { print("Could not resolve remote address") } } - + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { let envelope = Self.unwrapInboundIn(data) let byteBuffer = envelope.data - + self.numBytes -= byteBuffer.readableBytes - + if self.numBytes <= 0 { let string = String(buffer: byteBuffer) print("Received: '\(string)' back from the server, closing channel.") context.close(promise: nil) } } - + public func errorCaught(context: ChannelHandlerContext, error: Error) { print("error: ", error) - + // As we are not really interested getting notified on success or failure we just pass nil as promise to // reduce allocations. context.close(promise: nil) @@ -101,15 +101,15 @@ enum ConnectTo { let connectTarget: ConnectTo switch (arg1, arg1.flatMap(Int.init), arg2, arg2.flatMap(Int.init), arg3.flatMap(Int.init)) { -case (.some(let h), .none , _, .some(let sp), .some(let lp)): - /* We received three arguments (String Int Int), let's interpret that as a server host with a server port and a local listening port */ +case (.some(let h), .none, _, .some(let sp), .some(let lp)): + // We received three arguments (String Int Int), let's interpret that as a server host with a server port and a local listening port connectTarget = .ip(host: h, sendPort: sp, listeningPort: lp) -case (.some(let sp), .none , .some(let lp), .none, _): - /* We received two arguments (String String), let's interpret that as sending socket path and listening socket path */ +case (.some(let sp), .none, .some(let lp), .none, _): + // We received two arguments (String String), let's interpret that as sending socket path and listening socket path assert(sp != lp, "The sending and listening sockets should differ.") connectTarget = .unixDomainSocket(sendPath: sp, listeningPath: lp) case (_, .some(let sp), _, .some(let lp), _): - /* We received two argument (Int Int), let's interpret that as the server port and a listening port on the default host. */ + // We received two argument (Int Int), let's interpret that as the server port and a listening port on the default host. connectTarget = .ip(host: defaultHost, sendPort: sp, listeningPort: lp) default: connectTarget = .ip(host: defaultHost, sendPort: defaultServerPort, listeningPort: defaultListeningPort) @@ -130,7 +130,7 @@ let bootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in channel.pipeline.addHandler(EchoHandler(remoteAddressInitializer: remoteAddress)) -} + } defer { try! group.syncShutdownGracefully() } diff --git a/Sources/NIOUDPEchoServer/main.swift b/Sources/NIOUDPEchoServer/main.swift index 979e6df800..bc942740ce 100644 --- a/Sources/NIOUDPEchoServer/main.swift +++ b/Sources/NIOUDPEchoServer/main.swift @@ -55,10 +55,13 @@ defer { try! group.syncShutdownGracefully() } -var arguments = CommandLine.arguments.dropFirst(0) // just to get an ArraySlice from [String] +var arguments = CommandLine.arguments.dropFirst(0) // just to get an ArraySlice from [String] if arguments.dropFirst().first == .some("--enable-gathering-reads") { bootstrap = bootstrap.channelOption(ChannelOptions.datagramVectorReadMessageCount, value: 30) - bootstrap = bootstrap.channelOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 30 * 2048)) + bootstrap = bootstrap.channelOption( + ChannelOptions.recvAllocator, + value: FixedSizeRecvByteBufferAllocator(capacity: 30 * 2048) + ) arguments = arguments.dropFirst() } let arg1 = arguments.dropFirst().first @@ -74,14 +77,14 @@ enum BindTo { let bindTarget: BindTo switch (arg1, arg1.flatMap(Int.init), arg2.flatMap(Int.init)) { -case (.some(let h), _ , .some(let p)): - /* we got two arguments, let's interpret that as host and port */ +case (.some(let h), _, .some(let p)): + // we got two arguments, let's interpret that as host and port bindTarget = .ip(host: h, port: p) case (.some(let portString), .none, _): - /* couldn't parse as number, expecting unix domain socket path */ + // couldn't parse as number, expecting unix domain socket path bindTarget = .unixDomainSocket(path: portString) case (_, .some(let p), _): - /* only one argument --> port */ + // only one argument --> port bindTarget = .ip(host: defaultHost, port: p) default: bindTarget = .ip(host: defaultHost, port: defaultPort) @@ -94,7 +97,7 @@ let channel = try { () -> Channel in case .unixDomainSocket(let path): return try bootstrap.bind(unixDomainSocketPath: path).wait() } - }() +}() print("Server started and listening on \(channel.localAddress!)") diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index a9e456f857..a25e69df90 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -21,7 +21,7 @@ public typealias NIOWebClientSocketUpgrader = NIOWebSocketClientUpgrader /// A `HTTPClientProtocolUpgrader` that knows how to do the WebSocket upgrade dance. /// -/// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request. +/// This upgrader assumes that the `HTTPClientUpgradeHandler` will create and send the upgrade request. /// This upgrader also assumes that the `HTTPClientUpgradeHandler` will appropriately mutate the /// pipeline to remove the HTTP `ChannelHandler`s. public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { @@ -29,7 +29,7 @@ public final class NIOWebSocketClientUpgrader: NIOHTTPClientProtocolUpgrader { public let supportedProtocol: String = "websocket" /// None of the websocket headers are actually defined as 'required'. public let requiredUpgradeHeaders: [String] = [] - + private let requestKey: String private let maxFrameSize: Int private let automaticErrorHandling: Bool @@ -141,7 +141,7 @@ extension NIOWebSocketClientUpgrader { @inlinable public static func randomRequestKey( using generator: inout Generator - ) -> String where Generator: RandomNumberGenerator{ + ) -> String where Generator: RandomNumberGenerator { var buffer = ByteBuffer() buffer.reserveCapacity(minimumWritableBytes: 16) /// we may want to use `randomBytes(count:)` once the proposal is accepted: https://forums.swift.org/t/pitch-requesting-larger-amounts-of-randomness-from-systemrandomnumbergenerator/27226 @@ -193,9 +193,11 @@ private func _upgrade( enableAutomaticErrorHandling: Bool, upgradePipelineHandler: @escaping @Sendable (Channel, HTTPResponseHead) -> EventLoopFuture ) -> EventLoopFuture { - return channel.eventLoop.makeCompletedFuture { + channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder()) - try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize))) + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize)) + ) if enableAutomaticErrorHandling { try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) } diff --git a/Sources/NIOWebSocket/NIOWebSocketFrameAggregator.swift b/Sources/NIOWebSocket/NIOWebSocketFrameAggregator.swift index 6f2447007e..1f929e8cae 100644 --- a/Sources/NIOWebSocket/NIOWebSocketFrameAggregator.swift +++ b/Sources/NIOWebSocket/NIOWebSocketFrameAggregator.swift @@ -13,7 +13,6 @@ //===----------------------------------------------------------------------===// import NIOCore - /// `NIOWebSocketFrameAggregator` buffers inbound fragmented `WebSocketFrame`'s and aggregates them into a single `WebSocketFrame`. /// It guarantees that a `WebSocketFrame` with an `opcode` of `.continuation` is never forwarded. /// Frames which are not fragmented are just forwarded without any processing. @@ -30,15 +29,14 @@ public final class NIOWebSocketFrameAggregator: ChannelInboundHandler { } public typealias InboundIn = WebSocketFrame public typealias InboundOut = WebSocketFrame - + private let minNonFinalFragmentSize: Int private let maxAccumulatedFrameCount: Int private let maxAccumulatedFrameSize: Int - + private var bufferedFrames: [WebSocketFrame] = [] private var accumulatedFrameSize: Int = 0 - - + /// Configures a `NIOWebSocketFrameAggregator`. /// - Parameters: /// - minNonFinalFragmentSize: Minimum size in bytes of a fragment which is not the last fragment of a complete frame. Used to defend against many really small payloads. @@ -54,7 +52,6 @@ public final class NIOWebSocketFrameAggregator: ChannelInboundHandler { self.maxAccumulatedFrameSize = maxAccumulatedFrameSize } - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { let frame = Self.unwrapInboundIn(data) do { @@ -64,16 +61,16 @@ public final class NIOWebSocketFrameAggregator: ChannelInboundHandler { throw Error.didReceiveFragmentBeforeReceivingTextOrBinaryFrame } try self.bufferFrame(frame) - + guard frame.fin else { break } // final frame received - + let aggregatedFrame = self.aggregateFrames( opcode: firstFrameOpcode, allocator: context.channel.allocator ) self.clearBuffer() - + context.fireChannelRead(wrapInboundOut(aggregatedFrame)) case .binary, .text: if frame.fin { @@ -95,7 +92,7 @@ public final class NIOWebSocketFrameAggregator: ChannelInboundHandler { context.fireErrorCaught(error) } } - + private func bufferFrame(_ frame: WebSocketFrame) throws { guard self.bufferedFrames.isEmpty || frame.opcode == .continuation else { throw Error.receivedNewFrameWithoutFinishingPrevious @@ -106,31 +103,31 @@ public final class NIOWebSocketFrameAggregator: ChannelInboundHandler { guard self.bufferedFrames.count < self.maxAccumulatedFrameCount else { throw Error.tooManyFragments } - + // if this is not a final frame, we will at least receive one more frame guard frame.fin || (self.bufferedFrames.count + 1) < self.maxAccumulatedFrameCount else { throw Error.tooManyFragments } - + self.bufferedFrames.append(frame) self.accumulatedFrameSize += frame.length - + guard self.accumulatedFrameSize <= self.maxAccumulatedFrameSize else { throw Error.accumulatedFrameSizeIsTooLarge } } - + private func aggregateFrames(opcode: WebSocketOpcode, allocator: ByteBufferAllocator) -> WebSocketFrame { var dataBuffer = allocator.buffer(capacity: self.accumulatedFrameSize) - + for frame in self.bufferedFrames { var unmaskedData = frame.unmaskedData dataBuffer.writeBuffer(&unmaskedData) } - + return WebSocketFrame(fin: true, opcode: opcode, data: dataBuffer) } - + private func clearBuffer() { self.bufferedFrames.removeAll(keepingCapacity: true) self.accumulatedFrameSize = 0 diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 0672bc4a06..2ac03cf2a4 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -41,8 +41,8 @@ public struct NIOWebSocketUpgradeError: Error, Equatable { public static let unsupportedWebSocketTarget = NIOWebSocketUpgradeError(actualError: .unsupportedWebSocketTarget) } -fileprivate extension HTTPHeaders { - func nonListHeader(_ name: String) throws -> String { +extension HTTPHeaders { + fileprivate func nonListHeader(_ name: String) throws -> String { let fields = self[canonicalForm: name] guard fields.count == 1 else { throw NIOWebSocketUpgradeError.invalidUpgradeHeader @@ -103,8 +103,12 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch shouldUpgrade: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture, upgradePipelineHandler: @escaping @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture ) { - self.init(maxFrameSize: 1 << 14, automaticErrorHandling: automaticErrorHandling, - shouldUpgrade: shouldUpgrade, upgradePipelineHandler: upgradePipelineHandler) + self.init( + maxFrameSize: 1 << 14, + automaticErrorHandling: automaticErrorHandling, + shouldUpgrade: shouldUpgrade, + upgradePipelineHandler: upgradePipelineHandler + ) } /// Create a new `NIOWebSocketServerUpgrader`. @@ -155,8 +159,12 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch self.automaticErrorHandling = automaticErrorHandling } - public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { - return _buildUpgradeResponse( + public func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { + _buildUpgradeResponse( channel: channel, upgradeRequest: upgradeRequest, initialResponseHeaders: initialResponseHeaders, @@ -186,7 +194,9 @@ public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unch /// /// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to /// remove the HTTP `ChannelHandler`s. -public final class NIOTypedWebSocketServerUpgrader: NIOTypedHTTPServerProtocolUpgrader, Sendable { +public final class NIOTypedWebSocketServerUpgrader: NIOTypedHTTPServerProtocolUpgrader, + Sendable +{ private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture diff --git a/Sources/NIOWebSocket/WebSocketErrorCodes.swift b/Sources/NIOWebSocket/WebSocketErrorCodes.swift index 6199bdc179..5a26c76a82 100644 --- a/Sources/NIOWebSocket/WebSocketErrorCodes.swift +++ b/Sources/NIOWebSocket/WebSocketErrorCodes.swift @@ -132,7 +132,7 @@ extension ByteBuffer { /// /// - returns: The error code, or `nil` if there were not enough readable bytes. public mutating func readWebSocketErrorCode() -> WebSocketErrorCode? { - return self.readInteger(as: UInt16.self).map { WebSocketErrorCode(networkInteger: $0) } + self.readInteger(as: UInt16.self).map { WebSocketErrorCode(networkInteger: $0) } } /// Get a websocket error code from a byte buffer. @@ -144,7 +144,7 @@ extension ByteBuffer { /// - index: The index into the buffer to read the error code from. /// - returns: The error code, or `nil` if there were not enough bytes at that index. public func getWebSocketErrorCode(at index: Int) -> WebSocketErrorCode? { - return self.getInteger(at: index, as: UInt16.self).map { WebSocketErrorCode(networkInteger: $0) } + self.getInteger(at: index, as: UInt16.self).map { WebSocketErrorCode(networkInteger: $0) } } /// Write the given error code to the buffer. diff --git a/Sources/NIOWebSocket/WebSocketFrame.swift b/Sources/NIOWebSocket/WebSocketFrame.swift index f9f625c5a8..0e3c2cab83 100644 --- a/Sources/NIOWebSocket/WebSocketFrame.swift +++ b/Sources/NIOWebSocket/WebSocketFrame.swift @@ -14,12 +14,12 @@ import NIOCore -private extension UInt8 { - func isAnyBitSetInMask(_ mask: UInt8) -> Bool { - return self & mask != 0 +extension UInt8 { + fileprivate func isAnyBitSetInMask(_ mask: UInt8) -> Bool { + self & mask != 0 } - mutating func changingBitsInMask(_ mask: UInt8, to: Bool) { + fileprivate mutating func changingBitsInMask(_ mask: UInt8, to: Bool) { if to { self |= mask } else { @@ -42,10 +42,12 @@ public struct WebSocketMaskingKey: Sendable { return nil } - self._key = (buffer[buffer.startIndex], - buffer[buffer.index(buffer.startIndex, offsetBy: 1)], - buffer[buffer.index(buffer.startIndex, offsetBy: 2)], - buffer[buffer.index(buffer.startIndex, offsetBy: 3)]) + self._key = ( + buffer[buffer.startIndex], + buffer[buffer.index(buffer.startIndex, offsetBy: 1)], + buffer[buffer.index(buffer.startIndex, offsetBy: 2)], + buffer[buffer.index(buffer.startIndex, offsetBy: 3)] + ) } /// Creates a websocket masking key from the network-encoded @@ -56,10 +58,12 @@ public struct WebSocketMaskingKey: Sendable { /// masking key. @usableFromInline internal init(networkRepresentation integer: UInt32) { - self._key = (UInt8((integer & 0xFF000000) >> 24), - UInt8((integer & 0x00FF0000) >> 16), - UInt8((integer & 0x0000FF00) >> 8), - UInt8(integer & 0x000000FF)) + self._key = ( + UInt8((integer & 0xFF00_0000) >> 24), + UInt8((integer & 0x00FF_0000) >> 16), + UInt8((integer & 0x0000_FF00) >> 8), + UInt8(integer & 0x0000_00FF) + ) } } @@ -68,7 +72,7 @@ extension WebSocketMaskingKey: ExpressibleByArrayLiteral { public init(arrayLiteral elements: UInt8...) { precondition(elements.count == 4, "WebSocketMaskingKeys must be exactly 4 bytes long") - self.init(elements)! // length precondition above + self.init(elements)! // length precondition above } } @@ -81,9 +85,9 @@ extension WebSocketMaskingKey { public static func random( using generator: inout Generator ) -> WebSocketMaskingKey where Generator: RandomNumberGenerator { - return WebSocketMaskingKey(networkRepresentation: .random(in: UInt32.min...UInt32.max, using: &generator)) + WebSocketMaskingKey(networkRepresentation: .random(in: UInt32.min...UInt32.max, using: &generator)) } - + /// Returns a random masking key, using the `SystemRandomNumberGenerator` as a source for randomness. /// - Returns: A random masking key @inlinable @@ -94,8 +98,8 @@ extension WebSocketMaskingKey { } extension WebSocketMaskingKey: Equatable { - public static func ==(lhs: WebSocketMaskingKey, rhs: WebSocketMaskingKey) -> Bool { - return lhs._key == rhs._key + public static func == (lhs: WebSocketMaskingKey, rhs: WebSocketMaskingKey) -> Bool { + lhs._key == rhs._key } } @@ -103,11 +107,11 @@ extension WebSocketMaskingKey: Collection { public typealias Element = UInt8 public typealias Index = Int - public var startIndex: Int { return 0 } - public var endIndex: Int { return 4 } + public var startIndex: Int { 0 } + public var endIndex: Int { 4 } public func index(after: Int) -> Int { - return after + 1 + after + 1 } public subscript(index: Int) -> UInt8 { @@ -127,7 +131,7 @@ extension WebSocketMaskingKey: Collection { @inlinable public func withContiguousStorageIfAvailable(_ body: (UnsafeBufferPointer) throws -> R) rethrows -> R? { - return try withUnsafeBytes(of: self._key) { ptr in + try withUnsafeBytes(of: self._key) { ptr in // this is boilerplate necessary to convert from UnsafeRawBufferPointer to UnsafeBufferPointer // we know ptr is bound since we defined self._key as let let typedPointer = ptr.baseAddress?.assumingMemoryBound(to: UInt8.self) @@ -162,7 +166,7 @@ public struct WebSocketFrame { /// a frame is not fragmented at all. public var fin: Bool { get { - return self.firstByte.isAnyBitSetInMask(0x80) + self.firstByte.isAnyBitSetInMask(0x80) } set { self.firstByte.changingBitsInMask(0x80, to: newValue) @@ -172,7 +176,7 @@ public struct WebSocketFrame { /// The value of the first reserved bit. Must be `false` unless using an extension that defines its use. public var rsv1: Bool { get { - return self.firstByte.isAnyBitSetInMask(0x40) + self.firstByte.isAnyBitSetInMask(0x40) } set { self.firstByte.changingBitsInMask(0x40, to: newValue) @@ -182,7 +186,7 @@ public struct WebSocketFrame { /// The value of the second reserved bit. Must be `false` unless using an extension that defines its use. public var rsv2: Bool { get { - return self.firstByte.isAnyBitSetInMask(0x20) + self.firstByte.isAnyBitSetInMask(0x20) } set { self.firstByte.changingBitsInMask(0x20, to: newValue) @@ -192,7 +196,7 @@ public struct WebSocketFrame { /// The value of the third reserved bit. Must be `false` unless using an extension that defines its use. public var rsv3: Bool { get { - return self.firstByte.isAnyBitSetInMask(0x10) + self.firstByte.isAnyBitSetInMask(0x10) } set { self.firstByte.changingBitsInMask(0x10, to: newValue) @@ -204,7 +208,7 @@ public struct WebSocketFrame { get { // this is a public initialiser which only fails if the opcode is invalid. But all opcodes in 0...0xF // space are valid so this can never fail. - return WebSocketOpcode(encodedWebSocketOpcode: firstByte & 0x0F)! + WebSocketOpcode(encodedWebSocketOpcode: firstByte & 0x0F)! } set { self.firstByte = (self.firstByte & 0xF0) + UInt8(webSocketOpcode: newValue) @@ -213,7 +217,7 @@ public struct WebSocketFrame { /// The total length of the data in the frame. public var length: Int { - return data.readableBytes + (extensionData?.readableBytes ?? 0) + data.readableBytes + (extensionData?.readableBytes ?? 0) } /// The application data. @@ -223,7 +227,7 @@ public struct WebSocketFrame { /// obtain it, or transform this data directly by calling `data.unmask(maskKey)`. public var data: ByteBuffer { get { - return self._storage.data + self._storage.data } set { if !isKnownUniquelyReferenced(&self._storage) { @@ -240,7 +244,7 @@ public struct WebSocketFrame { /// obtain it, or transform this data directly by calling `extensionData.unmask(maskKey)`. public var extensionData: ByteBuffer? { get { - return self._storage.extensionData + self._storage.extensionData } set { if !isKnownUniquelyReferenced(&self._storage) { @@ -303,9 +307,16 @@ public struct WebSocketFrame { /// - maskKey: The masking key for the frame, if any. Defaults to `nil`. /// - data: The application data for the frame. /// - extensionData: The extension data for the frame. - public init(fin: Bool = false, rsv1: Bool = false, rsv2: Bool = false, rsv3: Bool = false, - opcode: WebSocketOpcode = .continuation, maskKey: WebSocketMaskingKey? = nil, - data: ByteBuffer, extensionData: ByteBuffer? = nil) { + public init( + fin: Bool = false, + rsv1: Bool = false, + rsv2: Bool = false, + rsv3: Bool = false, + opcode: WebSocketOpcode = .continuation, + maskKey: WebSocketMaskingKey? = nil, + data: ByteBuffer, + extensionData: ByteBuffer? = nil + ) { self._storage = .init(data: data, extensionData: extensionData) self.fin = fin self.rsv1 = rsv1 @@ -348,7 +359,7 @@ extension WebSocketFrame { extension WebSocketFrame._Storage: Sendable {} extension WebSocketFrame._Storage: Equatable { - static func ==(lhs: WebSocketFrame._Storage, rhs: WebSocketFrame._Storage) -> Bool { - return lhs.data == rhs.data && lhs.extensionData == rhs.extensionData + static func == (lhs: WebSocketFrame._Storage, rhs: WebSocketFrame._Storage) -> Bool { + lhs.data == rhs.data && lhs.extensionData == rhs.extensionData } } diff --git a/Sources/NIOWebSocket/WebSocketFrameDecoder.swift b/Sources/NIOWebSocket/WebSocketFrameDecoder.swift index b517186df7..583cdbec66 100644 --- a/Sources/NIOWebSocket/WebSocketFrameDecoder.swift +++ b/Sources/NIOWebSocket/WebSocketFrameDecoder.swift @@ -33,7 +33,7 @@ extension WebSocketErrorCode { case .invalidFrameLength: self = .messageTooLarge case .fragmentedControlFrame, - .multiByteControlFrameLength: + .multiByteControlFrameLength: self = .protocolError } } @@ -169,7 +169,11 @@ struct WSParser { return .insufficientData } - self.state = .waitingForData(firstByte: firstByte, length: length, maskingKey: WebSocketMaskingKey(networkRepresentation: maskingKey)) + self.state = .waitingForData( + firstByte: firstByte, + length: length, + maskingKey: WebSocketMaskingKey(networkRepresentation: maskingKey) + ) return .continueParsing case .waitingForData(let firstByte, let length, let maskingKey): @@ -226,7 +230,7 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder { public typealias OutboundOut = WebSocketFrame /// The maximum frame size the decoder is willing to tolerate from the remote peer. - /* private but tests */ let maxFrameSize: Int + let maxFrameSize: Int /// Our parser state. private var parser = WSParser() @@ -251,7 +255,7 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder { self.maxFrameSize = maxFrameSize } - public func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState { + public func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState { // Even though the calling code will loop around calling us in `decode`, we can't quite // rely on that: sometimes we have zero-length elements to parse, and the caller doesn't // guarantee to call us with zero-length bytes. @@ -262,7 +266,7 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder { return .continue case .continueParsing: try self.parser.validateState(maxFrameSize: self.maxFrameSize) - // loop again, might be 'waiting' for 0 bytes + // loop again, might be 'waiting' for 0 bytes case .insufficientData: return .needMoreData } diff --git a/Sources/NIOWebSocket/WebSocketFrameEncoder.swift b/Sources/NIOWebSocket/WebSocketFrameEncoder.swift index 3e78c86132..8bac7446dc 100644 --- a/Sources/NIOWebSocket/WebSocketFrameEncoder.swift +++ b/Sources/NIOWebSocket/WebSocketFrameEncoder.swift @@ -48,7 +48,7 @@ public final class WebSocketFrameEncoder: ChannelOutboundHandler { /// for a mask key. private static let maximumFrameHeaderLength: Int = (2 + 4 + 8) - public init() { } + public init() {} public func handlerAdded(context: ChannelHandlerContext) { self.headerBuffer = context.channel.allocator.buffer(capacity: WebSocketFrameEncoder.maximumFrameHeaderLength) @@ -63,7 +63,11 @@ public final class WebSocketFrameEncoder: ChannelOutboundHandler { // First, we explode the frame structure and apply the mask. let frameHeader = FrameHeader(frame: data) - var (extensionData, applicationData) = self.mask(key: frameHeader.maskKey, extensionData: data.extensionData, applicationData: data.data) + var (extensionData, applicationData) = self.mask( + key: frameHeader.maskKey, + extensionData: data.extensionData, + applicationData: data.data + ) // Now we attempt to prepend the frame header to the first buffer. If we can't, we'll write to the header buffer. If we have // an extension data buffer, that's the first buffer, and we'll also write it here. @@ -83,7 +87,11 @@ public final class WebSocketFrameEncoder: ChannelOutboundHandler { } /// Applies the websocket masking operation based on the passed byte buffers. - private func mask(key: WebSocketMaskingKey?, extensionData: ByteBuffer?, applicationData: ByteBuffer) -> (ByteBuffer?, ByteBuffer) { + private func mask( + key: WebSocketMaskingKey?, + extensionData: ByteBuffer?, + applicationData: ByteBuffer + ) -> (ByteBuffer?, ByteBuffer) { guard let key = key else { return (extensionData, applicationData) } @@ -181,9 +189,8 @@ extension ByteBuffer { } } - /// A helper object that holds only a websocket frame header. Used to avoid accidentally CoWing on some paths. -fileprivate struct FrameHeader { +private struct FrameHeader { var length: Int var maskKey: WebSocketMaskingKey? var firstByte: UInt8 = 0 @@ -214,7 +221,6 @@ fileprivate struct FrameHeader { size += 4 // Masking key } - return size } } diff --git a/Sources/NIOWebSocket/WebSocketOpcode.swift b/Sources/NIOWebSocket/WebSocketOpcode.swift index 245f853a80..953d8c1b53 100644 --- a/Sources/NIOWebSocket/WebSocketOpcode.swift +++ b/Sources/NIOWebSocket/WebSocketOpcode.swift @@ -48,13 +48,13 @@ public struct WebSocketOpcode: Sendable { /// Whether the opcode is in the control range: that is, if the /// high bit of the opcode nibble is `1`. public var isControlOpcode: Bool { - return self.networkRepresentation & 0x8 == 0x8 + self.networkRepresentation & 0x8 == 0x8 } } -extension WebSocketOpcode: Equatable { } +extension WebSocketOpcode: Equatable {} -extension WebSocketOpcode: Hashable { } +extension WebSocketOpcode: Hashable {} extension WebSocketOpcode: CaseIterable { public static var allCases = (0..<0x10).map { WebSocketOpcode(rawValue: $0) } diff --git a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift index a9913f7eaa..89714f4c90 100644 --- a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift +++ b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift @@ -13,7 +13,6 @@ //===----------------------------------------------------------------------===// import NIOCore - /// A simple `ChannelHandler` that catches protocol errors emitted by the /// `WebSocketFrameDecoder` and automatically generates protocol error responses. /// @@ -23,15 +22,17 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { public typealias InboundIn = Never public typealias OutboundOut = WebSocketFrame - public init() { } + public init() {} public func errorCaught(context: ChannelHandlerContext, error: Error) { if let error = error as? NIOWebSocketError { var data = context.channel.allocator.buffer(capacity: 2) data.write(webSocketErrorCode: WebSocketErrorCode(error)) - let frame = WebSocketFrame(fin: true, - opcode: .connectionClose, - data: data) + let frame = WebSocketFrame( + fin: true, + opcode: .connectionClose, + data: data + ) context.writeAndFlush(Self.wrapOutboundOut(frame)).whenComplete { (_: Result) in context.close(promise: nil) } diff --git a/Sources/NIOWebSocketClient/Client.swift b/Sources/NIOWebSocketClient/Client.swift index 0bf42d3b16..aaca0c76cd 100644 --- a/Sources/NIOWebSocketClient/Client.swift +++ b/Sources/NIOWebSocketClient/Client.swift @@ -53,7 +53,9 @@ struct Client { let upgrader = NIOTypedWebSocketClientUpgrader( upgradePipelineHandler: { (channel, _) in channel.eventLoop.makeCompletedFuture { - let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) + let asyncChannel = try NIOAsyncChannel( + wrappingChannelSynchronously: channel + ) return UpgradeResult.websocket(asyncChannel) } } @@ -75,14 +77,15 @@ struct Client { upgraders: [upgrader], notUpgradingCompletionHandler: { channel in channel.eventLoop.makeCompletedFuture { - return UpgradeResult.notUpgraded + UpgradeResult.notUpgraded } } ) - let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( - configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) - ) + let negotiationResultFuture = try channel.pipeline.syncOperations + .configureUpgradableHTTPClientPipeline( + configuration: .init(upgradeConfiguration: clientUpgradeConfiguration) + ) return negotiationResultFuture } diff --git a/Sources/NIOWebSocketServer/Server.swift b/Sources/NIOWebSocketServer/Server.swift index 1af63e29d5..6a9252b0e9 100644 --- a/Sources/NIOWebSocketServer/Server.swift +++ b/Sources/NIOWebSocketServer/Server.swift @@ -19,28 +19,28 @@ import NIOHTTP1 import NIOWebSocket let websocketResponse = """ - - - - - Swift NIO WebSocket Test Page - - - -

WebSocket Stream

-
- - -""" + var textDiv = document.getElementById("websocket-stream"); + textDiv.insertBefore(element, null); + }; + + + +

WebSocket Stream

+
+ + + """ @available(macOS 14, iOS 17, tvOS 17, watchOS 10, *) @main @@ -70,43 +70,49 @@ struct Server { /// This method starts the server and handles incoming connections. func run() async throws { - let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: self.eventLoopGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .bind( - host: self.host, - port: self.port - ) { channel in - channel.eventLoop.makeCompletedFuture { - let upgrader = NIOTypedWebSocketServerUpgrader( - shouldUpgrade: { (channel, head) in - channel.eventLoop.makeSucceededFuture(HTTPHeaders()) - }, - upgradePipelineHandler: { (channel, _) in - channel.eventLoop.makeCompletedFuture { - let asyncChannel = try NIOAsyncChannel(wrappingChannelSynchronously: channel) - return UpgradeResult.websocket(asyncChannel) - } + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap( + group: self.eventLoopGroup + ) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .bind( + host: self.host, + port: self.port + ) { channel in + channel.eventLoop.makeCompletedFuture { + let upgrader = NIOTypedWebSocketServerUpgrader( + shouldUpgrade: { (channel, head) in + channel.eventLoop.makeSucceededFuture(HTTPHeaders()) + }, + upgradePipelineHandler: { (channel, _) in + channel.eventLoop.makeCompletedFuture { + let asyncChannel = try NIOAsyncChannel( + wrappingChannelSynchronously: channel + ) + return UpgradeResult.websocket(asyncChannel) } - ) + } + ) - let serverUpgradeConfiguration = NIOTypedHTTPServerUpgradeConfiguration( - upgraders: [upgrader], - notUpgradingCompletionHandler: { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponsePartHandler()) - let asyncChannel = try NIOAsyncChannel>(wrappingChannelSynchronously: channel) - return UpgradeResult.notUpgraded(asyncChannel) - } + let serverUpgradeConfiguration = NIOTypedHTTPServerUpgradeConfiguration( + upgraders: [upgrader], + notUpgradingCompletionHandler: { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(HTTPByteBufferResponsePartHandler()) + let asyncChannel = try NIOAsyncChannel< + HTTPServerRequestPart, HTTPPart + >(wrappingChannelSynchronously: channel) + return UpgradeResult.notUpgraded(asyncChannel) } - ) + } + ) - let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( - configuration: .init(upgradeConfiguration: serverUpgradeConfiguration) - ) + let negotiationResultFuture = try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: .init(upgradeConfiguration: serverUpgradeConfiguration) + ) - return negotiationResultFuture - } + return negotiationResultFuture } + } // We are handling each incoming connection in a separate child task. It is important // to use a discarding task group here which automatically discards finished child tasks. @@ -208,8 +214,9 @@ struct Server { } } - - private func handleHTTPChannel(_ channel: NIOAsyncChannel>) async throws { + private func handleHTTPChannel( + _ channel: NIOAsyncChannel> + ) async throws { try await channel.executeThenClose { inbound, outbound in for try await requestPart in inbound { // We're not interested in request bodies here: we're just serving up GET responses @@ -238,14 +245,15 @@ struct Server { contentsOf: [ .head(responseHead), .body(Self.responseBody), - .end(nil) + .end(nil), ] ) } } } - private func respond405(writer: NIOAsyncChannelOutboundWriter>) async throws { + private func respond405(writer: NIOAsyncChannelOutboundWriter>) async throws + { var headers = HTTPHeaders() headers.add(name: "Connection", value: "close") headers.add(name: "Content-Length", value: "0") @@ -258,7 +266,7 @@ struct Server { try await writer.write( contentsOf: [ .head(head), - .end(nil) + .end(nil), ] ) } diff --git a/Sources/_NIOBase64/Base64.swift b/Sources/_NIOBase64/Base64.swift index eb63cdd6b7..943e2f3a93 100644 --- a/Sources/_NIOBase64/Base64.swift +++ b/Sources/_NIOBase64/Base64.swift @@ -15,17 +15,17 @@ // This is a simplified vendored version from: // https://github.com/fabianfett/swift-base64-kit -public extension String { +extension String { /// Base64 encode a collection of UInt8 to a string, without the use of Foundation. @inlinable - init(base64Encoding bytes: Buffer) where Buffer.Element == UInt8 { + public init(base64Encoding bytes: Buffer) where Buffer.Element == UInt8 { self = Base64.encode(bytes: bytes) } @inlinable - func base64Decoded() throws -> [UInt8] { - return try Base64.decode(string: self) + public func base64Decoded() throws -> [UInt8] { + try Base64.decode(string: self) } } @@ -46,7 +46,7 @@ internal struct Base64 { // In Base64, 3 bytes become 4 output characters, and we pad to the // nearest multiple of four. let base64StringLength = ((bytes.count + 2) / 3) * 4 - let alphabet = Base64.encodeBase64 + let alphabet = Base64.encodingTable return String(customUnsafeUninitializedCapacity: base64StringLength) { backingStorage in var input = bytes.makeIterator() @@ -56,8 +56,16 @@ internal struct Base64 { let thirdByte = input.next() backingStorage[offset] = Base64.encode(alphabet: alphabet, firstByte: firstByte) - backingStorage[offset + 1] = Base64.encode(alphabet: alphabet, firstByte: firstByte, secondByte: secondByte) - backingStorage[offset + 2] = Base64.encode(alphabet: alphabet, secondByte: secondByte, thirdByte: thirdByte) + backingStorage[offset + 1] = Base64.encode( + alphabet: alphabet, + firstByte: firstByte, + secondByte: secondByte + ) + backingStorage[offset + 2] = Base64.encode( + alphabet: alphabet, + secondByte: secondByte, + thirdByte: thirdByte + ) backingStorage[offset + 3] = Base64.encode(alphabet: alphabet, thirdByte: thirdByte) offset += 4 } @@ -77,8 +85,9 @@ internal struct Base64 { // Go over the encoded string in groups of 4 characters, // and build groups of 3 bytes from them. for i in stride(from: 0, to: bytes.count, by: 4) { - guard let byte0Index = Base64.encodeBase64.firstIndex(of: bytes[i]), - let byte1Index = Base64.encodeBase64.firstIndex(of: bytes[i+1]) else { + guard let byte0Index = Base64.encodingTable.firstIndex(of: bytes[i]), + let byte1Index = Base64.encodingTable.firstIndex(of: bytes[i + 1]) + else { throw Base64Error.invalidCharacter } @@ -86,8 +95,8 @@ internal struct Base64 { decoded.append(byte0) // Check if the 3rd char is not a padding character, and decode the 2nd byte - if bytes[i+2] != Base64.encodePaddingCharacter { - guard let byte2Index = Base64.encodeBase64.firstIndex(of: bytes[i+2]) else { + if bytes[i + 2] != Base64.encodePaddingCharacter { + guard let byte2Index = Base64.encodingTable.firstIndex(of: bytes[i + 2]) else { throw Base64Error.invalidCharacter } @@ -96,9 +105,10 @@ internal struct Base64 { } // Check if the 4th character is not a padding, and decode the 3rd byte - if bytes[i+3] != Base64.encodePaddingCharacter { - guard let byte3Index = Base64.encodeBase64.firstIndex(of: bytes[i+3]), - let byte2Index = Base64.encodeBase64.firstIndex(of: bytes[i+2]) else { + if bytes[i + 3] != Base64.encodePaddingCharacter { + guard let byte3Index = Base64.encodingTable.firstIndex(of: bytes[i + 3]), + let byte2Index = Base64.encodingTable.firstIndex(of: bytes[i + 2]) + else { throw Base64Error.invalidCharacter } let third = (UInt8(byte2Index) << 6 | UInt8(byte3Index)) @@ -108,12 +118,11 @@ internal struct Base64 { return decoded } - // MARK: Internal // The base64 unicode table. @usableFromInline - static let encodeBase64: [UInt8] = [ + static let encodingTable: [UInt8] = [ UInt8(ascii: "A"), UInt8(ascii: "B"), UInt8(ascii: "C"), UInt8(ascii: "D"), UInt8(ascii: "E"), UInt8(ascii: "F"), UInt8(ascii: "G"), UInt8(ascii: "H"), UInt8(ascii: "I"), UInt8(ascii: "J"), UInt8(ascii: "K"), UInt8(ascii: "L"), @@ -179,8 +188,10 @@ extension String { /// /// As this API does not exist prior to 5.3 on Linux, or on older Apple platforms, we fake it out with a pointer and accept the extra copy. @inlinable - init(backportUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { + init( + backportUnsafeUninitializedCapacity capacity: Int, + initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int + ) rethrows { // The buffer will store zero terminated C string let buffer = UnsafeMutableBufferPointer.allocate(capacity: capacity + 1) @@ -205,8 +216,10 @@ extension String { extension String { @inlinable - init(customUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { + init( + customUnsafeUninitializedCapacity capacity: Int, + initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int + ) rethrows { if #available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *) { try self.init(unsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) } else { diff --git a/Sources/_NIODataStructures/Heap.swift b/Sources/_NIODataStructures/Heap.swift index 15a2a32b09..9af0c0622b 100644 --- a/Sources/_NIODataStructures/Heap.swift +++ b/Sources/_NIODataStructures/Heap.swift @@ -26,7 +26,7 @@ import ucrt @usableFromInline internal struct Heap { @usableFromInline - internal private(set) var storage: Array + internal private(set) var storage: [Element] @inlinable internal init() { @@ -36,29 +36,29 @@ internal struct Heap { @inlinable internal func comparator(_ lhs: Element, _ rhs: Element) -> Bool { // This heap is always a min-heap. - return lhs < rhs + lhs < rhs } // named `PARENT` in CLRS @inlinable internal func parentIndex(_ i: Int) -> Int { - return (i-1) / 2 + (i - 1) / 2 } // named `LEFT` in CLRS @inlinable internal func leftIndex(_ i: Int) -> Int { - return 2*i + 1 + 2 * i + 1 } // named `RIGHT` in CLRS @inlinable internal func rightIndex(_ i: Int) -> Int { - return 2*i + 2 + 2 * i + 2 } // named `MAX-HEAPIFY` in CLRS - /* private but */ @inlinable + @inlinable mutating func _heapify(_ index: Int) { let left = self.leftIndex(index) let right = self.rightIndex(index) @@ -81,7 +81,7 @@ internal struct Heap { } // named `HEAP-INCREASE-KEY` in CRLS - /* private but */ @inlinable + @inlinable mutating func _heapRootify(index: Int, key: Element) { var index = index if self.comparator(storage[index], key) { @@ -108,7 +108,7 @@ internal struct Heap { @discardableResult @inlinable internal mutating func removeRoot() -> Element? { - return self._remove(index: 0) + self._remove(index: 0) } @discardableResult @@ -121,7 +121,7 @@ internal struct Heap { return false } } - + @inlinable internal mutating func removeFirst(where shouldBeRemoved: (Element) throws -> Bool) rethrows { guard self.storage.count > 0 else { @@ -131,12 +131,12 @@ internal struct Heap { guard let index = try self.storage.firstIndex(where: shouldBeRemoved) else { return } - + self._remove(index: index) } @discardableResult - /* private but */ @inlinable + @inlinable mutating func _remove(index: Int) -> Element? { guard self.storage.count > 0 else { return nil @@ -163,7 +163,7 @@ extension Heap: CustomDebugStringConvertible { return "" } let descriptions = self.storage.map { String(describing: $0) } - let maxLen: Int = descriptions.map { $0.count }.max()! // storage checked non-empty above + let maxLen: Int = descriptions.map { $0.count }.max()! // storage checked non-empty above let paddedDescs = descriptions.map { (desc: String) -> String in var desc = desc while desc.count < maxLen { @@ -200,7 +200,7 @@ extension Heap: CustomDebugStringConvertible { all += String(repeating: " ", count: rightWidth) func height(index: Int) -> Int { - return Int(log2(Double(index + 1))) + Int(log2(Double(index + 1))) } let myHeight = height(index: index) let nextHeight = height(index: index + 1) @@ -217,7 +217,7 @@ extension Heap: CustomDebugStringConvertible { @usableFromInline struct HeapIterator: IteratorProtocol { - /* private but */ @usableFromInline + @usableFromInline var _heap: Heap @inlinable @@ -227,44 +227,44 @@ struct HeapIterator: IteratorProtocol { @inlinable mutating func next() -> Element? { - return self._heap.removeRoot() + self._heap.removeRoot() } } extension Heap: Sequence { @inlinable var startIndex: Int { - return self.storage.startIndex + self.storage.startIndex } @inlinable var endIndex: Int { - return self.storage.endIndex + self.storage.endIndex } @inlinable var underestimatedCount: Int { - return self.storage.count + self.storage.count } @inlinable func makeIterator() -> HeapIterator { - return HeapIterator(heap: self) + HeapIterator(heap: self) } @inlinable subscript(position: Int) -> Element { - return self.storage[position] + self.storage[position] } @inlinable func index(after i: Int) -> Int { - return i + 1 + i + 1 } @inlinable var count: Int { - return self.storage.count + self.storage.count } } diff --git a/Sources/_NIODataStructures/PriorityQueue.swift b/Sources/_NIODataStructures/PriorityQueue.swift index b04b50f6a9..765826166e 100644 --- a/Sources/_NIODataStructures/PriorityQueue.swift +++ b/Sources/_NIODataStructures/PriorityQueue.swift @@ -25,7 +25,7 @@ public struct PriorityQueue { public mutating func remove(_ key: Element) { self._heap.remove(value: key) } - + @inlinable public mutating func removeFirst(where shouldBeRemoved: (Element) throws -> Bool) rethrows { try self._heap.removeFirst(where: shouldBeRemoved) @@ -38,18 +38,18 @@ public struct PriorityQueue { @inlinable public func peek() -> Element? { - return self._heap.storage.first + self._heap.storage.first } @inlinable public var isEmpty: Bool { - return self._heap.storage.isEmpty + self._heap.storage.isEmpty } @inlinable @discardableResult public mutating func pop() -> Element? { - return self._heap.removeRoot() + self._heap.removeRoot() } @inlinable @@ -60,45 +60,45 @@ public struct PriorityQueue { extension PriorityQueue: Equatable { @inlinable - public static func ==(lhs: PriorityQueue, rhs: PriorityQueue) -> Bool { - return lhs.count == rhs.count && lhs.elementsEqual(rhs) + public static func == (lhs: PriorityQueue, rhs: PriorityQueue) -> Bool { + lhs.count == rhs.count && lhs.elementsEqual(rhs) } } extension PriorityQueue: Sequence { public struct Iterator: IteratorProtocol { - /* private but */ @usableFromInline + @usableFromInline var _queue: PriorityQueue - /* fileprivate but */ @inlinable + @inlinable public init(queue: PriorityQueue) { self._queue = queue } @inlinable public mutating func next() -> Element? { - return self._queue.pop() + self._queue.pop() } } @inlinable public func makeIterator() -> Iterator { - return Iterator(queue: self) + Iterator(queue: self) } } extension PriorityQueue { @inlinable public var count: Int { - return self._heap.count + self._heap.count } } extension PriorityQueue: CustomStringConvertible { @inlinable public var description: String { - return "PriorityQueue(count: \(self.count)): \(Array(self))" + "PriorityQueue(count: \(self.count)): \(Array(self))" } } diff --git a/Sources/_NIODataStructures/_TinyArray.swift b/Sources/_NIODataStructures/_TinyArray.swift index bc1c154b6a..f4e21fd655 100644 --- a/Sources/_NIODataStructures/_TinyArray.swift +++ b/Sources/_NIODataStructures/_TinyArray.swift @@ -52,7 +52,7 @@ extension _TinyArray: RandomAccessCollection { @inlinable public func makeIterator() -> Iterator { - return Iterator(storage: self.storage) + Iterator(storage: self.storage) } public struct Iterator: IteratorProtocol { diff --git a/Tests/NIOBase64Tests/Base64Test.swift b/Tests/NIOBase64Tests/Base64Test.swift index 8dcaa457e8..dd6262a9d8 100644 --- a/Tests/NIOBase64Tests/Base64Test.swift +++ b/Tests/NIOBase64Tests/Base64Test.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import XCTest + @testable import _NIOBase64 class Base64Test: XCTestCase { @@ -59,13 +60,16 @@ class Base64Test: XCTestCase { func testBase64DecodeHelloWorld() throws { let encoded = "SGVsbG8sIHdvcmxkIQ==" let decoded = try! encoded.base64Decoded() - XCTAssertEqual(decoded, "Hello, world!".utf8.map{ UInt8($0) }) + XCTAssertEqual(decoded, "Hello, world!".utf8.map { UInt8($0) }) } func testBase64EncodingAllTheBytesSequentially() throws { let data = Array(UInt8(0)...UInt8(255)) let encodedData = String(base64Encoding: data) - XCTAssertEqual(encodedData, "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/w==") + XCTAssertEqual( + encodedData, + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouMjY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHCw8TFxsfIycrLzM3Oz9DR0tPU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+/w==" + ) } func testBase64DecodingWithInvalidLength() { diff --git a/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests.swift b/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests.swift index ae11f3eb78..05bcb953f7 100644 --- a/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests.swift +++ b/Tests/NIOConcurrencyHelpersTests/NIOConcurrencyHelpersTests.swift @@ -11,6 +11,13 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + +import Dispatch +import NIOCore +import XCTest + +@testable import NIOConcurrencyHelpers + #if canImport(Darwin) import Darwin #elseif canImport(Glibc) @@ -18,16 +25,12 @@ import Glibc #else #error("The Concurrency helpers test module was unable to identify your C library.") #endif -import Dispatch -import XCTest -import NIOCore -@testable import NIOConcurrencyHelpers class NIOConcurrencyHelpersTests: XCTestCase { private func sumOfIntegers(until n: UInt64) -> UInt64 { - return n*(n+1)/2 + n * (n + 1) / 2 } - + #if canImport(Darwin) let noAsyncs: UInt64 = 50 #else @@ -41,7 +44,7 @@ class NIOConcurrencyHelpersTests: XCTestCase { @available(*, deprecated, message: "deprecated because it tests deprecated functionality") func testLargeContendedAtomicSum() { - + let noCounts: UInt64 = 2_000 let q = DispatchQueue(label: "q", attributes: .concurrent) @@ -118,7 +121,7 @@ class NIOConcurrencyHelpersTests: XCTestCase { var counter = max for _ in 0..<255 { - XCTAssertTrue(ab.compareAndExchange(expected: counter, desired: counter-1)) + XCTAssertTrue(ab.compareAndExchange(expected: counter, desired: counter - 1)) counter = counter - 1 } } @@ -149,7 +152,7 @@ class NIOConcurrencyHelpersTests: XCTestCase { var counter = upperBound for _ in 0..<255 { - XCTAssertTrue(ab.compareAndExchange(expected: counter, desired: counter-1)) + XCTAssertTrue(ab.compareAndExchange(expected: counter, desired: counter - 1)) XCTAssertFalse(ab.compareAndExchange(expected: counter, desired: counter)) counter = counter - 1 } @@ -330,7 +333,7 @@ class NIOConcurrencyHelpersTests: XCTestCase { var counter = max for _ in 0..<255 { - XCTAssertTrue(ab.compareAndExchange(expected: counter, desired: counter-1)) + XCTAssertTrue(ab.compareAndExchange(expected: counter, desired: counter - 1)) counter = counter - 1 } } @@ -361,7 +364,7 @@ class NIOConcurrencyHelpersTests: XCTestCase { var counter = upperBound for _ in 0..<255 { - XCTAssertTrue(ab.compareAndExchange(expected: counter, desired: counter-1)) + XCTAssertTrue(ab.compareAndExchange(expected: counter, desired: counter - 1)) XCTAssertFalse(ab.compareAndExchange(expected: counter, desired: counter)) counter = counter - 1 } @@ -471,7 +474,6 @@ class NIOConcurrencyHelpersTests: XCTestCase { testFor(UInt.self) } - func testLockMutualExclusion() { let l = NIOLock() @@ -492,8 +494,10 @@ class NIOConcurrencyHelpersTests: XCTestCase { } sem1.wait() - XCTAssertEqual(DispatchTimeoutResult.timedOut, - g.wait(timeout: .now() + 0.1)) + XCTAssertEqual( + DispatchTimeoutResult.timedOut, + g.wait(timeout: .now() + 0.1) + ) XCTAssertEqual(1, x) l.unlock() @@ -523,8 +527,10 @@ class NIOConcurrencyHelpersTests: XCTestCase { } sem1.wait() - XCTAssertEqual(DispatchTimeoutResult.timedOut, - g.wait(timeout: .now() + 0.1)) + XCTAssertEqual( + DispatchTimeoutResult.timedOut, + g.wait(timeout: .now() + 0.1) + ) XCTAssertEqual(1, x) } sem2.wait() @@ -554,8 +560,10 @@ class NIOConcurrencyHelpersTests: XCTestCase { } sem1.wait() - XCTAssertEqual(DispatchTimeoutResult.timedOut, - g.wait(timeout: .now() + 0.1)) + XCTAssertEqual( + DispatchTimeoutResult.timedOut, + g.wait(timeout: .now() + 0.1) + ) XCTAssertEqual(1, x) l.unlock() @@ -644,13 +652,13 @@ class NIOConcurrencyHelpersTests: XCTestCase { l.lock() l.unlock(withValue: 1) - doneSem.wait() /* job on 'q1' is done */ + doneSem.wait() // job on 'q1' is done XCTAssertEqual(1, l.value) l.lock() l.unlock(withValue: 2) - doneSem.wait() /* job on 'q2' is done */ + doneSem.wait() // job on 'q2' is done } } @@ -890,12 +898,12 @@ class NIOConcurrencyHelpersTests: XCTestCase { spawnAndJoinRacingThreads(count: 6) { i in switch i { - case 0: // writer - for i in 1 ... iterations { + case 0: // writer + for i in 1...iterations { let nextObject = box.exchange(with: .init(i, allDeallocations: allDeallocations)) XCTAssertEqual(nextObject.value, i - 1) } - default: // readers + default: // readers while true { if box.load().value < 0 || box.load().value > iterations { XCTFail("bad") @@ -923,11 +931,11 @@ class NIOConcurrencyHelpersTests: XCTestCase { spawnAndJoinRacingThreads(count: 6) { i in switch i { - case 0: // writer - for i in 1 ... iterations { + case 0: // writer + for i in 1...iterations { box.store(IntHolderWithDeallocationTracking(i, allDeallocations: allDeallocations)) } - default: // readers + default: // readers while true { if box.load().value < 0 || box.load().value > iterations { XCTFail("loaded the wrong value") @@ -955,16 +963,18 @@ class NIOConcurrencyHelpersTests: XCTestCase { spawnAndJoinRacingThreads(count: 6) { i in switch i { - case 0: // writer - for i in 1 ... iterations { + case 0: // writer + for i in 1...iterations { let old = box.load() XCTAssertEqual(i - 1, old.value) - if !box.compareAndExchange(expected: old, - desired: .init(i, allDeallocations: allDeallocations)) { + if !box.compareAndExchange( + expected: old, + desired: .init(i, allDeallocations: allDeallocations) + ) { XCTFail("compare and exchange didn't work but it should have") } } - default: // readers + default: // readers while true { if box.load().value < 0 || box.load().value > iterations { XCTFail("loaded wrong value") @@ -1056,8 +1066,8 @@ class NIOConcurrencyHelpersTests: XCTestCase { } func spawnAndJoinRacingThreads(count: Int, _ body: @escaping (Int) -> Void) { - let go = DispatchSemaphore(value: 0) // will be incremented when the threads are supposed to run (and race). - let arrived = Array(repeating: DispatchSemaphore(value: 0), count: count) // waiting for all threads to arrive + let go = DispatchSemaphore(value: 0) // will be incremented when the threads are supposed to run (and race). + let arrived = Array(repeating: DispatchSemaphore(value: 0), count: count) // waiting for all threads to arrive let group = DispatchGroup() for i in 0.. Void) { group.wait() } -func assert(_ condition: @autoclosure () -> Bool, within time: TimeAmount, testInterval: TimeAmount? = nil, _ message: String = "condition not satisfied in time", file: StaticString = #filePath, line: UInt = #line) { +func assert( + _ condition: @autoclosure () -> Bool, + within time: TimeAmount, + testInterval: TimeAmount? = nil, + _ message: String = "condition not satisfied in time", + file: StaticString = #filePath, + line: UInt = #line +) { let testInterval = testInterval ?? TimeAmount.nanoseconds(time.nanoseconds / 5) let endTime = NIODeadline.now() + time repeat { if condition() { return } usleep(UInt32(testInterval.nanoseconds / 1000)) - } while (NIODeadline.now() < endTime) + } while NIODeadline.now() < endTime if !condition() { XCTFail(message, file: (file), line: line) @@ -1094,7 +1111,7 @@ func assert(_ condition: @autoclosure () -> Bool, within time: TimeAmount, testI } @available(*, deprecated, message: "deprecated because it is used to test deprecated functionality") -fileprivate class IntHolderWithDeallocationTracking { +private class IntHolderWithDeallocationTracking { private(set) var value: Int let allDeallocations: NIOAtomic diff --git a/Tests/NIOCoreTests/AddressedEnvelopeTests.swift b/Tests/NIOCoreTests/AddressedEnvelopeTests.swift index c2878cd7cb..6b40c10986 100644 --- a/Tests/NIOCoreTests/AddressedEnvelopeTests.swift +++ b/Tests/NIOCoreTests/AddressedEnvelopeTests.swift @@ -43,8 +43,16 @@ final class AddressedEnvelopeTests: XCTestCase { func testHashable_whenDifferentMetadata() throws { let address = try SocketAddress(ipAddress: "127.0.0.0", port: 443) - let envelope1 = AddressedEnvelope(remoteAddress: address, data: "foo", metadata: .init(ecnState: .congestionExperienced)) - let envelope2 = AddressedEnvelope(remoteAddress: address, data: "foo", metadata: .init(ecnState: .transportCapableFlag0)) + let envelope1 = AddressedEnvelope( + remoteAddress: address, + data: "foo", + metadata: .init(ecnState: .congestionExperienced) + ) + let envelope2 = AddressedEnvelope( + remoteAddress: address, + data: "foo", + metadata: .init(ecnState: .transportCapableFlag0) + ) XCTAssertNotEqual(envelope1, envelope2) } @@ -60,8 +68,16 @@ final class AddressedEnvelopeTests: XCTestCase { func testHashable_whenDifferentData_andDifferentMetadata() throws { let address = try SocketAddress(ipAddress: "127.0.0.0", port: 443) - let envelope1 = AddressedEnvelope(remoteAddress: address, data: "foo", metadata: .init(ecnState: .congestionExperienced)) - let envelope2 = AddressedEnvelope(remoteAddress: address, data: "bar", metadata: .init(ecnState: .transportCapableFlag0)) + let envelope1 = AddressedEnvelope( + remoteAddress: address, + data: "foo", + metadata: .init(ecnState: .congestionExperienced) + ) + let envelope2 = AddressedEnvelope( + remoteAddress: address, + data: "bar", + metadata: .init(ecnState: .transportCapableFlag0) + ) XCTAssertNotEqual(envelope1, envelope2) } @@ -69,8 +85,16 @@ final class AddressedEnvelopeTests: XCTestCase { func testHashable_whenDifferentAddress_andDifferentMetadata() throws { let address1 = try SocketAddress(ipAddress: "127.0.0.0", port: 443) let address2 = try SocketAddress(ipAddress: "127.0.0.0", port: 444) - let envelope1 = AddressedEnvelope(remoteAddress: address1, data: "foo", metadata: .init(ecnState: .congestionExperienced)) - let envelope2 = AddressedEnvelope(remoteAddress: address2, data: "bar", metadata: .init(ecnState: .transportCapableFlag0)) + let envelope1 = AddressedEnvelope( + remoteAddress: address1, + data: "foo", + metadata: .init(ecnState: .congestionExperienced) + ) + let envelope2 = AddressedEnvelope( + remoteAddress: address2, + data: "bar", + metadata: .init(ecnState: .transportCapableFlag0) + ) XCTAssertNotEqual(envelope1, envelope2) } diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift index 94e17ec593..c890ce027f 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelInboundStreamTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -@testable import NIOCore import XCTest +@testable import NIOCore + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelInboundStreamTests: XCTestCase { func testTestingStream() async throws { diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift index 68446e2a34..2a73a8f4df 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelOutboundWriterTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -@testable import NIOCore import XCTest +@testable import NIOCore + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelOutboundWriterTests: XCTestCase { func testTestingWriter() async throws { diff --git a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift index 4115b9dee8..7b38e49301 100644 --- a/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift +++ b/Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift @@ -13,10 +13,11 @@ //===----------------------------------------------------------------------===// import Atomics import NIOConcurrencyHelpers -@testable import NIOCore import NIOEmbedded import XCTest +@testable import NIOCore + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class AsyncChannelTests: XCTestCase { func testAsyncChannelCloseOnWrite() async throws { @@ -84,7 +85,7 @@ final class AsyncChannelTests: XCTestCase { func testAsyncChannelThrowsWhenChannelClosed() async throws { let channel = NIOAsyncTestingChannel() let wrapped = try await channel.testingEventLoop.executeInContext { - return try NIOAsyncChannel(wrappingChannelSynchronously: channel) + try NIOAsyncChannel(wrappingChannelSynchronously: channel) } try await channel.close(mode: .all) @@ -251,7 +252,9 @@ final class AsyncChannelTests: XCTestCase { do { let strongSentinel: Sentinel? = Sentinel() sentinel = strongSentinel! - try await XCTAsyncAssertNotNil(await channel.pipeline.handler(type: NIOAsyncChannelHandler.self).get()) + try await XCTAsyncAssertNotNil( + await channel.pipeline.handler(type: NIOAsyncChannelHandler.self).get() + ) try await channel.writeInbound(strongSentinel!) _ = try await channel.readInbound(as: Sentinel.self) } diff --git a/Tests/NIOCoreTests/AsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequenceTests.swift index 9a000a79ef..069519b40f 100644 --- a/Tests/NIOCoreTests/AsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequenceTests.swift @@ -14,7 +14,7 @@ import NIOCore import XCTest -fileprivate struct TestCase { +private struct TestCase { var buffers: [[UInt8]] var file: StaticString var line: UInt @@ -30,7 +30,7 @@ final class AsyncSequenceCollectTests: XCTestCase { func testAsyncSequenceCollect() async throws { let testCases = [ TestCase([ - [], + [] ]), TestCase([ [], @@ -74,7 +74,7 @@ final class AsyncSequenceCollectTests: XCTestCase { [], ]), TestCase([ - Array(0..<10), + Array(0..<10) ]), TestCase([ Array(0..<10), diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift index 4f7b9ec361..3720976224 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncSequenceTests.swift @@ -12,11 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import NIOCore import XCTest +@testable import NIOCore + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -final class MockNIOElementStreamBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategy, @unchecked Sendable { +final class MockNIOElementStreamBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategy, @unchecked Sendable +{ enum Event { case didYield case didNext @@ -26,7 +28,7 @@ final class MockNIOElementStreamBackPressureStrategy: NIOAsyncSequenceProducerBa init() { var eventsContinuation: AsyncStream.Continuation! - self.events = .init() { eventsContinuation = $0 } + self.events = .init { eventsContinuation = $0 } self.eventsContinuation = eventsContinuation! } @@ -60,7 +62,7 @@ final class MockNIOBackPressuredStreamSourceDelegate: NIOAsyncSequenceProducerDe init() { var eventsContinuation: AsyncStream.Continuation! - self.events = .init() { eventsContinuation = $0 } + self.events = .init { eventsContinuation = $0 } self.eventsContinuation = eventsContinuation! } @@ -85,16 +87,18 @@ final class MockNIOBackPressuredStreamSourceDelegate: NIOAsyncSequenceProducerDe final class NIOAsyncSequenceProducerTests: XCTestCase { private var backPressureStrategy: MockNIOElementStreamBackPressureStrategy! private var delegate: MockNIOBackPressuredStreamSourceDelegate! - private var sequence: NIOAsyncSequenceProducer< - Int, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >! - private var source: NIOAsyncSequenceProducer< - Int, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.Source! + private var sequence: + NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >! + private var source: + NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.Source! override func setUp() { super.setUp() @@ -237,7 +241,10 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let result = self.source.yield(contentsOf: [1]) XCTAssertEqual(result, .stopProducing) - XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(2).collect(), [.didYield, .didYield]) + XCTAssertEqualWithoutAutoclosure( + await self.backPressureStrategy.events.prefix(2).collect(), + [.didYield, .didYield] + ) } func testYield_whenStreaming_andNotSuspended_andDemandMore() async throws { @@ -248,7 +255,10 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let result = self.source.yield(contentsOf: [1]) XCTAssertEqual(result, .produceMore) - XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(2).collect(), [.didYield, .didYield]) + XCTAssertEqualWithoutAutoclosure( + await self.backPressureStrategy.events.prefix(2).collect(), + [.didYield, .didYield] + ) } func testYield_whenSourceFinished() async throws { @@ -315,16 +325,17 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { // MARK: - Source Deinited func testSourceDeinited_whenInitial() async { - var newSequence: NIOAsyncSequenceProducer< - Int, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( - elementType: Int.self, - backPressureStrategy: self.backPressureStrategy, - finishOnDeinit: true, - delegate: self.delegate - ) + var newSequence: + NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) let sequence = newSequence?.sequence var source = newSequence?.source newSequence = nil @@ -335,16 +346,17 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } func testSourceDeinited_whenStreaming_andSuspended() async throws { - var newSequence: NIOAsyncSequenceProducer< - Int, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( - elementType: Int.self, - backPressureStrategy: self.backPressureStrategy, - finishOnDeinit: true, - delegate: self.delegate - ) + var newSequence: + NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) let sequence = newSequence?.sequence var source = newSequence?.source newSequence = nil @@ -371,16 +383,17 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferEmpty() async throws { - var newSequence: NIOAsyncSequenceProducer< - Int, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( - elementType: Int.self, - backPressureStrategy: self.backPressureStrategy, - finishOnDeinit: true, - delegate: self.delegate - ) + var newSequence: + NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) let sequence = newSequence?.sequence var source = newSequence?.source newSequence = nil @@ -391,7 +404,7 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return await sequence!.first { _ in true } + await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -402,16 +415,17 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferNotEmpty() async throws { - var newSequence: NIOAsyncSequenceProducer< - Int, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( - elementType: Int.self, - backPressureStrategy: self.backPressureStrategy, - finishOnDeinit: true, - delegate: self.delegate - ) + var newSequence: + NIOAsyncSequenceProducer< + Int, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) let sequence = newSequence?.sequence var source = newSequence?.source newSequence = nil @@ -422,7 +436,7 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return await sequence!.first { _ in true } + await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -441,7 +455,6 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let suspended = expectation(description: "task suspended") sequence._throwingSequence._storage._setDidSuspend { suspended.fulfill() } - let task: Task = Task { let iterator = sequence.makeAsyncIterator() return await iterator.next() @@ -567,7 +580,10 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { _ = await sequence.first { _ in true } } - XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(2).collect(), [.didYield, .didNext]) + XCTAssertEqualWithoutAutoclosure( + await self.backPressureStrategy.events.prefix(2).collect(), + [.didYield, .didNext] + ) XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.produceMore]) } @@ -583,7 +599,10 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { _ = await sequence.first { _ in true } } - XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(2).collect(), [.didYield, .didNext]) + XCTAssertEqualWithoutAutoclosure( + await self.backPressureStrategy.events.prefix(2).collect(), + [.didYield, .didNext] + ) } func testNext_whenStreaming_whenNotEmptyBuffer_whenNoDemand() async throws { @@ -593,7 +612,10 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let element = await self.sequence.first { _ in true } XCTAssertEqual(element, 1) - XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(2).collect(), [.didYield, .didNext]) + XCTAssertEqualWithoutAutoclosure( + await self.backPressureStrategy.events.prefix(2).collect(), + [.didYield, .didNext] + ) } func testNext_whenStreaming_whenNotEmptyBuffer_whenNewDemand() async throws { @@ -603,7 +625,10 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let element = await self.sequence.first { _ in true } XCTAssertEqual(element, 1) - XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(2).collect(), [.didYield, .didNext]) + XCTAssertEqualWithoutAutoclosure( + await self.backPressureStrategy.events.prefix(2).collect(), + [.didYield, .didNext] + ) XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.produceMore]) } @@ -616,7 +641,10 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { let element = await self.sequence.first { _ in true } XCTAssertEqual(element, 1) - XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(2).collect(), [.didYield, .didNext]) + XCTAssertEqualWithoutAutoclosure( + await self.backPressureStrategy.events.prefix(2).collect(), + [.didYield, .didNext] + ) } func testNext_whenSourceFinished() async throws { @@ -685,7 +713,7 @@ final class NIOAsyncSequenceProducerTests: XCTestCase { } // This is needed until async let is supported to be used in autoclosures -fileprivate func XCTAssertEqualWithoutAutoclosure( +private func XCTAssertEqualWithoutAutoclosure( _ expression1: T, _ expression2: T, _ message: @autoclosure () -> String = "", diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift index 5787068510..95760bd411 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOAsyncWriterTests.swift @@ -13,9 +13,10 @@ //===----------------------------------------------------------------------===// import DequeModule -@testable import NIOCore -import XCTest import NIOConcurrencyHelpers +import XCTest + +@testable import NIOCore private struct SomeError: Error, Hashable {} @@ -101,9 +102,21 @@ final class NIOAsyncWriterTests: XCTestCase { file: StaticString = #filePath, line: UInt = #line ) { - XCTAssertEqual(self.delegate.didSuspendCallCount, suspendCallCount, "Unexpeced suspends", file: file, line: line) + XCTAssertEqual( + self.delegate.didSuspendCallCount, + suspendCallCount, + "Unexpeced suspends", + file: file, + line: line + ) XCTAssertEqual(self.delegate.didYieldCallCount, yieldCallCount, "Unexpected yields", file: file, line: line) - XCTAssertEqual(self.delegate.didTerminateCallCount, terminateCallCount, "Unexpected terminates", file: file, line: line) + XCTAssertEqual( + self.delegate.didTerminateCallCount, + terminateCallCount, + "Unexpected terminates", + file: file, + line: line + ) } func testMultipleConcurrentWrites() async throws { @@ -161,7 +174,7 @@ final class NIOAsyncWriterTests: XCTestCase { // MARK: - WriterDeinitialized func testWriterDeinitialized_whenInitial() async throws { - var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( elementType: String.self, isWritable: true, finishOnDeinit: true, @@ -180,7 +193,7 @@ final class NIOAsyncWriterTests: XCTestCase { } func testWriterDeinitialized_whenStreaming() async throws { - var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( + var newWriter: NIOAsyncWriter.NewWriter? = NIOAsyncWriter.makeWriter( elementType: String.self, isWritable: true, finishOnDeinit: true, @@ -571,7 +584,6 @@ final class NIOAsyncWriterTests: XCTestCase { // We are setting up a suspended yield here to check that it gets resumed self.sink.setWritability(to: false) - let suspended = expectation(description: "suspended on yield") self.delegate.didSuspendHandler = { suspended.fulfill() diff --git a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift index 0d9acc1018..5d4f8d1443 100644 --- a/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift +++ b/Tests/NIOCoreTests/AsyncSequences/NIOThrowingAsyncSequenceTests.swift @@ -12,25 +12,28 @@ // //===----------------------------------------------------------------------===// -@testable import NIOCore import XCTest +@testable import NIOCore + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { private var backPressureStrategy: MockNIOElementStreamBackPressureStrategy! private var delegate: MockNIOBackPressuredStreamSourceDelegate! - private var sequence: NIOThrowingAsyncSequenceProducer< - Int, - Error, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >! - private var source: NIOThrowingAsyncSequenceProducer< - Int, - Error, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.Source! + private var sequence: + NIOThrowingAsyncSequenceProducer< + Int, + Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >! + private var source: + NIOThrowingAsyncSequenceProducer< + Int, + Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.Source! override func setUp() { super.setUp() @@ -214,7 +217,10 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { let result = self.source.yield(contentsOf: [1]) XCTAssertEqual(result, .stopProducing) - XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(2).collect(), [.didYield, .didYield]) + XCTAssertEqualWithoutAutoclosure( + await self.backPressureStrategy.events.prefix(2).collect(), + [.didYield, .didYield] + ) } func testYield_whenStreaming_andNotSuspended_andProduceMore() async throws { @@ -225,7 +231,10 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { let result = self.source.yield(contentsOf: [1]) XCTAssertEqual(result, .produceMore) - XCTAssertEqualWithoutAutoclosure(await self.backPressureStrategy.events.prefix(2).collect(), [.didYield, .didYield]) + XCTAssertEqualWithoutAutoclosure( + await self.backPressureStrategy.events.prefix(2).collect(), + [.didYield, .didYield] + ) } func testYield_whenSourceFinished() async throws { @@ -273,7 +282,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await sequence.first { _ in true } + try await sequence.first { _ in true } } return try await group.next() ?? nil @@ -291,7 +300,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { let sequence = try XCTUnwrap(self.sequence) let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await sequence.first { _ in true } + try await sequence.first { _ in true } } return try await group.next() ?? nil @@ -305,7 +314,6 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testFinish_whenFinished() async throws { self.source.finish() - _ = try await self.sequence.first { _ in true } XCTAssertEqualWithoutAutoclosure(await self.delegate.events.prefix(1).collect(), [.didTerminate]) @@ -327,21 +335,23 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { func testFinishError_whenStreaming_andSuspended() async throws { let sequence = try XCTUnwrap(self.sequence) - await XCTAssertThrowsError(try await withThrowingTaskGroup(of: Void.self) { group in + await XCTAssertThrowsError( + try await withThrowingTaskGroup(of: Void.self) { group in - let suspended = expectation(description: "task suspended") - sequence._storage._setDidSuspend { suspended.fulfill() } + let suspended = expectation(description: "task suspended") + sequence._storage._setDidSuspend { suspended.fulfill() } - group.addTask { - _ = try await sequence.first { _ in true } - } + group.addTask { + _ = try await sequence.first { _ in true } + } - await fulfillment(of: [suspended], timeout: 1) + await fulfillment(of: [suspended], timeout: 1) - self.source.finish(ChannelError.alreadyClosed) + self.source.finish(ChannelError.alreadyClosed) - try await group.next() - }) { error in + try await group.next() + } + ) { error in XCTAssertEqual(error as? ChannelError, .alreadyClosed) } @@ -354,13 +364,15 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { self.source.finish(ChannelError.alreadyClosed) let sequence = try XCTUnwrap(self.sequence) - await XCTAssertThrowsError(try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - _ = try await sequence.first { _ in true } - } + await XCTAssertThrowsError( + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + _ = try await sequence.first { _ in true } + } - try await group.next() - }) { error in + try await group.next() + } + ) { error in XCTAssertEqual(error as? ChannelError, .alreadyClosed) } @@ -374,11 +386,13 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { var elements = [Int]() - await XCTAssertThrowsError(try await { - for try await element in self.sequence { - elements.append(element) - } - }()) { error in + await XCTAssertThrowsError( + try await { + for try await element in self.sequence { + elements.append(element) + } + }() + ) { error in XCTAssertEqual(error as? ChannelError, .alreadyClosed) } @@ -402,39 +416,41 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { // MARK: - Source Deinited func testSourceDeinited_whenInitial() async { - var newSequence: NIOThrowingAsyncSequenceProducer< - Int, - any Error, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( - elementType: Int.self, - backPressureStrategy: self.backPressureStrategy, - finishOnDeinit: true, - delegate: self.delegate - ) + var newSequence: + NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) let sequence = newSequence?.sequence var source = newSequence?.source newSequence = nil source = nil - + XCTAssertNil(source) XCTAssertNotNil(sequence) } func testSourceDeinited_whenStreaming_andSuspended() async throws { - var newSequence: NIOThrowingAsyncSequenceProducer< - Int, - any Error, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( - elementType: Int.self, - backPressureStrategy: self.backPressureStrategy, - finishOnDeinit: true, - delegate: self.delegate - ) + var newSequence: + NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) let sequence = newSequence?.sequence var source = newSequence?.source newSequence = nil @@ -462,17 +478,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferEmpty() async throws { - var newSequence: NIOThrowingAsyncSequenceProducer< - Int, - any Error, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( - elementType: Int.self, - backPressureStrategy: self.backPressureStrategy, - finishOnDeinit: true, - delegate: self.delegate - ) + var newSequence: + NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) let sequence = newSequence?.sequence var source = newSequence?.source newSequence = nil @@ -483,7 +500,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await sequence!.first { _ in true } + try await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -494,17 +511,18 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } func testSourceDeinited_whenStreaming_andNotSuspended_andBufferNotEmpty() async throws { - var newSequence: NIOThrowingAsyncSequenceProducer< - Int, - any Error, - MockNIOElementStreamBackPressureStrategy, - MockNIOBackPressuredStreamSourceDelegate - >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( - elementType: Int.self, - backPressureStrategy: self.backPressureStrategy, - finishOnDeinit: true, - delegate: self.delegate - ) + var newSequence: + NIOThrowingAsyncSequenceProducer< + Int, + any Error, + MockNIOElementStreamBackPressureStrategy, + MockNIOBackPressuredStreamSourceDelegate + >.NewSequence? = NIOThrowingAsyncSequenceProducer.makeSequence( + elementType: Int.self, + backPressureStrategy: self.backPressureStrategy, + finishOnDeinit: true, + delegate: self.delegate + ) let sequence = newSequence?.sequence var source = newSequence?.source newSequence = nil @@ -515,7 +533,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { let element: Int? = try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await sequence!.first { _ in true } + try await sequence!.first { _ in true } } return try await group.next() ?? nil @@ -548,7 +566,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { XCTAssertTrue(error is CancellationError) } } - + @available(*, deprecated, message: "tests the deprecated custom generic failure type") func testTaskCancel_whenStreaming_andSuspended_withCustomErrorType() async throws { struct CustomError: Error {} @@ -575,7 +593,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { task.cancel() let result = await task.result XCTAssertEqualWithoutAutoclosure(await delegate.events.prefix(1).collect(), [.didTerminate]) - + try withExtendedLifetime(new.source) { XCTAssertNil(try result.get()) } @@ -651,7 +669,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { XCTAssertTrue(error is CancellationError, "unexpected error \(error)") } } - + @available(*, deprecated, message: "tests the deprecated custom generic failure type") func testTaskCancel_whenStreaming_andTaskIsAlreadyCancelled_withCustomErrorType() async throws { struct CustomError: Error {} @@ -891,7 +909,7 @@ final class NIOThrowingAsyncSequenceProducerTests: XCTestCase { } // This is needed until async let is supported to be used in autoclosures -fileprivate func XCTAssertEqualWithoutAutoclosure( +private func XCTAssertEqualWithoutAutoclosure( _ expression1: T, _ expression2: T, _ message: @autoclosure () -> String = "", diff --git a/Tests/NIOCoreTests/BaseObjectsTest.swift b/Tests/NIOCoreTests/BaseObjectsTest.swift index d30cc40d0f..74f8a91563 100644 --- a/Tests/NIOCoreTests/BaseObjectsTest.swift +++ b/Tests/NIOCoreTests/BaseObjectsTest.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import XCTest + @testable import NIOCore class BaseObjectTest: XCTestCase { diff --git a/Tests/NIOCoreTests/ByteBufferLengthPrefixTests.swift b/Tests/NIOCoreTests/ByteBufferLengthPrefixTests.swift index 7f6c600d4c..e80750519d 100644 --- a/Tests/NIOCoreTests/ByteBufferLengthPrefixTests.swift +++ b/Tests/NIOCoreTests/ByteBufferLengthPrefixTests.swift @@ -17,7 +17,7 @@ import XCTest final class ByteBufferLengthPrefixTests: XCTestCase { private var buffer = ByteBuffer() - + // MARK: - writeLengthPrefixed Tests func testWriteMessageWithLengthOfZero() throws { let bytesWritten = try buffer.writeLengthPrefixed(as: UInt8.self) { buffer in @@ -39,9 +39,7 @@ final class ByteBufferLengthPrefixTests: XCTestCase { } func testWriteMessageWithMultipleWrites() throws { let bytesWritten = try buffer.writeLengthPrefixed(as: UInt8.self) { buffer in - buffer.writeString("Hello") + - buffer.writeString(" ") + - buffer.writeString("World") + buffer.writeString("Hello") + buffer.writeString(" ") + buffer.writeString("World") } XCTAssertEqual(bytesWritten, 12) XCTAssertEqual(buffer.readInteger(as: UInt8.self), 11) @@ -86,14 +84,14 @@ final class ByteBufferLengthPrefixTests: XCTestCase { XCTAssertEqual(buffer.readString(length: 256), message) XCTAssertTrue(buffer.readableBytesView.isEmpty) } - + // MARK: - readLengthPrefixed Tests func testReadMessageWithLengthOfZero() { buffer.writeInteger(UInt8(0)) XCTAssertEqual( try buffer.readLengthPrefixed(as: UInt8.self) { buffer in buffer - }, + }, ByteBuffer() ) } @@ -196,7 +194,7 @@ final class ByteBufferLengthPrefixTests: XCTestCase { nil ) } - + // MARK: - readLengthPrefixedSlice func testReadSliceWithBigEndianInteger() { buffer.writeInteger(UInt16(256), endianness: .big) @@ -216,7 +214,7 @@ final class ByteBufferLengthPrefixTests: XCTestCase { ) XCTAssertTrue(buffer.readableBytes == 0) } - + // MARK: - getLengthPrefixedSlice func testGetSliceWithBigEndianInteger() { buffer.writeString("some data before the length prefix") diff --git a/Tests/NIOCoreTests/ByteBufferTest.swift b/Tests/NIOCoreTests/ByteBufferTest.swift index a63c8a5f19..51b64fe1d2 100644 --- a/Tests/NIOCoreTests/ByteBufferTest.swift +++ b/Tests/NIOCoreTests/ByteBufferTest.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -import struct Foundation.Data -import XCTest -@testable import NIOCore import NIOFoundationCompat - +import XCTest import _NIOBase64 +import struct Foundation.Data + +@testable import NIOCore + class ByteBufferTest: XCTestCase { private let allocator = ByteBufferAllocator() private var buf: ByteBuffer! = nil @@ -66,14 +67,14 @@ class ByteBufferTest: XCTestCase { // Only cares about the read buffer self.buf.writeInteger(Int8.max) self.buf.writeString("oh hi") - let actual: Int8 = buf.readInteger()! // Just getting rid of it from the read buffer + let actual: Int8 = buf.readInteger()! // Just getting rid of it from the read buffer XCTAssertEqual(Int8.max, actual) var otherBuffer = allocator.buffer(capacity: 32) otherBuffer.writeString("oh hi") XCTAssertEqual(otherBuffer, buf) } - + func testHasherUsesReadBuffersOnly() { // Only cares about the read buffer self.buf.clear() @@ -88,14 +89,14 @@ class ByteBufferTest: XCTestCase { // https://bugs.swift.org/browse/SR-11975 hasher.combine(self.buf!) let hash = hasher.finalize() - + var otherBuffer = allocator.buffer(capacity: 6) otherBuffer.writeString("oh hi") var otherHasher = Hasher() otherHasher.combine(otherBuffer) let otherHash = otherHasher.finalize() - + XCTAssertEqual(hash, otherHash) } @@ -137,11 +138,11 @@ class ByteBufferTest: XCTestCase { XCTAssertEqual(1, slice.capacity) XCTAssertEqual(16, slice.storageCapacity) let oldStorageBegin = slice.withUnsafeReadableBytes { ptr in - return UInt(bitPattern: ptr.baseAddress!) + UInt(bitPattern: ptr.baseAddress!) } slice.setInteger(1, at: 0, as: UInt8.self) let newStorageBegin = slice.withUnsafeReadableBytes { ptr in - return UInt(bitPattern: ptr.baseAddress!) + UInt(bitPattern: ptr.baseAddress!) } XCTAssertEqual(oldStorageBegin, newStorageBegin) } @@ -159,7 +160,6 @@ class ByteBufferTest: XCTestCase { slice.writeBytes(Array(32..<47)) } - func testReadWrite() { buf.writeString("X") buf.writeString("Y") @@ -205,39 +205,39 @@ class ByteBufferTest: XCTestCase { let string = buf.getString(at: 0, length: written) XCTAssertEqual("Hello", string) } - + func testNullTerminatedString() { let writtenHello = buf.writeNullTerminatedString("Hello") XCTAssertEqual(writtenHello, 6) XCTAssertEqual(buf.readableBytes, 6) - + let writtenEmpty = buf.writeNullTerminatedString("") XCTAssertEqual(writtenEmpty, 1) XCTAssertEqual(buf.readableBytes, 7) - + let writtenFoo = buf.writeNullTerminatedString("foo") XCTAssertEqual(writtenFoo, 4) XCTAssertEqual(buf.readableBytes, 11) - + XCTAssertEqual(buf.getNullTerminatedString(at: 0), "Hello") XCTAssertEqual(buf.getNullTerminatedString(at: 6), "") XCTAssertEqual(buf.getNullTerminatedString(at: 7), "foo") - + XCTAssertEqual(buf.readNullTerminatedString(), "Hello") XCTAssertEqual(buf.readerIndex, 6) - + XCTAssertEqual(buf.readNullTerminatedString(), "") XCTAssertEqual(buf.readerIndex, 7) - + XCTAssertEqual(buf.readNullTerminatedString(), "foo") XCTAssertEqual(buf.readerIndex, 11) } - + func testReadNullTerminatedStringWithoutNullTermination() { buf.writeString("Hello") XCTAssertNil(buf.readNullTerminatedString()) } - + func testGetNullTerminatedStringOutOfRangeTests() { buf.writeNullTerminatedString("Hello") XCTAssertNil(buf.getNullTerminatedString(at: 100)) @@ -247,31 +247,31 @@ class ByteBufferTest: XCTestCase { buf.writeInteger(UInt8(0)) XCTAssertEqual(buf.readNullTerminatedString(), "") } - + func testWriteSubstring() { var text = "Hello" let written = buf.writeSubstring(text[...]) var string = buf.getString(at: 0, length: written) XCTAssertEqual(text, string) - + text = "" buf.writeSubstring(text[...]) string = buf.getString(at: 0, length: written) XCTAssertEqual("Hello", string) } - + func testSetSubstring() { let text = "Hello" buf.writeSubstring(text[...]) - + var written = buf.setSubstring(text[...], at: 0) var string = buf.getString(at: 0, length: written) XCTAssertEqual(text, string) - + written = buf.setSubstring(text[text.index(after: text.startIndex)...], at: 1) string = buf.getString(at: 0, length: written + 1) XCTAssertEqual(text, string) - + written = buf.setSubstring(text[text.index(after: text.startIndex)...], at: 0) string = buf.getString(at: 0, length: written) XCTAssertEqual("ello", string) @@ -319,7 +319,7 @@ class ByteBufferTest: XCTestCase { let bytesConsumed = buf.readWithUnsafeReadableBytes { dst in // Pretend we did some operation which made use of entire 11 byte string - return 11 + 11 } XCTAssertEqual(11, bytesConsumed) XCTAssertEqual(11, buf.readerIndex) @@ -543,7 +543,7 @@ class ByteBufferTest: XCTestCase { testAssumptionOriginalBuffer(&buffer) var buffer10Missing = buffer - let first10Bytes = buffer10Missing.readData(length: 10) /* make the first 10 bytes disappear */ + let first10Bytes = buffer10Missing.readData(length: 10) // make the first 10 bytes disappear let otherBuffer10Missing = buffer10Missing XCTAssertEqual("0123456789".data(using: .utf8)!, first10Bytes) testAssumptionOriginalBuffer(&buffer) @@ -591,7 +591,7 @@ class ByteBufferTest: XCTestCase { XCTAssertEqual(2, slice.writerIndex) XCTAssertEqual(UInt8(3), slice.readInteger()) XCTAssertEqual(UInt8(4), slice.readInteger()) - XCTAssertEqual(0,slice.readableBytes) + XCTAssertEqual(0, slice.readableBytes) XCTAssertTrue(slice.discardReadBytes()) XCTAssertFalse(slice.discardReadBytes()) } @@ -618,7 +618,7 @@ class ByteBufferTest: XCTestCase { buffer.withUnsafeReadableBytes { data -> Void in XCTAssertEqual(string.utf8.count, data.count) - for (idx, expected) in zip(data.startIndex..(_ type: T.Type) { initBuffer() - let tooMany = (byteCount + 1)/MemoryLayout.size + let tooMany = (byteCount + 1) / MemoryLayout.size for _ in 1...SubSequence typealias Indices = Array.Indices public var indices: Indices { - return self.storage.indices + self.storage.indices } public subscript(bounds: Range) -> SubSequence { - return self.storage[bounds] + self.storage[bounds] } public subscript(position: Index) -> Element { - /* this is wrong but we need to check that we don't access this */ + // this is wrong but we need to check that we don't access this XCTFail("shouldn't have been called") return 0xff } public var startIndex: Index { - return self.storage.startIndex + self.storage.startIndex } public var endIndex: Index { - return self.storage.endIndex + self.storage.endIndex } func index(after i: Index) -> Index { - return self.storage.index(after: i) + self.storage.index(after: i) } func withContiguousStorageIfAvailable(_ body: (UnsafeBufferPointer) throws -> R) rethrows -> R? { - return try self.storage.withUnsafeBufferPointer(body) + try self.storage.withUnsafeBufferPointer(body) } } buf.clear() @@ -1599,19 +1618,19 @@ class ByteBufferTest: XCTestCase { typealias Element = UInt8 public var indices: CountableRange { - return self.storage.indices + self.storage.indices } public subscript(position: Int) -> Element { - return self.storage[position] + self.storage[position] } public var underestimatedCount: Int { - return 8 + 8 } func makeIterator() -> Array.Iterator { - return self.storage.makeIterator() + self.storage.makeIterator() } } buf = self.allocator.buffer(capacity: 4) @@ -1642,10 +1661,10 @@ class ByteBufferTest: XCTestCase { self.buf.clear() self.buf.writeInteger(-1, endianness: .big, as: Int64.self) XCTAssertEqual(-1, self.buf.readInteger(endianness: .big, as: Int64.self)) - self.buf.setInteger(0xdeadbeef, at: 0, endianness: .little, as: UInt64.self) + self.buf.setInteger(0xdead_beef, at: 0, endianness: .little, as: UInt64.self) self.buf.moveWriterIndex(to: 8) self.buf.moveReaderIndex(to: 0) - XCTAssertEqual(0xdeadbeef, self.buf.getInteger(at: 0, endianness: .little, as: UInt64.self)) + XCTAssertEqual(0xdead_beef, self.buf.getInteger(at: 0, endianness: .little, as: UInt64.self)) } func testByteBufferFitsInACoupleOfEnums() throws { @@ -1683,7 +1702,7 @@ class ByteBufferTest: XCTestCase { func testLargeSliceBegin16MBIsOkayAndDoesNotCopy() throws { var fourMBBuf = self.allocator.buffer(capacity: 4 * 1024 * 1024) - fourMBBuf.writeBytes(Array(repeating: 0xff, count: fourMBBuf.capacity)) + fourMBBuf.writeBytes([UInt8](repeating: 0xff, count: fourMBBuf.capacity)) let totalBufferSize = 5 * fourMBBuf.readableBytes XCTAssertEqual(4 * 1024 * 1024, fourMBBuf.readableBytes) var buf = self.allocator.buffer(capacity: totalBufferSize) @@ -1720,7 +1739,7 @@ class ByteBufferTest: XCTestCase { func testLargeSliceBeginMoreThan16MBIsOkay() throws { var fourMBBuf = self.allocator.buffer(capacity: 4 * 1024 * 1024) - fourMBBuf.writeBytes(Array(repeating: 0xff, count: fourMBBuf.capacity)) + fourMBBuf.writeBytes([UInt8](repeating: 0xff, count: fourMBBuf.capacity)) let totalBufferSize = 5 * fourMBBuf.readableBytes + 1 XCTAssertEqual(4 * 1024 * 1024, fourMBBuf.readableBytes) var buf = self.allocator.buffer(capacity: totalBufferSize) @@ -1735,7 +1754,7 @@ class ByteBufferTest: XCTestCase { buf.setInteger(0xaa, at: 0, as: UInt8.self) buf.setInteger(0xbb, at: offset - 1, as: UInt8.self) buf.setInteger(0xcc, at: offset, as: UInt8.self) - buf.writeInteger(0xdd, as: UInt8.self) // write extra byte so the slice is the same length as above + buf.writeInteger(0xdd, as: UInt8.self) // write extra byte so the slice is the same length as above XCTAssertEqual(totalBufferSize, buf.readableBytes) let expectedReadableBytes = totalBufferSize - offset @@ -1761,7 +1780,7 @@ class ByteBufferTest: XCTestCase { hookedMalloc: { _ in .init(bitPattern: 0xdedbeef) }, hookedRealloc: { _, _ in fatalError() }, hookedFree: { precondition($0 == .init(bitPattern: 0xdedbeef)!) }, - hookedMemcpy: {_, _, _ in } + hookedMemcpy: { _, _, _ in } ) let targetSize = Int(UInt32.max) @@ -1782,11 +1801,11 @@ class ByteBufferTest: XCTestCase { XCTAssertEqual(slice.readerIndex, 0) XCTAssertEqual(Int(UInt32(3).nextPowerOf2()), slice.capacity) } - + func testSliceOnSliceAfterHitting16MBMark() { // This test ensures that a slice will get a new backing storage if its start is more than // 16MiB after the originating backing storage. - + // create a buffer with 16MiB + 1 byte let inputBufferLength = 16 * 1024 * 1024 + 1 var inputBuffer = ByteBufferAllocator().buffer(capacity: inputBufferLength) @@ -1794,28 +1813,30 @@ class ByteBufferTest: XCTestCase { inputBuffer.writeRepeatingByte(2, count: inputBufferLength - 9) inputBuffer.writeRepeatingByte(3, count: 1) // read a small slice from the inputBuffer, to create an offset of eight bytes - XCTAssertEqual(inputBuffer.readInteger(as: UInt64.self), 0x0101010101010101) - + XCTAssertEqual(inputBuffer.readInteger(as: UInt64.self), 0x0101_0101_0101_0101) + // read the remaining bytes into a new slice (this will have a length of 16MiB - 7Bbytes) let remainingSliceLength = inputBufferLength - 8 XCTAssertEqual(inputBuffer.readableBytes, remainingSliceLength) var remainingSlice = inputBuffer.readSlice(length: remainingSliceLength)! - + let finalSliceLength = 1 // let's create a new buffer that uses all but one byte - XCTAssertEqual(remainingSlice.readBytes(length: remainingSliceLength - finalSliceLength), - Array(repeating: 2, count: remainingSliceLength - finalSliceLength)) - + XCTAssertEqual( + remainingSlice.readBytes(length: remainingSliceLength - finalSliceLength), + [UInt8](repeating: 2, count: remainingSliceLength - finalSliceLength) + ) + // there should only be one byte left. XCTAssertEqual(remainingSlice.readableBytes, finalSliceLength) - + // with just one byte left, the last byte is exactly one byte above the 16MiB threshold. // For this reason a slice of the last byte, will need to get a new backing storage. let finalSlice = remainingSlice.readSlice(length: finalSliceLength) XCTAssertNotEqual(finalSlice?.storagePointerIntegerValue(), remainingSlice.storagePointerIntegerValue()) XCTAssertEqual(finalSlice?.storageCapacity, 1) XCTAssertEqual(finalSlice, ByteBuffer(integer: 3, as: UInt8.self)) - + XCTAssertEqual(remainingSlice.readableBytes, 0) } @@ -1845,15 +1866,15 @@ class ByteBufferTest: XCTestCase { } let actual = self.buf._storage.dumpBytes(slice: self.buf._slice, offset: 0, length: self.buf.readableBytes) let expected = """ - [ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f \ - 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f \ - 40 41 42 43 44 45 46 47 48 49 4a 4b 4c 4d 4e 4f 50 51 52 53 54 55 56 57 58 59 5a 5b 5c 5d 5e 5f \ - 60 61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70 71 72 73 74 75 76 77 78 79 7a 7b 7c 7d 7e 7f \ - 80 81 82 83 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f 90 91 92 93 94 95 96 97 98 99 9a 9b 9c 9d 9e 9f \ - a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 aa ab ac ad ae af b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 ba bb bc bd be bf \ - c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 ca cb cc cd ce cf d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 da db dc dd de df \ - e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 ea eb ec ed ee ef f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 fa fb fc fd fe ff ] - """ + [ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f \ + 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f \ + 40 41 42 43 44 45 46 47 48 49 4a 4b 4c 4d 4e 4f 50 51 52 53 54 55 56 57 58 59 5a 5b 5c 5d 5e 5f \ + 60 61 62 63 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70 71 72 73 74 75 76 77 78 79 7a 7b 7c 7d 7e 7f \ + 80 81 82 83 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f 90 91 92 93 94 95 96 97 98 99 9a 9b 9c 9d 9e 9f \ + a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 aa ab ac ad ae af b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 ba bb bc bd be bf \ + c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 ca cb cc cd ce cf d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 da db dc dd de df \ + e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 ea eb ec ed ee ef f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 fa fb fc fd fe ff ] + """ XCTAssertEqual(expected, actual) } @@ -1887,42 +1908,43 @@ class ByteBufferTest: XCTestCase { func testHexDumpDetailed() { let buf = ByteBuffer(string: "Goodbye, world! It was nice knowing you.\n") let expected = """ - 00000000 47 6f 6f 64 62 79 65 2c 20 77 6f 72 6c 64 21 20 |Goodbye, world! | - 00000010 49 74 20 77 61 73 20 6e 69 63 65 20 6b 6e 6f 77 |It was nice know| - 00000020 69 6e 67 20 79 6f 75 2e 0a |ing you..| - 00000029 - """ + 00000000 47 6f 6f 64 62 79 65 2c 20 77 6f 72 6c 64 21 20 |Goodbye, world! | + 00000010 49 74 20 77 61 73 20 6e 69 63 65 20 6b 6e 6f 77 |It was nice know| + 00000020 69 6e 67 20 79 6f 75 2e 0a |ing you..| + 00000029 + """ let actual = buf.hexDump(format: .detailed) XCTAssertEqual(expected, actual) } - func testHexDumpDetailedWithMaxBytes() { let buf = ByteBuffer(string: "Goodbye, world! It was nice knowing you.\n") let expected = """ - 00000000 47 6f 6f 64 62 79 65 2c |Goodbye, | - ........ .. .. .. .. .. .. .. .. .. .. .. .. .. .. .. .. .................. - 00000020 6e 67 20 79 6f 75 2e 0a | ng you..| - 00000029 - """ + 00000000 47 6f 6f 64 62 79 65 2c |Goodbye, | + ........ .. .. .. .. .. .. .. .. .. .. .. .. .. .. .. .. .................. + 00000020 6e 67 20 79 6f 75 2e 0a | ng you..| + 00000029 + """ let actual = buf.hexDump(format: .detailed(maxBytes: 16)) XCTAssertEqual(expected, actual) } func testHexDumpDetailedWithMultilineFrontAndBack() { - let buf = ByteBuffer(string: """ - Goodbye, world! It was nice knowing you. - I will miss this pull request with all of it's 94+ comments. - """) + let buf = ByteBuffer( + string: """ + Goodbye, world! It was nice knowing you. + I will miss this pull request with all of it's 94+ comments. + """ + ) let expected = """ - 00000000 47 6f 6f 64 62 79 65 2c 20 77 6f 72 6c 64 21 20 |Goodbye, world! | - 00000010 49 74 |It | - ........ .. .. .. .. .. .. .. .. .. .. .. .. .. .. .. .. .................. - 00000050 69 74 27 73 20 39 34 2b 20 63 6f 6d 6d | it's 94+ comm| - 00000060 65 6e 74 73 2e |ents.| - 00000065 - """ + 00000000 47 6f 6f 64 62 79 65 2c 20 77 6f 72 6c 64 21 20 |Goodbye, world! | + 00000010 49 74 |It | + ........ .. .. .. .. .. .. .. .. .. .. .. .. .. .. .. .. .................. + 00000050 69 74 27 73 20 39 34 2b 20 63 6f 6d 6d | it's 94+ comm| + 00000060 65 6e 74 73 2e |ents.| + 00000065 + """ let actual = buf.hexDump(format: .detailed(maxBytes: 36)) XCTAssertEqual(expected, actual) } @@ -1931,11 +1953,11 @@ class ByteBufferTest: XCTestCase { var buf = ByteBuffer(string: "Goodbye, world! It was nice knowing you.\n") let _ = buf.readBytes(length: 5) let expected = """ - 00000000 79 65 2c 20 77 6f 72 6c 64 21 20 49 74 20 77 61 |ye, world! It wa| - 00000010 73 20 6e 69 63 65 20 6b 6e 6f 77 69 6e 67 20 79 |s nice knowing y| - 00000020 6f 75 2e 0a |ou..| - 00000024 - """ + 00000000 79 65 2c 20 77 6f 72 6c 64 21 20 49 74 20 77 61 |ye, world! It wa| + 00000010 73 20 6e 69 63 65 20 6b 6e 6f 77 69 6e 67 20 79 |s nice knowing y| + 00000020 6f 75 2e 0a |ou..| + 00000024 + """ let actual = buf.hexDump(format: .detailed) XCTAssertEqual(expected, actual) } @@ -1965,14 +1987,20 @@ class ByteBufferTest: XCTestCase { func testBytesView() throws { self.buf.clear() self.buf.writeString("hello world 012345678") - XCTAssertEqual(self.buf.viewBytes(at: self.buf.readerIndex, - length: self.buf.writerIndex - self.buf.readerIndex).map { (view: ByteBufferView) -> String in - String(decoding: view, as: Unicode.UTF8.self) - }, - self.buf.getString(at: self.buf.readerIndex, length: self.buf.readableBytes)) + XCTAssertEqual( + self.buf.viewBytes( + at: self.buf.readerIndex, + length: self.buf.writerIndex - self.buf.readerIndex + ).map { (view: ByteBufferView) -> String in + String(decoding: view, as: Unicode.UTF8.self) + }, + self.buf.getString(at: self.buf.readerIndex, length: self.buf.readableBytes) + ) XCTAssertEqual(self.buf.viewBytes(at: 0, length: 0).map { Array($0) }, []) - XCTAssertEqual(Array("hello world 012345678".utf8), - self.buf.viewBytes(at: 0, length: self.buf.readableBytes).map(Array.init)) + XCTAssertEqual( + Array("hello world 012345678".utf8), + self.buf.viewBytes(at: 0, length: self.buf.readableBytes).map(Array.init) + ) } func testViewsStartIndexIsStable() throws { @@ -1988,17 +2016,17 @@ class ByteBufferTest: XCTestCase { self.buf.writeString("hello") let view: ByteBufferView? = self.buf.viewBytes(at: 1, length: 3) XCTAssertEqual("ell", view.map { String(decoding: $0, as: Unicode.UTF8.self) }) - let viewSlice: ByteBufferView? = view.map { $0[$0.startIndex + 1 ..< $0.endIndex] } + let viewSlice: ByteBufferView? = view.map { $0[$0.startIndex + 1..<$0.endIndex] } XCTAssertEqual("ll", viewSlice.map { String(decoding: $0, as: Unicode.UTF8.self) }) XCTAssertEqual("l", viewSlice.map { String(decoding: $0.dropFirst(), as: Unicode.UTF8.self) }) XCTAssertEqual("", viewSlice.map { String(decoding: $0.dropFirst().dropLast(), as: Unicode.UTF8.self) }) } - + func testReadableBufferViewRangeEqualCapacity() throws { self.buf.clear() self.buf.moveWriterIndex(forwardBy: buf.capacity) let view = self.buf.readableBytesView - let viewSlice: ByteBufferView = view[view.startIndex ..< view.endIndex] + let viewSlice: ByteBufferView = view[view.startIndex.. Void { +private func testAllocationOfReallyBigByteBuffer_freeHook(_ ptr: UnsafeMutableRawPointer?) { precondition(AllocationExpectationState.reallocDone == testAllocationOfReallyBigByteBuffer_state) testAllocationOfReallyBigByteBuffer_state = .freeDone - /* free the pointer initially produced by malloc and then rebased by realloc offsetting it back */ + // free the pointer initially produced by malloc and then rebased by realloc offsetting it back free(ptr!.advanced(by: Int(Int32.max))) } private func testAllocationOfReallyBigByteBuffer_mallocHook(_ size: Int) -> UnsafeMutableRawPointer? { precondition(AllocationExpectationState.begin == testAllocationOfReallyBigByteBuffer_state) testAllocationOfReallyBigByteBuffer_state = .mallocDone - /* return a 16 byte pointer here, good enough to write an integer in there */ + // return a 16 byte pointer here, good enough to write an integer in there return malloc(16) } -private func testAllocationOfReallyBigByteBuffer_reallocHook(_ ptr: UnsafeMutableRawPointer?, _ count: Int) -> UnsafeMutableRawPointer? { +private func testAllocationOfReallyBigByteBuffer_reallocHook( + _ ptr: UnsafeMutableRawPointer?, + _ count: Int +) -> UnsafeMutableRawPointer? { precondition(AllocationExpectationState.mallocDone == testAllocationOfReallyBigByteBuffer_state) testAllocationOfReallyBigByteBuffer_state = .reallocDone - /* rebase this pointer by -Int32.max so that the byte copy extending the ByteBuffer below will land at actual index 0 into this buffer ;) */ + // rebase this pointer by -Int32.max so that the byte copy extending the ByteBuffer below will land at actual index 0 into this buffer ;) return ptr!.advanced(by: -Int(Int32.max)) } -private func testAllocationOfReallyBigByteBuffer_memcpyHook(_ dst: UnsafeMutableRawPointer, _ src: UnsafeRawPointer, _ count: Int) -> Void { - /* not actually doing any copies */ +private func testAllocationOfReallyBigByteBuffer_memcpyHook( + _ dst: UnsafeMutableRawPointer, + _ src: UnsafeRawPointer, + _ count: Int +) { + // not actually doing any copies } - private var testReserveCapacityLarger_reallocCount = 0 private var testReserveCapacityLarger_mallocCount = 0 -private func testReserveCapacityLarger_freeHook( _ ptr: UnsafeMutableRawPointer) -> Void { +private func testReserveCapacityLarger_freeHook(_ ptr: UnsafeMutableRawPointer) { free(ptr) } @@ -3205,12 +3250,16 @@ private func testReserveCapacityLarger_mallocHook(_ size: Int) -> UnsafeMutableR return malloc(size) } -private func testReserveCapacityLarger_reallocHook(_ ptr: UnsafeMutableRawPointer?, _ count: Int) -> UnsafeMutableRawPointer? { +private func testReserveCapacityLarger_reallocHook( + _ ptr: UnsafeMutableRawPointer?, + _ count: Int +) -> UnsafeMutableRawPointer? { testReserveCapacityLarger_reallocCount += 1 return realloc(ptr, count) } -private func testReserveCapacityLarger_memcpyHook(_ dst: UnsafeMutableRawPointer, _ src: UnsafeRawPointer, _ count: Int) -> Void { +private func testReserveCapacityLarger_memcpyHook(_ dst: UnsafeMutableRawPointer, _ src: UnsafeRawPointer, _ count: Int) +{ // No copying } @@ -3226,29 +3275,29 @@ extension ByteBuffer { // MARK: - Array init extension ByteBufferTest { - + func testCreateArrayFromBuffer() { let testString = "some sample data" let buffer = ByteBuffer(ByteBufferView(testString.utf8)) XCTAssertEqual(Array(buffer: buffer), Array(testString.utf8)) } - + } // MARK: - String init extension ByteBufferTest { - + func testCreateStringFromBuffer() { let testString = "some sample data" let buffer = ByteBuffer(ByteBufferView(testString.utf8)) XCTAssertEqual(String(buffer: buffer), testString) } - + } // MARK: - DispatchData init extension ByteBufferTest { - + func testCreateDispatchDataFromBuffer() { let testString = "some sample data" let buffer = ByteBuffer(ByteBufferView(testString.utf8)) @@ -3257,16 +3306,16 @@ extension ByteBufferTest { } XCTAssertTrue(DispatchData(buffer: buffer).elementsEqual(expectedData)) } - + } // MARK: - ExpressibleByArrayLiteral init extension ByteBufferTest { - + func testCreateBufferFromArray() { let bufferView: ByteBufferView = [0x00, 0x01, 0x02] let buffer = ByteBuffer(ByteBufferView(bufferView)) - + XCTAssertEqual(buffer.readableBytesView, [0x00, 0x01, 0x02]) } @@ -3278,7 +3327,7 @@ extension ByteBufferTest { func testByteBufferViewEqualityWithRange() { var buffer = self.allocator.buffer(capacity: 8) buffer.writeString("AAAABBBB") - + let view = ByteBufferView(buffer: buffer, range: 2..<6) let comparisonBuffer: ByteBufferView = [0x41, 0x41, 0x42, 0x42] @@ -3288,7 +3337,7 @@ extension ByteBufferTest { func testInvalidBufferEqualityWithDifferentRange() { var buffer = self.allocator.buffer(capacity: 4) buffer.writeString("AAAA") - + let view = ByteBufferView(buffer: buffer, range: 0..<2) let comparisonBuffer: ByteBufferView = [0x41, 0x41, 0x41, 0x41] @@ -3298,7 +3347,7 @@ extension ByteBufferTest { func testInvalidBufferEqualityWithDifferentContent() { var buffer = self.allocator.buffer(capacity: 4) buffer.writeString("AAAA") - + let view = ByteBufferView(buffer: buffer, range: 0..<4) let comparisonBuffer: ByteBufferView = [0x41, 0x41, 0x00, 0x00] @@ -3317,7 +3366,6 @@ extension ByteBufferTest { XCTAssertEqual(bufferView.hashValue, comparisonBufferView.hashValue) } - func testInvalidHash() { let bufferView: ByteBufferView = [0x00, 0x00, 0x00] let comparisonBufferView: ByteBufferView = [0x00, 0x01, 0x02] @@ -3337,7 +3385,14 @@ extension ByteBufferTest { func testWritingMultipleIntegers() { let w1 = self.buf.writeMultipleIntegers(UInt32(1), UInt8(2), UInt16(3), UInt64(4), UInt16(5), endianness: .big) - let w2 = self.buf.writeMultipleIntegers(UInt32(1), UInt8(2), UInt16(3), UInt64(4), UInt16(5), endianness: .little) + let w2 = self.buf.writeMultipleIntegers( + UInt32(1), + UInt8(2), + UInt16(3), + UInt64(4), + UInt16(5), + endianness: .little + ) XCTAssertEqual(17, w1) XCTAssertEqual(17, w2) @@ -3381,14 +3436,26 @@ extension ByteBufferTest { let startWriterIndex = self.buf.writerIndex let written = self.buf.writeMultipleIntegers( - v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, + v1, + v2, + v3, + v4, + v5, + v6, + v7, + v8, + v9, + v10, endianness: endianness, - as: (UInt8, UInt16, UInt32, UInt64, UInt64, UInt32, UInt16, UInt8, UInt16, UInt32).self) + as: (UInt8, UInt16, UInt32, UInt64, UInt64, UInt32, UInt16, UInt8, UInt16, UInt32).self + ) XCTAssertEqual(startWriterIndex + written, self.buf.writerIndex) XCTAssertEqual(written, self.buf.readableBytes) - let result = self.buf.readMultipleIntegers(endianness: endianness, - as: (UInt8, UInt16, UInt32, UInt64, UInt64, UInt32, UInt16, UInt8, UInt16, UInt32).self) + let result = self.buf.readMultipleIntegers( + endianness: endianness, + as: (UInt8, UInt16, UInt32, UInt64, UInt64, UInt32, UInt16, UInt8, UInt16, UInt32).self + ) XCTAssertNotNil(result) XCTAssertEqual(0, self.buf.readableBytes) @@ -3429,13 +3496,29 @@ extension ByteBufferTest { var values6 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! var values7 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! var values8 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! - var values9 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! - var values10 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! - var values11 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! - var values12 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! - var values13 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! - var values14 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! - var values15 = self.buf.readMultipleIntegers(as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self)! + var values9 = self.buf.readMultipleIntegers( + as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self + )! + var values10 = self.buf.readMultipleIntegers( + as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self + )! + var values11 = self.buf.readMultipleIntegers( + as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self + )! + var values12 = self.buf.readMultipleIntegers( + as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self + )! + var values13 = self.buf.readMultipleIntegers( + as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self + )! + var values14 = self.buf.readMultipleIntegers( + as: (UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8).self + )! + var values15 = self.buf.readMultipleIntegers( + as: ( + UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8, UInt8 + ).self + )! XCTAssertEqual([i, i], withUnsafeBytes(of: &values2, { Array($0) })) XCTAssertEqual([i, i, i], withUnsafeBytes(of: &values3, { Array($0) })) diff --git a/Tests/NIOCoreTests/ChannelOptionStorageTest.swift b/Tests/NIOCoreTests/ChannelOptionStorageTest.swift index b222e0917a..638beb76e5 100644 --- a/Tests/NIOCoreTests/ChannelOptionStorageTest.swift +++ b/Tests/NIOCoreTests/ChannelOptionStorageTest.swift @@ -12,9 +12,9 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded +import XCTest class ChannelOptionStorageTest: XCTestCase { func testWeStartWithNoOptions() throws { @@ -34,8 +34,10 @@ class ChannelOptionStorageTest: XCTestCase { } func testSetTwoOptionsOfSameType() throws { - let options: [(ChannelOptions.Types.SocketOption, SocketOptionValue)] = [(ChannelOptions.socketOption(.so_reuseaddr), 1), - (ChannelOptions.socketOption(.so_rcvtimeo), 2)] + let options: [(ChannelOptions.Types.SocketOption, SocketOptionValue)] = [ + (ChannelOptions.socketOption(.so_reuseaddr), 1), + (ChannelOptions.socketOption(.so_rcvtimeo), 2), + ] var cos = ChannelOptions.Storage() let optionsCollector = OptionsCollectingChannel() for kv in options { @@ -43,14 +45,18 @@ class ChannelOptionStorageTest: XCTestCase { } XCTAssertNoThrow(try cos.applyAllChannelOptions(to: optionsCollector).wait()) XCTAssertEqual(2, optionsCollector.allOptions.count) - XCTAssertEqual(options.map { $0.0 }, - optionsCollector.allOptions.map { option in - return option.0 as! ChannelOptions.Types.SocketOption - }) - XCTAssertEqual(options.map { $0.1 }, - optionsCollector.allOptions.map { option in - return option.1 as! SocketOptionValue - }) + XCTAssertEqual( + options.map { $0.0 }, + optionsCollector.allOptions.map { option in + option.0 as! ChannelOptions.Types.SocketOption + } + ) + XCTAssertEqual( + options.map { $0.1 }, + optionsCollector.allOptions.map { option in + option.1 as! SocketOptionValue + } + ) } func testSetOneOptionTwice() throws { @@ -60,14 +66,18 @@ class ChannelOptionStorageTest: XCTestCase { cos.append(key: ChannelOptions.socketOption(.so_reuseaddr), value: 2) XCTAssertNoThrow(try cos.applyAllChannelOptions(to: optionsCollector).wait()) XCTAssertEqual(1, optionsCollector.allOptions.count) - XCTAssertEqual([ChannelOptions.socketOption(.so_reuseaddr)], - optionsCollector.allOptions.map { option in - return option.0 as! ChannelOptions.Types.SocketOption - }) - XCTAssertEqual([SocketOptionValue(2)], - optionsCollector.allOptions.map { option in - return option.1 as! SocketOptionValue - }) + XCTAssertEqual( + [ChannelOptions.socketOption(.so_reuseaddr)], + optionsCollector.allOptions.map { option in + option.0 as! ChannelOptions.Types.SocketOption + } + ) + XCTAssertEqual( + [SocketOptionValue(2)], + optionsCollector.allOptions.map { option in + option.1 as! SocketOptionValue + } + ) } func testClearingOptions() throws { @@ -82,15 +92,21 @@ class ChannelOptionStorageTest: XCTestCase { cos.remove(key: ChannelOptions.socketOption(.so_reuseaddr)) XCTAssertNoThrow(try cos.applyAllChannelOptions(to: optionsCollector).wait()) XCTAssertEqual(2, optionsCollector.allOptions.count) - XCTAssertEqual([ChannelOptions.socketOption(.so_keepalive), - ChannelOptions.socketOption(.so_rcvbuf)], - optionsCollector.allOptions.map { option in - return option.0 as! ChannelOptions.Types.SocketOption - }) - XCTAssertEqual([SocketOptionValue(3), SocketOptionValue(5)], - optionsCollector.allOptions.map { option in - return option.1 as! SocketOptionValue - }) + XCTAssertEqual( + [ + ChannelOptions.socketOption(.so_keepalive), + ChannelOptions.socketOption(.so_rcvbuf), + ], + optionsCollector.allOptions.map { option in + option.0 as! ChannelOptions.Types.SocketOption + } + ) + XCTAssertEqual( + [SocketOptionValue(3), SocketOptionValue(5)], + optionsCollector.allOptions.map { option in + option.1 as! SocketOptionValue + } + ) } } @@ -125,6 +141,6 @@ class OptionsCollectingChannel: Channel { var _channelCore: ChannelCore { fatalError() } var eventLoop: EventLoop { - return EmbeddedEventLoop() + EmbeddedEventLoop() } } diff --git a/Tests/NIOCoreTests/CircularBufferTests.swift b/Tests/NIOCoreTests/CircularBufferTests.swift index 3e47d08590..166cda7876 100644 --- a/Tests/NIOCoreTests/CircularBufferTests.swift +++ b/Tests/NIOCoreTests/CircularBufferTests.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import XCTest + @testable import NIOCore class CircularBufferTests: XCTestCase { @@ -105,10 +106,13 @@ class CircularBufferTests: XCTestCase { } func collectAllIndices(ring: CircularBuffer) -> [CircularBuffer.Index] { - return Array(ring.indices) + Array(ring.indices) } - func collectAllIndices(ring: CircularBuffer, range: Range.Index>) -> [CircularBuffer.Index] { + func collectAllIndices( + ring: CircularBuffer, + range: Range.Index> + ) -> [CircularBuffer.Index] { var index: CircularBuffer.Index = range.lowerBound var allIndices: [CircularBuffer.Index] = [] while index != range.upperBound { @@ -121,24 +125,42 @@ class CircularBufferTests: XCTestCase { func testHarderExpansion() { var ring = CircularBuffer(initialCapacity: 3) - XCTAssertEqual(self.collectAllIndices(ring: ring), - self.collectAllIndices(ring: ring, range: ring.startIndex ..< ring.startIndex)) + XCTAssertEqual( + self.collectAllIndices(ring: ring), + self.collectAllIndices(ring: ring, range: ring.startIndex..(initialCapacity: 4) - XCTAssertEqual(self.collectAllIndices(ring: ring), - self.collectAllIndices(ring: ring, range: ring.startIndex ..< ring.startIndex)) + XCTAssertEqual( + self.collectAllIndices(ring: ring), + self.collectAllIndices(ring: ring, range: ring.startIndex..(initialCapacity: 10) XCTAssertNil(ring.last) - for i in 0 ..< 20 { + for i in 0..<20 { ring.prepend(i) } XCTAssertEqual(20, ring.count) @@ -525,7 +613,7 @@ class CircularBufferTests: XCTestCase { XCTAssertEqual(19, ring.first) XCTAssertEqual(10, ring.last) } - + func testOperateOnBothSides() { var ring = CircularBuffer(initialCapacity: 3) XCTAssertNil(ring.last) @@ -593,7 +681,7 @@ class CircularBufferTests: XCTestCase { ring.append(2) XCTAssertEqual(ring.capacity, 4) XCTAssertEqual(ring.count, 2) - ring.removeAll() // default should not keep capacity + ring.removeAll() // default should not keep capacity XCTAssertTrue(ring.testOnly_verifyInvariantsForNonSlices()) XCTAssertEqual(ring.capacity, 1) XCTAssertEqual(ring.count, 0) @@ -606,7 +694,7 @@ class CircularBufferTests: XCTestCase { // Now we want to replace the last subrange with two elements. This should // force an increase in size. - ring.replaceSubrange(ring.startIndex ..< ring.index(ring.startIndex, offsetBy: 1), with: [3, 4]) + ring.replaceSubrange(ring.startIndex..(initialCapacity: 4) - (0..<16).forEach { ring.append($0) } + for i in (0..<16) { + ring.append(i) + } XCTAssertTrue(ring.testOnly_verifyInvariantsForNonSlices()) - (0..<4).forEach { _ in _ = ring.removeFirst() } + for _ in (0..<4) { + _ = ring.removeFirst() + } XCTAssertTrue(ring.testOnly_verifyInvariantsForNonSlices()) - (16..<20).forEach { ring.append($0) } + for i in (16..<20) { + ring.append(i) + } XCTAssertTrue(ring.testOnly_verifyInvariantsForNonSlices()) XCTAssertEqual(Array(4..<20), Array(ring)) ring.removeAll(keepingCapacity: shouldKeepCapacity) - (0..<8).forEach { ring.append($0) } + for i in (0..<8) { + ring.append(i) + } XCTAssertTrue(ring.testOnly_verifyInvariantsForNonSlices()) - (0..<4).forEach { _ in _ = ring.removeFirst() } + for _ in (0..<4) { + _ = ring.removeFirst() + } XCTAssertTrue(ring.testOnly_verifyInvariantsForNonSlices()) - (8..<64).forEach { ring.append($0) } + for i in (8..<64) { + ring.append(i) + } XCTAssertTrue(ring.testOnly_verifyInvariantsForNonSlices()) XCTAssertEqual(Array(4..<64), Array(ring)) @@ -739,36 +839,36 @@ class CircularBufferTests: XCTestCase { assert(dummy7 == nil, within: .seconds(1)) assert(dummy8 == nil, within: .seconds(1)) } - + func testIntIndexing() { var ring = CircularBuffer() - for i in 0 ..< 5 { + for i in 0..<5 { ring.append(i) XCTAssertEqual(ring[offset: i], i) } - - XCTAssertEqual(ring[ring.startIndex], ring[offset :0]) + + XCTAssertEqual(ring[ring.startIndex], ring[offset: 0]) XCTAssertEqual(ring[ring.index(before: ring.endIndex)], ring[offset: 4]) - + ring[offset: 1] = 10 XCTAssertEqual(ring[ring.index(after: ring.startIndex)], 10) } - + func testIndexDistance() { var bufferOfBackingSize4 = CircularBuffer(initialCapacity: 4) XCTAssertEqual(3, bufferOfBackingSize4.indexBeforeHeadIdx()) - + let index1 = CircularBuffer.Index(backingIndex: 0, backingCount: 4, backingIndexOfHead: 0) let index2 = CircularBuffer.Index(backingIndex: 1, backingCount: 4, backingIndexOfHead: 0) XCTAssertEqual(bufferOfBackingSize4.distance(from: index1, to: index2), 1) - + bufferOfBackingSize4.append(1) XCTAssertEqual(1, bufferOfBackingSize4.removeFirst()) XCTAssertEqual(1, bufferOfBackingSize4.headBackingIndex) let index3 = CircularBuffer.Index(backingIndex: 2, backingCount: 4, backingIndexOfHead: 1) let index4 = CircularBuffer.Index(backingIndex: 0, backingCount: 4, backingIndexOfHead: 1) XCTAssertEqual(bufferOfBackingSize4.distance(from: index3, to: index4), 2) - + let index5 = CircularBuffer.Index(backingIndex: 0, backingCount: 4, backingIndexOfHead: 1) let index6 = CircularBuffer.Index(backingIndex: 2, backingCount: 4, backingIndexOfHead: 1) XCTAssertEqual(bufferOfBackingSize4.distance(from: index5, to: index6), -2) @@ -782,7 +882,7 @@ class CircularBufferTests: XCTestCase { let index8 = CircularBuffer.Index(backingIndex: 2, backingCount: 4, backingIndexOfHead: 3) XCTAssertEqual(bufferOfBackingSize4.distance(from: index7, to: index8), 2) } - + func testIndexAdvancing() { var bufferOfBackingSize4 = CircularBuffer(initialCapacity: 4) XCTAssertEqual(3, bufferOfBackingSize4.indexBeforeHeadIdx()) @@ -791,7 +891,7 @@ class CircularBufferTests: XCTestCase { let index2 = bufferOfBackingSize4.index(after: index1) XCTAssertEqual(index2.backingIndex, 1) XCTAssertEqual(index2.isIndexGEQHeadIndex, true) - + bufferOfBackingSize4.append(1) bufferOfBackingSize4.append(2) XCTAssertEqual(1, bufferOfBackingSize4.removeFirst()) @@ -809,7 +909,7 @@ class CircularBufferTests: XCTestCase { let index6 = bufferOfBackingSize4.index(before: index5) XCTAssertEqual(index6.backingIndex, 3) XCTAssertEqual(index6.isIndexGEQHeadIndex, true) - + let index7 = CircularBuffer.Index(backingIndex: 2, backingCount: 4, backingIndexOfHead: 1) let index8 = bufferOfBackingSize4.index(before: index7) XCTAssertEqual(index8.backingIndex, 1) @@ -823,7 +923,7 @@ class CircularBufferTests: XCTestCase { } else { XCTFail("popFirst didn't find first element") } - + if let element = buf.popFirst() { XCTAssertEqual(2, element) } else { @@ -835,11 +935,11 @@ class CircularBufferTests: XCTestCase { } else { XCTFail("popFirst didn't find third element") } - + XCTAssertNil(buf.popFirst()) XCTAssertTrue(buf.testOnly_verifyInvariantsForNonSlices()) } - + func testSlicing() { var buf = CircularBuffer() for i in -4..<124 { @@ -853,10 +953,10 @@ class CircularBufferTests: XCTestCase { buf.append(125) buf.append(126) buf.append(127) - - let buf2: CircularBuffer = buf[buf.index(buf.startIndex, offsetBy: 100) ..< buf.endIndex] + + let buf2: CircularBuffer = buf[buf.index(buf.startIndex, offsetBy: 100)..() let emptyB = CircularBuffer() XCTAssertEqual(emptyA, emptyB) - + var buffA = CircularBuffer() var buffB = CircularBuffer() var buffC = CircularBuffer() var buffD = CircularBuffer() buffA.append(contentsOf: 1...10) buffB.append(contentsOf: 1...10) - buffC.append(contentsOf: 2...11) // Same count different values - buffD.append(contentsOf: 1...2) // Different count + buffC.append(contentsOf: 2...11) // Same count different values + buffD.append(contentsOf: 1...2) // Different count XCTAssertEqual(buffA, buffB) XCTAssertNotEqual(buffA, buffC) XCTAssertNotEqual(buffA, buffD) - + // Will make internal head/tail indexes different var prependBuff = CircularBuffer() var appendBuff = CircularBuffer() @@ -960,22 +1060,22 @@ class CircularBufferTests: XCTestCase { // But the contents are still the same XCTAssertEqual(prependBuff, appendBuff) } - + func testHash() { let emptyA = CircularBuffer() let emptyB = CircularBuffer() - XCTAssertEqual(Set([emptyA,emptyB]).count, 1) - + XCTAssertEqual(Set([emptyA, emptyB]).count, 1) + var buffA = CircularBuffer() var buffB = CircularBuffer() buffA.append(contentsOf: 1...10) buffB.append(contentsOf: 1...10) - XCTAssertEqual(Set([buffA,buffB]).count, 1) + XCTAssertEqual(Set([buffA, buffB]).count, 1) buffB.append(123) - XCTAssertEqual(Set([buffA,buffB]).count, 2) + XCTAssertEqual(Set([buffA, buffB]).count, 2) buffA.append(1) - XCTAssertEqual(Set([buffA,buffB]).count, 2) - + XCTAssertEqual(Set([buffA, buffB]).count, 2) + // Will make internal head/tail indexes different var prependBuff = CircularBuffer() var appendBuff = CircularBuffer() @@ -985,17 +1085,17 @@ class CircularBufferTests: XCTestCase { for i in 1...100 { appendBuff.append(i) } - XCTAssertEqual(Set([prependBuff,appendBuff]).count, 1) + XCTAssertEqual(Set([prependBuff, appendBuff]).count, 1) } - + func testArrayLiteralInit() { let empty: CircularBuffer = [] XCTAssert(empty.isEmpty) - + let increasingInts: CircularBuffer = [1, 2, 3, 4, 5] XCTAssertEqual(increasingInts.count, 5) XCTAssert(zip(increasingInts, 1...5).allSatisfy(==)) - + let someIntsArray = [-9, 384, 2, 10, 0, 0, 0] let someInts: CircularBuffer = [-9, 384, 2, 10, 0, 0, 0] XCTAssertEqual(someInts.count, 7) diff --git a/Tests/NIOCoreTests/CustomChannelTests.swift b/Tests/NIOCoreTests/CustomChannelTests.swift index a2493f466e..916fda7dbd 100644 --- a/Tests/NIOCoreTests/CustomChannelTests.swift +++ b/Tests/NIOCoreTests/CustomChannelTests.swift @@ -12,13 +12,13 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded +import XCTest -struct NotImplementedError: Error { } +struct NotImplementedError: Error {} -struct InvalidTypeError: Error { } +struct InvalidTypeError: Error {} /// A basic ChannelCore that expects write0 to receive a NIOAny containing an Int. /// diff --git a/Tests/NIOCoreTests/IOErrorTest.swift b/Tests/NIOCoreTests/IOErrorTest.swift index 6be8785cdc..a7924a9639 100644 --- a/Tests/NIOCoreTests/IOErrorTest.swift +++ b/Tests/NIOCoreTests/IOErrorTest.swift @@ -14,11 +14,12 @@ // import XCTest + @testable import NIOCore class IOErrorTest: XCTestCase { func testMemoryLayoutBelowThreshold() { - XCTAssert(MemoryLayout.size <= 24) + XCTAssert(MemoryLayout.size <= 24) } @available(*, deprecated, message: "deprecated because it tests deprecated functionality") diff --git a/Tests/NIOCoreTests/IntegerTypesTest.swift b/Tests/NIOCoreTests/IntegerTypesTest.swift index f1b4a98888..3229728ceb 100644 --- a/Tests/NIOCoreTests/IntegerTypesTest.swift +++ b/Tests/NIOCoreTests/IntegerTypesTest.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import XCTest + @testable import NIOCore public final class IntegerTypesTest: XCTestCase { @@ -85,7 +86,7 @@ public final class IntegerTypesTest: XCTestCase { func testDescriptionUInt24() { XCTAssertEqual("0", _UInt24.min.description) XCTAssertEqual("16777215", _UInt24.max.description) - XCTAssertEqual("12345678", _UInt24(12345678 as UInt32).description) + XCTAssertEqual("12345678", _UInt24(12_345_678 as UInt32).description) XCTAssertEqual("1", _UInt24(1).description) XCTAssertEqual("8388608", _UInt24(1 << 23).description) XCTAssertEqual("66", _UInt24(66).description) @@ -94,7 +95,7 @@ public final class IntegerTypesTest: XCTestCase { func testDescriptionUInt56() { XCTAssertEqual("0", _UInt56.min.description) XCTAssertEqual("72057594037927935", _UInt56.max.description) - XCTAssertEqual("12345678901234567", _UInt56(12345678901234567 as UInt64).description) + XCTAssertEqual("12345678901234567", _UInt56(12_345_678_901_234_567 as UInt64).description) XCTAssertEqual("1", _UInt56(1).description) XCTAssertEqual("66", _UInt56(66).description) XCTAssertEqual("36028797018963968", _UInt56(UInt64(1) << 55).description) diff --git a/Tests/NIOCoreTests/LinuxTest.swift b/Tests/NIOCoreTests/LinuxTest.swift index 243bfd722d..9e25d7f391 100644 --- a/Tests/NIOCoreTests/LinuxTest.swift +++ b/Tests/NIOCoreTests/LinuxTest.swift @@ -13,12 +13,13 @@ //===----------------------------------------------------------------------===// import XCTest + @testable import NIOCore class LinuxTest: XCTestCase { func testCoreCountQuota() throws { #if os(Linux) || os(Android) - try [ + let coreCountQuoats = [ ("50000", "100000", 1), ("100000", "100000", 1), ("100000\n", "100000", 1), @@ -29,8 +30,9 @@ class LinuxTest: XCTestCase { ("100000", "-1", nil), ("", "100000", nil), ("100000", "", nil), - ("100000", "0", nil) - ].forEach { quota, period, count in + ("100000", "0", nil), + ] + for (quota, period, count) in coreCountQuoats { try withTemporaryFile(content: quota) { (_, quotaPath) -> Void in try withTemporaryFile(content: period) { (_, periodPath) -> Void in XCTAssertEqual(Linux.coreCountCgroup1Restriction(quota: quotaPath, period: periodPath), count) @@ -42,15 +44,16 @@ class LinuxTest: XCTestCase { func testCoreCountCpuset() throws { #if os(Linux) || os(Android) - try [ + let cpusets = [ ("0", 1), ("0,3", 2), ("0-3", 4), ("0-3,7", 5), ("0-3,7\n", 5), ("0,2-4,6,7,9-11", 9), - ("", nil) - ].forEach { cpuset, count in + ("", nil), + ] + for (cpuset, count) in cpusets { try withTemporaryFile(content: cpuset) { (_, path) -> Void in XCTAssertEqual(Linux.coreCount(cpuset: path), count) } @@ -60,11 +63,12 @@ class LinuxTest: XCTestCase { func testCoreCountCgoup2() throws { #if os(Linux) || os(Android) - try [ + let contents = [ ("max 100000", nil), ("75000 100000", 1), - ("200000 100000", 2) - ].forEach { (content, count) in + ("200000 100000", 2), + ] + for (content, count) in contents { try withTemporaryFile(content: content) { (_, path) in XCTAssertEqual(Linux.coreCountCgroup2Restriction(cpuMaxPath: path), count) } diff --git a/Tests/NIOCoreTests/MarkedCircularBufferTests.swift b/Tests/NIOCoreTests/MarkedCircularBufferTests.swift index b4c55334b8..6d5681c66e 100644 --- a/Tests/NIOCoreTests/MarkedCircularBufferTests.swift +++ b/Tests/NIOCoreTests/MarkedCircularBufferTests.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore +import XCTest class MarkedCircularBufferTests: XCTestCase { func testEmptyMark() throws { @@ -124,7 +124,7 @@ class MarkedCircularBufferTests: XCTestCase { } let range = buf.startIndex..(remoteAddress: socketAddress, data: envelopeByteBuffer) - XCTAssertEqual(wrappedInNIOAnyBlock("\(envelope)"), wrappedInNIOAnyBlock(""" - AddressedEnvelope { \ - remoteAddress: \(socketAddress), \ - data: \(envelopeByteBuffer) } - """)) + XCTAssertEqual( + wrappedInNIOAnyBlock("\(envelope)"), + wrappedInNIOAnyBlock( + """ + AddressedEnvelope { \ + remoteAddress: \(socketAddress), \ + data: \(envelopeByteBuffer) } + """ + ) + ) } - + private func wrappedInNIOAnyBlock(_ item: Any) -> String { - return "NIOAny { \(item) }" + "NIOAny { \(item) }" } - + } diff --git a/Tests/NIOCoreTests/NIOCloseOnErrorHandlerTest.swift b/Tests/NIOCoreTests/NIOCloseOnErrorHandlerTest.swift index 7d3fd9e4e9..16445957b2 100644 --- a/Tests/NIOCoreTests/NIOCloseOnErrorHandlerTest.swift +++ b/Tests/NIOCoreTests/NIOCloseOnErrorHandlerTest.swift @@ -12,9 +12,9 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded +import XCTest final class DummyFailingHandler1: ChannelInboundHandler { typealias InboundIn = NIOAny @@ -47,14 +47,18 @@ class NIOCloseOnErrorHandlerTest: XCTestCase { } func testChannelCloseOnError() throws { - XCTAssertNoThrow(self.channel.pipeline.addHandlers([ - DummyFailingHandler1(), - NIOCloseOnErrorHandler() - ])) + XCTAssertNoThrow( + self.channel.pipeline.addHandlers([ + DummyFailingHandler1(), + NIOCloseOnErrorHandler(), + ]) + ) XCTAssertNoThrow(try self.channel.connect(to: .init(ipAddress: "1.2.3.4", port: 5)).wait()) XCTAssertTrue(self.channel.isActive) - XCTAssertThrowsError(try self.channel.writeInbound("Hello World")) { XCTAssertTrue($0 is DummyFailingHandler1.DummyError1) } + XCTAssertThrowsError(try self.channel.writeInbound("Hello World")) { + XCTAssertTrue($0 is DummyFailingHandler1.DummyError1) + } XCTAssertFalse(self.channel.isActive) } diff --git a/Tests/NIOCoreTests/RecvByteBufAllocatorTest.swift b/Tests/NIOCoreTests/RecvByteBufAllocatorTest.swift index 56d7ed2cda..eadceb7178 100644 --- a/Tests/NIOCoreTests/RecvByteBufAllocatorTest.swift +++ b/Tests/NIOCoreTests/RecvByteBufAllocatorTest.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore +import XCTest final class AdaptiveRecvByteBufferAllocatorTest: XCTestCase { private let allocator = ByteBufferAllocator() @@ -90,8 +90,20 @@ final class AdaptiveRecvByteBufferAllocatorTest: XCTestCase { } } - private func testActualReadBytes(mayGrow: Bool, actualReadBytes: Int, expectedCapacity: Int, file: StaticString = #filePath, line: UInt = #line) { - XCTAssertEqual(mayGrow, adaptive.record(actualReadBytes: actualReadBytes), "unexpected value for mayGrow", file: file, line: line) + private func testActualReadBytes( + mayGrow: Bool, + actualReadBytes: Int, + expectedCapacity: Int, + file: StaticString = #filePath, + line: UInt = #line + ) { + XCTAssertEqual( + mayGrow, + adaptive.record(actualReadBytes: actualReadBytes), + "unexpected value for mayGrow", + file: file, + line: line + ) let buffer = adaptive.buffer(allocator: allocator) XCTAssertEqual(expectedCapacity, buffer.capacity, "unexpected capacity", file: file, line: line) } @@ -135,7 +147,11 @@ final class AdaptiveRecvByteBufferAllocatorTest: XCTestCase { return } - let adaptive = AdaptiveRecvByteBufferAllocator(minimum: targetValue, initial: targetValue + 1, maximum: targetValue + 2) + let adaptive = AdaptiveRecvByteBufferAllocator( + minimum: targetValue, + initial: targetValue + 1, + maximum: targetValue + 2 + ) XCTAssertEqual(adaptive.minimum, 1 << 30) XCTAssertEqual(adaptive.maximum, 1 << 30) XCTAssertEqual(adaptive.initial, 1 << 30) diff --git a/Tests/NIOCoreTests/SingleStepByteToMessageDecoderTest.swift b/Tests/NIOCoreTests/SingleStepByteToMessageDecoderTest.swift index 80dbefca0f..60921984db 100644 --- a/Tests/NIOCoreTests/SingleStepByteToMessageDecoderTest.swift +++ b/Tests/NIOCoreTests/SingleStepByteToMessageDecoderTest.swift @@ -12,16 +12,17 @@ // //===----------------------------------------------------------------------===// +import NIOEmbedded import XCTest + @testable import NIOCore -import NIOEmbedded public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { private final class ByteToInt32Decoder: NIOSingleStepByteToMessageDecoder { typealias InboundOut = Int32 func decode(buffer: inout ByteBuffer) throws -> InboundOut? { - return buffer.readInteger() + buffer.readInteger() } func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> InboundOut? { @@ -34,7 +35,7 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { typealias InboundOut = ByteBuffer func decode(buffer: inout ByteBuffer) throws -> InboundOut? { - return buffer.readSlice(length: 512) + buffer.readSlice(length: 512) } func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> InboundOut? { @@ -69,7 +70,7 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { var lastBuffer: ByteBuffer? func decode(buffer: inout ByteBuffer) throws -> InboundOut? { - return buffer.readSlice(length: 2) + buffer.readSlice(length: 2) } func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> InboundOut? { @@ -87,7 +88,7 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { messages.append(message) } - var count: Int { return messages.count } + var count: Int { messages.count } func retrieveMessage() -> InboundOut? { if messages.isEmpty { @@ -111,7 +112,12 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { XCTAssertNil(messageReceiver.retrieveMessage()) buffer.moveWriterIndex(to: writerIndex) - XCTAssertNoThrow(try processor.process(buffer: buffer.getSlice(at: writerIndex - 1, length: 1)!, messageReceiver.receiveMessage)) + XCTAssertNoThrow( + try processor.process( + buffer: buffer.getSlice(at: writerIndex - 1, length: 1)!, + messageReceiver.receiveMessage + ) + ) var buffer2 = allocator.buffer(capacity: 32) buffer2.writeInteger(Int32(2)) @@ -163,7 +169,7 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { XCTAssertEqual(processor._buffer!.capacity, 2048) XCTAssertEqual(2, processor._buffer!.readableBytes) XCTAssertEqual(1024, processor._buffer!.readerIndex) - + // Finally we're going to send in another 513 bytes. This will cause another chunk to be // passed into our decoder buffer, which has a capacity of 2048 bytes. Since the buffer has // enough available space (1022 bytes) there will be no buffer resize before the decoding. @@ -171,10 +177,10 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { // (3 * 512 bytes). This means that 75% of the buffer's capacity can now be reclaimed, which // will lead to a reclaim. The resulting buffer will have a capacity of 2048 bytes (based // on its previous growth), with 3 readable bytes remaining. - + XCTAssertNoThrow(try processor.process(buffer: buffer, messageReceiver.receiveMessage)) XCTAssertEqual(512, messageReceiver.retrieveMessage()!.readableBytes) - + XCTAssertEqual(processor._buffer!.capacity, 2048) XCTAssertEqual(3, processor._buffer!.readableBytes) XCTAssertEqual(0, processor._buffer!.readerIndex) @@ -199,7 +205,9 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { // Now we're going to send in one more byte. This will cause a chunk to be passed on, // shrinking the held memory to 3072 bytes. However, memory will be reclaimed. - XCTAssertNoThrow(try processor.process(buffer: buffer.getSlice(at: 0, length: 1)!, messageReceiver.receiveMessage)) + XCTAssertNoThrow( + try processor.process(buffer: buffer.getSlice(at: 0, length: 1)!, messageReceiver.receiveMessage) + ) XCTAssertEqual(2048, messageReceiver.retrieveMessage()!.readableBytes) XCTAssertEqual(3072, processor._buffer!.readableBytes) XCTAssertEqual(0, processor._buffer!.readerIndex) @@ -228,26 +236,44 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { XCTAssertNoThrow(try processor.finishProcessing(seenEOF: false, messageReceiver.receiveMessage)) XCTAssertEqual(processor.unprocessedBytes, 1) - XCTAssertEqual("12", messageReceiver.retrieveMessage().map { - String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) - }) - XCTAssertEqual("34", messageReceiver.retrieveMessage().map { - String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) - }) - XCTAssertEqual("56", messageReceiver.retrieveMessage().map { - String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) - }) - XCTAssertEqual("78", messageReceiver.retrieveMessage().map { - String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) - }) - XCTAssertEqual("90", messageReceiver.retrieveMessage().map { - String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) - }) + XCTAssertEqual( + "12", + messageReceiver.retrieveMessage().map { + String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) + } + ) + XCTAssertEqual( + "34", + messageReceiver.retrieveMessage().map { + String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) + } + ) + XCTAssertEqual( + "56", + messageReceiver.retrieveMessage().map { + String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) + } + ) + XCTAssertEqual( + "78", + messageReceiver.retrieveMessage().map { + String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) + } + ) + XCTAssertEqual( + "90", + messageReceiver.retrieveMessage().map { + String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) + } + ) XCTAssertNil(messageReceiver.retrieveMessage()) - XCTAssertEqual("x", decoder.lastBuffer.map { - String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) - }) + XCTAssertEqual( + "x", + decoder.lastBuffer.map { + String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) + } + ) XCTAssertEqual(1, decoder.decodeLastCalls) } @@ -342,11 +368,11 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { typealias InboundOut = Never func decode(buffer: inout ByteBuffer) throws -> InboundOut? { - return nil + nil } func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> InboundOut? { - return nil + nil } } @@ -374,7 +400,7 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { } func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> InboundOut? { - return try decode(buffer: &buffer) + try decode(buffer: &buffer) } } @@ -423,7 +449,7 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { typealias InboundOut = String func decode(buffer: inout ByteBuffer) throws -> String? { - return buffer.readString(length: 1) + buffer.readString(length: 1) } func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> String? { @@ -440,19 +466,19 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { XCTAssertNoThrow(XCTAssertEqual("a", try channel.readInbound())) XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean)) } - + func testWeDoNotCallShouldReclaimMemoryAsLongAsFramesAreProduced() { struct TestByteToMessageDecoder: NIOSingleStepByteToMessageDecoder { typealias InboundOut = TestMessage - + enum TestMessage: Equatable { case foo } - + var lastByteBuffer: ByteBuffer? var decodeHits = 0 var reclaimHits = 0 - + mutating func decode(buffer: inout ByteBuffer) throws -> TestMessage? { XCTAssertEqual(self.decodeHits * 3, buffer.readerIndex) self.decodeHits += 1 @@ -462,26 +488,28 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { buffer.moveReaderIndex(forwardBy: 3) return .foo } - + mutating func decodeLast(buffer: inout ByteBuffer, seenEOF: Bool) throws -> TestMessage? { try self.decode(buffer: &buffer) } - + mutating func shouldReclaimBytes(buffer: ByteBuffer) -> Bool { self.reclaimHits += 1 return true } } - + let decoder = TestByteToMessageDecoder() let processor = NIOSingleStepByteToMessageProcessor(decoder, maximumBufferSize: nil) - + let buffer = ByteBuffer(repeating: 0, count: 3001) var callbackCount = 0 - XCTAssertNoThrow(try processor.process(buffer: buffer) { _ in - callbackCount += 1 - }) - + XCTAssertNoThrow( + try processor.process(buffer: buffer) { _ in + callbackCount += 1 + } + ) + XCTAssertEqual(callbackCount, 1000) XCTAssertEqual(processor.decoder.decodeHits, 1001) XCTAssertEqual(processor.decoder.reclaimHits, 1) @@ -490,7 +518,7 @@ public final class NIOSingleStepByteToMessageDecoderTest: XCTestCase { func testUnprocessedBytes() { let allocator = ByteBufferAllocator() - let processor = NIOSingleStepByteToMessageProcessor(LargeChunkDecoder()) // reads slices of 512 bytes + let processor = NIOSingleStepByteToMessageProcessor(LargeChunkDecoder()) // reads slices of 512 bytes let messageReceiver: MessageReceiver = MessageReceiver() // We're going to send in 128 bytes. This will be held. diff --git a/Tests/NIOCoreTests/TimeAmount+DurationTests.swift b/Tests/NIOCoreTests/TimeAmount+DurationTests.swift index 73c0103c6c..0ade53a39a 100644 --- a/Tests/NIOCoreTests/TimeAmount+DurationTests.swift +++ b/Tests/NIOCoreTests/TimeAmount+DurationTests.swift @@ -11,9 +11,11 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -@testable import NIOCore + import XCTest +@testable import NIOCore + class TimeAmountDurationTests: XCTestCase { func testTimeAmountFromDurationConversion() throws { guard #available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) else { diff --git a/Tests/NIOCoreTests/TimeAmountTests.swift b/Tests/NIOCoreTests/TimeAmountTests.swift index 7fb6a4efe8..0cf762a1dd 100644 --- a/Tests/NIOCoreTests/TimeAmountTests.swift +++ b/Tests/NIOCoreTests/TimeAmountTests.swift @@ -29,7 +29,7 @@ class TimeAmountTests: XCTestCase { let amounts: Set = [.seconds(1), .milliseconds(4), .seconds(1)] XCTAssertEqual(amounts, [.seconds(1), .milliseconds(4)]) } - + func testTimeAmountDoesAddTime() { var lhs = TimeAmount.nanoseconds(0) let rhs = TimeAmount.nanoseconds(5) @@ -43,7 +43,7 @@ class TimeAmountTests: XCTestCase { lhs -= rhs XCTAssertEqual(lhs, .nanoseconds(0)) } - + func testTimeAmountCappedOverflow() { let overflowCap = TimeAmount.nanoseconds(Int64.max) XCTAssertEqual(TimeAmount.microseconds(.max), overflowCap) @@ -52,7 +52,7 @@ class TimeAmountTests: XCTestCase { XCTAssertEqual(TimeAmount.minutes(.max), overflowCap) XCTAssertEqual(TimeAmount.hours(.max), overflowCap) } - + func testTimeAmountCappedUnderflow() { let underflowCap = TimeAmount.nanoseconds(.min) XCTAssertEqual(TimeAmount.microseconds(.min), underflowCap) diff --git a/Tests/NIOCoreTests/TypeAssistedChannelHandlerTests.swift b/Tests/NIOCoreTests/TypeAssistedChannelHandlerTests.swift index 49d55c68a9..274be63410 100644 --- a/Tests/NIOCoreTests/TypeAssistedChannelHandlerTests.swift +++ b/Tests/NIOCoreTests/TypeAssistedChannelHandlerTests.swift @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore +import XCTest class TypeAssistedChannelHandlerTest: XCTestCase { func testCanDefineBothInboundAndOutbound() throws { diff --git a/Tests/NIOCoreTests/UtilitiesTest.swift b/Tests/NIOCoreTests/UtilitiesTest.swift index 208d9aa644..b2a9814165 100644 --- a/Tests/NIOCoreTests/UtilitiesTest.swift +++ b/Tests/NIOCoreTests/UtilitiesTest.swift @@ -40,7 +40,10 @@ class UtilitiesTest: XCTestCase { XCTAssertNil(interface.pointToPointDestinationAddress) } else if try interface.address == SocketAddress(ipAddress: "::1", port: 0) { ipv6LoopbackPresent = true - XCTAssertEqual(interface.netmask, try SocketAddress(ipAddress: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", port: 0)) + XCTAssertEqual( + interface.netmask, + try SocketAddress(ipAddress: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", port: 0) + ) XCTAssertNil(interface.broadcastAddress) XCTAssertNil(interface.pointToPointDestinationAddress) } @@ -68,7 +71,10 @@ class UtilitiesTest: XCTestCase { XCTAssertNil(device.pointToPointDestinationAddress) } else if try device.address == SocketAddress(ipAddress: "::1", port: 0) { ipv6LoopbackPresent = true - XCTAssertEqual(device.netmask, try SocketAddress(ipAddress: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", port: 0)) + XCTAssertEqual( + device.netmask, + try SocketAddress(ipAddress: "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", port: 0) + ) XCTAssertNil(device.broadcastAddress) XCTAssertNil(device.pointToPointDestinationAddress) } diff --git a/Tests/NIOCoreTests/XCTest+AsyncAwait.swift b/Tests/NIOCoreTests/XCTest+AsyncAwait.swift index b602218894..9107a35518 100644 --- a/Tests/NIOCoreTests/XCTest+AsyncAwait.swift +++ b/Tests/NIOCoreTests/XCTest+AsyncAwait.swift @@ -24,23 +24,23 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -/* - * Copyright 2021, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// +// Copyright 2021, gRPC Authors All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. import XCTest + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) internal func XCTAssertThrowsError( _ expression: @autoclosure () async throws -> T, diff --git a/Tests/NIOCoreTests/XCTest+Extensions.swift b/Tests/NIOCoreTests/XCTest+Extensions.swift index 25a910f34b..1717542d9e 100644 --- a/Tests/NIOCoreTests/XCTest+Extensions.swift +++ b/Tests/NIOCoreTests/XCTest+Extensions.swift @@ -12,24 +12,37 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore +import XCTest -func assert(_ condition: @autoclosure () -> Bool, within time: TimeAmount, testInterval: TimeAmount? = nil, _ message: String = "condition not satisfied in time", file: StaticString = #filePath, line: UInt = #line) { +func assert( + _ condition: @autoclosure () -> Bool, + within time: TimeAmount, + testInterval: TimeAmount? = nil, + _ message: String = "condition not satisfied in time", + file: StaticString = #filePath, + line: UInt = #line +) { let testInterval = testInterval ?? TimeAmount.nanoseconds(time.nanoseconds / 5) let endTime = NIODeadline.now() + time repeat { if condition() { return } usleep(UInt32(testInterval.nanoseconds / 1000)) - } while (NIODeadline.now() < endTime) + } while NIODeadline.now() < endTime if !condition() { XCTFail(message, file: (file), line: line) } } -func assertNoThrowWithValue(_ body: @autoclosure () throws -> T, defaultValue: T? = nil, message: String? = nil, file: StaticString = #filePath, line: UInt = #line) throws -> T { +func assertNoThrowWithValue( + _ body: @autoclosure () throws -> T, + defaultValue: T? = nil, + message: String? = nil, + file: StaticString = #filePath, + line: UInt = #line +) throws -> T { do { return try body() } catch { @@ -57,22 +70,22 @@ func withTemporaryFile(content: String? = nil, _ body: (NIOCore.NIOFileHandle return try body(fileHandle, temporaryFilePath) } -fileprivate var temporaryDirectory: String { +private var temporaryDirectory: String { get { -#if targetEnvironment(simulator) + #if targetEnvironment(simulator) // Simulator temp directories are so long (and contain the user name) that they're not usable // for UNIX Domain Socket paths (which are limited to 103 bytes). return "/tmp" -#else -#if os(Linux) + #else + #if os(Linux) return "/tmp" -#else + #else if #available(macOS 10.12, iOS 10, tvOS 10, watchOS 3, *) { return FileManager.default.temporaryDirectory.path } else { return "/tmp" } -#endif // os -#endif // targetEnvironment + #endif // os + #endif // targetEnvironment } } diff --git a/Tests/NIODataStructuresTests/HeapTests.swift b/Tests/NIODataStructuresTests/HeapTests.swift index 64282edfd9..186ae4ad11 100644 --- a/Tests/NIODataStructuresTests/HeapTests.swift +++ b/Tests/NIODataStructuresTests/HeapTests.swift @@ -13,10 +13,11 @@ //===----------------------------------------------------------------------===// import XCTest + @testable import _NIODataStructures public func getRandomNumbers(count: Int) -> [UInt8] { - return (0..() - let input = [16, 14, 10, 9, 8, 7, 4, 3, 2, 1] - input.forEach { - minHeap.append($0) + let inputs = [16, 14, 10, 9, 8, 7, 4, 3, 2, 1] + for input in inputs { + minHeap.append(input) XCTAssertTrue(minHeap.checkHeapProperty()) } - var minHeapInputPtr = input.count - 1 + var minHeapInputPtr = inputs.count - 1 while let minE = minHeap.removeRoot() { - XCTAssertEqual(minE, input[minHeapInputPtr]) + XCTAssertEqual(minE, inputs[minHeapInputPtr]) minHeapInputPtr -= 1 XCTAssertTrue(minHeap.checkHeapProperty(), "\(minHeap.debugDescription)") } @@ -51,16 +52,16 @@ class HeapTests: XCTestCase { func testSortedAsc() throws { var minHeap = Heap() - let input = Array([16, 14, 10, 9, 8, 7, 4, 3, 2, 1].reversed()) - input.forEach { - minHeap.append($0) + let inputs = Array([16, 14, 10, 9, 8, 7, 4, 3, 2, 1].reversed()) + for input in inputs { + minHeap.append(input) } var minHeapInputPtr = 0 while let minE = minHeap.removeRoot() { - XCTAssertEqual(minE, input[minHeapInputPtr]) + XCTAssertEqual(minE, inputs[minHeapInputPtr]) minHeapInputPtr += 1 } - XCTAssertEqual(input.count, minHeapInputPtr) + XCTAssertEqual(inputs.count, minHeapInputPtr) } func testAddAndRemoveRandomNumbers() throws { @@ -76,7 +77,7 @@ class HeapTests: XCTestCase { XCTAssertEqual(Array(minHeap.sorted()), Array(minHeap)) } - for _ in 0..() let randoms = getRandomNumbers(count: size) - randoms.forEach { pq.push($0) } + for number in randoms { + pq.push(number) + } - /* remove one random member, add it back and assert we're still the same */ - randoms.forEach { random in + // remove one random member, add it back and assert we're still the same + for random in randoms { var pq2 = pq pq2.remove(random) XCTAssertEqual(pq.count - 1, pq2.count) @@ -69,7 +71,7 @@ class PriorityQueueTest: XCTestCase { XCTAssertEqual(pq, pq2) } - /* remove up to `n` members and add them back at the end and check that the priority queues are still the same */ + // remove up to `n` members and add them back at the end and check that the priority queues are still the same for n in 1...5 where n <= size { var pq2 = pq let deleted = randoms.prefix(n).map { (random: UInt8) -> UInt8 in @@ -78,7 +80,9 @@ class PriorityQueueTest: XCTestCase { } XCTAssertEqual(pq.count - n, pq2.count) XCTAssertNotEqual(pq, pq2) - deleted.reversed().forEach { pq2.push($0) } + for number in deleted.reversed() { + pq2.push(number) + } XCTAssertEqual(pq, pq2, "pq: \(pq), pq2: \(pq2), deleted: \(deleted)") } } @@ -89,22 +93,20 @@ class PriorityQueueTest: XCTestCase { let clearlyTheLargest = SomePartiallyOrderedDataType(width: 100, height: 100) let inTheMiddles = zip(1...99, (1...99).reversed()).map { SomePartiallyOrderedDataType(width: $0, height: $1) } - /* - the four values are only partially ordered (from small (top) to large (bottom)): + // the four values are only partially ordered (from small (top) to large (bottom)): - clearlyTheSmallest - / | \ - inTheMiddle[0] | inTheMiddle[1...] - \ | / - clearlyTheLargest - */ + // clearlyTheSmallest + // / | \ + // inTheMiddle[0] | inTheMiddle[1...] + // \ | / + // clearlyTheLargest var pq = PriorityQueue() pq.push(clearlyTheLargest) pq.push(inTheMiddles[0]) pq.push(clearlyTheSmallest) - inTheMiddles[1...].forEach { - pq.push($0) + for number in inTheMiddles[1...] { + pq.push(number) } let pop1 = pq.pop() XCTAssertEqual(clearlyTheSmallest, pop1) @@ -115,7 +117,7 @@ class PriorityQueueTest: XCTestCase { XCTAssertEqual(clearlyTheLargest, pq.pop()!) XCTAssert(pq.isEmpty) } - + func testDescription() { let pq1 = PriorityQueue() var pq2 = PriorityQueue() @@ -129,12 +131,12 @@ class PriorityQueueTest: XCTestCase { /// This data type is only partially ordered. Ie. from `a < b` and `a != b` we can't imply `a > b`. struct SomePartiallyOrderedDataType: Comparable, CustomStringConvertible { - public static func <(lhs: SomePartiallyOrderedDataType, rhs: SomePartiallyOrderedDataType) -> Bool { - return lhs.width < rhs.width && lhs.height < rhs.height + public static func < (lhs: SomePartiallyOrderedDataType, rhs: SomePartiallyOrderedDataType) -> Bool { + lhs.width < rhs.width && lhs.height < rhs.height } - public static func ==(lhs: SomePartiallyOrderedDataType, rhs: SomePartiallyOrderedDataType) -> Bool { - return lhs.width == rhs.width && lhs.height == rhs.height + public static func == (lhs: SomePartiallyOrderedDataType, rhs: SomePartiallyOrderedDataType) -> Bool { + lhs.width == rhs.width && lhs.height == rhs.height } private let width: Int @@ -145,6 +147,6 @@ struct SomePartiallyOrderedDataType: Comparable, CustomStringConvertible { } public var description: String { - return "(w: \(self.width), h: \(self.height))" + "(w: \(self.width), h: \(self.height))" } } diff --git a/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift b/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift index 5a40874674..8e0ada12ba 100644 --- a/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift +++ b/Tests/NIOEmbeddedTests/AsyncTestingChannelTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -import XCTest import Atomics import NIOCore +import XCTest + @testable import NIOEmbedded @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) @@ -84,11 +85,14 @@ class AsyncTestingChannelTests: XCTestCase { let task1 = Task { try await channel.waitForInboundWrite(as: Int.self) } let task2 = Task { try await channel.waitForInboundWrite(as: Int.self) } let task3 = Task { try await channel.waitForInboundWrite(as: Int.self) } - try await XCTAsyncAssertEqual(Set([ - try await task1.value, - try await task2.value, - try await task3.value, - ]), [1, 2, 3]) + try await XCTAsyncAssertEqual( + Set([ + try await task1.value, + try await task2.value, + try await task3.value, + ]), + [1, 2, 3] + ) } try await channel.writeInbound(1) @@ -117,11 +121,14 @@ class AsyncTestingChannelTests: XCTestCase { let task1 = Task { try await channel.waitForOutboundWrite(as: Int.self) } let task2 = Task { try await channel.waitForOutboundWrite(as: Int.self) } let task3 = Task { try await channel.waitForOutboundWrite(as: Int.self) } - try await XCTAsyncAssertEqual(Set([ - try await task1.value, - try await task2.value, - try await task3.value, - ]), [1, 2, 3]) + try await XCTAsyncAssertEqual( + Set([ + try await task1.value, + try await task2.value, + try await task3.value, + ]), + [1, 2, 3] + ) } try await channel.writeOutbound(1) @@ -254,22 +261,33 @@ class AsyncTestingChannelTests: XCTestCase { let buffer = channel.allocator.buffer(capacity: 0) try await XCTAsyncAssertTrue(await channel.writeOutbound(buffer).isFull) - try await XCTAsyncAssertTrue(await channel.writeOutbound( - AddressedEnvelope(remoteAddress: SocketAddress(ipAddress: "1.2.3.4", port: 5678), - data: buffer)).isFull) + try await XCTAsyncAssertTrue( + await channel.writeOutbound( + AddressedEnvelope( + remoteAddress: SocketAddress(ipAddress: "1.2.3.4", port: 5678), + data: buffer + ) + ).isFull + ) try await XCTAsyncAssertTrue(await channel.writeOutbound(buffer).isFull) - try await XCTAsyncAssertTrue(await channel.writeInbound(buffer).isFull) - try await XCTAsyncAssertTrue(await channel.writeInbound( - AddressedEnvelope(remoteAddress: SocketAddress(ipAddress: "1.2.3.4", port: 5678), - data: buffer)).isFull) + try await XCTAsyncAssertTrue( + await channel.writeInbound( + AddressedEnvelope( + remoteAddress: SocketAddress(ipAddress: "1.2.3.4", port: 5678), + data: buffer + ) + ).isFull + ) try await XCTAsyncAssertTrue(await channel.writeInbound(buffer).isFull) - func check(expected: Expected.Type, - actual: Actual.Type, - file: StaticString = #filePath, - line: UInt = #line) async { + func check( + expected: Expected.Type, + actual: Actual.Type, + file: StaticString = #filePath, + line: UInt = #line + ) async { do { _ = try await channel.readOutbound(as: Expected.self) XCTFail("this should have failed", file: (file), line: line) @@ -337,7 +355,7 @@ class AsyncTestingChannelTests: XCTestCase { XCTAssertFalse(channel.isActive) } - private final class ExceptionThrowingInboundHandler : ChannelInboundHandler { + private final class ExceptionThrowingInboundHandler: ChannelInboundHandler { typealias InboundIn = String public func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -345,7 +363,7 @@ class AsyncTestingChannelTests: XCTestCase { } } - private final class ExceptionThrowingOutboundHandler : ChannelOutboundHandler { + private final class ExceptionThrowingOutboundHandler: ChannelOutboundHandler { typealias OutboundIn = String typealias OutboundOut = Never @@ -474,20 +492,20 @@ class AsyncTestingChannelTests: XCTestCase { } func testFinishWithRecursivelyScheduledTasks() async throws { - let channel = NIOAsyncTestingChannel() - let invocations = AtomicCounter() - - @Sendable func recursivelyScheduleAndIncrement() { - channel.pipeline.eventLoop.scheduleTask(deadline: .distantFuture) { - invocations.increment() - recursivelyScheduleAndIncrement() - } + let channel = NIOAsyncTestingChannel() + let invocations = AtomicCounter() + + @Sendable func recursivelyScheduleAndIncrement() { + channel.pipeline.eventLoop.scheduleTask(deadline: .distantFuture) { + invocations.increment() + recursivelyScheduleAndIncrement() } + } - recursivelyScheduleAndIncrement() + recursivelyScheduleAndIncrement() - _ = try await channel.finish() - XCTAssertEqual(invocations.load(), 1) + _ = try await channel.finish() + XCTAssertEqual(invocations.load(), 1) } func testSyncOptionsAreSupported() throws { @@ -536,20 +554,34 @@ class AsyncTestingChannelTests: XCTestCase { } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -fileprivate func XCTAsyncAssertTrue(_ predicate: @autoclosure () async throws -> Bool, file: StaticString = #filePath, line: UInt = #line) async rethrows { +private func XCTAsyncAssertTrue( + _ predicate: @autoclosure () async throws -> Bool, + file: StaticString = #filePath, + line: UInt = #line +) async rethrows { let result = try await predicate() XCTAssertTrue(result, file: file, line: line) } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -fileprivate func XCTAsyncAssertEqual(_ lhs: @autoclosure () async throws -> Element, _ rhs: @autoclosure () async throws -> Element, file: StaticString = #filePath, line: UInt = #line) async rethrows { +private func XCTAsyncAssertEqual( + _ lhs: @autoclosure () async throws -> Element, + _ rhs: @autoclosure () async throws -> Element, + file: StaticString = #filePath, + line: UInt = #line +) async rethrows { let lhsResult = try await lhs() let rhsResult = try await rhs() XCTAssertEqual(lhsResult, rhsResult, file: file, line: line) } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -fileprivate func XCTAsyncAssertThrowsError(_ expression: @autoclosure () async throws -> ResultType, file: StaticString = #filePath, line: UInt = #line, _ callback: Optional<(Error) -> Void> = nil) async { +private func XCTAsyncAssertThrowsError( + _ expression: @autoclosure () async throws -> ResultType, + file: StaticString = #filePath, + line: UInt = #line, + _ callback: ((Error) -> Void)? = nil +) async { do { let _ = try await expression() XCTFail("Did not throw", file: file, line: line) @@ -559,13 +591,21 @@ fileprivate func XCTAsyncAssertThrowsError(_ expression: @autoclosur } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -fileprivate func XCTAsyncAssertNil(_ expression: @autoclosure () async throws -> Any?, file: StaticString = #filePath, line: UInt = #line) async rethrows { +private func XCTAsyncAssertNil( + _ expression: @autoclosure () async throws -> Any?, + file: StaticString = #filePath, + line: UInt = #line +) async rethrows { let result = try await expression() XCTAssertNil(result, file: file, line: line) } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -fileprivate func XCTAsyncAssertNotNil(_ expression: @autoclosure () async throws -> Any?, file: StaticString = #filePath, line: UInt = #line) async rethrows { +private func XCTAsyncAssertNotNil( + _ expression: @autoclosure () async throws -> Any?, + file: StaticString = #filePath, + line: UInt = #line +) async rethrows { let result = try await expression() XCTAssertNotNil(result, file: file, line: line) } diff --git a/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift b/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift index 0e67bd99d2..18dde6fbf8 100644 --- a/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift +++ b/Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift @@ -11,13 +11,15 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + +import Atomics +import NIOConcurrencyHelpers import NIOCore -@testable import NIOEmbedded import XCTest -import NIOConcurrencyHelpers -import Atomics -private class EmbeddedTestError: Error { } +@testable import NIOEmbedded + +private class EmbeddedTestError: Error {} @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) final class NIOAsyncTestingEventLoopTests: XCTestCase { @@ -453,13 +455,13 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase { let tasksRun = ManagedAtomic(0) let startTime = eventLoop.now - eventLoop.scheduleTask(in: .nanoseconds(3141592)) { - XCTAssertEqual(eventLoop.now, startTime + .nanoseconds(3141592)) + eventLoop.scheduleTask(in: .nanoseconds(3_141_592)) { + XCTAssertEqual(eventLoop.now, startTime + .nanoseconds(3_141_592)) tasksRun.wrappingIncrement(ordering: .relaxed) } - eventLoop.scheduleTask(in: .seconds(3141592)) { - XCTAssertEqual(eventLoop.now, startTime + .seconds(3141592)) + eventLoop.scheduleTask(in: .seconds(3_141_592)) { + XCTAssertEqual(eventLoop.now, startTime + .seconds(3_141_592)) tasksRun.wrappingIncrement(ordering: .relaxed) } diff --git a/Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift b/Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift index f6d64d00ee..0d82d20f3b 100644 --- a/Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift +++ b/Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift @@ -12,8 +12,9 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore +import XCTest + @testable import NIOEmbedded class ChannelLifecycleHandler: ChannelInboundHandler { @@ -241,29 +242,40 @@ class EmbeddedChannelTest: XCTestCase { XCTAssertTrue(try channel.writeOutbound(ioData).isFull) XCTAssertTrue(try channel.writeOutbound(fileHandle).isFull) XCTAssertTrue(try channel.writeOutbound(fileRegion).isFull) - XCTAssertTrue(try channel.writeOutbound( - AddressedEnvelope(remoteAddress: SocketAddress(ipAddress: "1.2.3.4", port: 5678), - data: buffer)).isFull) + XCTAssertTrue( + try channel.writeOutbound( + AddressedEnvelope( + remoteAddress: SocketAddress(ipAddress: "1.2.3.4", port: 5678), + data: buffer + ) + ).isFull + ) XCTAssertTrue(try channel.writeOutbound(buffer).isFull) XCTAssertTrue(try channel.writeOutbound(ioData).isFull) XCTAssertTrue(try channel.writeOutbound(fileRegion).isFull) - XCTAssertTrue(try channel.writeInbound(buffer).isFull) XCTAssertTrue(try channel.writeInbound(ioData).isFull) XCTAssertTrue(try channel.writeInbound(fileHandle).isFull) XCTAssertTrue(try channel.writeInbound(fileRegion).isFull) - XCTAssertTrue(try channel.writeInbound( - AddressedEnvelope(remoteAddress: SocketAddress(ipAddress: "1.2.3.4", port: 5678), - data: buffer)).isFull) + XCTAssertTrue( + try channel.writeInbound( + AddressedEnvelope( + remoteAddress: SocketAddress(ipAddress: "1.2.3.4", port: 5678), + data: buffer + ) + ).isFull + ) XCTAssertTrue(try channel.writeInbound(buffer).isFull) XCTAssertTrue(try channel.writeInbound(ioData).isFull) XCTAssertTrue(try channel.writeInbound(fileRegion).isFull) - func check(expected: Expected.Type, - actual: Actual.Type, - file: StaticString = #filePath, - line: UInt = #line) { + func check( + expected: Expected.Type, + actual: Actual.Type, + file: StaticString = #filePath, + line: UInt = #line + ) { do { _ = try channel.readOutbound(as: Expected.self) XCTFail("this should have failed", file: (file), line: line) @@ -336,7 +348,7 @@ class EmbeddedChannelTest: XCTestCase { XCTAssertFalse(channel.isActive) } - private final class ExceptionThrowingInboundHandler : ChannelInboundHandler { + private final class ExceptionThrowingInboundHandler: ChannelInboundHandler { typealias InboundIn = String public func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -344,7 +356,7 @@ class EmbeddedChannelTest: XCTestCase { } } - private final class ExceptionThrowingOutboundHandler : ChannelOutboundHandler { + private final class ExceptionThrowingOutboundHandler: ChannelOutboundHandler { typealias OutboundIn = String typealias OutboundOut = Never diff --git a/Tests/NIOEmbeddedTests/EmbeddedEventLoopTest.swift b/Tests/NIOEmbeddedTests/EmbeddedEventLoopTest.swift index 0f4ff6399e..f428c5534a 100644 --- a/Tests/NIOEmbeddedTests/EmbeddedEventLoopTest.swift +++ b/Tests/NIOEmbeddedTests/EmbeddedEventLoopTest.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// +import NIOConcurrencyHelpers import NIOCore -@testable import NIOEmbedded import XCTest -import NIOConcurrencyHelpers -private class EmbeddedTestError: Error { } +@testable import NIOEmbedded + +private class EmbeddedTestError: Error {} public final class EmbeddedEventLoopTest: XCTestCase { func testExecuteDoesNotImmediatelyRunTasks() throws { @@ -358,13 +359,13 @@ public final class EmbeddedEventLoopTest: XCTestCase { let timeAtStart = eventLoop._now var tasksRun = 0 - eventLoop.scheduleTask(in: .nanoseconds(3141592)) { - XCTAssertEqual(eventLoop._now, timeAtStart + .nanoseconds(3141592)) + eventLoop.scheduleTask(in: .nanoseconds(3_141_592)) { + XCTAssertEqual(eventLoop._now, timeAtStart + .nanoseconds(3_141_592)) tasksRun += 1 } - eventLoop.scheduleTask(in: .seconds(3141592)) { - XCTAssertEqual(eventLoop._now, timeAtStart + .seconds(3141592)) + eventLoop.scheduleTask(in: .seconds(3_141_592)) { + XCTAssertEqual(eventLoop._now, timeAtStart + .seconds(3_141_592)) tasksRun += 1 } diff --git a/Tests/NIOEmbeddedTests/TestUtils.swift b/Tests/NIOEmbeddedTests/TestUtils.swift index ea0bf05bee..02fe152e7d 100644 --- a/Tests/NIOEmbeddedTests/TestUtils.swift +++ b/Tests/NIOEmbeddedTests/TestUtils.swift @@ -13,19 +13,26 @@ //===----------------------------------------------------------------------===// import Atomics import Foundation -import XCTest -import NIOCore import NIOConcurrencyHelpers +import NIOCore +import XCTest // FIXME: Duplicated with NIO -func assert(_ condition: @autoclosure () -> Bool, within time: TimeAmount, testInterval: TimeAmount? = nil, _ message: String = "condition not satisfied in time", file: StaticString = #filePath, line: UInt = #line) { +func assert( + _ condition: @autoclosure () -> Bool, + within time: TimeAmount, + testInterval: TimeAmount? = nil, + _ message: String = "condition not satisfied in time", + file: StaticString = #filePath, + line: UInt = #line +) { let testInterval = testInterval ?? TimeAmount.nanoseconds(time.nanoseconds / 5) let endTime = NIODeadline.now() + time repeat { if condition() { return } usleep(UInt32(testInterval.nanoseconds / 1000)) - } while (NIODeadline.now() < endTime) + } while NIODeadline.now() < endTime if !condition() { XCTFail(message, file: (file), line: line) @@ -45,17 +52,17 @@ extension EventLoopFuture { } else { let lock = NIOLock() let group = DispatchGroup() - var fulfilled = false // protected by lock + var fulfilled = false // protected by lock group.enter() self.eventLoop.execute { - let isFulfilled = self.isFulfilled // This will now enter the above branch. + let isFulfilled = self.isFulfilled // This will now enter the above branch. lock.withLock { fulfilled = isFulfilled } group.leave() } - group.wait() // this is very nasty but this is for tests only, so... + group.wait() // this is very nasty but this is for tests only, so... return lock.withLock { fulfilled } } } diff --git a/Tests/NIOEmbeddedTests/XCTest+AsyncAwait.swift b/Tests/NIOEmbeddedTests/XCTest+AsyncAwait.swift index 59665ad561..e91c4a5a20 100644 --- a/Tests/NIOEmbeddedTests/XCTest+AsyncAwait.swift +++ b/Tests/NIOEmbeddedTests/XCTest+AsyncAwait.swift @@ -24,21 +24,20 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -/* - * Copyright 2021, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// +// Copyright 2021, gRPC Authors All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. import XCTest diff --git a/Tests/NIOFileSystemIntegrationTests/BufferedWriterTests.swift b/Tests/NIOFileSystemIntegrationTests/BufferedWriterTests.swift index d2b875e3e9..47b6a5e98e 100644 --- a/Tests/NIOFileSystemIntegrationTests/BufferedWriterTests.swift +++ b/Tests/NIOFileSystemIntegrationTests/BufferedWriterTests.swift @@ -218,16 +218,16 @@ final class BufferedWriterTests: XCTestCase { ) } XCTAssertEqual(writtenBytes, 128) - + guard let fileInfo = try await fs.info(forFileAt: path) else { XCTFail() return } - + // Test that the newly created file contains all the 128 characters. XCTAssertEqual(fileInfo.size, 128) } - + func testBufferedWriterReclaimsStorageAfterLargeWrite() async throws { let fs = FileSystem.shared let path = try await fs.temporaryFilePath() diff --git a/Tests/NIOFileSystemIntegrationTests/FileHandleTests.swift b/Tests/NIOFileSystemIntegrationTests/FileHandleTests.swift index cbfd56887a..aee12c5301 100644 --- a/Tests/NIOFileSystemIntegrationTests/FileHandleTests.swift +++ b/Tests/NIOFileSystemIntegrationTests/FileHandleTests.swift @@ -28,7 +28,7 @@ final class FileHandleTests: XCTestCase { .lexicallyNormalized() private static func temporaryFileName() -> FilePath { - return FilePath("swift-filesystem-tests-\(UInt64.random(in: .min ... .max))") + FilePath("swift-filesystem-tests-\(UInt64.random(in: .min ... .max))") } func withTemporaryFile( @@ -223,7 +223,7 @@ final class FileHandleTests: XCTestCase { try await self.withTemporaryFile { handle in // Check we can successfully return a value. let value = try await handle.withUnsafeDescriptor { descriptor in - return 42 + 42 } XCTAssertEqual(value, 42) } @@ -306,7 +306,7 @@ final class FileHandleTests: XCTestCase { func testWriteAndReadUnseekableFile() async throws { let privateTempDirPath = try await FileSystem.shared.createTemporaryDirectory(template: "test-XXX") self.addTeardownBlock { - try await FileSystem.shared.removeItem(at: privateTempDirPath, recursively: true) + try await FileSystem.shared.removeItem(at: privateTempDirPath, recursively: true) } guard mkfifo(privateTempDirPath.appending("fifo").string, 0o644) == 0 else { @@ -327,7 +327,7 @@ final class FileHandleTests: XCTestCase { func testWriteAndReadUnseekableFileOverMaximumSizeAllowedThrowsError() async throws { let privateTempDirPath = try await FileSystem.shared.createTemporaryDirectory(template: "test-XXX") self.addTeardownBlock { - try await FileSystem.shared.removeItem(at: privateTempDirPath, recursively: true) + try await FileSystem.shared.removeItem(at: privateTempDirPath, recursively: true) } guard mkfifo(privateTempDirPath.appending("fifo").string, 0o644) == 0 else { @@ -351,7 +351,7 @@ final class FileHandleTests: XCTestCase { func testWriteAndReadUnseekableFileWithOffsetsThrows() async throws { let privateTempDirPath = try await FileSystem.shared.createTemporaryDirectory(template: "test-XXX") self.addTeardownBlock { - try await FileSystem.shared.removeItem(at: privateTempDirPath, recursively: true) + try await FileSystem.shared.removeItem(at: privateTempDirPath, recursively: true) } guard mkfifo(privateTempDirPath.appending("fifo").string, 0o644) == 0 else { diff --git a/Tests/NIOFileSystemIntegrationTests/FileSystemTests.swift b/Tests/NIOFileSystemIntegrationTests/FileSystemTests.swift index 4a2be1ac4f..673f119f78 100644 --- a/Tests/NIOFileSystemIntegrationTests/FileSystemTests.swift +++ b/Tests/NIOFileSystemIntegrationTests/FileSystemTests.swift @@ -61,7 +61,7 @@ extension FileSystem { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class FileSystemTests: XCTestCase { - var fs: FileSystem { return .shared } + var fs: FileSystem { .shared } func testOpenFileForReading() async throws { try await self.fs.withFileHandle(forReadingAt: .testDataReadme) { file in @@ -708,7 +708,7 @@ final class FileSystemTests: XCTestCase { throw error } shouldCopyFile: { source, destination in // Copy the directory and 'file-1-regular' - return (source == path) || (source.lastComponent!.string == "file-0-regular") + (source == path) || (source.lastComponent!.string == "file-0-regular") } let paths = try await self.fs.withDirectoryHandle(atPath: copyPath) { dir in diff --git a/Tests/NIOFileSystemTests/Internal/CancellationTests.swift b/Tests/NIOFileSystemTests/Internal/CancellationTests.swift index 460d1c19cf..cd18efac14 100644 --- a/Tests/NIOFileSystemTests/Internal/CancellationTests.swift +++ b/Tests/NIOFileSystemTests/Internal/CancellationTests.swift @@ -48,7 +48,7 @@ final class CancellationTests: XCTestCase { let ranTearDown = ManagedAtomic(false) let isCancelled = try await withUncancellableTearDown { - return Task.isCancelled + Task.isCancelled } tearDown: { _ in ranTearDown.store(true, ordering: .releasing) } diff --git a/Tests/NIOFileSystemTests/Internal/Concurrency Primitives/BufferedStreamTests.swift b/Tests/NIOFileSystemTests/Internal/Concurrency Primitives/BufferedStreamTests.swift index 72fe020053..4791e66f0e 100644 --- a/Tests/NIOFileSystemTests/Internal/Concurrency Primitives/BufferedStreamTests.swift +++ b/Tests/NIOFileSystemTests/Internal/Concurrency Primitives/BufferedStreamTests.swift @@ -646,7 +646,7 @@ final class BufferedStreamTests: XCTestCase { try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await stream.first { _ in true } + try await stream.first { _ in true } } // This is always going to be a bit racy since we need the call to next() suspend @@ -666,7 +666,7 @@ final class BufferedStreamTests: XCTestCase { try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await stream.first { _ in true } + try await stream.first { _ in true } } // This is always going to be a bit racy since we need the call to next() suspend @@ -916,7 +916,7 @@ final class BufferedStreamTests: XCTestCase { try await withThrowingTaskGroup(of: Int?.self) { group in group.addTask { - return try await stream.first { $0 == 2 } + try await stream.first { $0 == 2 } } // This is always going to be a bit racy since we need the call to next() suspend diff --git a/Tests/NIOFileSystemTests/Internal/MockingInfrastructure.swift b/Tests/NIOFileSystemTests/Internal/MockingInfrastructure.swift index d8e4f31521..b61b3444f1 100644 --- a/Tests/NIOFileSystemTests/Internal/MockingInfrastructure.swift +++ b/Tests/NIOFileSystemTests/Internal/MockingInfrastructure.swift @@ -12,14 +12,12 @@ // //===----------------------------------------------------------------------===// -/* - This source file is part of the Swift System open source project - - Copyright (c) 2020 Apple Inc. and the Swift System project authors - Licensed under Apache License v2.0 with Runtime Library Exception - - See https://swift.org/LICENSE.txt for license information - */ +//This source file is part of the Swift System open source project +// +//Copyright (c) 2020 Apple Inc. and the Swift System project authors +//Licensed under Apache License v2.0 with Runtime Library Exception +// +//See https://swift.org/LICENSE.txt for license information #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) || os(Linux) || os(Android) @_spi(Testing) import _NIOFileSystem @@ -158,7 +156,7 @@ internal struct MockTestCase: TestCase { var expected: Trace.Entry var interruptBehavior: InterruptBehavior - var interruptable: Bool { return interruptBehavior == .interruptable } + var interruptable: Bool { interruptBehavior == .interruptable } internal enum InterruptBehavior { // Retry the syscall on EINTR diff --git a/Tests/NIOFileSystemTests/Internal/SyscallTests.swift b/Tests/NIOFileSystemTests/Internal/SyscallTests.swift index abd7209c95..e971c5f943 100644 --- a/Tests/NIOFileSystemTests/Internal/SyscallTests.swift +++ b/Tests/NIOFileSystemTests/Internal/SyscallTests.swift @@ -166,7 +166,7 @@ final class SyscallTests: XCTestCase { let testCases = [ MockTestCase(name: "link", .noInterrupt, "src", "dst") { _ in try Syscall.link(from: "src", to: "dst").get() - }, + } ] testCases.run() } @@ -175,7 +175,7 @@ final class SyscallTests: XCTestCase { let testCases = [ MockTestCase(name: "unlink", .noInterrupt, "path") { _ in try Syscall.unlink(path: "path").get() - }, + } ] testCases.run() } diff --git a/Tests/NIOFoundationCompatTests/ByteBuffer+UUIDTests.swift b/Tests/NIOFoundationCompatTests/ByteBuffer+UUIDTests.swift index 9505725a30..783c8d4414 100644 --- a/Tests/NIOFoundationCompatTests/ByteBuffer+UUIDTests.swift +++ b/Tests/NIOFoundationCompatTests/ByteBuffer+UUIDTests.swift @@ -19,8 +19,12 @@ import XCTest final class ByteBufferUUIDTests: XCTestCase { func testSetUUIDBytes() { - let uuid = UUID(uuid: (0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, - 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf)) + let uuid = UUID( + uuid: ( + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf + ) + ) var buffer = ByteBuffer() XCTAssertEqual(buffer.storageCapacity, 0) @@ -38,8 +42,12 @@ final class ByteBufferUUIDTests: XCTestCase { var buffer = ByteBuffer() buffer.writeRepeatingByte(.max, count: 32) - let uuid = UUID(uuid: (0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, - 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf)) + let uuid = UUID( + uuid: ( + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf + ) + ) buffer.setUUIDBytes(uuid, at: buffer.readerIndex + 4) XCTAssertEqual(buffer.readBytes(length: 4), Array(repeating: .max, count: 4)) @@ -62,20 +70,33 @@ final class ByteBufferUUIDTests: XCTestCase { } func testWriteUUIDBytesIntoEmptyBuffer() { - let uuid = UUID(uuid: (0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, - 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf)) + let uuid = UUID( + uuid: ( + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf + ) + ) var buffer = ByteBuffer() XCTAssertEqual(buffer.writeUUIDBytes(uuid), 16) - XCTAssertEqual(buffer.readableBytesView, [0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, - 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf]) + XCTAssertEqual( + buffer.readableBytesView, + [ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, + ] + ) XCTAssertEqual(buffer.readableBytes, 16) XCTAssertEqual(buffer.writerIndex, 16) } func testWriteUUIDBytesIntoNonEmptyBuffer() { - let uuid = UUID(uuid: (0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, - 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf)) + let uuid = UUID( + uuid: ( + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf + ) + ) var buffer = ByteBuffer() buffer.writeRepeatingByte(42, count: 10) @@ -83,13 +104,19 @@ final class ByteBufferUUIDTests: XCTestCase { XCTAssertEqual(buffer.readableBytes, 26) XCTAssertEqual(buffer.writerIndex, 26) - XCTAssertEqual(buffer.readableBytesView.dropFirst(10), - [0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf]) + XCTAssertEqual( + buffer.readableBytesView.dropFirst(10), + [0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf] + ) } func testReadUUID() { - let uuid = UUID(uuid: (0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, - 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf)) + let uuid = UUID( + uuid: ( + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf + ) + ) var buffer = ByteBuffer() XCTAssertEqual(buffer.writeUUIDBytes(uuid), 16) XCTAssertEqual(buffer.readUUIDBytes(), uuid) @@ -106,8 +133,10 @@ final class ByteBufferUUIDTests: XCTestCase { XCTAssertEqual(buffer.readerIndex, 0) buffer.writeRepeatingByte(0, count: 8) - XCTAssertEqual(buffer.readUUIDBytes(), - UUID(uuid: (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0))) + XCTAssertEqual( + buffer.readUUIDBytes(), + UUID(uuid: (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) + ) XCTAssertEqual(buffer.readerIndex, 16) } } diff --git a/Tests/NIOFoundationCompatTests/ByteBufferDataProtocolTests.swift b/Tests/NIOFoundationCompatTests/ByteBufferDataProtocolTests.swift index 00de9ff70e..6e469a1278 100644 --- a/Tests/NIOFoundationCompatTests/ByteBufferDataProtocolTests.swift +++ b/Tests/NIOFoundationCompatTests/ByteBufferDataProtocolTests.swift @@ -13,9 +13,9 @@ //===----------------------------------------------------------------------===// import Foundation -import XCTest import NIOCore import NIOFoundationCompat +import XCTest struct FakeContiguousBytes: ContiguousBytes { func withUnsafeBytes(_ block: (UnsafeRawBufferPointer) throws -> T) rethrows -> T { @@ -68,8 +68,10 @@ class ByteBufferDataProtocolTests: XCTestCase { buffer.writeInteger(UInt64.max) buffer.writeInteger(UInt64.max) buffer.setData(dd, at: 4) - XCTAssertEqual(buffer.readBytes(length: buffer.readableBytes), - [0xFF, 0xFF, 0xFF, 0xFF, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0xFF, 0xFF, 0xFF, 0xFF]) + XCTAssertEqual( + buffer.readBytes(length: buffer.readableBytes), + [0xFF, 0xFF, 0xFF, 0xFF, 0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0xFF, 0xFF, 0xFF, 0xFF] + ) } func testWriteContiguousBytes() { @@ -87,7 +89,9 @@ class ByteBufferDataProtocolTests: XCTestCase { b.writeInteger(UInt64.min) b.setContiguousBytes(fake, at: 4) - XCTAssertEqual(b.readBytes(length: b.readableBytes), - [0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00]) + XCTAssertEqual( + b.readBytes(length: b.readableBytes), + [0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00] + ) } } diff --git a/Tests/NIOFoundationCompatTests/ByteBufferView+MutableDataProtocolTest.swift b/Tests/NIOFoundationCompatTests/ByteBufferView+MutableDataProtocolTest.swift index 18850aeb3b..a08a8443d2 100644 --- a/Tests/NIOFoundationCompatTests/ByteBufferView+MutableDataProtocolTest.swift +++ b/Tests/NIOFoundationCompatTests/ByteBufferView+MutableDataProtocolTest.swift @@ -13,9 +13,9 @@ //===----------------------------------------------------------------------===// import Foundation -import XCTest import NIOCore import NIOFoundationCompat +import XCTest class ByteBufferViewDataProtocolTests: XCTestCase { @@ -32,7 +32,7 @@ class ByteBufferViewDataProtocolTests: XCTestCase { view.resetBytes(in: 2...4) XCTAssertTrue(view.elementsEqual([0, 0, 0, 0, 0])) } - + func testCreateDataFromBuffer() { let testString = "some sample bytes" let buffer = ByteBuffer(ByteBufferView(testString.utf8)) diff --git a/Tests/NIOFoundationCompatTests/Codable+ByteBufferTest.swift b/Tests/NIOFoundationCompatTests/Codable+ByteBufferTest.swift index 32e64ae4e3..f40b54ac4a 100644 --- a/Tests/NIOFoundationCompatTests/Codable+ByteBufferTest.swift +++ b/Tests/NIOFoundationCompatTests/Codable+ByteBufferTest.swift @@ -13,9 +13,9 @@ //===----------------------------------------------------------------------===// import Foundation -import XCTest import NIOCore import NIOFoundationCompat +import XCTest class CodableByteBufferTest: XCTestCase { var buffer: ByteBuffer! @@ -68,10 +68,16 @@ class CodableByteBufferTest: XCTestCase { self.buffer.writeString("GARBAGE {}!!? / GARBAGE") let expectedSandI = StringAndInt(string: "hello", int: 42) - XCTAssertNoThrow(XCTAssertEqual(expectedSandI, - try self.buffer.getJSONDecodable(StringAndInt.self, - at: beginIndex, - length: endIndex - beginIndex))) + XCTAssertNoThrow( + XCTAssertEqual( + expectedSandI, + try self.buffer.getJSONDecodable( + StringAndInt.self, + at: beginIndex, + length: endIndex - beginIndex + ) + ) + ) } func testGetJSONDecodableFromBufferFailsBecauseShort() { @@ -80,9 +86,13 @@ class CodableByteBufferTest: XCTestCase { self.buffer.writeString(#"{"string": "hello", "int": 42}"#) let endIndex = self.buffer.writerIndex - XCTAssertThrowsError(try self.buffer.getJSONDecodable(StringAndInt.self, - at: beginIndex, - length: endIndex - beginIndex - 1)) { error in + XCTAssertThrowsError( + try self.buffer.getJSONDecodable( + StringAndInt.self, + at: beginIndex, + length: endIndex - beginIndex - 1 + ) + ) { error in XCTAssert(error is DecodingError) } } @@ -94,9 +104,15 @@ class CodableByteBufferTest: XCTestCase { self.buffer.writeString("GARBAGE {}!!? / GARBAGE") let expectedSandI = StringAndInt(string: "hello", int: 42) - XCTAssertNoThrow(XCTAssertEqual(expectedSandI, - try self.buffer.readJSONDecodable(StringAndInt.self, - length: endIndex - beginIndex))) + XCTAssertNoThrow( + XCTAssertEqual( + expectedSandI, + try self.buffer.readJSONDecodable( + StringAndInt.self, + length: endIndex - beginIndex + ) + ) + ) } func testReadJSONDecodableFromBufferFailsBecauseShort() { @@ -104,8 +120,12 @@ class CodableByteBufferTest: XCTestCase { self.buffer.writeString(#"{"string": "hello", "int": 42}"#) let endIndex = self.buffer.writerIndex - XCTAssertThrowsError(try self.buffer.readJSONDecodable(StringAndInt.self, - length: endIndex - beginIndex - 1)) { error in + XCTAssertThrowsError( + try self.buffer.readJSONDecodable( + StringAndInt.self, + length: endIndex - beginIndex - 1 + ) + ) { error in XCTAssert(error is DecodingError) } } @@ -129,19 +149,39 @@ class CodableByteBufferTest: XCTestCase { let expectedSandI = StringAndInt(string: "hello", int: 42) self.buffer.writeString(String(repeating: "{", count: 1000)) var writtenBytes: Int? - XCTAssertNoThrow(writtenBytes = try self.buffer.setJSONEncodable(expectedSandI, - at: self.buffer.readerIndex + 123)) - XCTAssertNoThrow(try self.buffer.setJSONEncodable(expectedSandI, - encoder: JSONEncoder(), - at: self.buffer.readerIndex + 501)) - XCTAssertNoThrow(XCTAssertEqual(expectedSandI, - try self.buffer.getJSONDecodable(StringAndInt.self, - at: self.buffer.readerIndex + 123, - length: writtenBytes ?? -1))) - XCTAssertNoThrow(XCTAssertEqual(expectedSandI, - try self.buffer.getJSONDecodable(StringAndInt.self, - at: self.buffer.readerIndex + 501, - length: writtenBytes ?? -1))) + XCTAssertNoThrow( + writtenBytes = try self.buffer.setJSONEncodable( + expectedSandI, + at: self.buffer.readerIndex + 123 + ) + ) + XCTAssertNoThrow( + try self.buffer.setJSONEncodable( + expectedSandI, + encoder: JSONEncoder(), + at: self.buffer.readerIndex + 501 + ) + ) + XCTAssertNoThrow( + XCTAssertEqual( + expectedSandI, + try self.buffer.getJSONDecodable( + StringAndInt.self, + at: self.buffer.readerIndex + 123, + length: writtenBytes ?? -1 + ) + ) + ) + XCTAssertNoThrow( + XCTAssertEqual( + expectedSandI, + try self.buffer.getJSONDecodable( + StringAndInt.self, + at: self.buffer.readerIndex + 501, + length: writtenBytes ?? -1 + ) + ) + ) } func testFailingReadsDoNotChangeReaderIndex() { @@ -149,14 +189,18 @@ class CodableByteBufferTest: XCTestCase { var writtenBytes: Int? XCTAssertNoThrow(writtenBytes = try self.buffer.writeJSONEncodable(expectedSandI)) for length in 0..<(writtenBytes ?? 0) { - XCTAssertThrowsError(try self.buffer.readJSONDecodable(StringAndInt.self, - length: length)) { error in + XCTAssertThrowsError( + try self.buffer.readJSONDecodable( + StringAndInt.self, + length: length + ) + ) { error in XCTAssert(error is DecodingError) } } XCTAssertNoThrow(try self.buffer.readJSONDecodable(StringAndInt.self, length: writtenBytes ?? -1)) } - + func testCustomEncoderIsRespected() { let expectedDate = Date(timeIntervalSinceReferenceDate: 86400) let strategyExpectation = XCTestExpectation(description: "Custom encoding strategy invoked") @@ -185,7 +229,9 @@ class CodableByteBufferTest: XCTestCase { return Date(timeIntervalSinceReferenceDate: try container.decode(Double.self)) }) XCTAssertNoThrow(try encoder.encode(["date": expectedDate], into: &self.buffer)) - XCTAssertNoThrow(XCTAssertEqual(["date": expectedDate], try decoder.decode(Dictionary.self, from: self.buffer))) + XCTAssertNoThrow( + XCTAssertEqual(["date": expectedDate], try decoder.decode(Dictionary.self, from: self.buffer)) + ) XCTAssertEqual(XCTWaiter().wait(for: [strategyExpectation], timeout: 0.0), .completed) } @@ -207,10 +253,16 @@ class CodableByteBufferTest: XCTestCase { return Date(timeIntervalSinceReferenceDate: try container.decode(Double.self)) }) XCTAssertNoThrow(try self.buffer.writeJSONEncodable(["date": expectedDate], encoder: encoder)) - XCTAssertNoThrow(XCTAssertEqual(["date": expectedDate], - try self.buffer.readJSONDecodable(Dictionary.self, - decoder: decoder, - length: self.buffer.readableBytes))) + XCTAssertNoThrow( + XCTAssertEqual( + ["date": expectedDate], + try self.buffer.readJSONDecodable( + Dictionary.self, + decoder: decoder, + length: self.buffer.readableBytes + ) + ) + ) XCTAssertEqual(XCTWaiter().wait(for: [decoderStrategyExpectation], timeout: 0.0), .completed) XCTAssertEqual(XCTWaiter().wait(for: [encoderStrategyExpectation], timeout: 0.0), .completed) } diff --git a/Tests/NIOFoundationCompatTests/JSONSerialization+ByteBufferTest.swift b/Tests/NIOFoundationCompatTests/JSONSerialization+ByteBufferTest.swift index 8e3de7c239..97ff97b682 100644 --- a/Tests/NIOFoundationCompatTests/JSONSerialization+ByteBufferTest.swift +++ b/Tests/NIOFoundationCompatTests/JSONSerialization+ByteBufferTest.swift @@ -13,44 +13,58 @@ //===----------------------------------------------------------------------===// import Foundation -import XCTest import NIOCore import NIOFoundationCompat +import XCTest class JSONSerializationByteBufferTest: XCTestCase { - + func testSerializationRoundTrip() { - + let array = ["String1", "String2", "String3"] let dictionary = ["key1": "val1", "key2": "val2", "key3": "val3"] - + var dataArray = Data() var dataDictionary = Data() - + XCTAssertTrue(JSONSerialization.isValidJSONObject(array), "Array object cannot be converted to JSON") XCTAssertTrue(JSONSerialization.isValidJSONObject(dictionary), "Dictionary object cannot be converted to JSON") - + XCTAssertNoThrow(dataArray = try JSONSerialization.data(withJSONObject: array, options: .prettyPrinted)) - XCTAssertNoThrow(dataDictionary = try JSONSerialization.data(withJSONObject: dictionary, options: .prettyPrinted)) - + XCTAssertNoThrow( + dataDictionary = try JSONSerialization.data(withJSONObject: dictionary, options: .prettyPrinted) + ) + let arrayByteBuffer = ByteBuffer(data: dataArray) let dictByteBuffer = ByteBuffer(data: dataDictionary) - + var foundationArray: [String] = [] var foundationDict: [String: String] = [:] - + // Mutable containers comparison. - XCTAssertNoThrow(foundationArray = try JSONSerialization.jsonObject(with: arrayByteBuffer, options: .mutableContainers) as! [String]) + XCTAssertNoThrow( + foundationArray = + try JSONSerialization.jsonObject(with: arrayByteBuffer, options: .mutableContainers) as! [String] + ) XCTAssertEqual(foundationArray, array) - - XCTAssertNoThrow(foundationDict = try JSONSerialization.jsonObject(with: dictByteBuffer, options: .mutableContainers) as! [String : String]) + + XCTAssertNoThrow( + foundationDict = + try JSONSerialization.jsonObject(with: dictByteBuffer, options: .mutableContainers) as! [String: String] + ) XCTAssertEqual(foundationDict, dictionary) - + // Mutable leaves comparison. - XCTAssertNoThrow(foundationArray = try JSONSerialization.jsonObject(with: arrayByteBuffer, options: .mutableLeaves) as! [String]) + XCTAssertNoThrow( + foundationArray = + try JSONSerialization.jsonObject(with: arrayByteBuffer, options: .mutableLeaves) as! [String] + ) XCTAssertEqual(foundationArray, array) - - XCTAssertNoThrow(foundationDict = try JSONSerialization.jsonObject(with: dictByteBuffer, options: .mutableLeaves) as! [String : String]) + + XCTAssertNoThrow( + foundationDict = + try JSONSerialization.jsonObject(with: dictByteBuffer, options: .mutableLeaves) as! [String: String] + ) XCTAssertEqual(foundationDict, dictionary) } } diff --git a/Tests/NIOHTTP1Tests/ByteBufferUtilsTest.swift b/Tests/NIOHTTP1Tests/ByteBufferUtilsTest.swift index 43301862c9..dfc0ea6fd6 100644 --- a/Tests/NIOHTTP1Tests/ByteBufferUtilsTest.swift +++ b/Tests/NIOHTTP1Tests/ByteBufferUtilsTest.swift @@ -12,45 +12,64 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore +import XCTest + @testable import NIOHTTP1 -fileprivate enum DummyError: Error { +private enum DummyError: Error { case err } class ByteBufferUtilsTest: XCTestCase { - + func testComparators() { let someByteBuffer: ByteBuffer = ByteBuffer(string: "fiRSt") XCTAssert( someByteBuffer.readableBytesView.compareCaseInsensitiveASCIIBytes( - to: "first".utf8)) + to: "first".utf8 + ) + ) XCTAssert( someByteBuffer.readableBytesView.compareCaseInsensitiveASCIIBytes( - to: "fiRSt".utf8)) + to: "fiRSt".utf8 + ) + ) XCTAssert( someByteBuffer.readableBytesView.compareCaseInsensitiveASCIIBytes( - to: "fIrst".utf8)) + to: "fIrst".utf8 + ) + ) XCTAssertFalse( someByteBuffer.readableBytesView.compareCaseInsensitiveASCIIBytes( - to: "fIrt".utf8)) + to: "fIrt".utf8 + ) + ) XCTAssertFalse( someByteBuffer.readableBytesView.compareCaseInsensitiveASCIIBytes( - to: "firsta".utf8)) + to: "firsta".utf8 + ) + ) XCTAssertFalse( someByteBuffer.readableBytesView.compareCaseInsensitiveASCIIBytes( - to: "afirst".utf8)) + to: "afirst".utf8 + ) + ) XCTAssertFalse( someByteBuffer.readableBytesView.compareCaseInsensitiveASCIIBytes( - to: "eiRSt".utf8)) + to: "eiRSt".utf8 + ) + ) XCTAssertFalse( someByteBuffer.readableBytesView.compareCaseInsensitiveASCIIBytes( - to: "fIrso".utf8)) + to: "fIrso".utf8 + ) + ) XCTAssertFalse( someByteBuffer.readableBytesView.compareCaseInsensitiveASCIIBytes( - to: "firot".utf8)) + to: "firot".utf8 + ) + ) } private func byteBufferView(string: String) -> ByteBufferView { @@ -61,16 +80,42 @@ class ByteBufferUtilsTest: XCTestCase { } func testTrimming() { - XCTAssertEqual(byteBufferView(string: " first").trimSpaces().map({CChar($0)}), byteBufferView(string: "first").map({CChar($0)})) - XCTAssertEqual(byteBufferView(string: " first ").trimSpaces().map({CChar($0)}), byteBufferView(string: "first").map({CChar($0)})) - XCTAssertEqual(byteBufferView(string: "first ").trimSpaces().map({CChar($0)}), byteBufferView(string: "first").map({CChar($0)})) - XCTAssertEqual(byteBufferView(string: "first").trimSpaces().map({CChar($0)}), byteBufferView(string: "first").map({CChar($0)})) - XCTAssertEqual(byteBufferView(string: " \t\t fi rst").trimSpaces().map({CChar($0)}), byteBufferView(string: "fi rst").map({CChar($0)})) - XCTAssertEqual(byteBufferView(string: " firs t \t ").trimSpaces().map({CChar($0)}), byteBufferView(string: "firs t").map({CChar($0)})) - XCTAssertEqual(byteBufferView(string: "f\t irst ").trimSpaces().map({CChar($0)}), byteBufferView(string: "f\t irst").map({CChar($0)})) - XCTAssertEqual(byteBufferView(string: "f i rs t").trimSpaces().map({CChar($0)}), byteBufferView(string: "f i rs t").map({CChar($0)})) - XCTAssertEqual(byteBufferView(string: " \t \t ").trimSpaces().map({CChar($0)}), - byteBufferView(string: "").map({CChar($0)})) + XCTAssertEqual( + byteBufferView(string: " first").trimSpaces().map({ CChar($0) }), + byteBufferView(string: "first").map({ CChar($0) }) + ) + XCTAssertEqual( + byteBufferView(string: " first ").trimSpaces().map({ CChar($0) }), + byteBufferView(string: "first").map({ CChar($0) }) + ) + XCTAssertEqual( + byteBufferView(string: "first ").trimSpaces().map({ CChar($0) }), + byteBufferView(string: "first").map({ CChar($0) }) + ) + XCTAssertEqual( + byteBufferView(string: "first").trimSpaces().map({ CChar($0) }), + byteBufferView(string: "first").map({ CChar($0) }) + ) + XCTAssertEqual( + byteBufferView(string: " \t\t fi rst").trimSpaces().map({ CChar($0) }), + byteBufferView(string: "fi rst").map({ CChar($0) }) + ) + XCTAssertEqual( + byteBufferView(string: " firs t \t ").trimSpaces().map({ CChar($0) }), + byteBufferView(string: "firs t").map({ CChar($0) }) + ) + XCTAssertEqual( + byteBufferView(string: "f\t irst ").trimSpaces().map({ CChar($0) }), + byteBufferView(string: "f\t irst").map({ CChar($0) }) + ) + XCTAssertEqual( + byteBufferView(string: "f i rs t").trimSpaces().map({ CChar($0) }), + byteBufferView(string: "f i rs t").map({ CChar($0) }) + ) + XCTAssertEqual( + byteBufferView(string: " \t \t ").trimSpaces().map({ CChar($0) }), + byteBufferView(string: "").map({ CChar($0) }) + ) } } diff --git a/Tests/NIOHTTP1Tests/ContentLengthTests.swift b/Tests/NIOHTTP1Tests/ContentLengthTests.swift index b295df6f47..3e61316581 100644 --- a/Tests/NIOHTTP1Tests/ContentLengthTests.swift +++ b/Tests/NIOHTTP1Tests/ContentLengthTests.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOHTTP1 +import XCTest final class ContentLengthTests: XCTestCase { @@ -73,7 +73,8 @@ final class ContentLengthTests: XCTestCase { // First one is fine, the extra bytes will be treated as the next request XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) // Which means the next request is now malformed - XCTAssertThrowsError(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) { error in + XCTAssertThrowsError(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) { + error in XCTAssertEqual(error as? HTTPParserError, .invalidMethod) } @@ -96,7 +97,8 @@ final class ContentLengthTests: XCTestCase { // The original request is still 26 bytes short. Sending the request once more will complete it XCTAssertNoThrow(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) // The leftover bytes from the previous write (we wrote 100 bytes where it wanted 26) will form a new malformed request - XCTAssertThrowsError(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) { error in + XCTAssertThrowsError(try channel.receiveRequestAndSendResponse(request: badRequest, sendResponse: true)) { + error in XCTAssertEqual(error as? HTTPParserError, .invalidMethod) } @@ -110,7 +112,9 @@ extension EmbeddedChannel { /// Throws if receiving the response fails fileprivate func sendRequestAndReceiveResponse(response: String) throws { // Send a request - XCTAssertNoThrow(try self.writeOutbound(HTTPClientRequestPart.head(.init(version: .http1_1, method: .GET, uri: "/")))) + XCTAssertNoThrow( + try self.writeOutbound(HTTPClientRequestPart.head(.init(version: .http1_1, method: .GET, uri: "/"))) + ) XCTAssertNoThrow(try self.writeOutbound(HTTPClientRequestPart.end(nil))) // Receive a response try self.writeInbound(ByteBuffer(string: response)) diff --git a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift index 81c0822b6a..537fa0740a 100644 --- a/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPClientUpgradeTests.swift @@ -12,29 +12,32 @@ // //===----------------------------------------------------------------------===// -import XCTest import Dispatch -@testable import NIOCore import NIOEmbedded +import XCTest + +@testable import NIOCore @testable import NIOHTTP1 extension EmbeddedChannel { - + fileprivate func readByteBufferOutputAsString() throws -> String? { - + if let requestData: IOData = try self.readOutbound(), - case .byteBuffer(var requestBuffer) = requestData { - + case .byteBuffer(var requestBuffer) = requestData + { + return requestBuffer.readString(length: requestBuffer.readableBytes) } - + return nil } } #if !canImport(Darwin) || swift(>=5.10) @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader, NIOTypedHTTPClientProtocolUpgrader where UpgradeResult == Bool {} +protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader, NIOTypedHTTPClientProtocolUpgrader +where UpgradeResult == Bool {} #else @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrader {} @@ -43,31 +46,36 @@ protocol TypedAndUntypedHTTPClientProtocolUpgrader: NIOHTTPClientProtocolUpgrade private final class SuccessfulClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgrader { fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] - fileprivate let upgradeHeaders: [(String,String)] - + fileprivate let upgradeHeaders: [(String, String)] + private(set) var addCustomUpgradeRequestHeadersCallCount = 0 private(set) var shouldAllowUpgradeCallCount = 0 private(set) var upgradeContextResponseCallCount = 0 - - fileprivate init(forProtocol `protocol`: String, requiredUpgradeHeaders: [String] = [], upgradeHeaders: [(String,String)] = []) { + + fileprivate init( + forProtocol `protocol`: String, + requiredUpgradeHeaders: [String] = [], + upgradeHeaders: [(String, String)] = [] + ) { self.supportedProtocol = `protocol` self.requiredUpgradeHeaders = requiredUpgradeHeaders self.upgradeHeaders = upgradeHeaders } - + fileprivate func addCustom(upgradeRequestHeaders: inout HTTPHeaders) { self.addCustomUpgradeRequestHeadersCallCount += 1 for (name, value) in self.upgradeHeaders { upgradeRequestHeaders.replaceOrAdd(name: name, value: value) } } - + fileprivate func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { self.shouldAllowUpgradeCallCount += 1 return true } - - fileprivate func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + + fileprivate func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture + { self.upgradeContextResponseCallCount += 1 return context.channel.eventLoop.makeSucceededFuture(()) } @@ -82,28 +90,31 @@ private final class ExplodingClientUpgrader: TypedAndUntypedHTTPClientProtocolUp fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] - fileprivate let upgradeHeaders: [(String,String)] - - fileprivate init(forProtocol `protocol`: String, - requiredUpgradeHeaders: [String] = [], - upgradeHeaders: [(String,String)] = []) { + fileprivate let upgradeHeaders: [(String, String)] + + fileprivate init( + forProtocol `protocol`: String, + requiredUpgradeHeaders: [String] = [], + upgradeHeaders: [(String, String)] = [] + ) { self.supportedProtocol = `protocol` self.requiredUpgradeHeaders = requiredUpgradeHeaders self.upgradeHeaders = upgradeHeaders } - + fileprivate func addCustom(upgradeRequestHeaders: inout HTTPHeaders) { for (name, value) in self.upgradeHeaders { upgradeRequestHeaders.replaceOrAdd(name: name, value: value) } } - + fileprivate func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { XCTFail("This method should not be called.") return false } - - fileprivate func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + + fileprivate func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture + { XCTFail("Upgrade should not be called.") return context.channel.eventLoop.makeSucceededFuture(()) } @@ -118,31 +129,34 @@ private final class DenyingClientUpgrader: TypedAndUntypedHTTPClientProtocolUpgr fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] - fileprivate let upgradeHeaders: [(String,String)] + fileprivate let upgradeHeaders: [(String, String)] private(set) var addCustomUpgradeRequestHeadersCallCount = 0 - - fileprivate init(forProtocol `protocol`: String, - requiredUpgradeHeaders: [String] = [], - upgradeHeaders: [(String,String)] = []) { - + + fileprivate init( + forProtocol `protocol`: String, + requiredUpgradeHeaders: [String] = [], + upgradeHeaders: [(String, String)] = [] + ) { + self.supportedProtocol = `protocol` self.requiredUpgradeHeaders = requiredUpgradeHeaders self.upgradeHeaders = upgradeHeaders } - + fileprivate func addCustom(upgradeRequestHeaders: inout HTTPHeaders) { self.addCustomUpgradeRequestHeadersCallCount += 1 for (name, value) in self.upgradeHeaders { upgradeRequestHeaders.replaceOrAdd(name: name, value: value) } } - + fileprivate func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { - return false + false } - - fileprivate func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + + fileprivate func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture + { XCTFail("Upgrade should not be called.") return context.channel.eventLoop.makeSucceededFuture(()) } @@ -157,31 +171,34 @@ private final class UpgradeDelayClientUpgrader: TypedAndUntypedHTTPClientProtoco fileprivate let supportedProtocol: String fileprivate let requiredUpgradeHeaders: [String] - fileprivate let upgradeHeaders: [(String,String)] - + fileprivate let upgradeHeaders: [(String, String)] + fileprivate let upgradedHandler = SimpleUpgradedHandler() - + private var upgradePromise: EventLoopPromise? - fileprivate init(forProtocol `protocol`: String, - requiredUpgradeHeaders: [String] = [], - upgradeHeaders: [(String,String)] = []) { + fileprivate init( + forProtocol `protocol`: String, + requiredUpgradeHeaders: [String] = [], + upgradeHeaders: [(String, String)] = [] + ) { self.supportedProtocol = `protocol` self.requiredUpgradeHeaders = requiredUpgradeHeaders self.upgradeHeaders = upgradeHeaders } - + fileprivate func addCustom(upgradeRequestHeaders: inout HTTPHeaders) { for (name, value) in self.upgradeHeaders { upgradeRequestHeaders.replaceOrAdd(name: name, value: value) } } - + fileprivate func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { - return true + true } - fileprivate func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { + fileprivate func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture + { self.upgradePromise = context.eventLoop.makePromise() return self.upgradePromise!.futureResult.flatMap { context.pipeline.addHandler(self.upgradedHandler) @@ -192,7 +209,7 @@ private final class UpgradeDelayClientUpgrader: TypedAndUntypedHTTPClientProtoco self.upgradePromise = channel.eventLoop.makePromise() return self.upgradePromise!.futureResult.flatMap { channel.pipeline.addHandler(self.upgradedHandler) - }.map { _ in true} + }.map { _ in true } } fileprivate func unblockUpgrade() { @@ -203,38 +220,40 @@ private final class UpgradeDelayClientUpgrader: TypedAndUntypedHTTPClientProtoco private final class SimpleUpgradedHandler: ChannelInboundHandler { fileprivate typealias InboundIn = ByteBuffer fileprivate typealias OutboundOut = ByteBuffer - + fileprivate var handlerAddedContextCallCount = 0 fileprivate var channelReadContextDataCallCount = 0 - + fileprivate func handlerAdded(context: ChannelHandlerContext) { self.handlerAddedContextCallCount += 1 } - + fileprivate func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.channelReadContextDataCallCount += 1 } } extension ChannelInboundHandler where OutboundOut == HTTPClientRequestPart { - + fileprivate func fireSendRequest(context: ChannelHandlerContext) { - + var headers = HTTPHeaders() headers.add(name: "Content-Type", value: "text/plain; charset=utf-8") headers.add(name: "Content-Length", value: "\(0)") - - let requestHead = HTTPRequestHead(version: .http1_1, - method: .GET, - uri: "/", - headers: headers) - + + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: headers + ) + context.write(Self.wrapOutboundOut(.head(requestHead)), promise: nil) - + let emptyBuffer = context.channel.allocator.buffer(capacity: 0) let body = HTTPClientRequestPart.body(.byteBuffer(emptyBuffer)) - context.write(Self.wrapOutboundOut(body), promise: nil) - + context.write(self.wrapOutboundOut(body), promise: nil) + context.writeAndFlush(Self.wrapOutboundOut(.end(nil)), promise: nil) } } @@ -244,7 +263,7 @@ extension ChannelInboundHandler where OutboundOut == HTTPClientRequestPart { private final class ExplodingHTTPHandler: ChannelInboundHandler, RemovableChannelHandler { fileprivate typealias InboundIn = HTTPClientResponsePart fileprivate typealias OutboundOut = HTTPClientRequestPart - + fileprivate func channelActive(context: ChannelHandlerContext) { // We are connected. It's time to send the message to the server to initialise the upgrade dance. self.fireSendRequest(context: context) @@ -253,7 +272,7 @@ private final class ExplodingHTTPHandler: ChannelInboundHandler, RemovableChanne fileprivate func channelRead(context: ChannelHandlerContext, data: NIOAny) { XCTFail("Received unexpected read") } - + fileprivate func errorCaught(context: ChannelHandlerContext, error: Error) { XCTFail("Received unexpected erro") } @@ -264,20 +283,20 @@ private final class ExplodingHTTPHandler: ChannelInboundHandler, RemovableChanne private final class RecordingHTTPHandler: ChannelInboundHandler, RemovableChannelHandler { fileprivate typealias InboundIn = HTTPClientResponsePart fileprivate typealias OutboundOut = HTTPClientRequestPart - + fileprivate var channelReadChannelHandlerContextDataCallCount = 0 fileprivate var errorCaughtChannelHandlerContextCallCount = 0 fileprivate var errorCaughtChannelHandlerLatestError: Error? - + fileprivate func channelActive(context: ChannelHandlerContext) { // We are connected. It's time to send the message to the server to initialise the upgrade dance. self.fireSendRequest(context: context) } - + fileprivate func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.channelReadChannelHandlerContextDataCallCount += 1 } - + fileprivate func errorCaught(context: ChannelHandlerContext, error: Error) { self.errorCaughtChannelHandlerContextCallCount += 1 self.errorCaughtChannelHandlerLatestError = error @@ -311,11 +330,13 @@ class HTTPClientUpgradeTestCase: XCTestCase { completionHandler: { context in channel.pipeline.removeHandler(clientHTTPHandler, promise: nil) upgradeCompletionHandler(context) - }) + } + ) - try channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config).flatMap({ - channel.pipeline.addHandler(clientHTTPHandler) - }).wait() + try channel.pipeline.addHTTPClientHandlers(leftOverBytesStrategy: .forwardBytes, withClientUpgrade: config) + .flatMap({ + channel.pipeline.addHandler(clientHTTPHandler) + }).wait() try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) .wait() @@ -324,9 +345,9 @@ class HTTPClientUpgradeTestCase: XCTestCase { } // MARK: Test basic happy path requests and responses. - + func testSimpleUpgradeSucceeds() throws { - + let upgradeProtocol = "myProto" let addedUpgradeHeader = "myUpgradeHeader" let addedUpgradeValue = "upgradeHeader" @@ -336,27 +357,34 @@ class HTTPClientUpgradeTestCase: XCTestCase { // This header is not required by the server but we will validate its receipt. let clientHeaders = [(addedUpgradeHeader, addedUpgradeValue)] - let clientUpgrader = SuccessfulClientUpgrader(forProtocol: upgradeProtocol, - upgradeHeaders: clientHeaders) - + let clientUpgrader = SuccessfulClientUpgrader( + forProtocol: upgradeProtocol, + upgradeHeaders: clientHeaders + ) + // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(), - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: ExplodingHTTPHandler(), + clientUpgraders: [clientUpgrader] + ) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + // Read the server request. if let requestString = try clientChannel.readByteBufferOutputAsString() { - XCTAssertEqual(requestString, "GET / HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol.lowercased())\r\n\(addedUpgradeHeader): \(addedUpgradeValue)\r\n\r\n") + XCTAssertEqual( + requestString, + "GET / HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol.lowercased())\r\n\(addedUpgradeHeader): \(addedUpgradeValue)\r\n\r\n" + ) } else { XCTFail() } - + // Validate the pipeline still has http handlers. clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) @@ -366,40 +394,50 @@ class HTTPClientUpgradeTestCase: XCTestCase { let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + clientChannel.embeddedEventLoop.run() // Once upgraded, validate the pipeline has been removed. - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: ByteToMessageHandler.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) - + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: HTTPRequestEncoder.self) + ) + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: ByteToMessageHandler.self) + ) + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) + // Check the client upgrader was used correctly. XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) XCTAssertEqual(1, clientUpgrader.shouldAllowUpgradeCallCount) XCTAssertEqual(1, clientUpgrader.upgradeContextResponseCallCount) - + XCTAssert(upgradeHandlerCallbackFired) } - + func testUpgradeWithRequiredHeadersShowsInRequest() throws { - + let upgradeProtocol = "myProto" let addedUpgradeHeader = "myUpgradeHeader" let addedUpgradeValue = "upgradeValue" - + let clientHeaders = [(addedUpgradeHeader, addedUpgradeValue)] - - let clientUpgrader = SuccessfulClientUpgrader(forProtocol: upgradeProtocol, - requiredUpgradeHeaders: [addedUpgradeHeader], - upgradeHeaders: clientHeaders) - + + let clientUpgrader = SuccessfulClientUpgrader( + forProtocol: upgradeProtocol, + requiredUpgradeHeaders: [addedUpgradeHeader], + upgradeHeaders: clientHeaders + ) + // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(), - clientUpgraders: [clientUpgrader]) { _ in + let clientChannel = try setUpClientChannel( + clientHTTPHandler: ExplodingHTTPHandler(), + clientUpgraders: [clientUpgrader] + ) { _ in } defer { XCTAssertNoThrow(try clientChannel.finish()) @@ -407,89 +445,103 @@ class HTTPClientUpgradeTestCase: XCTestCase { // Read the server request and check that it has the required header also added to the connection header. if let requestString = try clientChannel.readByteBufferOutputAsString() { - XCTAssertEqual(requestString, "GET / HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0\r\nConnection: upgrade,\(addedUpgradeHeader)\r\nUpgrade: \(upgradeProtocol.lowercased())\r\n\(addedUpgradeHeader): \(addedUpgradeValue)\r\n\r\n") + XCTAssertEqual( + requestString, + "GET / HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0\r\nConnection: upgrade,\(addedUpgradeHeader)\r\nUpgrade: \(upgradeProtocol.lowercased())\r\n\(addedUpgradeHeader): \(addedUpgradeValue)\r\n\r\n" + ) } else { XCTFail() } - + // Check the client upgrader was used correctly, no response received. XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) XCTAssertEqual(0, clientUpgrader.shouldAllowUpgradeCallCount) XCTAssertEqual(0, clientUpgrader.upgradeContextResponseCallCount) } - + func testSimpleUpgradeSucceedsWhenMultipleAvailableProtocols() throws { - + let unusedUpgradeProtocol = "unusedMyProto" let unusedUpgradeHeader = "unusedMyUpgradeHeader" let unusedUpgradeValue = "unusedUpgradeHeaderValue" - + let upgradeProtocol = "myProto" let addedUpgradeHeader = "myUpgradeHeader" let addedUpgradeValue = "upgradeHeaderValue" - + var upgradeHandlerCallbackFired = false - + // These headers are not required by the server but we will validate their receipt. let unusedClientHeaders = [(unusedUpgradeHeader, unusedUpgradeValue)] let clientHeaders = [(addedUpgradeHeader, addedUpgradeValue)] - - let unusedClientUpgrader = ExplodingClientUpgrader(forProtocol: unusedUpgradeProtocol, - upgradeHeaders: unusedClientHeaders) - - let clientUpgrader = SuccessfulClientUpgrader(forProtocol: upgradeProtocol, - upgradeHeaders: clientHeaders) + + let unusedClientUpgrader = ExplodingClientUpgrader( + forProtocol: unusedUpgradeProtocol, + upgradeHeaders: unusedClientHeaders + ) + + let clientUpgrader = SuccessfulClientUpgrader( + forProtocol: upgradeProtocol, + upgradeHeaders: clientHeaders + ) let clientUpgraders: [any TypedAndUntypedHTTPClientProtocolUpgrader] = [unusedClientUpgrader, clientUpgrader] // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(), - clientUpgraders: clientUpgraders) { (context) in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: ExplodingHTTPHandler(), + clientUpgraders: clientUpgraders + ) { (context) in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + // Read the server request. if let requestString = try clientChannel.readByteBufferOutputAsString() { - + // Check that the details for both protocols are sent to the server, in preference order. let expectedUpgrade = "\(unusedUpgradeProtocol),\(upgradeProtocol)".lowercased() - - XCTAssertEqual(requestString, "GET / HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0\r\nConnection: upgrade\r\nUpgrade: \(expectedUpgrade)\r\n\(unusedUpgradeHeader): \(unusedUpgradeValue)\r\n\(addedUpgradeHeader): \(addedUpgradeValue)\r\n\r\n") + + XCTAssertEqual( + requestString, + "GET / HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0\r\nConnection: upgrade\r\nUpgrade: \(expectedUpgrade)\r\n\(unusedUpgradeHeader): \(unusedUpgradeValue)\r\n\(addedUpgradeHeader): \(addedUpgradeValue)\r\n\r\n" + ) } else { XCTFail() } - + // Push the successful server response. let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" - + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + clientChannel.embeddedEventLoop.run() - + // Should just upgrade to the accepted protocol, the other protocol uses an exploding upgrader. XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) XCTAssertEqual(1, clientUpgrader.shouldAllowUpgradeCallCount) XCTAssertEqual(1, clientUpgrader.upgradeContextResponseCallCount) - + XCTAssert(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } - + func testUpgradeCompleteFlush() throws { final class ChannelReadWriteHandler: ChannelDuplexHandler { typealias OutboundIn = Any typealias InboundIn = Any typealias OutboundOut = Any - + var messagesReceived = 0 - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.messagesReceived += 1 context.writeAndFlush(data, promise: nil) @@ -505,84 +557,93 @@ class HTTPClientUpgradeTestCase: XCTestCase { self.supportedProtocol = `protocol` self.handler = handler } - - func addCustom(upgradeRequestHeaders: inout HTTPHeaders) { } - + + func addCustom(upgradeRequestHeaders: inout HTTPHeaders) {} + func shouldAllowUpgrade(upgradeResponse: HTTPResponseHead) -> Bool { - return true + true } - + func upgrade(context: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { - return context.pipeline.addHandler(handler) + context.pipeline.addHandler(handler) } func upgrade(channel: any Channel, upgradeResponse: HTTPResponseHead) -> EventLoopFuture { - return channel.pipeline.addHandler(handler).map { _ in true } + channel.pipeline.addHandler(handler).map { _ in true } } } - + var upgradeHandlerCallbackFired = false let handler = ChannelReadWriteHandler() let upgrader = AddHandlerClientUpgrader(forProtocol: "myproto", addingHandler: handler) - let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(), - clientUpgraders: [upgrader]) { (context) in - - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: ExplodingHTTPHandler(), + clientUpgraders: [upgrader] + ) { (context) in + + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + // Read the server request. if let requestString = try clientChannel.readByteBufferOutputAsString() { - XCTAssertEqual(requestString, "GET / HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0\r\nConnection: upgrade\r\nUpgrade: myproto\r\n\r\n") + XCTAssertEqual( + requestString, + "GET / HTTP/1.1\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: 0\r\nConnection: upgrade\r\nUpgrade: myproto\r\n\r\n" + ) XCTAssertNoThrow(XCTAssertEqual(try clientChannel.readByteBufferOutputAsString(), "")) // Empty body XCTAssertNoThrow(XCTAssertNil(try clientChannel.readByteBufferOutputAsString())) } else { XCTFail() } - + // Push the successful server response. let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: myproto\r\n\r\nTest" - + XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + clientChannel.embeddedEventLoop.run() - + XCTAssert(upgradeHandlerCallbackFired) - + XCTAssertEqual(handler.messagesReceived, 1) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) XCTAssertNoThrow(XCTAssertEqual(try clientChannel.readByteBufferOutputAsString(), "Test")) } - + // MARK: Test requests and responses with other specific actions. - + func testNoUpgradeAsNoServerUpgrade() throws { - + var upgradeHandlerCallbackFired = false let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") let clientHandler = RecordingHTTPHandler() - + // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + let response = "HTTP/1.1 200 OK\r\n\r\n" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + clientChannel.embeddedEventLoop.run() - + // Check that the http elements are not removed from the pipeline. clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) @@ -593,307 +654,342 @@ class HTTPClientUpgradeTestCase: XCTestCase { XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) XCTAssertFalse(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } - + func testFirstResponseReturnsServerError() throws { - + var upgradeHandlerCallbackFired = false - + let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") let clientHandler = RecordingHTTPHandler() - + // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + let response = "HTTP/1.1 404 Not Found\r\n\r\n" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + clientChannel.embeddedEventLoop.run() - + // Should fail with error (response is malformed) and remove upgrader from pipeline. - + // Check that the http elements are not removed from the pipeline. clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) - + // Check that the HTTP handler received its response. XCTAssertEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) // Check a separate error is not reported, the error response will be forwarded on. XCTAssertEqual(0, clientHandler.errorCaughtChannelHandlerContextCallCount) - + XCTAssertFalse(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } func testUpgradeResponseMissingAllProtocols() throws { - + var upgradeHandlerCallbackFired = false - + let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") let clientHandler = RecordingHTTPHandler() - + // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\n\r\n" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + clientChannel.embeddedEventLoop.run() - + // Should fail with error (response is malformed) and remove upgrader from pipeline. - + // Check that the http elements are not removed from the pipeline. clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) - + // Check that the HTTP handler received its response. XCTAssertLessThanOrEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) // Check an error is reported XCTAssertEqual(1, clientHandler.errorCaughtChannelHandlerContextCallCount) - + let reportedError = clientHandler.errorCaughtChannelHandlerLatestError! as! NIOHTTPClientUpgradeError XCTAssertEqual(NIOHTTPClientUpgradeError.responseProtocolNotFound, reportedError) - + XCTAssertFalse(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } - + func testUpgradeOnlyHandlesKnownProtocols() throws { var upgradeHandlerCallbackFired = false - + let clientUpgrader = ExplodingClientUpgrader(forProtocol: "myProto") let clientHandler = RecordingHTTPHandler() - + // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: unknownProtocol\r\n\r\n" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + clientChannel.embeddedEventLoop.run() - + // Should fail with error (response is malformed) and remove upgrader from pipeline. - + // Check that the http elements are not removed from the pipeline. clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) - + // Check that the HTTP handler received its response. XCTAssertLessThanOrEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) // Check an error is reported XCTAssertEqual(1, clientHandler.errorCaughtChannelHandlerContextCallCount) - + let reportedError = clientHandler.errorCaughtChannelHandlerLatestError! as! NIOHTTPClientUpgradeError XCTAssertEqual(NIOHTTPClientUpgradeError.responseProtocolNotFound, reportedError) - + XCTAssertFalse(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } - + func testUpgradeResponseCanBeRejectedByClientUpgrader() throws { - + let upgradeProtocol = "myProto" - + var upgradeHandlerCallbackFired = false - + let clientUpgrader = DenyingClientUpgrader(forProtocol: upgradeProtocol) let clientHandler = RecordingHTTPHandler() - + // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) clientChannel.embeddedEventLoop.run() - + // Should fail with error (response is denied) and remove upgrader from pipeline. - + // Check that the http elements are not removed from the pipeline. clientChannel.pipeline.assertContains(handlerType: HTTPRequestEncoder.self) clientChannel.pipeline.assertContains(handlerType: ByteToMessageHandler.self) XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) - + // Check that the HTTP handler received its response. XCTAssertLessThanOrEqual(1, clientHandler.channelReadChannelHandlerContextDataCallCount) - + // Check an error is reported XCTAssertEqual(1, clientHandler.errorCaughtChannelHandlerContextCallCount) - + let reportedError = clientHandler.errorCaughtChannelHandlerLatestError! as! NIOHTTPClientUpgradeError XCTAssertEqual(NIOHTTPClientUpgradeError.upgraderDeniedUpgrade, reportedError) XCTAssertFalse(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } - + func testUpgradeIsCaseInsensitive() throws { - + let upgradeProtocol = "mYPrOtO123" var upgradeHandlerCallbackFired = false - + let clientUpgrader = SuccessfulClientUpgrader(forProtocol: upgradeProtocol) // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(), - clientUpgraders: [clientUpgrader]) { _ in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: ExplodingHTTPHandler(), + clientUpgraders: [clientUpgrader] + ) { _ in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + let response = "HTTP/1.1 101 Switching Protocols\r\nCoNnEcTiOn: uPgRaDe\r\nuPgRaDe: \(upgradeProtocol)\r\n\r\n" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + clientChannel.embeddedEventLoop.run() - + // Should fail with error (response is denied) and remove upgrader from pipeline. - + // Check that the http elements are removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: ByteToMessageHandler.self)) - + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: HTTPRequestEncoder.self) + ) + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: ByteToMessageHandler.self) + ) + // Check the client upgrader was used. XCTAssertEqual(1, clientUpgrader.addCustomUpgradeRequestHeadersCallCount) XCTAssertEqual(1, clientUpgrader.shouldAllowUpgradeCallCount) XCTAssertEqual(1, clientUpgrader.upgradeContextResponseCallCount) - + XCTAssert(upgradeHandlerCallbackFired) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } // MARK: Test when client pipeline experiences delay. func testBuffersInboundDataDuringAddingHandlers() throws { - + let upgradeProtocol = "myProto" var upgradeHandlerCallbackFired = false - + let clientUpgrader = UpgradeDelayClientUpgrader(forProtocol: upgradeProtocol) - let clientChannel = try setUpClientChannel(clientHTTPHandler: ExplodingHTTPHandler(), - clientUpgraders: [clientUpgrader]) { (context) in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + let clientChannel = try setUpClientChannel( + clientHTTPHandler: ExplodingHTTPHandler(), + clientUpgraders: [clientUpgrader] + ) { (context) in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - // Push the successful server response. let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + // Run the processing of the response, but with the upgrade delayed by the client upgrader. clientChannel.embeddedEventLoop.run() - + // Soundness check that the upgrade was delayed. XCTAssertEqual(0, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) - + // Add some non-http data. let appData = "supersecretawesome data definitely not http\r\nawesome\r\ndata\ryeah" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: appData))) - + // Upgrade now. clientUpgrader.unblockUpgrade() clientChannel.embeddedEventLoop.run() - + // Check that the http elements are removed from the pipeline. - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: HTTPRequestEncoder.self)) - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: ByteToMessageHandler.self)) - + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: HTTPRequestEncoder.self) + ) + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: ByteToMessageHandler.self) + ) + XCTAssert(upgradeHandlerCallbackFired) // Check that the data gets fired to the new handler once it is added. XCTAssertEqual(1, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) XCTAssertEqual(1, clientUpgrader.upgradedHandler.channelReadContextDataCallCount) - - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } - + func testFiresOutboundErrorDuringAddingHandlers() throws { - + let upgradeProtocol = "myProto" var errorOnAdditionalChannelWrite: Error? var upgradeHandlerCallbackFired = false - + let clientUpgrader = UpgradeDelayClientUpgrader(forProtocol: upgradeProtocol) let clientHandler = RecordingHTTPHandler() - - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { (context) in - - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { (context) in + + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } - + // Push the successful server response. let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" XCTAssertNoThrow(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) - + let promise = clientChannel.eventLoop.makePromise(of: Void.self) - - promise.futureResult.whenFailure() { error in + + promise.futureResult.whenFailure { error in errorOnAdditionalChannelWrite = error } @@ -901,7 +997,7 @@ class HTTPClientUpgradeTestCase: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let secondRequest: HTTPClientRequestPart = .head(requestHead) clientChannel.writeAndFlush(secondRequest, promise: promise) - + clientChannel.embeddedEventLoop.run() let reportedError = clientHandler.errorCaughtChannelHandlerLatestError! as! NIOHTTPClientUpgradeError @@ -909,10 +1005,10 @@ class HTTPClientUpgradeTestCase: XCTestCase { let promiseError = errorOnAdditionalChannelWrite as! NIOHTTPClientUpgradeError XCTAssertEqual(NIOHTTPClientUpgradeError.writingToHandlerDuringUpgrade, promiseError) - + // Soundness check that the upgrade was delayed. XCTAssertEqual(0, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) - + // Upgrade now. clientUpgrader.unblockUpgrade() clientChannel.embeddedEventLoop.run() @@ -921,35 +1017,41 @@ class HTTPClientUpgradeTestCase: XCTestCase { XCTAssert(upgradeHandlerCallbackFired) XCTAssertEqual(1, clientUpgrader.upgradedHandler.handlerAddedContextCallCount) } - + func testFiresInboundErrorBeforeSendsRequestUpgrade() throws { - + let upgradeProtocol = "myProto" - + let clientUpgrader = SuccessfulClientUpgrader(forProtocol: upgradeProtocol) let clientHandler = RecordingHTTPHandler() - + let clientChannel = EmbeddedChannel() defer { XCTAssertNoThrow(try clientChannel.finish()) } - - let upgrader = NIOHTTPClientUpgradeHandler(upgraders: [clientUpgrader], - httpHandlers: [clientHandler], - upgradeCompletionHandler: { context in - }) - + + let upgrader = NIOHTTPClientUpgradeHandler( + upgraders: [clientUpgrader], + httpHandlers: [clientHandler], + upgradeCompletionHandler: { context in + } + ) + try clientChannel.pipeline.addHandler(upgrader).wait() - + try clientChannel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait() - - let headers = HTTPHeaders([("Connection", "upgrade"), - ("Upgrade", "\(upgradeProtocol)")]) - let head = HTTPResponseHead(version: .http1_1, - status: .switchingProtocols, - headers: headers) + + let headers = HTTPHeaders([ + ("Connection", "upgrade"), + ("Upgrade", "\(upgradeProtocol)"), + ]) + let head = HTTPResponseHead( + version: .http1_1, + status: .switchingProtocols, + headers: headers + ) let response = HTTPClientResponsePart.head(head) - + XCTAssertThrowsError(try clientChannel.writeInbound(response)) { error in let reportedError = error as! NIOHTTPClientUpgradeError XCTAssertEqual(NIOHTTPClientUpgradeError.receivedResponseBeforeRequestSent, reportedError) @@ -979,7 +1081,9 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { headers: headers ) - let upgraders: [any NIOTypedHTTPClientProtocolUpgrader] = Array(clientUpgraders.map { $0 as! any NIOTypedHTTPClientProtocolUpgrader }) + let upgraders: [any NIOTypedHTTPClientProtocolUpgrader] = Array( + clientUpgraders.map { $0 as! any NIOTypedHTTPClientProtocolUpgrader } + ) let config = NIOTypedHTTPClientUpgradeConfiguration( upgradeRequestHead: requestHead, @@ -993,8 +1097,12 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { } var configuration = NIOUpgradableHTTPClientPipelineConfiguration(upgradeConfiguration: config) configuration.leftOverBytesStrategy = .forwardBytes - let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: configuration) - let context = try channel.pipeline.syncOperations.context(handlerType: NIOTypedHTTPClientUpgradeHandler.self) + let upgradeResult = try channel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline( + configuration: configuration + ) + let context = try channel.pipeline.syncOperations.context( + handlerType: NIOTypedHTTPClientUpgradeHandler.self + ) try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)) .wait() @@ -1016,18 +1124,21 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { let clientHandler = RecordingHTTPHandler() // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { _ in - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: unknownProtocol\r\n\r\n" - XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { + error in XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) } @@ -1046,8 +1157,10 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { XCTAssertFalse(upgradeHandlerCallbackFired) - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } override func testUpgradeResponseCanBeRejectedByClientUpgrader() throws { @@ -1059,18 +1172,21 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { let clientHandler = RecordingHTTPHandler() // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { _ in - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: \(upgradeProtocol)\r\n\r\n" - XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { + error in XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .upgraderDeniedUpgrade) } @@ -1092,8 +1208,10 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { XCTAssertFalse(upgradeHandlerCallbackFired) - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } override func testFiresOutboundErrorDuringAddingHandlers() throws { @@ -1104,11 +1222,13 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { let clientUpgrader = UpgradeDelayClientUpgrader(forProtocol: upgradeProtocol) let clientHandler = RecordingHTTPHandler() - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { (context) in + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { (context) in - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) @@ -1120,7 +1240,7 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { let promise = clientChannel.eventLoop.makePromise(of: Void.self) - promise.futureResult.whenFailure() { error in + promise.futureResult.whenFailure { error in errorOnAdditionalChannelWrite = error } @@ -1153,18 +1273,21 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { let clientHandler = RecordingHTTPHandler() // The process should kick-off independently by sending the upgrade request to the server. - let clientChannel = try setUpClientChannel(clientHTTPHandler: clientHandler, - clientUpgraders: [clientUpgrader]) { _ in + let clientChannel = try setUpClientChannel( + clientHTTPHandler: clientHandler, + clientUpgraders: [clientUpgrader] + ) { _ in - // This is called before the upgrader gets called. - upgradeHandlerCallbackFired = true + // This is called before the upgrader gets called. + upgradeHandlerCallbackFired = true } defer { XCTAssertNoThrow(try clientChannel.finish()) } let response = "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\n\r\n" - XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { error in + XCTAssertThrowsError(try clientChannel.writeInbound(clientChannel.allocator.buffer(string: response))) { + error in XCTAssertEqual(error as? NIOHTTPClientUpgradeError, .responseProtocolNotFound) } @@ -1183,8 +1306,10 @@ final class TypedHTTPClientUpgradeTestCase: HTTPClientUpgradeTestCase { XCTAssertFalse(upgradeHandlerCallbackFired) - XCTAssertNoThrow(try clientChannel.pipeline - .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self)) + XCTAssertNoThrow( + try clientChannel.pipeline + .assertDoesNotContain(handlerType: NIOHTTPClientUpgradeHandler.self) + ) } } #endif diff --git a/Tests/NIOHTTP1Tests/HTTPDecoderLengthTest.swift b/Tests/NIOHTTP1Tests/HTTPDecoderLengthTest.swift index d33e87453d..273ae2743e 100644 --- a/Tests/NIOHTTP1Tests/HTTPDecoderLengthTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPDecoderLengthTest.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOHTTP1 +import XCTest private class MessageEndHandler: ChannelInboundHandler { typealias InboundIn = HTTPPart @@ -45,7 +45,7 @@ private class MessageEndHandler: ChannelInboun class HTTPDecoderLengthTest: XCTestCase { private var channel: EmbeddedChannel! private var loop: EmbeddedEventLoop { - return self.channel.embeddedEventLoop + self.channel.embeddedEventLoop } override func setUp() { @@ -70,7 +70,8 @@ class HTTPDecoderLengthTest: XCTestCase { case neither } - private func assertSemanticEOFOnChannelInactiveResponse(version: HTTPVersion, how eofMechanism: EOFMechanism) throws { + private func assertSemanticEOFOnChannelInactiveResponse(version: HTTPVersion, how eofMechanism: EOFMechanism) throws + { class ChannelInactiveHandler: ChannelInboundHandler { typealias InboundIn = HTTPClientResponsePart var response: HTTPResponseHead? @@ -127,7 +128,11 @@ class HTTPDecoderLengthTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait()) // Prime the decoder with a GET and consume it. - XCTAssertTrue(try channel.writeOutbound(HTTPClientRequestPart.head(HTTPRequestHead(version: version, method: .GET, uri: "/"))).isFull) + XCTAssertTrue( + try channel.writeOutbound( + HTTPClientRequestPart.head(HTTPRequestHead(version: version, method: .GET, uri: "/")) + ).isFull + ) XCTAssertNoThrow(XCTAssertNotNil(try channel.readOutbound(as: ByteBuffer.self))) // We now want to send a HTTP/1.1 response. This response has no content-length, no transfer-encoding, @@ -144,7 +149,9 @@ class HTTPDecoderLengthTest: XCTestCase { XCTAssertFalse(handler.eof) // Send a body chunk. This should be immediately passed on. Still no end or EOF. - XCTAssertNoThrow(try channel.writeInbound(IOData.byteBuffer(channel.allocator.buffer(string: "some body data")))) + XCTAssertNoThrow( + try channel.writeInbound(IOData.byteBuffer(channel.allocator.buffer(string: "some body data"))) + ) XCTAssertNotNil(handler.response) XCTAssertEqual(handler.body!, Array("some body data".utf8)) XCTAssertFalse(handler.receivedEnd) @@ -180,9 +187,11 @@ class HTTPDecoderLengthTest: XCTestCase { try assertSemanticEOFOnChannelInactiveResponse(version: .http1_0, how: .halfClosure) } - private func assertIgnoresLengthFields(requestMethod: HTTPMethod, - responseStatus: HTTPResponseStatus, - responseFramingField: FramingField) throws { + private func assertIgnoresLengthFields( + requestMethod: HTTPMethod, + responseStatus: HTTPResponseStatus, + responseFramingField: FramingField + ) throws { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(HTTPRequestEncoder())) let decoder = HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes, informationalResponseStrategy: .forward) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(decoder))) @@ -191,9 +200,17 @@ class HTTPDecoderLengthTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait()) // Prime the decoder with a request and consume it. - XCTAssertTrue(try channel.writeOutbound(HTTPClientRequestPart.head(HTTPRequestHead(version: .http1_1, - method: requestMethod, - uri: "/"))).isFull) + XCTAssertTrue( + try channel.writeOutbound( + HTTPClientRequestPart.head( + HTTPRequestHead( + version: .http1_1, + method: requestMethod, + uri: "/" + ) + ) + ).isFull + ) XCTAssertNoThrow(XCTAssertNotNil(try channel.readOutbound(as: ByteBuffer.self))) // We now want to send a HTTP/1.1 response. This response may contain some length framing fields that RFC 7230 says MUST @@ -221,21 +238,29 @@ class HTTPDecoderLengthTest: XCTestCase { // follow. For this reason, we don't expect an `.end` here. XCTAssertFalse(handler.seenBody) XCTAssertFalse(handler.seenEnd) - + default: XCTAssertFalse(handler.seenBody) XCTAssert(handler.seenEnd) } - + XCTAssertTrue(try channel.finish().isClean) } func testIgnoresTransferEncodingFieldOnCONNECTResponses() throws { - try assertIgnoresLengthFields(requestMethod: .CONNECT, responseStatus: .ok, responseFramingField: .transferEncoding) + try assertIgnoresLengthFields( + requestMethod: .CONNECT, + responseStatus: .ok, + responseFramingField: .transferEncoding + ) } func testIgnoresContentLengthFieldOnCONNECTResponses() throws { - try assertIgnoresLengthFields(requestMethod: .CONNECT, responseStatus: .ok, responseFramingField: .contentLength) + try assertIgnoresLengthFields( + requestMethod: .CONNECT, + responseStatus: .ok, + responseFramingField: .contentLength + ) } func testEarlyFinishWithoutLengthAtAllOnCONNECTResponses() throws { @@ -243,7 +268,11 @@ class HTTPDecoderLengthTest: XCTestCase { } func testIgnoresTransferEncodingFieldOnHEADResponses() throws { - try assertIgnoresLengthFields(requestMethod: .HEAD, responseStatus: .ok, responseFramingField: .transferEncoding) + try assertIgnoresLengthFields( + requestMethod: .HEAD, + responseStatus: .ok, + responseFramingField: .transferEncoding + ) } func testIgnoresContentLengthFieldOnHEADResponses() throws { @@ -255,29 +284,43 @@ class HTTPDecoderLengthTest: XCTestCase { } func testIgnoresTransferEncodingFieldOn1XXResponses() throws { - try assertIgnoresLengthFields(requestMethod: .GET, - responseStatus: .custom(code: 103, reasonPhrase: "Early Hints"), - responseFramingField: .transferEncoding) + try assertIgnoresLengthFields( + requestMethod: .GET, + responseStatus: .custom(code: 103, reasonPhrase: "Early Hints"), + responseFramingField: .transferEncoding + ) } func testIgnoresContentLengthFieldOn1XXResponses() throws { - try assertIgnoresLengthFields(requestMethod: .GET, - responseStatus: .custom(code: 103, reasonPhrase: "Early Hints"), - responseFramingField: .contentLength) + try assertIgnoresLengthFields( + requestMethod: .GET, + responseStatus: .custom(code: 103, reasonPhrase: "Early Hints"), + responseFramingField: .contentLength + ) } func testEarlyFinishWithoutLengthAtAllOn1XXResponses() throws { - try assertIgnoresLengthFields(requestMethod: .GET, - responseStatus: .custom(code: 103, reasonPhrase: "Early Hints"), - responseFramingField: .neither) + try assertIgnoresLengthFields( + requestMethod: .GET, + responseStatus: .custom(code: 103, reasonPhrase: "Early Hints"), + responseFramingField: .neither + ) } func testIgnoresTransferEncodingFieldOn204Responses() throws { - try assertIgnoresLengthFields(requestMethod: .GET, responseStatus: .noContent, responseFramingField: .transferEncoding) + try assertIgnoresLengthFields( + requestMethod: .GET, + responseStatus: .noContent, + responseFramingField: .transferEncoding + ) } func testIgnoresContentLengthFieldOn204Responses() throws { - try assertIgnoresLengthFields(requestMethod: .GET, responseStatus: .noContent, responseFramingField: .contentLength) + try assertIgnoresLengthFields( + requestMethod: .GET, + responseStatus: .noContent, + responseFramingField: .contentLength + ) } func testEarlyFinishWithoutLengthAtAllOn204Responses() throws { @@ -285,11 +328,19 @@ class HTTPDecoderLengthTest: XCTestCase { } func testIgnoresTransferEncodingFieldOn304Responses() throws { - try assertIgnoresLengthFields(requestMethod: .GET, responseStatus: .notModified, responseFramingField: .transferEncoding) + try assertIgnoresLengthFields( + requestMethod: .GET, + responseStatus: .notModified, + responseFramingField: .transferEncoding + ) } func testIgnoresContentLengthFieldOn304Responses() throws { - try assertIgnoresLengthFields(requestMethod: .GET, responseStatus: .notModified, responseFramingField: .contentLength) + try assertIgnoresLengthFields( + requestMethod: .GET, + responseStatus: .notModified, + responseFramingField: .contentLength + ) } func testEarlyFinishWithoutLengthAtAllOn304Responses() throws { @@ -303,7 +354,13 @@ class HTTPDecoderLengthTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait()) // Send a GET with the appropriate Transfer Encoding header. - XCTAssertThrowsError(try channel.writeInbound(channel.allocator.buffer(string: "POST / HTTP/1.1\r\nTransfer-Encoding: \(transferEncodingHeader)\r\n\r\n"))) { error in + XCTAssertThrowsError( + try channel.writeInbound( + channel.allocator.buffer( + string: "POST / HTTP/1.1\r\nTransfer-Encoding: \(transferEncodingHeader)\r\n\r\n" + ) + ) + ) { error in XCTAssertEqual(error as? HTTPParserError, .unknown) } } @@ -315,7 +372,11 @@ class HTTPDecoderLengthTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait()) // Send a GET with the appropriate Transfer Encoding header. - XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: "POST / HTTP/1.1\r\nTransfer-Encoding: gzip, chunked\r\n\r\n0\r\n\r\n"))) + XCTAssertNoThrow( + try channel.writeInbound( + channel.allocator.buffer(string: "POST / HTTP/1.1\r\nTransfer-Encoding: gzip, chunked\r\n\r\n0\r\n\r\n") + ) + ) // We should have a request, no body, and immediately see end of request. XCTAssert(handler.seenHead) @@ -333,7 +394,10 @@ class HTTPDecoderLengthTest: XCTestCase { try assertRequestTransferEncodingInError(transferEncodingHeader: "gzip, chunked, deflate") } - private func assertResponseTransferEncodingHasBodyTerminatedByEOF(transferEncodingHeader: String, eofMechanism: EOFMechanism) throws { + private func assertResponseTransferEncodingHasBodyTerminatedByEOF( + transferEncodingHeader: String, + eofMechanism: EOFMechanism + ) throws { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(HTTPRequestEncoder())) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))) @@ -341,14 +405,28 @@ class HTTPDecoderLengthTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait()) // Prime the decoder with a request and consume it. - XCTAssertTrue(try channel.writeOutbound(HTTPClientRequestPart.head(HTTPRequestHead(version: .http1_1, - method: .GET, - uri: "/"))).isFull) + XCTAssertTrue( + try channel.writeOutbound( + HTTPClientRequestPart.head( + HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/" + ) + ) + ).isFull + ) XCTAssertNoThrow(XCTAssertNotNil(try channel.readOutbound(as: ByteBuffer.self))) // Send a 200 with the appropriate Transfer Encoding header. We should see the request, // but no body or end. - XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: "HTTP/1.1 200 OK\r\nTransfer-Encoding: \(transferEncodingHeader)\r\n\r\n"))) + XCTAssertNoThrow( + try channel.writeInbound( + channel.allocator.buffer( + string: "HTTP/1.1 200 OK\r\nTransfer-Encoding: \(transferEncodingHeader)\r\n\r\n" + ) + ) + ) XCTAssert(handler.seenHead) XCTAssertFalse(handler.seenBody) XCTAssertFalse(handler.seenEnd) @@ -372,7 +450,10 @@ class HTTPDecoderLengthTest: XCTestCase { XCTAssertTrue(try channel.finish().isClean) } - private func assertResponseTransferEncodingHasBodyTerminatedByEndOfChunk(transferEncodingHeader: String, eofMechanism: EOFMechanism) throws { + private func assertResponseTransferEncodingHasBodyTerminatedByEndOfChunk( + transferEncodingHeader: String, + eofMechanism: EOFMechanism + ) throws { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(HTTPRequestEncoder())) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))) @@ -380,13 +461,27 @@ class HTTPDecoderLengthTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait()) // Prime the decoder with a request and consume it. - XCTAssertTrue(try channel.writeOutbound(HTTPClientRequestPart.head(HTTPRequestHead(version: .http1_1, - method: .GET, - uri: "/"))).isFull) + XCTAssertTrue( + try channel.writeOutbound( + HTTPClientRequestPart.head( + HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/" + ) + ) + ).isFull + ) XCTAssertNoThrow(XCTAssertNotNil(try channel.readOutbound(as: ByteBuffer.self))) // Send a 200 with the appropriate Transfer Encoding header. We should see the request. - XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: "HTTP/1.1 200 OK\r\nTransfer-Encoding: \(transferEncodingHeader)\r\n\r\n"))) + XCTAssertNoThrow( + try channel.writeInbound( + channel.allocator.buffer( + string: "HTTP/1.1 200 OK\r\nTransfer-Encoding: \(transferEncodingHeader)\r\n\r\n" + ) + ) + ) XCTAssert(handler.seenHead) XCTAssertFalse(handler.seenBody) XCTAssertFalse(handler.seenEnd) @@ -416,38 +511,58 @@ class HTTPDecoderLengthTest: XCTestCase { } func testMultipleTEWithChunkedLastHasEOFBodyOnResponseWithChannelInactive() throws { - try assertResponseTransferEncodingHasBodyTerminatedByEndOfChunk(transferEncodingHeader: "gzip, chunked", eofMechanism: .channelInactive) + try assertResponseTransferEncodingHasBodyTerminatedByEndOfChunk( + transferEncodingHeader: "gzip, chunked", + eofMechanism: .channelInactive + ) } func testMultipleTEWithChunkedFirstHasEOFBodyOnResponseWithChannelInactive() throws { // Here http_parser is right, and this is EOF terminated. - try assertResponseTransferEncodingHasBodyTerminatedByEOF(transferEncodingHeader: "chunked, gzip", eofMechanism: .channelInactive) + try assertResponseTransferEncodingHasBodyTerminatedByEOF( + transferEncodingHeader: "chunked, gzip", + eofMechanism: .channelInactive + ) } func testMultipleTEWithChunkedInTheMiddleHasEOFBodyOnResponseWithChannelInactive() throws { // Here http_parser is right, and this is EOF terminated. - try assertResponseTransferEncodingHasBodyTerminatedByEOF(transferEncodingHeader: "gzip, chunked, deflate", eofMechanism: .channelInactive) + try assertResponseTransferEncodingHasBodyTerminatedByEOF( + transferEncodingHeader: "gzip, chunked, deflate", + eofMechanism: .channelInactive + ) } func testMultipleTEWithChunkedLastHasEOFBodyOnResponseWithHalfClosure() throws { - try assertResponseTransferEncodingHasBodyTerminatedByEndOfChunk(transferEncodingHeader: "gzip, chunked", eofMechanism: .halfClosure) + try assertResponseTransferEncodingHasBodyTerminatedByEndOfChunk( + transferEncodingHeader: "gzip, chunked", + eofMechanism: .halfClosure + ) } func testMultipleTEWithChunkedFirstHasEOFBodyOnResponseWithHalfClosure() throws { // Here http_parser is right, and this is EOF terminated. - try assertResponseTransferEncodingHasBodyTerminatedByEOF(transferEncodingHeader: "chunked, gzip", eofMechanism: .halfClosure) + try assertResponseTransferEncodingHasBodyTerminatedByEOF( + transferEncodingHeader: "chunked, gzip", + eofMechanism: .halfClosure + ) } func testMultipleTEWithChunkedInTheMiddleHasEOFBodyOnResponseWithHalfClosure() throws { // Here http_parser is right, and this is EOF terminated. - try assertResponseTransferEncodingHasBodyTerminatedByEOF(transferEncodingHeader: "gzip, chunked, deflate", eofMechanism: .halfClosure) + try assertResponseTransferEncodingHasBodyTerminatedByEOF( + transferEncodingHeader: "gzip, chunked, deflate", + eofMechanism: .halfClosure + ) } func testRequestWithTEAndContentLengthErrors() throws { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPRequestDecoder()))) // Send a GET with the invalid headers. - let request = channel.allocator.buffer(string: "POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\nContent-Length: 4\r\n\r\n") + let request = channel.allocator.buffer( + string: "POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\nContent-Length: 4\r\n\r\n" + ) XCTAssertThrowsError(try channel.writeInbound(request)) { error in XCTAssertEqual(HTTPParserError.unexpectedContentLength, error as? HTTPParserError) } @@ -462,9 +577,17 @@ class HTTPDecoderLengthTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))) // Prime the decoder with a request. - XCTAssertTrue(try channel.writeOutbound(HTTPClientRequestPart.head(HTTPRequestHead(version: .http1_1, - method: .GET, - uri: "/"))).isFull) + XCTAssertTrue( + try channel.writeOutbound( + HTTPClientRequestPart.head( + HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/" + ) + ) + ).isFull + ) // Send a 200 OK with the invalid headers. let response = "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\nContent-Length: 4\r\n\r\n" @@ -483,8 +606,10 @@ class HTTPDecoderLengthTest: XCTestCase { // Send a GET with the invalid headers. let request = "POST / HTTP/1.1\r\nContent-Length: \(contentLengthField)\r\n\r\n" XCTAssertThrowsError(try channel.writeInbound(channel.allocator.buffer(string: request))) { error in - XCTAssert(HTTPParserError.unexpectedContentLength == error as? HTTPParserError || - HTTPParserError.invalidContentLength == error as? HTTPParserError) + XCTAssert( + HTTPParserError.unexpectedContentLength == error as? HTTPParserError + || HTTPParserError.invalidContentLength == error as? HTTPParserError + ) } // Must spin the loop. @@ -544,7 +669,9 @@ class HTTPDecoderLengthTest: XCTestCase { // Send a POST without a length field of any kind. This should be a zero-length request, // so .end should come immediately. - XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: "POST / HTTP/1.1\r\nHost: example.org\r\n\r\n"))) + XCTAssertNoThrow( + try channel.writeInbound(channel.allocator.buffer(string: "POST / HTTP/1.1\r\nHost: example.org\r\n\r\n")) + ) XCTAssert(handler.seenHead) XCTAssertFalse(handler.seenBody) XCTAssert(handler.seenEnd) diff --git a/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift b/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift index b6d726808a..da74c7685d 100644 --- a/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift @@ -12,16 +12,16 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOHTTP1 import NIOTestUtils +import XCTest class HTTPDecoderTest: XCTestCase { private var channel: EmbeddedChannel! private var loop: EmbeddedEventLoop { - return self.channel.embeddedEventLoop + self.channel.embeddedEventLoop } override func setUp() { @@ -41,11 +41,11 @@ class HTTPDecoderTest: XCTestCase { // actually parse this at all. var buffer = channel.allocator.buffer(capacity: 64) buffer.writeStaticString("GET /a-file\r\n\r\n") - + XCTAssertThrowsError(try channel.writeInbound(buffer)) { error in XCTAssertEqual(.invalidVersion, error as? HTTPParserError) } - + self.loop.run() } @@ -83,7 +83,9 @@ class HTTPDecoderTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))) // We need to prime the decoder by seeing a GET request. - try channel.writeOutbound(HTTPClientRequestPart.head(HTTPRequestHead(version: .http0_9, method: .GET, uri: "/"))) + try channel.writeOutbound( + HTTPClientRequestPart.head(HTTPRequestHead(version: .http0_9, method: .GET, uri: "/")) + ) // The HTTP parser has no special logic for HTTP/0.9 simple responses, but we'll send // one anyway just to prove it explodes. @@ -102,7 +104,9 @@ class HTTPDecoderTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPResponseDecoder()))) // We need to prime the decoder by seeing a GET request. - try channel.writeOutbound(HTTPClientRequestPart.head(HTTPRequestHead(version: .http0_9, method: .GET, uri: "/"))) + try channel.writeOutbound( + HTTPClientRequestPart.head(HTTPRequestHead(version: .http0_9, method: .GET, uri: "/")) + ) // The HTTP parser rejects HTTP/1.1-formatted responses claiming 0.9 as a version. var buffer = channel.allocator.buffer(capacity: 64) @@ -168,7 +172,7 @@ class HTTPDecoderTest: XCTestCase { written += buffer2.writeStaticString("X-Header: value\r\n") try channel.writeInbound(buffer2) - } while written < 8192 // Use a value that w + } while written < 8192 // Use a value that w var buffer3 = channel.allocator.buffer(capacity: 2) buffer3.writeStaticString("\r\n") @@ -192,17 +196,25 @@ class HTTPDecoderTest: XCTestCase { } } } - XCTAssertNoThrow(try - channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPRequestDecoder()), - name: "decoder")) + XCTAssertNoThrow( + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(HTTPRequestDecoder()), + name: "decoder" + ) + ) XCTAssertNoThrow(try channel.pipeline.addHandler(Receiver()).wait()) var buffer = channel.allocator.buffer(capacity: 64) - buffer.writeStaticString("OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\nXXXX") + buffer.writeStaticString( + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\nXXXX" + ) + // allow the event loop to run (removal is not synchronous here) XCTAssertNoThrow(try channel.writeInbound(buffer)) - (channel.eventLoop as! EmbeddedEventLoop).run() // allow the event loop to run (removal is not synchronous here) - XCTAssertNoThrow(try channel.pipeline.assertDoesNotContain(handlerType: ByteToMessageHandler.self)) + (channel.eventLoop as! EmbeddedEventLoop).run() + XCTAssertNoThrow( + try channel.pipeline.assertDoesNotContain(handlerType: ByteToMessageHandler.self) + ) XCTAssertNoThrow(try channel.finish()) } @@ -248,48 +260,57 @@ class HTTPDecoderTest: XCTestCase { } } } - XCTAssertNoThrow(try - channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)), - name: "decoder")) + XCTAssertNoThrow( + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .forwardBytes)), + name: "decoder" + ) + ) XCTAssertNoThrow(try channel.pipeline.addHandler(Receiver()).wait()) - // This connect call is semantically wrong, but it's how you active embedded channels properly right now. + // This connect call is semantically wrong, but it's how you + // active embedded channels properly right now. XCTAssertNoThrow(try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 8888)).wait()) var buffer = channel.allocator.buffer(capacity: 64) - buffer.writeStaticString("OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\nXXXX") + buffer.writeStaticString( + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\nXXXX" + ) XCTAssertNoThrow(try channel.writeInbound(buffer)) - (channel.eventLoop as! EmbeddedEventLoop).run() // allow the event loop to run (removal is not synchrnous here) - XCTAssertNoThrow(try channel.pipeline.assertDoesNotContain(handlerType: ByteToMessageHandler.self)) + // allow the event loop to run (removal is not synchrnous here) + (channel.eventLoop as! EmbeddedEventLoop).run() + XCTAssertNoThrow( + try channel.pipeline.assertDoesNotContain(handlerType: ByteToMessageHandler.self) + ) XCTAssertNoThrow(try channel.finish()) } - + func testDontDropExtraBytesResponse() throws { class ByteCollector: ChannelInboundHandler { typealias InboundIn = ByteBuffer var called: Bool = false - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { var buffer = Self.unwrapInboundIn(data) XCTAssertEqual("XXXX", buffer.readString(length: buffer.readableBytes)!) self.called = true } - + func handlerAdded(context: ChannelHandlerContext) { _ = context.pipeline.removeHandler(name: "decoder") } - + func handlerRemoved(context: ChannelHandlerContext) { XCTAssert(self.called) } } - + class Receiver: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = HTTPClientResponsePart typealias InboundOut = HTTPClientResponsePart typealias OutboundOut = HTTPClientRequestPart - + func channelActive(context: ChannelHandlerContext) { var upgradeReq = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") upgradeReq.headers.add(name: "Connection", value: "Upgrade") @@ -298,7 +319,7 @@ class HTTPDecoderTest: XCTestCase { context.write(wrapOutboundOut(.head(upgradeReq)), promise: nil) context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil) } - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { let part = Self.unwrapInboundIn(data) switch part { @@ -313,19 +334,28 @@ class HTTPDecoderTest: XCTestCase { } } } - - XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)), - name: "decoder")) + + XCTAssertNoThrow( + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)), + name: "decoder" + ) + ) XCTAssertNoThrow(try channel.pipeline.addHandler(Receiver()).wait()) - + XCTAssertNoThrow(try channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 8888)).wait()) - + var buffer = channel.allocator.buffer(capacity: 32) - buffer.writeStaticString("HTTP/1.1 101 Switching Protocols\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\nXXXX") - + buffer.writeStaticString( + "HTTP/1.1 101 Switching Protocols\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\nXXXX" + ) + XCTAssertNoThrow(try channel.writeInbound(buffer)) - (channel.eventLoop as! EmbeddedEventLoop).run() // allow the event loop to run (removal is not synchrnous here) - XCTAssertNoThrow(try channel.pipeline.assertDoesNotContain(handlerType: ByteToMessageHandler.self)) + // allow the event loop to run (removal is not synchronous here) + (channel.eventLoop as! EmbeddedEventLoop).run() + XCTAssertNoThrow( + try channel.pipeline.assertDoesNotContain(handlerType: ByteToMessageHandler.self) + ) XCTAssertNoThrow(try channel.finish()) } @@ -394,7 +424,7 @@ class HTTPDecoderTest: XCTestCase { // changed in https://github.com/nodejs/http-parser/pull/432 . var buffer = channel.allocator.buffer(capacity: 64) buffer.writeStaticString("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") - buffer.writeStaticString("\r") // this is extra + buffer.writeStaticString("\r") // this is extra buffer.writeStaticString("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") try channel.writeInbound(buffer) @@ -459,8 +489,12 @@ class HTTPDecoderTest: XCTestCase { var expectedHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") expectedHead.headers.add(name: "foo", value: "bär") - XCTAssertNoThrow(XCTAssertEqual(.head(expectedHead), - try writeToFreshRequestDecoderChannel("GET / HTTP/1.1\r\nfoo: bär\r\n\r\n"))) + XCTAssertNoThrow( + XCTAssertEqual( + .head(expectedHead), + try writeToFreshRequestDecoderChannel("GET / HTTP/1.1\r\nfoo: bär\r\n\r\n") + ) + ) } func testDoesNotDeliverLeftoversUnnecessarily() { @@ -475,7 +509,11 @@ class HTTPDecoderTest: XCTestCase { var dataBuffer = channel.allocator.buffer(capacity: 128) dataBuffer.writeStaticString(data) - XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .fireError)))) + XCTAssertNoThrow( + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .fireError)) + ) + ) XCTAssertNoThrow(try channel.writeInbound(dataBuffer.getSlice(at: 0, length: dataBuffer.readableBytes - 6)!)) XCTAssertNoThrow(try channel.writeInbound(dataBuffer.getSlice(at: dataBuffer.readableBytes - 6, length: 6)!)) @@ -492,51 +530,109 @@ class HTTPDecoderTest: XCTestCase { var buffer = channel.allocator.buffer(capacity: 128) buffer.writeStaticString("HTTP/1.0 200 ok\r\n\r\n") - XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .fireError)))) - XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(.init(version: .http1_1, - method: .GET, uri: "/")))) + XCTAssertNoThrow( + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .fireError)) + ) + ) + XCTAssertNoThrow( + try channel.writeOutbound( + HTTPClientRequestPart.head( + .init( + version: .http1_1, + method: .GET, + uri: "/" + ) + ) + ) + ) XCTAssertNoThrow(try channel.writeInbound(buffer)) - XCTAssertNoThrow(XCTAssertEqual(HTTPClientResponsePart.head(.init(version: .http1_0, - status: .ok)), try channel.readInbound())) + XCTAssertNoThrow( + XCTAssertEqual( + HTTPClientResponsePart.head( + .init( + version: .http1_0, + status: .ok + ) + ), + try channel.readInbound() + ) + ) } func testBasicVerifications() { let byteBufferContainingJustAnX = ByteBuffer(string: "X") let expectedInOuts: [(String, [HTTPServerRequestPart])] = [ - ("GET / HTTP/1.1\r\n\r\n", - [.head(.init(version: .http1_1, method: .GET, uri: "/")), - .end(nil)]), - ("POST /foo HTTP/1.1\r\n\r\n", - [.head(.init(version: .http1_1, method: .POST, uri: "/foo")), - .end(nil)]), - ("POST / HTTP/1.1\r\ncontent-length: 1\r\n\r\nX", - [.head(.init(version: .http1_1, - method: .POST, - uri: "/", - headers: .init([("content-length", "1")]))), - .body(byteBufferContainingJustAnX), - .end(nil)]), - ("POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n1\r\nX\r\n0\r\n\r\n", - [.head(.init(version: .http1_1, - method: .POST, - uri: "/", - headers: .init([("transfer-encoding", "chunked")]))), - .body(byteBufferContainingJustAnX), - .end(nil)]), - ("POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\none: two\r\n\r\n1\r\nX\r\n0\r\nfoo: bar\r\n\r\n", - [.head(.init(version: .http1_1, - method: .POST, - uri: "/", - headers: .init([("transfer-encoding", "chunked"), ("one", "two")]))), - .body(byteBufferContainingJustAnX), - .end(.init([("foo", "bar")]))]), + ( + "GET / HTTP/1.1\r\n\r\n", + [ + .head(.init(version: .http1_1, method: .GET, uri: "/")), + .end(nil), + ] + ), + ( + "POST /foo HTTP/1.1\r\n\r\n", + [ + .head(.init(version: .http1_1, method: .POST, uri: "/foo")), + .end(nil), + ] + ), + ( + "POST / HTTP/1.1\r\ncontent-length: 1\r\n\r\nX", + [ + .head( + .init( + version: .http1_1, + method: .POST, + uri: "/", + headers: .init([("content-length", "1")]) + ) + ), + .body(byteBufferContainingJustAnX), + .end(nil), + ] + ), + ( + "POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n1\r\nX\r\n0\r\n\r\n", + [ + .head( + .init( + version: .http1_1, + method: .POST, + uri: "/", + headers: .init([("transfer-encoding", "chunked")]) + ) + ), + .body(byteBufferContainingJustAnX), + .end(nil), + ] + ), + ( + "POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\none: two\r\n\r\n1\r\nX\r\n0\r\nfoo: bar\r\n\r\n", + [ + .head( + .init( + version: .http1_1, + method: .POST, + uri: "/", + headers: .init([("transfer-encoding", "chunked"), ("one", "two")]) + ) + ), + .body(byteBufferContainingJustAnX), + .end(.init([("foo", "bar")])), + ] + ), ] let expectedInOutsBB: [(ByteBuffer, [HTTPServerRequestPart])] = expectedInOuts.map { io in - return (ByteBuffer(string: io.0), io.1) - } - XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder(inputOutputPairs: expectedInOutsBB, - decoderFactory: { HTTPRequestDecoder() })) + (ByteBuffer(string: io.0), io.1) + } + XCTAssertNoThrow( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: expectedInOutsBB, + decoderFactory: { HTTPRequestDecoder() } + ) + ) } func testNothingHappensOnEOFForLeftOversInAllLeftOversModes() throws { @@ -601,14 +697,22 @@ class HTTPDecoderTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(decoder)) XCTAssertNoThrow(try channel.pipeline.addHandler(receiver).wait()) XCTAssertNoThrow(try channel.writeInbound(buffer)) - let removalFutures = [ channel.pipeline.syncOperations.removeHandler(receiver), channel.pipeline.syncOperations.removeHandler(decoder) ] + let removalFutures = [ + channel.pipeline.syncOperations.removeHandler(receiver), + channel.pipeline.syncOperations.removeHandler(decoder), + ] channel.embeddedEventLoop.run() - try removalFutures.forEach { - XCTAssertNoThrow(try $0.wait()) - } - XCTAssertNoThrow(XCTAssertEqual("XXXX", try channel.readInbound(as: ByteBuffer.self).map { - String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) - })) + for future in removalFutures { + XCTAssertNoThrow(try future.wait()) + } + XCTAssertNoThrow( + XCTAssertEqual( + "XXXX", + try channel.readInbound(as: ByteBuffer.self).map { + String(decoding: $0.readableBytesView, as: Unicode.UTF8.self) + } + ) + ) XCTAssertNoThrow(XCTAssert(try channel.finish().isClean)) } @@ -638,10 +742,13 @@ class HTTPDecoderTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(decoder)) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(receiver)) XCTAssertNoThrow(try channel.writeInbound(buffer)) - let removalFutures = [ channel.pipeline.syncOperations.removeHandler(receiver), channel.pipeline.syncOperations.removeHandler(decoder) ] + let removalFutures = [ + channel.pipeline.syncOperations.removeHandler(receiver), + channel.pipeline.syncOperations.removeHandler(decoder), + ] channel.embeddedEventLoop.run() - try removalFutures.forEach { - XCTAssertNoThrow(try $0.wait()) + for future in removalFutures { + XCTAssertNoThrow(try future.wait()) } XCTAssertThrowsError(try channel.throwIfErrorCaught()) { error in switch error as? ByteToMessageDecoderError { @@ -682,10 +789,13 @@ class HTTPDecoderTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(decoder)) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(receiver)) XCTAssertNoThrow(try channel.writeInbound(buffer)) - let removalFutures = [ channel.pipeline.syncOperations.removeHandler(receiver), channel.pipeline.syncOperations.removeHandler(decoder) ] + let removalFutures = [ + channel.pipeline.syncOperations.removeHandler(receiver), + channel.pipeline.syncOperations.removeHandler(decoder), + ] channel.embeddedEventLoop.run() - try removalFutures.forEach { - XCTAssertNoThrow(try $0.wait()) + for future in removalFutures { + XCTAssertNoThrow(try future.wait()) } XCTAssertNoThrow(XCTAssert(try channel.finish().isClean)) } @@ -742,29 +852,53 @@ class HTTPDecoderTest: XCTestCase { let channel = EmbeddedChannel(handler: responseDecoder) XCTAssertNoThrow(try channel.pipeline.addHandler(eventCounter).wait()) - XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(.init(version: .http1_1, - method: .GET, uri: "/")))) + XCTAssertNoThrow( + try channel.writeOutbound( + HTTPClientRequestPart.head( + .init( + version: .http1_1, + method: .GET, + uri: "/" + ) + ) + ) + ) var buffer = channel.allocator.buffer(capacity: 128) buffer.writeString("HTTP/1.1 200 ok\r\ncontent-length: 0\r\n\r\nHTTP/1.1 200 ok\r\ncontent-length: 0\r\n\r\n") XCTAssertThrowsError(try channel.writeInbound(buffer)) { error in XCTAssertEqual(.unsolicitedResponse, error as? NIOHTTPDecoderError) } - XCTAssertNoThrow(XCTAssertEqual(.head(.init(version: .http1_1, - status: .ok, - headers: ["content-length": "0"])), - try channel.readInbound(as: HTTPClientResponsePart.self))) - XCTAssertNoThrow(XCTAssertEqual(.end(nil), - try channel.readInbound(as: HTTPClientResponsePart.self))) + XCTAssertNoThrow( + XCTAssertEqual( + .head( + .init( + version: .http1_1, + status: .ok, + headers: ["content-length": "0"] + ) + ), + try channel.readInbound(as: HTTPClientResponsePart.self) + ) + ) + XCTAssertNoThrow( + XCTAssertEqual( + .end(nil), + try channel.readInbound(as: HTTPClientResponsePart.self) + ) + ) XCTAssertNoThrow(XCTAssertNil(try channel.readInbound(as: HTTPClientResponsePart.self))) XCTAssertNoThrow(XCTAssertNotNil(try channel.readOutbound())) XCTAssertEqual(1, eventCounter.writeCalls) XCTAssertEqual(1, eventCounter.flushCalls) - XCTAssertEqual(2, eventCounter.channelReadCalls) // .head & .end + XCTAssertEqual(2, eventCounter.channelReadCalls) // .head & .end XCTAssertEqual(1, eventCounter.channelReadCompleteCalls) - XCTAssertEqual(["channelReadComplete", "write", "flush", "channelRead", "errorCaught"], eventCounter.allTriggeredEvents()) + XCTAssertEqual( + ["channelReadComplete", "write", "flush", "channelRead", "errorCaught"], + eventCounter.allTriggeredEvents() + ) XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean)) } - + func testForwardContinueThenResponse() { let eventCounter = EventCounterHandler() let decoder = HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes, informationalResponseStrategy: .forward) @@ -777,21 +911,27 @@ class HTTPDecoderTest: XCTestCase { var buffer = channel.allocator.buffer(capacity: 128) buffer.writeString("HTTP/1.1 100 continue\r\n\r\nHTTP/1.1 200 ok\r\ncontent-length: 0\r\n\r\n") XCTAssertNoThrow(try channel.writeInbound(buffer)) - - XCTAssertEqual(try channel.readInbound(as: HTTPClientResponsePart.self), .head(.init(version: .http1_1, status: .continue))) - XCTAssertEqual(try channel.readInbound(as: HTTPClientResponsePart.self), .head(.init(version: .http1_1, status: .ok, headers: ["content-length": "0"]))) + + XCTAssertEqual( + try channel.readInbound(as: HTTPClientResponsePart.self), + .head(.init(version: .http1_1, status: .continue)) + ) + XCTAssertEqual( + try channel.readInbound(as: HTTPClientResponsePart.self), + .head(.init(version: .http1_1, status: .ok, headers: ["content-length": "0"])) + ) XCTAssertEqual(.end(nil), try channel.readInbound(as: HTTPClientResponsePart.self)) XCTAssertNil(try channel.readInbound(as: HTTPClientResponsePart.self)) XCTAssertNotNil(try channel.readOutbound()) - + XCTAssertEqual(1, eventCounter.writeCalls) XCTAssertEqual(1, eventCounter.flushCalls) - XCTAssertEqual(3, eventCounter.channelReadCalls) // .head, .head & .end + XCTAssertEqual(3, eventCounter.channelReadCalls) // .head, .head & .end XCTAssertEqual(1, eventCounter.channelReadCompleteCalls) XCTAssertEqual(["channelReadComplete", "channelRead", "write", "flush"], eventCounter.allTriggeredEvents()) XCTAssertNoThrow(XCTAssertTrue(try channel.finish().isClean)) } - + func testForwardMultipleContinuesThenResponse() { let eventCounter = EventCounterHandler() let decoder = HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes, informationalResponseStrategy: .forward) @@ -806,25 +946,31 @@ class HTTPDecoderTest: XCTestCase { for _ in 0..?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" + let weirdAllowedFieldValue = + "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" - XCTAssertNoThrow(try self.channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPRequestDecoder()))) - let goodRequest = ByteBuffer(string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Field: \(weirdAllowedFieldValue)\r\n\r\n") + XCTAssertNoThrow( + try self.channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPRequestDecoder())) + ) + let goodRequest = ByteBuffer( + string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Field: \(weirdAllowedFieldValue)\r\n\r\n" + ) XCTAssertNoThrow(try self.channel.writeInbound(goodRequest)) @@ -1162,7 +1333,9 @@ class HTTPDecoderTest: XCTestCase { } let forbiddenFieldValue = weirdAllowedFieldValue + String(decoding: [byte], as: UTF8.self) let channel = EmbeddedChannel(handler: ByteToMessageHandler(HTTPRequestDecoder())) - let badRequest = ByteBuffer(string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Field: \(forbiddenFieldValue)\r\n\r\n") + let badRequest = ByteBuffer( + string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Field: \(forbiddenFieldValue)\r\n\r\n" + ) XCTAssertThrowsError( try channel.writeInbound(badRequest), @@ -1177,7 +1350,9 @@ class HTTPDecoderTest: XCTestCase { for byte in UInt8(128)...UInt8(255) { let evenWeirderAllowedValue = weirdAllowedFieldValue + String(decoding: [byte], as: UTF8.self) let channel = EmbeddedChannel(handler: ByteToMessageHandler(HTTPRequestDecoder())) - let weirdGoodRequest = ByteBuffer(string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Field: \(evenWeirderAllowedValue)\r\n\r\n") + let weirdGoodRequest = ByteBuffer( + string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Field: \(evenWeirderAllowedValue)\r\n\r\n" + ) XCTAssertNoThrow(try channel.writeInbound(weirdGoodRequest)) XCTAssertNoThrow(maybeHead = try channel.readInbound()) @@ -1194,11 +1369,16 @@ class HTTPDecoderTest: XCTestCase { func testDecodingInvalidTrailerFieldValues() throws { // We reject all ASCII control characters except HTAB and tolerate everything else. - let weirdAllowedFieldValue = "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" + let weirdAllowedFieldValue = + "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" - let request = ByteBuffer(string: "POST / HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n") + let request = ByteBuffer( + string: "POST / HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n" + ) - XCTAssertNoThrow(try self.channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPRequestDecoder()))) + XCTAssertNoThrow( + try self.channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPRequestDecoder())) + ) let goodTrailers = ByteBuffer(string: "Weird-Field: \(weirdAllowedFieldValue)\r\n\r\n") XCTAssertNoThrow(try self.channel.writeInbound(request)) @@ -1292,7 +1472,9 @@ class HTTPDecoderTest: XCTestCase { XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.end(nil))) // Send a response. - let goodResponseWithContent = ByteBuffer(string: "HTTP/1.1 200 OK\r\nServer: foo\r\nContent-Length: 4\r\n\r\nGood") + let goodResponseWithContent = ByteBuffer( + string: "HTTP/1.1 200 OK\r\nServer: foo\r\nContent-Length: 4\r\n\r\nGood" + ) XCTAssertNoThrow(try channel.writeInbound(goodResponseWithContent)) var maybeBody: HTTPClientResponsePart? @@ -1356,7 +1538,9 @@ class HTTPDecoderTest: XCTestCase { XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.end(nil))) // Send a response. - let goodResponseWithContent = ByteBuffer(string: "HTTP/1.1 200 OK\r\nServer: foo\r\nContent-Length: 4\r\n\r\nGood") + let goodResponseWithContent = ByteBuffer( + string: "HTTP/1.1 200 OK\r\nServer: foo\r\nContent-Length: 4\r\n\r\nGood" + ) XCTAssertNoThrow(try channel.writeInbound(goodResponseWithContent)) var maybeBody: HTTPClientResponsePart? diff --git a/Tests/NIOHTTP1Tests/HTTPHeaderValidationTests.swift b/Tests/NIOHTTP1Tests/HTTPHeaderValidationTests.swift index e8f6b92ce9..4665db11a3 100644 --- a/Tests/NIOHTTP1Tests/HTTPHeaderValidationTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPHeaderValidationTests.swift @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// -import XCTest import Dispatch import NIOCore import NIOEmbedded import NIOHTTP1 +import XCTest final class HTTPHeaderValidationTests: XCTestCase { func testEncodingInvalidHeaderFieldNamesInRequests() throws { @@ -39,7 +39,9 @@ final class HTTPHeaderValidationTests: XCTestCase { let headers = HTTPHeaders([("Host", "example.com"), (weirdAllowedFieldName, "present")]) let goodRequest = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: headers) - let goodRequestBytes = ByteBuffer(string: "GET / HTTP/1.1\r\nHost: example.com\r\n\(weirdAllowedFieldName): present\r\n\r\n") + let goodRequestBytes = ByteBuffer( + string: "GET / HTTP/1.1\r\nHost: example.com\r\n\(weirdAllowedFieldName): present\r\n\r\n" + ) XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(goodRequest))) XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.end(nil))) @@ -93,7 +95,9 @@ final class HTTPHeaderValidationTests: XCTestCase { let headers = HTTPHeaders([("Host", "example.com"), ("Transfer-Encoding", "chunked")]) let goodRequest = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: headers) - let goodRequestBytes = ByteBuffer(string: "POST / HTTP/1.1\r\nHost: example.com\r\ntransfer-encoding: chunked\r\n\r\n") + let goodRequestBytes = ByteBuffer( + string: "POST / HTTP/1.1\r\nHost: example.com\r\ntransfer-encoding: chunked\r\n\r\n" + ) let goodTrailers = ByteBuffer(string: "0\r\n\(weirdAllowedFieldName): present\r\n\r\n") XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(goodRequest))) @@ -132,14 +136,17 @@ final class HTTPHeaderValidationTests: XCTestCase { func testEncodingInvalidHeaderFieldValuesInRequests() throws { // We reject all ASCII control characters except HTAB and tolerate everything else. - let weirdAllowedFieldValue = "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" + let weirdAllowedFieldValue = + "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" let channel = EmbeddedChannel() try channel.pipeline.syncOperations.addHTTPClientHandlers() let headers = HTTPHeaders([("Host", "example.com"), ("Weird-Value", weirdAllowedFieldValue)]) let goodRequest = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: headers) - let goodRequestBytes = ByteBuffer(string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Value: \(weirdAllowedFieldValue)\r\n\r\n") + let goodRequestBytes = ByteBuffer( + string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Value: \(weirdAllowedFieldValue)\r\n\r\n" + ) XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(goodRequest))) @@ -180,7 +187,9 @@ final class HTTPHeaderValidationTests: XCTestCase { let headers = HTTPHeaders([("Host", "example.com"), ("Weird-Value", evenWeirderAllowedValue)]) let goodRequest = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: headers) - let goodRequestBytes = ByteBuffer(string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Value: \(evenWeirderAllowedValue)\r\n\r\n") + let goodRequestBytes = ByteBuffer( + string: "GET / HTTP/1.1\r\nHost: example.com\r\nWeird-Value: \(evenWeirderAllowedValue)\r\n\r\n" + ) XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(goodRequest))) @@ -195,14 +204,17 @@ final class HTTPHeaderValidationTests: XCTestCase { func testEncodingInvalidTrailerFieldValuesInRequests() throws { // We reject all ASCII control characters except HTAB and tolerate everything else. - let weirdAllowedFieldValue = "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" + let weirdAllowedFieldValue = + "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" let channel = EmbeddedChannel() try channel.pipeline.syncOperations.addHTTPClientHandlers() let headers = HTTPHeaders([("Host", "example.com"), ("Transfer-Encoding", "chunked")]) let goodRequest = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: headers) - let goodRequestBytes = ByteBuffer(string: "POST / HTTP/1.1\r\nHost: example.com\r\ntransfer-encoding: chunked\r\n\r\n") + let goodRequestBytes = ByteBuffer( + string: "POST / HTTP/1.1\r\nHost: example.com\r\ntransfer-encoding: chunked\r\n\r\n" + ) let goodTrailers = ByteBuffer(string: "0\r\nWeird-Value: \(weirdAllowedFieldValue)\r\n\r\n") XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(goodRequest))) @@ -227,7 +239,6 @@ final class HTTPHeaderValidationTests: XCTestCase { let channel = EmbeddedChannel() try channel.pipeline.syncOperations.addHTTPClientHandlers() - XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(goodRequest))) XCTAssertThrowsError( @@ -249,7 +260,9 @@ final class HTTPHeaderValidationTests: XCTestCase { let weirdGoodTrailers = ByteBuffer(string: "0\r\nWeird-Value: \(evenWeirderAllowedValue)\r\n\r\n") XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(goodRequest))) - XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.end(["Weird-Value": evenWeirderAllowedValue]))) + XCTAssertNoThrow( + try channel.writeOutbound(HTTPClientRequestPart.end(["Weird-Value": evenWeirderAllowedValue])) + ) XCTAssertNoThrow(maybeRequestHeadBytes = try channel.readOutbound()) XCTAssertNoThrow(maybeRequestEndBytes = try channel.readOutbound()) XCTAssertEqual(maybeRequestHeadBytes, goodRequestBytes) @@ -280,7 +293,9 @@ final class HTTPHeaderValidationTests: XCTestCase { let headers = HTTPHeaders([("Content-Length", "0"), (weirdAllowedFieldName, "present")]) let goodResponse = HTTPResponseHead(version: .http1_1, status: .ok, headers: headers) - let goodResponseBytes = ByteBuffer(string: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\(weirdAllowedFieldName): present\r\n\r\n") + let goodResponseBytes = ByteBuffer( + string: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\(weirdAllowedFieldName): present\r\n\r\n" + ) XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(goodResponse))) XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.end(nil))) @@ -376,7 +391,8 @@ final class HTTPHeaderValidationTests: XCTestCase { func testEncodingInvalidHeaderFieldValuesInResponses() throws { // We reject all ASCII control characters except HTAB and tolerate everything else. - let weirdAllowedFieldValue = "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" + let weirdAllowedFieldValue = + "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" let channel = EmbeddedChannel() try channel.pipeline.syncOperations.configureHTTPServerPipeline(withErrorHandling: false) @@ -384,7 +400,9 @@ final class HTTPHeaderValidationTests: XCTestCase { let headers = HTTPHeaders([("Content-Length", "0"), ("Weird-Value", weirdAllowedFieldValue)]) let goodResponse = HTTPResponseHead(version: .http1_1, status: .ok, headers: headers) - let goodResponseBytes = ByteBuffer(string: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nWeird-Value: \(weirdAllowedFieldValue)\r\n\r\n") + let goodResponseBytes = ByteBuffer( + string: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nWeird-Value: \(weirdAllowedFieldValue)\r\n\r\n" + ) XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(goodResponse))) @@ -427,7 +445,9 @@ final class HTTPHeaderValidationTests: XCTestCase { let headers = HTTPHeaders([("Content-Length", "0"), ("Weird-Value", evenWeirderAllowedValue)]) let goodResponse = HTTPResponseHead(version: .http1_1, status: .ok, headers: headers) - let goodResponseBytes = ByteBuffer(string: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nWeird-Value: \(evenWeirderAllowedValue)\r\n\r\n") + let goodResponseBytes = ByteBuffer( + string: "HTTP/1.1 200 OK\r\nContent-Length: 0\r\nWeird-Value: \(evenWeirderAllowedValue)\r\n\r\n" + ) XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(goodResponse))) @@ -442,7 +462,8 @@ final class HTTPHeaderValidationTests: XCTestCase { func testEncodingInvalidTrailerFieldValuesInResponses() throws { // We reject all ASCII control characters except HTAB and tolerate everything else. - let weirdAllowedFieldValue = "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" + let weirdAllowedFieldValue = + "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" let channel = EmbeddedChannel() try channel.pipeline.syncOperations.configureHTTPServerPipeline(withErrorHandling: false) @@ -498,7 +519,9 @@ final class HTTPHeaderValidationTests: XCTestCase { let weirdGoodTrailers = ByteBuffer(string: "0\r\nWeird-Value: \(evenWeirderAllowedValue)\r\n\r\n") XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(goodResponse))) - XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.end(["Weird-Value": evenWeirderAllowedValue]))) + XCTAssertNoThrow( + try channel.writeOutbound(HTTPServerResponsePart.end(["Weird-Value": evenWeirderAllowedValue])) + ) XCTAssertNoThrow(maybeResponseHeadBytes = try channel.readOutbound()) XCTAssertNoThrow(maybeResponseEndBytes = try channel.readOutbound()) XCTAssertEqual(maybeResponseHeadBytes, goodResponseBytes) @@ -515,10 +538,18 @@ final class HTTPHeaderValidationTests: XCTestCase { let channel = EmbeddedChannel() try channel.pipeline.syncOperations.addHTTPClientHandlers(enableOutboundHeaderValidation: false) - let headers = HTTPHeaders([("Host", "example.com"), ("Transfer-Encoding", "chunked"), (invalidHeaderName, invalidHeaderValue)]) + let headers = HTTPHeaders([ + ("Host", "example.com"), ("Transfer-Encoding", "chunked"), (invalidHeaderName, invalidHeaderValue), + ]) let toleratedRequest = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: headers) - let toleratedRequestBytes = ByteBuffer(string: "POST / HTTP/1.1\r\nHost: example.com\r\n\(invalidHeaderName): \(invalidHeaderValue)\r\ntransfer-encoding: chunked\r\n\r\n") - let toleratedTrailerBytes = ByteBuffer(string: "0\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\(invalidHeaderName): \(invalidHeaderValue)\r\n\r\n") + let toleratedRequestBytes = ByteBuffer( + string: + "POST / HTTP/1.1\r\nHost: example.com\r\n\(invalidHeaderName): \(invalidHeaderValue)\r\ntransfer-encoding: chunked\r\n\r\n" + ) + let toleratedTrailerBytes = ByteBuffer( + string: + "0\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\(invalidHeaderName): \(invalidHeaderValue)\r\n\r\n" + ) XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(toleratedRequest))) XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.end(headers))) @@ -537,13 +568,24 @@ final class HTTPHeaderValidationTests: XCTestCase { let invalidHeaderValue = "HeaderValueWith\rCR" let channel = EmbeddedChannel() - try channel.pipeline.syncOperations.configureHTTPServerPipeline(withErrorHandling: false, withOutboundHeaderValidation: false) + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withErrorHandling: false, + withOutboundHeaderValidation: false + ) try channel.primeForResponse() - let headers = HTTPHeaders([("Host", "example.com"), ("Transfer-Encoding", "chunked"), (invalidHeaderName, invalidHeaderValue)]) + let headers = HTTPHeaders([ + ("Host", "example.com"), ("Transfer-Encoding", "chunked"), (invalidHeaderName, invalidHeaderValue), + ]) let toleratedRequest = HTTPResponseHead(version: .http1_1, status: .ok, headers: headers) - let toleratedRequestBytes = ByteBuffer(string: "HTTP/1.1 200 OK\r\nHost: example.com\r\n\(invalidHeaderName): \(invalidHeaderValue)\r\ntransfer-encoding: chunked\r\n\r\n") - let toleratedTrailerBytes = ByteBuffer(string: "0\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\(invalidHeaderName): \(invalidHeaderValue)\r\n\r\n") + let toleratedRequestBytes = ByteBuffer( + string: + "HTTP/1.1 200 OK\r\nHost: example.com\r\n\(invalidHeaderName): \(invalidHeaderValue)\r\ntransfer-encoding: chunked\r\n\r\n" + ) + let toleratedTrailerBytes = ByteBuffer( + string: + "0\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\(invalidHeaderName): \(invalidHeaderValue)\r\n\r\n" + ) XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(toleratedRequest))) XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.end(headers))) diff --git a/Tests/NIOHTTP1Tests/HTTPHeadersTest.swift b/Tests/NIOHTTP1Tests/HTTPHeadersTest.swift index 028ef6b2cd..bd424f77dc 100644 --- a/Tests/NIOHTTP1Tests/HTTPHeadersTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPHeadersTest.swift @@ -12,18 +12,21 @@ // //===----------------------------------------------------------------------===// +import NIOEmbedded import XCTest + @testable import NIOCore -import NIOEmbedded @testable import NIOHTTP1 -class HTTPHeadersTest : XCTestCase { +class HTTPHeadersTest: XCTestCase { func testCasePreservedButInsensitiveLookup() { - let originalHeaders = [ ("User-Agent", "1"), - ("host", "2"), - ("X-SOMETHING", "3"), - ("SET-COOKIE", "foo=bar"), - ("Set-Cookie", "buz=cux")] + let originalHeaders = [ + ("User-Agent", "1"), + ("host", "2"), + ("X-SOMETHING", "3"), + ("SET-COOKIE", "foo=bar"), + ("Set-Cookie", "buz=cux"), + ] let headers = HTTPHeaders(originalHeaders) @@ -33,7 +36,7 @@ class HTTPHeadersTest : XCTestCase { XCTAssertEqual(["2"], headers["Host"]) XCTAssertEqual(["foo=bar", "buz=cux"], headers["set-cookie"]) - for (key,value) in headers { + for (key, value) in headers { switch key { case "User-Agent": XCTAssertEqual("1", value) @@ -52,11 +55,13 @@ class HTTPHeadersTest : XCTestCase { } func testDictionaryLiteralAlternative() { - let headers: HTTPHeaders = [ "User-Agent": "1", - "host": "2", - "X-SOMETHING": "3", - "SET-COOKIE": "foo=bar", - "Set-Cookie": "buz=cux"] + let headers: HTTPHeaders = [ + "User-Agent": "1", + "host": "2", + "X-SOMETHING": "3", + "SET-COOKIE": "foo=bar", + "Set-Cookie": "buz=cux", + ] // looking up headers value is case-insensitive XCTAssertEqual(["1"], headers["User-Agent"]) @@ -64,7 +69,7 @@ class HTTPHeadersTest : XCTestCase { XCTAssertEqual(["2"], headers["Host"]) XCTAssertEqual(["foo=bar", "buz=cux"], headers["set-cookie"]) - for (key,value) in headers { + for (key, value) in headers { switch key { case "User-Agent": XCTAssertEqual("1", value) @@ -83,12 +88,14 @@ class HTTPHeadersTest : XCTestCase { } func testWriteHeadersSeparately() { - let originalHeaders = [ ("User-Agent", "1"), - ("host", "2"), - ("X-SOMETHING", "3"), - ("X-Something", "4"), - ("SET-COOKIE", "foo=bar"), - ("Set-Cookie", "buz=cux")] + let originalHeaders = [ + ("User-Agent", "1"), + ("host", "2"), + ("X-SOMETHING", "3"), + ("X-Something", "4"), + ("SET-COOKIE", "foo=bar"), + ("Set-Cookie", "buz=cux"), + ] let headers = HTTPHeaders(originalHeaders) let channel = EmbeddedChannel() @@ -107,10 +114,12 @@ class HTTPHeadersTest : XCTestCase { } func testRevealHeadersSeparately() { - let originalHeaders = [ ("User-Agent", "1"), - ("host", "2"), - ("X-SOMETHING", "3, 4"), - ("X-Something", "5")] + let originalHeaders = [ + ("User-Agent", "1"), + ("host", "2"), + ("X-SOMETHING", "3, 4"), + ("X-Something", "5"), + ] let headers = HTTPHeaders(originalHeaders) XCTAssertEqual(headers[canonicalForm: "user-agent"], ["1"]) @@ -120,10 +129,12 @@ class HTTPHeadersTest : XCTestCase { } func testSubscriptDoesntSplitHeaders() { - let originalHeaders = [ ("User-Agent", "1"), - ("host", "2"), - ("X-SOMETHING", "3, 4"), - ("X-Something", "5")] + let originalHeaders = [ + ("User-Agent", "1"), + ("host", "2"), + ("X-SOMETHING", "3, 4"), + ("X-Something", "5"), + ] let headers = HTTPHeaders(originalHeaders) XCTAssertEqual(headers["user-agent"], ["1"]) @@ -133,16 +144,23 @@ class HTTPHeadersTest : XCTestCase { } func testCanonicalisationDoesntHappenForSetCookie() { - let originalHeaders = [ ("User-Agent", "1"), - ("host", "2"), - ("Set-Cookie", "foo=bar; expires=Sun, 17-Mar-2013 13:49:50 GMT"), - ("Set-Cookie", "buz=cux; expires=Fri, 13 Oct 2017 21:21:41 GMT")] + let originalHeaders = [ + ("User-Agent", "1"), + ("host", "2"), + ("Set-Cookie", "foo=bar; expires=Sun, 17-Mar-2013 13:49:50 GMT"), + ("Set-Cookie", "buz=cux; expires=Fri, 13 Oct 2017 21:21:41 GMT"), + ] let headers = HTTPHeaders(originalHeaders) XCTAssertEqual(headers[canonicalForm: "user-agent"], ["1"]) XCTAssertEqual(headers[canonicalForm: "host"], ["2"]) - XCTAssertEqual(headers[canonicalForm: "set-cookie"], ["foo=bar; expires=Sun, 17-Mar-2013 13:49:50 GMT", - "buz=cux; expires=Fri, 13 Oct 2017 21:21:41 GMT"]) + XCTAssertEqual( + headers[canonicalForm: "set-cookie"], + [ + "foo=bar; expires=Sun, 17-Mar-2013 13:49:50 GMT", + "buz=cux; expires=Fri, 13 Oct 2017 21:21:41 GMT", + ] + ) } func testTrimWhitespaceWorksOnEmptyString() { @@ -166,9 +184,11 @@ class HTTPHeadersTest : XCTestCase { } func testContains() { - let originalHeaders = [ ("X-Header", "1"), - ("X-SomeHeader", "3"), - ("X-Header", "2")] + let originalHeaders = [ + ("X-Header", "1"), + ("X-SomeHeader", "3"), + ("X-Header", "2"), + ] let headers = HTTPHeaders(originalHeaders) XCTAssertTrue(headers.contains(name: "x-header")) @@ -181,7 +201,7 @@ class HTTPHeadersTest : XCTestCase { (":method", "GET"), ("foo", "bar"), ("foo", "baz"), - ("custom-key", "value-1,value-2") + ("custom-key", "value-1,value-2"), ]) XCTAssertEqual(headers.first(name: ":method"), "GET") @@ -223,24 +243,30 @@ class HTTPHeadersTest : XCTestCase { } func testKeepAliveStateHasKeepAlive() { - let headers = HTTPHeaders([("Connection", "other, keEP-alive"), - ("Content-Type", "text/html"), - ("Connection", "server, x-options")]) + let headers = HTTPHeaders([ + ("Connection", "other, keEP-alive"), + ("Content-Type", "text/html"), + ("Connection", "server, x-options"), + ]) XCTAssertTrue(headers.isKeepAlive(version: .http1_1)) } func testKeepAliveStateHasClose() { - let headers = HTTPHeaders([("Connection", "x-options, other"), - ("Content-Type", "text/html"), - ("Connection", "server, clOse")]) + let headers = HTTPHeaders([ + ("Connection", "x-options, other"), + ("Content-Type", "text/html"), + ("Connection", "server, clOse"), + ]) XCTAssertFalse(headers.isKeepAlive(version: .http1_1)) } func testRandomAccess() { - let originalHeaders = [ ("X-first", "one"), - ("X-second", "two")] + let originalHeaders = [ + ("X-first", "one"), + ("X-second", "two"), + ] let headers = HTTPHeaders(originalHeaders) XCTAssertEqual(headers[headers.startIndex].name, originalHeaders.first?.0) diff --git a/Tests/NIOHTTP1Tests/HTTPRequestEncoderTest.swift b/Tests/NIOHTTP1Tests/HTTPRequestEncoderTest.swift index 0cc3eddb16..3fc4954085 100644 --- a/Tests/NIOHTTP1Tests/HTTPRequestEncoderTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPRequestEncoderTest.swift @@ -11,13 +11,15 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + +import NIOEmbedded import XCTest + @testable import NIOCore -import NIOEmbedded @testable import NIOHTTP1 -private extension ByteBuffer { - func assertContainsOnly(_ string: String) { +extension ByteBuffer { + fileprivate func assertContainsOnly(_ string: String) { let innerData = self.getString(at: self.readerIndex, length: self.readableBytes)! XCTAssertEqual(innerData, string) } @@ -64,7 +66,11 @@ class HTTPRequestEncoderTests: XCTestCase { } func testNoAutoHeadersForPOSTWhenDisabled() throws { - let writtenData = try sendRequest(withMethod: .POST, andHeaders: HTTPHeaders(), configuration: .noFramingTransformation) + let writtenData = try sendRequest( + withMethod: .POST, + andHeaders: HTTPHeaders(), + configuration: .noFramingTransformation + ) writtenData.assertContainsOnly("POST /uri HTTP/1.1\r\n\r\n") } @@ -98,7 +104,11 @@ class HTTPRequestEncoderTests: XCTestCase { func testAllowContentLengthHeadersWhenForced_forTRACE() throws { let headers = HTTPHeaders([("content-length", "0")]) - let writtenData = try sendRequest(withMethod: .TRACE, andHeaders: headers, configuration: .noFramingTransformation) + let writtenData = try sendRequest( + withMethod: .TRACE, + andHeaders: headers, + configuration: .noFramingTransformation + ) writtenData.assertContainsOnly("TRACE /uri HTTP/1.1\r\ncontent-length: 0\r\n\r\n") } @@ -110,7 +120,11 @@ class HTTPRequestEncoderTests: XCTestCase { func testAllowTransferEncodingHeadersWhenForced_forTRACE() throws { let headers = HTTPHeaders([("transfer-encoding", "chunked")]) - let writtenData = try sendRequest(withMethod: .TRACE, andHeaders: headers, configuration: .noFramingTransformation) + let writtenData = try sendRequest( + withMethod: .TRACE, + andHeaders: headers, + configuration: .noFramingTransformation + ) writtenData.assertContainsOnly("TRACE /uri HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n") } @@ -175,9 +189,17 @@ class HTTPRequestEncoderTests: XCTestCase { var buffer = channel.allocator.buffer(capacity: 16) var expected = channel.allocator.buffer(capacity: 32) - XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(.init(version: .http1_1, - method: .POST, - uri: "/")))) + XCTAssertNoThrow( + try channel.writeOutbound( + HTTPClientRequestPart.head( + .init( + version: .http1_1, + method: .POST, + uri: "/" + ) + ) + ) + ) expected.writeString("POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n") XCTAssertNoThrow(XCTAssertEqual(expected, try channel.readOutbound(as: ByteBuffer.self))) @@ -207,10 +229,18 @@ class HTTPRequestEncoderTests: XCTestCase { var buffer = channel.allocator.buffer(capacity: 16) var expected = channel.allocator.buffer(capacity: 32) - XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(.init(version: .http1_1, - method: .POST, - uri: "/", - headers: ["TrAnSfEr-encoding": "chuNKED"])))) + XCTAssertNoThrow( + try channel.writeOutbound( + HTTPClientRequestPart.head( + .init( + version: .http1_1, + method: .POST, + uri: "/", + headers: ["TrAnSfEr-encoding": "chuNKED"] + ) + ) + ) + ) expected.writeString("POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n") XCTAssertNoThrow(XCTAssertEqual(expected, try channel.readOutbound(as: ByteBuffer.self))) @@ -240,9 +270,17 @@ class HTTPRequestEncoderTests: XCTestCase { var buffer = channel.allocator.buffer(capacity: 16) var expected = channel.allocator.buffer(capacity: 32) - XCTAssertNoThrow(try channel.writeOutbound(HTTPClientRequestPart.head(.init(version: .http1_1, - method: .POST, - uri: "/")))) + XCTAssertNoThrow( + try channel.writeOutbound( + HTTPClientRequestPart.head( + .init( + version: .http1_1, + method: .POST, + uri: "/" + ) + ) + ) + ) expected.writeString("POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n") XCTAssertNoThrow(XCTAssertEqual(expected, try channel.readOutbound(as: ByteBuffer.self))) @@ -264,9 +302,16 @@ class HTTPRequestEncoderTests: XCTestCase { var buffer = channel.allocator.buffer(capacity: 16) var expected = channel.allocator.buffer(capacity: 32) - channel.write(HTTPClientRequestPart.head(.init(version: .http1_1, - method: .POST, - uri: "/")), promise: nil) + channel.write( + HTTPClientRequestPart.head( + .init( + version: .http1_1, + method: .POST, + uri: "/" + ) + ), + promise: nil + ) channel.flush() expected.writeString("POST / HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n") XCTAssertNoThrow(XCTAssertEqual(expected, try channel.readOutbound(as: ByteBuffer.self))) @@ -364,7 +409,12 @@ class HTTPRequestEncoderTests: XCTestCase { } try channel.pipeline.addHTTPClientHandlers(encoderConfiguration: .noFramingTransformation).wait() - let request = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/uri", headers: ["transfer-encoding": "chunked"]) + let request = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/uri", + headers: ["transfer-encoding": "chunked"] + ) try channel.writeOutbound(HTTPClientRequestPart.head(request)) guard let headBuffer = try channel.readOutbound(as: ByteBuffer.self) else { XCTFail("Unable to read buffer") @@ -399,9 +449,13 @@ class HTTPRequestEncoderTests: XCTestCase { } private func assertOutboundContainsOnly(_ channel: EmbeddedChannel, _ expected: String) { - XCTAssertNoThrow(XCTAssertNotNil(try channel.readOutbound(as: ByteBuffer.self).map { buffer in - buffer.assertContainsOnly(expected) - }, "couldn't read ByteBuffer")) + XCTAssertNoThrow( + XCTAssertNotNil( + try channel.readOutbound(as: ByteBuffer.self).map { buffer in + buffer.assertContainsOnly(expected) + }, + "couldn't read ByteBuffer" + ) + ) } } - diff --git a/Tests/NIOHTTP1Tests/HTTPResponseEncoderTest.swift b/Tests/NIOHTTP1Tests/HTTPResponseEncoderTest.swift index 7ac8031f51..95f67a9c1d 100644 --- a/Tests/NIOHTTP1Tests/HTTPResponseEncoderTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPResponseEncoderTest.swift @@ -11,13 +11,15 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// + +import NIOEmbedded import XCTest + @testable import NIOCore -import NIOEmbedded @testable import NIOHTTP1 -private extension ByteBuffer { - func assertContainsOnly(_ string: String) { +extension ByteBuffer { + fileprivate func assertContainsOnly(_ string: String) { let innerData = self.getString(at: self.readerIndex, length: self.readableBytes)! XCTAssertEqual(innerData, string) } @@ -42,7 +44,9 @@ class HTTPResponseEncoderTests: XCTestCase { XCTAssertEqual(true, try? channel.finish().isClean) } - XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(HTTPResponseEncoder(configuration: configuration))) + XCTAssertNoThrow( + try channel.pipeline.syncOperations.addHandler(HTTPResponseEncoder(configuration: configuration)) + ) var switchingResponse = HTTPResponseHead(version: .http1_1, status: status) switchingResponse.headers = headers XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(switchingResponse))) @@ -77,7 +81,11 @@ class HTTPResponseEncoderTests: XCTestCase { } func testNoAutoHeadersWhenDisabled() throws { - let writtenData = sendResponse(withStatus: .ok, andHeaders: HTTPHeaders(), configuration: .noFramingTransformation) + let writtenData = sendResponse( + withStatus: .ok, + andHeaders: HTTPHeaders(), + configuration: .noFramingTransformation + ) writtenData.assertContainsOnly("HTTP/1.1 200 OK\r\n\r\n") } @@ -89,7 +97,11 @@ class HTTPResponseEncoderTests: XCTestCase { func testAllowContentLengthHeadersWhenForced_for101() throws { let headers = HTTPHeaders([("content-length", "0")]) - let writtenData = sendResponse(withStatus: .switchingProtocols, andHeaders: headers, configuration: .noFramingTransformation) + let writtenData = sendResponse( + withStatus: .switchingProtocols, + andHeaders: headers, + configuration: .noFramingTransformation + ) writtenData.assertContainsOnly("HTTP/1.1 101 Switching Protocols\r\ncontent-length: 0\r\n\r\n") } @@ -101,8 +113,14 @@ class HTTPResponseEncoderTests: XCTestCase { func testAllowContentLengthHeadersWhenForced_forCustom1XX() throws { let headers = HTTPHeaders([("Link", "; rel=preload; as=style"), ("content-length", "0")]) - let writtenData = sendResponse(withStatus: .custom(code: 103, reasonPhrase: "Early Hints"), andHeaders: headers, configuration: .noFramingTransformation) - writtenData.assertContainsOnly("HTTP/1.1 103 Early Hints\r\nLink: ; rel=preload; as=style\r\ncontent-length: 0\r\n\r\n") + let writtenData = sendResponse( + withStatus: .custom(code: 103, reasonPhrase: "Early Hints"), + andHeaders: headers, + configuration: .noFramingTransformation + ) + writtenData.assertContainsOnly( + "HTTP/1.1 103 Early Hints\r\nLink: ; rel=preload; as=style\r\ncontent-length: 0\r\n\r\n" + ) } func testNoContentLengthHeadersFor204() throws { @@ -113,7 +131,11 @@ class HTTPResponseEncoderTests: XCTestCase { func testAllowContentLengthHeadersWhenForced_For204() throws { let headers = HTTPHeaders([("content-length", "0")]) - let writtenData = sendResponse(withStatus: .noContent, andHeaders: headers, configuration: .noFramingTransformation) + let writtenData = sendResponse( + withStatus: .noContent, + andHeaders: headers, + configuration: .noFramingTransformation + ) writtenData.assertContainsOnly("HTTP/1.1 204 No Content\r\ncontent-length: 0\r\n\r\n") } @@ -131,7 +153,11 @@ class HTTPResponseEncoderTests: XCTestCase { func testAllowTransferEncodingHeadersWhenForced_for101() throws { let headers = HTTPHeaders([("transfer-encoding", "chunked")]) - let writtenData = sendResponse(withStatus: .switchingProtocols, andHeaders: headers, configuration: .noFramingTransformation) + let writtenData = sendResponse( + withStatus: .switchingProtocols, + andHeaders: headers, + configuration: .noFramingTransformation + ) writtenData.assertContainsOnly("HTTP/1.1 101 Switching Protocols\r\ntransfer-encoding: chunked\r\n\r\n") } @@ -143,8 +169,14 @@ class HTTPResponseEncoderTests: XCTestCase { func testAllowTransferEncodingHeadersWhenForced_forCustom1XX() throws { let headers = HTTPHeaders([("Link", "; rel=preload; as=style"), ("transfer-encoding", "chunked")]) - let writtenData = sendResponse(withStatus: .custom(code: 103, reasonPhrase: "Early Hints"), andHeaders: headers, configuration: .noFramingTransformation) - writtenData.assertContainsOnly("HTTP/1.1 103 Early Hints\r\nLink: ; rel=preload; as=style\r\ntransfer-encoding: chunked\r\n\r\n") + let writtenData = sendResponse( + withStatus: .custom(code: 103, reasonPhrase: "Early Hints"), + andHeaders: headers, + configuration: .noFramingTransformation + ) + writtenData.assertContainsOnly( + "HTTP/1.1 103 Early Hints\r\nLink: ; rel=preload; as=style\r\ntransfer-encoding: chunked\r\n\r\n" + ) } func testNoTransferEncodingHeadersFor204() throws { @@ -155,7 +187,11 @@ class HTTPResponseEncoderTests: XCTestCase { func testAllowTransferEncodingHeadersWhenForced_for204() throws { let headers = HTTPHeaders([("transfer-encoding", "chunked")]) - let writtenData = sendResponse(withStatus: .noContent, andHeaders: headers, configuration: .noFramingTransformation) + let writtenData = sendResponse( + withStatus: .noContent, + andHeaders: headers, + configuration: .noFramingTransformation + ) writtenData.assertContainsOnly("HTTP/1.1 204 No Content\r\ntransfer-encoding: chunked\r\n\r\n") } @@ -190,7 +226,9 @@ class HTTPResponseEncoderTests: XCTestCase { XCTAssertNoThrow(try channel.finish()) } - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withEncoderConfiguration: .noFramingTransformation).wait()) + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline(withEncoderConfiguration: .noFramingTransformation).wait() + ) let request = ByteBuffer(string: "GET / HTTP/1.1\r\n\r\n") XCTAssertNoThrow(try channel.writeInbound(request)) @@ -211,7 +249,11 @@ class HTTPResponseEncoderTests: XCTestCase { XCTAssertNoThrow(try channel.finish()) } - XCTAssertNoThrow(try channel.pipeline.syncOperations.configureHTTPServerPipeline(withEncoderConfiguration: .noFramingTransformation)) + XCTAssertNoThrow( + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withEncoderConfiguration: .noFramingTransformation + ) + ) let request = ByteBuffer(string: "GET / HTTP/1.1\r\n\r\n") XCTAssertNoThrow(try channel.writeInbound(request)) @@ -232,7 +274,11 @@ class HTTPResponseEncoderTests: XCTestCase { XCTAssertNoThrow(try channel.finish()) } - XCTAssertNoThrow(try channel.pipeline.syncOperations.configureHTTPServerPipeline(withEncoderConfiguration: .noFramingTransformation)) + XCTAssertNoThrow( + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withEncoderConfiguration: .noFramingTransformation + ) + ) let request = ByteBuffer(string: "GET / HTTP/1.1\r\n\r\n") XCTAssertNoThrow(try channel.writeInbound(request)) @@ -268,7 +314,11 @@ class HTTPResponseEncoderTests: XCTestCase { XCTAssertNoThrow(try channel.finish()) } - XCTAssertNoThrow(try channel.pipeline.syncOperations.configureHTTPServerPipeline(withEncoderConfiguration: .noFramingTransformation)) + XCTAssertNoThrow( + try channel.pipeline.syncOperations.configureHTTPServerPipeline( + withEncoderConfiguration: .noFramingTransformation + ) + ) let request = ByteBuffer(string: "GET / HTTP/1.1\r\n\r\n") XCTAssertNoThrow(try channel.writeInbound(request)) diff --git a/Tests/NIOHTTP1Tests/HTTPResponseStatusTests.swift b/Tests/NIOHTTP1Tests/HTTPResponseStatusTests.swift index 00e3b49a15..7bee0d34ce 100644 --- a/Tests/NIOHTTP1Tests/HTTPResponseStatusTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPResponseStatusTests.swift @@ -85,6 +85,9 @@ class HTTPResponseStatusTests: XCTestCase { func testHTTPResponseStatusCodeAndReason() { XCTAssertEqual("\(HTTPResponseStatus.ok)", "200 OK") XCTAssertEqual("\(HTTPResponseStatus.imATeapot)", "418 I'm a teapot") - XCTAssertEqual("\(HTTPResponseStatus.custom(code: 347, reasonPhrase: "I like ice cream"))", "347 I like ice cream") + XCTAssertEqual( + "\(HTTPResponseStatus.custom(code: 347, reasonPhrase: "I like ice cream"))", + "347 I like ice cream" + ) } } diff --git a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift index f3ab3977c5..94ab77a4b2 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerClientTest.swift @@ -12,19 +12,20 @@ // //===----------------------------------------------------------------------===// -import XCTest -import NIOCore -import NIOPosix +import Dispatch import NIOConcurrencyHelpers +import NIOCore import NIOFoundationCompat -import Dispatch +import NIOPosix +import XCTest + @testable import NIOHTTP1 extension Array where Array.Element == ByteBuffer { public func allAsBytes() -> [UInt8] { var out: [UInt8] = [] out.reserveCapacity(self.reduce(0, { $0 + $1.readableBytes })) - self.forEach { bb in + for bb in self { bb.withUnsafeReadableBytes { ptr in out.append(contentsOf: ptr) } @@ -33,7 +34,7 @@ extension Array where Array.Element == ByteBuffer { } public func allAsString() -> String? { - return String(decoding: self.allAsBytes(), as: Unicode.UTF8.self) + String(decoding: self.allAsBytes(), as: Unicode.UTF8.self) } } @@ -61,11 +62,11 @@ internal class ArrayAccumulationHandler: ChannelInboundHandler { } } -class HTTPServerClientTest : XCTestCase { - /* needs to be something reasonably large and odd so it has good odds producing incomplete writes even on the loopback interface */ +class HTTPServerClientTest: XCTestCase { + // needs to be something reasonably large and odd so it has good odds producing incomplete writes even on the loopback interface private static let massiveResponseLength = 1 * 1024 * 1024 + 7 private static let massiveResponseBytes: [UInt8] = { - return Array(repeating: 0xff, count: HTTPServerClientTest.massiveResponseLength) + Array(repeating: 0xff, count: HTTPServerClientTest.massiveResponseLength) }() enum SendMode { @@ -87,7 +88,7 @@ class HTTPServerClientTest : XCTestCase { self.mode = mode } - private func outboundBody(_ buffer: ByteBuffer) -> (body: HTTPServerResponsePart, destructor: () -> Void) { + private func outboundBody(_ buffer: ByteBuffer) -> (body: HTTPServerResponsePart, destructor: () -> Void) { switch mode { case .byteBuffer: return (.body(.byteBuffer(buffer)), { () in }) @@ -98,9 +99,11 @@ class HTTPServerClientTest : XCTestCase { let content = buffer.getData(at: 0, length: buffer.readableBytes)! XCTAssertNoThrow(try content.write(to: URL(fileURLWithPath: filePath))) let fh = try! NIOFileHandle(path: filePath) - let region = FileRegion(fileHandle: fh, - readerIndex: 0, - endIndex: buffer.readableBytes) + let region = FileRegion( + fileHandle: fh, + readerIndex: 0, + endIndex: buffer.readableBytes + ) return (.body(.fileRegion(region)), { try! fh.close() }) } } @@ -247,23 +250,24 @@ class HTTPServerClientTest : XCTestCase { } context.write(Self.wrapOutboundOut(.end(nil))).recover { error in XCTFail("unexpected error \(error)") - }.whenComplete { (_: Result) in - self.sentEnd = true - self.maybeClose(context: context) + }.whenComplete { (_: Result) in + self.sentEnd = true + self.maybeClose(context: context) } case "/zero-length-body-part": - + let r = HTTPServerResponsePart.head(.init(version: req.version, status: .ok)) context.write(Self.wrapOutboundOut(r)).whenFailure { error in XCTFail("unexpected error \(error)") } - + context.writeAndFlush(Self.wrapOutboundOut(.body(.byteBuffer(ByteBuffer())))).whenFailure { error in XCTFail("unexpected error \(error)") } - context.writeAndFlush(Self.wrapOutboundOut(.body(.byteBuffer(ByteBuffer(string: "Hello World"))))).whenFailure { error in - XCTFail("unexpected error \(error)") - } + context.writeAndFlush(Self.wrapOutboundOut(.body(.byteBuffer(ByteBuffer(string: "Hello World"))))) + .whenFailure { error in + XCTFail("unexpected error \(error)") + } context.write(Self.wrapOutboundOut(.end(nil))).recover { error in XCTFail("unexpected error \(error)") }.whenComplete { (_: Result) in @@ -306,7 +310,13 @@ class HTTPServerClientTest : XCTestCase { } private class HTTPClientResponsePartAssertHandler: ArrayAccumulationHandler { - public init(_ expectedVersion: HTTPVersion, _ expectedStatus: HTTPResponseStatus, _ expectedHeaders: HTTPHeaders, _ expectedBody: String?, _ expectedTrailers: HTTPHeaders? = nil) { + public init( + _ expectedVersion: HTTPVersion, + _ expectedStatus: HTTPResponseStatus, + _ expectedHeaders: HTTPHeaders, + _ expectedBody: String?, + _ expectedTrailers: HTTPHeaders? = nil + ) { super.init { parts in guard parts.count >= 2 else { XCTFail("only \(parts.count) parts") @@ -344,10 +354,12 @@ class HTTPServerClientTest : XCTestCase { } } - private func testSimpleGet(_ mode: SendMode, - httpVersion: HTTPVersion = .http1_1, - uri: String = "/helloworld", - expectedHeaders maybeExpectedHeaders: HTTPHeaders? = nil) throws { + private func testSimpleGet( + _ mode: SendMode, + httpVersion: HTTPVersion = .http1_1, + uri: String = "/helloworld", + expectedHeaders maybeExpectedHeaders: HTTPHeaders? = nil + ) throws { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) @@ -357,29 +369,33 @@ class HTTPServerClientTest : XCTestCase { let accumulation = HTTPClientResponsePartAssertHandler(httpVersion, .ok, expectedHeaders, "Hello World!\r\n") let httpHandler = SimpleHTTPServer(mode) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - - // Set the handlers that are appled to the accepted Channels - .childChannelInitializer { channel in - // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { - channel.pipeline.addHandler(httpHandler) - } - }.bind(host: "127.0.0.1", port: 0).wait()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + + // Set the handlers that are appled to the accepted Channels + .childChannelInitializer { channel in + // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { + channel.pipeline.addHandler(httpHandler) + } + }.bind(host: "127.0.0.1", port: 0).wait() + ) defer { XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } - let clientChannel = try assertNoThrowWithValue(ClientBootstrap(group: group) - .channelInitializer { channel in - channel.pipeline.addHTTPClientHandlers().flatMap { - channel.pipeline.addHandler(accumulation) + let clientChannel = try assertNoThrowWithValue( + ClientBootstrap(group: group) + .channelInitializer { channel in + channel.pipeline.addHTTPClientHandlers().flatMap { + channel.pipeline.addHandler(accumulation) + } } - } - .connect(to: serverChannel.localAddress!) - .wait()) + .connect(to: serverChannel.localAddress!) + .wait() + ) defer { XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) @@ -414,29 +430,33 @@ class HTTPServerClientTest : XCTestCase { let accumulation = HTTPClientResponsePartAssertHandler(.http1_1, .ok, expectedHeaders, "12345678910") let httpHandler = SimpleHTTPServer(mode) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - - // Set the handlers that are appled to the accepted Channels - .childChannelInitializer { channel in - // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { - channel.pipeline.addHandler(httpHandler) - } - }.bind(host: "127.0.0.1", port: 0).wait()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + + // Set the handlers that are appled to the accepted Channels + .childChannelInitializer { channel in + // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { + channel.pipeline.addHandler(httpHandler) + } + }.bind(host: "127.0.0.1", port: 0).wait() + ) defer { XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } - let clientChannel = try assertNoThrowWithValue(ClientBootstrap(group: group) - .channelInitializer { channel in - channel.pipeline.addHTTPClientHandlers().flatMap { - channel.pipeline.addHandler(accumulation) + let clientChannel = try assertNoThrowWithValue( + ClientBootstrap(group: group) + .channelInitializer { channel in + channel.pipeline.addHTTPClientHandlers().flatMap { + channel.pipeline.addHandler(accumulation) + } } - } - .connect(to: serverChannel.localAddress!) - .wait()) + .connect(to: serverChannel.localAddress!) + .wait() + ) defer { XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) @@ -456,7 +476,7 @@ class HTTPServerClientTest : XCTestCase { func testSimpleGetTrailersFileRegion() throws { try testSimpleGetTrailers(.fileRegion) } - + func testSimpleGetChunkedEncodingWithZeroLengthBodyPart() throws { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { @@ -469,29 +489,33 @@ class HTTPServerClientTest : XCTestCase { let accumulation = HTTPClientResponsePartAssertHandler(.http1_1, .ok, expectedHeaders, "Hello World") let httpHandler = SimpleHTTPServer(.byteBuffer) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - - // Set the handlers that are appled to the accepted Channels - .childChannelInitializer { channel in - // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true).flatMap { - channel.pipeline.addHandler(httpHandler) - } - }.bind(host: "127.0.0.1", port: 0).wait()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + + // Set the handlers that are appled to the accepted Channels + .childChannelInitializer { channel in + // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: true).flatMap { + channel.pipeline.addHandler(httpHandler) + } + }.bind(host: "127.0.0.1", port: 0).wait() + ) defer { XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } - let clientChannel = try assertNoThrowWithValue(ClientBootstrap(group: group) - .channelInitializer { channel in - channel.pipeline.addHTTPClientHandlers().flatMap { - channel.pipeline.addHandler(accumulation) + let clientChannel = try assertNoThrowWithValue( + ClientBootstrap(group: group) + .channelInitializer { channel in + channel.pipeline.addHTTPClientHandlers().flatMap { + channel.pipeline.addHandler(accumulation) + } } - } - .connect(to: serverChannel.localAddress!) - .wait()) + .connect(to: serverChannel.localAddress!) + .wait() + ) defer { XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) @@ -518,29 +542,39 @@ class HTTPServerClientTest : XCTestCase { expectedTrailers.add(name: "x-url-path", value: "/trailers") expectedTrailers.add(name: "x-should-trail", value: "sure") - let accumulation = HTTPClientResponsePartAssertHandler(.http1_1, .ok, expectedHeaders, "12345678910", expectedTrailers) + let accumulation = HTTPClientResponsePartAssertHandler( + .http1_1, + .ok, + expectedHeaders, + "12345678910", + expectedTrailers + ) let httpHandler = SimpleHTTPServer(mode) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { - channel.pipeline.addHandler(httpHandler) - } - }.bind(host: "127.0.0.1", port: 0).wait()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { + channel.pipeline.addHandler(httpHandler) + } + }.bind(host: "127.0.0.1", port: 0).wait() + ) defer { XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } - let clientChannel = try assertNoThrowWithValue(ClientBootstrap(group: group) - .channelInitializer { channel in - channel.pipeline.addHTTPClientHandlers().flatMap { - channel.pipeline.addHandler(accumulation) + let clientChannel = try assertNoThrowWithValue( + ClientBootstrap(group: group) + .channelInitializer { channel in + channel.pipeline.addHTTPClientHandlers().flatMap { + channel.pipeline.addHandler(accumulation) + } } - } - .connect(to: serverChannel.localAddress!) - .wait()) + .connect(to: serverChannel.localAddress!) + .wait() + ) defer { XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) } @@ -577,24 +611,28 @@ class HTTPServerClientTest : XCTestCase { } let numBytes = 16 * 1024 let httpHandler = SimpleHTTPServer(mode) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - - // Set the handlers that are appled to the accepted Channels - .childChannelInitializer { channel in - // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { - channel.pipeline.addHandler(httpHandler) - } - }.bind(host: "127.0.0.1", port: 0).wait()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + + // Set the handlers that are appled to the accepted Channels + .childChannelInitializer { channel in + // Ensure we don't read faster then we can write by adding the BackPressureHandler into the pipeline. + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { + channel.pipeline.addHandler(httpHandler) + } + }.bind(host: "127.0.0.1", port: 0).wait() + ) defer { XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } - let clientChannel = try assertNoThrowWithValue(ClientBootstrap(group: group) - .channelInitializer({ $0.pipeline.addHandler(accumulation) }) - .connect(to: serverChannel.localAddress!) - .wait()) + let clientChannel = try assertNoThrowWithValue( + ClientBootstrap(group: group) + .channelInitializer({ $0.pipeline.addHandler(accumulation) }) + .connect(to: serverChannel.localAddress!) + .wait() + ) defer { XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) } @@ -619,25 +657,29 @@ class HTTPServerClientTest : XCTestCase { let accumulation = HTTPClientResponsePartAssertHandler(.http1_1, .ok, expectedHeaders, "") let httpHandler = SimpleHTTPServer(.byteBuffer) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { - channel.pipeline.addHandler(httpHandler) - } - }.bind(host: "127.0.0.1", port: 0).wait()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { + channel.pipeline.addHandler(httpHandler) + } + }.bind(host: "127.0.0.1", port: 0).wait() + ) defer { XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } - let clientChannel = try assertNoThrowWithValue(ClientBootstrap(group: group) - .channelInitializer { channel in - channel.pipeline.addHTTPClientHandlers().flatMap { - channel.pipeline.addHandler(accumulation) + let clientChannel = try assertNoThrowWithValue( + ClientBootstrap(group: group) + .channelInitializer { channel in + channel.pipeline.addHTTPClientHandlers().flatMap { + channel.pipeline.addHandler(accumulation) + } } - } - .connect(to: serverChannel.localAddress!) - .wait()) + .connect(to: serverChannel.localAddress!) + .wait() + ) defer { XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) @@ -663,25 +705,29 @@ class HTTPServerClientTest : XCTestCase { let accumulation = HTTPClientResponsePartAssertHandler(.http1_1, .noContent, expectedHeaders, "") let httpHandler = SimpleHTTPServer(.byteBuffer) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { - channel.pipeline.addHandler(httpHandler) - } - }.bind(host: "127.0.0.1", port: 0).wait()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false).flatMap { + channel.pipeline.addHandler(httpHandler) + } + }.bind(host: "127.0.0.1", port: 0).wait() + ) defer { XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } - let clientChannel = try assertNoThrowWithValue(ClientBootstrap(group: group) - .channelInitializer { channel in - channel.pipeline.addHTTPClientHandlers().flatMap { - channel.pipeline.addHandler(accumulation) + let clientChannel = try assertNoThrowWithValue( + ClientBootstrap(group: group) + .channelInitializer { channel in + channel.pipeline.addHTTPClientHandlers().flatMap { + channel.pipeline.addHandler(accumulation) + } } - } - .connect(to: serverChannel.localAddress!) - .wait()) + .connect(to: serverChannel.localAddress!) + .wait() + ) defer { XCTAssertNoThrow(try clientChannel.syncCloseAcceptingAlreadyClosed()) } @@ -695,9 +741,13 @@ class HTTPServerClientTest : XCTestCase { } func testNoResponseHeaders() { - XCTAssertNoThrow(try self.testSimpleGet(.byteBuffer, - httpVersion: .http1_0, - uri: "/no-headers", - expectedHeaders: [:])) + XCTAssertNoThrow( + try self.testSimpleGet( + .byteBuffer, + httpVersion: .http1_0, + uri: "/no-headers", + expectedHeaders: [:] + ) + ) } } diff --git a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift index 1298ab2ee7..7029e2f73e 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerPipelineHandlerTest.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded +import XCTest + @testable import NIOHTTP1 private final class ReadRecorder: ChannelInboundHandler { @@ -24,7 +25,7 @@ private final class ReadRecorder: ChannelInboundHandler { case channelRead(InboundIn) case halfClose - static func ==(lhs: Event, rhs: Event) -> Bool { + static func == (lhs: Event, rhs: Event) -> Bool { switch (lhs, rhs) { case (.channelRead(let b1), .channelRead(let b2)): return b1 == b2 @@ -77,7 +78,6 @@ private final class ReadCountingHandler: ChannelOutboundHandler { } } - private final class QuiesceEventRecorder: ChannelInboundHandler { typealias InboundIn = Any typealias InboundOut = Any @@ -107,7 +107,6 @@ private final class CloseOutputSuppressor: ChannelOutboundHandler { } } - class HTTPServerPipelineHandlerTest: XCTestCase { var channel: EmbeddedChannel! = nil var requestHead: HTTPRequestHead! = nil @@ -165,33 +164,45 @@ class HTTPServerPipelineHandlerTest: XCTestCase { } // Only one request should have made it through. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Unblock by sending a response. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) // Two requests should have made it through now. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Now send the last response. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) // Now all three. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) } func testReadCallsAreSuppressedWhenPipelining() throws { @@ -290,34 +301,46 @@ class HTTPServerPipelineHandlerTest: XCTestCase { self.channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) // Only one request should have made it through, no half-closure yet. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Unblock by sending a response. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) // Two requests should have made it through now. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Now send the last response. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) // Now the half-closure should be delivered. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .halfClose]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .halfClose, + ] + ) } func testPipelineHandlerWillDeliverHalfCloseEarly() throws { @@ -330,20 +353,28 @@ class HTTPServerPipelineHandlerTest: XCTestCase { self.channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) // Only one request should have made it through, no half-closure yet. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Unblock by sending a response. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) // The second request head, followed by the half-close, should have made it through. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead)), - .halfClose]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .halfClose, + ] + ) } func testAReadIsNotIssuedWhenUnbufferingAHalfCloseAfterRequestComplete() throws { @@ -385,31 +416,43 @@ class HTTPServerPipelineHandlerTest: XCTestCase { } // Only one request should have made it through, no half-closure yet. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Unblock by sending a response. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) // Two requests should have made it through now. Still no half-closure. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Now send the half-closure. self.channel.pipeline.fireUserInboundEventTriggered(ChannelEvent.inputClosed) // The half-closure should be delivered immediately. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .halfClose]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .halfClose, + ] + ) } func testRecursiveChannelReadInvocationsDoNotCauseIssues() throws { @@ -465,7 +508,10 @@ class HTTPServerPipelineHandlerTest: XCTestCase { default: XCTFail("didn't expect \(head)") } - context.write(Self.wrapOutboundOut(.head(HTTPResponseHead(version: .http1_1, status: .ok))), promise: nil) + context.write( + Self.wrapOutboundOut(.head(HTTPResponseHead(version: .http1_1, status: .ok))), + promise: nil + ) if sendEnd { context.write(Self.wrapOutboundOut(.end(nil)), promise: nil) } @@ -478,7 +524,11 @@ class HTTPServerPipelineHandlerTest: XCTestCase { self.state = .req3HeadExpected // this will cause `channelRead` to be recursively called and we need to make sure everything then still works - try! (context.channel as! EmbeddedChannel).writeInbound(HTTPServerRequestPart.head(HTTPRequestHead(version: .http1_1, method: .GET, uri: "/req_boom"))) + try! (context.channel as! EmbeddedChannel).writeInbound( + HTTPServerRequestPart.head( + HTTPRequestHead(version: .http1_1, method: .GET, uri: "/req_boom") + ) + ) try! (context.channel as! EmbeddedChannel).writeInbound(HTTPServerRequestPart.end(nil)) case .req3EndExpected: self.state = .reqBoomHeadExpected @@ -497,7 +547,9 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try channel.pipeline.addHandler(handler).wait()) for f in 1...3 { - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(makeRequestHead(uri: "/req_\(f)")))) + XCTAssertNoThrow( + try self.channel.writeInbound(HTTPServerRequestPart.head(makeRequestHead(uri: "/req_\(f)"))) + ) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) } @@ -524,18 +576,26 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) // The request should have made it through. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Now send a response. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) // No further events should have happened. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) XCTAssertTrue(self.channel.isActive) self.channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) @@ -546,8 +606,10 @@ class HTTPServerPipelineHandlerTest: XCTestCase { func testQuiescingInTheMiddleOfARequestNoResponseBitsYet() throws { // Send through only the head. XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead))]) + XCTAssertEqual( + self.readRecorder.reads, + [.channelRead(HTTPServerRequestPart.head(self.requestHead))] + ) XCTAssertTrue(self.channel.isActive) self.channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) @@ -562,9 +624,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { var reqWithConnectionClose: HTTPResponseHead = self.responseHead reqWithConnectionClose.headers.add(name: "connection", value: "close") - XCTAssertEqual([HTTPServerResponsePart.head(reqWithConnectionClose), - HTTPServerResponsePart.end(nil)], - self.writeRecorder.writes) + XCTAssertEqual( + [ + HTTPServerResponsePart.head(reqWithConnectionClose), + HTTPServerResponsePart.end(nil), + ], + self.writeRecorder.writes + ) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) @@ -577,9 +643,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) XCTAssertTrue(self.channel.isActive) self.channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) @@ -593,9 +663,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { var reqWithConnectionClose: HTTPResponseHead = self.responseHead reqWithConnectionClose.headers.add(name: "connection", value: "close") - XCTAssertEqual([HTTPServerResponsePart.head(reqWithConnectionClose), - HTTPServerResponsePart.end(nil)], - self.writeRecorder.writes) + XCTAssertEqual( + [ + HTTPServerResponsePart.head(reqWithConnectionClose), + HTTPServerResponsePart.end(nil), + ], + self.writeRecorder.writes + ) XCTAssertFalse(self.channel.isActive) XCTAssertEqual(self.quiesceEventRecorder.quiesceCount, 0) @@ -606,9 +680,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Now send the response .head. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) @@ -621,9 +699,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) XCTAssertFalse(self.channel.isActive) - XCTAssertEqual([HTTPServerResponsePart.head(self.responseHead), - HTTPServerResponsePart.end(nil)], - self.writeRecorder.writes) + XCTAssertEqual( + [ + HTTPServerResponsePart.head(self.responseHead), + HTTPServerResponsePart.end(nil), + ], + self.writeRecorder.writes + ) XCTAssertFalse(self.channel.isActive) XCTAssertEqual(self.quiesceEventRecorder.quiesceCount, 0) @@ -633,8 +715,10 @@ class HTTPServerPipelineHandlerTest: XCTestCase { // Send through a request .head. XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead))]) + XCTAssertEqual( + self.readRecorder.reads, + [.channelRead(HTTPServerRequestPart.head(self.requestHead))] + ) // Now send the response .head. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) @@ -646,18 +730,26 @@ class HTTPServerPipelineHandlerTest: XCTestCase { // Request .end. XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) XCTAssertTrue(self.channel.isActive) // Response .end. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) XCTAssertFalse(self.channel.isActive) - XCTAssertEqual([HTTPServerResponsePart.head(self.responseHead), - HTTPServerResponsePart.end(nil)], - self.writeRecorder.writes) + XCTAssertEqual( + [ + HTTPServerResponsePart.head(self.responseHead), + HTTPServerResponsePart.end(nil), + ], + self.writeRecorder.writes + ) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) @@ -669,8 +761,10 @@ class HTTPServerPipelineHandlerTest: XCTestCase { // Send through a request .head. XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead))]) + XCTAssertEqual( + self.readRecorder.reads, + [.channelRead(HTTPServerRequestPart.head(self.requestHead))] + ) // Now send the response .head. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) @@ -686,15 +780,23 @@ class HTTPServerPipelineHandlerTest: XCTestCase { // Request .end. XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) XCTAssertFalse(self.channel.isActive) - XCTAssertEqual([HTTPServerResponsePart.head(self.responseHead), - HTTPServerResponsePart.end(nil)], - self.writeRecorder.writes) + XCTAssertEqual( + [ + HTTPServerResponsePart.head(self.responseHead), + HTTPServerResponsePart.end(nil), + ], + self.writeRecorder.writes + ) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) @@ -710,9 +812,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { } // Check that only one request came through - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) XCTAssertTrue(self.channel.isActive) self.channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) @@ -728,14 +834,22 @@ class HTTPServerPipelineHandlerTest: XCTestCase { reqWithConnectionClose.headers.add(name: "connection", value: "close") // check that only one response (with connection: close) came through - XCTAssertEqual([HTTPServerResponsePart.head(reqWithConnectionClose), - HTTPServerResponsePart.end(nil)], - self.writeRecorder.writes) + XCTAssertEqual( + [ + HTTPServerResponsePart.head(reqWithConnectionClose), + HTTPServerResponsePart.end(nil), + ], + self.writeRecorder.writes + ) // Check that only one request came through - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) XCTAssertFalse(self.channel.isActive) XCTAssertEqual(self.quiesceEventRecorder.quiesceCount, 0) @@ -803,9 +917,17 @@ class HTTPServerPipelineHandlerTest: XCTestCase { // We dispatch this to the event loop so that it doesn't happen immediately but rather can be // run from the driving test code whenever it wants by running the EmbeddedEventLoop. context.eventLoop.execute { - context.writeAndFlush(Self.wrapOutboundOut(.head(.init(version: .http1_1, - status: .ok))), - promise: nil) + context.writeAndFlush( + Self.wrapOutboundOut( + .head( + .init( + version: .http1_1, + status: .ok + ) + ) + ), + promise: nil + ) } XCTAssertEqual(.reqHeadExpected, self.state) self.state = .reqEndExpected @@ -897,18 +1019,26 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) // Only one request should have made it through. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Remove the handler. XCTAssertNoThrow(try channel.pipeline.syncOperations.removeHandler(self.pipelineHandler).wait()) // The extra data should have been forwarded. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil)), - .channelRead(HTTPServerRequestPart.head(self.requestHead))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + ] + ) } func testQuiescingInAResponseThenRemovedFiresEventAndReads() throws { @@ -923,9 +1053,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) // Only one request should have made it through. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) XCTAssertTrue(self.channel.isActive) XCTAssertEqual(self.quiesceEventRecorder.quiesceCount, 0) @@ -962,8 +1096,10 @@ class HTTPServerPipelineHandlerTest: XCTestCase { // Send through just the head. XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead))]) + XCTAssertEqual( + self.readRecorder.reads, + [.channelRead(HTTPServerRequestPart.head(self.requestHead))] + ) XCTAssertTrue(self.channel.isActive) XCTAssertEqual(self.quiesceEventRecorder.quiesceCount, 0) @@ -989,7 +1125,7 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertEqual(self.quiesceEventRecorder.quiesceCount, 1) XCTAssertEqual(self.readCounter.readCount, 3) } - + func testServerCanRespondContinue() throws { // Send in the first part of a request. var expect100ContinueHead = self.requestHead! @@ -1008,51 +1144,51 @@ class HTTPServerPipelineHandlerTest: XCTestCase { // Now the server sends the final response. XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) - } + } func testServerCanRespondProcessingMultipleTimes() throws { - // Send in a request. - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) + // Send in a request. + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) - // We haven't completed our response, so no more reading - XCTAssertEqual(self.readCounter.readCount, 0) - self.channel.read() - XCTAssertEqual(self.readCounter.readCount, 0) + // We haven't completed our response, so no more reading + XCTAssertEqual(self.readCounter.readCount, 0) + self.channel.read() + XCTAssertEqual(self.readCounter.readCount, 0) - var processResponse: HTTPResponseHead = self.responseHead! - processResponse.status = .processing + var processResponse: HTTPResponseHead = self.responseHead! + processResponse.status = .processing - // Now the server sends multiple processing responses. - XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(processResponse)).wait()) + // Now the server sends multiple processing responses. + XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(processResponse)).wait()) - // We are processing... Reading not allowed - XCTAssertEqual(self.readCounter.readCount, 0) - self.channel.read() - XCTAssertEqual(self.readCounter.readCount, 0) + // We are processing... Reading not allowed + XCTAssertEqual(self.readCounter.readCount, 0) + self.channel.read() + XCTAssertEqual(self.readCounter.readCount, 0) - // Continue processing... - XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(processResponse)).wait()) + // Continue processing... + XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(processResponse)).wait()) - // We are processing... Reading not allowed - XCTAssertEqual(self.readCounter.readCount, 0) - self.channel.read() - XCTAssertEqual(self.readCounter.readCount, 0) + // We are processing... Reading not allowed + XCTAssertEqual(self.readCounter.readCount, 0) + self.channel.read() + XCTAssertEqual(self.readCounter.readCount, 0) - // Continue processing... - XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(processResponse)).wait()) + // Continue processing... + XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(processResponse)).wait()) - // We are processing... Reading not allowed - XCTAssertEqual(self.readCounter.readCount, 0) - self.channel.read() - XCTAssertEqual(self.readCounter.readCount, 0) + // We are processing... Reading not allowed + XCTAssertEqual(self.readCounter.readCount, 0) + self.channel.read() + XCTAssertEqual(self.readCounter.readCount, 0) - // Now send the actual response! - XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) - XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) + // Now send the actual response! + XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.head(self.responseHead)).wait()) + XCTAssertNoThrow(try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait()) - // This should have triggered a read - XCTAssertEqual(self.readCounter.readCount, 1) + // This should have triggered a read + XCTAssertEqual(self.readCounter.readCount, 1) } func testServerCloseOutputForcesReadsBackOn() throws { @@ -1065,9 +1201,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { self.channel.read() XCTAssertEqual(self.readCounter.readCount, 0) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Now the server sends close output XCTAssertNoThrow(try channel.close(mode: .output).wait()) @@ -1088,9 +1228,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { self.channel.read() XCTAssertEqual(self.readCounter.readCount, 0) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Now the server sends close output XCTAssertNoThrow(try channel.close(mode: .output).wait()) @@ -1109,9 +1253,13 @@ class HTTPServerPipelineHandlerTest: XCTestCase { self.channel.read() XCTAssertEqual(self.readCounter.readCount, 3) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) } func testCloseOutputFirstIsOkEvenIfItsABitWeird() throws { @@ -1154,16 +1302,24 @@ class HTTPServerPipelineHandlerTest: XCTestCase { self.channel.read() XCTAssertEqual(self.readCounter.readCount, 0) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) // Server sends close mode output. The buffered requests are dropped. XCTAssertNoThrow(try channel.close(mode: .output).wait()) - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(HTTPServerRequestPart.head(self.requestHead)), - .channelRead(HTTPServerRequestPart.end(nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead(HTTPServerRequestPart.head(self.requestHead)), + .channelRead(HTTPServerRequestPart.end(nil)), + ] + ) } func testWritesAfterCloseOutputAreDropped() throws { @@ -1188,7 +1344,10 @@ class HTTPServerPipelineHandlerTest: XCTestCase { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) // Sending a head twice is an error XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) { error in - XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "received request head in state requestAndResponseEndPending")) + XCTAssertEqual( + error as? HTTPServerPipelineHandler.ConnectionStateError, + .preconditionViolated(message: "received request head in state requestAndResponseEndPending") + ) } } @@ -1197,12 +1356,18 @@ class HTTPServerPipelineHandlerTest: XCTestCase { // Writing an end whilst in state idle is an error XCTAssertThrowsError(try self.channel.writeOutbound(HTTPServerResponsePart.end(nil))) { error in - XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Unexpectedly received a response in state idle")) + XCTAssertEqual( + error as? HTTPServerPipelineHandler.ConnectionStateError, + .preconditionViolated(message: "Unexpectedly received a response in state idle") + ) } // Calling finish surfaces the error again XCTAssertThrowsError(try self.channel.finish()) { error in - XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Unexpectedly received a response in state idle")) + XCTAssertEqual( + error as? HTTPServerPipelineHandler.ConnectionStateError, + .preconditionViolated(message: "Unexpectedly received a response in state idle") + ) } } @@ -1210,7 +1375,10 @@ class HTTPServerPipelineHandlerTest: XCTestCase { self.pipelineHandler.failOnPreconditions = false // End sending a request which was never started XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in - XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Received second request")) + XCTAssertEqual( + error as? HTTPServerPipelineHandler.ConnectionStateError, + .preconditionViolated(message: "Received second request") + ) } } @@ -1219,7 +1387,10 @@ class HTTPServerPipelineHandlerTest: XCTestCase { // End sending a request which was never started XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in - XCTAssertEqual(error as? HTTPServerPipelineHandler.ConnectionStateError, .preconditionViolated(message: "Received second request")) + XCTAssertEqual( + error as? HTTPServerPipelineHandler.ConnectionStateError, + .preconditionViolated(message: "Received second request") + ) } // The handler should now refuse further io, and forcefully shutdown XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) diff --git a/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest.swift b/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest.swift index cd39f43450..93fcaf05fe 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerProtocolErrorHandlerTest.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOHTTP1 +import XCTest class HTTPServerProtocolErrorHandlerTest: XCTestCase { func testHandlesBasicErrors() throws { @@ -54,9 +54,11 @@ class HTTPServerProtocolErrorHandlerTest: XCTestCase { XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound())) // Check the response. - assertResponseIs(response: written.readString(length: written.readableBytes)!, - expectedResponseLine: "HTTP/1.1 400 Bad Request", - expectedResponseHeaders: ["Connection: close", "Content-Length: 0"]) + assertResponseIs( + response: written.readString(length: written.readableBytes)!, + expectedResponseLine: "HTTP/1.1 400 Bad Request", + expectedResponseHeaders: ["Connection: close", "Content-Length: 0"] + ) } func testIgnoresNonParserErrors() throws { @@ -80,10 +82,17 @@ class HTTPServerProtocolErrorHandlerTest: XCTestCase { XCTAssertNoThrow(try channel.finish()) } - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, withErrorHandling: true).wait()) - let res = HTTPServerResponsePart.head(.init(version: .http1_1, - status: .ok, - headers: .init([("Content-Length", "0")]))) + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, withErrorHandling: true) + .wait() + ) + let res = HTTPServerResponsePart.head( + .init( + version: .http1_1, + status: .ok, + headers: .init([("Content-Length", "0")]) + ) + ) XCTAssertNoThrow(try channel.writeAndFlush(res).wait()) // now we have started a response but it's not complete yet, let's inject a parser error channel.pipeline.fireErrorCaught(HTTPParserError.invalidEOFState) @@ -114,9 +123,13 @@ class HTTPServerProtocolErrorHandlerTest: XCTestCase { case .head: XCTAssertEqual(.head, self.nextExpected) self.nextExpected = .end - let res = HTTPServerResponsePart.head(.init(version: .http1_1, - status: .ok, - headers: .init([("Content-Length", "0")]))) + let res = HTTPServerResponsePart.head( + .init( + version: .http1_1, + status: .ok, + headers: .init([("Content-Length", "0")]) + ) + ) context.writeAndFlush(Self.wrapOutboundOut(res), promise: nil) default: XCTAssertEqual(.end, self.nextExpected) @@ -124,12 +137,13 @@ class HTTPServerProtocolErrorHandlerTest: XCTestCase { } } - } let channel = EmbeddedChannel() - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { - channel.pipeline.addHandler(DelayWriteHandler()) - }.wait()) + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { + channel.pipeline.addHandler(DelayWriteHandler()) + }.wait() + ) var buffer = channel.allocator.buffer(capacity: 1024) buffer.writeStaticString("GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\nGET / HT") @@ -149,69 +163,99 @@ class HTTPServerProtocolErrorHandlerTest: XCTestCase { XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound())) // Check the response. - assertResponseIs(response: written.readString(length: written.readableBytes)!, - expectedResponseLine: "HTTP/1.1 200 OK", - expectedResponseHeaders: ["Content-Length: 0"]) + assertResponseIs( + response: written.readString(length: written.readableBytes)!, + expectedResponseLine: "HTTP/1.1 200 OK", + expectedResponseHeaders: ["Content-Length: 0"] + ) } - + func testDoesSendAResponseIfInformationalHeaderWasSent() throws { let channel = EmbeddedChannel() defer { XCTAssertNoThrow(try channel.finish(acceptAlreadyClosed: false)) } - - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, withErrorHandling: true).wait()) + + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, withErrorHandling: true) + .wait() + ) XCTAssertNoThrow(try channel.connect(to: .makeAddressResolvingHost("127.0.0.1", port: 0)).wait()) - + // Send an head that expects a continue informational response let reqHeadBytes = "POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\nExpect: 100-continue\r\n\r\n" XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: reqHeadBytes))) - let expectedHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["Transfer-Encoding":"chunked", "Expect":"100-continue"]) + let expectedHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: ["Transfer-Encoding": "chunked", "Expect": "100-continue"] + ) XCTAssertEqual(try channel.readInbound(as: HTTPServerRequestPart.self), .head(expectedHead)) - + // Respond with continue informational response let continueResponse = HTTPResponseHead(version: .http1_1, status: .continue) XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(continueResponse))) - XCTAssertEqual(try channel.readOutbound(as: ByteBuffer.self), ByteBuffer(string: "HTTP/1.1 100 Continue\r\n\r\n")) - + XCTAssertEqual( + try channel.readOutbound(as: ByteBuffer.self), + ByteBuffer(string: "HTTP/1.1 100 Continue\r\n\r\n") + ) + // Expects a hex digit... But receives garbage XCTAssertThrowsError(try channel.writeInbound(ByteBuffer(string: "xyz"))) { XCTAssertEqual($0 as? HTTPParserError, .invalidChunkSize) } - + // Receive a bad request - XCTAssertEqual(try channel.readOutbound(as: ByteBuffer.self), ByteBuffer(string: "HTTP/1.1 400 Bad Request\r\nConnection: close\r\nContent-Length: 0\r\n\r\n")) + XCTAssertEqual( + try channel.readOutbound(as: ByteBuffer.self), + ByteBuffer(string: "HTTP/1.1 400 Bad Request\r\nConnection: close\r\nContent-Length: 0\r\n\r\n") + ) } - + func testDoesNotSendAResponseIfRealHeaderWasSentAfterInformationalHeader() throws { let channel = EmbeddedChannel() defer { XCTAssertNoThrow(try channel.finish(acceptAlreadyClosed: false)) } - + XCTAssertNoThrow(try channel.connect(to: .makeAddressResolvingHost("127.0.0.1", port: 0)).wait()) - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, withErrorHandling: true).wait()) - + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, withErrorHandling: true) + .wait() + ) + // Send an head that expects a continue informational response let reqHeadBytes = "POST / HTTP/1.1\r\nTransfer-Encoding: chunked\r\nExpect: 100-continue\r\n\r\n" XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: reqHeadBytes))) - let expectedHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["Transfer-Encoding":"chunked", "Expect":"100-continue"]) + let expectedHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: ["Transfer-Encoding": "chunked", "Expect": "100-continue"] + ) XCTAssertEqual(try channel.readInbound(as: HTTPServerRequestPart.self), .head(expectedHead)) - + // Respond with continue informational response let continueResponse = HTTPResponseHead(version: .http1_1, status: .continue) XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(continueResponse))) - XCTAssertEqual(try channel.readOutbound(as: ByteBuffer.self), ByteBuffer(string: "HTTP/1.1 100 Continue\r\n\r\n")) - + XCTAssertEqual( + try channel.readOutbound(as: ByteBuffer.self), + ByteBuffer(string: "HTTP/1.1 100 Continue\r\n\r\n") + ) + // Send a a chunk XCTAssertNoThrow(try channel.writeInbound(ByteBuffer(string: "6\r\nfoobar\r\n"))) - + // Server responds with an actual head, even though request has not finished yet let acceptedResponse = HTTPResponseHead(version: .http1_1, status: .accepted, headers: ["Content-Length": "20"]) XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(acceptedResponse))) - XCTAssertEqual(try channel.readOutbound(as: ByteBuffer.self), ByteBuffer(string: "HTTP/1.1 202 Accepted\r\nContent-Length: 20\r\n\r\n")) - + XCTAssertEqual( + try channel.readOutbound(as: ByteBuffer.self), + ByteBuffer(string: "HTTP/1.1 202 Accepted\r\nContent-Length: 20\r\n\r\n") + ) + // Client sends garbage chunk XCTAssertThrowsError(try channel.writeInbound(ByteBuffer(string: "xyz"))) { XCTAssertEqual($0 as? HTTPParserError, .invalidChunkSize) } - + XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self)) } diff --git a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift index e94669b25f..343b6b9138 100644 --- a/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPServerUpgradeTests.swift @@ -12,20 +12,23 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded -@testable import NIOPosix +import XCTest + @testable import NIOHTTP1 +@testable import NIOPosix extension ChannelPipeline { fileprivate func assertDoesNotContainUpgrader() throws { try self.assertDoesNotContain(handlerType: HTTPServerUpgradeHandler.self) } - func assertDoesNotContain(handlerType: Handler.Type, - file: StaticString = #filePath, - line: UInt = #line) throws { + func assertDoesNotContain( + handlerType: Handler.Type, + file: StaticString = #filePath, + line: UInt = #line + ) throws { do { try self.context(handlerType: handlerType) .map { context in @@ -108,18 +111,23 @@ extension EmbeddedChannel { private typealias UpgradeCompletionHandler = @Sendable (ChannelHandlerContext) -> Void @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -private func serverHTTPChannelWithAutoremoval(group: EventLoopGroup, - pipelining: Bool, - upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], - extraHandlers: [ChannelHandler], - _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (Channel, EventLoopFuture) { +private func serverHTTPChannelWithAutoremoval( + group: EventLoopGroup, + pipelining: Bool, + upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], + extraHandlers: [ChannelHandler], + _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler +) throws -> (Channel, EventLoopFuture) { let p = group.next().makePromise(of: Channel.self) let c = try ServerBootstrap(group: group) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .childChannelInitializer { channel in p.succeed(channel) let upgradeConfig = (upgraders: upgraders, completionHandler: upgradeCompletionHandler) - return channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: pipelining, withServerUpgrade: upgradeConfig).flatMap { + return channel.pipeline.configureHTTPServerPipeline( + withPipeliningAssistance: pipelining, + withServerUpgrade: upgradeConfig + ).flatMap { let futureResults = extraHandlers.map { channel.pipeline.addHandler($0) } return EventLoopFuture.andAllSucceed(futureResults, on: channel.eventLoop) } @@ -140,7 +148,9 @@ private class SingleHTTPResponseAccumulator: ChannelInboundHandler { public func channelRead(context: ChannelHandlerContext, data: NIOAny) { let buffer = Self.unwrapInboundIn(data) self.receiveds.append(buffer) - if let finalBytes = buffer.getBytes(at: buffer.writerIndex - 4, length: 4), finalBytes == [0x0D, 0x0A, 0x0D, 0x0A] { + if let finalBytes = buffer.getBytes(at: buffer.writerIndex - 4, length: 4), + finalBytes == [0x0D, 0x0A, 0x0D, 0x0A] + { self.allDoneBlock(self.receiveds) } } @@ -155,7 +165,7 @@ private class ExplodingHandler: ChannelInboundHandler { } private func connectedClientChannel(group: EventLoopGroup, serverAddress: SocketAddress) throws -> Channel { - return try ClientBootstrap(group: group) + try ClientBootstrap(group: group) .connect(to: serverAddress) .wait() } @@ -186,7 +196,8 @@ internal func assertResponseIs(response: String, expectedResponseLine: String, e #if !canImport(Darwin) || swift(>=5.10) @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader, NIOTypedHTTPServerProtocolUpgrader where UpgradeResult == Bool {} +protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader, NIOTypedHTTPServerProtocolUpgrader +where UpgradeResult == Bool {} #else @available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) protocol TypedAndUntypedHTTPServerProtocolUpgrader: HTTPServerProtocolUpgrader {} @@ -205,7 +216,11 @@ private class ExplodingUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { self.requiredUpgradeHeaders = requiringHeaders } - public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { + public func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { XCTFail("buildUpgradeResponse called") return channel.eventLoop.makeFailedFuture(Explosion.KABOOM) } @@ -233,8 +248,12 @@ private class UpgraderSaysNo: TypedAndUntypedHTTPServerProtocolUpgrader { self.supportedProtocol = `protocol` } - public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { - return channel.eventLoop.makeFailedFuture(No.no) + public func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { + channel.eventLoop.makeFailedFuture(No.no) } public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { @@ -251,29 +270,39 @@ private class UpgraderSaysNo: TypedAndUntypedHTTPServerProtocolUpgrader { private class SuccessfulUpgrader: TypedAndUntypedHTTPServerProtocolUpgrader { let supportedProtocol: String let requiredUpgradeHeaders: [String] - private let onUpgradeComplete: (HTTPRequestHead) -> () + private let onUpgradeComplete: (HTTPRequestHead) -> Void private let buildUpgradeResponseFuture: (Channel, HTTPHeaders) -> EventLoopFuture - public init(forProtocol `protocol`: String, - requiringHeaders headers: [String], - buildUpgradeResponseFuture: @escaping (Channel, HTTPHeaders) -> EventLoopFuture, - onUpgradeComplete: @escaping (HTTPRequestHead) -> ()) { + public init( + forProtocol `protocol`: String, + requiringHeaders headers: [String], + buildUpgradeResponseFuture: @escaping (Channel, HTTPHeaders) -> EventLoopFuture, + onUpgradeComplete: @escaping (HTTPRequestHead) -> Void + ) { self.supportedProtocol = `protocol` self.requiredUpgradeHeaders = headers self.onUpgradeComplete = onUpgradeComplete self.buildUpgradeResponseFuture = buildUpgradeResponseFuture } - public convenience init(forProtocol `protocol`: String, - requiringHeaders headers: [String], - onUpgradeComplete: @escaping (HTTPRequestHead) -> ()) { - self.init(forProtocol: `protocol`, - requiringHeaders: headers, - buildUpgradeResponseFuture: { $0.eventLoop.makeSucceededFuture($1) }, - onUpgradeComplete: onUpgradeComplete) - } - - public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { + public convenience init( + forProtocol `protocol`: String, + requiringHeaders headers: [String], + onUpgradeComplete: @escaping (HTTPRequestHead) -> Void + ) { + self.init( + forProtocol: `protocol`, + requiringHeaders: headers, + buildUpgradeResponseFuture: { $0.eventLoop.makeSucceededFuture($1) }, + onUpgradeComplete: onUpgradeComplete + ) + } + + public func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { var headers = initialResponseHeaders headers.add(name: "X-Upgrade-Complete", value: "true") return self.buildUpgradeResponseFuture(channel, headers) @@ -301,10 +330,12 @@ private class DelayedUnsuccessfulUpgrader: TypedAndUntypedHTTPServerProtocolUpgr self.requiredUpgradeHeaders = [] } - func buildUpgradeResponse(channel: Channel, - upgradeRequest: HTTPRequestHead, - initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { - return channel.eventLoop.makeSucceededFuture([:]) + func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { + channel.eventLoop.makeSucceededFuture([:]) } func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { @@ -337,7 +368,11 @@ private class UpgradeDelayer: TypedAndUntypedHTTPServerProtocolUpgrader { self.upgradeRequestedPromise = upgradeRequestedPromise } - public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { + public func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { var headers = initialResponseHeaders headers.add(name: "X-Upgrade-Complete", value: "true") return channel.eventLoop.makeSucceededFuture(headers) @@ -372,8 +407,12 @@ private class UpgradeResponseDelayer: HTTPServerProtocolUpgrader { self.buildUpgradeResponseHandler = buildUpgradeResponseHandler } - public func buildUpgradeResponse(channel: Channel, upgradeRequest: HTTPRequestHead, initialResponseHeaders: HTTPHeaders) -> EventLoopFuture { - return self.buildUpgradeResponseHandler().map { + public func buildUpgradeResponse( + channel: Channel, + upgradeRequest: HTTPRequestHead, + initialResponseHeaders: HTTPHeaders + ) -> EventLoopFuture { + self.buildUpgradeResponseHandler().map { var headers = initialResponseHeaders headers.add(name: "X-Upgrade-Complete", value: "true") return headers @@ -381,7 +420,7 @@ private class UpgradeResponseDelayer: HTTPServerProtocolUpgrader { } public func upgrade(context: ChannelHandlerContext, upgradeRequest: HTTPRequestHead) -> EventLoopFuture { - return context.eventLoop.makeSucceededFuture(()) + context.eventLoop.makeSucceededFuture(()) } } @@ -416,8 +455,8 @@ private class DataRecorder: ChannelInboundHandler { } // Must be called from inside the event loop on pain of death! - public func receivedData() ->[T] { - return self.data + public func receivedData() -> [T] { + self.data } } @@ -445,23 +484,32 @@ class HTTPServerUpgradeTestCase: XCTestCase { static let eventLoop = MultiThreadedEventLoopGroup.singleton.next() - fileprivate func setUpTestWithAutoremoval(pipelining: Bool = false, - upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], - extraHandlers: [ChannelHandler], - notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, - _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler) throws -> (Channel, Channel, Channel) { - let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval(group: Self.eventLoop, - pipelining: pipelining, - upgraders: upgraders, - extraHandlers: extraHandlers, - upgradeCompletionHandler) - let clientChannel = try connectedClientChannel(group: Self.eventLoop, serverAddress: serverChannel.localAddress!) + fileprivate func setUpTestWithAutoremoval( + pipelining: Bool = false, + upgraders: [any TypedAndUntypedHTTPServerProtocolUpgrader], + extraHandlers: [ChannelHandler], + notUpgradingHandler: (@Sendable (Channel) -> EventLoopFuture)? = nil, + _ upgradeCompletionHandler: @escaping UpgradeCompletionHandler + ) throws -> (Channel, Channel, Channel) { + let (serverChannel, connectedServerChannelFuture) = try serverHTTPChannelWithAutoremoval( + group: Self.eventLoop, + pipelining: pipelining, + upgraders: upgraders, + extraHandlers: extraHandlers, + upgradeCompletionHandler + ) + let clientChannel = try connectedClientChannel( + group: Self.eventLoop, + serverAddress: serverChannel.localAddress! + ) return (serverChannel, clientChannel, try connectedServerChannelFuture.wait()) } func testUpgradeWithoutUpgrade() throws { - let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], - extraHandlers: []) { (_: ChannelHandlerContext) in + let (server, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + extraHandlers: [] + ) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { @@ -477,8 +525,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeAfterInitialRequest() throws { - let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], - extraHandlers: []) { (_: ChannelHandlerContext) in + let (server, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + extraHandlers: [] + ) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { @@ -487,7 +537,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { } // This request fires a subsequent upgrade in immediately. It should also be ignored. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\n\r\nOPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\n\r\nOPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nConnection: upgrade\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // At this time the channel pipeline should not contain our handler: it should have removed itself. @@ -500,9 +551,11 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertEqual(true, try? channel.finish().isClean) } - let handler = HTTPServerUpgradeHandler(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], - httpEncoder: HTTPResponseEncoder(), - extraHTTPHandlers: []) { (_: ChannelHandlerContext) in + let handler = HTTPServerUpgradeHandler( + upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + httpEncoder: HTTPResponseEncoder(), + extraHTTPHandlers: [] + ) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } let data = HTTPServerRequestPart.body(channel.allocator.buffer(string: "hello")) @@ -531,8 +584,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { upgraderCbFired.wrappedValue = true } - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { (context) in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { (context) in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) upgradeHandlerCbFired.wrappedValue = true @@ -541,19 +596,23 @@ class HTTPServerUpgradeTestCase: XCTestCase { context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -569,8 +628,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeRequiresCorrectHeaders() throws { - let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], - extraHandlers: []) { (_: ChannelHandlerContext) in + let (server, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], + extraHandlers: [] + ) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { @@ -586,8 +647,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeRequiresHeadersInConnection() throws { - let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], - extraHandlers: []) { (_: ChannelHandlerContext) in + let (server, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [ExplodingUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"])], + extraHandlers: [] + ) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { @@ -596,7 +659,8 @@ class HTTPServerUpgradeTestCase: XCTestCase { } // This request is missing a 'Kafkaesque' connection header. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nConnection: upgrade\r\nUpgrade: myproto\r\nKafkaesque: true\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nConnection: upgrade\r\nUpgrade: myproto\r\nKafkaesque: true\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // At this time the channel pipeline should not contain our handler: it should have removed itself. @@ -604,8 +668,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { } func testUpgradeOnlyHandlesKnownProtocols() throws { - let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [ExplodingUpgrader(forProtocol: "myproto")], - extraHandlers: []) { (_: ChannelHandlerContext) in + let (server, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [ExplodingUpgrader(forProtocol: "myproto")], + extraHandlers: [] + ) { (_: ChannelHandlerContext) in XCTFail("upgrade completed") } defer { @@ -632,8 +698,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { upgraderCbFired.wrappedValue = true } - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], - extraHandlers: []) { context in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: [] + ) { context in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) upgradeHandlerCbFired.wrappedValue = true @@ -642,19 +710,23 @@ class HTTPServerUpgradeTestCase: XCTestCase { context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto, exploder\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto, exploder\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -678,25 +750,31 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) } - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: [eventSaver.wrappedValue]) { context in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [eventSaver.wrappedValue] + ) { context in XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade,kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade,kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -704,17 +782,19 @@ class HTTPServerUpgradeTestCase: XCTestCase { // At this time we should have received one user event. We schedule this onto the // event loop to guarantee thread safety. - XCTAssertNoThrow(try connectedServer.eventLoop.scheduleTask(deadline: .now()) { - XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) - if case .upgradeComplete(let proto, let req) = eventSaver.wrappedValue.events[0] { - XCTAssertEqual(proto, "myproto") - XCTAssertEqual(req.method, .OPTIONS) - XCTAssertEqual(req.uri, "*") - XCTAssertEqual(req.version, .http1_1) - } else { - XCTFail("Unexpected event: \(eventSaver.wrappedValue.events[0])") - } - }.futureResult.wait()) + XCTAssertNoThrow( + try connectedServer.eventLoop.scheduleTask(deadline: .now()) { + XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) + if case .upgradeComplete(let proto, let req) = eventSaver.wrappedValue.events[0] { + XCTAssertEqual(proto, "myproto") + XCTAssertEqual(req.method, .OPTIONS) + XCTAssertEqual(req.uri, "*") + XCTAssertEqual(req.version, .http1_1) + } else { + XCTFail("Unexpected event: \(eventSaver.wrappedValue.events[0])") + } + }.futureResult.wait() + ) // We also want to confirm that the upgrade handler is no longer in the pipeline. try connectedServer.pipeline.waitForUpgraderToBeRemoved() @@ -733,8 +813,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { } let errorCatcher = ErrorSaver() - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], - extraHandlers: [errorCatcher]) { context in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: [errorCatcher] + ) { context in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) upgradeHandlerCbFired.wrappedValue = true @@ -743,19 +825,23 @@ class HTTPServerUpgradeTestCase: XCTestCase { context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: noproto,myproto\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: noproto,myproto\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -772,7 +858,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { // And we want to confirm we saved the error. XCTAssertEqual(errorCatcher.errors.count, 1) - switch(errorCatcher.errors[0]) { + switch errorCatcher.errors[0] { case UpgraderSaysNo.No.no: break default: @@ -782,24 +868,30 @@ class HTTPServerUpgradeTestCase: XCTestCase { func testUpgradeIsCaseInsensitive() throws { let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["WeIrDcAsE"]) { req in } - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { context in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { context in context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nWeirdcase: yup\r\nConnection: upgrade,weirdcase\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nWeirdcase: yup\r\nConnection: upgrade,weirdcase\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(client.allocator.buffer(string: request)).wait()) // Let the machinery do its thing. @@ -812,15 +904,21 @@ class HTTPServerUpgradeTestCase: XCTestCase { func testDelayedUpgradeBehaviour() throws { let upgradeRequestPromise = Self.eventLoop.makePromise(of: Void.self) let upgrader = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) - let (server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { context in } + let (server, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { context in } let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = SingleHTTPResponseAccumulator { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) @@ -849,16 +947,21 @@ class HTTPServerUpgradeTestCase: XCTestCase { let upgrader = UpgradeDelayer(forProtocol: "myproto", upgradeRequestedPromise: upgradeRequestPromise) let dataRecorder = DataRecorder() - let (server, client, _) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: [dataRecorder]) { context in } - + let (server, client, _) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [dataRecorder] + ) { context in } let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) @@ -884,7 +987,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { let data = try server.eventLoop.submit { dataRecorder.receivedData() }.wait() - let resultString = data.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + let resultString = data.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) XCTAssertEqual(resultString, appData) } @@ -903,10 +1008,15 @@ class HTTPServerUpgradeTestCase: XCTestCase { return delayedPromise.futureResult } - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: (upgraders: [delayedUpgrader], completionHandler: { context in })).wait()) + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline( + withServerUpgrade: (upgraders: [delayedUpgrader], completionHandler: { context in }) + ).wait() + ) // Let's send in an upgrade request. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: request))) // Upgrade has been requested but not proceeded. @@ -918,11 +1028,17 @@ class HTTPServerUpgradeTestCase: XCTestCase { delayedPromise.succeed(()) channel.embeddedEventLoop.run() XCTAssertNoThrow(try channel.pipeline.assertDoesNotContainUpgrader()) - XCTAssertNoThrow(assertResponseIs(response: try channel.readAllOutboundString(), - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", - "upgrade: myproto", - "connection: upgrade"])) + XCTAssertNoThrow( + assertResponseIs( + response: try channel.readAllOutboundString(), + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: [ + "X-Upgrade-Complete: true", + "upgrade: myproto", + "connection: upgrade", + ] + ) + ) } func testChainsDelayedUpgradesAppropriately() throws { @@ -951,10 +1067,17 @@ class HTTPServerUpgradeTestCase: XCTestCase { return myprotoPromise.futureResult } - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: (upgraders: [myprotoUpgrader, failingProtocolUpgrader], completionHandler: { context in })).wait()) + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline( + withServerUpgrade: ( + upgraders: [myprotoUpgrader, failingProtocolUpgrader], completionHandler: { context in } + ) + ).wait() + ) // Let's send in an upgrade request. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: failingProtocol, myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: failingProtocol, myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: request))) // Upgrade has been requested but not proceeded for the failing protocol. @@ -968,7 +1091,7 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertEqual(upgradingProtocol, "myproto") channel.pipeline.assertContainsUpgrader() XCTAssertNoThrow(XCTAssertNil(try channel.readOutbound(as: ByteBuffer.self))) - + XCTAssertThrowsError(try channel.throwIfErrorCaught()) { error in XCTAssertEqual(.no, error as? No) } @@ -977,9 +1100,11 @@ class HTTPServerUpgradeTestCase: XCTestCase { myprotoPromise.succeed(()) channel.embeddedEventLoop.run() XCTAssertNoThrow(try channel.pipeline.assertDoesNotContainUpgrader()) - assertResponseIs(response: try channel.readAllOutboundString(), - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + assertResponseIs( + response: try channel.readAllOutboundString(), + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) } func testDelayedUpgradeResponseDeliversFullRequest() throws { @@ -1001,10 +1126,15 @@ class HTTPServerUpgradeTestCase: XCTestCase { return delayedPromise.futureResult } - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: (upgraders: [delayedUpgrader], completionHandler: { context in })).wait()) + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline( + withServerUpgrade: (upgraders: [delayedUpgrader], completionHandler: { context in }) + ).wait() + ) // Let's send in an upgrade request. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: request))) // Upgrade has been requested but not proceeded. @@ -1061,11 +1191,16 @@ class HTTPServerUpgradeTestCase: XCTestCase { } // Here we're disabling the pipeline handler, because otherwise it makes this test case impossible to reach. - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, - withServerUpgrade: (upgraders: [delayedUpgrader], completionHandler: { context in })).wait()) + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline( + withPipeliningAssistance: false, + withServerUpgrade: (upgraders: [delayedUpgrader], completionHandler: { context in }) + ).wait() + ) // Let's send in an upgrade request. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: request))) // Upgrade has been requested but not proceeded. @@ -1075,11 +1210,13 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertNoThrow(try channel.throwIfErrorCaught()) // We now need to inject an extra buffered request. To do this we grab the context for the HTTPRequestDecoder and inject some reads. - XCTAssertNoThrow(try channel.pipeline.context(handlerType: ByteToMessageHandler.self).map { context in - let requestHead = HTTPServerRequestPart.head(.init(version: .http1_1, method: .GET, uri: "/test")) - context.fireChannelRead(NIOAny(requestHead)) - context.fireChannelRead(NIOAny(HTTPServerRequestPart.end(nil))) - }.wait()) + XCTAssertNoThrow( + try channel.pipeline.context(handlerType: ByteToMessageHandler.self).map { context in + let requestHead = HTTPServerRequestPart.head(.init(version: .http1_1, method: .GET, uri: "/test")) + context.fireChannelRead(NIOAny(requestHead)) + context.fireChannelRead(NIOAny(HTTPServerRequestPart.end(nil))) + }.wait() + ) // Ok, now we fail the upgrade. This fires an error, and then delivers the original request and the buffered one. delayedPromise.fail(No.no) @@ -1105,7 +1242,6 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTFail("Expected .head, got \(String(describing: t))") } - switch try channel.readInbound(as: HTTPServerRequestPart.self) { case .some(.head(let h)): XCTAssertEqual(h.method, .GET) @@ -1126,9 +1262,11 @@ class HTTPServerUpgradeTestCase: XCTestCase { func testRemovesAllHTTPRelatedHandlersAfterUpgrade() throws { let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { req in } - let (_, client, connectedServer) = try setUpTestWithAutoremoval(pipelining: true, - upgraders: [upgrader], - extraHandlers: []) { context in } + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + pipelining: true, + upgraders: [upgrader], + extraHandlers: [] + ) { context in } // First, validate the pipeline is right. connectedServer.pipeline.assertContains(handlerType: ByteToMessageHandler.self) @@ -1136,14 +1274,19 @@ class HTTPServerUpgradeTestCase: XCTestCase { connectedServer.pipeline.assertContains(handlerType: HTTPServerPipelineHandler.self) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. XCTAssertNoThrow(try connectedServer.pipeline.waitForUpgraderToBeRemoved()) // At this time we should validate that none of the HTTP handlers in the pipeline exist. - XCTAssertNoThrow(try connectedServer.pipeline.assertDoesNotContain(handlerType: ByteToMessageHandler.self)) + XCTAssertNoThrow( + try connectedServer.pipeline.assertDoesNotContain( + handlerType: ByteToMessageHandler.self + ) + ) XCTAssertNoThrow(try connectedServer.pipeline.assertDoesNotContain(handlerType: HTTPResponseEncoder.self)) XCTAssertNoThrow(try connectedServer.pipeline.assertDoesNotContain(handlerType: HTTPServerPipelineHandler.self)) } @@ -1153,12 +1296,12 @@ class HTTPServerUpgradeTestCase: XCTestCase { let upgradeRequest = UnsafeMutableTransferBox(nil) let upgradeHandlerCbFired = UnsafeMutableTransferBox(false) let upgraderCbFired = UnsafeMutableTransferBox(false) - + class CheckWeReadInlineAndExtraData: ChannelDuplexHandler { typealias InboundIn = ByteBuffer typealias OutboundIn = Never typealias OutboundOut = Never - + enum State { case fresh case added @@ -1166,25 +1309,27 @@ class HTTPServerUpgradeTestCase: XCTestCase { case extraDataRead case closed } - + private let firstByteDonePromise: EventLoopPromise private let secondByteDonePromise: EventLoopPromise private let allDonePromise: EventLoopPromise private var state = State.fresh - - init(firstByteDonePromise: EventLoopPromise, - secondByteDonePromise: EventLoopPromise, - allDonePromise: EventLoopPromise) { + + init( + firstByteDonePromise: EventLoopPromise, + secondByteDonePromise: EventLoopPromise, + allDonePromise: EventLoopPromise + ) { self.firstByteDonePromise = firstByteDonePromise self.secondByteDonePromise = secondByteDonePromise self.allDonePromise = allDonePromise } - + func handlerAdded(context: ChannelHandlerContext) { XCTAssertEqual(.fresh, self.state) self.state = .added } - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { var buf = Self.unwrapInboundIn(data) XCTAssertEqual(1, buf.readableBytes) @@ -1211,22 +1356,22 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTFail("channel read in wrong state \(self.state)") } } - + func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { XCTAssertEqual(.extraDataRead, self.state) self.state = .closed context.close(mode: mode, promise: promise) - + self.allDonePromise.succeed(()) } } - + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in upgradeRequest.wrappedValue = req XCTAssert(upgradeHandlerCbFired.wrappedValue) upgraderCbFired.wrappedValue = true } - + let promiseGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try promiseGroup.syncShutdownGracefully()) @@ -1234,51 +1379,62 @@ class HTTPServerUpgradeTestCase: XCTestCase { let firstByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let secondByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let allDonePromise = promiseGroup.next().makePromise(of: Void.self) - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { (context) in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { (context) in // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) upgradeHandlerCbFired.wrappedValue = true - _ = context.channel.pipeline.addHandler(CheckWeReadInlineAndExtraData(firstByteDonePromise: firstByteDonePromise, - secondByteDonePromise: secondByteDonePromise, - allDonePromise: allDonePromise)) + _ = context.channel.pipeline.addHandler( + CheckWeReadInlineAndExtraData( + firstByteDonePromise: firstByteDonePromise, + secondByteDonePromise: secondByteDonePromise, + allDonePromise: allDonePromise + ) + ) } let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) - + // This request is safe to upgrade. - var request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + var request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" request += "A" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) XCTAssertNoThrow(try firstByteDonePromise.futureResult.wait() as Void) XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: "B"))).wait()) - + XCTAssertNoThrow(try secondByteDonePromise.futureResult.wait() as Void) XCTAssertNoThrow(try allDonePromise.futureResult.wait() as Void) // Let the machinery do its thing. XCTAssertNoThrow(try completePromise.futureResult.wait()) - + // At this time we want to assert that everything got called. Their own callbacks assert // that the ordering was correct. XCTAssert(upgradeHandlerCbFired.wrappedValue) XCTAssert(upgraderCbFired.wrappedValue) - + // We also want to confirm that the upgrade handler is no longer in the pipeline. try connectedServer.pipeline.assertDoesNotContainUpgrader() - + XCTAssertNoThrow(try allDonePromise.futureResult.wait()) } @@ -1293,10 +1449,15 @@ class HTTPServerUpgradeTestCase: XCTestCase { defer { delayer.unblockUpgrade() } - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: (upgraders: [delayer], completionHandler: { context in })).wait()) + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline( + withServerUpgrade: (upgraders: [delayer], completionHandler: { context in }) + ).wait() + ) // Let's send in an upgrade request. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: request))) channel.embeddedEventLoop.run() @@ -1310,9 +1471,11 @@ class HTTPServerUpgradeTestCase: XCTestCase { return } XCTAssertNoThrow(try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self))) - assertResponseIs(response: responseBuffer.readString(length: responseBuffer.readableBytes)!, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + assertResponseIs( + response: responseBuffer.readString(length: responseBuffer.readableBytes)!, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) // Now send in some more bytes. XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: "B"))) @@ -1323,8 +1486,10 @@ class HTTPServerUpgradeTestCase: XCTestCase { // This should have delivered the pending bytes and the buffered request, and in all ways have behaved // as though upgrade simply failed. - XCTAssertEqual(try assertNoThrowWithValue(channel.readInbound(as: ByteBuffer.self)), - channel.allocator.buffer(string: "B")) + XCTAssertEqual( + try assertNoThrowWithValue(channel.readInbound(as: ByteBuffer.self)), + channel.allocator.buffer(string: "B") + ) XCTAssertNoThrow(try channel.pipeline.assertDoesNotContainUpgrader()) XCTAssertNoThrow(try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self))) } @@ -1343,10 +1508,15 @@ class HTTPServerUpgradeTestCase: XCTestCase { delayer.unblockUpgrade() } - XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: (upgraders: [delayer], completionHandler: { context in })).wait()) + XCTAssertNoThrow( + try channel.pipeline.configureHTTPServerPipeline( + withServerUpgrade: (upgraders: [delayer], completionHandler: { context in }) + ).wait() + ) // Let's send in an upgrade request. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: request))) channel.embeddedEventLoop.run() @@ -1360,9 +1530,11 @@ class HTTPServerUpgradeTestCase: XCTestCase { return } XCTAssertNoThrow(try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self))) - assertResponseIs(response: responseBuffer.readString(length: responseBuffer.readableBytes)!, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + assertResponseIs( + response: responseBuffer.readString(length: responseBuffer.readableBytes)!, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) // Now send in some more bytes. XCTAssertNoThrow(try channel.writeInbound(channel.allocator.buffer(string: "B"))) @@ -1375,10 +1547,14 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertNoThrow(try channel.pipeline.removeUpgrader()) // We should have received B and then the re-entrant read in that order. - XCTAssertEqual(try assertNoThrowWithValue(channel.readInbound(as: ByteBuffer.self)), - channel.allocator.buffer(string: "B")) - XCTAssertEqual(try assertNoThrowWithValue(channel.readInbound(as: ByteBuffer.self)), - channel.allocator.buffer(string: "re-entrant read from channelReadComplete!")) + XCTAssertEqual( + try assertNoThrowWithValue(channel.readInbound(as: ByteBuffer.self)), + channel.allocator.buffer(string: "B") + ) + XCTAssertEqual( + try assertNoThrowWithValue(channel.readInbound(as: ByteBuffer.self)), + channel.allocator.buffer(string: "re-entrant read from channelReadComplete!") + ) XCTAssertNoThrow(try channel.pipeline.assertDoesNotContainUpgrader()) XCTAssertNoThrow(try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self))) } @@ -1391,21 +1567,23 @@ class HTTPServerUpgradeTestCase: XCTestCase { defer { XCTAssertNoThrow(try otherELG.syncShutdownGracefully()) } - - let upgrader = SuccessfulUpgrader(forProtocol: "myproto", - requiringHeaders: ["kafkaesque"], - buildUpgradeResponseFuture: { - // this is the wrong EL - otherELG.next().makeSucceededFuture($1) - }) { req in + let upgrader = SuccessfulUpgrader( + forProtocol: "myproto", + requiringHeaders: ["kafkaesque"] + ) { + // this is the wrong EL + otherELG.next().makeSucceededFuture($1) + } onUpgradeComplete: { req in upgradeRequest.wrappedValue = req XCTAssert(upgradeHandlerCbFired.wrappedValue) upgraderCbFired.wrappedValue = true } - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { (context) in - // This is called before the upgrader gets called. + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { (context) in + // This is called before the upgrader gets called. XCTAssertNil(upgradeRequest.wrappedValue) upgradeHandlerCbFired.wrappedValue = true @@ -1415,16 +1593,21 @@ class HTTPServerUpgradeTestCase: XCTestCase { let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -1447,10 +1630,12 @@ class HTTPServerUpgradeTestCase: XCTestCase { let encoder = HTTPResponseEncoder() let handlers: [RemovableChannelHandler] = [HTTPServerPipelineHandler(), HTTPServerProtocolErrorHandler()] - let upgradeHandler = HTTPServerUpgradeHandler(upgraders: [SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: [], onUpgradeComplete: { _ in })], - httpEncoder: encoder, - extraHTTPHandlers: handlers, - upgradeCompletionHandler: { _ in }) + let upgradeHandler = HTTPServerUpgradeHandler( + upgraders: [SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: [], onUpgradeComplete: { _ in })], + httpEncoder: encoder, + extraHTTPHandlers: handlers, + upgradeCompletionHandler: { _ in } + ) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(encoder)) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandlers(handlers)) @@ -1464,7 +1649,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { // Remove one of the extra handlers. XCTAssertNoThrow(try channel.pipeline.removeHandler(handlers.last!).wait()) - let head = HTTPServerRequestPart.head(.init(version: .http1_1, method: .GET, uri: "/foo", headers: ["upgrade": "myproto"])) + let head = HTTPServerRequestPart.head( + .init(version: .http1_1, method: .GET, uri: "/foo", headers: ["upgrade": "myproto"]) + ) XCTAssertNoThrow(try channel.writeInbound(head)) XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in XCTAssertEqual(error as? ChannelPipelineError, .notFound) @@ -1494,9 +1681,11 @@ class HTTPServerUpgradeTestCase: XCTestCase { } let encoder = HTTPResponseEncoder() - let handler = HTTPServerUpgradeHandler(upgraders: [SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { _ in }], - httpEncoder: encoder, - extraHTTPHandlers: []) { (_: ChannelHandlerContext) in + let handler = HTTPServerUpgradeHandler( + upgraders: [SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: []) { _ in }], + httpEncoder: encoder, + extraHTTPHandlers: [] + ) { (_: ChannelHandlerContext) in () } @@ -1509,7 +1698,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(userEventSaver)) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(dataRecorder)) - let head = HTTPServerRequestPart.head(.init(version: .http1_1, method: .GET, uri: "/foo", headers: ["upgrade": "myproto"])) + let head = HTTPServerRequestPart.head( + .init(version: .http1_1, method: .GET, uri: "/foo", headers: ["upgrade": "myproto"]) + ) XCTAssertNoThrow(try channel.writeInbound(head)) XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in XCTAssert(error is FailAllWritesHandler.FailAllWritesError) @@ -1533,9 +1724,11 @@ class HTTPServerUpgradeTestCase: XCTestCase { let upgrader = DelayedUnsuccessfulUpgrader(forProtocol: "myproto") let encoder = HTTPResponseEncoder() - let handler = HTTPServerUpgradeHandler(upgraders: [upgrader], - httpEncoder: encoder, - extraHTTPHandlers: []) { (_: ChannelHandlerContext) in + let handler = HTTPServerUpgradeHandler( + upgraders: [upgrader], + httpEncoder: encoder, + extraHTTPHandlers: [] + ) { (_: ChannelHandlerContext) in // no-op. () } @@ -1548,7 +1741,9 @@ class HTTPServerUpgradeTestCase: XCTestCase { XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(userEventSaver)) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(dataRecorder)) - let head = HTTPServerRequestPart.head(.init(version: .http1_1, method: .GET, uri: "/foo", headers: ["upgrade": "myproto"])) + let head = HTTPServerRequestPart.head( + .init(version: .http1_1, method: .GET, uri: "/foo", headers: ["upgrade": "myproto"]) + ) XCTAssertNoThrow(try channel.writeInbound(head)) XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil))) @@ -1586,11 +1781,15 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { var configuration = NIOUpgradableHTTPServerPipelineConfiguration( upgradeConfiguration: .init( upgraders: upgraders.map { $0 as! any NIOTypedHTTPServerProtocolUpgrader }, - notUpgradingCompletionHandler: { notUpgradingHandler?($0) ?? $0.eventLoop.makeSucceededFuture(false) } + notUpgradingCompletionHandler: { + notUpgradingHandler?($0) ?? $0.eventLoop.makeSucceededFuture(false) + } ) ) configuration.enablePipelining = pipelining - return try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline(configuration: configuration) + return try channel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline( + configuration: configuration + ) .flatMap { result in if result { return channel.pipeline.context(handlerType: NIOTypedHTTPServerUpgradeHandler.self) @@ -1607,7 +1806,10 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { return EventLoopFuture.andAllSucceed(futureResults, on: channel.eventLoop) } }.bind(host: "127.0.0.1", port: 0) - let clientChannel = try connectedClientChannel(group: Self.eventLoop, serverAddress: serverChannelFuture.wait().localAddress!) + let clientChannel = try connectedClientChannel( + group: Self.eventLoop, + serverAddress: serverChannelFuture.wait().localAddress! + ) return (try serverChannelFuture.wait(), clientChannel, try connectionChannelPromise.futureResult.wait()) } @@ -1618,26 +1820,28 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { let (_, client, connectedServer) = try setUpTestWithAutoremoval( upgraders: [upgrader], - extraHandlers: [], - notUpgradingHandler: { channel in - notUpgraderCbFired.wrappedValue = true - // We're closing the connection now. - channel.close(promise: nil) - return channel.eventLoop.makeSucceededFuture(true) - } - ) { _ in } - + extraHandlers: [] + ) { channel in + notUpgraderCbFired.wrappedValue = true + // We're closing the connection now. + channel.close(promise: nil) + return channel.eventLoop.makeSucceededFuture(true) + } _: { _ in + } let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) XCTAssertEqual(resultString, "") completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: notmyproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: notmyproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -1680,19 +1884,23 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -1721,8 +1929,10 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { upgraderCbFired.wrappedValue = true } - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], - extraHandlers: []) { context in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: [] + ) { context in // This is called before the upgrader gets called. XCTAssertNotNil(upgradeRequest.wrappedValue) upgradeHandlerCbFired.wrappedValue = true @@ -1731,19 +1941,23 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto, exploder\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto, exploder\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -1773,8 +1987,10 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { } let errorCatcher = ErrorSaver() - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [explodingUpgrader, successfulUpgrader], - extraHandlers: [errorCatcher]) { context in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [explodingUpgrader, successfulUpgrader], + extraHandlers: [errorCatcher] + ) { context in // This is called before the upgrader gets called. XCTAssertNotNil(upgradeRequest.wrappedValue) upgradeHandlerCbFired.wrappedValue = true @@ -1783,19 +1999,23 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: noproto,myproto\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: noproto,myproto\r\nKafkaesque: yup\r\nConnection: upgrade, kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -1812,7 +2032,7 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { // And we want to confirm we saved the error. XCTAssertEqual(errorCatcher.errors.count, 1) - switch(errorCatcher.errors[0]) { + switch errorCatcher.errors[0] { case UpgraderSaysNo.No.no: break default: @@ -1846,9 +2066,11 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { private let allDonePromise: EventLoopPromise private var state = State.fresh - init(firstByteDonePromise: EventLoopPromise, - secondByteDonePromise: EventLoopPromise, - allDonePromise: EventLoopPromise) { + init( + firstByteDonePromise: EventLoopPromise, + secondByteDonePromise: EventLoopPromise, + allDonePromise: EventLoopPromise + ) { self.firstByteDonePromise = firstByteDonePromise self.secondByteDonePromise = secondByteDonePromise self.allDonePromise = allDonePromise @@ -1908,30 +2130,40 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { let firstByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let secondByteDonePromise = promiseGroup.next().makePromise(of: Void.self) let allDonePromise = promiseGroup.next().makePromise(of: Void.self) - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { (context) in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { (context) in // This is called before the upgrader gets called. XCTAssertNotNil(upgradeRequest.wrappedValue) upgradeHandlerCbFired.wrappedValue = true - _ = context.channel.pipeline.addHandler(CheckWeReadInlineAndExtraData(firstByteDonePromise: firstByteDonePromise, - secondByteDonePromise: secondByteDonePromise, - allDonePromise: allDonePromise)) + _ = context.channel.pipeline.addHandler( + CheckWeReadInlineAndExtraData( + firstByteDonePromise: firstByteDonePromise, + secondByteDonePromise: secondByteDonePromise, + allDonePromise: allDonePromise + ) + ) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - var request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + var request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" request += "A" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) @@ -1968,20 +2200,23 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { XCTAssertNoThrow(try otherELG.syncShutdownGracefully()) } - let upgrader = SuccessfulUpgrader(forProtocol: "myproto", - requiringHeaders: ["kafkaesque"], - buildUpgradeResponseFuture: { - // this is the wrong EL - otherELG.next().makeSucceededFuture($1) - }) { req in + let upgrader = SuccessfulUpgrader( + forProtocol: "myproto", + requiringHeaders: ["kafkaesque"] + ) { + // this is the wrong EL + otherELG.next().makeSucceededFuture($1) + } onUpgradeComplete: { req in upgradeRequest.wrappedValue = req XCTAssertFalse(upgradeHandlerCbFired.wrappedValue) upgraderCbFired.wrappedValue = true } - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: []) { (context) in - // This is called before the upgrader gets called. + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [] + ) { (context) in + // This is called before the upgrader gets called. XCTAssertNotNil(upgradeRequest.wrappedValue) upgradeHandlerCbFired.wrappedValue = true @@ -1989,19 +2224,23 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -2025,25 +2264,31 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { XCTAssertEqual(eventSaver.wrappedValue.events.count, 0) } - let (_, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], - extraHandlers: [eventSaver.wrappedValue]) { context in + let (_, client, connectedServer) = try setUpTestWithAutoremoval( + upgraders: [upgrader], + extraHandlers: [eventSaver.wrappedValue] + ) { context in XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) context.close(promise: nil) } - let completePromise = Self.eventLoop.makePromise(of: Void.self) let clientHandler = ArrayAccumulationHandler { buffers in - let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") - assertResponseIs(response: resultString, - expectedResponseLine: "HTTP/1.1 101 Switching Protocols", - expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined( + separator: "" + ) + assertResponseIs( + response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"] + ) completePromise.succeed(()) } XCTAssertNoThrow(try client.pipeline.addHandler(clientHandler).wait()) // This request is safe to upgrade. - let request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade,kafkaesque\r\n\r\n" + let request = + "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade,kafkaesque\r\n\r\n" XCTAssertNoThrow(try client.writeAndFlush(NIOAny(client.allocator.buffer(string: request))).wait()) // Let the machinery do its thing. @@ -2051,17 +2296,19 @@ final class TypedHTTPServerUpgradeTestCase: HTTPServerUpgradeTestCase { // At this time we should have received one user event. We schedule this onto the // event loop to guarantee thread safety. - XCTAssertNoThrow(try connectedServer.eventLoop.scheduleTask(deadline: .now()) { - XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) - if case .upgradeComplete(let proto, let req) = eventSaver.wrappedValue.events[0] { - XCTAssertEqual(proto, "myproto") - XCTAssertEqual(req.method, .OPTIONS) - XCTAssertEqual(req.uri, "*") - XCTAssertEqual(req.version, .http1_1) - } else { - XCTFail("Unexpected event: \(eventSaver.wrappedValue.events[0])") - } - }.futureResult.wait()) + XCTAssertNoThrow( + try connectedServer.eventLoop.scheduleTask(deadline: .now()) { + XCTAssertEqual(eventSaver.wrappedValue.events.count, 1) + if case .upgradeComplete(let proto, let req) = eventSaver.wrappedValue.events[0] { + XCTAssertEqual(proto, "myproto") + XCTAssertEqual(req.method, .OPTIONS) + XCTAssertEqual(req.uri, "*") + XCTAssertEqual(req.version, .http1_1) + } else { + XCTFail("Unexpected event: \(eventSaver.wrappedValue.events[0])") + } + }.futureResult.wait() + ) // We also want to confirm that the upgrade handler is no longer in the pipeline. try connectedServer.pipeline.waitForUpgraderToBeRemoved() diff --git a/Tests/NIOHTTP1Tests/HTTPTest.swift b/Tests/NIOHTTP1Tests/HTTPTest.swift index fe08510b81..b1cf126e80 100644 --- a/Tests/NIOHTTP1Tests/HTTPTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPTest.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// +import NIOEmbedded import XCTest + @testable import NIOCore -import NIOEmbedded @testable import NIOHTTP1 private final class TestChannelInboundHandler: ChannelInboundHandler { @@ -32,7 +33,6 @@ private final class TestChannelInboundHandler: ChannelInboundHandler { } } - class HTTPTest: XCTestCase { func checkHTTPRequest(_ expected: HTTPRequestHead, body: String? = nil, trailers: HTTPHeaders? = nil) throws { @@ -72,7 +72,12 @@ class HTTPTest: XCTestCase { return s } - func sendAndCheckRequests(_ expecteds: [HTTPRequestHead], body: String?, trailers: HTTPHeaders?, sendStrategy: (String, EmbeddedChannel) -> EventLoopFuture) throws -> String? { + func sendAndCheckRequests( + _ expecteds: [HTTPRequestHead], + body: String?, + trailers: HTTPHeaders?, + sendStrategy: (String, EmbeddedChannel) -> EventLoopFuture + ) throws -> String? { var step = 0 var index = 0 let channel = EmbeddedChannel() @@ -82,27 +87,29 @@ class HTTPTest: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(HTTPRequestDecoder())) var bodyData: [UInt8]? = nil var allBodyDatas: [[UInt8]] = [] - try channel.pipeline.addHandler(TestChannelInboundHandler { reqPart in - switch reqPart { - case .head(var req): - XCTAssertEqual((index * 2), step) - req.headers.remove(name: "Content-Length") - req.headers.remove(name: "Transfer-Encoding") - XCTAssertEqual(expecteds[index], req) - step += 1 - case .body(var buffer): - if bodyData == nil { - bodyData = buffer.readBytes(length: buffer.readableBytes)! - } else { - bodyData!.append(contentsOf: buffer.readBytes(length: buffer.readableBytes)!) + try channel.pipeline.addHandler( + TestChannelInboundHandler { reqPart in + switch reqPart { + case .head(var req): + XCTAssertEqual((index * 2), step) + req.headers.remove(name: "Content-Length") + req.headers.remove(name: "Transfer-Encoding") + XCTAssertEqual(expecteds[index], req) + step += 1 + case .body(var buffer): + if bodyData == nil { + bodyData = buffer.readBytes(length: buffer.readableBytes)! + } else { + bodyData!.append(contentsOf: buffer.readBytes(length: buffer.readableBytes)!) + } + case .end(let receivedTrailers): + XCTAssertEqual(trailers, receivedTrailers) + step += 1 + XCTAssertEqual(((index + 1) * 2), step) } - case .end(let receivedTrailers): - XCTAssertEqual(trailers, receivedTrailers) - step += 1 - XCTAssertEqual(((index + 1) * 2), step) + return reqPart } - return reqPart - }).wait() + ).wait() var writeFutures: [EventLoopFuture] = [] for expected in expecteds { @@ -130,28 +137,40 @@ class HTTPTest: XCTestCase { } } - /* send all bytes in one go */ - let bd1 = try sendAndCheckRequests(expecteds, body: body, trailers: trailers, sendStrategy: { (reqString, chan) in - var buf = chan.allocator.buffer(capacity: 1024) - buf.writeString(reqString) - return chan.eventLoop.makeSucceededFuture(()).flatMapThrowing { - try chan.writeInbound(buf) - } - }) - - /* send the bytes one by one */ - let bd2 = try sendAndCheckRequests(expecteds, body: body, trailers: trailers, sendStrategy: { (reqString, chan) in - var writeFutures: [EventLoopFuture] = [] - for c in reqString { + // send all bytes in one go + let bd1 = try sendAndCheckRequests( + expecteds, + body: body, + trailers: trailers, + sendStrategy: { (reqString, chan) in var buf = chan.allocator.buffer(capacity: 1024) - - buf.writeString("\(c)") - writeFutures.append(chan.eventLoop.makeSucceededFuture(()).flatMapThrowing { [buf] in + buf.writeString(reqString) + return chan.eventLoop.makeSucceededFuture(()).flatMapThrowing { try chan.writeInbound(buf) - }) + } + } + ) + + // send the bytes one by one + let bd2 = try sendAndCheckRequests( + expecteds, + body: body, + trailers: trailers, + sendStrategy: { (reqString, chan) in + var writeFutures: [EventLoopFuture] = [] + for c in reqString { + var buf = chan.allocator.buffer(capacity: 1024) + + buf.writeString("\(c)") + writeFutures.append( + chan.eventLoop.makeSucceededFuture(()).flatMapThrowing { [buf] in + try chan.writeInbound(buf) + } + ) + } + return EventLoopFuture.andAllSucceed(writeFutures, on: chan.eventLoop) } - return EventLoopFuture.andAllSucceed(writeFutures, on: chan.eventLoop) - }) + ) XCTAssertEqual(bd1, bd2) XCTAssertEqual(body, bd1) @@ -187,27 +206,42 @@ class HTTPTest: XCTestCase { } func testHTTPBody() throws { - try checkHTTPRequest(HTTPRequestHead(version: .http1_1, method: .GET, uri: "/"), - body: "hello world") + try checkHTTPRequest( + HTTPRequestHead(version: .http1_1, method: .GET, uri: "/"), + body: "hello world" + ) } func test1ByteHTTPBody() throws { - try checkHTTPRequest(HTTPRequestHead(version: .http1_1, method: .GET, uri: "/"), - body: "1") + try checkHTTPRequest( + HTTPRequestHead(version: .http1_1, method: .GET, uri: "/"), + body: "1" + ) } func testHTTPPipeliningWithBody() throws { - try checkHTTPRequests(Array(repeating: HTTPRequestHead(version: .http1_1, - method: .GET, uri: "/"), - count: 20), - body: "1") + try checkHTTPRequests( + Array( + repeating: HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/" + ), + count: 20 + ), + body: "1" + ) } func testChunkedBody() throws { var trailers = HTTPHeaders() trailers.add(name: "X-Key", value: "X-Value") trailers.add(name: "Something", value: "Else") - try checkHTTPRequest(HTTPRequestHead(version: .http1_1, method: .POST, uri: "/"), body: "100", trailers: trailers) + try checkHTTPRequest( + HTTPRequestHead(version: .http1_1, method: .POST, uri: "/"), + body: "100", + trailers: trailers + ) } func testHTTPRequestHeadCoWWorks() throws { diff --git a/Tests/NIOHTTP1Tests/HTTPTypesTest.swift b/Tests/NIOHTTP1Tests/HTTPTypesTest.swift index 3c07e7e230..7549229d74 100644 --- a/Tests/NIOHTTP1Tests/HTTPTypesTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPTypesTest.swift @@ -16,92 +16,91 @@ import NIOHTTP1 import XCTest final class HTTPTypesTest: XCTestCase { - + func testConvertToString() { - XCTAssertEqual(HTTPMethod.GET.rawValue, "GET") - XCTAssertEqual(HTTPMethod.PUT.rawValue, "PUT") - XCTAssertEqual(HTTPMethod.ACL.rawValue, "ACL") - XCTAssertEqual(HTTPMethod.HEAD.rawValue, "HEAD") - XCTAssertEqual(HTTPMethod.POST.rawValue, "POST") - XCTAssertEqual(HTTPMethod.COPY.rawValue, "COPY") - XCTAssertEqual(HTTPMethod.LOCK.rawValue, "LOCK") - XCTAssertEqual(HTTPMethod.MOVE.rawValue, "MOVE") - XCTAssertEqual(HTTPMethod.BIND.rawValue, "BIND") - XCTAssertEqual(HTTPMethod.LINK.rawValue, "LINK") - XCTAssertEqual(HTTPMethod.PATCH.rawValue, "PATCH") - XCTAssertEqual(HTTPMethod.TRACE.rawValue, "TRACE") - XCTAssertEqual(HTTPMethod.MKCOL.rawValue, "MKCOL") - XCTAssertEqual(HTTPMethod.MERGE.rawValue, "MERGE") - XCTAssertEqual(HTTPMethod.PURGE.rawValue, "PURGE") - XCTAssertEqual(HTTPMethod.NOTIFY.rawValue, "NOTIFY") - XCTAssertEqual(HTTPMethod.SEARCH.rawValue, "SEARCH") - XCTAssertEqual(HTTPMethod.UNLOCK.rawValue, "UNLOCK") - XCTAssertEqual(HTTPMethod.REBIND.rawValue, "REBIND") - XCTAssertEqual(HTTPMethod.UNBIND.rawValue, "UNBIND") - XCTAssertEqual(HTTPMethod.REPORT.rawValue, "REPORT") - XCTAssertEqual(HTTPMethod.DELETE.rawValue, "DELETE") - XCTAssertEqual(HTTPMethod.UNLINK.rawValue, "UNLINK") - XCTAssertEqual(HTTPMethod.CONNECT.rawValue, "CONNECT") - XCTAssertEqual(HTTPMethod.MSEARCH.rawValue, "MSEARCH") - XCTAssertEqual(HTTPMethod.OPTIONS.rawValue, "OPTIONS") - XCTAssertEqual(HTTPMethod.PROPFIND.rawValue, "PROPFIND") - XCTAssertEqual(HTTPMethod.CHECKOUT.rawValue, "CHECKOUT") - XCTAssertEqual(HTTPMethod.PROPPATCH.rawValue, "PROPPATCH") - XCTAssertEqual(HTTPMethod.SUBSCRIBE.rawValue, "SUBSCRIBE") - XCTAssertEqual(HTTPMethod.MKCALENDAR.rawValue, "MKCALENDAR") - XCTAssertEqual(HTTPMethod.MKACTIVITY.rawValue, "MKACTIVITY") - XCTAssertEqual(HTTPMethod.UNSUBSCRIBE.rawValue, "UNSUBSCRIBE") + XCTAssertEqual(HTTPMethod.GET.rawValue, "GET") + XCTAssertEqual(HTTPMethod.PUT.rawValue, "PUT") + XCTAssertEqual(HTTPMethod.ACL.rawValue, "ACL") + XCTAssertEqual(HTTPMethod.HEAD.rawValue, "HEAD") + XCTAssertEqual(HTTPMethod.POST.rawValue, "POST") + XCTAssertEqual(HTTPMethod.COPY.rawValue, "COPY") + XCTAssertEqual(HTTPMethod.LOCK.rawValue, "LOCK") + XCTAssertEqual(HTTPMethod.MOVE.rawValue, "MOVE") + XCTAssertEqual(HTTPMethod.BIND.rawValue, "BIND") + XCTAssertEqual(HTTPMethod.LINK.rawValue, "LINK") + XCTAssertEqual(HTTPMethod.PATCH.rawValue, "PATCH") + XCTAssertEqual(HTTPMethod.TRACE.rawValue, "TRACE") + XCTAssertEqual(HTTPMethod.MKCOL.rawValue, "MKCOL") + XCTAssertEqual(HTTPMethod.MERGE.rawValue, "MERGE") + XCTAssertEqual(HTTPMethod.PURGE.rawValue, "PURGE") + XCTAssertEqual(HTTPMethod.NOTIFY.rawValue, "NOTIFY") + XCTAssertEqual(HTTPMethod.SEARCH.rawValue, "SEARCH") + XCTAssertEqual(HTTPMethod.UNLOCK.rawValue, "UNLOCK") + XCTAssertEqual(HTTPMethod.REBIND.rawValue, "REBIND") + XCTAssertEqual(HTTPMethod.UNBIND.rawValue, "UNBIND") + XCTAssertEqual(HTTPMethod.REPORT.rawValue, "REPORT") + XCTAssertEqual(HTTPMethod.DELETE.rawValue, "DELETE") + XCTAssertEqual(HTTPMethod.UNLINK.rawValue, "UNLINK") + XCTAssertEqual(HTTPMethod.CONNECT.rawValue, "CONNECT") + XCTAssertEqual(HTTPMethod.MSEARCH.rawValue, "MSEARCH") + XCTAssertEqual(HTTPMethod.OPTIONS.rawValue, "OPTIONS") + XCTAssertEqual(HTTPMethod.PROPFIND.rawValue, "PROPFIND") + XCTAssertEqual(HTTPMethod.CHECKOUT.rawValue, "CHECKOUT") + XCTAssertEqual(HTTPMethod.PROPPATCH.rawValue, "PROPPATCH") + XCTAssertEqual(HTTPMethod.SUBSCRIBE.rawValue, "SUBSCRIBE") + XCTAssertEqual(HTTPMethod.MKCALENDAR.rawValue, "MKCALENDAR") + XCTAssertEqual(HTTPMethod.MKACTIVITY.rawValue, "MKACTIVITY") + XCTAssertEqual(HTTPMethod.UNSUBSCRIBE.rawValue, "UNSUBSCRIBE") XCTAssertEqual(HTTPMethod.SOURCE.rawValue, "SOURCE") XCTAssertEqual(HTTPMethod.RAW(value: "SOMETHINGELSE").rawValue, "SOMETHINGELSE") } - + func testConvertFromString() { - XCTAssertEqual(HTTPMethod(rawValue: "GET"), .GET) - XCTAssertEqual(HTTPMethod(rawValue: "PUT"), .PUT) - XCTAssertEqual(HTTPMethod(rawValue: "ACL"), .ACL) - XCTAssertEqual(HTTPMethod(rawValue: "HEAD"), .HEAD) - XCTAssertEqual(HTTPMethod(rawValue: "POST"), .POST) - XCTAssertEqual(HTTPMethod(rawValue: "COPY"), .COPY) - XCTAssertEqual(HTTPMethod(rawValue: "LOCK"), .LOCK) - XCTAssertEqual(HTTPMethod(rawValue: "MOVE"), .MOVE) - XCTAssertEqual(HTTPMethod(rawValue: "BIND"), .BIND) - XCTAssertEqual(HTTPMethod(rawValue: "LINK"), .LINK) - XCTAssertEqual(HTTPMethod(rawValue: "PATCH"), .PATCH) - XCTAssertEqual(HTTPMethod(rawValue: "TRACE"), .TRACE) - XCTAssertEqual(HTTPMethod(rawValue: "MKCOL"), .MKCOL) - XCTAssertEqual(HTTPMethod(rawValue: "MERGE"), .MERGE) - XCTAssertEqual(HTTPMethod(rawValue: "PURGE"), .PURGE) - XCTAssertEqual(HTTPMethod(rawValue: "NOTIFY"), .NOTIFY) - XCTAssertEqual(HTTPMethod(rawValue: "SEARCH"), .SEARCH) - XCTAssertEqual(HTTPMethod(rawValue: "UNLOCK"), .UNLOCK) - XCTAssertEqual(HTTPMethod(rawValue: "REBIND"), .REBIND) - XCTAssertEqual(HTTPMethod(rawValue: "UNBIND"), .UNBIND) - XCTAssertEqual(HTTPMethod(rawValue: "REPORT"), .REPORT) - XCTAssertEqual(HTTPMethod(rawValue: "DELETE"), .DELETE) - XCTAssertEqual(HTTPMethod(rawValue: "UNLINK"), .UNLINK) - XCTAssertEqual(HTTPMethod(rawValue: "CONNECT"), .CONNECT) - XCTAssertEqual(HTTPMethod(rawValue: "MSEARCH"), .MSEARCH) - XCTAssertEqual(HTTPMethod(rawValue: "OPTIONS"), .OPTIONS) - XCTAssertEqual(HTTPMethod(rawValue: "PROPFIND"), .PROPFIND) - XCTAssertEqual(HTTPMethod(rawValue: "CHECKOUT"), .CHECKOUT) - XCTAssertEqual(HTTPMethod(rawValue: "PROPPATCH"), .PROPPATCH) - XCTAssertEqual(HTTPMethod(rawValue: "SUBSCRIBE"), .SUBSCRIBE) - XCTAssertEqual(HTTPMethod(rawValue: "MKCALENDAR"), .MKCALENDAR) - XCTAssertEqual(HTTPMethod(rawValue: "MKACTIVITY"), .MKACTIVITY) - XCTAssertEqual(HTTPMethod(rawValue: "UNSUBSCRIBE"), .UNSUBSCRIBE) + XCTAssertEqual(HTTPMethod(rawValue: "GET"), .GET) + XCTAssertEqual(HTTPMethod(rawValue: "PUT"), .PUT) + XCTAssertEqual(HTTPMethod(rawValue: "ACL"), .ACL) + XCTAssertEqual(HTTPMethod(rawValue: "HEAD"), .HEAD) + XCTAssertEqual(HTTPMethod(rawValue: "POST"), .POST) + XCTAssertEqual(HTTPMethod(rawValue: "COPY"), .COPY) + XCTAssertEqual(HTTPMethod(rawValue: "LOCK"), .LOCK) + XCTAssertEqual(HTTPMethod(rawValue: "MOVE"), .MOVE) + XCTAssertEqual(HTTPMethod(rawValue: "BIND"), .BIND) + XCTAssertEqual(HTTPMethod(rawValue: "LINK"), .LINK) + XCTAssertEqual(HTTPMethod(rawValue: "PATCH"), .PATCH) + XCTAssertEqual(HTTPMethod(rawValue: "TRACE"), .TRACE) + XCTAssertEqual(HTTPMethod(rawValue: "MKCOL"), .MKCOL) + XCTAssertEqual(HTTPMethod(rawValue: "MERGE"), .MERGE) + XCTAssertEqual(HTTPMethod(rawValue: "PURGE"), .PURGE) + XCTAssertEqual(HTTPMethod(rawValue: "NOTIFY"), .NOTIFY) + XCTAssertEqual(HTTPMethod(rawValue: "SEARCH"), .SEARCH) + XCTAssertEqual(HTTPMethod(rawValue: "UNLOCK"), .UNLOCK) + XCTAssertEqual(HTTPMethod(rawValue: "REBIND"), .REBIND) + XCTAssertEqual(HTTPMethod(rawValue: "UNBIND"), .UNBIND) + XCTAssertEqual(HTTPMethod(rawValue: "REPORT"), .REPORT) + XCTAssertEqual(HTTPMethod(rawValue: "DELETE"), .DELETE) + XCTAssertEqual(HTTPMethod(rawValue: "UNLINK"), .UNLINK) + XCTAssertEqual(HTTPMethod(rawValue: "CONNECT"), .CONNECT) + XCTAssertEqual(HTTPMethod(rawValue: "MSEARCH"), .MSEARCH) + XCTAssertEqual(HTTPMethod(rawValue: "OPTIONS"), .OPTIONS) + XCTAssertEqual(HTTPMethod(rawValue: "PROPFIND"), .PROPFIND) + XCTAssertEqual(HTTPMethod(rawValue: "CHECKOUT"), .CHECKOUT) + XCTAssertEqual(HTTPMethod(rawValue: "PROPPATCH"), .PROPPATCH) + XCTAssertEqual(HTTPMethod(rawValue: "SUBSCRIBE"), .SUBSCRIBE) + XCTAssertEqual(HTTPMethod(rawValue: "MKCALENDAR"), .MKCALENDAR) + XCTAssertEqual(HTTPMethod(rawValue: "MKACTIVITY"), .MKACTIVITY) + XCTAssertEqual(HTTPMethod(rawValue: "UNSUBSCRIBE"), .UNSUBSCRIBE) XCTAssertEqual(HTTPMethod(rawValue: "SOURCE"), .SOURCE) XCTAssertEqual(HTTPMethod(rawValue: "SOMETHINGELSE"), HTTPMethod.RAW(value: "SOMETHINGELSE")) } - + func testConvertFromStringToExplicitValue() { switch HTTPMethod(rawValue: "GET") { case .RAW(value: "GET"): XCTFail("Expected \"GET\" to map to explicit .GET value and not .RAW(value: \"GET\")") case .GET: - break // everything is awesome + break // everything is awesome default: XCTFail("Unexpected case") } } } - diff --git a/Tests/NIOHTTP1Tests/NIOHTTPObjectAggregatorTest.swift b/Tests/NIOHTTP1Tests/NIOHTTPObjectAggregatorTest.swift index fdd532d285..4931be76af 100644 --- a/Tests/NIOHTTP1Tests/NIOHTTPObjectAggregatorTest.swift +++ b/Tests/NIOHTTP1Tests/NIOHTTPObjectAggregatorTest.swift @@ -12,22 +12,21 @@ // //===----------------------------------------------------------------------===// -import XCTest import NIOCore import NIOEmbedded import NIOHTTP1 import NIOTestUtils - +import XCTest private final class ReadRecorder: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = T - + enum Event: Equatable { case channelRead(InboundIn) case httpFrameTooLongEvent case httpExpectationFailedEvent - - static func ==(lhs: Event, rhs: Event) -> Bool { + + static func == (lhs: Event, rhs: Event) -> Bool { switch (lhs, rhs) { case (.channelRead(let b1), .channelRead(let b2)): return b1 == b2 @@ -40,14 +39,14 @@ private final class ReadRecorder: ChannelInboundHandler, Removable } } } - + public var reads: [Event] = [] - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.reads.append(.channelRead(Self.unwrapInboundIn(data))) context.fireChannelRead(data) } - + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { case let evt as NIOHTTPObjectAggregatorEvent where evt == NIOHTTPObjectAggregatorEvent.httpFrameTooLong: @@ -80,8 +79,8 @@ private final class WriteRecorder: ChannelOutboundHandler, RemovableChannelHandl } } -private extension ByteBuffer { - func assertContainsOnly(_ string: String) { +extension ByteBuffer { + fileprivate func assertContainsOnly(_ string: String) { let innerData = self.getString(at: self.readerIndex, length: self.readableBytes)! XCTAssertEqual(innerData, string) } @@ -96,7 +95,6 @@ private func asHTTPResponseHead(_ response: HTTPServerResponsePart) -> HTTPRespo } } - class NIOHTTPServerRequestAggregatorTest: XCTestCase { var channel: EmbeddedChannel! = nil var requestHead: HTTPRequestHead! = nil @@ -104,13 +102,13 @@ class NIOHTTPServerRequestAggregatorTest: XCTestCase { fileprivate var readRecorder: ReadRecorder! = nil fileprivate var writeRecorder: WriteRecorder! = nil fileprivate var aggregatorHandler: NIOHTTPServerRequestAggregator! = nil - + override func setUp() { self.channel = EmbeddedChannel() self.readRecorder = ReadRecorder() self.writeRecorder = WriteRecorder() self.aggregatorHandler = NIOHTTPServerRequestAggregator(maxContentLength: 1024 * 1024) - + XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(HTTPResponseEncoder())) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(self.writeRecorder)) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(self.aggregatorHandler)) @@ -119,10 +117,10 @@ class NIOHTTPServerRequestAggregatorTest: XCTestCase { self.requestHead = HTTPRequestHead(version: .http1_1, method: .PUT, uri: "/path") self.requestHead.headers.add(name: "Host", value: "example.com") self.requestHead.headers.add(name: "X-Test", value: "True") - + self.responseHead = HTTPResponseHead(version: .http1_1, status: .ok) self.responseHead.headers.add(name: "Server", value: "SwiftNIO") - + // this activates the channel XCTAssertNoThrow(try self.channel.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 1)).wait()) } @@ -135,7 +133,7 @@ class NIOHTTPServerRequestAggregatorTest: XCTestCase { XCTAssertNoThrow(try self.channel.pipeline.syncOperations.addHandler(self.aggregatorHandler)) XCTAssertNoThrow(try channel.pipeline.syncOperations.addHandler(self.readRecorder)) } - + override func tearDown() { if let channel = self.channel { XCTAssertNoThrow(try channel.finish(acceptAlreadyClosed: true)) @@ -147,92 +145,160 @@ class NIOHTTPServerRequestAggregatorTest: XCTestCase { self.writeRecorder = nil self.aggregatorHandler = nil } - + func testAggregateNoBody() { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) // Only one request should have made it through. - XCTAssertEqual(self.readRecorder.reads, - [.channelRead(NIOHTTPServerRequestFull(head: self.requestHead, body: nil))]) + XCTAssertEqual( + self.readRecorder.reads, + [.channelRead(NIOHTTPServerRequestFull(head: self.requestHead, body: nil))] + ) } - + func testAggregateWithBody() { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body( - channel.allocator.buffer(string: "hello")))) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPServerRequestPart.body( + channel.allocator.buffer(string: "hello") + ) + ) + ) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) - + // Only one request should have made it through. - XCTAssertEqual(self.readRecorder.reads, [ - .channelRead(NIOHTTPServerRequestFull( - head: self.requestHead, - body: channel.allocator.buffer(string: "hello")))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead( + NIOHTTPServerRequestFull( + head: self.requestHead, + body: channel.allocator.buffer(string: "hello") + ) + ) + ] + ) } - + func testAggregateChunkedBody() { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) - - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body( - channel.allocator.buffer(string: "hello")))) - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body( - channel.allocator.buffer(string: "world")))) + + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPServerRequestPart.body( + channel.allocator.buffer(string: "hello") + ) + ) + ) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPServerRequestPart.body( + channel.allocator.buffer(string: "world") + ) + ) + ) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) - + // Only one request should have made it through. - XCTAssertEqual(self.readRecorder.reads, [ - .channelRead(NIOHTTPServerRequestFull( - head: self.requestHead, - body: channel.allocator.buffer(string: "helloworld")))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead( + NIOHTTPServerRequestFull( + head: self.requestHead, + body: channel.allocator.buffer(string: "helloworld") + ) + ) + ] + ) } - + func testAggregateWithTrailer() { var reqWithChunking: HTTPRequestHead = self.requestHead reqWithChunking.headers.add(name: "transfer-encoding", value: "chunked") reqWithChunking.headers.add(name: "Trailer", value: "X-Trailer") - + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(reqWithChunking))) - - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body( - channel.allocator.buffer(string: "hello")))) - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body( - channel.allocator.buffer(string: "world")))) - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end( - HTTPHeaders.init([("X-Trailer", "true")])))) + + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPServerRequestPart.body( + channel.allocator.buffer(string: "hello") + ) + ) + ) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPServerRequestPart.body( + channel.allocator.buffer(string: "world") + ) + ) + ) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPServerRequestPart.end( + HTTPHeaders.init([("X-Trailer", "true")]) + ) + ) + ) reqWithChunking.headers.remove(name: "Trailer") reqWithChunking.headers.add(name: "X-Trailer", value: "true") // Trailer headers should get moved to normal ones - XCTAssertEqual(self.readRecorder.reads, [ - .channelRead(NIOHTTPServerRequestFull( - head: reqWithChunking, - body: channel.allocator.buffer(string: "helloworld")))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead( + NIOHTTPServerRequestFull( + head: reqWithChunking, + body: channel.allocator.buffer(string: "helloworld") + ) + ) + ] + ) } - + func testOversizeRequest() { resetSmallHandler(maxContentLength: 4) - + XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(self.requestHead))) XCTAssertTrue(channel.isActive) - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body( - channel.allocator.buffer(string: "he")))) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPServerRequestPart.body( + channel.allocator.buffer(string: "he") + ) + ) + ) XCTAssertEqual(self.writeRecorder.writes, []) - XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.body( - channel.allocator.buffer(string: "llo")))) { error in + XCTAssertThrowsError( + try self.channel.writeInbound( + HTTPServerRequestPart.body( + channel.allocator.buffer(string: "llo") + ) + ) + ) { error in XCTAssertEqual(NIOHTTPObjectAggregatorError.frameTooLong, error as? NIOHTTPObjectAggregatorError) } let resTooLarge = HTTPResponseHead( version: .http1_1, status: .payloadTooLarge, - headers: HTTPHeaders([("Content-Length", "0"), ("connection", "close")])) - - XCTAssertEqual(self.writeRecorder.writes, [ - HTTPServerResponsePart.head(resTooLarge), - HTTPServerResponsePart.end(nil)]) + headers: HTTPHeaders([("Content-Length", "0"), ("connection", "close")]) + ) + + XCTAssertEqual( + self.writeRecorder.writes, + [ + HTTPServerResponsePart.head(resTooLarge), + HTTPServerResponsePart.end(nil), + ] + ) XCTAssertFalse(channel.isActive) XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) { error in @@ -246,20 +312,27 @@ class NIOHTTPServerRequestAggregatorTest: XCTestCase { // send an HTTP/1.0 request with no keep-alive header let requestHead: HTTPRequestHead = HTTPRequestHead( version: .http1_0, - method: .PUT, uri: "/path", + method: .PUT, + uri: "/path", headers: HTTPHeaders( - [("Host", "example.com"), ("X-Test", "True"), ("content-length", "5")])) + [("Host", "example.com"), ("X-Test", "True"), ("content-length", "5")]) + ) XCTAssertThrowsError(try self.channel.writeInbound(HTTPServerRequestPart.head(requestHead))) let resTooLarge = HTTPResponseHead( version: .http1_0, status: .payloadTooLarge, - headers: HTTPHeaders([("Content-Length", "0"), ("connection", "close")])) + headers: HTTPHeaders([("Content-Length", "0"), ("connection", "close")]) + ) - XCTAssertEqual(self.writeRecorder.writes, [ - HTTPServerResponsePart.head(resTooLarge), - HTTPServerResponsePart.end(nil)]) + XCTAssertEqual( + self.writeRecorder.writes, + [ + HTTPServerResponsePart.head(resTooLarge), + HTTPServerResponsePart.end(nil), + ] + ) // Connection should be closed right away XCTAssertFalse(channel.isActive) @@ -275,9 +348,11 @@ class NIOHTTPServerRequestAggregatorTest: XCTestCase { // HTTP/1.1 uses Keep-Alive unless told otherwise let requestHead: HTTPRequestHead = HTTPRequestHead( version: .http1_1, - method: .PUT, uri: "/path", + method: .PUT, + uri: "/path", headers: HTTPHeaders( - [("Host", "example.com"), ("X-Test", "True"), ("content-length", "8")])) + [("Host", "example.com"), ("X-Test", "True"), ("content-length", "8")]) + ) resetSmallHandler(maxContentLength: 4) @@ -294,8 +369,8 @@ class NIOHTTPServerRequestAggregatorTest: XCTestCase { // An ill-behaved client may continue writing the request let requestParts = [ HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [1, 2, 3, 4])), - HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [5,6])), - HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [7,8])) + HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [5, 6])), + HTTPServerRequestPart.body(channel.allocator.buffer(bytes: [7, 8])), ] for requestPart in requestParts { @@ -314,18 +389,35 @@ class NIOHTTPServerRequestAggregatorTest: XCTestCase { XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.head(secondReqWithContentLength))) - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body( - channel.allocator.buffer(bytes: [1])))) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPServerRequestPart.body( + channel.allocator.buffer(bytes: [1]) + ) + ) + ) XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent]) - XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.body( - channel.allocator.buffer(bytes: [2])))) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPServerRequestPart.body( + channel.allocator.buffer(bytes: [2]) + ) + ) + ) XCTAssertNoThrow(try self.channel.writeInbound(HTTPServerRequestPart.end(nil))) - XCTAssertEqual(self.readRecorder.reads, [ - .httpFrameTooLongEvent, - .channelRead(NIOHTTPServerRequestFull( - head: secondReqWithContentLength, - body: channel.allocator.buffer(bytes: [1, 2])))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .httpFrameTooLongEvent, + .channelRead( + NIOHTTPServerRequestFull( + head: secondReqWithContentLength, + body: channel.allocator.buffer(bytes: [1, 2]) + ) + ), + ] + ) } } @@ -393,68 +485,113 @@ class NIOHTTPClientResponseAggregatorTest: XCTestCase { resetSmallHandler(maxContentLength: 5) XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead))) - XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.body( - self.channel.allocator.buffer(string: "hello")))) - - XCTAssertThrowsError(try self.channel.writeInbound( - HTTPClientResponsePart.body( - self.channel.allocator.buffer(string: "world")))) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPClientResponsePart.body( + self.channel.allocator.buffer(string: "hello") + ) + ) + ) + + XCTAssertThrowsError( + try self.channel.writeInbound( + HTTPClientResponsePart.body( + self.channel.allocator.buffer(string: "world") + ) + ) + ) XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.end(nil))) // User event triggered XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent]) } - func testAggregatedResponse() { XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead))) - XCTAssertNoThrow(try self.channel.writeInbound( - HTTPClientResponsePart.body( - self.channel.allocator.buffer(string: "hello")))) - XCTAssertNoThrow(try self.channel.writeInbound( - HTTPClientResponsePart.body( - self.channel.allocator.buffer(string: "world")))) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPClientResponsePart.body( + self.channel.allocator.buffer(string: "hello") + ) + ) + ) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPClientResponsePart.body( + self.channel.allocator.buffer(string: "world") + ) + ) + ) XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.end(HTTPHeaders([("X-Trail", "true")])))) var aggregatedHead: HTTPResponseHead = self.responseHead aggregatedHead.headers.add(name: "X-Trail", value: "true") - XCTAssertEqual(self.readRecorder.reads, [ - .channelRead(NIOHTTPClientResponseFull( - head: aggregatedHead, - body: self.channel.allocator.buffer(string: "helloworld")))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .channelRead( + NIOHTTPClientResponseFull( + head: aggregatedHead, + body: self.channel.allocator.buffer(string: "helloworld") + ) + ) + ] + ) } func testOkAfterOversized() { resetSmallHandler(maxContentLength: 4) XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead))) - XCTAssertNoThrow(try self.channel.writeInbound( - HTTPClientResponsePart.body( - self.channel.allocator.buffer(string: "hell")))) - XCTAssertThrowsError(try self.channel.writeInbound( - HTTPClientResponsePart.body( - self.channel.allocator.buffer(string: "owor")))) - XCTAssertThrowsError(try self.channel.writeInbound( - HTTPClientResponsePart.body( - self.channel.allocator.buffer(string: "ld")))) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPClientResponsePart.body( + self.channel.allocator.buffer(string: "hell") + ) + ) + ) + XCTAssertThrowsError( + try self.channel.writeInbound( + HTTPClientResponsePart.body( + self.channel.allocator.buffer(string: "owor") + ) + ) + ) + XCTAssertThrowsError( + try self.channel.writeInbound( + HTTPClientResponsePart.body( + self.channel.allocator.buffer(string: "ld") + ) + ) + ) XCTAssertThrowsError(try self.channel.writeInbound(HTTPClientResponsePart.end(nil))) // User event triggered XCTAssertEqual(self.readRecorder.reads, [.httpFrameTooLongEvent]) XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.head(self.responseHead))) - XCTAssertNoThrow(try self.channel.writeInbound( - HTTPClientResponsePart.body( - self.channel.allocator.buffer(string: "test")))) + XCTAssertNoThrow( + try self.channel.writeInbound( + HTTPClientResponsePart.body( + self.channel.allocator.buffer(string: "test") + ) + ) + ) XCTAssertNoThrow(try self.channel.writeInbound(HTTPClientResponsePart.end(nil))) - XCTAssertEqual(self.readRecorder.reads, [ - .httpFrameTooLongEvent, - .channelRead(NIOHTTPClientResponseFull( - head: self.responseHead, - body: self.channel.allocator.buffer(string: "test")))]) + XCTAssertEqual( + self.readRecorder.reads, + [ + .httpFrameTooLongEvent, + .channelRead( + NIOHTTPClientResponseFull( + head: self.responseHead, + body: self.channel.allocator.buffer(string: "test") + ) + ), + ] + ) } - } diff --git a/Tests/NIOHTTP1Tests/UnsafeTransfer.swift b/Tests/NIOHTTP1Tests/UnsafeTransfer.swift index 007d96bdfb..83a42de82a 100644 --- a/Tests/NIOHTTP1Tests/UnsafeTransfer.swift +++ b/Tests/NIOHTTP1Tests/UnsafeTransfer.swift @@ -19,7 +19,7 @@ struct UnsafeTransfer { @usableFromInline var wrappedValue: Wrapped - + @inlinable init(_ wrappedValue: Wrapped) { self.wrappedValue = wrappedValue @@ -38,7 +38,7 @@ extension UnsafeTransfer: Hashable where Wrapped: Hashable {} final class UnsafeMutableTransferBox { @usableFromInline var wrappedValue: Wrapped - + @inlinable init(_ wrappedValue: Wrapped) { self.wrappedValue = wrappedValue diff --git a/Tests/NIOPosixTests/AcceptBackoffHandlerTest.swift b/Tests/NIOPosixTests/AcceptBackoffHandlerTest.swift index 015c6382ed..162cb454ba 100644 --- a/Tests/NIOPosixTests/AcceptBackoffHandlerTest.swift +++ b/Tests/NIOPosixTests/AcceptBackoffHandlerTest.swift @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// -import XCTest -import NIOCore -@testable import NIOPosix import Atomics +import NIOCore +import XCTest +@testable import NIOPosix public final class AcceptBackoffHandlerTest: XCTestCase { @@ -49,26 +49,37 @@ public final class AcceptBackoffHandlerTest: XCTestCase { } let readCountHandler = ReadCountHandler() - let serverChannel = try setupChannel(group: group, - readCountHandler: readCountHandler, - backoffProvider: { _ in return .milliseconds(100) }, - errors: [error]) - XCTAssertEqual(0, try serverChannel.eventLoop.submit { - serverChannel.readable() - serverChannel.read() - return readCountHandler.readCount - }.wait()) + let serverChannel = try setupChannel( + group: group, + readCountHandler: readCountHandler, + backoffProvider: { _ in .milliseconds(100) }, + errors: [error] + ) + XCTAssertEqual( + 0, + try serverChannel.eventLoop.submit { + serverChannel.readable() + serverChannel.read() + return readCountHandler.readCount + }.wait() + ) // Inspect the read count after our scheduled backoff elapsed. - XCTAssertEqual(1, try serverChannel.eventLoop.scheduleTask(in: .milliseconds(100)) { - return readCountHandler.readCount - }.futureResult.wait()) + XCTAssertEqual( + 1, + try serverChannel.eventLoop.scheduleTask(in: .milliseconds(100)) { + readCountHandler.readCount + }.futureResult.wait() + ) // The read should go through as the scheduled read happened - XCTAssertEqual(2, try serverChannel.eventLoop.submit { - serverChannel.read() - return readCountHandler.readCount - }.wait()) + XCTAssertEqual( + 2, + try serverChannel.eventLoop.submit { + serverChannel.read() + return readCountHandler.readCount + }.wait() + ) XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } @@ -88,29 +99,43 @@ public final class AcceptBackoffHandlerTest: XCTestCase { } let readCountHandler = ReadCountHandler() - let serverChannel = try setupChannel(group: group, readCountHandler: readCountHandler, backoffProvider: { err in - return .hours(1) - }, errors: [ENFILE]) - XCTAssertEqual(0, try serverChannel.eventLoop.submit { - serverChannel.readable() - if read { - serverChannel.read() - } - return readCountHandler.readCount - }.wait()) + let serverChannel = try setupChannel( + group: group, + readCountHandler: readCountHandler, + backoffProvider: { err in + .hours(1) + }, + errors: [ENFILE] + ) + XCTAssertEqual( + 0, + try serverChannel.eventLoop.submit { + serverChannel.readable() + if read { + serverChannel.read() + } + return readCountHandler.readCount + }.wait() + ) XCTAssertNoThrow(try serverChannel.pipeline.removeHandler(name: acceptHandlerName).wait()) if read { // Removal should have triggered a read. - XCTAssertEqual(1, try serverChannel.eventLoop.submit { - return readCountHandler.readCount - }.wait()) + XCTAssertEqual( + 1, + try serverChannel.eventLoop.submit { + readCountHandler.readCount + }.wait() + ) } else { // Removal should have triggered no read. - XCTAssertEqual(0, try serverChannel.eventLoop.submit { - return readCountHandler.readCount - }.wait()) + XCTAssertEqual( + 0, + try serverChannel.eventLoop.submit { + readCountHandler.readCount + }.wait() + ) } XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } @@ -122,27 +147,41 @@ public final class AcceptBackoffHandlerTest: XCTestCase { } let readCountHandler = ReadCountHandler() - let serverChannel = try setupChannel(group: group, readCountHandler: readCountHandler, backoffProvider: { err in - return .milliseconds(10) - }, errors: [ENFILE]) - XCTAssertEqual(0, try serverChannel.eventLoop.submit { - serverChannel.readable() - serverChannel.read() - serverChannel.read() - return readCountHandler.readCount - }.wait()) + let serverChannel = try setupChannel( + group: group, + readCountHandler: readCountHandler, + backoffProvider: { err in + .milliseconds(10) + }, + errors: [ENFILE] + ) + XCTAssertEqual( + 0, + try serverChannel.eventLoop.submit { + serverChannel.readable() + serverChannel.read() + serverChannel.read() + return readCountHandler.readCount + }.wait() + ) // Inspect the read count after our scheduled backoff elapsed multiple times. This should still only have triggered one read as we should only ever // schedule one read. - XCTAssertEqual(1, try serverChannel.eventLoop.scheduleTask(in: .milliseconds(500)) { - return readCountHandler.readCount - }.futureResult.wait()) + XCTAssertEqual( + 1, + try serverChannel.eventLoop.scheduleTask(in: .milliseconds(500)) { + readCountHandler.readCount + }.futureResult.wait() + ) // The read should go through as the scheduled read happened - XCTAssertEqual(2, try serverChannel.eventLoop.submit { - serverChannel.read() - return readCountHandler.readCount - }.wait()) + XCTAssertEqual( + 2, + try serverChannel.eventLoop.submit { + serverChannel.read() + return readCountHandler.readCount + }.wait() + ) XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) } @@ -156,7 +195,6 @@ public final class AcceptBackoffHandlerTest: XCTestCase { class InactiveVerificationHandler: ChannelInboundHandler { typealias InboundIn = Any - private let promise: EventLoopPromise init(promise: EventLoopPromise) { @@ -173,25 +211,36 @@ public final class AcceptBackoffHandlerTest: XCTestCase { } let readCountHandler = ReadCountHandler() - let serverChannel = try setupChannel(group: group, readCountHandler: readCountHandler, backoffProvider: { err in - return .milliseconds(10) - }, errors: [ENFILE]) + let serverChannel = try setupChannel( + group: group, + readCountHandler: readCountHandler, + backoffProvider: { err in + .milliseconds(10) + }, + errors: [ENFILE] + ) let inactiveVerificationHandler = InactiveVerificationHandler(promise: serverChannel.eventLoop.makePromise()) XCTAssertNoThrow(try serverChannel.pipeline.addHandler(inactiveVerificationHandler).wait()) - XCTAssertEqual(0, try serverChannel.eventLoop.submit { - serverChannel.readable() - serverChannel.read() - // Close the channel, this should also take care of cancel the scheduled read. - serverChannel.close(promise: nil) - return readCountHandler.readCount - }.wait()) + XCTAssertEqual( + 0, + try serverChannel.eventLoop.submit { + serverChannel.readable() + serverChannel.read() + // Close the channel, this should also take care of cancel the scheduled read. + serverChannel.close(promise: nil) + return readCountHandler.readCount + }.wait() + ) // Inspect the read count after our scheduled backoff elapsed multiple times. This should have triggered no read as the channel was closed. - XCTAssertEqual(0, try serverChannel.eventLoop.scheduleTask(in: .milliseconds(500)) { - return readCountHandler.readCount - }.futureResult.wait()) + XCTAssertEqual( + 0, + try serverChannel.eventLoop.scheduleTask(in: .milliseconds(500)) { + readCountHandler.readCount + }.futureResult.wait() + ) XCTAssertNoThrow(try inactiveVerificationHandler.waitForInactive()) } @@ -205,32 +254,46 @@ public final class AcceptBackoffHandlerTest: XCTestCase { let readCountHandler = ReadCountHandler() let backoffProviderCalled = ManagedAtomic(0) - let serverChannel = try setupChannel(group: group, readCountHandler: readCountHandler, backoffProvider: { err in - if backoffProviderCalled.loadThenWrappingIncrement(ordering: .relaxed) == 0 { - return .seconds(1) - } - return .seconds(2) - }, errors: [ENFILE, EMFILE]) - - XCTAssertEqual(0, try serverChannel.eventLoop.submit { - serverChannel.readable() - serverChannel.read() - let readCount = readCountHandler.readCount - // Directly trigger a read again without going through the pipeline. This will allow us to use serverChannel.readable() - serverChannel._channelCore.read0() - serverChannel.readable() - return readCount - }.wait()) + let serverChannel = try setupChannel( + group: group, + readCountHandler: readCountHandler, + backoffProvider: { err in + if backoffProviderCalled.loadThenWrappingIncrement(ordering: .relaxed) == 0 { + return .seconds(1) + } + return .seconds(2) + }, + errors: [ENFILE, EMFILE] + ) + + XCTAssertEqual( + 0, + try serverChannel.eventLoop.submit { + serverChannel.readable() + serverChannel.read() + let readCount = readCountHandler.readCount + // Directly trigger a read again without going through the pipeline. This will allow us to use serverChannel.readable() + serverChannel._channelCore.read0() + serverChannel.readable() + return readCount + }.wait() + ) // This should have not fired a read yet as we updated the scheduled read because we received two errors. - XCTAssertEqual(0, try serverChannel.eventLoop.scheduleTask(in: .seconds(1)) { - return readCountHandler.readCount - }.futureResult.wait()) + XCTAssertEqual( + 0, + try serverChannel.eventLoop.scheduleTask(in: .seconds(1)) { + readCountHandler.readCount + }.futureResult.wait() + ) // This should have fired now as the updated scheduled read task should have been complete by now - XCTAssertEqual(1, try serverChannel.eventLoop.scheduleTask(in: .seconds(1)) { - return readCountHandler.readCount - }.futureResult.wait()) + XCTAssertEqual( + 1, + try serverChannel.eventLoop.scheduleTask(in: .seconds(1)) { + readCountHandler.readCount + }.futureResult.wait() + ) XCTAssertNoThrow(try serverChannel.syncCloseAcceptingAlreadyClosed()) @@ -249,31 +312,43 @@ public final class AcceptBackoffHandlerTest: XCTestCase { } } - private func setupChannel(group: EventLoopGroup, - readCountHandler: ReadCountHandler, - backoffProvider: @escaping (IOError) -> TimeAmount? = AcceptBackoffHandler.defaultBackoffProvider, - errors: [Int32]) throws -> ServerSocketChannel { + private func setupChannel( + group: EventLoopGroup, + readCountHandler: ReadCountHandler, + backoffProvider: @escaping (IOError) -> TimeAmount? = AcceptBackoffHandler.defaultBackoffProvider, + errors: [Int32] + ) throws -> ServerSocketChannel { let eventLoop = group.next() as! SelectableEventLoop let socket = try NonAcceptingServerSocket(errors: errors) - let serverChannel = try assertNoThrowWithValue(ServerSocketChannel(serverSocket: socket, - eventLoop: eventLoop, - group: group)) + let serverChannel = try assertNoThrowWithValue( + ServerSocketChannel( + serverSocket: socket, + eventLoop: eventLoop, + group: group + ) + ) XCTAssertNoThrow(try serverChannel.setOption(ChannelOptions.autoRead, value: false).wait()) - XCTAssertNoThrow(try serverChannel.pipeline.addHandler(readCountHandler).flatMap { _ in - serverChannel.pipeline.addHandler(AcceptBackoffHandler(backoffProvider: backoffProvider), - name: self.acceptHandlerName) - }.wait()) - - XCTAssertNoThrow(try eventLoop.flatSubmit { - // this is pretty delicate at the moment: - // `bind` must be _synchronously_ follow `register`, otherwise in our current implementation, `epoll` will - // send us `EPOLLHUP`. To have it run synchronously, we need to invoke the `flatMap` on the eventloop that the - // `register` will succeed. - serverChannel.register().flatMap { () -> EventLoopFuture<()> in - return serverChannel.bind(to: try! SocketAddress(ipAddress: "127.0.0.1", port: 0)) - } - }.wait() as Void) + XCTAssertNoThrow( + try serverChannel.pipeline.addHandler(readCountHandler).flatMap { _ in + serverChannel.pipeline.addHandler( + AcceptBackoffHandler(backoffProvider: backoffProvider), + name: self.acceptHandlerName + ) + }.wait() + ) + + XCTAssertNoThrow( + try eventLoop.flatSubmit { + // this is pretty delicate at the moment: + // `bind` must be _synchronously_ follow `register`, otherwise in our current implementation, `epoll` will + // send us `EPOLLHUP`. To have it run synchronously, we need to invoke the `flatMap` on the eventloop that the + // `register` will succeed. + serverChannel.register().flatMap { () -> EventLoopFuture<()> in + serverChannel.bind(to: try! SocketAddress(ipAddress: "127.0.0.1", port: 0)) + } + }.wait() as Void + ) return serverChannel } } diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index c267c83704..b1fc060feb 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -13,10 +13,11 @@ //===----------------------------------------------------------------------===// import NIOConcurrencyHelpers +import NIOTLS +import XCTest + @testable import NIOCore @testable import NIOPosix -import XCTest -import NIOTLS private final class IPHeaderRemoverHandler: ChannelInboundHandler { typealias InboundIn = AddressedEnvelope @@ -118,7 +119,9 @@ private final class TLSUserEventHandler: ChannelInboundHandler, RemovableChannel context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: alpn)) context.pipeline.removeHandler(self, promise: nil) } else if string.hasPrefix("alpn:") { - context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(string.dropFirst(5)))) + context.fireUserInboundEventTriggered( + TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(string.dropFirst(5))) + ) context.pipeline.removeHandler(self, promise: nil) } else { context.fireChannelRead(data) @@ -183,7 +186,10 @@ private final class AddressedEnvelopingHandler: ChannelDuplexHandler { func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { let buffer = Self.unwrapOutboundIn(data) if let remoteAddress = self.remoteAddress { - context.write(Self.wrapOutboundOut(AddressedEnvelope(remoteAddress: remoteAddress, data: buffer)), promise: promise) + context.write( + Self.wrapOutboundOut(AddressedEnvelope(remoteAddress: remoteAddress, data: buffer)), + promise: promise + ) return } @@ -252,7 +258,10 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } - let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!) + let stringChannel = try await self.makeClientChannel( + eventLoopGroup: eventLoopGroup, + port: channel.channel.localAddress!.port! + ) try await stringChannel.executeThenClose { _, outbound in try await outbound.write("hello") } @@ -269,17 +278,19 @@ final class AsyncChannelBootstrapTests: XCTestCase { try! eventLoopGroup.syncShutdownGracefully() } - let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: eventLoopGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelOption(ChannelOptions.autoRead, value: true) - .bind( - host: "127.0.0.1", - port: 0 - ) { channel in - channel.eventLoop.makeCompletedFuture { - try self.configureProtocolNegotiationHandlers(channel: channel) - } + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap( + group: eventLoopGroup + ) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(ChannelOptions.autoRead, value: true) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel in + channel.eventLoop.makeCompletedFuture { + try self.configureProtocolNegotiationHandlers(channel: channel) } + } try await withThrowingTaskGroup(of: Void.self) { group in let (stream, continuation) = AsyncStream.makeStream() @@ -354,7 +365,8 @@ final class AsyncChannelBootstrapTests: XCTestCase { try! eventLoopGroup.syncShutdownGracefully() } - let channel: NIOAsyncChannel>, Never> = try await ServerBootstrap(group: eventLoopGroup) + let channel: NIOAsyncChannel>, Never> = + try await ServerBootstrap(group: eventLoopGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .bind( host: "127.0.0.1", @@ -491,22 +503,24 @@ final class AsyncChannelBootstrapTests: XCTestCase { } let channels = NIOLockedValueBox<[Channel]>([Channel]()) - let channel: NIOAsyncChannel, Never> = try await ServerBootstrap(group: eventLoopGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .serverChannelInitializer { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(CollectingHandler(channels: channels)) - } + let channel: NIOAsyncChannel, Never> = try await ServerBootstrap( + group: eventLoopGroup + ) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .serverChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(CollectingHandler(channels: channels)) } - .childChannelOption(ChannelOptions.autoRead, value: true) - .bind( - host: "127.0.0.1", - port: 0 - ) { channel in - channel.eventLoop.makeCompletedFuture { - try self.configureProtocolNegotiationHandlers(channel: channel) - } + } + .childChannelOption(ChannelOptions.autoRead, value: true) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel in + channel.eventLoop.makeCompletedFuture { + try self.configureProtocolNegotiationHandlers(channel: channel) } + } try await withThrowingTaskGroup(of: Void.self) { group in let (stream, continuation) = AsyncStream.makeStream() @@ -582,8 +596,8 @@ final class AsyncChannelBootstrapTests: XCTestCase { try await ClientBootstrap( group: .singletonMultiThreadedEventLoopGroup ).connect(unixDomainSocketPath: "testClientBootstrapConnectFails") { channel in - return channel.eventLoop.makeCompletedFuture { - return try NIOAsyncChannel( + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( wrappingChannelSynchronously: channel, configuration: .init( inboundType: ByteBuffer.self, @@ -632,7 +646,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { let channel = try await DatagramBootstrap(group: eventLoopGroup) .bind( to: .init(ipAddress: "127.0.0.1", port: 0), - channelInitializer: { channel -> EventLoopFuture in channel.eventLoop.makeSucceededFuture(channel) } + channelInitializer: { channel -> EventLoopFuture in + channel.eventLoop.makeSucceededFuture(channel) + } ) let port = channel.localAddress!.port! @@ -690,8 +706,8 @@ final class AsyncChannelBootstrapTests: XCTestCase { try await DatagramBootstrap( group: .singletonMultiThreadedEventLoopGroup ).connect(unixDomainSocketPath: "testDatagramBootstrapConnectFails") { channel in - return channel.eventLoop.makeCompletedFuture { - return try NIOAsyncChannel( + channel.eventLoop.makeCompletedFuture { + try NIOAsyncChannel( wrappingChannelSynchronously: channel, configuration: .init( inboundType: AddressedEnvelope.self, @@ -726,7 +742,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } } catch { - try [pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } throw error } @@ -740,7 +758,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } } catch { - try [pipe1WriteFD, pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe1WriteFD, pipe2ReadFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } throw error } @@ -754,7 +774,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } } catch { - try [pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe2ReadFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } throw error } @@ -794,7 +816,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } } catch { - try [pipe1ReadFD, pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe1ReadFD, pipe1WriteFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } throw error } @@ -808,7 +832,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } } catch { - try [pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe1WriteFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } throw error } @@ -845,7 +871,10 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } } catch { - try [pipe1ReadFD, pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe1ReadFD, pipe1WriteFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } + throw error } @@ -859,7 +888,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } } catch { - try [pipe1WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe1WriteFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } throw error } @@ -894,12 +925,14 @@ final class AsyncChannelBootstrapTests: XCTestCase { input: pipe1ReadFD, output: pipe2WriteFD ) { channel in - return channel.eventLoop.makeCompletedFuture { - return try self.configureProtocolNegotiationHandlers(channel: channel) + channel.eventLoop.makeCompletedFuture { + try self.configureProtocolNegotiationHandlers(channel: channel) } } } catch { - try [pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe1ReadFD, pipe1WriteFD, pipe2ReadFD, pipe2WriteFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } throw error } @@ -913,7 +946,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } } catch { - try [pipe1WriteFD, pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe1WriteFD, pipe2ReadFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } throw error } @@ -927,7 +962,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } } catch { - try [pipe2ReadFD].forEach { try SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe2ReadFD] { + try SystemCalls.close(descriptor: fileDescriptor) + } throw error } @@ -949,7 +986,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { XCTAssertEqual(response, expectedResponse) } catch { // We only got to close the FDs that are not owned by the PipeChannel - [pipe1WriteFD, pipe2ReadFD].forEach { try? SystemCalls.close(descriptor: $0) } + for fileDescriptor in [pipe1WriteFD, pipe2ReadFD] { + try? SystemCalls.close(descriptor: fileDescriptor) + } throw error } } @@ -1102,10 +1141,11 @@ final class AsyncChannelBootstrapTests: XCTestCase { } } - // MARK: - Test Helpers - private func makePipeFileDescriptors() -> (pipe1ReadFD: CInt, pipe1WriteFD: CInt, pipe2ReadFD: CInt, pipe2WriteFD: CInt) { + private func makePipeFileDescriptors() -> ( + pipe1ReadFD: CInt, pipe1WriteFD: CInt, pipe2ReadFD: CInt, pipe2WriteFD: CInt + ) { var pipe1FDs: [CInt] = [-1, -1] pipe1FDs.withUnsafeMutableBufferPointer { ptr in XCTAssertEqual(0, pipe(ptr.baseAddress!)) @@ -1125,9 +1165,9 @@ final class AsyncChannelBootstrapTests: XCTestCase { return (pipeFDs[0], pipeFDs[1]) } - - - private func makeRawSocketServerChannel(eventLoopGroup: EventLoopGroup) async throws -> NIOAsyncChannel { + private func makeRawSocketServerChannel( + eventLoopGroup: EventLoopGroup + ) async throws -> NIOAsyncChannel { try await NIORawSocketBootstrap(group: eventLoopGroup) .bind( host: "127.0.0.1", @@ -1135,16 +1175,24 @@ final class AsyncChannelBootstrapTests: XCTestCase { ) { channel in channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(IPHeaderRemoverHandler()) - try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler(remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0))) - try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder(inboundID: 1))) - try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder(outboundID: 2))) + try channel.pipeline.syncOperations.addHandler( + AddressedEnvelopingHandler(remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + ) + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(LineDelimiterCoder(inboundID: 1)) + ) + try channel.pipeline.syncOperations.addHandler( + MessageToByteHandler(LineDelimiterCoder(outboundID: 2)) + ) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } } } - private func makeRawSocketClientChannel(eventLoopGroup: EventLoopGroup) async throws -> NIOAsyncChannel { + private func makeRawSocketClientChannel( + eventLoopGroup: EventLoopGroup + ) async throws -> NIOAsyncChannel { try await NIORawSocketBootstrap(group: eventLoopGroup) .connect( host: "127.0.0.1", @@ -1152,9 +1200,15 @@ final class AsyncChannelBootstrapTests: XCTestCase { ) { channel in channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(IPHeaderRemoverHandler()) - try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler(remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0))) - try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder(inboundID: 2))) - try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder(outboundID: 1))) + try channel.pipeline.syncOperations.addHandler( + AddressedEnvelopingHandler(remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + ) + try channel.pipeline.syncOperations.addHandler( + ByteToMessageHandler(LineDelimiterCoder(inboundID: 2)) + ) + try channel.pipeline.syncOperations.addHandler( + MessageToByteHandler(LineDelimiterCoder(outboundID: 1)) + ) try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } @@ -1169,10 +1223,17 @@ final class AsyncChannelBootstrapTests: XCTestCase { host: "127.0.0.1", ipProtocol: .reservedForTesting ) { channel in - return channel.eventLoop.makeCompletedFuture { + channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(IPHeaderRemoverHandler()) - try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler(remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0))) - return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: nil, inboundID: 1, outboundID: 2) + try channel.pipeline.syncOperations.addHandler( + AddressedEnvelopingHandler(remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + ) + return try self.configureProtocolNegotiationHandlers( + channel: channel, + proposedALPN: nil, + inboundID: 1, + outboundID: 2 + ) } } } @@ -1186,16 +1247,26 @@ final class AsyncChannelBootstrapTests: XCTestCase { host: "127.0.0.1", ipProtocol: .reservedForTesting ) { channel in - return channel.eventLoop.makeCompletedFuture { + channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(IPHeaderRemoverHandler()) - try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler(remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0))) - return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN, inboundID: 2, outboundID: 1) + try channel.pipeline.syncOperations.addHandler( + AddressedEnvelopingHandler(remoteAddress: SocketAddress(ipAddress: "127.0.0.1", port: 0)) + ) + return try self.configureProtocolNegotiationHandlers( + channel: channel, + proposedALPN: proposedALPN, + inboundID: 2, + outboundID: 1 + ) } } } - private func makeClientChannel(eventLoopGroup: EventLoopGroup, port: Int) async throws -> NIOAsyncChannel { - return try await ClientBootstrap(group: eventLoopGroup) + private func makeClientChannel( + eventLoopGroup: EventLoopGroup, + port: Int + ) async throws -> NIOAsyncChannel { + try await ClientBootstrap(group: eventLoopGroup) .connect( to: .init(ipAddress: "127.0.0.1", port: port) ) { channel in @@ -1214,12 +1285,12 @@ final class AsyncChannelBootstrapTests: XCTestCase { port: Int, proposedALPN: TLSUserEventHandler.ALPN ) async throws -> EventLoopFuture { - return try await ClientBootstrap(group: eventLoopGroup) + try await ClientBootstrap(group: eventLoopGroup) .connect( to: .init(ipAddress: "127.0.0.1", port: port) ) { channel in - return channel.eventLoop.makeCompletedFuture { - return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN) + channel.eventLoop.makeCompletedFuture { + try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN) } } } @@ -1230,11 +1301,11 @@ final class AsyncChannelBootstrapTests: XCTestCase { proposedOuterALPN: TLSUserEventHandler.ALPN, proposedInnerALPN: TLSUserEventHandler.ALPN ) async throws -> EventLoopFuture> { - return try await ClientBootstrap(group: eventLoopGroup) + try await ClientBootstrap(group: eventLoopGroup) .connect( to: .init(ipAddress: "127.0.0.1", port: port) ) { channel in - return channel.eventLoop.makeCompletedFuture { + channel.eventLoop.makeCompletedFuture { try self.configureNestedProtocolNegotiationHandlers( channel: channel, proposedOuterALPN: proposedOuterALPN, @@ -1270,14 +1341,17 @@ final class AsyncChannelBootstrapTests: XCTestCase { host: "127.0.0.1", port: port ) { channel in - return channel.eventLoop.makeCompletedFuture { + channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler()) return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN) } } } - private func makeUDPClientChannel(eventLoopGroup: EventLoopGroup, port: Int) async throws -> NIOAsyncChannel { + private func makeUDPClientChannel( + eventLoopGroup: EventLoopGroup, + port: Int + ) async throws -> NIOAsyncChannel { try await DatagramBootstrap(group: eventLoopGroup) .connect( host: "127.0.0.1", @@ -1303,7 +1377,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { host: "127.0.0.1", port: port ) { channel in - return channel.eventLoop.makeCompletedFuture { + channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler()) return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN) } @@ -1332,20 +1406,26 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN)) - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler> { alpnResult, channel in + let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler> { + alpnResult, + channel in switch alpnResult { case .negotiated(let alpn): switch alpn { case "string": return channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) + try channel.pipeline.syncOperations.addHandler( + TLSUserEventHandler(proposedALPN: proposedInnerALPN) + ) let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) return negotiationFuture } case "byte": return channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) + try channel.pipeline.syncOperations.addHandler( + TLSUserEventHandler(proposedALPN: proposedInnerALPN) + ) let negotiationHandler = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) return negotiationHandler @@ -1362,8 +1442,12 @@ final class AsyncChannelBootstrapTests: XCTestCase { } @discardableResult - private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture { - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in + private func addTypedApplicationProtocolNegotiationHandler( + to channel: Channel + ) throws -> EventLoopFuture { + let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { + alpnResult, + channel in switch alpnResult { case .negotiated(let alpn): switch alpn { @@ -1412,7 +1496,12 @@ extension AsyncStream { } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -private func XCTAsyncAssertEqual(_ lhs: @autoclosure () async throws -> Element, _ rhs: @autoclosure () async throws -> Element, file: StaticString = #filePath, line: UInt = #line) async rethrows { +private func XCTAsyncAssertEqual( + _ lhs: @autoclosure () async throws -> Element, + _ rhs: @autoclosure () async throws -> Element, + file: StaticString = #filePath, + line: UInt = #line +) async rethrows { let lhsResult = try await lhs() let rhsResult = try await rhs() XCTAssertEqual(lhsResult, rhsResult, file: file, line: line) diff --git a/Tests/NIOPosixTests/BlockingIOThreadPoolTest.swift b/Tests/NIOPosixTests/BlockingIOThreadPoolTest.swift index a58b1de07d..1545d88ad8 100644 --- a/Tests/NIOPosixTests/BlockingIOThreadPoolTest.swift +++ b/Tests/NIOPosixTests/BlockingIOThreadPoolTest.swift @@ -12,10 +12,10 @@ // //===----------------------------------------------------------------------===// -import XCTest -import NIOPosix import Dispatch import Foundation +import NIOPosix +import XCTest class BlockingIOThreadPoolTest: XCTestCase { func testDoubleShutdownWorks() throws { @@ -96,7 +96,7 @@ class BlockingIOThreadPoolTest: XCTestCase { XCTAssertNil(error) allDone.signal() } - blockOneThreadSem.signal() // that'll unblock the thread in the pool + blockOneThreadSem.signal() // that'll unblock the thread in the pool allDone.wait() } diff --git a/Tests/NIOPosixTests/BootstrapTest.swift b/Tests/NIOPosixTests/BootstrapTest.swift index 0351bb73c3..250db702cd 100644 --- a/Tests/NIOPosixTests/BootstrapTest.swift +++ b/Tests/NIOPosixTests/BootstrapTest.swift @@ -12,15 +12,16 @@ // //===----------------------------------------------------------------------===// +import NIOConcurrencyHelpers import NIOCore import NIOEmbedded -@testable import NIOPosix -import NIOConcurrencyHelpers import XCTest +@testable import NIOPosix + class BootstrapTest: XCTestCase { var group: MultiThreadedEventLoopGroup! - var groupBag: [MultiThreadedEventLoopGroup]? = nil // protected by `self.lock` + var groupBag: [MultiThreadedEventLoopGroup]? = nil // protected by `self.lock` let lock = NIOLock() override func setUp() { @@ -33,17 +34,19 @@ class BootstrapTest: XCTestCase { } override func tearDown() { - XCTAssertNoThrow(try self.lock.withLock { - guard let groupBag = self.groupBag else { - XCTFail() - return + XCTAssertNoThrow( + try self.lock.withLock { + guard let groupBag = self.groupBag else { + XCTFail() + return + } + for group in groupBag { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + self.groupBag = nil + XCTAssertNotNil(self.group) } - XCTAssertNoThrow(try groupBag.forEach { - XCTAssertNoThrow(try $0.syncShutdownGracefully()) - }) - self.groupBag = nil - XCTAssertNotNil(self.group) - }) + ) XCTAssertNoThrow(try self.group?.syncShutdownGracefully()) self.group = nil } @@ -57,9 +60,11 @@ class BootstrapTest: XCTestCase { } func testBootstrapsCallInitializersOnCorrectEventLoop() throws { - for numThreads in [1 /* everything on one event loop */, - 2 /* some stuff has shared event loops */, - 5 /* everything on a different event loop */] { + for numThreads in [ + 1, // everything on one event loop + 2, // some stuff has shared event loops + 5, // everything on a different event loop + ] { let group = MultiThreadedEventLoopGroup(numberOfThreads: numThreads) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) @@ -67,30 +72,36 @@ class BootstrapTest: XCTestCase { let childChannelDone = group.next().makePromise(of: Void.self) let serverChannelDone = group.next().makePromise(of: Void.self) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) - .childChannelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - childChannelDone.succeed(()) - return channel.eventLoop.makeSucceededFuture(()) - } - .serverChannelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - serverChannelDone.succeed(()) - return channel.eventLoop.makeSucceededFuture(()) - } - .bind(host: "localhost", port: 0) - .wait()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: group) + .childChannelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + childChannelDone.succeed(()) + return channel.eventLoop.makeSucceededFuture(()) + } + .serverChannelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + serverChannelDone.succeed(()) + return channel.eventLoop.makeSucceededFuture(()) + } + .bind(host: "localhost", port: 0) + .wait() + ) defer { XCTAssertNoThrow(try serverChannel.close().wait()) } - let client = try assertNoThrowWithValue(ClientBootstrap(group: group) - .channelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - return channel.eventLoop.makeSucceededFuture(()) - } - .connect(to: serverChannel.localAddress!) - .wait(), message: "resolver debug info: \(try! resolverDebugInformation(eventLoop: group.next(),host: "localhost", previouslyReceivedResult: serverChannel.localAddress!))") + let client = try assertNoThrowWithValue( + ClientBootstrap(group: group) + .channelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + return channel.eventLoop.makeSucceededFuture(()) + } + .connect(to: serverChannel.localAddress!) + .wait(), + message: + "resolver debug info: \(try! resolverDebugInformation(eventLoop: group.next(),host: "localhost", previouslyReceivedResult: serverChannel.localAddress!))" + ) defer { XCTAssertNoThrow(try client.syncCloseAcceptingAlreadyClosed()) } @@ -102,34 +113,38 @@ class BootstrapTest: XCTestCase { func testTCPBootstrapsTolerateFuturesFromDifferentEventLoopsReturnedInInitializers() throws { let childChannelDone = self.freshEventLoop().makePromise(of: Void.self) let serverChannelDone = self.freshEventLoop().makePromise(of: Void.self) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: self.freshEventLoop()) - .childChannelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - defer { - childChannelDone.succeed(()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: self.freshEventLoop()) + .childChannelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + defer { + childChannelDone.succeed(()) + } + return self.freshEventLoop().makeSucceededFuture(()) } - return self.freshEventLoop().makeSucceededFuture(()) - } - .serverChannelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - defer { - serverChannelDone.succeed(()) + .serverChannelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + defer { + serverChannelDone.succeed(()) + } + return self.freshEventLoop().makeSucceededFuture(()) } - return self.freshEventLoop().makeSucceededFuture(()) - } - .bind(host: "127.0.0.1", port: 0) - .wait()) + .bind(host: "127.0.0.1", port: 0) + .wait() + ) defer { XCTAssertNoThrow(try serverChannel.close().wait()) } - let client = try assertNoThrowWithValue(ClientBootstrap(group: self.freshEventLoop()) - .channelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - return self.freshEventLoop().makeSucceededFuture(()) - } - .connect(to: serverChannel.localAddress!) - .wait()) + let client = try assertNoThrowWithValue( + ClientBootstrap(group: self.freshEventLoop()) + .channelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + return self.freshEventLoop().makeSucceededFuture(()) + } + .connect(to: serverChannel.localAddress!) + .wait() + ) defer { XCTAssertNoThrow(try client.syncCloseAcceptingAlreadyClosed()) } @@ -138,37 +153,45 @@ class BootstrapTest: XCTestCase { } func testUDPBootstrapToleratesFuturesFromDifferentEventLoopsReturnedInInitializers() throws { - XCTAssertNoThrow(try DatagramBootstrap(group: self.freshEventLoop()) - .channelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - return self.freshEventLoop().makeSucceededFuture(()) - } - .bind(host: "127.0.0.1", port: 0) - .wait() - .close() - .wait()) + XCTAssertNoThrow( + try DatagramBootstrap(group: self.freshEventLoop()) + .channelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + return self.freshEventLoop().makeSucceededFuture(()) + } + .bind(host: "127.0.0.1", port: 0) + .wait() + .close() + .wait() + ) } func testPreConnectedClientSocketToleratesFuturesFromDifferentEventLoopsReturnedInInitializers() throws { var socketFDs: [CInt] = [-1, -1] - XCTAssertNoThrow(try Posix.socketpair(domain: .local, - type: .stream, - protocolSubtype: .default, - socketVector: &socketFDs)) + XCTAssertNoThrow( + try Posix.socketpair( + domain: .local, + type: .stream, + protocolSubtype: .default, + socketVector: &socketFDs + ) + ) defer { // 0 is closed together with the Channel below. XCTAssertNoThrow(try NIOBSDSocket.close(socket: socketFDs[1])) } - XCTAssertNoThrow(try ClientBootstrap(group: self.freshEventLoop()) - .channelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - return self.freshEventLoop().makeSucceededFuture(()) - } - .withConnectedSocket(socketFDs[0]) - .wait() - .close() - .wait()) + XCTAssertNoThrow( + try ClientBootstrap(group: self.freshEventLoop()) + .channelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + return self.freshEventLoop().makeSucceededFuture(()) + } + .withConnectedSocket(socketFDs[0]) + .wait() + .close() + .wait() + ) } func testPreConnectedServerSocketToleratesFuturesFromDifferentEventLoopsReturnedInInitializers() throws { @@ -177,37 +200,44 @@ class BootstrapTest: XCTestCase { let serverAddress = try assertNoThrowWithValue(SocketAddress.makeAddressResolvingHost("127.0.0.1", port: 0)) try serverAddress.withSockAddr { address, len in - try NIOBSDSocket.bind(socket: socket, address: address, - address_len: socklen_t(len)) + try NIOBSDSocket.bind( + socket: socket, + address: address, + address_len: socklen_t(len) + ) } let childChannelDone = self.freshEventLoop().next().makePromise(of: Void.self) let serverChannelDone = self.freshEventLoop().next().makePromise(of: Void.self) - let serverChannel = try assertNoThrowWithValue(try ServerBootstrap(group: self.freshEventLoop()) - .childChannelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - defer { - childChannelDone.succeed(()) + let serverChannel = try assertNoThrowWithValue( + try ServerBootstrap(group: self.freshEventLoop()) + .childChannelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + defer { + childChannelDone.succeed(()) + } + return self.freshEventLoop().makeSucceededFuture(()) } - return self.freshEventLoop().makeSucceededFuture(()) - } - .serverChannelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - defer { - serverChannelDone.succeed(()) + .serverChannelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + defer { + serverChannelDone.succeed(()) + } + return self.freshEventLoop().makeSucceededFuture(()) } - return self.freshEventLoop().makeSucceededFuture(()) - } - .withBoundSocket(socket) - .wait()) - let client = try assertNoThrowWithValue(ClientBootstrap(group: self.freshEventLoop()) - .channelInitializer { channel in - XCTAssert(channel.eventLoop.inEventLoop) - return self.freshEventLoop().makeSucceededFuture(()) - } - .connect(to: serverChannel.localAddress!) - .wait()) + .withBoundSocket(socket) + .wait() + ) + let client = try assertNoThrowWithValue( + ClientBootstrap(group: self.freshEventLoop()) + .channelInitializer { channel in + XCTAssert(channel.eventLoop.inEventLoop) + return self.freshEventLoop().makeSucceededFuture(()) + } + .connect(to: serverChannel.localAddress!) + .wait() + ) defer { XCTAssertNoThrow(try client.syncCloseAcceptingAlreadyClosed()) } @@ -223,16 +253,20 @@ class BootstrapTest: XCTestCase { func restrictBootstrapType(clientBootstrap: NIOClientTCPBootstrap) throws { let serverAcceptedChannelPromise = group.next().makePromise(of: Channel.self) - let serverChannel = try assertNoThrowWithValue(ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelInitializer { channel in - serverAcceptedChannelPromise.succeed(channel) - return channel.eventLoop.makeSucceededFuture(()) - }.bind(host: "127.0.0.1", port: 0).wait()) + let serverChannel = try assertNoThrowWithValue( + ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + serverAcceptedChannelPromise.succeed(channel) + return channel.eventLoop.makeSucceededFuture(()) + }.bind(host: "127.0.0.1", port: 0).wait() + ) - let clientChannel = try assertNoThrowWithValue(clientBootstrap - .channelInitializer({ (channel: Channel) in channel.eventLoop.makeSucceededFuture(()) }) - .connect(host: "127.0.0.1", port: serverChannel.localAddress!.port!).wait()) + let clientChannel = try assertNoThrowWithValue( + clientBootstrap + .channelInitializer({ (channel: Channel) in channel.eventLoop.makeSucceededFuture(()) }) + .connect(host: "127.0.0.1", port: serverChannel.localAddress!.port!).wait() + ) var buffer = clientChannel.allocator.buffer(capacity: 1) buffer.writeString("a") @@ -244,15 +278,21 @@ class BootstrapTest: XCTestCase { XCTAssertNoThrow(try clientChannel.close().wait()) // Wait for the close promises. These fire last. - XCTAssertNoThrow(try EventLoopFuture.andAllSucceed([clientChannel.closeFuture, - serverAcceptedChannel.closeFuture], - on: group.next()).wait()) + XCTAssertNoThrow( + try EventLoopFuture.andAllSucceed( + [ + clientChannel.closeFuture, + serverAcceptedChannel.closeFuture, + ], + on: group.next() + ).wait() + ) } let bootstrap = NIOClientTCPBootstrap(ClientBootstrap(group: group), tls: NIOInsecureNoTLS()) try restrictBootstrapType(clientBootstrap: bootstrap) } - + func testServerBootstrapBindTimeout() throws { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { @@ -271,58 +311,89 @@ class BootstrapTest: XCTestCase { func testServerBootstrapSetsChannelOptionsBeforeChannelInitializer() { var channel: Channel? = nil - XCTAssertNoThrow(channel = try ServerBootstrap(group: self.group) - .serverChannelOption(ChannelOptions.autoRead, value: false) - .serverChannelInitializer { channel in - channel.getOption(ChannelOptions.autoRead).whenComplete { result in - func workaround() { - XCTAssertNoThrow(XCTAssertFalse(try result.get())) + XCTAssertNoThrow( + channel = try ServerBootstrap(group: self.group) + .serverChannelOption(ChannelOptions.autoRead, value: false) + .serverChannelInitializer { channel in + channel.getOption(ChannelOptions.autoRead).whenComplete { result in + func workaround() { + XCTAssertNoThrow(XCTAssertFalse(try result.get())) + } + workaround() } - workaround() + return channel.pipeline.addHandler(MakeSureAutoReadIsOffInChannelInitializer()) } - return channel.pipeline.addHandler(MakeSureAutoReadIsOffInChannelInitializer()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) XCTAssertNotNil(channel) XCTAssertNoThrow(try channel?.close().wait()) } func testClientBootstrapSetsChannelOptionsBeforeChannelInitializer() { - XCTAssertNoThrow(try withTCPServerChannel(group: self.group) { server in - var channel: Channel? = nil - XCTAssertNoThrow(channel = try ClientBootstrap(group: self.group) - .channelOption(ChannelOptions.autoRead, value: false) - .channelInitializer { channel in - channel.getOption(ChannelOptions.autoRead).whenComplete { result in - func workaround() { - XCTAssertNoThrow(XCTAssertFalse(try result.get())) + XCTAssertNoThrow( + try withTCPServerChannel(group: self.group) { server in + var channel: Channel? = nil + XCTAssertNoThrow( + channel = try ClientBootstrap(group: self.group) + .channelOption(ChannelOptions.autoRead, value: false) + .channelInitializer { channel in + channel.getOption(ChannelOptions.autoRead).whenComplete { result in + func workaround() { + XCTAssertNoThrow(XCTAssertFalse(try result.get())) + } + workaround() + } + return channel.pipeline.addHandler(MakeSureAutoReadIsOffInChannelInitializer()) } - workaround() - } - return channel.pipeline.addHandler(MakeSureAutoReadIsOffInChannelInitializer()) + .connect(to: server.localAddress!) + .wait() + ) + XCTAssertNotNil(channel) + XCTAssertNoThrow(try channel?.close().wait()) } - .connect(to: server.localAddress!) - .wait()) - XCTAssertNotNil(channel) - XCTAssertNoThrow(try channel?.close().wait()) - }) + ) } func testPreConnectedSocketSetsChannelOptionsBeforeChannelInitializer() { - XCTAssertNoThrow(try withTCPServerChannel(group: self.group) { server in - var maybeSocket: Socket? = nil - XCTAssertNoThrow(maybeSocket = try Socket(protocolFamily: .inet, type: .stream)) - XCTAssertNoThrow(XCTAssertEqual(true, try maybeSocket?.connect(to: server.localAddress!))) - var maybeFD: CInt? = nil - XCTAssertNoThrow(maybeFD = try maybeSocket?.takeDescriptorOwnership()) - guard let fd = maybeFD else { - XCTFail("could not get a socket fd") - return + XCTAssertNoThrow( + try withTCPServerChannel(group: self.group) { server in + var maybeSocket: Socket? = nil + XCTAssertNoThrow(maybeSocket = try Socket(protocolFamily: .inet, type: .stream)) + XCTAssertNoThrow(XCTAssertEqual(true, try maybeSocket?.connect(to: server.localAddress!))) + var maybeFD: CInt? = nil + XCTAssertNoThrow(maybeFD = try maybeSocket?.takeDescriptorOwnership()) + guard let fd = maybeFD else { + XCTFail("could not get a socket fd") + return + } + + var channel: Channel? = nil + XCTAssertNoThrow( + channel = try ClientBootstrap(group: self.group) + .channelOption(ChannelOptions.autoRead, value: false) + .channelInitializer { channel in + channel.getOption(ChannelOptions.autoRead).whenComplete { result in + func workaround() { + XCTAssertNoThrow(XCTAssertFalse(try result.get())) + } + workaround() + } + return channel.pipeline.addHandler(MakeSureAutoReadIsOffInChannelInitializer()) + } + .withConnectedSocket(fd) + .wait() + ) + XCTAssertNotNil(channel) + XCTAssertNoThrow(try channel?.close().wait()) } + ) + } - var channel: Channel? = nil - XCTAssertNoThrow(channel = try ClientBootstrap(group: self.group) + func testDatagramBootstrapSetsChannelOptionsBeforeChannelInitializer() { + var channel: Channel? = nil + XCTAssertNoThrow( + channel = try DatagramBootstrap(group: self.group) .channelOption(ChannelOptions.autoRead, value: false) .channelInitializer { channel in channel.getOption(ChannelOptions.autoRead).whenComplete { result in @@ -332,61 +403,46 @@ class BootstrapTest: XCTestCase { workaround() } return channel.pipeline.addHandler(MakeSureAutoReadIsOffInChannelInitializer()) - } - .withConnectedSocket(fd) - .wait()) - XCTAssertNotNil(channel) - XCTAssertNoThrow(try channel?.close().wait()) - }) - } - - func testDatagramBootstrapSetsChannelOptionsBeforeChannelInitializer() { - var channel: Channel? = nil - XCTAssertNoThrow(channel = try DatagramBootstrap(group: self.group) - .channelOption(ChannelOptions.autoRead, value: false) - .channelInitializer { channel in - channel.getOption(ChannelOptions.autoRead).whenComplete { result in - func workaround() { - XCTAssertNoThrow(XCTAssertFalse(try result.get())) - } - workaround() } - return channel.pipeline.addHandler(MakeSureAutoReadIsOffInChannelInitializer()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) XCTAssertNotNil(channel) XCTAssertNoThrow(try channel?.close().wait()) } func testPipeBootstrapSetsChannelOptionsBeforeChannelInitializer() { - XCTAssertNoThrow(try withPipe { inPipe, outPipe in - var maybeInFD: CInt? = nil - var maybeOutFD: CInt? = nil - XCTAssertNoThrow(maybeInFD = try inPipe.takeDescriptorOwnership()) - XCTAssertNoThrow(maybeOutFD = try outPipe.takeDescriptorOwnership()) - guard let inFD = maybeInFD, let outFD = maybeOutFD else { - XCTFail("couldn't get pipe fds") - return [inPipe, outPipe] - } - var channel: Channel? = nil - XCTAssertNoThrow(channel = try NIOPipeBootstrap(group: self.group) - .channelOption(ChannelOptions.autoRead, value: false) - .channelInitializer { channel in - channel.getOption(ChannelOptions.autoRead).whenComplete { result in - func workaround() { - XCTAssertNoThrow(XCTAssertFalse(try result.get())) + XCTAssertNoThrow( + try withPipe { inPipe, outPipe in + var maybeInFD: CInt? = nil + var maybeOutFD: CInt? = nil + XCTAssertNoThrow(maybeInFD = try inPipe.takeDescriptorOwnership()) + XCTAssertNoThrow(maybeOutFD = try outPipe.takeDescriptorOwnership()) + guard let inFD = maybeInFD, let outFD = maybeOutFD else { + XCTFail("couldn't get pipe fds") + return [inPipe, outPipe] + } + var channel: Channel? = nil + XCTAssertNoThrow( + channel = try NIOPipeBootstrap(group: self.group) + .channelOption(ChannelOptions.autoRead, value: false) + .channelInitializer { channel in + channel.getOption(ChannelOptions.autoRead).whenComplete { result in + func workaround() { + XCTAssertNoThrow(XCTAssertFalse(try result.get())) + } + workaround() + } + return channel.pipeline.addHandler(MakeSureAutoReadIsOffInChannelInitializer()) } - workaround() - } - return channel.pipeline.addHandler(MakeSureAutoReadIsOffInChannelInitializer()) + .takingOwnershipOfDescriptors(input: inFD, output: outFD) + .wait() + ) + XCTAssertNotNil(channel) + XCTAssertNoThrow(try channel?.close().wait()) + return [] } - .takingOwnershipOfDescriptors(input: inFD, output: outFD) - .wait()) - XCTAssertNotNil(channel) - XCTAssertNoThrow(try channel?.close().wait()) - return [] - }) + ) } func testPipeBootstrapInEventLoop() { @@ -401,7 +457,10 @@ class BootstrapTest: XCTestCase { let readHandle = NIOFileHandle(descriptor: pipe.fileHandleForReading.fileDescriptor) let writeHandle = NIOFileHandle(descriptor: pipe.fileHandleForWriting.fileDescriptor) _ = NIOPipeBootstrap(group: self.group) - .takingOwnershipOfDescriptors(input: try readHandle.takeDescriptorOwnership(), output: try writeHandle.takeDescriptorOwnership()) + .takingOwnershipOfDescriptors( + input: try readHandle.takeDescriptorOwnership(), + output: try writeHandle.takeDescriptorOwnership() + ) .flatMap({ channel in channel.close() }).always({ _ in @@ -425,22 +484,24 @@ class BootstrapTest: XCTestCase { struct FoundHandlerThatWasNotSupposedToBeThereError: Error {} var maybeServer: Channel? = nil - XCTAssertNoThrow(maybeServer = try ServerBootstrap(group: group) - .serverChannelInitializer { channel in - // Here, we test that we can't find the AcceptHandler - return channel.pipeline.context(name: "AcceptHandler").flatMap { context -> EventLoopFuture in - XCTFail("unexpectedly found \(context)") - return channel.eventLoop.makeFailedFuture(FoundHandlerThatWasNotSupposedToBeThereError()) - }.flatMapError { error -> EventLoopFuture in - XCTAssertEqual(.notFound, error as? ChannelPipelineError) - if case .some(.notFound) = error as? ChannelPipelineError { - return channel.eventLoop.makeSucceededFuture(()) + XCTAssertNoThrow( + maybeServer = try ServerBootstrap(group: group) + .serverChannelInitializer { channel in + // Here, we test that we can't find the AcceptHandler + channel.pipeline.context(name: "AcceptHandler").flatMap { context -> EventLoopFuture in + XCTFail("unexpectedly found \(context)") + return channel.eventLoop.makeFailedFuture(FoundHandlerThatWasNotSupposedToBeThereError()) + }.flatMapError { error -> EventLoopFuture in + XCTAssertEqual(.notFound, error as? ChannelPipelineError) + if case .some(.notFound) = error as? ChannelPipelineError { + return channel.eventLoop.makeSucceededFuture(()) + } + return channel.eventLoop.makeFailedFuture(error) } - return channel.eventLoop.makeFailedFuture(error) } - } - .bind(host: "127.0.0.1", port: 0) - .wait()) + .bind(host: "127.0.0.1", port: 0) + .wait() + ) guard let server = maybeServer else { XCTFail("couldn't bootstrap server") @@ -558,30 +619,45 @@ class BootstrapTest: XCTestCase { XCTAssertNil(NIOPipeBootstrap(validatingGroup: elg)) XCTAssertNil(NIOPipeBootstrap(validatingGroup: el)) } - + func testConvenienceOptionsAreEquivalentUniversalClient() throws { - func setAndGetOption