diff --git a/client/src/App.tsx b/client/src/App.tsx index 6c9ae3331..9fd6efe39 100644 --- a/client/src/App.tsx +++ b/client/src/App.tsx @@ -240,6 +240,8 @@ const App = () => { handleDragStart: handleSidebarDragStart, } = useDraggableSidebar(320); + const [clientEncryptionKey, setClientEncryptionKey] = useState(""); + const { connectionStatus, serverCapabilities, @@ -262,6 +264,7 @@ const App = () => { oauthClientId, oauthScope, config, + clientEncryptionKey, connectionType, onNotification: (notification) => { setNotifications((prev) => [...prev, notification as ServerNotification]); @@ -435,9 +438,13 @@ const App = () => { }; try { - const stateMachine = new OAuthStateMachine(sseUrl, (updates) => { - currentState = { ...currentState, ...updates }; - }); + const stateMachine = new OAuthStateMachine( + sseUrl, + clientEncryptionKey, + (updates) => { + currentState = { ...currentState, ...updates }; + }, + ); while ( currentState.oauthStep !== "complete" && @@ -475,7 +482,7 @@ const App = () => { }); } }, - [sseUrl], + [sseUrl, clientEncryptionKey], ); useEffect(() => { @@ -528,6 +535,9 @@ const App = () => { if (data.defaultServerUrl) { setSseUrl(data.defaultServerUrl); } + if (data.clientEncryptionKey) { + setClientEncryptionKey(data.clientEncryptionKey); + } }) .catch((error) => console.error("Error fetching default environment:", error), @@ -833,6 +843,7 @@ const App = () => { onBack={() => setIsAuthDebuggerVisible(false)} authState={authState} updateAuthState={updateAuthState} + clientEncryptionKey={clientEncryptionKey} /> ); @@ -843,7 +854,10 @@ const App = () => { ); return ( Loading...}> - + ); } @@ -854,7 +868,10 @@ const App = () => { ); return ( Loading...}> - + ); } diff --git a/client/src/components/AuthDebugger.tsx b/client/src/components/AuthDebugger.tsx index 6252c1161..4961f3b52 100644 --- a/client/src/components/AuthDebugger.tsx +++ b/client/src/components/AuthDebugger.tsx @@ -8,11 +8,14 @@ import { OAuthStateMachine } from "../lib/oauth-state-machine"; import { SESSION_KEYS } from "../lib/constants"; import { validateRedirectUrl } from "@/utils/urlValidation"; +import { encodeWithKey } from "../lib/auth"; + export interface AuthDebuggerProps { serverUrl: string; onBack: () => void; authState: AuthDebuggerState; updateAuthState: (updates: Partial) => void; + clientEncryptionKey: string; } interface StatusMessageProps { @@ -60,13 +63,17 @@ const AuthDebugger = ({ onBack, authState, updateAuthState, + clientEncryptionKey, }: AuthDebuggerProps) => { // Check for existing tokens on mount useEffect(() => { if (serverUrl && !authState.oauthTokens) { const checkTokens = async () => { try { - const provider = new DebugInspectorOAuthClientProvider(serverUrl); + const provider = new DebugInspectorOAuthClientProvider( + serverUrl, + clientEncryptionKey, + ); const existingTokens = await provider.tokens(); if (existingTokens) { updateAuthState({ @@ -80,7 +87,7 @@ const AuthDebugger = ({ }; checkTokens(); } - }, [serverUrl, updateAuthState, authState.oauthTokens]); + }, [serverUrl, updateAuthState, authState.oauthTokens, clientEncryptionKey]); const startOAuthFlow = useCallback(() => { if (!serverUrl) { @@ -103,8 +110,9 @@ const AuthDebugger = ({ }, [serverUrl, updateAuthState]); const stateMachine = useMemo( - () => new OAuthStateMachine(serverUrl, updateAuthState), - [serverUrl, updateAuthState], + () => + new OAuthStateMachine(serverUrl, clientEncryptionKey, updateAuthState), + [serverUrl, updateAuthState, clientEncryptionKey], ); const proceedToNextStep = useCallback(async () => { @@ -150,11 +158,15 @@ const AuthDebugger = ({ latestError: null, }; - const oauthMachine = new OAuthStateMachine(serverUrl, (updates) => { - // Update our temporary state during the process - currentState = { ...currentState, ...updates }; - // But don't call updateAuthState yet - }); + const oauthMachine = new OAuthStateMachine( + serverUrl, + clientEncryptionKey, + (updates) => { + // Update our temporary state during the process + currentState = { ...currentState, ...updates }; + // But don't call updateAuthState yet + }, + ); // Manually step through each stage of the OAuth flow while (currentState.oauthStep !== "complete") { @@ -181,11 +193,30 @@ const AuthDebugger = ({ return; } + // Encrypt the client secret before storing + const client_secret = currentState.oauthClientInfo?.client_secret; + const encrypted_secret = + clientEncryptionKey && + client_secret && + typeof encodeWithKey === "function" + ? encodeWithKey(clientEncryptionKey, client_secret) + : undefined; + const stateToStore = encrypted_secret + ? { + ...currentState, + oauthClientInfo: { + ...currentState.oauthClientInfo, + client_secret: encrypted_secret, + }, + } + : currentState; + // Store the current auth state before redirecting sessionStorage.setItem( SESSION_KEYS.AUTH_DEBUGGER_STATE, - JSON.stringify(currentState), + JSON.stringify(stateToStore), ); + // Open the authorization URL automatically window.location.href = currentState.authorizationUrl.toString(); break; @@ -214,12 +245,13 @@ const AuthDebugger = ({ } finally { updateAuthState({ isInitiatingAuth: false }); } - }, [serverUrl, updateAuthState, authState]); + }, [serverUrl, updateAuthState, authState, clientEncryptionKey]); const handleClearOAuth = useCallback(() => { if (serverUrl) { const serverAuthProvider = new DebugInspectorOAuthClientProvider( serverUrl, + clientEncryptionKey, ); serverAuthProvider.clear(); updateAuthState({ @@ -235,7 +267,7 @@ const AuthDebugger = ({ updateAuthState({ statusMessage: null }); }, 3000); } - }, [serverUrl, updateAuthState]); + }, [serverUrl, updateAuthState, clientEncryptionKey]); return (
@@ -312,6 +344,7 @@ const AuthDebugger = ({ authState={authState} updateAuthState={updateAuthState} proceedToNextStep={proceedToNextStep} + clientEncryptionKey={clientEncryptionKey} />
diff --git a/client/src/components/JsonView.tsx b/client/src/components/JsonView.tsx index 1febff6a4..1eb2bbb74 100644 --- a/client/src/components/JsonView.tsx +++ b/client/src/components/JsonView.tsx @@ -51,7 +51,7 @@ const JsonView = memo( variant: "destructive", }); } - }, [toast, normalizedData]); + }, [toast, normalizedData, setCopied]); return (
diff --git a/client/src/components/OAuthCallback.tsx b/client/src/components/OAuthCallback.tsx index ccfd6d928..87f19428b 100644 --- a/client/src/components/OAuthCallback.tsx +++ b/client/src/components/OAuthCallback.tsx @@ -10,9 +10,13 @@ import { interface OAuthCallbackProps { onConnect: (serverUrl: string) => void; + clientEncryptionKey: string; } -const OAuthCallback = ({ onConnect }: OAuthCallbackProps) => { +const OAuthCallback = ({ + onConnect, + clientEncryptionKey, +}: OAuthCallbackProps) => { const { toast } = useToast(); const hasProcessedRef = useRef(false); @@ -44,7 +48,10 @@ const OAuthCallback = ({ onConnect }: OAuthCallbackProps) => { let result; try { // Create an auth provider with the current server URL - const serverAuthProvider = new InspectorOAuthClientProvider(serverUrl); + const serverAuthProvider = new InspectorOAuthClientProvider( + serverUrl, + clientEncryptionKey, + ); result = await auth(serverAuthProvider, { serverUrl, @@ -73,7 +80,7 @@ const OAuthCallback = ({ onConnect }: OAuthCallbackProps) => { handleCallback().finally(() => { window.history.replaceState({}, document.title, "/"); }); - }, [toast, onConnect]); + }, [toast, onConnect, clientEncryptionKey]); return (
diff --git a/client/src/components/OAuthDebugCallback.tsx b/client/src/components/OAuthDebugCallback.tsx index 95ccc0760..c1a37416b 100644 --- a/client/src/components/OAuthDebugCallback.tsx +++ b/client/src/components/OAuthDebugCallback.tsx @@ -5,6 +5,7 @@ import { parseOAuthCallbackParams, } from "@/utils/oauthUtils.ts"; import { AuthDebuggerState } from "@/lib/auth-types"; +import { decodeWithKey } from "@/lib/auth.ts"; interface OAuthCallbackProps { onConnect: ({ @@ -16,9 +17,13 @@ interface OAuthCallbackProps { errorMsg?: string; restoredState?: AuthDebuggerState; }) => void; + clientEncryptionKey: string; } -const OAuthDebugCallback = ({ onConnect }: OAuthCallbackProps) => { +const OAuthDebugCallback = ({ + onConnect, + clientEncryptionKey, +}: OAuthCallbackProps) => { useEffect(() => { let isProcessed = false; @@ -57,6 +62,14 @@ const OAuthDebugCallback = ({ onConnect }: OAuthCallbackProps) => { restoredState.authorizationUrl, ); } + // Decrypt client secret if present + if (restoredState && restoredState.oauthClientInfo.client_secret) { + const client_secret = restoredState.oauthClientInfo?.client_secret; + restoredState.oauthClientInfo.client_secret = + clientEncryptionKey && client_secret + ? decodeWithKey(clientEncryptionKey, client_secret) + : undefined; + } // Clean up the stored state sessionStorage.removeItem(SESSION_KEYS.AUTH_DEBUGGER_STATE); } catch (e) { @@ -94,7 +107,7 @@ const OAuthDebugCallback = ({ onConnect }: OAuthCallbackProps) => { return () => { isProcessed = true; }; - }, [onConnect]); + }, [onConnect, clientEncryptionKey]); const callbackParams = parseOAuthCallbackParams(window.location.search); diff --git a/client/src/components/OAuthFlowProgress.tsx b/client/src/components/OAuthFlowProgress.tsx index 5f44a4f51..e3d4480cc 100644 --- a/client/src/components/OAuthFlowProgress.tsx +++ b/client/src/components/OAuthFlowProgress.tsx @@ -56,6 +56,7 @@ interface OAuthFlowProgressProps { authState: AuthDebuggerState; updateAuthState: (updates: Partial) => void; proceedToNextStep: () => Promise; + clientEncryptionKey: string; } const steps: Array = [ @@ -72,11 +73,12 @@ export const OAuthFlowProgress = ({ authState, updateAuthState, proceedToNextStep, + clientEncryptionKey, }: OAuthFlowProgressProps) => { const { toast } = useToast(); const provider = useMemo( - () => new DebugInspectorOAuthClientProvider(serverUrl), - [serverUrl], + () => new DebugInspectorOAuthClientProvider(serverUrl, clientEncryptionKey), + [serverUrl, clientEncryptionKey], ); const [clientInfo, setClientInfo] = useState( null, diff --git a/client/src/components/__tests__/AuthDebugger.test.tsx b/client/src/components/__tests__/AuthDebugger.test.tsx index 5d5042ea5..1bd5fcd0f 100644 --- a/client/src/components/__tests__/AuthDebugger.test.tsx +++ b/client/src/components/__tests__/AuthDebugger.test.tsx @@ -141,6 +141,8 @@ Object.defineProperty(window, "sessionStorage", { value: sessionStorageMock, }); +const clientEncryptionKey = "test-encryption-key"; + describe("AuthDebugger", () => { const defaultAuthState = EMPTY_DEBUGGER_STATE; @@ -149,6 +151,7 @@ describe("AuthDebugger", () => { onBack: jest.fn(), authState: defaultAuthState, updateAuthState: jest.fn(), + clientEncryptionKey, }; beforeEach(() => { diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index 0a8d26cfc..35c4e5d9a 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -46,12 +46,38 @@ export const discoverScopes = async ( } }; +// Helper: simple reversible encoding of a string using the token (XOR + base64) +export const encodeWithKey = (key: string, plaintext: string): string => { + const enc = new TextEncoder(); + const textBytes = enc.encode(plaintext); + const keyBytes = enc.encode(key); + const out = new Uint8Array(textBytes.length); + for (let i = 0; i < textBytes.length; i++) { + out[i] = textBytes[i] ^ keyBytes[i % keyBytes.length]; + } + return btoa(String.fromCharCode(...out)); +}; + +export const decodeWithKey = (key: string, encodedB64: string): string => { + const enc = new TextEncoder(); + const dec = new TextDecoder(); + const keyBytes = enc.encode(key); + const bytes = Uint8Array.from(atob(encodedB64), (c) => c.charCodeAt(0)); + const out = new Uint8Array(bytes.length); + for (let i = 0; i < bytes.length; i++) { + out[i] = bytes[i] ^ keyBytes[i % keyBytes.length]; + } + return dec.decode(out); +}; + export const getClientInformationFromSessionStorage = async ({ serverUrl, isPreregistered, + clientEncryptionKey, }: { serverUrl: string; isPreregistered?: boolean; + clientEncryptionKey: string; }) => { const key = getServerSpecificKey( isPreregistered @@ -65,17 +91,37 @@ export const getClientInformationFromSessionStorage = async ({ return undefined; } - return await OAuthClientInformationSchema.parseAsync(JSON.parse(value)); + const parsed = await OAuthClientInformationSchema.parseAsync( + JSON.parse(value), + ); + + // Decrypt client_secret if marked as encrypted and token is available + try { + if (clientEncryptionKey && parsed.client_secret) { + const decrypted = decodeWithKey( + clientEncryptionKey, + parsed.client_secret as unknown as string, + ); + return { ...parsed, client_secret: decrypted } as OAuthClientInformation; + } + } catch (e) { + console.warn("Failed to decrypt client_secret from session storage:", e); + // Fallback to parsed as-is + } + + return parsed; }; export const saveClientInformationToSessionStorage = ({ serverUrl, clientInformation, isPreregistered, + clientEncryptionKey, }: { serverUrl: string; clientInformation: OAuthClientInformation; isPreregistered?: boolean; + clientEncryptionKey: string; }) => { const key = getServerSpecificKey( isPreregistered @@ -83,7 +129,24 @@ export const saveClientInformationToSessionStorage = ({ : SESSION_KEYS.CLIENT_INFORMATION, serverUrl, ); - sessionStorage.setItem(key, JSON.stringify(clientInformation)); + let toStore: Partial & { + _encrypted_client_secret?: boolean; + } = { ...clientInformation }; + if (clientEncryptionKey && clientInformation.client_secret) { + try { + const encrypted = encodeWithKey( + clientEncryptionKey, + clientInformation.client_secret as unknown as string, + ); + toStore = { + ...toStore, + client_secret: encrypted, + }; + } catch (e) { + console.warn("Failed to encrypt client_secret for session storage:", e); + } + } + sessionStorage.setItem(key, JSON.stringify(toStore)); }; export const clearClientInformationFromSessionStorage = ({ @@ -105,6 +168,7 @@ export const clearClientInformationFromSessionStorage = ({ export class InspectorOAuthClientProvider implements OAuthClientProvider { constructor( protected serverUrl: string, + protected clientEncryptionKey: string, scope?: string, ) { this.scope = scope; @@ -144,6 +208,7 @@ export class InspectorOAuthClientProvider implements OAuthClientProvider { await getClientInformationFromSessionStorage({ serverUrl: this.serverUrl, isPreregistered: true, + clientEncryptionKey: this.clientEncryptionKey, }); // If no preregistered client information is found, get the dynamically registered client information @@ -152,23 +217,18 @@ export class InspectorOAuthClientProvider implements OAuthClientProvider { (await getClientInformationFromSessionStorage({ serverUrl: this.serverUrl, isPreregistered: false, + clientEncryptionKey: this.clientEncryptionKey, })) ); } saveClientInformation(clientInformation: OAuthClientInformation) { - // Remove client_secret before storing (not needed after initial OAuth flow) - const safeInfo = Object.fromEntries( - Object.entries(clientInformation).filter( - ([key]) => key !== "client_secret", - ), - ) as OAuthClientInformation; - // Save the dynamically registered client information to session storage saveClientInformationToSessionStorage({ serverUrl: this.serverUrl, - clientInformation: safeInfo, + clientInformation: clientInformation, isPreregistered: false, + clientEncryptionKey: this.clientEncryptionKey, }); } diff --git a/client/src/lib/hooks/__tests__/useConnection.test.tsx b/client/src/lib/hooks/__tests__/useConnection.test.tsx index b16eaa3d0..c8354c271 100644 --- a/client/src/lib/hooks/__tests__/useConnection.test.tsx +++ b/client/src/lib/hooks/__tests__/useConnection.test.tsx @@ -23,6 +23,7 @@ global.fetch = jest.fn().mockResolvedValue({ }, }); +const clientEncryptionKey = "test-encryption-key"; // Mock the SDK dependencies const mockRequest = jest.fn().mockResolvedValue({ test: "response" }); const mockClient = { @@ -127,6 +128,7 @@ describe("useConnection", () => { sseUrl: "http://localhost:8080", env: {}, config: DEFAULT_INSPECTOR_CONFIG, + clientEncryptionKey, }; describe("Request Configuration", () => { diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index 1341b0c2b..b83c95cb1 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -80,6 +80,7 @@ interface UseConnectionOptions { // eslint-disable-next-line @typescript-eslint/no-explicit-any getRoots?: () => any[]; defaultLoggingLevel?: LoggingLevel; + clientEncryptionKey: string; } export function useConnection({ @@ -98,6 +99,7 @@ export function useConnection({ onElicitationRequest, getRoots, defaultLoggingLevel, + clientEncryptionKey, }: UseConnectionOptions) { const [connectionStatus, setConnectionStatus] = useState("disconnected"); @@ -130,8 +132,9 @@ export function useConnection({ serverUrl: sseUrl, clientInformation: { client_id: oauthClientId }, isPreregistered: true, + clientEncryptionKey, }); - }, [oauthClientId, sseUrl]); + }, [oauthClientId, sseUrl, clientEncryptionKey]); const pushHistory = (request: object, response?: object) => { setRequestHistory((prev) => [ @@ -339,6 +342,7 @@ export function useConnection({ } const serverAuthProvider = new InspectorOAuthClientProvider( sseUrl, + clientEncryptionKey, scope, ); @@ -397,7 +401,10 @@ export function useConnection({ const headers: HeadersInit = {}; // Create an auth provider with the current server URL - const serverAuthProvider = new InspectorOAuthClientProvider(sseUrl); + const serverAuthProvider = new InspectorOAuthClientProvider( + sseUrl, + clientEncryptionKey, + ); // Use custom headers (migration is handled in App.tsx) let finalHeaders: CustomHeaders = customHeaders || []; @@ -754,7 +761,10 @@ export function useConnection({ clientTransport as StreamableHTTPClientTransport ).terminateSession(); await mcpClient?.close(); - const authProvider = new InspectorOAuthClientProvider(sseUrl); + const authProvider = new InspectorOAuthClientProvider( + sseUrl, + clientEncryptionKey, + ); authProvider.clear(); setMcpClient(null); setClientTransport(null); diff --git a/client/src/lib/hooks/useCopy.ts b/client/src/lib/hooks/useCopy.ts index 8dd7e401b..51c42d80a 100644 --- a/client/src/lib/hooks/useCopy.ts +++ b/client/src/lib/hooks/useCopy.ts @@ -21,7 +21,7 @@ function useCopy({ timeout = 500 }: UseCopyProps = {}) { clearTimeout(timeoutId); } }; - }, [copied]); + }, [copied, setCopied, timeout]); return { copied, setCopied }; } diff --git a/client/src/lib/oauth-state-machine.ts b/client/src/lib/oauth-state-machine.ts index c505f698e..77793593b 100644 --- a/client/src/lib/oauth-state-machine.ts +++ b/client/src/lib/oauth-state-machine.ts @@ -203,11 +203,15 @@ export const oauthTransitions: Record = { export class OAuthStateMachine { constructor( private serverUrl: string, + private clientEncryptionKey: string, private updateState: (updates: Partial) => void, ) {} async executeStep(state: AuthDebuggerState): Promise { - const provider = new DebugInspectorOAuthClientProvider(this.serverUrl); + const provider = new DebugInspectorOAuthClientProvider( + this.serverUrl, + this.clientEncryptionKey, + ); const context: StateMachineContext = { state, serverUrl: this.serverUrl, diff --git a/server/src/index.ts b/server/src/index.ts index 88954ebc5..9f02a12b9 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -160,6 +160,9 @@ const sessionToken = process.env.MCP_PROXY_AUTH_TOKEN || randomBytes(32).toString("hex"); const authDisabled = !!process.env.DANGEROUSLY_OMIT_AUTH; +const clientEncryptionKey = + process.env.MCP_CLIENT_ENCRYPTION_KEY || randomBytes(32).toString("hex"); + // Origin validation middleware to prevent DNS rebinding attacks const originValidationMiddleware = ( req: express.Request, @@ -747,6 +750,7 @@ app.get("/config", originValidationMiddleware, authMiddleware, (req, res) => { defaultArgs: values.args, defaultTransport: values.transport, defaultServerUrl: values["server-url"], + clientEncryptionKey, }); } catch (error) { console.error("Error in /config route:", error);