Skip to content

Commit

Permalink
Chunk Data/Text frames by outboundMaxFrameSize.
Browse files Browse the repository at this point in the history
This change fixes a bug when sending text/data that was large WebsocketKit did not chunk these text/data into mutliple frames.
  • Loading branch information
hishnash committed Jul 24, 2021
1 parent bc3c30d commit 91d92c1
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 20 deletions.
97 changes: 85 additions & 12 deletions Sources/WebSocketKit/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@ public final class WebSocket {
private var waitingForPong: Bool
private var waitingForClose: Bool
private var scheduledTimeoutTask: Scheduled<Void>?

init(channel: Channel, type: PeerType) {
private var outboundMaxFrameSize: WebSocketMaxFrameSize

init(
channel: Channel,
type: PeerType,
outboundMaxFrameSize: WebSocketMaxFrameSize = .default
) {
self.channel = channel
self.type = type
self.onTextCallback = { _, _ in }
Expand All @@ -45,6 +50,7 @@ public final class WebSocket {
self.waitingForPong = false
self.waitingForClose = false
self.scheduledTimeoutTask = nil
self.outboundMaxFrameSize = outboundMaxFrameSize
}

public func onText(_ callback: @escaping (WebSocket, String) -> ()) {
Expand Down Expand Up @@ -88,12 +94,78 @@ public final class WebSocket {
let string = String(text)
var buffer = channel.allocator.buffer(capacity: text.count)
buffer.writeString(string)
self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise)

self.send(buffer: buffer, opcode: .text, promise: promise)
}

public func send(_ binary: [UInt8], promise: EventLoopPromise<Void>? = nil) {
self.send(raw: binary, opcode: .binary, fin: true, promise: promise)
var buffer = channel.allocator.buffer(capacity: binary.count)
buffer.writeBytes(binary)
self.send(buffer: buffer, opcode: .binary, promise: promise)
}

public func send(
buffer: NIO.ByteBuffer,
opcode: WebSocketOpcode,
promise: EventLoopPromise<Void>? = nil
) {

guard buffer.readableBytes > self.outboundMaxFrameSize.value else {
let frame = WebSocketFrame(
fin: true,
opcode: opcode,
maskKey: self.makeMaskKey(),
data: buffer
)
self.channel.writeAndFlush(frame, promise: promise)
return
}

var buffer = buffer

// We need to ensure we write all of these items in order on the event loop without other writes interrupting the frames.
self.channel.eventLoop.execute {
// Send the first frame with the opcode
let frameBuffer = buffer.readSlice(length: self.outboundMaxFrameSize.value)!
let frame = WebSocketFrame(
fin: false,
opcode: opcode,
maskKey: self.makeMaskKey(),
data: frameBuffer
)

self.channel.write(frame, promise: nil)


while let frameBuffer = buffer.readSlice(length: self.outboundMaxFrameSize.value) {

let isFinalFrame = buffer.readableBytes == 0

let frame = WebSocketFrame(
fin: isFinalFrame,
opcode: .continuation,
maskKey: self.makeMaskKey(),
data: frameBuffer
)

if isFinalFrame {
self.channel.writeAndFlush(frame, promise: promise)
return
} else {
// write operations that happen when already on the event loop go directly through without any `delay`.
self.channel.write(frame, promise: nil)
}
}
// we will end up here if the number bytes is not a multiple of the `outboundMaxFrameSize`
let finalFrame = WebSocketFrame(
fin: true,
opcode: .continuation,
maskKey: self.makeMaskKey(),
data: buffer
)

self.channel.writeAndFlush(finalFrame, promise: promise)
}
}

public func sendPing(promise: EventLoopPromise<Void>? = nil) {
Expand All @@ -115,6 +187,7 @@ public final class WebSocket {
{
var buffer = channel.allocator.buffer(capacity: data.count)
buffer.writeBytes(data)

let frame = WebSocketFrame(
fin: fin,
opcode: opcode,
Expand Down Expand Up @@ -213,16 +286,16 @@ public final class WebSocket {
self.close(code: .protocolError, promise: nil)
}
case .text, .binary, .pong:
// create a new frame sequence or use existing
var frameSequence: WebSocketFrameSequence
if let existing = self.frameSequence {
frameSequence = existing
if self.frameSequence != nil {
// we should not have an existing sequence
self.close(code: .protocolError, promise: nil)
} else {
frameSequence = WebSocketFrameSequence(type: frame.opcode)
// Create a new frame sequence
var frameSequence = WebSocketFrameSequence(type: frame.opcode)
// Append this frame and update the sequence
frameSequence.append(frame)
self.frameSequence = frameSequence
}
// append this frame and update the sequence
frameSequence.append(frame)
self.frameSequence = frameSequence
case .continuation:
// we must have an existing sequence
if var frameSequence = self.frameSequence {
Expand Down
26 changes: 21 additions & 5 deletions Sources/WebSocketKit/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,26 @@ public final class WebSocketClient {

public struct Configuration {
public var tlsConfiguration: TLSConfiguration?
public var maxFrameSize: Int
public var inboundMaxFrameSize: WebSocketMaxFrameSize
public var outboundMaxFrameSize: WebSocketMaxFrameSize

public init(
tlsConfiguration: TLSConfiguration? = nil,
maxFrameSize: Int = 1 << 14
maxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default
) {
self.tlsConfiguration = tlsConfiguration
self.maxFrameSize = maxFrameSize
self.inboundMaxFrameSize = maxFrameSize
self.outboundMaxFrameSize = maxFrameSize
}

public init(
tlsConfiguration: TLSConfiguration? = nil,
inboundMaxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default,
outboundMaxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default
) {
self.tlsConfiguration = tlsConfiguration
self.inboundMaxFrameSize = inboundMaxFrameSize
self.outboundMaxFrameSize = outboundMaxFrameSize
}
}

Expand Down Expand Up @@ -75,10 +87,14 @@ public final class WebSocketClient {
}
let websocketUpgrader = NIOWebSocketClientUpgrader(
requestKey: Data(key).base64EncodedString(),
maxFrameSize: self.configuration.maxFrameSize,
maxFrameSize: self.configuration.inboundMaxFrameSize.value,
automaticErrorHandling: true,
upgradePipelineHandler: { channel, req in
return WebSocket.client(on: channel, onUpgrade: onUpgrade)
return WebSocket.client(
on: channel,
outboundMaxFrameSize: self.configuration.outboundMaxFrameSize,
onUpgrade: onUpgrade
)
}
)

Expand Down
24 changes: 21 additions & 3 deletions Sources/WebSocketKit/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,42 @@ import NIOWebSocket
extension WebSocket {
public static func client(
on channel: Channel,
outboundMaxFrameSize: WebSocketMaxFrameSize = .default,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
return self.handle(on: channel, as: .client, onUpgrade: onUpgrade)
return self.handle(
on: channel,
as: .client,
outboundMaxFrameSize: outboundMaxFrameSize,
onUpgrade: onUpgrade
)
}

public static func server(
on channel: Channel,
outboundMaxFrameSize: WebSocketMaxFrameSize = .default,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
return self.handle(on: channel, as: .server, onUpgrade: onUpgrade)
return self.handle(
on: channel,
as: .server,
outboundMaxFrameSize: outboundMaxFrameSize,
onUpgrade: onUpgrade
)
}

private static func handle(
on channel: Channel,
as type: PeerType,
outboundMaxFrameSize: WebSocketMaxFrameSize = .default,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
let webSocket = WebSocket(channel: channel, type: type)
let webSocket = WebSocket(
channel: channel,
type: type,
outboundMaxFrameSize: outboundMaxFrameSize
)

return channel.pipeline.addHandler(WebSocketHandler(webSocket: webSocket)).map { _ in
onUpgrade(webSocket)
}
Expand Down
17 changes: 17 additions & 0 deletions Sources/WebSocketKit/WebSocketMaxFrameSize.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
public struct WebSocketMaxFrameSize: ExpressibleByIntegerLiteral {
public let value: Int

public init(_ value: Int) {
precondition(value <= UInt32.max, "invalid overlarge max frame size")
self.value = value
}

public init(integerLiteral value: Int) {
precondition(value <= UInt32.max, "invalid overlarge max frame size")
self.value = value
}

public static var `default`: Self {
self.init(integerLiteral: 1 << 14)
}
}
27 changes: 27 additions & 0 deletions Tests/WebSocketKitTests/WebSocketKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,33 @@ final class WebSocketKitTests: XCTestCase {
func testBadHost() throws {
XCTAssertThrowsError(try WebSocket.connect(host: "asdf", on: elg) { _ in }.wait())
}


func testLargeTextFrame() throws {
let port = Int.random(in: 8000..<9000)

let sendPromise = self.elg.next().makePromise(of: Void.self)
let serverClose = self.elg.next().makePromise(of: Void.self)
let clientClose = self.elg.next().makePromise(of: Void.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onText { ws, text in
if text == "close" {
ws.close(promise: serverClose)
}
}
}.bind(host: "localhost", port: port).wait()
let config = WebSocketClient.Configuration(tlsConfiguration: nil, maxFrameSize: 2)
WebSocket.connect(to: "ws://localhost:\(port)", configuration: config, on: self.elg) { ws in
ws.send("close", promise: sendPromise)
ws.onClose.cascade(to: clientClose)
}.cascadeFailure(to: sendPromise)

XCTAssertNoThrow(try sendPromise.futureResult.wait())
XCTAssertNoThrow(try serverClose.futureResult.wait())
XCTAssertNoThrow(try clientClose.futureResult.wait())
try server.close(mode: .all).wait()
}


func testServerClose() throws {
let port = Int.random(in: 8000..<9000)
Expand Down

0 comments on commit 91d92c1

Please sign in to comment.