From d8a861a4e8831806cf82bf1ac6ede4253c984dd0 Mon Sep 17 00:00:00 2001 From: Matthaus Woolard Date: Thu, 22 Jul 2021 17:01:26 +1200 Subject: [PATCH] Chunk Data/Text frames by outboundMaxFrameSize. This change fixes a bug when sending text/data that was large WebsocketKit did not chunk these text/data into mutliple frames. --- Sources/WebSocketKit/WebSocket.swift | 67 +++++++++++++++++-- Sources/WebSocketKit/WebSocketClient.swift | 26 +++++-- Sources/WebSocketKit/WebSocketHandler.swift | 24 ++++++- .../WebSocketKit/WebSocketMaxFrameSize.swift | 17 +++++ .../WebSocketKitTests/WebSocketKitTests.swift | 27 ++++++++ 5 files changed, 149 insertions(+), 12 deletions(-) create mode 100644 Sources/WebSocketKit/WebSocketMaxFrameSize.swift diff --git a/Sources/WebSocketKit/WebSocket.swift b/Sources/WebSocketKit/WebSocket.swift index b280212b..d739ee9b 100644 --- a/Sources/WebSocketKit/WebSocket.swift +++ b/Sources/WebSocketKit/WebSocket.swift @@ -34,8 +34,13 @@ public final class WebSocket { private var waitingForPong: Bool private var waitingForClose: Bool private var scheduledTimeoutTask: Scheduled? - - init(channel: Channel, type: PeerType) { + private var outboundMaxFrameSize: WebSocketMaxFrameSize + + init( + channel: Channel, + type: PeerType, + outboundMaxFrameSize: WebSocketMaxFrameSize = .default + ) { self.channel = channel self.type = type self.onTextCallback = { _, _ in } @@ -45,6 +50,7 @@ public final class WebSocket { self.waitingForPong = false self.waitingForClose = false self.scheduledTimeoutTask = nil + self.outboundMaxFrameSize = outboundMaxFrameSize } public func onText(_ callback: @escaping (WebSocket, String) -> ()) { @@ -88,12 +94,65 @@ public final class WebSocket { let string = String(text) var buffer = channel.allocator.buffer(capacity: text.count) buffer.writeString(string) - self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise) + self.send(buffer: buffer, opcode: .text, promise: promise) } public func send(_ binary: [UInt8], promise: EventLoopPromise? = nil) { - self.send(raw: binary, opcode: .binary, fin: true, promise: promise) + var buffer = channel.allocator.buffer(capacity: binary.count) + buffer.writeBytes(binary) + self.send(buffer: buffer, opcode: .binary, promise: promise) + } + + public func send( + buffer: NIO.ByteBuffer, + opcode: WebSocketOpcode, + promise: EventLoopPromise? = nil + ) { + guard buffer.readableBytes > outboundMaxFrameSize.value else { + let frame = WebSocketFrame( + fin: true, + opcode: opcode, + maskKey: self.makeMaskKey(), + data: buffer + ) + self.channel.writeAndFlush(frame, promise: promise) + return + } + + var buffer = buffer + + var framesToSend: [WebSocketFrame] = [] + + while let frameBuffer = buffer.readSlice(length: outboundMaxFrameSize.value) { + let frame = WebSocketFrame( + fin: buffer.readableBytes == 0, + opcode: opcode, + maskKey: self.makeMaskKey(), + data: frameBuffer + ) + framesToSend.append(frame) + } + + if buffer.readableBytes > 0 { + let frame = WebSocketFrame( + fin: true, + opcode: opcode, + maskKey: self.makeMaskKey(), + data: buffer + ) + framesToSend.append(frame) + } + + let startingOut: EventLoopFuture = self.channel.eventLoop.makeSucceededFuture(Void()) + + let future: EventLoopFuture = framesToSend.reduce(startingOut) { future, frame in + return future.flatMap { _ in + self.channel.writeAndFlush(frame) + } + } + + promise?.completeWith(future) } public func sendPing(promise: EventLoopPromise? = nil) { diff --git a/Sources/WebSocketKit/WebSocketClient.swift b/Sources/WebSocketKit/WebSocketClient.swift index 1eb4df78..0e6f80a0 100644 --- a/Sources/WebSocketKit/WebSocketClient.swift +++ b/Sources/WebSocketKit/WebSocketClient.swift @@ -22,14 +22,26 @@ public final class WebSocketClient { public struct Configuration { public var tlsConfiguration: TLSConfiguration? - public var maxFrameSize: Int + public var inboundMaxFrameSize: WebSocketMaxFrameSize + public var outboundMaxFrameSize: WebSocketMaxFrameSize public init( tlsConfiguration: TLSConfiguration? = nil, - maxFrameSize: Int = 1 << 14 + maxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default ) { self.tlsConfiguration = tlsConfiguration - self.maxFrameSize = maxFrameSize + self.inboundMaxFrameSize = maxFrameSize + self.outboundMaxFrameSize = maxFrameSize + } + + public init( + tlsConfiguration: TLSConfiguration? = nil, + inboundMaxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default, + outboundMaxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default + ) { + self.tlsConfiguration = tlsConfiguration + self.inboundMaxFrameSize = inboundMaxFrameSize + self.outboundMaxFrameSize = outboundMaxFrameSize } } @@ -75,10 +87,14 @@ public final class WebSocketClient { } let websocketUpgrader = NIOWebSocketClientUpgrader( requestKey: Data(key).base64EncodedString(), - maxFrameSize: self.configuration.maxFrameSize, + maxFrameSize: self.configuration.inboundMaxFrameSize.value, automaticErrorHandling: true, upgradePipelineHandler: { channel, req in - return WebSocket.client(on: channel, onUpgrade: onUpgrade) + return WebSocket.client( + on: channel, + outboundMaxFrameSize: self.configuration.outboundMaxFrameSize, + onUpgrade: onUpgrade + ) } ) diff --git a/Sources/WebSocketKit/WebSocketHandler.swift b/Sources/WebSocketKit/WebSocketHandler.swift index b54f9fb7..9a671be3 100644 --- a/Sources/WebSocketKit/WebSocketHandler.swift +++ b/Sources/WebSocketKit/WebSocketHandler.swift @@ -4,24 +4,42 @@ import NIOWebSocket extension WebSocket { public static func client( on channel: Channel, + outboundMaxFrameSize: WebSocketMaxFrameSize = .default, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - return self.handle(on: channel, as: .client, onUpgrade: onUpgrade) + return self.handle( + on: channel, + as: .client, + outboundMaxFrameSize: outboundMaxFrameSize, + onUpgrade: onUpgrade + ) } public static func server( on channel: Channel, + outboundMaxFrameSize: WebSocketMaxFrameSize = .default, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - return self.handle(on: channel, as: .server, onUpgrade: onUpgrade) + return self.handle( + on: channel, + as: .server, + outboundMaxFrameSize: outboundMaxFrameSize, + onUpgrade: onUpgrade + ) } private static func handle( on channel: Channel, as type: PeerType, + outboundMaxFrameSize: WebSocketMaxFrameSize = .default, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - let webSocket = WebSocket(channel: channel, type: type) + let webSocket = WebSocket( + channel: channel, + type: type, + outboundMaxFrameSize: outboundMaxFrameSize + ) + return channel.pipeline.addHandler(WebSocketHandler(webSocket: webSocket)).map { _ in onUpgrade(webSocket) } diff --git a/Sources/WebSocketKit/WebSocketMaxFrameSize.swift b/Sources/WebSocketKit/WebSocketMaxFrameSize.swift new file mode 100644 index 00000000..8212eaf2 --- /dev/null +++ b/Sources/WebSocketKit/WebSocketMaxFrameSize.swift @@ -0,0 +1,17 @@ +public struct WebSocketMaxFrameSize: ExpressibleByIntegerLiteral { + public let value: Int + + public init(_ value: Int) { + precondition(value <= UInt32.max, "invalid overlarge max frame size") + self.value = value + } + + public init(integerLiteral value: Int) { + precondition(value <= UInt32.max, "invalid overlarge max frame size") + self.value = value + } + + public static var `default`: Self { + self.init(integerLiteral: 1 << 14) + } +} diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift index a6936d1f..55216e81 100644 --- a/Tests/WebSocketKitTests/WebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift @@ -34,6 +34,33 @@ final class WebSocketKitTests: XCTestCase { func testBadHost() throws { XCTAssertThrowsError(try WebSocket.connect(host: "asdf", on: elg) { _ in }.wait()) } + + + func testLargeTextFrame() throws { + let port = Int.random(in: 8000..<9000) + + let sendPromise = self.elg.next().makePromise(of: Void.self) + let serverClose = self.elg.next().makePromise(of: Void.self) + let clientClose = self.elg.next().makePromise(of: Void.self) + let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in + ws.onText { ws, text in + if text == "close" { + ws.close(promise: serverClose) + } + } + }.bind(host: "localhost", port: port).wait() + let config = WebSocketClient.Configuration(tlsConfiguration: nil, maxFrameSize: 2) + WebSocket.connect(to: "ws://localhost:\(port)", configuration: config, on: self.elg) { ws in + ws.send("close", promise: sendPromise) + ws.onClose.cascade(to: clientClose) + }.cascadeFailure(to: sendPromise) + + XCTAssertNoThrow(try sendPromise.futureResult.wait()) + XCTAssertNoThrow(try serverClose.futureResult.wait()) + XCTAssertNoThrow(try clientClose.futureResult.wait()) + try server.close(mode: .all).wait() + } + func testServerClose() throws { let port = Int.random(in: 8000..<9000)