8
8
import ConcurrencyExtras
9
9
import Foundation
10
10
import Helpers
11
+ import WebSocket
12
+ import WebSocketFoundation
11
13
12
14
#if canImport(FoundationNetworking)
13
15
import FoundationNetworking
14
16
#endif
15
17
18
+ typealias WebSocketFactory = @Sendable (
19
+ _ url: URL ,
20
+ _ headers: [ String : String ]
21
+ ) async throws -> any WebSocket
22
+
16
23
public typealias JSONObject = Helpers . JSONObject
17
24
18
25
@MainActor
@@ -24,18 +31,17 @@ public final class RealtimeClientV2 {
24
31
/// Long-running task that keeps sending heartbeat messages.
25
32
var heartbeatTask : Task < Void , Never > ?
26
33
27
- /// Long-running task for listening for incoming messages from WebSocket.
28
- var messageTask : Task < Void , Never > ?
29
-
30
34
var connectionTask : Task < Void , Never > ?
31
35
32
36
/// All managed channels indexed by their topics.
33
37
private( set) public var channels : [ String : RealtimeChannelV2 ] = [ : ]
34
- var sendBuffer : [ @Sendable ( ) async -> Void ] = [ ]
38
+ var sendBuffer : [ ( ) -> Void ] = [ ]
39
+
40
+ var ws : ( any WebSocket ) ?
35
41
36
42
let url : URL
37
43
let options : RealtimeClientOptions
38
- let ws : any WebSocketClient
44
+ let wsFactory : WebSocketFactory
39
45
let http : any HTTPClientType
40
46
let apikey : String ?
41
47
@@ -75,13 +81,11 @@ public final class RealtimeClientV2 {
75
81
self . init (
76
82
url: url,
77
83
options: options,
78
- ws: WebSocket (
79
- realtimeURL: Self . realtimeWebSocketURL (
80
- baseURL: Self . realtimeBaseURL ( url: url) ,
81
- apikey: options. apikey
82
- ) ,
83
- options: options
84
- ) ,
84
+ wsFactory: { url, headers in
85
+ let configuration = URLSessionConfiguration . default
86
+ configuration. httpAdditionalHeaders = headers
87
+ return try await URLSessionWebSocket . connect ( to: url, configuration: configuration)
88
+ } ,
85
89
http: HTTPClient (
86
90
fetch: options. fetch ?? { try await URLSession . shared. data ( for: $0) } ,
87
91
interceptors: interceptors
@@ -92,12 +96,12 @@ public final class RealtimeClientV2 {
92
96
init (
93
97
url: URL ,
94
98
options: RealtimeClientOptions ,
95
- ws : any WebSocketClient ,
99
+ wsFactory : @escaping WebSocketFactory ,
96
100
http: any HTTPClientType
97
101
) {
98
102
self . url = url
99
103
self . options = options
100
- self . ws = ws
104
+ self . wsFactory = wsFactory
101
105
self . http = http
102
106
apikey = options. apikey
103
107
@@ -110,7 +114,6 @@ public final class RealtimeClientV2 {
110
114
111
115
deinit {
112
116
heartbeatTask? . cancel ( )
113
- messageTask? . cancel ( )
114
117
channels = [ : ]
115
118
}
116
119
@@ -122,77 +125,67 @@ public final class RealtimeClientV2 {
122
125
}
123
126
124
127
func connect( reconnect: Bool ) async {
125
- if status == . disconnected {
126
- connectionTask = Task {
127
- if reconnect {
128
- try ? await Task . sleep ( nanoseconds: NSEC_PER_SEC * UInt64( options. reconnectDelay) )
128
+ connectionTask = Task {
129
+ if reconnect {
130
+ try ? await Task . sleep ( nanoseconds: NSEC_PER_SEC * UInt64( options. reconnectDelay) )
129
131
130
- if Task . isCancelled {
131
- options. logger? . debug ( " Reconnect cancelled, returning " )
132
- return
133
- }
134
- }
132
+ if Task . isCancelled {
133
+ options. logger? . debug ( " Reconnect cancelled, returning " )
135
134
136
- if status == . connected {
137
- options. logger? . debug ( " WebsSocket already connected " )
138
135
return
139
136
}
137
+ }
140
138
141
- status = . connecting
142
-
143
- for await connectionStatus in ws. connect ( ) {
144
- if Task . isCancelled {
145
- break
146
- }
147
-
148
- switch connectionStatus {
149
- case . connected:
150
- await onConnected ( reconnect: reconnect)
139
+ if status == . connected {
140
+ // websocket connected while it was waiting for a reconnection.
141
+ options. logger? . debug ( " WebsSocket already connected " )
142
+ return
143
+ }
151
144
152
- case . disconnected:
153
- await onDisconnected ( )
145
+ status = . connecting
154
146
155
- case let . error( error) :
156
- await onError ( error)
147
+ do {
148
+ let ws = try await wsFactory (
149
+ Self . realtimeWebSocketURL (
150
+ baseURL: Self . realtimeBaseURL ( url: url) ,
151
+ apikey: options. apikey
152
+ ) ,
153
+ options. headers. dictionary
154
+ )
155
+ self . ws = ws
156
+ status = . connected
157
+ startHeartbeating ( )
158
+ if reconnect {
159
+ rejoinChannels ( )
160
+ }
161
+ flushSendBuffer ( )
162
+
163
+ for await event in ws. events {
164
+ if Task . isCancelled { break }
165
+
166
+ switch event {
167
+ case let . text( text) :
168
+ await onMessage ( Data ( text. utf8) )
169
+ case let . binary( data) :
170
+ await onMessage ( data)
171
+ case let . close( code, reason) :
172
+ options. logger? . verbose ( " connection closed code \( code ?? 0 ) , reason \( reason) " )
157
173
}
158
174
}
175
+ } catch {
176
+ options. logger?
177
+ . debug (
178
+ " WebSocket error \( error. localizedDescription) . Trying again in \( options. reconnectDelay) "
179
+ )
180
+ Task {
181
+ self . disconnect ( )
182
+ await self . connect ( reconnect: true )
183
+ }
159
184
}
160
185
}
161
186
162
187
_ = await statusChange. first { @Sendable in $0 == . connected }
163
- }
164
-
165
- private func onConnected( reconnect: Bool ) async {
166
- status = . connected
167
- options. logger? . debug ( " Connected to realtime WebSocket " )
168
- listenForMessages ( )
169
- startHeartbeating ( )
170
- if reconnect {
171
- await rejoinChannels ( )
172
- }
173
188
174
- await flushSendBuffer ( )
175
- }
176
-
177
- private func onDisconnected( ) async {
178
- options. logger?
179
- . debug (
180
- " WebSocket disconnected. Trying again in \( options. reconnectDelay) "
181
- )
182
- await reconnect ( )
183
- }
184
-
185
- private func onError( _ error: ( any Error ) ? ) async {
186
- options. logger?
187
- . debug (
188
- " WebSocket error \( error? . localizedDescription ?? " <none> " ) . Trying again in \( options. reconnectDelay) "
189
- )
190
- await reconnect ( )
191
- }
192
-
193
- private func reconnect( ) async {
194
- disconnect ( )
195
- await connect ( reconnect: true )
196
189
}
197
190
198
191
/// Creates a new channel and bind it to this client.
@@ -252,35 +245,10 @@ public final class RealtimeClientV2 {
252
245
}
253
246
}
254
247
255
- private func rejoinChannels( ) async {
256
- await withTaskGroup ( of: Void . self) { group in
257
- for channel in channels. values {
258
- group. addTask {
259
- await channel. subscribe ( )
260
- }
261
- }
262
-
263
- await group. waitForAll ( )
264
- }
265
- }
266
-
267
- private func listenForMessages( ) {
268
- messageTask = Task { [ weak self] in
269
- guard let self else { return }
270
-
271
- do {
272
- for try await message in ws. receive ( ) {
273
- if Task . isCancelled {
274
- return
275
- }
276
-
277
- await onMessage ( message)
278
- }
279
- } catch {
280
- options. logger? . debug (
281
- " Error while listening for messages. Trying again in \( options. reconnectDelay) \( error) "
282
- )
283
- await reconnect ( )
248
+ private func rejoinChannels( ) {
249
+ for channel in channels. values {
250
+ Task {
251
+ await channel. subscribe ( )
284
252
}
285
253
}
286
254
}
@@ -300,8 +268,11 @@ public final class RealtimeClientV2 {
300
268
private func sendHeartbeat( ) async {
301
269
if pendingHeartbeatRef != nil {
302
270
pendingHeartbeatRef = nil
303
- options. logger? . debug ( " Heartbeat timeout " )
304
- await reconnect ( )
271
+ options. logger? . debug ( " Heartbeat timeout, trying to reconnect in \( options. reconnectDelay) s " )
272
+ Task {
273
+ disconnect ( )
274
+ await connect ( reconnect: true )
275
+ }
305
276
} else {
306
277
let ref = makeRef ( )
307
278
pendingHeartbeatRef = ref
@@ -331,10 +302,9 @@ public final class RealtimeClientV2 {
331
302
public func disconnect( code: Int ? = nil , reason: String ? = nil ) {
332
303
options. logger? . debug ( " Closing WebSocket connection " )
333
304
ref = 0
334
- messageTask? . cancel ( )
335
305
heartbeatTask? . cancel ( )
336
306
connectionTask? . cancel ( )
337
- ws. disconnect ( code: code, reason: reason)
307
+ ws? . close ( code: code, reason: reason)
338
308
status = . disconnected
339
309
}
340
310
@@ -354,6 +324,14 @@ public final class RealtimeClientV2 {
354
324
}
355
325
}
356
326
327
+ private func onMessage( _ data: Data ) async {
328
+ guard let message = try ? JSONDecoder ( ) . decode ( RealtimeMessageV2 . self, from: data) else {
329
+ return
330
+ }
331
+
332
+ await onMessage ( message)
333
+ }
334
+
357
335
private func onMessage( _ message: RealtimeMessageV2 ) async {
358
336
let channel = channels [ message. topic]
359
337
@@ -376,11 +354,11 @@ public final class RealtimeClientV2 {
376
354
///
377
355
/// If the socket is not connected, the message gets enqueued within a local buffer, and sent out when a connection is next established.
378
356
public func push( _ message: RealtimeMessageV2 ) async {
379
- let callback = { @ Sendable [ weak self] in
357
+ let callback = { [ weak self] in
380
358
do {
381
359
// Check cancellation before sending, because this push may have been cancelled before a connection was established.
382
360
try Task . checkCancellation ( )
383
- try await self ? . ws. send ( message)
361
+ try self ? . ws? . send ( binary : JSONEncoder ( ) . encode ( message) )
384
362
} catch {
385
363
self ? . options. logger? . error (
386
364
"""
@@ -394,19 +372,15 @@ public final class RealtimeClientV2 {
394
372
}
395
373
396
374
if status == . connected {
397
- await callback ( )
375
+ callback ( )
398
376
} else {
399
377
sendBuffer. append ( callback)
400
378
}
401
379
}
402
380
403
- private func flushSendBuffer( ) async {
404
- let sendBuffer = self . sendBuffer
405
- self . sendBuffer = [ ]
406
-
407
- for send in sendBuffer {
408
- await send ( )
409
- }
381
+ private func flushSendBuffer( ) {
382
+ sendBuffer. forEach { $0 ( ) }
383
+ sendBuffer = [ ]
410
384
}
411
385
412
386
func makeRef( ) -> Int {
0 commit comments