Skip to content

Commit

Permalink
Merge pull request #6 from vapor/ws-close-code
Browse files Browse the repository at this point in the history
add support for websocket close code; other bug fixes
  • Loading branch information
tanner0101 authored May 9, 2018
2 parents 39eeef2 + a5ec841 commit 141cb4d
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
/Packages
/*.xcodeproj
Package.resolved
DerivedData

5 changes: 4 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ let package = Package(
dependencies: [
// 🌎 Utility package containing tools for byte manipulation, Codable, OS APIs, and debugging.
.package(url: "https://github.com/vapor/core.git", from: "3.0.0"),

// 🔑 Hashing (BCrypt, SHA2, HMAC), encryption (AES), public-key (RSA), and random data generation.
.package(url: "https://github.com/vapor/crypto.git", from: "3.0.0"),

// 🚀 Non-blocking, event-driven HTTP for Swift built on Swift NIO.
.package(url: "https://github.com/vapor/http.git", from: "3.0.0"),
Expand All @@ -20,7 +23,7 @@ let package = Package(
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "1.0.1"),
],
targets: [
.target(name: "WebSocket", dependencies: ["Core", "HTTP", "NIO", "NIOWebSocket"]),
.target(name: "WebSocket", dependencies: ["Core", "Crypto", "HTTP", "NIO", "NIOWebSocket"]),
.testTarget(name: "WebSocketTests", dependencies: ["WebSocket"]),
]
)
23 changes: 18 additions & 5 deletions Sources/WebSocket/WebSocket+Client.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Crypto

/// Allows `HTTPClient` to be used to create `WebSocket` connections.
///
/// let ws = try HTTPClient.webSocket(hostname: "echo.websocket.org", on: ...).wait()
Expand Down Expand Up @@ -25,6 +27,7 @@ extension HTTPClient {
/// - port: Remote server's port, defaults to 80 for TCP and 443 for TLS.
/// - path: Path on remote server to connect to.
/// - headers: Additional HTTP headers are used to establish a connection.
/// - maxFrameSize: Maximum WebSocket frame size this client will accept.
/// - worker: `Worker` to perform async work on.
/// - returns: A `Future` containing the connected `WebSocket`.
public static func webSocket(
Expand All @@ -33,9 +36,10 @@ extension HTTPClient {
port: Int? = nil,
path: String = "/",
headers: HTTPHeaders = .init(),
maxFrameSize: Int = 1 << 14,
on worker: Worker
) -> Future<WebSocket> {
let upgrader = WebSocketClientUpgrader(hostname: hostname, path: path, headers: headers)
let upgrader = WebSocketClientUpgrader(hostname: hostname, path: path, headers: headers, maxFrameSize: maxFrameSize)
return HTTPClient.upgrade(scheme: scheme, hostname: hostname, port: port, upgrader: upgrader, on: worker)
}
}
Expand All @@ -53,11 +57,15 @@ private final class WebSocketClientUpgrader: HTTPClientProtocolUpgrader {
/// Additional headers to use when upgrading.
let headers: HTTPHeaders

/// Maximum frame size for decoder.
private let maxFrameSize: Int

/// Creates a new `WebSocketClientUpgrader`.
init(hostname: String, path: String, headers: HTTPHeaders) {
init(hostname: String, path: String, headers: HTTPHeaders, maxFrameSize: Int) {
self.hostname = hostname
self.path = path
self.headers = headers
self.maxFrameSize = maxFrameSize
}

/// See `HTTPClientProtocolUpgrader`.
Expand All @@ -68,8 +76,13 @@ private final class WebSocketClientUpgrader: HTTPClientProtocolUpgrader {
upgradeReq.headers.add(name: .upgrade, value: "websocket")
upgradeReq.headers.add(name: .host, value: hostname)
upgradeReq.headers.add(name: .origin, value: "vapor/websocket")
upgradeReq.headers.add(name: .secWebSocketVersion, value: "13") // fixme: randomly gen
upgradeReq.headers.add(name: .secWebSocketKey, value: "MTMtMTUyMzk4NDIxNzk3NQ==") // fixme: randomly gen
upgradeReq.headers.add(name: .secWebSocketVersion, value: "13")
do {
let webSocketKey = try CryptoRandom().generateData(count: 16).base64EncodedString()
upgradeReq.headers.add(name: .secWebSocketKey, value: webSocketKey)
} catch {
print("[WebSocket] [Upgrader] Could not generate random value for Sec-WebSocket-Key header: \(error)")
}
return upgradeReq
}

Expand All @@ -86,7 +99,7 @@ private final class WebSocketClientUpgrader: HTTPClientProtocolUpgrader {
/// See `HTTPClientProtocolUpgrader`.
func upgrade(ctx: ChannelHandlerContext, upgradeResponse: HTTPResponseHead) -> Future<WebSocket> {
let webSocket = WebSocket(channel: ctx.channel)
return ctx.channel.pipeline.addHandlers(WebSocketFrameEncoder(), WebSocketFrameDecoder(), first: false).then {
return ctx.channel.pipeline.addHandlers(WebSocketFrameEncoder(), WebSocketFrameDecoder(maxFrameSize: maxFrameSize), first: false).then {
return ctx.channel.pipeline.add(webSocket: webSocket)
}.map(to: WebSocket.self) {
return webSocket
Expand Down
9 changes: 8 additions & 1 deletion Sources/WebSocket/WebSocket+Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@ extension HTTPServer {
///
/// HTTPServer.start(..., upgraders: [ws])
///
/// - parameters:
/// - maxFrameSize: Maximum WebSocket frame size this server will accept.
/// - shouldUpgrade: Called when an incoming HTTPRequest attempts to upgrade.
/// Return non-nil headers to accept the upgrade.
/// - onUpgrade: Called when a new WebSocket client has connected.
/// - returns: An `HTTPProtocolUpgrader` for use with `HTTPServer`.
public static func webSocketUpgrader(
maxFrameSize: Int = 1 << 14,
shouldUpgrade: @escaping (HTTPRequest) -> (HTTPHeaders?),
onUpgrade: @escaping (WebSocket, HTTPRequest) -> ()
) -> HTTPProtocolUpgrader {
return WebSocketUpgrader(shouldUpgrade: { head in
return WebSocketUpgrader(maxFrameSize: maxFrameSize, shouldUpgrade: { head in
let req = HTTPRequest(
method: head.method,
url: head.uri,
Expand Down
40 changes: 37 additions & 3 deletions Sources/WebSocket/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public final class WebSocket: BasicWorker {
/// See `onError(...)`.
var onErrorCallback: (WebSocket, Error) -> ()

/// See `onCloseCode(...)`.
var onCloseCodeCallback: (WebSocketErrorCode) -> ()

/// Creates a new `WebSocket` using the supplied `Channel`.
/// Use `httpProtocolUpgrader(...)` to create a protocol upgrader that can create `WebSocket`s.
internal init(channel: Channel) {
Expand All @@ -31,6 +34,7 @@ public final class WebSocket: BasicWorker {
self.onTextCallback = { _, _ in }
self.onBinaryCallback = { _, _ in }
self.onErrorCallback = { _, _ in }
self.onCloseCodeCallback = { _ in }
}

// MARK: Receive
Expand Down Expand Up @@ -77,6 +81,18 @@ public final class WebSocket: BasicWorker {
onErrorCallback = callback
}

/// Adds a callback to this `WebSocket` to handle incoming close codes.
///
/// ws.onCloseCode { closeCode in
/// print(closeCode)
/// }
///
/// - parameters:
/// - callback: Closure to handle received close codes.
public func onCloseCode(_ callback: @escaping (WebSocketErrorCode) -> ()) {
onCloseCodeCallback = callback
}

// MARK: Send

/// Sends text-formatted data to the connected client.
Expand Down Expand Up @@ -134,23 +150,41 @@ public final class WebSocket: BasicWorker {
// MARK: Close

/// `true` if the `WebSocket` has been closed.
public private(set) var isClosed: Bool
public internal(set) var isClosed: Bool

/// A `Future` that will be completed when the `WebSocket` closes.
public var onClose: Future<Void> {
return channel.closeFuture
}

/// Closes the `WebSocket`'s connection, disconnecting the client.
public func close() {
///
/// - parameters:
/// - code: Optional `WebSocketCloseCode` to send before closing the connection.
/// If a code is provided, the WebSocket will wait until an acknowledgment is
/// received from the server before actually closing the connection.
public func close(code: WebSocketErrorCode? = nil) {
guard !isClosed else {
return
}
channel.close(promise: nil)
self.isClosed = true
if let code = code {
sendClose(code: code)
} else {
channel.close(promise: nil)
}
}

// MARK: Private

/// Private just send close code.
private func sendClose(code: WebSocketErrorCode) {
var buffer = channel.allocator.buffer(capacity: 2)
buffer.write(webSocketErrorCode: code)
let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: buffer)
send(frame, promise: nil)
}

/// Private send that accepts a raw `WebSocketOpcode`.
private func send(_ data: LosslessDataConvertible, opcode: WebSocketOpcode, promise: Promise<Void>?) {
guard !isClosed else { return }
Expand Down
21 changes: 12 additions & 9 deletions Sources/WebSocket/WebSocketHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,11 @@ private final class WebSocketHandler: ChannelInboundHandler {
/// See `ChannelInboundHandler`.
typealias OutboundOut = WebSocketFrame

/// If true, a close has been sent from this decoder.
private var awaitingClose: Bool

/// `WebSocket` to handle the incoming events.
private var webSocket: WebSocket

/// Creates a new `WebSocketEventDecoder`
init(webSocket: WebSocket) {
self.awaitingClose = false
self.webSocket = webSocket
}

Expand Down Expand Up @@ -61,18 +57,25 @@ private final class WebSocketHandler: ChannelInboundHandler {

/// Closes gracefully.
private func receivedClose(ctx: ChannelHandlerContext, frame: WebSocketFrame) {
/// Parse the close frame.
var data = frame.unmaskedData
if let closeCode = data.readInteger(as: UInt16.self)
.map(Int.init)
.flatMap(WebSocketErrorCode.init(codeNumber:))
{
webSocket.onCloseCodeCallback(closeCode)
}

// Handle a received close frame. In websockets, we're just going to send the close
// frame and then close, unless we already sent our own close frame.
if awaitingClose {
if webSocket.isClosed {
// Cool, we started the close and were waiting for the user. We're done.
ctx.close(promise: nil)
} else {
// This is an unsolicited close. We're going to send a response frame and
// then, when we've sent it, close up shop. We should send back the close code the remote
// peer sent us, unless they didn't send one at all.
var data = frame.unmaskedData
let closeDataCode = data.readSlice(length: 2) ?? ctx.channel.allocator.buffer(capacity: 0)
let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode)
let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data)
_ = ctx.write(wrapOutboundOut(closeFrame)).always {
_ = ctx.close(promise: nil)
}
Expand Down Expand Up @@ -103,6 +106,6 @@ private final class WebSocketHandler: ChannelInboundHandler {
_ = ctx.write(self.wrapOutboundOut(frame)).then {
ctx.close(mode: .output)
}
awaitingClose = true
webSocket.isClosed = true
}
}
14 changes: 11 additions & 3 deletions Tests/WebSocketTests/WebSocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@ class WebSocketTests: XCTestCase {
func testClient() throws {
// ws://echo.websocket.org
let worker = MultiThreadedEventLoopGroup(numThreads: 1)
let webSocket = try HTTPClient.webSocket(hostname: "echo.websocket.org", on: worker).wait()
let ws = try HTTPClient.webSocket(hostname: "echo.websocket.org", on: worker).wait()

let promise = worker.eventLoop.newPromise(String.self)
webSocket.onText { ws, text in
ws.onText { ws, text in
promise.succeed(result: text)
ws.close(code: .normalClosure)
}
ws.onCloseCode { code in
print("code: \(code)")
}
let message = "Hello, world!"
webSocket.send(message)
ws.send(message)
try XCTAssertEqual(promise.futureResult.wait(), message)
try ws.onClose.wait()
}

func testClientTLS() throws {
Expand Down Expand Up @@ -49,6 +54,9 @@ class WebSocketTests: XCTestCase {
ws.onBinary { ws, data in
print("data: \(data)")
}
ws.onCloseCode { code in
print("code: \(code)")
}
ws.onClose.always {
print("closed")
}
Expand Down

0 comments on commit 141cb4d

Please sign in to comment.