Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Install and run RabbitMQ
run: |
brew install rabbitmq
rabbitmq-plugins enable rabbitmq_stomp
rabbitmq-plugins enable rabbitmq_stomp rabbitmq_web_stomp
brew services start rabbitmq
- name: Run unit tests
run: swift test --enable-code-coverage
Expand Down
3 changes: 2 additions & 1 deletion Notice.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ components that this product depends on.
-------------------------------------------------------------------------------

This product was heavily influenced by MQTT NIO.
It contains a derivation of MQTT NIO's 'MQTTTask.swift'.
It contains a derivation of MQTT NIO's 'MQTTTask.swift', 'WebSocketInitialRequest.swift' and 'WebSocketHandler.swift'.
It also contains a version of MQTT NIO's 'TSTLSConfiguration.swift'.

* LICENSE (Apache License 2.0)
* https://github.com/swift-server-community/mqtt-nio/blob/main/LICENSE
Expand Down
4 changes: 4 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ let package = Package(
dependencies: [
.package(url: "https://github.com/apple/swift-log.git", from: "1.7.1"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.91.0"),
.package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.36.0"),
.package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.26.0"),
.package(url: "https://github.com/apple/swift-configuration.git", from: "1.0.0"),
],
Expand All @@ -26,6 +27,9 @@ let package = Package(
.product(name: "Logging", package: "swift-log"),
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOPosix", package: "swift-nio"),
.product(name: "NIOWebSocket", package: "swift-nio"),
.product(name: "NIOHTTP1", package: "swift-nio"),
.product(name: "NIOSSL", package: "swift-nio-ssl", condition: .when(platforms: [.linux, .macOS])),
.product(name: "NIOTransportServices", package: "swift-nio-transport-services"),
.product(name: "Configuration", package: "swift-configuration"),
],
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ It defines a text based wire-format for messages passed between these clients an
STOMP has been in active use for several years and is supported by many message brokers and client libraries.

STOMPNIO is a Swift NIO based implementation of a STOMP client. It supports:
- [x] STOMP versions 1.0, 1.1, and 1.2
- [ ] Unencrypted and encrypted (via TLS) connections
- [ ] WebSocket connections
- [x] POSIX sockets
- [x] Apple's Network framework via [NIOTransportServices](https://github.com/apple/swift-nio-transport-services) (required for iOS)
- [x] Unix domain sockets
- STOMP versions 1.0, 1.1, and 1.2
- Unencrypted and encrypted (via TLS) connections
- WebSocket connections
- POSIX sockets
- Apple's Network framework via [NIOTransportServices](https://github.com/apple/swift-nio-transport-services) (required for iOS)
- Unix domain sockets

## Overview

Expand Down
2 changes: 1 addition & 1 deletion RabbitMQ/enabled_plugins
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[rabbitmq_stomp].
[rabbitmq_stomp,rabbitmq_web_stomp].
230 changes: 180 additions & 50 deletions Sources/STOMPNIO/Connection/STOMPConnection.swift
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
public import Logging
public import NIOCore
import NIOHTTP1
public import NIOPosix
import NIOWebSocket
import Synchronization

#if os(macOS) || os(Linux)
import NIOSSL
#endif

#if canImport(Network)
import Network
import NIOTransportServices
Expand Down Expand Up @@ -201,46 +207,69 @@ public final actor STOMPConnection: Sendable {
) -> EventLoopFuture<STOMPConnection> {
eventLoop.assertInEventLoop()

let bootstrap: any NIOClientTCPBootstrapProtocol
#if canImport(Network)
if let tsBootstrap = createTSBootstrap(eventLoopGroup: eventLoop, tlsOptions: nil) {
bootstrap = tsBootstrap
} else {
#if os(iOS) || os(tvOS)
logger.warning(
"Running BSD sockets on iOS or tvOS is not recommended. Please use NIOTSEventLoopGroup, to run with the Network framework"
)
#endif
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoop)
}
#else
bootstrap = self.createSocketsBootstrap(eventLoopGroup: eventLoop)
#endif

let connect = bootstrap.channelInitializer { channel in
do {
try self._setupChannel(channel, configuration: configuration, logger: logger)
return eventLoop.makeSucceededVoidFuture()
} catch {
return eventLoop.makeFailedFuture(error)
let host =
switch address.value {
case .hostname(let hostname, _):
hostname
case .unixDomainSocket(let path):
path
}
}

let future: EventLoopFuture<any Channel>
switch address.value {
case .hostname(let host, let port):
future = connect.connect(host: host, port: port)
future.whenSuccess { _ in
logger.debug("Client connected to \(host):\(port)")
let channelPromise = eventLoop.makePromise(of: (any Channel).self)

do {
let bootstrap = try Self.createBootstrap(configuration: configuration, eventLoopGroup: eventLoop, host: host, logger: logger)

let connect = bootstrap.channelInitializer { channel in
do {
if let webSocketConfiguration = configuration.webSocket {
// Prepare for WebSockets and on upgrade add handlers
let promise = eventLoop.makePromise(of: Void.self)
promise.futureResult.map { _ in channel }.cascade(to: channelPromise)

return Self._setupChannelForWebSockets(
channel,
address: address,
configuration: configuration,
webSocketConfiguration: webSocketConfiguration,
upgradePromise: promise
) {
try self._setupChannel(channel, configuration: configuration, logger: logger)
}
} else {
try self._setupChannel(channel, configuration: configuration, logger: logger)
}
return eventLoop.makeSucceededVoidFuture()
} catch {
channelPromise.fail(error)
return eventLoop.makeFailedFuture(error)
}
}
case .unixDomainSocket(let path):
future = connect.connect(unixDomainSocketPath: path)
future.whenSuccess { _ in
logger.debug("Client connected to socket path \(path)")

let future: EventLoopFuture<any Channel>
switch address.value {
case .hostname(let host, let port):
future = connect.connect(host: host, port: port)
future.whenSuccess { _ in
logger.debug("Client connected to \(host):\(port)")
}
case .unixDomainSocket(let path):
future = connect.connect(unixDomainSocketPath: path)
future.whenSuccess { _ in
logger.debug("Client connected to socket path \(path)")
}
}

future.map { channel in
if configuration.webSocket == nil {
channelPromise.succeed(channel)
}
}.cascadeFailure(to: channelPromise)
} catch {
channelPromise.fail(error)
}

return future.flatMapThrowing { channel in
return channelPromise.futureResult.flatMapThrowing { channel in
let handler = try channel.pipeline.syncOperations.handler(type: STOMPChannelHandler.self)
return STOMPConnection(
channel: channel,
Expand Down Expand Up @@ -307,28 +336,129 @@ public final actor STOMPConnection: Sendable {
return stompChannelHandler
}

/// Create a BSD sockets based bootstrap
private static func createSocketsBootstrap(eventLoopGroup: any EventLoopGroup) -> ClientBootstrap {
ClientBootstrap(group: eventLoopGroup)
private static func _setupChannelForWebSockets(
_ channel: any Channel,
address: STOMPServerAddress,
configuration: STOMPConnectionConfiguration,
webSocketConfiguration: STOMPConnectionConfiguration.WebSocket,
upgradePromise promise: EventLoopPromise<Void>,
afterHandlerAdded: @Sendable @escaping () throws -> Void
) -> EventLoopFuture<Void> {
var hostHeader: String {
if case .enable(_, let sniServerName) = configuration.tls.base, let sniServerName {
return sniServerName
}
switch (configuration.tls.base, address.value) {
case (.enable, .hostname(let host, let port)) where port != 443:
return "\(host):\(port)"
case (.disable, .hostname(let host, let port)) where port != 80:
return "\(host):\(port)"
case (.enable, .hostname(let host, _)), (.disable, .hostname(let host, _)):
return host
case (.enable, .unixDomainSocket(let path)), (.disable, .unixDomainSocket(let path)):
return path
}
}

// Initial HTTP request handler, before upgrade
let httpHandler = STOMPWebSocketInitialRequestChannelHandler(
host: hostHeader,
urlPath: webSocketConfiguration.urlPath,
additionalHeaders: webSocketConfiguration.initialRequestHeaders,
upgradePromise: promise
)

// Create random request key
let requestKey = (0..<16).map { _ in UInt8.random(in: .min ..< .max) }
let websocketUpgrader = NIOWebSocketClientUpgrader(
requestKey: Data(requestKey).base64EncodedString(),
maxFrameSize: webSocketConfiguration.maxFrameSize
) { channel, _ in
let future = channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(STOMPWebSocketChannelHandler())
try afterHandlerAdded()
}
future.cascade(to: promise)
return future
}
let upgradeConfig: NIOHTTPClientUpgradeSendableConfiguration = (
upgraders: [websocketUpgrader],
completionHandler: { _ in
channel.pipeline.removeHandler(httpHandler, promise: nil)
}
)

// Add HTTP handler with WebSocket upgrade
return channel.pipeline.addHTTPClientHandlers(withClientUpgrade: upgradeConfig).flatMap {
channel.pipeline.addHandler(httpHandler)
}
}

#if canImport(Network)
/// Create a NIOTransportServices bootstrap using Network.framework
private static func createTSBootstrap(
private static func createBootstrap(
configuration: STOMPConnectionConfiguration,
eventLoopGroup: any EventLoopGroup,
tlsOptions: NWProtocolTLS.Options?
) -> NIOTSConnectionBootstrap? {
guard
let bootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoopGroup)
else {
return nil
host: String,
logger: Logger
) throws -> NIOClientTCPBootstrap {
var serverName: String {
if case .enable(_, let sniServerName) = configuration.tls.base, let sniServerName {
sniServerName
} else {
host
}
}
if let tlsOptions {
return bootstrap.tlsOptions(tlsOptions)

let bootstrap: NIOClientTCPBootstrap
#if canImport(Network)
// If the EventLoop is compatible with NIOTransportServices create a `NIOTSConnectionBootstrap`
if let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoopGroup) {
// Create `NIOClientTCPBootstrap` with NIOTS TLS provider
let options: NWProtocolTLS.Options
if case .enable(let config, _) = configuration.tls.base {
switch config {
case .ts(let config):
options = try config.getNWProtocolTLSOptions(logger: logger)
#if os(macOS) || os(Linux)
case .niossl:
throw STOMPClientError.wrongTLSConfig
#endif
}
} else {
options = NWProtocolTLS.Options()
}
sec_protocol_options_set_tls_server_name(options.securityProtocolOptions, serverName)
let tlsProvider = NIOTSClientTLSProvider(tlsOptions: options)
bootstrap = NIOClientTCPBootstrap(tsBootstrap, tls: tlsProvider)
if case .enable = configuration.tls.base {
return bootstrap.enableTLS()
}
return bootstrap
}
return bootstrap
#endif

#if os(macOS) || os(Linux)
if let clientBootstrap = ClientBootstrap(validatingGroup: eventLoopGroup) {
if case .enable(let configuration, _) = configuration.tls.base {
let tlsConfiguration: TLSConfiguration
switch configuration {
case .niossl(let config):
tlsConfiguration = config
default:
tlsConfiguration = TLSConfiguration.makeClientConfiguration()
}
let sslContext = try NIOSSLContext(configuration: tlsConfiguration)
let tlsProvider = try NIOSSLClientTLSProvider<ClientBootstrap>(context: sslContext, serverHostname: serverName)
bootstrap = NIOClientTCPBootstrap(clientBootstrap, tls: tlsProvider)
return bootstrap.enableTLS()
} else {
bootstrap = NIOClientTCPBootstrap(clientBootstrap, tls: NIOInsecureNoTLS())
}
return bootstrap
}
#endif

preconditionFailure("Cannot create bootstrap for the supplied EventLoop")
}
#endif

@usableFromInline
func sendFrame(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
public import Configuration
import NIOHTTP1

extension STOMPConnectionConfiguration {
/// Creates a new STOMP connection configuration using values from the provided reader.
Expand Down Expand Up @@ -47,6 +48,26 @@ extension STOMPConnectionConfiguration {
} else {
[:]
}

let stompWebSocketConfig = stompConfig.scoped(to: "webSocket")
let urlPath = stompWebSocketConfig.string(forKey: "urlPath")
let maxFrameSize = stompWebSocketConfig.int(forKey: "maxFrameSize")
let initialRequestHeaders = stompWebSocketConfig.stringArray(forKey: "initialRequestHeaders").flatMap {
HTTPHeaders(configStringArray: $0)
}
self.webSocket =
if urlPath != nil || maxFrameSize != nil || initialRequestHeaders != nil {
.init(
urlPath: urlPath ?? "/ws",
maxFrameSize: maxFrameSize ?? 1 << 24,
initialRequestHeaders: initialRequestHeaders ?? [:]
)
} else {
nil
}

// TLS is disabled by default
self.tls = .disable
}
}

Expand All @@ -66,3 +87,24 @@ extension STOMPHeader: ExpressibleByConfigString {
self.init(name: name, value: value)
}
}

extension HTTPHeaders {
/// Creates HTTP headers from an array of configuration strings.
///
/// Each configuration string must be in the `<key>:<value>` format.
///
/// - Parameter configStringArray: The array of configuration strings to create the HTTP headers from.
fileprivate init?(configStringArray: [String]) {
var headers = HTTPHeaders()
for configString in configStringArray {
guard let colonIndex = configString.firstIndex(of: ":") else {
return nil
}
let name = String(configString[..<colonIndex].trimmingWhitespace())
let valueStartIndex = configString.index(after: colonIndex)
let value = String(configString[valueStartIndex...].trimmingWhitespace())
headers.add(name: name, value: value)
}
self = headers
}
}
Loading
Loading