diff --git a/README.md b/README.md index 9752db4..0864e5d 100644 --- a/README.md +++ b/README.md @@ -482,6 +482,7 @@ npm publish ```bash npm install # Install dependencies npm run dev # Development with hot reload +npm test # Run tests npm run build # Build for production npm start # Start production server ``` diff --git a/package-lock.json b/package-lock.json index d78d713..2b41613 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16,7 +16,7 @@ "replicate": "^1.1.0" }, "bin": { - "imagegen-mcp": "dist/server.js" + "imagegen-mcp-server": "dist/server.js" }, "devDependencies": { "@types/node": "^24.3.1", @@ -464,6 +464,7 @@ "version": "1.17.5", "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.17.5.tgz", "integrity": "sha512-QakrKIGniGuRVfWBdMsDea/dx1PNE739QJ7gCM41s9q+qaCYTHCdsIBXQVVXry3mfWAiaM9kT22Hyz53Uw8mfg==", + "peer": true, "dependencies": { "ajv": "^6.12.6", "content-type": "^1.0.5", @@ -899,6 +900,7 @@ "version": "5.1.0", "resolved": "https://registry.npmjs.org/express/-/express-5.1.0.tgz", "integrity": "sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==", + "peer": true, "dependencies": { "accepts": "^2.0.0", "body-parser": "^2.2.0", @@ -1923,6 +1925,7 @@ "version": "3.25.76", "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", + "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } @@ -2132,6 +2135,7 @@ "version": "1.17.5", "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.17.5.tgz", "integrity": "sha512-QakrKIGniGuRVfWBdMsDea/dx1PNE739QJ7gCM41s9q+qaCYTHCdsIBXQVVXry3mfWAiaM9kT22Hyz53Uw8mfg==", + "peer": true, "requires": { "ajv": "^6.12.6", "content-type": "^1.0.5", @@ -2433,6 +2437,7 @@ "version": "5.1.0", "resolved": "https://registry.npmjs.org/express/-/express-5.1.0.tgz", "integrity": "sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==", + "peer": true, "requires": { "accepts": "^2.0.0", "body-parser": "^2.2.0", @@ -3128,7 +3133,8 @@ "zod": { "version": "3.25.76", "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", - "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==" + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", + "peer": true }, "zod-to-json-schema": { "version": "3.24.6", diff --git a/package.json b/package.json index 022a575..e08f6e3 100644 --- a/package.json +++ b/package.json @@ -33,7 +33,7 @@ "start": "node dist/server.js", "dev": "tsx watch src/server.ts", "prepublishOnly": "npm run build", - "test": "echo \"No tests yet\" && exit 0" + "test": "node --import tsx --test" }, "files": [ "dist/**/*", diff --git a/src/providers/gemini.ts b/src/providers/gemini.ts index c2d9644..4fe4fae 100644 --- a/src/providers/gemini.ts +++ b/src/providers/gemini.ts @@ -1,15 +1,28 @@ import { GoogleGenAI } from "@google/genai"; -import { GenerateImageArgs, GeneratedImage } from "../types.js"; +import { GenerateImageArgs, GeneratedImage, Provider } from "../types.js"; import { saveBase64Image } from "../utils/fs.js"; +import { z } from "zod"; const DEFAULT_MODEL = "gemini-2.5-flash-image-preview"; -export async function generateImageGemini(args: GenerateImageArgs): Promise { - const apiKey = process.env.GOOGLE_API_KEY; +type GeminiDeps = { + GoogleGenAI?: typeof GoogleGenAI; + saveBase64Image?: typeof saveBase64Image; + now?: () => number; + env?: NodeJS.ProcessEnv; +}; + +async function generateImageGemini( + args: GenerateImageArgs, + deps: GeminiDeps = {} +): Promise { + const env = deps.env ?? process.env; + const apiKey = env.GOOGLE_API_KEY; if (!apiKey) throw new Error("Missing GOOGLE_API_KEY environment variable"); - const model = args.model || process.env.GOOGLE_IMAGE_MODEL || DEFAULT_MODEL; - const ai = new GoogleGenAI({ apiKey }); + const model = args.model || env.GOOGLE_IMAGE_MODEL || DEFAULT_MODEL; + const GoogleGenAIClient = deps.GoogleGenAI ?? GoogleGenAI; + const ai = new GoogleGenAIClient({ apiKey }); // The new SDK accepts contents as a string for simple prompts const response = await ai.models.generateContent({ @@ -47,14 +60,16 @@ export async function generateImageGemini(args: GenerateImageArgs): Promise number; + env?: NodeJS.ProcessEnv; +}; /** * Placeholder Google image generation provider. @@ -17,9 +25,13 @@ import { saveBase64Image } from "../utils/fs.js"; * } * and returns JSON with a base64-encoded image under `image.base64` and optional mimeType. */ -export async function generateImageGoogle(args: GenerateImageArgs): Promise { - const apiKey = process.env.GOOGLE_API_KEY; - const endpoint = process.env.GOOGLE_IMAGEN_ENDPOINT; +async function generateImageGoogle( + args: GenerateImageArgs, + deps: GoogleDeps = {} +): Promise { + const env = deps.env ?? process.env; + const apiKey = env.GOOGLE_API_KEY; + const endpoint = env.GOOGLE_IMAGEN_ENDPOINT; if (!apiKey) throw new Error("Missing GOOGLE_API_KEY environment variable"); if (!endpoint) throw new Error( @@ -39,8 +51,8 @@ export async function generateImageGoogle(args: GenerateImageArgs): Promise { - const apiKey = process.env.OPENAI_API_KEY; +type OpenAIDeps = { + OpenAI?: typeof OpenAI; + saveBase64Image?: typeof saveBase64Image; + now?: () => number; + env?: NodeJS.ProcessEnv; +}; + +async function generateImageOpenAI( + args: GenerateImageArgs, + deps: OpenAIDeps = {} +): Promise { + const env = deps.env ?? process.env; + const apiKey = env.OPENAI_API_KEY; if (!apiKey) { throw new Error("Missing OPENAI_API_KEY environment variable"); } - const client = new OpenAI({ apiKey }); + const OpenAIClient = deps.OpenAI ?? OpenAI; + const client = new OpenAIClient({ apiKey }); const format = (args.format ?? "png").toLowerCase(); const ext = format === "jpeg" ? "jpg" : (format as string); @@ -18,17 +31,17 @@ export async function generateImageOpenAI(args: GenerateImageArgs): Promise width/height > default const size = args.size ? args.size : args.width && args.height - ? `${args.width}x${args.height}` - : "1024x1024"; + ? `${args.width}x${args.height}` + : "1024x1024"; - const model = args.model || process.env.OPENAI_IMAGE_MODEL || DEFAULT_MODEL; + const model = args.model || env.OPENAI_IMAGE_MODEL || DEFAULT_MODEL; // Build request parameters based on model capabilities const requestParams: any = { @@ -66,14 +79,16 @@ export async function generateImageOpenAI(args: GenerateImageArgs): Promise number; + env?: NodeJS.ProcessEnv; +}; + // Available Replicate models -export const REPLICATE_MODELS = { +const REPLICATE_MODELS = { "flux-1.1-pro": "black-forest-labs/flux-1.1-pro", "qwen-image": "qwen/qwen-image", "seedream-4": "bytedance/seedream-4", } as const; -export async function generateImageReplicate(args: GenerateImageArgs): Promise { - const apiKey = process.env.REPLICATE_API_TOKEN; +async function generateImageReplicate( + args: GenerateImageArgs, + deps: ReplicateDeps = {} +): Promise { + const env = deps.env ?? process.env; + const apiKey = env.REPLICATE_API_TOKEN; if (!apiKey) { throw new Error("Missing REPLICATE_API_TOKEN environment variable"); } - const replicate = new Replicate({ auth: apiKey }); + const ReplicateClient = deps.Replicate ?? Replicate; + const replicate = new ReplicateClient({ auth: apiKey }); const format = (args.format ?? "png").toLowerCase(); const ext = format === "jpeg" ? "jpg" : (format as string); @@ -25,8 +39,8 @@ export async function generateImageReplicate(args: GenerateImageArgs): Promise { @@ -42,7 +56,7 @@ export async function generateImageReplicate(args: GenerateImageArgs): Promise (model === DEFAULT_MODEL ? `${model} (default)` : model)) + .join(", ")}. Requires REPLICATE_API_TOKEN.`, + inputSchema: { + prompt: z.string(), + model: z.string().optional(), + width: z.number().int().positive().optional(), + height: z.number().int().positive().optional(), + size: z.string().optional(), + format: z.enum(["png", "jpeg", "jpg", "webp"]).optional(), + seed: z.number().int().optional(), + returnBase64: z.boolean().optional(), + filenameHint: z.string().optional(), + }, + }, + generateImage: generateImageReplicate, +}; \ No newline at end of file diff --git a/src/server.ts b/src/server.ts index 51a5afb..6452b9d 100644 --- a/src/server.ts +++ b/src/server.ts @@ -2,149 +2,68 @@ import "dotenv/config"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import { generateImageOpenAI } from "./providers/openai.js"; -import { generateImageGoogle } from "./providers/google.js"; -import { generateImageGemini } from "./providers/gemini.js"; -import { generateImageReplicate } from "./providers/replicate.js"; -import { GenerateImageArgs } from "./types.js"; -import { z } from "zod"; +import { pathToFileURL } from "node:url"; +import { openaiProvider } from "./providers/openai.js"; +import { googleProvider } from "./providers/google.js"; +import { geminiProvider } from "./providers/gemini.js"; +import { replicateProvider } from "./providers/replicate.js"; +import { GenerateImageArgs, Provider } from "./types.js"; -const server = new McpServer({ - name: "imagegen-mcp-server", - version: "0.1.0", -}); +export const providers: Provider[] = [ + openaiProvider, + googleProvider, + geminiProvider, + replicateProvider, +]; -server.registerTool( - "image.generate.openai", - { - description: - "Generate an image using OpenAI (default model gpt-image-1). Returns a saved file path and optional base64.", - inputSchema: { - prompt: z.string(), - negativePrompt: z.string().optional(), - width: z.number().int().positive().optional(), - height: z.number().int().positive().optional(), - size: z.string().optional(), - format: z.enum(["png", "jpeg", "jpg", "webp"]).optional(), - seed: z.number().int().optional(), - quality: z.enum(["standard", "hd"]).optional(), - style: z.string().optional(), - background: z.enum(["transparent", "solid"]).optional(), - model: z.string().optional(), - returnBase64: z.boolean().optional(), - filenameHint: z.string().optional(), - }, - }, - async (args) => { - const r = await generateImageOpenAI(args as unknown as GenerateImageArgs); - const parts = [] as any[]; - parts.push({ - type: "text", - text: `provider=openai model=${r.model ?? ""} saved=${r.path}`.trim(), - }); - if (r.base64) - parts.push({ type: "image", data: r.base64, mimeType: r.mimeType }); - return { content: parts } as any; - } -); +function formatTextResult(provider: Provider, result: { model?: string; path?: string }) { + const label = provider.responseProviderLabel ?? provider.id; + const bits: string[] = [`provider=${label}`]; + if (result.model) bits.push(`model=${result.model}`); + if (result.path) bits.push(`saved=${result.path}`); + return bits.join(" "); +} -server.registerTool( - "image.generate.google", - { - description: - "Generate an image using Google (e.g., Imagen 3). Requires GOOGLE_API_KEY and GOOGLE_IMAGEN_ENDPOINT. Returns a saved file path and optional base64.", - inputSchema: { - prompt: z.string(), - negativePrompt: z.string().optional(), - width: z.number().int().positive().optional(), - height: z.number().int().positive().optional(), - size: z.string().optional(), - format: z.enum(["png", "jpeg", "jpg", "webp"]).optional(), - seed: z.number().int().optional(), - quality: z.string().optional(), - style: z.string().optional(), - background: z.enum(["transparent", "solid"]).optional(), - model: z.string().optional(), - returnBase64: z.boolean().optional(), - filenameHint: z.string().optional(), - }, - }, - async (args) => { - const r = await generateImageGoogle(args as unknown as GenerateImageArgs); - const parts = [] as any[]; - parts.push({ - type: "text", - text: `provider=google model=${r.model ?? ""} saved=${r.path}`.trim(), - }); - if (r.base64) - parts.push({ type: "image", data: r.base64, mimeType: r.mimeType }); - return { content: parts } as any; +export function registerProviders( + server: { registerTool: (...args: any[]) => any }, + providersToRegister: Provider[] = providers +) { + for (const provider of providersToRegister) { + server.registerTool( + provider.tool.name, + { + description: provider.tool.description, + inputSchema: provider.tool.inputSchema, + }, + async (args: unknown) => { + const r = await provider.generateImage(args as GenerateImageArgs); + const parts = [] as any[]; + parts.push({ + type: "text", + text: formatTextResult(provider, { model: r.model, path: r.path }), + }); + if (r.base64) parts.push({ type: "image", data: r.base64, mimeType: r.mimeType }); + return { content: parts } as any; + } + ); } -); +} -server.registerTool( - "image.generate.gemini", - { - description: - "Generate an image using Google Gemini via @google/genai (default gemini-2.5-flash-image-preview). Requires GOOGLE_API_KEY.", - inputSchema: { - prompt: z.string(), - model: z.string().optional(), - returnBase64: z.boolean().optional(), - filenameHint: z.string().optional(), - }, - }, - async (args) => { - const r = await generateImageGemini(args as unknown as GenerateImageArgs); - const parts = [] as any[]; - parts.push({ - type: "text", - text: `provider=google(gemini) model=${r.model ?? ""} saved=${ - r.path - }`.trim(), - }); - if (r.base64) - parts.push({ type: "image", data: r.base64, mimeType: r.mimeType }); - return { content: parts } as any; - } -); +async function main() { + const server = new McpServer({ + name: "imagegen-mcp-server", + version: "0.1.0", + }); -server.registerTool( - "image.generate.replicate", - { - description: - "Generate an image using Replicate models: Flux 1.1 Pro (default), Qwen Image, or SeedDream-4. Requires REPLICATE_API_TOKEN.", - inputSchema: { - prompt: z.string(), - model: z.string().optional(), - width: z.number().int().positive().optional(), - height: z.number().int().positive().optional(), - size: z.string().optional(), - format: z.enum(["png", "jpeg", "jpg", "webp"]).optional(), - seed: z.number().int().optional(), - returnBase64: z.boolean().optional(), - filenameHint: z.string().optional(), - }, - }, - async (args) => { - const r = await generateImageReplicate(args as unknown as GenerateImageArgs); - const parts = [] as any[]; - parts.push({ - type: "text", - text: `provider=replicate model=${r.model ?? ""} saved=${r.path}`.trim(), - }); - if (r.base64) - parts.push({ type: "image", data: r.base64, mimeType: r.mimeType }); - return { content: parts } as any; - } -); + registerProviders(server); -async function main() { const transport = new StdioServerTransport(); await server.connect(transport); } -main().catch((err) => { - console.error("Fatal error starting server:", err); - process.exit(1); -}); +if (import.meta.url === pathToFileURL(process.argv[1] ?? "").href) { + main().catch((err) => { + console.error("Fatal error starting server:", err); + process.exit(1); + }); +} diff --git a/src/types.ts b/src/types.ts index 669ba5c..26791d7 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,5 +1,7 @@ export type ImageFormat = "png" | "jpeg" | "jpg" | "webp"; +import type { ZodTypeAny } from "zod"; + export interface GenerateImageArgs { prompt: string; negativePrompt?: string; @@ -27,3 +29,20 @@ export interface GeneratedImage { provider: "openai" | "google" | "replicate"; model?: string; } + +export type ProviderId = "openai" | "google" | "gemini" | "replicate"; + +export type ProviderToolDefinition = { + name: string; // MCP tool name (e.g. "image.generate.openai") + description: string; + inputSchema: Record; +}; + +export type Provider = { + id: ProviderId; + displayName: string; + tool: ProviderToolDefinition; + generateImage: (args: GenerateImageArgs) => Promise; + // Optional label used in the text response (e.g. "google(gemini)") + responseProviderLabel?: string; +}; diff --git a/test/providers/gemini.provider.test.ts b/test/providers/gemini.provider.test.ts new file mode 100644 index 0000000..3bdbf86 --- /dev/null +++ b/test/providers/gemini.provider.test.ts @@ -0,0 +1,57 @@ +import { describe, it } from "node:test"; +import assert from "node:assert/strict"; + +const { geminiProvider } = await import("../../src/providers/gemini.ts"); + +describe("geminiProvider", () => { + describe("generateImage", () => { + it("returns image data without calling network", async () => { + const env = { + GOOGLE_API_KEY: "test-key", + OUTPUT_DIR: "outputs-test", + } as NodeJS.ProcessEnv; + + class FakeGoogleGenAI { + models = { + generateContent: async (_req: any) => ({ + candidates: [ + { + content: { + parts: [ + { + inlineData: { + data: "AAAA", + mimeType: "image/png", + }, + }, + ], + }, + }, + ], + }), + }; + constructor(_opts: any) { } + } + + const save = (_b64: string, outputDir: string, filename: string) => `${outputDir}/${filename}`; + const now = () => 123; + + const r = await (geminiProvider.generateImage as any)( + { prompt: "a mountain", returnBase64: true }, + { env, GoogleGenAI: FakeGoogleGenAI as any, saveBase64Image: save as any, now } + ); + assert.equal(r.provider, "google"); + assert.equal(r.mimeType, "image/png"); + assert.ok(r.path?.startsWith("outputs-test")); + assert.ok(r.base64); + }); + }); + + describe("tool descriptor", () => { + it("preserves response label", () => { + assert.equal(geminiProvider.id, "gemini"); + assert.equal(geminiProvider.responseProviderLabel, "google(gemini)"); + assert.ok(geminiProvider.tool.name); + }); + }); +}); diff --git a/test/providers/google.provider.test.ts b/test/providers/google.provider.test.ts new file mode 100644 index 0000000..4920022 --- /dev/null +++ b/test/providers/google.provider.test.ts @@ -0,0 +1,52 @@ +import { describe, it } from "node:test"; +import assert from "node:assert/strict"; + +const { googleProvider } = await import("../../src/providers/google.ts"); + +describe("googleProvider", () => { + describe("generateImage", () => { + it("uses fetch endpoint and returns image", async () => { + const env = { + GOOGLE_API_KEY: "test-key", + GOOGLE_IMAGEN_ENDPOINT: "https://generativelanguage.googleapis.com/v1/images", + OUTPUT_DIR: "outputs-test", + } as NodeJS.ProcessEnv; + + let lastFetchArgs: any[] | undefined; + const fetchFn = async (...args: any[]) => { + lastFetchArgs = args; + return { + ok: true, + async json() { + return { image: { base64: "AAAA", mimeType: "image/png" } }; + }, + }; + }; + + const save = (_b64: string, outputDir: string, filename: string) => `${outputDir}/${filename}`; + const now = () => 123; + + const r = await (googleProvider.generateImage as any)( + { prompt: "dog" }, + { env, fetch: fetchFn as any, saveBase64Image: save as any, now } + ); + assert.equal(r.provider, "google"); + assert.equal(r.mimeType, "image/png"); + assert.ok(r.path?.startsWith("outputs-test")); + + assert.ok(lastFetchArgs); + const [_url, init] = lastFetchArgs!; + assert.equal(init.method, "POST"); + assert.equal(init.headers["x-goog-api-key"], "test-key"); + }); + }); + + describe("tool descriptor", () => { + it("exposes tool metadata", () => { + assert.equal(googleProvider.id, "google"); + assert.ok(googleProvider.tool.name); + assert.ok(googleProvider.tool.description); + assert.ok(googleProvider.tool.inputSchema.prompt); + }); + }); +}); diff --git a/test/providers/openai.provider.test.ts b/test/providers/openai.provider.test.ts new file mode 100644 index 0000000..7a1a8f3 --- /dev/null +++ b/test/providers/openai.provider.test.ts @@ -0,0 +1,52 @@ +import { describe, it } from "node:test"; +import assert from "node:assert/strict"; + +const { openaiProvider } = await import("../../src/providers/openai.ts"); + +describe("openaiProvider", () => { + describe("generateImage", () => { + it("returns saved path and optional base64", async () => { + const env = { + OPENAI_API_KEY: "test-key", + OUTPUT_DIR: "outputs-test", + } as NodeJS.ProcessEnv; + + class FakeOpenAI { + images = { + generate: async (_params: any) => ({ + data: [{ b64_json: "aGVsbG8=" }], + }), + }; + constructor(_opts: any) { } + } + + const save = (_b64: string, outputDir: string, filename: string) => `${outputDir}/${filename}`; + const now = () => 123; + + const r1 = await (openaiProvider.generateImage as any)( + { prompt: "cat" }, + { OpenAI: FakeOpenAI as any, env, saveBase64Image: save as any, now } + ); + assert.equal(r1.provider, "openai"); + assert.equal(r1.mimeType, "image/png"); + assert.ok(r1.path?.startsWith("outputs-test")); + assert.equal(r1.base64, undefined); + + const r2 = await (openaiProvider.generateImage as any)( + { prompt: "cat", returnBase64: true, format: "png" }, + { OpenAI: FakeOpenAI as any, env, saveBase64Image: save as any, now } + ); + assert.equal(r2.provider, "openai"); + assert.ok(r2.base64); + }); + }); + + describe("tool descriptor", () => { + it("exposes tool metadata", () => { + assert.equal(openaiProvider.id, "openai"); + assert.ok(openaiProvider.tool.name); + assert.ok(openaiProvider.tool.description); + assert.ok(openaiProvider.tool.inputSchema.prompt); + }); + }); +}); diff --git a/test/providers/providers.test.ts b/test/providers/providers.test.ts new file mode 100644 index 0000000..6d85abf --- /dev/null +++ b/test/providers/providers.test.ts @@ -0,0 +1,28 @@ +import { describe, it } from "node:test"; +import assert from "node:assert/strict"; + +import { providers } from "../../src/server.ts"; + +describe("providers", () => { + describe("tool metadata", () => { + it("has unique tool names", () => { + const names = providers.map((p) => p.tool.name); + const unique = new Set(names); + assert.equal(unique.size, names.length); + }); + + it("has required fields", () => { + for (const p of providers) { + assert.ok(p.id); + assert.ok(p.displayName); + + assert.ok(p.tool?.name); + assert.ok(p.tool?.description); + assert.ok(p.tool?.inputSchema); + + // All tools should at least accept prompt + assert.ok("prompt" in p.tool.inputSchema); + } + }); + }); +}); diff --git a/test/providers/replicate.provider.test.ts b/test/providers/replicate.provider.test.ts new file mode 100644 index 0000000..a70df27 --- /dev/null +++ b/test/providers/replicate.provider.test.ts @@ -0,0 +1,53 @@ +import { describe, it } from "node:test"; +import assert from "node:assert/strict"; + +const { replicateProvider } = await import("../../src/providers/replicate.ts"); + +describe("replicateProvider", () => { + describe("generateImage", () => { + it("downloads image URL and returns saved path", async () => { + const env = { + REPLICATE_API_TOKEN: "test-token", + OUTPUT_DIR: "outputs-test", + } as NodeJS.ProcessEnv; + + class FakeReplicate { + constructor(_opts: any) { } + async run(_model: any, _opts: any) { + return "https://example.test/image.webp"; + } + } + + const fetchFn = async (url: string) => { + assert.equal(url, "https://example.test/image.webp"); + return { + ok: true, + statusText: "OK", + async arrayBuffer() { + return new Uint8Array([1, 2, 3, 4]).buffer; + }, + }; + }; + + const save = (_b64: string, outputDir: string, filename: string) => `${outputDir}/${filename}`; + const now = () => 123; + + const r = await (replicateProvider.generateImage as any)( + { prompt: "robot", format: "webp", returnBase64: true }, + { env, Replicate: FakeReplicate as any, fetch: fetchFn as any, saveBase64Image: save as any, now } + ); + assert.equal(r.provider, "replicate"); + assert.equal(r.mimeType, "image/webp"); + assert.ok(r.path?.startsWith("outputs-test")); + assert.ok(r.base64); + }); + }); + + describe("tool descriptor", () => { + it("exposes tool metadata", () => { + assert.equal(replicateProvider.id, "replicate"); + assert.ok(replicateProvider.tool.name); + assert.ok(replicateProvider.tool.inputSchema.prompt); + }); + }); +}); diff --git a/test/server-registration.test.ts b/test/server-registration.test.ts new file mode 100644 index 0000000..7034321 --- /dev/null +++ b/test/server-registration.test.ts @@ -0,0 +1,103 @@ +import { describe, it } from "node:test"; +import assert from "node:assert/strict"; + +import { registerProviders } from "../src/server.ts"; +import type { Provider } from "../src/types.ts"; + +describe("registerProviders", () => { + describe("registration", () => { + it("registers one tool per provider", () => { + const calls: Array<{ name: string; meta: any; handler: Function }> = []; + + const fakeServer = { + registerTool(name: string, meta: any, handler: Function) { + calls.push({ name, meta, handler }); + }, + }; + + const p1: Provider = { + id: "openai", + displayName: "OpenAI", + tool: { + name: "image.generate.openai", + description: "desc", + inputSchema: { prompt: { _fake: true } as any }, + }, + generateImage: async () => ({ + mimeType: "image/png", + provider: "openai", + path: "outputs/x.png", + model: "m", + }), + }; + + const p2: Provider = { + id: "gemini", + displayName: "Gemini", + responseProviderLabel: "google(gemini)", + tool: { + name: "image.generate.gemini", + description: "desc", + inputSchema: { prompt: { _fake: true } as any }, + }, + generateImage: async () => ({ + mimeType: "image/png", + provider: "google", + path: "outputs/y.png", + model: "m2", + base64: "Zm9v", + }), + }; + + registerProviders(fakeServer as any, [p1, p2]); + + assert.equal(calls.length, 2); + assert.equal(calls[0].name, "image.generate.openai"); + assert.equal(calls[1].name, "image.generate.gemini"); + }); + }); + + describe("handler", () => { + it("formats text and optional image", async () => { + const calls: Array<{ name: string; meta: any; handler: (args: unknown) => Promise }> = []; + + const fakeServer = { + registerTool(name: string, meta: any, handler: any) { + calls.push({ name, meta, handler }); + }, + }; + + const provider: Provider = { + id: "replicate", + displayName: "Replicate", + tool: { + name: "image.generate.replicate", + description: "desc", + inputSchema: { prompt: { _fake: true } as any }, + }, + generateImage: async (args: any) => ({ + mimeType: "image/webp", + provider: "replicate", + model: "flux", + path: "outputs/z.webp", + base64: args.returnBase64 ? "AAAA" : undefined, + }), + }; + + registerProviders(fakeServer as any, [provider]); + assert.equal(calls.length, 1); + + const r1 = await calls[0].handler({ prompt: "hi" }); + assert.equal(r1.content[0].type, "text"); + assert.match(r1.content[0].text, /provider=replicate/); + assert.match(r1.content[0].text, /model=flux/); + assert.match(r1.content[0].text, /saved=outputs\/z\.webp/); + assert.equal(r1.content.length, 1); + + const r2 = await calls[0].handler({ prompt: "hi", returnBase64: true }); + assert.equal(r2.content[0].type, "text"); + assert.equal(r2.content[1].type, "image"); + assert.equal(r2.content[1].mimeType, "image/webp"); + }); + }); +});