diff --git a/client/src/components/OAuthCallback.tsx b/client/src/components/OAuthCallback.tsx index a7439df94..869eef187 100644 --- a/client/src/components/OAuthCallback.tsx +++ b/client/src/components/OAuthCallback.tsx @@ -24,9 +24,15 @@ const OAuthCallback = () => { } try { - const accessToken = await handleOAuthCallback(serverUrl, code); - // Store the access token for future use - sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, accessToken); + const tokens = await handleOAuthCallback(serverUrl, code); + // Store both access and refresh tokens + sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token); + if (tokens.refresh_token) { + sessionStorage.setItem( + SESSION_KEYS.REFRESH_TOKEN, + tokens.refresh_token, + ); + } // Redirect back to the main app with server URL to trigger auto-connect window.location.href = `/?serverUrl=${encodeURIComponent(serverUrl)}`; } catch (error) { diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index 0417731d9..592dc178e 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -1,10 +1,21 @@ import pkceChallenge from "pkce-challenge"; import { SESSION_KEYS } from "./constants"; +import { z } from "zod"; -export interface OAuthMetadata { - authorization_endpoint: string; - token_endpoint: string; -} +export const OAuthMetadataSchema = z.object({ + authorization_endpoint: z.string(), + token_endpoint: z.string(), +}); + +export type OAuthMetadata = z.infer; + +export const OAuthTokensSchema = z.object({ + access_token: z.string(), + refresh_token: z.string().optional(), + expires_in: z.number().optional(), +}); + +export type OAuthTokens = z.infer; export async function discoverOAuthMetadata( serverUrl: string, @@ -15,10 +26,11 @@ export async function discoverOAuthMetadata( if (response.ok) { const metadata = await response.json(); - return { + const validatedMetadata = OAuthMetadataSchema.parse({ authorization_endpoint: metadata.authorization_endpoint, token_endpoint: metadata.token_endpoint, - }; + }); + return validatedMetadata; } } catch (error) { console.warn("OAuth metadata discovery failed:", error); @@ -26,10 +38,11 @@ export async function discoverOAuthMetadata( // Fall back to default endpoints const baseUrl = new URL(serverUrl); - return { + const defaultMetadata = { authorization_endpoint: new URL("/authorize", baseUrl).toString(), token_endpoint: new URL("/token", baseUrl).toString(), }; + return OAuthMetadataSchema.parse(defaultMetadata); } export async function startOAuthFlow(serverUrl: string): Promise { @@ -60,7 +73,7 @@ export async function startOAuthFlow(serverUrl: string): Promise { export async function handleOAuthCallback( serverUrl: string, code: string, -): Promise { +): Promise { // Get stored code verifier const codeVerifier = sessionStorage.getItem(SESSION_KEYS.CODE_VERIFIER); if (!codeVerifier) { @@ -69,7 +82,6 @@ export async function handleOAuthCallback( // Discover OAuth endpoints const metadata = await discoverOAuthMetadata(serverUrl); - // Exchange code for tokens const response = await fetch(metadata.token_endpoint, { method: "POST", @@ -88,6 +100,35 @@ export async function handleOAuthCallback( throw new Error("Token exchange failed"); } - const data = await response.json(); - return data.access_token; + const tokens = await response.json(); + return OAuthTokensSchema.parse(tokens); +} + +export async function refreshAccessToken( + serverUrl: string, +): Promise { + const refreshToken = sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN); + if (!refreshToken) { + throw new Error("No refresh token available"); + } + + const metadata = await discoverOAuthMetadata(serverUrl); + + const response = await fetch(metadata.token_endpoint, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + grant_type: "refresh_token", + refresh_token: refreshToken, + }), + }); + + if (!response.ok) { + throw new Error("Token refresh failed"); + } + + const tokens = await response.json(); + return OAuthTokensSchema.parse(tokens); } diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index e302b52fe..13a237037 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -3,4 +3,5 @@ export const SESSION_KEYS = { CODE_VERIFIER: "mcp_code_verifier", SERVER_URL: "mcp_server_url", ACCESS_TOKEN: "mcp_access_token", + REFRESH_TOKEN: "mcp_refresh_token", } as const; diff --git a/client/src/lib/hooks/useConnection.ts b/client/src/lib/hooks/useConnection.ts index de2d29ecc..6c42c3f55 100644 --- a/client/src/lib/hooks/useConnection.ts +++ b/client/src/lib/hooks/useConnection.ts @@ -16,7 +16,7 @@ import { import { useState } from "react"; import { toast } from "react-toastify"; import { z } from "zod"; -import { startOAuthFlow } from "../auth"; +import { startOAuthFlow, refreshAccessToken } from "../auth"; import { SESSION_KEYS } from "../constants"; import { Notification, StdErrNotificationSchema } from "../notificationTypes"; @@ -121,7 +121,49 @@ export function useConnection({ } }; - const connect = async () => { + const initiateOAuthFlow = async () => { + sessionStorage.removeItem(SESSION_KEYS.ACCESS_TOKEN); + sessionStorage.removeItem(SESSION_KEYS.REFRESH_TOKEN); + sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl); + const redirectUrl = await startOAuthFlow(sseUrl); + window.location.href = redirectUrl; + }; + + const handleTokenRefresh = async () => { + try { + const tokens = await refreshAccessToken(sseUrl); + sessionStorage.setItem(SESSION_KEYS.ACCESS_TOKEN, tokens.access_token); + if (tokens.refresh_token) { + sessionStorage.setItem( + SESSION_KEYS.REFRESH_TOKEN, + tokens.refresh_token, + ); + } + return tokens.access_token; + } catch (error) { + console.error("Token refresh failed:", error); + await initiateOAuthFlow(); + throw error; + } + }; + + const handleAuthError = async (error: unknown) => { + if (error instanceof SseError && error.code === 401) { + if (sessionStorage.getItem(SESSION_KEYS.REFRESH_TOKEN)) { + try { + await handleTokenRefresh(); + return true; + } catch (error) { + console.error("Token refresh failed:", error); + } + } else { + await initiateOAuthFlow(); + } + } + return false; + }; + + const connect = async (_e?: unknown, retryCount: number = 0) => { try { const client = new Client( { @@ -182,14 +224,15 @@ export function useConnection({ await client.connect(clientTransport); } catch (error) { console.error("Failed to connect to MCP server:", error); + const shouldRetry = await handleAuthError(error); + if (shouldRetry) { + return connect(undefined, retryCount + 1); + } + if (error instanceof SseError && error.code === 401) { - // Store the server URL for the callback handler - sessionStorage.setItem(SESSION_KEYS.SERVER_URL, sseUrl); - const redirectUrl = await startOAuthFlow(sseUrl); - window.location.href = redirectUrl; + // Don't set error state if we're about to redirect for auth return; } - throw error; }