diff --git a/README.md b/README.md index 2cbef13e8..d767f2c15 100644 --- a/README.md +++ b/README.md @@ -243,6 +243,14 @@ Default: 600 (10 minutes) WebSocket port to listen on.
Default: 3001 +The `websocketPort` may also be set to the same as `port` so that a single port can be used for mulitple protocols. + +Note: In the event the ports are the same: + +- The [Connections API Server](#websocket-connections-api) will be hosted on `lambdaPort`. + +### CLI Options in `serverless.yml` + Any of the CLI options can be added to your `serverless.yml`. For example: ```yml @@ -722,20 +730,25 @@ Example response velocity template: }, ``` -## WebSocket +## WebSocket Connections API -Usage in order to send messages back to clients: +The `connections-port` for the connections API is available at the following endpoint: -`POST http://localhost:3001/@connections/{connectionId}` +- if `websocketPort == 3001`: (connections API and websocket share `websocketPort`) + - `POST http://localhost:3001/@connections/{connectionId}` +- if `websocketPort == port`: (connections API is bound to the `lambdaPort`) + - `POST http://localhost:3002/@connections/{connectionId}` Or, ```js import aws from 'aws-sdk' +const connectionsPort = 3001; // Or 3002 if websocketPort === port in serverless offline options + const apiGatewayManagementApi = new aws.ApiGatewayManagementApi({ apiVersion: '2018-11-29', - endpoint: 'http://localhost:3001', + endpoint: `http://localhost:${connectionsPort}`, }); apiGatewayManagementApi.postToConnection({ diff --git a/src/events/alb/HttpServer.js b/src/events/alb/HttpServer.js index 172509c06..971e8df8a 100644 --- a/src/events/alb/HttpServer.js +++ b/src/events/alb/HttpServer.js @@ -1,6 +1,4 @@ import { Buffer } from "node:buffer" -import { exit } from "node:process" -import { Server } from "@hapi/hapi" import { log } from "../../utils/log.js" import { detectEncoding, @@ -9,43 +7,30 @@ import { } from "../../utils/index.js" import LambdaAlbRequestEvent from "./lambda-events/LambdaAlbRequestEvent.js" import logRoutes from "../../utils/logRoutes.js" +import AbstractHttpServer from "../../lambda/AbstractHttpServer.js" const { stringify } = JSON const { entries } = Object -export default class HttpServer { +export default class HttpServer extends AbstractHttpServer { #lambda = null #options = null #serverless = null - #server = null - #terminalInfo = [] constructor(serverless, options, lambda) { + super(lambda, options, options.albPort) + this.#serverless = serverless this.#options = options this.#lambda = lambda } async createServer() { - const { host, albPort } = this.#options - - const serverOptions = { - host, - port: albPort, - router: { - // allows for paths with trailing slashes to be the same as without - // e.g. : /my-path is the same as /my-path/ - stripTrailingSlash: true, - }, - } - - this.#server = new Server(serverOptions) - - this.#server.ext("onPreResponse", (request, h) => { + this.httpServer.ext("onPreResponse", (request, h) => { if (request.headers.origin) { const response = request.response.isBoom ? request.response.output @@ -134,32 +119,13 @@ export default class HttpServer { } async start() { - const { albPort, host, httpsProtocol } = this.#options - - try { - await this.#server.start() - } catch (err) { - log.error( - `Unexpected error while starting serverless-offline alb server on port ${albPort}:`, - err, - ) - exit(1) - } - - // TODO move the following block - const server = `${httpsProtocol ? "https" : "http"}://${host}:${albPort}` + await super.start() - log.notice(`ALB Server ready: ${server} 🚀`) - } - - stop(timeout) { - return this.#server.stop({ - timeout, - }) + log.notice(`${this.serverName} Server ready: ${this.basePath} 🚀`) } get server() { - return this.#server.listener + return this.httpServer.listener } #createHapiHandler(params) { @@ -346,7 +312,7 @@ export default class HttpServer { stage, }) - this.#server.route({ + this.httpServer.route({ handler: hapiHandler, method: hapiMethod, options: hapiOptions, diff --git a/src/events/http/HttpServer.js b/src/events/http/HttpServer.js index 59e07b225..b8094105d 100644 --- a/src/events/http/HttpServer.js +++ b/src/events/http/HttpServer.js @@ -1,10 +1,8 @@ import { Buffer } from "node:buffer" -import { readFile } from "node:fs/promises" + import { createRequire } from "node:module" import { join, resolve } from "node:path" -import { exit } from "node:process" import h2o2 from "@hapi/h2o2" -import { Server } from "@hapi/hapi" import { log } from "../../utils/log.js" import authFunctionNameExtractor from "../authFunctionNameExtractor.js" import authJWTSettingsExtractor from "./authJWTSettingsExtractor.js" @@ -30,11 +28,12 @@ import { jsonPath, splitHandlerPathAndName, } from "../../utils/index.js" +import AbstractHttpServer from "../../lambda/AbstractHttpServer.js" const { parse, stringify } = JSON const { assign, entries, keys } = Object -export default class HttpServer { +export default class HttpServer extends AbstractHttpServer { #apiKeysValues = null #hasPrivateHttpEvent = false @@ -43,68 +42,26 @@ export default class HttpServer { #options = null - #server = null - #serverless = null #terminalInfo = [] constructor(serverless, options, lambda) { + super(lambda, options, options.httpPort) this.#lambda = lambda this.#options = options this.#serverless = serverless } - async #loadCerts(httpsProtocol) { - const [cert, key] = await Promise.all([ - readFile(resolve(httpsProtocol, "cert.pem"), "utf8"), - readFile(resolve(httpsProtocol, "key.pem"), "utf8"), - ]) - - return { - cert, - key, - } - } - async createServer() { - const { enforceSecureCookies, host, httpPort, httpsProtocol } = - this.#options - - const serverOptions = { - host, - port: httpPort, - router: { - stripTrailingSlash: true, - }, - state: enforceSecureCookies - ? { - isHttpOnly: true, - isSameSite: false, - isSecure: true, - } - : { - isHttpOnly: false, - isSameSite: false, - isSecure: false, - }, - // https support - ...(httpsProtocol != null && { - tls: await this.#loadCerts(httpsProtocol), - }), - } - - // Hapijs server creation - this.#server = new Server(serverOptions) - try { - await this.#server.register([h2o2]) + await this.httpServer.register([h2o2]) } catch (err) { log.error(err) } // Enable CORS preflight response - this.#server.ext("onPreResponse", (request, h) => { + this.httpServer.ext("onPreResponse", (request, h) => { if (request.headers.origin) { const response = request.response.isBoom ? request.response.output @@ -193,29 +150,9 @@ export default class HttpServer { } async start() { - const { host, httpPort, httpsProtocol } = this.#options + await super.start() - try { - await this.#server.start() - } catch (err) { - log.error( - `Unexpected error while starting serverless-offline server on port ${httpPort}:`, - err, - ) - exit(1) - } - - // TODO move the following block - const server = `${httpsProtocol ? "https" : "http"}://${host}:${httpPort}` - - log.notice(`Server ready: ${server} 🚀`) - } - - // stops the server - stop(timeout) { - return this.#server.stop({ - timeout, - }) + log.notice(`Server ready: ${this.basePath} 🚀`) } #logPluginIssue() { @@ -278,8 +215,8 @@ export default class HttpServer { const scheme = createJWTAuthScheme(jwtSettings) // Set the auth scheme and strategy on the server - this.#server.auth.scheme(authSchemeName, scheme) - this.#server.auth.strategy(authStrategyName, authSchemeName) + this.httpServer.auth.scheme(authSchemeName, scheme) + this.httpServer.auth.strategy(authStrategyName, authSchemeName) return authStrategyName } @@ -387,8 +324,8 @@ export default class HttpServer { ) // Set the auth scheme and strategy on the server - this.#server.auth.scheme(authSchemeName, scheme) - this.#server.auth.strategy(authStrategyName, authSchemeName) + this.httpServer.auth.scheme(authSchemeName, scheme) + this.httpServer.auth.strategy(authStrategyName, authSchemeName) return authStrategyName } @@ -416,11 +353,11 @@ export default class HttpServer { const strategy = provider(endpoint, functionKey, method, path) - this.#server.auth.scheme( + this.httpServer.auth.scheme( strategy.scheme, strategy.getAuthenticateFunction, ) - this.#server.auth.strategy(strategy.name, strategy.scheme) + this.httpServer.auth.strategy(strategy.name, strategy.scheme) return strategy.name } @@ -1118,7 +1055,7 @@ export default class HttpServer { stage, }) - this.#server.route({ + this.httpServer.route({ handler: hapiHandler, method: hapiMethod, options: hapiOptions, @@ -1267,17 +1204,17 @@ export default class HttpServer { path: hapiPath, } - this.#server.route(route) + this.httpServer.route(route) }) } create404Route() { // If a {proxy+} or $default route exists, don't conflict with it - if (this.#server.match("*", "/{p*}")) { + if (this.httpServer.match("*", "/{p*}")) { return } - const existingRoutes = this.#server + const existingRoutes = this.httpServer .table() // Exclude this (404) route .filter((route) => route.path !== "/{p*}") @@ -1305,7 +1242,7 @@ export default class HttpServer { path: "/{p*}", } - this.#server.route(route) + this.httpServer.route(route) } #getArrayStackTrace(stack) { @@ -1329,6 +1266,6 @@ export default class HttpServer { // TEMP FIXME quick fix to expose gateway server for testing, look for better solution getServer() { - return this.#server + return this.httpServer } } diff --git a/src/events/websocket/HttpServer.js b/src/events/websocket/HttpServer.js index c50289398..2edf9d7ef 100644 --- a/src/events/websocket/HttpServer.js +++ b/src/events/websocket/HttpServer.js @@ -1,87 +1,62 @@ -import { readFile } from "node:fs/promises" -import { resolve } from "node:path" -import { exit } from "node:process" -import { Server } from "@hapi/hapi" -import { log } from "../../utils/log.js" +import AbstractHttpServer from "../../lambda/AbstractHttpServer.js" import { catchAllRoute, connectionsRoutes } from "./http-routes/index.js" -export default class HttpServer { +export default class HttpServer extends AbstractHttpServer { #options = null - #server = null + #lambda = null #webSocketClients = null - constructor(options, webSocketClients) { + constructor(options, lambda, webSocketClients) { + super(lambda, options, options.websocketPort) this.#options = options + this.#lambda = lambda this.#webSocketClients = webSocketClients } - async #loadCerts(httpsProtocol) { - const [cert, key] = await Promise.all([ - readFile(resolve(httpsProtocol, "cert.pem"), "utf8"), - readFile(resolve(httpsProtocol, "key.pem"), "utf8"), - ]) - - return { - cert, - key, - } + async createServer() { + // No-op } - async createServer() { - const { host, httpsProtocol, websocketPort } = this.#options + async start() { + // add routes + this.httpServer.route(connectionsRoutes(this.#webSocketClients)) - const serverOptions = { - host, - port: websocketPort, - router: { - stripTrailingSlash: true, - }, - // https support - ...(httpsProtocol != null && { - tls: await this.#loadCerts(httpsProtocol), - }), + if (this.#options.websocketPort !== this.#options.httpPort) { + this.httpServer.route([catchAllRoute()]) } - this.#server = new Server(serverOptions) + await super.start() } - async start() { - // add routes - const routes = [ - ...connectionsRoutes(this.#webSocketClients), - catchAllRoute(), - ] - this.#server.route(routes) + async stop(timeout) { + if (this.#options.websocketPort === this.#options.httpPort) { + return + } - const { host, httpsProtocol, websocketPort } = this.#options + await super.stop(timeout) + } - try { - await this.#server.start() - } catch (err) { - log.error( - `Unexpected error while starting serverless-offline websocket server on port ${websocketPort}:`, - err, - ) - exit(1) - } + get listener() { + return this.#lambda.getServer(this.#options.websocketPort).listener + } - log.notice( - `Offline [http for websocket] listening on ${ - httpsProtocol ? "https" : "http" - }://${host}:${websocketPort}`, + get httpServer() { + return this.#lambda.getServer( + this.#options.websocketPort === this.#options.httpPort + ? this.#options.lambdaPort + : this.#options.websocketPort, ) } - // stops the server - stop(timeout) { - return this.#server.stop({ - timeout, - }) + get serverName() { + return "websocket" } - get server() { - return this.#server.listener + get port() { + return this.#options.websocketPort === this.#options.httpPort + ? this.#options.lambdaPort + : this.#options.websocketPort } } diff --git a/src/events/websocket/WebSocket.js b/src/events/websocket/WebSocket.js index 19b543855..f0d5c75e3 100644 --- a/src/events/websocket/WebSocket.js +++ b/src/events/websocket/WebSocket.js @@ -27,7 +27,11 @@ export default class WebSocket { this.#lambda, ) - this.#httpServer = new HttpServer(this.#options, webSocketClients) + this.#httpServer = new HttpServer( + this.#options, + this.#lambda, + webSocketClients, + ) await this.#httpServer.createServer() @@ -35,7 +39,7 @@ export default class WebSocket { this.#webSocketServer = new WebSocketServer( this.#options, webSocketClients, - this.#httpServer.server, + this.#httpServer.listener, ) await this.#webSocketServer.createServer() diff --git a/src/events/websocket/WebSocketServer.js b/src/events/websocket/WebSocketServer.js index a421f5a4c..4fbb731df 100644 --- a/src/events/websocket/WebSocketServer.js +++ b/src/events/websocket/WebSocketServer.js @@ -7,19 +7,19 @@ export default class WebSocketServer { #options = null - #sharedServer = null + #sharedListener = null #webSocketClients = null - constructor(options, webSocketClients, sharedServer) { + constructor(options, webSocketClients, sharedListener) { this.#options = options - this.#sharedServer = sharedServer + this.#sharedListener = sharedListener this.#webSocketClients = webSocketClients } async createServer() { const server = new WsWebSocketServer({ - server: this.#sharedServer, + server: this.#sharedListener, verifyClient: async ({ req }, cb) => { const connectionId = crypto.randomUUID() const key = req.headers["sec-websocket-key"] diff --git a/src/lambda/AbstractHttpServer.js b/src/lambda/AbstractHttpServer.js new file mode 100644 index 000000000..51fff9618 --- /dev/null +++ b/src/lambda/AbstractHttpServer.js @@ -0,0 +1,124 @@ +import { exit } from "node:process" +import { readFileSync } from "node:fs" +import { resolve } from "node:path" +import { Server } from "@hapi/hapi" +import { log } from "../utils/log.js" + +function loadCerts(httpsProtocol) { + return { + cert: readFileSync(resolve(httpsProtocol, "cert.pem"), "utf8"), + key: readFileSync(resolve(httpsProtocol, "key.pem"), "utf8"), + } +} + +export default class AbstractHttpServer { + #lambda = null + + #options = null + + #port = null + + #started = false + + constructor(lambda, options, port) { + this.#lambda = lambda + this.#options = options + this.#port = port + + if (this.#lambda.getServer(port)) { + return + } + + const { host, httpsProtocol, enforceSecureCookies } = options + + const server = new Server({ + host, + port, + router: { + stripTrailingSlash: true, + }, + state: enforceSecureCookies + ? { + isHttpOnly: true, + isSameSite: false, + isSecure: true, + } + : { + isHttpOnly: false, + isSameSite: false, + isSecure: false, + }, + ...(httpsProtocol != null && { + tls: loadCerts(httpsProtocol), + }), + }) + + this.#lambda.putServer(port, server) + } + + async start() { + if (this.#started) { + return + } + this.#started = true + + try { + await this.httpServer.start() + } catch (err) { + log.error( + `Unexpected error while starting serverless-offline ${this.serverName} server on port ${this.port}:`, + err, + ) + exit(1) + } + + log.notice( + `Offline [http for ${this.serverName}] listening on ${this.basePath}`, + ) + } + + stop(timeout) { + if (!this.#started) { + return Promise.resolve() + } + this.#started = false + return this.httpServer.stop({ timeout }) + } + + get httpServer() { + return this.#lambda.getServer(this.port) + } + + get listener() { + return this.httpServer.listener + } + + get port() { + return this.#port + } + + get basePath() { + const { host, httpsProtocol } = this.#options + return `${httpsProtocol ? "https" : "http"}://${host}:${this.port}` + } + + get serverName() { + if (this.port === this.#options.lambdaPort) { + return "lambda" + } + + if (this.port === this.#options.httpPort) { + return "api gateway" + } + + if (this.port === this.#options.websocketPort) { + return "websocket" + } + + if (this.port === this.#options.albPort) { + return "alb" + } + + return "unknown" + } +} diff --git a/src/lambda/HttpServer.js b/src/lambda/HttpServer.js index d2eb946e2..24fae8323 100644 --- a/src/lambda/HttpServer.js +++ b/src/lambda/HttpServer.js @@ -1,27 +1,20 @@ -import { exit } from "node:process" -import { Server } from "@hapi/hapi" import { log } from "../utils/log.js" import { invocationsRoute, invokeAsyncRoute } from "./routes/index.js" +import AbstractHttpServer from "./AbstractHttpServer.js" -export default class HttpServer { +export default class HttpServer extends AbstractHttpServer { #lambda = null #options = null - #server = null + constructor(lambda, options) { + super(lambda, options, options.lambdaPort) - constructor(options, lambda) { this.#lambda = lambda this.#options = options - const { host, lambdaPort } = options - - const serverOptions = { - host, - port: lambdaPort, - } - - this.#server = new Server(serverOptions) + // disable the default stripTrailingSlash + this.httpServer.settings.router.stripTrailingSlash = false } async start() { @@ -29,30 +22,11 @@ export default class HttpServer { const invRoute = invocationsRoute(this.#lambda, this.#options) const invAsyncRoute = invokeAsyncRoute(this.#lambda, this.#options) - this.#server.route([invAsyncRoute, invRoute]) - - const { host, httpsProtocol, lambdaPort } = this.#options - - try { - await this.#server.start() - } catch (err) { - log.error( - `Unexpected error while starting serverless-offline lambda server on port ${lambdaPort}:`, - err, - ) - exit(1) - } + this.httpServer.route([invAsyncRoute, invRoute]) - log.notice( - `Offline [http for lambda] listening on ${ - httpsProtocol ? "https" : "http" - }://${host}:${lambdaPort}`, - ) + await super.start() // Print all the invocation routes to debug - const basePath = `${ - httpsProtocol ? "https" : "http" - }://${host}:${lambdaPort}` const funcNamePairs = this.#lambda.listFunctionNamePairs() log.notice( @@ -75,7 +49,7 @@ export default class HttpServer { (functionName) => ` * ${ invRoute.method - } ${basePath}${invRoute.path.replace( + } ${this.basePath}${invRoute.path.replace( "{functionName}", functionName, )}`, @@ -92,7 +66,7 @@ export default class HttpServer { (functionName) => ` * ${ invAsyncRoute.method - } ${basePath}${invAsyncRoute.path.replace( + } ${this.basePath}${invAsyncRoute.path.replace( "{functionName}", functionName, )}`, @@ -100,11 +74,4 @@ export default class HttpServer { ].join("\n"), ) } - - // stops the server - stop(timeout) { - return this.#server.stop({ - timeout, - }) - } } diff --git a/src/lambda/Lambda.js b/src/lambda/Lambda.js index 779a8ef35..db19fd79d 100644 --- a/src/lambda/Lambda.js +++ b/src/lambda/Lambda.js @@ -6,6 +6,10 @@ const { assign } = Object export default class Lambda { #httpServer = null + #httpServers = new Map() + + #options = null + #lambdas = new Map() #lambdaFunctionNamesKeys = new Map() @@ -13,8 +17,9 @@ export default class Lambda { #lambdaFunctionPool = null constructor(serverless, options) { - this.#httpServer = new HttpServer(options, this) - this.#lambdaFunctionPool = new LambdaFunctionPool(serverless, options) + this.#httpServer = new HttpServer(this, options) + this.#options = options + this.#lambdaFunctionPool = new LambdaFunctionPool(serverless, this.#options) } #createEvent(functionKey, functionDefinition) { @@ -53,7 +58,6 @@ export default class Lambda { start() { this.#lambdaFunctionPool.start() - return this.#httpServer.start() } @@ -65,4 +69,12 @@ export default class Lambda { cleanup() { return this.#lambdaFunctionPool.cleanup() } + + putServer(port, server) { + this.#httpServers.set(port, server) + } + + getServer(port) { + return this.#httpServers.get(port) + } } diff --git a/tests/integration/websocket-oneway-shared/serverless.yml b/tests/integration/websocket-oneway-shared/serverless.yml new file mode 100644 index 000000000..41bf43e75 --- /dev/null +++ b/tests/integration/websocket-oneway-shared/serverless.yml @@ -0,0 +1,35 @@ +service: oneway-shared-websocket-tests + +configValidationMode: error +deprecationNotificationMode: error + +plugins: + - ../../../src/index.js + +provider: + architecture: arm64 + deploymentMethod: direct + memorySize: 1024 + name: aws + region: us-east-1 + runtime: nodejs18.x + stage: dev + versionFunctions: false + +functions: + handler: + events: + - http: + method: get + path: echo + - websocket: + route: $connect + - websocket: + route: $disconnect + - websocket: + route: $default + handler: src/handler.handler + +custom: + serverless-offline: + websocketPort: 3000 diff --git a/tests/integration/websocket-oneway-shared/src/handler.js b/tests/integration/websocket-oneway-shared/src/handler.js new file mode 100644 index 000000000..6d234640d --- /dev/null +++ b/tests/integration/websocket-oneway-shared/src/handler.js @@ -0,0 +1,19 @@ +const { parse } = JSON + +export async function handler(event) { + const { body, requestContext } = event + + if ( + body && + parse(body).throwError && + requestContext && + requestContext.routeKey === "$default" + ) { + throw new Error("Throwing error from incoming message") + } + + return { + body: body || undefined, + statusCode: 200, + } +} diff --git a/tests/integration/websocket-oneway-shared/src/package.json b/tests/integration/websocket-oneway-shared/src/package.json new file mode 100644 index 000000000..3dbc1ca59 --- /dev/null +++ b/tests/integration/websocket-oneway-shared/src/package.json @@ -0,0 +1,3 @@ +{ + "type": "module" +} diff --git a/tests/integration/websocket-oneway-shared/websocket-oneway-shared.test.js b/tests/integration/websocket-oneway-shared/websocket-oneway-shared.test.js new file mode 100644 index 000000000..3db39da36 --- /dev/null +++ b/tests/integration/websocket-oneway-shared/websocket-oneway-shared.test.js @@ -0,0 +1,55 @@ +import assert from "node:assert" +import { join } from "desm" +import { WebSocket } from "ws" +import { setup, teardown } from "../../_testHelpers/index.js" +import websocketSend from "../../_testHelpers/websocketPromise.js" +import { BASE_URL } from "../../config.js" + +const { parse, stringify } = JSON + +describe("one way websocket tests on shared port", function desc() { + beforeEach(() => + setup({ + servicePath: join(import.meta.url), + }), + ) + + afterEach(() => teardown()) + + it("websocket echos nothing", async () => { + const url = new URL("/dev", BASE_URL) + url.port = url.port ? "3000" : url.port + url.protocol = "ws" + + const payload = stringify({ + hello: "world", + now: new Date().toISOString(), + }) + + const ws = new WebSocket(url) + const { data, code, err } = await websocketSend(ws, payload) + + assert.equal(code, undefined) + assert.equal(err, undefined) + assert.equal(data, undefined) + }) + + it("execution error emits Internal Server Error", async () => { + const url = new URL("/dev", BASE_URL) + url.port = url.port ? "3000" : url.port + url.protocol = "ws" + + const payload = stringify({ + hello: "world", + now: new Date().toISOString(), + throwError: true, + }) + + const ws = new WebSocket(url) + const { data, code, err } = await websocketSend(ws, payload) + + assert.equal(code, undefined) + assert.equal(err, undefined) + assert.equal(parse(data).message, "Internal server error") + }) +}) diff --git a/tests/integration/websocket-twoway-shared/serverless.yml b/tests/integration/websocket-twoway-shared/serverless.yml new file mode 100644 index 000000000..1f3643c5d --- /dev/null +++ b/tests/integration/websocket-twoway-shared/serverless.yml @@ -0,0 +1,37 @@ +service: twoway-shared-websocket-tests + +configValidationMode: error +deprecationNotificationMode: error + +plugins: + - ../../../src/index.js + +provider: + architecture: arm64 + deploymentMethod: direct + memorySize: 1024 + name: aws + region: us-east-1 + runtime: nodejs18.x + stage: dev + versionFunctions: false + +functions: + handler: + events: + - http: + method: get + path: echo + - websocket: + route: $connect + - websocket: + route: $disconnect + - websocket: + route: $default + # Enable 2-way comms + routeResponseSelectionExpression: $default + handler: src/handler.handler + +custom: + serverless-offline: + websocketPort: 3000 diff --git a/tests/integration/websocket-twoway-shared/src/handler.js b/tests/integration/websocket-twoway-shared/src/handler.js new file mode 100644 index 000000000..55dbdb570 --- /dev/null +++ b/tests/integration/websocket-twoway-shared/src/handler.js @@ -0,0 +1,32 @@ +const { parse } = JSON + +export async function handler(event) { + const { body, queryStringParameters, requestContext } = event + const statusCode = + queryStringParameters && queryStringParameters.statusCode + ? Number(queryStringParameters.statusCode) + : 200 + + if ( + queryStringParameters && + queryStringParameters.throwError && + requestContext && + requestContext.routeKey === "$connect" + ) { + throw new Error("Throwing error during connect phase") + } + + if ( + body && + parse(body).throwError && + requestContext && + requestContext.routeKey === "$default" + ) { + throw new Error("Throwing error from incoming message") + } + + return { + body: body || undefined, + statusCode, + } +} diff --git a/tests/integration/websocket-twoway-shared/src/package.json b/tests/integration/websocket-twoway-shared/src/package.json new file mode 100644 index 000000000..3dbc1ca59 --- /dev/null +++ b/tests/integration/websocket-twoway-shared/src/package.json @@ -0,0 +1,3 @@ +{ + "type": "module" +} diff --git a/tests/integration/websocket-twoway-shared/websocket-twoway-shared.test.js b/tests/integration/websocket-twoway-shared/websocket-twoway-shared.test.js new file mode 100644 index 000000000..bff283f3a --- /dev/null +++ b/tests/integration/websocket-twoway-shared/websocket-twoway-shared.test.js @@ -0,0 +1,102 @@ +import assert from "node:assert" +import { join } from "desm" +import { WebSocket } from "ws" +import { setup, teardown } from "../../_testHelpers/index.js" +import websocketSend from "../../_testHelpers/websocketPromise.js" +import { BASE_URL } from "../../config.js" + +const { parse, stringify } = JSON + +describe("two way websocket tests on shared port", function desc() { + beforeEach(() => + setup({ + servicePath: join(import.meta.url), + }), + ) + + afterEach(() => teardown()) + + it("websocket echos sent message", async () => { + const url = new URL("/dev", BASE_URL) + url.port = url.port ? "3000" : url.port + url.protocol = "ws" + + const payload = stringify({ + hello: "world", + now: new Date().toISOString(), + }) + + const ws = new WebSocket(url) + const { code, data, err } = await websocketSend(ws, payload) + + assert.equal(code, undefined) + assert.equal(err, undefined) + assert.deepEqual(data, payload) + }) + + // + ;[401, 500, 501, 502].forEach((statusCode) => { + it(`websocket connection emits status code ${statusCode}`, async () => { + const url = new URL("/dev", BASE_URL) + url.port = url.port ? "3000" : url.port + url.searchParams.set("statusCode", statusCode) + url.protocol = "ws" + + const payload = stringify({ + hello: "world", + now: new Date().toISOString(), + }) + + const ws = new WebSocket(url) + const { code, data, err } = await websocketSend(ws, payload) + + assert.equal(code, undefined) + + if (statusCode >= 200 && statusCode < 300) { + assert.equal(err, undefined) + assert.deepEqual(data, payload) + } else { + assert.equal(err.message, `Unexpected server response: ${statusCode}`) + assert.equal(data, undefined) + } + }) + }) + + it("websocket emits 502 on connection error", async () => { + const url = new URL("/dev", BASE_URL) + url.port = url.port ? "3000" : url.port + url.searchParams.set("throwError", "true") + url.protocol = "ws" + + const payload = stringify({ + hello: "world", + now: new Date().toISOString(), + }) + + const ws = new WebSocket(url) + const { code, data, err } = await websocketSend(ws, payload) + + assert.equal(code, undefined) + assert.equal(err.message, "Unexpected server response: 502") + assert.equal(data, undefined) + }) + + it("execution error emits Internal Server Error", async () => { + const url = new URL("/dev", BASE_URL) + url.port = url.port ? "3000" : url.port + url.protocol = "ws" + + const payload = stringify({ + hello: "world", + now: new Date().toISOString(), + throwError: true, + }) + + const ws = new WebSocket(url) + const { code, data, err } = await websocketSend(ws, payload) + + assert.equal(code, undefined) + assert.equal(err, undefined) + assert.equal(parse(data).message, "Internal server error") + }) +})