diff --git a/.gitignore b/.gitignore index 848d233..7dcccaf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ dist node_modules package-lock.json -.DS_Store \ No newline at end of file +.DS_Store +PR-SUBMISSION-GUIDE.md diff --git a/src/startHTTPServer.test.ts b/src/startHTTPServer.test.ts index bc0ff15..096096e 100644 --- a/src/startHTTPServer.test.ts +++ b/src/startHTTPServer.test.ts @@ -701,3 +701,483 @@ it("does not require auth for OPTIONS requests", async () => { await httpServer.close(); }); + +// Stateless OAuth 2.0 JWT Bearer Token Authentication Tests (PR #37) + +it("accepts requests with valid Bearer token in stateless mode", async () => { + const stdioTransport = new StdioClientTransport({ + args: ["src/fixtures/simple-stdio-server.ts"], + command: "tsx", + }); + + const stdioClient = new Client( + { + name: "mcp-proxy", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + await stdioClient.connect(stdioTransport); + + const serverVersion = stdioClient.getServerVersion() as { + name: string; + version: string; + }; + + const serverCapabilities = stdioClient.getServerCapabilities() as { + capabilities: Record; + }; + + const port = await getRandomPort(); + + // Mock authenticate callback that validates JWT Bearer token + const mockAuthResult = { email: "test@example.com", userId: "user123" }; + const authenticate = vi.fn().mockResolvedValue(mockAuthResult); + + const httpServer = await startHTTPServer({ + authenticate, + createServer: async () => { + const mcpServer = new Server(serverVersion, { + capabilities: serverCapabilities, + }); + + await proxyServer({ + client: stdioClient, + server: mcpServer, + serverCapabilities, + }); + + return mcpServer; + }, + port, + stateless: true, // Enable stateless mode + }); + + // Create a stateless streamable HTTP client with Bearer token + const streamTransport = new StreamableHTTPClientTransport( + new URL(`http://localhost:${port}/mcp`), + { + requestInit: { + headers: { + Authorization: "Bearer valid-jwt-token", + }, + }, + }, + ); + + const streamClient = new Client( + { + name: "stream-client-oauth", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + await streamClient.connect(streamTransport); + + // Test that we can make requests with valid authentication + const result = await streamClient.listResources(); + expect(result).toEqual({ + resources: [ + { + name: "Example Resource", + uri: "file:///example.txt", + }, + ], + }); + + // Verify authenticate callback was called + expect(authenticate).toHaveBeenCalled(); + + await streamClient.close(); + await httpServer.close(); + await stdioClient.close(); +}); + +it("returns 401 when authenticate callback returns null in stateless mode", async () => { + const stdioTransport = new StdioClientTransport({ + args: ["src/fixtures/simple-stdio-server.ts"], + command: "tsx", + }); + + const stdioClient = new Client( + { + name: "mcp-proxy", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + await stdioClient.connect(stdioTransport); + + const serverVersion = stdioClient.getServerVersion() as { + name: string; + version: string; + }; + + const serverCapabilities = stdioClient.getServerCapabilities() as { + capabilities: Record; + }; + + const port = await getRandomPort(); + + // Mock authenticate callback that rejects invalid token + const authenticate = vi.fn().mockResolvedValue(null); + + const httpServer = await startHTTPServer({ + authenticate, + createServer: async () => { + const mcpServer = new Server(serverVersion, { + capabilities: serverCapabilities, + }); + + await proxyServer({ + client: stdioClient, + server: mcpServer, + serverCapabilities, + }); + + return mcpServer; + }, + port, + stateless: true, + }); + + // Create client with invalid Bearer token + const streamTransport = new StreamableHTTPClientTransport( + new URL(`http://localhost:${port}/mcp`), + { + requestInit: { + headers: { + Authorization: "Bearer invalid-jwt-token", + }, + }, + }, + ); + + const streamClient = new Client( + { + name: "stream-client-invalid-token", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + // Connection should fail due to invalid authentication + await expect(streamClient.connect(streamTransport)).rejects.toThrow(); + + // Verify authenticate callback was called + expect(authenticate).toHaveBeenCalled(); + + await httpServer.close(); + await stdioClient.close(); +}); + +it("returns 401 when authenticate callback throws error in stateless mode", async () => { + const stdioTransport = new StdioClientTransport({ + args: ["src/fixtures/simple-stdio-server.ts"], + command: "tsx", + }); + + const stdioClient = new Client( + { + name: "mcp-proxy", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + await stdioClient.connect(stdioTransport); + + const serverVersion = stdioClient.getServerVersion() as { + name: string; + version: string; + }; + + const serverCapabilities = stdioClient.getServerCapabilities() as { + capabilities: Record; + }; + + const port = await getRandomPort(); + + // Mock authenticate callback that throws (e.g., JWKS endpoint failure) + const authenticate = vi + .fn() + .mockRejectedValue(new Error("JWKS fetch failed")); + + const httpServer = await startHTTPServer({ + authenticate, + createServer: async () => { + const mcpServer = new Server(serverVersion, { + capabilities: serverCapabilities, + }); + + await proxyServer({ + client: stdioClient, + server: mcpServer, + serverCapabilities, + }); + + return mcpServer; + }, + port, + stateless: true, + }); + + // Create client with Bearer token + const streamTransport = new StreamableHTTPClientTransport( + new URL(`http://localhost:${port}/mcp`), + { + requestInit: { + headers: { + Authorization: "Bearer some-token", + }, + }, + }, + ); + + const streamClient = new Client( + { + name: "stream-client-auth-error", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + // Connection should fail due to authentication error + await expect(streamClient.connect(streamTransport)).rejects.toThrow(); + + // Verify authenticate callback was called + expect(authenticate).toHaveBeenCalled(); + + await httpServer.close(); + await stdioClient.close(); +}); + +it("does not call authenticate on subsequent requests in stateful mode", async () => { + const stdioTransport = new StdioClientTransport({ + args: ["src/fixtures/simple-stdio-server.ts"], + command: "tsx", + }); + + const stdioClient = new Client( + { + name: "mcp-proxy", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + await stdioClient.connect(stdioTransport); + + const serverVersion = stdioClient.getServerVersion() as { + name: string; + version: string; + }; + + const serverCapabilities = stdioClient.getServerCapabilities() as { + capabilities: Record; + }; + + const port = await getRandomPort(); + + // Mock authenticate callback + const authenticate = vi.fn().mockResolvedValue({ userId: "user123" }); + + const onConnect = vi.fn().mockResolvedValue(undefined); + const onClose = vi.fn().mockResolvedValue(undefined); + + const httpServer = await startHTTPServer({ + authenticate, + createServer: async () => { + const mcpServer = new Server(serverVersion, { + capabilities: serverCapabilities, + }); + + await proxyServer({ + client: stdioClient, + server: mcpServer, + serverCapabilities, + }); + + return mcpServer; + }, + onClose, + onConnect, + port, + stateless: false, // Explicitly use stateful mode + }); + + // Create client + const streamTransport = new StreamableHTTPClientTransport( + new URL(`http://localhost:${port}/mcp`), + ); + + const streamClient = new Client( + { + name: "stream-client-stateful", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + await streamClient.connect(streamTransport); + + // Make first request + await streamClient.listResources(); + + // Make second request + await streamClient.listResources(); + + // In stateful mode, authenticate should NOT be called per-request + // It may be called during initialization, but not on every tool call + // The key is that it's not called multiple times for each request + expect(authenticate).not.toHaveBeenCalled(); + + await streamClient.close(); + await httpServer.close(); + await stdioClient.close(); +}); + +it("calls authenticate on every request in stateless mode", async () => { + const stdioTransport = new StdioClientTransport({ + args: ["src/fixtures/simple-stdio-server.ts"], + command: "tsx", + }); + + const stdioClient = new Client( + { + name: "mcp-proxy", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + await stdioClient.connect(stdioTransport); + + const serverVersion = stdioClient.getServerVersion() as { + name: string; + version: string; + }; + + const serverCapabilities = stdioClient.getServerCapabilities() as { + capabilities: Record; + }; + + const port = await getRandomPort(); + + // Mock authenticate callback + const authenticate = vi.fn().mockResolvedValue({ userId: "user123" }); + + const httpServer = await startHTTPServer({ + authenticate, + createServer: async () => { + const mcpServer = new Server(serverVersion, { + capabilities: serverCapabilities, + }); + + await proxyServer({ + client: stdioClient, + server: mcpServer, + serverCapabilities, + }); + + return mcpServer; + }, + port, + stateless: true, // Enable stateless mode + }); + + // Create client with Bearer token + const streamTransport = new StreamableHTTPClientTransport( + new URL(`http://localhost:${port}/mcp`), + { + requestInit: { + headers: { + Authorization: "Bearer test-token", + }, + }, + }, + ); + + const streamClient = new Client( + { + name: "stream-client-per-request", + version: "1.0.0", + }, + { + capabilities: {}, + }, + ); + + await streamClient.connect(streamTransport); + + const initialCallCount = authenticate.mock.calls.length; + + // Make first request + await streamClient.listResources(); + const firstRequestCallCount = authenticate.mock.calls.length; + + // Make second request + await streamClient.listResources(); + const secondRequestCallCount = authenticate.mock.calls.length; + + // In stateless mode, authenticate should be called on EVERY request + expect(firstRequestCallCount).toBeGreaterThan(initialCallCount); + expect(secondRequestCallCount).toBeGreaterThan(firstRequestCallCount); + + await streamClient.close(); + await httpServer.close(); + await stdioClient.close(); +}); + +it("includes Authorization in CORS allowed headers", async () => { + const port = await getRandomPort(); + + const httpServer = await startHTTPServer({ + createServer: async () => { + const mcpServer = new Server( + { name: "test", version: "1.0.0" }, + { capabilities: {} }, + ); + return mcpServer; + }, + port, + }); + + // Test OPTIONS request to verify CORS headers + const response = await fetch(`http://localhost:${port}/mcp`, { + headers: { + Origin: "https://example.com", + }, + method: "OPTIONS", + }); + + expect(response.status).toBe(204); + + // Verify Authorization is in the allowed headers + const allowedHeaders = response.headers.get("Access-Control-Allow-Headers"); + expect(allowedHeaders).toBeTruthy(); + expect(allowedHeaders).toContain("Authorization"); + + await httpServer.close(); +}); diff --git a/src/startHTTPServer.ts b/src/startHTTPServer.ts index f94705f..5ebdd82 100644 --- a/src/startHTTPServer.ts +++ b/src/startHTTPServer.ts @@ -93,6 +93,7 @@ const cleanupServer = async ( const handleStreamRequest = async ({ activeTransports, + authenticate, createServer, enableJsonResponse, endpoint, @@ -107,6 +108,7 @@ const handleStreamRequest = async ({ string, { server: T; transport: StreamableHTTPServerTransport } >; + authenticate?: (request: http.IncomingMessage) => Promise; createServer: (request: http.IncomingMessage) => Promise; enableJsonResponse?: boolean; endpoint: string; @@ -132,6 +134,41 @@ const handleStreamRequest = async ({ const body = await getBody(req); + // Per-request authentication in stateless mode + if (stateless && authenticate) { + try { + const authResult = await authenticate(req); + if (!authResult) { + res.setHeader("Content-Type", "application/json"); + res.writeHead(401).end( + JSON.stringify({ + error: { + code: -32000, + message: "Unauthorized: Authentication failed" + }, + id: (body as { id?: unknown })?.id ?? null, + jsonrpc: "2.0" + }) + ); + return true; + } + } catch (error) { + console.error("Authentication error:", error); + res.setHeader("Content-Type", "application/json"); + res.writeHead(401).end( + JSON.stringify({ + error: { + code: -32000, + message: "Unauthorized: Authentication error" + }, + id: (body as { id?: unknown })?.id ?? null, + jsonrpc: "2.0" + }) + ); + return true; + } + } + if (sessionId) { const activeTransport = activeTransports[sessionId]; if (!activeTransport) { @@ -457,6 +494,7 @@ const handleSSERequest = async ({ export const startHTTPServer = async ({ apiKey, + authenticate, createServer, enableJsonResponse, eventStore, @@ -470,6 +508,7 @@ export const startHTTPServer = async ({ streamEndpoint = "/mcp", }: { apiKey?: string; + authenticate?: (request: http.IncomingMessage) => Promise; createServer: (request: http.IncomingMessage) => Promise; enableJsonResponse?: boolean; eventStore?: EventStore; @@ -508,8 +547,8 @@ export const startHTTPServer = async ({ res.setHeader("Access-Control-Allow-Origin", origin.origin); res.setHeader("Access-Control-Allow-Credentials", "true"); res.setHeader("Access-Control-Allow-Methods", "GET, POST, OPTIONS"); - res.setHeader("Access-Control-Allow-Headers", "*"); - res.setHeader("Access-Control-Expose-Headers", "mcp-session-id"); + res.setHeader("Access-Control-Allow-Headers", "Content-Type, Authorization, Accept, Mcp-Session-Id, Last-Event-Id"); + res.setHeader("Access-Control-Expose-Headers", "Mcp-Session-Id"); } catch (error) { console.error("[mcp-proxy] error parsing origin", error); } @@ -553,6 +592,7 @@ export const startHTTPServer = async ({ streamEndpoint && (await handleStreamRequest({ activeTransports: activeStreamTransports, + authenticate, createServer, enableJsonResponse, endpoint: streamEndpoint,