diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..82ab396 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,31 @@ +name: CI + +on: + push: + pull_request: + schedule: + - cron: "0 0 * * 0" + +jobs: + test: + runs-on: ubuntu-latest + + strategy: + matrix: + deno-version: + - 1.25.2 + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + - name: Use Deno ${{ matrix.deno-version }} + uses: denolib/setup-deno@v2 + with: + deno-version: ${{ matrix.deno-version }} + + - name: Check format + run: deno fmt --check + + - name: Run tests + run: deno test --allow-net diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b973bde --- /dev/null +++ b/LICENSE @@ -0,0 +1,15 @@ +ISC License + +Copyright (c) 2022 Damien ARRACHEQUESNE + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..d97e002 --- /dev/null +++ b/README.md @@ -0,0 +1,236 @@ +# Socket.IO server for Deno + +An implementation of the Socket.IO protocol for Deno. + +Table of content: + +- [Usage](#usage) +- [Options](#options) + - [`path`](#path) + - [`connectTimeout`](#connecttimeout) + - [`pingTimeout`](#pingtimeout) + - [`pingInterval`](#pinginterval) + - [`upgradeTimeout`](#upgradetimeout) + - [`maxHttpBufferSize`](#maxhttpbuffersize) + - [`allowRequest`](#allowrequest) + - [`cors`](#cors) + - [`editHandshakeHeaders`](#edithandshakeheaders) + - [`editResponseHeaders`](#editresponseheaders) +- [Logs](#logs) + +## Usage + +```ts +import { serve } from "https://deno.land/std@0.150.0/http/server.ts"; +import { Server } from "https://deno.land/x/socket.io@0.1.0/mod.ts"; + +const io = new Server(); + +io.on("connection", (socket) => { + console.log(`socket ${socket.id} connected`); + + socket.emit("hello", "world"); + + socket.on("disconnect", (reason) => { + console.log(`socket ${socket.id} disconnected due to ${reason}`); + }); +}); + +await serve(io.handler(), { + port: 3000, +}); +``` + +And then run with: + +``` +$ deno run --allow-net index.ts +``` + +Like the [Node.js server](https://socket.io/docs/v4/typescript/), you can also +provide types for the events sent between the server and the clients: + +```ts +interface ServerToClientEvents { + noArg: () => void; + basicEmit: (a: number, b: string, c: Buffer) => void; + withAck: (d: string, callback: (e: number) => void) => void; +} + +interface ClientToServerEvents { + hello: () => void; +} + +interface InterServerEvents { + ping: () => void; +} + +interface SocketData { + user_id: string; +} + +const io = new Server< + ClientToServerEvents, + ServerToClientEvents, + InterServerEvents, + SocketData +>(); +``` + +## Options + +### `path` + +Default value: `/socket.io/` + +It is the name of the path that is captured on the server side. + +Caution! The server and the client values must match (unless you are using a +path-rewriting proxy in between). + +Example: + +```ts +const io = new Server(httpServer, { + path: "/my-custom-path/", +}); +``` + +### `connectTimeout` + +Default value: `45000` + +The number of ms before disconnecting a client that has not successfully joined +a namespace. + +### `pingTimeout` + +Default value: `20000` + +This value is used in the heartbeat mechanism, which periodically checks if the +connection is still alive between the server and the client. + +The server sends a ping, and if the client does not answer with a pong within +`pingTimeout` ms, the server considers that the connection is closed. + +Similarly, if the client does not receive a ping from the server within +`pingInterval + pingTimeout` ms, the client also considers that the connection +is closed. + +### `pingInterval` + +Default value: `25000` + +See [`pingTimeout`](#pingtimeout) for more explanation. + +### `upgradeTimeout` + +Default value: `10000` + +This is the delay in milliseconds before an uncompleted transport upgrade is +cancelled. + +### `maxHttpBufferSize` + +Default value: `1e6` (1 MB) + +This defines how many bytes a single message can be, before closing the socket. +You may increase or decrease this value depending on your needs. + +### `allowRequest` + +Default value: `-` + +A function that receives a given handshake or upgrade request as its first +parameter, and can decide whether to continue or not. + +Example: + +```ts +const io = new Server({ + allowRequest: (req, connInfo) => { + return Promise.reject("thou shall not pass"); + }, +}); +``` + +### `cors` + +Default value: `-` + +A set of options related to +[Cross-Origin Resource Sharing](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) +(CORS). + +Example: + +```ts +const io = new Server({ + cors: { + origin: ["https://example.com"], + allowedHeaders: ["my-header"], + credentials: true, + }, +}); +``` + +### `editHandshakeHeaders` + +Default value: `-` + +A function that allows to edit the response headers of the handshake request. + +Example: + +```ts +const io = new Server({ + editHandshakeHeaders: (responseHeaders, req, connInfo) => { + responseHeaders.set("set-cookie", "sid=1234"); + }, +}); +``` + +### `editResponseHeaders` + +Default value: `-` + +A function that allows to edit the response headers of all requests. + +Example: + +```ts +const io = new Server({ + editResponseHeaders: (responseHeaders, req, connInfo) => { + responseHeaders.set("my-header", "abcd"); + }, +}); +``` + +## Logs + +The library relies on the standard `log` module, so you can display the internal +logs of the Socket.IO server with: + +```ts +import * as log from "https://deno.land/std@0.150.0/log/mod.ts"; + +await log.setup({ + handlers: { + console: new log.handlers.ConsoleHandler("DEBUG"), + }, + loggers: { + "socket.io": { + level: "DEBUG", + handlers: ["console"], + }, + "engine.io": { + level: "DEBUG", + handlers: ["console"], + }, + }, +}); +``` + +## License + +[ISC](/LICENSE) diff --git a/deps.ts b/deps.ts new file mode 100644 index 0000000..b065b62 --- /dev/null +++ b/deps.ts @@ -0,0 +1,6 @@ +export { + type ConnInfo, + type Handler, +} from "https://deno.land/std@0.150.0/http/server.ts"; + +export { getLogger } from "https://deno.land/std@0.150.0/log/mod.ts"; diff --git a/mod.ts b/mod.ts new file mode 100644 index 0000000..fb2c82e --- /dev/null +++ b/mod.ts @@ -0,0 +1 @@ +export { Server, type ServerOptions } from "./packages/socket.io/mod.ts"; diff --git a/packages/engine.io-parser/base64-arraybuffer.ts b/packages/engine.io-parser/base64-arraybuffer.ts new file mode 100644 index 0000000..021ecc7 --- /dev/null +++ b/packages/engine.io-parser/base64-arraybuffer.ts @@ -0,0 +1,63 @@ +const chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +// Use a lookup table to find the index. +const lookup = new Uint8Array(256); +for (let i = 0; i < chars.length; i++) { + lookup[chars.charCodeAt(i)] = i; +} + +export function encodeToBase64(arraybuffer: ArrayBuffer): string { + const bytes = new Uint8Array(arraybuffer); + const len = bytes.length; + let base64 = ""; + + for (let i = 0; i < len; i += 3) { + base64 += chars[bytes[i] >> 2]; + base64 += chars[((bytes[i] & 3) << 4) | (bytes[i + 1] >> 4)]; + base64 += chars[((bytes[i + 1] & 15) << 2) | (bytes[i + 2] >> 6)]; + base64 += chars[bytes[i + 2] & 63]; + } + + if (len % 3 === 2) { + base64 = base64.substring(0, base64.length - 1) + "="; + } else if (len % 3 === 1) { + base64 = base64.substring(0, base64.length - 2) + "=="; + } + + return base64; +} + +export function decodeFromBase64(base64: string): ArrayBuffer { + const len = base64.length; + let bufferLength = base64.length * 0.75, + i, + p = 0, + encoded1, + encoded2, + encoded3, + encoded4; + + if (base64[base64.length - 1] === "=") { + bufferLength--; + if (base64[base64.length - 2] === "=") { + bufferLength--; + } + } + + const arraybuffer = new ArrayBuffer(bufferLength), + bytes = new Uint8Array(arraybuffer); + + for (i = 0; i < len; i += 4) { + encoded1 = lookup[base64.charCodeAt(i)]; + encoded2 = lookup[base64.charCodeAt(i + 1)]; + encoded3 = lookup[base64.charCodeAt(i + 2)]; + encoded4 = lookup[base64.charCodeAt(i + 3)]; + + bytes[p++] = (encoded1 << 2) | (encoded2 >> 4); + bytes[p++] = ((encoded2 & 15) << 4) | (encoded3 >> 2); + bytes[p++] = ((encoded3 & 3) << 6) | (encoded4 & 63); + } + + return arraybuffer; +} diff --git a/packages/engine.io-parser/mod.ts b/packages/engine.io-parser/mod.ts new file mode 100644 index 0000000..9686443 --- /dev/null +++ b/packages/engine.io-parser/mod.ts @@ -0,0 +1,159 @@ +import { decodeFromBase64, encodeToBase64 } from "./base64-arraybuffer.ts"; + +const SEPARATOR = String.fromCharCode(30); // see https://en.wikipedia.org/wiki/Delimiter#ASCII_delimited_text + +export type PacketType = + | "open" + | "close" + | "ping" + | "pong" + | "message" + | "upgrade" + | "noop" + | "error"; + +export type RawData = string | ArrayBuffer | ArrayBufferView | Blob; + +export interface Packet { + type: PacketType; + data?: RawData; +} + +const PACKET_TYPES = new Map<PacketType, string>(); +const PACKET_TYPES_REVERSE = new Map<string, PacketType>(); + +([ + "open", + "close", + "ping", + "pong", + "message", + "upgrade", + "noop", +] as PacketType[]).forEach((type, index) => { + PACKET_TYPES.set(type, "" + index); + PACKET_TYPES_REVERSE.set("" + index, type); +}); + +const ERROR_PACKET: Packet = { type: "error", data: "parser error" }; + +type BinaryType = "arraybuffer" | "blob"; + +export const Parser = { + encodePacket( + { type, data }: Packet, + supportsBinary: boolean, + callback: (encodedPacket: RawData) => void, + ) { + if (data instanceof Blob) { + return supportsBinary + ? callback(data) + : encodeBlobAsBase64(data, callback); + } else if (data instanceof ArrayBuffer) { + return callback(supportsBinary ? data : "b" + encodeToBase64(data)); + } else if (ArrayBuffer.isView(data)) { + if (supportsBinary) { + return callback(data); + } else { + const array = new Uint8Array( + data.buffer, + data.byteOffset, + data.byteLength, + ); + return callback("b" + encodeToBase64(array)); + } + } + // plain string + return callback(PACKET_TYPES.get(type) + (data || "")); + }, + + decodePacket( + encodedPacket: RawData, + binaryType?: BinaryType, + ): Packet { + if (typeof encodedPacket !== "string") { + return { + type: "message", + data: mapBinary(encodedPacket, binaryType), + }; + } + const typeChar = encodedPacket.charAt(0); + if (typeChar === "b") { + const buffer = decodeFromBase64(encodedPacket.substring(1)); + return { + type: "message", + data: mapBinary(buffer, binaryType), + }; + } + if (!PACKET_TYPES_REVERSE.has(typeChar)) { + return ERROR_PACKET; + } + const type = PACKET_TYPES_REVERSE.get(typeChar)!; + return encodedPacket.length > 1 + ? { + type, + data: encodedPacket.substring(1), + } + : { + type, + }; + }, + + encodePayload( + packets: Packet[], + callback: (encodedPayload: string) => void, + ) { + // some packets may be added to the array while encoding, so the initial length must be saved + const length = packets.length; + const encodedPackets = new Array(length); + let count = 0; + + packets.forEach((packet, i) => { + // force base64 encoding for binary packets + this.encodePacket(packet, false, (encodedPacket) => { + encodedPackets[i] = encodedPacket; + if (++count === length) { + callback(encodedPackets.join(SEPARATOR)); + } + }); + }); + }, + + decodePayload( + encodedPayload: string, + binaryType?: BinaryType, + ): Packet[] { + const encodedPackets = encodedPayload.split(SEPARATOR); + const packets = []; + for (let i = 0; i < encodedPackets.length; i++) { + const decodedPacket = this.decodePacket(encodedPackets[i], binaryType); + packets.push(decodedPacket); + if (decodedPacket.type === "error") { + break; + } + } + return packets; + }, +}; + +function encodeBlobAsBase64( + data: Blob, + callback: (encodedPacket: RawData) => void, +) { + const fileReader = new FileReader(); + fileReader.onload = function () { + const content = (fileReader.result as string).split(",")[1]; + callback("b" + content); + }; + return fileReader.readAsDataURL(data); +} + +function mapBinary(data: RawData, binaryType?: BinaryType) { + switch (binaryType) { + case "blob": + return new Blob([data]); + case "arraybuffer": + default: + return data; // assuming the data is already an ArrayBuffer + } +} diff --git a/packages/engine.io-parser/test.ts b/packages/engine.io-parser/test.ts new file mode 100644 index 0000000..e93efca --- /dev/null +++ b/packages/engine.io-parser/test.ts @@ -0,0 +1,185 @@ +import { + assertEquals, + assertInstanceOf, + assertStrictEquals, + describe, + it, +} from "../../test_deps.ts"; +import { Packet, Parser } from "./mod.ts"; + +describe("engine.io-parser", () => { + describe("single packet", () => { + it("should encode/decode a string", () => { + return new Promise((done) => { + const packet: Packet = { type: "message", data: "test" }; + Parser.encodePacket(packet, true, (encodedPacket) => { + assertEquals(encodedPacket, "4test"); + assertEquals(Parser.decodePacket(encodedPacket), packet); + done(); + }); + }); + }); + + it("should fail to decode a malformed packet", () => { + assertEquals(Parser.decodePacket(""), { + type: "error", + data: "parser error", + }); + assertEquals(Parser.decodePacket("a123"), { + type: "error", + data: "parser error", + }); + }); + + it("should encode/decode an ArrayBuffer", () => { + return new Promise((done) => { + const packet: Packet = { + type: "message", + data: Uint8Array.from([1, 2, 3, 4]).buffer, + }; + Parser.encodePacket(packet, true, (encodedPacket) => { + assertEquals( + new Uint8Array(encodedPacket as ArrayBuffer), + Uint8Array.from([1, 2, 3, 4]), + ); + + const decodedPacket = Parser.decodePacket( + encodedPacket, + "arraybuffer", + ); + + assertEquals(decodedPacket.type, packet.type); + assertInstanceOf(decodedPacket.data, ArrayBuffer); + assertEquals( + new Uint8Array(decodedPacket.data as ArrayBuffer), + new Uint8Array(packet.data as ArrayBuffer), + ); + done(); + }); + }); + }); + + it("should encode/decode an ArrayBuffer as base64", () => { + return new Promise((done) => { + const packet: Packet = { + type: "message", + data: Uint8Array.from([1, 2, 3, 4]).buffer, + }; + Parser.encodePacket(packet, false, (encodedPacket) => { + assertEquals(encodedPacket, "bAQIDBA=="); + + const decodedPacket = Parser.decodePacket( + encodedPacket, + "arraybuffer", + ); + + assertEquals(decodedPacket.type, packet.type); + assertInstanceOf(decodedPacket.data, ArrayBuffer); + assertEquals( + new Uint8Array(decodedPacket.data as ArrayBuffer), + new Uint8Array(packet.data as ArrayBuffer), + ); + done(); + }); + }); + }); + + it("should encode a typed array", () => { + return new Promise((done) => { + const buffer = Uint8Array.from([1, 2, 3, 4]).buffer; + const data = new Uint8Array(buffer, 1, 2); + + Parser.encodePacket( + { type: "message", data }, + true, + (encodedPacket) => { + assertStrictEquals(encodedPacket, data); // unmodified typed array + done(); + }, + ); + }); + }); + + it("should encode/decode a Blob", () => { + return new Promise((done) => { + const packet: Packet = { + type: "message", + data: new Blob(["1234", Uint8Array.from([1, 2, 3, 4])]), + }; + Parser.encodePacket(packet, true, (encodedPacket) => { + assertInstanceOf(encodedPacket, Blob); + + const decodedPacket = Parser.decodePacket(encodedPacket, "blob"); + + assertEquals(decodedPacket.type, "message"); + assertInstanceOf(decodedPacket.data, Blob); + done(); + }); + }); + }); + + it("should encode/decode a Blob as base64", () => { + return new Promise((done) => { + const packet: Packet = { + type: "message", + data: new Blob(["1234", Uint8Array.from([1, 2, 3, 4])]), + }; + Parser.encodePacket(packet, false, (encodedPacket) => { + assertEquals(encodedPacket, "bMTIzNAECAwQ="); + + const decodedPacket = Parser.decodePacket(encodedPacket, "blob"); + + assertEquals(decodedPacket.type, "message"); + assertInstanceOf(decodedPacket.data, Blob); + done(); + }); + }); + }); + }); + + describe("payload", () => { + it("should encode/decode all packet types", () => { + return new Promise((done) => { + const packets: Packet[] = [ + { type: "open" }, + { type: "close" }, + { type: "ping", data: "probe" }, + { type: "pong", data: "probe" }, + { type: "message", data: "test" }, + ]; + + Parser.encodePayload(packets, (payload) => { + assertEquals(payload, "0\x1e1\x1e2probe\x1e3probe\x1e4test"); + assertEquals(Parser.decodePayload(payload), packets); + done(); + }); + }); + }); + + it("should fail to decode a malformed payload", () => { + assertEquals(Parser.decodePayload("{"), [ + { type: "error", data: "parser error" }, + ]); + assertEquals(Parser.decodePayload("{}"), [ + { type: "error", data: "parser error" }, + ]); + assertEquals(Parser.decodePayload('["a123", "a456"]'), [ + { type: "error", data: "parser error" }, + ]); + }); + + it("should encode/decode a string + ArrayBuffer payload", () => { + return new Promise((done) => { + const packets: Packet[] = [ + { type: "message", data: "test" }, + { type: "message", data: Uint8Array.from([1, 2, 3, 4]).buffer }, + ]; + Parser.encodePayload(packets, (payload) => { + assertEquals(payload, "4test\x1ebAQIDBA=="); + assertEquals(Parser.decodePayload(payload, "arraybuffer"), packets); + done(); + }); + }); + }); + }); +}); diff --git a/packages/engine.io/lib/cors.ts b/packages/engine.io/lib/cors.ts new file mode 100644 index 0000000..6255d41 --- /dev/null +++ b/packages/engine.io/lib/cors.ts @@ -0,0 +1,99 @@ +type OriginOption = boolean | string | RegExp | (string | RegExp)[]; + +export interface CorsOptions { + origin?: OriginOption; + methods?: string | string[]; + allowedHeaders?: string | string[]; + exposedHeaders?: string | string[]; + credentials?: boolean; + maxAge?: number; +} + +export function addCorsHeaders( + headers: Headers, + opts: CorsOptions, + req: Request, +) { + addOrigin(opts, headers, req); + addCredentials(opts, headers); + addExposedHeaders(opts, headers); + + if (req.method === "OPTIONS") { + addMethods(opts, headers); + addAllowedHeaders(opts, headers, req); + addMaxAge(opts, headers); + } +} + +function join(arg: string | string[]) { + return Array.isArray(arg) ? arg.join(",") : arg; +} + +function isOriginAllowed(allowedOrigin: OriginOption, origin: string): boolean { + if (Array.isArray(allowedOrigin)) { + for (let i = 0; i < allowedOrigin.length; i++) { + if (isOriginAllowed(allowedOrigin[i], origin)) { + return true; + } + } + return false; + } else if (typeof allowedOrigin === "string") { + return allowedOrigin === origin; + } else if (allowedOrigin instanceof RegExp) { + return allowedOrigin.test(origin); + } else { + return !!allowedOrigin; + } +} + +function addOrigin(opts: CorsOptions, headers: Headers, req: Request) { + const origin = req.headers.get("origin")!; + const allowedOrigin = opts.origin; + + if (!allowedOrigin || allowedOrigin === "*") { + headers.set("Access-Control-Allow-Origin", "*"); + } else if (typeof allowedOrigin === "string") { + headers.set("Access-Control-Allow-Origin", allowedOrigin); + headers.append("Vary", "Origin"); + } else { + const isAllowed = isOriginAllowed(allowedOrigin, origin); + headers.set("Access-Control-Allow-Origin", isAllowed ? origin : "false"); + headers.append("Vary", "Origin"); + } +} + +function addMethods(opts: CorsOptions, headers: Headers) { + if (opts.methods) { + headers.set("Access-Control-Allow-Methods", join(opts.methods)); + } +} + +function addAllowedHeaders(opts: CorsOptions, headers: Headers, req: Request) { + if (opts.allowedHeaders) { + headers.set("Access-Control-Allow-Headers", join(opts.allowedHeaders)); + return; + } + const requestedHeaders = req.headers.get("access-control-request-headers"); + if (requestedHeaders) { + headers.append("Vary", "Access-Control-Request-Headers"); + headers.set("Access-Control-Allow-Headers", requestedHeaders); + } +} + +function addExposedHeaders(opts: CorsOptions, headers: Headers) { + if (opts.exposedHeaders) { + headers.set("Access-Control-Expose-Headers", join(opts.exposedHeaders)); + } +} + +function addCredentials(opts: CorsOptions, headers: Headers) { + if (opts.credentials) { + headers.set("Access-Control-Allow-Credentials", "true"); + } +} + +function addMaxAge(opts: CorsOptions, headers: Headers) { + if (opts.maxAge) { + headers.set("Access-Control-Max-Age", opts.maxAge.toString()); + } +} diff --git a/packages/engine.io/lib/server.ts b/packages/engine.io/lib/server.ts new file mode 100644 index 0000000..a89cad8 --- /dev/null +++ b/packages/engine.io/lib/server.ts @@ -0,0 +1,340 @@ +import { ConnInfo, getLogger, Handler } from "../../../deps.ts"; +import { EventEmitter } from "../../event-emitter/mod.ts"; +import { Socket } from "./socket.ts"; +import { Polling } from "./transports/polling.ts"; +import { WS } from "./transports/websocket.ts"; +import { addCorsHeaders, CorsOptions } from "./cors.ts"; +import { Transport } from "./transport.ts"; +import { generateId } from "./util.ts"; + +const TRANSPORTS = ["polling", "websocket"]; + +export interface ServerOptions { + /** + * Name of the request path to handle + * @default "/engine.io/" + */ + path: string; + /** + * Duration in milliseconds without a pong packet to consider the connection closed + * @default 20000 + */ + pingTimeout: number; + /** + * Duration in milliseconds before sending a new ping packet + * @default 25000 + */ + pingInterval: number; + /** + * Duration in milliseconds before an uncompleted transport upgrade is cancelled + * @default 10000 + */ + upgradeTimeout: number; + /** + * Maximum size in bytes or number of characters a message can be, before closing the session (to avoid DoS). + * @default 1e6 (1 MB) + */ + maxHttpBufferSize: number; + /** + * A function that receives a given handshake or upgrade request as its first parameter, + * and can decide whether to continue or not. + */ + allowRequest?: ( + req: Request, + connInfo: ConnInfo, + ) => Promise<void>; + /** + * The options related to Cross-Origin Resource Sharing (CORS) + */ + cors?: CorsOptions; + /** + * A function that allows to edit the response headers of the handshake request + */ + editHandshakeHeaders?: ( + responseHeaders: Headers, + req: Request, + connInfo: ConnInfo, + ) => void | Promise<void>; + /** + * A function that allows to edit the response headers of all requests + */ + editResponseHeaders?: ( + responseHeaders: Headers, + req: Request, + connInfo: ConnInfo, + ) => void | Promise<void>; +} + +interface ConnectionError { + req: Request; + code: number; + message: string; + context: Record<string, unknown>; +} + +interface ServerReservedEvents { + connection: (socket: Socket, request: Request, connInfo: ConnInfo) => void; + connection_error: (err: ConnectionError) => void; +} + +const enum ERROR_CODES { + UNKNOWN_TRANSPORT = 0, + UNKNOWN_SID, + BAD_HANDSHAKE_METHOD, + BAD_REQUEST, + FORBIDDEN, + UNSUPPORTED_PROTOCOL_VERSION, +} + +const ERROR_MESSAGES = new Map<ERROR_CODES, string>([ + [ERROR_CODES.UNKNOWN_TRANSPORT, "Transport unknown"], + [ERROR_CODES.UNKNOWN_SID, "Session ID unknown"], + [ERROR_CODES.BAD_HANDSHAKE_METHOD, "Bad handshake method"], + [ERROR_CODES.BAD_REQUEST, "Bad request"], + [ERROR_CODES.FORBIDDEN, "Forbidden"], + [ERROR_CODES.UNSUPPORTED_PROTOCOL_VERSION, "Unsupported protocol version"], +]); + +export class Server extends EventEmitter<never, never, ServerReservedEvents> { + public readonly opts: ServerOptions; + + private clients: Map<string, Socket> = new Map(); + + constructor(opts: Partial<ServerOptions> = {}) { + super(); + + this.opts = Object.assign( + { + path: "/engine.io/", + pingTimeout: 20000, + pingInterval: 25000, + upgradeTimeout: 10000, + maxHttpBufferSize: 1e6, + }, + opts, + ); + } + + /** + * Returns a request handler. + * + * @param additionalHandler - another handler which will receive the request if the path does not match + */ + public handler(additionalHandler?: Handler) { + return (req: Request, connInfo: ConnInfo): Response | Promise<Response> => { + const url = new URL(req.url); + if (url.pathname === this.opts.path) { + return this.handleRequest(req, connInfo, url); + } else if (additionalHandler) { + return additionalHandler(req, connInfo); + } else { + return new Response(null, { status: 404 }); + } + }; + } + + /** + * Handles an HTTP request. + * + * @param req + * @param connInfo + * @param url + * @private + */ + private async handleRequest( + req: Request, + connInfo: ConnInfo, + url: URL, + ): Promise<Response> { + getLogger("engine.io").debug(`[server] handling ${req.method} ${req.url}`); + + const responseHeaders = new Headers(); + if (this.opts.cors) { + addCorsHeaders(responseHeaders, this.opts.cors, req); + + if (req.method === "OPTIONS") { + return new Response(null, { status: 204, headers: responseHeaders }); + } + } + + if (this.opts.editResponseHeaders) { + await this.opts.editResponseHeaders(responseHeaders, req, connInfo); + } + + try { + await this.verify(req, url); + } catch ({ code, context }) { + const message = ERROR_MESSAGES.get(code)!; + this.emitReserved("connection_error", { + req, + code, + message, + context, + }); + const body = JSON.stringify({ + code, + message, + }); + responseHeaders.set("Content-Type", "application/json"); + return new Response(body, { + status: 400, + headers: responseHeaders, + }); + } + + if (this.opts.allowRequest) { + try { + await this.opts.allowRequest(req, connInfo); + } catch (reason) { + this.emitReserved("connection_error", { + req, + code: ERROR_CODES.FORBIDDEN, + message: ERROR_MESSAGES.get(ERROR_CODES.FORBIDDEN)!, + context: { + message: reason, + }, + }); + const body = JSON.stringify({ + code: ERROR_CODES.FORBIDDEN, + message: reason, + }); + responseHeaders.set("Content-Type", "application/json"); + return new Response(body, { + status: 403, + headers: responseHeaders, + }); + } + } + + const sid = url.searchParams.get("sid"); + if (sid) { + // the client must exist since we have checked it in the verify method + const socket = this.clients.get(sid)!; + + if (url.searchParams.get("transport") === "websocket") { + const transport = new WS(this.opts); + + const promise = transport.onRequest(req, responseHeaders); + + socket._maybeUpgrade(transport); + + return promise; + } + + getLogger("engine.io").debug( + "[server] setting new request for existing socket", + ); + + return socket.transport.onRequest(req, responseHeaders); + } else { + return this.handshake(req, connInfo, responseHeaders); + } + } + + /** + * Verifies a request. + * + * @param req + * @param url + * @private + */ + private verify(req: Request, url: URL): Promise<void> { + const transport = url.searchParams.get("transport") || ""; + if (!TRANSPORTS.includes(transport)) { + getLogger("engine.io").debug(`unknown transport "${transport}"`); + return Promise.reject({ + code: ERROR_CODES.UNKNOWN_TRANSPORT, + context: { + transport, + }, + }); + } + + const sid = url.searchParams.get("sid"); + if (sid) { + if (!this.clients.has(sid)) { + return Promise.reject({ + code: ERROR_CODES.UNKNOWN_SID, + context: { + sid, + }, + }); + } + } else { + // handshake is GET only + if (req.method !== "GET") { + return Promise.reject({ + code: ERROR_CODES.BAD_HANDSHAKE_METHOD, + context: { + method: req.method, + }, + }); + } + + const protocol = url.searchParams.get("EIO") === "4" ? 4 : 3; // 3rd revision by default + if (protocol === 3) { + return Promise.reject({ + code: ERROR_CODES.UNSUPPORTED_PROTOCOL_VERSION, + context: { + protocol, + }, + }); + } + } + + return Promise.resolve(); + } + + /** + * Handshakes a new client. + * + * @param req + * @param connInfo + * @param responseHeaders + * @private + */ + private async handshake( + req: Request, + connInfo: ConnInfo, + responseHeaders: Headers, + ): Promise<Response> { + const id = generateId(); + + let transport: Transport; + if (req.headers.has("upgrade")) { + transport = new WS(this.opts); + } else { + transport = new Polling(this.opts); + } + + getLogger("engine.io").info(`[server] new socket ${id}`); + + const socket = new Socket(id, this.opts, transport); + this.clients.set(id, socket); + + socket.once("close", (reason) => { + getLogger("engine.io").info( + `[server] socket ${id} closed due to ${reason}`, + ); + this.clients.delete(id); + }); + + if (this.opts.editHandshakeHeaders) { + await this.opts.editHandshakeHeaders(responseHeaders, req, connInfo); + } + + const promise = transport.onRequest(req, responseHeaders); + + this.emitReserved("connection", socket, req, connInfo); + + return promise; + } + + /** + * Closes all clients. + */ + public close() { + getLogger("engine.io").debug("[server] closing all open clients"); + this.clients.forEach((client) => client.close()); + } +} diff --git a/packages/engine.io/lib/socket.ts b/packages/engine.io/lib/socket.ts new file mode 100644 index 0000000..11e3354 --- /dev/null +++ b/packages/engine.io/lib/socket.ts @@ -0,0 +1,356 @@ +import { EventEmitter } from "../../event-emitter/mod.ts"; +import { getLogger } from "../../../deps.ts"; +import { Packet, PacketType, RawData } from "../../engine.io-parser/mod.ts"; +import { Transport, TransportError } from "./transport.ts"; +import { ServerOptions } from "./server.ts"; + +type ReadyState = "opening" | "open" | "closing" | "closed"; + +type UpgradeState = "not_upgraded" | "upgrading" | "upgraded"; + +export type CloseReason = + | "transport error" + | "transport close" + | "forced close" + | "ping timeout" + | "parse error"; + +interface SocketEvents { + open: () => void; + packet: (packet: Packet) => void; + packetCreate: (packet: Packet) => void; + message: (message: RawData) => void; + flush: (writeBuffer: Packet[]) => void; + drain: () => void; + heartbeat: () => void; + upgrading: (transport: Transport) => void; + upgrade: (transport: Transport) => void; + close: (reason: CloseReason) => void; +} + +export class Socket extends EventEmitter< + never, + never, + SocketEvents +> { + public readonly id: string; + public readyState: ReadyState = "opening"; + public transport: Transport; + + private readonly opts: ServerOptions; + private upgradeState: UpgradeState = "not_upgraded"; + private writeBuffer: Packet[] = []; + private pingIntervalTimerId?: number; + private pingTimeoutTimerId?: number; + + constructor(id: string, opts: ServerOptions, transport: Transport) { + super(); + + this.id = id; + this.opts = opts; + + this.transport = transport; + this.bindTransport(transport); + this.onOpen(); + } + + /** + * Called upon transport considered open. + * + * @private + */ + private onOpen() { + this.readyState = "open"; + + this.sendPacket( + "open", + JSON.stringify({ + sid: this.id, + upgrades: this.transport.upgradesTo, + pingInterval: this.opts.pingInterval, + pingTimeout: this.opts.pingTimeout, + maxPayload: this.opts.maxHttpBufferSize, + }), + ); + + this.emitReserved("open"); + this.schedulePing(); + } + + /** + * Called upon transport packet. + * + * @param packet + * @private + */ + private onPacket(packet: Packet) { + if (this.readyState !== "open") { + getLogger("engine.io").debug( + "[socket] packet received with closed socket", + ); + return; + } + + getLogger("engine.io").debug(`[socket] received packet ${packet.type}`); + + this.emitReserved("packet", packet); + + switch (packet.type) { + case "pong": + getLogger("engine.io").debug("[socket] got pong"); + + clearTimeout(this.pingTimeoutTimerId); + this.schedulePing(); + + this.emitReserved("heartbeat"); + break; + + case "message": + this.emitReserved("message", packet.data!); + break; + + case "error": + default: + this.onClose("parse error"); + break; + } + } + + /** + * Called upon transport error. + * + * @param err + * @private + */ + private onError(err: TransportError) { + getLogger("engine.io").debug(`[socket] transport error: ${err.message}`); + this.onClose("transport error"); + } + + /** + * Pings client every `pingInterval` and expects response + * within `pingTimeout` or closes connection. + * + * @private + */ + private schedulePing() { + this.pingIntervalTimerId = setTimeout(() => { + getLogger("engine.io").debug( + `[socket] writing ping packet - expecting pong within ${this.opts.pingTimeout} ms`, + this.opts.pingTimeout, + ); + this.sendPacket("ping"); + this.resetPingTimeout(); + }, this.opts.pingInterval); + } + + /** + * Resets ping timeout. + * + * @private + */ + private resetPingTimeout() { + clearTimeout(this.pingTimeoutTimerId); + this.pingTimeoutTimerId = setTimeout(() => { + if (this.readyState !== "closed") { + this.onClose("ping timeout"); + } + }, this.opts.pingTimeout); + } + + /** + * Attaches handlers for the given transport. + * + * @param transport + * @private + */ + private bindTransport(transport: Transport) { + this.transport = transport; + this.transport.once("error", (err) => this.onError(err)); + this.transport.on("packet", (packet) => this.onPacket(packet)); + this.transport.on("drain", () => this.flush()); + this.transport.on("close", () => this.onClose("transport close")); + } + + /** + * Upgrades socket to the given transport + * + * @param transport + * @private + */ + /* private */ _maybeUpgrade(transport: Transport) { + if (this.upgradeState === "upgrading") { + getLogger("engine.io").debug( + "[socket] transport has already been trying to upgrade", + ); + return transport.close(); + } else if (this.upgradeState === "upgraded") { + getLogger("engine.io").debug( + "[socket] transport has already been upgraded", + ); + return transport.close(); + } + + getLogger("engine.io").debug("[socket] upgrading existing transport"); + this.upgradeState = "upgrading"; + + const timeoutId = setTimeout(() => { + getLogger("engine.io").debug( + "[socket] client did not complete upgrade - closing transport", + ); + transport.close(); + }, this.opts.upgradeTimeout); + + transport.on("close", () => { + transport.off(); + }); + + transport.on("packet", (packet) => { + if (packet.type === "ping" && packet.data === "probe") { + getLogger("engine.io").debug( + "[socket] got probe ping packet, sending pong", + ); + transport.send([{ type: "pong", data: "probe" }]); + this.emitReserved("upgrading", transport); + } else if (packet.type === "upgrade") { + getLogger("engine.io").debug("[socket] got upgrade packet - upgrading"); + + this.upgradeState = "upgraded"; + + clearTimeout(timeoutId); + transport.off(); + this.closeTransport(); + this.bindTransport(transport); + + this.emitReserved("upgrade", transport); + this.flush(); + } else { + getLogger("engine.io").debug("[socket] invalid upgrade packet"); + + clearTimeout(timeoutId); + transport.close(); + } + }); + } + + /** + * Called upon transport considered closed. + * + * @param reason + * @private + */ + private onClose(reason: CloseReason) { + if (this.readyState === "closed") { + return; + } + getLogger("engine.io").debug(`[socket] socket closed due to ${reason}`); + + this.readyState = "closed"; + clearTimeout(this.pingIntervalTimerId); + clearTimeout(this.pingTimeoutTimerId); + + this.closeTransport(); + this.emitReserved("close", reason); + } + + /** + * Sends a "message" packet. + * + * @param data + */ + public send(data: RawData): Socket { + this.sendPacket("message", data); + return this; + } + + /** + * Sends a packet. + * + * @param type + * @param data + * @private + */ + private sendPacket( + type: PacketType, + data?: RawData, + ) { + if (["closing", "closed"].includes(this.readyState)) { + return; + } + + getLogger("engine.io").debug(`[socket] sending packet ${type} (${data})`); + + const packet: Packet = { + type, + data, + }; + + this.emitReserved("packetCreate", packet); + + this.writeBuffer.push(packet); + + this.flush(); + } + + /** + * Attempts to flush the packets buffer. + * + * @private + */ + private flush() { + const shouldFlush = this.readyState !== "closed" && + this.transport.writable && + this.writeBuffer.length > 0; + + if (!shouldFlush) { + return; + } + + getLogger("engine.io").debug( + `[socket] flushing buffer with ${this.writeBuffer.length} packet(s) to transport`, + ); + + this.emitReserved("flush", this.writeBuffer); + + const buffer = this.writeBuffer; + this.writeBuffer = []; + + this.transport.send(buffer); + this.emitReserved("drain"); + } + + /** + * Closes the socket and underlying transport. + */ + public close() { + if (this.readyState !== "open") { + return; + } + + this.readyState = "closing"; + + const close = () => { + this.closeTransport(); + this.onClose("forced close"); + }; + + if (this.writeBuffer.length) { + getLogger("engine.io").debug( + `[socket] buffer not empty, waiting for the drain event`, + ); + this.once("drain", close); + } else { + close(); + } + } + + /** + * Closes the underlying transport. + * + * @private + */ + private closeTransport() { + this.transport.off(); + this.transport.close(); + } +} diff --git a/packages/engine.io/lib/transport.ts b/packages/engine.io/lib/transport.ts new file mode 100644 index 0000000..40cbfcd --- /dev/null +++ b/packages/engine.io/lib/transport.ts @@ -0,0 +1,125 @@ +import { EventEmitter } from "../../event-emitter/mod.ts"; +import { Packet, Parser, RawData } from "../../engine.io-parser/mod.ts"; +import { getLogger } from "../../../deps.ts"; +import { ServerOptions } from "./server.ts"; + +interface TransportEvents { + packet: (packet: Packet) => void; + error: (error: TransportError) => void; + drain: () => void; + close: () => void; +} + +type ReadyState = "open" | "closing" | "closed"; + +export abstract class Transport extends EventEmitter< + never, + never, + TransportEvents +> { + public writable = false; + + protected readyState: ReadyState = "open"; + protected readonly opts: ServerOptions; + + constructor(opts: ServerOptions) { + super(); + this.opts = opts; + } + + /** + * The name of the transport + */ + public abstract get name(): string; + + /** + * The list of transports to upgrade to + */ + public abstract get upgradesTo(): string[]; + + /** + * Called with an incoming HTTP request. + * + * @param req + * @param responseHeaders + */ + public abstract onRequest( + req: Request, + responseHeaders: Headers, + ): Promise<Response>; + + /** + * Writes an array of packets. + * + * @param packets + */ + public abstract send(packets: Packet[]): void; + + /** + * Closes the transport. + * + * @protected + */ + protected abstract doClose(): void; + + /** + * Manually closes the transport. + */ + public close() { + if (["closing", "closed"].includes(this.readyState)) { + return; + } + + getLogger("engine.io").debug("[transport] closing transport"); + this.readyState = "closing"; + this.doClose(); + } + + /** + * Called when the transport encounters a fatal error. + * + * @param message + * @protected + */ + protected onError(message: string) { + this.emitReserved("error", new TransportError(message)); + } + + /** + * Called with a parsed packet from the data stream. + * + * @param packet + * @protected + */ + protected onPacket(packet: Packet) { + if (packet.type === "close") { + getLogger("engine.io").debug("[transport] received 'close' packet"); + return this.doClose(); + } + this.emitReserved("packet", packet); + } + + /** + * Called with the encoded packet data. + * + * @param data + * @protected + */ + protected onData(data: RawData) { + this.onPacket(Parser.decodePacket(data)); + } + + /** + * Called upon transport close. + * + * @protected + */ + protected onClose() { + this.readyState = "closed"; + this.emitReserved("close"); + } +} + +export class TransportError extends Error { + public readonly type = "TransportError"; +} diff --git a/packages/engine.io/lib/transports/polling.ts b/packages/engine.io/lib/transports/polling.ts new file mode 100644 index 0000000..894bef1 --- /dev/null +++ b/packages/engine.io/lib/transports/polling.ts @@ -0,0 +1,157 @@ +import { getLogger } from "../../../../deps.ts"; +import { Transport } from "../transport.ts"; +import { Packet, Parser } from "../../../engine.io-parser/mod.ts"; + +export class Polling extends Transport { + private pollingPromise?: { + resolve: (res: Response) => void; + reject: () => void; + responseHeaders: Headers; + }; + + public get name() { + return "polling"; + } + + public get upgradesTo(): string[] { + return ["websocket"]; + } + + public onRequest(req: Request, responseHeaders: Headers): Promise<Response> { + if (req.method === "GET") { + return this.onPollRequest(req, responseHeaders); + } else if (req.method === "POST") { + return this.onDataRequest(req, responseHeaders); + } + return Promise.resolve( + new Response(null, { status: 400, headers: responseHeaders }), + ); + } + + /** + * The client sends a long-polling request awaiting the server to send data. + * + * @param req + * @param responseHeaders + * @private + */ + private onPollRequest( + req: Request, + responseHeaders: Headers, + ): Promise<Response> { + if (this.pollingPromise) { + getLogger("engine.io").debug("[polling] request overlap"); + this.onError("overlap from client"); + return Promise.resolve( + new Response(null, { status: 400, headers: responseHeaders }), + ); + } + + req.signal.addEventListener("abort", () => { + // note: this gets never triggered + this.onError("poll connection closed prematurely"); + }); + + getLogger("engine.io").debug( + "[polling] new polling request", + ); + + return new Promise<Response>((resolve, reject) => { + this.pollingPromise = { resolve, reject, responseHeaders }; + + getLogger("engine.io").debug("[polling] transport is now writable"); + this.writable = true; + this.emitReserved("drain"); + }); + } + + /** + * The client sends a request with data. + * + * @param req + * @param responseHeaders + */ + private async onDataRequest( + req: Request, + responseHeaders: Headers, + ): Promise<Response> { + req.signal.addEventListener("abort", () => { + // note: this gets never triggered + this.onError("data request connection closed prematurely"); + }); + + getLogger("engine.io").debug( + "[polling] new data request", + ); + + const data = await req.text(); + + if (data.length > this.opts.maxHttpBufferSize) { + this.onError("payload too large"); + return Promise.resolve( + new Response(null, { status: 413, headers: responseHeaders }), + ); + } + + const packets = Parser.decodePayload(data); + + getLogger("engine.io").debug( + `[polling] decoded ${packets.length} packet(s)`, + ); + + for (const packet of packets) { + this.onPacket(packet); + } + + return Promise.resolve( + new Response("OK", { + status: 200, + headers: responseHeaders, + }), + ); + } + + public send(packets: Packet[]) { + this.writable = false; + Parser.encodePayload(packets, (data: string) => this.write(data)); + } + + /** + * Writes data as response to long-polling request + * + * @param data + * @private + */ + private write(data: string) { + getLogger("engine.io").debug(`[polling] writing ${data}`); + + if (!this.pollingPromise) { + return; + } + + const headers = this.pollingPromise.responseHeaders; + headers.set("Content-Type", "text/plain; charset=UTF-8"); + + // note: the HTTP server automatically handles the compression + // see https://deno.land/manual@v1.24.3/runtime/http_server_apis#automatic-body-compression + this.pollingPromise.resolve( + new Response(data, { + status: 200, + headers, + }), + ); + + this.pollingPromise = undefined; + } + + protected doClose() { + if (this.writable) { + getLogger("engine.io").debug( + "[polling] transport writable - closing right away", + ); + this.send([{ type: "close" }]); + } + + this.onClose(); + } +} diff --git a/packages/engine.io/lib/transports/websocket.ts b/packages/engine.io/lib/transports/websocket.ts new file mode 100644 index 0000000..1d2dea4 --- /dev/null +++ b/packages/engine.io/lib/transports/websocket.ts @@ -0,0 +1,69 @@ +import { getLogger } from "../../../../deps.ts"; +import { Transport } from "../transport.ts"; +import { Packet, Parser, RawData } from "../../../engine.io-parser/mod.ts"; + +export class WS extends Transport { + private socket?: WebSocket; + + public get name() { + return "websocket"; + } + + public get upgradesTo(): string[] { + return []; + } + + public send(packets: Packet[]) { + for (const packet of packets) { + Parser.encodePacket(packet, true, (data: RawData) => { + if (this.writable) { + this.socket?.send(data); + } + }); + } + } + + public onRequest(req: Request, responseHeaders: Headers): Promise<Response> { + const { socket, response } = Deno.upgradeWebSocket(req); + + this.socket = socket; + + socket.onopen = () => { + getLogger("engine.io").debug( + "[websocket] transport is now writable", + ); + this.writable = true; + this.emitReserved("drain"); + }; + + socket.onmessage = ({ data }) => { + // note: we use the length of the string here, which might be different from the number of bytes (up to 4 bytes) + const byteLength = typeof data === "string" + ? data.length + : data.byteLength; + if (byteLength > this.opts.maxHttpBufferSize) { + return this.onError("payload too large"); + } else { + this.onData(data); + } + }; + + socket.onclose = (closeEvent) => { + getLogger("engine.io").debug( + `[websocket] onclose with code ${closeEvent.code}`, + ); + this.writable = false; + this.onClose(); + }; + + responseHeaders.forEach((value, key) => { + response.headers.set(key, value); + }); + + return Promise.resolve(response); + } + + protected doClose() { + this.socket?.close(); + } +} diff --git a/packages/engine.io/lib/util.ts b/packages/engine.io/lib/util.ts new file mode 100644 index 0000000..11d83d1 --- /dev/null +++ b/packages/engine.io/lib/util.ts @@ -0,0 +1,7 @@ +import { encodeToBase64 } from "../../engine.io-parser/base64-arraybuffer.ts"; + +export function generateId(): string { + const buffer = new Uint8Array(15); + crypto.getRandomValues(buffer); + return encodeToBase64(buffer).replace(/\//g, "-").replace(/\+/g, "_"); +} diff --git a/packages/engine.io/mod.ts b/packages/engine.io/mod.ts new file mode 100644 index 0000000..d122bf9 --- /dev/null +++ b/packages/engine.io/mod.ts @@ -0,0 +1,3 @@ +export { Server, type ServerOptions } from "./lib/server.ts"; +export { type Socket } from "./lib/socket.ts"; +export { generateId } from "./lib/util.ts"; diff --git a/packages/engine.io/test/close.test.ts b/packages/engine.io/test/close.test.ts new file mode 100644 index 0000000..a6d9e2b --- /dev/null +++ b/packages/engine.io/test/close.test.ts @@ -0,0 +1,328 @@ +import { assertEquals, describe, it } from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { + enableLogs, + parseSessionID, + sleep, + testServe, + testServeWithAsyncResults, +} from "./util.ts"; + +await enableLogs(); + +describe("close", () => { + it("should trigger upon ping timeout (polling)", () => { + const engine = new Server({ + pingInterval: 5, + pingTimeout: 5, + }); + + return testServe(engine, async (port) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "ping timeout"); + }); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + await sleep(10); + + const pollResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "get", + }, + ); + + assertEquals(pollResponse.status, 400); + + // consume the response body + await pollResponse.body?.cancel(); + }); + }); + + it("should trigger upon ping timeout (ws)", () => { + const engine = new Server({ + pingInterval: 5, + pingTimeout: 5, + }); + + return testServeWithAsyncResults(engine, 2, (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "ping timeout"); + + partialDone(); + }); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.onclose = partialDone; + }); + }); + + it("should trigger when the server closes the socket (polling)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "forced close"); + + partialDone(); + }); + + setTimeout(() => socket.close(), 10); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const pollResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "get", + }, + ); + + assertEquals(pollResponse.status, 200); + + const body = await pollResponse.text(); + + assertEquals(body, "1"); + + partialDone(); + }); + }); + + it("should trigger when the server closes the socket (ws)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "forced close"); + + partialDone(); + }); + + socket.close(); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.onopen = () => { + socket.onclose = partialDone; + }; + }); + }); + + it("should trigger when the client sends a 'close' packet (polling)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "transport close"); + + partialDone(); + }); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: "1", + }, + ); + + assertEquals(dataResponse.status, 200); + + const body = await dataResponse.text(); + + assertEquals(body, "OK"); + + partialDone(); + }); + }); + + it("should trigger when the client sends a 'close' packet (ws)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "transport close"); + + partialDone(); + }); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.onmessage = () => { + socket.send("1"); + socket.onclose = partialDone; + }; + }); + }); + + it.ignore("should trigger when the client closes the connection (polling)", () => { + // TODO + }); + + it("should trigger when the client closes the connection (ws)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "transport close"); + partialDone(); + }); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.onmessage = () => { + socket.close(); + partialDone(); + }; + }); + }); + + it("should trigger when the client sends ill-formatted data", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "parse error"); + + partialDone(); + }); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: "abc", + }, + ); + + assertEquals(dataResponse.status, 200); + + const body = await dataResponse.text(); + + assertEquals(body, "OK"); + + partialDone(); + }); + }); + + it("should trigger when the client sends a payload bigger than maxHttpBufferSize (polling)", () => { + const engine = new Server({ + maxHttpBufferSize: 100, + }); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "transport error"); + + partialDone(); + }); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: "a".repeat(101), + }, + ); + + assertEquals(dataResponse.status, 413); + + // consume the response body + await dataResponse.body?.cancel(); + + partialDone(); + }); + }); + + it("should trigger when the client sends a payload bigger than maxHttpBufferSize (ws)", () => { + const engine = new Server({ + maxHttpBufferSize: 100, + }); + + return testServeWithAsyncResults(engine, 1, (port, done) => { + engine.on("connection", (socket) => { + socket.on("close", (reason) => { + assertEquals(reason, "transport error"); + done(); + }); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.onmessage = () => { + socket.send("b".repeat(101)); + }; + }); + }); +}); diff --git a/packages/engine.io/test/cors.test.ts b/packages/engine.io/test/cors.test.ts new file mode 100644 index 0000000..80e4373 --- /dev/null +++ b/packages/engine.io/test/cors.test.ts @@ -0,0 +1,190 @@ +import { assertEquals, describe, it } from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { enableLogs, parseSessionID, testServe } from "./util.ts"; + +await enableLogs(); + +describe("CORS", () => { + it("should send the CORS headers for an authorized origin (preflight request)", () => { + const engine = new Server({ + cors: { + origin: ["https://example.com"], + exposedHeaders: ["my-other-header"], + credentials: true, + methods: ["GET", "POST"], + allowedHeaders: ["my-header"], + maxAge: 42, + }, + }); + + return testServe(engine, async (port) => { + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "OPTIONS", + headers: { + origin: "https://example.com", + }, + }, + ); + + assertEquals(response.status, 204); + assertEquals( + response.headers.get("Access-Control-Allow-Origin"), + "https://example.com", + ); + assertEquals( + response.headers.get("Access-Control-Allow-Credentials"), + "true", + ); + assertEquals( + response.headers.get("Access-Control-Expose-Headers"), + "my-other-header", + ); + assertEquals( + response.headers.get("Access-Control-Allow-Methods"), + "GET,POST", + ); + assertEquals( + response.headers.get("Access-Control-Allow-Headers"), + "my-header", + ); + assertEquals(response.headers.get("Access-Control-Max-Age"), "42"); + + // consume the response body + await response.body?.cancel(); + }); + }); + + it("should send the CORS headers for an authorized origin (actual request)", () => { + const engine = new Server({ + cors: { + origin: ["https://example.com"], + exposedHeaders: ["my-other-header"], + credentials: true, + methods: ["GET", "POST"], + allowedHeaders: ["my-header"], + maxAge: 42, + }, + }); + + return testServe(engine, async (port) => { + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "GET", + headers: { + origin: "https://example.com", + }, + }, + ); + + assertEquals(response.status, 200); + assertEquals( + response.headers.get("Access-Control-Allow-Origin"), + "https://example.com", + ); + assertEquals( + response.headers.get("Access-Control-Allow-Credentials"), + "true", + ); + assertEquals( + response.headers.get("Access-Control-Expose-Headers"), + "my-other-header", + ); + assertEquals(response.headers.has("Access-Control-Allow-Methods"), false); + assertEquals(response.headers.has("Access-Control-Allow-Headers"), false); + assertEquals(response.headers.has("Access-Control-Max-Age"), false); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "POST", + body: "1", + headers: { + origin: "https://example.com", + }, + }, + ); + + assertEquals(dataResponse.status, 200); + assertEquals( + dataResponse.headers.get("Access-Control-Allow-Origin"), + "https://example.com", + ); + assertEquals( + dataResponse.headers.get("Access-Control-Allow-Credentials"), + "true", + ); + assertEquals( + dataResponse.headers.get("Access-Control-Expose-Headers"), + "my-other-header", + ); + assertEquals( + dataResponse.headers.has("Access-Control-Allow-Methods"), + false, + ); + assertEquals( + dataResponse.headers.has("Access-Control-Allow-Headers"), + false, + ); + assertEquals(dataResponse.headers.has("Access-Control-Max-Age"), false); + + // consume the response body + await dataResponse.body?.cancel(); + }); + }); + + it("should not send the CORS headers for an unauthorized origin", () => { + const engine = new Server({ + cors: { + origin: ["https://example.com"], + exposedHeaders: ["my-other-header"], + credentials: true, + methods: ["GET", "POST"], + allowedHeaders: ["my-header"], + maxAge: 42, + }, + }); + + return testServe(engine, async (port) => { + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "OPTIONS", + headers: { + origin: "https://wrong-domain.com", + }, + }, + ); + + assertEquals(response.status, 204); + assertEquals( + response.headers.get("Access-Control-Allow-Origin"), + "false", + ); + assertEquals( + response.headers.get("Access-Control-Allow-Credentials"), + "true", + ); + assertEquals( + response.headers.get("Access-Control-Expose-Headers"), + "my-other-header", + ); + assertEquals( + response.headers.get("Access-Control-Allow-Methods"), + "GET,POST", + ); + assertEquals( + response.headers.get("Access-Control-Allow-Headers"), + "my-header", + ); + assertEquals(response.headers.get("Access-Control-Max-Age"), "42"); + + // consume the response body + await response.body?.cancel(); + }); + }); +}); diff --git a/packages/engine.io/test/handshake.test.ts b/packages/engine.io/test/handshake.test.ts new file mode 100644 index 0000000..f606495 --- /dev/null +++ b/packages/engine.io/test/handshake.test.ts @@ -0,0 +1,124 @@ +import { + assertEquals, + assertExists, + describe, + it, +} from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { enableLogs, testServe, testServeWithAsyncResults } from "./util.ts"; + +await enableLogs(); + +describe("handshake", () => { + it("should send handshake data (polling)", () => { + const engine = new Server(); + + return testServe(engine, async (port) => { + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.status, 200); + + const body = await response.text(); + assertEquals(body[0], "0"); + + const handshake = JSON.parse(body.substring(1)); + assertExists(handshake.sid); + assertEquals(handshake.pingTimeout, 20000); + assertEquals(handshake.pingInterval, 25000); + assertEquals(handshake.upgrades, ["websocket"]); + assertEquals(handshake.maxPayload, 1000000); + }); + }); + + it("should send handshake data with custom values (polling)", () => { + const engine = new Server({ + pingInterval: 100, + pingTimeout: 200, + maxHttpBufferSize: 300, + }); + + return testServe(engine, async (port) => { + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.status, 200); + + const body = await response.text(); + assertEquals(body[0], "0"); + + const handshake = JSON.parse(body.substring(1)); + assertExists(handshake.sid); + assertEquals(handshake.pingInterval, 100); + assertEquals(handshake.pingTimeout, 200); + assertEquals(handshake.maxPayload, 300); + }); + }); + + it("should send handshake data (ws)", () => { + const engine = new Server(); + + return testServeWithAsyncResults( + engine, + 1, + (port, done) => { + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.onmessage = (event) => { + assertEquals(event.data[0], "0"); + + const handshake = JSON.parse(event.data.substring(1)); + assertExists(handshake.sid); + assertEquals(handshake.pingTimeout, 20000); + assertEquals(handshake.pingInterval, 25000); + assertEquals(handshake.upgrades, []); + assertEquals(handshake.maxPayload, 1000000); + + socket.close(); + done(); + }; + }, + ); + }); + + it("should trigger a connection event", () => { + const engine = new Server(); + + return testServeWithAsyncResults( + engine, + 2, + async (port, partialDone) => { + engine.on("connection", (socket) => { + assertExists(socket.id); + assertEquals(socket.transport.name, "polling"); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.status, 200); + + // consume the response body + await response.body?.cancel(); + + partialDone(); + }, + ); + }); +}); diff --git a/packages/engine.io/test/heartbeat.test.ts b/packages/engine.io/test/heartbeat.test.ts new file mode 100644 index 0000000..37cd0d8 --- /dev/null +++ b/packages/engine.io/test/heartbeat.test.ts @@ -0,0 +1,86 @@ +import { assertEquals, describe, it } from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { + enableLogs, + parseSessionID, + testServe, + testServeWithAsyncResults, +} from "./util.ts"; + +await enableLogs(); + +describe("heartbeat", () => { + it("should keep the connection alive (polling)", () => { + const engine = new Server({ + pingInterval: 10, + pingTimeout: 25, + }); + + return testServe(engine, async (port) => { + const handshake = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(handshake); + const HEARTBEAT_COUNT = 10; + + for (let i = 0; i < HEARTBEAT_COUNT; i++) { + const pollResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "get", + }, + ); + + assertEquals(pollResponse.status, 200); + assertEquals(await pollResponse.text(), "2"); + + const dataResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: "3", + }, + ); + + assertEquals(dataResponse.status, 200); + // consume the response body + await dataResponse.body?.cancel(); + } + }); + }); + + it("should keep the connection alive (ws)", () => { + const engine = new Server({ + pingInterval: 10, + pingTimeout: 25, + }); + + return testServeWithAsyncResults(engine, 1, (port, done) => { + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + let i = 0; + const HEARTBEAT_COUNT = 10; + + socket.onmessage = ({ data }) => { + switch (i++) { + case 0: + // ignore handshake + break; + case HEARTBEAT_COUNT: + socket.close(); + done(); + break; + default: + assertEquals(data, "2"); + socket.send("3"); + } + }; + }); + }); +}); diff --git a/packages/engine.io/test/messages.test.ts b/packages/engine.io/test/messages.test.ts new file mode 100644 index 0000000..531b712 --- /dev/null +++ b/packages/engine.io/test/messages.test.ts @@ -0,0 +1,332 @@ +import { + assertEquals, + assertInstanceOf, + describe, + it, +} from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { + enableLogs, + parseSessionID, + testServe, + testServeWithAsyncResults, +} from "./util.ts"; + +await enableLogs(); + +describe("messages", () => { + it("should arrive from server to client (polling, plain-text)", () => { + const engine = new Server(); + + return testServe(engine, async (port) => { + engine.on("connection", (socket) => { + socket.send("hello €亜Б"); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const pollResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "get", + }, + ); + + assertEquals(pollResponse.status, 200); + + const body = await pollResponse.text(); + + assertEquals(body, "4hello €亜Б"); + }); + }); + + it("should arrive from server to client (polling, binary)", () => { + const engine = new Server(); + + return testServe(engine, async (port) => { + engine.on("connection", (socket) => { + socket.send(Uint8Array.from([1, 2, 3, 4])); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const pollResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "get", + }, + ); + + assertEquals(pollResponse.status, 200); + + const body = await pollResponse.text(); + + assertEquals(body, "bAQIDBA=="); + }); + }); + + it("should arrive from server to client (polling, mixed)", () => { + const engine = new Server(); + + return testServe(engine, async (port) => { + engine.on("connection", (socket) => { + socket.send(Uint8Array.from([1, 2, 3, 4])); + socket.send("hello €亜Б"); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const pollResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "get", + }, + ); + + assertEquals(pollResponse.status, 200); + + const body = await pollResponse.text(); + + assertEquals(body, "bAQIDBA==\x1e4hello €亜Б"); + }); + }); + + it("should arrive from server to client (ws, plain-text)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 1, (port, done) => { + engine.on("connection", (socket) => { + socket.send("hello €亜Б"); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.onmessage = ({ data }) => { + if (typeof data === "string" && data[0] === "0") { + // ignore handshake + return; + } + + assertEquals(data, "4hello €亜Б"); + + done(); + }; + }); + }); + + it("should arrive from server to client (ws, binary)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 1, (port, done) => { + engine.on("connection", (socket) => { + socket.send(Uint8Array.from([1, 2, 3, 4])); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.binaryType = "arraybuffer"; + + socket.onmessage = ({ data }) => { + if (typeof data === "string" && data[0] === "0") { + // ignore handshake + return; + } + + assertEquals(new Uint8Array(data), Uint8Array.from([1, 2, 3, 4])); + + done(); + }; + }); + }); + + it("should arrive from client to server (polling, plain-text)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("message", (val) => { + assertEquals(val, "hello €亜Б"); + partialDone(); + }); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: "4hello €亜Б", + }, + ); + + // consume the response body + await dataResponse.body?.cancel(); + + partialDone(); + }); + }); + + it("should arrive from client to server (polling, binary)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("message", (val) => { + assertInstanceOf(val, ArrayBuffer); + assertEquals( + new Uint8Array(val), + Uint8Array.from([1, 2, 3, 4]), + ); + + partialDone(); + }); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: "bAQIDBA==", + }, + ); + + // consume the response body + await dataResponse.body?.cancel(); + + partialDone(); + }); + }); + + it("should arrive from client to server (polling, mixed)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 3, async (port, partialDone) => { + engine.on("connection", (socket) => { + let count = 0; + + socket.on("message", (val) => { + if (++count === 1) { + assertInstanceOf(val, ArrayBuffer); + assertEquals(new Uint8Array(val), Uint8Array.from([1, 2, 3, 4])); + } else { + assertEquals(val, "hello €亜Б"); + } + partialDone(); + }); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: "bAQIDBA==\x1e4hello €亜Б", + }, + ); + + // consume the response body + await dataResponse.body?.cancel(); + + partialDone(); + }); + }); + + it("should arrive from client to server (ws, plain-text)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 1, (port, done) => { + engine.on("connection", (socket) => { + socket.on("message", (val) => { + assertEquals(val, "hello €亜Б"); + done(); + }); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.onmessage = () => { + socket.send("4hello €亜Б"); + }; + }); + }); + + it("should arrive from client to server (ws, binary)", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 1, (port, done) => { + engine.on("connection", (socket) => { + socket.on("message", (val) => { + assertInstanceOf(val, ArrayBuffer); + assertEquals( + new Uint8Array(val), + Uint8Array.from([1, 2, 3, 4]), + ); + + done(); + }); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.binaryType = "arraybuffer"; + + socket.onmessage = () => { + socket.send(Uint8Array.from([1, 2, 3, 4])); + }; + }); + }); +}); diff --git a/packages/engine.io/test/response_headers.test.ts b/packages/engine.io/test/response_headers.test.ts new file mode 100644 index 0000000..3d80299 --- /dev/null +++ b/packages/engine.io/test/response_headers.test.ts @@ -0,0 +1,46 @@ +import { assertEquals, describe, it } from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { enableLogs, parseSessionID, testServe } from "./util.ts"; + +await enableLogs(); + +describe("response headers", () => { + it("should send custom response headers", () => { + const engine = new Server({ + editHandshakeHeaders: (responseHeaders) => { + responseHeaders.set("abc", "123"); + }, + editResponseHeaders: (responseHeaders) => { + responseHeaders.set("def", "456"); + }, + }); + + return testServe(engine, async (port) => { + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.headers.get("abc"), "123"); + assertEquals(response.headers.get("def"), "456"); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: "4hello", + }, + ); + + assertEquals(dataResponse.headers.has("abc"), false); + assertEquals(dataResponse.headers.get("def"), "456"); + + // consume the response body + await dataResponse.body?.cancel(); + }); + }); +}); diff --git a/packages/engine.io/test/upgrade.test.ts b/packages/engine.io/test/upgrade.test.ts new file mode 100644 index 0000000..0d64ef8 --- /dev/null +++ b/packages/engine.io/test/upgrade.test.ts @@ -0,0 +1,91 @@ +import { assertEquals, describe, it } from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { + enableLogs, + parseSessionID, + testServeWithAsyncResults, +} from "./util.ts"; + +await enableLogs(); + +describe("upgrade", () => { + it("should upgrade", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("message", (val) => { + assertEquals(val, "upgraded!"); + + partialDone(); + }); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket&sid=${sid}`, + ); + + socket.onopen = () => { + socket.send("2probe"); // ping packet + }; + + socket.onmessage = ({ data }) => { + assertEquals(data, "3probe"); // pong packet + socket.send("5"); // upgrade packet + socket.send("4upgraded!"); + + partialDone(); + }; + }); + }); + + it("should timeout if the upgrade takes too much time", () => { + const engine = new Server({ + upgradeTimeout: 5, + }); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection", (socket) => { + socket.on("upgrading", (transport) => { + transport.on("close", partialDone); + }); + + socket.on("upgrade", () => { + throw "should not happen"; + }); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket&sid=${sid}`, + ); + + socket.onopen = () => { + socket.send("2probe"); // ping packet + }; + + socket.onmessage = ({ data }) => { + assertEquals(data, "3probe"); // pong packet + + partialDone(); + }; + }); + }); +}); diff --git a/packages/engine.io/test/util.ts b/packages/engine.io/test/util.ts new file mode 100644 index 0000000..c3c6b2c --- /dev/null +++ b/packages/engine.io/test/util.ts @@ -0,0 +1,86 @@ +import { Server } from "../lib/server.ts"; +import * as log from "../../../test_deps.ts"; +import { serve } from "../../../test_deps.ts"; + +export function testServe( + engine: Server, + callback: (port: number) => Promise<void>, +): Promise<void> { + return new Promise((resolve) => { + const abortController = new AbortController(); + + serve(engine.handler(), { + onListen: async ({ port }) => { + await callback(port); + + // close the server + abortController.abort(); + engine.close(); + setTimeout(resolve, 5); + }, + signal: abortController.signal, + }); + }); +} + +function createPartialDone( + count: number, + resolve: () => void, + reject: (reason: string) => void, +) { + let i = 0; + return () => { + if (++i === count) { + resolve(); + } else if (i > count) { + reject(`called too many times: ${i} > ${count}`); + } + }; +} + +export function testServeWithAsyncResults( + engine: Server, + count: number, + callback: (port: number, partialDone: () => void) => Promise<void> | void, +): Promise<void> { + return new Promise((resolve, reject) => { + const abortController = new AbortController(); + + serve(engine.handler(), { + onListen: ({ port }) => { + const partialDone = createPartialDone(count, () => { + // close the server + abortController.abort(); + engine.close(); + setTimeout(resolve, 10); + }, reject); + + return callback(port, partialDone); + }, + signal: abortController.signal, + }); + }); +} + +export function sleep(duration: number): Promise<void> { + return new Promise((resolve) => setTimeout(resolve, duration)); +} + +export async function parseSessionID(response: Response): Promise<string> { + const body = await response.text(); + return JSON.parse(body.substring(1)).sid; +} + +export function enableLogs() { + return log.setup({ + handlers: { + console: new log.handlers.ConsoleHandler("DEBUG"), + }, + loggers: { + "engine.io": { + level: "ERROR", // set to "DEBUG" to display the logs + handlers: ["console"], + }, + }, + }); +} diff --git a/packages/engine.io/test/verification.test.ts b/packages/engine.io/test/verification.test.ts new file mode 100644 index 0000000..c0d6d34 --- /dev/null +++ b/packages/engine.io/test/verification.test.ts @@ -0,0 +1,241 @@ +import { + assertEquals, + assertExists, + describe, + it, +} from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { enableLogs, testServe, testServeWithAsyncResults } from "./util.ts"; + +await enableLogs(); + +describe("verification", () => { + it("should ignore requests that do not match the given path", () => { + const engine = new Server(); + + return testServe(engine, async (port) => { + const response = await fetch(`http://localhost:${port}/test/`, { + method: "get", + }); + + assertEquals(response.status, 404); + + // consume the response body + await response.body?.cancel(); + }); + }); + + it("should disallow non-existent transports", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection_error", (err) => { + assertEquals(err.code, 0); + assertEquals(err.message, "Transport unknown"); + assertEquals(err.context.transport, "tobi"); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?transport=tobi`, + { + method: "get", + }, + ); + + assertEquals(response.status, 400); + + const body = await response.json(); + assertEquals(body.code, 0); + assertEquals(body.message, "Transport unknown"); + + partialDone(); + }); + }); + + it("should disallow `constructor` as transports", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection_error", (err) => { + assertEquals(err.code, 0); + assertEquals(err.message, "Transport unknown"); + assertEquals(err.context.transport, "constructor"); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?transport=constructor`, + { + method: "get", + }, + ); + + assertEquals(response.status, 400); + + const body = await response.json(); + assertEquals(body.code, 0); + assertEquals(body.message, "Transport unknown"); + + partialDone(); + }); + }); + + it("should disallow non-existent sids", () => { + const engine = new Server(); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection_error", (err) => { + assertEquals(err.code, 1); + assertEquals(err.message, "Session ID unknown"); + assertEquals(err.context.sid, "test"); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?transport=polling&sid=test`, + { + method: "get", + }, + ); + + assertEquals(response.status, 400); + + const body = await response.json(); + assertEquals(body.code, 1); + assertEquals(body.message, "Session ID unknown"); + + partialDone(); + }); + }); + + it("should disallow requests that are rejected by `allowRequest` (polling)", () => { + const engine = new Server({ + allowRequest: () => { + return Promise.reject("Thou shall not pass"); + }, + }); + + return testServeWithAsyncResults(engine, 2, async (port, partialDone) => { + engine.on("connection_error", (err) => { + assertExists(err.req); + assertEquals(err.code, 4); + assertEquals(err.message, "Forbidden"); + assertEquals(err.context.message, "Thou shall not pass"); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.status, 403); + + const body = await response.json(); + assertEquals(body.code, 4); + assertEquals(body.message, "Thou shall not pass"); + + partialDone(); + }); + }); + + it("should disallow requests that are rejected by `allowRequest` (ws)", () => { + const engine = new Server({ + allowRequest: () => { + return Promise.reject("Thou shall not pass"); + }, + }); + + return testServeWithAsyncResults(engine, 2, (port, partialDone) => { + engine.on("connection_error", (err) => { + assertExists(err.req); + assertEquals(err.code, 4); + assertEquals(err.message, "Forbidden"); + assertEquals(err.context.message, "Thou shall not pass"); + + partialDone(); + }); + + const socket = new WebSocket( + `ws://localhost:${port}/engine.io/?EIO=4&transport=websocket`, + ); + + socket.onclose = partialDone; + }); + }); + + it("should disallow invalid handshake method", () => { + const engine = new Server(); + + return testServeWithAsyncResults( + engine, + 2, + async (port, partialDone) => { + engine.on("connection_error", (err) => { + assertExists(err.req); + assertEquals(err.code, 2); + assertEquals(err.message, "Bad handshake method"); + assertEquals(err.context.method, "PUT"); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?transport=polling`, + { + method: "put", + }, + ); + + assertEquals(response.status, 400); + + const body = await response.json(); + assertEquals(body.code, 2); + assertEquals(body.message, "Bad handshake method"); + + partialDone(); + }, + ); + }); + + it("should disallow unsupported protocol versions", () => { + const engine = new Server(); + + return testServeWithAsyncResults( + engine, + 2, + async (port, partialDone) => { + engine.on("connection_error", (err) => { + assertExists(err.req); + assertEquals(err.code, 5); + assertEquals(err.message, "Unsupported protocol version"); + assertEquals(err.context.protocol, 3); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/engine.io/?EIO=3&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.status, 400); + + const body = await response.json(); + assertEquals(body.code, 5); + assertEquals(body.message, "Unsupported protocol version"); + + partialDone(); + }, + ); + }); +}); diff --git a/packages/event-emitter/mod.ts b/packages/event-emitter/mod.ts new file mode 100644 index 0000000..10572aa --- /dev/null +++ b/packages/event-emitter/mod.ts @@ -0,0 +1,230 @@ +/** + * An events map is an interface that maps event names to their value, which represents the type of the `on` listener. + */ +export interface EventsMap { + [event: string]: any; +} + +/** + * The default events map, used if no EventsMap is given. Using this EventsMap is equivalent to accepting all event + * names, and any data. + */ +export interface DefaultEventsMap { + [event: string]: (...args: any[]) => void; +} + +/** + * Returns a union type containing all the keys of an event map. + */ +export type EventNames<Map extends EventsMap> = keyof Map & (string | symbol); + +/** The tuple type representing the parameters of an event listener */ +export type EventParams< + Map extends EventsMap, + Ev extends EventNames<Map>, +> = Parameters<Map[Ev]>; + +/** + * The event names that are either in ReservedEvents or in UserEvents + */ +export type ReservedOrUserEventNames< + ReservedEventsMap extends EventsMap, + UserEvents extends EventsMap, +> = EventNames<ReservedEventsMap> | EventNames<UserEvents>; + +/** + * Type of a listener of a user event or a reserved event. If `Ev` is in `ReservedEvents`, the reserved event listener + * is returned. + */ +export type ReservedOrUserListener< + ReservedEvents extends EventsMap, + UserEvents extends EventsMap, + Ev extends ReservedOrUserEventNames<ReservedEvents, UserEvents>, +> = FallbackToUntypedListener< + Ev extends EventNames<ReservedEvents> ? ReservedEvents[Ev] + : Ev extends EventNames<UserEvents> ? UserEvents[Ev] + : never +>; + +/** + * Returns an untyped listener type if `T` is `never`; otherwise, returns `T`. + * + * Needed because of https://github.com/microsoft/TypeScript/issues/41778 + */ +type FallbackToUntypedListener<T> = [T] extends [never] + ? (...args: any[]) => void | Promise<void> + : T; + +/** + * Strictly typed version of an `EventEmitter`. A `TypedEventEmitter` takes type parameters for mappings of event names + * to event data types, and strictly types method calls to the `EventEmitter` according to these event maps. + * + * @typeParam ListenEvents - `EventsMap` of user-defined events that can be listened to with `on` or `once` + * @typeParam EmitEvents - `EventsMap` of user-defined events that can be emitted with `emit` + * @typeParam ReservedEvents - `EventsMap` of reserved events, that can be emitted with `emitReserved`, and can be + * listened to with `listen`. + */ +abstract class BaseEventEmitter< + ListenEvents extends EventsMap, + EmitEvents extends EventsMap, + ReservedEvents extends EventsMap = never, +> { + private _listeners: Map< + ReservedOrUserEventNames<ReservedEvents, ListenEvents>, + Array<ReservedOrUserListener<ReservedEvents, ListenEvents, any>> + > = new Map(); + + /** + * Adds the `listener` function as an event listener for `ev`. + * + * @param event - Name of the event + * @param listener - Callback function + */ + public on<Ev extends ReservedOrUserEventNames<ReservedEvents, ListenEvents>>( + event: Ev, + listener: ReservedOrUserListener<ReservedEvents, ListenEvents, Ev>, + ): this { + const listeners = this._listeners.get(event); + if (listeners) { + listeners.push(listener); + } else { + this._listeners.set(event, [listener]); + } + return this; + } + + /** + * Adds a one-time `listener` function as an event listener for `ev`. + * + * @param event - Name of the event + * @param listener - Callback function + */ + public once< + Ev extends ReservedOrUserEventNames<ReservedEvents, ListenEvents>, + >( + event: Ev, + listener: ReservedOrUserListener<ReservedEvents, ListenEvents, Ev>, + ): this { + // @ts-ignore force listener type + const onceListener: ReservedOrUserListener< + ReservedEvents, + ListenEvents, + Ev + > = (...args: any[]) => { + this.off(event, onceListener); + listener.apply(this, args); + }; + + // to work with .off(event, listener) + onceListener.fn = listener; + + return this.on(event, onceListener); + } + + /** + * Removes the `listener` function as an event listener for `ev`. + * + * @param event - Name of the event + * @param listener - Callback function + */ + public off<Ev extends ReservedOrUserEventNames<ReservedEvents, ListenEvents>>( + event?: Ev, + listener?: ReservedOrUserListener<ReservedEvents, ListenEvents, Ev>, + ): this { + if (!event) { + this._listeners.clear(); + return this; + } + + if (!listener) { + this._listeners.delete(event); + return this; + } + + const listeners = this._listeners.get(event); + + if (!listeners) { + return this; + } + + for (let i = 0; i < listeners.length; i++) { + if (listeners[i] === listener || listeners[i].fn === listener) { + listeners.splice(i, 1); + break; + } + } + + if (listeners.length === 0) { + this._listeners.delete(event); + } + + return this; + } + + /** + * Emits an event. + * + * @param event - Name of the event + * @param args - Values to send to listeners of this event + */ + public emit<Ev extends EventNames<EmitEvents>>( + event: Ev, + ...args: EventParams<EmitEvents, Ev> + ): boolean { + const listeners = this._listeners.get(event as EventNames<ListenEvents>); + + if (!listeners) { + return false; + } + + if (listeners.length === 1) { + listeners[0].apply(this, args); + } else { + for (const listener of listeners.slice()) { + listener.apply(this, args); + } + } + + return true; + } + + /** + * Returns the listeners listening to an event. + * + * @param event - Event name + * @returns Array of listeners subscribed to `event` + */ + public listeners< + Ev extends ReservedOrUserEventNames<ReservedEvents, ListenEvents>, + >( + event: Ev, + ): ReservedOrUserListener<ReservedEvents, ListenEvents, Ev>[] { + return this._listeners.get(event) || []; + } +} + +/** + * This class extends the BaseEventEmitter abstract class, so a class extending `EventEmitter` can override the `emit` + * method and still call `emitReserved()` (since it uses `super.emit()`) + */ +export class EventEmitter< + ListenEvents extends EventsMap, + EmitEvents extends EventsMap, + ReservedEvents extends EventsMap = never, +> extends BaseEventEmitter<ListenEvents, EmitEvents, ReservedEvents> { + /** + * Emits a reserved event. + * + * This method is `protected`, so that only a class extending `EventEmitter` can emit its own reserved events. + * + * @param event - Reserved event name + * @param args - Arguments to emit along with the event + * @protected + */ + protected emitReserved<Ev extends EventNames<ReservedEvents>>( + event: Ev, + ...args: EventParams<ReservedEvents, Ev> + ): boolean { + return super.emit(event as EventNames<EmitEvents>, ...args); + } +} diff --git a/packages/event-emitter/test.ts b/packages/event-emitter/test.ts new file mode 100644 index 0000000..1fa4ee1 --- /dev/null +++ b/packages/event-emitter/test.ts @@ -0,0 +1,245 @@ +import { EventEmitter } from "./mod.ts"; +import { assertEquals, describe, it } from "../../test_deps.ts"; + +describe("EventEmitter", () => { + describe(".on(event, fn)", () => { + it("should add listeners", () => { + const emitter = new EventEmitter(); + const calls: Array<string | number> = []; + + emitter.on("foo", (val: number) => { + calls.push("one", val); + }); + + emitter.on("foo", (val: number) => { + calls.push("two", val); + }); + + emitter.emit("foo", 1); + emitter.emit("bar", 1); + emitter.emit("foo", 2); + + assertEquals(calls, ["one", 1, "two", 1, "one", 2, "two", 2]); + }); + + it("should add listeners for events which are same names with methods of Object.prototype", () => { + const emitter = new EventEmitter(); + const calls: Array<string | number> = []; + + emitter.on("constructor", (val: number) => { + calls.push("one", val); + }); + + emitter.on("__proto__", (val: number) => { + calls.push("two", val); + }); + + emitter.emit("constructor", 1); + emitter.emit("__proto__", 2); + + assertEquals(calls, ["one", 1, "two", 2]); + }); + }); + + describe(".once(event, fn)", () => { + it("should add a single-shot listener", () => { + const emitter = new EventEmitter(); + const calls: Array<string | number> = []; + + emitter.once("foo", (val: number) => { + calls.push("one", val); + }); + + emitter.emit("foo", 1); + emitter.emit("foo", 2); + emitter.emit("foo", 3); + emitter.emit("bar", 1); + + assertEquals(calls, ["one", 1]); + }); + }); + + describe(".off(event, fn)", () => { + it("should remove a listener", () => { + const emitter = new EventEmitter(); + const calls: string[] = []; + + function one() { + calls.push("one"); + } + function two() { + calls.push("two"); + } + + emitter.on("foo", one); + emitter.on("foo", two); + emitter.off("foo", two); + + emitter.emit("foo"); + + assertEquals(calls, ["one"]); + }); + + it("should work with .once()", () => { + const emitter = new EventEmitter(); + const calls: string[] = []; + + function one() { + calls.push("one"); + } + + emitter.once("foo", one); + emitter.once("fee", one); + emitter.off("foo", one); + + emitter.emit("foo"); + + assertEquals(calls, []); + }); + + it("should work when called from an event", () => { + const emitter = new EventEmitter(); + let called; + + function b() { + called = true; + } + emitter.on("tobi", () => { + emitter.off("tobi", b); + }); + emitter.on("tobi", b); + emitter.emit("tobi"); + + assertEquals(called, true); + + called = false; + emitter.emit("tobi"); + + assertEquals(called, false); + }); + }); + + describe(".off(event)", () => { + it("should remove all listeners for an event", () => { + const emitter = new EventEmitter(); + const calls: string[] = []; + + function one() { + calls.push("one"); + } + function two() { + calls.push("two"); + } + + emitter.on("foo", one); + emitter.on("foo", two); + emitter.off("foo"); + + emitter.emit("foo"); + emitter.emit("foo"); + + assertEquals(calls, []); + }); + + it("should remove event array to avoid memory leak", () => { + const emitter = new EventEmitter(); + + function cb() {} + + emitter.on("foo", cb); + emitter.off("foo", cb); + + // @ts-ignore check internal state + assertEquals(emitter._listeners.has("foo"), false); + }); + + it("should only remove the event array when the last subscriber unsubscribes", () => { + const emitter = new EventEmitter(); + + function cb1() {} + function cb2() {} + + emitter.on("foo", cb1); + emitter.on("foo", cb2); + emitter.off("foo", cb1); + + // @ts-ignore check internal state + assertEquals(emitter._listeners.has("foo"), true); + }); + }); + + describe(".off()", () => { + it("should remove all listeners", () => { + const emitter = new EventEmitter(); + const calls: string[] = []; + + function one() { + calls.push("one"); + } + function two() { + calls.push("two"); + } + + emitter.on("foo", one); + emitter.on("bar", two); + + emitter.emit("foo"); + emitter.emit("bar"); + + emitter.off(); + + emitter.emit("foo"); + emitter.emit("bar"); + + assertEquals(calls, ["one", "two"]); + }); + }); + + describe(".listeners(event)", () => { + describe("when handlers are present", () => { + it("should return an array of callbacks", () => { + const emitter = new EventEmitter(); + + function foo() {} + emitter.on("foo", foo); + + assertEquals(emitter.listeners("foo"), [foo]); + }); + }); + + describe("when no handlers are present", () => { + it("should return an empty array", () => { + const emitter = new EventEmitter(); + + assertEquals(emitter.listeners("foo"), []); + }); + }); + }); + + describe("typed events", () => { + interface ListenEvents { + foo: () => void; + } + + interface EmitEvents { + bar: (foo: number) => void; + } + + interface ReservedEvents { + foobar: (bar: string) => void; + } + + class CustomEmitter + extends EventEmitter<ListenEvents, EmitEvents, ReservedEvents> { + foobar() { + this.emitReserved("foobar", "1"); + } + } + + const emitter = new CustomEmitter(); + emitter.on("foo", () => {}); + emitter.on("foobar", (_bar) => {}); + emitter.emit("bar", 1); + emitter.foobar(); + }); +}); diff --git a/packages/socket.io-parser/mod.ts b/packages/socket.io-parser/mod.ts new file mode 100644 index 0000000..f3f86d2 --- /dev/null +++ b/packages/socket.io-parser/mod.ts @@ -0,0 +1,276 @@ +import { type RawData } from "../engine.io-parser/mod.ts"; +import { EventEmitter } from "../event-emitter/mod.ts"; +import { getLogger } from "../../deps.ts"; + +export enum PacketType { + CONNECT, + DISCONNECT, + EVENT, + ACK, + CONNECT_ERROR, + BINARY_EVENT, + BINARY_ACK, +} + +export interface Packet { + type: PacketType; + nsp: string; + data?: any; + id?: number; + attachments?: number; +} + +type Attachments = ArrayBuffer | ArrayBufferView | Blob; + +export class Encoder { + public encode(packet: Packet): RawData[] { + if (packet.type === PacketType.EVENT || packet.type === PacketType.ACK) { + const attachments: Attachments[] = []; + extractAttachments(packet.data, attachments); + if (attachments.length) { + packet.attachments = attachments.length; + packet.type = packet.type === PacketType.EVENT + ? PacketType.BINARY_EVENT + : PacketType.BINARY_ACK; + return [encodeAsString(packet), ...attachments]; + } + } + return [encodeAsString(packet)]; + } +} + +function encodeAsString(packet: Packet): string { + let output = "" + packet.type; + + if ( + packet.type === PacketType.BINARY_EVENT || + packet.type === PacketType.BINARY_ACK + ) { + output += packet.attachments + "-"; + } + + if (packet.nsp !== "/") { + output += packet.nsp + ","; + } + + if (packet.id !== undefined) { + output += packet.id; + } + + if (packet.data) { + output += JSON.stringify(packet.data); + } + + getLogger("socket.io").debug(`[parser] encoded packet as ${output}`); + + return output; +} + +/** + * Extract the binary attachments from the payload + * @param data + * @param attachments + */ +function extractAttachments(data: any, attachments: Attachments[]) { + if (Array.isArray(data)) { + for (let i = 0; i < data.length; i++) { + const elem = data[i]; + if (isAttachment(elem)) { + data[i] = { _placeholder: true, num: attachments.length }; + attachments.push(elem); + } else { + extractAttachments(data[i], attachments); + } + } + } else if (typeof data === "object" && !(data instanceof Date)) { + for (const key in data) { + if (Object.prototype.hasOwnProperty.call(data, key)) { + const elem = data[key]; + if (isAttachment(elem)) { + data[key] = { _placeholder: true, num: attachments.length }; + attachments.push(elem); + } else { + extractAttachments(data[key], attachments); + } + } + } + } +} + +function isAttachment(data: unknown): boolean { + return data instanceof ArrayBuffer || ArrayBuffer.isView(data) || + data instanceof Blob; +} + +interface DecoderEvents { + packet: (packet: Packet) => void; + error: () => void; +} + +export class Decoder extends EventEmitter< + Record<never, never>, + Record<never, never>, + DecoderEvents +> { + private buffer?: { + packet: Packet; + attachments: Attachments[]; + }; + + public add(data: RawData): void { + if (typeof data === "string") { + if (this.buffer) { + getLogger("socket.io").debug( + "[parser] got plaintext data while reconstructing a packet", + ); + this.emitReserved("error"); + return; + } + + const packet = decodeString(data); + + if (packet === null) { + this.emitReserved("error"); + } else if (packet.attachments) { + this.buffer = { + packet, + attachments: [], + }; + } else { + this.emitReserved("packet", packet); + } + } else { + if (!this.buffer) { + getLogger("socket.io").debug( + "[parser] got plaintext data while not reconstructing a packet", + ); + this.emitReserved("error"); + return; + } + const { packet, attachments } = this.buffer; + attachments.push(data); + if (attachments.length === packet.attachments) { + injectAttachments(packet.data, attachments); + this.emitReserved("packet", packet); + this.buffer = undefined; + } + } + } + + public destroy() { + this.buffer = undefined; + } +} + +function decodeString(str: string): Packet | null { + const packet: Partial<Packet> = {}; + + const type = parseInt(str.charAt(0), 10); + if (PacketType[type] === undefined) { + getLogger("socket.io").debug(`[parser] unknown packet type: ${type}`); + return null; + } + packet.type = type; + + let i = 0; + if ( + type === PacketType.BINARY_EVENT || + type === PacketType.BINARY_ACK + ) { + const start = i + 1; + while (str.charAt(++i) !== "-" && i !== str.length) { + // advance cursor + } + const attachments = parseInt(str.substring(start, i), 10); + if (str.charAt(i) !== "-" || !isFinite(attachments) || attachments < 0) { + getLogger("socket.io").debug( + `[parser] illegal attachment count: ${attachments}`, + ); + return null; + } + packet.attachments = attachments; + } + + if (str.charAt(i + 1) === "/") { + const start = i + 1; + while (str.charAt(++i) !== "," && i !== str.length) { + // advance cursor + } + packet.nsp = str.substring(start, i); + } else { + packet.nsp = "/"; + } + + if (isValidCharCodeForInteger(str.charCodeAt(i + 1))) { + const start = i + 1; + while (++i !== str.length) { + if (!isValidCharCodeForInteger(str.charCodeAt(i))) { + --i; + break; + } + } + packet.id = parseInt(str.substring(start, i + 1), 10); + } + + if (str.charAt(++i)) { + let payload; + try { + payload = JSON.parse(str.substr(i)); + } catch (err) { + getLogger("socket.io").debug(`[parser] invalid payload: ${err}`); + return null; + } + if (!isPayloadValid(type, payload)) { + getLogger("socket.io").debug(`[parser] invalid payload`); + return null; + } + packet.data = payload; + } + + return packet as Packet; +} + +function isValidCharCodeForInteger(code: number) { + return code >= 48 && code <= 57; +} + +function isPayloadValid(type: PacketType, payload: unknown): boolean { + switch (type) { + case PacketType.CONNECT: + return typeof payload === "object"; + case PacketType.DISCONNECT: + return payload === undefined; + case PacketType.CONNECT_ERROR: + return typeof payload === "string" || typeof payload === "object"; + case PacketType.EVENT: + case PacketType.BINARY_EVENT: + return Array.isArray(payload) && payload.length > 0; + case PacketType.ACK: + case PacketType.BINARY_ACK: + return Array.isArray(payload); + } +} + +function injectAttachments(data: any, attachments: Attachments[]) { + if (Array.isArray(data)) { + for (let i = 0; i < data.length; i++) { + const elem = data[i]; + if (elem && elem._placeholder === true) { + data[i] = attachments.shift(); + } else { + injectAttachments(elem, attachments); + } + } + } else if (typeof data === "object" && !(data instanceof Date)) { + for (const key in data) { + if (Object.prototype.hasOwnProperty.call(data, key)) { + const elem = data[key]; + if (elem && elem._placeholder === true) { + data[key] = attachments.shift(); + } else { + injectAttachments(elem, attachments); + } + } + } + } +} diff --git a/packages/socket.io-parser/test.ts b/packages/socket.io-parser/test.ts new file mode 100644 index 0000000..376c957 --- /dev/null +++ b/packages/socket.io-parser/test.ts @@ -0,0 +1,181 @@ +import { assertEquals, describe, it } from "../../test_deps.ts"; +import { Decoder, Encoder, Packet, PacketType } from "./mod.ts"; +import { RawData } from "../engine.io-parser/mod.ts"; + +const encoder = new Encoder(); + +describe("socket.io-parser", () => { + describe("without attachments", () => { + it("should encode/decode a CONNECT packet (main namespace)", () => { + const packet = { + type: PacketType.CONNECT, + nsp: "/", + }; + return testEncodeDecode(packet, "0"); + }); + + it("should encode/decode a CONNECT packet (custom namespace)", () => { + const packet = { + type: PacketType.CONNECT, + nsp: "/woot", + data: { + token: "123", + }, + }; + return testEncodeDecode(packet, '0/woot,{"token":"123"}'); + }); + + it("should encode/decode a DISCONNECT packet", () => { + const packet = { + type: PacketType.DISCONNECT, + nsp: "/", + }; + return testEncodeDecode(packet, "1"); + }); + + it("should encode/decode an EVENT packet", () => { + const packet = { + type: PacketType.EVENT, + nsp: "/", + data: ["a", 1, {}], + }; + return testEncodeDecode(packet, '2["a",1,{}]'); + }); + + it("should encode/decode an EVENT packet with an ACK id", () => { + const packet = { + type: PacketType.EVENT, + nsp: "/", + id: 1, + data: ["a", 1, {}], + }; + return testEncodeDecode(packet, '21["a",1,{}]'); + }); + + it("should encode/decode an ACK packet", () => { + const packet = { + type: PacketType.ACK, + nsp: "/", + id: 123, + data: ["a", 1, {}], + }; + return testEncodeDecode(packet, '3123["a",1,{}]'); + }); + + it("should encode/decode a CONNECT_ERROR packet", () => { + const packet = { + type: PacketType.CONNECT_ERROR, + nsp: "/", + data: { + message: "Unauthorized", + }, + }; + return testEncodeDecode(packet, '4{"message":"Unauthorized"}'); + }); + + it("should emit a 'decode_error' event upon parsing error", async () => { + await expectDecodeError('442["some","data"'); + await expectDecodeError('0/admin,"invalid"'); + await expectDecodeError("1/admin,{}"); + await expectDecodeError('2/admin,"invalid'); + await expectDecodeError("2/admin,{}"); + await expectDecodeError("999"); + await expectDecodeError("5"); + }); + }); + + describe("with binary attachments", () => { + it("should encode/decode an EVENT packet with multiple binary attachments", () => { + return new Promise((done) => { + const packet = { + type: PacketType.EVENT, + nsp: "/cool", + id: 23, + data: ["a", { b: Uint8Array.from([1, 2, 3]) }, [ + "c", + Int32Array.from([4, 5, 6]), + ]], + }; + + const encodedPackets = encoder.encode(packet); + + assertEquals(encodedPackets.length, 3); + assertEquals( + encodedPackets[0], + '52-/cool,23["a",{"b":{"_placeholder":true,"num":0}},["c",{"_placeholder":true,"num":1}]]', + ); + assertEquals(encodedPackets[1], Uint8Array.from([1, 2, 3])); + assertEquals(encodedPackets[2], Int32Array.from([4, 5, 6])); + + const decoder = new Decoder(); + + decoder.on("packet", (decodedPacket) => { + assertEquals(decodedPacket, { + type: PacketType.BINARY_EVENT, + nsp: "/cool", + id: 23, + data: ["a", { b: Uint8Array.from([1, 2, 3]) }, [ + "c", + Int32Array.from([4, 5, 6]), + ]], + attachments: 2, + }); + + done(); + }); + + encodedPackets.forEach((p) => decoder.add(p)); + }); + }); + + it("should emit a 'decode_error' event when adding an attachment without header", () => { + return expectDecodeError(Uint8Array.from([1, 2, 3])); + }); + + it("should emit a 'decode_error' event when decoding a binary event without attachments", () => { + return expectDecodeError( + '51-["hello",{"_placeholder":true,"num":0}]', + '2["hello"]', + ); + }); + }); +}); + +function testEncodeDecode(packet: Packet, expected: string): Promise<void> { + return new Promise((resolve, reject) => { + const encodedPackets = encoder.encode(packet); + + assertEquals(encodedPackets.length, 1); + assertEquals(encodedPackets[0], expected); + + const decoder = new Decoder(); + + decoder.on("packet", (decodedPacket) => { + assertEquals(decodedPacket, packet); + + resolve(); + }); + + decoder.on("error", () => { + reject("should not happen"); + }); + + decoder.add(encodedPackets[0]); + }); +} + +function expectDecodeError(...encodedPackets: RawData[]): Promise<void> { + return new Promise((resolve, reject) => { + const decoder = new Decoder(); + + decoder.on("packet", () => { + reject("should not happen"); + }); + + decoder.on("error", () => { + resolve(); + }); + + encodedPackets.forEach((p) => decoder.add(p)); + }); +} diff --git a/packages/socket.io/lib/adapter.ts b/packages/socket.io/lib/adapter.ts new file mode 100644 index 0000000..e7fa592 --- /dev/null +++ b/packages/socket.io/lib/adapter.ts @@ -0,0 +1,298 @@ +import { EventEmitter } from "../../event-emitter/mod.ts"; +import { type Socket } from "./socket.ts"; +import { type Namespace } from "./namespace.ts"; +import { type Packet } from "../../socket.io-parser/mod.ts"; + +export type SocketId = string; +export type Room = string | number; + +export interface BroadcastOptions { + rooms: Set<Room>; + except?: Set<Room>; + flags?: BroadcastFlags; +} + +export interface BroadcastFlags { + volatile?: boolean; + local?: boolean; + broadcast?: boolean; + timeout?: number; +} + +interface AdapterEvents { + "create-room": (room: Room) => void; + "delete-room": (room: Room) => void; + "join-room": (room: Room, sid: SocketId) => void; + "leave-room": (room: Room, sid: SocketId) => void; +} + +export class Adapter extends EventEmitter<never, never, AdapterEvents> { + private readonly nsp: Namespace; + + private rooms: Map<Room, Set<SocketId>> = new Map(); + private sids: Map<SocketId, Set<Room>> = new Map(); + + constructor(nsp: Namespace) { + super(); + this.nsp = nsp; + } + + /** + * Returns the number of Socket.IO servers in the cluster + */ + public serverCount(): Promise<number> { + return Promise.resolve(1); + } + + /** + * Adds a socket to a list of room. + * + * @param id - the socket ID + * @param rooms - a set of rooms + */ + public addAll(id: SocketId, rooms: Set<Room>): Promise<void> | void { + let roomsForSid = this.sids.get(id); + if (!roomsForSid) { + this.sids.set(id, roomsForSid = new Set()); + } + + for (const room of rooms) { + roomsForSid.add(room); + + let sidsForRoom = this.rooms.get(room); + + if (!sidsForRoom) { + this.rooms.set(room, sidsForRoom = new Set()); + this.emitReserved("create-room", room); + } + if (!sidsForRoom.has(id)) { + sidsForRoom.add(id); + this.emitReserved("join-room", room, id); + } + } + } + + /** + * Removes a socket from a room. + * + * @param {SocketId} id the socket id + * @param {Room} room the room name + */ + public del(id: SocketId, room: Room): Promise<void> | void { + this.sids.get(id)?.delete(room); + this.removeSidFromRoom(room, id); + } + + private removeSidFromRoom(room: Room, id: SocketId) { + const sids = this.rooms.get(room); + + if (!sids) { + return; + } + + const deleted = sids.delete(id); + if (deleted) { + this.emitReserved("leave-room", room, id); + } + if (sids.size === 0 && this.rooms.delete(room)) { + this.emitReserved("delete-room", room); + } + } + + /** + * Removes a socket from all rooms it's joined. + * + * @param id - the socket ID + */ + public delAll(id: SocketId): void { + const rooms = this.sids.get(id); + + if (!rooms) { + return; + } + + for (const room of rooms) { + this.removeSidFromRoom(room, id); + } + + this.sids.delete(id); + } + + /** + * Broadcasts a packet. + * + * Options: + * - `flags` {Object} flags for this packet + * - `except` {Array} sids that should be excluded + * - `rooms` {Array} list of rooms to broadcast to + * + * @param {Object} packet the packet object + * @param {Object} opts the options + */ + public broadcast(packet: Packet, opts: BroadcastOptions): void { + const encodedPackets = this.nsp._server._encoder.encode(packet); + + this.apply(opts, (socket) => { + socket._notifyOutgoingListeners(packet); + socket.client._writeToEngine(encodedPackets, { + volatile: opts.flags && opts.flags.volatile, + }); + }); + } + + /** + * Broadcasts a packet and expects multiple acknowledgements. + * + * Options: + * - `flags` {Object} flags for this packet + * - `except` {Array} sids that should be excluded + * - `rooms` {Array} list of rooms to broadcast to + * + * @param {Object} packet the packet object + * @param {Object} opts the options + * @param clientCountCallback - the number of clients that received the packet + * @param ack - the callback that will be called for each client response + */ + public broadcastWithAck( + packet: Packet, + opts: BroadcastOptions, + clientCountCallback: (clientCount: number) => void, + ack: (...args: unknown[]) => void, + ) { + const flags = opts.flags || {}; + const packetOpts = { + preEncoded: true, + volatile: flags.volatile, + }; + + packet.nsp = this.nsp.name; + // we can use the same id for each packet, since the _ids counter is common (no duplicate) + packet.id = this.nsp._ids++; + + const encodedPackets = this.nsp._server._encoder.encode(packet); + + let clientCount = 0; + + this.apply(opts, (socket) => { + // track the total number of acknowledgements that are expected + clientCount++; + // call the ack callback for each client response + socket._acks.set(packet.id!, ack); + + socket._notifyOutgoingListeners(packet); + socket.client._writeToEngine(encodedPackets, packetOpts); + }); + + clientCountCallback(clientCount); + } + + /** + * Gets the list of rooms a given socket has joined. + * + * @param {SocketId} id the socket id + */ + public socketRooms(id: SocketId): Set<Room> | undefined { + return this.sids.get(id); + } + + /** + * Returns the matching socket instances + * + * @param opts - the filters to apply + */ + public fetchSockets(opts: BroadcastOptions): Promise<Socket[]> { + const sockets: Socket[] = []; + + this.apply(opts, (socket) => { + sockets.push(socket); + }); + + return Promise.resolve(sockets); + } + + /** + * Makes the matching socket instances join the specified rooms + * + * @param opts - the filters to apply + * @param rooms - the rooms to join + */ + public addSockets(opts: BroadcastOptions, rooms: Room[]): void { + this.apply(opts, (socket) => { + socket.join(rooms); + }); + } + + /** + * Makes the matching socket instances leave the specified rooms + * + * @param opts - the filters to apply + * @param rooms - the rooms to leave + */ + public delSockets(opts: BroadcastOptions, rooms: Room[]): void { + this.apply(opts, (socket) => { + rooms.forEach((room) => socket.leave(room)); + }); + } + + /** + * Makes the matching socket instances disconnect + * + * @param opts - the filters to apply + * @param close - whether to close the underlying connection + */ + public disconnectSockets(opts: BroadcastOptions, close: boolean): void { + this.apply(opts, (socket) => { + socket.disconnect(close); + }); + } + + private apply( + opts: BroadcastOptions, + callback: (socket: Socket) => void, + ): void { + const rooms = opts.rooms; + const except = this.computeExceptSids(opts.except); + + if (rooms.size) { + const ids = new Set(); + for (const room of rooms) { + if (!this.rooms.has(room)) continue; + + for (const id of this.rooms.get(room)!) { + if (ids.has(id) || except.has(id)) continue; + const socket = this.nsp.sockets.get(id); + if (socket) { + callback(socket); + ids.add(id); + } + } + } + } else { + for (const [id] of this.sids) { + if (except.has(id)) continue; + const socket = this.nsp.sockets.get(id); + if (socket) callback(socket); + } + } + } + + private computeExceptSids(exceptRooms?: Set<Room>) { + const exceptSids = new Set(); + if (exceptRooms && exceptRooms.size > 0) { + for (const room of exceptRooms) { + this.rooms.get(room)?.forEach((sid) => exceptSids.add(sid)); + } + } + return exceptSids; + } + + /** + * Send a packet to the other Socket.IO servers in the cluster + * @param _packet - an array of arguments, which may include an acknowledgement callback at the end + */ + public serverSideEmit(_packet: unknown[]): void { + console.warn( + "this adapter does not support the serverSideEmit() functionality", + ); + } +} diff --git a/packages/socket.io/lib/broadcast-operator.ts b/packages/socket.io/lib/broadcast-operator.ts new file mode 100644 index 0000000..040149e --- /dev/null +++ b/packages/socket.io/lib/broadcast-operator.ts @@ -0,0 +1,361 @@ +import { Adapter, BroadcastFlags, Room, SocketId } from "./adapter.ts"; +import { EventNames, EventParams, EventsMap } from "../../event-emitter/mod.ts"; +import { Handshake, RESERVED_EVENTS, Socket } from "./socket.ts"; +import { PacketType } from "../../socket.io-parser/mod.ts"; + +/** + * Interface for classes that aren't `EventEmitter`s, but still expose a + * strictly typed `emit` method. + */ +interface TypedEventBroadcaster<EmitEvents extends EventsMap> { + emit<Ev extends EventNames<EmitEvents>>( + ev: Ev, + ...args: EventParams<EmitEvents, Ev> + ): boolean; +} + +export class BroadcastOperator<EmitEvents extends EventsMap, SocketData> + implements TypedEventBroadcaster<EmitEvents> { + constructor( + private readonly adapter: Adapter, + private readonly rooms: Set<Room> = new Set<Room>(), + private readonly exceptRooms: Set<Room> = new Set<Room>(), + private readonly flags: BroadcastFlags = {}, + ) {} + + /** + * Targets a room when emitting. + * + * @param room + * @return a new BroadcastOperator instance + */ + public to(room: Room | Room[]): BroadcastOperator<EmitEvents, SocketData> { + const rooms = new Set(this.rooms); + if (Array.isArray(room)) { + room.forEach((r) => rooms.add(r)); + } else { + rooms.add(room); + } + return new BroadcastOperator( + this.adapter, + rooms, + this.exceptRooms, + this.flags, + ); + } + + /** + * Targets a room when emitting. + * + * @param room + * @return a new BroadcastOperator instance + */ + public in(room: Room | Room[]): BroadcastOperator<EmitEvents, SocketData> { + return this.to(room); + } + + /** + * Excludes a room when emitting. + * + * @param room + * @return a new BroadcastOperator instance + */ + public except( + room: Room | Room[], + ): BroadcastOperator<EmitEvents, SocketData> { + const exceptRooms = new Set(this.exceptRooms); + if (Array.isArray(room)) { + room.forEach((r) => exceptRooms.add(r)); + } else { + exceptRooms.add(room); + } + return new BroadcastOperator( + this.adapter, + this.rooms, + exceptRooms, + this.flags, + ); + } + + /** + * Sets a modifier for a subsequent event emission that the event data may be lost if the client is not ready to + * receive messages (because of network slowness or other issues, or because they’re connected through long polling + * and is in the middle of a request-response cycle). + * + * @return a new BroadcastOperator instance + */ + public get volatile(): BroadcastOperator<EmitEvents, SocketData> { + const flags = Object.assign({}, this.flags, { volatile: true }); + return new BroadcastOperator( + this.adapter, + this.rooms, + this.exceptRooms, + flags, + ); + } + + /** + * Sets a modifier for a subsequent event emission that the event data will only be broadcast to the current node. + * + * @return a new BroadcastOperator instance + */ + public get local(): BroadcastOperator<EmitEvents, SocketData> { + const flags = Object.assign({}, this.flags, { local: true }); + return new BroadcastOperator( + this.adapter, + this.rooms, + this.exceptRooms, + flags, + ); + } + + /** + * Adds a timeout in milliseconds for the next operation + * + * <pre><code> + * + * io.timeout(1000).emit("some-event", (err, responses) => { + * // ... + * }); + * + * </pre></code> + * + * @param timeout + */ + public timeout(timeout: number) { + const flags = Object.assign({}, this.flags, { timeout }); + return new BroadcastOperator( + this.adapter, + this.rooms, + this.exceptRooms, + flags, + ); + } + + /** + * Emits to all clients. + * + * @return Always true + */ + public emit<Ev extends EventNames<EmitEvents>>( + ev: Ev, + ...args: EventParams<EmitEvents, Ev> + ): boolean { + if (RESERVED_EVENTS.has(ev)) { + throw new Error(`"${String(ev)}" is a reserved event name`); + } + // set up packet object + const data = [ev, ...args]; + const packet = { + // @ts-ignore FIXME + nsp: this.adapter.nsp.name, + type: PacketType.EVENT, + data, + }; + + const withAck = typeof data[data.length - 1] === "function"; + + if (!withAck) { + this.adapter.broadcast(packet, { + rooms: this.rooms, + except: this.exceptRooms, + flags: this.flags, + }); + + return true; + } + + const ack = data.pop() as (...args: unknown[]) => void; + let timedOut = false; + const responses: unknown[] = []; + + const timer = setTimeout(() => { + timedOut = true; + ack.apply(this, [new Error("operation has timed out"), responses]); + }, this.flags.timeout); + + let expectedServerCount = -1; + let actualServerCount = 0; + let expectedClientCount = 0; + + const checkCompleteness = () => { + if ( + !timedOut && + expectedServerCount === actualServerCount && + responses.length === expectedClientCount + ) { + clearTimeout(timer); + ack.apply(this, [null, responses]); + } + }; + + this.adapter.broadcastWithAck( + packet, + { + rooms: this.rooms, + except: this.exceptRooms, + flags: this.flags, + }, + (clientCount: number) => { + // each Socket.IO server in the cluster sends the number of clients that were notified + expectedClientCount += clientCount; + actualServerCount++; + checkCompleteness(); + }, + (clientResponse: unknown) => { + // each client sends an acknowledgement + responses.push(clientResponse); + checkCompleteness(); + }, + ); + + this.adapter.serverCount().then((serverCount: number) => { + expectedServerCount = serverCount; + checkCompleteness(); + }); + + return true; + } + + /** + * Returns the matching socket instances + */ + public fetchSockets<SocketData = unknown>(): Promise< + RemoteSocket<EmitEvents, SocketData>[] + > { + return this.adapter + .fetchSockets({ + rooms: this.rooms, + except: this.exceptRooms, + flags: this.flags, + }) + .then((sockets: Socket[]) => { + return sockets.map((socket) => { + if (socket instanceof Socket) { + // FIXME the TypeScript compiler complains about missing private properties + return socket as unknown as RemoteSocket<EmitEvents, SocketData>; + } else { + return new RemoteSocket( + this.adapter, + socket as SocketDetails<SocketData>, + ); + } + }); + }); + } + + /** + * Makes the matching socket instances join the specified rooms + * + * @param room + */ + public socketsJoin(room: Room | Room[]): void { + this.adapter.addSockets( + { + rooms: this.rooms, + except: this.exceptRooms, + flags: this.flags, + }, + Array.isArray(room) ? room : [room], + ); + } + + /** + * Makes the matching socket instances leave the specified rooms + * + * @param room + */ + public socketsLeave(room: Room | Room[]): void { + this.adapter.delSockets( + { + rooms: this.rooms, + except: this.exceptRooms, + flags: this.flags, + }, + Array.isArray(room) ? room : [room], + ); + } + + /** + * Makes the matching socket instances disconnect + * + * @param close - whether to close the underlying connection + */ + public disconnectSockets(close = false): void { + this.adapter.disconnectSockets( + { + rooms: this.rooms, + except: this.exceptRooms, + flags: this.flags, + }, + close, + ); + } +} + +/** + * Format of the data when the Socket instance exists on another Socket.IO server + */ +interface SocketDetails<SocketData> { + id: SocketId; + handshake: Handshake; + rooms: Room[]; + data: SocketData; +} + +/** + * Expose of subset of the attributes and methods of the Socket class + */ +export class RemoteSocket<EmitEvents extends EventsMap, SocketData> + implements TypedEventBroadcaster<EmitEvents> { + public readonly id: SocketId; + public readonly handshake: Handshake; + public readonly rooms: Set<Room>; + public readonly data: SocketData; + + private readonly operator: BroadcastOperator<EmitEvents, SocketData>; + + constructor(adapter: Adapter, details: SocketDetails<SocketData>) { + this.id = details.id; + this.handshake = details.handshake; + this.rooms = new Set(details.rooms); + this.data = details.data; + this.operator = new BroadcastOperator(adapter, new Set([this.id])); + } + + public emit<Ev extends EventNames<EmitEvents>>( + ev: Ev, + ...args: EventParams<EmitEvents, Ev> + ): boolean { + return this.operator.emit(ev, ...args); + } + + /** + * Joins a room. + * + * @param {String|Array} room - room or array of rooms + */ + public join(room: Room | Room[]): void { + return this.operator.socketsJoin(room); + } + + /** + * Leaves a room. + * + * @param {String} room + */ + public leave(room: Room): void { + return this.operator.socketsLeave(room); + } + + /** + * Disconnects this client. + * + * @param {Boolean} close - if `true`, closes the underlying connection + * @return {Socket} self + */ + public disconnect(close = false): this { + this.operator.disconnectSockets(close); + return this; + } +} diff --git a/packages/socket.io/lib/client.ts b/packages/socket.io/lib/client.ts new file mode 100644 index 0000000..3754754 --- /dev/null +++ b/packages/socket.io/lib/client.ts @@ -0,0 +1,235 @@ +import { EventsMap } from "../../event-emitter/mod.ts"; +import { Decoder, Packet, PacketType } from "../../socket.io-parser/mod.ts"; +import { type Socket as RawSocket } from "../../engine.io/mod.ts"; +import { ConnInfo, getLogger } from "../../../deps.ts"; +import { Handshake, Socket } from "./socket.ts"; +import { Server } from "./server.ts"; +import { RawData } from "../../engine.io-parser/mod.ts"; +import { CloseReason } from "../../engine.io/lib/socket.ts"; + +interface WriteOptions { + volatile?: boolean; +} + +export class Client< + ListenEvents extends EventsMap, + EmitEvents extends EventsMap, + ServerSideEvents extends EventsMap, + SocketData = unknown, +> { + public readonly conn: RawSocket; + + private readonly server: Server< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + >; + private readonly handshake: Omit<Handshake, "issued" | "time" | "auth">; + private readonly decoder: Decoder; + + private sockets = new Map< + string, + Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData> + >(); + + private connectTimerId?: number; + + constructor( + server: Server<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + decoder: Decoder, + conn: RawSocket, + req: Request, + connInfo: ConnInfo, + ) { + this.server = server; + this.decoder = decoder; + this.conn = conn; + + const url = new URL(req.url); + this.handshake = { + url: url.pathname, + headers: req.headers, + query: url.searchParams, + address: (connInfo.remoteAddr as Deno.NetAddr).hostname, + secure: false, + xdomain: req.headers.has("origin"), + }; + + conn.on("message", (data) => this.decoder.add(data)); + conn.on("close", (reason) => this.onclose(reason)); + + this.decoder.on("packet", (packet) => this.onPacket(packet)); + this.decoder.on("error", () => this.onclose("parse error")); + + this.connectTimerId = setTimeout(() => { + if (this.sockets.size === 0) { + getLogger("socket.io").debug( + "[client] no namespace joined yet, close the client", + ); + this.close(); + } + // @ts-ignore FIXME + }, this.server.opts.connectTimeout); + } + + private onPacket(packet: Packet) { + const socket = this.sockets.get(packet.nsp); + + if (!socket && packet.type === PacketType.CONNECT) { + this.connect(packet.nsp, packet.data); + } else if ( + socket && + packet.type !== PacketType.CONNECT && + packet.type !== PacketType.CONNECT_ERROR + ) { + socket._onpacket(packet); + } else { + getLogger("socket.io").debug( + `[client] invalid state (packet type: ${packet.type})`, + ); + this.close(); + } + } + + private async connect(name: string, auth: Record<string, unknown> = {}) { + if (this.server._nsps.has(name)) { + getLogger("socket.io").debug(`[client] connecting to namespace ${name}`); + return this.doConnect(name, auth); + } + + try { + await this.server._checkNamespace(name, auth); + } catch (_) { + getLogger("socket.io").debug( + `[client] creation of namespace ${name} was denied`, + ); + this._packet({ + type: PacketType.CONNECT_ERROR, + nsp: name, + data: { + message: "Invalid namespace", + }, + }); + return; + } + + getLogger("socket.io").debug( + `[client] connecting to dynamic namespace ${name}`, + ); + this.doConnect(name, auth); + } + + /** + * Connects a client to a namespace. + * + * @param name - the namespace + * @param {Object} auth - the auth parameters + * + * @private + */ + private doConnect(name: string, auth: Record<string, unknown>): void { + const nsp = this.server.of(name); + + const now = new Date(); + const handshake: Handshake = Object.assign({ + issued: now.getTime(), + time: now.toISOString(), + auth, + }, this.handshake); + + nsp._add(this, handshake, (socket) => { + this.sockets.set(name, socket); + + if (this.connectTimerId) { + clearTimeout(this.connectTimerId); + this.connectTimerId = undefined; + } + }); + } + + /** + * Disconnects from all namespaces and closes transport. + * + * @private + */ + _disconnect(): void { + for (const socket of this.sockets.values()) { + socket.disconnect(); + } + this.sockets.clear(); + this.close(); + } + + /** + * Removes a socket. Called by each `Socket`. + * + * @private + */ + _remove( + socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + ): void { + this.sockets.delete(socket.id); + } + + private close() { + if (this.conn.readyState === "open") { + getLogger("socket.io").debug("[client] forcing transport close"); + this.conn.close(); + this.onclose("forced close"); + } + } + + private onclose(reason: CloseReason) { + getLogger("socket.io").debug( + `[client] client closed with reason ${reason}`, + ); + + // ignore a potential subsequent `close` event + this.conn.off(); + this.decoder.off(); + + if (this.connectTimerId) { + clearTimeout(this.connectTimerId); + this.connectTimerId = undefined; + } + + for (const socket of this.sockets.values()) { + socket._onclose(reason); + } + + this.sockets.clear(); + this.decoder.destroy(); + } + + /** + * Writes a packet to the transport. + * + * @param {Object} packet object + * @param {Object} opts + * @private + */ + /* private */ _packet(packet: Packet, opts: WriteOptions = {}) { + if (this.conn.readyState !== "open") { + getLogger("socket.io").debug(`[client] ignoring packet write ${packet}`); + return; + } + const encodedPackets = this.server._encoder.encode(packet); + this._writeToEngine(encodedPackets, opts); + } + + /* private */ _writeToEngine( + encodedPackets: RawData[], + opts: WriteOptions, + ) { + if (opts.volatile && !this.conn.transport.writable) { + getLogger("socket.io").debug( + "[client] volatile packet is discarded since the transport is not currently writable", + ); + return; + } + for (const encodedPacket of encodedPackets) { + this.conn.send(encodedPacket); + } + } +} diff --git a/packages/socket.io/lib/namespace.ts b/packages/socket.io/lib/namespace.ts new file mode 100644 index 0000000..136643c --- /dev/null +++ b/packages/socket.io/lib/namespace.ts @@ -0,0 +1,340 @@ +import { + DefaultEventsMap, + EventEmitter, + EventNames, + EventParams, + EventsMap, +} from "../../event-emitter/mod.ts"; +import { Handshake, Socket } from "./socket.ts"; +import { Server, ServerReservedEvents } from "./server.ts"; +import { Adapter, Room, SocketId } from "./adapter.ts"; +import { Client } from "./client.ts"; +import { getLogger } from "../../../deps.ts"; +import { BroadcastOperator, RemoteSocket } from "./broadcast-operator.ts"; + +export interface NamespaceReservedEvents< + ListenEvents extends EventsMap, + EmitEvents extends EventsMap, + ServerSideEvents extends EventsMap, + SocketData, +> { + connection: ( + socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + ) => void; +} + +export const RESERVED_EVENTS: ReadonlySet<string | symbol> = new Set< + keyof ServerReservedEvents<never, never, never, never> +>(["connection", "new_namespace"] as const); + +export class Namespace< + ListenEvents extends EventsMap = DefaultEventsMap, + EmitEvents extends EventsMap = DefaultEventsMap, + ServerSideEvents extends EventsMap = DefaultEventsMap, + SocketData = unknown, +> extends EventEmitter< + ServerSideEvents, + EmitEvents, + NamespaceReservedEvents< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + > +> { + public readonly name: string; + public readonly sockets: Map< + SocketId, + Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData> + > = new Map(); + public adapter: Adapter; + + /* private */ readonly _server: Server< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + >; + + /* private */ _fns: Array< + ( + socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + ) => Promise<void> + > = []; + + /* private */ _ids = 0; + + constructor( + server: Server<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + name: string, + ) { + super(); + this._server = server; + this.name = name; + this.adapter = new Adapter(this as Namespace); + } + + /** + * Sets up namespace middleware. + * + * @param fn - the middleware function + */ + public use( + fn: ( + socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + ) => Promise<void>, + ): this { + this._fns.push(fn); + return this; + } + + /** + * Executes the middleware for an incoming client. + * + * @param socket - the socket that will get added + * @private + */ + private async run( + socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + ): Promise<void> { + switch (this._fns.length) { + case 0: + return; + case 1: + return this._fns[0](socket); + default: + for (const fn of this._fns.slice()) { + await fn(socket); + } + } + } + + /** + * Targets a room when emitting. + * + * @param room + * @return self + */ + public to(room: Room | Room[]): BroadcastOperator<EmitEvents, SocketData> { + return new BroadcastOperator(this.adapter).to(room); + } + + /** + * Targets a room when emitting. + * + * @param room + * @return self + */ + public in(room: Room | Room[]): BroadcastOperator<EmitEvents, SocketData> { + return new BroadcastOperator(this.adapter).in(room); + } + + /** + * Excludes a room when emitting. + * + * @param room + * @return self + */ + public except( + room: Room | Room[], + ): BroadcastOperator<EmitEvents, SocketData> { + return new BroadcastOperator(this.adapter).except(room); + } + + /** + * Adds a new client + * + * @param client - the client + * @param handshake - the handshake + * @private + */ + /* private */ async _add( + client: Client<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + handshake: Handshake, + callback: ( + socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + ) => void, + ) { + getLogger("socket.io").debug( + `[namespace] adding socket to nsp ${this.name}`, + ); + const socket = new Socket< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + >(this, client, handshake); + + try { + await this.run(socket); + } catch (err) { + getLogger("socket.io").debug( + "[namespace] middleware error, sending CONNECT_ERROR packet to the client", + ); + socket._cleanup(); + return socket._error({ + message: err.message || err, + data: err.data, + }); + } + + if (client.conn.readyState !== "open") { + getLogger("socket.io").debug( + "[namespace] next called after client was closed - ignoring socket", + ); + socket._cleanup(); + return; + } + + // track socket + this.sockets.set(socket.id, socket); + + // it's paramount that the internal `onconnect` logic + // fires before user-set events to prevent state order + // violations (such as a disconnection before the connection + // logic is complete) + socket._onconnect(); + + callback(socket); + + // fire user-set events + this.emitReserved("connection", socket); + } + + /** + * Removes a client. Called by each `Socket`. + * + * @private + */ + /* private */ _remove( + socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + ): void { + this.sockets.delete(socket.id); + } + + /** + * Emits to all clients. + * + * @return Always true + */ + public emit<Ev extends EventNames<EmitEvents>>( + ev: Ev, + ...args: EventParams<EmitEvents, Ev> + ): boolean { + return new BroadcastOperator<EmitEvents, SocketData>(this.adapter).emit( + ev, + ...args, + ); + } + + /** + * Sends a `message` event to all clients. + * + * @return self + */ + public send(...args: EventParams<EmitEvents, "message">): this { + this.emit("message", ...args); + return this; + } + + /** + * Emit a packet to other Socket.IO servers + * + * @param ev - the event name + * @param args - an array of arguments, which may include an acknowledgement callback at the end + */ + public serverSideEmit<Ev extends EventNames<ServerSideEvents>>( + ev: Ev, + ...args: EventParams<ServerSideEvents, Ev> + ): boolean { + if (RESERVED_EVENTS.has(ev)) { + throw new Error(`"${String(ev)}" is a reserved event name`); + } + args.unshift(ev); + this.adapter.serverSideEmit(args); + return true; + } + + /** + * Called when a packet is received from another Socket.IO server + * + * @param args - an array of arguments, which may include an acknowledgement callback at the end + * + * @private + */ + /* private */ _onServerSideEmit(args: [string, ...unknown[]]) { + // @ts-ignore FIXME + super.emit.apply(this, args); + } + + /** + * Sets a modifier for a subsequent event emission that the event data may be lost if the client is not ready to + * receive messages (because of network slowness or other issues, or because they’re connected through long polling + * and is in the middle of a request-response cycle). + * + * @return self + */ + public get volatile(): BroadcastOperator<EmitEvents, SocketData> { + return new BroadcastOperator(this.adapter).volatile; + } + + /** + * Sets a modifier for a subsequent event emission that the event data will only be broadcast to the current node. + * + * @return self + */ + public get local(): BroadcastOperator<EmitEvents, SocketData> { + return new BroadcastOperator(this.adapter).local; + } + + /** + * Adds a timeout in milliseconds for the next operation + * + * <pre><code> + * + * io.timeout(1000).emit("some-event", (err, responses) => { + * // ... + * }); + * + * </pre></code> + * + * @param timeout + */ + public timeout(timeout: number) { + return new BroadcastOperator(this.adapter).timeout(timeout); + } + + /** + * Returns the matching socket instances + */ + public fetchSockets(): Promise<RemoteSocket<EmitEvents, SocketData>[]> { + return new BroadcastOperator(this.adapter).fetchSockets(); + } + + /** + * Makes the matching socket instances join the specified rooms + * + * @param room + */ + public socketsJoin(room: Room | Room[]): void { + return new BroadcastOperator(this.adapter).socketsJoin(room); + } + + /** + * Makes the matching socket instances leave the specified rooms + * + * @param room + */ + public socketsLeave(room: Room | Room[]): void { + return new BroadcastOperator(this.adapter).socketsLeave(room); + } + + /** + * Makes the matching socket instances disconnect + * + * @param close - whether to close the underlying connection + */ + public disconnectSockets(close = false): void { + return new BroadcastOperator(this.adapter).disconnectSockets(close); + } +} diff --git a/packages/socket.io/lib/parent-namespace.ts b/packages/socket.io/lib/parent-namespace.ts new file mode 100644 index 0000000..7517dfc --- /dev/null +++ b/packages/socket.io/lib/parent-namespace.ts @@ -0,0 +1,70 @@ +import { + DefaultEventsMap, + EventNames, + EventParams, + EventsMap, +} from "../../event-emitter/mod.ts"; +import { Namespace } from "./namespace.ts"; +import { Server } from "./server.ts"; + +export class ParentNamespace< + ListenEvents extends EventsMap = DefaultEventsMap, + EmitEvents extends EventsMap = DefaultEventsMap, + ServerSideEvents extends EventsMap = DefaultEventsMap, + SocketData = unknown, +> extends Namespace<ListenEvents, EmitEvents, ServerSideEvents, SocketData> { + private static count = 0; + + private children: Set< + Namespace<ListenEvents, EmitEvents, ServerSideEvents, SocketData> + > = new Set(); + + constructor( + server: Server<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + ) { + super(server, "/_" + ParentNamespace.count++); + // this.adapter = { + // broadcast(packet: Packet, opts: BroadcastOptions) { + // this.children.forEach((nsp: Namespace) => { + // nsp.adapter.broadcast(packet, opts); + // }); + // } + // }; + } + + public emit<Ev extends EventNames<EmitEvents>>( + ev: Ev, + ...args: EventParams<EmitEvents, Ev> + ): boolean { + this.children.forEach((nsp) => { + nsp.emit(ev, ...args); + }); + + return true; + } + + /* private */ _createChild( + name: string, + ): Namespace<ListenEvents, EmitEvents, ServerSideEvents, SocketData> { + const namespace = new Namespace(this._server, name); + namespace._fns = this._fns.slice(0); + this.listeners("connect").forEach((listener) => + namespace.on("connect", listener) + ); + this.listeners("connection").forEach((listener) => + namespace.on("connection", listener) + ); + this.children.add(namespace); + this._server._nsps.set(name, namespace); + return namespace; + } + + // fetchSockets(): Promise<RemoteSocket<EmitEvents, SocketData>[]> { + // // note: we could make the fetchSockets() method work for dynamic namespaces created with a regex (by sending the + // // regex to the other Socket.IO servers, and returning the sockets of each matching namespace for example), but + // // the behavior for namespaces created with a function is less clear + // // note²: we cannot loop over each children namespace, because with multiple Socket.IO servers, a given namespace + // // may exist on one node but not exist on another (since it is created upon client connection) + // throw new Error("fetchSockets() is not supported on parent namespaces"); + // } +} diff --git a/packages/socket.io/lib/server.ts b/packages/socket.io/lib/server.ts new file mode 100644 index 0000000..ba37b0e --- /dev/null +++ b/packages/socket.io/lib/server.ts @@ -0,0 +1,381 @@ +import { + Server as Engine, + ServerOptions as EngineOptions, +} from "../../engine.io/mod.ts"; +import { + DefaultEventsMap, + EventEmitter, + EventNames, + EventParams, + EventsMap, +} from "../../event-emitter/mod.ts"; +import { getLogger, type Handler } from "../../../deps.ts"; +import { Client } from "./client.ts"; +import { Decoder, Encoder } from "../../socket.io-parser/mod.ts"; +import { Namespace, NamespaceReservedEvents } from "./namespace.ts"; +import { ParentNamespace } from "./parent-namespace.ts"; +import { Socket } from "./socket.ts"; +import { Room } from "./adapter.ts"; +import { BroadcastOperator, RemoteSocket } from "./broadcast-operator.ts"; + +export interface ServerOptions { + /** + * Name of the request path to handle + * @default "/socket.io/" + */ + path: string; + /** + * Duration in milliseconds before a client without namespace is closed + * @default 45000 + */ + connectTimeout: number; + /** + * The parser to use to encode and decode packets + */ + parser: { + createEncoder(): Encoder; + createDecoder(): Decoder; + }; +} + +export interface ServerReservedEvents< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData, +> extends + NamespaceReservedEvents< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + > { + new_namespace: ( + namespace: Namespace< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + >, + ) => void; +} + +type ParentNspNameMatchFn = ( + name: string, + auth: Record<string, unknown>, +) => Promise<void>; + +export class Server< + ListenEvents extends EventsMap = DefaultEventsMap, + EmitEvents extends EventsMap = ListenEvents, + ServerSideEvents extends EventsMap = DefaultEventsMap, + SocketData = unknown, +> extends EventEmitter< + ListenEvents, + EmitEvents, + ServerReservedEvents< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + > +> { + public readonly engine: Engine; + public readonly mainNamespace: Namespace< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + >; + + private readonly opts: ServerOptions; + /* private */ readonly _encoder: Encoder; + + /* private */ _nsps: Map< + string, + Namespace<ListenEvents, EmitEvents, ServerSideEvents, SocketData> + > = new Map(); + + private parentNsps: Map< + ParentNspNameMatchFn, + ParentNamespace<ListenEvents, EmitEvents, ServerSideEvents, SocketData> + > = new Map(); + + constructor(opts: Partial<ServerOptions & EngineOptions> = {}) { + super(); + + this.opts = Object.assign({ + path: "/socket.io/", + connectTimeout: 45_000, + parser: { + createEncoder() { + return new Encoder(); + }, + createDecoder() { + return new Decoder(); + }, + }, + }, opts); + + this.engine = new Engine(this.opts); + + this.engine.on("connection", (conn, req, connInfo) => { + getLogger("socket.io").debug( + `[server] incoming connection with id ${conn.id}`, + ); + new Client(this, this.opts.parser.createDecoder(), conn, req, connInfo); + }); + + this._encoder = this.opts.parser.createEncoder(); + + const mainNamespace = this.of("/"); + + ["on", "once", "off", "emit", "listeners"].forEach((method) => { + // @ts-ignore FIXME proper typing + this[method] = function () { + // @ts-ignore FIXME proper typing + return mainNamespace[method].apply(mainNamespace, arguments); + }; + }); + + this.mainNamespace = mainNamespace; + } + + /** + * Returns a request handler. + * + * @param additionalHandler - another handler which will receive the request if the path does not match + */ + public handler(additionalHandler?: Handler) { + return this.engine.handler(additionalHandler); + } + + /** + * Executes the middleware for an incoming namespace not already created on the server. + * + * @param name - name of incoming namespace + * @param auth - the auth parameters + * @param fn - callback + * + * @private + */ + /* private */ async _checkNamespace( + name: string, + auth: Record<string, unknown>, + ): Promise<void> { + if (this.parentNsps.size === 0) return Promise.reject(); + + for (const [isValid, parentNsp] of this.parentNsps) { + try { + await isValid(name, auth); + } catch (_) { + continue; + } + + if (this._nsps.has(name)) { + // the namespace was created in the meantime + getLogger("socket.io").debug( + `[server] dynamic namespace ${name} already exists`, + ); + } else { + const namespace = parentNsp._createChild(name); + getLogger("socket.io").debug( + `[server] dynamic namespace ${name} was created`, + ); + this.emitReserved("new_namespace", namespace); + } + + return Promise.resolve(); + } + + return Promise.reject(); + } + + /** + * Looks up a namespace. + * + * @param name - nsp name + */ + public of( + name: string | RegExp | ParentNspNameMatchFn, + ): Namespace<ListenEvents, EmitEvents, ServerSideEvents, SocketData> { + if (typeof name === "function" || name instanceof RegExp) { + const parentNsp = new ParentNamespace(this); + getLogger("socket.io").debug( + `[server] initializing parent namespace ${parentNsp.name}`, + ); + if (typeof name === "function") { + this.parentNsps.set(name, parentNsp); + } else { + this.parentNsps.set( + (nsp: string) => + (name as RegExp).test(nsp) ? Promise.resolve() : Promise.reject(), + parentNsp, + ); + } + + return parentNsp; + } + + if (String(name)[0] !== "/") name = "/" + name; + + let nsp = this._nsps.get(name); + if (!nsp) { + getLogger("socket.io").debug(`[server] initializing namespace ${name}`); + nsp = new Namespace(this, name); + this._nsps.set(name, nsp); + if (name !== "/") { + this.emitReserved("new_namespace", nsp); + } + } + + return nsp; + } + + /** + * Closes the server + */ + public close() { + this.engine.close(); + } + + /** + * Sets up namespace middleware. + * + * @param fn - the middleware function + */ + public use( + fn: ( + socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + ) => Promise<void>, + ): this { + this.mainNamespace.use(fn); + return this; + } + + /** + * Targets a room when emitting. + * + * @param room + * @return self + */ + public to(room: Room | Room[]): BroadcastOperator<EmitEvents, SocketData> { + return this.mainNamespace.to(room); + } + + /** + * Targets a room when emitting. + * + * @param room + * @return self + */ + public in(room: Room | Room[]): BroadcastOperator<EmitEvents, SocketData> { + return this.mainNamespace.in(room); + } + + /** + * Excludes a room when emitting. + * + * @param name + * @return self + */ + public except( + name: Room | Room[], + ): BroadcastOperator<EmitEvents, SocketData> { + return this.mainNamespace.except(name); + } + + /** + * Sends a `message` event to all clients. + * + * @return self + */ + public send(...args: EventParams<EmitEvents, "message">): this { + this.mainNamespace.emit("message", ...args); + return this; + } + + /** + * Emit a packet to other Socket.IO servers + * + * @param ev - the event name + * @param args - an array of arguments, which may include an acknowledgement callback at the end + */ + public serverSideEmit<Ev extends EventNames<ServerSideEvents>>( + ev: Ev, + ...args: EventParams<ServerSideEvents, Ev> + ): boolean { + return this.mainNamespace.serverSideEmit(ev, ...args); + } + + /** + * Sets a modifier for a subsequent event emission that the event data may be lost if the client is not ready to + * receive messages (because of network slowness or other issues, or because they’re connected through long polling + * and is in the middle of a request-response cycle). + * + * @return self + */ + public get volatile(): BroadcastOperator<EmitEvents, SocketData> { + return this.mainNamespace.volatile; + } + + /** + * Sets a modifier for a subsequent event emission that the event data will only be broadcast to the current node. + * + * @return self + */ + public get local(): BroadcastOperator<EmitEvents, SocketData> { + return this.mainNamespace.local; + } + + /** + * Adds a timeout in milliseconds for the next operation + * + * <pre><code> + * + * io.timeout(1000).emit("some-event", (err, responses) => { + * // ... + * }); + * + * </pre></code> + * + * @param timeout + */ + public timeout(timeout: number): BroadcastOperator<EmitEvents, SocketData> { + return this.mainNamespace.timeout(timeout); + } + + /** + * Returns the matching socket instances + */ + public fetchSockets(): Promise<RemoteSocket<EmitEvents, SocketData>[]> { + return this.mainNamespace.fetchSockets(); + } + + /** + * Makes the matching socket instances join the specified rooms + * + * @param room + */ + public socketsJoin(room: Room | Room[]): void { + return this.mainNamespace.socketsJoin(room); + } + + /** + * Makes the matching socket instances leave the specified rooms + * + * @param room + */ + public socketsLeave(room: Room | Room[]): void { + return this.mainNamespace.socketsLeave(room); + } + + /** + * Makes the matching socket instances disconnect + * + * @param close - whether to close the underlying connection + */ + public disconnectSockets(close = false): void { + return this.mainNamespace.disconnectSockets(close); + } +} diff --git a/packages/socket.io/lib/socket.ts b/packages/socket.io/lib/socket.ts new file mode 100644 index 0000000..8c17431 --- /dev/null +++ b/packages/socket.io/lib/socket.ts @@ -0,0 +1,551 @@ +import { Packet, PacketType } from "../../socket.io-parser/mod.ts"; +import { getLogger } from "../../../deps.ts"; +import { + DefaultEventsMap, + EventEmitter, + EventNames, + EventParams, + EventsMap, +} from "../../event-emitter/mod.ts"; +import { Adapter, BroadcastFlags, Room, SocketId } from "./adapter.ts"; +import { generateId } from "../../engine.io/mod.ts"; +import { Namespace } from "./namespace.ts"; +import { Client } from "./client.ts"; +import { BroadcastOperator } from "./broadcast-operator.ts"; + +type ClientReservedEvents = "connect" | "connect_error"; + +type DisconnectReason = + // Engine.IO close reasons + | "transport error" + | "transport close" + | "forced close" + | "ping timeout" + | "parse error" + // Socket.IO disconnect reasons + | "client namespace disconnect" + | "server namespace disconnect"; + +export interface SocketReservedEvents { + disconnect: (reason: DisconnectReason) => void; + disconnecting: (reason: DisconnectReason) => void; +} + +// EventEmitter reserved events: https://nodejs.org/api/events.html#events_event_newlistener +export interface EventEmitterReservedEvents { + newListener: ( + eventName: string | symbol, + listener: (...args: unknown[]) => void, + ) => void; + removeListener: ( + eventName: string | symbol, + listener: (...args: unknown[]) => void, + ) => void; +} + +export const RESERVED_EVENTS: ReadonlySet<string | symbol> = new Set< + | ClientReservedEvents + | keyof SocketReservedEvents + | keyof EventEmitterReservedEvents +>( + [ + "connect", + "connect_error", + "disconnect", + "disconnecting", + "newListener", + "removeListener", + ] as const, +); + +/** + * The handshake details + */ +export interface Handshake { + /** + * The headers sent as part of the handshake + */ + headers: Headers; + + /** + * The date of creation (as string) + */ + time: string; + + /** + * The ip of the client + */ + address: string; + + /** + * Whether the connection is cross-domain + */ + xdomain: boolean; + + /** + * Whether the connection is secure + */ + secure: boolean; + + /** + * The date of creation (as unix timestamp) + */ + issued: number; + + /** + * The request URL string + */ + url: string; + + /** + * The query object + */ + query: URLSearchParams; + + /** + * The auth object + */ + auth: Record<string, unknown>; +} + +function noop() {} + +export class Socket< + ListenEvents extends EventsMap = DefaultEventsMap, + EmitEvents extends EventsMap = DefaultEventsMap, + ServerSideEvents extends EventsMap = DefaultEventsMap, + SocketData = unknown, +> extends EventEmitter< + ListenEvents, + EmitEvents, + SocketReservedEvents +> { + public readonly id: SocketId; + public readonly handshake: Handshake; + /** + * Additional information that can be attached to the Socket instance and which will be used in the fetchSockets method + */ + public data: Partial<SocketData> = {}; + + public connected = false; + + private readonly nsp: Namespace< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + >; + private readonly adapter: Adapter; + + /* private */ _acks: Map<number, () => void> = new Map(); + private flags: BroadcastFlags = {}; + private anyIncomingListeners?: Array<(...args: unknown[]) => void>; + private anyOutgoingListeners?: Array<(...args: unknown[]) => void>; + + /* private */ readonly client: Client< + ListenEvents, + EmitEvents, + ServerSideEvents, + SocketData + >; + + constructor( + nsp: Namespace<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + client: Client<ListenEvents, EmitEvents, ServerSideEvents, SocketData>, + handshake: Handshake, + ) { + super(); + this.nsp = nsp; + this.id = generateId(); + this.client = client; + this.adapter = nsp.adapter; + this.handshake = handshake; + } + + /** + * Emits to this client. + * + * @return Always returns `true`. + */ + public emit<Ev extends EventNames<EmitEvents>>( + ev: Ev, + ...args: EventParams<EmitEvents, Ev> + ): boolean { + if (RESERVED_EVENTS.has(ev)) { + throw new Error(`"${String(ev)}" is a reserved event name`); + } + const data: unknown[] = [ev, ...args]; + const packet: Packet = { + nsp: this.nsp.name, + type: PacketType.EVENT, + data: data, + }; + + // access last argument to see if it's an ACK callback + if (typeof data[data.length - 1] === "function") { + const id = this.nsp._ids++; + getLogger("socket.io").debug( + `[socket] emitting packet with ack id ${id}`, + ); + + this.registerAckCallback(id, data.pop() as (...args: unknown[]) => void); + packet.id = id; + } + + const flags = Object.assign({}, this.flags); + this.flags = {}; + + this._notifyOutgoingListeners(packet); + this.packet(packet, flags); + + return true; + } + + /** + * @private + */ + private registerAckCallback(id: number, ack: (...args: unknown[]) => void) { + const timeout = this.flags.timeout; + if (timeout === undefined) { + this._acks.set(id, ack); + return; + } + + const timerId = setTimeout(() => { + getLogger("socket.io").debug( + `[socket] event with ack id ${id} has timed out after ${timeout} ms`, + ); + this._acks.delete(id); + ack.call(this, new Error("operation has timed out")); + }, timeout); + + this._acks.set(id, (...args) => { + clearTimeout(timerId); + ack.apply(this, [null, ...args]); + }); + } + + /** + * @param packet + */ + /* private */ _onpacket(packet: Packet) { + if (!this.connected) { + return; + } + + getLogger("socket.io").debug(`[socket] got packet type ${packet.type}`); + switch (packet.type) { + case PacketType.EVENT: + case PacketType.BINARY_EVENT: + this.onevent(packet); + break; + + case PacketType.ACK: + case PacketType.BINARY_ACK: + this.onack(packet); + break; + + case PacketType.DISCONNECT: + this.ondisconnect(); + break; + } + } + + /** + * Called upon event packet. + * + * @param {Packet} packet - packet object + * @private + */ + private onevent(packet: Packet): void { + const args = packet.data || []; + getLogger("socket.io").debug(`[socket] emitting event ${args}`); + + if (null != packet.id) { + getLogger("socket.io").debug("[socket] attaching ack callback to event"); + args.push(this.ack(packet.id)); + } + + if (this.anyIncomingListeners && this.anyIncomingListeners.length) { + const listeners = this.anyIncomingListeners.slice(); + for (const listener of listeners) { + listener.apply(this, args); + } + } + + if (this.connected) { + super.emit.apply(this, args); + } + } + + /** + * Produces an ack callback to emit with an event. + * + * @param {Number} id - packet id + * @private + */ + private ack(id: number): () => void { + const self = this; + let sent = false; + return function () { + // prevent double callbacks + if (sent) return; + const args = Array.prototype.slice.call(arguments); + getLogger("socket.io").debug(`[socket] sending ack ${id}`); + + self.packet({ + id: id, + type: PacketType.ACK, + data: args, + }); + + sent = true; + }; + } + + /** + * Called upon ack packet. + * + * @private + */ + private onack(packet: Packet): void { + const ack = this._acks.get(packet.id!); + if ("function" == typeof ack) { + getLogger("socket.io").debug( + `[socket] calling ack ${packet.id}`, + ); + ack.apply(this, packet.data); + this._acks.delete(packet.id!); + } else { + getLogger("socket.io").debug(`[socket] bad ack ${packet.id}`); + } + } + + /** + * Called upon client disconnect packet. + * + * @private + */ + private ondisconnect(): void { + getLogger("socket.io").debug("[socket] got disconnect packet"); + this._onclose("client namespace disconnect"); + } + + /** + * Called upon closing. Called by `Client`. + * + * @param {String} reason + * @throw {Error} optional error object + * + * @private + */ + /* private */ _onclose(reason: DisconnectReason): this | undefined { + if (!this.connected) return this; + getLogger("socket.io").debug(`[socket] closing socket - reason ${reason}`); + this.emitReserved("disconnecting", reason); + this._cleanup(); + this.nsp._remove(this); + this.client._remove(this); + this.connected = false; + this.emitReserved("disconnect", reason); + return; + } + + /** + * Makes the socket leave all the rooms it was part of and prevents it from joining any other room + * + * @private + */ + /* private */ _cleanup() { + this.leaveAll(); + this.join = noop; + } + + /** + * Notify the listeners for each packet sent (emit or broadcast) + * + * @param packet + * + * @private + */ + /* private */ _notifyOutgoingListeners(packet: Packet) { + if (this.anyOutgoingListeners && this.anyOutgoingListeners.length) { + const listeners = this.anyOutgoingListeners.slice(); + for (const listener of listeners) { + listener.apply(this, packet.data); + } + } + } + + /** + * Sends a `message` event. + * + * @return self + */ + public send(...args: EventParams<EmitEvents, "message">): this { + this.emit("message", ...args); + return this; + } + + /** + * Writes a packet. + * + * @param {Object} packet - packet object + * @param {Object} opts - options + * @private + */ + private packet( + packet: Omit<Packet, "nsp"> & Partial<Pick<Packet, "nsp">>, + opts = {}, + ): void { + packet.nsp = this.nsp.name; + this.client._packet(packet as Packet, opts); + } + + /** + * Joins a room. + * + * @param {String|Array} rooms - room or array of rooms + * @return a Promise or nothing, depending on the adapter + */ + public join(rooms: Room | Array<Room>): Promise<void> | void { + getLogger("socket.io").debug(`[socket] join room ${rooms}`); + + return this.adapter.addAll( + this.id, + new Set(Array.isArray(rooms) ? rooms : [rooms]), + ); + } + + /** + * Leaves a room. + * + * @param {String} room + * @return a Promise or nothing, depending on the adapter + */ + public leave(room: Room): Promise<void> | void { + getLogger("socket.io").debug("[socket] leave room %s", room); + + return this.adapter.del(this.id, room); + } + + /** + * Leave all rooms. + * + * @private + */ + private leaveAll(): void { + this.adapter.delAll(this.id); + } + + /** + * Called by `Namespace` upon successful + * middleware execution (ie: authorization). + * Socket is added to namespace array before + * call to join, so adapters can access it. + * + * @private + */ + /* private */ _onconnect(): void { + getLogger("socket.io").debug("[socket] socket connected - writing packet"); + this.connected = true; + this.join(this.id); + this.packet({ type: PacketType.CONNECT, data: { sid: this.id } }); + } + + /** + * Produces an `error` packet. + * + * @param err - error object + * + * @private + */ + /* private */ _error(err: { message: string; data: unknown }) { + this.packet({ type: PacketType.CONNECT_ERROR, data: err }); + } + + /** + * Disconnects this client. + * + * @param {Boolean} close - if `true`, closes the underlying connection + * @return {Socket} self + */ + public disconnect(close = false): this { + if (!this.connected) return this; + if (close) { + this.client._disconnect(); + } else { + this.packet({ type: PacketType.DISCONNECT }); + this._onclose("server namespace disconnect"); + } + return this; + } + + /** + * Sets a modifier for a subsequent event emission that the event data may be lost if the client is not ready to + * receive messages (because of network slowness or other issues, or because they’re connected through long polling + * and is in the middle of a request-response cycle). + * + * @return {Socket} self + */ + public get volatile(): this { + this.flags.volatile = true; + return this; + } + + /** + * Sets a modifier for a subsequent event emission that the event data will only be broadcast to every sockets but the + * sender. + * + * @return {Socket} self + */ + public get broadcast(): BroadcastOperator<EmitEvents, SocketData> { + return this.newBroadcastOperator(); + } + + /** + * Sets a modifier for a subsequent event emission that the event data will only be broadcast to the current node. + * + * @return {Socket} self + */ + public get local(): BroadcastOperator<EmitEvents, SocketData> { + return this.newBroadcastOperator().local; + } + + /** + * Sets a modifier for a subsequent event emission that the callback will be called with an error when the + * given number of milliseconds have elapsed without an acknowledgement from the client: + * + * ``` + * socket.timeout(5000).emit("my-event", (err) => { + * if (err) { + * // the client did not acknowledge the event in the given delay + * } + * }); + * ``` + * + * @returns self + */ + public timeout(timeout: number): this { + this.flags.timeout = timeout; + return this; + } + + /** + * Returns the rooms the socket is currently in + */ + public get rooms(): Set<Room> { + return this.adapter.socketRooms(this.id) || new Set(); + } + + private newBroadcastOperator(): BroadcastOperator<EmitEvents, SocketData> { + const flags = Object.assign({}, this.flags); + this.flags = {}; + return new BroadcastOperator( + this.adapter, + new Set<Room>(), + new Set<Room>([this.id]), + flags, + ); + } +} diff --git a/packages/socket.io/mod.ts b/packages/socket.io/mod.ts new file mode 100644 index 0000000..11f4536 --- /dev/null +++ b/packages/socket.io/mod.ts @@ -0,0 +1 @@ +export { Server, type ServerOptions } from "./lib/server.ts"; diff --git a/packages/socket.io/test/broadcast.test.ts b/packages/socket.io/test/broadcast.test.ts new file mode 100644 index 0000000..c66620b --- /dev/null +++ b/packages/socket.io/test/broadcast.test.ts @@ -0,0 +1,87 @@ +import { assertEquals, describe, it } from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { + eioPoll, + enableLogs, + runHandshake, + testServeWithAsyncResults, +} from "./util.ts"; + +await enableLogs(); + +describe("broadcast", () => { + it("should emit to all sockets", () => { + const io = new Server({ + pingInterval: 50, + }); + + return testServeWithAsyncResults( + io, + 1, + async (port, done) => { + io.of("/custom"); + + const [sid1] = await runHandshake(port); + const [sid2] = await runHandshake(port); + const [sid3] = await runHandshake(port, "/custom"); + + io.of("/").emit("foo", "bar"); + + const [body1, body2, body3] = await Promise.all([ + eioPoll(port, sid1), + eioPoll(port, sid2), + eioPoll(port, sid3), + ]); + + assertEquals(body1, '42["foo","bar"]'); + assertEquals(body2, '42["foo","bar"]'); + assertEquals(body3, "2"); + + // drain buffer + await eioPoll(port, sid1); + await eioPoll(port, sid2); + + done(); + }, + ); + }); + + it("should emit to all sockets in a room", () => { + const io = new Server({ + pingInterval: 50, + }); + + return testServeWithAsyncResults( + io, + 1, + async (port, done) => { + io.of("/custom"); + + io.once("connection", (socket) => { + socket.join("room1"); + }); + + const [sid1] = await runHandshake(port); + const [sid2] = await runHandshake(port); + const [sid3] = await runHandshake(port, "/custom"); + + io.to("room1").emit("foo", "bar"); + + const [body1, body2, body3] = await Promise.all([ + eioPoll(port, sid1), + eioPoll(port, sid2), + eioPoll(port, sid3), + ]); + + assertEquals(body1, '42["foo","bar"]'); + assertEquals(body2, "2"); + assertEquals(body3, "2"); + + // drain buffer + await eioPoll(port, sid1); + + done(); + }, + ); + }); +}); diff --git a/packages/socket.io/test/event.test.ts b/packages/socket.io/test/event.test.ts new file mode 100644 index 0000000..dca3829 --- /dev/null +++ b/packages/socket.io/test/event.test.ts @@ -0,0 +1,188 @@ +// TODO +// - dynamic namespace + +import { + assertEquals, + assertIsError, + describe, + it, +} from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { + eioPoll, + eioPush, + enableLogs, + runHandshake, + testServeWithAsyncResults, +} from "./util.ts"; + +await enableLogs(); + +describe("event", () => { + it("should receive events", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 2, + async (port, partialDone) => { + io.on("connection", (socket) => { + socket.on("random", (a, b, c) => { + assertEquals(a, 1); + assertEquals(b, "2"); + assertEquals(c, [3]); + + partialDone(); + }); + }); + + const [sid] = await runHandshake(port); + + await eioPush(port, sid, '42["random",1,"2",[3]]'); + + partialDone(); + }, + ); + }); + + it("should emit events", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 2, + async (port, partialDone) => { + io.on("connection", (socket) => { + socket.emit("random", 4, "5", [6]); + + partialDone(); + }); + + const [_, firstPacket] = await runHandshake(port); + assertEquals(firstPacket, '42["random",4,"5",[6]]'); + + partialDone(); + }, + ); + }); + + it("should receive events with ack", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 2, + async (port, partialDone) => { + io.on("connection", (socket) => { + socket.on("random", (a, b, c, callback) => { + assertEquals(a, 1); + assertEquals(b, "2"); + assertEquals(c, [3]); + callback("foo", 123); + + partialDone(); + }); + }); + + const [sid] = await runHandshake(port); + + await eioPush(port, sid, '421["random",1,"2",[3]]'); + + const body = await eioPoll(port, sid); + assertEquals(body, '431["foo",123]'); + + partialDone(); + }, + ); + }); + + it("should emit events with ack", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 2, + async (port, partialDone) => { + io.on("connection", (socket) => { + socket.emit("random", 4, "5", [6], (a: string, b: number) => { + assertEquals(a, "bar"); + assertEquals(b, 456); + + partialDone(); + }); + }); + + const [sid, firstPacket] = await runHandshake(port); + assertEquals(firstPacket, '420["random",4,"5",[6]]'); + await eioPush(port, sid, '430["bar",456]'); + + partialDone(); + }, + ); + }); + + it("should timeout if the client does not acknowledge the event", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 1, + async (port, done) => { + io.on("connection", (socket) => { + socket.timeout(0).emit("unknown", (err: Error) => { + assertIsError(err); + + setTimeout(done, 10); + }); + }); + + await runHandshake(port); + }, + ); + }); + + it("should timeout if the client does not acknowledge the event in time", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 1, + async (port, done) => { + io.on("connection", (socket) => { + socket.timeout(0).emit("echo", 42, (err: Error) => { + assertIsError(err); + + setTimeout(done, 10); + }); + }); + + const [sid] = await runHandshake(port); + await eioPush(port, sid, "430[]"); + }, + ); + }); + + it("should not timeout if the client does acknowledge the event", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 2, + async (port, partialDone) => { + io.on("connection", (socket) => { + socket.timeout(50).emit("echo", (err: Error, val: number) => { + assertEquals(err, null); + assertEquals(val, 42); + + partialDone(); + }); + }); + + const [sid] = await runHandshake(port); + await eioPush(port, sid, "430[42]"); + + partialDone(); + }, + ); + }); +}); diff --git a/packages/socket.io/test/handshake.test.ts b/packages/socket.io/test/handshake.test.ts new file mode 100644 index 0000000..b177ec5 --- /dev/null +++ b/packages/socket.io/test/handshake.test.ts @@ -0,0 +1,239 @@ +import { + assertEquals, + assertExists, + describe, + it, +} from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { + eioPoll, + eioPush, + enableLogs, + testServeWithAsyncResults, +} from "./util.ts"; +import { parseSessionID } from "../../engine.io/test/util.ts"; + +await enableLogs(); + +describe("handshake", () => { + it("should trigger a connection event", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 2, + async (port, partialDone) => { + io.on("connection", (socket) => { + assertExists(socket.id); + assertEquals(socket.handshake.address, "127.0.0.1"); + assertEquals(socket.handshake.auth, {}); + assertEquals(socket.handshake.xdomain, false); + assertEquals(socket.handshake.secure, false); // always false + assertEquals(socket.handshake.query.get("EIO"), "4"); + assertEquals(socket.handshake.query.get("transport"), "polling"); + assertEquals(socket.handshake.url, "/socket.io/"); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.status, 200); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: "40", + }, + ); + + assertEquals(dataResponse.status, 200); + + // consume the response body + await dataResponse.body?.cancel(); + + const pollResponse = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "get", + }, + ); + + assertEquals(pollResponse.status, 200); + + const body = await pollResponse.text(); + + assertEquals(body[0], "4"); // Engine.IO MESSAGE packet type + assertEquals(body[1], "0"); // Socket.IO CONNECT packet type + + const handshake = JSON.parse(body.substring(2)); + assertExists(handshake.sid); + + partialDone(); + }, + ); + }); + + it("should trigger a connection event with custom auth payload, header and query parameter", () => { + const server = new Server(); + + return testServeWithAsyncResults( + server, + 2, + async (port, partialDone) => { + server.on("connection", (socket) => { + assertExists(socket.id); + assertEquals(socket.handshake.query.get("foo"), "123"); + assertEquals(socket.handshake.headers.get("bar"), "456"); + assertEquals(socket.handshake.auth, { + foobar: "789", + }); + + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling&foo=123`, + { + method: "get", + headers: { + bar: "456", + }, + }, + ); + + assertEquals(response.status, 200); + + const sid = await parseSessionID(response); + + const dataResponse = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body: '40{"foobar":"789"}', + }, + ); + + assertEquals(dataResponse.status, 200); + + // consume the response body + await dataResponse.body?.cancel(); + + const pollResponse = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "get", + }, + ); + + assertEquals(pollResponse.status, 200); + + // consume the response body + await pollResponse.body?.cancel(); + + partialDone(); + }, + ); + }); + + it("should trigger a connection event (custom namespace)", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 2, + async (port, partialDone) => { + io.of("/custom").on("connection", (socket) => { + assertExists(socket.id); + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.status, 200); + + const sid = await parseSessionID(response); + + await eioPush(port, sid, "40/custom,"); + + const body = await eioPoll(port, sid); + assertEquals(body.startsWith("40/custom,{"), true); + + partialDone(); + }, + ); + }); + + it("should trigger a connection event (dynamic namespace)", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 2, + async (port, partialDone) => { + io.of(/^\/dynamic-\d+$/).on("connection", (socket) => { + assertExists(socket.id); + partialDone(); + }); + + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + assertEquals(response.status, 200); + + const sid = await parseSessionID(response); + + await eioPush(port, sid, "40/dynamic-101,"); + + const body = await eioPoll(port, sid); + assertEquals(body.startsWith("40/dynamic-101,{"), true); + + partialDone(); + }, + ); + }); + + it("should return an error when reaching a non-existent namespace", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 1, + async (port, done) => { + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + await eioPush(port, sid, "40/unknown,"); + + const body = await eioPoll(port, sid); + + assertEquals(body, '44/unknown,{"message":"Invalid namespace"}'); + + done(); + }, + ); + }); +}); diff --git a/packages/socket.io/test/middleware.test.ts b/packages/socket.io/test/middleware.test.ts new file mode 100644 index 0000000..bf10719 --- /dev/null +++ b/packages/socket.io/test/middleware.test.ts @@ -0,0 +1,154 @@ +import { assertEquals, describe, it } from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { + eioPoll, + eioPush, + enableLogs, + parseSessionID, + runHandshake, + testServeWithAsyncResults, +} from "./util.ts"; + +await enableLogs(); + +describe("event", () => { + it("should call the middleware functions before the connection", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 1, + async (port, done) => { + const result: number[] = []; + + io.use((socket) => { + assertEquals(socket.connected, false); + result.push(1); + return Promise.resolve(); + }); + + io.use((_) => { + result.push(2); + return Promise.resolve(); + }); + + io.use((_) => { + result.push(3); + return Promise.resolve(); + }); + + io.on("connection", (socket) => { + assertEquals(socket.connected, true); + assertEquals(result, [1, 2, 3]); + + done(); + }); + + await runHandshake(port); + }, + ); + }); + + it("should be ignored if socket gets closed", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 1, + async (port, done) => { + io.use((socket) => { + socket.client.conn.close(); + setTimeout(done, 10); + return Promise.resolve(); + }); + + io.on("connection", (_) => { + throw "should not happen"; + }); + + await runHandshake(port); + }, + ); + }); + + it("should disallow connection", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 1, + async (port, done) => { + io.use((_) => { + throw "Authentication error"; + }); + + io.use((_) => { + throw "should not happen"; + }); + + io.on("connection", (_) => { + throw "should not happen"; + }); + + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + await eioPush(port, sid, "40"); + const body = await eioPoll(port, sid); + + assertEquals(body, '44{"message":"Authentication error"}'); + + done(); + }, + ); + }); + + it("should disallow connection and include an error object", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 1, + async (port, done) => { + io.use((_) => { + throw { + message: "Authentication error", + data: { a: "b", c: 3 }, + }; + }); + + io.on("connection", (_) => { + throw "should not happen"; + }); + + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + await eioPush(port, sid, "40"); + const body = await eioPoll(port, sid); + + assertEquals( + body, + '44{"message":"Authentication error","data":{"a":"b","c":3}}', + ); + + done(); + }, + ); + }); + + it("should work with a custom namespace", () => { + }); +}); diff --git a/packages/socket.io/test/socket.test.ts b/packages/socket.io/test/socket.test.ts new file mode 100644 index 0000000..0c82aa5 --- /dev/null +++ b/packages/socket.io/test/socket.test.ts @@ -0,0 +1,50 @@ +import { assertEquals, describe, it } from "../../../test_deps.ts"; +import { Server } from "../lib/server.ts"; +import { enableLogs, runHandshake, testServeWithAsyncResults } from "./util.ts"; + +await enableLogs(); + +describe("socket", () => { + it("should keep track of rooms", () => { + const io = new Server(); + + return testServeWithAsyncResults( + io, + 2, + async (port, partialDone) => { + io.on("connection", (socket) => { + assertEquals(socket.rooms.size, 1); + assertEquals(socket.rooms.has(socket.id), true); + + socket.join("room1"); + + assertEquals(socket.rooms.size, 2); + assertEquals(socket.rooms.has("room1"), true); + + socket.leave("room1"); + + assertEquals(socket.rooms.size, 1); + assertEquals(socket.rooms.has("room1"), false); + + socket.join("room2"); + + socket.on("disconnecting", () => { + assertEquals(socket.rooms.has("room2"), true); + + partialDone(); + }); + + socket.on("disconnect", () => { + assertEquals(socket.rooms.size, 0); + + partialDone(); + }); + + socket.disconnect(); + }); + + await runHandshake(port); + }, + ); + }); +}); diff --git a/packages/socket.io/test/util.ts b/packages/socket.io/test/util.ts new file mode 100644 index 0000000..262a70f --- /dev/null +++ b/packages/socket.io/test/util.ts @@ -0,0 +1,114 @@ +import { Server } from "../lib/server.ts"; +import * as log from "../../../test_deps.ts"; +import { serve } from "../../../test_deps.ts"; + +function createPartialDone( + count: number, + resolve: () => void, + reject: (reason: string) => void, +) { + let i = 0; + return () => { + if (++i === count) { + resolve(); + } else if (i > count) { + reject(`called too many times: ${i} > ${count}`); + } + }; +} + +export function testServeWithAsyncResults( + server: Server, + count: number, + callback: (port: number, partialDone: () => void) => Promise<void> | void, +): Promise<void> { + return new Promise((resolve, reject) => { + const abortController = new AbortController(); + + serve(server.handler(), { + onListen: ({ port }) => { + const partialDone = createPartialDone(count, () => { + setTimeout(() => { + // close the server + abortController.abort(); + server.close(); + + setTimeout(resolve, 10); + }, 10); + }, reject); + + return callback(port, partialDone); + }, + signal: abortController.signal, + }); + }); +} + +export async function parseSessionID(response: Response): Promise<string> { + const body = await response.text(); + return JSON.parse(body.substring(1)).sid; +} + +export async function runHandshake( + port: number, + namespace = "/", +): Promise<string[]> { + // Engine.IO handshake + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling`, + { + method: "get", + }, + ); + + const sid = await parseSessionID(response); + + // Socket.IO handshake + await eioPush(port, sid, namespace === "/" ? "40" : `40${namespace},`); + const body = await eioPoll(port, sid); + // might be defined if an event is emitted in the "connection" handler + const firstPacket = body.substring(33); // length of '40{"sid":"xxx"}' + 1 for the separator character + + return [sid, firstPacket]; +} + +export async function eioPoll(port: number, sid: string) { + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "get", + }, + ); + + return response.text(); +} +export async function eioPush(port: number, sid: string, body: BodyInit) { + const response = await fetch( + `http://localhost:${port}/socket.io/?EIO=4&transport=polling&sid=${sid}`, + { + method: "post", + body, + }, + ); + + // consume the response body + await response.body?.cancel(); +} + +export function enableLogs() { + return log.setup({ + handlers: { + console: new log.handlers.ConsoleHandler("DEBUG"), + }, + loggers: { + "engine.io": { + level: "ERROR", // set to "DEBUG" to display the Engine.IO logs + handlers: ["console"], + }, + "socket.io": { + level: "ERROR", // set to "DEBUG" to display the Socket.IO logs + handlers: ["console"], + }, + }, + }); +} diff --git a/test_deps.ts b/test_deps.ts new file mode 100644 index 0000000..c453e73 --- /dev/null +++ b/test_deps.ts @@ -0,0 +1,7 @@ +export * from "https://deno.land/std@0.150.0/testing/asserts.ts"; + +export { describe, it } from "https://deno.land/std@0.150.0/testing/bdd.ts"; + +export { serve } from "https://deno.land/std@0.150.0/http/server.ts"; + +export * from "https://deno.land/std@0.150.0/log/mod.ts";