Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chunk Data/Text frames by outboundMaxFrameSize. #96

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 85 additions & 12 deletions Sources/WebSocketKit/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@ public final class WebSocket {
private var waitingForPong: Bool
private var waitingForClose: Bool
private var scheduledTimeoutTask: Scheduled<Void>?

init(channel: Channel, type: PeerType) {
private var outboundMaxFrameSize: WebSocketMaxFrameSize

init(
channel: Channel,
type: PeerType,
outboundMaxFrameSize: WebSocketMaxFrameSize = .default
) {
self.channel = channel
self.type = type
self.onTextCallback = { _, _ in }
Expand All @@ -45,6 +50,7 @@ public final class WebSocket {
self.waitingForPong = false
self.waitingForClose = false
self.scheduledTimeoutTask = nil
self.outboundMaxFrameSize = outboundMaxFrameSize
}

public func onText(_ callback: @escaping (WebSocket, String) -> ()) {
Expand Down Expand Up @@ -88,12 +94,78 @@ public final class WebSocket {
let string = String(text)
var buffer = channel.allocator.buffer(capacity: text.count)
buffer.writeString(string)
self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise)

self.send(buffer: buffer, opcode: .text, promise: promise)
}

public func send(_ binary: [UInt8], promise: EventLoopPromise<Void>? = nil) {
self.send(raw: binary, opcode: .binary, fin: true, promise: promise)
var buffer = channel.allocator.buffer(capacity: binary.count)
buffer.writeBytes(binary)
self.send(buffer: buffer, opcode: .binary, promise: promise)
}

public func send(
buffer: NIO.ByteBuffer,
opcode: WebSocketOpcode,
promise: EventLoopPromise<Void>? = nil
) {

guard buffer.readableBytes > self.outboundMaxFrameSize.value else {
let frame = WebSocketFrame(
fin: true,
opcode: opcode,
maskKey: self.makeMaskKey(),
data: buffer
)
self.channel.writeAndFlush(frame, promise: promise)
return
}

var buffer = buffer

// We need to ensure we write all of these items in order on the event loop without other writes interrupting the frames.
self.channel.eventLoop.execute {
// Send the first frame with the opcode
let frameBuffer = buffer.readSlice(length: self.outboundMaxFrameSize.value)!
let frame = WebSocketFrame(
fin: false,
opcode: opcode,
maskKey: self.makeMaskKey(),
data: frameBuffer
)

self.channel.write(frame, promise: nil)


while let frameBuffer = buffer.readSlice(length: self.outboundMaxFrameSize.value) {

let isFinalFrame = buffer.readableBytes == 0

let frame = WebSocketFrame(
fin: isFinalFrame,
opcode: .continuation,
maskKey: self.makeMaskKey(),
data: frameBuffer
)

if isFinalFrame {
self.channel.writeAndFlush(frame, promise: promise)
return
} else {
// write operations that happen when already on the event loop go directly through without any `delay`.
self.channel.write(frame, promise: nil)
}
}
// we will end up here if the number bytes is not a multiple of the `outboundMaxFrameSize`
let finalFrame = WebSocketFrame(
fin: true,
opcode: .continuation,
maskKey: self.makeMaskKey(),
data: buffer
)

self.channel.writeAndFlush(finalFrame, promise: promise)
}
hishnash marked this conversation as resolved.
Show resolved Hide resolved
}

public func sendPing(promise: EventLoopPromise<Void>? = nil) {
Expand All @@ -115,6 +187,7 @@ public final class WebSocket {
{
var buffer = channel.allocator.buffer(capacity: data.count)
buffer.writeBytes(data)

let frame = WebSocketFrame(
fin: fin,
opcode: opcode,
Expand Down Expand Up @@ -213,16 +286,16 @@ public final class WebSocket {
self.close(code: .protocolError, promise: nil)
}
case .text, .binary, .pong:
// create a new frame sequence or use existing
var frameSequence: WebSocketFrameSequence
if let existing = self.frameSequence {
frameSequence = existing
if self.frameSequence != nil {
// we should not have an existing sequence
self.close(code: .protocolError, promise: nil)
} else {
frameSequence = WebSocketFrameSequence(type: frame.opcode)
// Create a new frame sequence
var frameSequence = WebSocketFrameSequence(type: frame.opcode)
// Append this frame and update the sequence
frameSequence.append(frame)
self.frameSequence = frameSequence
}
// append this frame and update the sequence
frameSequence.append(frame)
self.frameSequence = frameSequence
case .continuation:
// we must have an existing sequence
if var frameSequence = self.frameSequence {
Expand Down
26 changes: 21 additions & 5 deletions Sources/WebSocketKit/WebSocketClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,26 @@ public final class WebSocketClient {

public struct Configuration {
public var tlsConfiguration: TLSConfiguration?
public var maxFrameSize: Int
public var inboundMaxFrameSize: WebSocketMaxFrameSize
public var outboundMaxFrameSize: WebSocketMaxFrameSize

public init(
tlsConfiguration: TLSConfiguration? = nil,
maxFrameSize: Int = 1 << 14
maxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default
) {
self.tlsConfiguration = tlsConfiguration
self.maxFrameSize = maxFrameSize
self.inboundMaxFrameSize = maxFrameSize
self.outboundMaxFrameSize = maxFrameSize
}

public init(
tlsConfiguration: TLSConfiguration? = nil,
inboundMaxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default,
outboundMaxFrameSize: WebSocketMaxFrameSize = WebSocketMaxFrameSize.default
) {
self.tlsConfiguration = tlsConfiguration
self.inboundMaxFrameSize = inboundMaxFrameSize
self.outboundMaxFrameSize = outboundMaxFrameSize
}
}

Expand Down Expand Up @@ -75,10 +87,14 @@ public final class WebSocketClient {
}
let websocketUpgrader = NIOWebSocketClientUpgrader(
requestKey: Data(key).base64EncodedString(),
maxFrameSize: self.configuration.maxFrameSize,
maxFrameSize: self.configuration.inboundMaxFrameSize.value,
automaticErrorHandling: true,
upgradePipelineHandler: { channel, req in
return WebSocket.client(on: channel, onUpgrade: onUpgrade)
return WebSocket.client(
on: channel,
outboundMaxFrameSize: self.configuration.outboundMaxFrameSize,
onUpgrade: onUpgrade
)
}
)

Expand Down
24 changes: 21 additions & 3 deletions Sources/WebSocketKit/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,42 @@ import NIOWebSocket
extension WebSocket {
public static func client(
on channel: Channel,
outboundMaxFrameSize: WebSocketMaxFrameSize = .default,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
return self.handle(on: channel, as: .client, onUpgrade: onUpgrade)
return self.handle(
on: channel,
as: .client,
outboundMaxFrameSize: outboundMaxFrameSize,
onUpgrade: onUpgrade
)
}

public static func server(
on channel: Channel,
outboundMaxFrameSize: WebSocketMaxFrameSize = .default,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
return self.handle(on: channel, as: .server, onUpgrade: onUpgrade)
return self.handle(
on: channel,
as: .server,
outboundMaxFrameSize: outboundMaxFrameSize,
onUpgrade: onUpgrade
)
}

private static func handle(
on channel: Channel,
as type: PeerType,
outboundMaxFrameSize: WebSocketMaxFrameSize = .default,
onUpgrade: @escaping (WebSocket) -> ()
) -> EventLoopFuture<Void> {
let webSocket = WebSocket(channel: channel, type: type)
let webSocket = WebSocket(
channel: channel,
type: type,
outboundMaxFrameSize: outboundMaxFrameSize
)

return channel.pipeline.addHandler(WebSocketHandler(webSocket: webSocket)).map { _ in
onUpgrade(webSocket)
}
Expand Down
17 changes: 17 additions & 0 deletions Sources/WebSocketKit/WebSocketMaxFrameSize.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
public struct WebSocketMaxFrameSize: ExpressibleByIntegerLiteral {
public let value: Int

public init(_ value: Int) {
precondition(value <= UInt32.max, "invalid overlarge max frame size")
self.value = value
}

public init(integerLiteral value: Int) {
precondition(value <= UInt32.max, "invalid overlarge max frame size")
self.value = value
}

public static var `default`: Self {
self.init(integerLiteral: 1 << 14)
}
}
111 changes: 110 additions & 1 deletion Tests/WebSocketKitTests/WebSocketKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,111 @@ final class WebSocketKitTests: XCTestCase {
func testBadHost() throws {
XCTAssertThrowsError(try WebSocket.connect(host: "asdf", on: elg) { _ in }.wait())
}


func testMutliFrameMessage() throws {
let port = Int.random(in: 8000..<9000)

let sendPromise = self.elg.next().makePromise(of: Void.self)
let serverClose = self.elg.next().makePromise(of: Void.self)
let clientClose = self.elg.next().makePromise(of: Void.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onText { ws, text in
if text == "close" {
ws.close(promise: serverClose)
}
}
}.bind(host: "localhost", port: port).wait()
let config = WebSocketClient.Configuration(tlsConfiguration: nil, maxFrameSize: 2)
WebSocket.connect(to: "ws://localhost:\(port)", configuration: config, on: self.elg) { ws in
ws.send("close", promise: sendPromise)
ws.onClose.cascade(to: clientClose)
}.cascadeFailure(to: sendPromise)

XCTAssertNoThrow(try sendPromise.futureResult.wait())
XCTAssertNoThrow(try serverClose.futureResult.wait())
XCTAssertNoThrow(try clientClose.futureResult.wait())
try server.close(mode: .all).wait()
}

func testMultiFrameMessageOrdering() throws {
let port = Int.random(in: 8000..<9000)

// Sending from client to server
let sendMultiFrameMessagePromise = self.elg.next().makePromise(of: Void.self)
let sendSingleFrameMessagePromise = self.elg.next().makePromise(of: Void.self)

// received on server
let receivedMultiFrameMessagePromise = self.elg.next().makePromise(of: Void.self)
let receivedSingleFrameMessagePromise = self.elg.next().makePromise(of: Void.self)

// received echo on client
let receivedMultiFrameMessageEchoPromise = self.elg.next().makePromise(of: Void.self)
let receivedSingleFrameMessageEchoPromise = self.elg.next().makePromise(of: Void.self)


let clientClose = self.elg.next().makePromise(of: Void.self)
let serverClose = self.elg.next().makePromise(of: Void.self)

let maxFrameSize: WebSocketMaxFrameSize = 13

let server = try ServerBootstrap.webSocket(
on: self.elg,
outboundMaxFrameSize: maxFrameSize,
inboundMaxFrameSize: maxFrameSize
) { req, ws in
ws.onClose.cascade(to: serverClose)

ws.onText { ws, text in
receivedSingleFrameMessagePromise.succeed(())
ws.send(text, promise: nil)
}

ws.onBinary { ws, buffer in
receivedMultiFrameMessagePromise.succeed(())
ws.send(buffer: buffer, opcode: .binary, promise: nil)
}
}.bind(host: "localhost", port: port).wait()

let config = WebSocketClient.Configuration(tlsConfiguration: nil, maxFrameSize: maxFrameSize)

WebSocket.connect(to: "ws://localhost:\(port)", configuration: config, on: self.elg) { ws in
ws.onBinary { ws, buffer in
XCTAssertEqual(buffer.readableBytes, 10000)
receivedMultiFrameMessageEchoPromise.succeed(())
}

ws.onText { ws, str in
XCTAssertEqual(str, "singleFrame")
receivedSingleFrameMessageEchoPromise.succeed(())
}

ws.send(buffer: ByteBuffer(repeating: 1, count: 10000), opcode: .binary, promise: sendMultiFrameMessagePromise)

sendMultiFrameMessagePromise.futureResult.whenSuccess { _ in
ws.send("singleFrame", promise: sendSingleFrameMessagePromise)
}

// Send close after Multi frame response has arrived.
receivedMultiFrameMessageEchoPromise.futureResult.whenComplete { _ in
ws.close(promise: clientClose)
}
}.cascadeFailure(to: clientClose)

XCTAssertNoThrow(try sendMultiFrameMessagePromise.futureResult.wait())
XCTAssertNoThrow(try sendSingleFrameMessagePromise.futureResult.wait())

XCTAssertNoThrow(try receivedMultiFrameMessagePromise.futureResult.wait())
XCTAssertNoThrow(try receivedSingleFrameMessagePromise.futureResult.wait())

XCTAssertNoThrow(try receivedMultiFrameMessageEchoPromise.futureResult.wait())
XCTAssertNoThrow(try receivedSingleFrameMessageEchoPromise.futureResult.wait())

XCTAssertNoThrow(try clientClose.futureResult.wait())

try server.close(mode: .all).wait()
}


func testServerClose() throws {
let port = Int.random(in: 8000..<9000)
Expand Down Expand Up @@ -226,15 +331,19 @@ final class WebSocketKitTests: XCTestCase {
extension ServerBootstrap {
static func webSocket(
on eventLoopGroup: EventLoopGroup,
outboundMaxFrameSize: WebSocketMaxFrameSize = .default,
inboundMaxFrameSize: WebSocketMaxFrameSize = .default,
onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> ()

) -> ServerBootstrap {
ServerBootstrap(group: eventLoopGroup).childChannelInitializer { channel in
let webSocket = NIOWebSocketServerUpgrader(
maxFrameSize: inboundMaxFrameSize.value,
shouldUpgrade: { channel, req in
return channel.eventLoop.makeSucceededFuture([:])
},
upgradePipelineHandler: { channel, req in
return WebSocket.server(on: channel) { ws in
return WebSocket.server(on: channel, outboundMaxFrameSize: outboundMaxFrameSize) { ws in
onUpgrade(req, ws)
}
}
Expand Down