diff --git a/Sources/WebSocketKit/WebSocket.swift b/Sources/WebSocketKit/WebSocket.swift index b280212b..ff83401a 100644 --- a/Sources/WebSocketKit/WebSocket.swift +++ b/Sources/WebSocketKit/WebSocket.swift @@ -34,8 +34,13 @@ public final class WebSocket { private var waitingForPong: Bool private var waitingForClose: Bool private var scheduledTimeoutTask: Scheduled? - - init(channel: Channel, type: PeerType) { + private var outboundMaxFrameSize: WebSocketMaxFrameSize + + init( + channel: Channel, + type: PeerType, + outboundMaxFrameSize: WebSocketMaxFrameSize = .default + ) { self.channel = channel self.type = type self.onTextCallback = { _, _ in } @@ -45,6 +50,7 @@ public final class WebSocket { self.waitingForPong = false self.waitingForClose = false self.scheduledTimeoutTask = nil + self.outboundMaxFrameSize = outboundMaxFrameSize } public func onText(_ callback: @escaping (WebSocket, String) -> ()) { @@ -88,12 +94,78 @@ public final class WebSocket { let string = String(text) var buffer = channel.allocator.buffer(capacity: text.count) buffer.writeString(string) - self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise) + self.send(buffer: buffer, opcode: .text, promise: promise) } public func send(_ binary: [UInt8], promise: EventLoopPromise? = nil) { - self.send(raw: binary, opcode: .binary, fin: true, promise: promise) + var buffer = channel.allocator.buffer(capacity: binary.count) + buffer.writeBytes(binary) + self.send(buffer: buffer, opcode: .binary, promise: promise) + } + + public func send( + buffer: NIO.ByteBuffer, + opcode: WebSocketOpcode, + promise: EventLoopPromise? = nil + ) { + + guard buffer.readableBytes > self.outboundMaxFrameSize.value else { + let frame = WebSocketFrame( + fin: true, + opcode: opcode, + maskKey: self.makeMaskKey(), + data: buffer + ) + self.channel.writeAndFlush(frame, promise: promise) + return + } + + var buffer = buffer + + // We need to ensure we write all of these items in order on the event loop without other writes interrupting the frames. + self.channel.eventLoop.execute { + // Send the first frame with the opcode + let frameBuffer = buffer.readSlice(length: self.outboundMaxFrameSize.value)! + let frame = WebSocketFrame( + fin: false, + opcode: opcode, + maskKey: self.makeMaskKey(), + data: frameBuffer + ) + + self.channel.write(frame, promise: nil) + + + while let frameBuffer = buffer.readSlice(length: self.outboundMaxFrameSize.value) { + + let isFinalFrame = buffer.readableBytes == 0 + + let frame = WebSocketFrame( + fin: isFinalFrame, + opcode: .continuation, + maskKey: self.makeMaskKey(), + data: frameBuffer + ) + + if isFinalFrame { + self.channel.writeAndFlush(frame, promise: promise) + return + } else { + // write operations that happen when already on the event loop go directly through without any `delay`. + self.channel.write(frame, promise: nil) + } + } + // we will end up here if the number bytes is not a multiple of the `outboundMaxFrameSize` + let finalFrame = WebSocketFrame( + fin: true, + opcode: .continuation, + maskKey: self.makeMaskKey(), + data: buffer + ) + + self.channel.writeAndFlush(finalFrame, promise: promise) + } } public func sendPing(promise: EventLoopPromise? = nil) { @@ -115,6 +187,7 @@ public final class WebSocket { { var buffer = channel.allocator.buffer(capacity: data.count) buffer.writeBytes(data) + let frame = WebSocketFrame( fin: fin, opcode: opcode, @@ -213,16 +286,16 @@ public final class WebSocket { self.close(code: .protocolError, promise: nil) } case .text, .binary, .pong: - // create a new frame sequence or use existing - var frameSequence: WebSocketFrameSequence - if let existing = self.frameSequence { - frameSequence = existing + if self.frameSequence != nil { + // we should not have an existing sequence + self.close(code: .protocolError, promise: nil) } else { - frameSequence = WebSocketFrameSequence(type: frame.opcode) + // Create a new frame sequence + var frameSequence = WebSocketFrameSequence(type: frame.opcode) + // Append this frame and update the sequence + frameSequence.append(frame) + self.frameSequence = frameSequence } - // append this frame and update the sequence - frameSequence.append(frame) - self.frameSequence = frameSequence case .continuation: // we must have an existing sequence if var frameSequence = self.frameSequence { diff --git a/Sources/WebSocketKit/WebSocketClient.swift b/Sources/WebSocketKit/WebSocketClient.swift index 1eb4df78..0e6f80a0 100644 --- a/Sources/WebSocketKit/WebSocketClient.swift +++ b/Sources/WebSocketKit/WebSocketClient.swift @@ -22,14 +22,26 @@ public final class WebSocketClient { public struct Configuration { public var tlsConfiguration: TLSConfiguration? - public var maxFrameSize: Int + public var inboundMaxFrameSize: WebSocketMaxFrameSize + public var outboundMaxFrameSize: WebSocketMaxFrameSize public init( tlsConfiguration: TLSConfiguration? = nil, - maxFrameSize: Int = 1 << 14 + maxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default ) { self.tlsConfiguration = tlsConfiguration - self.maxFrameSize = maxFrameSize + self.inboundMaxFrameSize = maxFrameSize + self.outboundMaxFrameSize = maxFrameSize + } + + public init( + tlsConfiguration: TLSConfiguration? = nil, + inboundMaxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default, + outboundMaxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default + ) { + self.tlsConfiguration = tlsConfiguration + self.inboundMaxFrameSize = inboundMaxFrameSize + self.outboundMaxFrameSize = outboundMaxFrameSize } } @@ -75,10 +87,14 @@ public final class WebSocketClient { } let websocketUpgrader = NIOWebSocketClientUpgrader( requestKey: Data(key).base64EncodedString(), - maxFrameSize: self.configuration.maxFrameSize, + maxFrameSize: self.configuration.inboundMaxFrameSize.value, automaticErrorHandling: true, upgradePipelineHandler: { channel, req in - return WebSocket.client(on: channel, onUpgrade: onUpgrade) + return WebSocket.client( + on: channel, + outboundMaxFrameSize: self.configuration.outboundMaxFrameSize, + onUpgrade: onUpgrade + ) } ) diff --git a/Sources/WebSocketKit/WebSocketHandler.swift b/Sources/WebSocketKit/WebSocketHandler.swift index b54f9fb7..9a671be3 100644 --- a/Sources/WebSocketKit/WebSocketHandler.swift +++ b/Sources/WebSocketKit/WebSocketHandler.swift @@ -4,24 +4,42 @@ import NIOWebSocket extension WebSocket { public static func client( on channel: Channel, + outboundMaxFrameSize: WebSocketMaxFrameSize = .default, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - return self.handle(on: channel, as: .client, onUpgrade: onUpgrade) + return self.handle( + on: channel, + as: .client, + outboundMaxFrameSize: outboundMaxFrameSize, + onUpgrade: onUpgrade + ) } public static func server( on channel: Channel, + outboundMaxFrameSize: WebSocketMaxFrameSize = .default, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - return self.handle(on: channel, as: .server, onUpgrade: onUpgrade) + return self.handle( + on: channel, + as: .server, + outboundMaxFrameSize: outboundMaxFrameSize, + onUpgrade: onUpgrade + ) } private static func handle( on channel: Channel, as type: PeerType, + outboundMaxFrameSize: WebSocketMaxFrameSize = .default, onUpgrade: @escaping (WebSocket) -> () ) -> EventLoopFuture { - let webSocket = WebSocket(channel: channel, type: type) + let webSocket = WebSocket( + channel: channel, + type: type, + outboundMaxFrameSize: outboundMaxFrameSize + ) + return channel.pipeline.addHandler(WebSocketHandler(webSocket: webSocket)).map { _ in onUpgrade(webSocket) } diff --git a/Sources/WebSocketKit/WebSocketMaxFrameSize.swift b/Sources/WebSocketKit/WebSocketMaxFrameSize.swift new file mode 100644 index 00000000..8212eaf2 --- /dev/null +++ b/Sources/WebSocketKit/WebSocketMaxFrameSize.swift @@ -0,0 +1,17 @@ +public struct WebSocketMaxFrameSize: ExpressibleByIntegerLiteral { + public let value: Int + + public init(_ value: Int) { + precondition(value <= UInt32.max, "invalid overlarge max frame size") + self.value = value + } + + public init(integerLiteral value: Int) { + precondition(value <= UInt32.max, "invalid overlarge max frame size") + self.value = value + } + + public static var `default`: Self { + self.init(integerLiteral: 1 << 14) + } +} diff --git a/Tests/WebSocketKitTests/WebSocketKitTests.swift b/Tests/WebSocketKitTests/WebSocketKitTests.swift index a6936d1f..812119c7 100644 --- a/Tests/WebSocketKitTests/WebSocketKitTests.swift +++ b/Tests/WebSocketKitTests/WebSocketKitTests.swift @@ -34,6 +34,111 @@ final class WebSocketKitTests: XCTestCase { func testBadHost() throws { XCTAssertThrowsError(try WebSocket.connect(host: "asdf", on: elg) { _ in }.wait()) } + + + func testMutliFrameMessage() throws { + let port = Int.random(in: 8000..<9000) + + let sendPromise = self.elg.next().makePromise(of: Void.self) + let serverClose = self.elg.next().makePromise(of: Void.self) + let clientClose = self.elg.next().makePromise(of: Void.self) + let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in + ws.onText { ws, text in + if text == "close" { + ws.close(promise: serverClose) + } + } + }.bind(host: "localhost", port: port).wait() + let config = WebSocketClient.Configuration(tlsConfiguration: nil, maxFrameSize: 2) + WebSocket.connect(to: "ws://localhost:\(port)", configuration: config, on: self.elg) { ws in + ws.send("close", promise: sendPromise) + ws.onClose.cascade(to: clientClose) + }.cascadeFailure(to: sendPromise) + + XCTAssertNoThrow(try sendPromise.futureResult.wait()) + XCTAssertNoThrow(try serverClose.futureResult.wait()) + XCTAssertNoThrow(try clientClose.futureResult.wait()) + try server.close(mode: .all).wait() + } + + func testMultiFrameMessageOrdering() throws { + let port = Int.random(in: 8000..<9000) + + // Sending from client to server + let sendMultiFrameMessagePromise = self.elg.next().makePromise(of: Void.self) + let sendSingleFrameMessagePromise = self.elg.next().makePromise(of: Void.self) + + // received on server + let receivedMultiFrameMessagePromise = self.elg.next().makePromise(of: Void.self) + let receivedSingleFrameMessagePromise = self.elg.next().makePromise(of: Void.self) + + // received echo on client + let receivedMultiFrameMessageEchoPromise = self.elg.next().makePromise(of: Void.self) + let receivedSingleFrameMessageEchoPromise = self.elg.next().makePromise(of: Void.self) + + + let clientClose = self.elg.next().makePromise(of: Void.self) + let serverClose = self.elg.next().makePromise(of: Void.self) + + let maxFrameSize: WebSocketMaxFrameSize = 13 + + let server = try ServerBootstrap.webSocket( + on: self.elg, + outboundMaxFrameSize: maxFrameSize, + inboundMaxFrameSize: maxFrameSize + ) { req, ws in + ws.onClose.cascade(to: serverClose) + + ws.onText { ws, text in + receivedSingleFrameMessagePromise.succeed(()) + ws.send(text, promise: nil) + } + + ws.onBinary { ws, buffer in + receivedMultiFrameMessagePromise.succeed(()) + ws.send(buffer: buffer, opcode: .binary, promise: nil) + } + }.bind(host: "localhost", port: port).wait() + + let config = WebSocketClient.Configuration(tlsConfiguration: nil, maxFrameSize: maxFrameSize) + + WebSocket.connect(to: "ws://localhost:\(port)", configuration: config, on: self.elg) { ws in + ws.onBinary { ws, buffer in + XCTAssertEqual(buffer.readableBytes, 10000) + receivedMultiFrameMessageEchoPromise.succeed(()) + } + + ws.onText { ws, str in + XCTAssertEqual(str, "singleFrame") + receivedSingleFrameMessageEchoPromise.succeed(()) + } + + ws.send(buffer: ByteBuffer(repeating: 1, count: 10000), opcode: .binary, promise: sendMultiFrameMessagePromise) + + sendMultiFrameMessagePromise.futureResult.whenSuccess { _ in + ws.send("singleFrame", promise: sendSingleFrameMessagePromise) + } + + // Send close after Multi frame response has arrived. + receivedMultiFrameMessageEchoPromise.futureResult.whenComplete { _ in + ws.close(promise: clientClose) + } + }.cascadeFailure(to: clientClose) + + XCTAssertNoThrow(try sendMultiFrameMessagePromise.futureResult.wait()) + XCTAssertNoThrow(try sendSingleFrameMessagePromise.futureResult.wait()) + + XCTAssertNoThrow(try receivedMultiFrameMessagePromise.futureResult.wait()) + XCTAssertNoThrow(try receivedSingleFrameMessagePromise.futureResult.wait()) + + XCTAssertNoThrow(try receivedMultiFrameMessageEchoPromise.futureResult.wait()) + XCTAssertNoThrow(try receivedSingleFrameMessageEchoPromise.futureResult.wait()) + + XCTAssertNoThrow(try clientClose.futureResult.wait()) + + try server.close(mode: .all).wait() + } + func testServerClose() throws { let port = Int.random(in: 8000..<9000) @@ -226,15 +331,19 @@ final class WebSocketKitTests: XCTestCase { extension ServerBootstrap { static func webSocket( on eventLoopGroup: EventLoopGroup, + outboundMaxFrameSize: WebSocketMaxFrameSize = .default, + inboundMaxFrameSize: WebSocketMaxFrameSize = .default, onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> () + ) -> ServerBootstrap { ServerBootstrap(group: eventLoopGroup).childChannelInitializer { channel in let webSocket = NIOWebSocketServerUpgrader( + maxFrameSize: inboundMaxFrameSize.value, shouldUpgrade: { channel, req in return channel.eventLoop.makeSucceededFuture([:]) }, upgradePipelineHandler: { channel, req in - return WebSocket.server(on: channel) { ws in + return WebSocket.server(on: channel, outboundMaxFrameSize: outboundMaxFrameSize) { ws in onUpgrade(req, ws) } }