diff --git a/README.md b/README.md index 6843fa7..8417272 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,26 @@ -# MCP TypeScript SDK ![NPM Version](https://img.shields.io/npm/v/%40modelcontextprotocol%2Fsdk) +# MCP TypeScript SDK ![NPM Version](https://img.shields.io/npm/v/%40modelcontextprotocol%2Fsdk) ![MIT licensed](https://img.shields.io/npm/l/%40modelcontextprotocol%2Fsdk) -TypeScript implementation of the [Model Context Protocol](https://modelcontextprotocol.io) (MCP), providing both client and server capabilities for integrating with LLM surfaces. +## Table of Contents +- [Overview](#overview) +- [Installation](#installation) +- [Quickstart](#quickstart) +- [What is MCP?](#what-is-mcp) +- [Core Concepts](#core-concepts) + - [Server](#server) + - [Resources](#resources) + - [Tools](#tools) + - [Prompts](#prompts) +- [Running Your Server](#running-your-server) + - [Development Mode](#development-mode) + - [Claude Desktop Integration](#claude-desktop-integration) + - [Direct Execution](#direct-execution) +- [Examples](#examples) + - [Echo Server](#echo-server) + - [SQLite Explorer](#sqlite-explorer) +- [Advanced Usage](#advanced-usage) + - [Low-Level Server](#low-level-server) + - [Writing MCP Clients](#writing-mcp-clients) + - [Server Capabilities](#server-capabilities) ## Overview @@ -19,97 +39,377 @@ npm install @modelcontextprotocol/sdk ## Quick Start -### Creating a Client +Let's create a simple MCP server that exposes a calculator tool and some data: ```typescript -import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; -import { - ListResourcesRequestSchema, - ReadResourceRequestSchema, -} from "@modelcontextprotocol/sdk/types.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { z } from "zod"; -const transport = new StdioClientTransport({ - command: "path/to/server", +// Create an MCP server +const server = new McpServer({ + name: "Demo", + version: "1.0.0" }); -const client = new Client({ - name: "example-client", - version: "1.0.0", -}, { - capabilities: {} +// Add an addition tool +server.tool("add", + { a: z.number(), b: z.number() }, + async ({ a, b }) => ({ + content: [{ type: "text", text: String(a + b) }] + }) +); + +// Add a dynamic greeting resource +server.resource( + "greeting", + "greeting://{name}", + async (uri, { name }) => ({ + contents: [{ + uri: uri.href, + text: `Hello, ${name}!` + }] + }) +); +``` + +## What is MCP? + +The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you build servers that expose data and functionality to LLM applications in a secure, standardized way. Think of it like a web API, but specifically designed for LLM interactions. MCP servers can: + +- Expose data through **Resources** (think of these sort of like GET endpoints; they are used to load information into the LLM's context) +- Provide functionality through **Tools** (sort of like POST endpoints; they are used to execute code or otherwise produce a side effect) +- Define interaction patterns through **Prompts** (reusable templates for LLM interactions) +- And more! + +## Core Concepts + +### Server + +The McpServer is your core interface to the MCP protocol. It handles connection management, protocol compliance, and message routing: + +```typescript +const server = new McpServer({ + name: "My App", + version: "1.0.0" }); +``` -await client.connect(transport); +### Resources + +Resources are how you expose data to LLMs. They're similar to GET endpoints in a REST API - they provide data but shouldn't perform significant computation or have side effects: + +```typescript +// Static resource +server.resource( + "config", + "config://app", + async (uri) => ({ + contents: [{ + uri: uri.href, + text: "App configuration here" + }] + }) +); -// List available resources -const resources = await client.request( - { method: "resources/list" }, - ListResourcesResultSchema +// Dynamic resource with parameters +server.resource( + "user-profile", + "users://{userId}/profile", + async (uri, { userId }) => ({ + contents: [{ + uri: uri.href, + text: `Profile data for user ${userId}` + }] + }) ); +``` + +### Tools -// Read a specific resource -const resourceContent = await client.request( +Tools let LLMs take actions through your server. Unlike resources, tools are expected to perform computation and have side effects: + +```typescript +// Simple tool with parameters +server.tool( + "calculate-bmi", { - method: "resources/read", - params: { - uri: "file:///example.txt" - } + weightKg: z.number(), + heightM: z.number() }, - ReadResourceResultSchema + async ({ weightKg, heightM }) => ({ + content: [{ + type: "text", + text: String(weightKg / (heightM * heightM)) + }] + }) +); + +// Async tool with external API call +server.tool( + "fetch-weather", + { city: z.string() }, + async ({ city }) => { + const response = await fetch(`https://api.weather.com/${city}`); + const data = await response.text(); + return { + content: [{ type: "text", text: data }] + }; + } +); +``` + +### Prompts + +Prompts are reusable templates that help LLMs interact with your server effectively: + +```typescript +server.prompt( + "review-code", + { code: z.string() }, + ({ code }) => ({ + messages: [{ + role: "user", + content: { + type: "text", + text: `Please review this code:\n\n${code}` + } + }] + }) ); ``` -### Creating a Server +## Examples + +### Echo Server + +A simple server demonstrating resources, tools, and prompts: + +```typescript +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { z } from "zod"; + +const server = new McpServer({ + name: "Echo", + version: "1.0.0" +}); + +server.resource( + "echo", + "echo://{message}", + async (uri, { message }) => ({ + contents: [{ + uri: uri.href, + text: `Resource echo: ${message}` + }] + }) +); + +server.tool( + "echo", + { message: z.string() }, + async ({ message }) => ({ + content: [{ type: "text", text: `Tool echo: ${message}` }] + }) +); + +server.prompt( + "echo", + { message: z.string() }, + ({ message }) => ({ + messages: [{ + role: "user", + content: { + type: "text", + text: `Please process this message: ${message}` + } + }] + }) +); +``` + +### SQLite Explorer + +A more complex example showing database integration: + +```typescript +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import sqlite3 from "sqlite3"; +import { promisify } from "util"; +import { z } from "zod"; + +const server = new McpServer({ + name: "SQLite Explorer", + version: "1.0.0" +}); + +// Helper to create DB connection +const getDb = () => { + const db = new sqlite3.Database("database.db"); + return { + all: promisify(db.all.bind(db)), + close: promisify(db.close.bind(db)) + }; +}; + +server.resource( + "schema", + "schema://main", + async (uri) => { + const db = getDb(); + try { + const tables = await db.all( + "SELECT sql FROM sqlite_master WHERE type='table'" + ); + return { + contents: [{ + uri: uri.href, + text: tables.map((t: {sql: string}) => t.sql).join("\n") + }] + }; + } finally { + await db.close(); + } + } +); + +server.tool( + "query", + { sql: z.string() }, + async ({ sql }) => { + const db = getDb(); + try { + const results = await db.all(sql); + return { + content: [{ + type: "text", + text: JSON.stringify(results, null, 2) + }] + }; + } catch (err: unknown) { + const error = err as Error; + return { + content: [{ + type: "text", + text: `Error: ${error.message}` + }], + isError: true + }; + } finally { + await db.close(); + } + } +); +``` + +## Advanced Usage + +### Low-Level Server + +For more control, you can use the low-level Server class directly: ```typescript import { Server } from "@modelcontextprotocol/sdk/server/index.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; import { - ListResourcesRequestSchema, - ReadResourceRequestSchema, + ListPromptsRequestSchema, + GetPromptRequestSchema } from "@modelcontextprotocol/sdk/types.js"; -const server = new Server({ - name: "example-server", - version: "1.0.0", -}, { - capabilities: { - resources: {} +const server = new Server( + { + name: "example-server", + version: "1.0.0" + }, + { + capabilities: { + prompts: {} + } } -}); +); -server.setRequestHandler(ListResourcesRequestSchema, async () => { +server.setRequestHandler(ListPromptsRequestSchema, async () => { return { - resources: [ - { - uri: "file:///example.txt", - name: "Example Resource", - }, - ], + prompts: [{ + name: "example-prompt", + description: "An example prompt template", + arguments: [{ + name: "arg1", + description: "Example argument", + required: true + }] + }] }; }); -server.setRequestHandler(ReadResourceRequestSchema, async (request) => { - if (request.params.uri === "file:///example.txt") { - return { - contents: [ - { - uri: "file:///example.txt", - mimeType: "text/plain", - text: "This is the content of the example resource.", - }, - ], - }; - } else { - throw new Error("Resource not found"); +server.setRequestHandler(GetPromptRequestSchema, async (request) => { + if (request.params.name !== "example-prompt") { + throw new Error("Unknown prompt"); } + return { + description: "Example prompt", + messages: [{ + role: "user", + content: { + type: "text", + text: "Example prompt text" + } + }] + }; }); const transport = new StdioServerTransport(); await server.connect(transport); ``` +### Writing MCP Clients + +The SDK provides a high-level client interface: + +```typescript +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; + +const transport = new StdioClientTransport({ + command: "node", + args: ["server.js"] +}); + +const client = new Client( + { + name: "example-client", + version: "1.0.0" + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {} + } + } +); + +await client.connect(transport); + +// List prompts +const prompts = await client.listPrompts(); + +// Get a prompt +const prompt = await client.getPrompt("example-prompt", { + arg1: "value" +}); + +// List resources +const resources = await client.listResources(); + +// Read a resource +const resource = await client.readResource("file:///example.txt"); + +// Call a tool +const result = await client.callTool("example-tool", { + arg1: "value" +}); +``` + ## Documentation - [Model Context Protocol documentation](https://modelcontextprotocol.io) diff --git a/package-lock.json b/package-lock.json index 5e68c8c..6f8490b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,17 +1,18 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.1.0", + "version": "1.3.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.1.0", + "version": "1.3.0", "license": "MIT", "dependencies": { "content-type": "^1.0.5", "raw-body": "^3.0.0", - "zod": "^3.23.8" + "zod": "^3.23.8", + "zod-to-json-schema": "^3.24.1" }, "devDependencies": { "@eslint/js": "^9.8.0", @@ -3122,10 +3123,11 @@ } }, "node_modules/express": { - "version": "4.21.1", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", - "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", + "version": "4.21.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.2.tgz", + "integrity": "sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==", "dev": true, + "license": "MIT", "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", @@ -3146,7 +3148,7 @@ "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", - "path-to-regexp": "0.1.10", + "path-to-regexp": "0.1.12", "proxy-addr": "~2.0.7", "qs": "6.13.0", "range-parser": "~1.2.1", @@ -3161,6 +3163,10 @@ }, "engines": { "node": ">= 0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" } }, "node_modules/express/node_modules/debug": { @@ -5007,10 +5013,11 @@ "dev": true }, "node_modules/path-to-regexp": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz", - "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==", - "dev": true + "version": "0.1.12", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.12.tgz", + "integrity": "sha512-RA1GjUVMnvYFxuqovrEqZoxxW5NUZqbwKtYz/Tt7nXerk0LbLblQmrsgdeOxV5SFHf0UDggjS/bSeOZwt1pmEQ==", + "dev": true, + "license": "MIT" }, "node_modules/picocolors": { "version": "1.1.1", @@ -6146,12 +6153,22 @@ } }, "node_modules/zod": { - "version": "3.23.8", - "resolved": "https://registry.npmjs.org/zod/-/zod-3.23.8.tgz", - "integrity": "sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==", + "version": "3.24.1", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.1.tgz", + "integrity": "sha512-muH7gBL9sI1nciMZV67X5fTKKBLtwpZ5VBp1vsOQzj1MhrBZ4wlVCm3gedKZWLp0Oyel8sIGfeiz54Su+OVT+A==", + "license": "MIT", "funding": { "url": "https://github.com/sponsors/colinhacks" } + }, + "node_modules/zod-to-json-schema": { + "version": "3.24.1", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.1.tgz", + "integrity": "sha512-3h08nf3Vw3Wl3PK+q3ow/lIil81IT2Oa7YpQyUUDsEWbXveMesdfK1xBd2RhCkynwZndAxixji/7SYJJowr62w==", + "license": "ISC", + "peerDependencies": { + "zod": "^3.24.1" + } } } } diff --git a/package.json b/package.json index e674e74..bc7fb42 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.2.0", + "version": "1.3.0", "description": "Model Context Protocol implementation for TypeScript", "license": "MIT", "author": "Anthropic, PBC (https://anthropic.com)", @@ -48,7 +48,8 @@ "dependencies": { "content-type": "^1.0.5", "raw-body": "^3.0.0", - "zod": "^3.23.8" + "zod": "^3.23.8", + "zod-to-json-schema": "^3.24.1" }, "devDependencies": { "@eslint/js": "^9.8.0", diff --git a/src/client/index.ts b/src/client/index.ts index 1ad5161..da0d04c 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,4 +1,5 @@ import { + mergeCapabilities, Protocol, ProtocolOptions, RequestOptions, @@ -44,7 +45,7 @@ export type ClientOptions = ProtocolOptions & { /** * Capabilities to advertise as being supported by this client. */ - capabilities: ClientCapabilities; + capabilities?: ClientCapabilities; }; /** @@ -90,10 +91,25 @@ export class Client< */ constructor( private _clientInfo: Implementation, - options: ClientOptions, + options?: ClientOptions, ) { super(options); - this._capabilities = options.capabilities; + this._capabilities = options?.capabilities ?? {}; + } + + /** + * Registers new capabilities. This can only be called before connecting to a transport. + * + * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). + */ + public registerCapabilities(capabilities: ClientCapabilities): void { + if (this.transport) { + throw new Error( + "Cannot register capabilities after connecting to transport", + ); + } + + this._capabilities = mergeCapabilities(this._capabilities, capabilities); } protected assertCapability( diff --git a/src/server/completable.test.ts b/src/server/completable.test.ts new file mode 100644 index 0000000..6040ff3 --- /dev/null +++ b/src/server/completable.test.ts @@ -0,0 +1,46 @@ +import { z } from "zod"; +import { completable } from "./completable.js"; + +describe("completable", () => { + it("preserves types and values of underlying schema", () => { + const baseSchema = z.string(); + const schema = completable(baseSchema, () => []); + + expect(schema.parse("test")).toBe("test"); + expect(() => schema.parse(123)).toThrow(); + }); + + it("provides access to completion function", async () => { + const completions = ["foo", "bar", "baz"]; + const schema = completable(z.string(), () => completions); + + expect(await schema._def.complete("")).toEqual(completions); + }); + + it("allows async completion functions", async () => { + const completions = ["foo", "bar", "baz"]; + const schema = completable(z.string(), async () => completions); + + expect(await schema._def.complete("")).toEqual(completions); + }); + + it("passes current value to completion function", async () => { + const schema = completable(z.string(), (value) => [value + "!"]); + + expect(await schema._def.complete("test")).toEqual(["test!"]); + }); + + it("works with number schemas", async () => { + const schema = completable(z.number(), () => [1, 2, 3]); + + expect(schema.parse(1)).toBe(1); + expect(await schema._def.complete(0)).toEqual([1, 2, 3]); + }); + + it("preserves schema description", () => { + const desc = "test description"; + const schema = completable(z.string().describe(desc), () => []); + + expect(schema.description).toBe(desc); + }); +}); diff --git a/src/server/completable.ts b/src/server/completable.ts new file mode 100644 index 0000000..3b5bc16 --- /dev/null +++ b/src/server/completable.ts @@ -0,0 +1,95 @@ +import { + ZodTypeAny, + ZodTypeDef, + ZodType, + ParseInput, + ParseReturnType, + RawCreateParams, + ZodErrorMap, + ProcessedCreateParams, +} from "zod"; + +export enum McpZodTypeKind { + Completable = "McpCompletable", +} + +export type CompleteCallback = ( + value: T["_input"], +) => T["_input"][] | Promise; + +export interface CompletableDef + extends ZodTypeDef { + type: T; + complete: CompleteCallback; + typeName: McpZodTypeKind.Completable; +} + +export class Completable extends ZodType< + T["_output"], + CompletableDef, + T["_input"] +> { + _parse(input: ParseInput): ParseReturnType { + const { ctx } = this._processInputParams(input); + const data = ctx.data; + return this._def.type._parse({ + data, + path: ctx.path, + parent: ctx, + }); + } + + unwrap() { + return this._def.type; + } + + static create = ( + type: T, + params: RawCreateParams & { + complete: CompleteCallback; + }, + ): Completable => { + return new Completable({ + type, + typeName: McpZodTypeKind.Completable, + complete: params.complete, + ...processCreateParams(params), + }); + }; +} + +/** + * Wraps a Zod type to provide autocompletion capabilities. Useful for, e.g., prompt arguments in MCP. + */ +export function completable( + schema: T, + complete: CompleteCallback, +): Completable { + return Completable.create(schema, { ...schema._def, complete }); +} + +// Not sure why this isn't exported from Zod: +// https://github.com/colinhacks/zod/blob/f7ad26147ba291cb3fb257545972a8e00e767470/src/types.ts#L130 +function processCreateParams(params: RawCreateParams): ProcessedCreateParams { + if (!params) return {}; + const { errorMap, invalid_type_error, required_error, description } = params; + if (errorMap && (invalid_type_error || required_error)) { + throw new Error( + `Can't use "invalid_type_error" or "required_error" in conjunction with custom error map.`, + ); + } + if (errorMap) return { errorMap: errorMap, description }; + const customMap: ZodErrorMap = (iss, ctx) => { + const { message } = params; + + if (iss.code === "invalid_enum_value") { + return { message: message ?? ctx.defaultError }; + } + if (typeof ctx.data === "undefined") { + return { message: message ?? required_error ?? ctx.defaultError }; + } + if (iss.code !== "invalid_type") return { message: ctx.defaultError }; + return { message: message ?? invalid_type_error ?? ctx.defaultError }; + }; + return { errorMap: customMap, description }; +} diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 34c2503..7c0fbc5 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -478,6 +478,7 @@ test("should handle server cancelling a request", async () => { // Request should be rejected await expect(createMessagePromise).rejects.toBe("Cancelled by test"); }); + test("should handle request timeout", async () => { const server = new Server( { diff --git a/src/server/index.ts b/src/server/index.ts index 0333ad8..3901099 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,4 +1,5 @@ import { + mergeCapabilities, Protocol, ProtocolOptions, RequestOptions, @@ -32,7 +33,7 @@ export type ServerOptions = ProtocolOptions & { /** * Capabilities to advertise as being supported by this server. */ - capabilities: ServerCapabilities; + capabilities?: ServerCapabilities; /** * Optional instructions describing how to use the server and its features. @@ -89,11 +90,11 @@ export class Server< */ constructor( private _serverInfo: Implementation, - options: ServerOptions, + options?: ServerOptions, ) { super(options); - this._capabilities = options.capabilities; - this._instructions = options.instructions; + this._capabilities = options?.capabilities ?? {}; + this._instructions = options?.instructions; this.setRequestHandler(InitializeRequestSchema, (request) => this._oninitialize(request), @@ -103,6 +104,21 @@ export class Server< ); } + /** + * Registers new capabilities. This can only be called before connecting to a transport. + * + * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). + */ + public registerCapabilities(capabilities: ServerCapabilities): void { + if (this.transport) { + throw new Error( + "Cannot register capabilities after connecting to transport", + ); + } + + this._capabilities = mergeCapabilities(this._capabilities, capabilities); + } + protected assertCapabilityForMethod(method: RequestT["method"]): void { switch (method as ServerRequest["method"]) { case "sampling/createMessage": diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts new file mode 100644 index 0000000..33fdeb9 --- /dev/null +++ b/src/server/mcp.test.ts @@ -0,0 +1,1395 @@ +import { McpServer } from "./mcp.js"; +import { Client } from "../client/index.js"; +import { InMemoryTransport } from "../inMemory.js"; +import { z } from "zod"; +import { + ListToolsResultSchema, + CallToolResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ReadResourceResultSchema, + ListPromptsResultSchema, + GetPromptResultSchema, + CompleteResultSchema, +} from "../types.js"; +import { ResourceTemplate } from "./mcp.js"; +import { completable } from "./completable.js"; +import { UriTemplate } from "../shared/uriTemplate.js"; + +describe("McpServer", () => { + test("should expose underlying Server instance", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + expect(mcpServer.server).toBeDefined(); + }); + + test("should allow sending notifications via Server", async () => { + const mcpServer = new McpServer( + { + name: "test server", + version: "1.0", + }, + { capabilities: { logging: {} } }, + ); + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + // This should work because we're using the underlying server + await expect( + mcpServer.server.sendLoggingMessage({ + level: "info", + data: "Test log message", + }), + ).resolves.not.toThrow(); + }); +}); + +describe("ResourceTemplate", () => { + test("should create ResourceTemplate with string pattern", () => { + const template = new ResourceTemplate("test://{category}/{id}", { + list: undefined, + }); + expect(template.uriTemplate.toString()).toBe("test://{category}/{id}"); + expect(template.listCallback).toBeUndefined(); + }); + + test("should create ResourceTemplate with UriTemplate", () => { + const uriTemplate = new UriTemplate("test://{category}/{id}"); + const template = new ResourceTemplate(uriTemplate, { list: undefined }); + expect(template.uriTemplate).toBe(uriTemplate); + expect(template.listCallback).toBeUndefined(); + }); + + test("should create ResourceTemplate with list callback", async () => { + const list = jest.fn().mockResolvedValue({ + resources: [{ name: "Test", uri: "test://example" }], + }); + + const template = new ResourceTemplate("test://{id}", { list }); + expect(template.listCallback).toBe(list); + + const abortController = new AbortController(); + const result = await template.listCallback?.({ + signal: abortController.signal, + }); + expect(result?.resources).toHaveLength(1); + expect(list).toHaveBeenCalled(); + }); +}); + +describe("tool()", () => { + test("should register zero-argument tool", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.tool("test", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/list", + }, + ListToolsResultSchema, + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("test"); + expect(result.tools[0].inputSchema).toEqual({ + type: "object", + }); + }); + + test("should register tool with args schema", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.tool( + "test", + { + name: z.string(), + value: z.number(), + }, + async ({ name, value }) => ({ + content: [ + { + type: "text", + text: `${name}: ${value}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/list", + }, + ListToolsResultSchema, + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("test"); + expect(result.tools[0].inputSchema).toMatchObject({ + type: "object", + properties: { + name: { type: "string" }, + value: { type: "number" }, + }, + }); + }); + + test("should register tool with description", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.tool("test", "Test description", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/list", + }, + ListToolsResultSchema, + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe("test"); + expect(result.tools[0].description).toBe("Test description"); + }); + + test("should validate tool args", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + mcpServer.tool( + "test", + { + name: z.string(), + value: z.number(), + }, + async ({ name, value }) => ({ + content: [ + { + type: "text", + text: `${name}: ${value}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "tools/call", + params: { + name: "test", + arguments: { + name: "test", + value: "not a number", + }, + }, + }, + CallToolResultSchema, + ), + ).rejects.toThrow(/Invalid arguments/); + }); + + test("should prevent duplicate tool registration", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + mcpServer.tool("test", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + expect(() => { + mcpServer.tool("test", async () => ({ + content: [ + { + type: "text", + text: "Test response 2", + }, + ], + })); + }).toThrow(/already registered/); + }); + + test("should allow client to call server tools", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + mcpServer.tool( + "test", + "Test tool", + { + input: z.string(), + }, + async ({ input }) => ({ + content: [ + { + type: "text", + text: `Processed: ${input}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/call", + params: { + name: "test", + arguments: { + input: "hello", + }, + }, + }, + CallToolResultSchema, + ); + + expect(result.content).toEqual([ + { + type: "text", + text: "Processed: hello", + }, + ]); + }); + + test("should handle server tool errors gracefully", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + mcpServer.tool("error-test", async () => { + throw new Error("Tool execution failed"); + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "tools/call", + params: { + name: "error-test", + }, + }, + CallToolResultSchema, + ); + + expect(result.isError).toBe(true); + expect(result.content).toEqual([ + { + type: "text", + text: "Tool execution failed", + }, + ]); + }); + + test("should throw McpError for invalid tool name", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + tools: {}, + }, + }, + ); + + mcpServer.tool("test-tool", async () => ({ + content: [ + { + type: "text", + text: "Test response", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "tools/call", + params: { + name: "nonexistent-tool", + }, + }, + CallToolResultSchema, + ), + ).rejects.toThrow(/Tool nonexistent-tool not found/); + }); +}); + +describe("resource()", () => { + test("should register resource with uri and readCallback", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource("test", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/list", + }, + ListResourcesResultSchema, + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].name).toBe("test"); + expect(result.resources[0].uri).toBe("test://resource"); + }); + + test("should register resource with metadata", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource( + "test", + "test://resource", + { + description: "Test resource", + mimeType: "text/plain", + }, + async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/list", + }, + ListResourcesResultSchema, + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].description).toBe("Test resource"); + expect(result.resources[0].mimeType).toBe("text/plain"); + }); + + test("should register resource template", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{id}", { list: undefined }), + async () => ({ + contents: [ + { + uri: "test://resource/123", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/templates/list", + }, + ListResourceTemplatesResultSchema, + ); + + expect(result.resourceTemplates).toHaveLength(1); + expect(result.resourceTemplates[0].name).toBe("test"); + expect(result.resourceTemplates[0].uriTemplate).toBe( + "test://resource/{id}", + ); + }); + + test("should register resource template with listCallback", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{id}", { + list: async () => ({ + resources: [ + { + name: "Resource 1", + uri: "test://resource/1", + }, + { + name: "Resource 2", + uri: "test://resource/2", + }, + ], + }), + }), + async (uri) => ({ + contents: [ + { + uri: uri.href, + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/list", + }, + ListResourcesResultSchema, + ); + + expect(result.resources).toHaveLength(2); + expect(result.resources[0].name).toBe("Resource 1"); + expect(result.resources[0].uri).toBe("test://resource/1"); + expect(result.resources[1].name).toBe("Resource 2"); + expect(result.resources[1].uri).toBe("test://resource/2"); + }); + + test("should pass template variables to readCallback", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{category}/{id}", { + list: undefined, + }), + async (uri, { category, id }) => ({ + contents: [ + { + uri: uri.href, + text: `Category: ${category}, ID: ${id}`, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "resources/read", + params: { + uri: "test://resource/books/123", + }, + }, + ReadResourceResultSchema, + ); + + expect(result.contents[0].text).toBe("Category: books, ID: 123"); + }); + + test("should prevent duplicate resource registration", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + mcpServer.resource("test", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + })); + + expect(() => { + mcpServer.resource("test2", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content 2", + }, + ], + })); + }).toThrow(/already registered/); + }); + + test("should prevent duplicate resource template registration", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{id}", { list: undefined }), + async () => ({ + contents: [ + { + uri: "test://resource/123", + text: "Test content", + }, + ], + }), + ); + + expect(() => { + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{id}", { list: undefined }), + async () => ({ + contents: [ + { + uri: "test://resource/123", + text: "Test content 2", + }, + ], + }), + ); + }).toThrow(/already registered/); + }); + + test("should handle resource read errors gracefully", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource("error-test", "test://error", async () => { + throw new Error("Resource read failed"); + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "resources/read", + params: { + uri: "test://error", + }, + }, + ReadResourceResultSchema, + ), + ).rejects.toThrow(/Resource read failed/); + }); + + test("should throw McpError for invalid resource URI", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.resource("test", "test://resource", async () => ({ + contents: [ + { + uri: "test://resource", + text: "Test content", + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "resources/read", + params: { + uri: "test://nonexistent", + }, + }, + ReadResourceResultSchema, + ), + ).rejects.toThrow(/Resource test:\/\/nonexistent not found/); + }); + + test("should support completion of resource template parameters", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + resources: {}, + }, + }, + ); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{category}", { + list: undefined, + complete: { + category: () => ["books", "movies", "music"], + }, + }), + async () => ({ + contents: [ + { + uri: "test://resource/test", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/resource", + uri: "test://resource/{category}", + }, + argument: { + name: "category", + value: "", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["books", "movies", "music"]); + expect(result.completion.total).toBe(3); + }); + + test("should support filtered completion of resource template parameters", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + resources: {}, + }, + }, + ); + + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{category}", { + list: undefined, + complete: { + category: (test: string) => + ["books", "movies", "music"].filter((value) => + value.startsWith(test), + ), + }, + }), + async () => ({ + contents: [ + { + uri: "test://resource/test", + text: "Test content", + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/resource", + uri: "test://resource/{category}", + }, + argument: { + name: "category", + value: "m", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["movies", "music"]); + expect(result.completion.total).toBe(2); + }); +}); + +describe("prompt()", () => { + test("should register zero-argument prompt", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.prompt("test", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/list", + }, + ListPromptsResultSchema, + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe("test"); + expect(result.prompts[0].arguments).toBeUndefined(); + }); + + test("should register prompt with args schema", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.prompt( + "test", + { + name: z.string(), + value: z.string(), + }, + async ({ name, value }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `${name}: ${value}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/list", + }, + ListPromptsResultSchema, + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe("test"); + expect(result.prompts[0].arguments).toEqual([ + { name: "name", required: true }, + { name: "value", required: true }, + ]); + }); + + test("should register prompt with description", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + const client = new Client({ + name: "test client", + version: "1.0", + }); + + mcpServer.prompt("test", "Test description", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "prompts/list", + }, + ListPromptsResultSchema, + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe("test"); + expect(result.prompts[0].description).toBe("Test description"); + }); + + test("should validate prompt args", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + mcpServer.prompt( + "test", + { + name: z.string(), + value: z.string().min(3), + }, + async ({ name, value }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `${name}: ${value}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "prompts/get", + params: { + name: "test", + arguments: { + name: "test", + value: "ab", // Too short + }, + }, + }, + GetPromptResultSchema, + ), + ).rejects.toThrow(/Invalid arguments/); + }); + + test("should prevent duplicate prompt registration", () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + mcpServer.prompt("test", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + expect(() => { + mcpServer.prompt("test", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response 2", + }, + }, + ], + })); + }).toThrow(/already registered/); + }); + + test("should throw McpError for invalid prompt name", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + mcpServer.prompt("test-prompt", async () => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + }, + ], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + await expect( + client.request( + { + method: "prompts/get", + params: { + name: "nonexistent-prompt", + }, + }, + GetPromptResultSchema, + ), + ).rejects.toThrow(/Prompt nonexistent-prompt not found/); + }); + + test("should support completion of prompt arguments", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + mcpServer.prompt( + "test-prompt", + { + name: completable(z.string(), () => ["Alice", "Bob", "Charlie"]), + }, + async ({ name }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `Hello ${name}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/prompt", + name: "test-prompt", + }, + argument: { + name: "name", + value: "", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["Alice", "Bob", "Charlie"]); + expect(result.completion.total).toBe(3); + }); + + test("should support filtered completion of prompt arguments", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + }, + }, + ); + + mcpServer.prompt( + "test-prompt", + { + name: completable(z.string(), (test) => + ["Alice", "Bob", "Charlie"].filter((value) => value.startsWith(test)), + ), + }, + async ({ name }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `Hello ${name}`, + }, + }, + ], + }), + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + const result = await client.request( + { + method: "completion/complete", + params: { + ref: { + type: "ref/prompt", + name: "test-prompt", + }, + argument: { + name: "name", + value: "A", + }, + }, + }, + CompleteResultSchema, + ); + + expect(result.completion.values).toEqual(["Alice"]); + expect(result.completion.total).toBe(1); + }); +}); diff --git a/src/server/mcp.ts b/src/server/mcp.ts new file mode 100644 index 0000000..4b48446 --- /dev/null +++ b/src/server/mcp.ts @@ -0,0 +1,766 @@ +import { Server, ServerOptions } from "./index.js"; +import { zodToJsonSchema } from "zod-to-json-schema"; +import { + z, + ZodRawShape, + ZodObject, + ZodString, + AnyZodObject, + ZodTypeAny, + ZodType, + ZodTypeDef, + ZodOptional, +} from "zod"; +import { + Implementation, + Tool, + ListToolsResult, + CallToolResult, + McpError, + ErrorCode, + CompleteRequest, + CompleteResult, + PromptReference, + ResourceReference, + Resource, + ListResourcesResult, + ListResourceTemplatesRequestSchema, + ReadResourceRequestSchema, + ListToolsRequestSchema, + CallToolRequestSchema, + ListResourcesRequestSchema, + ListPromptsRequestSchema, + GetPromptRequestSchema, + CompleteRequestSchema, + ListPromptsResult, + Prompt, + PromptArgument, + GetPromptResult, + ReadResourceResult, +} from "../types.js"; +import { Completable, CompletableDef } from "./completable.js"; +import { UriTemplate, Variables } from "../shared/uriTemplate.js"; +import { RequestHandlerExtra } from "../shared/protocol.js"; +import { Transport } from "../shared/transport.js"; + +/** + * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. + * For advanced usage (like sending notifications or setting custom request handlers), use the underlying + * Server instance available via the `server` property. + */ +export class McpServer { + /** + * The underlying Server instance, useful for advanced operations like sending notifications. + */ + public readonly server: Server; + + private _registeredResources: { [uri: string]: RegisteredResource } = {}; + private _registeredResourceTemplates: { + [name: string]: RegisteredResourceTemplate; + } = {}; + private _registeredTools: { [name: string]: RegisteredTool } = {}; + private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + + constructor(serverInfo: Implementation, options?: ServerOptions) { + this.server = new Server(serverInfo, options); + } + + /** + * Attaches to the given transport, starts it, and starts listening for messages. + * + * The `server` object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. + */ + async connect(transport: Transport): Promise { + return await this.server.connect(transport); + } + + /** + * Closes the connection. + */ + async close(): Promise { + await this.server.close(); + } + + private setToolRequestHandlers() { + this.server.assertCanSetRequestHandler( + ListToolsRequestSchema.shape.method.value, + ); + this.server.assertCanSetRequestHandler( + CallToolRequestSchema.shape.method.value, + ); + + this.server.registerCapabilities({ + tools: {}, + }); + + this.server.setRequestHandler( + ListToolsRequestSchema, + (): ListToolsResult => ({ + tools: Object.entries(this._registeredTools).map( + ([name, tool]): Tool => { + return { + name, + description: tool.description, + inputSchema: tool.inputSchema + ? (zodToJsonSchema(tool.inputSchema) as Tool["inputSchema"]) + : EMPTY_OBJECT_JSON_SCHEMA, + }; + }, + ), + }), + ); + + this.server.setRequestHandler( + CallToolRequestSchema, + async (request, extra): Promise => { + const tool = this._registeredTools[request.params.name]; + if (!tool) { + throw new McpError( + ErrorCode.InvalidParams, + `Tool ${request.params.name} not found`, + ); + } + + if (tool.inputSchema) { + const parseResult = await tool.inputSchema.safeParseAsync( + request.params.arguments, + ); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`, + ); + } + + const args = parseResult.data; + const cb = tool.callback as ToolCallback; + try { + return await Promise.resolve(cb(args, extra)); + } catch (error) { + return { + content: [ + { + type: "text", + text: error instanceof Error ? error.message : String(error), + }, + ], + isError: true, + }; + } + } else { + const cb = tool.callback as ToolCallback; + try { + return await Promise.resolve(cb(extra)); + } catch (error) { + return { + content: [ + { + type: "text", + text: error instanceof Error ? error.message : String(error), + }, + ], + isError: true, + }; + } + } + }, + ); + } + + private setCompletionRequestHandler() { + this.server.assertCanSetRequestHandler( + CompleteRequestSchema.shape.method.value, + ); + + this.server.setRequestHandler( + CompleteRequestSchema, + async (request): Promise => { + switch (request.params.ref.type) { + case "ref/prompt": + return this.handlePromptCompletion(request, request.params.ref); + + case "ref/resource": + return this.handleResourceCompletion(request, request.params.ref); + + default: + throw new McpError( + ErrorCode.InvalidParams, + `Invalid completion reference: ${request.params.ref}`, + ); + } + }, + ); + } + + private async handlePromptCompletion( + request: CompleteRequest, + ref: PromptReference, + ): Promise { + const prompt = this._registeredPrompts[ref.name]; + if (!prompt) { + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${request.params.ref.name} not found`, + ); + } + + if (!prompt.argsSchema) { + return EMPTY_COMPLETION_RESULT; + } + + const field = prompt.argsSchema.shape[request.params.argument.name]; + if (!(field instanceof Completable)) { + return EMPTY_COMPLETION_RESULT; + } + + const def: CompletableDef = field._def; + const suggestions = await def.complete(request.params.argument.value); + return createCompletionResult(suggestions); + } + + private async handleResourceCompletion( + request: CompleteRequest, + ref: ResourceReference, + ): Promise { + const template = Object.values(this._registeredResourceTemplates).find( + (t) => t.resourceTemplate.uriTemplate.toString() === ref.uri, + ); + + if (!template) { + if (this._registeredResources[ref.uri]) { + // Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be). + return EMPTY_COMPLETION_RESULT; + } + + throw new McpError( + ErrorCode.InvalidParams, + `Resource template ${request.params.ref.uri} not found`, + ); + } + + const completer = template.resourceTemplate.completeCallback( + request.params.argument.name, + ); + if (!completer) { + return EMPTY_COMPLETION_RESULT; + } + + const suggestions = await completer(request.params.argument.value); + return createCompletionResult(suggestions); + } + + private setResourceRequestHandlers() { + this.server.assertCanSetRequestHandler( + ListResourcesRequestSchema.shape.method.value, + ); + this.server.assertCanSetRequestHandler( + ListResourceTemplatesRequestSchema.shape.method.value, + ); + this.server.assertCanSetRequestHandler( + ReadResourceRequestSchema.shape.method.value, + ); + + this.server.registerCapabilities({ + resources: {}, + }); + + this.server.setRequestHandler( + ListResourcesRequestSchema, + async (request, extra) => { + const resources = Object.entries(this._registeredResources).map( + ([uri, resource]) => ({ + uri, + name: resource.name, + ...resource.metadata, + }), + ); + + const templateResources: Resource[] = []; + for (const template of Object.values( + this._registeredResourceTemplates, + )) { + if (!template.resourceTemplate.listCallback) { + continue; + } + + const result = await template.resourceTemplate.listCallback(extra); + for (const resource of result.resources) { + templateResources.push({ + ...resource, + ...template.metadata, + }); + } + } + + return { resources: [...resources, ...templateResources] }; + }, + ); + + this.server.setRequestHandler( + ListResourceTemplatesRequestSchema, + async () => { + const resourceTemplates = Object.entries( + this._registeredResourceTemplates, + ).map(([name, template]) => ({ + name, + uriTemplate: template.resourceTemplate.uriTemplate.toString(), + ...template.metadata, + })); + + return { resourceTemplates }; + }, + ); + + this.server.setRequestHandler( + ReadResourceRequestSchema, + async (request, extra) => { + const uri = new URL(request.params.uri); + + // First check for exact resource match + const resource = this._registeredResources[uri.toString()]; + if (resource) { + return resource.readCallback(uri, extra); + } + + // Then check templates + for (const template of Object.values( + this._registeredResourceTemplates, + )) { + const variables = template.resourceTemplate.uriTemplate.match( + uri.toString(), + ); + if (variables) { + return template.readCallback(uri, variables, extra); + } + } + + throw new McpError( + ErrorCode.InvalidParams, + `Resource ${uri} not found`, + ); + }, + ); + + this.setCompletionRequestHandler(); + } + + private setPromptRequestHandlers() { + this.server.assertCanSetRequestHandler( + ListPromptsRequestSchema.shape.method.value, + ); + this.server.assertCanSetRequestHandler( + GetPromptRequestSchema.shape.method.value, + ); + + this.server.registerCapabilities({ + prompts: {}, + }); + + this.server.setRequestHandler( + ListPromptsRequestSchema, + (): ListPromptsResult => ({ + prompts: Object.entries(this._registeredPrompts).map( + ([name, prompt]): Prompt => { + return { + name, + description: prompt.description, + arguments: prompt.argsSchema + ? promptArgumentsFromSchema(prompt.argsSchema) + : undefined, + }; + }, + ), + }), + ); + + this.server.setRequestHandler( + GetPromptRequestSchema, + async (request, extra): Promise => { + const prompt = this._registeredPrompts[request.params.name]; + if (!prompt) { + throw new McpError( + ErrorCode.InvalidParams, + `Prompt ${request.params.name} not found`, + ); + } + + if (prompt.argsSchema) { + const parseResult = await prompt.argsSchema.safeParseAsync( + request.params.arguments, + ); + if (!parseResult.success) { + throw new McpError( + ErrorCode.InvalidParams, + `Invalid arguments for prompt ${request.params.name}: ${parseResult.error.message}`, + ); + } + + const args = parseResult.data; + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(args, extra)); + } else { + const cb = prompt.callback as PromptCallback; + return await Promise.resolve(cb(extra)); + } + }, + ); + + this.setCompletionRequestHandler(); + } + + /** + * Registers a resource `name` at a fixed URI, which will use the given callback to respond to read requests. + */ + resource(name: string, uri: string, readCallback: ReadResourceCallback): void; + + /** + * Registers a resource `name` at a fixed URI with metadata, which will use the given callback to respond to read requests. + */ + resource( + name: string, + uri: string, + metadata: ResourceMetadata, + readCallback: ReadResourceCallback, + ): void; + + /** + * Registers a resource `name` with a template pattern, which will use the given callback to respond to read requests. + */ + resource( + name: string, + template: ResourceTemplate, + readCallback: ReadResourceTemplateCallback, + ): void; + + /** + * Registers a resource `name` with a template pattern and metadata, which will use the given callback to respond to read requests. + */ + resource( + name: string, + template: ResourceTemplate, + metadata: ResourceMetadata, + readCallback: ReadResourceTemplateCallback, + ): void; + + resource( + name: string, + uriOrTemplate: string | ResourceTemplate, + ...rest: unknown[] + ): void { + let metadata: ResourceMetadata | undefined; + if (typeof rest[0] === "object") { + metadata = rest.shift() as ResourceMetadata; + } + + const readCallback = rest[0] as + | ReadResourceCallback + | ReadResourceTemplateCallback; + + if (typeof uriOrTemplate === "string") { + if (this._registeredResources[uriOrTemplate]) { + throw new Error(`Resource ${uriOrTemplate} is already registered`); + } + + this._registeredResources[uriOrTemplate] = { + name, + metadata, + readCallback: readCallback as ReadResourceCallback, + }; + } else { + if (this._registeredResourceTemplates[name]) { + throw new Error(`Resource template ${name} is already registered`); + } + + this._registeredResourceTemplates[name] = { + resourceTemplate: uriOrTemplate, + metadata, + readCallback: readCallback as ReadResourceTemplateCallback, + }; + } + + this.setResourceRequestHandlers(); + } + + /** + * Registers a zero-argument tool `name`, which will run the given function when the client calls it. + */ + tool(name: string, cb: ToolCallback): void; + + /** + * Registers a zero-argument tool `name` (with a description) which will run the given function when the client calls it. + */ + tool(name: string, description: string, cb: ToolCallback): void; + + /** + * Registers a tool `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + tool( + name: string, + paramsSchema: Args, + cb: ToolCallback, + ): void; + + /** + * Registers a tool `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + tool( + name: string, + description: string, + paramsSchema: Args, + cb: ToolCallback, + ): void; + + tool(name: string, ...rest: unknown[]): void { + if (this._registeredTools[name]) { + throw new Error(`Tool ${name} is already registered`); + } + + let description: string | undefined; + if (typeof rest[0] === "string") { + description = rest.shift() as string; + } + + let paramsSchema: ZodRawShape | undefined; + if (rest.length > 1) { + paramsSchema = rest.shift() as ZodRawShape; + } + + const cb = rest[0] as ToolCallback; + this._registeredTools[name] = { + description, + inputSchema: + paramsSchema === undefined ? undefined : z.object(paramsSchema), + callback: cb, + }; + + this.setToolRequestHandlers(); + } + + /** + * Registers a zero-argument prompt `name`, which will run the given function when the client calls it. + */ + prompt(name: string, cb: PromptCallback): void; + + /** + * Registers a zero-argument prompt `name` (with a description) which will run the given function when the client calls it. + */ + prompt(name: string, description: string, cb: PromptCallback): void; + + /** + * Registers a prompt `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + prompt( + name: string, + argsSchema: Args, + cb: PromptCallback, + ): void; + + /** + * Registers a prompt `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments. + */ + prompt( + name: string, + description: string, + argsSchema: Args, + cb: PromptCallback, + ): void; + + prompt(name: string, ...rest: unknown[]): void { + if (this._registeredPrompts[name]) { + throw new Error(`Prompt ${name} is already registered`); + } + + let description: string | undefined; + if (typeof rest[0] === "string") { + description = rest.shift() as string; + } + + let argsSchema: PromptArgsRawShape | undefined; + if (rest.length > 1) { + argsSchema = rest.shift() as PromptArgsRawShape; + } + + const cb = rest[0] as PromptCallback; + this._registeredPrompts[name] = { + description, + argsSchema: argsSchema === undefined ? undefined : z.object(argsSchema), + callback: cb, + }; + + this.setPromptRequestHandlers(); + } +} + +/** + * A callback to complete one variable within a resource template's URI template. + */ +export type CompleteResourceTemplateCallback = ( + value: string, +) => string[] | Promise; + +/** + * A resource template combines a URI pattern with optional functionality to enumerate + * all resources matching that pattern. + */ +export class ResourceTemplate { + private _uriTemplate: UriTemplate; + + constructor( + uriTemplate: string | UriTemplate, + private _callbacks: { + /** + * A callback to list all resources matching this template. This is required to specified, even if `undefined`, to avoid accidentally forgetting resource listing. + */ + list: ListResourcesCallback | undefined; + + /** + * An optional callback to autocomplete variables within the URI template. Useful for clients and users to discover possible values. + */ + complete?: { + [variable: string]: CompleteResourceTemplateCallback; + }; + }, + ) { + this._uriTemplate = + typeof uriTemplate === "string" + ? new UriTemplate(uriTemplate) + : uriTemplate; + } + + /** + * Gets the URI template pattern. + */ + get uriTemplate(): UriTemplate { + return this._uriTemplate; + } + + /** + * Gets the list callback, if one was provided. + */ + get listCallback(): ListResourcesCallback | undefined { + return this._callbacks.list; + } + + /** + * Gets the callback for completing a specific URI template variable, if one was provided. + */ + completeCallback( + variable: string, + ): CompleteResourceTemplateCallback | undefined { + return this._callbacks.complete?.[variable]; + } +} + +/** + * Callback for a tool handler registered with Server.tool(). + * + * Parameters will include tool arguments, if applicable, as well as other request handler context. + */ +export type ToolCallback = + Args extends ZodRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => CallToolResult | Promise + : (extra: RequestHandlerExtra) => CallToolResult | Promise; + +type RegisteredTool = { + description?: string; + inputSchema?: AnyZodObject; + callback: ToolCallback; +}; + +const EMPTY_OBJECT_JSON_SCHEMA = { + type: "object" as const, +}; + +/** + * Additional, optional information for annotating a resource. + */ +export type ResourceMetadata = Omit; + +/** + * Callback to list all resources matching a given template. + */ +export type ListResourcesCallback = ( + extra: RequestHandlerExtra, +) => ListResourcesResult | Promise; + +/** + * Callback to read a resource at a given URI. + */ +export type ReadResourceCallback = ( + uri: URL, + extra: RequestHandlerExtra, +) => ReadResourceResult | Promise; + +type RegisteredResource = { + name: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; +}; + +/** + * Callback to read a resource at a given URI, following a filled-in URI template. + */ +export type ReadResourceTemplateCallback = ( + uri: URL, + variables: Variables, + extra: RequestHandlerExtra, +) => ReadResourceResult | Promise; + +type RegisteredResourceTemplate = { + resourceTemplate: ResourceTemplate; + metadata?: ResourceMetadata; + readCallback: ReadResourceTemplateCallback; +}; + +type PromptArgsRawShape = { + [k: string]: + | ZodType + | ZodOptional>; +}; + +export type PromptCallback< + Args extends undefined | PromptArgsRawShape = undefined, +> = Args extends PromptArgsRawShape + ? ( + args: z.objectOutputType, + extra: RequestHandlerExtra, + ) => GetPromptResult | Promise + : (extra: RequestHandlerExtra) => GetPromptResult | Promise; + +type RegisteredPrompt = { + description?: string; + argsSchema?: ZodObject; + callback: PromptCallback; +}; + +function promptArgumentsFromSchema( + schema: ZodObject, +): PromptArgument[] { + return Object.entries(schema.shape).map( + ([name, field]): PromptArgument => ({ + name, + description: field.description, + required: !field.isOptional(), + }), + ); +} + +function createCompletionResult(suggestions: string[]): CompleteResult { + return { + completion: { + values: suggestions.slice(0, 100), + total: suggestions.length, + hasMore: suggestions.length > 100, + }, + }; +} + +const EMPTY_COMPLETION_RESULT: CompleteResult = { + completion: { + values: [], + hasMore: false, + }, +}; diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index 9423a2a..3073d0a 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -1,13 +1,15 @@ -import { Protocol } from "./protocol.js"; -import { Transport } from "./transport.js"; +import { ZodType, z } from "zod"; import { - McpError, + ClientCapabilities, ErrorCode, + McpError, + Notification, Request, Result, - Notification, + ServerCapabilities, } from "../types.js"; -import { ZodType, z } from "zod"; +import { Protocol, mergeCapabilities } from "./protocol.js"; +import { Transport } from "./transport.js"; // Mock Transport class class MockTransport implements Transport { @@ -61,3 +63,89 @@ describe("protocol tests", () => { expect(oncloseMock).toHaveBeenCalled(); }); }); + +describe("mergeCapabilities", () => { + it("should merge client capabilities", () => { + const base: ClientCapabilities = { + sampling: {}, + roots: { + listChanged: true, + }, + }; + + const additional: ClientCapabilities = { + experimental: { + feature: true, + }, + roots: { + newProp: true, + }, + }; + + const merged = mergeCapabilities(base, additional); + expect(merged).toEqual({ + sampling: {}, + roots: { + listChanged: true, + newProp: true, + }, + experimental: { + feature: true, + }, + }); + }); + + it("should merge server capabilities", () => { + const base: ServerCapabilities = { + logging: {}, + prompts: { + listChanged: true, + }, + }; + + const additional: ServerCapabilities = { + resources: { + subscribe: true, + }, + prompts: { + newProp: true, + }, + }; + + const merged = mergeCapabilities(base, additional); + expect(merged).toEqual({ + logging: {}, + prompts: { + listChanged: true, + newProp: true, + }, + resources: { + subscribe: true, + }, + }); + }); + + it("should override existing values with additional values", () => { + const base: ServerCapabilities = { + prompts: { + listChanged: false, + }, + }; + + const additional: ServerCapabilities = { + prompts: { + listChanged: true, + }, + }; + + const merged = mergeCapabilities(base, additional); + expect(merged.prompts!.listChanged).toBe(true); + }); + + it("should handle empty objects", () => { + const base = {}; + const additional = {}; + const merged = mergeCapabilities(base, additional); + expect(merged).toEqual({}); + }); +}); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index f430b31..a4f211c 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1,6 +1,7 @@ import { ZodLiteral, ZodObject, ZodType, z } from "zod"; import { CancelledNotificationSchema, + ClientCapabilities, ErrorCode, JSONRPCError, JSONRPCNotification, @@ -15,6 +16,7 @@ import { Request, RequestId, Result, + ServerCapabilities, } from "../types.js"; import { Transport } from "./transport.js"; @@ -511,6 +513,17 @@ export abstract class Protocol< this._requestHandlers.delete(method); } + /** + * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. + */ + assertCanSetRequestHandler(method: string): void { + if (this._requestHandlers.has(method)) { + throw new Error( + `A request handler for ${method} already exists, which would be overridden`, + ); + } + } + /** * Registers a handler to invoke when this protocol object receives a notification with the given method. * @@ -538,3 +551,19 @@ export abstract class Protocol< this._notificationHandlers.delete(method); } } + +export function mergeCapabilities< + T extends ServerCapabilities | ClientCapabilities, +>(base: T, additional: T): T { + return Object.entries(additional).reduce( + (acc, [key, value]) => { + if (value && typeof value === "object") { + acc[key] = acc[key] ? { ...acc[key], ...value } : value; + } else { + acc[key] = value; + } + return acc; + }, + { ...base }, + ); +} diff --git a/src/shared/uriTemplate.test.ts b/src/shared/uriTemplate.test.ts new file mode 100644 index 0000000..7941d44 --- /dev/null +++ b/src/shared/uriTemplate.test.ts @@ -0,0 +1,259 @@ +import { UriTemplate } from "./uriTemplate.js"; + +describe("UriTemplate", () => { + describe("isTemplate", () => { + it("should return true for strings containing template expressions", () => { + expect(UriTemplate.isTemplate("{foo}")).toBe(true); + expect(UriTemplate.isTemplate("/users/{id}")).toBe(true); + expect(UriTemplate.isTemplate("http://example.com/{path}/{file}")).toBe(true); + expect(UriTemplate.isTemplate("/search{?q,limit}")).toBe(true); + }); + + it("should return false for strings without template expressions", () => { + expect(UriTemplate.isTemplate("")).toBe(false); + expect(UriTemplate.isTemplate("plain string")).toBe(false); + expect(UriTemplate.isTemplate("http://example.com/foo/bar")).toBe(false); + expect(UriTemplate.isTemplate("{}")).toBe(false); // Empty braces don't count + expect(UriTemplate.isTemplate("{ }")).toBe(false); // Just whitespace doesn't count + }); + }); + + describe("simple string expansion", () => { + it("should expand simple string variables", () => { + const template = new UriTemplate("http://example.com/users/{username}"); + expect(template.expand({ username: "fred" })).toBe( + "http://example.com/users/fred", + ); + }); + + it("should handle multiple variables", () => { + const template = new UriTemplate("{x,y}"); + expect(template.expand({ x: "1024", y: "768" })).toBe("1024,768"); + }); + + it("should encode reserved characters", () => { + const template = new UriTemplate("{var}"); + expect(template.expand({ var: "value with spaces" })).toBe( + "value%20with%20spaces", + ); + }); + }); + + describe("reserved expansion", () => { + it("should not encode reserved characters with + operator", () => { + const template = new UriTemplate("{+path}/here"); + expect(template.expand({ path: "/foo/bar" })).toBe("/foo/bar/here"); + }); + }); + + describe("fragment expansion", () => { + it("should add # prefix and not encode reserved chars", () => { + const template = new UriTemplate("X{#var}"); + expect(template.expand({ var: "/test" })).toBe("X#/test"); + }); + }); + + describe("label expansion", () => { + it("should add . prefix", () => { + const template = new UriTemplate("X{.var}"); + expect(template.expand({ var: "test" })).toBe("X.test"); + }); + }); + + describe("path expansion", () => { + it("should add / prefix", () => { + const template = new UriTemplate("X{/var}"); + expect(template.expand({ var: "test" })).toBe("X/test"); + }); + }); + + describe("query expansion", () => { + it("should add ? prefix and name=value format", () => { + const template = new UriTemplate("X{?var}"); + expect(template.expand({ var: "test" })).toBe("X?var=test"); + }); + }); + + describe("form continuation expansion", () => { + it("should add & prefix and name=value format", () => { + const template = new UriTemplate("X{&var}"); + expect(template.expand({ var: "test" })).toBe("X&var=test"); + }); + }); + + describe("matching", () => { + it("should match simple strings and extract variables", () => { + const template = new UriTemplate("http://example.com/users/{username}"); + const match = template.match("http://example.com/users/fred"); + expect(match).toEqual({ username: "fred" }); + }); + + it("should match multiple variables", () => { + const template = new UriTemplate("/users/{username}/posts/{postId}"); + const match = template.match("/users/fred/posts/123"); + expect(match).toEqual({ username: "fred", postId: "123" }); + }); + + it("should return null for non-matching URIs", () => { + const template = new UriTemplate("/users/{username}"); + const match = template.match("/posts/123"); + expect(match).toBeNull(); + }); + + it("should handle exploded arrays", () => { + const template = new UriTemplate("{/list*}"); + const match = template.match("/red,green,blue"); + expect(match).toEqual({ list: ["red", "green", "blue"] }); + }); + }); + + describe("edge cases", () => { + it("should handle empty variables", () => { + const template = new UriTemplate("{empty}"); + expect(template.expand({})).toBe(""); + expect(template.expand({ empty: "" })).toBe(""); + }); + + it("should handle undefined variables", () => { + const template = new UriTemplate("{a}{b}{c}"); + expect(template.expand({ b: "2" })).toBe("2"); + }); + + it("should handle special characters in variable names", () => { + const template = new UriTemplate("{$var_name}"); + expect(template.expand({ "$var_name": "value" })).toBe("value"); + }); + }); + + describe("complex patterns", () => { + it("should handle nested path segments", () => { + const template = new UriTemplate("/api/{version}/{resource}/{id}"); + expect(template.expand({ + version: "v1", + resource: "users", + id: "123" + })).toBe("/api/v1/users/123"); + }); + + it("should handle query parameters with arrays", () => { + const template = new UriTemplate("/search{?tags*}"); + expect(template.expand({ + tags: ["nodejs", "typescript", "testing"] + })).toBe("/search?tags=nodejs,typescript,testing"); + }); + + it("should handle multiple query parameters", () => { + const template = new UriTemplate("/search{?q,page,limit}"); + expect(template.expand({ + q: "test", + page: "1", + limit: "10" + })).toBe("/search?q=test&page=1&limit=10"); + }); + }); + + describe("matching complex patterns", () => { + it("should match nested path segments", () => { + const template = new UriTemplate("/api/{version}/{resource}/{id}"); + const match = template.match("/api/v1/users/123"); + expect(match).toEqual({ + version: "v1", + resource: "users", + id: "123" + }); + }); + + it("should match query parameters", () => { + const template = new UriTemplate("/search{?q}"); + const match = template.match("/search?q=test"); + expect(match).toEqual({ q: "test" }); + }); + + it("should match multiple query parameters", () => { + const template = new UriTemplate("/search{?q,page}"); + const match = template.match("/search?q=test&page=1"); + expect(match).toEqual({ q: "test", page: "1" }); + }); + + it("should handle partial matches correctly", () => { + const template = new UriTemplate("/users/{id}"); + expect(template.match("/users/123/extra")).toBeNull(); + expect(template.match("/users")).toBeNull(); + }); + }); + + describe("security and edge cases", () => { + it("should handle extremely long input strings", () => { + const longString = "x".repeat(100000); + const template = new UriTemplate(`/api/{param}`); + expect(template.expand({ param: longString })).toBe(`/api/${longString}`); + expect(template.match(`/api/${longString}`)).toEqual({ param: longString }); + }); + + it("should handle deeply nested template expressions", () => { + const template = new UriTemplate("{a}{b}{c}{d}{e}{f}{g}{h}{i}{j}".repeat(1000)); + expect(() => template.expand({ + a: "1", b: "2", c: "3", d: "4", e: "5", + f: "6", g: "7", h: "8", i: "9", j: "0" + })).not.toThrow(); + }); + + it("should handle malformed template expressions", () => { + expect(() => new UriTemplate("{unclosed")).toThrow(); + expect(() => new UriTemplate("{}")).not.toThrow(); + expect(() => new UriTemplate("{,}")).not.toThrow(); + expect(() => new UriTemplate("{a}{")).toThrow(); + }); + + it("should handle pathological regex patterns", () => { + const template = new UriTemplate("/api/{param}"); + // Create a string that could cause catastrophic backtracking + const input = "/api/" + "a".repeat(100000); + expect(() => template.match(input)).not.toThrow(); + }); + + it("should handle invalid UTF-8 sequences", () => { + const template = new UriTemplate("/api/{param}"); + const invalidUtf8 = "���"; + expect(() => template.expand({ param: invalidUtf8 })).not.toThrow(); + expect(() => template.match(`/api/${invalidUtf8}`)).not.toThrow(); + }); + + it("should handle template/URI length mismatches", () => { + const template = new UriTemplate("/api/{param}"); + expect(template.match("/api/")).toBeNull(); + expect(template.match("/api")).toBeNull(); + expect(template.match("/api/value/extra")).toBeNull(); + }); + + it("should handle repeated operators", () => { + const template = new UriTemplate("{?a}{?b}{?c}"); + expect(template.expand({ a: "1", b: "2", c: "3" })).toBe("?a=1&b=2&c=3"); + }); + + it("should handle overlapping variable names", () => { + const template = new UriTemplate("{var}{vara}"); + expect(template.expand({ var: "1", vara: "2" })).toBe("12"); + }); + + it("should handle empty segments", () => { + const template = new UriTemplate("///{a}////{b}////"); + expect(template.expand({ a: "1", b: "2" })).toBe("///1////2////"); + expect(template.match("///1////2////")).toEqual({ a: "1", b: "2" }); + }); + + it("should handle maximum template expression limit", () => { + // Create a template with many expressions + const expressions = Array(10000).fill("{param}").join(""); + expect(() => new UriTemplate(expressions)).not.toThrow(); + }); + + it("should handle maximum variable name length", () => { + const longName = "a".repeat(10000); + const template = new UriTemplate(`{${longName}}`); + const vars: Record = {}; + vars[longName] = "value"; + expect(() => template.expand(vars)).not.toThrow(); + }); + }); +}); diff --git a/src/shared/uriTemplate.ts b/src/shared/uriTemplate.ts new file mode 100644 index 0000000..bb17732 --- /dev/null +++ b/src/shared/uriTemplate.ts @@ -0,0 +1,312 @@ +// Claude-authored implementation of RFC 6570 URI Templates + +export type Variables = Record; + +const MAX_TEMPLATE_LENGTH = 1000000; // 1MB +const MAX_VARIABLE_LENGTH = 1000000; // 1MB +const MAX_TEMPLATE_EXPRESSIONS = 10000; +const MAX_REGEX_LENGTH = 1000000; // 1MB + +export class UriTemplate { + /** + * Returns true if the given string contains any URI template expressions. + * A template expression is a sequence of characters enclosed in curly braces, + * like {foo} or {?bar}. + */ + static isTemplate(str: string): boolean { + // Look for any sequence of characters between curly braces + // that isn't just whitespace + return /\{[^}\s]+\}/.test(str); + } + + private static validateLength( + str: string, + max: number, + context: string, + ): void { + if (str.length > max) { + throw new Error( + `${context} exceeds maximum length of ${max} characters (got ${str.length})`, + ); + } + } + private readonly template: string; + private readonly parts: Array< + | string + | { name: string; operator: string; names: string[]; exploded: boolean } + >; + + constructor(template: string) { + UriTemplate.validateLength(template, MAX_TEMPLATE_LENGTH, "Template"); + this.template = template; + this.parts = this.parse(template); + } + + toString(): string { + return this.template; + } + + private parse( + template: string, + ): Array< + | string + | { name: string; operator: string; names: string[]; exploded: boolean } + > { + const parts: Array< + | string + | { name: string; operator: string; names: string[]; exploded: boolean } + > = []; + let currentText = ""; + let i = 0; + let expressionCount = 0; + + while (i < template.length) { + if (template[i] === "{") { + if (currentText) { + parts.push(currentText); + currentText = ""; + } + const end = template.indexOf("}", i); + if (end === -1) throw new Error("Unclosed template expression"); + + expressionCount++; + if (expressionCount > MAX_TEMPLATE_EXPRESSIONS) { + throw new Error( + `Template contains too many expressions (max ${MAX_TEMPLATE_EXPRESSIONS})`, + ); + } + + const expr = template.slice(i + 1, end); + const operator = this.getOperator(expr); + const exploded = expr.includes("*"); + const names = this.getNames(expr); + const name = names[0]; + + // Validate variable name length + for (const name of names) { + UriTemplate.validateLength( + name, + MAX_VARIABLE_LENGTH, + "Variable name", + ); + } + + parts.push({ name, operator, names, exploded }); + i = end + 1; + } else { + currentText += template[i]; + i++; + } + } + + if (currentText) { + parts.push(currentText); + } + + return parts; + } + + private getOperator(expr: string): string { + const operators = ["+", "#", ".", "/", "?", "&"]; + return operators.find((op) => expr.startsWith(op)) || ""; + } + + private getNames(expr: string): string[] { + const operator = this.getOperator(expr); + return expr + .slice(operator.length) + .split(",") + .map((name) => name.replace("*", "").trim()) + .filter((name) => name.length > 0); + } + + private encodeValue(value: string, operator: string): string { + UriTemplate.validateLength(value, MAX_VARIABLE_LENGTH, "Variable value"); + if (operator === "+" || operator === "#") { + return encodeURI(value); + } + return encodeURIComponent(value); + } + + private expandPart( + part: { + name: string; + operator: string; + names: string[]; + exploded: boolean; + }, + variables: Variables, + ): string { + if (part.operator === "?" || part.operator === "&") { + const pairs = part.names + .map((name) => { + const value = variables[name]; + if (value === undefined) return ""; + const encoded = Array.isArray(value) + ? value.map((v) => this.encodeValue(v, part.operator)).join(",") + : this.encodeValue(value.toString(), part.operator); + return `${name}=${encoded}`; + }) + .filter((pair) => pair.length > 0); + + if (pairs.length === 0) return ""; + const separator = part.operator === "?" ? "?" : "&"; + return separator + pairs.join("&"); + } + + if (part.names.length > 1) { + const values = part.names + .map((name) => variables[name]) + .filter((v) => v !== undefined); + if (values.length === 0) return ""; + return values.map((v) => (Array.isArray(v) ? v[0] : v)).join(","); + } + + const value = variables[part.name]; + if (value === undefined) return ""; + + const values = Array.isArray(value) ? value : [value]; + const encoded = values.map((v) => this.encodeValue(v, part.operator)); + + switch (part.operator) { + case "": + return encoded.join(","); + case "+": + return encoded.join(","); + case "#": + return "#" + encoded.join(","); + case ".": + return "." + encoded.join("."); + case "/": + return "/" + encoded.join("/"); + default: + return encoded.join(","); + } + } + + expand(variables: Variables): string { + let result = ""; + let hasQueryParam = false; + + for (const part of this.parts) { + if (typeof part === "string") { + result += part; + continue; + } + + const expanded = this.expandPart(part, variables); + if (!expanded) continue; + + // Convert ? to & if we already have a query parameter + if ((part.operator === "?" || part.operator === "&") && hasQueryParam) { + result += expanded.replace("?", "&"); + } else { + result += expanded; + } + + if (part.operator === "?" || part.operator === "&") { + hasQueryParam = true; + } + } + + return result; + } + + private escapeRegExp(str: string): string { + return str.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); + } + + private partToRegExp(part: { + name: string; + operator: string; + names: string[]; + exploded: boolean; + }): Array<{ pattern: string; name: string }> { + const patterns: Array<{ pattern: string; name: string }> = []; + + // Validate variable name length for matching + for (const name of part.names) { + UriTemplate.validateLength(name, MAX_VARIABLE_LENGTH, "Variable name"); + } + + if (part.operator === "?" || part.operator === "&") { + for (let i = 0; i < part.names.length; i++) { + const name = part.names[i]; + const prefix = i === 0 ? "\\" + part.operator : "&"; + patterns.push({ + pattern: prefix + this.escapeRegExp(name) + "=([^&]+)", + name, + }); + } + return patterns; + } + + let pattern: string; + const name = part.name; + + switch (part.operator) { + case "": + pattern = part.exploded ? "([^/]+(?:,[^/]+)*)" : "([^/,]+)"; + break; + case "+": + case "#": + pattern = "(.+)"; + break; + case ".": + pattern = "\\.([^/,]+)"; + break; + case "/": + pattern = "/" + (part.exploded ? "([^/]+(?:,[^/]+)*)" : "([^/,]+)"); + break; + default: + pattern = "([^/]+)"; + } + + patterns.push({ pattern, name }); + return patterns; + } + + match(uri: string): Variables | null { + UriTemplate.validateLength(uri, MAX_TEMPLATE_LENGTH, "URI"); + let pattern = "^"; + const names: Array<{ name: string; exploded: boolean }> = []; + + for (const part of this.parts) { + if (typeof part === "string") { + pattern += this.escapeRegExp(part); + } else { + const patterns = this.partToRegExp(part); + for (const { pattern: partPattern, name } of patterns) { + pattern += partPattern; + names.push({ name, exploded: part.exploded }); + } + } + } + + pattern += "$"; + UriTemplate.validateLength( + pattern, + MAX_REGEX_LENGTH, + "Generated regex pattern", + ); + const regex = new RegExp(pattern); + const match = uri.match(regex); + + if (!match) return null; + + const result: Variables = {}; + for (let i = 0; i < names.length; i++) { + const { name, exploded } = names[i]; + const value = match[i + 1]; + const cleanName = name.replace("*", ""); + + if (exploded && value.includes(",")) { + result[cleanName] = value.split(","); + } else { + result[cleanName] = value; + } + } + + return result; + } +}