diff --git a/benchmarks/operations/benchmark.js b/benchmarks/operations/benchmark.js index 0c36f5422..8961b6093 100644 --- a/benchmarks/operations/benchmark.js +++ b/benchmarks/operations/benchmark.js @@ -30,6 +30,10 @@ const modules = [ module: "http-request", name: "Outgoing HTTP request (`http.request`)", }, + { + module: "undici", + name: "Outgoing HTTP request (`undici.request`)", + }, ]; (async () => { diff --git a/benchmarks/operations/http-request.js b/benchmarks/operations/http-request.js index c8b51d248..500a8756d 100644 --- a/benchmarks/operations/http-request.js +++ b/benchmarks/operations/http-request.js @@ -3,13 +3,8 @@ const http = require("http"); module.exports = { step: async function step() { return new Promise((resolve, reject) => { - const options = { - hostname: "localhost", - port: 10411, - }; - - const req = http.request(options, (res) => { - res.on("data", () => {}); + const req = http.request("http://localhost:10411", (res) => { + res.resume(); res.on("end", resolve); }); diff --git a/benchmarks/operations/package-lock.json b/benchmarks/operations/package-lock.json index 4594d6f90..2722b4f51 100644 --- a/benchmarks/operations/package-lock.json +++ b/benchmarks/operations/package-lock.json @@ -10,7 +10,8 @@ "dependencies": { "@aikidosec/firewall": "file:../../build", "better-sqlite3": "^11.7.2", - "mongodb": "^6.9.0" + "mongodb": "^6.9.0", + "undici": "^7.3.0" } }, "../../build": { @@ -540,6 +541,14 @@ "node": "*" } }, + "node_modules/undici": { + "version": "7.3.0", + "resolved": "https://registry.npmjs.org/undici/-/undici-7.3.0.tgz", + "integrity": "sha512-Qy96NND4Dou5jKoSJ2gm8ax8AJM/Ey9o9mz7KN1bb9GP+G0l20Zw8afxTnY2f4b7hmhn/z8aC2kfArVQlAhFBw==", + "engines": { + "node": ">=20.18.1" + } + }, "node_modules/util-deprecate": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", diff --git a/benchmarks/operations/package.json b/benchmarks/operations/package.json index 20956f93d..736b97aad 100644 --- a/benchmarks/operations/package.json +++ b/benchmarks/operations/package.json @@ -7,6 +7,7 @@ "dependencies": { "@aikidosec/firewall": "file:../../build", "better-sqlite3": "^11.7.2", - "mongodb": "^6.9.0" + "mongodb": "^6.9.0", + "undici": "^7.3.0" } } diff --git a/benchmarks/operations/undici.js b/benchmarks/operations/undici.js new file mode 100644 index 000000000..bda22d2a0 --- /dev/null +++ b/benchmarks/operations/undici.js @@ -0,0 +1,7 @@ +const undici = require("undici"); + +module.exports = { + step: async function step() { + return await undici.request("http://localhost:10411"); + }, +}; diff --git a/library/sinks/Fetch.ts b/library/sinks/Fetch.ts index b9c767681..a0405dc6e 100644 --- a/library/sinks/Fetch.ts +++ b/library/sinks/Fetch.ts @@ -8,6 +8,7 @@ import { Wrapper } from "../agent/Wrapper"; import { getPortFromURL } from "../helpers/getPortFromURL"; import { tryParseURL } from "../helpers/tryParseURL"; import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF"; +import { Hostname } from "../vulnerabilities/ssrf/Hostname"; import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls"; import { wrapDispatch } from "./undici/wrapDispatch"; @@ -16,13 +17,13 @@ export class Fetch implements Wrapper { private inspectHostname( agent: Agent, - hostname: string, + hostname: Hostname, port: number | undefined ): InterceptorResult { // Let the agent know that we are connecting to this hostname // This is to build a list of all hostnames that the application is connecting to if (typeof port === "number" && port > 0) { - agent.onConnectHostname(hostname, port); + agent.onConnectHostname(hostname.asString(), port); } const context = getContext(); @@ -46,7 +47,7 @@ export class Fetch implements Wrapper { if (url) { const attack = this.inspectHostname( agent, - url.hostname, + Hostname.fromURL(url), getPortFromURL(url) ); if (attack) { @@ -64,7 +65,7 @@ export class Fetch implements Wrapper { if (url) { const attack = this.inspectHostname( agent, - url.hostname, + Hostname.fromURL(url), getPortFromURL(url) ); if (attack) { @@ -77,7 +78,7 @@ export class Fetch implements Wrapper { if (args[0] instanceof URL && args[0].hostname.length > 0) { const attack = this.inspectHostname( agent, - args[0].hostname, + Hostname.fromURL(args[0]), getPortFromURL(args[0]) ); if (attack) { @@ -91,7 +92,7 @@ export class Fetch implements Wrapper { if (url) { const attack = this.inspectHostname( agent, - url.hostname, + Hostname.fromURL(url), getPortFromURL(url) ); if (attack) { diff --git a/library/sinks/HTTPRequest.ts b/library/sinks/HTTPRequest.ts index 8da9b4077..c1ad8e427 100644 --- a/library/sinks/HTTPRequest.ts +++ b/library/sinks/HTTPRequest.ts @@ -7,6 +7,7 @@ import { InterceptorResult } from "../agent/hooks/InterceptorResult"; import { Wrapper } from "../agent/Wrapper"; import { getPortFromURL } from "../helpers/getPortFromURL"; import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF"; +import { Hostname } from "../vulnerabilities/ssrf/Hostname"; import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls"; import { isRedirectToPrivateIP } from "../vulnerabilities/ssrf/isRedirectToPrivateIP"; import { getUrlFromHTTPRequestArgs } from "./http-request/getUrlFromHTTPRequestArgs"; @@ -34,7 +35,7 @@ export class HTTPRequest implements Wrapper { // Check if the hostname is inside the context const foundDirectSSRF = checkContextForSSRF({ - hostname: url.hostname, + hostname: Hostname.fromURL(url), operation: `${module}.request`, context: context, port: port, diff --git a/library/sinks/Undici.ts b/library/sinks/Undici.ts index fcc370c4b..07fd33a1a 100644 --- a/library/sinks/Undici.ts +++ b/library/sinks/Undici.ts @@ -7,7 +7,9 @@ import { InterceptorResult } from "../agent/hooks/InterceptorResult"; import { Wrapper } from "../agent/Wrapper"; import { getSemverNodeVersion } from "../helpers/getNodeVersion"; import { isVersionGreaterOrEqual } from "../helpers/isVersionGreaterOrEqual"; +import { tryParseURL } from "../helpers/tryParseURL"; import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF"; +import { Hostname } from "../vulnerabilities/ssrf/Hostname"; import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls"; import { wrapDispatch } from "./undici/wrapDispatch"; import { wrapExport } from "../agent/hooks/wrapExport"; @@ -25,15 +27,16 @@ const methods = [ export class Undici implements Wrapper { private inspectHostname( agent: Agent, - hostname: string, + hostname: Hostname, port: number | undefined, method: string ): InterceptorResult { // Let the agent know that we are connecting to this hostname // This is to build a list of all hostnames that the application is connecting to if (typeof port === "number" && port > 0) { - agent.onConnectHostname(hostname, port); + agent.onConnectHostname(hostname.asString(), port); } + const context = getContext(); if (!context) { diff --git a/library/sinks/http-request/wrapResponseHandler.ts b/library/sinks/http-request/wrapResponseHandler.ts index da29aade0..69526b2bd 100644 --- a/library/sinks/http-request/wrapResponseHandler.ts +++ b/library/sinks/http-request/wrapResponseHandler.ts @@ -6,6 +6,7 @@ import { isRedirectStatusCode } from "../../helpers/isRedirectStatusCode"; import { tryParseURL } from "../../helpers/tryParseURL"; import { findHostnameInContext } from "../../vulnerabilities/ssrf/findHostnameInContext"; import { getRedirectOrigin } from "../../vulnerabilities/ssrf/getRedirectOrigin"; +import { Hostname } from "../../vulnerabilities/ssrf/Hostname"; import { getUrlFromHTTPRequestArgs } from "./getUrlFromHTTPRequestArgs"; /** @@ -78,7 +79,11 @@ function addRedirectToContext(source: URL, destination: URL, context: Context) { const sourcePort = getPortFromURL(source); // Check if the source hostname is in the context - is true if it's the first redirect in the chain and the user input is the source - const found = findHostnameInContext(source.hostname, context, sourcePort); + const found = findHostnameInContext( + Hostname.fromURL(source), + context, + sourcePort + ); // If the source hostname is not in the context, check if it's a redirect in a already existing chain if (!found && context.outgoingRequestRedirects) { diff --git a/library/sinks/undici/getHostnameAndPortFromArgs.test.ts b/library/sinks/undici/getHostnameAndPortFromArgs.test.ts index b5b1bfcde..0a41ed063 100644 --- a/library/sinks/undici/getHostnameAndPortFromArgs.test.ts +++ b/library/sinks/undici/getHostnameAndPortFromArgs.test.ts @@ -1,78 +1,79 @@ import * as t from "tap"; +import { Hostname } from "../../vulnerabilities/ssrf/Hostname"; import { getHostnameAndPortFromArgs as get } from "./getHostnameAndPortFromArgs"; import { parse as parseUrl } from "url"; t.test("it works with url string", async (t) => { t.same(get(["http://localhost:4000"]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 4000, }); t.same(get(["http://localhost?test=1"]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 80, }); t.same(get(["https://localhost"]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 443, }); }); t.test("it works with url object", async (t) => { t.same(get([new URL("http://localhost:4000")]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 4000, }); t.same(get([new URL("http://localhost?test=1")]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 80, }); t.same(get([new URL("https://localhost")]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 443, }); }); t.test("it works with an array of strings", async (t) => { t.same(get([["http://localhost:4000"]]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 4000, }); t.same(get([["http://localhost?test=1"]]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 80, }); t.same(get([["https://localhost"]]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 443, }); }); t.test("it works with an legacy url object", async (t) => { t.same(get([parseUrl("http://localhost:4000")]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 4000, }); t.same(get([parseUrl("http://localhost?test=1")]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 80, }); t.same(get([parseUrl("https://localhost")]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 443, }); }); t.test("it works with an options object containing origin", async (t) => { t.same(get([{ origin: "http://localhost:4000" }]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 4000, }); t.same(get([{ origin: "http://localhost?test=1" }]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 80, }); t.same(get([{ origin: "https://localhost" }]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 443, }); }); @@ -81,15 +82,15 @@ t.test( "it works with an options object containing protocol, hostname and port", async (t) => { t.same(get([{ protocol: "http:", hostname: "localhost", port: 4000 }]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 4000, }); t.same(get([{ hostname: "localhost", port: 4000 }]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 4000, }); t.same(get([{ protocol: "https:", hostname: "localhost" }]), { - hostname: "localhost", + hostname: Hostname.fromString("localhost"), port: 443, }); } @@ -104,3 +105,11 @@ t.test("without hostname", async (t) => { t.same(get([{}]), undefined); t.same(get([{ protocol: "https:", port: 4000 }]), undefined); }); + +t.test("invalid hostname", async (t) => { + t.same(get([{ protocol: "https:", hostname: " " }]), undefined); +}); + +t.test("empty args", async (t) => { + t.same(get([]), undefined); +}); diff --git a/library/sinks/undici/getHostnameAndPortFromArgs.ts b/library/sinks/undici/getHostnameAndPortFromArgs.ts index a8e408a50..ede259206 100644 --- a/library/sinks/undici/getHostnameAndPortFromArgs.ts +++ b/library/sinks/undici/getHostnameAndPortFromArgs.ts @@ -1,9 +1,10 @@ import { getPortFromURL } from "../../helpers/getPortFromURL"; import { tryParseURL } from "../../helpers/tryParseURL"; +import { Hostname } from "../../vulnerabilities/ssrf/Hostname"; import { isOptionsObject } from "../http-request/isOptionsObject"; type HostnameAndPort = { - hostname: string; + hostname: Hostname; port: number | undefined; }; @@ -36,7 +37,7 @@ export function getHostnameAndPortFromArgs( // If url is not undefined, extract the hostname and port if (url && url.hostname.length > 0) { return { - hostname: url.hostname, + hostname: Hostname.fromURL(url), port: getPortFromURL(url), }; } @@ -60,7 +61,7 @@ function parseOptionsObject(obj: any): HostnameAndPort | undefined { const url = tryParseURL(obj.origin); if (url) { return { - hostname: url.hostname, + hostname: Hostname.fromURL(url), port: getPortFromURL(url), }; } @@ -87,8 +88,13 @@ function parseOptionsObject(obj: any): HostnameAndPort | undefined { return undefined; } + const hostname = Hostname.fromString(obj.hostname); + if (!hostname) { + return undefined; + } + return { - hostname: obj.hostname, + hostname, port, }; } diff --git a/library/sinks/undici/onRedirect.ts b/library/sinks/undici/onRedirect.ts index b804d0f36..e6d9a6f86 100644 --- a/library/sinks/undici/onRedirect.ts +++ b/library/sinks/undici/onRedirect.ts @@ -1,6 +1,7 @@ import { Context, updateContext } from "../../agent/Context"; import { findHostnameInContext } from "../../vulnerabilities/ssrf/findHostnameInContext"; import { getRedirectOrigin } from "../../vulnerabilities/ssrf/getRedirectOrigin"; +import { Hostname } from "../../vulnerabilities/ssrf/Hostname"; import { RequestContextStorage } from "./RequestContextStorage"; /** @@ -20,7 +21,7 @@ export function onRedirect( // Check if the source hostname is in the context - is true if it's the first redirect in the chain and the user input is the source const found = findHostnameInContext( - requestContext.url.hostname, + Hostname.fromURL(requestContext.url), context, requestContext.port ); diff --git a/library/vulnerabilities/ssrf/Hostname.ts b/library/vulnerabilities/ssrf/Hostname.ts new file mode 100644 index 000000000..89877fdf3 --- /dev/null +++ b/library/vulnerabilities/ssrf/Hostname.ts @@ -0,0 +1,27 @@ +import { tryParseURL } from "../../helpers/tryParseURL"; + +export class Hostname { + private constructor(private readonly url: URL) {} + + static fromURL(url: URL) { + return new Hostname(url); + } + + static fromString(str: string) { + const url = tryParseURL(`http://${str}`); + + if (!url) { + return undefined; + } + + return new Hostname(url); + } + + asString() { + return this.url.hostname; + } + + toString() { + throw new Error("Use asString() instead"); + } +} diff --git a/library/vulnerabilities/ssrf/checkContextForSSRF.ts b/library/vulnerabilities/ssrf/checkContextForSSRF.ts index e19c32b9f..fac5a51c0 100644 --- a/library/vulnerabilities/ssrf/checkContextForSSRF.ts +++ b/library/vulnerabilities/ssrf/checkContextForSSRF.ts @@ -7,6 +7,7 @@ import { extractStringsFromUserInputCached } from "../../helpers/extractStringsF import { containsPrivateIPAddress } from "./containsPrivateIPAddress"; import { findHostnameInUserInput } from "./findHostnameInUserInput"; import { getMetadataForSSRFAttack } from "./getMetadataForSSRFAttack"; +import { Hostname } from "./Hostname"; import { isRequestToItself } from "./isRequestToItself"; /** @@ -19,7 +20,7 @@ export function checkContextForSSRF({ operation, context, }: { - hostname: string; + hostname: Hostname; port: number | undefined; operation: string; context: Context; @@ -28,7 +29,7 @@ export function checkContextForSSRF({ // DNS lookup calls will be inspected somewhere else // This is just to inspect direct invocations of `http.request` and similar // Where the hostname might be a private IP address (or localhost) - if (!containsPrivateIPAddress(hostname)) { + if (!containsPrivateIPAddress(hostname.asString())) { return; } @@ -61,7 +62,10 @@ export function checkContextForSSRF({ kind: "ssrf", source: source, pathsToPayload: paths, - metadata: getMetadataForSSRFAttack({ hostname, port }), + metadata: getMetadataForSSRFAttack({ + hostname: hostname.asString(), + port, + }), payload: str, }; } diff --git a/library/vulnerabilities/ssrf/findHostnameInContext.ts b/library/vulnerabilities/ssrf/findHostnameInContext.ts index a4827355c..ec506047b 100644 --- a/library/vulnerabilities/ssrf/findHostnameInContext.ts +++ b/library/vulnerabilities/ssrf/findHostnameInContext.ts @@ -3,6 +3,7 @@ import { Source, SOURCES } from "../../agent/Source"; import { getPathsToPayload } from "../../helpers/attackPath"; import { extractStringsFromUserInputCached } from "../../helpers/extractStringsFromUserInputCached"; import { findHostnameInUserInput } from "./findHostnameInUserInput"; +import { Hostname } from "./Hostname"; import { isRequestToItself } from "./isRequestToItself"; type HostnameLocation = { @@ -14,7 +15,7 @@ type HostnameLocation = { }; export function findHostnameInContext( - hostname: string, + hostname: Hostname, context: Context, port: number | undefined ): HostnameLocation | undefined { @@ -48,7 +49,7 @@ export function findHostnameInContext( pathsToPayload: paths, payload: str, port: port, - hostname: hostname, + hostname: hostname.asString(), }; } } diff --git a/library/vulnerabilities/ssrf/findHostnameInUserInput.test.ts b/library/vulnerabilities/ssrf/findHostnameInUserInput.test.ts index 362777fbe..a0a231efc 100644 --- a/library/vulnerabilities/ssrf/findHostnameInUserInput.test.ts +++ b/library/vulnerabilities/ssrf/findHostnameInUserInput.test.ts @@ -1,41 +1,67 @@ import * as t from "tap"; import { findHostnameInUserInput } from "./findHostnameInUserInput"; - -t.test("returns false if user input and hostname are empty", async (t) => { - t.same(findHostnameInUserInput("", ""), false); -}); +import { Hostname } from "./Hostname"; t.test("returns false if user input is empty", async (t) => { - t.same(findHostnameInUserInput("", "example.com"), false); -}); - -t.test("returns false if hostname is empty", async (t) => { - t.same(findHostnameInUserInput("http://example.com", ""), false); + t.same( + findHostnameInUserInput("", Hostname.fromString("example.com")!), + false + ); }); t.test("it parses hostname from user input", async (t) => { - t.same(findHostnameInUserInput("http://localhost", "localhost"), true); + t.same( + findHostnameInUserInput( + "http://localhost", + Hostname.fromString("localhost")! + ), + true + ); }); t.test("it parses special IP", async (t) => { - t.same(findHostnameInUserInput("http://localhost", "localhost"), true); + t.same( + findHostnameInUserInput( + "http://localhost", + Hostname.fromString("localhost")! + ), + true + ); }); t.test("it parses hostname from user input with path behind it", async (t) => { - t.same(findHostnameInUserInput("http://localhost/path", "localhost"), true); + t.same( + findHostnameInUserInput( + "http://localhost/path", + Hostname.fromString("localhost")! + ), + true + ); }); t.test( "it parses hostname from user input with misspelled protocol", async (t) => { - t.same(findHostnameInUserInput("http:/localhost", "localhost"), true); + t.same( + findHostnameInUserInput( + "http:/localhost", + Hostname.fromString("localhost")! + ), + true + ); } ); t.test( "it parses hostname from user input without protocol separator", async (t) => { - t.same(findHostnameInUserInput("http:localhost", "localhost"), true); + t.same( + findHostnameInUserInput( + "http:localhost", + Hostname.fromString("localhost")! + ), + true + ); } ); @@ -43,7 +69,10 @@ t.test( "it parses hostname from user input with misspelled protocol and path behind it", async (t) => { t.same( - findHostnameInUserInput("http:/localhost/path/path", "localhost"), + findHostnameInUserInput( + "http:/localhost/path/path", + Hostname.fromString("localhost")! + ), true ); } @@ -52,52 +81,109 @@ t.test( t.test( "it parses hostname from user input without protocol and path behind it", async (t) => { - t.same(findHostnameInUserInput("localhost/path/path", "localhost"), true); + t.same( + findHostnameInUserInput( + "localhost/path/path", + Hostname.fromString("localhost")! + ), + true + ); } ); t.test("it flags FTP as protocol", async (t) => { - t.same(findHostnameInUserInput("ftp://localhost", "localhost"), true); + t.same( + findHostnameInUserInput( + "ftp://localhost", + Hostname.fromString("localhost")! + ), + true + ); }); t.test("it parses hostname from user input", async (t) => { - t.same(findHostnameInUserInput("localhost", "localhost"), true); + t.same( + findHostnameInUserInput("localhost", Hostname.fromString("localhost")!), + true + ); }); t.test("it ignores invalid URLs", async (t) => { - t.same(findHostnameInUserInput("http://", "localhost"), false); + t.same( + findHostnameInUserInput("http://", Hostname.fromString("localhost")!), + false + ); }); t.test("user input is smaller than hostname", async (t) => { - t.same(findHostnameInUserInput("localhost", "localhost localhost"), false); + t.same( + findHostnameInUserInput( + "localhost", + Hostname.fromString("localhost-localhost")! + ), + false + ); }); t.test("it find IP address inside URL", async () => { t.same( findHostnameInUserInput( "http://169.254.169.254/latest/meta-data/", - "169.254.169.254" + Hostname.fromString("169.254.169.254")! ), true ); }); t.test("it find IP address with strange notation inside URL", async () => { - t.same(findHostnameInUserInput("http://2130706433", "2130706433"), true); - t.same(findHostnameInUserInput("http://127.1", "127.1"), true); - t.same(findHostnameInUserInput("http://127.0.1", "127.0.1"), true); + t.same( + findHostnameInUserInput( + "http://2130706433", + Hostname.fromString("2130706433")! + ), + true + ); + t.same( + findHostnameInUserInput("http://127.1", Hostname.fromString("127.1")!), + true + ); + t.same( + findHostnameInUserInput("http://127.0.1", Hostname.fromString("127.0.1")!), + true + ); }); t.test("it works with ports", async () => { - t.same(findHostnameInUserInput("http://localhost", "localhost", 8080), false); t.same( - findHostnameInUserInput("http://localhost:8080", "localhost", 8080), + findHostnameInUserInput( + "http://localhost", + Hostname.fromString("localhost")!, + 8080 + )!, + false + ); + t.same( + findHostnameInUserInput( + "http://localhost:8080", + Hostname.fromString("localhost")!, + 8080 + )!, true ); // If port is not specified, it should return true - t.same(findHostnameInUserInput("http://localhost:8080", "localhost"), true); t.same( - findHostnameInUserInput("http://localhost:8080", "localhost", 4321), + findHostnameInUserInput( + "http://localhost:8080", + Hostname.fromString("localhost")! + )!, + true + ); + t.same( + findHostnameInUserInput( + "http://localhost:8080", + Hostname.fromString("localhost")!, + 4321 + ), false ); }); diff --git a/library/vulnerabilities/ssrf/findHostnameInUserInput.ts b/library/vulnerabilities/ssrf/findHostnameInUserInput.ts index f7809eec9..083c74211 100644 --- a/library/vulnerabilities/ssrf/findHostnameInUserInput.ts +++ b/library/vulnerabilities/ssrf/findHostnameInUserInput.ts @@ -1,29 +1,26 @@ import { getPortFromURL } from "../../helpers/getPortFromURL"; import { tryParseURL } from "../../helpers/tryParseURL"; +import { Hostname } from "./Hostname"; export function findHostnameInUserInput( userInput: string, - hostname: string, + hostnameURL: Hostname, port?: number ): boolean { if (userInput.length <= 1) { return false; } - const hostnameURL = tryParseURL(`http://${hostname}`); - if (!hostnameURL) { - return false; - } - const variants = [userInput, `http://${userInput}`, `https://${userInput}`]; for (const variant of variants) { const userInputURL = tryParseURL(variant); - if (userInputURL && userInputURL.hostname === hostnameURL.hostname) { + if (userInputURL && userInputURL.hostname === hostnameURL.asString()) { const userPort = getPortFromURL(userInputURL); if (!port) { return true; } + if (port && userPort === port) { return true; } diff --git a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts index 4731902a4..cbba74031 100644 --- a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts +++ b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts @@ -7,6 +7,7 @@ import { cleanupStackTrace } from "../../helpers/cleanupStackTrace"; import { escapeHTML } from "../../helpers/escapeHTML"; import { isPlainObject } from "../../helpers/isPlainObject"; import { getMetadataForSSRFAttack } from "./getMetadataForSSRFAttack"; +import { Hostname } from "./Hostname"; import { isPrivateIP } from "./isPrivateIP"; import { isIMDSIPAddress, isTrustedHostname } from "./imds"; import { RequestContextStorage } from "../../sinks/undici/RequestContextStorage"; @@ -141,7 +142,12 @@ function wrapDNSLookupCallback( return callback(err, addresses, family); } - let found = findHostnameInContext(hostname, context, port); + const validHostname = Hostname.fromString(hostname); + if (!validHostname) { + return callback(err, addresses, family); + } + + let found = findHostnameInContext(validHostname, context, port); // The hostname is not found in the context, check if it's a redirect if (!found && context.outgoingRequestRedirects) { @@ -165,7 +171,7 @@ function wrapDNSLookupCallback( // If the URL is the result of a redirect, get the origin of the redirect chain for reporting the attack source if (redirectOrigin) { found = findHostnameInContext( - redirectOrigin.hostname, + Hostname.fromURL(redirectOrigin), context, getPortFromURL(redirectOrigin) ); diff --git a/library/vulnerabilities/ssrf/isRedirectToPrivateIP.ts b/library/vulnerabilities/ssrf/isRedirectToPrivateIP.ts index 446cb17e6..409969a53 100644 --- a/library/vulnerabilities/ssrf/isRedirectToPrivateIP.ts +++ b/library/vulnerabilities/ssrf/isRedirectToPrivateIP.ts @@ -2,6 +2,7 @@ import { Context } from "../../agent/Context"; import { containsPrivateIPAddress } from "./containsPrivateIPAddress"; import { findHostnameInContext } from "./findHostnameInContext"; import { getRedirectOrigin } from "./getRedirectOrigin"; +import { Hostname } from "./Hostname"; /** * This function is called before a outgoing request is made. @@ -23,7 +24,7 @@ export function isRedirectToPrivateIP(url: URL, context: Context) { if (redirectOrigin) { return findHostnameInContext( - redirectOrigin.hostname, + Hostname.fromURL(redirectOrigin), context, parseInt(redirectOrigin.port, 10) );