Skip to content

Commit

Permalink
Pass a copy of the control frame buffer to ping/pong callbacks (#116)
Browse files Browse the repository at this point in the history
* Pass a copy of the control frame buffer to callbacks

* Add back old API

Since the new and old methods are only overloads and share the same name, the 'renamed' parameter of the deprecation warning doesn't help.

* Allow specifying payload when sending ping

* Remove default value in favor of method forwarding

This preserves the signature of the original method and doesn't break the API

* Apply suggestions from code review

New APIs should use safe code

---------

Co-authored-by: Tim Condon <[email protected]>
  • Loading branch information
tkrajacic and 0xTim authored May 29, 2023
1 parent 2ec1450 commit 53fe063
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 14 deletions.
32 changes: 29 additions & 3 deletions Sources/WebSocketKit/Concurrency/WebSocket+Concurrency.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ extension WebSocket {
}

public func sendPing() async throws {
try await sendPing(Data())
}

public func sendPing(_ data: Data) async throws {
let promise = eventLoop.makePromise(of: Void.self)
sendPing(promise: promise)
sendPing(data, promise: promise)
return try await promise.futureResult.get()
}

Expand Down Expand Up @@ -60,19 +64,41 @@ extension WebSocket {
}
}

public func onPong(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) async -> ()) {
self.eventLoop.execute {
self.onPong { socket, data in
Task {
await callback(socket, data)
}
}
}
}

@available(*, deprecated, message: "Please use `onPong { socket, data in /* … */ }` with the additional `data` parameter.")
@preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) async -> ()) {
self.eventLoop.execute {
self.onPong { socket in
self.onPong { socket, _ in
Task {
await callback(socket)
}
}
}
}

public func onPing(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) async -> ()) {
self.eventLoop.execute {
self.onPing { socket, data in
Task {
await callback(socket, data)
}
}
}
}

@available(*, deprecated, message: "Please use `onPing { socket, data in /* … */ }` with the additional `data` parameter.")
@preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) async -> ()) {
self.eventLoop.execute {
self.onPing { socket in
self.onPing { socket, _ in
Task {
await callback(socket)
}
Expand Down
32 changes: 23 additions & 9 deletions Sources/WebSocketKit/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ public final class WebSocket: Sendable {
internal let channel: Channel
private let onTextCallback: NIOLoopBoundBox<@Sendable (WebSocket, String) -> ()>
private let onBinaryCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()>
private let onPongCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()>
private let onPingCallback: NIOLoopBoundBox<@Sendable (WebSocket) -> ()>
private let onPongCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()>
private let onPingCallback: NIOLoopBoundBox<@Sendable (WebSocket, ByteBuffer) -> ()>
private let type: PeerType
private let waitingForPong: NIOLockedValueBox<Bool>
private let waitingForClose: NIOLockedValueBox<Bool>
Expand All @@ -48,8 +48,8 @@ public final class WebSocket: Sendable {
self.type = type
self.onTextCallback = .init({ _, _ in }, eventLoop: channel.eventLoop)
self.onBinaryCallback = .init({ _, _ in }, eventLoop: channel.eventLoop)
self.onPongCallback = .init({ _ in }, eventLoop: channel.eventLoop)
self.onPingCallback = .init({ _ in }, eventLoop: channel.eventLoop)
self.onPongCallback = .init({ _, _ in }, eventLoop: channel.eventLoop)
self.onPingCallback = .init({ _, _ in }, eventLoop: channel.eventLoop)
self.waitingForPong = .init(false)
self.waitingForClose = .init(false)
self.scheduledTimeoutTask = .init(nil)
Expand All @@ -66,13 +66,23 @@ public final class WebSocket: Sendable {
self.onBinaryCallback.value = callback
}

@preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) -> ()) {
public func onPong(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) -> ()) {
self.onPongCallback.value = callback
}

@available(*, deprecated, message: "Please use `onPong { socket, data in /* … */ }` with the additional `data` parameter.")
@preconcurrency public func onPong(_ callback: @Sendable @escaping (WebSocket) -> ()) {
self.onPongCallback.value = { ws, _ in callback(ws) }
}

@preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) -> ()) {
public func onPing(_ callback: @Sendable @escaping (WebSocket, ByteBuffer) -> ()) {
self.onPingCallback.value = callback
}

@available(*, deprecated, message: "Please use `onPing { socket, data in /* … */ }` with the additional `data` parameter.")
@preconcurrency public func onPing(_ callback: @Sendable @escaping (WebSocket) -> ()) {
self.onPingCallback.value = { ws, _ in callback(ws) }
}

/// If set, this will trigger automatic pings on the connection. If ping is not answered before
/// the next ping is sent, then the WebSocket will be presumed inactive and will be closed
Expand Down Expand Up @@ -112,8 +122,12 @@ public final class WebSocket: Sendable {
}

public func sendPing(promise: EventLoopPromise<Void>? = nil) {
sendPing(Data(), promise: promise)
}

public func sendPing(_ data: Data, promise: EventLoopPromise<Void>? = nil) {
self.send(
raw: Data(),
raw: data,
opcode: .ping,
fin: true,
promise: promise
Expand Down Expand Up @@ -236,7 +250,7 @@ public final class WebSocket: Sendable {
if let maskingKey = maskingKey {
frameData.webSocketUnmask(maskingKey)
}
self.onPingCallback.value(self)
self.onPingCallback.value(self, ByteBuffer(buffer: frameData))
self.send(
raw: frameData.readableBytesView,
opcode: .pong,
Expand All @@ -254,7 +268,7 @@ public final class WebSocket: Sendable {
frameData.webSocketUnmask(maskingKey)
}
self.waitingForPong.withLockedValue { $0 = false }
self.onPongCallback.value(self)
self.onPongCallback.value(self, ByteBuffer(buffer: frameData))
} else {
self.close(code: .protocolError, promise: nil)
}
Expand Down
7 changes: 5 additions & 2 deletions Tests/WebSocketKitTests/WebSocketKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ final class WebSocketKitTests: XCTestCase {
let pingPongData = ByteBuffer(bytes: "Vapor rules".utf8)

let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onPing { ws in
ws.onPing { ws, data in
XCTAssertEqual(pingPongData, data)
pingPromise.succeed("ping")
}
}.bind(host: "localhost", port: 0).wait()
Expand All @@ -144,7 +145,9 @@ final class WebSocketKitTests: XCTestCase {
}

WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
ws.onPong { ws in
ws.sendPing(Data(pingPongData.readableBytesView))
ws.onPong { ws, data in
XCTAssertEqual(pingPongData, data)
pongPromise.succeed("pong")
ws.close(promise: nil)
}
Expand Down

0 comments on commit 53fe063

Please sign in to comment.