Skip to content

Commit c49b361

Browse files
committed
refactor(realtime): adopt WebSocket from swift-websocket package
1 parent 44243da commit c49b361

File tree

4 files changed

+254
-268
lines changed

4 files changed

+254
-268
lines changed

Package.swift

+3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ let package = Package(
2828
.package(url: "https://github.com/pointfreeco/swift-custom-dump", from: "1.3.2"),
2929
.package(url: "https://github.com/pointfreeco/swift-snapshot-testing", from: "1.17.2"),
3030
.package(url: "https://github.com/pointfreeco/xctest-dynamic-overlay", from: "1.2.2"),
31+
.package(url: "https://github.com/grdsdev/swift-websocket", branch: "main"),
3132
],
3233
targets: [
3334
.target(
@@ -121,6 +122,8 @@ let package = Package(
121122
dependencies: [
122123
.product(name: "ConcurrencyExtras", package: "swift-concurrency-extras"),
123124
.product(name: "IssueReporting", package: "xctest-dynamic-overlay"),
125+
.product(name: "WebSocket", package: "swift-websocket"),
126+
.product(name: "WebSocketFoundation", package: "swift-websocket"),
124127
"Helpers",
125128
]
126129
),

Sources/Realtime/V2/RealtimeClientV2.swift

+89-115
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,18 @@
88
import ConcurrencyExtras
99
import Foundation
1010
import Helpers
11+
import WebSocket
12+
import WebSocketFoundation
1113

1214
#if canImport(FoundationNetworking)
1315
import FoundationNetworking
1416
#endif
1517

18+
typealias WebSocketFactory = @Sendable (
19+
_ url: URL,
20+
_ headers: [String: String]
21+
) async throws -> any WebSocket
22+
1623
public typealias JSONObject = Helpers.JSONObject
1724

1825
@MainActor
@@ -24,18 +31,17 @@ public final class RealtimeClientV2 {
2431
/// Long-running task that keeps sending heartbeat messages.
2532
var heartbeatTask: Task<Void, Never>?
2633

27-
/// Long-running task for listening for incoming messages from WebSocket.
28-
var messageTask: Task<Void, Never>?
29-
3034
var connectionTask: Task<Void, Never>?
3135

3236
/// All managed channels indexed by their topics.
3337
private(set) public var channels: [String: RealtimeChannelV2] = [:]
34-
var sendBuffer: [@Sendable () async -> Void] = []
38+
var sendBuffer: [() -> Void] = []
39+
40+
var ws: (any WebSocket)?
3541

3642
let url: URL
3743
let options: RealtimeClientOptions
38-
let ws: any WebSocketClient
44+
let wsFactory: WebSocketFactory
3945
let http: any HTTPClientType
4046
let apikey: String?
4147

@@ -75,13 +81,11 @@ public final class RealtimeClientV2 {
7581
self.init(
7682
url: url,
7783
options: options,
78-
ws: WebSocket(
79-
realtimeURL: Self.realtimeWebSocketURL(
80-
baseURL: Self.realtimeBaseURL(url: url),
81-
apikey: options.apikey
82-
),
83-
options: options
84-
),
84+
wsFactory: { url, headers in
85+
let configuration = URLSessionConfiguration.default
86+
configuration.httpAdditionalHeaders = headers
87+
return try await URLSessionWebSocket.connect(to: url, configuration: configuration)
88+
},
8589
http: HTTPClient(
8690
fetch: options.fetch ?? { try await URLSession.shared.data(for: $0) },
8791
interceptors: interceptors
@@ -92,12 +96,12 @@ public final class RealtimeClientV2 {
9296
init(
9397
url: URL,
9498
options: RealtimeClientOptions,
95-
ws: any WebSocketClient,
99+
wsFactory: @escaping WebSocketFactory,
96100
http: any HTTPClientType
97101
) {
98102
self.url = url
99103
self.options = options
100-
self.ws = ws
104+
self.wsFactory = wsFactory
101105
self.http = http
102106
apikey = options.apikey
103107

@@ -110,7 +114,6 @@ public final class RealtimeClientV2 {
110114

111115
deinit {
112116
heartbeatTask?.cancel()
113-
messageTask?.cancel()
114117
channels = [:]
115118
}
116119

@@ -122,77 +125,67 @@ public final class RealtimeClientV2 {
122125
}
123126

124127
func connect(reconnect: Bool) async {
125-
if status == .disconnected {
126-
connectionTask = Task {
127-
if reconnect {
128-
try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(options.reconnectDelay))
128+
connectionTask = Task {
129+
if reconnect {
130+
try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(options.reconnectDelay))
129131

130-
if Task.isCancelled {
131-
options.logger?.debug("Reconnect cancelled, returning")
132-
return
133-
}
134-
}
132+
if Task.isCancelled {
133+
options.logger?.debug("Reconnect cancelled, returning")
135134

136-
if status == .connected {
137-
options.logger?.debug("WebsSocket already connected")
138135
return
139136
}
137+
}
140138

141-
status = .connecting
142-
143-
for await connectionStatus in ws.connect() {
144-
if Task.isCancelled {
145-
break
146-
}
147-
148-
switch connectionStatus {
149-
case .connected:
150-
await onConnected(reconnect: reconnect)
139+
if status == .connected {
140+
// websocket connected while it was waiting for a reconnection.
141+
options.logger?.debug("WebsSocket already connected")
142+
return
143+
}
151144

152-
case .disconnected:
153-
await onDisconnected()
145+
status = .connecting
154146

155-
case let .error(error):
156-
await onError(error)
147+
do {
148+
let ws = try await wsFactory(
149+
Self.realtimeWebSocketURL(
150+
baseURL: Self.realtimeBaseURL(url: url),
151+
apikey: options.apikey
152+
),
153+
options.headers.dictionary
154+
)
155+
self.ws = ws
156+
status = .connected
157+
startHeartbeating()
158+
if reconnect {
159+
rejoinChannels()
160+
}
161+
flushSendBuffer()
162+
163+
for await event in ws.events {
164+
if Task.isCancelled { break }
165+
166+
switch event {
167+
case let .text(text):
168+
await onMessage(Data(text.utf8))
169+
case let .binary(data):
170+
await onMessage(data)
171+
case let .close(code, reason):
172+
options.logger?.verbose("connection closed code \(code ?? 0), reason \(reason)")
157173
}
158174
}
175+
} catch {
176+
options.logger?
177+
.debug(
178+
"WebSocket error \(error.localizedDescription). Trying again in \(options.reconnectDelay)"
179+
)
180+
Task {
181+
self.disconnect()
182+
await self.connect(reconnect: true)
183+
}
159184
}
160185
}
161186

162187
_ = await statusChange.first { @Sendable in $0 == .connected }
163-
}
164-
165-
private func onConnected(reconnect: Bool) async {
166-
status = .connected
167-
options.logger?.debug("Connected to realtime WebSocket")
168-
listenForMessages()
169-
startHeartbeating()
170-
if reconnect {
171-
await rejoinChannels()
172-
}
173188

174-
await flushSendBuffer()
175-
}
176-
177-
private func onDisconnected() async {
178-
options.logger?
179-
.debug(
180-
"WebSocket disconnected. Trying again in \(options.reconnectDelay)"
181-
)
182-
await reconnect()
183-
}
184-
185-
private func onError(_ error: (any Error)?) async {
186-
options.logger?
187-
.debug(
188-
"WebSocket error \(error?.localizedDescription ?? "<none>"). Trying again in \(options.reconnectDelay)"
189-
)
190-
await reconnect()
191-
}
192-
193-
private func reconnect() async {
194-
disconnect()
195-
await connect(reconnect: true)
196189
}
197190

198191
/// Creates a new channel and bind it to this client.
@@ -252,35 +245,10 @@ public final class RealtimeClientV2 {
252245
}
253246
}
254247

255-
private func rejoinChannels() async {
256-
await withTaskGroup(of: Void.self) { group in
257-
for channel in channels.values {
258-
group.addTask {
259-
await channel.subscribe()
260-
}
261-
}
262-
263-
await group.waitForAll()
264-
}
265-
}
266-
267-
private func listenForMessages() {
268-
messageTask = Task { [weak self] in
269-
guard let self else { return }
270-
271-
do {
272-
for try await message in ws.receive() {
273-
if Task.isCancelled {
274-
return
275-
}
276-
277-
await onMessage(message)
278-
}
279-
} catch {
280-
options.logger?.debug(
281-
"Error while listening for messages. Trying again in \(options.reconnectDelay) \(error)"
282-
)
283-
await reconnect()
248+
private func rejoinChannels() {
249+
for channel in channels.values {
250+
Task {
251+
await channel.subscribe()
284252
}
285253
}
286254
}
@@ -300,8 +268,11 @@ public final class RealtimeClientV2 {
300268
private func sendHeartbeat() async {
301269
if pendingHeartbeatRef != nil {
302270
pendingHeartbeatRef = nil
303-
options.logger?.debug("Heartbeat timeout")
304-
await reconnect()
271+
options.logger?.debug("Heartbeat timeout, trying to reconnect in \(options.reconnectDelay)s")
272+
Task {
273+
disconnect()
274+
await connect(reconnect: true)
275+
}
305276
} else {
306277
let ref = makeRef()
307278
pendingHeartbeatRef = ref
@@ -331,10 +302,9 @@ public final class RealtimeClientV2 {
331302
public func disconnect(code: Int? = nil, reason: String? = nil) {
332303
options.logger?.debug("Closing WebSocket connection")
333304
ref = 0
334-
messageTask?.cancel()
335305
heartbeatTask?.cancel()
336306
connectionTask?.cancel()
337-
ws.disconnect(code: code, reason: reason)
307+
ws?.close(code: code, reason: reason)
338308
status = .disconnected
339309
}
340310

@@ -354,6 +324,14 @@ public final class RealtimeClientV2 {
354324
}
355325
}
356326

327+
private func onMessage(_ data: Data) async {
328+
guard let message = try? JSONDecoder().decode(RealtimeMessageV2.self, from: data) else {
329+
return
330+
}
331+
332+
await onMessage(message)
333+
}
334+
357335
private func onMessage(_ message: RealtimeMessageV2) async {
358336
let channel = channels[message.topic]
359337

@@ -376,11 +354,11 @@ public final class RealtimeClientV2 {
376354
///
377355
/// If the socket is not connected, the message gets enqueued within a local buffer, and sent out when a connection is next established.
378356
public func push(_ message: RealtimeMessageV2) async {
379-
let callback = { @Sendable [weak self] in
357+
let callback = { [weak self] in
380358
do {
381359
// Check cancellation before sending, because this push may have been cancelled before a connection was established.
382360
try Task.checkCancellation()
383-
try await self?.ws.send(message)
361+
try self?.ws?.send(binary: JSONEncoder().encode(message))
384362
} catch {
385363
self?.options.logger?.error(
386364
"""
@@ -394,19 +372,15 @@ public final class RealtimeClientV2 {
394372
}
395373

396374
if status == .connected {
397-
await callback()
375+
callback()
398376
} else {
399377
sendBuffer.append(callback)
400378
}
401379
}
402380

403-
private func flushSendBuffer() async {
404-
let sendBuffer = self.sendBuffer
405-
self.sendBuffer = []
406-
407-
for send in sendBuffer {
408-
await send()
409-
}
381+
private func flushSendBuffer() {
382+
sendBuffer.forEach { $0() }
383+
sendBuffer = []
410384
}
411385

412386
func makeRef() -> Int {

0 commit comments

Comments
 (0)