Skip to content

Commit

Permalink
Merge pull request #3 from claucambra/bugfix/improved-websocket-handling
Browse files Browse the repository at this point in the history
Bugfix/improved websocket handling
  • Loading branch information
claucambra committed Jul 9, 2024
2 parents d63580c + 5902576 commit b0eddd1
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class RemoteChangeObserver: NSObject, NKCommonDelegate, URLSessionWebSock
private var webSocketUrlSession: URLSession?
private var webSocketTask: URLSessionWebSocketTask?
private var webSocketOperationQueue = OperationQueue()
private var webSocketPingTask: Task<(), Never>?
private(set) var webSocketPingFailCount = 0
private(set) var webSocketAuthenticationFailCount = 0

Expand Down Expand Up @@ -117,6 +118,8 @@ public class RemoteChangeObserver: NSObject, NKCommonDelegate, URLSessionWebSock
webSocketTask = nil
webSocketOperationQueue.cancelAllOperations()
webSocketOperationQueue.isSuspended = true
webSocketPingTask?.cancel()
webSocketPingTask = nil
webSocketPingFailCount = 0
}

Expand Down Expand Up @@ -250,6 +253,29 @@ public class RemoteChangeObserver: NSObject, NKCommonDelegate, URLSessionWebSock
readWebSocket()
}

private func startNewWebSocketPingTask() {
guard !Task.isCancelled else { return }

if let webSocketPingTask, !webSocketPingTask.isCancelled {
webSocketPingTask.cancel()
}

webSocketPingTask = Task.detached(priority: .background) {
do {
try await Task.sleep(nanoseconds: self.webSocketPingIntervalNanoseconds)
} catch let error {
self.logger.error(
"""
Could not sleep websocket ping for \(self.accountId, privacy: .public):
\(error.localizedDescription, privacy: .public)
"""
)
}
guard !Task.isCancelled else { return }
self.pingWebSocket()
}
}

private func pingWebSocket() { // Keep the socket connection alive
guard networkReachability != .notReachable else {
logger.error("Not pinging \(self.accountId, privacy: .public), network is unreachable")
Expand All @@ -268,24 +294,12 @@ public class RemoteChangeObserver: NSObject, NKCommonDelegate, URLSessionWebSock
if self.webSocketPingFailCount > self.webSocketPingFailLimit {
Task.detached(priority: .medium) { self.reconnectWebSocket() }
} else {
Task.detached(priority: .background) { self.pingWebSocket() }
self.startNewWebSocketPingTask()
}
return
}

Task.detached(priority: .background) {
do {
try await Task.sleep(nanoseconds: self.webSocketPingIntervalNanoseconds)
} catch let error {
self.logger.error(
"""
Could not sleep websocket ping for \(self.accountId, privacy: .public):
\(error.localizedDescription, privacy: .public)
"""
)
}
self.pingWebSocket()
}
self.startNewWebSocketPingTask()
}
}

Expand Down Expand Up @@ -331,7 +345,7 @@ public class RemoteChangeObserver: NSObject, NKCommonDelegate, URLSessionWebSock
NotificationCenter.default.post(
name: NotifyPushAuthenticatedNotificationName, object: self
)
pingWebSocket()
startNewWebSocketPingTask()
} else if string == "err: Invalid credentials" {
logger.debug(
"""
Expand Down
4 changes: 4 additions & 0 deletions Tests/Interface/MockNotifyPushServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class MockNotifyPushServer {
private var connectedClients: [NIOAsyncChannel<WebSocketFrame, WebSocketFrame>] = []
public var delay: Int?
public var refuse = false
public var pingHandler: (() -> Void)?

enum UpgradeResult {
case websocket(NIOAsyncChannel<WebSocketFrame, WebSocketFrame>)
Expand Down Expand Up @@ -56,6 +57,7 @@ public class MockNotifyPushServer {
self.delay = nil
self.refuse = false
self.connectedClients = []
self.pingHandler = nil
}

/// This method starts the server and handles incoming connections.
Expand Down Expand Up @@ -151,6 +153,8 @@ public class MockNotifyPushServer {
switch frame.opcode {
case .ping:
print("Received ping")
self.pingHandler?()

var frameData = frame.data
let maskingKey = frame.maskKey

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,46 @@ final class RemoteChangeObserverTests: XCTestCase {
}
XCTAssertTrue(authenticated)
}

func testPinging() async throws {
let remoteInterface = MockRemoteInterface(account: Self.account)
remoteInterface.capabilities = mockCapabilities

var authenticated = false

NotificationCenter.default.addObserver(
forName: NotifyPushAuthenticatedNotificationName, object: nil, queue: nil
) { _ in
authenticated = true
}

remoteChangeObserver = RemoteChangeObserver(
remoteInterface: remoteInterface,
changeNotificationInterface: MockChangeNotificationInterface(),
domain: nil
)

let pingIntervalNsecs = 500_000_000
remoteChangeObserver?.webSocketPingIntervalNanoseconds = UInt64(pingIntervalNsecs)

for _ in 0...Self.timeout {
try await Task.sleep(nanoseconds: 1_000_000)
if authenticated {
break
}
}
XCTAssertTrue(authenticated)

let intendedPings = 3
// Add a bit of buffer to the wait time
let intendedPingsWait = (intendedPings + 1) * pingIntervalNsecs

var pings = 0
Self.notifyPushServer.pingHandler = {
pings += 1
}

try await Task.sleep(nanoseconds: UInt64(intendedPingsWait))
XCTAssertEqual(pings, intendedPings)
}
}

0 comments on commit b0eddd1

Please sign in to comment.