From a48b85577c94e768e94f23a5767f6e73fbfa4b9a Mon Sep 17 00:00:00 2001 From: Matt Gibson Date: Thu, 8 Aug 2024 10:36:51 -0700 Subject: [PATCH] Start integration testing --- .cspell.json | 11 +- spec/integration.spec.ts | 174 +++++++++++++++++ spec/matchers/index.ts | 9 + spec/matchers/web-socket-received.ts | 97 ++++++++++ spec/test-websocket-server.ts | 268 +++++++++++++++++++++++++++ src/push-manager.ts | 19 +- 6 files changed, 573 insertions(+), 5 deletions(-) create mode 100644 spec/integration.spec.ts create mode 100644 spec/matchers/web-socket-received.ts create mode 100644 spec/test-websocket-server.ts diff --git a/.cspell.json b/.cspell.json index b4e915f..38a8874 100644 --- a/.cspell.json +++ b/.cspell.json @@ -2,8 +2,15 @@ "version": "0.2", "language": "en", // words - list of words to be always considered correct - "words": ["Csprng", "Uaid", "unregistering"], + "words": [ + "ACK'd", + "Csprng", + "Uaid", + "unregistering" + ], // flagWords - list of words to be always considered incorrect // This is useful for offensive words and common spelling errors. - "flagWords": ["channelId"] + "flagWords": [ + "channelId" + ] } diff --git a/spec/integration.spec.ts b/spec/integration.spec.ts new file mode 100644 index 0000000..7668b3a --- /dev/null +++ b/spec/integration.spec.ts @@ -0,0 +1,174 @@ +import * as crypto from "crypto"; + +import { createPushManager } from "../src"; +import { deriveKeyAndNonce, generateEcKeys, randomBytes } from "../src/crypto"; +import { PushManager } from "../src/push-manager"; +import { GenericPushSubscription } from "../src/push-subscription"; +import { + fromBufferToUrlB64, + fromUrlB64ToBuffer, + fromUtf8ToBuffer, +} from "../src/string-manipulation"; + +import { applicationPublicKey } from "./constants"; +import { TestLogger } from "./test-logger"; +import { TestBackingStore } from "./test-storage"; +import { defaultUaid, TestWebSocketServer } from "./test-websocket-server"; + +const port = 1234; +const url = "ws://localhost:" + port; + +describe("end to end", () => { + let storage: TestBackingStore; + let logger: TestLogger; + let server: TestWebSocketServer; + let pushManager: PushManager; + + beforeAll(() => { + server = new TestWebSocketServer(port); + }); + + afterAll(async () => { + await server.close(); + }); + + beforeEach(() => { + storage = new TestBackingStore(); + logger = new TestLogger(); + }); + + afterEach(async () => { + await pushManager?.destroy(); + // ensure we don't leak connections between tests + server.closeClients(); + }); + + describe("Hello", () => { + it("connects to the server", async () => { + pushManager = await createPushManager(storage, logger, { autopushUrl: url }); + expect(server.clients).toHaveLength(1); + }); + + it("immediately sends a hello message", async () => { + pushManager = await createPushManager(storage, logger, { autopushUrl: url }); + const client = server.clients[0]; + expect(client).toHaveReceived(expect.objectContaining({ messageType: "hello" })); + }); + + it("sends a hello message with the correct uaid", async () => { + await storage.write("uaid", JSON.stringify("test-uaid")); + pushManager = await createPushManager(storage, logger, { autopushUrl: url }); + const client = server.clients[0]; + expect(client).toHaveReceived({ + messageType: "hello", + uaid: "test-uaid", + channelIDs: [], + use_webpush: true, + }); + }); + + it("records the correct uaid to storage", async () => { + pushManager = await createPushManager(storage, logger, { autopushUrl: url }); + expect(storage.mock.write).toHaveBeenCalledWith("uaid", JSON.stringify(defaultUaid)); + // await expect(storage.read("uaid")).resolves.toEqual(defaultUaid); + }); + + it("updates uaid in storage when a new one is received", async () => { + await storage.write("uaid", JSON.stringify("test-uaid")); + pushManager = await createPushManager(storage, logger, { autopushUrl: url }); + const client = server.clients[0]; + expect(client).toHaveReceived({ + messageType: "hello", + uaid: "test-uaid", + channelIDs: [], + use_webpush: true, + }); + }); + }); + + describe("Notification", () => { + it("sends a notification", async () => { + pushManager = await createPushManager(storage, logger, { autopushUrl: url }); + const sub = await pushManager.subscribe({ + userVisibleOnly: true, + applicationServerKey: applicationPublicKey, + }); + const notifiedSpy = jest.fn(); + const notifiedCalled = new Promise((resolve) => { + sub.addEventListener("notification", (data) => { + notifiedSpy(data); + resolve(); + }); + }); + + server.sendNotification(sub.channelID); + await notifiedCalled; + + expect(notifiedSpy).toHaveBeenCalledWith(null); + }); + + it("sends a notification message", async () => { + pushManager = await createPushManager(storage, logger, { autopushUrl: url }); + const sub = await pushManager.subscribe({ + userVisibleOnly: true, + applicationServerKey: applicationPublicKey, + }); + const notifiedSpy = jest.fn(); + const notifiedCalled = new Promise((resolve, reject) => { + sub.addEventListener("notification", (data) => { + notifiedSpy(data); + resolve(); + }); + setTimeout(() => reject(), 1000); + }); + + const data = "some data"; + const encrypted = await aes128GcmEncrypt(data, sub); + + server.sendNotification(sub.channelID, encrypted, { encoding: "aes128gcm" }); + const client = server.identifiedClientFor(sub.channelID); + if (!client) { + fail("Client not found"); + } + + await notifiedCalled; + + expect(notifiedSpy).toHaveBeenCalledWith("some data"); + }); + }); +}); + +const recordSize = new Uint8Array([0, 0, 4, 0]); +const keyLength = new Uint8Array([65]); +async function aes128GcmEncrypt(data: string, sub: GenericPushSubscription) { + const paddedData = Buffer.concat([fromUtf8ToBuffer(data), new Uint8Array([2, 0, 0, 0, 0])]); + const salt = await randomBytes(16); + const ecKeys = await generateEcKeys(); + const { contentEncryptionKey, nonce } = await deriveKeyAndNonce( + { + publicKey: sub.getKey("p256dhBuffer"), + }, + { + publicKey: ecKeys.uncompressedPublicKey, + privateKey: ecKeys.privateKey, + }, + fromUrlB64ToBuffer(sub.getKey("auth")), + salt, + ); + + const cryptoKey = crypto.createSecretKey(Buffer.from(contentEncryptionKey)); + const cipher = crypto.createCipheriv("aes-128-gcm", cryptoKey, Buffer.from(nonce)); + const encrypted = cipher.update(paddedData); + cipher.final(); + const authTag = cipher.getAuthTag(); + const result = Buffer.concat([ + salt, + recordSize, + keyLength, + new Uint8Array(ecKeys.uncompressedPublicKey), + encrypted, + authTag, + ]); + + return fromBufferToUrlB64(result); +} diff --git a/spec/matchers/index.ts b/spec/matchers/index.ts index 7f880c5..341f4a7 100644 --- a/spec/matchers/index.ts +++ b/spec/matchers/index.ts @@ -1,13 +1,22 @@ +import type { JsonObject, JsonValue } from "type-fest"; + import { toEqualBuffer } from "./to-equal-buffer"; +import { toHaveLastReceived, toHaveNthReceived, toHaveReceived } from "./web-socket-received"; export * from "./to-equal-buffer"; export function addCustomMatchers() { expect.extend({ toEqualBuffer: toEqualBuffer, + toHaveReceived: toHaveReceived, + toHaveLastReceived: toHaveLastReceived, + toHaveNthReceived: toHaveNthReceived, }); } export interface CustomMatchers { toEqualBuffer(expected: Uint8Array | ArrayBuffer | ArrayBufferLike): R; + toHaveReceived(expected: JsonObject | JsonValue): R; + toHaveLastReceived(expected: JsonObject | JsonValue): R; + toHaveNthReceived(expected: JsonObject | JsonValue, n: number): R; } diff --git a/spec/matchers/web-socket-received.ts b/spec/matchers/web-socket-received.ts new file mode 100644 index 0000000..f8bb8e9 --- /dev/null +++ b/spec/matchers/web-socket-received.ts @@ -0,0 +1,97 @@ +import type { JsonObject, JsonValue } from "type-fest"; + +import { TestWebSocketClient } from "../test-websocket-server"; + +/** + * Asserts that a given message was sent by a WebSocket client and received by the test server + */ +export const toHaveReceived: jest.CustomMatcher = function ( + received: TestWebSocketClient, + expected: JsonObject | JsonValue, +) { + if (received.messages.some((message) => this.equals(message, expected))) { + return { + message: () => `expected +${this.utils.printReceived(received.messages)} +not to have received +${this.utils.printExpected(expected)}`, + pass: true, + }; + } + + return { + message: () => `expected +${this.utils.printReceived(received.messages)} +to have received +${this.utils.printExpected(expected)}`, + pass: false, + }; +}; + +/** + * Asserts that a given message was the last one sent by a WebSocket client and received by the test server + */ +export const toHaveLastReceived: jest.CustomMatcher = function ( + received: TestWebSocketClient, + expected: JsonObject | JsonValue, +) { + if (this.equals(received.messages[received.messages.length - 1], expected)) { + return { + message: () => `expected +${received} +not to have last received +${expected}`, + pass: true, + }; + } + + return { + message: () => `expected +${received} +to have last received +${expected}`, + pass: false, + }; +}; + +/** + * Asserts that a given message was the Nth one sent by a WebSocket client and received by the test server + */ +export const toHaveNthReceived: jest.CustomMatcher = function ( + received: TestWebSocketClient, + expected: JsonObject | JsonValue, + n: number, +) { + if (n < 0) { + return { + message: () => "expected positive value for n", + pass: false, + }; + } + if (received.messages.length <= n) { + return { + message: () => `expected +${received} +to have received at least ${n + 1} messages`, + pass: false, + }; + } + + if (this.equals(received.messages[n], expected)) { + return { + message: () => `expected +${received} +not to have last received +${expected}`, + pass: true, + }; + } + + return { + message: () => `expected +${received} +to have last received +${expected}`, + pass: false, + }; +}; diff --git a/spec/test-websocket-server.ts b/spec/test-websocket-server.ts new file mode 100644 index 0000000..76fb668 --- /dev/null +++ b/spec/test-websocket-server.ts @@ -0,0 +1,268 @@ +import { WebSocketServer, WebSocket } from "ws"; + +import { + AutoConnectClientMessage, + ClientHello, + ClientRegister, + ServerHello, + ServerPing, + ServerRegister, + ServerUnregister, +} from "../src/messages/message"; +import { fromBufferToUtf8, newUuid, Uuid } from "../src/string-manipulation"; + +export const defaultUaid = "5f0774ac-09a3-45d9-91e4-f4aaebaeec72"; +const defaultHelloHandler = ( + client: TestWebSocketClient, + message: ClientHello, + server: TestWebSocketServer, +) => { + // Identify the client + const identifiedClient = new IdentifiedWebSocketClient(client, defaultUaid); + server.identifiedClients.push(identifiedClient); + + // Make sure we track channels for this client + for (const channelID of message.channelIDs ?? []) { + server.channelToClientMap.set(channelID, identifiedClient); + } + + // Send a response + const response: ServerHello = { + messageType: "hello", + uaid: defaultUaid, + useWebPush: true, + status: 200, + // broadcasts: {}, + }; + client.ws.send(JSON.stringify(response)); +}; + +const defaultRegisterHandler = ( + client: IdentifiedWebSocketClient, + message: ClientRegister, + server: TestWebSocketServer, +) => { + server.channelToClientMap.set(message.channelID, client); + + // Send a response + const response: ServerRegister = { + messageType: "register", + channelID: message.channelID, + pushEndpoint: `https://example.com/push//${message.channelID}`, + status: 200, + }; + client.ws.send(JSON.stringify(response)); +}; + +const defaultUnregisterHandler = ( + client: IdentifiedWebSocketClient, + message: ClientRegister, + server: TestWebSocketServer, +) => { + server.channelToClientMap.delete(message.channelID); + + // Send a response + const response: ServerUnregister = { + messageType: "unregister", + channelID: message.channelID, + status: 200, + }; + client.ws.send(JSON.stringify(response)); +}; + +const defaultServerPingHandler = (client: IdentifiedWebSocketClient) => { + const response: ServerPing = { + messageType: "ping", + }; + client.ws.send(JSON.stringify(response)); +}; + +export class TestWebSocketServer { + readonly server: WebSocketServer; + readonly channelToClientMap: Map; + readonly clients: TestWebSocketClient[] = []; + readonly identifiedClients: IdentifiedWebSocketClient[] = []; + + helloHandler = defaultHelloHandler; + registerHandler = defaultRegisterHandler; + unregisterHandler = defaultUnregisterHandler; + serverPingHandler = defaultServerPingHandler; + + constructor(port: number) { + this.server = new WebSocketServer({ port }); + this.channelToClientMap = new Map(); + + this.server.on("connection", (ws) => { + let client = new TestWebSocketClient(ws); + this.clients.push(client); + + ws.on("message", (data, isBinary) => { + if (isBinary) { + ws.close(1002, "Bad request"); + return; + } + + client = this.identifiedClientFor(ws) ?? client; + + this.messageHandler(client, JSON.parse(fromBufferToUtf8(data as ArrayBuffer))); + }); + ws.on("close", () => { + //remove client from identifiedClients + const identifiedClient = this.identifiedClientFor(ws); + if (identifiedClient) { + this.identifiedClients.splice(this.identifiedClients.indexOf(identifiedClient), 1); + } + + //remove client from clients + const clientIndex = this.clients.indexOf(client); + if (clientIndex !== -1) { + this.clients.splice(clientIndex, 1); + } + }); + }); + } + + closeClients() { + for (const client of this.clients) { + client.ws.close(1001, "Server closing"); + } + // Clear the identifiedClients array + this.identifiedClients.splice(0, this.identifiedClients.length); + // Clear the clients array + this.clients.splice(0, this.clients.length); + } + + async close() { + return new Promise((resolve, reject) => { + this.closeClients(); + this.server.close((e) => { + if (e) { + reject(e); + } else { + resolve(); + } + }); + }); + } + + get port() { + return this.server.options.port; + } + + clientFor(ws: WebSocket) { + return this.clients.find((client) => client.ws === ws); + } + + identifiedClientFor(id: Uuid | WebSocket) { + if (id instanceof WebSocket) { + return this.identifiedClients.find((client) => client.ws === id); + } + return this.channelToClientMap.get(id); + } + + /** + * Sends a notification to a client channel + * @param channelID channel ID to notify + * @param data The data to send + * @returns The version of the notification. This version should be ACK'd by the client + */ + sendNotification(channelID: Uuid, data?: string, headers?: Record): string { + const client = this.channelToClientMap.get(channelID); + if (!client) { + throw new Error("Client not found"); + } + + const version = newUuid(); + const message = { + messageType: "notification", + channelID, + version, + ttl: 60, + data, + headers, + }; + client.ws.send(JSON.stringify(message)); + return version; + } + + private messageHandler( + client: TestWebSocketClient | IdentifiedWebSocketClient, + message: AutoConnectClientMessage, + ) { + if (!message?.messageType) { + client.ws.close(1002, "Bad request"); + return; + } + + if (client instanceof IdentifiedWebSocketClient) { + this.identifiedMessageHandler(client, message); + return; + } else { + this.unidentifiedMessageHandler(client, message); + } + } + + private unidentifiedMessageHandler( + client: TestWebSocketClient, + message: AutoConnectClientMessage, + ) { + if (message.messageType === "hello") { + this.helloHandler(client, message as ClientHello, this); + } else { + client.ws.close(1002, "Bad request"); + } + } + + private identifiedMessageHandler( + client: IdentifiedWebSocketClient, + message: AutoConnectClientMessage, + ) { + switch (message.messageType) { + case "register": + this.registerHandler(client, message as ClientRegister, this); + break; + case "unregister": + this.unregisterHandler(client, message as ClientRegister, this); + break; + case "ping": + this.serverPingHandler(client); + break; + default: + client.ws.close(1002, "Bad request"); + } + } +} + +export class TestWebSocketClient { + readonly messages: unknown[] = []; + constructor(readonly ws: WebSocket) { + ws.on("message", (message) => { + const utf8 = fromBufferToUtf8(message as ArrayBuffer); + const json = JSON.parse(utf8); + this.messages.push(json); + }); + } + + send(...args: Parameters) { + this.ws.send(...args); + } +} + +export class IdentifiedWebSocketClient implements TestWebSocketClient { + constructor( + readonly upgradedClient: TestWebSocketClient, + readonly uaid: string, + ) {} + + get messages() { + return this.upgradedClient.messages; + } + + get ws() { + return this.upgradedClient.ws; + } + + send(...args: Parameters) { + this.upgradedClient.send(...args); + } +} diff --git a/src/push-manager.ts b/src/push-manager.ts index 349c242..deb04eb 100644 --- a/src/push-manager.ts +++ b/src/push-manager.ts @@ -22,6 +22,14 @@ export interface PublicPushManager { destroy(): Promise; } +type PushManagerOptions = { + autopushUrl: string; +}; + +const defaultPushManagerOptions: PushManagerOptions = Object.freeze({ + autopushUrl: "wss://push.services.mozilla.com", +}); + export class PushManager implements PublicPushManager { private _uaid: string | null = null; private _websocket: WebSocket | null = null; @@ -33,6 +41,7 @@ export class PushManager implements PublicPushManager { private constructor( private readonly storage: Storage, private readonly logger: Logger, + private readonly options: PushManagerOptions, ) {} get uaid() { @@ -98,10 +107,14 @@ export class PushManager implements PublicPushManager { await promise; } - static async create(externalStorage: PublicStorage, externalLogger: Logger) { + static async create( + externalStorage: PublicStorage, + externalLogger: Logger, + options: PushManagerOptions = defaultPushManagerOptions, + ) { const storage = new Storage(externalStorage); const logger = new TimedLogger(externalLogger); - const manager = new PushManager(storage, logger); + const manager = new PushManager(storage, logger, options); const subscriptionHandler = await SubscriptionHandler.create( storage, (channelID: Uuid) => manager.unsubscribe(channelID), @@ -137,7 +150,7 @@ export class PushManager implements PublicPushManager { const helloCompleted = new Promise((resolve) => { this._helloResolve = resolve; }); - this._websocket = new WebSocket("wss://push.services.mozilla.com"); + this._websocket = new WebSocket(this.options.autopushUrl); this._websocket.onmessage = async (event) => { // this.logger.debug("Received ws message", event); let messageData: AutoConnectServerMessage;