Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve SSRF performance: re-use extracted hostname from URL #536

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions benchmarks/operations/benchmark.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ const modules = [
module: "http-request",
name: "Outgoing HTTP request (`http.request`)",
},
{
module: "undici",
name: "Outgoing HTTP request (`undici.request`)",
},
];

(async () => {
Expand Down
9 changes: 2 additions & 7 deletions benchmarks/operations/http-request.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,8 @@ const http = require("http");
module.exports = {
step: async function step() {
return new Promise((resolve, reject) => {
const options = {
hostname: "localhost",
port: 10411,
};

const req = http.request(options, (res) => {
res.on("data", () => {});
const req = http.request("http://localhost:10411", (res) => {
res.resume();
res.on("end", resolve);
});

Expand Down
11 changes: 10 additions & 1 deletion benchmarks/operations/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion benchmarks/operations/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"dependencies": {
"@aikidosec/firewall": "file:../../build",
"better-sqlite3": "^11.7.2",
"mongodb": "^6.9.0"
"mongodb": "^6.9.0",
"undici": "^7.3.0"
}
}
7 changes: 7 additions & 0 deletions benchmarks/operations/undici.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
const undici = require("undici");

module.exports = {
step: async function step() {
return await undici.request("http://localhost:10411");
},
};
13 changes: 7 additions & 6 deletions library/sinks/Fetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { Wrapper } from "../agent/Wrapper";
import { getPortFromURL } from "../helpers/getPortFromURL";
import { tryParseURL } from "../helpers/tryParseURL";
import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF";
import { Hostname } from "../vulnerabilities/ssrf/Hostname";
import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls";
import { wrapDispatch } from "./undici/wrapDispatch";

Expand All @@ -16,13 +17,13 @@ export class Fetch implements Wrapper {

private inspectHostname(
agent: Agent,
hostname: string,
hostname: Hostname,
port: number | undefined
): InterceptorResult {
// Let the agent know that we are connecting to this hostname
// This is to build a list of all hostnames that the application is connecting to
if (typeof port === "number" && port > 0) {
agent.onConnectHostname(hostname, port);
agent.onConnectHostname(hostname.asString(), port);
}
const context = getContext();

Expand All @@ -46,7 +47,7 @@ export class Fetch implements Wrapper {
if (url) {
const attack = this.inspectHostname(
agent,
url.hostname,
Hostname.fromURL(url),
getPortFromURL(url)
);
if (attack) {
Expand All @@ -64,7 +65,7 @@ export class Fetch implements Wrapper {
if (url) {
const attack = this.inspectHostname(
agent,
url.hostname,
Hostname.fromURL(url),
getPortFromURL(url)
);
if (attack) {
Expand All @@ -77,7 +78,7 @@ export class Fetch implements Wrapper {
if (args[0] instanceof URL && args[0].hostname.length > 0) {
const attack = this.inspectHostname(
agent,
args[0].hostname,
Hostname.fromURL(args[0]),
getPortFromURL(args[0])
);
if (attack) {
Expand All @@ -91,7 +92,7 @@ export class Fetch implements Wrapper {
if (url) {
const attack = this.inspectHostname(
agent,
url.hostname,
Hostname.fromURL(url),
getPortFromURL(url)
);
if (attack) {
Expand Down
3 changes: 2 additions & 1 deletion library/sinks/HTTPRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { InterceptorResult } from "../agent/hooks/InterceptorResult";
import { Wrapper } from "../agent/Wrapper";
import { getPortFromURL } from "../helpers/getPortFromURL";
import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF";
import { Hostname } from "../vulnerabilities/ssrf/Hostname";
import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls";
import { isRedirectToPrivateIP } from "../vulnerabilities/ssrf/isRedirectToPrivateIP";
import { getUrlFromHTTPRequestArgs } from "./http-request/getUrlFromHTTPRequestArgs";
Expand Down Expand Up @@ -34,7 +35,7 @@ export class HTTPRequest implements Wrapper {

// Check if the hostname is inside the context
const foundDirectSSRF = checkContextForSSRF({
hostname: url.hostname,
hostname: Hostname.fromURL(url),
operation: `${module}.request`,
context: context,
port: port,
Expand Down
7 changes: 5 additions & 2 deletions library/sinks/Undici.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import { InterceptorResult } from "../agent/hooks/InterceptorResult";
import { Wrapper } from "../agent/Wrapper";
import { getSemverNodeVersion } from "../helpers/getNodeVersion";
import { isVersionGreaterOrEqual } from "../helpers/isVersionGreaterOrEqual";
import { tryParseURL } from "../helpers/tryParseURL";
import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF";
import { Hostname } from "../vulnerabilities/ssrf/Hostname";
import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls";
import { wrapDispatch } from "./undici/wrapDispatch";
import { wrapExport } from "../agent/hooks/wrapExport";
Expand All @@ -25,15 +27,16 @@ const methods = [
export class Undici implements Wrapper {
private inspectHostname(
agent: Agent,
hostname: string,
hostname: Hostname,
port: number | undefined,
method: string
): InterceptorResult {
// Let the agent know that we are connecting to this hostname
// This is to build a list of all hostnames that the application is connecting to
if (typeof port === "number" && port > 0) {
agent.onConnectHostname(hostname, port);
agent.onConnectHostname(hostname.asString(), port);
}

const context = getContext();

if (!context) {
Expand Down
7 changes: 6 additions & 1 deletion library/sinks/http-request/wrapResponseHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { isRedirectStatusCode } from "../../helpers/isRedirectStatusCode";
import { tryParseURL } from "../../helpers/tryParseURL";
import { findHostnameInContext } from "../../vulnerabilities/ssrf/findHostnameInContext";
import { getRedirectOrigin } from "../../vulnerabilities/ssrf/getRedirectOrigin";
import { Hostname } from "../../vulnerabilities/ssrf/Hostname";
import { getUrlFromHTTPRequestArgs } from "./getUrlFromHTTPRequestArgs";

/**
Expand Down Expand Up @@ -78,7 +79,11 @@ function addRedirectToContext(source: URL, destination: URL, context: Context) {
const sourcePort = getPortFromURL(source);

// Check if the source hostname is in the context - is true if it's the first redirect in the chain and the user input is the source
const found = findHostnameInContext(source.hostname, context, sourcePort);
const found = findHostnameInContext(
Hostname.fromURL(source),
context,
sourcePort
);

// If the source hostname is not in the context, check if it's a redirect in a already existing chain
if (!found && context.outgoingRequestRedirects) {
Expand Down
45 changes: 27 additions & 18 deletions library/sinks/undici/getHostnameAndPortFromArgs.test.ts
Original file line number Diff line number Diff line change
@@ -1,78 +1,79 @@
import * as t from "tap";
import { Hostname } from "../../vulnerabilities/ssrf/Hostname";
import { getHostnameAndPortFromArgs as get } from "./getHostnameAndPortFromArgs";
import { parse as parseUrl } from "url";

t.test("it works with url string", async (t) => {
t.same(get(["http://localhost:4000"]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 4000,
});
t.same(get(["http://localhost?test=1"]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 80,
});
t.same(get(["https://localhost"]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 443,
});
});

t.test("it works with url object", async (t) => {
t.same(get([new URL("http://localhost:4000")]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 4000,
});
t.same(get([new URL("http://localhost?test=1")]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 80,
});
t.same(get([new URL("https://localhost")]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 443,
});
});

t.test("it works with an array of strings", async (t) => {
t.same(get([["http://localhost:4000"]]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 4000,
});
t.same(get([["http://localhost?test=1"]]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 80,
});
t.same(get([["https://localhost"]]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 443,
});
});

t.test("it works with an legacy url object", async (t) => {
t.same(get([parseUrl("http://localhost:4000")]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 4000,
});
t.same(get([parseUrl("http://localhost?test=1")]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 80,
});
t.same(get([parseUrl("https://localhost")]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 443,
});
});

t.test("it works with an options object containing origin", async (t) => {
t.same(get([{ origin: "http://localhost:4000" }]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 4000,
});
t.same(get([{ origin: "http://localhost?test=1" }]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 80,
});
t.same(get([{ origin: "https://localhost" }]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 443,
});
});
Expand All @@ -81,15 +82,15 @@ t.test(
"it works with an options object containing protocol, hostname and port",
async (t) => {
t.same(get([{ protocol: "http:", hostname: "localhost", port: 4000 }]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 4000,
});
t.same(get([{ hostname: "localhost", port: 4000 }]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 4000,
});
t.same(get([{ protocol: "https:", hostname: "localhost" }]), {
hostname: "localhost",
hostname: Hostname.fromString("localhost"),
port: 443,
});
}
Expand All @@ -104,3 +105,11 @@ t.test("without hostname", async (t) => {
t.same(get([{}]), undefined);
t.same(get([{ protocol: "https:", port: 4000 }]), undefined);
});

t.test("invalid hostname", async (t) => {
t.same(get([{ protocol: "https:", hostname: " " }]), undefined);
});

t.test("empty args", async (t) => {
t.same(get([]), undefined);
});
14 changes: 10 additions & 4 deletions library/sinks/undici/getHostnameAndPortFromArgs.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { getPortFromURL } from "../../helpers/getPortFromURL";
import { tryParseURL } from "../../helpers/tryParseURL";
import { Hostname } from "../../vulnerabilities/ssrf/Hostname";
import { isOptionsObject } from "../http-request/isOptionsObject";

type HostnameAndPort = {
hostname: string;
hostname: Hostname;
port: number | undefined;
};

Expand Down Expand Up @@ -36,7 +37,7 @@ export function getHostnameAndPortFromArgs(
// If url is not undefined, extract the hostname and port
if (url && url.hostname.length > 0) {
return {
hostname: url.hostname,
hostname: Hostname.fromURL(url),
port: getPortFromURL(url),
};
}
Expand All @@ -60,7 +61,7 @@ function parseOptionsObject(obj: any): HostnameAndPort | undefined {
const url = tryParseURL(obj.origin);
if (url) {
return {
hostname: url.hostname,
hostname: Hostname.fromURL(url),
port: getPortFromURL(url),
};
}
Expand All @@ -87,8 +88,13 @@ function parseOptionsObject(obj: any): HostnameAndPort | undefined {
return undefined;
}

const hostname = Hostname.fromString(obj.hostname);
if (!hostname) {
return undefined;
}

return {
hostname: obj.hostname,
hostname,
port,
};
}
3 changes: 2 additions & 1 deletion library/sinks/undici/onRedirect.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Context, updateContext } from "../../agent/Context";
import { findHostnameInContext } from "../../vulnerabilities/ssrf/findHostnameInContext";
import { getRedirectOrigin } from "../../vulnerabilities/ssrf/getRedirectOrigin";
import { Hostname } from "../../vulnerabilities/ssrf/Hostname";
import { RequestContextStorage } from "./RequestContextStorage";

/**
Expand All @@ -20,7 +21,7 @@ export function onRedirect(

// Check if the source hostname is in the context - is true if it's the first redirect in the chain and the user input is the source
const found = findHostnameInContext(
requestContext.url.hostname,
Hostname.fromURL(requestContext.url),
context,
requestContext.port
);
Expand Down
Loading
Loading