From ba4d595f22c2b0316b63760822c7d72ee42cd942 Mon Sep 17 00:00:00 2001 From: Matt Gibson Date: Thu, 8 Aug 2024 13:41:37 -0700 Subject: [PATCH] Continue integration tests TODO: rework reconnect subscription stuff. --- spec/integration.spec.ts | 283 +++++++++++++++++++++++++++++-- spec/test-websocket-server.ts | 49 +++++- src/messages/message-mediator.ts | 9 +- src/push-manager.ts | 30 +++- src/push-subscription.ts | 3 +- 5 files changed, 343 insertions(+), 31 deletions(-) diff --git a/spec/integration.spec.ts b/spec/integration.spec.ts index 7668b3a..ac27f89 100644 --- a/spec/integration.spec.ts +++ b/spec/integration.spec.ts @@ -2,6 +2,7 @@ import * as crypto from "crypto"; import { createPushManager } from "../src"; import { deriveKeyAndNonce, generateEcKeys, randomBytes } from "../src/crypto"; +import { ClientAck, ClientAckCodes } from "../src/messages/message"; import { PushManager } from "../src/push-manager"; import { GenericPushSubscription } from "../src/push-subscription"; import { @@ -10,10 +11,15 @@ import { fromUtf8ToBuffer, } from "../src/string-manipulation"; -import { applicationPublicKey } from "./constants"; +import { + applicationPrivateKey, + applicationPublicKey, + applicationPublicKeyX, + applicationPublicKeyY, +} from "./constants"; import { TestLogger } from "./test-logger"; import { TestBackingStore } from "./test-storage"; -import { defaultUaid, TestWebSocketServer } from "./test-websocket-server"; +import { defaultUaid, helloHandlerWithUaid, TestWebSocketServer } from "./test-websocket-server"; const port = 1234; const url = "ws://localhost:" + port; @@ -39,10 +45,110 @@ describe("end to end", () => { afterEach(async () => { await pushManager?.destroy(); + // ensure the server is using the default handlers + server.useDefaultHandlers(); // ensure we don't leak connections between tests server.closeClients(); }); + describe("reconnection", () => { + beforeEach(async () => { + pushManager = await createPushManager(storage, logger, { + autopushUrl: url, + // Set reconnect to occur after 10ms + reconnectDelay: () => new Promise((resolve) => setTimeout(resolve, 10)), + }); + }); + + async function closeWebSocket() { + const client = server.clients[0]; + client.ws.close(); + await new Promise((resolve) => { + client.ws.on("close", resolve); + }); + return client; + } + + it("reconnects when the connection is closed", async () => { + const previousClient = await closeWebSocket(); + + // TODO: better await for reconnect + await new Promise((resolve) => { + setTimeout(() => { + resolve(); + }, 1000); + }); + + expect(server.clients).toHaveLength(1); + expect(server.clients[0]).not.toBe(previousClient); + }); + + it("maintains event subscriptions after reconnect", async () => { + const sub = await pushManager.subscribe({ + userVisibleOnly: true, + applicationServerKey: applicationPublicKey, + }); + + const notificationSpy = jest.fn(); + const notificationPromise = new Promise((resolve, reject) => { + sub.addEventListener("notification", (d) => { + notificationSpy(d); + resolve(); + }); + setTimeout(reject, 1000); + }); + + await closeWebSocket(); + + // TODO: better await for reconnect + await new Promise((resolve) => { + setTimeout(() => { + resolve(); + }, 500); + }); + + server.sendNotification(sub.channelID); + + await notificationPromise; + expect(notificationSpy).toHaveBeenCalled(); + }); + + it("maintains event subscriptions after reconnect and a new uaid", async () => { + const newUaid = "new-uaid"; + server.helloHandler = helloHandlerWithUaid(newUaid); + + const sub = await pushManager.subscribe({ + userVisibleOnly: true, + applicationServerKey: applicationPublicKey, + }); + + const notificationSpy = jest.fn(); + const notificationPromise = new Promise((resolve, reject) => { + sub.addEventListener("notification", (d) => { + notificationSpy(d); + resolve(); + }); + setTimeout(reject, 1000); + }); + + await closeWebSocket(); + + // TODO: better await for reconnect + await new Promise((resolve) => { + setTimeout(() => { + resolve(); + }, 500); + }); + + expect(server.identifiedClients[0].uaid).toEqual(newUaid); + + server.sendNotification(sub.channelID); + + await notificationPromise; + expect(notificationSpy).toHaveBeenCalled(); + }); + }); + describe("Hello", () => { it("connects to the server", async () => { pushManager = await createPushManager(storage, logger, { autopushUrl: url }); @@ -84,21 +190,78 @@ describe("end to end", () => { use_webpush: true, }); }); + + describe("existing subscriptions", () => { + beforeEach(async () => { + // Set up existing storage + await storage.write("channelIDs", JSON.stringify(["f2ca74ee-d688-4cb2-8ae1-9deb4805be29"])); + await storage.write( + "f2ca74ee-d688-4cb2-8ae1-9deb4805be29:endpoint", + JSON.stringify("https://example.com/push//f2ca74ee-d688-4cb2-8ae1-9deb4805be29"), + ); + await storage.write( + "f2ca74ee-d688-4cb2-8ae1-9deb4805be29:options", + JSON.stringify({ + userVisibleOnly: true, + applicationServerKey: applicationPublicKey, + }), + ); + await storage.write( + "f2ca74ee-d688-4cb2-8ae1-9deb4805be29:auth", + JSON.stringify("kKZ96yjFVbvnUa458DDWNg"), + ); + await storage.write( + "f2ca74ee-d688-4cb2-8ae1-9deb4805be29:privateEncKey", + JSON.stringify({ + key_ops: ["deriveKey", "deriveBits"], + ext: true, + kty: "EC", + x: applicationPublicKeyX, + y: applicationPublicKeyY, + crv: "P-256", + d: applicationPrivateKey, + }), + ); + }); + + it("reconnects existing channels", async () => { + // Same Uaid as response + await storage.write("uaid", JSON.stringify(defaultUaid)); + + pushManager = await createPushManager(storage, logger, { autopushUrl: url }); + const client = server.clients[0]; + expect(client).toHaveReceived({ + messageType: "hello", + uaid: "5f0774ac-09a3-45d9-91e4-f4aaebaeec72", + channelIDs: ["f2ca74ee-d688-4cb2-8ae1-9deb4805be29"], + use_webpush: true, + }); + }); + }); }); describe("Notification", () => { - it("sends a notification", async () => { - pushManager = await createPushManager(storage, logger, { autopushUrl: url }); - const sub = await pushManager.subscribe({ + let sub: GenericPushSubscription; + + beforeEach(async () => { + pushManager = await createPushManager(storage, logger, { + autopushUrl: url, + ackIntervalMs: 100, + }); + sub = await pushManager.subscribe({ userVisibleOnly: true, applicationServerKey: applicationPublicKey, }); + }); + + it("sends a notification", async () => { const notifiedSpy = jest.fn(); - const notifiedCalled = new Promise((resolve) => { + const notifiedCalled = new Promise((resolve, reject) => { sub.addEventListener("notification", (data) => { notifiedSpy(data); resolve(); }); + setTimeout(() => reject(), 1000); }); server.sendNotification(sub.channelID); @@ -108,11 +271,6 @@ describe("end to end", () => { }); it("sends a notification message", async () => { - pushManager = await createPushManager(storage, logger, { autopushUrl: url }); - const sub = await pushManager.subscribe({ - userVisibleOnly: true, - applicationServerKey: applicationPublicKey, - }); const notifiedSpy = jest.fn(); const notifiedCalled = new Promise((resolve, reject) => { sub.addEventListener("notification", (data) => { @@ -135,6 +293,109 @@ describe("end to end", () => { expect(notifiedSpy).toHaveBeenCalledWith("some data"); }); + + it("sends acks when notifications are received", async () => { + const ackPromise = new Promise((resolve, reject) => { + server.ackHandler = () => resolve(); + setTimeout(() => reject(), 1000); + }); + + const version = server.sendNotification(sub.channelID); + + const expectedAck: ClientAck = { + messageType: "ack", + updates: [{ channelID: sub.channelID, version, code: ClientAckCodes.SUCCESS }], + }; + + await expect(ackPromise).resolves.toBeUndefined(); + + const client = server.identifiedClientFor(sub.channelID); + if (!client) { + fail("Client not found"); + } + expect(client).toHaveReceived(expectedAck); + }); + + it("acks decryption errors", async () => { + const ackPromise = new Promise((resolve, reject) => { + server.ackHandler = () => resolve(); + setTimeout(() => reject(), 1000); + }); + + const version = server.sendNotification(sub.channelID, "This should have been encrypted", { + encoding: "aes128gcm", + }); + + const expectedAck: ClientAck = { + messageType: "ack", + updates: [{ channelID: sub.channelID, version, code: ClientAckCodes.DECRYPT_FAIL }], + }; + + await expect(ackPromise).resolves.toBeUndefined(); + + const client = server.identifiedClientFor(sub.channelID); + if (!client) { + fail("Client not found"); + } + expect(client).toHaveReceived(expectedAck); + }); + + it("groups acks together", async () => { + const ackPromise = new Promise((resolve, reject) => { + server.ackHandler = () => resolve(); + setTimeout(() => reject(), 1000); + }); + + const version1 = server.sendNotification(sub.channelID); + const version2 = server.sendNotification(sub.channelID); + + const expectedAck: ClientAck = { + messageType: "ack", + updates: [ + { channelID: sub.channelID, version: version1, code: ClientAckCodes.SUCCESS }, + { channelID: sub.channelID, version: version2, code: ClientAckCodes.SUCCESS }, + ], + }; + + await expect(ackPromise).resolves.toBeUndefined(); + + const client = server.identifiedClientFor(sub.channelID); + if (!client) { + fail("Client not found"); + } + expect(client).toHaveReceived(expectedAck); + }); + + it("groups acks togher across subscriptions", async () => { + const sub2 = await pushManager.subscribe({ + userVisibleOnly: true, + applicationServerKey: applicationPublicKey, + }); + + const ackPromise = new Promise((resolve, reject) => { + server.ackHandler = () => resolve(); + setTimeout(() => reject(), 1000); + }); + + const version1 = server.sendNotification(sub.channelID); + const version2 = server.sendNotification(sub2.channelID); + + const expectedAck: ClientAck = { + messageType: "ack", + updates: [ + { channelID: sub.channelID, version: version1, code: ClientAckCodes.SUCCESS }, + { channelID: sub2.channelID, version: version2, code: ClientAckCodes.SUCCESS }, + ], + }; + + await expect(ackPromise).resolves.toBeUndefined(); + + const client = server.identifiedClientFor(sub.channelID); + if (!client) { + fail("Client not found"); + } + expect(client).toHaveReceived(expectedAck); + }); }); }); diff --git a/spec/test-websocket-server.ts b/spec/test-websocket-server.ts index 76fb668..4980ccf 100644 --- a/spec/test-websocket-server.ts +++ b/spec/test-websocket-server.ts @@ -2,6 +2,7 @@ import { WebSocketServer, WebSocket } from "ws"; import { AutoConnectClientMessage, + ClientAck, ClientHello, ClientRegister, ServerHello, @@ -12,13 +13,19 @@ import { import { fromBufferToUtf8, newUuid, Uuid } from "../src/string-manipulation"; export const defaultUaid = "5f0774ac-09a3-45d9-91e4-f4aaebaeec72"; -const defaultHelloHandler = ( +export const helloHandlerWithUaid = + (uaid: string) => + (client: TestWebSocketClient, message: ClientHello, server: TestWebSocketServer) => + helloHandler(client, message, server, uaid); +const defaultHelloHandler = helloHandlerWithUaid(defaultUaid); +const helloHandler = ( client: TestWebSocketClient, message: ClientHello, server: TestWebSocketServer, + uaidToAssign: string, ) => { // Identify the client - const identifiedClient = new IdentifiedWebSocketClient(client, defaultUaid); + const identifiedClient = new IdentifiedWebSocketClient(client, uaidToAssign); server.identifiedClients.push(identifiedClient); // Make sure we track channels for this client @@ -29,7 +36,7 @@ const defaultHelloHandler = ( // Send a response const response: ServerHello = { messageType: "hello", - uaid: defaultUaid, + uaid: uaidToAssign, useWebPush: true, status: 200, // broadcasts: {}, @@ -70,7 +77,15 @@ const defaultUnregisterHandler = ( client.ws.send(JSON.stringify(response)); }; -const defaultServerPingHandler = (client: IdentifiedWebSocketClient) => { +const defaultAckHandler = (_client: IdentifiedWebSocketClient, _message: ClientAck) => { + // Do nothing +}; + +const defaultNackHandler = (_client: IdentifiedWebSocketClient, _message: ClientAck) => { + // Do nothing +}; + +const defaultPingHandler = (client: IdentifiedWebSocketClient) => { const response: ServerPing = { messageType: "ping", }; @@ -83,12 +98,15 @@ export class TestWebSocketServer { readonly clients: TestWebSocketClient[] = []; readonly identifiedClients: IdentifiedWebSocketClient[] = []; - helloHandler = defaultHelloHandler; - registerHandler = defaultRegisterHandler; - unregisterHandler = defaultUnregisterHandler; - serverPingHandler = defaultServerPingHandler; + helloHandler!: typeof defaultHelloHandler; + registerHandler!: typeof defaultRegisterHandler; + unregisterHandler!: typeof defaultUnregisterHandler; + ackHandler!: typeof defaultAckHandler; + nackHandler!: typeof defaultNackHandler; + serverPingHandler!: typeof defaultPingHandler; constructor(port: number) { + this.useDefaultHandlers(); this.server = new WebSocketServer({ port }); this.channelToClientMap = new Map(); @@ -122,6 +140,15 @@ export class TestWebSocketServer { }); } + useDefaultHandlers() { + this.helloHandler = defaultHelloHandler; + this.registerHandler = defaultRegisterHandler; + this.unregisterHandler = defaultUnregisterHandler; + this.ackHandler = defaultAckHandler; + this.nackHandler = defaultNackHandler; + this.serverPingHandler = defaultPingHandler; + } + closeClients() { for (const client of this.clients) { client.ws.close(1001, "Server closing"); @@ -224,6 +251,12 @@ export class TestWebSocketServer { case "unregister": this.unregisterHandler(client, message as ClientRegister, this); break; + case "ack": + this.ackHandler(client, message as ClientAck); + break; + case "nack": + this.nackHandler(client, message as ClientAck); + break; case "ping": this.serverPingHandler(client); break; diff --git a/src/messages/message-mediator.ts b/src/messages/message-mediator.ts index 5d632c9..4db0ca3 100644 --- a/src/messages/message-mediator.ts +++ b/src/messages/message-mediator.ts @@ -21,18 +21,18 @@ import { PingSender } from "./senders/ping-sender"; import { RegisterSender } from "./senders/register-sender"; import { UnregisterSender } from "./senders/unregister-sender"; -const ACK_INTERVAL = 30_000; // 30 seconds export class MessageMediator { private handlers: MessageHandler[]; // eslint-disable-next-line @typescript-eslint/no-explicit-any -- TODO: get rid of this any private senders: MessageSender[]; private ackInterval: NodeJS.Timeout | null = null; - private ackQueue: ClientMessageAck[] = []; + private readonly ackQueue: ClientMessageAck[] = []; private ackSender: AckSender; constructor( readonly pushManager: PushManager, readonly subscriptionHandler: SubscriptionHandler, + options: { ackIntervalMs: number }, private readonly logger: Logger, ) { this.handlers = [ @@ -54,7 +54,7 @@ export class MessageMediator { // Ack is separate because acks are grouped to reduce server load this.ackSender = new AckSender(new NamespacedLogger(logger, "AckSender")); - this.ackInterval = setInterval(() => this.sendAck(), ACK_INTERVAL); + this.ackInterval = setInterval(() => this.sendAck(), options.ackIntervalMs); } destroy() { @@ -136,8 +136,7 @@ export class MessageMediator { return; } - const updates = this.ackQueue.slice(); - this.ackQueue = []; + const updates = this.ackQueue.splice(0, this.ackQueue.length); const message = await this.ackSender.buildMessage({ updates }); this.pushManager.websocket.send(JSON.stringify(message)); diff --git a/src/push-manager.ts b/src/push-manager.ts index deb04eb..d3234c3 100644 --- a/src/push-manager.ts +++ b/src/push-manager.ts @@ -23,13 +23,29 @@ export interface PublicPushManager { } type PushManagerOptions = { - autopushUrl: string; + /** The Url to connect to. Defaults to `wss://push.services.mozilla.com` */ + autopushUrl?: string; + /** The interval between ACK messages. Defaults to 30 seconds (30000) */ + ackIntervalMs?: number; + /** A method which is awaited prior to reconnecting, should the websocket be disconnected. Defaults to a constant timeout of 1 second */ + reconnectDelay?: () => Promise; }; -const defaultPushManagerOptions: PushManagerOptions = Object.freeze({ +const defaultPushManagerOptions: Required = Object.freeze({ autopushUrl: "wss://push.services.mozilla.com", + ackIntervalMs: 30_000, // 30 seconds + reconnectDelay: async () => { + await new Promise((resolve) => setTimeout(resolve, 1000)); + }, }); +function populateOptions(userOptions: PushManagerOptions): Required { + return { + ...defaultPushManagerOptions, + ...userOptions, + }; +} + export class PushManager implements PublicPushManager { private _uaid: string | null = null; private _websocket: WebSocket | null = null; @@ -41,7 +57,7 @@ export class PushManager implements PublicPushManager { private constructor( private readonly storage: Storage, private readonly logger: Logger, - private readonly options: PushManagerOptions, + private readonly options: Required, ) {} get uaid() { @@ -112,15 +128,16 @@ export class PushManager implements PublicPushManager { externalLogger: Logger, options: PushManagerOptions = defaultPushManagerOptions, ) { + const finalOptions = populateOptions(options); const storage = new Storage(externalStorage); const logger = new TimedLogger(externalLogger); - const manager = new PushManager(storage, logger, options); + const manager = new PushManager(storage, logger, finalOptions); const subscriptionHandler = await SubscriptionHandler.create( storage, (channelID: Uuid) => manager.unsubscribe(channelID), new NamespacedLogger(logger, "SubscriptionHandler"), ); - const mediator = new MessageMediator(manager, subscriptionHandler, logger); + const mediator = new MessageMediator(manager, subscriptionHandler, finalOptions, logger); // Assign the circular dependencies manager.mediator = mediator; @@ -140,6 +157,7 @@ export class PushManager implements PublicPushManager { async destroy() { this.reconnect = false; this._websocket?.close(); + this.mediator.destroy(); } private async connect() { @@ -185,7 +203,7 @@ export class PushManager implements PublicPushManager { // TODO: implement a backoff strategy if (this.reconnect) { - setTimeout(() => this.connect(), 1000); + void this.options.reconnectDelay().then(() => this.connect()); // await this.connect(); } }; diff --git a/src/push-subscription.ts b/src/push-subscription.ts index 5d4a0ea..d829340 100644 --- a/src/push-subscription.ts +++ b/src/push-subscription.ts @@ -135,7 +135,8 @@ export class PushSubscription< if ( !message.headers || (message.headers["encoding"] !== "aes128gcm" && - message.headers["Content-Encoding"] !== "aes128gcm") + message.headers["Content-Encoding"] !== "aes128gcm" && + message.headers["content-encoding"] !== "aes128gcm") ) { this.logger.error("Unsupported encoding", message); throw ClientAckCodes.DECRYPT_FAIL;