From f8f5686966c28dd8543490b0b3cc4dcdb4ffb08f Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 6 Jan 2025 22:47:35 +0000 Subject: [PATCH 01/31] Server.tool() convenience API --- package-lock.json | 19 +++++-- package.json | 3 +- src/server/index.ts | 117 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 5 deletions(-) diff --git a/package-lock.json b/package-lock.json index 5e68c8c..44a876a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,7 +11,8 @@ "dependencies": { "content-type": "^1.0.5", "raw-body": "^3.0.0", - "zod": "^3.23.8" + "zod": "^3.23.8", + "zod-to-json-schema": "^3.24.1" }, "devDependencies": { "@eslint/js": "^9.8.0", @@ -6146,12 +6147,22 @@ } }, "node_modules/zod": { - "version": "3.23.8", - "resolved": "https://registry.npmjs.org/zod/-/zod-3.23.8.tgz", - "integrity": "sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==", + "version": "3.24.1", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.1.tgz", + "integrity": "sha512-muH7gBL9sI1nciMZV67X5fTKKBLtwpZ5VBp1vsOQzj1MhrBZ4wlVCm3gedKZWLp0Oyel8sIGfeiz54Su+OVT+A==", + "license": "MIT", "funding": { "url": "https://github.com/sponsors/colinhacks" } + }, + "node_modules/zod-to-json-schema": { + "version": "3.24.1", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.1.tgz", + "integrity": "sha512-3h08nf3Vw3Wl3PK+q3ow/lIil81IT2Oa7YpQyUUDsEWbXveMesdfK1xBd2RhCkynwZndAxixji/7SYJJowr62w==", + "license": "ISC", + "peerDependencies": { + "zod": "^3.24.1" + } } } } diff --git a/package.json b/package.json index abc9e5c..893e010 100644 --- a/package.json +++ b/package.json @@ -43,7 +43,8 @@ "dependencies": { "content-type": "^1.0.5", "raw-body": "^3.0.0", - "zod": "^3.23.8" + "zod": "^3.23.8", + "zod-to-json-schema": "^3.24.1" }, "devDependencies": { "@eslint/js": "^9.8.0", diff --git a/src/server/index.ts b/src/server/index.ts index d15ad3c..77c5f5c 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,13 +1,19 @@ +import z, { AnyZodObject, ZodRawShape, ZodTypeAny } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; import { Protocol, ProtocolOptions, + RequestHandlerExtra, RequestOptions, } from "../shared/protocol.js"; import { + CallToolRequestSchema, + CallToolResult, ClientCapabilities, CreateMessageRequest, CreateMessageResultSchema, EmptyResultSchema, + ErrorCode, Implementation, InitializedNotificationSchema, InitializeRequest, @@ -16,7 +22,10 @@ import { LATEST_PROTOCOL_VERSION, ListRootsRequest, ListRootsResultSchema, + ListToolsRequestSchema, + ListToolsResult, LoggingMessageNotification, + McpError, Notification, Request, ResourceUpdatedNotification, @@ -26,6 +35,7 @@ import { ServerRequest, ServerResult, SUPPORTED_PROTOCOL_VERSIONS, + Tool } from "../types.js"; export type ServerOptions = ProtocolOptions & { @@ -35,6 +45,31 @@ export type ServerOptions = ProtocolOptions & { capabilities: ServerCapabilities; }; +/** + * Callback for a tool handler registered with Server.tool(). + * + * Parameters will include tool arguments, if applicable, as well as other request handler context. + */ +export type ToolCallback = + Args extends ZodRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => CallToolResult | Promise + : ( + extra: RequestHandlerExtra, + ) => CallToolResult | Promise; + +type RegisteredTool = { + description?: string; + inputSchema?: AnyZodObject; + callback: ToolCallback; +}; + +const EMPTY_OBJECT_JSON_SCHEMA = { + type: "object" as const +}; + /** * An MCP server on top of a pluggable transport. * @@ -72,6 +107,7 @@ export class Server< private _clientCapabilities?: ClientCapabilities; private _clientVersion?: Implementation; private _capabilities: ServerCapabilities; + private _registeredTools: { [name: string]: RegisteredTool } = {}; /** * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). @@ -305,4 +341,85 @@ export class Server< async sendPromptListChanged() { return this.notification({ method: "notifications/prompts/list_changed" }); } + + private setToolRequestHandlers() { + // TODO: Check that these handlers do not already exist + // TODO: Register tool capability + + this.setRequestHandler(ListToolsRequestSchema, (): ListToolsResult => ({ + tools: Object.entries(this._registeredTools).map(([name, tool]): Tool => { + return { + name, + description: tool.description, + inputSchema: tool.inputSchema ? zodToJsonSchema(tool.inputSchema) as Tool["inputSchema"] : EMPTY_OBJECT_JSON_SCHEMA, + }; + }) + })); + + this.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { + const tool = this._registeredTools[request.params.name]; + if (!tool) { + throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); + } + + if (tool.inputSchema) { + const parseResult = await tool.inputSchema.safeParseAsync(request.params.arguments); + if (!parseResult.success) { + throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`); + } + + const args = parseResult.data; + const cb = tool.callback as ToolCallback + return await Promise.resolve(cb(args, extra)); + } else { + const cb = tool.callback as ToolCallback; + return await Promise.resolve(cb(extra)); + } + }); + } + + /** + * Registers a zero-argument tool `name`, which will run the given function when the client calls it. + */ + tool(name: string, cb: ToolCallback): void; + + /** + * Registers a zero-argument tool `name` (with a description) which will run the given function when the client calls it. + */ + tool(name: string, description: string, cb: ToolCallback): void; + + /** + * Registers a tool `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + tool(name: string, paramsSchema: Args, cb: ToolCallback): void; + + /** + * Registers a tool `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + tool(name: string, description: string, paramsSchema: Args, cb: ToolCallback): void; + + tool(name: string, ...rest: unknown[]): void { + if (this._registeredTools[name]) { + throw new Error(`Tool ${name} is already registered`); + } + + let description: string | undefined; + if (rest[0] instanceof String) { + description = rest.shift() as string; + } + + let paramsSchema: ZodRawShape | undefined; + if (rest.length > 1) { + paramsSchema = rest.shift() as ZodRawShape; + } + + const cb = rest[0] as ToolCallback; + this._registeredTools[name] = { + description, + inputSchema: paramsSchema === undefined ? undefined : z.object(paramsSchema), + callback: cb, + }; + + this.setToolRequestHandlers(); + } } From 4bf5c5358a4c00303938d4a88790d6a97ee6b0f0 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 14:13:47 +0000 Subject: [PATCH 02/31] Assert tool request handlers do not already exist --- src/server/index.ts | 4 +++- src/shared/protocol.ts | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/server/index.ts b/src/server/index.ts index 77c5f5c..a5345dd 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -343,7 +343,9 @@ export class Server< } private setToolRequestHandlers() { - // TODO: Check that these handlers do not already exist + this.assertCanSetRequestHandler(ListToolsRequestSchema.shape.method.value); + this.assertCanSetRequestHandler(CallToolRequestSchema.shape.method.value); + // TODO: Register tool capability this.setRequestHandler(ListToolsRequestSchema, (): ListToolsResult => ({ diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index f430b31..9b4ba56 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -511,6 +511,15 @@ export abstract class Protocol< this._requestHandlers.delete(method); } + /** + * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. + */ + protected assertCanSetRequestHandler(method: string): void { + if (this._requestHandlers.has(method)) { + throw new Error(`A request handler for ${method} already exists, which would be overridden`); + } + } + /** * Registers a handler to invoke when this protocol object receives a notification with the given method. * From 1506bd1b513f3f8fefb90e60b6997454c94daa6c Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 14:26:03 +0000 Subject: [PATCH 03/31] Allow capabilities to be dynamically registered after initialization --- src/client/index.ts | 22 +++++++- src/server/index.ts | 122 +++++++++++++++++++++++++++-------------- src/shared/protocol.ts | 16 ++++++ 3 files changed, 117 insertions(+), 43 deletions(-) diff --git a/src/client/index.ts b/src/client/index.ts index 1ad5161..da0d04c 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,4 +1,5 @@ import { + mergeCapabilities, Protocol, ProtocolOptions, RequestOptions, @@ -44,7 +45,7 @@ export type ClientOptions = ProtocolOptions & { /** * Capabilities to advertise as being supported by this client. */ - capabilities: ClientCapabilities; + capabilities?: ClientCapabilities; }; /** @@ -90,10 +91,25 @@ export class Client< */ constructor( private _clientInfo: Implementation, - options: ClientOptions, + options?: ClientOptions, ) { super(options); - this._capabilities = options.capabilities; + this._capabilities = options?.capabilities ?? {}; + } + + /** + * Registers new capabilities. This can only be called before connecting to a transport. + * + * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). + */ + public registerCapabilities(capabilities: ClientCapabilities): void { + if (this.transport) { + throw new Error( + "Cannot register capabilities after connecting to transport", + ); + } + + this._capabilities = mergeCapabilities(this._capabilities, capabilities); } protected assertCapability( diff --git a/src/server/index.ts b/src/server/index.ts index a5345dd..ab4f413 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,6 +1,7 @@ import z, { AnyZodObject, ZodRawShape, ZodTypeAny } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { + mergeCapabilities, Protocol, ProtocolOptions, RequestHandlerExtra, @@ -35,19 +36,19 @@ import { ServerRequest, ServerResult, SUPPORTED_PROTOCOL_VERSIONS, - Tool + Tool, } from "../types.js"; export type ServerOptions = ProtocolOptions & { /** * Capabilities to advertise as being supported by this server. */ - capabilities: ServerCapabilities; + capabilities?: ServerCapabilities; }; /** * Callback for a tool handler registered with Server.tool(). - * + * * Parameters will include tool arguments, if applicable, as well as other request handler context. */ export type ToolCallback = @@ -56,9 +57,7 @@ export type ToolCallback = args: z.objectOutputType, extra: RequestHandlerExtra, ) => CallToolResult | Promise - : ( - extra: RequestHandlerExtra, - ) => CallToolResult | Promise; + : (extra: RequestHandlerExtra) => CallToolResult | Promise; type RegisteredTool = { description?: string; @@ -67,7 +66,7 @@ type RegisteredTool = { }; const EMPTY_OBJECT_JSON_SCHEMA = { - type: "object" as const + type: "object" as const, }; /** @@ -119,10 +118,10 @@ export class Server< */ constructor( private _serverInfo: Implementation, - options: ServerOptions, + options?: ServerOptions, ) { super(options); - this._capabilities = options.capabilities; + this._capabilities = options?.capabilities ?? {}; this.setRequestHandler(InitializeRequestSchema, (request) => this._oninitialize(request), @@ -132,6 +131,21 @@ export class Server< ); } + /** + * Registers new capabilities. This can only be called before connecting to a transport. + * + * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). + */ + public registerCapabilities(capabilities: ServerCapabilities): void { + if (this.transport) { + throw new Error( + "Cannot register capabilities after connecting to transport", + ); + } + + this._capabilities = mergeCapabilities(this._capabilities, capabilities); + } + protected assertCapabilityForMethod(method: RequestT["method"]): void { switch (method as ServerRequest["method"]) { case "sampling/createMessage": @@ -348,36 +362,54 @@ export class Server< // TODO: Register tool capability - this.setRequestHandler(ListToolsRequestSchema, (): ListToolsResult => ({ - tools: Object.entries(this._registeredTools).map(([name, tool]): Tool => { - return { - name, - description: tool.description, - inputSchema: tool.inputSchema ? zodToJsonSchema(tool.inputSchema) as Tool["inputSchema"] : EMPTY_OBJECT_JSON_SCHEMA, - }; - }) - })); - - this.setRequestHandler(CallToolRequestSchema, async (request, extra): Promise => { - const tool = this._registeredTools[request.params.name]; - if (!tool) { - throw new McpError(ErrorCode.InvalidParams, `Tool ${request.params.name} not found`); - } - - if (tool.inputSchema) { - const parseResult = await tool.inputSchema.safeParseAsync(request.params.arguments); - if (!parseResult.success) { - throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`); + this.setRequestHandler( + ListToolsRequestSchema, + (): ListToolsResult => ({ + tools: Object.entries(this._registeredTools).map( + ([name, tool]): Tool => { + return { + name, + description: tool.description, + inputSchema: tool.inputSchema + ? (zodToJsonSchema(tool.inputSchema) as Tool["inputSchema"]) + : EMPTY_OBJECT_JSON_SCHEMA, + }; + }, + ), + }), + ); + + this.setRequestHandler( + CallToolRequestSchema, + async (request, extra): Promise => { + const tool = this._registeredTools[request.params.name]; + if (!tool) { + throw new McpError( + ErrorCode.InvalidParams, + `Tool ${request.params.name} not found`, + ); } - const args = parseResult.data; - const cb = tool.callback as ToolCallback - return await Promise.resolve(cb(args, extra)); - } else { - const cb = tool.callback as ToolCallback; - return await Promise.resolve(cb(extra)); - } - }); + if (tool.inputSchema) { + const parseResult = await tool.inputSchema.safeParseAsync( + request.params.arguments, + ); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`, + ); + } + + const args = parseResult.data; + const cb = tool.callback as ToolCallback; + return await Promise.resolve(cb(args, extra)); + } else { + const cb = tool.callback as ToolCallback; + return await Promise.resolve(cb(extra)); + } + }, + ); } /** @@ -393,12 +425,21 @@ export class Server< /** * Registers a tool `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. */ - tool(name: string, paramsSchema: Args, cb: ToolCallback): void; + tool( + name: string, + paramsSchema: Args, + cb: ToolCallback, + ): void; /** * Registers a tool `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. */ - tool(name: string, description: string, paramsSchema: Args, cb: ToolCallback): void; + tool( + name: string, + description: string, + paramsSchema: Args, + cb: ToolCallback, + ): void; tool(name: string, ...rest: unknown[]): void { if (this._registeredTools[name]) { @@ -418,7 +459,8 @@ export class Server< const cb = rest[0] as ToolCallback; this._registeredTools[name] = { description, - inputSchema: paramsSchema === undefined ? undefined : z.object(paramsSchema), + inputSchema: + paramsSchema === undefined ? undefined : z.object(paramsSchema), callback: cb, }; diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 9b4ba56..cf33cd2 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1,6 +1,7 @@ import { ZodLiteral, ZodObject, ZodType, z } from "zod"; import { CancelledNotificationSchema, + ClientCapabilities, ErrorCode, JSONRPCError, JSONRPCNotification, @@ -15,6 +16,7 @@ import { Request, RequestId, Result, + ServerCapabilities, } from "../types.js"; import { Transport } from "./transport.js"; @@ -547,3 +549,17 @@ export abstract class Protocol< this._notificationHandlers.delete(method); } } + +export function mergeCapabilities(base: T, additional: T): T { + return Object.entries(additional).reduce( + (acc, [key, value]) => { + if (value && typeof value === "object") { + acc[key] = acc[key] ? { ...acc[key], ...value } : value; + } else { + acc[key] = value; + } + return acc; + }, + { ...base }, + ); +} From 7ac5a5fcf54038a60c42995c5a13ead4885f327a Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 14:29:21 +0000 Subject: [PATCH 04/31] Automatically register `tools` capability --- src/server/index.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/server/index.ts b/src/server/index.ts index ab4f413..5c2ba46 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -360,7 +360,9 @@ export class Server< this.assertCanSetRequestHandler(ListToolsRequestSchema.shape.method.value); this.assertCanSetRequestHandler(CallToolRequestSchema.shape.method.value); - // TODO: Register tool capability + this.registerCapabilities({ + tools: {}, + }); this.setRequestHandler( ListToolsRequestSchema, From c14a5d44bfb58abb7257d7f9e6f74d0eb15a153c Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 14:33:43 +0000 Subject: [PATCH 05/31] Tests for `mergeCapabilities` --- src/shared/protocol.test.ts | 98 +++++++++++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 5 deletions(-) diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 9423a2a..3073d0a 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1,13 +1,15 @@ -import { Protocol } from "./protocol.js"; -import { Transport } from "./transport.js"; +import { ZodType, z } from "zod"; import { - McpError, + ClientCapabilities, ErrorCode, + McpError, + Notification, Request, Result, - Notification, + ServerCapabilities, } from "../types.js"; -import { ZodType, z } from "zod"; +import { Protocol, mergeCapabilities } from "./protocol.js"; +import { Transport } from "./transport.js"; // Mock Transport class class MockTransport implements Transport { @@ -61,3 +63,89 @@ describe("protocol tests", () => { expect(oncloseMock).toHaveBeenCalled(); }); }); + +describe("mergeCapabilities", () => { + it("should merge client capabilities", () => { + const base: ClientCapabilities = { + sampling: {}, + roots: { + listChanged: true, + }, + }; + + const additional: ClientCapabilities = { + experimental: { + feature: true, + }, + roots: { + newProp: true, + }, + }; + + const merged = mergeCapabilities(base, additional); + expect(merged).toEqual({ + sampling: {}, + roots: { + listChanged: true, + newProp: true, + }, + experimental: { + feature: true, + }, + }); + }); + + it("should merge server capabilities", () => { + const base: ServerCapabilities = { + logging: {}, + prompts: { + listChanged: true, + }, + }; + + const additional: ServerCapabilities = { + resources: { + subscribe: true, + }, + prompts: { + newProp: true, + }, + }; + + const merged = mergeCapabilities(base, additional); + expect(merged).toEqual({ + logging: {}, + prompts: { + listChanged: true, + newProp: true, + }, + resources: { + subscribe: true, + }, + }); + }); + + it("should override existing values with additional values", () => { + const base: ServerCapabilities = { + prompts: { + listChanged: false, + }, + }; + + const additional: ServerCapabilities = { + prompts: { + listChanged: true, + }, + }; + + const merged = mergeCapabilities(base, additional); + expect(merged.prompts!.listChanged).toBe(true); + }); + + it("should handle empty objects", () => { + const base = {}; + const additional = {}; + const merged = mergeCapabilities(base, additional); + expect(merged).toEqual({}); + }); +}); From b538731af0c028582042c63d4dc0f9e6f90f6230 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 14:40:53 +0000 Subject: [PATCH 06/31] Fix string type check --- src/server/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/index.ts b/src/server/index.ts index 5c2ba46..db5e36a 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -449,7 +449,7 @@ export class Server< } let description: string | undefined; - if (rest[0] instanceof String) { + if (typeof rest[0] === "string") { description = rest.shift() as string; } From ff22f255adf0a09ce03eecf271f2cb8e4d41def1 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 14:44:23 +0000 Subject: [PATCH 07/31] Tests for Server.tool --- src/server/index.test.ts | 328 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 328 insertions(+) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 0a23955..6c63f25 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -15,6 +15,8 @@ import { ListToolsRequestSchema, SetLevelRequestSchema, ErrorCode, + ListToolsResultSchema, + CallToolResultSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; @@ -545,3 +547,329 @@ test("should handle request timeout", async () => { code: ErrorCode.RequestTimeout, }); }); + +describe("Server.tool", () => { + test("should register zero-argument tool", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.tool("test", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/list", + }, + ListToolsResultSchema, + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("test"); + expect(result.tools[0].inputSchema).toEqual({ + type: "object", + }); + }); + + test("should register tool with args schema", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.tool( + "test", + { + name: z.string(), + value: z.number(), + }, + async (args) => ({ + content: [ + { + type: "text", + text: `${args.name}: ${args.value}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/list", + }, + ListToolsResultSchema, + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("test"); + expect(result.tools[0].inputSchema).toMatchObject({ + type: "object", + properties: { + name: { type: "string" }, + value: { type: "number" }, + }, + }); + }); + + test("should register tool with description", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.tool("test", "Test description", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/list", + }, + ListToolsResultSchema, + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("test"); + expect(result.tools[0].description).toBe("Test description"); + }); + + test("should validate tool args", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + server.tool( + "test", + { + name: z.string(), + value: z.number(), + }, + async (args) => ({ + content: [ + { + type: "text", + text: `${args.name}: ${args.value}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "tools/call", + params: { + name: "test", + arguments: { + name: "test", + value: "not a number", + }, + }, + }, + CallToolResultSchema, + ), + ).rejects.toThrow(/Invalid arguments/); + }); + + test("should prevent duplicate tool registration", () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + server.tool("test", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + expect(() => { + server.tool("test", async () => ({ + content: [ + { + type: "text", + text: "Test response 2", + }, + ], + })); + }).toThrow(/already registered/); + }); + + test("should allow client to call server tools", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + server.tool( + "test", + "Test tool", + { + input: z.string(), + }, + async (args) => ({ + content: [ + { + type: "text", + text: `Processed: ${args.input}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/call", + params: { + name: "test", + arguments: { + input: "hello", + }, + }, + }, + CallToolResultSchema, + ); + + expect(result.content).toEqual([ + { + type: "text", + text: "Processed: hello", + }, + ]); + }); + + test("should handle server tool errors gracefully", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + server.tool("error-test", async () => { + throw new Error("Tool execution failed"); + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "tools/call", + params: { + name: "error-test", + }, + }, + CallToolResultSchema, + ), + ).rejects.toThrow("Tool execution failed"); + }); +}); From aac195985584f10f975064b8ba2ce855eb07bddd Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 14:53:56 +0000 Subject: [PATCH 08/31] Errors invoking tools should be erroneous tool results, not McpErrors --- src/server/index.test.ts | 58 ++++++++++++++++++++++++++++++++++++++-- src/server/index.ts | 28 +++++++++++++++++-- 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 6c63f25..978b26b 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -860,16 +860,70 @@ describe("Server.tool", () => { server.connect(serverTransport), ]); + const result = await client.request( + { + method: "tools/call", + params: { + name: "error-test", + }, + }, + CallToolResultSchema, + ); + + expect(result.isError).toBe(true); + expect(result.content).toEqual([ + { + type: "text", + text: "Tool execution failed", + }, + ]); + }); + + test("should throw McpError for invalid tool name", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + server.tool("test-tool", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + await expect( client.request( { method: "tools/call", params: { - name: "error-test", + name: "nonexistent-tool", }, }, CallToolResultSchema, ), - ).rejects.toThrow("Tool execution failed"); + ).rejects.toThrow(/Tool nonexistent-tool not found/); }); }); diff --git a/src/server/index.ts b/src/server/index.ts index db5e36a..469a3a7 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -405,10 +405,34 @@ export class Server< const args = parseResult.data; const cb = tool.callback as ToolCallback; - return await Promise.resolve(cb(args, extra)); + try { + return await Promise.resolve(cb(args, extra)); + } catch (error) { + return { + content: [ + { + type: "text", + text: error instanceof Error ? error.message : String(error), + }, + ], + isError: true, + }; + } } else { const cb = tool.callback as ToolCallback; - return await Promise.resolve(cb(extra)); + try { + return await Promise.resolve(cb(extra)); + } catch (error) { + return { + content: [ + { + type: "text", + text: error instanceof Error ? error.message : String(error), + }, + ], + isError: true, + }; + } } }, ); From bb28c9b519c09bf110bcb1d5c9b36c8ab577a958 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 15:57:10 +0000 Subject: [PATCH 09/31] Minor tweak to make test code more idiomatic --- src/server/index.test.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 978b26b..ed614a2 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -606,11 +606,11 @@ describe("Server.tool", () => { name: z.string(), value: z.number(), }, - async (args) => ({ + async ({ name, value }) => ({ content: [ { type: "text", - text: `${args.name}: ${args.value}`, + text: `${name}: ${value}`, }, ], }), @@ -705,11 +705,11 @@ describe("Server.tool", () => { name: z.string(), value: z.number(), }, - async (args) => ({ + async ({ name, value }) => ({ content: [ { type: "text", - text: `${args.name}: ${args.value}`, + text: `${name}: ${value}`, }, ], }), @@ -791,11 +791,11 @@ describe("Server.tool", () => { { input: z.string(), }, - async (args) => ({ + async ({ input }) => ({ content: [ { type: "text", - text: `Processed: ${args.input}`, + text: `Processed: ${input}`, }, ], }), From 7f0cf7305ddbb20287a37e63e156782f673729b8 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 16:42:13 +0000 Subject: [PATCH 10/31] URI Template parser and matcher --- src/shared/uriTemplate.test.ts | 167 ++++++++++++++++++++++++ src/shared/uriTemplate.ts | 229 +++++++++++++++++++++++++++++++++ 2 files changed, 396 insertions(+) create mode 100644 src/shared/uriTemplate.test.ts create mode 100644 src/shared/uriTemplate.ts diff --git a/src/shared/uriTemplate.test.ts b/src/shared/uriTemplate.test.ts new file mode 100644 index 0000000..8b60ccb --- /dev/null +++ b/src/shared/uriTemplate.test.ts @@ -0,0 +1,167 @@ +import { UriTemplate } from "./uriTemplate.js"; + +describe("UriTemplate", () => { + describe("simple string expansion", () => { + it("should expand simple string variables", () => { + const template = new UriTemplate("http://example.com/users/{username}"); + expect(template.expand({ username: "fred" })).toBe( + "http://example.com/users/fred", + ); + }); + + it("should handle multiple variables", () => { + const template = new UriTemplate("{x,y}"); + expect(template.expand({ x: "1024", y: "768" })).toBe("1024,768"); + }); + + it("should encode reserved characters", () => { + const template = new UriTemplate("{var}"); + expect(template.expand({ var: "value with spaces" })).toBe( + "value+with+spaces", + ); + }); + }); + + describe("reserved expansion", () => { + it("should not encode reserved characters with + operator", () => { + const template = new UriTemplate("{+path}/here"); + expect(template.expand({ path: "/foo/bar" })).toBe("/foo/bar/here"); + }); + }); + + describe("fragment expansion", () => { + it("should add # prefix and not encode reserved chars", () => { + const template = new UriTemplate("X{#var}"); + expect(template.expand({ var: "/test" })).toBe("X#/test"); + }); + }); + + describe("label expansion", () => { + it("should add . prefix", () => { + const template = new UriTemplate("X{.var}"); + expect(template.expand({ var: "test" })).toBe("X.test"); + }); + }); + + describe("path expansion", () => { + it("should add / prefix", () => { + const template = new UriTemplate("X{/var}"); + expect(template.expand({ var: "test" })).toBe("X/test"); + }); + }); + + describe("query expansion", () => { + it("should add ? prefix and name=value format", () => { + const template = new UriTemplate("X{?var}"); + expect(template.expand({ var: "test" })).toBe("X?var=test"); + }); + }); + + describe("form continuation expansion", () => { + it("should add & prefix and name=value format", () => { + const template = new UriTemplate("X{&var}"); + expect(template.expand({ var: "test" })).toBe("X&var=test"); + }); + }); + + describe("matching", () => { + it("should match simple strings and extract variables", () => { + const template = new UriTemplate("http://example.com/users/{username}"); + const match = template.match("http://example.com/users/fred"); + expect(match).toEqual({ username: "fred" }); + }); + + it("should match multiple variables", () => { + const template = new UriTemplate("/users/{username}/posts/{postId}"); + const match = template.match("/users/fred/posts/123"); + expect(match).toEqual({ username: "fred", postId: "123" }); + }); + + it("should return null for non-matching URIs", () => { + const template = new UriTemplate("/users/{username}"); + const match = template.match("/posts/123"); + expect(match).toBeNull(); + }); + + it("should handle exploded arrays", () => { + const template = new UriTemplate("{/list*}"); + const match = template.match("/red,green,blue"); + expect(match).toEqual({ list: ["red", "green", "blue"] }); + }); + }); + + describe("edge cases", () => { + it("should handle empty variables", () => { + const template = new UriTemplate("{empty}"); + expect(template.expand({})).toBe(""); + expect(template.expand({ empty: "" })).toBe(""); + }); + + it("should handle undefined variables", () => { + const template = new UriTemplate("{a}{b}{c}"); + expect(template.expand({ b: "2" })).toBe("2"); + }); + + it("should handle special characters in variable names", () => { + const template = new UriTemplate("{$var_name}"); + expect(template.expand({ "$var_name": "value" })).toBe("value"); + }); + }); + + describe("complex patterns", () => { + it("should handle nested path segments", () => { + const template = new UriTemplate("/api/{version}/{resource}/{id}"); + expect(template.expand({ + version: "v1", + resource: "users", + id: "123" + })).toBe("/api/v1/users/123"); + }); + + it("should handle query parameters with arrays", () => { + const template = new UriTemplate("/search{?tags*}"); + expect(template.expand({ + tags: ["nodejs", "typescript", "testing"] + })).toBe("/search?tags=nodejs,typescript,testing"); + }); + + it("should handle multiple query parameters", () => { + const template = new UriTemplate("/search{?q,page,limit}"); + expect(template.expand({ + q: "test", + page: "1", + limit: "10" + })).toBe("/search?q=test&page=1&limit=10"); + }); + }); + + describe("matching complex patterns", () => { + it("should match nested path segments", () => { + const template = new UriTemplate("/api/{version}/{resource}/{id}"); + const match = template.match("/api/v1/users/123"); + expect(match).toEqual({ + version: "v1", + resource: "users", + id: "123" + }); + }); + + it("should match query parameters", () => { + const template = new UriTemplate("/search{?q}"); + const match = template.match("/search?q=test"); + expect(match).toEqual({ q: "test" }); + }); + + it("should match multiple query parameters", () => { + const template = new UriTemplate("/search{?q,page}"); + const match = template.match("/search?q=test&page=1"); + expect(match).toEqual({ q: "test", page: "1" }); + }); + + it("should handle partial matches correctly", () => { + const template = new UriTemplate("/users/{id}"); + expect(template.match("/users/123/extra")).toBeNull(); + expect(template.match("/users")).toBeNull(); + }); + }); +}); diff --git a/src/shared/uriTemplate.ts b/src/shared/uriTemplate.ts new file mode 100644 index 0000000..c75343c --- /dev/null +++ b/src/shared/uriTemplate.ts @@ -0,0 +1,229 @@ +// Claude-authored implementation of RFC 6570 URI Templates + +type Variables = Record; + +export class UriTemplate { + private readonly parts: Array< + | string + | { name: string; operator: string; names: string[]; exploded: boolean } + >; + + constructor(template: string) { + this.parts = this.parse(template); + } + + private parse( + template: string, + ): Array< + | string + | { name: string; operator: string; names: string[]; exploded: boolean } + > { + const parts: Array< + | string + | { name: string; operator: string; names: string[]; exploded: boolean } + > = []; + let currentText = ""; + let i = 0; + + while (i < template.length) { + if (template[i] === "{") { + if (currentText) { + parts.push(currentText); + currentText = ""; + } + const end = template.indexOf("}", i); + if (end === -1) throw new Error("Unclosed template expression"); + + const expr = template.slice(i + 1, end); + const operator = this.getOperator(expr); + const exploded = expr.includes("*"); + const names = this.getNames(expr); + const name = names[0]; + parts.push({ name, operator, names, exploded }); + i = end + 1; + } else { + currentText += template[i]; + i++; + } + } + + if (currentText) { + parts.push(currentText); + } + + return parts; + } + + private getOperator(expr: string): string { + const operators = ["+", "#", ".", "/", "?", "&"]; + return operators.find((op) => expr.startsWith(op)) || ""; + } + + private getNames(expr: string): string[] { + const operator = this.getOperator(expr); + return expr + .slice(operator.length) + .split(",") + .map((name) => name.replace("*", "").trim()) + .filter((name) => name.length > 0); + } + + private encodeValue(value: string, operator: string): string { + if (operator === "+" || operator === "#") { + return encodeURI(value); + } + return encodeURIComponent(value).replace(/%20/g, "+"); + } + + private expandPart( + part: { + name: string; + operator: string; + names: string[]; + exploded: boolean; + }, + variables: Variables, + ): string { + if (part.operator === "?" || part.operator === "&") { + const pairs = part.names + .map((name) => { + const value = variables[name]; + if (value === undefined) return ""; + const encoded = Array.isArray(value) + ? value.map((v) => this.encodeValue(v, part.operator)).join(",") + : this.encodeValue(value.toString(), part.operator); + return `${name}=${encoded}`; + }) + .filter((pair) => pair.length > 0); + + if (pairs.length === 0) return ""; + const separator = part.operator === "?" ? "?" : "&"; + return separator + pairs.join("&"); + } + + if (part.names.length > 1) { + const values = part.names + .map((name) => variables[name]) + .filter((v) => v !== undefined); + if (values.length === 0) return ""; + return values.map((v) => (Array.isArray(v) ? v[0] : v)).join(","); + } + + const value = variables[part.name]; + if (value === undefined) return ""; + + const values = Array.isArray(value) ? value : [value]; + const encoded = values.map((v) => this.encodeValue(v, part.operator)); + + switch (part.operator) { + case "": + return encoded.join(","); + case "+": + return encoded.join(","); + case "#": + return "#" + encoded.join(","); + case ".": + return "." + encoded.join("."); + case "/": + return "/" + encoded.join("/"); + default: + return encoded.join(","); + } + } + + expand(variables: Variables): string { + return this.parts + .map((part) => { + if (typeof part === "string") return part; + return this.expandPart(part, variables); + }) + .join(""); + } + + private escapeRegExp(str: string): string { + return str.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); + } + + private partToRegExp(part: { + name: string; + operator: string; + names: string[]; + exploded: boolean; + }): Array<{ pattern: string; name: string }> { + const patterns: Array<{ pattern: string; name: string }> = []; + + if (part.operator === "?" || part.operator === "&") { + for (let i = 0; i < part.names.length; i++) { + const name = part.names[i]; + const prefix = i === 0 ? "\\" + part.operator : "&"; + patterns.push({ + pattern: prefix + this.escapeRegExp(name) + "=([^&]+)", + name, + }); + } + return patterns; + } + + let pattern: string; + const name = part.name; + + switch (part.operator) { + case "": + pattern = part.exploded ? "([^/]+(?:,[^/]+)*)" : "([^/,]+)"; + break; + case "+": + case "#": + pattern = "(.+)"; + break; + case ".": + pattern = "\\.([^/,]+)"; + break; + case "/": + pattern = "/" + (part.exploded ? "([^/]+(?:,[^/]+)*)" : "([^/,]+)"); + break; + default: + pattern = "([^/]+)"; + } + + patterns.push({ pattern, name }); + return patterns; + } + + match(uri: string): Variables | null { + let pattern = "^"; + const names: Array<{ name: string; exploded: boolean }> = []; + + for (const part of this.parts) { + if (typeof part === "string") { + pattern += this.escapeRegExp(part); + } else { + const patterns = this.partToRegExp(part); + for (const { pattern: partPattern, name } of patterns) { + pattern += partPattern; + names.push({ name, exploded: part.exploded }); + } + } + } + + pattern += "$"; + const regex = new RegExp(pattern); + const match = uri.match(regex); + + if (!match) return null; + + const result: Variables = {}; + for (let i = 0; i < names.length; i++) { + const { name, exploded } = names[i]; + const value = match[i + 1]; + const cleanName = name.replace("*", ""); + + if (exploded && value.includes(",")) { + result[cleanName] = value.split(","); + } else { + result[cleanName] = value; + } + } + + return result; + } +} \ No newline at end of file From dc77d9cf47cb42abae215cdbb5627156808fd624 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 18:40:53 +0000 Subject: [PATCH 11/31] More thorough handling of edge cases/pathologies --- src/shared/uriTemplate.test.ts | 75 ++++++++++++++++++++++++++++++++++ src/shared/uriTemplate.ts | 70 ++++++++++++++++++++++++++++--- 2 files changed, 139 insertions(+), 6 deletions(-) diff --git a/src/shared/uriTemplate.test.ts b/src/shared/uriTemplate.test.ts index 8b60ccb..fc20121 100644 --- a/src/shared/uriTemplate.test.ts +++ b/src/shared/uriTemplate.test.ts @@ -164,4 +164,79 @@ describe("UriTemplate", () => { expect(template.match("/users")).toBeNull(); }); }); + + describe("security and edge cases", () => { + it("should handle extremely long input strings", () => { + const longString = "x".repeat(100000); + const template = new UriTemplate(`/api/{param}`); + expect(template.expand({ param: longString })).toBe(`/api/${longString}`); + expect(template.match(`/api/${longString}`)).toEqual({ param: longString }); + }); + + it("should handle deeply nested template expressions", () => { + const template = new UriTemplate("{a}{b}{c}{d}{e}{f}{g}{h}{i}{j}".repeat(1000)); + expect(() => template.expand({ + a: "1", b: "2", c: "3", d: "4", e: "5", + f: "6", g: "7", h: "8", i: "9", j: "0" + })).not.toThrow(); + }); + + it("should handle malformed template expressions", () => { + expect(() => new UriTemplate("{unclosed")).toThrow(); + expect(() => new UriTemplate("{}")).not.toThrow(); + expect(() => new UriTemplate("{,}")).not.toThrow(); + expect(() => new UriTemplate("{a}{")).toThrow(); + }); + + it("should handle pathological regex patterns", () => { + const template = new UriTemplate("/api/{param}"); + // Create a string that could cause catastrophic backtracking + const input = "/api/" + "a".repeat(100000); + expect(() => template.match(input)).not.toThrow(); + }); + + it("should handle invalid UTF-8 sequences", () => { + const template = new UriTemplate("/api/{param}"); + const invalidUtf8 = "���"; + expect(() => template.expand({ param: invalidUtf8 })).not.toThrow(); + expect(() => template.match(`/api/${invalidUtf8}`)).not.toThrow(); + }); + + it("should handle template/URI length mismatches", () => { + const template = new UriTemplate("/api/{param}"); + expect(template.match("/api/")).toBeNull(); + expect(template.match("/api")).toBeNull(); + expect(template.match("/api/value/extra")).toBeNull(); + }); + + it("should handle repeated operators", () => { + const template = new UriTemplate("{?a}{?b}{?c}"); + expect(template.expand({ a: "1", b: "2", c: "3" })).toBe("?a=1&b=2&c=3"); + }); + + it("should handle overlapping variable names", () => { + const template = new UriTemplate("{var}{vara}"); + expect(template.expand({ var: "1", vara: "2" })).toBe("12"); + }); + + it("should handle empty segments", () => { + const template = new UriTemplate("///{a}////{b}////"); + expect(template.expand({ a: "1", b: "2" })).toBe("///1////2////"); + expect(template.match("///1////2////")).toEqual({ a: "1", b: "2" }); + }); + + it("should handle maximum template expression limit", () => { + // Create a template with many expressions + const expressions = Array(10000).fill("{param}").join(""); + expect(() => new UriTemplate(expressions)).not.toThrow(); + }); + + it("should handle maximum variable name length", () => { + const longName = "a".repeat(10000); + const template = new UriTemplate(`{${longName}}`); + const vars: Record = {}; + vars[longName] = "value"; + expect(() => template.expand(vars)).not.toThrow(); + }); + }); }); diff --git a/src/shared/uriTemplate.ts b/src/shared/uriTemplate.ts index c75343c..900c14e 100644 --- a/src/shared/uriTemplate.ts +++ b/src/shared/uriTemplate.ts @@ -2,13 +2,26 @@ type Variables = Record; +const MAX_TEMPLATE_LENGTH = 1000000; // 1MB +const MAX_VARIABLE_LENGTH = 1000000; // 1MB +const MAX_TEMPLATE_EXPRESSIONS = 10000; +const MAX_REGEX_LENGTH = 1000000; // 1MB + export class UriTemplate { + private static validateLength(str: string, max: number, context: string): void { + if (str.length > max) { + throw new Error( + `${context} exceeds maximum length of ${max} characters (got ${str.length})`, + ); + } + } private readonly parts: Array< | string | { name: string; operator: string; names: string[]; exploded: boolean } >; constructor(template: string) { + UriTemplate.validateLength(template, MAX_TEMPLATE_LENGTH, "Template"); this.parts = this.parse(template); } @@ -24,6 +37,7 @@ export class UriTemplate { > = []; let currentText = ""; let i = 0; + let expressionCount = 0; while (i < template.length) { if (template[i] === "{") { @@ -34,11 +48,28 @@ export class UriTemplate { const end = template.indexOf("}", i); if (end === -1) throw new Error("Unclosed template expression"); + expressionCount++; + if (expressionCount > MAX_TEMPLATE_EXPRESSIONS) { + throw new Error( + `Template contains too many expressions (max ${MAX_TEMPLATE_EXPRESSIONS})`, + ); + } + const expr = template.slice(i + 1, end); const operator = this.getOperator(expr); const exploded = expr.includes("*"); const names = this.getNames(expr); const name = names[0]; + + // Validate variable name length + for (const name of names) { + UriTemplate.validateLength( + name, + MAX_VARIABLE_LENGTH, + "Variable name", + ); + } + parts.push({ name, operator, names, exploded }); i = end + 1; } else { @@ -69,6 +100,7 @@ export class UriTemplate { } private encodeValue(value: string, operator: string): string { + UriTemplate.validateLength(value, MAX_VARIABLE_LENGTH, "Variable value"); if (operator === "+" || operator === "#") { return encodeURI(value); } @@ -132,12 +164,31 @@ export class UriTemplate { } expand(variables: Variables): string { - return this.parts - .map((part) => { - if (typeof part === "string") return part; - return this.expandPart(part, variables); - }) - .join(""); + let result = ""; + let hasQueryParam = false; + + for (const part of this.parts) { + if (typeof part === "string") { + result += part; + continue; + } + + const expanded = this.expandPart(part, variables); + if (!expanded) continue; + + // Convert ? to & if we already have a query parameter + if ((part.operator === "?" || part.operator === "&") && hasQueryParam) { + result += expanded.replace("?", "&"); + } else { + result += expanded; + } + + if (part.operator === "?" || part.operator === "&") { + hasQueryParam = true; + } + } + + return result; } private escapeRegExp(str: string): string { @@ -152,6 +203,11 @@ export class UriTemplate { }): Array<{ pattern: string; name: string }> { const patterns: Array<{ pattern: string; name: string }> = []; + // Validate variable name length for matching + for (const name of part.names) { + UriTemplate.validateLength(name, MAX_VARIABLE_LENGTH, "Variable name"); + } + if (part.operator === "?" || part.operator === "&") { for (let i = 0; i < part.names.length; i++) { const name = part.names[i]; @@ -190,6 +246,7 @@ export class UriTemplate { } match(uri: string): Variables | null { + UriTemplate.validateLength(uri, MAX_TEMPLATE_LENGTH, "URI"); let pattern = "^"; const names: Array<{ name: string; exploded: boolean }> = []; @@ -206,6 +263,7 @@ export class UriTemplate { } pattern += "$"; + UriTemplate.validateLength(pattern, MAX_REGEX_LENGTH, "Generated regex pattern"); const regex = new RegExp(pattern); const match = uri.match(regex); From 1e040e99320a534194397da34f2694ce1d7b270b Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 18:42:21 +0000 Subject: [PATCH 12/31] Add method to determine if a URI is a template or not --- src/shared/uriTemplate.test.ts | 17 +++++++++++++++++ src/shared/uriTemplate.ts | 29 ++++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/shared/uriTemplate.test.ts b/src/shared/uriTemplate.test.ts index fc20121..0e75131 100644 --- a/src/shared/uriTemplate.test.ts +++ b/src/shared/uriTemplate.test.ts @@ -1,6 +1,23 @@ import { UriTemplate } from "./uriTemplate.js"; describe("UriTemplate", () => { + describe("isTemplate", () => { + it("should return true for strings containing template expressions", () => { + expect(UriTemplate.isTemplate("{foo}")).toBe(true); + expect(UriTemplate.isTemplate("/users/{id}")).toBe(true); + expect(UriTemplate.isTemplate("http://example.com/{path}/{file}")).toBe(true); + expect(UriTemplate.isTemplate("/search{?q,limit}")).toBe(true); + }); + + it("should return false for strings without template expressions", () => { + expect(UriTemplate.isTemplate("")).toBe(false); + expect(UriTemplate.isTemplate("plain string")).toBe(false); + expect(UriTemplate.isTemplate("http://example.com/foo/bar")).toBe(false); + expect(UriTemplate.isTemplate("{}")).toBe(false); // Empty braces don't count + expect(UriTemplate.isTemplate("{ }")).toBe(false); // Just whitespace doesn't count + }); + }); + describe("simple string expansion", () => { it("should expand simple string variables", () => { const template = new UriTemplate("http://example.com/users/{username}"); diff --git a/src/shared/uriTemplate.ts b/src/shared/uriTemplate.ts index 900c14e..47671fd 100644 --- a/src/shared/uriTemplate.ts +++ b/src/shared/uriTemplate.ts @@ -1,6 +1,6 @@ // Claude-authored implementation of RFC 6570 URI Templates -type Variables = Record; +export type Variables = Record; const MAX_TEMPLATE_LENGTH = 1000000; // 1MB const MAX_VARIABLE_LENGTH = 1000000; // 1MB @@ -8,7 +8,22 @@ const MAX_TEMPLATE_EXPRESSIONS = 10000; const MAX_REGEX_LENGTH = 1000000; // 1MB export class UriTemplate { - private static validateLength(str: string, max: number, context: string): void { + /** + * Returns true if the given string contains any URI template expressions. + * A template expression is a sequence of characters enclosed in curly braces, + * like {foo} or {?bar}. + */ + static isTemplate(str: string): boolean { + // Look for any sequence of characters between curly braces + // that isn't just whitespace + return /\{[^}\s]+\}/.test(str); + } + + private static validateLength( + str: string, + max: number, + context: string, + ): void { if (str.length > max) { throw new Error( `${context} exceeds maximum length of ${max} characters (got ${str.length})`, @@ -60,7 +75,7 @@ export class UriTemplate { const exploded = expr.includes("*"); const names = this.getNames(expr); const name = names[0]; - + // Validate variable name length for (const name of names) { UriTemplate.validateLength( @@ -263,7 +278,11 @@ export class UriTemplate { } pattern += "$"; - UriTemplate.validateLength(pattern, MAX_REGEX_LENGTH, "Generated regex pattern"); + UriTemplate.validateLength( + pattern, + MAX_REGEX_LENGTH, + "Generated regex pattern", + ); const regex = new RegExp(pattern); const match = uri.match(regex); @@ -284,4 +303,4 @@ export class UriTemplate { return result; } -} \ No newline at end of file +} From e4b382044eb64484851331da938f5308c7f7cd79 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 22:00:55 +0000 Subject: [PATCH 13/31] Add UriTemplate.toString --- src/shared/uriTemplate.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/shared/uriTemplate.ts b/src/shared/uriTemplate.ts index 47671fd..cd3f46f 100644 --- a/src/shared/uriTemplate.ts +++ b/src/shared/uriTemplate.ts @@ -30,6 +30,7 @@ export class UriTemplate { ); } } + private readonly template: string; private readonly parts: Array< | string | { name: string; operator: string; names: string[]; exploded: boolean } @@ -37,9 +38,14 @@ export class UriTemplate { constructor(template: string) { UriTemplate.validateLength(template, MAX_TEMPLATE_LENGTH, "Template"); + this.template = template; this.parts = this.parse(template); } + toString(): string { + return this.template; + } + private parse( template: string, ): Array< From 098266ae22d10737a3c319927dce82c447bea13b Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 19:12:03 +0000 Subject: [PATCH 14/31] Add simplified API for registering resources and resource templates --- src/server/index.test.ts | 275 +++++++++++++++++++++++++++++++++++++++ src/server/index.ts | 242 +++++++++++++++++++++++++++++++++- 2 files changed, 512 insertions(+), 5 deletions(-) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index ed614a2..8a5e297 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -17,10 +17,14 @@ import { ErrorCode, ListToolsResultSchema, CallToolResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; import { Client } from "../client/index.js"; +import { UriTemplate } from "../shared/uriTemplate.js"; test("should accept latest protocol version", async () => { let sendPromiseResolve: (value: unknown) => void; @@ -478,6 +482,7 @@ test("should handle server cancelling a request", async () => { // Request should be rejected await expect(createMessagePromise).rejects.toBe("Cancelled by test"); }); + test("should handle request timeout", async () => { const server = new Server( { @@ -927,3 +932,273 @@ describe("Server.tool", () => { ).rejects.toThrow(/Tool nonexistent-tool not found/); }); }); + +describe("Server.resource", () => { + test("should register resource with uri and readCallback", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.resource("test", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/list", + }, + ListResourcesResultSchema, + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].name).toBe("test"); + expect(result.resources[0].uri).toBe("test://resource"); + }); + + test("should register resource with metadata", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.resource( + "test", + "test://resource", + { + description: "Test resource", + mimeType: "text/plain", + }, + async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/list", + }, + ListResourcesResultSchema, + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].description).toBe("Test resource"); + expect(result.resources[0].mimeType).toBe("text/plain"); + }); + + test("should register resource template", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.resource( + "test", + new UriTemplate("test://resource/{id}"), + async () => ({ + contents: [ + { + uri: "test://resource/123", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/templates/list", + }, + ListResourceTemplatesResultSchema, + ); + + expect(result.resourceTemplates).toHaveLength(1); + expect(result.resourceTemplates[0].name).toBe("test"); + expect(result.resourceTemplates[0].uriTemplate).toBe( + "test://resource/{id}", + ); + }); + + test("should prevent duplicate resource registration", () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + server.resource("test", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + })); + + expect(() => { + server.resource("test2", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content 2", + }, + ], + })); + }).toThrow(/already registered/); + }); + + test("should prevent duplicate resource template registration", () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + server.resource( + "test", + new UriTemplate("test://resource/{id}"), + async () => ({ + contents: [ + { + uri: "test://resource/123", + text: "Test content", + }, + ], + }), + ); + + expect(() => { + server.resource( + "test", + new UriTemplate("test://resource/{id}"), + async () => ({ + contents: [ + { + uri: "test://resource/123", + text: "Test content 2", + }, + ], + }), + ); + }).toThrow(/already registered/); + }); + + test("should handle resource read errors gracefully", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.resource("error-test", "test://error", async () => { + throw new Error("Resource read failed"); + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "resources/read", + params: { + uri: "test://error", + }, + }, + ReadResourceResultSchema, + ), + ).rejects.toThrow(/Resource read failed/); + }); + + test("should throw McpError for invalid resource URI", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.resource("test", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "resources/read", + params: { + uri: "test://nonexistent", + }, + }, + ReadResourceResultSchema, + ), + ).rejects.toThrow(/Resource test:\/\/nonexistent not found/); + }); +}); diff --git a/src/server/index.ts b/src/server/index.ts index 469a3a7..8bb900e 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -21,6 +21,9 @@ import { InitializeRequestSchema, InitializeResult, LATEST_PROTOCOL_VERSION, + ListResourcesRequestSchema, + ListResourcesResult, + ListResourceTemplatesRequestSchema, ListRootsRequest, ListRootsResultSchema, ListToolsRequestSchema, @@ -28,7 +31,10 @@ import { LoggingMessageNotification, McpError, Notification, + ReadResourceRequestSchema, + ReadResourceResult, Request, + Resource, ResourceUpdatedNotification, Result, ServerCapabilities, @@ -38,6 +44,7 @@ import { SUPPORTED_PROTOCOL_VERSIONS, Tool, } from "../types.js"; +import { UriTemplate, Variables } from "../shared/uriTemplate.js"; export type ServerOptions = ProtocolOptions & { /** @@ -53,11 +60,11 @@ export type ServerOptions = ProtocolOptions & { */ export type ToolCallback = Args extends ZodRawShape - ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra, - ) => CallToolResult | Promise - : (extra: RequestHandlerExtra) => CallToolResult | Promise; + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => CallToolResult | Promise + : (extra: RequestHandlerExtra) => CallToolResult | Promise; type RegisteredTool = { description?: string; @@ -69,6 +76,30 @@ const EMPTY_OBJECT_JSON_SCHEMA = { type: "object" as const, }; +export type ResourceMetadata = Omit; + +export type ReadResourceCallback = ( + uri: URL, + variables?: Variables, +) => ReadResourceResult | Promise; + +export type ListResourcesCallback = () => + | ListResourcesResult + | Promise; + +type RegisteredResource = { + name: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; +}; + +type RegisteredResourceTemplate = { + uriTemplate: UriTemplate; + metadata?: ResourceMetadata; + listCallback?: ListResourcesCallback; + readCallback: ReadResourceCallback; +}; + /** * An MCP server on top of a pluggable transport. * @@ -107,6 +138,10 @@ export class Server< private _clientVersion?: Implementation; private _capabilities: ServerCapabilities; private _registeredTools: { [name: string]: RegisteredTool } = {}; + private _registeredResources: { [uri: string]: RegisteredResource } = {}; + private _registeredResourceTemplates: { + [name: string]: RegisteredResourceTemplate; + } = {}; /** * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). @@ -492,4 +527,201 @@ export class Server< this.setToolRequestHandlers(); } + + private setResourceRequestHandlers() { + this.assertCanSetRequestHandler( + ListResourcesRequestSchema.shape.method.value, + ); + this.assertCanSetRequestHandler( + ListResourceTemplatesRequestSchema.shape.method.value, + ); + this.assertCanSetRequestHandler( + ReadResourceRequestSchema.shape.method.value, + ); + + this.registerCapabilities({ + resources: {}, + }); + + this.setRequestHandler(ListResourcesRequestSchema, async () => { + const resources = Object.entries(this._registeredResources).map( + ([uri, resource]) => ({ + uri, + name: resource.name, + ...resource.metadata, + }), + ); + + const templateResources: Resource[] = []; + for (const template of Object.values(this._registeredResourceTemplates)) { + if (!template.listCallback) { + continue; + } + + const result = await template.listCallback(); + for (const resource of result.resources) { + templateResources.push({ + ...resource, + ...template.metadata, + }); + } + } + + return { resources: [...resources, ...templateResources] }; + }); + + this.setRequestHandler(ListResourceTemplatesRequestSchema, async () => { + const resourceTemplates = Object.entries( + this._registeredResourceTemplates, + ).map(([name, template]) => ({ + name, + uriTemplate: template.uriTemplate.toString(), + ...template.metadata, + })); + + return { resourceTemplates }; + }); + + this.setRequestHandler(ReadResourceRequestSchema, async (request) => { + const uri = new URL(request.params.uri); + + // First check for exact resource match + const resource = this._registeredResources[uri.toString()]; + if (resource) { + return resource.readCallback(uri); + } + + // Then check templates + for (const template of Object.values(this._registeredResourceTemplates)) { + const variables = template.uriTemplate.match(uri.toString()); + if (variables) { + return template.readCallback(uri, variables); + } + } + + throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} not found`); + }); + } + + resource(name: string, uri: string, readCallback: ReadResourceCallback): void; + + resource( + name: string, + uri: string, + metadata: ResourceMetadata, + readCallback: ReadResourceCallback, + ): void; + + resource( + name: string, + uriTemplate: UriTemplate, + readCallback: ReadResourceCallback, + ): void; + + resource( + name: string, + uriTemplate: UriTemplate, + metadata: ResourceMetadata, + readCallback: ReadResourceCallback, + ): void; + + resource( + name: string, + uriTemplate: UriTemplate, + listCallback: ListResourcesCallback, + readCallback: ReadResourceCallback, + ): void; + + resource( + name: string, + uriTemplate: UriTemplate, + metadata: ResourceMetadata, + listCallback: ListResourcesCallback, + readCallback: ReadResourceCallback, + ): void; + + resource( + name: string, + uriOrTemplate: string | UriTemplate, + ...rest: unknown[] + ): void { + let metadata: ResourceMetadata | undefined; + if (typeof rest[0] === "object") { + metadata = rest.shift() as ResourceMetadata; + } + + let listCallback: ListResourcesCallback | undefined; + if (rest.length > 1) { + listCallback = rest.shift() as ListResourcesCallback; + } + + const readCallback = rest[0] as ReadResourceCallback; + if (typeof uriOrTemplate === "string") { + this.registerResource({ + name, + uri: uriOrTemplate, + metadata, + readCallback, + }); + } else { + this.registerResourceTemplate({ + name, + uriTemplate: uriOrTemplate, + metadata, + listCallback, + readCallback, + }); + } + } + + private registerResource({ + name, + uri, + metadata, + readCallback, + }: { + name: string; + uri: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; + }): void { + if (this._registeredResources[uri]) { + throw new Error(`Resource ${uri} is already registered`); + } + + this._registeredResources[uri] = { + name, + metadata, + readCallback, + }; + + this.setResourceRequestHandlers(); + } + + private registerResourceTemplate({ + name, + uriTemplate, + metadata, + listCallback, + readCallback, + }: { + name: string; + uriTemplate: UriTemplate; + metadata?: ResourceMetadata; + listCallback?: ListResourcesCallback; + readCallback: ReadResourceCallback; + }): void { + if (this._registeredResourceTemplates[name]) { + throw new Error(`Resource template ${name} is already registered`); + } + + this._registeredResourceTemplates[name] = { + uriTemplate, + metadata, + listCallback, + readCallback, + }; + + this.setResourceRequestHandlers(); + } } From 19705091311713f197c1f5f2f73b45b71e567643 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 22:05:55 +0000 Subject: [PATCH 15/31] Method documentation --- src/server/index.ts | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/server/index.ts b/src/server/index.ts index 8bb900e..159b382 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -603,8 +603,14 @@ export class Server< }); } + /** + * Registers a resource `name` at a fixed URI, which will use the given callback to respond to read requests. + */ resource(name: string, uri: string, readCallback: ReadResourceCallback): void; + /** + * Registers a resource `name` at a fixed URI with metadata, which will use the given callback to respond to read requests. + */ resource( name: string, uri: string, @@ -612,12 +618,18 @@ export class Server< readCallback: ReadResourceCallback, ): void; + /** + * Registers a resource `name` with a URI template pattern, which will use the given callback to respond to read requests. + */ resource( name: string, uriTemplate: UriTemplate, readCallback: ReadResourceCallback, ): void; + /** + * Registers a resource `name` with a URI template pattern and metadata, which will use the given callback to respond to read requests. + */ resource( name: string, uriTemplate: UriTemplate, @@ -625,6 +637,9 @@ export class Server< readCallback: ReadResourceCallback, ): void; + /** + * Registers a resource `name` with a URI template pattern, which will use the list callback to enumerate matching resources and read callback to respond to read requests. + */ resource( name: string, uriTemplate: UriTemplate, @@ -632,6 +647,9 @@ export class Server< readCallback: ReadResourceCallback, ): void; + /** + * Registers a resource `name` with a URI template pattern and metadata, which will use the list callback to enumerate matching resources and read callback to respond to read requests. + */ resource( name: string, uriTemplate: UriTemplate, From 3f9acec31ef27614a88285878868f8e19484dfe3 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 22:10:05 +0000 Subject: [PATCH 16/31] Add a test for listCallback --- src/server/index.test.ts | 57 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 8a5e297..f6aee24 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -1065,6 +1065,63 @@ describe("Server.resource", () => { ); }); + test("should register resource template with listCallback", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.resource( + "test", + new UriTemplate("test://resource/{id}"), + async () => ({ + resources: [ + { + name: "Resource 1", + uri: "test://resource/1", + }, + { + name: "Resource 2", + uri: "test://resource/2", + }, + ], + }), + async (uri) => ({ + contents: [ + { + uri: uri.href, + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/list", + }, + ListResourcesResultSchema, + ); + + expect(result.resources).toHaveLength(2); + expect(result.resources[0].name).toBe("Resource 1"); + expect(result.resources[0].uri).toBe("test://resource/1"); + expect(result.resources[1].name).toBe("Resource 2"); + expect(result.resources[1].uri).toBe("test://resource/2"); + }); + test("should prevent duplicate resource registration", () => { const server = new Server({ name: "test server", From 45f99e6590237f490e6ceccd92be0e9bee150646 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Tue, 7 Jan 2025 22:36:20 +0000 Subject: [PATCH 17/31] Make `Variables` non-optional with resource template URIs --- src/server/index.test.ts | 44 ++++++++++++++++++++++++++++++++++++++++ src/server/index.ts | 29 +++++++++++++++----------- 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index f6aee24..99a60e7 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -1122,6 +1122,50 @@ describe("Server.resource", () => { expect(result.resources[1].uri).toBe("test://resource/2"); }); + test("should pass template variables to readCallback", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.resource( + "test", + new UriTemplate("test://resource/{category}/{id}"), + async (uri, { category, id }) => ({ + contents: [ + { + uri: uri.href, + text: `Category: ${category}, ID: ${id}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/read", + params: { + uri: "test://resource/books/123", + }, + }, + ReadResourceResultSchema, + ); + + expect(result.contents[0].text).toBe("Category: books, ID: 123"); + }); + test("should prevent duplicate resource registration", () => { const server = new Server({ name: "test server", diff --git a/src/server/index.ts b/src/server/index.ts index 159b382..a0cc2ed 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -78,26 +78,30 @@ const EMPTY_OBJECT_JSON_SCHEMA = { export type ResourceMetadata = Omit; -export type ReadResourceCallback = ( - uri: URL, - variables?: Variables, -) => ReadResourceResult | Promise; - export type ListResourcesCallback = () => | ListResourcesResult | Promise; +export type ReadResourceCallback = ( + uri: URL, +) => ReadResourceResult | Promise; + type RegisteredResource = { name: string; metadata?: ResourceMetadata; readCallback: ReadResourceCallback; }; +export type ReadResourceTemplateCallback = ( + uri: URL, + variables: Variables, +) => ReadResourceResult | Promise; + type RegisteredResourceTemplate = { uriTemplate: UriTemplate; metadata?: ResourceMetadata; listCallback?: ListResourcesCallback; - readCallback: ReadResourceCallback; + readCallback: ReadResourceTemplateCallback; }; /** @@ -624,7 +628,7 @@ export class Server< resource( name: string, uriTemplate: UriTemplate, - readCallback: ReadResourceCallback, + readCallback: ReadResourceTemplateCallback, ): void; /** @@ -634,7 +638,7 @@ export class Server< name: string, uriTemplate: UriTemplate, metadata: ResourceMetadata, - readCallback: ReadResourceCallback, + readCallback: ReadResourceTemplateCallback, ): void; /** @@ -644,7 +648,7 @@ export class Server< name: string, uriTemplate: UriTemplate, listCallback: ListResourcesCallback, - readCallback: ReadResourceCallback, + readCallback: ReadResourceTemplateCallback, ): void; /** @@ -655,7 +659,7 @@ export class Server< uriTemplate: UriTemplate, metadata: ResourceMetadata, listCallback: ListResourcesCallback, - readCallback: ReadResourceCallback, + readCallback: ReadResourceTemplateCallback, ): void; resource( @@ -673,8 +677,8 @@ export class Server< listCallback = rest.shift() as ListResourcesCallback; } - const readCallback = rest[0] as ReadResourceCallback; if (typeof uriOrTemplate === "string") { + const readCallback = rest[0] as ReadResourceCallback; this.registerResource({ name, uri: uriOrTemplate, @@ -682,6 +686,7 @@ export class Server< readCallback, }); } else { + const readCallback = rest[0] as ReadResourceTemplateCallback; this.registerResourceTemplate({ name, uriTemplate: uriOrTemplate, @@ -727,7 +732,7 @@ export class Server< uriTemplate: UriTemplate; metadata?: ResourceMetadata; listCallback?: ListResourcesCallback; - readCallback: ReadResourceCallback; + readCallback: ReadResourceTemplateCallback; }): void { if (this._registeredResourceTemplates[name]) { throw new Error(`Resource template ${name} is already registered`); From 8356cb900c3843b2f1fc2b3d0a598bef6060b8c7 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 8 Jan 2025 13:44:12 +0000 Subject: [PATCH 18/31] Create `ResourceTemplate` class and move `listCallback` into it --- src/server/index.test.ts | 42 +++++++++++++--- src/server/index.ts | 101 ++++++++++++++++++++------------------- 2 files changed, 88 insertions(+), 55 deletions(-) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 99a60e7..4ec0f30 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -25,6 +25,7 @@ import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; import { Client } from "../client/index.js"; import { UriTemplate } from "../shared/uriTemplate.js"; +import { ResourceTemplate } from "./index.js"; test("should accept latest protocol version", async () => { let sendPromiseResolve: (value: unknown) => void; @@ -553,6 +554,34 @@ test("should handle request timeout", async () => { }); }); +describe("ResourceTemplate", () => { + test("should create ResourceTemplate with string pattern", () => { + const template = new ResourceTemplate("test://{category}/{id}", undefined); + expect(template.uriTemplate.toString()).toBe("test://{category}/{id}"); + expect(template.listCallback).toBeUndefined(); + }); + + test("should create ResourceTemplate with UriTemplate", () => { + const uriTemplate = new UriTemplate("test://{category}/{id}"); + const template = new ResourceTemplate(uriTemplate, undefined); + expect(template.uriTemplate).toBe(uriTemplate); + expect(template.listCallback).toBeUndefined(); + }); + + test("should create ResourceTemplate with list callback", async () => { + const listCallback = jest.fn().mockResolvedValue({ + resources: [{ name: "Test", uri: "test://example" }], + }); + + const template = new ResourceTemplate("test://{id}", listCallback); + expect(template.listCallback).toBe(listCallback); + + const result = await template.listCallback?.(); + expect(result?.resources).toHaveLength(1); + expect(listCallback).toHaveBeenCalled(); + }); +}); + describe("Server.tool", () => { test("should register zero-argument tool", async () => { const server = new Server({ @@ -1032,7 +1061,7 @@ describe("Server.resource", () => { server.resource( "test", - new UriTemplate("test://resource/{id}"), + new ResourceTemplate("test://resource/{id}", undefined), async () => ({ contents: [ { @@ -1077,8 +1106,7 @@ describe("Server.resource", () => { server.resource( "test", - new UriTemplate("test://resource/{id}"), - async () => ({ + new ResourceTemplate("test://resource/{id}", async () => ({ resources: [ { name: "Resource 1", @@ -1089,7 +1117,7 @@ describe("Server.resource", () => { uri: "test://resource/2", }, ], - }), + })), async (uri) => ({ contents: [ { @@ -1134,7 +1162,7 @@ describe("Server.resource", () => { server.resource( "test", - new UriTemplate("test://resource/{category}/{id}"), + new ResourceTemplate("test://resource/{category}/{id}", undefined), async (uri, { category, id }) => ({ contents: [ { @@ -1201,7 +1229,7 @@ describe("Server.resource", () => { server.resource( "test", - new UriTemplate("test://resource/{id}"), + new ResourceTemplate("test://resource/{id}", undefined), async () => ({ contents: [ { @@ -1215,7 +1243,7 @@ describe("Server.resource", () => { expect(() => { server.resource( "test", - new UriTemplate("test://resource/{id}"), + new ResourceTemplate("test://resource/{id}", undefined), async () => ({ contents: [ { diff --git a/src/server/index.ts b/src/server/index.ts index a0cc2ed..9f51f91 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -98,9 +98,8 @@ export type ReadResourceTemplateCallback = ( ) => ReadResourceResult | Promise; type RegisteredResourceTemplate = { - uriTemplate: UriTemplate; + resourceTemplate: ResourceTemplate; metadata?: ResourceMetadata; - listCallback?: ListResourcesCallback; readCallback: ReadResourceTemplateCallback; }; @@ -558,11 +557,11 @@ export class Server< const templateResources: Resource[] = []; for (const template of Object.values(this._registeredResourceTemplates)) { - if (!template.listCallback) { + if (!template.resourceTemplate.listCallback) { continue; } - const result = await template.listCallback(); + const result = await template.resourceTemplate.listCallback(); for (const resource of result.resources) { templateResources.push({ ...resource, @@ -579,7 +578,7 @@ export class Server< this._registeredResourceTemplates, ).map(([name, template]) => ({ name, - uriTemplate: template.uriTemplate.toString(), + uriTemplate: template.resourceTemplate.uriTemplate.toString(), ...template.metadata, })); @@ -597,7 +596,9 @@ export class Server< // Then check templates for (const template of Object.values(this._registeredResourceTemplates)) { - const variables = template.uriTemplate.match(uri.toString()); + const variables = template.resourceTemplate.uriTemplate.match( + uri.toString(), + ); if (variables) { return template.readCallback(uri, variables); } @@ -623,48 +624,27 @@ export class Server< ): void; /** - * Registers a resource `name` with a URI template pattern, which will use the given callback to respond to read requests. + * Registers a resource `name` with a template pattern, which will use the given callback to respond to read requests. */ resource( name: string, - uriTemplate: UriTemplate, + template: ResourceTemplate, readCallback: ReadResourceTemplateCallback, ): void; /** - * Registers a resource `name` with a URI template pattern and metadata, which will use the given callback to respond to read requests. + * Registers a resource `name` with a template pattern and metadata, which will use the given callback to respond to read requests. */ resource( name: string, - uriTemplate: UriTemplate, + template: ResourceTemplate, metadata: ResourceMetadata, readCallback: ReadResourceTemplateCallback, ): void; - /** - * Registers a resource `name` with a URI template pattern, which will use the list callback to enumerate matching resources and read callback to respond to read requests. - */ - resource( - name: string, - uriTemplate: UriTemplate, - listCallback: ListResourcesCallback, - readCallback: ReadResourceTemplateCallback, - ): void; - - /** - * Registers a resource `name` with a URI template pattern and metadata, which will use the list callback to enumerate matching resources and read callback to respond to read requests. - */ resource( name: string, - uriTemplate: UriTemplate, - metadata: ResourceMetadata, - listCallback: ListResourcesCallback, - readCallback: ReadResourceTemplateCallback, - ): void; - - resource( - name: string, - uriOrTemplate: string | UriTemplate, + uriOrTemplate: string | ResourceTemplate, ...rest: unknown[] ): void { let metadata: ResourceMetadata | undefined; @@ -672,27 +652,23 @@ export class Server< metadata = rest.shift() as ResourceMetadata; } - let listCallback: ListResourcesCallback | undefined; - if (rest.length > 1) { - listCallback = rest.shift() as ListResourcesCallback; - } + const readCallback = rest[0] as + | ReadResourceCallback + | ReadResourceTemplateCallback; if (typeof uriOrTemplate === "string") { - const readCallback = rest[0] as ReadResourceCallback; this.registerResource({ name, uri: uriOrTemplate, metadata, - readCallback, + readCallback: readCallback as ReadResourceCallback, }); } else { - const readCallback = rest[0] as ReadResourceTemplateCallback; this.registerResourceTemplate({ name, - uriTemplate: uriOrTemplate, + resourceTemplate: uriOrTemplate, metadata, - listCallback, - readCallback, + readCallback: readCallback as ReadResourceTemplateCallback, }); } } @@ -723,15 +699,13 @@ export class Server< private registerResourceTemplate({ name, - uriTemplate, + resourceTemplate, metadata, - listCallback, readCallback, }: { name: string; - uriTemplate: UriTemplate; + resourceTemplate: ResourceTemplate; metadata?: ResourceMetadata; - listCallback?: ListResourcesCallback; readCallback: ReadResourceTemplateCallback; }): void { if (this._registeredResourceTemplates[name]) { @@ -739,12 +713,43 @@ export class Server< } this._registeredResourceTemplates[name] = { - uriTemplate, + resourceTemplate, metadata, - listCallback, readCallback, }; this.setResourceRequestHandlers(); } } + +/** + * A resource template combines a URI pattern with optional functionality to enumerate + * all resources matching that pattern. + */ +export class ResourceTemplate { + private _uriTemplate: UriTemplate; + + constructor( + uriTemplate: string | UriTemplate, + private _listCallback: ListResourcesCallback | undefined, + ) { + this._uriTemplate = + typeof uriTemplate === "string" + ? new UriTemplate(uriTemplate) + : uriTemplate; + } + + /** + * Gets the URI template pattern. + */ + get uriTemplate(): UriTemplate { + return this._uriTemplate; + } + + /** + * Gets the list callback, if one was provided. + */ + get listCallback(): ListResourcesCallback | undefined { + return this._listCallback; + } +} From f090c931274ed0383cf950b6004535f941fce945 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 8 Jan 2025 13:47:38 +0000 Subject: [PATCH 19/31] Move `Server` helper definitions to bottom of file --- src/server/index.ts | 112 ++++++++++++++++++++++++-------------------- 1 file changed, 62 insertions(+), 50 deletions(-) diff --git a/src/server/index.ts b/src/server/index.ts index 9f51f91..d926b21 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -53,56 +53,6 @@ export type ServerOptions = ProtocolOptions & { capabilities?: ServerCapabilities; }; -/** - * Callback for a tool handler registered with Server.tool(). - * - * Parameters will include tool arguments, if applicable, as well as other request handler context. - */ -export type ToolCallback = - Args extends ZodRawShape - ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra, - ) => CallToolResult | Promise - : (extra: RequestHandlerExtra) => CallToolResult | Promise; - -type RegisteredTool = { - description?: string; - inputSchema?: AnyZodObject; - callback: ToolCallback; -}; - -const EMPTY_OBJECT_JSON_SCHEMA = { - type: "object" as const, -}; - -export type ResourceMetadata = Omit; - -export type ListResourcesCallback = () => - | ListResourcesResult - | Promise; - -export type ReadResourceCallback = ( - uri: URL, -) => ReadResourceResult | Promise; - -type RegisteredResource = { - name: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceCallback; -}; - -export type ReadResourceTemplateCallback = ( - uri: URL, - variables: Variables, -) => ReadResourceResult | Promise; - -type RegisteredResourceTemplate = { - resourceTemplate: ResourceTemplate; - metadata?: ResourceMetadata; - readCallback: ReadResourceTemplateCallback; -}; - /** * An MCP server on top of a pluggable transport. * @@ -753,3 +703,65 @@ export class ResourceTemplate { return this._listCallback; } } + +/** + * Callback for a tool handler registered with Server.tool(). + * + * Parameters will include tool arguments, if applicable, as well as other request handler context. + */ +export type ToolCallback = + Args extends ZodRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => CallToolResult | Promise + : (extra: RequestHandlerExtra) => CallToolResult | Promise; + +type RegisteredTool = { + description?: string; + inputSchema?: AnyZodObject; + callback: ToolCallback; +}; + +const EMPTY_OBJECT_JSON_SCHEMA = { + type: "object" as const, +}; + +/** + * Additional, optional information for annotating a resource. + */ +export type ResourceMetadata = Omit; + +/** + * Callback to list all resources matching a given template. + */ +export type ListResourcesCallback = () => + | ListResourcesResult + | Promise; + +/** + * Callback to read a resource at a given URI. + */ +export type ReadResourceCallback = ( + uri: URL, +) => ReadResourceResult | Promise; + +type RegisteredResource = { + name: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; +}; + +/** + * Callback to read a resource at a given URI, following a filled-in URI template. + */ +export type ReadResourceTemplateCallback = ( + uri: URL, + variables: Variables, +) => ReadResourceResult | Promise; + +type RegisteredResourceTemplate = { + resourceTemplate: ResourceTemplate; + metadata?: ResourceMetadata; + readCallback: ReadResourceTemplateCallback; +}; From 5ddfa0a0d93861566da197bab739632e7fa85a93 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 8 Jan 2025 14:01:28 +0000 Subject: [PATCH 20/31] Add `extra` to resource request handlers --- src/server/index.test.ts | 5 +- src/server/index.ts | 101 ++++++++++++++++++++++----------------- 2 files changed, 62 insertions(+), 44 deletions(-) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 4ec0f30..b862600 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -576,7 +576,10 @@ describe("ResourceTemplate", () => { const template = new ResourceTemplate("test://{id}", listCallback); expect(template.listCallback).toBe(listCallback); - const result = await template.listCallback?.(); + const abortController = new AbortController(); + const result = await template.listCallback?.({ + signal: abortController.signal, + }); expect(result?.resources).toHaveLength(1); expect(listCallback).toHaveBeenCalled(); }); diff --git a/src/server/index.ts b/src/server/index.ts index d926b21..bc5ea61 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -496,32 +496,37 @@ export class Server< resources: {}, }); - this.setRequestHandler(ListResourcesRequestSchema, async () => { - const resources = Object.entries(this._registeredResources).map( - ([uri, resource]) => ({ - uri, - name: resource.name, - ...resource.metadata, - }), - ); + this.setRequestHandler( + ListResourcesRequestSchema, + async (request, extra) => { + const resources = Object.entries(this._registeredResources).map( + ([uri, resource]) => ({ + uri, + name: resource.name, + ...resource.metadata, + }), + ); - const templateResources: Resource[] = []; - for (const template of Object.values(this._registeredResourceTemplates)) { - if (!template.resourceTemplate.listCallback) { - continue; - } + const templateResources: Resource[] = []; + for (const template of Object.values( + this._registeredResourceTemplates, + )) { + if (!template.resourceTemplate.listCallback) { + continue; + } - const result = await template.resourceTemplate.listCallback(); - for (const resource of result.resources) { - templateResources.push({ - ...resource, - ...template.metadata, - }); + const result = await template.resourceTemplate.listCallback(extra); + for (const resource of result.resources) { + templateResources.push({ + ...resource, + ...template.metadata, + }); + } } - } - return { resources: [...resources, ...templateResources] }; - }); + return { resources: [...resources, ...templateResources] }; + }, + ); this.setRequestHandler(ListResourceTemplatesRequestSchema, async () => { const resourceTemplates = Object.entries( @@ -535,27 +540,35 @@ export class Server< return { resourceTemplates }; }); - this.setRequestHandler(ReadResourceRequestSchema, async (request) => { - const uri = new URL(request.params.uri); - - // First check for exact resource match - const resource = this._registeredResources[uri.toString()]; - if (resource) { - return resource.readCallback(uri); - } + this.setRequestHandler( + ReadResourceRequestSchema, + async (request, extra) => { + const uri = new URL(request.params.uri); + + // First check for exact resource match + const resource = this._registeredResources[uri.toString()]; + if (resource) { + return resource.readCallback(uri, extra); + } - // Then check templates - for (const template of Object.values(this._registeredResourceTemplates)) { - const variables = template.resourceTemplate.uriTemplate.match( - uri.toString(), - ); - if (variables) { - return template.readCallback(uri, variables); + // Then check templates + for (const template of Object.values( + this._registeredResourceTemplates, + )) { + const variables = template.resourceTemplate.uriTemplate.match( + uri.toString(), + ); + if (variables) { + return template.readCallback(uri, variables, extra); + } } - } - throw new McpError(ErrorCode.InvalidParams, `Resource ${uri} not found`); - }); + throw new McpError( + ErrorCode.InvalidParams, + `Resource ${uri} not found`, + ); + }, + ); } /** @@ -735,15 +748,16 @@ export type ResourceMetadata = Omit; /** * Callback to list all resources matching a given template. */ -export type ListResourcesCallback = () => - | ListResourcesResult - | Promise; +export type ListResourcesCallback = ( + extra: RequestHandlerExtra, +) => ListResourcesResult | Promise; /** * Callback to read a resource at a given URI. */ export type ReadResourceCallback = ( uri: URL, + extra: RequestHandlerExtra, ) => ReadResourceResult | Promise; type RegisteredResource = { @@ -758,6 +772,7 @@ type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, + extra: RequestHandlerExtra, ) => ReadResourceResult | Promise; type RegisteredResourceTemplate = { From 9b95fa8203da0f3c60240faff1ff24dfca88f290 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 8 Jan 2025 14:20:22 +0000 Subject: [PATCH 21/31] Server.prompt() convenience API --- src/server/index.test.ts | 286 +++++++++++++++++++++++++++++++++++++++ src/server/index.ts | 149 +++++++++++++++++++- 2 files changed, 433 insertions(+), 2 deletions(-) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index b862600..89bb578 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -20,6 +20,8 @@ import { ListResourcesResultSchema, ListResourceTemplatesResultSchema, ReadResourceResultSchema, + ListPromptsResultSchema, + GetPromptResultSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; @@ -1334,3 +1336,287 @@ describe("Server.resource", () => { ).rejects.toThrow(/Resource test:\/\/nonexistent not found/); }); }); + +describe("Server.prompt", () => { + test("should register zero-argument prompt", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.prompt("test", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/list", + }, + ListPromptsResultSchema, + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe("test"); + expect(result.prompts[0].arguments).toBeUndefined(); + }); + + test("should register prompt with args schema", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.prompt( + "test", + { + name: z.string(), + value: z.string(), + }, + async ({ name, value }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `${name}: ${value}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/list", + }, + ListPromptsResultSchema, + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe("test"); + expect(result.prompts[0].arguments).toEqual([ + { name: "name", required: true }, + { name: "value", required: true }, + ]); + }); + + test("should register prompt with description", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + server.prompt("test", "Test description", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/list", + }, + ListPromptsResultSchema, + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe("test"); + expect(result.prompts[0].description).toBe("Test description"); + }); + + test("should validate prompt args", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + server.prompt( + "test", + { + name: z.string(), + value: z.string().min(3), + }, + async ({ name, value }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `${name}: ${value}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "prompts/get", + params: { + name: "test", + arguments: { + name: "test", + value: "ab", // Too short + }, + }, + }, + GetPromptResultSchema, + ), + ).rejects.toThrow(/Invalid arguments/); + }); + + test("should prevent duplicate prompt registration", () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + server.prompt("test", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + expect(() => { + server.prompt("test", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response 2", + }, + }, + ], + })); + }).toThrow(/already registered/); + }); + + test("should throw McpError for invalid prompt name", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + server.prompt("test-prompt", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "prompts/get", + params: { + name: "nonexistent-prompt", + }, + }, + GetPromptResultSchema, + ), + ).rejects.toThrow(/Prompt nonexistent-prompt not found/); + }); +}); diff --git a/src/server/index.ts b/src/server/index.ts index bc5ea61..c1efdd8 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,4 +1,11 @@ -import z, { AnyZodObject, ZodRawShape, ZodTypeAny } from "zod"; +import z, { + AnyZodObject, + ZodObject, + ZodOptional, + ZodRawShape, + ZodString, + ZodTypeAny, +} from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { mergeCapabilities, @@ -15,12 +22,16 @@ import { CreateMessageResultSchema, EmptyResultSchema, ErrorCode, + GetPromptRequestSchema, + GetPromptResult, Implementation, InitializedNotificationSchema, InitializeRequest, InitializeRequestSchema, InitializeResult, LATEST_PROTOCOL_VERSION, + ListPromptsRequestSchema, + ListPromptsResult, ListResourcesRequestSchema, ListResourcesResult, ListResourceTemplatesRequestSchema, @@ -31,6 +42,8 @@ import { LoggingMessageNotification, McpError, Notification, + Prompt, + PromptArgument, ReadResourceRequestSchema, ReadResourceResult, Request, @@ -90,11 +103,12 @@ export class Server< private _clientCapabilities?: ClientCapabilities; private _clientVersion?: Implementation; private _capabilities: ServerCapabilities; - private _registeredTools: { [name: string]: RegisteredTool } = {}; private _registeredResources: { [uri: string]: RegisteredResource } = {}; private _registeredResourceTemplates: { [name: string]: RegisteredResourceTemplate; } = {}; + private _registeredTools: { [name: string]: RegisteredTool } = {}; + private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; /** * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). @@ -683,6 +697,106 @@ export class Server< this.setResourceRequestHandlers(); } + + private setPromptRequestHandlers() { + this.assertCanSetRequestHandler( + ListPromptsRequestSchema.shape.method.value, + ); + this.assertCanSetRequestHandler(GetPromptRequestSchema.shape.method.value); + + this.registerCapabilities({ + prompts: {}, + }); + + this.setRequestHandler( + ListPromptsRequestSchema, + (): ListPromptsResult => ({ + prompts: Object.entries(this._registeredPrompts).map( + ([name, prompt]): Prompt => { + return { + name, + description: prompt.description, + arguments: prompt.argsSchema + ? promptArgumentsFromSchema(prompt.argsSchema) + : undefined, + }; + }, + ), + }), + ); + + this.setRequestHandler( + GetPromptRequestSchema, + async (request, extra): Promise => { + const prompt = this._registeredPrompts[request.params.name]; + if (!prompt) { + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${request.params.name} not found`, + ); + } + + if (prompt.argsSchema) { + const parseResult = await prompt.argsSchema.safeParseAsync( + request.params.arguments, + ); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for prompt ${request.params.name}: ${parseResult.error.message}`, + ); + } + + const args = parseResult.data; + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(args, extra)); + } else { + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(extra)); + } + }, + ); + } + + prompt(name: string, cb: PromptCallback): void; + prompt(name: string, description: string, cb: PromptCallback): void; + prompt( + name: string, + argsSchema: Args, + cb: PromptCallback, + ): void; + + prompt( + name: string, + description: string, + argsSchema: Args, + cb: PromptCallback, + ): void; + + prompt(name: string, ...rest: unknown[]): void { + if (this._registeredPrompts[name]) { + throw new Error(`Prompt ${name} is already registered`); + } + + let description: string | undefined; + if (typeof rest[0] === "string") { + description = rest.shift() as string; + } + + let argsSchema: PromptArgsRawShape | undefined; + if (rest.length > 1) { + argsSchema = rest.shift() as PromptArgsRawShape; + } + + const cb = rest[0] as PromptCallback; + this._registeredPrompts[name] = { + description, + argsSchema: argsSchema === undefined ? undefined : z.object(argsSchema), + callback: cb, + }; + + this.setPromptRequestHandlers(); + } } /** @@ -780,3 +894,34 @@ type RegisteredResourceTemplate = { metadata?: ResourceMetadata; readCallback: ReadResourceTemplateCallback; }; + +type PromptArgsRawShape = { + [k: string]: ZodString | ZodOptional; +}; + +export type PromptCallback< + Args extends undefined | PromptArgsRawShape = undefined, +> = Args extends PromptArgsRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => GetPromptResult | Promise + : (extra: RequestHandlerExtra) => GetPromptResult | Promise; + +type RegisteredPrompt = { + description?: string; + argsSchema?: ZodObject; + callback: PromptCallback; +}; + +function promptArgumentsFromSchema( + schema: ZodObject, +): PromptArgument[] { + return Object.entries(schema.shape).map( + ([name, field]): PromptArgument => ({ + name, + description: field.description, + required: !field.isOptional(), + }), + ); +} From 02fa78728cc5ae97c6a849cde944cb74fb281ded Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 8 Jan 2025 15:00:33 +0000 Subject: [PATCH 22/31] Add `completable` wrapper for Zod schemas --- src/server/completable.test.ts | 46 ++++++++++++++++ src/server/completable.ts | 95 ++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 src/server/completable.test.ts create mode 100644 src/server/completable.ts diff --git a/src/server/completable.test.ts b/src/server/completable.test.ts new file mode 100644 index 0000000..6040ff3 --- /dev/null +++ b/src/server/completable.test.ts @@ -0,0 +1,46 @@ +import { z } from "zod"; +import { completable } from "./completable.js"; + +describe("completable", () => { + it("preserves types and values of underlying schema", () => { + const baseSchema = z.string(); + const schema = completable(baseSchema, () => []); + + expect(schema.parse("test")).toBe("test"); + expect(() => schema.parse(123)).toThrow(); + }); + + it("provides access to completion function", async () => { + const completions = ["foo", "bar", "baz"]; + const schema = completable(z.string(), () => completions); + + expect(await schema._def.complete("")).toEqual(completions); + }); + + it("allows async completion functions", async () => { + const completions = ["foo", "bar", "baz"]; + const schema = completable(z.string(), async () => completions); + + expect(await schema._def.complete("")).toEqual(completions); + }); + + it("passes current value to completion function", async () => { + const schema = completable(z.string(), (value) => [value + "!"]); + + expect(await schema._def.complete("test")).toEqual(["test!"]); + }); + + it("works with number schemas", async () => { + const schema = completable(z.number(), () => [1, 2, 3]); + + expect(schema.parse(1)).toBe(1); + expect(await schema._def.complete(0)).toEqual([1, 2, 3]); + }); + + it("preserves schema description", () => { + const desc = "test description"; + const schema = completable(z.string().describe(desc), () => []); + + expect(schema.description).toBe(desc); + }); +}); diff --git a/src/server/completable.ts b/src/server/completable.ts new file mode 100644 index 0000000..3b5bc16 --- /dev/null +++ b/src/server/completable.ts @@ -0,0 +1,95 @@ +import { + ZodTypeAny, + ZodTypeDef, + ZodType, + ParseInput, + ParseReturnType, + RawCreateParams, + ZodErrorMap, + ProcessedCreateParams, +} from "zod"; + +export enum McpZodTypeKind { + Completable = "McpCompletable", +} + +export type CompleteCallback = ( + value: T["_input"], +) => T["_input"][] | Promise; + +export interface CompletableDef + extends ZodTypeDef { + type: T; + complete: CompleteCallback; + typeName: McpZodTypeKind.Completable; +} + +export class Completable extends ZodType< + T["_output"], + CompletableDef, + T["_input"] +> { + _parse(input: ParseInput): ParseReturnType { + const { ctx } = this._processInputParams(input); + const data = ctx.data; + return this._def.type._parse({ + data, + path: ctx.path, + parent: ctx, + }); + } + + unwrap() { + return this._def.type; + } + + static create = ( + type: T, + params: RawCreateParams & { + complete: CompleteCallback; + }, + ): Completable => { + return new Completable({ + type, + typeName: McpZodTypeKind.Completable, + complete: params.complete, + ...processCreateParams(params), + }); + }; +} + +/** + * Wraps a Zod type to provide autocompletion capabilities. Useful for, e.g., prompt arguments in MCP. + */ +export function completable( + schema: T, + complete: CompleteCallback, +): Completable { + return Completable.create(schema, { ...schema._def, complete }); +} + +// Not sure why this isn't exported from Zod: +// https://github.com/colinhacks/zod/blob/f7ad26147ba291cb3fb257545972a8e00e767470/src/types.ts#L130 +function processCreateParams(params: RawCreateParams): ProcessedCreateParams { + if (!params) return {}; + const { errorMap, invalid_type_error, required_error, description } = params; + if (errorMap && (invalid_type_error || required_error)) { + throw new Error( + `Can't use "invalid_type_error" or "required_error" in conjunction with custom error map.`, + ); + } + if (errorMap) return { errorMap: errorMap, description }; + const customMap: ZodErrorMap = (iss, ctx) => { + const { message } = params; + + if (iss.code === "invalid_enum_value") { + return { message: message ?? ctx.defaultError }; + } + if (typeof ctx.data === "undefined") { + return { message: message ?? required_error ?? ctx.defaultError }; + } + if (iss.code !== "invalid_type") return { message: ctx.defaultError }; + return { message: message ?? invalid_type_error ?? ctx.defaultError }; + }; + return { errorMap: customMap, description }; +} From 8635c6384f65f4229cb9c5ea8deb1c158d7d3beb Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 8 Jan 2025 15:19:16 +0000 Subject: [PATCH 23/31] Autocomplete on prompt arguments --- src/server/index.test.ts | 133 +++++++++++++++++++++++++++++++++++++++ src/server/index.ts | 61 +++++++++++++++++- 2 files changed, 192 insertions(+), 2 deletions(-) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 89bb578..cb686d6 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -22,12 +22,14 @@ import { ReadResourceResultSchema, ListPromptsResultSchema, GetPromptResultSchema, + CompleteResultSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; import { Client } from "../client/index.js"; import { UriTemplate } from "../shared/uriTemplate.js"; import { ResourceTemplate } from "./index.js"; +import { completable } from "./completable.js"; test("should accept latest protocol version", async () => { let sendPromiseResolve: (value: unknown) => void; @@ -1619,4 +1621,135 @@ describe("Server.prompt", () => { ), ).rejects.toThrow(/Prompt nonexistent-prompt not found/); }); + test("should support completion of prompt arguments", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + server.prompt( + "test-prompt", + { + name: completable(z.string(), () => ["Alice", "Bob", "Charlie"]), + }, + async ({ name }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `Hello ${name}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/prompt", + name: "test-prompt", + }, + argument: { + name: "name", + value: "", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["Alice", "Bob", "Charlie"]); + expect(result.completion.total).toBe(3); + }); + + test("should support filtered completion of prompt arguments", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + server.prompt( + "test-prompt", + { + name: completable(z.string(), (test) => + ["Alice", "Bob", "Charlie"].filter((value) => value.startsWith(test)), + ), + }, + async ({ name }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `Hello ${name}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/prompt", + name: "test-prompt", + }, + argument: { + name: "name", + value: "A", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["Alice"]); + expect(result.completion.total).toBe(1); + }); }); diff --git a/src/server/index.ts b/src/server/index.ts index c1efdd8..2816ee8 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -4,7 +4,9 @@ import z, { ZodOptional, ZodRawShape, ZodString, + ZodType, ZodTypeAny, + ZodTypeDef, } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { @@ -18,6 +20,8 @@ import { CallToolRequestSchema, CallToolResult, ClientCapabilities, + CompleteRequestSchema, + CompleteResult, CreateMessageRequest, CreateMessageResultSchema, EmptyResultSchema, @@ -58,6 +62,7 @@ import { Tool, } from "../types.js"; import { UriTemplate, Variables } from "../shared/uriTemplate.js"; +import { Completable, CompletableDef } from "./completable.js"; export type ServerOptions = ProtocolOptions & { /** @@ -703,6 +708,7 @@ export class Server< ListPromptsRequestSchema.shape.method.value, ); this.assertCanSetRequestHandler(GetPromptRequestSchema.shape.method.value); + this.assertCanSetRequestHandler(CompleteRequestSchema.shape.method.value); this.registerCapabilities({ prompts: {}, @@ -756,6 +762,56 @@ export class Server< } }, ); + + this.setRequestHandler( + CompleteRequestSchema, + async (request): Promise => { + if (request.params.ref.type !== "ref/prompt") { + throw new McpError( + ErrorCode.InvalidParams, + "Only prompt completions are supported", + ); + } + + const prompt = this._registeredPrompts[request.params.ref.name]; + if (!prompt) { + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${request.params.ref.name} not found`, + ); + } + + if (!prompt.argsSchema) { + return { + completion: { + values: [], + hasMore: false, + }, + }; + } + + const field = prompt.argsSchema.shape[request.params.argument.name]; + if (!(field instanceof Completable)) { + return { + completion: { + values: [], + hasMore: false, + }, + }; + } + + const def: CompletableDef = field._def; + const completer = def.complete; + const suggestions = await completer(request.params.argument.value); + return { + completion: { + values: suggestions.slice(0, 100), + total: suggestions.length, + hasMore: suggestions.length > 100, + }, + }; + }, + ); } prompt(name: string, cb: PromptCallback): void; @@ -765,7 +821,6 @@ export class Server< argsSchema: Args, cb: PromptCallback, ): void; - prompt( name: string, description: string, @@ -896,7 +951,9 @@ type RegisteredResourceTemplate = { }; type PromptArgsRawShape = { - [k: string]: ZodString | ZodOptional; + [k: string]: + | ZodType + | ZodOptional>; }; export type PromptCallback< From 189cc84113ec5ce9cb132a8d24f808ec4be3fa92 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 8 Jan 2025 15:46:43 +0000 Subject: [PATCH 24/31] Autocomplete on resource template variables --- src/server/index.test.ts | 183 +++++++++++++++++++++++++++++++++----- src/server/index.ts | 185 ++++++++++++++++++++++++++++----------- 2 files changed, 294 insertions(+), 74 deletions(-) diff --git a/src/server/index.test.ts b/src/server/index.test.ts index cb686d6..2957b87 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -560,32 +560,34 @@ test("should handle request timeout", async () => { describe("ResourceTemplate", () => { test("should create ResourceTemplate with string pattern", () => { - const template = new ResourceTemplate("test://{category}/{id}", undefined); + const template = new ResourceTemplate("test://{category}/{id}", { + list: undefined, + }); expect(template.uriTemplate.toString()).toBe("test://{category}/{id}"); expect(template.listCallback).toBeUndefined(); }); test("should create ResourceTemplate with UriTemplate", () => { const uriTemplate = new UriTemplate("test://{category}/{id}"); - const template = new ResourceTemplate(uriTemplate, undefined); + const template = new ResourceTemplate(uriTemplate, { list: undefined }); expect(template.uriTemplate).toBe(uriTemplate); expect(template.listCallback).toBeUndefined(); }); test("should create ResourceTemplate with list callback", async () => { - const listCallback = jest.fn().mockResolvedValue({ + const list = jest.fn().mockResolvedValue({ resources: [{ name: "Test", uri: "test://example" }], }); - const template = new ResourceTemplate("test://{id}", listCallback); - expect(template.listCallback).toBe(listCallback); + const template = new ResourceTemplate("test://{id}", { list }); + expect(template.listCallback).toBe(list); const abortController = new AbortController(); const result = await template.listCallback?.({ signal: abortController.signal, }); expect(result?.resources).toHaveLength(1); - expect(listCallback).toHaveBeenCalled(); + expect(list).toHaveBeenCalled(); }); }); @@ -1068,7 +1070,7 @@ describe("Server.resource", () => { server.resource( "test", - new ResourceTemplate("test://resource/{id}", undefined), + new ResourceTemplate("test://resource/{id}", { list: undefined }), async () => ({ contents: [ { @@ -1113,18 +1115,20 @@ describe("Server.resource", () => { server.resource( "test", - new ResourceTemplate("test://resource/{id}", async () => ({ - resources: [ - { - name: "Resource 1", - uri: "test://resource/1", - }, - { - name: "Resource 2", - uri: "test://resource/2", - }, - ], - })), + new ResourceTemplate("test://resource/{id}", { + list: async () => ({ + resources: [ + { + name: "Resource 1", + uri: "test://resource/1", + }, + { + name: "Resource 2", + uri: "test://resource/2", + }, + ], + }), + }), async (uri) => ({ contents: [ { @@ -1169,7 +1173,9 @@ describe("Server.resource", () => { server.resource( "test", - new ResourceTemplate("test://resource/{category}/{id}", undefined), + new ResourceTemplate("test://resource/{category}/{id}", { + list: undefined, + }), async (uri, { category, id }) => ({ contents: [ { @@ -1236,7 +1242,7 @@ describe("Server.resource", () => { server.resource( "test", - new ResourceTemplate("test://resource/{id}", undefined), + new ResourceTemplate("test://resource/{id}", { list: undefined }), async () => ({ contents: [ { @@ -1250,7 +1256,7 @@ describe("Server.resource", () => { expect(() => { server.resource( "test", - new ResourceTemplate("test://resource/{id}", undefined), + new ResourceTemplate("test://resource/{id}", { list: undefined }), async () => ({ contents: [ { @@ -1337,6 +1343,139 @@ describe("Server.resource", () => { ), ).rejects.toThrow(/Resource test:\/\/nonexistent not found/); }); + + test("should support completion of resource template parameters", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + resources: {}, + }, + }, + ); + + server.resource( + "test", + new ResourceTemplate("test://resource/{category}", { + list: undefined, + complete: { + category: () => ["books", "movies", "music"], + }, + }), + async () => ({ + contents: [ + { + uri: "test://resource/test", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/resource", + uri: "test://resource/{category}", + }, + argument: { + name: "category", + value: "", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["books", "movies", "music"]); + expect(result.completion.total).toBe(3); + }); + + test("should support filtered completion of resource template parameters", async () => { + const server = new Server({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + resources: {}, + }, + }, + ); + + server.resource( + "test", + new ResourceTemplate("test://resource/{category}", { + list: undefined, + complete: { + category: (test) => + ["books", "movies", "music"].filter((value) => + value.startsWith(test), + ), + }, + }), + async () => ({ + contents: [ + { + uri: "test://resource/test", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/resource", + uri: "test://resource/{category}", + }, + argument: { + name: "category", + value: "m", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["movies", "music"]); + expect(result.completion.total).toBe(2); + }); }); describe("Server.prompt", () => { diff --git a/src/server/index.ts b/src/server/index.ts index 2816ee8..1810883 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -20,6 +20,7 @@ import { CallToolRequestSchema, CallToolResult, ClientCapabilities, + CompleteRequest, CompleteRequestSchema, CompleteResult, CreateMessageRequest, @@ -48,10 +49,12 @@ import { Notification, Prompt, PromptArgument, + PromptReference, ReadResourceRequestSchema, ReadResourceResult, Request, Resource, + ResourceReference, ResourceUpdatedNotification, Result, ServerCapabilities, @@ -500,6 +503,86 @@ export class Server< this.setToolRequestHandlers(); } + private setCompletionRequestHandler() { + this.assertCanSetRequestHandler(CompleteRequestSchema.shape.method.value); + + this.setRequestHandler( + CompleteRequestSchema, + async (request): Promise => { + switch (request.params.ref.type) { + case "ref/prompt": + return this.handlePromptCompletion(request, request.params.ref); + + case "ref/resource": + return this.handleResourceCompletion(request, request.params.ref); + + default: + throw new McpError( + ErrorCode.InvalidParams, + `Invalid completion reference: ${request.params.ref}`, + ); + } + }, + ); + } + + private async handlePromptCompletion( + request: CompleteRequest, + ref: PromptReference, + ): Promise { + const prompt = this._registeredPrompts[ref.name]; + if (!prompt) { + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${request.params.ref.name} not found`, + ); + } + + if (!prompt.argsSchema) { + return EMPTY_COMPLETION_RESULT; + } + + const field = prompt.argsSchema.shape[request.params.argument.name]; + if (!(field instanceof Completable)) { + return EMPTY_COMPLETION_RESULT; + } + + const def: CompletableDef = field._def; + const suggestions = await def.complete(request.params.argument.value); + return createCompletionResult(suggestions); + } + + private async handleResourceCompletion( + request: CompleteRequest, + ref: ResourceReference, + ): Promise { + const template = Object.values(this._registeredResourceTemplates).find( + (t) => t.resourceTemplate.uriTemplate.toString() === ref.uri, + ); + + if (!template) { + if (this._registeredResources[ref.uri]) { + // Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; + } + + throw new McpError( + ErrorCode.InvalidParams, + `Resource template ${request.params.ref.uri} not found`, + ); + } + + const completer = template.resourceTemplate.completeCallback( + request.params.argument.name, + ); + if (!completer) { + return EMPTY_COMPLETION_RESULT; + } + + const suggestions = await completer(request.params.argument.value); + return createCompletionResult(suggestions); + } + private setResourceRequestHandlers() { this.assertCanSetRequestHandler( ListResourcesRequestSchema.shape.method.value, @@ -588,6 +671,8 @@ export class Server< ); }, ); + + this.setCompletionRequestHandler(); } /** @@ -708,7 +793,6 @@ export class Server< ListPromptsRequestSchema.shape.method.value, ); this.assertCanSetRequestHandler(GetPromptRequestSchema.shape.method.value); - this.assertCanSetRequestHandler(CompleteRequestSchema.shape.method.value); this.registerCapabilities({ prompts: {}, @@ -763,55 +847,7 @@ export class Server< }, ); - this.setRequestHandler( - CompleteRequestSchema, - async (request): Promise => { - if (request.params.ref.type !== "ref/prompt") { - throw new McpError( - ErrorCode.InvalidParams, - "Only prompt completions are supported", - ); - } - - const prompt = this._registeredPrompts[request.params.ref.name]; - if (!prompt) { - throw new McpError( - ErrorCode.InvalidParams, - `Prompt ${request.params.ref.name} not found`, - ); - } - - if (!prompt.argsSchema) { - return { - completion: { - values: [], - hasMore: false, - }, - }; - } - - const field = prompt.argsSchema.shape[request.params.argument.name]; - if (!(field instanceof Completable)) { - return { - completion: { - values: [], - hasMore: false, - }, - }; - } - - const def: CompletableDef = field._def; - const completer = def.complete; - const suggestions = await completer(request.params.argument.value); - return { - completion: { - values: suggestions.slice(0, 100), - total: suggestions.length, - hasMore: suggestions.length > 100, - }, - }; - }, - ); + this.setCompletionRequestHandler(); } prompt(name: string, cb: PromptCallback): void; @@ -854,6 +890,13 @@ export class Server< } } +/** + * A callback to complete one variable within a resource template's URI template. + */ +export type CompleteResourceTemplateCallback = ( + value: string, +) => string[] | Promise; + /** * A resource template combines a URI pattern with optional functionality to enumerate * all resources matching that pattern. @@ -863,7 +906,19 @@ export class ResourceTemplate { constructor( uriTemplate: string | UriTemplate, - private _listCallback: ListResourcesCallback | undefined, + private _callbacks: { + /** + * A callback to list all resources matching this template. This is required to specified, even if `undefined`, to avoid accidentally forgetting resource listing. + */ + list: ListResourcesCallback | undefined; + + /** + * An optional callback to autocomplete variables within the URI template. Useful for clients and users to discover possible values. + */ + complete?: { + [variable: string]: CompleteResourceTemplateCallback; + }; + }, ) { this._uriTemplate = typeof uriTemplate === "string" @@ -882,7 +937,16 @@ export class ResourceTemplate { * Gets the list callback, if one was provided. */ get listCallback(): ListResourcesCallback | undefined { - return this._listCallback; + return this._callbacks.list; + } + + /** + * Gets the callback for completing a specific URI template variable, if one was provided. + */ + completeCallback( + variable: string, + ): CompleteResourceTemplateCallback | undefined { + return this._callbacks.complete?.[variable]; } } @@ -982,3 +1046,20 @@ function promptArgumentsFromSchema( }), ); } + +function createCompletionResult(suggestions: string[]): CompleteResult { + return { + completion: { + values: suggestions.slice(0, 100), + total: suggestions.length, + hasMore: suggestions.length > 100, + }, + }; +} + +const EMPTY_COMPLETION_RESULT: CompleteResult = { + completion: { + values: [], + hasMore: false, + }, +}; From 2289dbc0d0e48314eeeed5dfc38ae606fd3b901d Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 8 Jan 2025 21:59:38 +0000 Subject: [PATCH 25/31] Factor out convenience APIs into another class --- src/server/index.test.ts | 1346 ------------------------------------ src/server/index.ts | 741 -------------------- src/server/mcp.test.ts | 1395 ++++++++++++++++++++++++++++++++++++++ src/server/mcp.ts | 767 +++++++++++++++++++++ src/shared/protocol.ts | 10 +- 5 files changed, 2169 insertions(+), 2090 deletions(-) create mode 100644 src/server/mcp.test.ts create mode 100644 src/server/mcp.ts diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 2957b87..2454339 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -15,21 +15,10 @@ import { ListToolsRequestSchema, SetLevelRequestSchema, ErrorCode, - ListToolsResultSchema, - CallToolResultSchema, - ListResourcesResultSchema, - ListResourceTemplatesResultSchema, - ReadResourceResultSchema, - ListPromptsResultSchema, - GetPromptResultSchema, - CompleteResultSchema, } from "../types.js"; import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; import { Client } from "../client/index.js"; -import { UriTemplate } from "../shared/uriTemplate.js"; -import { ResourceTemplate } from "./index.js"; -import { completable } from "./completable.js"; test("should accept latest protocol version", async () => { let sendPromiseResolve: (value: unknown) => void; @@ -557,1338 +546,3 @@ test("should handle request timeout", async () => { code: ErrorCode.RequestTimeout, }); }); - -describe("ResourceTemplate", () => { - test("should create ResourceTemplate with string pattern", () => { - const template = new ResourceTemplate("test://{category}/{id}", { - list: undefined, - }); - expect(template.uriTemplate.toString()).toBe("test://{category}/{id}"); - expect(template.listCallback).toBeUndefined(); - }); - - test("should create ResourceTemplate with UriTemplate", () => { - const uriTemplate = new UriTemplate("test://{category}/{id}"); - const template = new ResourceTemplate(uriTemplate, { list: undefined }); - expect(template.uriTemplate).toBe(uriTemplate); - expect(template.listCallback).toBeUndefined(); - }); - - test("should create ResourceTemplate with list callback", async () => { - const list = jest.fn().mockResolvedValue({ - resources: [{ name: "Test", uri: "test://example" }], - }); - - const template = new ResourceTemplate("test://{id}", { list }); - expect(template.listCallback).toBe(list); - - const abortController = new AbortController(); - const result = await template.listCallback?.({ - signal: abortController.signal, - }); - expect(result?.resources).toHaveLength(1); - expect(list).toHaveBeenCalled(); - }); -}); - -describe("Server.tool", () => { - test("should register zero-argument tool", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.tool("test", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/list", - }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(1); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].inputSchema).toEqual({ - type: "object", - }); - }); - - test("should register tool with args schema", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.tool( - "test", - { - name: z.string(), - value: z.number(), - }, - async ({ name, value }) => ({ - content: [ - { - type: "text", - text: `${name}: ${value}`, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/list", - }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(1); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].inputSchema).toMatchObject({ - type: "object", - properties: { - name: { type: "string" }, - value: { type: "number" }, - }, - }); - }); - - test("should register tool with description", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.tool("test", "Test description", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/list", - }, - ListToolsResultSchema, - ); - - expect(result.tools).toHaveLength(1); - expect(result.tools[0].name).toBe("test"); - expect(result.tools[0].description).toBe("Test description"); - }); - - test("should validate tool args", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); - - server.tool( - "test", - { - name: z.string(), - value: z.number(), - }, - async ({ name, value }) => ({ - content: [ - { - type: "text", - text: `${name}: ${value}`, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "tools/call", - params: { - name: "test", - arguments: { - name: "test", - value: "not a number", - }, - }, - }, - CallToolResultSchema, - ), - ).rejects.toThrow(/Invalid arguments/); - }); - - test("should prevent duplicate tool registration", () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - server.tool("test", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - expect(() => { - server.tool("test", async () => ({ - content: [ - { - type: "text", - text: "Test response 2", - }, - ], - })); - }).toThrow(/already registered/); - }); - - test("should allow client to call server tools", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); - - server.tool( - "test", - "Test tool", - { - input: z.string(), - }, - async ({ input }) => ({ - content: [ - { - type: "text", - text: `Processed: ${input}`, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/call", - params: { - name: "test", - arguments: { - input: "hello", - }, - }, - }, - CallToolResultSchema, - ); - - expect(result.content).toEqual([ - { - type: "text", - text: "Processed: hello", - }, - ]); - }); - - test("should handle server tool errors gracefully", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); - - server.tool("error-test", async () => { - throw new Error("Tool execution failed"); - }); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "tools/call", - params: { - name: "error-test", - }, - }, - CallToolResultSchema, - ); - - expect(result.isError).toBe(true); - expect(result.content).toEqual([ - { - type: "text", - text: "Tool execution failed", - }, - ]); - }); - - test("should throw McpError for invalid tool name", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); - - server.tool("test-tool", async () => ({ - content: [ - { - type: "text", - text: "Test response", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "tools/call", - params: { - name: "nonexistent-tool", - }, - }, - CallToolResultSchema, - ), - ).rejects.toThrow(/Tool nonexistent-tool not found/); - }); -}); - -describe("Server.resource", () => { - test("should register resource with uri and readCallback", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.resource("test", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/list", - }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(1); - expect(result.resources[0].name).toBe("test"); - expect(result.resources[0].uri).toBe("test://resource"); - }); - - test("should register resource with metadata", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.resource( - "test", - "test://resource", - { - description: "Test resource", - mimeType: "text/plain", - }, - async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/list", - }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(1); - expect(result.resources[0].description).toBe("Test resource"); - expect(result.resources[0].mimeType).toBe("text/plain"); - }); - - test("should register resource template", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.resource( - "test", - new ResourceTemplate("test://resource/{id}", { list: undefined }), - async () => ({ - contents: [ - { - uri: "test://resource/123", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/templates/list", - }, - ListResourceTemplatesResultSchema, - ); - - expect(result.resourceTemplates).toHaveLength(1); - expect(result.resourceTemplates[0].name).toBe("test"); - expect(result.resourceTemplates[0].uriTemplate).toBe( - "test://resource/{id}", - ); - }); - - test("should register resource template with listCallback", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.resource( - "test", - new ResourceTemplate("test://resource/{id}", { - list: async () => ({ - resources: [ - { - name: "Resource 1", - uri: "test://resource/1", - }, - { - name: "Resource 2", - uri: "test://resource/2", - }, - ], - }), - }), - async (uri) => ({ - contents: [ - { - uri: uri.href, - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/list", - }, - ListResourcesResultSchema, - ); - - expect(result.resources).toHaveLength(2); - expect(result.resources[0].name).toBe("Resource 1"); - expect(result.resources[0].uri).toBe("test://resource/1"); - expect(result.resources[1].name).toBe("Resource 2"); - expect(result.resources[1].uri).toBe("test://resource/2"); - }); - - test("should pass template variables to readCallback", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.resource( - "test", - new ResourceTemplate("test://resource/{category}/{id}", { - list: undefined, - }), - async (uri, { category, id }) => ({ - contents: [ - { - uri: uri.href, - text: `Category: ${category}, ID: ${id}`, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "resources/read", - params: { - uri: "test://resource/books/123", - }, - }, - ReadResourceResultSchema, - ); - - expect(result.contents[0].text).toBe("Category: books, ID: 123"); - }); - - test("should prevent duplicate resource registration", () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - server.resource("test", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content", - }, - ], - })); - - expect(() => { - server.resource("test2", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content 2", - }, - ], - })); - }).toThrow(/already registered/); - }); - - test("should prevent duplicate resource template registration", () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - server.resource( - "test", - new ResourceTemplate("test://resource/{id}", { list: undefined }), - async () => ({ - contents: [ - { - uri: "test://resource/123", - text: "Test content", - }, - ], - }), - ); - - expect(() => { - server.resource( - "test", - new ResourceTemplate("test://resource/{id}", { list: undefined }), - async () => ({ - contents: [ - { - uri: "test://resource/123", - text: "Test content 2", - }, - ], - }), - ); - }).toThrow(/already registered/); - }); - - test("should handle resource read errors gracefully", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.resource("error-test", "test://error", async () => { - throw new Error("Resource read failed"); - }); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "resources/read", - params: { - uri: "test://error", - }, - }, - ReadResourceResultSchema, - ), - ).rejects.toThrow(/Resource read failed/); - }); - - test("should throw McpError for invalid resource URI", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.resource("test", "test://resource", async () => ({ - contents: [ - { - uri: "test://resource", - text: "Test content", - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "resources/read", - params: { - uri: "test://nonexistent", - }, - }, - ReadResourceResultSchema, - ), - ).rejects.toThrow(/Resource test:\/\/nonexistent not found/); - }); - - test("should support completion of resource template parameters", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - resources: {}, - }, - }, - ); - - server.resource( - "test", - new ResourceTemplate("test://resource/{category}", { - list: undefined, - complete: { - category: () => ["books", "movies", "music"], - }, - }), - async () => ({ - contents: [ - { - uri: "test://resource/test", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/resource", - uri: "test://resource/{category}", - }, - argument: { - name: "category", - value: "", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result.completion.values).toEqual(["books", "movies", "music"]); - expect(result.completion.total).toBe(3); - }); - - test("should support filtered completion of resource template parameters", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - resources: {}, - }, - }, - ); - - server.resource( - "test", - new ResourceTemplate("test://resource/{category}", { - list: undefined, - complete: { - category: (test) => - ["books", "movies", "music"].filter((value) => - value.startsWith(test), - ), - }, - }), - async () => ({ - contents: [ - { - uri: "test://resource/test", - text: "Test content", - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/resource", - uri: "test://resource/{category}", - }, - argument: { - name: "category", - value: "m", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result.completion.values).toEqual(["movies", "music"]); - expect(result.completion.total).toBe(2); - }); -}); - -describe("Server.prompt", () => { - test("should register zero-argument prompt", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.prompt("test", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "prompts/list", - }, - ListPromptsResultSchema, - ); - - expect(result.prompts).toHaveLength(1); - expect(result.prompts[0].name).toBe("test"); - expect(result.prompts[0].arguments).toBeUndefined(); - }); - - test("should register prompt with args schema", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.prompt( - "test", - { - name: z.string(), - value: z.string(), - }, - async ({ name, value }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `${name}: ${value}`, - }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "prompts/list", - }, - ListPromptsResultSchema, - ); - - expect(result.prompts).toHaveLength(1); - expect(result.prompts[0].name).toBe("test"); - expect(result.prompts[0].arguments).toEqual([ - { name: "name", required: true }, - { name: "value", required: true }, - ]); - }); - - test("should register prompt with description", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - server.prompt("test", "Test description", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "prompts/list", - }, - ListPromptsResultSchema, - ); - - expect(result.prompts).toHaveLength(1); - expect(result.prompts[0].name).toBe("test"); - expect(result.prompts[0].description).toBe("Test description"); - }); - - test("should validate prompt args", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - }, - }, - ); - - server.prompt( - "test", - { - name: z.string(), - value: z.string().min(3), - }, - async ({ name, value }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `${name}: ${value}`, - }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "prompts/get", - params: { - name: "test", - arguments: { - name: "test", - value: "ab", // Too short - }, - }, - }, - GetPromptResultSchema, - ), - ).rejects.toThrow(/Invalid arguments/); - }); - - test("should prevent duplicate prompt registration", () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - server.prompt("test", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - }, - ], - })); - - expect(() => { - server.prompt("test", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response 2", - }, - }, - ], - })); - }).toThrow(/already registered/); - }); - - test("should throw McpError for invalid prompt name", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - }, - }, - ); - - server.prompt("test-prompt", async () => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: "Test response", - }, - }, - ], - })); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - await expect( - client.request( - { - method: "prompts/get", - params: { - name: "nonexistent-prompt", - }, - }, - GetPromptResultSchema, - ), - ).rejects.toThrow(/Prompt nonexistent-prompt not found/); - }); - test("should support completion of prompt arguments", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - }, - }, - ); - - server.prompt( - "test-prompt", - { - name: completable(z.string(), () => ["Alice", "Bob", "Charlie"]), - }, - async ({ name }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `Hello ${name}`, - }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/prompt", - name: "test-prompt", - }, - argument: { - name: "name", - value: "", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result.completion.values).toEqual(["Alice", "Bob", "Charlie"]); - expect(result.completion.total).toBe(3); - }); - - test("should support filtered completion of prompt arguments", async () => { - const server = new Server({ - name: "test server", - version: "1.0", - }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - }, - }, - ); - - server.prompt( - "test-prompt", - { - name: completable(z.string(), (test) => - ["Alice", "Bob", "Charlie"].filter((value) => value.startsWith(test)), - ), - }, - async ({ name }) => ({ - messages: [ - { - role: "assistant", - content: { - type: "text", - text: `Hello ${name}`, - }, - }, - ], - }), - ); - - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - - await Promise.all([ - client.connect(clientTransport), - server.connect(serverTransport), - ]); - - const result = await client.request( - { - method: "completion/complete", - params: { - ref: { - type: "ref/prompt", - name: "test-prompt", - }, - argument: { - name: "name", - value: "A", - }, - }, - }, - CompleteResultSchema, - ); - - expect(result.completion.values).toEqual(["Alice"]); - expect(result.completion.total).toBe(1); - }); -}); diff --git a/src/server/index.ts b/src/server/index.ts index 1810883..bf8cd6b 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,60 +1,25 @@ -import z, { - AnyZodObject, - ZodObject, - ZodOptional, - ZodRawShape, - ZodString, - ZodType, - ZodTypeAny, - ZodTypeDef, -} from "zod"; -import { zodToJsonSchema } from "zod-to-json-schema"; import { mergeCapabilities, Protocol, ProtocolOptions, - RequestHandlerExtra, RequestOptions, } from "../shared/protocol.js"; import { - CallToolRequestSchema, - CallToolResult, ClientCapabilities, - CompleteRequest, - CompleteRequestSchema, - CompleteResult, CreateMessageRequest, CreateMessageResultSchema, EmptyResultSchema, - ErrorCode, - GetPromptRequestSchema, - GetPromptResult, Implementation, InitializedNotificationSchema, InitializeRequest, InitializeRequestSchema, InitializeResult, LATEST_PROTOCOL_VERSION, - ListPromptsRequestSchema, - ListPromptsResult, - ListResourcesRequestSchema, - ListResourcesResult, - ListResourceTemplatesRequestSchema, ListRootsRequest, ListRootsResultSchema, - ListToolsRequestSchema, - ListToolsResult, LoggingMessageNotification, - McpError, Notification, - Prompt, - PromptArgument, - PromptReference, - ReadResourceRequestSchema, - ReadResourceResult, Request, - Resource, - ResourceReference, ResourceUpdatedNotification, Result, ServerCapabilities, @@ -62,10 +27,7 @@ import { ServerRequest, ServerResult, SUPPORTED_PROTOCOL_VERSIONS, - Tool, } from "../types.js"; -import { UriTemplate, Variables } from "../shared/uriTemplate.js"; -import { Completable, CompletableDef } from "./completable.js"; export type ServerOptions = ProtocolOptions & { /** @@ -111,12 +73,6 @@ export class Server< private _clientCapabilities?: ClientCapabilities; private _clientVersion?: Implementation; private _capabilities: ServerCapabilities; - private _registeredResources: { [uri: string]: RegisteredResource } = {}; - private _registeredResourceTemplates: { - [name: string]: RegisteredResourceTemplate; - } = {}; - private _registeredTools: { [name: string]: RegisteredTool } = {}; - private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; /** * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). @@ -365,701 +321,4 @@ export class Server< async sendPromptListChanged() { return this.notification({ method: "notifications/prompts/list_changed" }); } - - private setToolRequestHandlers() { - this.assertCanSetRequestHandler(ListToolsRequestSchema.shape.method.value); - this.assertCanSetRequestHandler(CallToolRequestSchema.shape.method.value); - - this.registerCapabilities({ - tools: {}, - }); - - this.setRequestHandler( - ListToolsRequestSchema, - (): ListToolsResult => ({ - tools: Object.entries(this._registeredTools).map( - ([name, tool]): Tool => { - return { - name, - description: tool.description, - inputSchema: tool.inputSchema - ? (zodToJsonSchema(tool.inputSchema) as Tool["inputSchema"]) - : EMPTY_OBJECT_JSON_SCHEMA, - }; - }, - ), - }), - ); - - this.setRequestHandler( - CallToolRequestSchema, - async (request, extra): Promise => { - const tool = this._registeredTools[request.params.name]; - if (!tool) { - throw new McpError( - ErrorCode.InvalidParams, - `Tool ${request.params.name} not found`, - ); - } - - if (tool.inputSchema) { - const parseResult = await tool.inputSchema.safeParseAsync( - request.params.arguments, - ); - if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`, - ); - } - - const args = parseResult.data; - const cb = tool.callback as ToolCallback; - try { - return await Promise.resolve(cb(args, extra)); - } catch (error) { - return { - content: [ - { - type: "text", - text: error instanceof Error ? error.message : String(error), - }, - ], - isError: true, - }; - } - } else { - const cb = tool.callback as ToolCallback; - try { - return await Promise.resolve(cb(extra)); - } catch (error) { - return { - content: [ - { - type: "text", - text: error instanceof Error ? error.message : String(error), - }, - ], - isError: true, - }; - } - } - }, - ); - } - - /** - * Registers a zero-argument tool `name`, which will run the given function when the client calls it. - */ - tool(name: string, cb: ToolCallback): void; - - /** - * Registers a zero-argument tool `name` (with a description) which will run the given function when the client calls it. - */ - tool(name: string, description: string, cb: ToolCallback): void; - - /** - * Registers a tool `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. - */ - tool( - name: string, - paramsSchema: Args, - cb: ToolCallback, - ): void; - - /** - * Registers a tool `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. - */ - tool( - name: string, - description: string, - paramsSchema: Args, - cb: ToolCallback, - ): void; - - tool(name: string, ...rest: unknown[]): void { - if (this._registeredTools[name]) { - throw new Error(`Tool ${name} is already registered`); - } - - let description: string | undefined; - if (typeof rest[0] === "string") { - description = rest.shift() as string; - } - - let paramsSchema: ZodRawShape | undefined; - if (rest.length > 1) { - paramsSchema = rest.shift() as ZodRawShape; - } - - const cb = rest[0] as ToolCallback; - this._registeredTools[name] = { - description, - inputSchema: - paramsSchema === undefined ? undefined : z.object(paramsSchema), - callback: cb, - }; - - this.setToolRequestHandlers(); - } - - private setCompletionRequestHandler() { - this.assertCanSetRequestHandler(CompleteRequestSchema.shape.method.value); - - this.setRequestHandler( - CompleteRequestSchema, - async (request): Promise => { - switch (request.params.ref.type) { - case "ref/prompt": - return this.handlePromptCompletion(request, request.params.ref); - - case "ref/resource": - return this.handleResourceCompletion(request, request.params.ref); - - default: - throw new McpError( - ErrorCode.InvalidParams, - `Invalid completion reference: ${request.params.ref}`, - ); - } - }, - ); - } - - private async handlePromptCompletion( - request: CompleteRequest, - ref: PromptReference, - ): Promise { - const prompt = this._registeredPrompts[ref.name]; - if (!prompt) { - throw new McpError( - ErrorCode.InvalidParams, - `Prompt ${request.params.ref.name} not found`, - ); - } - - if (!prompt.argsSchema) { - return EMPTY_COMPLETION_RESULT; - } - - const field = prompt.argsSchema.shape[request.params.argument.name]; - if (!(field instanceof Completable)) { - return EMPTY_COMPLETION_RESULT; - } - - const def: CompletableDef = field._def; - const suggestions = await def.complete(request.params.argument.value); - return createCompletionResult(suggestions); - } - - private async handleResourceCompletion( - request: CompleteRequest, - ref: ResourceReference, - ): Promise { - const template = Object.values(this._registeredResourceTemplates).find( - (t) => t.resourceTemplate.uriTemplate.toString() === ref.uri, - ); - - if (!template) { - if (this._registeredResources[ref.uri]) { - // Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be). - return EMPTY_COMPLETION_RESULT; - } - - throw new McpError( - ErrorCode.InvalidParams, - `Resource template ${request.params.ref.uri} not found`, - ); - } - - const completer = template.resourceTemplate.completeCallback( - request.params.argument.name, - ); - if (!completer) { - return EMPTY_COMPLETION_RESULT; - } - - const suggestions = await completer(request.params.argument.value); - return createCompletionResult(suggestions); - } - - private setResourceRequestHandlers() { - this.assertCanSetRequestHandler( - ListResourcesRequestSchema.shape.method.value, - ); - this.assertCanSetRequestHandler( - ListResourceTemplatesRequestSchema.shape.method.value, - ); - this.assertCanSetRequestHandler( - ReadResourceRequestSchema.shape.method.value, - ); - - this.registerCapabilities({ - resources: {}, - }); - - this.setRequestHandler( - ListResourcesRequestSchema, - async (request, extra) => { - const resources = Object.entries(this._registeredResources).map( - ([uri, resource]) => ({ - uri, - name: resource.name, - ...resource.metadata, - }), - ); - - const templateResources: Resource[] = []; - for (const template of Object.values( - this._registeredResourceTemplates, - )) { - if (!template.resourceTemplate.listCallback) { - continue; - } - - const result = await template.resourceTemplate.listCallback(extra); - for (const resource of result.resources) { - templateResources.push({ - ...resource, - ...template.metadata, - }); - } - } - - return { resources: [...resources, ...templateResources] }; - }, - ); - - this.setRequestHandler(ListResourceTemplatesRequestSchema, async () => { - const resourceTemplates = Object.entries( - this._registeredResourceTemplates, - ).map(([name, template]) => ({ - name, - uriTemplate: template.resourceTemplate.uriTemplate.toString(), - ...template.metadata, - })); - - return { resourceTemplates }; - }); - - this.setRequestHandler( - ReadResourceRequestSchema, - async (request, extra) => { - const uri = new URL(request.params.uri); - - // First check for exact resource match - const resource = this._registeredResources[uri.toString()]; - if (resource) { - return resource.readCallback(uri, extra); - } - - // Then check templates - for (const template of Object.values( - this._registeredResourceTemplates, - )) { - const variables = template.resourceTemplate.uriTemplate.match( - uri.toString(), - ); - if (variables) { - return template.readCallback(uri, variables, extra); - } - } - - throw new McpError( - ErrorCode.InvalidParams, - `Resource ${uri} not found`, - ); - }, - ); - - this.setCompletionRequestHandler(); - } - - /** - * Registers a resource `name` at a fixed URI, which will use the given callback to respond to read requests. - */ - resource(name: string, uri: string, readCallback: ReadResourceCallback): void; - - /** - * Registers a resource `name` at a fixed URI with metadata, which will use the given callback to respond to read requests. - */ - resource( - name: string, - uri: string, - metadata: ResourceMetadata, - readCallback: ReadResourceCallback, - ): void; - - /** - * Registers a resource `name` with a template pattern, which will use the given callback to respond to read requests. - */ - resource( - name: string, - template: ResourceTemplate, - readCallback: ReadResourceTemplateCallback, - ): void; - - /** - * Registers a resource `name` with a template pattern and metadata, which will use the given callback to respond to read requests. - */ - resource( - name: string, - template: ResourceTemplate, - metadata: ResourceMetadata, - readCallback: ReadResourceTemplateCallback, - ): void; - - resource( - name: string, - uriOrTemplate: string | ResourceTemplate, - ...rest: unknown[] - ): void { - let metadata: ResourceMetadata | undefined; - if (typeof rest[0] === "object") { - metadata = rest.shift() as ResourceMetadata; - } - - const readCallback = rest[0] as - | ReadResourceCallback - | ReadResourceTemplateCallback; - - if (typeof uriOrTemplate === "string") { - this.registerResource({ - name, - uri: uriOrTemplate, - metadata, - readCallback: readCallback as ReadResourceCallback, - }); - } else { - this.registerResourceTemplate({ - name, - resourceTemplate: uriOrTemplate, - metadata, - readCallback: readCallback as ReadResourceTemplateCallback, - }); - } - } - - private registerResource({ - name, - uri, - metadata, - readCallback, - }: { - name: string; - uri: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceCallback; - }): void { - if (this._registeredResources[uri]) { - throw new Error(`Resource ${uri} is already registered`); - } - - this._registeredResources[uri] = { - name, - metadata, - readCallback, - }; - - this.setResourceRequestHandlers(); - } - - private registerResourceTemplate({ - name, - resourceTemplate, - metadata, - readCallback, - }: { - name: string; - resourceTemplate: ResourceTemplate; - metadata?: ResourceMetadata; - readCallback: ReadResourceTemplateCallback; - }): void { - if (this._registeredResourceTemplates[name]) { - throw new Error(`Resource template ${name} is already registered`); - } - - this._registeredResourceTemplates[name] = { - resourceTemplate, - metadata, - readCallback, - }; - - this.setResourceRequestHandlers(); - } - - private setPromptRequestHandlers() { - this.assertCanSetRequestHandler( - ListPromptsRequestSchema.shape.method.value, - ); - this.assertCanSetRequestHandler(GetPromptRequestSchema.shape.method.value); - - this.registerCapabilities({ - prompts: {}, - }); - - this.setRequestHandler( - ListPromptsRequestSchema, - (): ListPromptsResult => ({ - prompts: Object.entries(this._registeredPrompts).map( - ([name, prompt]): Prompt => { - return { - name, - description: prompt.description, - arguments: prompt.argsSchema - ? promptArgumentsFromSchema(prompt.argsSchema) - : undefined, - }; - }, - ), - }), - ); - - this.setRequestHandler( - GetPromptRequestSchema, - async (request, extra): Promise => { - const prompt = this._registeredPrompts[request.params.name]; - if (!prompt) { - throw new McpError( - ErrorCode.InvalidParams, - `Prompt ${request.params.name} not found`, - ); - } - - if (prompt.argsSchema) { - const parseResult = await prompt.argsSchema.safeParseAsync( - request.params.arguments, - ); - if (!parseResult.success) { - throw new McpError( - ErrorCode.InvalidParams, - `Invalid arguments for prompt ${request.params.name}: ${parseResult.error.message}`, - ); - } - - const args = parseResult.data; - const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(args, extra)); - } else { - const cb = prompt.callback as PromptCallback; - return await Promise.resolve(cb(extra)); - } - }, - ); - - this.setCompletionRequestHandler(); - } - - prompt(name: string, cb: PromptCallback): void; - prompt(name: string, description: string, cb: PromptCallback): void; - prompt( - name: string, - argsSchema: Args, - cb: PromptCallback, - ): void; - prompt( - name: string, - description: string, - argsSchema: Args, - cb: PromptCallback, - ): void; - - prompt(name: string, ...rest: unknown[]): void { - if (this._registeredPrompts[name]) { - throw new Error(`Prompt ${name} is already registered`); - } - - let description: string | undefined; - if (typeof rest[0] === "string") { - description = rest.shift() as string; - } - - let argsSchema: PromptArgsRawShape | undefined; - if (rest.length > 1) { - argsSchema = rest.shift() as PromptArgsRawShape; - } - - const cb = rest[0] as PromptCallback; - this._registeredPrompts[name] = { - description, - argsSchema: argsSchema === undefined ? undefined : z.object(argsSchema), - callback: cb, - }; - - this.setPromptRequestHandlers(); - } -} - -/** - * A callback to complete one variable within a resource template's URI template. - */ -export type CompleteResourceTemplateCallback = ( - value: string, -) => string[] | Promise; - -/** - * A resource template combines a URI pattern with optional functionality to enumerate - * all resources matching that pattern. - */ -export class ResourceTemplate { - private _uriTemplate: UriTemplate; - - constructor( - uriTemplate: string | UriTemplate, - private _callbacks: { - /** - * A callback to list all resources matching this template. This is required to specified, even if `undefined`, to avoid accidentally forgetting resource listing. - */ - list: ListResourcesCallback | undefined; - - /** - * An optional callback to autocomplete variables within the URI template. Useful for clients and users to discover possible values. - */ - complete?: { - [variable: string]: CompleteResourceTemplateCallback; - }; - }, - ) { - this._uriTemplate = - typeof uriTemplate === "string" - ? new UriTemplate(uriTemplate) - : uriTemplate; - } - - /** - * Gets the URI template pattern. - */ - get uriTemplate(): UriTemplate { - return this._uriTemplate; - } - - /** - * Gets the list callback, if one was provided. - */ - get listCallback(): ListResourcesCallback | undefined { - return this._callbacks.list; - } - - /** - * Gets the callback for completing a specific URI template variable, if one was provided. - */ - completeCallback( - variable: string, - ): CompleteResourceTemplateCallback | undefined { - return this._callbacks.complete?.[variable]; - } } - -/** - * Callback for a tool handler registered with Server.tool(). - * - * Parameters will include tool arguments, if applicable, as well as other request handler context. - */ -export type ToolCallback = - Args extends ZodRawShape - ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra, - ) => CallToolResult | Promise - : (extra: RequestHandlerExtra) => CallToolResult | Promise; - -type RegisteredTool = { - description?: string; - inputSchema?: AnyZodObject; - callback: ToolCallback; -}; - -const EMPTY_OBJECT_JSON_SCHEMA = { - type: "object" as const, -}; - -/** - * Additional, optional information for annotating a resource. - */ -export type ResourceMetadata = Omit; - -/** - * Callback to list all resources matching a given template. - */ -export type ListResourcesCallback = ( - extra: RequestHandlerExtra, -) => ListResourcesResult | Promise; - -/** - * Callback to read a resource at a given URI. - */ -export type ReadResourceCallback = ( - uri: URL, - extra: RequestHandlerExtra, -) => ReadResourceResult | Promise; - -type RegisteredResource = { - name: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceCallback; -}; - -/** - * Callback to read a resource at a given URI, following a filled-in URI template. - */ -export type ReadResourceTemplateCallback = ( - uri: URL, - variables: Variables, - extra: RequestHandlerExtra, -) => ReadResourceResult | Promise; - -type RegisteredResourceTemplate = { - resourceTemplate: ResourceTemplate; - metadata?: ResourceMetadata; - readCallback: ReadResourceTemplateCallback; -}; - -type PromptArgsRawShape = { - [k: string]: - | ZodType - | ZodOptional>; -}; - -export type PromptCallback< - Args extends undefined | PromptArgsRawShape = undefined, -> = Args extends PromptArgsRawShape - ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra, - ) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; - -type RegisteredPrompt = { - description?: string; - argsSchema?: ZodObject; - callback: PromptCallback; -}; - -function promptArgumentsFromSchema( - schema: ZodObject, -): PromptArgument[] { - return Object.entries(schema.shape).map( - ([name, field]): PromptArgument => ({ - name, - description: field.description, - required: !field.isOptional(), - }), - ); -} - -function createCompletionResult(suggestions: string[]): CompleteResult { - return { - completion: { - values: suggestions.slice(0, 100), - total: suggestions.length, - hasMore: suggestions.length > 100, - }, - }; -} - -const EMPTY_COMPLETION_RESULT: CompleteResult = { - completion: { - values: [], - hasMore: false, - }, -}; diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts new file mode 100644 index 0000000..33fdeb9 --- /dev/null +++ b/src/server/mcp.test.ts @@ -0,0 +1,1395 @@ +import { McpServer } from "./mcp.js"; +import { Client } from "../client/index.js"; +import { InMemoryTransport } from "../inMemory.js"; +import { z } from "zod"; +import { + ListToolsResultSchema, + CallToolResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + ListPromptsResultSchema, + GetPromptResultSchema, + CompleteResultSchema, +} from "../types.js"; +import { ResourceTemplate } from "./mcp.js"; +import { completable } from "./completable.js"; +import { UriTemplate } from "../shared/uriTemplate.js"; + +describe("McpServer", () => { + test("should expose underlying Server instance", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + expect(mcpServer.server).toBeDefined(); + }); + + test("should allow sending notifications via Server", async () => { + const mcpServer = new McpServer( + { + name: "test server", + version: "1.0", + }, + { capabilities: { logging: {} } }, + ); + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + // This should work because we're using the underlying server + await expect( + mcpServer.server.sendLoggingMessage({ + level: "info", + data: "Test log message", + }), + ).resolves.not.toThrow(); + }); +}); + +describe("ResourceTemplate", () => { + test("should create ResourceTemplate with string pattern", () => { + const template = new ResourceTemplate("test://{category}/{id}", { + list: undefined, + }); + expect(template.uriTemplate.toString()).toBe("test://{category}/{id}"); + expect(template.listCallback).toBeUndefined(); + }); + + test("should create ResourceTemplate with UriTemplate", () => { + const uriTemplate = new UriTemplate("test://{category}/{id}"); + const template = new ResourceTemplate(uriTemplate, { list: undefined }); + expect(template.uriTemplate).toBe(uriTemplate); + expect(template.listCallback).toBeUndefined(); + }); + + test("should create ResourceTemplate with list callback", async () => { + const list = jest.fn().mockResolvedValue({ + resources: [{ name: "Test", uri: "test://example" }], + }); + + const template = new ResourceTemplate("test://{id}", { list }); + expect(template.listCallback).toBe(list); + + const abortController = new AbortController(); + const result = await template.listCallback?.({ + signal: abortController.signal, + }); + expect(result?.resources).toHaveLength(1); + expect(list).toHaveBeenCalled(); + }); +}); + +describe("tool()", () => { + test("should register zero-argument tool", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.tool("test", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/list", + }, + ListToolsResultSchema, + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("test"); + expect(result.tools[0].inputSchema).toEqual({ + type: "object", + }); + }); + + test("should register tool with args schema", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.tool( + "test", + { + name: z.string(), + value: z.number(), + }, + async ({ name, value }) => ({ + content: [ + { + type: "text", + text: `${name}: ${value}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/list", + }, + ListToolsResultSchema, + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("test"); + expect(result.tools[0].inputSchema).toMatchObject({ + type: "object", + properties: { + name: { type: "string" }, + value: { type: "number" }, + }, + }); + }); + + test("should register tool with description", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.tool("test", "Test description", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/list", + }, + ListToolsResultSchema, + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("test"); + expect(result.tools[0].description).toBe("Test description"); + }); + + test("should validate tool args", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + mcpServer.tool( + "test", + { + name: z.string(), + value: z.number(), + }, + async ({ name, value }) => ({ + content: [ + { + type: "text", + text: `${name}: ${value}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "tools/call", + params: { + name: "test", + arguments: { + name: "test", + value: "not a number", + }, + }, + }, + CallToolResultSchema, + ), + ).rejects.toThrow(/Invalid arguments/); + }); + + test("should prevent duplicate tool registration", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + mcpServer.tool("test", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + expect(() => { + mcpServer.tool("test", async () => ({ + content: [ + { + type: "text", + text: "Test response 2", + }, + ], + })); + }).toThrow(/already registered/); + }); + + test("should allow client to call server tools", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + mcpServer.tool( + "test", + "Test tool", + { + input: z.string(), + }, + async ({ input }) => ({ + content: [ + { + type: "text", + text: `Processed: ${input}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/call", + params: { + name: "test", + arguments: { + input: "hello", + }, + }, + }, + CallToolResultSchema, + ); + + expect(result.content).toEqual([ + { + type: "text", + text: "Processed: hello", + }, + ]); + }); + + test("should handle server tool errors gracefully", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + mcpServer.tool("error-test", async () => { + throw new Error("Tool execution failed"); + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/call", + params: { + name: "error-test", + }, + }, + CallToolResultSchema, + ); + + expect(result.isError).toBe(true); + expect(result.content).toEqual([ + { + type: "text", + text: "Tool execution failed", + }, + ]); + }); + + test("should throw McpError for invalid tool name", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + mcpServer.tool("test-tool", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "tools/call", + params: { + name: "nonexistent-tool", + }, + }, + CallToolResultSchema, + ), + ).rejects.toThrow(/Tool nonexistent-tool not found/); + }); +}); + +describe("resource()", () => { + test("should register resource with uri and readCallback", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource("test", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/list", + }, + ListResourcesResultSchema, + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].name).toBe("test"); + expect(result.resources[0].uri).toBe("test://resource"); + }); + + test("should register resource with metadata", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource( + "test", + "test://resource", + { + description: "Test resource", + mimeType: "text/plain", + }, + async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/list", + }, + ListResourcesResultSchema, + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].description).toBe("Test resource"); + expect(result.resources[0].mimeType).toBe("text/plain"); + }); + + test("should register resource template", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{id}", { list: undefined }), + async () => ({ + contents: [ + { + uri: "test://resource/123", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/templates/list", + }, + ListResourceTemplatesResultSchema, + ); + + expect(result.resourceTemplates).toHaveLength(1); + expect(result.resourceTemplates[0].name).toBe("test"); + expect(result.resourceTemplates[0].uriTemplate).toBe( + "test://resource/{id}", + ); + }); + + test("should register resource template with listCallback", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{id}", { + list: async () => ({ + resources: [ + { + name: "Resource 1", + uri: "test://resource/1", + }, + { + name: "Resource 2", + uri: "test://resource/2", + }, + ], + }), + }), + async (uri) => ({ + contents: [ + { + uri: uri.href, + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/list", + }, + ListResourcesResultSchema, + ); + + expect(result.resources).toHaveLength(2); + expect(result.resources[0].name).toBe("Resource 1"); + expect(result.resources[0].uri).toBe("test://resource/1"); + expect(result.resources[1].name).toBe("Resource 2"); + expect(result.resources[1].uri).toBe("test://resource/2"); + }); + + test("should pass template variables to readCallback", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{category}/{id}", { + list: undefined, + }), + async (uri, { category, id }) => ({ + contents: [ + { + uri: uri.href, + text: `Category: ${category}, ID: ${id}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/read", + params: { + uri: "test://resource/books/123", + }, + }, + ReadResourceResultSchema, + ); + + expect(result.contents[0].text).toBe("Category: books, ID: 123"); + }); + + test("should prevent duplicate resource registration", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + mcpServer.resource("test", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + })); + + expect(() => { + mcpServer.resource("test2", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content 2", + }, + ], + })); + }).toThrow(/already registered/); + }); + + test("should prevent duplicate resource template registration", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{id}", { list: undefined }), + async () => ({ + contents: [ + { + uri: "test://resource/123", + text: "Test content", + }, + ], + }), + ); + + expect(() => { + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{id}", { list: undefined }), + async () => ({ + contents: [ + { + uri: "test://resource/123", + text: "Test content 2", + }, + ], + }), + ); + }).toThrow(/already registered/); + }); + + test("should handle resource read errors gracefully", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource("error-test", "test://error", async () => { + throw new Error("Resource read failed"); + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "resources/read", + params: { + uri: "test://error", + }, + }, + ReadResourceResultSchema, + ), + ).rejects.toThrow(/Resource read failed/); + }); + + test("should throw McpError for invalid resource URI", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource("test", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "resources/read", + params: { + uri: "test://nonexistent", + }, + }, + ReadResourceResultSchema, + ), + ).rejects.toThrow(/Resource test:\/\/nonexistent not found/); + }); + + test("should support completion of resource template parameters", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + resources: {}, + }, + }, + ); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{category}", { + list: undefined, + complete: { + category: () => ["books", "movies", "music"], + }, + }), + async () => ({ + contents: [ + { + uri: "test://resource/test", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/resource", + uri: "test://resource/{category}", + }, + argument: { + name: "category", + value: "", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["books", "movies", "music"]); + expect(result.completion.total).toBe(3); + }); + + test("should support filtered completion of resource template parameters", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + resources: {}, + }, + }, + ); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{category}", { + list: undefined, + complete: { + category: (test: string) => + ["books", "movies", "music"].filter((value) => + value.startsWith(test), + ), + }, + }), + async () => ({ + contents: [ + { + uri: "test://resource/test", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/resource", + uri: "test://resource/{category}", + }, + argument: { + name: "category", + value: "m", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["movies", "music"]); + expect(result.completion.total).toBe(2); + }); +}); + +describe("prompt()", () => { + test("should register zero-argument prompt", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.prompt("test", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/list", + }, + ListPromptsResultSchema, + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe("test"); + expect(result.prompts[0].arguments).toBeUndefined(); + }); + + test("should register prompt with args schema", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.prompt( + "test", + { + name: z.string(), + value: z.string(), + }, + async ({ name, value }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `${name}: ${value}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/list", + }, + ListPromptsResultSchema, + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe("test"); + expect(result.prompts[0].arguments).toEqual([ + { name: "name", required: true }, + { name: "value", required: true }, + ]); + }); + + test("should register prompt with description", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.prompt("test", "Test description", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/list", + }, + ListPromptsResultSchema, + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe("test"); + expect(result.prompts[0].description).toBe("Test description"); + }); + + test("should validate prompt args", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + mcpServer.prompt( + "test", + { + name: z.string(), + value: z.string().min(3), + }, + async ({ name, value }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `${name}: ${value}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "prompts/get", + params: { + name: "test", + arguments: { + name: "test", + value: "ab", // Too short + }, + }, + }, + GetPromptResultSchema, + ), + ).rejects.toThrow(/Invalid arguments/); + }); + + test("should prevent duplicate prompt registration", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + mcpServer.prompt("test", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + expect(() => { + mcpServer.prompt("test", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response 2", + }, + }, + ], + })); + }).toThrow(/already registered/); + }); + + test("should throw McpError for invalid prompt name", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + mcpServer.prompt("test-prompt", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "prompts/get", + params: { + name: "nonexistent-prompt", + }, + }, + GetPromptResultSchema, + ), + ).rejects.toThrow(/Prompt nonexistent-prompt not found/); + }); + + test("should support completion of prompt arguments", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + mcpServer.prompt( + "test-prompt", + { + name: completable(z.string(), () => ["Alice", "Bob", "Charlie"]), + }, + async ({ name }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `Hello ${name}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/prompt", + name: "test-prompt", + }, + argument: { + name: "name", + value: "", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["Alice", "Bob", "Charlie"]); + expect(result.completion.total).toBe(3); + }); + + test("should support filtered completion of prompt arguments", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + mcpServer.prompt( + "test-prompt", + { + name: completable(z.string(), (test) => + ["Alice", "Bob", "Charlie"].filter((value) => value.startsWith(test)), + ), + }, + async ({ name }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `Hello ${name}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/prompt", + name: "test-prompt", + }, + argument: { + name: "name", + value: "A", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["Alice"]); + expect(result.completion.total).toBe(1); + }); +}); diff --git a/src/server/mcp.ts b/src/server/mcp.ts new file mode 100644 index 0000000..db27099 --- /dev/null +++ b/src/server/mcp.ts @@ -0,0 +1,767 @@ +import { Server, ServerOptions } from "./index.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { + z, + ZodRawShape, + ZodObject, + ZodString, + AnyZodObject, + ZodTypeAny, + ZodType, + ZodTypeDef, + ZodOptional, +} from "zod"; +import { + Implementation, + Tool, + ListToolsResult, + CallToolResult, + McpError, + ErrorCode, + CompleteRequest, + CompleteResult, + PromptReference, + ResourceReference, + Resource, + ListResourcesResult, + ListResourceTemplatesRequestSchema, + ReadResourceRequestSchema, + ListToolsRequestSchema, + CallToolRequestSchema, + ListResourcesRequestSchema, + ListPromptsRequestSchema, + GetPromptRequestSchema, + CompleteRequestSchema, + ListPromptsResult, + Prompt, + PromptArgument, + GetPromptResult, + ReadResourceResult, + CallToolResultSchema, +} from "../types.js"; +import { Completable, CompletableDef } from "./completable.js"; +import { UriTemplate, Variables } from "../shared/uriTemplate.js"; +import { RequestHandlerExtra } from "../shared/protocol.js"; +import { Transport } from "../shared/transport.js"; + +/** + * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. + * For advanced usage (like sending notifications or setting custom request handlers), use the underlying + * Server instance available via the `server` property. + */ +export class McpServer { + /** + * The underlying Server instance, useful for advanced operations like sending notifications. + */ + public readonly server: Server; + + private _registeredResources: { [uri: string]: RegisteredResource } = {}; + private _registeredResourceTemplates: { + [name: string]: RegisteredResourceTemplate; + } = {}; + private _registeredTools: { [name: string]: RegisteredTool } = {}; + private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + + constructor(serverInfo: Implementation, options?: ServerOptions) { + this.server = new Server(serverInfo, options); + } + + /** + * Attaches to the given transport, starts it, and starts listening for messages. + * + * The `server` object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. + */ + async connect(transport: Transport): Promise { + return await this.server.connect(transport); + } + + /** + * Closes the connection. + */ + async close(): Promise { + await this.server.close(); + } + + private setToolRequestHandlers() { + this.server.assertCanSetRequestHandler( + ListToolsRequestSchema.shape.method.value, + ); + this.server.assertCanSetRequestHandler( + CallToolRequestSchema.shape.method.value, + ); + + this.server.registerCapabilities({ + tools: {}, + }); + + this.server.setRequestHandler( + ListToolsRequestSchema, + (): ListToolsResult => ({ + tools: Object.entries(this._registeredTools).map( + ([name, tool]): Tool => { + return { + name, + description: tool.description, + inputSchema: tool.inputSchema + ? (zodToJsonSchema(tool.inputSchema) as Tool["inputSchema"]) + : EMPTY_OBJECT_JSON_SCHEMA, + }; + }, + ), + }), + ); + + this.server.setRequestHandler( + CallToolRequestSchema, + async (request, extra): Promise => { + const tool = this._registeredTools[request.params.name]; + if (!tool) { + throw new McpError( + ErrorCode.InvalidParams, + `Tool ${request.params.name} not found`, + ); + } + + if (tool.inputSchema) { + const parseResult = await tool.inputSchema.safeParseAsync( + request.params.arguments, + ); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`, + ); + } + + const args = parseResult.data; + const cb = tool.callback as ToolCallback; + try { + return await Promise.resolve(cb(args, extra)); + } catch (error) { + return { + content: [ + { + type: "text", + text: error instanceof Error ? error.message : String(error), + }, + ], + isError: true, + }; + } + } else { + const cb = tool.callback as ToolCallback; + try { + return await Promise.resolve(cb(extra)); + } catch (error) { + return { + content: [ + { + type: "text", + text: error instanceof Error ? error.message : String(error), + }, + ], + isError: true, + }; + } + } + }, + ); + } + + private setCompletionRequestHandler() { + this.server.assertCanSetRequestHandler( + CompleteRequestSchema.shape.method.value, + ); + + this.server.setRequestHandler( + CompleteRequestSchema, + async (request): Promise => { + switch (request.params.ref.type) { + case "ref/prompt": + return this.handlePromptCompletion(request, request.params.ref); + + case "ref/resource": + return this.handleResourceCompletion(request, request.params.ref); + + default: + throw new McpError( + ErrorCode.InvalidParams, + `Invalid completion reference: ${request.params.ref}`, + ); + } + }, + ); + } + + private async handlePromptCompletion( + request: CompleteRequest, + ref: PromptReference, + ): Promise { + const prompt = this._registeredPrompts[ref.name]; + if (!prompt) { + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${request.params.ref.name} not found`, + ); + } + + if (!prompt.argsSchema) { + return EMPTY_COMPLETION_RESULT; + } + + const field = prompt.argsSchema.shape[request.params.argument.name]; + if (!(field instanceof Completable)) { + return EMPTY_COMPLETION_RESULT; + } + + const def: CompletableDef = field._def; + const suggestions = await def.complete(request.params.argument.value); + return createCompletionResult(suggestions); + } + + private async handleResourceCompletion( + request: CompleteRequest, + ref: ResourceReference, + ): Promise { + const template = Object.values(this._registeredResourceTemplates).find( + (t) => t.resourceTemplate.uriTemplate.toString() === ref.uri, + ); + + if (!template) { + if (this._registeredResources[ref.uri]) { + // Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; + } + + throw new McpError( + ErrorCode.InvalidParams, + `Resource template ${request.params.ref.uri} not found`, + ); + } + + const completer = template.resourceTemplate.completeCallback( + request.params.argument.name, + ); + if (!completer) { + return EMPTY_COMPLETION_RESULT; + } + + const suggestions = await completer(request.params.argument.value); + return createCompletionResult(suggestions); + } + + private setResourceRequestHandlers() { + this.server.assertCanSetRequestHandler( + ListResourcesRequestSchema.shape.method.value, + ); + this.server.assertCanSetRequestHandler( + ListResourceTemplatesRequestSchema.shape.method.value, + ); + this.server.assertCanSetRequestHandler( + ReadResourceRequestSchema.shape.method.value, + ); + + this.server.registerCapabilities({ + resources: {}, + }); + + this.server.setRequestHandler( + ListResourcesRequestSchema, + async (request, extra) => { + const resources = Object.entries(this._registeredResources).map( + ([uri, resource]) => ({ + uri, + name: resource.name, + ...resource.metadata, + }), + ); + + const templateResources: Resource[] = []; + for (const template of Object.values( + this._registeredResourceTemplates, + )) { + if (!template.resourceTemplate.listCallback) { + continue; + } + + const result = await template.resourceTemplate.listCallback(extra); + for (const resource of result.resources) { + templateResources.push({ + ...resource, + ...template.metadata, + }); + } + } + + return { resources: [...resources, ...templateResources] }; + }, + ); + + this.server.setRequestHandler( + ListResourceTemplatesRequestSchema, + async () => { + const resourceTemplates = Object.entries( + this._registeredResourceTemplates, + ).map(([name, template]) => ({ + name, + uriTemplate: template.resourceTemplate.uriTemplate.toString(), + ...template.metadata, + })); + + return { resourceTemplates }; + }, + ); + + this.server.setRequestHandler( + ReadResourceRequestSchema, + async (request, extra) => { + const uri = new URL(request.params.uri); + + // First check for exact resource match + const resource = this._registeredResources[uri.toString()]; + if (resource) { + return resource.readCallback(uri, extra); + } + + // Then check templates + for (const template of Object.values( + this._registeredResourceTemplates, + )) { + const variables = template.resourceTemplate.uriTemplate.match( + uri.toString(), + ); + if (variables) { + return template.readCallback(uri, variables, extra); + } + } + + throw new McpError( + ErrorCode.InvalidParams, + `Resource ${uri} not found`, + ); + }, + ); + + this.setCompletionRequestHandler(); + } + + private setPromptRequestHandlers() { + this.server.assertCanSetRequestHandler( + ListPromptsRequestSchema.shape.method.value, + ); + this.server.assertCanSetRequestHandler( + GetPromptRequestSchema.shape.method.value, + ); + + this.server.registerCapabilities({ + prompts: {}, + }); + + this.server.setRequestHandler( + ListPromptsRequestSchema, + (): ListPromptsResult => ({ + prompts: Object.entries(this._registeredPrompts).map( + ([name, prompt]): Prompt => { + return { + name, + description: prompt.description, + arguments: prompt.argsSchema + ? promptArgumentsFromSchema(prompt.argsSchema) + : undefined, + }; + }, + ), + }), + ); + + this.server.setRequestHandler( + GetPromptRequestSchema, + async (request, extra): Promise => { + const prompt = this._registeredPrompts[request.params.name]; + if (!prompt) { + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${request.params.name} not found`, + ); + } + + if (prompt.argsSchema) { + const parseResult = await prompt.argsSchema.safeParseAsync( + request.params.arguments, + ); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for prompt ${request.params.name}: ${parseResult.error.message}`, + ); + } + + const args = parseResult.data; + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(args, extra)); + } else { + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(extra)); + } + }, + ); + + this.setCompletionRequestHandler(); + } + + /** + * Registers a resource `name` at a fixed URI, which will use the given callback to respond to read requests. + */ + resource(name: string, uri: string, readCallback: ReadResourceCallback): void; + + /** + * Registers a resource `name` at a fixed URI with metadata, which will use the given callback to respond to read requests. + */ + resource( + name: string, + uri: string, + metadata: ResourceMetadata, + readCallback: ReadResourceCallback, + ): void; + + /** + * Registers a resource `name` with a template pattern, which will use the given callback to respond to read requests. + */ + resource( + name: string, + template: ResourceTemplate, + readCallback: ReadResourceTemplateCallback, + ): void; + + /** + * Registers a resource `name` with a template pattern and metadata, which will use the given callback to respond to read requests. + */ + resource( + name: string, + template: ResourceTemplate, + metadata: ResourceMetadata, + readCallback: ReadResourceTemplateCallback, + ): void; + + resource( + name: string, + uriOrTemplate: string | ResourceTemplate, + ...rest: unknown[] + ): void { + let metadata: ResourceMetadata | undefined; + if (typeof rest[0] === "object") { + metadata = rest.shift() as ResourceMetadata; + } + + const readCallback = rest[0] as + | ReadResourceCallback + | ReadResourceTemplateCallback; + + if (typeof uriOrTemplate === "string") { + if (this._registeredResources[uriOrTemplate]) { + throw new Error(`Resource ${uriOrTemplate} is already registered`); + } + + this._registeredResources[uriOrTemplate] = { + name, + metadata, + readCallback: readCallback as ReadResourceCallback, + }; + } else { + if (this._registeredResourceTemplates[name]) { + throw new Error(`Resource template ${name} is already registered`); + } + + this._registeredResourceTemplates[name] = { + resourceTemplate: uriOrTemplate, + metadata, + readCallback: readCallback as ReadResourceTemplateCallback, + }; + } + + this.setResourceRequestHandlers(); + } + + /** + * Registers a zero-argument tool `name`, which will run the given function when the client calls it. + */ + tool(name: string, cb: ToolCallback): void; + + /** + * Registers a zero-argument tool `name` (with a description) which will run the given function when the client calls it. + */ + tool(name: string, description: string, cb: ToolCallback): void; + + /** + * Registers a tool `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + tool( + name: string, + paramsSchema: Args, + cb: ToolCallback, + ): void; + + /** + * Registers a tool `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + tool( + name: string, + description: string, + paramsSchema: Args, + cb: ToolCallback, + ): void; + + tool(name: string, ...rest: unknown[]): void { + if (this._registeredTools[name]) { + throw new Error(`Tool ${name} is already registered`); + } + + let description: string | undefined; + if (typeof rest[0] === "string") { + description = rest.shift() as string; + } + + let paramsSchema: ZodRawShape | undefined; + if (rest.length > 1) { + paramsSchema = rest.shift() as ZodRawShape; + } + + const cb = rest[0] as ToolCallback; + this._registeredTools[name] = { + description, + inputSchema: + paramsSchema === undefined ? undefined : z.object(paramsSchema), + callback: cb, + }; + + this.setToolRequestHandlers(); + } + + /** + * Registers a zero-argument prompt `name`, which will run the given function when the client calls it. + */ + prompt(name: string, cb: PromptCallback): void; + + /** + * Registers a zero-argument prompt `name` (with a description) which will run the given function when the client calls it. + */ + prompt(name: string, description: string, cb: PromptCallback): void; + + /** + * Registers a prompt `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + prompt( + name: string, + argsSchema: Args, + cb: PromptCallback, + ): void; + + /** + * Registers a prompt `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + prompt( + name: string, + description: string, + argsSchema: Args, + cb: PromptCallback, + ): void; + + prompt(name: string, ...rest: unknown[]): void { + if (this._registeredPrompts[name]) { + throw new Error(`Prompt ${name} is already registered`); + } + + let description: string | undefined; + if (typeof rest[0] === "string") { + description = rest.shift() as string; + } + + let argsSchema: PromptArgsRawShape | undefined; + if (rest.length > 1) { + argsSchema = rest.shift() as PromptArgsRawShape; + } + + const cb = rest[0] as PromptCallback; + this._registeredPrompts[name] = { + description, + argsSchema: argsSchema === undefined ? undefined : z.object(argsSchema), + callback: cb, + }; + + this.setPromptRequestHandlers(); + } +} + +/** + * A callback to complete one variable within a resource template's URI template. + */ +export type CompleteResourceTemplateCallback = ( + value: string, +) => string[] | Promise; + +/** + * A resource template combines a URI pattern with optional functionality to enumerate + * all resources matching that pattern. + */ +export class ResourceTemplate { + private _uriTemplate: UriTemplate; + + constructor( + uriTemplate: string | UriTemplate, + private _callbacks: { + /** + * A callback to list all resources matching this template. This is required to specified, even if `undefined`, to avoid accidentally forgetting resource listing. + */ + list: ListResourcesCallback | undefined; + + /** + * An optional callback to autocomplete variables within the URI template. Useful for clients and users to discover possible values. + */ + complete?: { + [variable: string]: CompleteResourceTemplateCallback; + }; + }, + ) { + this._uriTemplate = + typeof uriTemplate === "string" + ? new UriTemplate(uriTemplate) + : uriTemplate; + } + + /** + * Gets the URI template pattern. + */ + get uriTemplate(): UriTemplate { + return this._uriTemplate; + } + + /** + * Gets the list callback, if one was provided. + */ + get listCallback(): ListResourcesCallback | undefined { + return this._callbacks.list; + } + + /** + * Gets the callback for completing a specific URI template variable, if one was provided. + */ + completeCallback( + variable: string, + ): CompleteResourceTemplateCallback | undefined { + return this._callbacks.complete?.[variable]; + } +} + +/** + * Callback for a tool handler registered with Server.tool(). + * + * Parameters will include tool arguments, if applicable, as well as other request handler context. + */ +export type ToolCallback = + Args extends ZodRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => CallToolResult | Promise + : (extra: RequestHandlerExtra) => CallToolResult | Promise; + +type RegisteredTool = { + description?: string; + inputSchema?: AnyZodObject; + callback: ToolCallback; +}; + +const EMPTY_OBJECT_JSON_SCHEMA = { + type: "object" as const, +}; + +/** + * Additional, optional information for annotating a resource. + */ +export type ResourceMetadata = Omit; + +/** + * Callback to list all resources matching a given template. + */ +export type ListResourcesCallback = ( + extra: RequestHandlerExtra, +) => ListResourcesResult | Promise; + +/** + * Callback to read a resource at a given URI. + */ +export type ReadResourceCallback = ( + uri: URL, + extra: RequestHandlerExtra, +) => ReadResourceResult | Promise; + +type RegisteredResource = { + name: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; +}; + +/** + * Callback to read a resource at a given URI, following a filled-in URI template. + */ +export type ReadResourceTemplateCallback = ( + uri: URL, + variables: Variables, + extra: RequestHandlerExtra, +) => ReadResourceResult | Promise; + +type RegisteredResourceTemplate = { + resourceTemplate: ResourceTemplate; + metadata?: ResourceMetadata; + readCallback: ReadResourceTemplateCallback; +}; + +type PromptArgsRawShape = { + [k: string]: + | ZodType + | ZodOptional>; +}; + +export type PromptCallback< + Args extends undefined | PromptArgsRawShape = undefined, +> = Args extends PromptArgsRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => GetPromptResult | Promise + : (extra: RequestHandlerExtra) => GetPromptResult | Promise; + +type RegisteredPrompt = { + description?: string; + argsSchema?: ZodObject; + callback: PromptCallback; +}; + +function promptArgumentsFromSchema( + schema: ZodObject, +): PromptArgument[] { + return Object.entries(schema.shape).map( + ([name, field]): PromptArgument => ({ + name, + description: field.description, + required: !field.isOptional(), + }), + ); +} + +function createCompletionResult(suggestions: string[]): CompleteResult { + return { + completion: { + values: suggestions.slice(0, 100), + total: suggestions.length, + hasMore: suggestions.length > 100, + }, + }; +} + +const EMPTY_COMPLETION_RESULT: CompleteResult = { + completion: { + values: [], + hasMore: false, + }, +}; diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index cf33cd2..a4f211c 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -516,9 +516,11 @@ export abstract class Protocol< /** * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. */ - protected assertCanSetRequestHandler(method: string): void { + assertCanSetRequestHandler(method: string): void { if (this._requestHandlers.has(method)) { - throw new Error(`A request handler for ${method} already exists, which would be overridden`); + throw new Error( + `A request handler for ${method} already exists, which would be overridden`, + ); } } @@ -550,7 +552,9 @@ export abstract class Protocol< } } -export function mergeCapabilities(base: T, additional: T): T { +export function mergeCapabilities< + T extends ServerCapabilities | ClientCapabilities, +>(base: T, additional: T): T { return Object.entries(additional).reduce( (acc, [key, value]) => { if (value && typeof value === "object") { From 7629c70ef988b7c8256796060564f8b13a5d0154 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Wed, 8 Jan 2025 22:02:00 +0000 Subject: [PATCH 26/31] Remove unused import --- src/server/mcp.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/src/server/mcp.ts b/src/server/mcp.ts index db27099..4b48446 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -37,7 +37,6 @@ import { PromptArgument, GetPromptResult, ReadResourceResult, - CallToolResultSchema, } from "../types.js"; import { Completable, CompletableDef } from "./completable.js"; import { UriTemplate, Variables } from "../shared/uriTemplate.js"; From e60fb5e2b8ccbc143d8dd20796b076b3a48ce50d Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 13 Jan 2025 20:41:01 +0000 Subject: [PATCH 27/31] Update README.md --- README.md | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/README.md b/README.md index 6843fa7..7ad7636 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,67 @@ const resourceContent = await client.request( ### Creating a Server +The SDK provides two ways to create a server: using the low-level `Server` class or the simplified `McpServer` class with an Express-style API. + +#### Using McpServer (Recommended) + +```typescript +import { McpServer } from "@modelcontextprotocol/sdk/server/index.js"; +import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { z } from "zod"; + +const server = new McpServer({ + name: "example-server", + version: "1.0.0", +}); + +// Define a simple tool with no parameters +server.tool("save", async () => { + return { + content: [{ type: "text", text: "Saved successfully." }] + }; +}); + +// Define a tool with parameters +server.tool("add", { a: z.number(), b: z.number() }, async ({ a, b }) => { + return { + content: [{ type: "text", text: String(a + b) }] + }; +}); + +// Define a static resource +server.resource( + "welcome-message", + "file:///welcome.txt", + async (uri) => ({ + contents: [{ + uri: uri.href, + text: "Welcome to the server!" + }] + }) +); + +// Define a prompt with parameters +server.prompt( + "greeting", + { name: z.string(), language: z.string().optional() }, + ({ name, language }) => ({ + messages: [{ + role: "assistant", + content: { + type: "text", + text: `${language === "es" ? "¡Hola" : "Hello"} ${name}!` + } + }] + }) +); + +const transport = new StdioServerTransport(); +await server.connect(transport); +``` + +#### Using Server (Low-level API) + ```typescript import { Server } from "@modelcontextprotocol/sdk/server/index.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; From 0790890bf3a9d452be5290952cddb13911bfc8bb Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 13 Jan 2025 21:08:47 +0000 Subject: [PATCH 28/31] Model README on the Python SDK's README --- README.md | 425 ++++++++++++++++++++++++++++++++++++---------- package-lock.json | 4 +- 2 files changed, 334 insertions(+), 95 deletions(-) diff --git a/README.md b/README.md index 7ad7636..8417272 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,26 @@ -# MCP TypeScript SDK ![NPM Version](https://img.shields.io/npm/v/%40modelcontextprotocol%2Fsdk) - -TypeScript implementation of the [Model Context Protocol](https://modelcontextprotocol.io) (MCP), providing both client and server capabilities for integrating with LLM surfaces. +# MCP TypeScript SDK ![NPM Version](https://img.shields.io/npm/v/%40modelcontextprotocol%2Fsdk) ![MIT licensed](https://img.shields.io/npm/l/%40modelcontextprotocol%2Fsdk) + +## Table of Contents +- [Overview](#overview) +- [Installation](#installation) +- [Quickstart](#quickstart) +- [What is MCP?](#what-is-mcp) +- [Core Concepts](#core-concepts) + - [Server](#server) + - [Resources](#resources) + - [Tools](#tools) + - [Prompts](#prompts) +- [Running Your Server](#running-your-server) + - [Development Mode](#development-mode) + - [Claude Desktop Integration](#claude-desktop-integration) + - [Direct Execution](#direct-execution) +- [Examples](#examples) + - [Echo Server](#echo-server) + - [SQLite Explorer](#sqlite-explorer) +- [Advanced Usage](#advanced-usage) + - [Low-Level Server](#low-level-server) + - [Writing MCP Clients](#writing-mcp-clients) + - [Server Capabilities](#server-capabilities) ## Overview @@ -19,158 +39,377 @@ npm install @modelcontextprotocol/sdk ## Quick Start -### Creating a Client +Let's create a simple MCP server that exposes a calculator tool and some data: ```typescript -import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; -import { - ListResourcesRequestSchema, - ReadResourceRequestSchema, -} from "@modelcontextprotocol/sdk/types.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { z } from "zod"; -const transport = new StdioClientTransport({ - command: "path/to/server", +// Create an MCP server +const server = new McpServer({ + name: "Demo", + version: "1.0.0" }); -const client = new Client({ - name: "example-client", - version: "1.0.0", -}, { - capabilities: {} +// Add an addition tool +server.tool("add", + { a: z.number(), b: z.number() }, + async ({ a, b }) => ({ + content: [{ type: "text", text: String(a + b) }] + }) +); + +// Add a dynamic greeting resource +server.resource( + "greeting", + "greeting://{name}", + async (uri, { name }) => ({ + contents: [{ + uri: uri.href, + text: `Hello, ${name}!` + }] + }) +); +``` + +## What is MCP? + +The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you build servers that expose data and functionality to LLM applications in a secure, standardized way. Think of it like a web API, but specifically designed for LLM interactions. MCP servers can: + +- Expose data through **Resources** (think of these sort of like GET endpoints; they are used to load information into the LLM's context) +- Provide functionality through **Tools** (sort of like POST endpoints; they are used to execute code or otherwise produce a side effect) +- Define interaction patterns through **Prompts** (reusable templates for LLM interactions) +- And more! + +## Core Concepts + +### Server + +The McpServer is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: + +```typescript +const server = new McpServer({ + name: "My App", + version: "1.0.0" }); +``` -await client.connect(transport); +### Resources + +Resources are how you expose data to LLMs. They're similar to GET endpoints in a REST API - they provide data but shouldn't perform significant computation or have side effects: + +```typescript +// Static resource +server.resource( + "config", + "config://app", + async (uri) => ({ + contents: [{ + uri: uri.href, + text: "App configuration here" + }] + }) +); -// List available resources -const resources = await client.request( - { method: "resources/list" }, - ListResourcesResultSchema +// Dynamic resource with parameters +server.resource( + "user-profile", + "users://{userId}/profile", + async (uri, { userId }) => ({ + contents: [{ + uri: uri.href, + text: `Profile data for user ${userId}` + }] + }) ); +``` + +### Tools -// Read a specific resource -const resourceContent = await client.request( +Tools let LLMs take actions through your server. Unlike resources, tools are expected to perform computation and have side effects: + +```typescript +// Simple tool with parameters +server.tool( + "calculate-bmi", { - method: "resources/read", - params: { - uri: "file:///example.txt" - } + weightKg: z.number(), + heightM: z.number() }, - ReadResourceResultSchema + async ({ weightKg, heightM }) => ({ + content: [{ + type: "text", + text: String(weightKg / (heightM * heightM)) + }] + }) +); + +// Async tool with external API call +server.tool( + "fetch-weather", + { city: z.string() }, + async ({ city }) => { + const response = await fetch(`https://api.weather.com/${city}`); + const data = await response.text(); + return { + content: [{ type: "text", text: data }] + }; + } ); ``` -### Creating a Server +### Prompts -The SDK provides two ways to create a server: using the low-level `Server` class or the simplified `McpServer` class with an Express-style API. +Prompts are reusable templates that help LLMs interact with your server effectively: -#### Using McpServer (Recommended) +```typescript +server.prompt( + "review-code", + { code: z.string() }, + ({ code }) => ({ + messages: [{ + role: "user", + content: { + type: "text", + text: `Please review this code:\n\n${code}` + } + }] + }) +); +``` + +## Examples + +### Echo Server + +A simple server demonstrating resources, tools, and prompts: ```typescript -import { McpServer } from "@modelcontextprotocol/sdk/server/index.js"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { z } from "zod"; const server = new McpServer({ - name: "example-server", - version: "1.0.0", -}); - -// Define a simple tool with no parameters -server.tool("save", async () => { - return { - content: [{ type: "text", text: "Saved successfully." }] - }; + name: "Echo", + version: "1.0.0" }); -// Define a tool with parameters -server.tool("add", { a: z.number(), b: z.number() }, async ({ a, b }) => { - return { - content: [{ type: "text", text: String(a + b) }] - }; -}); - -// Define a static resource server.resource( - "welcome-message", - "file:///welcome.txt", - async (uri) => ({ + "echo", + "echo://{message}", + async (uri, { message }) => ({ contents: [{ uri: uri.href, - text: "Welcome to the server!" + text: `Resource echo: ${message}` }] }) ); -// Define a prompt with parameters +server.tool( + "echo", + { message: z.string() }, + async ({ message }) => ({ + content: [{ type: "text", text: `Tool echo: ${message}` }] + }) +); + server.prompt( - "greeting", - { name: z.string(), language: z.string().optional() }, - ({ name, language }) => ({ + "echo", + { message: z.string() }, + ({ message }) => ({ messages: [{ - role: "assistant", + role: "user", content: { type: "text", - text: `${language === "es" ? "¡Hola" : "Hello"} ${name}!` + text: `Please process this message: ${message}` } }] }) ); +``` -const transport = new StdioServerTransport(); -await server.connect(transport); +### SQLite Explorer + +A more complex example showing database integration: + +```typescript +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import sqlite3 from "sqlite3"; +import { promisify } from "util"; +import { z } from "zod"; + +const server = new McpServer({ + name: "SQLite Explorer", + version: "1.0.0" +}); + +// Helper to create DB connection +const getDb = () => { + const db = new sqlite3.Database("database.db"); + return { + all: promisify(db.all.bind(db)), + close: promisify(db.close.bind(db)) + }; +}; + +server.resource( + "schema", + "schema://main", + async (uri) => { + const db = getDb(); + try { + const tables = await db.all( + "SELECT sql FROM sqlite_master WHERE type='table'" + ); + return { + contents: [{ + uri: uri.href, + text: tables.map((t: {sql: string}) => t.sql).join("\n") + }] + }; + } finally { + await db.close(); + } + } +); + +server.tool( + "query", + { sql: z.string() }, + async ({ sql }) => { + const db = getDb(); + try { + const results = await db.all(sql); + return { + content: [{ + type: "text", + text: JSON.stringify(results, null, 2) + }] + }; + } catch (err: unknown) { + const error = err as Error; + return { + content: [{ + type: "text", + text: `Error: ${error.message}` + }], + isError: true + }; + } finally { + await db.close(); + } + } +); ``` -#### Using Server (Low-level API) +## Advanced Usage + +### Low-Level Server + +For more control, you can use the low-level Server class directly: ```typescript import { Server } from "@modelcontextprotocol/sdk/server/index.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; import { - ListResourcesRequestSchema, - ReadResourceRequestSchema, + ListPromptsRequestSchema, + GetPromptRequestSchema } from "@modelcontextprotocol/sdk/types.js"; -const server = new Server({ - name: "example-server", - version: "1.0.0", -}, { - capabilities: { - resources: {} +const server = new Server( + { + name: "example-server", + version: "1.0.0" + }, + { + capabilities: { + prompts: {} + } } -}); +); -server.setRequestHandler(ListResourcesRequestSchema, async () => { +server.setRequestHandler(ListPromptsRequestSchema, async () => { return { - resources: [ - { - uri: "file:///example.txt", - name: "Example Resource", - }, - ], + prompts: [{ + name: "example-prompt", + description: "An example prompt template", + arguments: [{ + name: "arg1", + description: "Example argument", + required: true + }] + }] }; }); -server.setRequestHandler(ReadResourceRequestSchema, async (request) => { - if (request.params.uri === "file:///example.txt") { - return { - contents: [ - { - uri: "file:///example.txt", - mimeType: "text/plain", - text: "This is the content of the example resource.", - }, - ], - }; - } else { - throw new Error("Resource not found"); +server.setRequestHandler(GetPromptRequestSchema, async (request) => { + if (request.params.name !== "example-prompt") { + throw new Error("Unknown prompt"); } + return { + description: "Example prompt", + messages: [{ + role: "user", + content: { + type: "text", + text: "Example prompt text" + } + }] + }; }); const transport = new StdioServerTransport(); await server.connect(transport); ``` +### Writing MCP Clients + +The SDK provides a high-level client interface: + +```typescript +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; + +const transport = new StdioClientTransport({ + command: "node", + args: ["server.js"] +}); + +const client = new Client( + { + name: "example-client", + version: "1.0.0" + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {} + } + } +); + +await client.connect(transport); + +// List prompts +const prompts = await client.listPrompts(); + +// Get a prompt +const prompt = await client.getPrompt("example-prompt", { + arg1: "value" +}); + +// List resources +const resources = await client.listResources(); + +// Read a resource +const resource = await client.readResource("file:///example.txt"); + +// Call a tool +const result = await client.callTool("example-tool", { + arg1: "value" +}); +``` + ## Documentation - [Model Context Protocol documentation](https://modelcontextprotocol.io) diff --git a/package-lock.json b/package-lock.json index 44a876a..627ba05 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.1.0", + "version": "1.1.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.1.0", + "version": "1.1.1", "license": "MIT", "dependencies": { "content-type": "^1.0.5", From 5a6823a0db253fa35bcae8d696bdb3a0eb9da637 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Thu, 16 Jan 2025 11:24:39 +0000 Subject: [PATCH 29/31] Stick with %20, don't replace with + --- src/shared/uriTemplate.test.ts | 2 +- src/shared/uriTemplate.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shared/uriTemplate.test.ts b/src/shared/uriTemplate.test.ts index 0e75131..7941d44 100644 --- a/src/shared/uriTemplate.test.ts +++ b/src/shared/uriTemplate.test.ts @@ -34,7 +34,7 @@ describe("UriTemplate", () => { it("should encode reserved characters", () => { const template = new UriTemplate("{var}"); expect(template.expand({ var: "value with spaces" })).toBe( - "value+with+spaces", + "value%20with%20spaces", ); }); }); diff --git a/src/shared/uriTemplate.ts b/src/shared/uriTemplate.ts index cd3f46f..bb17732 100644 --- a/src/shared/uriTemplate.ts +++ b/src/shared/uriTemplate.ts @@ -125,7 +125,7 @@ export class UriTemplate { if (operator === "+" || operator === "#") { return encodeURI(value); } - return encodeURIComponent(value).replace(/%20/g, "+"); + return encodeURIComponent(value); } private expandPart( From d47a76f60835ee6f46e91b36d66d459afaaa54c7 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 20 Jan 2025 15:27:58 +0000 Subject: [PATCH 30/31] Pre-emptively bump package version --- package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/package.json b/package.json index 8a4d11e..bc7fb42 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.2.0", + "version": "1.3.0", "description": "Model Context Protocol implementation for TypeScript", "license": "MIT", "author": "Anthropic, PBC (https://anthropic.com)", From e8a5ffc64f8289b6625471fefebd4a1d7ca9a683 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 20 Jan 2025 15:28:01 +0000 Subject: [PATCH 31/31] `npm audit fix` --- package-lock.json | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/package-lock.json b/package-lock.json index 627ba05..6f8490b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.1.1", + "version": "1.3.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.1.1", + "version": "1.3.0", "license": "MIT", "dependencies": { "content-type": "^1.0.5", @@ -3123,10 +3123,11 @@ } }, "node_modules/express": { - "version": "4.21.1", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", - "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", + "version": "4.21.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.2.tgz", + "integrity": "sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==", "dev": true, + "license": "MIT", "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", @@ -3147,7 +3148,7 @@ "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", - "path-to-regexp": "0.1.10", + "path-to-regexp": "0.1.12", "proxy-addr": "~2.0.7", "qs": "6.13.0", "range-parser": "~1.2.1", @@ -3162,6 +3163,10 @@ }, "engines": { "node": ">= 0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" } }, "node_modules/express/node_modules/debug": { @@ -5008,10 +5013,11 @@ "dev": true }, "node_modules/path-to-regexp": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz", - "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==", - "dev": true + "version": "0.1.12", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.12.tgz", + "integrity": "sha512-RA1GjUVMnvYFxuqovrEqZoxxW5NUZqbwKtYz/Tt7nXerk0LbLblQmrsgdeOxV5SFHf0UDggjS/bSeOZwt1pmEQ==", + "dev": true, + "license": "MIT" }, "node_modules/picocolors": { "version": "1.1.1",