diff --git a/eslint-rules/enforce-zod-v4.js b/eslint-rules/enforce-zod-v4.js index c631412f..86f33d75 100644 --- a/eslint-rules/enforce-zod-v4.js +++ b/eslint-rules/enforce-zod-v4.js @@ -3,6 +3,7 @@ import path from "path"; // The file that is allowed to import from zod/v4 const configFilePath = path.resolve(import.meta.dirname, "../src/common/config.ts"); +const schemasFilePath = path.resolve(import.meta.dirname, "../src/common/schemas.ts"); // Ref: https://eslint.org/docs/latest/extend/custom-rules export default { diff --git a/scripts/generateArguments.ts b/scripts/generateArguments.ts index 414d06ba..ee3695c4 100644 --- a/scripts/generateArguments.ts +++ b/scripts/generateArguments.ts @@ -11,9 +11,10 @@ import { readFileSync, writeFileSync } from "fs"; import { join, dirname } from "path"; import { fileURLToPath } from "url"; -import { OPTIONS, UserConfigSchema, defaultUserConfig, configRegistry } from "../src/common/config.js"; +import { UserConfigSchema, configRegistry } from "../src/common/config.js"; import assert from "assert"; import { execSync } from "child_process"; +import { OPTIONS } from "../src/common/argsParserOptions.js"; const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); @@ -68,7 +69,8 @@ function extractZodDescriptions(): Record { let description = schema.description || `Configuration option: ${key}`; if ("innerType" in schema.def) { - if (schema.def.innerType.def.type === "array") { + // "pipe" is used for our comma-separated arrays + if (schema.def.innerType.def.type === "pipe") { assert( description.startsWith("An array of"), `Field description for field "${key}" with array type does not start with 'An array of'` @@ -255,9 +257,7 @@ function generateReadmeConfigTable(argumentInfos: ArgumentInfo[]): string { const cliOption = `\`${argumentInfo.configKey}\``; const envVarName = `\`${argumentInfo.name}\``; - // Get default value from Zod schema or fallback to defaultUserConfig - const config = defaultUserConfig as unknown as Record; - const defaultValue = argumentInfo.defaultValue ?? config[argumentInfo.configKey]; + const defaultValue = argumentInfo.defaultValue; let defaultValueString = argumentInfo.defaultValueDescription ?? "``"; if (!argumentInfo.defaultValueDescription && defaultValue !== undefined && defaultValue !== null) { diff --git a/src/common/argsParserOptions.ts b/src/common/argsParserOptions.ts new file mode 100644 index 00000000..8decc318 --- /dev/null +++ b/src/common/argsParserOptions.ts @@ -0,0 +1,109 @@ +type ArgsParserOptions = { + string: string[]; + number: string[]; + boolean: string[]; + array: string[]; + alias: Record; + configuration: Record; +}; + +// TODO: Export this from arg-parser or find a better way to do this +// From: https://github.com/mongodb-js/mongosh/blob/main/packages/cli-repl/src/arg-parser.ts +export const OPTIONS = { + number: ["maxDocumentsPerQuery", "maxBytesPerQuery"], + string: [ + "apiBaseUrl", + "apiClientId", + "apiClientSecret", + "connectionString", + "httpHost", + "httpPort", + "idleTimeoutMs", + "logPath", + "notificationTimeoutMs", + "telemetry", + "transport", + "apiVersion", + "authenticationDatabase", + "authenticationMechanism", + "browser", + "db", + "gssapiHostName", + "gssapiServiceName", + "host", + "oidcFlows", + "oidcRedirectUri", + "password", + "port", + "sslCAFile", + "sslCRLFile", + "sslCertificateSelector", + "sslDisabledProtocols", + "sslPEMKeyFile", + "sslPEMKeyPassword", + "sspiHostnameCanonicalization", + "sspiRealmOverride", + "tlsCAFile", + "tlsCRLFile", + "tlsCertificateKeyFile", + "tlsCertificateKeyFilePassword", + "tlsCertificateSelector", + "tlsDisabledProtocols", + "username", + "atlasTemporaryDatabaseUserLifetimeMs", + "exportsPath", + "exportTimeoutMs", + "exportCleanupIntervalMs", + "voyageApiKey", + ], + boolean: [ + "apiDeprecationErrors", + "apiStrict", + "disableEmbeddingsValidation", + "help", + "indexCheck", + "ipv6", + "nodb", + "oidcIdTokenAsAccessToken", + "oidcNoNonce", + "oidcTrustedEndpoint", + "readOnly", + "retryWrites", + "ssl", + "sslAllowInvalidCertificates", + "sslAllowInvalidHostnames", + "sslFIPSMode", + "tls", + "tlsAllowInvalidCertificates", + "tlsAllowInvalidHostnames", + "tlsFIPSMode", + "version", + ], + array: ["disabledTools", "loggers", "confirmationRequiredTools", "previewFeatures"], + alias: { + h: "help", + p: "password", + u: "username", + "build-info": "buildInfo", + browser: "browser", + oidcDumpTokens: "oidcDumpTokens", + oidcRedirectUrl: "oidcRedirectUri", + oidcIDTokenAsAccessToken: "oidcIdTokenAsAccessToken", + }, + configuration: { + "camel-case-expansion": false, + "unknown-options-as-args": true, + "parse-positional-numbers": false, + "parse-numbers": false, + "greedy-arrays": true, + "short-option-groups": false, + }, +} as Readonly; + +export const ALL_CONFIG_KEYS = new Set( + (OPTIONS.string as readonly string[]) + .concat(OPTIONS.number) + .concat(OPTIONS.array) + .concat(OPTIONS.boolean) + .concat(Object.keys(OPTIONS.alias)) +); diff --git a/src/common/config.ts b/src/common/config.ts index 902fbcc2..e7ece022 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -1,184 +1,20 @@ -import path from "path"; -import os from "os"; import argv from "yargs-parser"; import type { CliOptions, ConnectionInfo } from "@mongosh/arg-parser"; import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser"; import { Keychain } from "./keychain.js"; import type { Secret } from "./keychain.js"; -import * as levenshteinModule from "ts-levenshtein"; -import type { Similarity } from "./search/vectorSearchEmbeddingsManager.js"; import { z as z4 } from "zod/v4"; -const levenshtein = levenshteinModule.default; - -// From: https://github.com/mongodb-js/mongosh/blob/main/packages/cli-repl/src/arg-parser.ts -export const OPTIONS = { - number: ["maxDocumentsPerQuery", "maxBytesPerQuery"], - string: [ - "apiBaseUrl", - "apiClientId", - "apiClientSecret", - "connectionString", - "httpHost", - "httpPort", - "idleTimeoutMs", - "logPath", - "notificationTimeoutMs", - "telemetry", - "transport", - "apiVersion", - "authenticationDatabase", - "authenticationMechanism", - "browser", - "db", - "gssapiHostName", - "gssapiServiceName", - "host", - "oidcFlows", - "oidcRedirectUri", - "password", - "port", - "sslCAFile", - "sslCRLFile", - "sslCertificateSelector", - "sslDisabledProtocols", - "sslPEMKeyFile", - "sslPEMKeyPassword", - "sspiHostnameCanonicalization", - "sspiRealmOverride", - "tlsCAFile", - "tlsCRLFile", - "tlsCertificateKeyFile", - "tlsCertificateKeyFilePassword", - "tlsCertificateSelector", - "tlsDisabledProtocols", - "username", - "atlasTemporaryDatabaseUserLifetimeMs", - "exportsPath", - "exportTimeoutMs", - "exportCleanupIntervalMs", - "voyageApiKey", - ], - boolean: [ - "apiDeprecationErrors", - "apiStrict", - "disableEmbeddingsValidation", - "help", - "indexCheck", - "ipv6", - "nodb", - "oidcIdTokenAsAccessToken", - "oidcNoNonce", - "oidcTrustedEndpoint", - "readOnly", - "retryWrites", - "ssl", - "sslAllowInvalidCertificates", - "sslAllowInvalidHostnames", - "sslFIPSMode", - "tls", - "tlsAllowInvalidCertificates", - "tlsAllowInvalidHostnames", - "tlsFIPSMode", - "version", - ], - array: ["disabledTools", "loggers", "confirmationRequiredTools", "previewFeatures"], - alias: { - h: "help", - p: "password", - u: "username", - "build-info": "buildInfo", - browser: "browser", - oidcDumpTokens: "oidcDumpTokens", - oidcRedirectUrl: "oidcRedirectUri", - oidcIDTokenAsAccessToken: "oidcIdTokenAsAccessToken", - }, - configuration: { - "camel-case-expansion": false, - "unknown-options-as-args": true, - "parse-positional-numbers": false, - "parse-numbers": false, - "greedy-arrays": true, - "short-option-groups": false, - }, -} as Readonly; - -interface Options { - string: string[]; - number: string[]; - boolean: string[]; - array: string[]; - alias: Record; - configuration: Record; -} - -export const ALL_CONFIG_KEYS = new Set( - (OPTIONS.string as readonly string[]) - .concat(OPTIONS.number) - .concat(OPTIONS.array) - .concat(OPTIONS.boolean) - .concat(Object.keys(OPTIONS.alias)) -); - -function validateConfigKey(key: string): { valid: boolean; suggestion?: string } { - if (ALL_CONFIG_KEYS.has(key)) { - return { valid: true }; - } - - let minLev = Number.MAX_VALUE; - let suggestion = ""; - - // find the closest match for a suggestion - for (const validKey of ALL_CONFIG_KEYS) { - // check if there is an exact case-insensitive match - if (validKey.toLowerCase() === key.toLowerCase()) { - return { valid: false, suggestion: validKey }; - } +import { + commaSeparatedToArray, + type ConfigFieldMeta, + getExportsPath, + getLogPath, + isConnectionSpecifier, + validateConfigKey, +} from "./configUtils.js"; +import { OPTIONS } from "./argsParserOptions.js"; +import { similarityValues, previewFeatureValues } from "./schemas.js"; - // else, infer something using levenshtein so we suggest a valid key - const lev = levenshtein.get(key, validKey); - if (lev < minLev) { - minLev = lev; - suggestion = validKey; - } - } - - if (minLev <= 2) { - // accept up to 2 typos - return { valid: false, suggestion }; - } - - return { valid: false }; -} - -function isConnectionSpecifier(arg: string | undefined): boolean { - return ( - arg !== undefined && - (arg.startsWith("mongodb://") || - arg.startsWith("mongodb+srv://") || - !(arg.endsWith(".js") || arg.endsWith(".mongodb"))) - ); -} - -/** - * Metadata for config schema fields. - */ -interface ConfigFieldMeta { - /** - * Custom description for the default value, used when generating documentation. - */ - defaultValueDescription?: string; - /** - * Marks the field as containing sensitive/secret information, used for MCP Registry. - * Secret fields will be marked as secret in environment variable definitions. - */ - isSecret?: boolean; - - [key: string]: unknown; -} - -/** - * Custom registry for storing metadata specific to config schema fields. - */ export const configRegistry = z4.registry(); export const UserConfigSchema = z4.object({ @@ -201,7 +37,16 @@ export const UserConfigSchema = z4.object({ ) .register(configRegistry, { isSecret: true }), loggers: z4 - .array(z4.enum(["stderr", "disk", "mcp"])) + .preprocess( + (val: string | string[] | undefined) => commaSeparatedToArray(val), + z4.array(z4.enum(["stderr", "disk", "mcp"])) + ) + .check( + z4.minLength(1, "Cannot be an empty array"), + z4.refine((val) => new Set(val).size === val.length, { + message: "Duplicate loggers found in config", + }) + ) .default(["disk", "mcp"]) .describe("An array of logger types.") .register(configRegistry, { @@ -209,14 +54,15 @@ export const UserConfigSchema = z4.object({ }), logPath: z4 .string() + .default(getLogPath()) .describe("Folder to store logs.") .register(configRegistry, { defaultValueDescription: "see below*" }), disabledTools: z4 - .array(z4.string()) + .preprocess((val: string | string[] | undefined) => commaSeparatedToArray(val), z4.array(z4.string())) .default([]) .describe("An array of tool names, operation types, and/or categories of tools that will be disabled."), confirmationRequiredTools: z4 - .array(z4.string()) + .preprocess((val: string | string[] | undefined) => commaSeparatedToArray(val), z4.array(z4.string())) .default([ "atlas-create-access-list", "atlas-create-db-user", @@ -245,8 +91,11 @@ export const UserConfigSchema = z4.object({ .default("enabled") .describe("When set to disabled, disables telemetry collection."), transport: z4.enum(["stdio", "http"]).default("stdio").describe("Either 'stdio' or 'http'."), - httpPort: z4 + httpPort: z4.coerce .number() + .int() + .min(1, "Invalid httpPort: must be at least 1") + .max(65535, "Invalid httpPort: must be at most 65535") .default(3000) .describe("Port number for the HTTP server (only used when transport is 'http')."), httpHost: z4 @@ -260,21 +109,21 @@ export const UserConfigSchema = z4.object({ .describe( "Header that the HTTP server will validate when making requests (only used when transport is 'http')." ), - idleTimeoutMs: z4 + idleTimeoutMs: z4.coerce .number() .default(600_000) .describe("Idle timeout for a client to disconnect (only applies to http transport)."), - notificationTimeoutMs: z4 + notificationTimeoutMs: z4.coerce .number() .default(540_000) .describe("Notification timeout for a client to be aware of disconnect (only applies to http transport)."), - maxBytesPerQuery: z4 + maxBytesPerQuery: z4.coerce .number() .default(16_777_216) .describe( "The maximum size in bytes for results from a find or aggregate tool call. This serves as an upper bound for the responseBytesLimit parameter in those tools." ), - maxDocumentsPerQuery: z4 + maxDocumentsPerQuery: z4.coerce .number() .default(100) .describe( @@ -282,17 +131,18 @@ export const UserConfigSchema = z4.object({ ), exportsPath: z4 .string() + .default(getExportsPath()) .describe("Folder to store exported data files.") .register(configRegistry, { defaultValueDescription: "see below*" }), - exportTimeoutMs: z4 + exportTimeoutMs: z4.coerce .number() .default(300_000) .describe("Time in milliseconds after which an export is considered expired and eligible for cleanup."), - exportCleanupIntervalMs: z4 + exportCleanupIntervalMs: z4.coerce .number() .default(120_000) .describe("Time in milliseconds between export cleanup cycles that remove expired export files."), - atlasTemporaryDatabaseUserLifetimeMs: z4 + atlasTemporaryDatabaseUserLifetimeMs: z4.coerce .number() .default(14_400_000) .describe( @@ -307,73 +157,32 @@ export const UserConfigSchema = z4.object({ .register(configRegistry, { isSecret: true }), disableEmbeddingsValidation: z4 .boolean() - .optional() + .default(false) .describe("When set to true, disables validation of embeddings dimensions."), - vectorSearchDimensions: z4 + vectorSearchDimensions: z4.coerce .number() .default(1024) .describe("Default number of dimensions for vector search embeddings."), vectorSearchSimilarityFunction: z4 - .custom() - .optional() + .enum(similarityValues) .default("euclidean") .describe("Default similarity function for vector search: 'euclidean', 'cosine', or 'dotProduct'."), previewFeatures: z4 - .array(z4.enum(["vectorSearch"])) + .preprocess( + (val: string | string[] | undefined) => commaSeparatedToArray(val), + z4.array(z4.enum(previewFeatureValues)) + ) .default([]) .describe("An array of preview features that are enabled."), }); -export type PreviewFeature = z4.infer["previewFeatures"][number]; export type UserConfig = z4.infer & CliOptions; -export const defaultUserConfig: UserConfig = { - apiBaseUrl: "https://cloud.mongodb.com/", - logPath: getLogPath(), - exportsPath: getExportsPath(), - exportTimeoutMs: 5 * 60 * 1000, // 5 minutes - exportCleanupIntervalMs: 2 * 60 * 1000, // 2 minutes - disabledTools: [], - telemetry: "enabled", - readOnly: false, - indexCheck: false, - confirmationRequiredTools: [ - "atlas-create-access-list", - "atlas-create-db-user", - "drop-database", - "drop-collection", - "delete-many", - "drop-index", - ], - transport: "stdio", - httpPort: 3000, - httpHost: "127.0.0.1", - loggers: ["disk", "mcp"], - idleTimeoutMs: 10 * 60 * 1000, // 10 minutes - notificationTimeoutMs: 9 * 60 * 1000, // 9 minutes - httpHeaders: {}, - maxDocumentsPerQuery: 100, // By default, we only fetch a maximum 100 documents per query / aggregation - maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation - atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours - voyageApiKey: "", - disableEmbeddingsValidation: false, - vectorSearchDimensions: 1024, - vectorSearchSimilarityFunction: "euclidean", - previewFeatures: [], -}; - export const config = setupUserConfig({ - defaults: defaultUserConfig, cli: process.argv, env: process.env, }); -function getLocalDataPath(): string { - return process.platform === "win32" - ? path.join(process.env.LOCALAPPDATA || process.env.APPDATA || os.homedir(), "mongodb") - : path.join(os.homedir(), ".mongodb"); -} - export type DriverOptions = ConnectionInfo["driverOptions"]; export const defaultDriverOptions: DriverOptions = { readConcern: { @@ -388,22 +197,17 @@ export const defaultDriverOptions: DriverOptions = { applyProxyToOIDC: true, }; -function getLogPath(): string { - const logPath = path.join(getLocalDataPath(), "mongodb-mcp", ".app-logs"); - return logPath; -} - -function getExportsPath(): string { - return path.join(getLocalDataPath(), "mongodb-mcp", "exports"); -} - // Gets the config supplied by the user as environment variables. The variable names // are prefixed with `MDB_MCP_` and the keys match the UserConfig keys, but are converted // to SNAKE_UPPER_CASE. function parseEnvConfig(env: Record): Partial { const CONFIG_WITH_URLS: Set = new Set<(typeof OPTIONS)["string"][number]>(["connectionString"]); - function setValue(obj: Record, path: string[], value: string): void { + function setValue( + obj: Record | undefined>, + path: string[], + value: string + ): void { const currentField = path.shift(); if (!currentField) { return; @@ -440,10 +244,10 @@ function parseEnvConfig(env: Record): Partial { obj[currentField] = {}; } - setValue(obj[currentField] as Record, path, value); + setValue(obj[currentField] as Record, path, value); } - const result: Record = {}; + const result: Record = {}; const mcpVariables = Object.entries(env).filter( ([key, value]) => value !== undefined && key.startsWith("MDB_MCP_") ) as [string, string][]; @@ -468,12 +272,14 @@ function SNAKE_CASE_toCamelCase(str: string): string { // We will consolidate them in a way where the mongosh format takes precedence. // We will warn users that previous configuration is deprecated in favour of // whatever is in mongosh. -function parseCliConfig(args: string[]): CliOptions { +function parseCliConfig(args: string[]): Partial> { const programArgs = args.slice(2); - const parsed = argv(programArgs, OPTIONS as unknown as argv.Options) as unknown as CliOptions & - UserConfig & { - _?: string[]; - }; + const parsed = argv(programArgs, OPTIONS as unknown as argv.Options) as unknown as Record< + keyof CliOptions, + string | number | undefined + > & { + _?: string[]; + }; const positionalArguments = parsed._ ?? []; @@ -543,29 +349,6 @@ export function warnAboutDeprecatedOrUnknownCliArgs( } } -function commaSeparatedToArray(str: string | string[] | undefined): T { - if (str === undefined) { - return [] as unknown as T; - } - - if (!Array.isArray(str)) { - return [str] as T; - } - - if (str.length === 0) { - return str as T; - } - - if (str.length === 1) { - return str[0] - ?.split(",") - .map((e) => e.trim()) - .filter((e) => e.length > 0) as T; - } - - return str as T; -} - export function registerKnownSecretsInRootKeychain(userConfig: Partial): void { const keychain = Keychain.root; @@ -589,59 +372,25 @@ export function registerKnownSecretsInRootKeychain(userConfig: Partial; - defaults: UserConfig; -}): UserConfig { - const userConfig = { - ...defaults, +export function setupUserConfig({ cli, env }: { cli: string[]; env: Record }): UserConfig { + const rawConfig = { ...parseEnvConfig(env), ...parseCliConfig(cli), - } satisfies UserConfig; - - userConfig.disabledTools = commaSeparatedToArray(userConfig.disabledTools); - userConfig.loggers = commaSeparatedToArray(userConfig.loggers); - userConfig.confirmationRequiredTools = commaSeparatedToArray(userConfig.confirmationRequiredTools); - - if (userConfig.connectionString && userConfig.connectionSpecifier) { - const connectionInfo = generateConnectionInfoFromCliArgs(userConfig); - userConfig.connectionString = connectionInfo.connectionString; - } - - const transport = userConfig.transport as string; - if (transport !== "http" && transport !== "stdio") { - throw new Error(`Invalid transport: ${transport}`); - } - - const telemetry = userConfig.telemetry as string; - if (telemetry !== "enabled" && telemetry !== "disabled") { - throw new Error(`Invalid telemetry: ${telemetry}`); - } - - const httpPort = +userConfig.httpPort; - if (httpPort < 1 || httpPort > 65535 || isNaN(httpPort)) { - throw new Error(`Invalid httpPort: ${userConfig.httpPort}`); - } - - if (userConfig.loggers.length === 0) { - throw new Error("No loggers found in config"); - } + }; - const loggerTypes = new Set(userConfig.loggers); - if (loggerTypes.size !== userConfig.loggers.length) { - throw new Error("Duplicate loggers found in config"); + if (rawConfig.connectionString && rawConfig.connectionSpecifier) { + const connectionInfo = generateConnectionInfoFromCliArgs(rawConfig as UserConfig); + rawConfig.connectionString = connectionInfo.connectionString; } - for (const loggerType of userConfig.loggers as string[]) { - if (loggerType !== "mcp" && loggerType !== "disk" && loggerType !== "stderr") { - throw new Error(`Invalid logger: ${loggerType}`); - } + const parseResult = UserConfigSchema.safeParse(rawConfig); + if (parseResult.error) { + throw new Error( + `Invalid configuration for the following fields:\n${parseResult.error.issues.map((issue) => `${issue.path.join(".")} - ${issue.message}`).join("\n")}` + ); } + // We don't have as schema defined for all args-parser arguments so we need to merge the raw config with the parsed config. + const userConfig = { ...rawConfig, ...parseResult.data } as UserConfig; registerKnownSecretsInRootKeychain(userConfig); return userConfig; diff --git a/src/common/configUtils.ts b/src/common/configUtils.ts new file mode 100644 index 00000000..47285152 --- /dev/null +++ b/src/common/configUtils.ts @@ -0,0 +1,96 @@ +import path from "path"; +import os from "os"; +import { ALL_CONFIG_KEYS } from "./argsParserOptions.js"; +import * as levenshteinModule from "ts-levenshtein"; +const levenshtein = levenshteinModule.default; + +export function validateConfigKey(key: string): { valid: boolean; suggestion?: string } { + if (ALL_CONFIG_KEYS.has(key)) { + return { valid: true }; + } + + let minLev = Number.MAX_VALUE; + let suggestion = ""; + + // find the closest match for a suggestion + for (const validKey of ALL_CONFIG_KEYS) { + // check if there is an exact case-insensitive match + if (validKey.toLowerCase() === key.toLowerCase()) { + return { valid: false, suggestion: validKey }; + } + + // else, infer something using levenshtein so we suggest a valid key + const lev = levenshtein.get(key, validKey); + if (lev < minLev) { + minLev = lev; + suggestion = validKey; + } + } + + if (minLev <= 2) { + // accept up to 2 typos + return { valid: false, suggestion }; + } + + return { valid: false }; +} + +export function isConnectionSpecifier(arg: string | undefined): boolean { + return ( + arg !== undefined && + (arg.startsWith("mongodb://") || + arg.startsWith("mongodb+srv://") || + !(arg.endsWith(".js") || arg.endsWith(".mongodb"))) + ); +} + +/** + * Metadata for config schema fields. + */ +export type ConfigFieldMeta = { + /** + * Custom description for the default value, used when generating documentation. + */ + defaultValueDescription?: string; + /** + * Marks the field as containing sensitive/secret information, used for MCP Registry. + * Secret fields will be marked as secret in environment variable definitions. + */ + isSecret?: boolean; + + [key: string]: unknown; +}; + +export function getLocalDataPath(): string { + return process.platform === "win32" + ? path.join(process.env.LOCALAPPDATA || process.env.APPDATA || os.homedir(), "mongodb") + : path.join(os.homedir(), ".mongodb"); +} + +export function getLogPath(): string { + const logPath = path.join(getLocalDataPath(), "mongodb-mcp", ".app-logs"); + return logPath; +} + +export function getExportsPath(): string { + return path.join(getLocalDataPath(), "mongodb-mcp", "exports"); +} + +export function commaSeparatedToArray(str: string | string[] | undefined): T | undefined { + if (str === undefined) { + return undefined; + } + + if (!Array.isArray(str)) { + return [str] as T; + } + + if (str.length === 1) { + return str[0] + ?.split(",") + .map((e) => e.trim()) + .filter((e) => e.length > 0) as T; + } + + return str as T; +} diff --git a/src/common/schemas.ts b/src/common/schemas.ts new file mode 100644 index 00000000..6375d25c --- /dev/null +++ b/src/common/schemas.ts @@ -0,0 +1,6 @@ +export const previewFeatureValues = ["vectorSearch"] as const; +export type PreviewFeature = (typeof previewFeatureValues)[number]; + +export const similarityValues = ["cosine", "euclidean", "dotProduct"] as const; + +export type Similarity = (typeof similarityValues)[number]; diff --git a/src/common/search/vectorSearchEmbeddingsManager.ts b/src/common/search/vectorSearchEmbeddingsManager.ts index 1af3a8a6..e570f064 100644 --- a/src/common/search/vectorSearchEmbeddingsManager.ts +++ b/src/common/search/vectorSearchEmbeddingsManager.ts @@ -7,9 +7,7 @@ import { ErrorCodes, MongoDBError } from "../errors.js"; import { getEmbeddingsProvider } from "./embeddingsProvider.js"; import type { EmbeddingParameters, SupportedEmbeddingParameters } from "./embeddingsProvider.js"; import { formatUntrustedData } from "../../tools/tool.js"; - -export const similarityEnum = z.enum(["cosine", "euclidean", "dotProduct"]); -export type Similarity = z.infer; +import type { Similarity } from "../schemas.js"; export const quantizationEnum = z.enum(["none", "scalar", "binary"]); export type Quantization = z.infer; diff --git a/src/lib.ts b/src/lib.ts index c81472e0..a9d5cfed 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -1,6 +1,6 @@ export { Server, type ServerOptions } from "./server.js"; export { Session, type SessionOptions } from "./common/session.js"; -export { defaultUserConfig, type UserConfig, ALL_CONFIG_KEYS as configurableProperties } from "./common/config.js"; +export { type UserConfig } from "./common/config.js"; export { LoggerBase, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js"; export { StreamableHttpRunner } from "./transports/streamableHttp.js"; export { diff --git a/src/tools/mongodb/create/createIndex.ts b/src/tools/mongodb/create/createIndex.ts index 252d8c4c..68ad4d91 100644 --- a/src/tools/mongodb/create/createIndex.ts +++ b/src/tools/mongodb/create/createIndex.ts @@ -3,7 +3,8 @@ import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { type ToolArgs, type OperationType } from "../../tool.js"; import type { IndexDirection } from "mongodb"; -import { quantizationEnum, similarityEnum } from "../../../common/search/vectorSearchEmbeddingsManager.js"; +import { quantizationEnum } from "../../../common/search/vectorSearchEmbeddingsManager.js"; +import { similarityValues } from "../../../common/schemas.js"; export class CreateIndexTool extends MongoDBToolBase { private vectorSearchIndexDefinition = z.object({ @@ -38,7 +39,8 @@ export class CreateIndexTool extends MongoDBToolBase { .describe( "Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time" ), - similarity: similarityEnum + similarity: z + .enum(similarityValues) .default(this.config.vectorSearchSimilarityFunction) .describe( "Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields." diff --git a/src/tools/tool.ts b/src/tools/tool.ts index ec9f01a6..f8d594f7 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -6,9 +6,10 @@ import type { Session } from "../common/session.js"; import { LogId } from "../common/logger.js"; import type { Telemetry } from "../telemetry/telemetry.js"; import { type ToolEvent } from "../telemetry/types.js"; -import type { PreviewFeature, UserConfig } from "../common/config.js"; +import type { UserConfig } from "../common/config.js"; import type { Server } from "../server.js"; import type { Elicitation } from "../elicitation.js"; +import type { PreviewFeature } from "../common/schemas.js"; export type ToolArgs = z.objectOutputType; export type ToolCallbackArgs = Parameters>; diff --git a/tests/integration/build.test.ts b/tests/integration/build.test.ts index 7453cb3d..064af001 100644 --- a/tests/integration/build.test.ts +++ b/tests/integration/build.test.ts @@ -49,7 +49,6 @@ describe("Build Test", () => { "Session", "StreamableHttpRunner", "Telemetry", - "defaultUserConfig", ]) ); }); diff --git a/tests/unit/common/config.test.ts b/tests/unit/common/config.test.ts index 78a0382e..5c671ca7 100644 --- a/tests/unit/common/config.test.ts +++ b/tests/unit/common/config.test.ts @@ -2,15 +2,56 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; import type { UserConfig } from "../../../src/common/config.js"; import { setupUserConfig, - defaultUserConfig, registerKnownSecretsInRootKeychain, warnAboutDeprecatedOrUnknownCliArgs, + UserConfigSchema, } from "../../../src/common/config.js"; +import { getLogPath, getExportsPath } from "../../../src/common/configUtils.js"; import type { CliOptions } from "@mongosh/arg-parser"; import { Keychain } from "../../../src/common/keychain.js"; import type { Secret } from "../../../src/common/keychain.js"; describe("config", () => { + it("should generate defaults from UserConfigSchema that match expected values", () => { + // Expected hardcoded values (what we had before) + const expectedDefaults = { + apiBaseUrl: "https://cloud.mongodb.com/", + logPath: getLogPath(), + exportsPath: getExportsPath(), + exportTimeoutMs: 5 * 60 * 1000, // 5 minutes + exportCleanupIntervalMs: 2 * 60 * 1000, // 2 minutes + disabledTools: [], + telemetry: "enabled", + readOnly: false, + indexCheck: false, + confirmationRequiredTools: [ + "atlas-create-access-list", + "atlas-create-db-user", + "drop-database", + "drop-collection", + "delete-many", + "drop-index", + ], + transport: "stdio", + httpPort: 3000, + httpHost: "127.0.0.1", + loggers: ["disk", "mcp"], + idleTimeoutMs: 10 * 60 * 1000, // 10 minutes + notificationTimeoutMs: 9 * 60 * 1000, // 9 minutes + httpHeaders: {}, + maxDocumentsPerQuery: 100, + maxBytesPerQuery: 16 * 1024 * 1024, // ~16 mb + atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours + voyageApiKey: "", + vectorSearchDimensions: 1024, + vectorSearchSimilarityFunction: "euclidean", + disableEmbeddingsValidation: false, + previewFeatures: [], + }; + + expect(UserConfigSchema.parse({})).toStrictEqual(expectedDefaults); + }); + describe("env var parsing", () => { describe("mongodb urls", () => { it("should not try to parse a multiple-host urls", () => { @@ -19,7 +60,6 @@ describe("config", () => { MDB_MCP_CONNECTION_STRING: "mongodb://user:password@host1,host2,host3/", }, cli: [], - defaults: defaultUserConfig, }); expect(actual.connectionString).toEqual("mongodb://user:password@host1,host2,host3/"); @@ -55,7 +95,6 @@ describe("config", () => { env: { [envVar]: String(value), }, - defaults: defaultUserConfig, }); expect(actual[property]).toBe(value); @@ -76,7 +115,6 @@ describe("config", () => { env: { [envVar]: "disk,mcp", }, - defaults: defaultUserConfig, }); expect(actual[config]).toEqual(["disk", "mcp"]); @@ -90,7 +128,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--connectionString", "mongodb://user:password@host1,host2,host3/"], env: {}, - defaults: defaultUserConfig, }); expect(actual.connectionString).toEqual("mongodb://user:password@host1,host2,host3/"); @@ -120,11 +157,11 @@ describe("config", () => { }, { cli: ["--httpPort", "8080"], - expected: { httpPort: "8080" }, + expected: { httpPort: 8080 }, }, { cli: ["--idleTimeoutMs", "42"], - expected: { idleTimeoutMs: "42" }, + expected: { idleTimeoutMs: 42 }, }, { cli: ["--logPath", "/var/"], @@ -132,11 +169,11 @@ describe("config", () => { }, { cli: ["--notificationTimeoutMs", "42"], - expected: { notificationTimeoutMs: "42" }, + expected: { notificationTimeoutMs: 42 }, }, { cli: ["--atlasTemporaryDatabaseUserLifetimeMs", "12345"], - expected: { atlasTemporaryDatabaseUserLifetimeMs: "12345" }, + expected: { atlasTemporaryDatabaseUserLifetimeMs: 12345 }, }, { cli: ["--telemetry", "enabled"], @@ -184,11 +221,19 @@ describe("config", () => { }, { cli: ["--oidcRedirectUri", "https://oidc"], - expected: { oidcRedirectUri: "https://oidc" }, + expected: { oidcRedirectUri: "https://oidc", oidcRedirectUrl: "https://oidc" }, + }, + { + cli: ["--oidcRedirectUrl", "https://oidc"], + expected: { oidcRedirectUrl: "https://oidc", oidcRedirectUri: "https://oidc" }, }, { cli: ["--password", "123456"], - expected: { password: "123456" }, + expected: { password: "123456", p: "123456" }, + }, + { + cli: ["-p", "123456"], + expected: { password: "123456", p: "123456" }, }, { cli: ["--port", "27017"], @@ -252,7 +297,11 @@ describe("config", () => { }, { cli: ["--username", "admin"], - expected: { username: "admin" }, + expected: { username: "admin", u: "admin" }, + }, + { + cli: ["-u", "admin"], + expected: { username: "admin", u: "admin" }, }, ] as { cli: string[]; expected: Partial }[]; @@ -261,12 +310,12 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", ...cli], env: {}, - defaults: defaultUserConfig, }); - for (const [key, value] of Object.entries(expected)) { - expect(actual[key as keyof UserConfig]).toBe(value); - } + expect(actual).toStrictEqual({ + ...UserConfigSchema.parse({}), + ...expected, + }); }); } }); @@ -360,7 +409,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", ...cli], env: {}, - defaults: defaultUserConfig, }); for (const [key, value] of Object.entries(expected)) { @@ -387,7 +435,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", ...cli], env: {}, - defaults: defaultUserConfig, }); for (const [key, value] of Object.entries(expected)) { @@ -403,7 +450,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--connectionString", "mongodb://localhost"], env: { MDB_MCP_CONNECTION_STRING: "mongodb://crazyhost" }, - defaults: defaultUserConfig, }); expect(actual.connectionString).toBe("mongodb://localhost"); @@ -413,10 +459,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--connectionString", "mongodb://localhost"], env: {}, - defaults: { - ...defaultUserConfig, - connectionString: "mongodb://crazyhost", - }, }); expect(actual.connectionString).toBe("mongodb://localhost"); @@ -426,10 +468,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: [], env: { MDB_MCP_CONNECTION_STRING: "mongodb://localhost" }, - defaults: { - ...defaultUserConfig, - connectionString: "mongodb://crazyhost", - }, }); expect(actual.connectionString).toBe("mongodb://localhost"); @@ -441,7 +479,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "mongodb://localhost", "--connectionString", "toRemove"], env: {}, - defaults: defaultUserConfig, }); // the shell specifies directConnection=true and serverSelectionTimeoutMS=2000 by default @@ -458,7 +495,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--transport", "http"], env: {}, - defaults: defaultUserConfig, }); expect(actual.transport).toEqual("http"); @@ -468,7 +504,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--transport", "stdio"], env: {}, - defaults: defaultUserConfig, }); expect(actual.transport).toEqual("stdio"); @@ -479,9 +514,10 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--transport", "sse"], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError("Invalid transport: sse"); + ).toThrowError( + 'Invalid configuration for the following fields:\ntransport - Invalid option: expected one of "stdio"|"http"' + ); }); it("should not support arbitrary values", () => { @@ -491,9 +527,10 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--transport", value], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError(`Invalid transport: ${value}`); + ).toThrowError( + `Invalid configuration for the following fields:\ntransport - Invalid option: expected one of "stdio"|"http"` + ); }); }); @@ -502,7 +539,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--telemetry", "enabled"], env: {}, - defaults: defaultUserConfig, }); expect(actual.telemetry).toEqual("enabled"); @@ -512,7 +548,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--telemetry", "disabled"], env: {}, - defaults: defaultUserConfig, }); expect(actual.telemetry).toEqual("disabled"); @@ -523,9 +558,10 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--telemetry", "true"], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError("Invalid telemetry: true"); + ).toThrowError( + 'Invalid configuration for the following fields:\ntelemetry - Invalid option: expected one of "enabled"|"disabled"' + ); }); it("should not support the boolean false value", () => { @@ -533,9 +569,10 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--telemetry", "false"], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError("Invalid telemetry: false"); + ).toThrowError( + 'Invalid configuration for the following fields:\ntelemetry - Invalid option: expected one of "enabled"|"disabled"' + ); }); it("should not support arbitrary values", () => { @@ -545,9 +582,10 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--telemetry", value], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError(`Invalid telemetry: ${value}`); + ).toThrowError( + `Invalid configuration for the following fields:\ntelemetry - Invalid option: expected one of "enabled"|"disabled"` + ); }); }); @@ -557,9 +595,10 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--httpPort", "0"], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError("Invalid httpPort: 0"); + ).toThrowError( + "Invalid configuration for the following fields:\nhttpPort - Invalid httpPort: must be at least 1" + ); }); it("must be below 65535 (OS limit)", () => { @@ -567,9 +606,10 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--httpPort", "89527345"], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError("Invalid httpPort: 89527345"); + ).toThrowError( + "Invalid configuration for the following fields:\nhttpPort - Invalid httpPort: must be at most 65535" + ); }); it("should not support non numeric values", () => { @@ -577,19 +617,19 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--httpPort", "portAventura"], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError("Invalid httpPort: portAventura"); + ).toThrowError( + "Invalid configuration for the following fields:\nhttpPort - Invalid input: expected number, received NaN" + ); }); it("should support numeric values", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--httpPort", "8888"], env: {}, - defaults: defaultUserConfig, }); - expect(actual.httpPort).toEqual("8888"); + expect(actual.httpPort).toEqual(8888); }); }); @@ -599,9 +639,8 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--loggers", ""], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError("No loggers found in config"); + ).toThrowError("Invalid configuration for the following fields:\nloggers - Cannot be an empty array"); }); it("must not allow duplicates", () => { @@ -609,16 +648,16 @@ describe("config", () => { setupUserConfig({ cli: ["myself", "--", "--loggers", "disk,disk,disk"], env: {}, - defaults: defaultUserConfig, }) - ).toThrowError("Duplicate loggers found in config"); + ).toThrowError( + "Invalid configuration for the following fields:\nloggers - Duplicate loggers found in config" + ); }); it("allows mcp logger", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--loggers", "mcp"], env: {}, - defaults: defaultUserConfig, }); expect(actual.loggers).toEqual(["mcp"]); @@ -628,7 +667,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--loggers", "disk"], env: {}, - defaults: defaultUserConfig, }); expect(actual.loggers).toEqual(["disk"]); @@ -638,7 +676,6 @@ describe("config", () => { const actual = setupUserConfig({ cli: ["myself", "--", "--loggers", "stderr"], env: {}, - defaults: defaultUserConfig, }); expect(actual.loggers).toEqual(["stderr"]); diff --git a/tests/unit/common/roles.test.ts b/tests/unit/common/roles.test.ts index 058e605a..e9eac0f2 100644 --- a/tests/unit/common/roles.test.ts +++ b/tests/unit/common/roles.test.ts @@ -1,11 +1,9 @@ import { describe, it, expect } from "vitest"; import { getDefaultRoleFromConfig } from "../../../src/common/atlas/roles.js"; -import { defaultUserConfig, type UserConfig } from "../../../src/common/config.js"; +import { UserConfigSchema, type UserConfig } from "../../../src/common/config.js"; describe("getDefaultRoleFromConfig", () => { - const defaultConfig: UserConfig = { - ...defaultUserConfig, - }; + const defaultConfig: UserConfig = UserConfigSchema.parse({}); const readOnlyConfig: UserConfig = { ...defaultConfig, diff --git a/tests/unit/toolBase.test.ts b/tests/unit/toolBase.test.ts index 984aa5bf..9f45eb55 100644 --- a/tests/unit/toolBase.test.ts +++ b/tests/unit/toolBase.test.ts @@ -3,11 +3,12 @@ import { z } from "zod"; import { ToolBase, type OperationType, type ToolCategory, type ToolConstructorParams } from "../../src/tools/tool.js"; import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import type { Session } from "../../src/common/session.js"; -import type { PreviewFeature, UserConfig } from "../../src/common/config.js"; +import type { UserConfig } from "../../src/common/config.js"; import type { Telemetry } from "../../src/telemetry/telemetry.js"; import type { Elicitation } from "../../src/elicitation.js"; import type { CompositeLogger } from "../../src/common/logger.js"; import type { TelemetryToolMetadata, ToolCallbackArgs } from "../../src/tools/tool.js"; +import type { PreviewFeature } from "../../src/common/schemas.js"; describe("ToolBase", () => { let mockSession: Session;