diff --git a/SConstruct b/SConstruct
index 4ec07a36a5159..dcd591c10d7f4 100644
--- a/SConstruct
+++ b/SConstruct
@@ -1260,9 +1260,22 @@ env_vars.Add(
     default=version_data['githash'],
 )
 
+VSINSTALLDIR = 'C:/Program Files/Microsoft Visual Studio/2022/Professional'
+MSVC_TOOLSET_VERSION = "14.31.31103"
+
+env_vars.Add('MSVC_USE_SCRIPT', help='Sets the script used to setup Visual Studio.',
+             default=f'{VSINSTALLDIR}/VC/Auxiliary/Build/vcvarsall.bat')
+
+env_vars.Add('VSINSTALLDIR', help='The path where Visual Studio is installed.',
+             default=VSINSTALLDIR)
+
+env_vars.Add('MSVC_USE_SCRIPT_ARGS', help='Sets the arguments passed to the Visual Studio setup script.',
+             default=f'x64 -vcvars_ver={MSVC_TOOLSET_VERSION}')
+
 env_vars.Add(
-    'MSVC_USE_SCRIPT',
-    help='Sets the script used to setup Visual Studio.',
+    'MSVC_TOOLSET_VERSION',
+    help='Sets the full toolset version of Visual C++ to use.',
+    default=MSVC_TOOLSET_VERSION,
 )
 
 env_vars.Add(
@@ -1271,12 +1284,6 @@ env_vars.Add(
     default="14.3",
 )
 
-env_vars.Add(
-    'MSVC_TOOLSET_VERSION',
-    help='Sets the full toolset version of Visual C++ to use.',
-    default="14.31.31103",
-)
-
 env_vars.Add(
     'LINKFLAGS_COMPILER_EXEC_PREFIX',
     help='Specify the search path to be injected into the LINKFLAGS',
@@ -1647,6 +1654,9 @@ envDict = dict(
     LIBDEPS_TAG_EXPANSIONS=[],
     MSVC_VERSION=variables_only_env.get("MSVC_VERSION"),
     MSVC_TOOLSET_VERSION=variables_only_env.get("MSVC_TOOLSET_VERSION"),
+    VSINSTALLDIR=variables_only_env.get("VSINSTALLDIR"),
+    MSVC_USE_SCRIPT=variables_only_env.get("MSVC_USE_SCRIPT"),
+    MSVC_USE_SCRIPT_ARGS=variables_only_env.get("MSVC_USE_SCRIPT_ARGS"),
 )
 
 # By default, we will get the normal SCons tool search. But if the
diff --git a/buildscripts/resmokelib/mongod_fuzzer_configs.py b/buildscripts/resmokelib/mongod_fuzzer_configs.py
index 42a92a0bfedc7..a64cc96029c0f 100644
--- a/buildscripts/resmokelib/mongod_fuzzer_configs.py
+++ b/buildscripts/resmokelib/mongod_fuzzer_configs.py
@@ -129,6 +129,8 @@ def generate_independent_parameters(rng, mode):
         ret["minSnapshotHistoryWindowInSeconds"] = rng.choice([300, rng.randint(5, 600)])
     # TODO (SERVER-75632): Uncomment this to enable passthrough testing.
     # ret["lockCodeSegmentsInMemory"] = rng.choice([True, False])
+    ret["preAuthMaximumMessageSizeBytes"] = rng.randint(65536, 16777216)
+    ret["capMemoryConsumptionForPreAuthBuffers"] = rng.randint(80, 100)
 
     return ret
 
diff --git a/buildscripts/resmokelib/multiversionconstants.py b/buildscripts/resmokelib/multiversionconstants.py
index cd3b67f32a8ce..182f30b8af497 100644
--- a/buildscripts/resmokelib/multiversionconstants.py
+++ b/buildscripts/resmokelib/multiversionconstants.py
@@ -53,13 +53,21 @@ def generate_releases_file():
 
 def in_git_root_dir():
     """Return True if we are in the root of a git directory."""
-    if call(["git", "branch"], stderr=STDOUT, stdout=DEVNULL) != 0:
-        # We are not in a git directory.
+    try:
+        if call(["git", "branch"], stderr=STDOUT, stdout=DEVNULL) != 0:
+            # We are not in a git directory.
+            return False
+    except FileNotFoundError:
+        # Git is not even installed.
         return False
 
-    git_root_dir = check_output("git rev-parse --show-toplevel", shell=True, text=True).strip()
-    # Always use forward slash for the cwd path to resolve inconsistent formatting with Windows.
- curr_dir = os.getcwd().replace("\\", "/") + git_root_dir = os.path.realpath( + check_output("git rev-parse --show-toplevel", shell=True, text=True).strip()) + curr_dir = os.path.realpath(os.getcwd()) + # python on windows under version 3.8 can have + # weird behavior with identifying the real drive + curr_dir = curr_dir.replace("Z:\\", "C:\\") + git_root_dir = git_root_dir.replace("Z:\\", "C:\\") return git_root_dir == curr_dir diff --git a/buildscripts/sync_repo_with_copybara.py b/buildscripts/sync_repo_with_copybara.py index d33037877aec1..e617753ccb058 100644 --- a/buildscripts/sync_repo_with_copybara.py +++ b/buildscripts/sync_repo_with_copybara.py @@ -11,6 +11,9 @@ from buildscripts.util.read_config import read_config_file from evergreen.api import RetryingEvergreenApi +# Commit hash of Copybara to use (v20251110) +COPYBARA_COMMIT_HASH = "3f050c9e08b84aeda98875bf1b02a3288d351333" + def run_command(command): # noqa: D406 """ @@ -107,6 +110,9 @@ def main(): else: run_command("git clone https://github.com/10gen/copybara.git") + # Checkout the specific commit of Copybara we want to use + run_command(f"cd copybara && git checkout {COPYBARA_COMMIT_HASH}") + # Navigate to the Copybara directory and build the Copybara Docker image run_command("cd copybara && docker build --rm -t copybara .") diff --git a/etc/backports_required_for_multiversion_tests.yml b/etc/backports_required_for_multiversion_tests.yml index 40c051bceed54..457e1b8521b3c 100644 --- a/etc/backports_required_for_multiversion_tests.yml +++ b/etc/backports_required_for_multiversion_tests.yml @@ -629,10 +629,14 @@ last-continuous: ticket: SERVER-103960 - test_file: jstests/core/txns/multi_statement_transaction_abort.js ticket: SERVER-84081 + - test_file: jstests/aggregation/expressions/reduce_overflow.js + ticket: SERVER-102364 - test_file: jstests/change_streams/ddl_create_drop_index_events.js ticket: SERVER-93153 - test_file: jstests/change_streams/resume_expanded_events.js ticket: SERVER-93153 + - test_file: jstests/replsets/insert_documents_close_to_size_limit.js + ticket: SERVER-113532 suites: null last-lts: all: @@ -1314,8 +1318,12 @@ last-lts: ticket: SERVER-103960 - test_file: jstests/core/txns/multi_statement_transaction_abort.js ticket: SERVER-84081 + - test_file: jstests/aggregation/expressions/reduce_overflow.js + ticket: SERVER-102364 - test_file: jstests/change_streams/ddl_create_drop_index_events.js ticket: SERVER-93153 - test_file: jstests/change_streams/resume_expanded_events.js ticket: SERVER-93153 + - test_file: jstests/replsets/insert_documents_close_to_size_limit.js + ticket: SERVER-113532 suites: null diff --git a/evergreen/multiversion_setup.sh b/evergreen/multiversion_setup.sh index 438f0fb4ccefb..781dac59c989e 100644 --- a/evergreen/multiversion_setup.sh +++ b/evergreen/multiversion_setup.sh @@ -95,4 +95,4 @@ db-contrib-tool setup-repro-env \ --fallbackToMaster \ --resmokeCmd "python buildscripts/resmoke.py" \ --debug \ - $last_lts_arg 5.0 + $last_lts_arg 5.0 7.0.25 diff --git a/jstests/aggregation/expressions/reduce_overflow.js b/jstests/aggregation/expressions/reduce_overflow.js new file mode 100644 index 0000000000000..1c9ef094f1c7e --- /dev/null +++ b/jstests/aggregation/expressions/reduce_overflow.js @@ -0,0 +1,50 @@ +/** + * Verify the server does not crash when $reduce creates a deeply nested intermediate document. 
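*
+ * The test repeatedly doubles the nesting depth of the $reduce 'in' expression: shallow depths
+ * must succeed, while a depth that pushes the intermediate document past the maximum BSON depth
+ * is expected to fail the aggregation with ErrorCodes.Overflow rather than crash the server.
+ 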
+ */ +const coll = db[jsTestName()]; + +let seenSuccess = false; +let seenOverflow = false; +let recursiveObject = {"a": "$$value.array"}; +let depth = 0; +for (let recursiveObjectDepth = 10; recursiveObjectDepth < 150; recursiveObjectDepth *= 2) { + while (depth < recursiveObjectDepth) { + recursiveObject = {"a": recursiveObject}; + depth = depth + 1; + } + + const pipeline = [ + {"$group": {"_id": null, "entries": {"$push": "$value"}}}, + { + "$project": { + "filtered": { + "$reduce": { + "input": "$entries", + "initialValue": {"array": []}, + "in": {"array": [recursiveObject]} + } + } + } + } + ]; + + for (let numDocs = 10; numDocs < 500; numDocs *= 2) { + coll.drop(); + const bulk = coll.initializeUnorderedBulkOp(); + for (let i = 0; i < numDocs; i++) { + bulk.insert({"value": 0}); + } + assert.commandWorked(bulk.execute()); + try { + coll.aggregate(pipeline); + seenSuccess = true; + assert(!seenOverflow); + } catch (error) { + assert(seenSuccess); + assert(error.code === ErrorCodes.Overflow, error); + jsTest.log("Pipeline exceeded max BSON depth", numDocs, recursiveObjectDepth); + seenOverflow = true; + } + } +} +assert(seenOverflow, "expected test to trigger overflow case"); diff --git a/jstests/auth/lib/commands_lib.js b/jstests/auth/lib/commands_lib.js index d4388b2a9ad7b..18680ea10bd29 100644 --- a/jstests/auth/lib/commands_lib.js +++ b/jstests/auth/lib/commands_lib.js @@ -5805,6 +5805,22 @@ export const authCommandsLib = { } ] }, + { + testname: "profileFilter", + command: {profile: -1, filter: {$alwaysTrue: 1}}, + testcases: [ + { + runOnDb: firstDbName, + roles: roles_dbAdmin, + privileges: [{resource: {db: firstDbName, collection: ""}, actions: ["enableProfiler"]}], + }, + { + runOnDb: secondDbName, + roles: roles_dbAdminAny, + privileges: [{resource: {db: secondDbName, collection: ""}, actions: ["enableProfiler"]}], + }, + ], + }, { testname: "profile_mongos", command: {profile: 0, slowms: 10, sampleRate: 0.5}, diff --git a/jstests/core/compound_wildcard_index_validation.js b/jstests/core/compound_wildcard_index_validation.js index c64074342aad7..a4e4a79bfb865 100644 --- a/jstests/core/compound_wildcard_index_validation.js +++ b/jstests/core/compound_wildcard_index_validation.js @@ -32,11 +32,13 @@ assert.commandWorked( // Tests that _id can be excluded in an inclusion projection statement. assert.commandWorked( coll.createIndex({"$**": 1, "other": 1}, {"wildcardProjection": {"_id": 0, "a": 1}})); -// Tests that _id can be inccluded in an exclusion projection statement. +// Tests that _id can be included in an exclusion projection statement. assert.commandWorked(coll.createIndex({"$**": 1, "another": 1}, {"wildcardProjection": {"_id": 1, "a": 0, "another": 0}})); +assert.commandWorked( + coll.createIndex({"$**": 1, "yetAnother": 1}, {"wildcardProjection": {"_id": 1}})); -// Tests we wildcard projections allow nested objects. +// Tests wildcard projections allow nested objects. assert.commandWorked( coll.createIndex({"$**": 1, "d": 1}, {"wildcardProjection": {"a": {"b": 1, "c": 1}}})); @@ -67,10 +69,55 @@ assert.commandFailedWithCode(coll.createIndex({"a.$**": 1, b: 1}, {expireAfterSe // 'wildcardProjection' is not specified. assert.commandFailedWithCode(coll.createIndex({a: 1, "$**": 1}), 67); +// Tests that the wildcardProjection cannot be empty. +assert.commandFailedWithCode(coll.createIndex({a: 1, "$**": 1}, {wildcardProjection: {}}), + ErrorCodes.FailedToParse); + +// Tests that a wildcardProjection is not allowed for non-wildcard indexes. 
+assert.commandFailedWithCode(coll.createIndex({a: 1, b: 1}, {wildcardProjection: {c: 1}}), + ErrorCodes.BadValue); + +// Tests that wildcardProjection cannot include regular index fields. +assert.commandFailedWithCode(coll.createIndex({a: 1, "$**": 1}, {wildcardProjection: {a: 1}}), + 7246208); +assert.commandFailedWithCode(coll.createIndex({a: 1, "$**": 1}, {wildcardProjection: {a: 1, b: 1}}), + 7246208); +assert.commandFailedWithCode( + coll.createIndex({a: 1, "$**": 1, b: 1}, {wildcardProjection: {a: 1, b: 1}}), 7246208); +assert.commandFailedWithCode(coll.createIndex({_id: 1, "$**": 1}, {wildcardProjection: {_id: 1}}), + 7246208); + +// Tests that a wildcardProjection can only mix inclusion/exclusion projections with _id. +assert.commandFailedWithCode( + coll.createIndex({"$**": 1, c: 1}, {wildcardProjection: {_id: 0, a: 0, b: 1}}), 7246211); +assert.commandFailedWithCode( + coll.createIndex({"$**": 1, c: 1}, {wildcardProjection: {_id: 1, a: 0, b: 1}}), 7246211); +assert.commandWorked(coll.createIndex({"$**": 1, "c": 1}, {wildcardProjection: {"_id": 0, b: 1}})); +assert.commandWorked(coll.createIndex({"$**": 1, "e": 1}, {wildcardProjection: {"_id": 1, e: 0}})); +assert.commandFailedWithCode( + coll.createIndex({"$**": 1, "_id": 1, "e": 1}, {wildcardProjection: {"_id": 1, e: 0}}), + 7246209, +); +assert.commandWorked( + coll.createIndex({"$**": 1, "_id": 1, "e": 1}, {wildcardProjection: {"_id": 0, a: 1}})); + // Tests that wildcard projections accept only numeric values. assert.commandFailedWithCode( coll.createIndex({"st": 1, "$**": 1}, {wildcardProjection: {"a": "something"}}), 51271); +// Tests that just excluding _id is not valid in the wildcardProjection, unless the regular part is +// _id. +assert.commandFailedWithCode(coll.createIndex({"a": 1, "$**": 1}, {wildcardProjection: {"_id": 0}}), + 7246210); +assert.commandFailedWithCode(coll.createIndex({"$**": 1, "a": 1}, {wildcardProjection: {"_id": 0}}), + 7246210); +assert.commandFailedWithCode( + coll.createIndex({"b": 1, "$**": 1, "a": 1}, {wildcardProjection: {"_id": 0}}), 7246210); +assert.commandWorked(coll.createIndex({"_id": 1, "$**": 1}, {wildcardProjection: {"_id": 0}})); +assert.commandWorked(coll.createIndex({"$**": 1, "_id": 1}, {wildcardProjection: {"_id": 0}})); +assert.commandFailedWithCode( + coll.createIndex({"_id": 1, "$**": 1, "a": 1}, {wildcardProjection: {"_id": 0}}), 7246210); + // Tests that all compound wildcard indexes in the catalog can be validated by running validate() // command. 
@@ -79,6 +126,54 @@ assert.commandWorked(coll.createIndex({a: 1, "b.$**": 1, str: 1})); assert.commandWorked(coll.createIndex({"b.$**": 1, str: 1})); assert.commandWorked(coll.createIndex({a: 1, "b.$**": 1})); assert.commandWorked(coll.createIndex({"$**": 1})); +assert.commandWorked( + coll.createIndex({"b": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a": 1}})); +assert.commandWorked( + coll.createIndex({"_id": 1, "a": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a": 0}})); +assert.commandFailedWithCode( + coll.createIndex({"a": 1, "_id": 1, "$**": 1}, {wildcardProjection: {"a": 0}}), 7246209); +assert.commandFailedWithCode( + coll.createIndex({"a": 1, "b": 1, "$**": 1}, {wildcardProjection: {"b": 0}}), 7246209); +assert.commandWorked( + coll.createIndex({"a": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a": 0}})); +assert.commandWorked( + coll.createIndex({"e": 1, "$**": 1}, {wildcardProjection: {"_id": 1, "f": 1}})); +assert.commandFailedWithCode( + coll.createIndex({"b": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a": 0}}), 7246210); +assert.commandFailedWithCode( + coll.createIndex({"a": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a": 1}}), 7246208); +assert.commandFailedWithCode( + coll.createIndex({"a": 1, "$**": 1}, {wildcardProjection: {"_id": 1, "a": 1}}), 7246208); +assert.commandFailedWithCode( + coll.createIndex({"_id": 1, "$**": 1}, {wildcardProjection: {"_id": 1, "a": 1}}), 7246208); +assert.commandFailedWithCode( + coll.createIndex({"$**": 1, "d": 1}, {wildcardProjection: {"_id": 1, e: 0}}), 7246209); + +// Dotted paths +assert.commandWorked( + coll.createIndex({"_id": 1, "a.b": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a": 0}})); +assert.commandFailedWithCode( + coll.createIndex({"_id": 1, "a": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a.b": 0}}), + 7246209, +); +assert.commandWorked( + coll.createIndex({"a.b": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a": 0}})); +assert.commandFailedWithCode( + coll.createIndex({"a": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a.b": 0}}), 7246209); +assert.commandWorked( + coll.createIndex({"$**": 1, "_id": 1, "e.b": 1}, {wildcardProjection: {"_id": 0, "a.b": 1}})); +assert.commandWorked( + coll.createIndex({"b.c": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a": 1}})); +assert.commandWorked( + coll.createIndex({"b.d": 1, "$**": 1}, {wildcardProjection: {"_id": 0, "a.b": 1}})); +assert.commandWorked( + coll.createIndex({"$**": 1, "e.b": 1}, {wildcardProjection: {"_id": 1, e: 0}})); +assert.commandFailedWithCode( + coll.createIndex({"$**": 1, "e": 1}, {wildcardProjection: {"_id": 1, "e.b": 0}}), 7246209); +assert.commandFailedWithCode( + coll.createIndex({"$**": 1, "e.c": 1}, {wildcardProjection: {"_id": 1, "e.b": 0}}), + 7246210, +); // Insert documents to index. for (let i = 0; i < 10; i++) { diff --git a/jstests/multiVersion/genericBinVersion/invalid_wildcard_index_log_at_startup.js b/jstests/multiVersion/genericBinVersion/invalid_wildcard_index_log_at_startup.js new file mode 100644 index 0000000000000..497a12b708828 --- /dev/null +++ b/jstests/multiVersion/genericBinVersion/invalid_wildcard_index_log_at_startup.js @@ -0,0 +1,96 @@ +/** + * Tests that a server containing an invalid wildcard index will log a warning on startup. + * + * @tags: [ + * requires_persistence, + * requires_replication, + * ] + */ + +// This is a version that allows the bad index to be created. 
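// (On this branch, a compound wildcard index whose wildcardProjection merely excludes _id is
+// rejected with code 7246210; see compound_wildcard_index_validation.js.)
+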
const oldVersion = "7.0.25";
+
+// Standalone mongod
+{
+    const testName = "invalid_wildcard_index_log_at_startup";
+    const dbpath = MongoRunner.dataPath + testName;
+    const collName = "collectionWithInvalidWildcardIndex";
+
+    {
+        // Start up a mongod version where we are allowed to create the invalid index.
+        const conn = MongoRunner.runMongod({dbpath: dbpath, binVersion: oldVersion});
+        assert.neq(null, conn, "mongod was unable to start up");
+
+        const testDB = conn.getDB("test");
+        assert.commandWorked(testDB[collName].insert({a: 1}));
+
+        // Invalid index
+        assert.commandWorked(
+            testDB[collName].createIndex({"a": 1, "$**": 1}, {wildcardProjection: {"_id": 0}}));
+
+        MongoRunner.stopMongod(conn);
+    }
+    {
+        const conn = MongoRunner.runMongod({dbpath: dbpath, noCleanData: true});
+        assert.neq(null, conn, "mongod was unable to start up");
+        const testDB = conn.getDB("test");
+
+        const cmdRes = assert.commandWorked(testDB.adminCommand({getLog: "startupWarnings"}));
+        assert(
+            /Found a compound wildcard index with an invalid wildcardProjection. Such indexes can no longer be created./
+                .test(cmdRes.log),
+        );
+
+        // Be sure that inserting to the collection with the invalid index succeeds.
+        assert.commandWorked(testDB[collName].insert({a: 2}));
+
+        // Inserting to another collection should succeed.
+        assert.commandWorked(testDB.someOtherCollection.insert({a: 1}));
+        assert.eq(testDB.someOtherCollection.find().itcount(), 1);
+
+        MongoRunner.stopMongod(conn);
+    }
+}
+
+// Replica set
+{
+    let nodes = {
+        n1: {binVersion: oldVersion},
+        n2: {binVersion: oldVersion},
+    };
+
+    const rst = new ReplSetTest({nodes: nodes});
+    rst.startSet();
+    rst.initiate();
+
+    let primary = rst.getPrimary();
+    const db = primary.getDB("test");
+    const coll = db.t;
+    assert.commandWorked(coll.insert({a: 1}));
+
+    assert.commandWorked(coll.createIndex({"a": 1, "$**": 1}, {wildcardProjection: {"_id": 0}}));
+
+    // Force a checkpoint in the storage engine to ensure the index is part of the catalog in a
+    // finished state at startup.
+    rst.awaitReplication();
+    let secondary = rst.getSecondary();
+    assert.commandWorked(secondary.adminCommand({fsync: 1}));
+
+    // Check that initial sync works; this node would not allow the index to be created
+    // (since it is on a version with the new validation logic) but should not fail on startup.
+    const initialSyncNode = rst.add({rsConfig: {priority: 0}});
+    rst.reInitiate();
+    rst.awaitSecondaryNodes(null, [initialSyncNode]);
+
+    // Restart the new node and check for the startup warning in the logs.
+    rst.restart(initialSyncNode);
+    rst.awaitSecondaryNodes(null, [initialSyncNode]);
+
+    checkLog.containsJson(initialSyncNode, 11389700, {
+        ns: coll.getFullName(),
+    });
+
+    rst.stopSet();
+}
diff --git a/jstests/noPassthrough/cap_memory_consumption_preauth_buffers.js b/jstests/noPassthrough/cap_memory_consumption_preauth_buffers.js
new file mode 100644
index 0000000000000..d0d1c5fef164a
--- /dev/null
+++ b/jstests/noPassthrough/cap_memory_consumption_preauth_buffers.js
@@ -0,0 +1,60 @@
+/**
+ * Tests that a server started with a low limit for `capMemoryConsumptionForPreAuthBuffers` will
+ * adjust maxConnections and log a warning message on startup, noting that the memory cap forces a
+ * reduction in the maximum number of connections.
+ *
+ * @tags: [
+ *   multiversion_incompatible,
+ * ]
+ */
+
+load("jstests/libs/log.js");         // For findMatchingLogLine
+load("jstests/libs/os_helpers.js");  // For isLinux
+
+if (!isLinux()) {
+    jsTest.log("Skipping test since it requires Linux-specific features.");
+    quit();
+}
+
+(function testCapMemoryConsumptionPreAuthBuffersWarning() {
+    // Start mongod with capMemoryConsumptionForPreAuthBuffers set to 1% (very low).
+    // This should trigger the warning because the calculated connection limit will be
+    // much lower than the default maxConns, regardless of available system memory.
+    const conn = MongoRunner.runMongod({
+        setParameter: {
+            capMemoryConsumptionForPreAuthBuffers: 1,
+            // TODO set buffer size to 16 * 1024 for pre-auth buffers
+        },
+    });
+
+    assert.neq(null, conn, "mongod was unable to start up");
+
+    const db = conn.getDB("admin");
+
+    // TODO This can only run on small Linux machines with properly set ulimits, so some of the
+    // burn-in tests will fail since they run on very large machines. This needs to be adjusted.
+    const hostInfo = assert.commandWorked(db.hostInfo());
+    const memLimitBytes = hostInfo.system.memLimitMB * 1024 * 1024;
+    const maxOpenFiles = hostInfo.extra.maxOpenFiles;
+
+    const maxConnsLimit = (maxOpenFiles * 0.8) / 2;
+    const connCapByMemoryLimit = Math.floor((memLimitBytes * 0.01) / (16 * 1024));
+    if (maxConnsLimit <= connCapByMemoryLimit) {
+        jsTest.log("Skipping test since the connection cap won't be triggered.");
+        jsTest.log("  memLimitBytes: " + memLimitBytes);
+        jsTest.log("  maxOpenFiles: " + maxOpenFiles);
+        jsTest.log("  maxConnsLimit: " + maxConnsLimit);
+        jsTest.log("  connCapByMemoryLimit: " + connCapByMemoryLimit);
+        MongoRunner.stopMongod(conn);
+        return;
+    }
+
+    // Get all log messages (including warnings).
+    const logResults =
+        assert.commandWorked(db.adminCommand({getLog: "global"}), "Failed to get global log");
+
+    assert(!!findMatchingLogLine(logResults.log, {id: 11621101}),
+           "Failed to find the expected warning message. Log contents: " + tojson(logResults.log));
+
+    MongoRunner.stopMongod(conn);
+})();
diff --git a/jstests/noPassthrough/cap_memory_consumption_preauth_buffers_validation.js b/jstests/noPassthrough/cap_memory_consumption_preauth_buffers_validation.js
new file mode 100644
index 0000000000000..1381b4f820a98
--- /dev/null
+++ b/jstests/noPassthrough/cap_memory_consumption_preauth_buffers_validation.js
@@ -0,0 +1,49 @@
+/**
+ * Tests that `capMemoryConsumptionForPreAuthBuffers` parameter validation fails at startup
+ 
+ * + * @tags: [ + * multiversion_incompatible, + * ] + */ + +(function testCapMemoryConsumptionPreAuthBuffersValidation() { + jsTest.log("Testing parameter validation for capMemoryConsumptionForPreAuthBuffers"); + + // Test 1: Value 0 should fail (must be > 0) + jsTest.log("Testing that value 0 fails at startup"); + assert.throws( + () => MongoRunner.runMongod({ + setParameter: { + capMemoryConsumptionForPreAuthBuffers: 0, + }, + }), + [], + "Expected mongod to fail to startup with capMemoryConsumptionForPreAuthBuffers=0"); + + // Test 2: Value 101 should fail (must be <= 100) + jsTest.log("Testing that value 101 fails at startup"); + assert.throws( + () => MongoRunner.runMongod({ + setParameter: { + capMemoryConsumptionForPreAuthBuffers: 101, + }, + }), + [], + "Expected mongod to fail to startup with capMemoryConsumptionForPreAuthBuffers=101"); + + // Test 3: Valid value 1 should succeed + jsTest.log("Testing that valid value 1 succeeds at startup"); + let conn = MongoRunner.runMongod({ + setParameter: { + capMemoryConsumptionForPreAuthBuffers: 1, + }, + }); + assert.neq( + null, + conn, + "Expected mongod to start successfully with capMemoryConsumptionForPreAuthBuffers=1"); + MongoRunner.stopMongod(conn); + + jsTest.log("All parameter validation tests passed"); +})(); diff --git a/jstests/noPassthrough/message_size_check.js b/jstests/noPassthrough/message_size_check.js new file mode 100644 index 0000000000000..743fbd17a616a --- /dev/null +++ b/jstests/noPassthrough/message_size_check.js @@ -0,0 +1,85 @@ +/** + * Verifies the message size limits for pre-auth and post-auth commands are correctly enforced. + */ + +function sendLargeHello(conn, shouldFail) { + const largeHello = {hello: 1, bigField: "x".repeat(20 * 1024)}; + + let ssNetworkMetricsBefore = connAdmin.runCommand({serverStatus: 1}).metrics.network; + if (shouldFail) { + let closedConnErr; + try { + conn.adminCommand(largeHello); + } catch (e) { + closedConnErr = e; + } + assert(closedConnErr, "Expected server to close the connection on oversized hello"); + assert( + isNetworkError(closedConnErr), + () => "Expected a network error from closed connection, got: " + tojson(closedConnErr), + ); + } else { + assert.commandWorked(conn.adminCommand(largeHello)); + } + + const expectedTotalMessageSizeErrorPreAuth = shouldFail + ? 
Number(ssNetworkMetricsBefore.totalMessageSizeErrorPreAuth) + 1
+        : Number(ssNetworkMetricsBefore.totalMessageSizeErrorPreAuth);
+    let ssNetworkMetricsAfter = connAdmin.runCommand({serverStatus: 1}).metrics.network;
+    assert.eq(expectedTotalMessageSizeErrorPreAuth,
+              Number(ssNetworkMetricsAfter.totalMessageSizeErrorPreAuth));
+    assert.eq(Number(ssNetworkMetricsAfter.totalMessageSizeErrorPostAuth), 0);
+}
+
+const rsName = jsTestName();
+
+const rs = new ReplSetTest({name: rsName, nodes: 1, keyFile: "jstests/libs/key1"});
+
+rs.startSet({
+    setParameter: {
+        messageSizeErrorRateSec: 5,
+    },
+});
+rs.initiate();
+
+const primary = rs.getPrimary();
+const adminDB = primary.getDB("admin");
+
+const user = jsTestName() + "_admin";
+const pwd = "pwd";
+assert.commandWorked(
+    adminDB.runCommand({
+        createUser: user,
+        pwd: pwd,
+        roles: ["root"],
+        writeConcern: {w: "majority"},
+    }),
+);
+
+const conn = new Mongo(primary.host);
+const connAdmin = conn.getDB("admin");
+assert.eq(1, connAdmin.auth(user, pwd), "Authentication failed");
+
+{
+    const newConn = new Mongo(primary.host);
+    sendLargeHello(newConn, true);
+}
+
+sendLargeHello(conn, false);
+
+assert.commandWorked(
+    connAdmin.runCommand({
+        setParameter: 1,
+        preAuthMaximumMessageSizeBytes: 1024 * 1024,
+    }),
+    "Failed to set preAuthMaximumMessageSizeBytes to 1MB",
+);
+
+{
+    const newConn = new Mongo(primary.host);
+    sendLargeHello(newConn, false);
+    newConn.close();
+}
+
+conn.close();
+rs.stopSet();
diff --git a/jstests/noPassthrough/sharded_find_with_collation.js b/jstests/noPassthrough/sharded_find_with_collation.js
new file mode 100644
index 0000000000000..a9eb6dbc47263
--- /dev/null
+++ b/jstests/noPassthrough/sharded_find_with_collation.js
@@ -0,0 +1,161 @@
+/**
+ * Tests for the absence/presence of a shard filtering stage for queries which may/may not
+ * be single shard targeted.
+ * TODO SERVER-94611: Extend testing once the shard key may have non-simple collation.
+ */
+load('jstests/libs/analyze_plan.js');
+load('jstests/libs/fail_point_util.js');
+
+const st = new ShardingTest({shards: 2});
+
+const db = st.getDB("test");
+const coll = db[jsTestName()];
+
+const setupCollection = ({shardKey, splits}) => {
+    coll.drop();
+
+    assert.commandWorked(db.createCollection(coll.getName()));
+
+    // Shard the collection with the provided spec, implicitly creating an index with simple
+    // collation.
+    assert.commandWorked(db.adminCommand({shardCollection: coll.getFullName(), key: shardKey}));
+
+    // Split the collection. e.g.,
+    // shard0: [chunk 1] { <key>: { "$minKey" : 1 } } -->> { <key>: 0 }
+    // shard0: [chunk 2] { <key>: 0 } -->> { <key>: "b" }
+    // shard1: [chunk 3] { <key>: "b" } -->> { <key>: { "$maxKey" : 1 } }
+    // Chunk 2 will be moved between the shards.
+    for (let mid of splits) {
+        assert.commandWorked(db.adminCommand({split: coll.getFullName(), middle: mid}));
+    }
+
+    let doc = {_id: "a", a: "a"};
+
+    assert.commandWorked(db.adminCommand(
+        {moveChunk: coll.getFullName(), find: {_id: MinKey, a: MinKey}, to: st.shard0.shardName}));
+    assert.commandWorked(
+        db.adminCommand({moveChunk: coll.getFullName(), find: doc, to: st.shard0.shardName}));
+    assert.commandWorked(db.adminCommand(
+        {moveChunk: coll.getFullName(), find: {_id: MaxKey, a: MaxKey}, to: st.shard1.shardName}));
+
+    // Put data on shard0; it will go into chunk 2.
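+    // (This holds for the range-sharded key patterns because numbers sort before strings in
+    // BSON, so 0 < "a" < "b".)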
+    assert.commandWorked(coll.insert(doc));
+
+    // Perform a chunk migration of chunk 2 from shard0 to shard1, but do not clean
+    // up orphans on shard0 (see suspendRangeDeletionShard0 failpoint).
+    db.adminCommand({moveChunk: coll.getFullName(), find: doc, to: st.shard1.shardName});
+};
+
+/**
+ * Verify that if a SHARD_MERGE is required, each shard includes a SHARD_FILTER
+ * in their plan.
+ * (Note that !SHARD_MERGE && SHARD_FILTER is a pessimisation but not a functional issue).
+ *
+ * Accepts a DBQuery object - e.g.,
+ *   assertShardMergeImpliesShardFilter(db.coll.find({foo:"bar"}))
+ */
+function assertShardMergeImpliesShardFilter(queryObj, testInfo) {
+    const explain = queryObj.explain();
+    const winningPlan = getWinningPlan(explain.queryPlanner);
+
+    const shardMerge = planHasStage(db, winningPlan, "SHARD_MERGE");
+
+    if (!shardMerge) {
+        return;
+    }
+
+    const express = winningPlan.shards.every((shard) => {
+        return planHasStage(db, getWinningPlan(shard), "EXPRESS_IXSCAN");
+    });
+
+    if (express) {
+        // Lookup by a specific _id.
+        // TODO: SERVER-98300 - explain for express path doesn't indicate if shard
+        // filtering will be applied, so it can't be asserted here.
+        // The document count assertions will still be applied, and would fail if
+        // duplicate results were returned due to missing shard filtering + orphans.
+        return;
+    }
+
+    const shardFiltered = winningPlan.shards.every((shard) => {
+        return planHasStage(db, getWinningPlan(shard), "SHARDING_FILTER");
+    });
+
+    assert(shardFiltered,
+           {msg: "SHARD_MERGE but missing SHARDING_FILTER", explain: explain, testInfo: testInfo});
+}
+
+const caseInsensitive = {
+    locale: 'en_US',
+    strength: 2
+};
+
+const doQueries = (shardKey, collation) => {
+    let evalFn = (query, expectedCount = 1) => {
+        // Information about the current combination of parameters under test,
+        // to log in the event of a failure.
+        const testInfo = {query: query, collation: collation, shardKey: shardKey};
+        let queryObj = coll.find(query);
+        if (collation !== undefined) {
+            queryObj = queryObj.collation(collation);
+        }
+        assert.eq(expectedCount, queryObj.toArray().length, testInfo);
+        assertShardMergeImpliesShardFilter(queryObj, testInfo);
+    };
+    for (let fieldNames of [["a"], ["_id"], ["a", "_id"]]) {
+        // Helper to compose an object like:
+        //   {a: <value>}, {_id: <value>}, {a: <value>, _id: <value>}
+        // using the above field names.
+        let query = value => Object.fromEntries(fieldNames.map(name => [name, value]));
+        // Equality.
+        evalFn(query("a"));
+        // A document was not inserted with {a:1} or {_id:1} (or both), so no docs should match.
+        evalFn(query(1), /* expectedCount */ 0);
+        // Ranges within a single chunk.
+        evalFn(query({"$lte": "a"}));
+        evalFn(query({"$gte": "a"}));
+        evalFn(query({"$gte": "a", "$lte": "a"}));
+    }
+};
+
+let suspendRangeDeletionShard0;
+
+function getSplitPoints(shardKey) {
+    // We wish to isolate a document in a particular chunk, by splitting at 0, and 'b'.
+    // However, we also have wider testing of hashed and compound shard keys.
+    let results = [];
+    for (const splitPoint of [0, 'b']) {
+        let split = {};
+        for (const [key, value] of Object.entries(shardKey)) {
+            split[key] = value == "hashed" ? 
convertShardKeyToHashed(splitPoint) : splitPoint; + } + results.push(split); + } + return results; +} +const shardKeys = [ + {a: 1}, + {_id: 1}, + {a: 1, _id: 1}, + {a: "hashed"}, + {_id: "hashed"}, + {a: "hashed", _id: 1}, + {a: 1, _id: "hashed"} +]; +for (let shardKey of shardKeys) { + let splitPoints = getSplitPoints(shardKey); + suspendRangeDeletionShard0 = configureFailPoint(st.shard0, 'suspendRangeDeletion'); + setupCollection({shardKey: shardKey, splits: splitPoints}); + + // Queries without collation. + doQueries(shardKey); + + // Queries WITH collation. + // Since the query has non-simple collation we will have to broadcast to all shards (since + // the shard key is on a simple collation index), and should have a SHARD_FILTER stage. + doQueries(shardKey, caseInsensitive); + + suspendRangeDeletionShard0.off(); + coll.drop(); +} +st.stop(); diff --git a/jstests/noPassthrough/traffic_recording.js b/jstests/noPassthrough/traffic_recording.js index e3cfc33d819e2..bcc8a6024b4ec 100644 --- a/jstests/noPassthrough/traffic_recording.js +++ b/jstests/noPassthrough/traffic_recording.js @@ -7,6 +7,11 @@ function getDB(client) { return db; } +function addPreAuth(params) { + params["preAuthMaximumMessageSizeBytes"] = 16 * 1024 * 1024; + return params; +} + function runTest(client, restartCommand) { let db = getDB(client); @@ -19,26 +24,26 @@ function runTest(client, restartCommand) { if (!jsTest.isMongos(client)) { TestData.enableTestCommands = false; - client = restartCommand({ + client = restartCommand(addPreAuth({ trafficRecordingDirectory: path, AlwaysRecordTraffic: "notARealPath", enableTestCommands: 0, - }); + })); TestData.enableTestCommands = true; assert.eq(null, client, "AlwaysRecordTraffic and not enableTestCommands should fail"); } - client = restartCommand({ + client = restartCommand(addPreAuth({ trafficRecordingDirectory: path, AlwaysRecordTraffic: "notARealPath", enableTestCommands: 1 - }); + })); assert.neq(null, client, "AlwaysRecordTraffic and with enableTestCommands should suceed"); db = getDB(client); assert(db.runCommand({"serverStatus": 1}).trafficRecording.running); - client = restartCommand({trafficRecordingDirectory: path}); + client = restartCommand(addPreAuth({trafficRecordingDirectory: path})); db = getDB(client); res = db.runCommand({'startRecordingTraffic': 1, 'filename': 'notARealPath'}); diff --git a/jstests/replsets/insert_documents_close_to_size_limit.js b/jstests/replsets/insert_documents_close_to_size_limit.js new file mode 100644 index 0000000000000..6536c061fed52 --- /dev/null +++ b/jstests/replsets/insert_documents_close_to_size_limit.js @@ -0,0 +1,131 @@ +/** + * Tests that document inserts with documents that are close to the size limit and that have larger + * _id values can be successfully inserted into all nodes in a replica set. This implicitly tests + * the message size limits that are applied on messages received by the secondaries' oplog fetcher. + * + * This test can be slow to execute. Exclude from any build variants that are known to be slow. + * + * @tags: [ + * incompatible_aubsan, + * resource_intensive, + * tsan_incompatible, + * ] + */ +load("jstests/libs/collection_drop_recreate.js"); + +const kDbName = 'insert_docs'; +const kCollName = 'coll'; + +function runTest(db) { + const coll = db[kCollName]; + + // Generate _id values with different characters so we avoid duplicate key errors. 
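// Each runInsert() call below consumes the next character as the repeated _id character.
+    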
let nextChar;
+    const resetNextChar = () => {
+        nextChar = 'a'.charCodeAt(0);
+    };
+    const getNextChar = () => String.fromCharCode(nextChar++);
+
+    const runInsert = (idLength, payloadLength) => {
+        const doc = {_id: getNextChar().repeat(idLength), payload: 'a'.repeat(payloadLength)};
+
+        jsTestLog(`Inserting document with id length ${idLength}, payload length ${payloadLength}`);
+
+        // The insert needs to be acknowledged by all nodes in the replica set, not just the
+        // primary.
+        return coll.insert(doc, {writeConcern: {w: 3}});
+    };
+
+    const kMaxUserSize = 16 << 20;
+    const maxPayloadLengthForIdLength = (idLength) => {
+        // Overhead:
+        // - 4 bytes for object size
+        // - 1 byte for string type tag
+        // - 4 bytes for "_id" field name plus trailing \0 byte
+        // - 4 bytes for string field length
+        // - 1 byte for trailing \0 byte for string value
+        // - 1 byte for string type tag
+        // - 8 bytes for "payload" field name plus trailing \0 byte
+        // - 4 bytes for string field length
+        // - 1 byte for trailing \0 byte for string value
+        // - 1 byte for trailing \0 byte for object
+        // ========================================
+        // = 29 bytes total overhead
+        //
+        // This formula is accurate for all id lengths <= 10000.
+        // It does not work accurately with arbitrarily large _id values, because the _id value
+        // is duplicated in the oplog entry, and the total size of the oplog entry may then
+        // exceed the size limit when using too large _ids.
+        return kMaxUserSize - 29 - idLength;
+    };
+
+    [1, 10, 100, 1000, 10000].forEach((idLength) => {
+        coll.remove({});
+
+        resetNextChar();
+
+        // Test that inserting large documents into all nodes of the replica set works.
+        // The documents created below have total sizes that are below the BSON object size limit,
+        // but as the _id value gets repeated in oplog entries, the examples below are already at
+        // the fringe of what can be inserted.
+        const maxPayloadLength = maxPayloadLengthForIdLength(idLength);
+        assert.commandWorked(runInsert(idLength, maxPayloadLength));
+        assert.commandWorked(runInsert(idLength + 1, maxPayloadLength - 1));
+        assert.commandWorked(runInsert(idLength - 1, maxPayloadLength + 1));
+
+        assert.soon(() => {
+            const count = coll.find({}).readConcern('majority').itcount();
+            return count == 3;
+        });
+
+        assert.commandFailedWithCode(runInsert(idLength, maxPayloadLength + 1),
+                                     ErrorCodes.BSONObjectTooLarge);
+        assert.commandFailedWithCode(runInsert(idLength + 1, maxPayloadLength),
+                                     ErrorCodes.BSONObjectTooLarge);
+        assert.commandFailedWithCode(runInsert(idLength + 1, maxPayloadLength + 1),
+                                     ErrorCodes.BSONObjectTooLarge);
+    });
+
+    // Now test documents with larger _id values. Large _id values can be a problem because the
+    // _id is repeated in the oplog entry. The size of an oplog entry is affected by the size of
+    // the actually inserted document, the size of the _id value, and a few other factors (e.g.
+    // the size of the namespace string).
+
+    // Here, we produce documents with large _id values and a large payload. The documents
+    // themselves are exactly 16MiB large, and thus permitted. However, as the _id value is
+    // duplicated in the oplog entry, and because there are other fields that are stored in an
+    // oplog entry, the maximum _id length usable together with a large document payload is
+    // limited to slightly less than 16KiB. The value of 135 has been determined for the
+    // namespace string used by this test. It can be different for other namespace values.
+    
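// For reference: (16 << 10) is 16384 bytes, so the limit below works out to 16384 - 135 = 16249.
+    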
const kMaxIdLengthUsableInOplogEntry = (16 << 10) - 135;
+
+    const runInsertForIdLength = (idLength) => {
+        return runInsert(idLength, maxPayloadLengthForIdLength(idLength));
+    };
+
+    // Still fits into oplog.
+    assert.commandWorked(runInsertForIdLength(kMaxIdLengthUsableInOplogEntry));
+
+    // Slightly too large for oplog.
+    assert.commandFailedWithCode(runInsertForIdLength(kMaxIdLengthUsableInOplogEntry + 1),
+                                 ErrorCodes.BSONObjectTooLarge);
+
+    // Vastly too large for oplog.
+    assert.commandFailedWithCode(runInsertForIdLength(kMaxIdLengthUsableInOplogEntry * 2),
+                                 ErrorCodes.BSONObjectTooLarge);
+
+    // Final check: insert a small document to verify that replication still works.
+    coll.remove({});
+    assert.commandWorked(coll.insert({_id: "test"}, {writeConcern: {w: 3}}));
+
+    assert.soon(() => {
+        const count = coll.find({}).readConcern('majority').itcount();
+        return count == 1;
+    });
+}
+
+const replTest = new ReplSetTest({nodes: 3});
+replTest.startSet();
+replTest.initiate();
+runTest(replTest.getPrimary().getDB(kDbName));
+replTest.stopSet();
diff --git a/site_scons/site_tools/vcredist.py b/site_scons/site_tools/vcredist.py
index 0f86629b2819e..1061eefb9302c 100644
--- a/site_scons/site_tools/vcredist.py
+++ b/site_scons/site_tools/vcredist.py
@@ -87,6 +87,14 @@ def generate(env):
     if not exists(env):
         return
 
+    if 'TARGET_ARCH' not in env or env['TARGET_ARCH'] is None:
+        env['TARGET_ARCH'] = "x86_64"
+
+    if env['TARGET_ARCH'] not in target_arch_expansion_map:
+        raise Exception(
+            f"TARGET_ARCH={env['TARGET_ARCH']}, TARGET_ARCH must be in {target_arch_expansion_map.keys()} on Windows."
+        )
+
     env.Tool("msvc")
 
     env.AddMethod(_get_merge_module_name_for_feature, "GetMergeModuleNameForFeature")
@@ -128,6 +136,9 @@ def generate(env):
         if os.path.isdir(mergemodulepath):
             env["MSVS"]["VCREDISTMERGEMODULEPATH"] = mergemodulepath
 
+    if 'VSINSTALLDIR' in env:
+        env["MSVS"]["VSINSTALLDIR"] = env["VSINSTALLDIR"]
+
     if not "VSINSTALLDIR" in env["MSVS"]:
 
         # Compute a VS version based on the VC version. VC 14.0 is VS 2015, VC
diff --git a/src/mongo/bson/bsonobj.cpp b/src/mongo/bson/bsonobj.cpp
index 54736ffbeeab3..64170b5e2c2d7 100644
--- a/src/mongo/bson/bsonobj.cpp
+++ b/src/mongo/bson/bsonobj.cpp
@@ -85,6 +85,20 @@ int compareObjects(const BSONObj& firstObj,
     MONGO_UNREACHABLE;
 }
 
+/**
+ * Generic implementation to remove field 'name' from BSONObj 'obj', by producing a new BSONObj in
+ 
+ */ +void rebuildObjectWithoutField(BSONObjBuilder& builder, const BSONObj& obj, StringData name) { + BSONObjIterator i(obj); + while (i.more()) { + BSONElement e = i.next(); + StringData fname = e.fieldNameStringData(); + if (name != fname) + builder.append(e); + } +} + } // namespace /* BSONObj ------------------------------------------------------------*/ @@ -713,15 +727,15 @@ BSONObj BSONObj::addFields(const BSONObj& from, } BSONObj BSONObj::removeField(StringData name) const { - BSONObjBuilder b; - BSONObjIterator i(*this); - while (i.more()) { - BSONElement e = i.next(); - const char* fname = e.fieldName(); - if (name != fname) - b.append(e); - } - return b.obj(); + BSONObjBuilder builder; + rebuildObjectWithoutField(builder, *this, name); + return builder.obj(); +} + +BSONObj BSONObj::removeField(StringData name, WireMessageSizeTrait) const { + BSONObjBuilder builder; + rebuildObjectWithoutField(builder, *this, name); + return builder.obj(); } BSONObj BSONObj::removeFields(const std::set& fields) const { diff --git a/src/mongo/bson/bsonobj.h b/src/mongo/bson/bsonobj.h index e228461b6209a..3a6fb3eba06ee 100644 --- a/src/mongo/bson/bsonobj.h +++ b/src/mongo/bson/bsonobj.h @@ -107,6 +107,14 @@ class BSONObj { constexpr static int MaxSize = BufferMaxSize; }; + /** + * Special size trait for validating wire messages, which may have a small size overhead for + * metadata compared to other BSON objects. + */ + struct WireMessageSizeTrait { + constexpr static int MaxSize = BSONObjMaxWireMessageSize; + }; + // Declared in bsonobj_comparator_interface.h. class ComparatorInterface; @@ -345,10 +353,18 @@ class BSONObj { /** * Remove specified field and return a new object with the remaining fields. - * slowish as builds a full new object + * Slowish, as it builds a full new object. + * This uses the DefaultSizeTrait to validate the size of the resulting BSONObj. */ BSONObj removeField(StringData name) const; + /** + * Remove specified field and return a new object with the remaining fields. + * Same as above, but using the WireMessageSizeTrait to validate the size of the resulting + * BSONObj. + */ + BSONObj removeField(StringData name, WireMessageSizeTrait) const; + /** * Remove specified fields and return a new object with the remaining fields. */ diff --git a/src/mongo/bson/util/builder.h b/src/mongo/bson/util/builder.h index e78787601fe16..525c77c7d57d3 100644 --- a/src/mongo/bson/util/builder.h +++ b/src/mongo/bson/util/builder.h @@ -76,6 +76,14 @@ const int BSONObjMaxUserSize = 16 * 1024 * 1024; */ const int BSONObjMaxInternalSize = BSONObjMaxUserSize + (16 * 1024); +/* + * Maximum internal BSONObj size used for wire messages. Using a slightly larger value here than the + * BSONObjMaxInternalSize allows for command metadata to be safely added to a command response, and + * also allows for safe deserialization thereof. This value should only be used for command + * serialization and deserialization, but not for building BSONObjs in any other situation. + */ +const int BSONObjMaxWireMessageSize = BSONObjMaxUserSize + (32 * 1024); + /** * Maximum size of a builder buffer and for BSONObj with BsonLargeSizeTrait. Limiting it to 27 bits * because SharedBuffer::Holder might bit pack information. 
Setting it to 125 MB to have some diff --git a/src/mongo/db/catalog/index_catalog_impl.cpp b/src/mongo/db/catalog/index_catalog_impl.cpp index a5666ec83b582..52b6dbd97ac26 100644 --- a/src/mongo/db/catalog/index_catalog_impl.cpp +++ b/src/mongo/db/catalog/index_catalog_impl.cpp @@ -51,6 +51,7 @@ #include "mongo/db/index/index_descriptor.h" #include "mongo/db/index/s2_access_method.h" #include "mongo/db/index/s2_bucket_access_method.h" +#include "mongo/db/index/wildcard_validation.h" #include "mongo/db/index_names.h" #include "mongo/db/jsobj.h" #include "mongo/db/keypattern.h" @@ -213,6 +214,23 @@ void IndexCatalogImpl::init(OperationContext* opCtx, "spec"_attr = spec); } + // Look for an invalid compound wildcard index. + if (IndexNames::findPluginName(keyPattern) == IndexNames::WILDCARD && + keyPattern.nFields() > 1 && spec.hasField("wildcardProjection")) { + auto validationStatus = + validateWildcardProjection(keyPattern, spec.getObjectField("wildcardProjection")); + if (!validationStatus.isOK()) { + LOGV2_OPTIONS(11389700, + {logv2::LogTag::kStartupWarnings}, + "Found a compound wildcard index with an invalid wildcardProjection. " + "Such indexes can no longer be created.", + "ns"_attr = collection->ns(), + "uuid"_attr = collection->uuid(), + "index"_attr = indexName, + "spec"_attr = spec); + } + } + auto descriptor = IndexDescriptor(_getAccessMethodName(keyPattern), spec); if (spec.hasField(IndexDescriptor::kExpireAfterSecondsFieldName)) { diff --git a/src/mongo/db/commands/getmore_cmd.cpp b/src/mongo/db/commands/getmore_cmd.cpp index 3b4bb9bd4cd96..f69d4443899ce 100644 --- a/src/mongo/db/commands/getmore_cmd.cpp +++ b/src/mongo/db/commands/getmore_cmd.cpp @@ -802,7 +802,7 @@ class GetMoreCmd final : public Command { auto ret = reply->getBodyBuilder().asTempObj(); CursorGetMoreReply::parse( IDLParserContext{"CursorGetMoreReply", false /* apiStrict */, tenantId}, - ret.removeField("ok")); + ret.removeField("ok", BSONObj::WireMessageSizeTrait{})); } const GetMoreCommandRequest _cmd; diff --git a/src/mongo/db/commands/profile.idl b/src/mongo/db/commands/profile.idl index 497e95624fe3f..48b20c2c05e29 100644 --- a/src/mongo/db/commands/profile.idl +++ b/src/mongo/db/commands/profile.idl @@ -66,6 +66,8 @@ commands: an alternative to slowms and sampleRate. The special value 'unset' removes the filter." optional: true + # WARNING: If adding a new parameter to this command, be sure to update the 'isReadOnly' + # special-case in the authorization logic. setProfilingFilterGlobally: description: "Parser for the 'setProfilingFilterGlobally' command." 
diff --git a/src/mongo/db/commands/profile_common.cpp b/src/mongo/db/commands/profile_common.cpp index 6a33e66a561f8..039ab04fb3484 100644 --- a/src/mongo/db/commands/profile_common.cpp +++ b/src/mongo/db/commands/profile_common.cpp @@ -44,6 +44,13 @@ namespace mongo { +namespace { + +bool isReadOnly(const ProfileCmdRequest& request) { + return !request.getSlowms() && !request.getSampleRate() && !request.getFilter(); +} +} // namespace + Status ProfileCmdBase::checkAuthForOperation(OperationContext* opCtx, const DatabaseName& dbName, const BSONObj& cmdObj) const { @@ -52,9 +59,9 @@ Status ProfileCmdBase::checkAuthForOperation(OperationContext* opCtx, auto request = ProfileCmdRequest::parse(IDLParserContext("profile"), cmdObj); const auto profilingLevel = request.getCommandParameter(); - if (profilingLevel < 0 && !request.getSlowms() && !request.getSampleRate()) { - // If the user just wants to view the current values of 'slowms' and 'sampleRate', they - // only need read rights on system.profile, even if they can't change the profiling level. + if (profilingLevel < 0 && isReadOnly(request)) { + // If the user just wants to view the current profiling settings, they only need read rights + // on system.profile, even if they can't change the profiling level. if (authzSession->isAuthorizedForActionsOnResource( ResourcePattern::forExactNamespace( NamespaceStringUtil::parseNamespaceFromRequest(dbName, "system.profile")), diff --git a/src/mongo/db/concurrency/lock_manager.cpp b/src/mongo/db/concurrency/lock_manager.cpp index e4c99cf0c8dac..333612a78039a 100644 --- a/src/mongo/db/concurrency/lock_manager.cpp +++ b/src/mongo/db/concurrency/lock_manager.cpp @@ -904,7 +904,7 @@ void LockManager::_buildLocksArray(const std::map& lockToClie auto o = BSONObjBuilder(locks->subobjStart()); if (forLogging) o.append("lockAddr", formatPtr(lock)); - o.append("resourceId", lock->resourceId.toString()); + o.append("resourceId", toStringForLogging(lock->resourceId)); struct { StringData key; LockRequest* iter; @@ -972,19 +972,45 @@ LockHead* LockManager::LockBucket::findOrInsert(ResourceId resId) { // // ResourceId // -std::string ResourceId::toString() const { +std::string toStringForLogging(const ResourceId& rId) { StringBuilder ss; - ss << "{" << _fullHash << ": " << resourceTypeName(getType()) << ", " << getHashId(); - if (getType() == RESOURCE_MUTEX) { - ss << ", " << Lock::ResourceMutex::getName(*this); - } - - if (getType() == RESOURCE_DATABASE || getType() == RESOURCE_COLLECTION) { - if (auto resourceName = ResourceCatalog::get(getGlobalServiceContext()).name(*this)) { + const auto type = rId.getType(); + ss << "{" << rId._fullHash << ": " << resourceTypeName(type) << ", " << rId.getHashId(); + if (type == RESOURCE_DATABASE || type == RESOURCE_COLLECTION) { + if (auto resourceName = ResourceCatalog::get(getGlobalServiceContext()).name(rId)) { ss << ", " << *resourceName; } } + if (type == RESOURCE_MUTEX) { + ss << ", " << Lock::ResourceMutex::getName(rId); + } + ss << "}"; + + return ss.str(); +} +std::string ResourceId::toStringForErrorMessage() const { + StringBuilder ss; + const auto type = getType(); + ss << "{" << resourceTypeName(type); + switch (type) { + case RESOURCE_GLOBAL: + ss << " : " << getHashId(); + break; + case RESOURCE_DATABASE: + case RESOURCE_COLLECTION: + if (auto resourceName = ResourceCatalog::get(getGlobalServiceContext()).name(*this)) { + ss << " : " << *resourceName; + } + break; + case RESOURCE_MUTEX: + ss << " : " << Lock::ResourceMutex::getName(*this); + break; + case 
ResourceTypesCount: + case RESOURCE_INVALID: + case RESOURCE_METADATA: + break; + } ss << "}"; return ss.str(); diff --git a/src/mongo/db/concurrency/lock_manager_defs.h b/src/mongo/db/concurrency/lock_manager_defs.h index 6d91108dd7365..513cbc1d3f2e7 100644 --- a/src/mongo/db/concurrency/lock_manager_defs.h +++ b/src/mongo/db/concurrency/lock_manager_defs.h @@ -42,6 +42,7 @@ #include "mongo/base/string_data.h" #include "mongo/config.h" #include "mongo/db/namespace_string.h" +#include "mongo/platform/random.h" namespace mongo { @@ -225,6 +226,14 @@ static const char* resourceGlobalIdName(ResourceGlobalId id) { return ResourceGlobalIdNames[static_cast(id)]; } +inline static uint64_t hashStringDataForResourceId(StringData str, uint64_t salt) { + // We salt the hash with a given random value to generate randomness in ResourceId selection on + // every restart. This aids in testing for detecting lock ordering issues. + char hash[16]; + MurmurHash3_x64_128(str.rawData(), str.size(), salt, hash); + return static_cast(ConstDataView(hash).read>()); +} + /** * Uniquely identifies a lockable resource. */ @@ -236,14 +245,19 @@ class ResourceId { public: ResourceId() : _fullHash(0) {} ResourceId(ResourceType type, const NamespaceString& nss) - : _fullHash(fullHash(type, hashStringData(nss.toStringWithTenantId()))) { + : _fullHash(fullHash( + type, + hashStringDataForResourceId(nss.toStringWithTenantId(), kHashingSaltForResourceId))) { verifyNoResourceMutex(type); } ResourceId(ResourceType type, const DatabaseName& dbName) - : _fullHash(fullHash(type, hashStringData(dbName.toStringWithTenantId()))) { + : _fullHash(fullHash(type, + hashStringDataForResourceId(dbName.toStringWithTenantId(), + kHashingSaltForResourceId))) { verifyNoResourceMutex(type); } - ResourceId(ResourceType type, StringData str) : _fullHash(fullHash(type, hashStringData(str))) { + ResourceId(ResourceType type, StringData str) + : _fullHash(fullHash(type, hashStringDataForResourceId(str, kHashingSaltForResourceId))) { // Resources of type database, collection, or tenant must never be passed as a raw string. invariant(type != RESOURCE_DATABASE && type != RESOURCE_COLLECTION); verifyNoResourceMutex(type); @@ -273,7 +287,9 @@ class ResourceId { return _fullHash & (std::numeric_limits::max() >> resourceTypeBits); } - std::string toString() const; + // String representation of the resource type that omits the parts not intended to be read by + // humans. Intended to be used for error messages that are returned to the client. 
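+    // (For a collection this yields something like "{Collection : test.coll}", omitting the
+    // full hash that toStringForLogging() includes.)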
+ std::string toStringForErrorMessage() const; template friend H AbslHashValue(H h, const ResourceId& resource) { @@ -281,6 +297,8 @@ class ResourceId { } private: + friend std::string toStringForLogging(const ResourceId&); + ResourceId(uint64_t fullHash) : _fullHash(fullHash) {} // Used to allow Lock::ResourceMutex to create ResourceIds with RESOURCE_MUTEX type @@ -307,13 +325,17 @@ class ResourceId { (hashId & (std::numeric_limits::max() >> resourceTypeBits)); } - static uint64_t hashStringData(StringData str) { - char hash[16]; - MurmurHash3_x64_128(str.rawData(), str.size(), 0, hash); - return static_cast(ConstDataView(hash).read>()); - } + static inline const uint64_t kHashingSaltForResourceId = [] { + SecureUrbg entropy; + const auto result = entropy(); + static_assert(std::is_same_v, uint64_t>, + "salting hash entropy must be a uint64"); + return result; + }(); }; +std::string toStringForLogging(const ResourceId&); + #ifndef MONGO_CONFIG_DEBUG_BUILD // Treat the resource ids as 64-bit integers in release mode in order to ensure we do // not spend too much time doing comparisons for hashing. diff --git a/src/mongo/db/concurrency/lock_manager_test.cpp b/src/mongo/db/concurrency/lock_manager_test.cpp index a1bd0e1403c81..1d161aed98d86 100644 --- a/src/mongo/db/concurrency/lock_manager_test.cpp +++ b/src/mongo/db/concurrency/lock_manager_test.cpp @@ -75,6 +75,18 @@ TEST(ResourceId, Masking) { } } +TEST(ResourceId, SaltingWorks) { + const std::string collAName = "db1.collA"; + + const uint64_t salt1 = 0; + const uint64_t salt2 = 1; + + const auto id1Salt1 = hashStringDataForResourceId(collAName, salt1); + const auto id1Salt2 = hashStringDataForResourceId(collAName, salt2); + + ASSERT_NE(id1Salt1, id1Salt2); +} + class ResourceIdTest : public unittest::Test {}; DEATH_TEST_F(ResourceIdTest, StringConstructorMustNotBeCollection, "invariant") { diff --git a/src/mongo/db/concurrency/lock_state.cpp b/src/mongo/db/concurrency/lock_state.cpp index bc8f377467600..079dcbf8207e3 100644 --- a/src/mongo/db/concurrency/lock_state.cpp +++ b/src/mongo/db/concurrency/lock_state.cpp @@ -209,7 +209,7 @@ void LockerImpl::dump() const { BSONObj toBSON() const { BSONObjBuilder b; - b.append("key", key.toString()); + b.append("key", toStringForLogging(key)); b.append("status", lockRequestStatusName(status)); b.append("recursiveCount", static_cast(recursiveCount)); b.append("unlockPending", static_cast(unlockPending)); @@ -376,7 +376,7 @@ void LockerImpl::reacquireTicket(OperationContext* opCtx) { fmt::format("Unable to acquire ticket with mode '{}' due to detected lock " "conflict for resource {}", _modeForTicket, - it.key().toString()), + it.key().toStringForErrorMessage()), !getGlobalLockManager()->hasConflictingRequests(it.key(), it.objAddr())); } } while (!_acquireTicket(opCtx, _modeForTicket, Date_t::now() + Milliseconds{100})); @@ -866,7 +866,7 @@ LockResult LockerImpl::_lockBegin(OperationContext* opCtx, ResourceId resId, Loc invariant(!opCtx->recoveryUnit()->isTimestamped(), str::stream() << "Operation holding open an oplog hole tried to acquire locks. ResourceId: " - << resId << ", mode: " << modeName(mode)); + << toStringForLogging(resId) << ", mode: " << modeName(mode)); } LockRequest* request; @@ -956,7 +956,7 @@ void LockerImpl::_lockComplete(OperationContext* opCtx, invariant(!opCtx->recoveryUnit()->isTimestamped(), str::stream() << "Operation holding open an oplog hole tried to acquire locks. 
ResourceId: " - << resId << ", mode: " << modeName(mode)); + << toStringForLogging(resId) << ", mode: " << modeName(mode)); } // Clean up the state on any failed lock attempts. @@ -973,7 +973,8 @@ void LockerImpl::_lockComplete(OperationContext* opCtx, if (!_uninterruptibleLocksRequested && isUserOperation && MONGO_unlikely(failNonIntentLocksIfWaitNeeded.shouldFail())) { uassert(ErrorCodes::LockTimeout, - str::stream() << "Cannot immediately acquire lock '" << resId.toString() + str::stream() << "Cannot immediately acquire lock '" + << resId.toStringForErrorMessage() << "'. Timing out due to failpoint.", (mode == MODE_IS || mode == MODE_IX)); } @@ -1037,8 +1038,8 @@ void LockerImpl::_lockComplete(OperationContext* opCtx, onTimeout(); } std::string timeoutMessage = str::stream() - << "Unable to acquire " << modeName(mode) << " lock on '" << resId.toString() - << "' within " << timeout << "."; + << "Unable to acquire " << modeName(mode) << " lock on '" + << resId.toStringForErrorMessage() << "' within " << timeout << "."; if (opCtx && opCtx->getClient()) { timeoutMessage = str::stream() << timeoutMessage << " opId: " << opCtx->getOpID() diff --git a/src/mongo/db/concurrency/resource_catalog_test.cpp b/src/mongo/db/concurrency/resource_catalog_test.cpp index 026e5cf587adf..09de8da80a8c0 100644 --- a/src/mongo/db/concurrency/resource_catalog_test.cpp +++ b/src/mongo/db/concurrency/resource_catalog_test.cpp @@ -47,7 +47,7 @@ class ResourceCatalogTest : public unittest::Test { NamespaceString secondCollection = NamespaceString::createNamespaceString_forTest(boost::none, "1626936312"); - ResourceId secondResourceId{RESOURCE_COLLECTION, secondCollection}; + ResourceId secondResourceId = firstResourceId; NamespaceString thirdCollection = NamespaceString::createNamespaceString_forTest(boost::none, "2930102946"); diff --git a/src/mongo/db/exec/document_value/document_value_test.cpp b/src/mongo/db/exec/document_value/document_value_test.cpp index 9ca60e7200df0..9022c0e908176 100644 --- a/src/mongo/db/exec/document_value/document_value_test.cpp +++ b/src/mongo/db/exec/document_value/document_value_test.cpp @@ -168,6 +168,77 @@ TEST(DocumentSerialization, CannotSerializeDocumentThatExceedsDepthLimit) { throwaway.abandon(); } +TEST(DocumentDepthCalculations, Sanity) { + { + // A scalar has depth 0. + ASSERT_EQ(0, Value(1).depth(BSONDepth::getMaxAllowableDepth())); + } + { + // Nesting documents increments depth. + int32_t initialDepth = 1; + MutableDocument md; + md.addField("a", Value(1)); + Document doc(md.freeze()); + Value val(doc); + int32_t iters = 16; + ASSERT_EQ(initialDepth, val.depth(BSONDepth::getMaxAllowableDepth())); + for (int32_t idx = 0; idx < iters; ++idx) { + MutableDocument md; + md.addField("a", Value(doc)); + doc = md.freeze(); + Value val(doc); + ASSERT_EQ(idx + initialDepth + 1, val.depth(BSONDepth::getMaxAllowableDepth())); + } + } + { + // Simple document with no nested paths has depth 1. + Value val(BSON("a" << 1)); + ASSERT_EQ(1, val.depth(BSONDepth::getMaxAllowableDepth())); + } + { + // Depth is max of children. + BSONObj bson = BSON("a" << 1 << "b" << BSON("c" << 1)); + Document document = fromBson(bson); + Value val(document); + ASSERT_EQ(2, val.depth(BSONDepth::getMaxAllowableDepth())); + } + { + // Arrays increment depth. + BSONObj bson = BSON("a" << BSON_ARRAY(1 << 1)); + Value val(fromBson(bson)); + ASSERT_EQ(2, val.depth(BSONDepth::getMaxAllowableDepth())); + } + { + // Array length does not affect depth. 
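// (Depth counts nesting levels, not the number of elements at a level.)
+        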
+ BSONObj bson = BSON("a" << BSON_ARRAY(1 << 1)); + BSONObj bson2 = BSON("a" << BSON_ARRAY(1 << 1 << 1)); + Value val(fromBson(bson)); + Value val2(fromBson(bson2)); + ASSERT_EQ(val.depth(BSONDepth::getMaxAllowableDepth()), + val2.depth(BSONDepth::getMaxAllowableDepth())); + } + { + // Nested arrays increment depth. + BSONObj bson = BSON("a" << BSON_ARRAY(1 << BSON_ARRAY(1 << 1))); + Value val(fromBson(bson)); + ASSERT_EQ(3, val.depth(BSONDepth::getMaxAllowableDepth())); + } + { + // If maxDepth at least document depth, this function returns -1. + BSONObj bson = BSON("a" << 1 << "b" << BSON("c" << 1)); + Document document = fromBson(bson); + Value val(document); + int32_t depth = 2; + for (int32_t maxDepth = 0; maxDepth < 2 * depth; maxDepth++) { + if (maxDepth <= depth) { + ASSERT_EQ(-1, val.depth(maxDepth)); + } else { + ASSERT_EQ(depth, val.depth(maxDepth)); + } + } + } +} + TEST(DocumentGetFieldNonCaching, UncachedTopLevelFields) { BSONObj bson = BSON("scalar" << 1 << "array" << BSON_ARRAY(1 << 2 << 3) << "scalar2" << true); Document document = fromBson(bson); diff --git a/src/mongo/db/exec/document_value/value.cpp b/src/mongo/db/exec/document_value/value.cpp index 74e10c4deb36f..bb78d59af98ec 100644 --- a/src/mongo/db/exec/document_value/value.cpp +++ b/src/mongo/db/exec/document_value/value.cpp @@ -1207,6 +1207,42 @@ size_t Value::getApproximateSize() const { MONGO_verify(false); } +int32_t Value::depth(int32_t maxDepth, int32_t curDepth /*=0*/) const { + if (curDepth >= maxDepth) { + return -1; + } + int32_t maxChildDepth = -1; + switch (getType()) { + case BSONType::Object: { + FieldIterator f(getDocument()); + while (f.more()) { + auto fp = f.next(); + int32_t childDepth = fp.second.depth(maxDepth, curDepth + 1); + if (childDepth == -1) { + return -1; + } + maxChildDepth = std::max(maxChildDepth, childDepth); + } + break; + } + case BSONType::Array: { + for (const auto& val : getArray()) { + int32_t childDepth = val.depth(maxDepth, curDepth + 1); + if (childDepth == -1) { + return -1; + } + maxChildDepth = std::max(maxChildDepth, childDepth); + } + break; + } + default: + // No op + break; + } + // Increment depth to account for this level. + return maxChildDepth + 1; +} + string Value::toString() const { // TODO use StringBuilder when operator << is ready stringstream out; diff --git a/src/mongo/db/exec/document_value/value.h b/src/mongo/db/exec/document_value/value.h index 12c08c421911e..e5b9297618bc0 100644 --- a/src/mongo/db/exec/document_value/value.h +++ b/src/mongo/db/exec/document_value/value.h @@ -366,6 +366,12 @@ class Value { /// Get the approximate memory size of the value, in bytes. Includes sizeof(Value) size_t getApproximateSize() const; + /** + * Returns object/array depth of this value. Returns -1 if the depth is at least 'maxDepth'. + * Returns 0 on a scalar value. + */ + int32_t depth(int32_t maxDepth, int32_t curDepth = 0) const; + /** * Calculate a hash value. * diff --git a/src/mongo/db/exec/geo_near.cpp b/src/mongo/db/exec/geo_near.cpp index ca7c1859c0b06..4554a448c8308 100644 --- a/src/mongo/db/exec/geo_near.cpp +++ b/src/mongo/db/exec/geo_near.cpp @@ -72,7 +72,7 @@ static double computeGeoNearDistance(const GeoNearParams& nearParams, WorkingSet // // Must have an object in order to get geometry out of it. 
- invariant(member->hasObj()); + tassert(9911912, "", member->hasObj()); CRS queryCRS = nearParams.nearQuery->centroid->crs; @@ -118,7 +118,7 @@ static double computeGeoNearDistance(const GeoNearParams& nearParams, WorkingSet if (nearParams.nearQuery->unitsAreRadians) { // Hack for nearSphere // TODO: Remove nearSphere? - invariant(SPHERE == queryCRS); + tassert(9911927, "", SPHERE == queryCRS); member->metadata().setGeoNearDistance(minDistance / kRadiusOfEarthInMeters); } else { member->metadata().setGeoNearDistance(minDistance); @@ -139,7 +139,7 @@ static R2Annulus geoNearDistanceBounds(const GeoNearExpression& query) { return R2Annulus(query.centroid->oldPoint, query.minDistance, query.maxDistance); } - invariant(SPHERE == queryCRS); + tassert(9911913, "", SPHERE == queryCRS); // TODO: Tighten this up a bit by making a CRS for "sphere with radians" double minDistance = query.minDistance; @@ -188,8 +188,8 @@ static R2Annulus twoDDistanceBounds(const GeoNearParams& nearParams, } else { // Spherical queries have upper bounds set by the earth - no-op // TODO: Wrapping errors would creep in here if nearSphere wasn't defined to not wrap - invariant(SPHERE == queryCRS); - invariant(!nearParams.nearQuery->isWrappingQuery); + tassert(9911914, "", SPHERE == queryCRS); + tassert(9911915, "", !nearParams.nearQuery->isWrappingQuery); } return fullBounds; @@ -247,13 +247,13 @@ void GeoNear2DStage::DensityEstimator::buildIndexScan(ExpressionContext* expCtx, builder.obj(), BoundInclusion::kIncludeBothStartAndEndKeys)); } - invariant(oil.isValidFor(1)); + tassert(9911916, "", oil.isValidFor(1)); // Intersect the $near bounds we just generated into the bounds we have for anything else // in the scan (i.e. $within) IndexBoundsBuilder::intersectize(oil, &scanParams.bounds.fields[twoDFieldPosition]); - invariant(!_indexScan); + tassert(9911917, "", !_indexScan); _indexScan = new IndexScan(expCtx, _collection, scanParams, workingSet, nullptr); _children->emplace_back(_indexScan); } @@ -312,7 +312,7 @@ PlanStage::StageState GeoNear2DStage::DensityEstimator::work(ExpressionContext* // Advance to the next level and search again. _currentLevel--; // Reset index scan for the next level. - invariant(_children->back().get() == _indexScan); + tassert(9911918, "", _children->back().get() == _indexScan); _indexScan = nullptr; _children->pop_back(); return PlanStage::NEED_TIME; @@ -369,7 +369,7 @@ PlanStage::StageState GeoNear2DStage::initialize(OperationContext* opCtx, _boundsIncrement = 3 * estimatedDistance; } - invariant(_boundsIncrement > 0.0); + tassert(9911919, "", _boundsIncrement > 0.0); // Clean up _densityEstimator.reset(nullptr); @@ -432,7 +432,7 @@ static double min2DBoundsIncrement(const GeoNearExpression& query, if (FLAT == queryCRS) return minBoundsIncrement; - invariant(SPHERE == queryCRS); + tassert(9911920, "", SPHERE == queryCRS); // If this is a spherical query, units are in meters - this is just a heuristic return minBoundsIncrement * kMetersPerDegreeAtEquator; @@ -528,7 +528,7 @@ std::unique_ptr GeoNear2DStage::nextInterval( max(0.0, nextBounds.getInner() - epsilon), nextBounds.getOuter() + epsilon)); } else { - invariant(SPHERE == queryCRS); + tassert(9911921, "", SPHERE == queryCRS); // TODO: As above, make this consistent with $within : $centerSphere // Our intervals aren't in the same CRS as our index, so we need to adjust them @@ -743,13 +743,13 @@ void GeoNear2DSphereStage::DensityEstimator::buildIndexScan(ExpressionContext* e // The search area expands 4X each time. 
// Return the neighbors of closest vertex to this cell at the given level. - invariant(_currentLevel < centerId.level()); + tassert(9911922, "", _currentLevel < centerId.level()); centerId.AppendVertexNeighbors(_currentLevel, &neighbors); ExpressionMapping::S2CellIdsToIntervals(neighbors, _indexParams.indexVersion, coveredIntervals); // Index scan - invariant(!_indexScan); + tassert(9911923, "", !_indexScan); _indexScan = new IndexScan(expCtx, _collection, scanParams, workingSet, nullptr); _children->emplace_back(_indexScan); } @@ -808,7 +808,7 @@ PlanStage::StageState GeoNear2DSphereStage::DensityEstimator::work(ExpressionCon // Advance to the next level and search again. _currentLevel--; // Reset index scan for the next level. - invariant(_children->back().get() == _indexScan); + tassert(9911924, "", _children->back().get() == _indexScan); _indexScan = nullptr; _children->pop_back(); return PlanStage::NEED_TIME; @@ -853,7 +853,7 @@ PlanStage::StageState GeoNear2DSphereStage::initialize(OperationContext* opCtx, // // At the coarsest level, the search area is the whole earth. _boundsIncrement = 3 * estimatedDistance; - invariant(_boundsIncrement > 0.0); + tassert(9911925, "", _boundsIncrement > 0.0); // Clean up _densityEstimator.reset(nullptr); @@ -883,7 +883,7 @@ std::unique_ptr GeoNear2DSphereStage::nextInterval( _boundsIncrement /= 2; } - invariant(_boundsIncrement > 0.0); + tassert(9911926, "", _boundsIncrement > 0.0); R2Annulus nextBounds(_currBounds.center(), _currBounds.getOuter(), @@ -913,7 +913,7 @@ std::unique_ptr GeoNear2DSphereStage::nextInterval( // Generate a covering that does not intersect with any previous coverings S2CellUnion coverUnion; coverUnion.InitSwap(&cover); - invariant(cover.empty()); + tassert(9911910, "", cover.empty()); S2CellUnion diffUnion; diffUnion.GetDifference(&coverUnion, &_scannedCells); for (const auto& cellId : diffUnion.cell_ids()) { diff --git a/src/mongo/db/geo/big_polygon.cpp b/src/mongo/db/geo/big_polygon.cpp index c4cffc35cea4a..a516eef1d3beb 100644 --- a/src/mongo/db/geo/big_polygon.cpp +++ b/src/mongo/db/geo/big_polygon.cpp @@ -217,14 +217,14 @@ bool BigSimplePolygon::VirtualContainsPoint(const S2Point& p) const { } void BigSimplePolygon::Encode(Encoder* const encoder) const { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911951); } bool BigSimplePolygon::Decode(Decoder* const decoder) { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911952); } bool BigSimplePolygon::DecodeWithinScope(Decoder* const decoder) { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911953); } } // namespace mongo diff --git a/src/mongo/db/geo/geometry_container.cpp b/src/mongo/db/geo/geometry_container.cpp index 1c05147d1cc92..bb490e9414a23 100644 --- a/src/mongo/db/geo/geometry_container.cpp +++ b/src/mongo/db/geo/geometry_container.cpp @@ -50,7 +50,7 @@ bool GeometryContainer::isPoint() const { } PointWithCRS GeometryContainer::getPoint() const { - invariant(isPoint()); + tassert(9911939, "", isPoint()); return *_point; } @@ -85,7 +85,7 @@ const S2Region& GeometryContainer::getS2Region() const { } else if (nullptr != _multiPolygon) { return *_s2Region; } else { - invariant(nullptr != _geometryCollection); + tassert(9911928, "", nullptr != _geometryCollection); return *_s2Region; } } @@ -982,8 +982,7 @@ Status GeometryContainer::parseFromGeoJSON(bool skipValidation) { } } } else { - // Should not reach here. - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911954); } // Check parsing result. 
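For context on the invariant -> tassert(id, msg, cond) and MONGO_UNREACHABLE -> MONGO_UNREACHABLE_TASSERT(id) conversions throughout these geo files: as we understand the server's assertion utilities, a tripwire assertion throws a catchable error in production (while still failing the process under test) instead of unconditionally aborting, and every call site carries a globally unique numeric ID for log triage. A rough conceptual sketch of that distinction, not the real assert_util implementation:

#include <cstdlib>
#include <stdexcept>
#include <string>

// Hypothetical, simplified semantics for illustration only.
inline void invariantLike(bool cond) {
    if (!cond)
        std::abort();  // always fatal, in production builds too
}

inline void tassertLike(int uniqueId, const std::string& msg, bool cond) {
    if (!cond) {
        // A real tassert additionally records the tripwire so that test
        // suites fail the process at shutdown; in production the failure
        // surfaces as a catchable exception rather than an immediate crash.
        throw std::runtime_error(std::to_string(uniqueId) + ": " + msg);
    }
}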
@@ -1115,7 +1114,7 @@ string GeometryContainer::getDebugType() const { } else if (nullptr != _geometryCollection) { return "gc"; } else { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911955); return ""; } } @@ -1142,7 +1141,7 @@ CRS GeometryContainer::getNativeCRS() const { } else if (nullptr != _geometryCollection) { return SPHERE; } else { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911956); return FLAT; } } @@ -1167,7 +1166,7 @@ bool GeometryContainer::supportsProject(CRS otherCRS) const { } else if (nullptr != _multiPolygon) { return _multiPolygon->crs == otherCRS; } else { - invariant(nullptr != _geometryCollection); + tassert(9911929, "", nullptr != _geometryCollection); return SPHERE == otherCRS; } } @@ -1181,7 +1180,7 @@ void GeometryContainer::projectInto(CRS otherCRS) { return; } - invariant(nullptr != _point); + tassert(9911930, "", nullptr != _point); ShapeProjection::projectInto(_point.get(), otherCRS); } @@ -1229,7 +1228,7 @@ static double s2MinDistanceRad(const S2Point& s2Point, for (vector::const_iterator it = geometryCollection.points.begin(); it != geometryCollection.points.end(); ++it) { - invariant(SPHERE == it->crs); + tassert(9911931, "", SPHERE == it->crs); double nextDistance = S2Distance::distanceRad(s2Point, it->point); if (minDistance < 0 || nextDistance < minDistance) { minDistance = nextDistance; @@ -1237,7 +1236,7 @@ static double s2MinDistanceRad(const S2Point& s2Point, } for (const auto& line : geometryCollection.lines) { - invariant(SPHERE == line->crs); + tassert(9911932, "", SPHERE == line->crs); double nextDistance = S2Distance::minDistanceRad(s2Point, line->line); if (minDistance < 0 || nextDistance < minDistance) { minDistance = nextDistance; @@ -1245,9 +1244,9 @@ static double s2MinDistanceRad(const S2Point& s2Point, } for (const auto& polygon : geometryCollection.polygons) { - invariant(SPHERE == polygon->crs); + tassert(9911933, "", SPHERE == polygon->crs); // We don't support distances for big polygons yet. - invariant(polygon->s2Polygon); + tassert(9911934, "", polygon->s2Polygon); double nextDistance = S2Distance::minDistanceRad(s2Point, *(polygon->s2Polygon)); if (minDistance < 0 || nextDistance < minDistance) { minDistance = nextDistance; @@ -1282,7 +1281,7 @@ double GeometryContainer::minDistance(const PointWithCRS& otherPoint) const { const CRS crs = getNativeCRS(); if (FLAT == crs) { - invariant(nullptr != _point); + tassert(9911935, "", nullptr != _point); if (FLAT == otherPoint.crs) { return distance(_point->oldPoint, otherPoint.oldPoint); @@ -1292,7 +1291,7 @@ double GeometryContainer::minDistance(const PointWithCRS& otherPoint) const { Point(latLng.lng().degrees(), latLng.lat().degrees())); } } else { - invariant(SPHERE == crs); + tassert(9911936, "", SPHERE == crs); double minDistance = -1; @@ -1310,7 +1309,7 @@ double GeometryContainer::minDistance(const PointWithCRS& otherPoint) const { minDistance = S2Distance::minDistanceRad(otherPoint.point, _line->line); } else if (nullptr != _polygon) { // We don't support distances for big polygons yet. 
- invariant(nullptr != _polygon->s2Polygon); + tassert(9911937, "", nullptr != _polygon->s2Polygon); minDistance = S2Distance::minDistanceRad(otherPoint.point, *_polygon->s2Polygon); } else if (nullptr != _cap) { minDistance = S2Distance::minDistanceRad(otherPoint.point, _cap->cap); @@ -1324,7 +1323,7 @@ double GeometryContainer::minDistance(const PointWithCRS& otherPoint) const { minDistance = s2MinDistanceRad(otherPoint.point, *_geometryCollection); } - invariant(minDistance != -1); + tassert(9911938, "", minDistance != -1); return minDistance * kRadiusOfEarthInMeters; } } diff --git a/src/mongo/db/geo/geoparser.cpp b/src/mongo/db/geo/geoparser.cpp index 53b163a60aca5..7d4b643d6b0fa 100644 --- a/src/mongo/db/geo/geoparser.cpp +++ b/src/mongo/db/geo/geoparser.cpp @@ -754,8 +754,7 @@ Status GeoParser::parseGeometryCollection(const BSONObj& obj, out->multiPolygons.push_back(std::make_unique()); status = parseMultiPolygon(geoObj, skipValidation, out->multiPolygons.back().get()); } else { - // Should not reach here. - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911957); } // Check parsing result. diff --git a/src/mongo/db/geo/hash.cpp b/src/mongo/db/geo/hash.cpp index 7983895dc7785..574d9d7ff17ca 100644 --- a/src/mongo/db/geo/hash.cpp +++ b/src/mongo/db/geo/hash.cpp @@ -606,7 +606,7 @@ GeoHash GeoHash::parent() const { void GeoHash::appendVertexNeighbors(unsigned level, vector* output) const { - invariant(level >= 0 && level < _bits); + tassert(9911940, "", level >= 0 && level < _bits); // Parent at the given level. GeoHash parentHash = parent(level); @@ -906,8 +906,8 @@ double GeoHashConverter::sizeOfDiag(const GeoHash& a) const { // Relative error = epsilon_(max-min). ldexp() is just a direct translation to // floating point exponent, and should be exact. 
double GeoHashConverter::sizeEdge(unsigned level) const { - invariant(level >= 0); - invariant((int)level <= _params.bits); + tassert(9911941, "", level >= 0); + tassert(9911942, "", (int)level <= _params.bits); #pragma warning(push) // C4146: unary minus operator applied to unsigned type, result still unsigned #pragma warning(disable : 4146) diff --git a/src/mongo/db/geo/r2_region_coverer.cpp b/src/mongo/db/geo/r2_region_coverer.cpp index 8f88d087d1a9d..cefe31eb4bbc0 100644 --- a/src/mongo/db/geo/r2_region_coverer.cpp +++ b/src/mongo/db/geo/r2_region_coverer.cpp @@ -206,7 +206,7 @@ void R2RegionCoverer::addCandidate(Candidate* candidate) { // Dones't take ownership of "candidate" int R2RegionCoverer::expandChildren(Candidate* candidate) { GeoHash childCells[4]; - invariant(candidate->cell.subdivide(childCells)); + tassert(9911943, "", candidate->cell.subdivide(childCells)); int numTerminals = 0; for (int i = 0; i < 4; ++i) { diff --git a/src/mongo/db/geo/shapes.cpp b/src/mongo/db/geo/shapes.cpp index 56414e5c1c14d..dfd97e164adee 100644 --- a/src/mongo/db/geo/shapes.cpp +++ b/src/mongo/db/geo/shapes.cpp @@ -512,7 +512,7 @@ std::unique_ptr MultiLineWithCRS::clone() const { cloned->crs = crs; for (const auto& line : lines) { - invariant(line); + tassert(9911944, "", line); cloned->lines.emplace_back(line->Clone()); } @@ -524,7 +524,7 @@ std::unique_ptr MultiPolygonWithCRS::clone() const { cloned->crs = crs; for (const auto& polygon : polygons) { - invariant(polygon); + tassert(9911945, "", polygon); cloned->polygons.emplace_back(polygon->Clone()); } @@ -835,7 +835,7 @@ bool ShapeProjection::supportsProject(const PointWithCRS& point, const CRS crs) if (point.crs == crs || point.crs == SPHERE) return true; - invariant(point.crs == FLAT); + tassert(9911946, "", point.crs == FLAT); // If crs is FLAT, we might be able to upgrade the point to SPHERE if it's a valid SPHERE // point (lng/lat in bounds). In this case, we can use FLAT data with SPHERE predicates. return isValidLngLat(point.oldPoint.x, point.oldPoint.y); @@ -853,7 +853,7 @@ void ShapeProjection::projectInto(PointWithCRS* point, CRS crs) { if (FLAT == point->crs) { // Prohibit projection to STRICT_SPHERE CRS - invariant(SPHERE == crs); + tassert(9911947, "", SPHERE == crs); // Note that it's (lat, lng) for S2 but (lng, lat) for MongoDB. S2LatLng latLng = S2LatLng::FromDegrees(point->oldPoint.y, point->oldPoint.x).Normalized(); @@ -865,7 +865,7 @@ void ShapeProjection::projectInto(PointWithCRS* point, CRS crs) { } // Prohibit projection to STRICT_SPHERE CRS - invariant(SPHERE == point->crs && FLAT == crs); + tassert(9911948, "", SPHERE == point->crs && FLAT == crs); // Just remove the additional spherical information point->point = S2Point(); point->cell = S2Cell(); @@ -877,7 +877,7 @@ void ShapeProjection::projectInto(PolygonWithCRS* polygon, CRS crs) { return; // Only project from STRICT_SPHERE to SPHERE - invariant(STRICT_SPHERE == polygon->crs && SPHERE == crs); + tassert(9911949, "", STRICT_SPHERE == polygon->crs && SPHERE == crs); polygon->crs = SPHERE; } diff --git a/src/mongo/db/index/wildcard_validation.cpp b/src/mongo/db/index/wildcard_validation.cpp index 57c17ef7680ba..5301c25aa3f6d 100644 --- a/src/mongo/db/index/wildcard_validation.cpp +++ b/src/mongo/db/index/wildcard_validation.cpp @@ -36,7 +36,7 @@ namespace mongo { namespace { static const StringData idFieldName = "_id"; /* - * Validate that wildcatdProject fields have no overlapping. 
It takes a sorted list of the + * Validate that wildcardProjection fields do not overlap. It takes a sorted list of the * projection fields. */ Status validateOverlappingFieldsInWildcardProjectionOnly( @@ -110,7 +110,7 @@ Status validateWildcardIndex(const BSONObj& keyPattern) { Status validateWildcardProjection(const BSONObj& keyPattern, const BSONObj& pathProjection) { if (pathProjection.isEmpty()) { - return {ErrorCodes::Error{7246205}, "WildcardProjection must be non-empty if specified."}; + return {ErrorCodes::Error{7246205}, "WildcardProjection must be non-empty if specified"}; } // Prepare data for validation. @@ -152,8 +152,41 @@ Status validateWildcardProjection(const BSONObj& keyPattern, const BSONObj& path return status; } - // test overlappings between index keys and wildcard projection - { + // The wildcardProjection cannot combine inclusion and exclusion statements, with the exception + // that _id may be excluded for inclusion projections and included for exclusion projections. + if (!projectionIncludedFields.empty() && !projectionExcludedFields.empty()) { + const FieldRef idFieldRef{idFieldName}; + const bool idOnlyExclusion = + projectionExcludedFields.size() == 1 && projectionExcludedFields.front() == idFieldRef; + const bool idOnlyInclusion = + projectionIncludedFields.size() == 1 && projectionIncludedFields.front() == idFieldRef; + + // In order for the projection to be valid when there are both inclusions and exclusions, + // _id has to be the sole field whose inclusion/exclusion value does not match the others. + if (idOnlyExclusion && idOnlyInclusion) { + return {ErrorCodes::Error{11368500}, + "The wildcard projection both excludes and includes _id"}; + } else if (!idOnlyExclusion && !idOnlyInclusion) { + return {ErrorCodes::Error{7246211}, + "The wildcardProjection cannot combine inclusion and exclusion statements, " + "with the exception that _id may be excluded for inclusion projections and " + "included for exclusion projections"}; + } + + // If _id is the only excluded field, ignore the exclusion in the checks below. For example, + // we can treat {_id: 0, a: 1} as just {a: 1}. In wildcard indexes (unlike regular + // projections) _id is excluded by default. + if (idOnlyExclusion) { + projectionExcludedFields.clear(); + } else { + // Here idOnlyInclusion is implied from the checks above. Similarly, ignore an _id-only + // inclusion. + projectionIncludedFields.clear(); + } + } + + // There cannot be overlap between the index keys and the wildcard projection's inclusions. 
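+ // (Hedged illustration of the rule above: with keyPattern {a: 1, "$**": 1}, a
+ // wildcardProjection of {b: 1, c: 1} is a pure inclusion, {_id: 0, b: 1} is still
+ // treated as an inclusion because the lone _id exclusion is ignored, while
+ // {b: 1, c: 0} must fail with error 7246211 and {_id: 0, _id: 1} with 11368500.)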
+ if (!projectionIncludedFields.empty()) { auto indexPos = indexFields.begin(); auto projectionPos = projectionIncludedFields.begin(); while (indexPos != indexFields.end() && projectionPos != projectionIncludedFields.end()) { @@ -162,8 +195,8 @@ Status validateWildcardProjection(const BSONObj& keyPattern, const BSONObj& path str::stream() << "Index Key and Wildcard Projection cannot contain " "overlapping fields, however '" - << indexPos->dottedField() << "' index field is ovverlapping with '" - << projectionPos->dottedField() << "' wildcardProjection path."}; + << indexPos->dottedField() << "' index field is overlapping with '" + << projectionPos->dottedField() << "' wildcardProjection path"}; int cmp = projectionPos->compare(*indexPos); @@ -173,14 +206,12 @@ Status validateWildcardProjection(const BSONObj& keyPattern, const BSONObj& path ++indexPos; } } - } - - const FieldRef idFieldRef{idFieldName}; - const bool idOnlyExclusion = - projectionExcludedFields.size() == 1 && projectionExcludedFields.front() == idFieldRef; + } else { + tassert(11368501, + "Expected projectionExcludedFields to be populated", + !projectionExcludedFields.empty()); - // test test wildcard projects exclude all regular index fields - if (!projectionExcludedFields.empty() && !idOnlyExclusion) { + // If the wildcardProjection is an exclusion, it must exclude all regular index fields. auto indexPos = indexFields.begin(); auto projectionPos = projectionExcludedFields.begin(); while (indexPos != indexFields.end() && projectionPos != projectionExcludedFields.end()) { @@ -195,7 +226,7 @@ Status validateWildcardProjection(const BSONObj& keyPattern, const BSONObj& path return {ErrorCodes::Error{7246209}, str::stream() << "wildcardProjection paths must exclude all regular " "index fields, however '" - << indexPos->dottedField() << "'is not excluded."}; + << indexPos->dottedField() << "' is not excluded"}; } } } @@ -204,22 +235,7 @@ Status validateWildcardProjection(const BSONObj& keyPattern, const BSONObj& path return {ErrorCodes::Error{7246210}, str::stream() << "wildcardProjection paths must exclude all regular " "index fields, however '" - << indexPos->dottedField() << "'is not excluded."}; - } - } - - // With the exception of explicitly including _id field, you cannot combine inclusion and - exclusion statements in the wildcardProjection document.
- if (!projectionIncludedFields.empty() && !projectionExcludedFields.empty()) { - const bool idOnlyInclusion = - projectionIncludedFields.size() == 1 && projectionIncludedFields.front() == idFieldRef; - const bool idIsSingleField = idOnlyExclusion || idOnlyInclusion; - if (!idIsSingleField) { - return { - ErrorCodes::Error{7246211}, - str::stream() - << "Inclusion and exclusion statements cannot combine in the " - "wildcardProjection with an exception of explicitly including _id field"}; + << indexPos->dottedField() << "' is not excluded"}; } } diff --git a/src/mongo/db/index/wildcard_validation_test.cpp b/src/mongo/db/index/wildcard_validation_test.cpp index bac71b0309e9d..f7a4032d9af86 100644 --- a/src/mongo/db/index/wildcard_validation_test.cpp +++ b/src/mongo/db/index/wildcard_validation_test.cpp @@ -110,4 +110,14 @@ TEST(WildcardProjectionValidation, IdField) { ASSERT_OK(validateWildcardProjection(BSON("$**" << 1 << "other" << 1), BSON("_id" << 1 << "a" << 0 << "other" << 0))); } + +TEST(WildcardProjectionValidation, IdFieldExcludedWithCompoundIndex) { + ASSERT_NOT_OK(validateWildcardProjection(BSON("a" << 1 << "$**" << 1), BSON("_id" << 0))); + ASSERT_NOT_OK(validateWildcardProjection(BSON("$**" << 1 << "a" << 1), BSON("_id" << 0))); + ASSERT_NOT_OK( + validateWildcardProjection(BSON("b" << 1 << "$**" << 1 << "a" << 1), BSON("_id" << 0))); + + ASSERT_OK(validateWildcardProjection(BSON("$**" << 1 << "_id" << 1), BSON("_id" << 0))); + ASSERT_OK(validateWildcardProjection(BSON("_id" << 1 << "$**" << 1), BSON("_id" << 0))); +} } // namespace mongo diff --git a/src/mongo/db/matcher/expression_geo.h b/src/mongo/db/matcher/expression_geo.h index 42966c88b3d42..b249eb0d92edd 100644 --- a/src/mongo/db/matcher/expression_geo.h +++ b/src/mongo/db/matcher/expression_geo.h @@ -281,20 +281,20 @@ class TwoDPtInAnnulusExpression : public LeafMatchExpression { void appendSerializedRightHandSide(BSONObjBuilder* bob, const SerializationOptions& opts = {}, bool includePath = true) const final { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911958); } void debugString(StringBuilder& debug, int level = 0) const final { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911959); } bool equivalent(const MatchExpression* other) const final { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911960); return false; } std::unique_ptr<MatchExpression> clone() const final { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911961); return nullptr; } diff --git a/src/mongo/db/pipeline/document_source_geo_near.cpp b/src/mongo/db/pipeline/document_source_geo_near.cpp index 8b9814777279d..8725a72fe3fe3 100644 --- a/src/mongo/db/pipeline/document_source_geo_near.cpp +++ b/src/mongo/db/pipeline/document_source_geo_near.cpp @@ -119,7 +119,7 @@ Pipeline::SourceContainer::iterator DocumentSourceGeoNear::doOptimizeAt( Pipeline::SourceContainer::iterator DocumentSourceGeoNear::splitForTimeseries( Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) { - invariant(*itr == this); + tassert(9911904, "", *itr == this); // Only do this rewrite if we are immediately following an $_internalUnpackBucket stage.
if (container->begin() == itr || diff --git a/src/mongo/db/pipeline/document_source_geo_near.h b/src/mongo/db/pipeline/document_source_geo_near.h index 26c0977fa15c5..dbd47c8c66d46 100644 --- a/src/mongo/db/pipeline/document_source_geo_near.h +++ b/src/mongo/db/pipeline/document_source_geo_near.h @@ -76,7 +76,7 @@ class DocumentSourceGeoNear : public DocumentSource { * executing a pipeline, so this method should never be called. */ GetNextResult doGetNext() final { - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911962); } Value serialize(const SerializationOptions& opts = SerializationOptions{}) const final override; diff --git a/src/mongo/db/pipeline/document_source_geo_near_cursor.cpp b/src/mongo/db/pipeline/document_source_geo_near_cursor.cpp index 8de9f300ab927..5dd5c86031ec5 100644 --- a/src/mongo/db/pipeline/document_source_geo_near_cursor.cpp +++ b/src/mongo/db/pipeline/document_source_geo_near_cursor.cpp @@ -77,7 +77,7 @@ DocumentSourceGeoNearCursor::DocumentSourceGeoNearCursor( _distanceField(std::move(distanceField)), _locationField(std::move(locationField)), _distanceMultiplier(distanceMultiplier) { - invariant(_distanceMultiplier >= 0); + tassert(9911901, "", _distanceMultiplier >= 0); } const char* DocumentSourceGeoNearCursor::getSourceName() const { @@ -88,19 +88,20 @@ Document DocumentSourceGeoNearCursor::transformDoc(Document&& objInput) const { MutableDocument output(std::move(objInput)); // Scale the distance by the requested factor. - invariant(output.peek().metadata().hasGeoNearDistance(), - str::stream() - << "Query returned a document that is unexpectedly missing the geoNear distance: " - << output.peek().toString()); + tassert(9911902, + str::stream() + << "Query returned a document that is unexpectedly missing the geoNear distance: " + << output.peek().toString(), + output.peek().metadata().hasGeoNearDistance()); const auto distance = output.peek().metadata().getGeoNearDistance() * _distanceMultiplier; output.setNestedField(_distanceField, Value(distance)); if (_locationField) { - invariant( - output.peek().metadata().hasGeoNearPoint(), - str::stream() - << "Query returned a document that is unexpectedly missing the geoNear point: " - << output.peek().toString()); + tassert(9911903, + str::stream() + << "Query returned a document that is unexpectedly missing the geoNear point: " + << output.peek().toString(), + output.peek().metadata().hasGeoNearPoint()); output.setNestedField(*_locationField, output.peek().metadata().getGeoNearPoint()); } diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index b20f01cbdcb22..fa432bad33f0c 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -4655,13 +4655,34 @@ Value ExpressionReduce::evaluate(const Document& root, Variables* variables) con << inputVal.toString(), inputVal.isArray()); - Value accumulatedValue = _children[_kInitial]->evaluate(root, variables); + size_t itr = 0; + int32_t prevDepth = -1; + size_t interval = getAccumulatedValueDepthCheckInterval(); + Value accumulatedValue = _children[_kInitial]->evaluate(root, variables); for (auto&& elem : inputVal.getArray()) { variables->setValue(_thisVar, elem); variables->setValue(_valueVar, accumulatedValue); accumulatedValue = _children[_kIn]->evaluate(root, variables); + if ((interval > 0) && (itr % interval) == 0 && + (accumulatedValue.isObject() || accumulatedValue.isArray())) { + int32_t depth = + accumulatedValue.depth(2 * BSONDepth::getMaxAllowableDepth() /*maxDepth*/); + if 
(MONGO_unlikely(depth == -1)) { + uasserted(ErrorCodes::Overflow, + "$reduce accumulated value exceeded max allowable BSON depth"); + } + // Exponential backoff if depth has not increased. + if (depth == prevDepth) { + tassert(10236400, + "unexpected control flow in $reduce object/array depth verification", + prevDepth != -1); + interval *= 2; + } + prevDepth = depth; + } + itr++; } return accumulatedValue; diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 228ede1cd0756..c2103a679a3ff 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -2861,6 +2861,10 @@ class ExpressionReduce final : public Expression { return visitor->visit(this); } + size_t getAccumulatedValueDepthCheckInterval() const { + return _accumulatedValueDepthCheckInterval; + } + private: static constexpr size_t _kInput = 0; static constexpr size_t _kInitial = 1; @@ -2868,6 +2872,9 @@ class ExpressionReduce final : public Expression { Variables::Id _thisVar; Variables::Id _valueVar; + + const size_t _accumulatedValueDepthCheckInterval = + gInternalReduceAccumulatedValueDepthCheckInterval.load(); }; diff --git a/src/mongo/db/pipeline/expression_bm_fixture.cpp b/src/mongo/db/pipeline/expression_bm_fixture.cpp index 00b656f8733e0..6a20bd26033f9 100644 --- a/src/mongo/db/pipeline/expression_bm_fixture.cpp +++ b/src/mongo/db/pipeline/expression_bm_fixture.cpp @@ -1611,4 +1611,86 @@ void ExpressionBenchmarkFixture::benchmarkPercentile(benchmark::State& state, } +/** + * Tests performance of $reduce that sums an array. + */ +void ExpressionBenchmarkFixture::benchmarkReduceSum(benchmark::State& state) { + const size_t numEntries = 16 * 1000; + BSONArray entries = rangeBSONArray(numEntries); + + BSONObj constantDepthExpr = BSON("$sum" << BSON_ARRAY("$$value" << 1)); + + BSONObj reduceExpression = + BSON("$reduce" << BSON("input" + << "$entries" + << "initialValue" << 0 << "in" << constantDepthExpr)); + + benchmarkExpression( + reduceExpression, state, std::vector(1, {{"entries"_sd, entries}})); +} + +/** + * Tests performance of $reduce that's equivalent to the identity function (uses $concatArrays + * to build the same array again). + */ +void ExpressionBenchmarkFixture::benchmarkReduceConcatArrays(benchmark::State& state) { + const size_t numEntries = 1000; + BSONArray entries = rangeBSONArray(numEntries); + + BSONArray emptyArray = rangeBSONArray(0); + + BSONObj reduceExpression = + BSON("$reduce" << BSON("input" + << "$entries" + << "initialValue" << emptyArray << "in" + << BSON("$concatArrays" + << BSON_ARRAY("$$value" << BSON_ARRAY("$$this"))))); + + benchmarkExpression( + reduceExpression, state, std::vector(1, {{"entries"_sd, entries}})); +} + + +/** + * Tests performance of $reduce that transforms an array into a deeply nested document. + * + * "$reduce": { + * "input": "$entries", + * "initialValue": [], + * "in": {"a": {"a" : {"a": ... {"a": {}}}}} + * } + * + * The nestedness of the "in" expression is configured by 'perIterationNestingDepth'. 
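+ *
+ * (A note on the depth check added to ExpressionReduce::evaluate above: the check only
+ * runs every 'internalReduceAccumulatedValueDepthCheckInterval' iterations (default 16)
+ * and doubles that interval whenever the observed depth has not grown since the last
+ * check. For an accumulator whose depth stays flat, such as the $concatArrays benchmark
+ * above, the depth computation therefore fires at iterations 0, 16, 32, 64, 128, and so
+ * on, so its steady-state cost decays geometrically.)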
+ */ +void ExpressionBenchmarkFixture::benchmarkReduceCreatingNestedObject( + benchmark::State& state, size_t perIterationNestingDepth) { + const size_t numEntries = 16 * 8 / perIterationNestingDepth; + BSONArray entries = rangeBSONArray(numEntries); + BSONObj recursiveObject = BSON("a" + << "$$value"); + for (size_t depth = 1; depth < perIterationNestingDepth; depth++) { + recursiveObject = BSON("a" << recursiveObject); + } + BSONArray emptyArray = rangeBSONArray(0); + BSONObj reduceExpression = + BSON("$reduce" << BSON("input" + << "$entries" + << "initialValue" << emptyArray << "in" << recursiveObject)); + benchmarkExpression( + reduceExpression, state, std::vector(1, {{"entries"_sd, entries}})); +} + + +void ExpressionBenchmarkFixture::benchmarkReduceCreatingNestedObject1(benchmark::State& state) { + benchmarkReduceCreatingNestedObject(state, 1); +} +void ExpressionBenchmarkFixture::benchmarkReduceCreatingNestedObject2(benchmark::State& state) { + benchmarkReduceCreatingNestedObject(state, 2); +} +void ExpressionBenchmarkFixture::benchmarkReduceCreatingNestedObject4(benchmark::State& state) { + benchmarkReduceCreatingNestedObject(state, 4); +} +void ExpressionBenchmarkFixture::benchmarkReduceCreatingNestedObject8(benchmark::State& state) { + benchmarkReduceCreatingNestedObject(state, 8); +} } // namespace mongo diff --git a/src/mongo/db/pipeline/expression_bm_fixture.h b/src/mongo/db/pipeline/expression_bm_fixture.h index 9b0a182aae368..cca69f1019fbb 100644 --- a/src/mongo/db/pipeline/expression_bm_fixture.h +++ b/src/mongo/db/pipeline/expression_bm_fixture.h @@ -229,6 +229,15 @@ class ExpressionBenchmarkFixture : public benchmark::Fixture { void benchmarkPercentile(benchmark::State& state, int arraySize, const std::vector& ps); + void benchmarkReduceSum(benchmark::State& state); + void benchmarkReduceConcatArrays(benchmark::State& state); + void benchmarkReduceCreatingNestedObject(benchmark::State& state, + size_t perIterationNestingDepth); + void benchmarkReduceCreatingNestedObject1(benchmark::State& state); + void benchmarkReduceCreatingNestedObject2(benchmark::State& state); + void benchmarkReduceCreatingNestedObject4(benchmark::State& state); + void benchmarkReduceCreatingNestedObject8(benchmark::State& state); + private: void testDateDiffExpression(long long startDate, long long endDate, @@ -291,6 +300,30 @@ class ExpressionBenchmarkFixture : public benchmark::Fixture { benchmarkArrayFilter10(state); \ } \ \ + BENCHMARK_F(Fixture, ReduceSum)(benchmark::State & state) { \ + benchmarkReduceSum(state); \ + } \ + \ + BENCHMARK_F(Fixture, ReduceConcatArrays)(benchmark::State & state) { \ + benchmarkReduceConcatArrays(state); \ + } \ + \ + BENCHMARK_F(Fixture, ReduceCreatingNestedObject1)(benchmark::State & state) { \ + benchmarkReduceCreatingNestedObject1(state); \ + } \ + \ + BENCHMARK_F(Fixture, ReduceCreatingNestedObject2)(benchmark::State & state) { \ + benchmarkReduceCreatingNestedObject2(state); \ + } \ + \ + BENCHMARK_F(Fixture, ReduceCreatingNestedObject4)(benchmark::State & state) { \ + benchmarkReduceCreatingNestedObject4(state); \ + } \ + \ + BENCHMARK_F(Fixture, ReduceCreatingNestedObject8)(benchmark::State & state) { \ + benchmarkReduceCreatingNestedObject8(state); \ + } \ + \ BENCHMARK_F(Fixture, ConditionalCond)(benchmark::State & state) { \ benchmarkConditionalCond(state); \ } \ diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp index d21c181b173b1..2aabd4f63281b 100644 --- a/src/mongo/db/pipeline/pipeline_d.cpp +++ 
b/src/mongo/db/pipeline/pipeline_d.cpp @@ -363,7 +363,7 @@ StatusWith> attemptToGetExe */ StringData extractGeoNearFieldFromIndexes(OperationContext* opCtx, const CollectionPtr& collection) { - invariant(collection); + tassert(9911911, "", collection); std::vector idxs; collection->getIndexCatalog()->findIndexByType(opCtx, IndexNames::GEO_2D, idxs); @@ -1578,7 +1578,7 @@ PipelineD::buildInnerQueryExecutorGeoNear(const MultipleCollectionAccessor& coll Pipeline::SourceContainer& sources = pipeline->_sources; auto expCtx = pipeline->getContext(); const auto geoNearStage = dynamic_cast(sources.front().get()); - invariant(geoNearStage); + tassert(9911900, "", geoNearStage); // If the user specified a "key" field, use that field to satisfy the "near" query. Otherwise, // look for a geo-indexed field in 'collection' that can. diff --git a/src/mongo/db/query/expression_index.cpp b/src/mongo/db/query/expression_index.cpp index a7cc20322766b..aca6b5caa33e6 100644 --- a/src/mongo/db/query/expression_index.cpp +++ b/src/mongo/db/query/expression_index.cpp @@ -152,7 +152,7 @@ void S2CellIdsToIntervalsUnsorted(const std::vector& intervalSet, long long end = static_cast(interval.range_max().id()); b.append("start", start); b.append("end", end); - invariant(start <= end); + tassert(9911950, "", start <= end); oilOut->intervals.push_back(IndexBoundsBuilder::makeRangeInterval( b.obj(), BoundInclusion::kIncludeBothStartAndEndKeys)); } else { @@ -183,7 +183,7 @@ void ExpressionMapping::S2CellIdsToIntervals(const std::vector& interv LOGV2(6029801, "invalid OrderedIntervalList", "orderedIntervalList"_attr = redact(oilOut->toString(false))); - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911963); } } @@ -232,7 +232,7 @@ void ExpressionMapping::S2CellIdsToIntervalsWithParents(const std::vectortoString(false))); - MONGO_UNREACHABLE; + MONGO_UNREACHABLE_TASSERT(9911964); } } diff --git a/src/mongo/db/query/get_executor.cpp b/src/mongo/db/query/get_executor.cpp index fef4f2e206e10..938c10ca4833a 100644 --- a/src/mongo/db/query/get_executor.cpp +++ b/src/mongo/db/query/get_executor.cpp @@ -109,6 +109,7 @@ #include "mongo/db/timeseries/timeseries_options.h" #include "mongo/logv2/log.h" #include "mongo/s/shard_key_pattern_query_util.h" +#include "mongo/s/shard_targeting_helpers.h" #include "mongo/scripting/engine.h" #include "mongo/util/duration.h" #include "mongo/util/processinfo.h" @@ -375,6 +376,36 @@ CollectionStats fillOutCollectionStats(OperationContext* opCtx, const Collection return stats; } +bool requiresShardFiltering(const CanonicalQuery& canonicalQuery, + const CollectionPtr& collection, + QueryPlannerParams* plannerParams) { + if (!(plannerParams->options & QueryPlannerParams::INCLUDE_SHARD_FILTER)) { + // Shard filter was not requested; cmd may not be from a router. + return false; + } + // If the caller wants a shard filter, make sure we're actually sharded. + if (!collection.isSharded()) { + // Not actually sharded. + return false; + } + + const auto& shardKeyPattern = collection.getShardKeyPattern(); + // Shards cannot own orphans for the key ranges they own, so there is no need + // to include a shard filtering stage. By omitting the shard filter, it may be + // possible to get a more efficient plan (for example, a COUNT_SCAN may be used if + // the query is eligible). + const BSONObj extractedKey = extractShardKeyFromQuery(shardKeyPattern, canonicalQuery); + + if (extractedKey.isEmpty()) { + // Couldn't extract all the fields of the shard key from the query, + // no way to target a single shard. 
+ return true; + } + + return !isEqualityOnShardKeyTargetable( + extractedKey, shardKeyPattern, !canonicalQuery.getCollator()); +} + void fillOutPlannerParams(OperationContext* opCtx, const CollectionPtr& collection, const CanonicalQuery* canonicalQuery, @@ -412,27 +443,14 @@ void fillOutPlannerParams(OperationContext* opCtx, } // If the caller wants a shard filter, make sure we're actually sharded. - if (plannerParams->options & QueryPlannerParams::INCLUDE_SHARD_FILTER) { - if (collection.isSharded()) { - const auto& shardKeyPattern = collection.getShardKeyPattern(); - - // If the shard key is specified exactly, the query is guaranteed to only target one - // shard. Shards cannot own orphans for the key ranges they own, so there is no need - // to include a shard filtering stage. By omitting the shard filter, it may be possible - // to get a more efficient plan (for example, a COUNT_SCAN may be used if the query is - // eligible). - const BSONObj extractedKey = extractShardKeyFromQuery(shardKeyPattern, *canonicalQuery); - - if (extractedKey.isEmpty()) { - plannerParams->shardKey = shardKeyPattern.toBSON(); - } else { - plannerParams->options &= ~QueryPlannerParams::INCLUDE_SHARD_FILTER; - } - } else { - // If there's no metadata don't bother w/the shard filter since we won't know what - // the key pattern is anyway... - plannerParams->options &= ~QueryPlannerParams::INCLUDE_SHARD_FILTER; - } + if (requiresShardFiltering(*canonicalQuery, collection, plannerParams)) { + // This query may have been issued to multiple shards. + // Knowing the shardKey may avoid fetching the document to apply shard filtering + // e.g., if an ixscan will provide all required fields. + plannerParams->shardKey = collection.getShardKeyPattern().toBSON(); + } else { + // A shard filter was not requested, or is not required - clear the flag. 
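+ // (Hedged illustration: with shard key {a: 1}, a query such as {a: 5} extracts a
+ // complete shard key and, when that equality is targetable under the collation rules
+ // checked by isEqualityOnShardKeyTargetable(), the filter stage can be skipped; a
+ // query like {a: {$gt: 5}} extracts nothing and must keep the filter.)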
+ plannerParams->options &= ~QueryPlannerParams::INCLUDE_SHARD_FILTER; } if (internalQueryPlannerEnableIndexIntersection.load()) { diff --git a/src/mongo/db/query/plan_enumerator.cpp b/src/mongo/db/query/plan_enumerator.cpp index 4491c5639a613..15d3da87deba4 100644 --- a/src/mongo/db/query/plan_enumerator.cpp +++ b/src/mongo/db/query/plan_enumerator.cpp @@ -1677,6 +1677,10 @@ void PlanEnumerator::tagMemo(size_t id) { for (size_t j = 0; j < assign.preds.size(); ++j) { MatchExpression* pred = assign.preds[j]; if (pred->getTag()) { + tassert(11390000, + "Expected the predicate's tag to be of type OrPushdownTag", + pred->getTag()->getType() == + MatchExpression::TagData::Type::OrPushdownTag); OrPushdownTag* orPushdownTag = static_cast(pred->getTag()); orPushdownTag->setIndexTag( new IndexTag(assign.index, assign.positions[j], assign.canCombineBounds)); diff --git a/src/mongo/db/query/planner_wildcard_helpers.cpp b/src/mongo/db/query/planner_wildcard_helpers.cpp index 62839d8b0e7a7..63d3322a56f7d 100644 --- a/src/mongo/db/query/planner_wildcard_helpers.cpp +++ b/src/mongo/db/query/planner_wildcard_helpers.cpp @@ -401,6 +401,11 @@ std::pair expandWildcardIndexKeyPattern(const BSONObj& wildcard builder.appendAs(field, expandFieldName); wildcardFieldPos = fieldPos; } else { + tassert(11390001, + str::stream() << "Expansion of wildcard index " << wildcardKeyPattern + << " would result in duplicate field: " << expandFieldName, + fieldName != expandFieldName); + builder.append(field); } ++fieldPos; diff --git a/src/mongo/db/query/planner_wildcard_helpers_test.cpp b/src/mongo/db/query/planner_wildcard_helpers_test.cpp index 5b11bdc7c01c9..1585c50033f98 100644 --- a/src/mongo/db/query/planner_wildcard_helpers_test.cpp +++ b/src/mongo/db/query/planner_wildcard_helpers_test.cpp @@ -32,6 +32,7 @@ #include "mongo/db/query/planner_wildcard_helpers.h" #include "mongo/db/query/query_solution.h" #include "mongo/idl/server_parameter_test_util.h" +#include "mongo/unittest/death_test.h" #include "mongo/unittest/unittest.h" namespace mongo::wildcard_planning { @@ -259,4 +260,19 @@ TEST(PlannerWildcardHelpersTest, Expand_CompoundWildcardIndex_NumericComponents) ASSERT_FALSE(expandedIndexes.front().multikey); ASSERT_EQ(expectedMks, expandedIndexes.front().multikeyPaths); } + +DEATH_TEST(PlannerWildcardHelpersTest, InvalidIndexExpansion, "11390001") { + IndexEntryMock wildcardIndex{BSON("a" << 1 << "$**" << 1), BSON("_id" << 0), {}}; + stdx::unordered_set fields{"a"}; + std::vector expandedIndexes{}; + expandWildcardIndexEntry(*wildcardIndex.indexEntry, fields, &expandedIndexes); +} + +DEATH_TEST(PlannerWildcardHelpersTest, AnotherInvalidIndexExpansion, "11390001") { + IndexEntryMock wildcardIndex{BSON("$**" << 1 << "a" << 1), BSON("_id" << 0), {}}; + stdx::unordered_set fields{"a"}; + std::vector expandedIndexes{}; + expandWildcardIndexEntry(*wildcardIndex.indexEntry, fields, &expandedIndexes); +} + } // namespace mongo::wildcard_planning diff --git a/src/mongo/db/query/query_knobs.idl b/src/mongo/db/query/query_knobs.idl index 2b0dc086d43c4..d97e76a63bac1 100644 --- a/src/mongo/db/query/query_knobs.idl +++ b/src/mongo/db/query/query_knobs.idl @@ -1257,6 +1257,19 @@ server_parameters: default: false redact: false + internalReduceAccumulatedValueDepthCheckInterval: + description: >- + Configures how frequently $reduce checks if its accumulated value has exceeded the maximum + allowable nestedness. Arrays and subdocuments both count. If set to 0, no check is performed. 
+ set_at: [startup, runtime] + cpp_vartype: "AtomicWord<int>" + cpp_varname: gInternalReduceAccumulatedValueDepthCheckInterval + default: 16 + validator: + gte: 0 + lte: 1048576 # 1024 ** 2 + redact: false + # Note for adding additional query knobs: # # When adding a new query knob, you should consider whether or not you need to add an 'on_update' diff --git a/src/mongo/db/s/query_analysis_coordinator.cpp b/src/mongo/db/s/query_analysis_coordinator.cpp index 13afb01f0dd48..8c0fafab14048 100644 --- a/src/mongo/db/s/query_analysis_coordinator.cpp +++ b/src/mongo/db/s/query_analysis_coordinator.cpp @@ -160,8 +160,7 @@ void QueryAnalysisCoordinator::onSamplerDelete(const MongosType& doc) { invariant(serverGlobalParams.clusterRole.has(ClusterRole::ConfigServer)); stdx::lock_guard<Latch> lk(_mutex); - auto erased = _samplers.erase(doc.getName()); - invariant(erased); + _samplers.erase(doc.getName()); } void QueryAnalysisCoordinator::onStartup(OperationContext* opCtx) { diff --git a/src/mongo/db/s/query_analysis_coordinator_test.cpp b/src/mongo/db/s/query_analysis_coordinator_test.cpp index 0639a6e67eca1..644be2402a5d7 100644 --- a/src/mongo/db/s/query_analysis_coordinator_test.cpp +++ b/src/mongo/db/s/query_analysis_coordinator_test.cpp @@ -534,6 +534,24 @@ TEST_F(QueryAnalysisCoordinatorTest, RemoveSamplersOnDelete) { ASSERT(samplers.empty()); } +TEST_F(QueryAnalysisCoordinatorTest, RemoveUntrackedSamplerOnDeleteDoesNotCrash) { + auto coordinator = QueryAnalysisCoordinator::get(operationContext()); + + // There are no samplers initially. + auto samplers = coordinator->getSamplersForTest(); + ASSERT(samplers.empty()); + + // Directly call onSamplerDelete for a sampler that was never inserted into the coordinator's + // in-memory map. This simulates the scenario where a delete on config.mongos targets a sampler + // that was not loaded during onStartup (e.g., because its ping time was too old). This must not + // crash the server (SERVER-121686).
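+ // (Sketch of the coordinator fix above, assuming a std::map-like '_samplers'
+ // container: map::erase(key) returns the number of elements removed, so the old
+ //     auto erased = _samplers.erase(doc.getName());
+ //     invariant(erased);
+ // aborted the process whenever the key was untracked; a plain erase() makes the
+ // delete handler idempotent and a no-op for unknown samplers.)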
+ auto mongosDoc = makeConfigMongosDocument(mongosName0); + coordinator->onSamplerDelete(mongosDoc); + + samplers = coordinator->getSamplersForTest(); + ASSERT(samplers.empty()); +} + TEST_F(QueryAnalysisCoordinatorTest, CreateSamplersOnStartup) { auto coordinator = QueryAnalysisCoordinator::get(operationContext()); diff --git a/src/mongo/db/s/query_analysis_op_observer.cpp b/src/mongo/db/s/query_analysis_op_observer.cpp index 71b4c4d101b3b..f8803401f2b2f 100644 --- a/src/mongo/db/s/query_analysis_op_observer.cpp +++ b/src/mongo/db/s/query_analysis_op_observer.cpp @@ -69,8 +69,16 @@ void QueryAnalysisOpObserver::onInserts(OperationContext* opCtx, const auto parsedDoc = uassertStatusOK(MongosType::fromBSON(it->doc)); opCtx->recoveryUnit()->onCommit( [parsedDoc](OperationContext* opCtx, boost::optional<Timestamp>) { - analyze_shard_key::QueryAnalysisCoordinator::get(opCtx)->onSamplerInsert( - parsedDoc); + try { + analyze_shard_key::QueryAnalysisCoordinator::get(opCtx) + ->onSamplerInsert(parsedDoc); + } catch (const DBException& ex) { + LOGV2_WARNING(10690305, + "Failed to handle sampler insert in " + "QueryAnalysisCoordinator", + "sampler"_attr = parsedDoc, + "error"_attr = ex.toString()); + } }); } } @@ -93,7 +101,16 @@ void QueryAnalysisOpObserver::onUpdate(OperationContext* opCtx, const OplogUpdat uassertStatusOK(MongosType::fromBSON(args.updateArgs->updatedDoc)); opCtx->recoveryUnit()->onCommit([parsedDoc](OperationContext* opCtx, boost::optional<Timestamp>) { - analyze_shard_key::QueryAnalysisCoordinator::get(opCtx)->onSamplerUpdate(parsedDoc); + try { + analyze_shard_key::QueryAnalysisCoordinator::get(opCtx)->onSamplerUpdate( + parsedDoc); + } catch (const DBException& ex) { + LOGV2_WARNING(10690306, + "Failed to handle sampler update in " + "QueryAnalysisCoordinator", + "sampler"_attr = parsedDoc, + "error"_attr = ex.toString()); + } }); } } @@ -143,7 +160,16 @@ void QueryAnalysisOpObserver::onDelete(OperationContext* opCtx, const auto parsedDoc = uassertStatusOK(MongosType::fromBSON(doc)); opCtx->recoveryUnit()->onCommit([parsedDoc](OperationContext* opCtx, boost::optional<Timestamp>) { - analyze_shard_key::QueryAnalysisCoordinator::get(opCtx)->onSamplerDelete(parsedDoc); + try { + analyze_shard_key::QueryAnalysisCoordinator::get(opCtx)->onSamplerDelete( + parsedDoc); + } catch (const DBException& ex) { + LOGV2_WARNING(10690307, + "Failed to handle sampler delete in " + "QueryAnalysisCoordinator", + "sampler"_attr = parsedDoc, + "error"_attr = ex.toString()); + } }); } } diff --git a/src/mongo/db/server_options.h b/src/mongo/db/server_options.h index 96b9f294d2303..21f95c70cd100 100644 --- a/src/mongo/db/server_options.h +++ b/src/mongo/db/server_options.h @@ -84,8 +84,10 @@ struct ServerGlobalParams { int defaultLocalThresholdMillis = 15; // --localThreshold in ms to consider a node local bool moveParanoia = false; // for move chunk paranoia - bool noUnixSocket = false; // --nounixsocket - bool doFork = false; // --fork + bool noUnixSocket = false; // --nounixsocket + bool doFork = false; // --fork + bool isMongoBridge = false; + std::string socket = "/tmp"; // UNIX domain socket directory std::string transportLayer; // --transportLayer (must be either "asio" or "legacy") diff --git a/src/mongo/db/service_entry_point_common.cpp b/src/mongo/db/service_entry_point_common.cpp index 49fa07e3faa38..2879deae3de35 100644 --- a/src/mongo/db/service_entry_point_common.cpp +++ b/src/mongo/db/service_entry_point_common.cpp @@ -1497,6 +1497,10 @@ void ExecCommandDatabase::_initiateCommand() {
apiVersionMetrics.update(appName, apiParams); } + // Start authz contract tracking before we evaluate failpoints. + auto authzSession = AuthorizationSession::get(client); + authzSession->startContractTracking(); + rpc::TrackingMetadata::get(opCtx).initWithOperName(command->getName()); auto const replCoord = repl::ReplicationCoordinator::get(opCtx); @@ -1508,8 +1512,5 @@ void ExecCommandDatabase::_initiateCommand() { replCoord->getReplicationMode() == repl::ReplicationCoordinator::modeReplSet); - // Start authz contract tracking before we evaluate failpoints - auto authzSession = AuthorizationSession::get(client); - authzSession->startContractTracking(); CommandHelpers::evaluateFailCommandFailPoint(opCtx, _invocation.get()); diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_prepare_conflict.h b/src/mongo/db/storage/wiredtiger/wiredtiger_prepare_conflict.h index 260dec3aa6e6b..141f299358353 100644 --- a/src/mongo/db/storage/wiredtiger/wiredtiger_prepare_conflict.h +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_prepare_conflict.h @@ -122,7 +122,8 @@ int wiredTigerPrepareConflictRetry(OperationContext* opCtx, F&& f) { // lock for completeness). if (type == RESOURCE_GLOBAL || type == RESOURCE_DATABASE || type == RESOURCE_COLLECTION) invariant(lock.mode != MODE_S && lock.mode != MODE_X, - str::stream() << lock.resourceId.toString() << " in " << modeName(lock.mode)); + str::stream() + << toStringForLogging(lock.resourceId) << " in " << modeName(lock.mode)); } if (MONGO_unlikely(WTSkipPrepareConflictRetries.shouldFail())) { diff --git a/src/mongo/rpc/SConscript b/src/mongo/rpc/SConscript index 1052d1cacc9a7..809f8ceea61a5 100644 --- a/src/mongo/rpc/SConscript +++ b/src/mongo/rpc/SConscript @@ -176,6 +176,7 @@ if wiredtiger: env.CppUnitTest( target='rpc_test', source=[ + 'data_type_wire_message_payload_test.cpp', 'get_status_from_command_result_test.cpp', 'metadata/client_metadata_test.cpp', 'metadata/egress_metadata_hook_list_test.cpp', diff --git a/src/mongo/rpc/data_type_wire_message_payload_test.cpp b/src/mongo/rpc/data_type_wire_message_payload_test.cpp new file mode 100644 index 0000000000000..3b1ca70f51d22 --- /dev/null +++ b/src/mongo/rpc/data_type_wire_message_payload_test.cpp @@ -0,0 +1,117 @@ +/** + * Copyright (C) 2018-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version.
If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include + +#include "mongo/base/data_range_cursor.h" +#include "mongo/bson/bsonobj.h" +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/rpc/wire_message_payload.h" +#include "mongo/unittest/assert.h" + +namespace mongo { + +namespace { + +TEST(DataTypeWireMessagePayload, EmptyBSONObject) { + const BSONObj empty; + ConstDataRangeCursor cdrc(empty.objdata(), empty.objsize()); + + WireMessagePayload wmp; + + ASSERT_OK(cdrc.readAndAdvanceNoThrow(&wmp)); + ASSERT_BSONOBJ_EQ(empty, wmp.obj); +} + +TEST(DataTypeWireMessagePayload, SmallBSONObject) { + const BSONObj obj = BSON("foo" << BSON_ARRAY(1 << 2 << 3 << 4 << 5) << "bar" + << "baz"); + ConstDataRangeCursor cdrc(obj.objdata(), obj.objsize()); + + WireMessagePayload wmp; + + ASSERT_OK(cdrc.readAndAdvanceNoThrow(&wmp)); + ASSERT_BSONOBJ_EQ(obj, wmp.obj); +} + +TEST(DataTypeWireMessagePayload, BSONObjectWith1MPayload) { + const BSONObj obj = BSON("payload" << std::string(1024 * 1024, 'x')); + ConstDataRangeCursor cdrc(obj.objdata(), obj.objsize()); + + WireMessagePayload wmp; + + ASSERT_OK(cdrc.readAndAdvanceNoThrow(&wmp)); + ASSERT_BSONOBJ_EQ(obj, wmp.obj); +} + +TEST(DataTypeWireMessagePayload, BSONObjectAtSizeLimits) { + // Overhead: + // - 4 bytes for object size + // - 1 byte for string type tag + // - 8 bytes for "payload" field name plus trailing \0 byte + // - 4 bytes for string field length + // - 1 byte for trailing \0 byte for string value + // - 1 byte for trailing \0 byte for object + // ======================================== + // = 19 bytes total overhead + auto buildBSONObjAtSizeLimit = [](int sizeLimit) { + constexpr int kOverheadInBytes = 19; + + BSONObjBuilder bob; + bob.append("payload", std::string(sizeLimit - kOverheadInBytes, 'x')); + return bob.obj(); + }; + + // Build objects that are exactly the size as the different internal size limits. + for (int sizeLimit : {BSONObjMaxUserSize, BSONObjMaxInternalSize, BSONObjMaxWireMessageSize}) { + BSONObj obj = buildBSONObjAtSizeLimit(sizeLimit); + + ConstDataRangeCursor cdrc(obj.objdata(), obj.objsize()); + + WireMessagePayload wmp; + + ASSERT_OK(cdrc.readAndAdvanceNoThrow(&wmp)); + ASSERT_BSONOBJ_EQ(obj, wmp.obj); + ASSERT_EQ(sizeLimit, wmp.obj.objsize()); + } + + // Build an object that is larger than the size limit for WireMessagePayloads. + { + BSONObj obj = buildBSONObjAtSizeLimit(BSONObjMaxWireMessageSize + 1); + ConstDataRangeCursor cdrc(obj.objdata(), obj.objsize()); + + WireMessagePayload wmp; + + ASSERT_EQUALS(ErrorCodes::BSONObjectTooLarge, cdrc.readAndAdvanceNoThrow(&wmp).code()); + } +} + +} // namespace + +} // namespace mongo diff --git a/src/mongo/rpc/object_check.h b/src/mongo/rpc/object_check.h index ac9ddb899d8be..374ab424342fe 100644 --- a/src/mongo/rpc/object_check.h +++ b/src/mongo/rpc/object_check.h @@ -36,6 +36,9 @@ #include "mongo/bson/bsontypes.h" #include "mongo/db/server_options.h" #include "mongo/logv2/redaction.h" +#include "mongo/platform/compiler.h" +#include "mongo/rpc/wire_message_payload.h" +#include "mongo/util/assert_util.h" #include "mongo/util/hex.h" // We do not use the rpc namespace here so we can specialize Validator. 
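The 19-byte overhead constant used by BSONObjectAtSizeLimits above follows directly from the BSON wire format; a minimal sketch double-checking the arithmetic (checkPayloadOverhead is an illustrative helper, not part of this patch):

#include <cassert>
#include <string>

#include "mongo/bson/bsonobjbuilder.h"

// {payload: "<N chars>"} serializes as: 4 (total document length) + 1 (string
// type tag) + 8 ("payload" plus its NUL) + 4 (string length prefix) + N
// + 1 (string NUL) + 1 (document terminator) = N + 19 bytes.
inline void checkPayloadOverhead() {
    mongo::BSONObjBuilder bob;
    bob.append("payload", std::string(100, 'x'));
    assert(bob.obj().objsize() == 100 + 19);
}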
@@ -49,7 +52,6 @@ class Status;
  */
 template <>
 struct Validator<BSONObj> {
-
     inline static Status validateLoad(const char* ptr, size_t length) {
         if (!serverGlobalParams.objcheck) {
             return Status::OK();
@@ -73,4 +75,27 @@ struct Validator<BSONObj> {
     static Status validateStore(const BSONObj& toStore);
 };
+
+/**
+ * A validator for WireMessagePayload objects. The implementation will validate the input object
+ * in the same way as a regular BSONObj. This type is only needed because it is required to be
+ * present by DataType::Handler.
+ */
+template <>
+struct Validator<WireMessagePayload> {
+    inline static Status validateLoad(const char* ptr, size_t length) {
+        // Reuse BSONObj validation logic.
+        return Validator<BSONObj>::validateLoad(ptr, length);
+    }
+
+    static Status validateStore(const WireMessagePayload& toStore) {
+        invariant(
+            false,
+            "Validator must only be used to read incoming op_msg requests, and "
+            "never as a serialization sink.");
+
+        MONGO_COMPILER_UNREACHABLE;
+    }
+};
+
 }  // namespace mongo
diff --git a/src/mongo/rpc/op_msg.cpp b/src/mongo/rpc/op_msg.cpp
index a5bb2c558edc6..a41ea387d2628 100644
--- a/src/mongo/rpc/op_msg.cpp
+++ b/src/mongo/rpc/op_msg.cpp
@@ -42,7 +42,12 @@
 #include "mongo/db/multitenancy_gen.h"
 #include "mongo/db/server_feature_flags_gen.h"
 #include "mongo/logv2/log.h"
-#include "mongo/rpc/object_check.h"
+#include "mongo/logv2/log_attr.h"
+#include "mongo/logv2/log_component.h"
+#include "mongo/logv2/redaction.h"
+#include "mongo/rpc/object_check.h"  // IWYU pragma: keep
+#include "mongo/rpc/op_msg.h"
+#include "mongo/rpc/wire_message_payload.h"
 #include "mongo/util/bufreader.h"
 #include "mongo/util/database_name_util.h"
 #include "mongo/util/hex.h"
@@ -168,7 +173,13 @@ OpMsg OpMsg::parse(const Message& message, Client* client) try {
             case Section::kBody: {
                 uassert(40430, "Multiple body sections in message", !haveBody);
                 haveBody = true;
-                msg.body = sectionsBuf.read<Validated<BSONObj>>();
+
+                // Parse and validate the payload using a temporary WireMessagePayload struct.
+                // This allows for a slightly higher maximum object size than a regular BSONObj.
+                // This is needed in order to parse messages containing large BSONObj bodies with
+                // extra metadata, as in oplog replication.
+                WireMessagePayload payload = sectionsBuf.read<Validated<WireMessagePayload>>();
+                msg.body = std::move(payload.obj);
 
                 uassert(ErrorCodes::InvalidOptions,
                         "Multitenancy not enabled, cannot set $tenant in command body",
@@ -443,11 +454,12 @@ AtomicWord<bool> OpMsgBuilder::disableDupeFieldCheck_forTest{false};
 
 Message OpMsgBuilder::finish() {
     const auto size = _buf.len();
+    constexpr auto maxSize = BSONObjMaxWireMessageSize;
     uassert(ErrorCodes::BSONObjectTooLarge,
             str::stream() << "BSON size limit hit while building Message. Size: " << size << " (0x"
-                          << unsignedHex(size) << "); maxSize: " << BSONObjMaxInternalSize << "("
-                          << (BSONObjMaxInternalSize / (1024 * 1024)) << "MB)",
-            size <= BSONObjMaxInternalSize);
+                          << unsignedHex(size) << "); maxSize: " << maxSize << "("
+                          << (maxSize / (1024 * 1024)) << "MB)",
+            size <= maxSize);
 
     return finishWithoutSizeChecking();
 }
diff --git a/src/mongo/rpc/wire_message_payload.h b/src/mongo/rpc/wire_message_payload.h
new file mode 100644
index 0000000000000..5295a2f995218
--- /dev/null
+++ b/src/mongo/rpc/wire_message_payload.h
@@ -0,0 +1,94 @@
+/**
+ * Copyright (C) 2025-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include + +#include "mongo/base/data_type.h" +#include "mongo/base/data_type_validated.h" +#include "mongo/base/error_codes.h" +#include "mongo/base/status.h" +#include "mongo/bson/bsonobj.h" +#include "mongo/platform/compiler.h" +#include "mongo/util/assert_util.h" + +namespace mongo { + +/** + * A small wrapper around BSONObj, used exclusively for data type validation for BSON wire + * messages. The specialized DataType::Handler and Validator implementations for this type allow + * for a slightly larger BSONObj size (using the BSONObj::WireMessageSizeTrait) than regular + * BSONObjs. This slightly increased maximum size allows additional metadata to be sent along with + * the actual payload in server-to-server messages, e.g. oplog replication. + */ +struct WireMessagePayload { + BSONObj obj; +}; + +// Provides a DataType::Handler specialization for WireMessagePayload. 
+template <>
+struct DataType::Handler<WireMessagePayload> {
+    static Status load(WireMessagePayload* bson,
+                       const char* ptr,
+                       size_t length,
+                       size_t* advanced,
+                       std::ptrdiff_t debugOffset) try {
+        auto temp = BSONObj(ptr, BSONObj::WireMessageSizeTrait{});
+        auto len = temp.objsize();
+        if (bson) {
+            *bson = WireMessagePayload{std::move(temp)};
+        }
+        if (advanced) {
+            *advanced = len;
+        }
+        return Status::OK();
+    } catch (const DBException& e) {
+        return e.toStatus();
+    }
+
+    static Status store(const WireMessagePayload& bson,
+                        char* ptr,
+                        size_t length,
+                        size_t* advanced,
+                        std::ptrdiff_t debugOffset) {
+        invariant(
+            false,
+            "Handler must only be used to read incoming op_msg requests, and "
+            "never as a serialization sink.");
+
+        MONGO_COMPILER_UNREACHABLE;
+    }
+
+    static WireMessagePayload defaultConstruct() {
+        return WireMessagePayload();
+    }
+};
+
+}  // namespace mongo
diff --git a/src/mongo/s/SConscript b/src/mongo/s/SConscript
index b4f067a6e575c..2fd47474795e6 100644
--- a/src/mongo/s/SConscript
+++ b/src/mongo/s/SConscript
@@ -74,6 +74,23 @@ env.Library(
     ],
 )
 
+env.Library(
+    target='sharding_helpers',
+    source=[
+        'shard_targeting_helpers.cpp',
+        'shard_key_pattern.cpp',
+    ],
+    LIBDEPS_PRIVATE=[
+        '$BUILD_DIR/mongo/base',
+        '$BUILD_DIR/mongo/db/common',
+        '$BUILD_DIR/mongo/db/matcher/path',
+        '$BUILD_DIR/mongo/db/mongohasher',
+        '$BUILD_DIR/mongo/db/query/collation/collator_interface',
+        '$BUILD_DIR/mongo/db/server_base',
+        '$BUILD_DIR/mongo/db/storage/key_string',
+    ],
+)
+
 env.Library(
     target='sharding_router_api',
     source=[
@@ -99,6 +116,7 @@ env.Library(
         'async_requests_sender',
         'grid',
         'query_analysis_sampler',
+        'sharding_helpers',
     ],
     LIBDEPS_PRIVATE=[
         '$BUILD_DIR/mongo/db/catalog/collection_uuid_mismatch_info',
@@ -270,7 +288,6 @@ env.Library(
         'resharding/type_collection_fields.idl',
         'shard_cannot_refresh_due_to_locks_held_exception.cpp',
         'shard_invalidated_for_targeting_exception.cpp',
-        'shard_key_pattern.cpp',
         'shard_version.cpp',
         'shard_version.idl',
         'shard_version_factory.cpp',
@@ -291,6 +308,7 @@ env.Library(
         '$BUILD_DIR/mongo/rpc/message',
         '$BUILD_DIR/mongo/util/caching',
         'analyze_shard_key_common',
+        'sharding_helpers',
     ],
     LIBDEPS_PRIVATE=[
         '$BUILD_DIR/mongo/client/read_preference',
diff --git a/src/mongo/s/chunk_manager.cpp b/src/mongo/s/chunk_manager.cpp
index 2c473097b12b0..ef9e0b081464c 100644
--- a/src/mongo/s/chunk_manager.cpp
+++ b/src/mongo/s/chunk_manager.cpp
@@ -37,6 +37,7 @@
 #include "mongo/logv2/log.h"
 #include "mongo/s/mongod_and_mongos_server_parameters_gen.h"
 #include "mongo/s/shard_invalidated_for_targeting_exception.h"
+#include "mongo/s/shard_targeting_helpers.h"
 
 #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kSharding
 
@@ -663,26 +664,13 @@ Chunk ChunkManager::findIntersectingChunk(const BSONObj& shardKey,
                                           bool bypassIsFieldHashedCheck) const {
     const bool hasSimpleCollation = (collation.isEmpty() && !_rt->optRt->getDefaultCollator()) ||
         SimpleBSONObjComparator::kInstance.evaluate(collation == CollationSpec::kSimpleSpec);
-    if (!hasSimpleCollation) {
-        for (BSONElement elt : shardKey) {
-            // We must assume that if the field is specified as "hashed" in the shard key pattern,
-            // then the hash value could have come from a collatable type.
- const bool isFieldHashed = - (_rt->optRt->getShardKeyPattern().isHashedPattern() && - _rt->optRt->getShardKeyPattern().getHashedField().fieldNameStringData() == - elt.fieldNameStringData()); - - // If we want to skip the check in the special case where the _id field is hashed and - // used as the shard key, set bypassIsFieldHashedCheck. This assumes that a request with - // a query that contains an _id field can target a specific shard. - uassert(ErrorCodes::ShardKeyNotFound, - str::stream() << "Cannot target single shard due to collation of key " - << elt.fieldNameStringData() << " for namespace " - << _rt->optRt->nss(), - !CollationIndexKey::isCollatableType(elt.type()) && - (!isFieldHashed || bypassIsFieldHashedCheck)); - } - } + const auto& elt = getFirstFieldWithIncompatibleCollation( + shardKey, _rt->optRt->getShardKeyPattern(), hasSimpleCollation, bypassIsFieldHashedCheck); + + uassert(ErrorCodes::ShardKeyNotFound, + str::stream() << "Cannot target single shard due to collation of key " + << elt.fieldNameStringData() << " for namespace " << _rt->optRt->nss(), + elt.eoo()); auto chunkInfo = _rt->optRt->findIntersectingChunk(shardKey); diff --git a/src/mongo/s/shard_targeting_helpers.cpp b/src/mongo/s/shard_targeting_helpers.cpp new file mode 100644 index 0000000000000..457d19765b196 --- /dev/null +++ b/src/mongo/s/shard_targeting_helpers.cpp @@ -0,0 +1,69 @@ +/** + * Copyright (C) 2024-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/s/shard_targeting_helpers.h" + +#include "mongo/db/query/collation/collation_index_key.h" + + +namespace mongo { +BSONElement getFirstFieldWithIncompatibleCollation(const BSONObj& shardKey, + const ShardKeyPattern& shardKeyPattern, + bool queryHasSimpleCollation, + bool permitHashedFields) { + if (queryHasSimpleCollation) { + return {}; + } + + for (BSONElement elt : shardKey) { + // We must assume that if the field is specified as "hashed" in the shard key pattern, + // then the hash value could have come from a collatable type. 
+        const bool isFieldHashed =
+            (shardKeyPattern.isHashedPattern() &&
+             shardKeyPattern.getHashedField().fieldNameStringData() == elt.fieldNameStringData());
+
+        // If we want to skip the check in the special case where the _id field is hashed and
+        // used as the shard key, set permitHashedFields. This assumes that a request with
+        // a query that contains an _id field can target a specific shard.
+        if (CollationIndexKey::isCollatableType(elt.type()) ||
+            (isFieldHashed && !permitHashedFields)) {
+            return elt;
+        }
+    }
+    return {};
+}
+
+bool isEqualityOnShardKeyTargetable(const BSONObj& shardKey,
+                                    const ShardKeyPattern& shardKeyPattern,
+                                    bool queryHasSimpleCollation) {
+    return getFirstFieldWithIncompatibleCollation(
+               shardKey, shardKeyPattern, queryHasSimpleCollation, /* permitHashedFields */ false)
+        .eoo();
+}
+}  // namespace mongo
diff --git a/src/mongo/s/shard_targeting_helpers.h b/src/mongo/s/shard_targeting_helpers.h
new file mode 100644
index 0000000000000..fd8338baa644b
--- /dev/null
+++ b/src/mongo/s/shard_targeting_helpers.h
@@ -0,0 +1,106 @@
+/**
+ * Copyright (C) 2024-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/bson/bsonobj.h"
+#include "mongo/s/shard_key_pattern.h"
+
+namespace mongo {
+
+/**
+ * Given a shardKeyPattern, and the corresponding values extracted from a query specifying
+ * equalities (see extractShardKeyFromQuery), extract the first element which is incompatible with
+ * single shard targeting, because:
+ *
+ * * The value is affected by collation, and the query collation does not match the sharding
+ *   collation
+ * * The field is hashed.
+ *
+ * For specific use cases (see cluster_find_and_modify_cmd targetSingleShard), hashed fields may not
+ * prevent single shard targeting; such callers can set permitHashedFields.
+ *
+ * e.g.,
+ *
+ * shardKeyPattern : <any>
+ * shardKey : <any>
+ * queryHasSimpleCollation : true
+ * permitHashedFields : <any>
+ * -> eoo() : collation is simple (matches sharding collation), so all elements are compatible.
+ *
+ * shardKeyPattern : {a: 1, b: 1}
+ * shardKey : {a: 123, b: null}
+ * queryHasSimpleCollation : false
+ * permitHashedFields : <any>
+ * -> eoo() : no fields are affected by collation.
+ *
+ * shardKeyPattern : {a: 1, b: 1}
+ * shardKey : {a: 123, b: "foobar"}
+ * queryHasSimpleCollation : false
+ * permitHashedFields : <any>
+ * -> {b: "foobar"} : strings are affected by collation, and collation differs from sharding
+ *    collation.
+ *
+ * shardKeyPattern : {a: "hashed", b: 1}
+ * shardKey : {a: 123, b: "foobar"}
+ * queryHasSimpleCollation : false
+ * permitHashedFields : false
+ * -> {a: 123} : field a is hashed in shardKeyPattern; the originating data type _may_ have been
+ *    affected by collation.
+ *
+ * shardKeyPattern : {a: "hashed", b: 1}
+ * shardKey : {a: 123, b: "foobar"}
+ * queryHasSimpleCollation : false
+ * permitHashedFields : true
+ * -> {b: "foobar"} : strings are affected by collation, and collation differs from sharding
+ *    collation, and hashed fields are allowed.
+ *
+ * Precondition: shardKey must contain values for all fields of shardKeyPattern.
+ *
+ * @return BSONElement identified element, or eoo() if none.
+ */
+BSONElement getFirstFieldWithIncompatibleCollation(const BSONObj& shardKey,
+                                                   const ShardKeyPattern& shardKeyPattern,
+                                                   bool queryHasSimpleCollation,
+                                                   bool permitHashedFields);
+
+
+/**
+ * Given a shardKeyPattern, and the corresponding values extracted from a query specifying
+ * equalities (see extractShardKeyFromQuery), check if the query can be targeted to a single shard.
+ *
+ * See getFirstFieldWithIncompatibleCollation() for examples. There must be no incompatible fields
+ * for a query to be single shard targeted.
+ *
+ * Precondition: shardKey must contain values for all fields of shardKeyPattern.
+ */
+bool isEqualityOnShardKeyTargetable(const BSONObj& shardKey,
+                                    const ShardKeyPattern& shardKeyPattern,
+                                    bool queryHasSimpleCollation);
+}  // namespace mongo
diff --git a/src/mongo/tools/mongobridge_tool/bridge.cpp b/src/mongo/tools/mongobridge_tool/bridge.cpp
index 4e8fee23646ed..db9b5350af7e0 100644
--- a/src/mongo/tools/mongobridge_tool/bridge.cpp
+++ b/src/mongo/tools/mongobridge_tool/bridge.cpp
@@ -481,6 +481,7 @@ Future<DbResponse> ServiceEntryPointBridge::handleRequest(OperationContext* opCt
 
 int bridgeMain(int argc, char** argv) {
+    serverGlobalParams.isMongoBridge = true;
     registerShutdownTask([&] {
         // NOTE: This function may be called at any time. It must not
         // depend on the prior execution of mongo initializers or the
diff --git a/src/mongo/transport/asio/asio_session_impl.cpp b/src/mongo/transport/asio/asio_session_impl.cpp
index 14127d3d31cde..7521b55aa4eca 100644
--- a/src/mongo/transport/asio/asio_session_impl.cpp
+++ b/src/mongo/transport/asio/asio_session_impl.cpp
@@ -150,6 +150,17 @@ int getExceptionLogSeverityLevel() {
     return logSeverity().toInt();
 }
+
+int getMessageSizeErrorLogSeverityLevel() {
+    static logv2::SeveritySuppressor logSeverity{Seconds{gMessageSizeErrorRateSec},
+                                                 logv2::LogSeverity::Info(),
+                                                 logv2::LogSeverity::Debug(1)};
+
+    return logSeverity().toInt();
+}
+
+CounterMetric totalMessageSizeErrorsPreAuth("network.totalMessageSizeErrorPreAuth");
+CounterMetric totalMessageSizeErrorsPostAuth("network.totalMessageSizeErrorPostAuth");
 }  // namespace
 
@@ -565,17 +576,26 @@ Future<Message> CommonAsioSession::sourceMessageImpl(const BatonHandle& baton) {
     }
 
     const auto msgLen = size_t(MSGHEADER::View(headerBuffer.get()).getMessageLength());
-    if (msgLen < kHeaderSize || msgLen > MaxMessageSizeBytes) {
+
+    const size_t maxMessageSize = _restrictedMode
+        ? static_cast<size_t>(gPreAuthMaximumMessageSizeBytes.loadRelaxed())
+        : MaxMessageSizeBytes;
+    if (msgLen < kHeaderSize || msgLen > maxMessageSize) {
         StringBuilder sb;
         sb << "recv(): message msgLen " << msgLen << " is invalid. "
-           << "Min " << kHeaderSize << " Max: " << MaxMessageSizeBytes;
+           << "Min " << kHeaderSize << " Max: " << maxMessageSize;
         const auto str = sb.str();
 
-        LOGV2(4615638,
-              "recv(): message msgLen {msgLen} is invalid. Min: {min} Max: {max}",
-              "recv(): message mstLen is invalid.",
-              "msgLen"_attr = msgLen,
-              "min"_attr = kHeaderSize,
-              "max"_attr = MaxMessageSizeBytes);
+        LOGV2_DEBUG(4615638,
+                    getMessageSizeErrorLogSeverityLevel(),
+                    "recv(): message msgLen is invalid.",
+                    "msgLen"_attr = msgLen,
+                    "min"_attr = kHeaderSize,
+                    "max"_attr = maxMessageSize);
+        if (_restrictedMode) {
+            totalMessageSizeErrorsPreAuth.increment();
+        } else {
+            totalMessageSizeErrorsPostAuth.increment();
+        }
 
         return Future<Message>::makeReady(Status(ErrorCodes::ProtocolError, str));
     }
diff --git a/src/mongo/transport/asio/asio_transport_layer_test.cpp b/src/mongo/transport/asio/asio_transport_layer_test.cpp
index 1b38206a25481..dc0ca1d6d50b4 100644
--- a/src/mongo/transport/asio/asio_transport_layer_test.cpp
+++ b/src/mongo/transport/asio/asio_transport_layer_test.cpp
@@ -41,6 +41,7 @@
 #include "mongo/client/dbclient_connection.h"
 #include "mongo/config.h"
+#include "mongo/db/commands/server_status_metric.h"
 #include "mongo/db/concurrency/locker_noop_service_context_test_fixture.h"
 #include "mongo/db/server_options.h"
 #include "mongo/db/service_context_test_fixture.h"
@@ -114,6 +115,18 @@ void ping(SyncClient& client) {
     ASSERT_EQ(client.write(msg.buf(), msg.size()), std::error_code{});
 }
 
+/**
+ * Returns the current value of a network metric by name.
+ * The metricName should be the field name under metrics.network (e.g.,
+ * "totalMessageSizeErrorPreAuth").
+ */
+long long getNetworkMetric(StringData metricName) {
+    BSONObjBuilder bob;
+    globalMetricTree()->appendTo(bob);
+    auto obj = bob.obj();
+    return obj["metrics"]["network"][metricName].Long();
+}
+
 transport::AsioTransportLayer::Options defaultTLAOptions() {
     ServerGlobalParams params;
     params.noUnixSocket = true;
@@ -437,6 +450,153 @@ TEST(AsioTransportLayer, SourceSyncTimeoutSucceeds) {
     ASSERT_OK(received.get().getStatus());
 }
 
+/**
+ * Test that when the session is in restricted mode (pre-auth), messages with a size
+ * larger than preAuthMaximumMessageSizeBytes are rejected with a ProtocolError.
+ */
+TEST(AsioTransportLayer, UnauthenticatedConnectionRejectsOversizedMessage) {
+    // Set pre-auth max message size to a small value for testing (1024 bytes).
+    // We also disable the post-header timeout to isolate message size validation.
+    RAIIServerParameterControllerForTest maxSizeController{"preAuthMaximumMessageSizeBytes", 1024};
+
+    TestFixture tf;
+    Notification<StatusWith<Message>> received;
+    tf.sep().setOnStartSession([&](transport::test::SessionThread& st) {
+        st.schedule([&](auto& session) {
+            // Put the session in restricted mode (pre-auth).
+            session.setRestrictedMode(true);
+            received.set(session.sourceMessage());
+        });
+    });
+
+    SyncClient conn(tf.tla().listenerPort());
+
+    const auto errorsBefore = getNetworkMetric("totalMessageSizeErrorPreAuth");
+
+    static constexpr size_t kHeaderSize = sizeof(MSGHEADER::Value);
+    // Claim a message size larger than our configured preAuthMaximumMessageSizeBytes (1024).
+    static constexpr int32_t kOversizedMessageSize = 2048;
+
+    char headerBuffer[kHeaderSize];
+    MSGHEADER::View header(headerBuffer);
+    header.setMessageLength(kOversizedMessageSize);
+    header.setRequestMsgId(0);
+    header.setResponseToMsgId(0);
+    header.setOpCode(dbMsg);
+
+    // Send the header with an oversized message length. The server should reject this
+    // immediately after reading the header due to the pre-auth message size limit.
+    auto ec = conn.write(headerBuffer, kHeaderSize);
+    ASSERT_FALSE(ec) << errorMessage(ec);
+
+    auto result = received.get();
+    ASSERT_EQ(result.getStatus().code(), ErrorCodes::ProtocolError)
+        << "Expected ProtocolError when message size exceeds preAuthMaximumMessageSizeBytes, got: "
+        << result.getStatus();
+
+    // Verify the pre-auth message size error metric was incremented.
+    ASSERT_EQ(getNetworkMetric("totalMessageSizeErrorPreAuth"), errorsBefore + 1)
+        << "totalMessageSizeErrorPreAuth metric should increment on oversized pre-auth message";
+}
+
+/**
+ * Test that when the session is NOT in restricted mode (post-auth), messages larger than
+ * preAuthMaximumMessageSizeBytes are allowed (up to MaxMessageSizeBytes).
+ */
+TEST(AsioTransportLayer, AuthenticatedConnectionAllowsLargerMessages) {
+    // Set pre-auth max message size to a small value for testing (1024 bytes).
+    const auto preAuthMaxMsgSize = 1024;
+    RAIIServerParameterControllerForTest maxSizeController{"preAuthMaximumMessageSizeBytes",
+                                                           preAuthMaxMsgSize};
+
+    TestFixture tf;
+    Notification<transport::test::SessionThread*> mockSessionCreated;
+    tf.sep().setOnStartSession(
+        [&](transport::test::SessionThread& st) { mockSessionCreated.set(&st); });
+
+    SyncClient conn(tf.tla().listenerPort());
+    auto& st = *mockSessionCreated.get();
+
+    const auto preAuthErrorsBefore = getNetworkMetric("totalMessageSizeErrorPreAuth");
+    const auto postAuthErrorsBefore = getNetworkMetric("totalMessageSizeErrorPostAuth");
+
+    Notification<StatusWith<Message>> done;
+    st.schedule([&](auto& session) {
+        // NOT in restricted mode - simulates an authenticated connection.
+        session.setRestrictedMode(false);
+        done.set(session.sourceMessage());
+    });
+
+    // Build and send a message that is larger than preAuthMaximumMessageSizeBytes (1024)
+    // but valid for an authenticated connection.
+    OpMsgBuilder builder;
+    builder.setBody(BSON("ping" << 1 << "padding" << std::string(1024, 'x')));
+    Message msg = builder.finish();
+    msg.header().setResponseToMsgId(0);
+    msg.header().setId(0);
+    OpMsg::appendChecksum(&msg);
+
+    // Verify the message size is larger than our configured pre-auth limit.
+    ASSERT_GT(msg.size(), preAuthMaxMsgSize)
+        << "Test message should be larger than preAuthMaximumMessageSizeBytes";
+
+    auto ec = conn.write(msg.buf(), msg.size());
+    ASSERT_FALSE(ec) << errorMessage(ec);
+
+    // Should succeed since we're not in restricted mode.
+    ASSERT_OK(done.get().getStatus()) << "Authenticated connections should accept messages larger "
+                                         "than preAuthMaximumMessageSizeBytes";
+
+    // Verify neither message size error metric was incremented.
+    ASSERT_EQ(getNetworkMetric("totalMessageSizeErrorPreAuth"), preAuthErrorsBefore)
+        << "totalMessageSizeErrorPreAuth should not increment for authenticated connections";
+    ASSERT_EQ(getNetworkMetric("totalMessageSizeErrorPostAuth"), postAuthErrorsBefore)
+        << "totalMessageSizeErrorPostAuth should not increment for valid messages";
+}
+
+/**
+ * Test that pre-auth message size validation correctly accepts messages within the limit.
+ */
+TEST(AsioTransportLayer, UnauthenticatedConnectionAcceptsValidSizedMessage) {
+    // Set pre-auth max message size large enough to accept our test message.
+    RAIIServerParameterControllerForTest maxSizeController{"preAuthMaximumMessageSizeBytes", 16384};
+
+    TestFixture tf;
+    Notification<StatusWith<Message>> received;
+    tf.sep().setOnStartSession([&](transport::test::SessionThread& st) {
+        st.schedule([&](auto& session) {
+            session.setRestrictedMode(true);
+            received.set(session.sourceMessage());
+        });
+    });
+
+    SyncClient conn(tf.tla().listenerPort());
+
+    const auto errorsBefore = getNetworkMetric("totalMessageSizeErrorPreAuth");
+
+    // Build and send a valid message smaller than preAuthMaximumMessageSizeBytes.
+    OpMsgBuilder builder;
+    builder.setBody(BSON("ping" << 1));
+    Message msg = builder.finish();
+    msg.header().setResponseToMsgId(0);
+    msg.header().setId(0);
+    OpMsg::appendChecksum(&msg);
+
+    // Verify the message size is within our configured pre-auth limit.
+    ASSERT_LT(msg.size(), 16384) << "Test message should be within preAuthMaximumMessageSizeBytes";
+
+    auto ec = conn.write(msg.buf(), msg.size());
+    ASSERT_FALSE(ec) << errorMessage(ec);
+
+    // Should succeed since message is within the pre-auth limit.
+    ASSERT_OK(received.get().getStatus())
+        << "Pre-auth connections should accept messages within preAuthMaximumMessageSizeBytes";
+
+    // Verify no error metrics were incremented.
+    ASSERT_EQ(getNetworkMetric("totalMessageSizeErrorPreAuth"), errorsBefore)
+        << "totalMessageSizeErrorPreAuth should not increment for valid-sized pre-auth messages";
+}
+
 /** Switching from timeouts to no timeouts must reset the timeout to unlimited. */
 TEST(AsioTransportLayer, SwitchTimeoutModes) {
     TestFixture tf;
diff --git a/src/mongo/transport/message_compressor_base.h b/src/mongo/transport/message_compressor_base.h
index 1ac7543bac904..bea001b1eabe2 100644
--- a/src/mongo/transport/message_compressor_base.h
+++ b/src/mongo/transport/message_compressor_base.h
@@ -36,6 +36,8 @@
 #include
 
+#include <boost/optional.hpp>
+
 namespace mongo {
 enum class MessageCompressor : uint8_t {
     kNoop = 0,
@@ -89,6 +91,12 @@ class MessageCompressorBase {
      */
     virtual StatusWith<std::size_t> decompressData(ConstDataRange input, DataRange output) = 0;
 
+    /*
+     * Returns the max uncompressed length of the data in the input ConstDataRange as given by the
+     * header, if available.
+     */
+    virtual boost::optional<std::size_t> getMaxDecompressedSize(ConstDataRange input) = 0;
+
     /*
     * This returns the number of bytes passed in the input for compressData
     */
diff --git a/src/mongo/transport/message_compressor_manager.cpp b/src/mongo/transport/message_compressor_manager.cpp
index 6cc3a928ec50b..a127bad762f13 100644
--- a/src/mongo/transport/message_compressor_manager.cpp
+++ b/src/mongo/transport/message_compressor_manager.cpp
@@ -47,33 +47,6 @@
 namespace mongo {
 namespace {
 
-// TODO(JBR): This should be changed so it's closer to the MSGHEADER View/ConstView classes
-// than this little struct.
-struct CompressionHeader {
-    int32_t originalOpCode;
-    int32_t uncompressedSize;
-    uint8_t compressorId;
-
-    void serialize(DataRangeCursor* cursor) {
-        cursor->writeAndAdvance<LittleEndian<int32_t>>(originalOpCode);
-        cursor->writeAndAdvance<LittleEndian<int32_t>>(uncompressedSize);
-        cursor->writeAndAdvance<LittleEndian<uint8_t>>(compressorId);
-    }
-
-    CompressionHeader(int32_t _opcode, int32_t _size, uint8_t _id)
-        : originalOpCode{_opcode}, uncompressedSize{_size}, compressorId{_id} {}
-
-    CompressionHeader(ConstDataRangeCursor* cursor) {
-        originalOpCode = cursor->readAndAdvance<LittleEndian<int32_t>>();
-        uncompressedSize = cursor->readAndAdvance<LittleEndian<int32_t>>();
-        compressorId = cursor->readAndAdvance<LittleEndian<uint8_t>>();
-    }
-
-    static size_t size() {
-        return sizeof(originalOpCode) + sizeof(uncompressedSize) + sizeof(compressorId);
-    }
-};
-
 const transport::Session::Decoration<MessageCompressorManager> getForSession =
     transport::Session::declareDecoration<MessageCompressorManager>();
 }  // namespace
 
@@ -145,7 +118,8 @@ StatusWith<Message> MessageCompressorManager::compressMessage(
 }
 
 StatusWith<Message> MessageCompressorManager::decompressMessage(const Message& msg,
-                                                                MessageCompressorId* compressorId) {
+                                                                MessageCompressorId* compressorId,
+                                                                size_t maxMessageSize) {
     auto inputHeader = msg.header();
     ConstDataRangeCursor input(inputHeader.data(), inputHeader.data() + inputHeader.dataLen());
     if (input.length() < CompressionHeader::size()) {
@@ -177,11 +151,17 @@ StatusWith<Message> MessageCompressorManager::decompressMessage(const Message& m
     // avoid potential overflow.
     size_t bufferSize =
         static_cast<size_t>(compressionHeader.uncompressedSize) + MsgData::MsgDataHeaderSize;
-    if (bufferSize > MaxMessageSizeBytes) {
+    if (bufferSize > maxMessageSize) {
         return {ErrorCodes::BadValue,
                 "Decompressed message would be larger than maximum message size"};
     }
 
+    auto maxDecompressedSize = compressor->getMaxDecompressedSize(input);
+    if (maxDecompressedSize &&
+        *maxDecompressedSize < static_cast<size_t>(compressionHeader.uncompressedSize)) {
+        return {ErrorCodes::BadValue, "Uncompressed message size does not match expected size"};
+    }
+
     auto outputMessageBuffer = SharedBuffer::allocate(bufferSize);
     MsgData::View outMessage(outputMessageBuffer.get());
     outMessage.setId(inputHeader.getId());
diff --git a/src/mongo/transport/message_compressor_manager.h b/src/mongo/transport/message_compressor_manager.h
index 3e49084fd0fca..64a3ec0153d02 100644
--- a/src/mongo/transport/message_compressor_manager.h
+++ b/src/mongo/transport/message_compressor_manager.h
@@ -47,6 +47,33 @@ class MessageCompressorManager {
     MessageCompressorManager& operator=(const MessageCompressorManager&) = delete;
 
 public:
+    // TODO(JBR): This should be changed so it's closer to the MSGHEADER View/ConstView classes
+    // than this little struct.
+    struct CompressionHeader {
+        int32_t originalOpCode;
+        int32_t uncompressedSize;
+        uint8_t compressorId;
+
+        void serialize(DataRangeCursor* cursor) {
+            cursor->writeAndAdvance<LittleEndian<int32_t>>(originalOpCode);
+            cursor->writeAndAdvance<LittleEndian<int32_t>>(uncompressedSize);
+            cursor->writeAndAdvance<LittleEndian<uint8_t>>(compressorId);
+        }
+
+        CompressionHeader(int32_t _opcode, int32_t _size, uint8_t _id)
+            : originalOpCode{_opcode}, uncompressedSize{_size}, compressorId{_id} {}
+
+        CompressionHeader(ConstDataRangeCursor* cursor) {
+            originalOpCode = cursor->readAndAdvance<LittleEndian<int32_t>>();
+            uncompressedSize = cursor->readAndAdvance<LittleEndian<int32_t>>();
+            compressorId = cursor->readAndAdvance<LittleEndian<uint8_t>>();
+        }
+
+        static size_t size() {
+            return sizeof(originalOpCode) + sizeof(uncompressedSize) + sizeof(compressorId);
+        }
+    };
+
     /*
      * Default constructor. Uses the global MessageCompressorRegistry.
     */
@@ -119,7 +146,8 @@ class MessageCompressorManager {
      * compressMessage, ensuring that the same compressor is used on both sides of a conversation.
      */
     StatusWith<Message> decompressMessage(const Message& msg,
-                                          MessageCompressorId* compressorId = nullptr);
+                                          MessageCompressorId* compressorId = nullptr,
+                                          size_t maxMessageSize = MaxMessageSizeBytes);
 
     const std::vector<MessageCompressorBase*>& getNegotiatedCompressors() const;
 
diff --git a/src/mongo/transport/message_compressor_manager_test.cpp b/src/mongo/transport/message_compressor_manager_test.cpp
index cbb9a7378a4e0..1c37e64e3f12d 100644
--- a/src/mongo/transport/message_compressor_manager_test.cpp
+++ b/src/mongo/transport/message_compressor_manager_test.cpp
@@ -36,6 +36,7 @@
 #include "mongo/bson/bsonobjbuilder.h"
 #include "mongo/rpc/message.h"
+#include "mongo/rpc/op_msg.h"
 #include "mongo/transport/message_compressor_manager.h"
 #include "mongo/transport/message_compressor_noop.h"
 #include "mongo/transport/message_compressor_registry.h"
@@ -170,6 +171,26 @@ void checkOverflow(std::unique_ptr<MessageCompressorBase> compressor) {
                   compressor->decompressData(tooSmallRange, DataRange(scratch.data(), scratch.size())));
 }
 
+void checkUndersize(const Message& compressedMsg,
+                    std::unique_ptr<MessageCompressorBase> compressor) {
+    MessageCompressorRegistry registry;
+    const auto compressorName = compressor->getName();
+
+    std::vector<std::string> compressorList = {compressorName};
+    registry.setSupportedCompressors(std::move(compressorList));
+    registry.registerImplementation(std::move(compressor));
+    registry.finalizeSupportedCompressors().transitional_ignore();
+
+    MessageCompressorManager mgr(&registry);
+    BSONObjBuilder negotiatorOut;
+    std::vector<std::string> negotiator({compressorName});
+    mgr.serverNegotiate(negotiator, &negotiatorOut);
+    checkNegotiationResult(negotiatorOut.done(), {compressorName});
+
+    auto swm = mgr.decompressMessage(compressedMsg);
+    ASSERT_EQ(ErrorCodes::BadValue, swm.getStatus());
+}
+
 Message buildMessage() {
     const auto data = std::string{"Hello, world!"};
     const auto bufferSize = MsgData::MsgDataHeaderSize + data.size();
@@ -275,6 +296,52 @@ TEST(ZstdMessageCompressor, Overflow) {
     checkOverflow(std::make_unique<ZstdMessageCompressor>());
 }
 
+TEST(ZlibMessageCompressor, Mismatch) {
+    checkOverflow(std::make_unique<ZlibMessageCompressor>());
+}
+
+TEST(SnappyMessageCompressor, Undersize) {
+    std::vector<uint8_t> payload = {
+        0x41, 0x0,  0x0,  0x0,  0xad, 0xde, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0xdc,
+        0x7,  0x0,  0x0,  0xdd, 0x7,  0x0,  0x0,  0x0,  0x20, 0x0,  0x0,  0x1,  0x27,
+        0x0,  0x0,  0x1,  0x1,  0x84, 0xfb, 0x1f, 0x0,  0x0,  0x5,  0x5f, 0x69, 0x64,
+        0x0,  0x0,  0x10, 0x0,  0x0,  0x0,  0x48, 0x45, 0x41, 0x50, 0x4c, 0x45, 0x41,
+        0x4b, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0};
+
+
+    auto buffer = SharedBuffer::allocate(payload.size());
+    std::copy(payload.begin(), payload.end(), buffer.get());
+
+    checkUndersize(Message(buffer), std::make_unique<SnappyMessageCompressor>());
+}
+
+TEST(ZlibMessageCompressor, Undersize) {
+    std::vector<uint8_t> payload = {
+        0x3c, 0x00, 0x00, 0x00, 0xad, 0xde, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xdc, 0x07, 0x00,
+        0x00, 0xdd, 0x07, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x02, 0x78, 0xda, 0x63, 0x60, 0x00,
+        0x82, 0xdf, 0xf2, 0x0c, 0x0c, 0xac, 0xf1, 0x99, 0x29, 0x0c, 0x0c, 0x02, 0x40, 0x9e, 0x87,
+        0xab, 0x63, 0x80, 0x8f, 0xab, 0xa3, 0x37, 0x03, 0x12, 0x00, 0x00, 0x6d, 0x26, 0x04, 0x97};
+
+    auto buffer = SharedBuffer::allocate(payload.size());
+    std::copy(payload.begin(), payload.end(), buffer.get());
+
+    checkUndersize(Message(buffer), std::make_unique<ZlibMessageCompressor>());
+}
+
+TEST(ZstdMessageCompressor, Undersize) {
+    std::vector<uint8_t> payload = {
+        0x44, 0x0,  0x0,  0x0,  0xad, 0xde, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0xdc, 0x7,
+        0x0,  0x0,  0xdd, 0x7,  0x0,  0x0,  0x0,  0x20, 0x0,  0x0,  0x3,  0x28, 0xb5, 0x2f,
+        0xfd, 0x20, 0x27, 0x15, 0x1,  0x0,  0xe0, 0x0,  0x0,  0x0,  0x0,  0x0,  0xfb, 0x1f,
+        0x0,  0x0,  0x5,  0x5f, 0x69, 0x64, 0x0,  0x0,  0x10, 0x0,  0x0,  0x0,  0x48, 0x45,
+        0x41, 0x50, 0x4c, 0x45, 0x41, 0x4b, 0x0,  0x1,  0x0,  0x18, 0xc0, 0x9};
+
+    auto buffer = SharedBuffer::allocate(payload.size());
+    std::copy(payload.begin(), payload.end(), buffer.get());
+
+    checkUndersize(Message(buffer), std::make_unique<ZstdMessageCompressor>());
+}
+
 TEST(MessageCompressorManager, SERVER_28008) {
 
     // Create a client and server that will negotiate the same compressors,
@@ -416,7 +483,7 @@ TEST(MessageCompressorManager, RuntMessage) {
     badMessage.setOperation(dbCompressed);
     badMessage.setLen(MsgData::MsgDataHeaderSize + 8);
 
-    // This is a totally bogus compression header of just the orginal opcode + 0 byte uncompressed
+    // This is a totally bogus compression header of just the original opcode + 0 byte uncompressed
     // size
     DataRangeCursor cursor(badMessage.data(), badMessage.data() + badMessage.dataLen());
     cursor.writeAndAdvance<LittleEndian<int32_t>>(dbQuery);
@@ -426,5 +493,80 @@ TEST(MessageCompressorManager, RuntMessage) {
     ASSERT_NOT_OK(status);
 }
 
+void checkWrongUncompressedSize(std::unique_ptr<MessageCompressorBase> compressor,
+                                std::string expectedError) {
+    MessageCompressorRegistry registry;
+    const auto compressorId = compressor->getId();
+
+    std::vector<std::string> compressorList = {compressor->getName()};
+    registry.setSupportedCompressors(std::move(compressorList));
+    registry.registerImplementation(std::move(compressor));
+    ASSERT_OK(registry.finalizeSupportedCompressors());
+    MessageCompressorManager manager(&registry);
+
+    OpMsgBuilder msgBuilder;
+    msgBuilder.setBody(BSON("ping" << 1));
+    Message originalMsg = msgBuilder.finishWithoutSizeChecking();
+    const auto originalView = originalMsg.singleData();
+    const size_t originalDataSize = originalView.dataLen();
+
+    auto swCompressed = manager.compressMessage(originalMsg, &compressorId);
+    ASSERT_OK(swCompressed);
+    Message properlyCompressedMsg = std::move(swCompressed.getValue());
+
+    const auto compressedView = properlyCompressedMsg.singleData();
+    ConstDataRangeCursor input(compressedView.data(),
+                               compressedView.data() + compressedView.dataLen());
+
+    MessageCompressorManager::CompressionHeader originalHeader(&input);
+
+    const size_t compressedDataSize = input.length();
+    const char* compressedDataPtr = input.data();
+
+    ASSERT_EQ(originalHeader.compressorId, compressorId);
+    ASSERT_EQ(originalHeader.originalOpCode, dbMsg);
+
+    const int32_t wrongUncompressedSize = static_cast<int32_t>(originalDataSize * 2);
+
+    const size_t malformedMessageSize = MsgData::MsgDataHeaderSize +
+        MessageCompressorManager::CompressionHeader::size() + compressedDataSize;
+    auto malformedBuffer = SharedBuffer::allocate(malformedMessageSize);
+
+    MsgData::View malformedView(malformedBuffer.get());
+    malformedView.setId(originalView.getId());
+    malformedView.setResponseToMsgId(originalView.getResponseToMsgId());
+    malformedView.setOperation(dbCompressed);
+    malformedView.setLen(malformedMessageSize);
+
+    DataRangeCursor output(malformedView.data(), malformedView.data() + malformedView.dataLen());
+    output.writeAndAdvance<LittleEndian<int32_t>>(originalHeader.originalOpCode);
+    output.writeAndAdvance<LittleEndian<int32_t>>(wrongUncompressedSize);
+    output.writeAndAdvance<LittleEndian<uint8_t>>(originalHeader.compressorId);
+
+    std::memcpy(output.data(), compressedDataPtr, compressedDataSize);
+
+    Message malformedMessage(malformedBuffer);
+
+    auto swDecompressed = manager.decompressMessage(malformedMessage);
+    ASSERT_NOT_OK(swDecompressed.getStatus());
+    ASSERT_EQ(swDecompressed.getStatus().code(), ErrorCodes::BadValue);
+    ASSERT_STRING_CONTAINS(swDecompressed.getStatus().reason(), expectedError);
+}
+
+TEST(SnappyMessageCompressor, WrongUncompressedSize) {
+    checkWrongUncompressedSize(std::make_unique<SnappyMessageCompressor>(),
+                               "Uncompressed message size does not match expected size");
+}
+
+TEST(ZstdMessageCompressor, WrongUncompressedSize) {
+    checkWrongUncompressedSize(std::make_unique<ZstdMessageCompressor>(),
+                               "Uncompressed message size does not match expected size");
+}
+
+TEST(ZlibMessageCompressor, WrongUncompressedSize) {
+    checkWrongUncompressedSize(std::make_unique<ZlibMessageCompressor>(),
+                               "Decompressing message returned less data than expected");
+}
+
 }  // namespace
 }  // namespace mongo
diff --git a/src/mongo/transport/message_compressor_noop.h b/src/mongo/transport/message_compressor_noop.h
index e8f07560d64e4..ed378ce5e1601 100644
--- a/src/mongo/transport/message_compressor_noop.h
+++ b/src/mongo/transport/message_compressor_noop.h
@@ -39,6 +39,10 @@ class NoopMessageCompressor final : public MessageCompressorBase {
         return inputSize;
     }
 
+    boost::optional<std::size_t> getMaxDecompressedSize(ConstDataRange input) override {
+        return boost::none;
+    }
+
    StatusWith<std::size_t> compressData(ConstDataRange input, DataRange output) override try {
        output.write(input);
        counterHitCompress(input.length(), input.length());
diff --git a/src/mongo/transport/message_compressor_snappy.cpp b/src/mongo/transport/message_compressor_snappy.cpp
index 1d926174cf5d4..a137fc7b7fae3 100644
--- a/src/mongo/transport/message_compressor_snappy.cpp
+++ b/src/mongo/transport/message_compressor_snappy.cpp
@@ -47,6 +47,14 @@ std::size_t SnappyMessageCompressor::getMaxCompressedSize(size_t inputSize) {
     return snappy::MaxCompressedLength(inputSize);
 }
 
+boost::optional<std::size_t> SnappyMessageCompressor::getMaxDecompressedSize(ConstDataRange input) {
+    size_t length = 0;
+    if (snappy::GetUncompressedLength(input.data(), input.length(), &length)) {
+        return length;
+    }
+    return boost::none;
+}
+
 StatusWith<std::size_t> SnappyMessageCompressor::compressData(ConstDataRange input,
                                                               DataRange output) {
     size_t outLength = output.length();
diff --git a/src/mongo/transport/message_compressor_snappy.h b/src/mongo/transport/message_compressor_snappy.h
index 135aa49fad3e2..1cdb602096d98 100644
--- a/src/mongo/transport/message_compressor_snappy.h
+++ b/src/mongo/transport/message_compressor_snappy.h
@@ -29,6 +29,10 @@
 
 #include "mongo/transport/message_compressor_base.h"
 
+#include <cstddef>
+
+#include <boost/optional.hpp>
+
 namespace mongo {
 class SnappyMessageCompressor final : public MessageCompressorBase {
 public:
@@ -39,6 +43,8 @@ class SnappyMessageCompressor final : public MessageCompressorBase {
     StatusWith<std::size_t> compressData(ConstDataRange input, DataRange output) override;
 
     StatusWith<std::size_t> decompressData(ConstDataRange input, DataRange output) override;
+
+    boost::optional<std::size_t> getMaxDecompressedSize(ConstDataRange input) override;
 };
diff --git a/src/mongo/transport/message_compressor_zlib.cpp b/src/mongo/transport/message_compressor_zlib.cpp
index 878431b28244c..06d1693fbbe94 100644
--- a/src/mongo/transport/message_compressor_zlib.cpp
+++ b/src/mongo/transport/message_compressor_zlib.cpp
@@ -35,6 +35,8 @@
 #include "mongo/transport/message_compressor_registry.h"
 #include "mongo/transport/message_compressor_zlib.h"
 
+#include <boost/optional.hpp>
+#include <cstddef>
 #include <zlib.h>
 
 namespace mongo {
@@ -45,6 +47,10 @@ std::size_t ZlibMessageCompressor::getMaxCompressedSize(size_t inputSize) {
     return ::compressBound(inputSize);
 }
 
+boost::optional<std::size_t> ZlibMessageCompressor::getMaxDecompressedSize(ConstDataRange input) {
+    return boost::none;
+}
+
 StatusWith<std::size_t> ZlibMessageCompressor::compressData(ConstDataRange input,
                                                             DataRange output) {
     size_t outLength = output.length();
@@ -74,7 +80,7 @@ StatusWith<std::size_t> ZlibMessageCompressor::decompressData(ConstDataRange inp
     }
 
     counterHitDecompress(input.length(), output.length());
-    return {output.length()};
+    return {length};
 }
diff --git a/src/mongo/transport/message_compressor_zlib.h b/src/mongo/transport/message_compressor_zlib.h
index 2c25cac339c22..1770868426870 100644
--- a/src/mongo/transport/message_compressor_zlib.h
+++ b/src/mongo/transport/message_compressor_zlib.h
@@ -29,6 +29,10 @@
 
 #include "mongo/transport/message_compressor_base.h"
 
+#include <cstddef>
+
+#include <boost/optional.hpp>
+
 namespace mongo {
 class ZlibMessageCompressor final : public MessageCompressorBase {
 public:
@@ -36,6 +40,8 @@ class ZlibMessageCompressor final : public MessageCompressorBase {
 
     std::size_t getMaxCompressedSize(size_t inputSize) override;
 
+    boost::optional<std::size_t> getMaxDecompressedSize(ConstDataRange input) override;
+
     StatusWith<std::size_t> compressData(ConstDataRange input, DataRange output) override;
 
     StatusWith<std::size_t> decompressData(ConstDataRange input, DataRange output) override;
diff --git a/src/mongo/transport/message_compressor_zstd.cpp b/src/mongo/transport/message_compressor_zstd.cpp
index ddfe58c12f326..ba1076aa7d59e 100644
--- a/src/mongo/transport/message_compressor_zstd.cpp
+++ b/src/mongo/transport/message_compressor_zstd.cpp
@@ -31,6 +31,8 @@
 
 #include <zstd.h>
 
+#include <boost/optional.hpp>
+#include <cstddef>
 #include
 
 #include "mongo/base/init.h"
@@ -75,8 +77,12 @@ StatusWith<std::size_t> ZstdMessageCompressor::decompressData(ConstDataRange inp
     return {ret};
 }
 
-std::size_t ZstdMessageCompressor::getMaxDecompressedSize(const void* src, size_t srcSize) {
-    auto maxDecompressedSize = ZSTD_getFrameContentSize(src, srcSize);
+boost::optional<std::size_t> ZstdMessageCompressor::getMaxDecompressedSize(ConstDataRange input) {
+    auto maxDecompressedSize = ZSTD_getFrameContentSize(input.data(), input.length());
+    if (maxDecompressedSize == ZSTD_CONTENTSIZE_UNKNOWN ||
+        maxDecompressedSize == ZSTD_CONTENTSIZE_ERROR) {
+        return boost::none;
+    }
     return static_cast<std::size_t>(maxDecompressedSize);
 }
diff --git a/src/mongo/transport/message_compressor_zstd.h b/src/mongo/transport/message_compressor_zstd.h
index c393f63b2f376..03c7487350bf5 100644
--- a/src/mongo/transport/message_compressor_zstd.h
+++ b/src/mongo/transport/message_compressor_zstd.h
@@ -40,7 +40,7 @@ class ZstdMessageCompressor final : public MessageCompressorBase {
 
     StatusWith<std::size_t> decompressData(ConstDataRange input, DataRange output) override;
 
-    std::size_t getMaxDecompressedSize(const void* src, size_t srcSize);
+    boost::optional<std::size_t> getMaxDecompressedSize(ConstDataRange input) override;
 };
diff --git a/src/mongo/transport/service_entry_point_impl.cpp b/src/mongo/transport/service_entry_point_impl.cpp
index 534bd40364f19..f486592d21a9a 100644
--- a/src/mongo/transport/service_entry_point_impl.cpp
+++ b/src/mongo/transport/service_entry_point_impl.cpp
@@ -62,9 +62,11 @@
 #include "mongo/transport/session_establishment_rate_limiter.h"
 #include "mongo/transport/session_establishment_rate_limiter_utils.h"
 #include "mongo/transport/session_workflow.h"
+#include "mongo/transport/transport_options_gen.h"
 #include "mongo/util/duration.h"
 #include "mongo/util/hierarchical_acquisition.h"
 #include "mongo/util/net/cidr.h"
+#include "mongo/util/processinfo.h"
 
 #if !defined(_WIN32)
 #include
@@ -86,6 +88,21 @@ bool quiet() {
     return serverGlobalParams.quiet.load();
 }
 
+boost::optional<size_t> calculateSafeConnectionLimit() {
+    invariant(transport::gMemoryCapPercentageForPreAuthBuffers > 0 &&
+              transport::gMemoryCapPercentageForPreAuthBuffers <= 100);
+    if (transport::gMemoryCapPercentageForPreAuthBuffers == 100) {
+        return {};
+    }
+
+    const auto maxPreAuthBufferSizeBytes = transport::gPreAuthMaximumMessageSizeBytes.load();
+    const auto totalMemoryBytes = ProcessInfo::getMemSizeMB() * 1024 * 1024;
+    const auto maxMemoryForPreAuthBuffersBytes =
+        (totalMemoryBytes * transport::gMemoryCapPercentageForPreAuthBuffers) / 100;
+
+    return maxMemoryForPreAuthBuffersBytes / maxPreAuthBufferSizeBytes;
+}
+
 /** Some diagnostic data that we will want to log about a Client after its death. */
 struct ClientSummary {
     explicit ClientSummary(const Client* c)
@@ -147,6 +164,15 @@ size_t getSupportedMax() {
                       "limit"_attr = supportedMax);
     }
 
+    const auto safeConnectionLimit = calculateSafeConnectionLimit();
+    if (safeConnectionLimit && *safeConnectionLimit < supportedMax) {
+        LOGV2_WARNING(
+            11621101,
+            "Overriding max connections to honor `capMemoryConsumptionForPreAuthBuffers` settings",
+            "limit"_attr = *safeConnectionLimit);
+        return *safeConnectionLimit;
+    }
+
     return supportedMax;
 }
diff --git a/src/mongo/transport/session.h b/src/mongo/transport/session.h
index 7dbabe324157d..3dc0be28100a5 100644
--- a/src/mongo/transport/session.h
+++ b/src/mongo/transport/session.h
@@ -106,6 +106,14 @@
 class Session : public std::enable_shared_from_this<Session>,
                 public Decorable<Session> {
diff --git a/src/mongo/transport/session_workflow.cpp b/src/mongo/transport/session_workflow.cpp
@@ ... @@
+bool isPreAuth(Client* client) {
+    auto authorizationSession = AuthorizationSession::get(client);
+    if (authorizationSession->shouldIgnoreAuthChecks()) {
+        // If authentication is disabled, never consider sessions pre-auth.
+        return false;
+    }
+
+    return !authorizationSession->isAuthenticated();
+};
 }  // namespace
 
 class SessionWorkflow::Impl {
@@ -487,6 +505,9 @@ class SessionWorkflow::Impl {
             // latency.
             _yieldPointReached();
             _iterationFrame->metrics.yieldedBeforeReceive();
+            ON_BLOCK_EXIT(
+                [&, old = session()->getRestrictedMode()] { session()->setRestrictedMode(old); });
+            session()->setRestrictedMode(isPreAuth(client()));
             return _receiveRequest();
         }
         auto&& [p, f] = makePromiseFuture<void>();
@@ -575,7 +596,10 @@ class SessionWorkflow::Impl::WorkItem {
         if (_in.operation() != dbCompressed)
             return;
         MessageCompressorId cid;
-        _in = uassertStatusOK(compressorMgr().decompressMessage(_in, &cid));
+        _in = isPreAuth(_swf->client()) ?
uassertStatusOK(compressorMgr().decompressMessage( + _in, &cid, gPreAuthMaximumMessageSizeBytes.loadRelaxed())) + : uassertStatusOK(compressorMgr().decompressMessage(_in, &cid)); _compressorId = cid; } diff --git a/src/mongo/transport/session_workflow_test.cpp b/src/mongo/transport/session_workflow_test.cpp index 4078cf051d532..f71cf37527240 100644 --- a/src/mongo/transport/session_workflow_test.cpp +++ b/src/mongo/transport/session_workflow_test.cpp @@ -41,9 +41,13 @@ #include #include "mongo/base/checked_cast.h" +#include "mongo/base/data_range_cursor.h" +#include "mongo/base/data_type_endian.h" +#include "mongo/base/error_codes.h" #include "mongo/base/status.h" #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobjbuilder.h" +#include "mongo/db/auth/authorization_manager.h" #include "mongo/db/client.h" #include "mongo/db/client_strand.h" #include "mongo/db/concurrency/locker_noop_service_context_test_fixture.h" @@ -53,8 +57,12 @@ #include "mongo/logv2/log.h" #include "mongo/platform/compiler.h" #include "mongo/platform/mutex.h" +#include "mongo/rpc/message.h" #include "mongo/rpc/op_msg.h" #include "mongo/stdx/mutex.h" +#include "mongo/transport/message_compressor_manager.h" +#include "mongo/transport/message_compressor_registry.h" +#include "mongo/transport/message_compressor_snappy.h" #include "mongo/transport/mock_session.h" #include "mongo/transport/service_entry_point.h" #include "mongo/transport/service_entry_point_impl.h" @@ -67,6 +75,7 @@ #include "mongo/unittest/log_test.h" #include "mongo/unittest/unittest.h" #include "mongo/util/concurrency/thread_pool.h" +#include "mongo/util/shared_buffer.h" #include "mongo/util/synchronized_value.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest @@ -946,5 +955,42 @@ TEST_F(SessionWorkflowWithBorrowedThreadsTest, MoreToComeLoop) { runSteps(convertStepsToBorrowed(moreToComeLoop())); } +TEST_F(SessionWorkflowTest, OversizedDecompressedMessage) { + auto* authzManager = AuthorizationManager::get(getServiceContext()); + authzManager->setAuthEnabled(true); + ScopeGuard resetAuth([authzManager] { authzManager->setAuthEnabled(false); }); + + auto& registry = MessageCompressorRegistry::get(); + const auto& compressorNames = registry.getCompressorNames(); + if (std::ranges::find(compressorNames, "snappy") == compressorNames.end()) { + registry.setSupportedCompressors({"snappy"}); + registry.registerImplementation(std::make_unique()); + uassertStatusOK(registry.finalizeSupportedCompressors()); + } + + RAIIServerParameterControllerForTest maxSizeController{"preAuthMaximumMessageSizeBytes", 1024}; + + startSession(); + + const size_t bufferSize = + MsgData::MsgDataHeaderSize + MessageCompressorManager::CompressionHeader::size(); + auto buffer = SharedBuffer::allocate(bufferSize); + MsgData::View msgView(buffer.get()); + msgView.setId(1); + msgView.setResponseToMsgId(0); + msgView.setOperation(dbCompressed); + msgView.setLen(bufferSize); + + DataRangeCursor cursor(msgView.data(), msgView.data() + msgView.dataLen()); + const auto snappyId = static_cast(MessageCompressor::kSnappy); + MessageCompressorManager::CompressionHeader header(dbMsg, 2048, snappyId); + header.serialize(&cursor); + + expect(StatusWith{Message(buffer)}); + + expect(); + joinSessions(); +} + } // namespace } // namespace mongo::transport diff --git a/src/mongo/transport/transport_options.idl b/src/mongo/transport/transport_options.idl index a5db361ae2a04..04ea785055b4b 100644 --- a/src/mongo/transport/transport_options.idl +++ 
b/src/mongo/transport/transport_options.idl @@ -165,7 +165,39 @@ server_parameters: set_at: [startup, runtime] cpp_varname: gProxyProtocolMaximumPendingConnections cpp_vartype: AtomicWord - default: {expr: "static_cast(DEFAULT_MAX_CONN)"} + default: 16000 validator: {gte: 0} redact: false + preAuthMaximumMessageSizeBytes: + description: >- + The maximum size of a message that can be sent before the session is authenticated. + set_at: [startup, runtime] + cpp_varname: gPreAuthMaximumMessageSizeBytes + cpp_vartype: AtomicWord + default: 16384 # 16KB + validator: {gte: 1024} # 1KB + redact: false + + messageSizeErrorRateSec: + description: >- + The rate, in seconds, at which message size errors are logged at the Info level. + set_at: [startup] + cpp_varname: gMessageSizeErrorRateSec + cpp_vartype: int32_t + default: 5 + validator: {gte: 0} + redact: false + + capMemoryConsumptionForPreAuthBuffers: + description: >- + Adjusts the maximum number of connections accepted by the server, if needed, to ensure + the total memory allocated to pre-auth buffers does not exceed X% of total memory + available to the server, where X is set to 20 by default. Accepts values in percentages + and setting it to 100 disables this limitation. + set_at: [startup] + cpp_varname: gMemoryCapPercentageForPreAuthBuffers + cpp_vartype: int + default: 20 + validator: {gt: 0, lte: 100} + redact: false
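
For reviewers, a few standalone C++ sketches follow; none of them are part of the patch, and all helper names in them are hypothetical. First, the 19-byte overhead constant documented in the BSONObjectAtSizeLimits test can be sanity-checked in isolation:

#include <cstdint>

// {"payload": "<N bytes>"} serializes as: int32 object size (4) + 0x02 string
// type tag (1) + "payload\0" (8) + int32 string length (4) + N string bytes
// + string NUL (1) + object NUL (1). Hypothetical helper name.
constexpr std::int32_t bsonSizeOfPayloadObj(std::int32_t payloadLen) {
    return 4 + 1 + 8 + 4 + payloadLen + 1 + 1;  // == payloadLen + 19
}
static_assert(bsonSizeOfPayloadObj(0) == 19, "matches the test's documented overhead");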
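
The DataType::Handler<WireMessagePayload>::load() contract (peek the object's size prefix, bounds-check it, report how far the cursor should advance) can be illustrated without MongoDB types. A simplified sketch; the bool return and function name are stand-ins for the real Status-returning handler:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Peek the little-endian int32 size prefix of a BSON blob, check it fits the
// remaining buffer, and report the advance distance. Assumes a little-endian
// host for brevity; hypothetical helper.
bool peekBsonObjectLength(const char* ptr, std::size_t remaining, std::size_t* advanced) {
    if (remaining < 4)
        return false;
    std::int32_t objSize;
    std::memcpy(&objSize, ptr, 4);
    // 5 bytes is the smallest legal BSON object (size prefix + terminating NUL).
    if (objSize < 5 || static_cast<std::size_t>(objSize) > remaining)
        return false;
    *advanced = static_cast<std::size_t>(objSize);
    return true;
}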
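
The CompressionHeader struct that moved into message_compressor_manager.h models the 9-byte OP_COMPRESSED prelude. A standalone sketch of the same layout, assuming a little-endian host (the real code goes through DataRangeCursor with LittleEndian<T> to stay portable):

#include <cstddef>
#include <cstdint>
#include <cstring>

struct CompressionHeaderSketch {
    std::int32_t originalOpCode;    // opcode of the message before compression
    std::int32_t uncompressedSize;  // claimed size of the decompressed body
    std::uint8_t compressorId;      // which negotiated compressor was used

    static constexpr std::size_t kSize = 4 + 4 + 1;  // 9 bytes on the wire

    void serialize(char* out) const {
        std::memcpy(out, &originalOpCode, 4);
        std::memcpy(out + 4, &uncompressedSize, 4);
        out[8] = static_cast<char>(compressorId);
    }

    static CompressionHeaderSketch parse(const char* in) {
        CompressionHeaderSketch h{};
        std::memcpy(&h.originalOpCode, in, 4);
        std::memcpy(&h.uncompressedSize, in + 4, 4);
        h.compressorId = static_cast<std::uint8_t>(in[8]);
        return h;
    }
};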
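
getMaxDecompressedSize() leans on codec-level frame metadata; both library calls used by the patch are public snappy/zstd APIs. A sketch using std::optional in place of boost::optional, with hypothetical function names:

#include <cstddef>
#include <optional>

#include <snappy.h>
#include <zstd.h>

std::optional<std::size_t> snappyMaxDecompressedSize(const char* data, std::size_t len) {
    // snappy stores the uncompressed length as a varint preamble.
    std::size_t out = 0;
    if (snappy::GetUncompressedLength(data, len, &out))
        return out;
    return std::nullopt;
}

std::optional<std::size_t> zstdMaxDecompressedSize(const char* data, std::size_t len) {
    // zstd frames may or may not record their content size.
    const unsigned long long sz = ZSTD_getFrameContentSize(data, len);
    if (sz == ZSTD_CONTENTSIZE_UNKNOWN || sz == ZSTD_CONTENTSIZE_ERROR)
        return std::nullopt;
    return static_cast<std::size_t>(sz);
}
// zlib's deflate streams carry no such field, hence its override returns none.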
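
Putting the two decompression guards together: the claimed uncompressed size must fit the (possibly pre-auth-reduced) message cap, and must not exceed what the compressed frame can actually produce. A condensed sketch of that decision, with hypothetical names:

#include <cstddef>
#include <cstdint>
#include <optional>

enum class DecompressCheck { kOk, kTooLarge, kSizeMismatch };

// headerSize mirrors MsgData::MsgDataHeaderSize; maxMessageSize is
// MaxMessageSizeBytes post-auth and preAuthMaximumMessageSizeBytes pre-auth.
DecompressCheck checkClaimedSize(std::int32_t claimedUncompressedSize,
                                 std::size_t headerSize,
                                 std::size_t maxMessageSize,
                                 std::optional<std::size_t> maxFromFrame) {
    const std::size_t bufferSize =
        static_cast<std::size_t>(claimedUncompressedSize) + headerSize;
    if (bufferSize > maxMessageSize)
        return DecompressCheck::kTooLarge;  // refuse to allocate an oversized buffer
    if (maxFromFrame && *maxFromFrame < static_cast<std::size_t>(claimedUncompressedSize))
        return DecompressCheck::kSizeMismatch;  // the header lies about the payload
    return DecompressCheck::kOk;
}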
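
The header-stage gate in sourceMessageImpl() reduces to a single comparison once the effective maximum is chosen. A sketch (constant and names illustrative):

#include <cstddef>

// Stand-in for sizeof(MSGHEADER::Value); the last two arguments mirror
// gPreAuthMaximumMessageSizeBytes and MaxMessageSizeBytes.
constexpr std::size_t kHeaderSize = 16;

bool isAcceptableMessageLength(std::size_t msgLen,
                               bool restrictedMode,
                               std::size_t preAuthMax,
                               std::size_t maxMessageSizeBytes) {
    const std::size_t maxSize = restrictedMode ? preAuthMax : maxMessageSizeBytes;
    return msgLen >= kHeaderSize && msgLen <= maxSize;
}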
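
getMessageSizeErrorLogSeverityLevel() demotes repeated errors from Info to Debug(1) within a time window (messageSizeErrorRateSec). A single-threaded sketch of that suppression idea; this is not the logv2::SeveritySuppressor API:

#include <chrono>

class SeveritySuppressorSketch {
public:
    explicit SeveritySuppressorSketch(std::chrono::seconds window) : _window(window) {}

    // True when the caller should log at the loud (Info) severity; at most
    // once per window, with quieter Debug-level logs in between.
    bool shouldBeLoud() {
        const auto now = std::chrono::steady_clock::now();
        if (now - _last >= _window) {
            _last = now;
            return true;
        }
        return false;
    }

private:
    std::chrono::seconds _window;
    std::chrono::steady_clock::time_point _last{};
};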
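
calculateSafeConnectionLimit() is plain integer arithmetic: with the defaults above (20% cap, 16 KiB pre-auth buffer), a 16 GiB host works out to 209715 connections. A sketch with the worked example (hypothetical function name):

#include <cstdint>
#include <optional>

std::optional<std::uint64_t> safeConnectionLimit(std::uint64_t totalMemoryBytes,
                                                 std::uint64_t preAuthMaxMessageBytes,
                                                 int capPercent) {
    if (capPercent == 100)
        return std::nullopt;  // setting the parameter to 100 disables the cap
    const std::uint64_t budget = totalMemoryBytes * capPercent / 100;
    return budget / preAuthMaxMessageBytes;
}

// With the defaults on a 16 GiB host:
//   budget = 17179869184 * 20 / 100 = 3435973836 bytes
//   limit  = 3435973836 / 16384     = 209715 connections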
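
Finally, the shard-targeting refactor: getFirstFieldWithIncompatibleCollation() scans shard key values for collation-affected (string) values and, optionally, hashed fields. A toy model of the same decision, with simplified stand-ins for the BSON and shard-key-pattern types:

#include <optional>
#include <string>
#include <vector>

struct ToyShardKeyField {
    std::string name;
    bool isString;  // stands in for CollationIndexKey::isCollatableType()
    bool isHashed;  // stands in for the hashed-pattern check
};

std::optional<std::string> firstIncompatibleField(const std::vector<ToyShardKeyField>& shardKey,
                                                  bool queryHasSimpleCollation,
                                                  bool permitHashedFields) {
    if (queryHasSimpleCollation)
        return std::nullopt;  // simple collation matches the sharding collation
    for (const auto& f : shardKey) {
        if (f.isString || (f.isHashed && !permitHashedFields))
            return f.name;  // analogous to returning the offending BSONElement
    }
    return std::nullopt;  // analogous to eoo(): single-shard targetable
}

// e.g. firstIncompatibleField({{"a", false, true}, {"b", true, false}}, false, false)
// returns "a", mirroring the fourth example in the shard_targeting_helpers.h comment.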