Skip to content

Commit

Permalink
Implement CSRF
Browse files Browse the repository at this point in the history
  • Loading branch information
BellCubeDev committed Dec 25, 2024
1 parent d1f1b08 commit b9e17ae
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 45 deletions.
9 changes: 9 additions & 0 deletions src/interface/hooks/updateRef.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import type { MutableRefObject, RefObject } from "react";
import type { Writeable } from "zod";

/**
* A wonderful workaround for not being allowed to mutate hook props directly according to React Compiler, including with a ref.
*/
export function updateRef<T>(ref: RefObject<T> | MutableRefObject<T>, value: T) {
(ref as Writeable<typeof ref>).current = value;
}
19 changes: 16 additions & 3 deletions src/interface/hooks/useCookies.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import CookieManager, { type Cookies } from '@react-native-cookies/cookies';
import { useConfig } from '../provider/config-provider';
import React from 'react';
import { setCookie as setCookieReal } from '../utils/setCookie';

// TODO: Two major optimizations
// 1. Change hook to a useCookie('cookieName') hook to reduce re-renders and block the main thread less
// 2. Only run one instance of the timeout in case the same cookie is requested by multiple components

/**
* Returns a tuple of the cookies for the Supplementary Server's origin (index 0) and whether the cookies are still being loaded (index 1).
*
* Includes separate implementations for React Native and web. React Native gets updates by polling every 30 seconds while web gets updates via the `cookieStore` API (except for Firefox, which falls back to polling).
*
*/
export function useCookies(): [cookies: Cookies | null, isLoadingCookies: boolean] {
export function useCookies(): [cookies: Cookies | null, isLoadingCookies: boolean, setCookie: typeof setCookieReal] {
const config = useConfig();
const [cookies, setCookies] = React.useState<Cookies | null>(null);
const [isLoadingCookies, setIsLoadingCookies] = React.useState(true);

const stringifiedCookiesRef = React.useRef<string | null>(null);
const generateNewAllCookiesPromise = React.useCallback(async (canceledObj: {canceled: boolean}) => {
const generateNewAllCookiesPromise = React.useCallback(async (canceledObj: {canceled: boolean}, onlyOnce?: boolean) => {
if (!config.supplementary) return;

const newCookies = await CookieManager.get(config.supplementary.canonicalRoot.origin, false);
Expand All @@ -29,6 +34,7 @@ export function useCookies(): [cookies: Cookies | null, isLoadingCookies: boolea
setCookies(newCookies);
setIsLoadingCookies(false);

if (onlyOnce) return;
setTimeout(() => generateNewAllCookiesPromise(canceledObj), 30000); // If Firefox gave us some sort of event for this, that'd be great
}, [config.supplementary]);

Expand All @@ -38,6 +44,13 @@ export function useCookies(): [cookies: Cookies | null, isLoadingCookies: boolea
return () => { canceledObj.canceled = true };
}, [generateNewAllCookiesPromise]);

const setCookie = React.useCallback(async (...args: Parameters<typeof setCookieReal>) => {
const res = await setCookieReal(...args);
if (!res) return res;
await generateNewAllCookiesPromise({canceled: false}, true);
return res;
}, [generateNewAllCookiesPromise]);

return React.useMemo(()=>[cookies, isLoadingCookies] as const, [cookies, isLoadingCookies]);
// eslint-disable-next-line react-compiler/react-compiler
return React.useMemo(()=>[cookies, isLoadingCookies, setCookie] as const, [cookies, isLoadingCookies, setCookie]);
}
18 changes: 15 additions & 3 deletions src/interface/hooks/useCookies.web.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import type { Cookies, Cookie } from '@react-native-cookies/cookies';
import React from 'react';
if (typeof window !== 'undefined' && !window.cookieStore) window.cookieStore = (await import('cookie-store')).cookieStore;
import { setCookie as setCookieReal } from '../utils/setCookie';

export function useCookies(): [cookies: Cookies | null, isLoadingCookies: boolean] {
// TODO: Two major optimizations
// 1. Change hook to a useCookie('cookieName') hook to reduce re-renders and block the main thread less
// 2. Only run one instance of the timeout in case the same cookie is requested by multiple components

export function useCookies(): [cookies: Cookies | null, isLoadingCookies: boolean, setCookie: typeof setCookieReal] {
const [cookies, setCookies] = React.useState<Cookies | null>(null);
const [isLoadingCookies, setIsLoadingCookies] = React.useState(true);

Expand Down Expand Up @@ -44,11 +49,11 @@ export function useCookies(): [cookies: Cookies | null, isLoadingCookies: boolea
}, []);

React.useEffect(() => {
generateNewAllCookiesPromise();
if (typeof window === 'undefined') return;

try {
window.cookieStore.addEventListener('change', generateNewAllCookiesPromise);
generateNewAllCookiesPromise();
return () => window.cookieStore.removeEventListener('change', generateNewAllCookiesPromise);
} catch {
// Thanks, Firefox, for being mean
Expand All @@ -59,5 +64,12 @@ export function useCookies(): [cookies: Cookies | null, isLoadingCookies: boolea

}, [generateNewAllCookiesPromise, generateNewAllCookiesPromiseLoopEdition]);

return React.useMemo(()=>[cookies, isLoadingCookies] as const, [cookies, isLoadingCookies]);
const setCookie = React.useCallback(async (...args: Parameters<typeof setCookieReal>) => {
const res = await setCookieReal(...args);
if (!res) return res;
await generateNewAllCookiesPromise();
return res;
}, [generateNewAllCookiesPromise]);

return React.useMemo(()=>[cookies, isLoadingCookies, setCookie] as const, [cookies, isLoadingCookies, setCookie]);
}
11 changes: 3 additions & 8 deletions src/interface/hooks/useMergeHooks.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
import React from "react";
import type { Writeable } from "zod";
import { updateRef } from "./updateRef";

export function useMergeRefs<TRefType>(...refs: (React.LegacyRef<TRefType> | React.MutableRefObject<TRefType> | undefined)[]): React.Ref<TRefType> {
return React.useMemo(() => (value: TRefType) => {
for (const ref_ of refs) {
const ref = noOp(ref_); // makes the React Compiler not complain
for (const ref of refs) {
if (!ref) continue;
if (typeof ref === 'function') ref(value);
else if (typeof ref === 'string') throw new Error('Cannot merge string refs; if you are using this, you can implement this functionality yourself.');
else (ref as Writeable<typeof ref>).current = value;
else updateRef(ref, value);
}
}, [refs]);
}

function noOp<T>(value: T): T {
return value;
}
3 changes: 2 additions & 1 deletion src/interface/hooks/useUpdatedRef.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import React from 'react';
import { updateRef } from './updateRef';

/**
* Returns a ref that is updated every time the value changes.
Expand All @@ -12,7 +13,7 @@ import React from 'react';
export function useUpdatedRef<T>(value: T) {
const ref = React.useRef(value);
React.useEffect(() => {
ref.current = value;
updateRef(ref, value);
}, [value]);
return ref;
}
5 changes: 2 additions & 3 deletions src/interface/provider/auth/authentication.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,12 @@ function AuthenticationProviderInternal({ children, trpc }: { readonly children:
const [hasFetchedForThisUser, setHasFetchedForThisUser] = React.useState(false);

const [cookies, isLoadingCookies] = useCookies();
const expirationCookie = cookies?.[STOCKEDHOME_COOKIE_EXPIRATION_NAME];
const getExpiresAt = React.useCallback(() => {
if (isLoadingCookies) return 'loading';
if (!cookies) return null;
const expirationCookie = cookies[STOCKEDHOME_COOKIE_EXPIRATION_NAME];
if (!expirationCookie) return null;
return new Date(expirationCookie.value);
}, [isLoadingCookies, cookies]);
}, [isLoadingCookies, expirationCookie]);

const [lastAskedForAuthRenewal, setLastAskedForAuthRenewal, isLoadingLastAskedForAuthRenewal] = useLastAskedForAuthRenewal();

Expand Down
67 changes: 43 additions & 24 deletions src/interface/provider/tRPC-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import type { APIRouter } from 'lib/trpc/primaryRouter';
import { httpBatchLink } from '@trpc/client';
import React from 'react';
import React, { useEffect } from 'react';
import { useConfig } from './config-provider';
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import type { BuiltRouter, RouterRecord } from '@trpc/server/unstable-core-do-not-import';
import { createTRPCReact } from '@trpc/react-query';
import type { Config } from 'lib/config/schema';
import superjson from 'superjson';
import { useCookies } from '../hooks/useCookies';
import { STOCKEDHOME_CSRF_COOKIE_HEADER_NAME } from 'lib/trpc/_csrf';

export type TRPCClient = Omit<typeof trpc, 'Provider' | 'useContext' | ''>;

Expand All @@ -26,44 +28,45 @@ const queryClient = new QueryClient({
defaultOptions: { queries: { staleTime: 5 * 1000 } },
});

function createTRPCClient(primaryConfig: Config, supplementaryConfig: Config | null) {
function createHttpLink(url: URL, csrfToken: string | undefined) {
return httpBatchLink<APIRouter>({
url,
transformer: superjson,
headers: { [`x-${STOCKEDHOME_CSRF_COOKIE_HEADER_NAME}`]: csrfToken, },
});
}

function createTRPCClient(primaryConfig: Config, supplementaryConfig: Config | null, csrfToken: string | undefined) {


// Dumb server if there's only one API server
if (!supplementaryConfig || primaryConfig.canonicalRoot === supplementaryConfig.canonicalRoot) {
return trpc.createClient({
links: [
httpBatchLink<APIRouter>({
url: new URL('api', primaryConfig.canonicalRoot),
transformer: superjson,
}),
createHttpLink(new URL('api', primaryConfig.canonicalRoot), csrfToken),
],
});

} else { // Witchcraft if we have multiple servers
const servers = {
primary: createHttpLink(new URL('api', primaryConfig.canonicalRoot), csrfToken),
supplementary: createHttpLink(new URL('api', supplementaryConfig.canonicalRoot), csrfToken),
} as const;
return trpc.createClient({
links: [
(runtime) => {
const servers = {
primary: httpBatchLink<APIRouter>({ url: new URL('api', primaryConfig.canonicalRoot), transformer: superjson, })(runtime),
supplementary: httpBatchLink<APIRouter>({ url: new URL('api', supplementaryConfig.canonicalRoot), transformer: superjson })(runtime),
} as const;

return (ctx) => {
if (primaryConfig.canonicalRoot === supplementaryConfig.canonicalRoot)
return servers.primary(ctx);

const { op } = ctx;
const pathParts = op.path.split('.');

const server = getServerForPath(primaryConfig.primaryEndpoints, pathParts);
return servers[server](ctx);
};
(runtime) => (ctx) => {
if (primaryConfig.canonicalRoot === supplementaryConfig.canonicalRoot)
return servers.primary(runtime)(ctx);

const { op } = ctx;
const pathParts = op.path.split('.');

const server = getServerForPath(primaryConfig.primaryEndpoints, pathParts);
return servers[server](runtime)(ctx);
},
],
});
}

}


Expand Down Expand Up @@ -95,7 +98,23 @@ export function getServerForPath<TRouter extends TRPCClient | BuiltRouter<{ ctx:

export function TRPCProvider({ children }: React.PropsWithChildren) {
const config = useConfig();
const client = React.useMemo(() => config.primary && createTRPCClient(config.primary, config.supplementary), [config]);
const [cookies, isLoadingCookies, setCookie] = useCookies();
const csrfToken = cookies?.[STOCKEDHOME_CSRF_COOKIE_HEADER_NAME]?.value;
const client = React.useMemo(() => config.primary && createTRPCClient(config.primary, config.supplementary, csrfToken), [config, csrfToken]);

const supplementaryCanonicalHost = config.supplementary?.canonicalRoot.host;
useEffect(() => {
if (csrfToken !== undefined) return;
if (isLoadingCookies) return;
if (!supplementaryCanonicalHost) return;

setCookie({
name: STOCKEDHOME_CSRF_COOKIE_HEADER_NAME,
value: window.crypto.randomUUID(),
domain: supplementaryCanonicalHost,
expires: new Date(Date.now() + (1000 * 60 * 60 * 24 * 31)),
}, supplementaryCanonicalHost);
}, [isLoadingCookies, csrfToken, supplementaryCanonicalHost, setCookie]);

if (!client) {return <trpcContext.Provider value={null}>
{children}
Expand Down
28 changes: 28 additions & 0 deletions src/interface/utils/setCookie.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import CookieManager, { type Cookie } from '@react-native-cookies/cookies';
import type { cookieStore } from 'cookie-store';

declare enum CookieSameSite {
strict = "strict",
lax = "lax",
none = "none"
}

export interface UnifiedCookie {
name: string;
value: string;
path?: string;
domain?: string;
version?: string;
expires?: Date;
secure?: boolean;
httpOnly?: boolean;
sameSite?: CookieSameSite;
}

export async function setCookie(cookie: UnifiedCookie, supplementaryCanonicalHost: string): Promise<boolean> {
const augmentedCookie = Object.assign(cookie, {
expires: cookie.expires?.toUTCString(),
});

return await CookieManager.set(supplementaryCanonicalHost, augmentedCookie);
}
16 changes: 16 additions & 0 deletions src/interface/utils/setCookie.web.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import type { cookieStore } from 'cookie-store';
import type { Cookie } from '@react-native-cookies/cookies';
import type { UnifiedCookie } from './setCookie';
if (typeof window !== 'undefined' && !window.cookieStore) window.cookieStore = (await import('cookie-store')).cookieStore;

export async function setCookie(cookie: UnifiedCookie, supplementaryCanonicalHost: string): Promise<true> {
const augmentedCookie = Object.assign(cookie, {
domain: supplementaryCanonicalHost,
expires: cookie.expires ?? null,
});

console.log('setCookie', augmentedCookie);

await window.cookieStore.set(augmentedCookie);
return true;
}
1 change: 1 addition & 0 deletions src/lib/trpc/_csrf.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export const STOCKEDHOME_CSRF_COOKIE_HEADER_NAME = 'stockedhome-csrf-token';
75 changes: 72 additions & 3 deletions src/platforms/next/app/api/[...endpoint]/route.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { fetchRequestHandler } from "@trpc/server/adapters/fetch";
import { apiRouter } from "lib/trpc/primaryRouter";
import { type NextRequest } from "next/server";
import { NextResponse, type NextRequest } from "next/server";
import { loadConfigServer } from "lib/config/loader-server";
import { STOCKEDHOME_CSRF_COOKIE_HEADER_NAME } from "lib/trpc/_csrf";
import { TRPC_ERROR_CODES_BY_KEY } from "@trpc/server/unstable-core-do-not-import";

export const dynamic = 'force-dynamic';

async function tRPCRequestHandler(req: NextRequest) {
const config = await loadConfigServer();
Expand All @@ -17,5 +21,70 @@ async function tRPCRequestHandler(req: NextRequest) {
});
}

export const dynamic = 'force-dynamic';
export { tRPCRequestHandler as GET, tRPCRequestHandler as POST };
export async function GET(req: NextRequest) {
return await tRPCRequestHandler(req);
}

/*
{
"id": null,
"error": {
"message": "\"password\" must be at least 4 characters",
"code": -32600,
"data": {
"code": "BAD_REQUEST",
"httpStatus": 400,
"stack": "...",
"path": "user.changepassword"
}
}
}
*/

export async function POST(req: NextRequest) { // Do CSRF validation for POST requests
const csrfHeader = req.headers.get(`x-${STOCKEDHOME_CSRF_COOKIE_HEADER_NAME}`);
if (csrfHeader === null) {
return NextResponse.json({
error: {
message: "Missing CSRF token header",
code: TRPC_ERROR_CODES_BY_KEY.BAD_REQUEST,
data: {
code: "BAD_REQUEST",
httpStatus: 400,
stack: "...",
}
}
}, { status: 400 });
}

const csrfCookie = req.cookies.get(STOCKEDHOME_CSRF_COOKIE_HEADER_NAME);
if (!csrfCookie) {
return NextResponse.json({
error: {
message: "Missing CSRF token cookie",
code: TRPC_ERROR_CODES_BY_KEY.BAD_REQUEST,
data: {
code: "BAD_REQUEST",
httpStatus: 400,
stack: "...",
}
}
}, { status: 400 });
}

if (csrfHeader !== csrfCookie.value) {
return NextResponse.json({
error: {
message: "CSRF token header did not match CSRF token cookie",
code: TRPC_ERROR_CODES_BY_KEY.BAD_REQUEST,
data: {
code: "BAD_REQUEST",
httpStatus: 400,
stack: "...",
}
}
}, { status: 400 });
}

return await tRPCRequestHandler(req);
}

0 comments on commit b9e17ae

Please sign in to comment.