diff --git a/README.md b/README.md index 241ded3..1e55bcc 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ options: - `--streamEndpoint`: Set the streamable HTTP endpoint path (default: `/mcp`). Overrides `--endpoint` if `server` is set to `stream`. - `--stateless`: Enable stateless mode for HTTP streamable transport (no session management). In this mode, each request creates a new server instance instead of maintaining persistent sessions. - `--port`: Specify the port to listen on (default: 8080) +- `--requestTimeout`: Timeout in milliseconds for requests to the MCP server (default: 300000, which is 5 minutes) - `--debug`: Enable debug logging - `--shell`: Spawn the server via the user's shell - `--apiKey`: API key for authenticating requests (uses X-API-Key header) diff --git a/src/bin/mcp-proxy.ts b/src/bin/mcp-proxy.ts index 7c6bd3d..5125939 100644 --- a/src/bin/mcp-proxy.ts +++ b/src/bin/mcp-proxy.ts @@ -66,6 +66,11 @@ const argv = await yargs(hideBin(process.argv)) describe: "The port to listen on", type: "number", }, + requestTimeout: { + default: 300000, + describe: "The timeout (in milliseconds) for requests to the MCP server (default: 5 minutes)", + type: "number", + }, server: { choices: ["sse", "stream"], describe: @@ -156,6 +161,7 @@ const proxy = async () => { proxyServer({ client, + requestTimeout: argv.requestTimeout, server, serverCapabilities, }); diff --git a/src/fixtures/slow-stdio-server.ts b/src/fixtures/slow-stdio-server.ts new file mode 100644 index 0000000..8b479f0 --- /dev/null +++ b/src/fixtures/slow-stdio-server.ts @@ -0,0 +1,91 @@ +#!/usr/bin/env tsx +/** + * A test fixture that simulates a slow MCP server for testing timeout functionality. + * This server intentionally delays responses to test timeout behavior. + */ + +import { Server } from "@modelcontextprotocol/sdk/server/index.js"; +import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { setTimeout as delay } from "node:timers/promises"; + +const server = new Server( + { + name: "slow-test-server", + version: "1.0.0", + }, + { + capabilities: { + resources: {}, + tools: {}, + }, + }, +); + +// Configure delay via environment variable or default to 2 seconds +const RESPONSE_DELAY = parseInt(process.env.RESPONSE_DELAY || "2000", 10); + +import { + CallToolRequestSchema, + ListResourcesRequestSchema, + ListToolsRequestSchema, + ReadResourceRequestSchema, +} from "@modelcontextprotocol/sdk/types.js"; + +server.setRequestHandler(ListResourcesRequestSchema, async () => { + await delay(RESPONSE_DELAY); + return { + resources: [ + { + name: "Slow Resource", + uri: "file:///slow.txt", + }, + ], + }; +}); + +server.setRequestHandler(ReadResourceRequestSchema, async ({ params }) => { + await delay(RESPONSE_DELAY); + return { + contents: [ + { + text: `Content from slow server after ${RESPONSE_DELAY}ms delay`, + uri: params.uri, + }, + ], + }; +}); + +server.setRequestHandler(ListToolsRequestSchema, async () => { + await delay(RESPONSE_DELAY); + return { + tools: [ + { + description: "A slow test tool", + inputSchema: { + properties: { + input: { + type: "string", + }, + }, + type: "object", + }, + name: "slowTool", + }, + ], + }; +}); + +server.setRequestHandler(CallToolRequestSchema, async ({ params }) => { + await delay(RESPONSE_DELAY); + return { + content: [ + { + text: `Tool response after ${RESPONSE_DELAY}ms delay: ${params.arguments?.input}`, + type: "text" as const, + }, + ], + }; +}); + +const transport = new StdioServerTransport(); +await server.connect(transport); diff --git a/src/proxyServer.test.ts b/src/proxyServer.test.ts new file mode 100644 index 0000000..c6edf1a --- /dev/null +++ b/src/proxyServer.test.ts @@ -0,0 +1,148 @@ +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; +import { Server } from "@modelcontextprotocol/sdk/server/index.js"; +import { McpError } from "@modelcontextprotocol/sdk/types.js"; +import { EventSource } from "eventsource"; +import { getRandomPort } from "get-port-please"; +import { describe, expect, it } from "vitest"; + +import { proxyServer } from "./proxyServer.js"; +import { startHTTPServer } from "./startHTTPServer.js"; + +if (!("EventSource" in global)) { + // @ts-expect-error - figure out how to use --experimental-eventsource with vitest + global.EventSource = EventSource; +} + +interface TestConfig { + requestTimeout?: number; + serverDelay?: string; + serverFixture?: string; +} + +interface TestEnvironment { + cleanup: () => Promise; + httpServer: { close: () => Promise }; + stdioClient: Client; + streamClient: Client; +} + +async function createTestEnvironment(config: TestConfig = {}): Promise { + const { + requestTimeout, + serverDelay, + serverFixture = "simple-stdio-server.ts" + } = config; + + const stdioTransport = new StdioClientTransport({ + args: [`src/fixtures/${serverFixture}`], + command: "tsx", + env: serverDelay ? { ...process.env, RESPONSE_DELAY: serverDelay } as Record : process.env as Record, + }); + + const stdioClient = new Client( + { name: "mcp-proxy-test", version: "1.0.0" }, + { capabilities: {} } + ); + + await stdioClient.connect(stdioTransport); + + const serverVersion = stdioClient.getServerVersion() as { name: string; version: string }; + const serverCapabilities = stdioClient.getServerCapabilities() as { capabilities: Record }; + const port = await getRandomPort(); + + const httpServer = await startHTTPServer({ + createServer: async () => { + const mcpServer = new Server(serverVersion, { capabilities: serverCapabilities }); + await proxyServer({ + client: stdioClient, + requestTimeout, + server: mcpServer, + serverCapabilities, + }); + return mcpServer; + }, + port, + }); + + const streamClient = new Client( + { name: "stream-client", version: "1.0.0" }, + { capabilities: {} } + ); + + const transport = new StreamableHTTPClientTransport(new URL(`http://localhost:${port}/mcp`)); + await streamClient.connect(transport); + + return { + cleanup: async () => { + await streamClient.close(); + await stdioClient.close(); + }, + httpServer, + stdioClient, + streamClient + }; +} + +describe("proxyServer timeout functionality", () => { + it("should respect custom timeout settings", async () => { + const { cleanup, streamClient } = await createTestEnvironment({ + requestTimeout: 1000, + serverDelay: "500", + serverFixture: "slow-stdio-server.ts" + }); + + // This should succeed as timeout (1s) > delay (500ms) + const result = await streamClient.listResources(); + expect(result.resources).toHaveLength(1); + expect(result.resources[0].name).toBe("Slow Resource"); + + await cleanup(); + }, 10000); + + it("should timeout when request takes longer than configured timeout", async () => { + const { cleanup, streamClient } = await createTestEnvironment({ + requestTimeout: 500, + serverDelay: "1000", + serverFixture: "slow-stdio-server.ts" + }); + + // This should throw a timeout error as delay (1s) > timeout (500ms) + await expect(streamClient.listResources()).rejects.toThrow(McpError); + + await cleanup(); + }, 10000); + + it("should use default SDK timeout when no custom timeout is provided", async () => { + const { cleanup, streamClient } = await createTestEnvironment(); + + // This should succeed with default timeout + const result = await streamClient.listResources(); + expect(result.resources).toBeDefined(); + + await cleanup(); + }, 10000); + + it("should handle resource reads with custom timeout", async () => { + const { cleanup, streamClient } = await createTestEnvironment({ + requestTimeout: 600, + serverDelay: "300", + serverFixture: "slow-stdio-server.ts" + }); + + // First get the resources + const resources = await streamClient.listResources(); + expect(resources.resources).toHaveLength(1); + + // Resource read should succeed as timeout (600ms) > delay (300ms) + const resourceContent = await streamClient.readResource({ + uri: resources.resources[0].uri, + }); + + expect(resourceContent.contents).toBeDefined(); + expect(resourceContent.contents[0].text).toContain("300ms delay"); + + await cleanup(); + }, 10000); +}); diff --git a/src/proxyServer.ts b/src/proxyServer.ts index 5520e78..921d6d8 100644 --- a/src/proxyServer.ts +++ b/src/proxyServer.ts @@ -18,10 +18,12 @@ import { export const proxyServer = async ({ client, + requestTimeout, server, serverCapabilities, }: { client: Client; + requestTimeout?: number; server: Server; serverCapabilities: ServerCapabilities; }): Promise => { @@ -42,28 +44,43 @@ export const proxyServer = async ({ if (serverCapabilities?.prompts) { server.setRequestHandler(GetPromptRequestSchema, async (args) => { - return client.getPrompt(args.params); + return client.getPrompt( + args.params, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }); server.setRequestHandler(ListPromptsRequestSchema, async (args) => { - return client.listPrompts(args.params); + return client.listPrompts( + args.params, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }); } if (serverCapabilities?.resources) { server.setRequestHandler(ListResourcesRequestSchema, async (args) => { - return client.listResources(args.params); + return client.listResources( + args.params, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }); server.setRequestHandler( ListResourceTemplatesRequestSchema, async (args) => { - return client.listResourceTemplates(args.params); + return client.listResourceTemplates( + args.params, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }, ); server.setRequestHandler(ReadResourceRequestSchema, async (args) => { - return client.readResource(args.params); + return client.readResource( + args.params, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }); if (serverCapabilities?.resources.subscribe) { @@ -75,26 +92,42 @@ export const proxyServer = async ({ ); server.setRequestHandler(SubscribeRequestSchema, async (args) => { - return client.subscribeResource(args.params); + return client.subscribeResource( + args.params, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }); server.setRequestHandler(UnsubscribeRequestSchema, async (args) => { - return client.unsubscribeResource(args.params); + return client.unsubscribeResource( + args.params, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }); } } if (serverCapabilities?.tools) { server.setRequestHandler(CallToolRequestSchema, async (args) => { - return client.callTool(args.params); + return client.callTool( + args.params, + undefined, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }); server.setRequestHandler(ListToolsRequestSchema, async (args) => { - return client.listTools(args.params); + return client.listTools( + args.params, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }); } server.setRequestHandler(CompleteRequestSchema, async (args) => { - return client.complete(args.params); + return client.complete( + args.params, + requestTimeout ? { timeout: requestTimeout } : undefined, + ); }); };