diff --git a/.github/workflows/bullmq-tests.yml b/.github/workflows/bullmq-tests.yml index 047af4e6ee72..351e792a68f7 100644 --- a/.github/workflows/bullmq-tests.yml +++ b/.github/workflows/bullmq-tests.yml @@ -54,12 +54,15 @@ jobs: - name: Start Dragonfly run: | + mkdir -p /tmp/df-logs ${GITHUB_WORKSPACE}/build/dragonfly \ --alsologtostderr \ + --log_dir=/tmp/df-logs \ --cluster_mode=emulated \ --lock_on_hashtags \ --dbfilename= \ - --port 6379 & + --port 6379 \ + >/tmp/df-logs/stdout.log 2>/tmp/df-logs/stderr.log & timeout 15s bash -c 'until redis-cli -p 6379 PING 2>/dev/null | grep -q PONG; do sleep 0.1; done' - name: Build BullMQ @@ -84,8 +87,8 @@ jobs: if: failure() uses: actions/upload-artifact@v7 with: - name: unit_logs - path: /tmp/dragonfly.* + name: dragonfly-logs + path: /tmp/df-logs/ - name: Send notification on failure if: failure() && github.ref == 'refs/heads/main' diff --git a/contrib/charts/dragonfly/go.mod b/contrib/charts/dragonfly/go.mod index 22806292eef6..c59e2410fa9c 100644 --- a/contrib/charts/dragonfly/go.mod +++ b/contrib/charts/dragonfly/go.mod @@ -1,13 +1,11 @@ module dragonfly -go 1.24.0 - -toolchain go1.24.7 +go 1.25.0 require github.com/gruntwork-io/terratest v0.51.0 require ( - filippo.io/edwards25519 v1.1.0 // indirect + filippo.io/edwards25519 v1.1.1 // indirect github.com/BurntSushi/toml v1.5.0 // indirect github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect @@ -83,7 +81,7 @@ require ( github.com/homeport/dyff v1.10.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.7.6 // indirect + github.com/jackc/pgx/v5 v5.9.2 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect @@ -93,7 +91,7 @@ require ( github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/go-ps v1.0.0 // indirect github.com/mitchellh/hashstructure v1.1.0 // indirect - github.com/moby/spdystream v0.5.0 // indirect + github.com/moby/spdystream v0.5.1 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect diff --git a/contrib/charts/dragonfly/go.sum b/contrib/charts/dragonfly/go.sum index f113dc8a9f0c..e14ff49cef56 100644 --- a/contrib/charts/dragonfly/go.sum +++ b/contrib/charts/dragonfly/go.sum @@ -1,5 +1,5 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw= +filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= @@ -169,8 +169,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 
h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= -github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw= +github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -198,8 +198,8 @@ github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= github.com/mitchellh/hashstructure v1.1.0 h1:P6P1hdjqAAknpY/M1CGipelZgp+4y9ja9kmUZPXP+H0= github.com/mitchellh/hashstructure v1.1.0/go.mod h1:xUDAozZz0Wmdiufv0uyhnHkUTN6/6d8ulp4AwfLKrmA= -github.com/moby/spdystream v0.5.0 h1:7r0J1Si3QO/kjRitvSLVVFUjxMEb/YLj6S9FF62JBCU= -github.com/moby/spdystream v0.5.0/go.mod h1:xBAYlnt/ay+11ShkdFKNAG7LsyK/tmNBVvVOwrfMgdI= +github.com/moby/spdystream v0.5.1 h1:9sNYeYZUcci9R6/w7KDaFWEWeV4LStVG78Mpyq/Zm/Y= +github.com/moby/spdystream v0.5.1/go.mod h1:xBAYlnt/ay+11ShkdFKNAG7LsyK/tmNBVvVOwrfMgdI= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/docs/pub-sub.md b/docs/pub-sub.md index ab300ce8989f..a78ca14f2b24 100644 --- a/docs/pub-sub.md +++ b/docs/pub-sub.md @@ -2,9 +2,9 @@ This document describes how Dragonfly implements the Publish-Subscribe (Pub/Sub) messaging paradigm within its shared-nothing, multi-threaded architecture. It covers the global -subscription registry, the Read-Copy-Update (RCU) mechanism used to prevent lock contention -on the publish path, the asynchronous message delivery pipeline, and the backpressure system -that protects the server from slow-subscriber OOM. +subscription registry backed by a `ShardedHashMap`, the per-shard two-lock RCU mechanism +used to minimize lock contention on the publish path, the asynchronous message delivery +pipeline, and the backpressure system that protects the server from slow-subscriber OOM. ## Overview @@ -13,19 +13,25 @@ unique challenge: subscriptions must be globally addressable across all threads, global lock on every `PUBLISH` would create a severe bottleneck. A single popular channel with thousands of subscribers could serialize all publish operations onto one shard thread. -Dragonfly solves this by using a **centralized `ChannelStore` updated via RCU -(Read-Copy-Update)**: +Dragonfly solves this with a **single global `ChannelStore`** backed by a +`ShardedHashMap` — a custom hash map split into 16 independent +shards, each protected by two fiber-aware locks: -- **Reads (`PUBLISH` / `SPUBLISH`)** are lock-free and use a thread-local pointer to the - most recent `ChannelStore` snapshot. 
-- **Writes (`SUBSCRIBE` / `UNSUBSCRIBE` / `PSUBSCRIBE` / `PUNSUBSCRIBE`)** are serialized
-  by a single mutex, performed by copying the necessary routing maps, applying the mutation,
-  and atomically swapping the global pointer.
+- **`write_mu_`** (exclusive) — serializes writers within a shard. Readers never acquire it.
+- **`read_mu_`** (shared/exclusive) — taken shared by readers; taken exclusively only for
+  structural map changes (inserting/erasing channel entries) and for safe deletion of old
+  `SubscribeMap` pointers (draining in-flight readers).
+
+Within each shard, subscriber updates use an **RCU-style pointer swap** via
+`UpdatablePointer`: the writer copies the old `SubscribeMap`, modifies the copy, and
+atomically stores the new pointer — all while holding only `write_mu_`, so readers on the
+same shard proceed concurrently. Structural changes (new channel, channel deletion) briefly
+acquire `read_mu_` exclusively to block readers.

 This design avoids contention on a single shard thread for heavy throughput on a single
-channel and seamlessly scales across multiple threads even with a small number of channels.
-Publish latency is lower than a shard-routed design because no inter-thread hop is required
-to look up subscribers — the caller reads its local copy directly.
+channel and scales across threads even with a small number of channels. Publish latency is
+low because no inter-thread hop is required to look up subscribers — the caller takes its
+shard's `read_mu_` in shared mode and reads the map directly.

 Dragonfly supports three flavors of Pub/Sub:

@@ -39,11 +45,11 @@ Dragonfly supports three flavors of Pub/Sub:

 | Type | Location | Role |
 |------|----------|------|
-| `ChannelStore` | `src/server/channel_store.h` | Centralized registry mapping channels/patterns to subscribers. Updated via RCU. |
-| `ChannelStoreUpdater` | `src/server/channel_store.h` | Orchestrates RCU mutations (add/remove) to the `ChannelStore`. |
+| `ChannelStore` | `src/server/channel_store.h` | Centralized registry mapping channels/patterns to subscribers. Single global instance (`extern ChannelStore* channel_store`). |
+| `ChannelStoreUpdater` | `src/server/channel_store.h` | Batches subscribe/unsubscribe operations by shard and applies them in one `Mutate` call per shard. |
 | `ChannelStore::Subscriber` | `src/server/channel_store.h` | Represents a subscribed client. Wraps `facade::ConnectionRef` plus a pattern string. |
-| `ChannelStore::ControlBlock` | `src/server/channel_store.h` | Holds the `most_recent` atomic pointer and `update_mu` mutex. Prevents overlapping structural updates. |
-| `ChannelStore::ChannelMap` | `src/server/channel_store.h` | `flat_hash_map` — maps channel/pattern names to subscriber lists. |
+| `ChannelStore::ChannelMap` | `src/server/channel_store.h` | `ShardedHashMap` — sharded map of channel/pattern names to subscriber lists. |
+| `ShardedHashMap` | `src/core/sharded_hash_map.h` | Generic thread-safe sharded hash map. 16 shards, each with `write_mu_` and `read_mu_` fiber-aware locks over an `absl::flat_hash_map`. |
 | `ChannelStore::SubscribeMap` | `src/server/channel_store.h` | `flat_hash_map` — maps subscriber contexts to their owning thread. |
 | `ChannelStore::UpdatablePointer` | `src/server/channel_store.h` | Atomic wrapper around `SubscribeMap*`. Supports lock-free reads (`acquire`) and RCU-style swaps (`release`). |
 | `ConnectionState::SubscribeInfo` | `src/server/conn_context.h` | Per-connection set of subscribed channels and patterns. Created lazily on first subscription.
| @@ -58,89 +64,86 @@ Dragonfly supports three flavors of Pub/Sub: Pub/Sub Data Flow -## Subscription Management (RCU) +## Subscription Management (Sharded RCU) ### Data Structure Layout -Each `ChannelStore` instance holds two `ChannelMap` pointers: +The single global `ChannelStore` holds two `ChannelMap` instances (each a +`ShardedHashMap`):
Data Structure Layout
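+To make the layout concrete, here is a minimal structural sketch (illustrative only:
+type and member names follow the table above, and the helio fiber primitives are an
+assumption, not the actual `sharded_hash_map.h` source):
+
+```cpp
+#include <cstddef>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/hash/hash.h"
+#include "util/fibers/synchronization.h"  // fiber-aware Mutex / SharedMutex (helio)
+
+// Simplified stand-ins for the real types.
+using SubscribeMap = absl::flat_hash_map<void* /*conn context*/, unsigned /*thread id*/>;
+
+struct UpdatablePointer {
+  SubscribeMap* ptr = nullptr;  // the real code swaps this with release/acquire semantics
+};
+
+struct Shard {
+  util::fb2::Mutex write_mu_;       // serializes writers within this shard
+  util::fb2::SharedMutex read_mu_;  // shared by readers; exclusive for map changes
+  absl::flat_hash_map<std::string, UpdatablePointer> map_;  // channel -> subscribers
+  std::vector<SubscribeMap*> freelist_;  // retired maps, deleted once readers drain
+};
+
+struct ShardedChannelMap {
+  static constexpr size_t kNumShards = 16;
+  Shard shards[kNumShards];
+
+  Shard& ShardOf(std::string_view channel) {
+    return shards[absl::HashOf(channel) % kNumShards];
+  }
+};
+```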
+Each of the 16 shards contains an `absl::flat_hash_map` guarded
+by two fiber-aware locks: `write_mu_` (serializes writers) and `read_mu_` (shared for
+readers, exclusive for structural changes).
+
 `UpdatablePointer` wraps a `std::atomic<SubscribeMap*>` with `memory_order_acquire` on read and
 `memory_order_release` on write. This ensures that when a thread reads the pointer, it also
 sees the fully constructed `SubscribeMap` that the writer published.

-### Two Levels of RCU
+### Per-Shard Two-Lock RCU

-The `ChannelStoreUpdater` implements two granularities of copy-on-write:
+The `ChannelStoreUpdater` groups pending subscribe/unsubscribe operations by shard index
+(via `Record()` → `ShardOf(channel)`) and processes each shard in a single `Mutate()` call.

-1. **ChannelMap-level copy** — triggered when a channel slot must be added (first subscriber)
-   or removed (last subscriber leaves). The entire `ChannelMap` is shallow-copied, the slot is
-   added/removed on the copy, a new `ChannelStore` is allocated pointing to the new map, and
-   the global `control_block.most_recent` is swapped.
+Within each shard's `Mutate()` callback, the updater handles two cases:

-2. **SubscribeMap-level RCU** — triggered when adding/removing a subscriber to an existing
-   channel (the map slot already exists). Only the `SubscribeMap` for that channel is copied,
-   the mutation is applied, and the `UpdatablePointer` is atomically swapped. No new
-   `ChannelStore` or `ChannelMap` is needed.
+**Case 1: Existing channel (add/remove subscriber, channel slot stays)**
+1. Acquire `write_mu_` exclusively (done by `Mutate`) — serializes writers on this shard.
+2. Copy the `SubscribeMap`, apply the mutation, atomically swap via `UpdatablePointer::Set`.
+   Readers are NOT blocked — they may still read the old pointer.
+3. Push the old `SubscribeMap*` onto a per-shard `freelist_`.
+4. Release `write_mu_` (Mutate returns).
+5. Acquire `read_mu_` exclusively via `WithReadExclusiveLock` — this drains any reader that
+   loaded the old `SubscribeMap` pointer, then deletes all entries in the freelist.

-This two-level scheme is implemented in `ChannelStoreUpdater::GetTargetMap()`:
+**Case 2: New channel (first subscriber) or channel deletion (last subscriber leaves)**
+1. Inside the `Mutate` callback, call `LockReaders()` to acquire `read_mu_` exclusively.
+   This blocks all readers in the shard while inserting or erasing the key.
+2. For add: emplace a new `UpdatablePointer{new SubscribeMap{{cntx_, thread_id_}}}`.
+3. For remove: delete the `SubscribeMap`, erase the map entry.
+4. Writers on other shards are unaffected.

-```cpp
-pair<ChannelStore::ChannelMap*, bool> ChannelStoreUpdater::GetTargetMap(ChannelStore* store) {
-  auto* target = pattern_ ? store->patterns_ : store->channels_;
-
-  for (auto key : ops_) {
-    auto it = target->find(key);
-    DCHECK(it != target->end() || to_add_);
-    // We need to make a copy, if we are going to add or delete a new map slot.
-    if ((to_add_ && it == target->end()) || (!to_add_ && it->second->size() == 1))
-      return {new ChannelStore::ChannelMap{*target}, true};
-  }
-
-  return {target, false};
-}
-```
-
-### Apply() Flow
-
- Apply Flow -
- -Step 8 uses `AwaitBrief` (non-preempting dispatch) to update each thread's local pointer. -The `seq_cst` load in the callback ensures the thread reads the latest pointer value _and_ -the memory published behind it. - -### Modify() — Per-Key Mutation - -For each key in the pending operations: +### Apply() — Batch Per-Shard Mutation ``` -Modify(target, key) - it = target->find(key) - - Case 1: Adding, key not in map (new channel) - → target->emplace(key, new SubscribeMap{{cntx_, thread_id_}}) - - Case 2: Removing, last subscriber (channel disappears) - → freelist_.push_back(it->second.Get()) // defer deletion - → target->erase(it) - - Case 3: Existing channel, add/remove subscriber (RCU on SubscribeMap) - → replacement = new SubscribeMap{*it->second} - → if to_add_: replacement->emplace(cntx_, thread_id_) - else: replacement->erase(cntx_) - → freelist_.push_back(it->second.Get()) // old map, defer deletion - → it->second.Set(replacement) // atomic release-store +ChannelStoreUpdater::Apply() + for each shard sid in 0..15: + if ops_[sid] empty: continue + + map.Mutate(ShardId{sid}, [&](const auto& m, auto LockReaders) { + // Phase 1: RCU updates for existing channels (only write_mu_ held) + for each key in ops_[sid]: + it = m.find(key) + if to_add_ and it exists: + → copy SubscribeMap, add {cntx_, thread_id_}, swap pointer + → push old pointer to freelist_[sid] + if !to_add_ and it exists and size > 1: + → copy SubscribeMap, erase cntx_, swap pointer + → push old pointer to freelist_[sid] + if needs structural change: + → mark needs_map_change[i] = true + + // Phase 2: structural changes (acquire read_mu_ exclusively) + if has_map_change: + auto locked = LockReaders() + for each key needing map change: + if to_add_: locked.map.emplace(key, new SubscribeMap{...}) + if !to_add_: delete ptr, locked.map.erase(it) + }) + + // Phase 3: drain readers, delete old SubscribeMaps + if freelist_[sid] not empty: + map.WithReadExclusiveLock(ShardId{sid}, [&] { + for each sm in freelist_[sid]: delete sm + }) ``` -Old `SubscribeMap` pointers are not immediately deleted because concurrent `PUBLISH` -operations on other threads may still be reading them. They are placed in a `freelist_` and -deleted only after `AwaitBrief` completes — at which point every thread has acknowledged the -new state and no reader can hold a reference to the old maps. +This batching minimizes lock acquisitions: all keys mapping to the same shard are processed +under a single `write_mu_` acquisition, and old `SubscribeMap` pointers are cleaned up in +one `read_mu_` exclusive pass. ### Connection-Level Subscription State @@ -176,8 +179,8 @@ When a client issues `PUBLISH channel message` (or `SPUBLISH`): ``` SendMessages(channel, messages, sharded) 1. subscribers = FetchSubscribers(channel) - → exact match: channels_->find(channel) - → pattern match: for each (pat, subs) in *patterns_: + → exact match: channels_.FindIf(channel, ...) + → pattern match: patterns_.ForEachShared(...) if GlobMatcher{pat}.Matches(channel): Fill(subs, pat, &result) → sort result by thread_id (enables efficient per-thread dispatch) @@ -221,10 +224,12 @@ string allocations. ``` FetchSubscribers(channel) - 1. Exact match: channels_->find(channel) + 1. Exact match: channels_.FindIf(channel, callback) + → acquires read_mu_ shared on the channel's shard → if found, Fill() creates Subscriber entries from the SubscribeMap - 2. Pattern match: iterate ALL patterns + 2. 
Pattern match: patterns_.ForEachShared(callback)
+     → iterates ALL patterns across all 16 shards (each shard locked independently)
      → for each (pat, subs): GlobMatcher{pat, case_sensitive=true}.Matches(channel)
      → matching subscribers are added with their pattern string

  3. Sort by thread_id
      → enables O(log n) per-thread lookup during dispatch
```

+**Note**: `FetchSubscribers` is not atomic — each shard is locked independently via shared
+`read_mu_`, so the result may not reflect a fully consistent state. This trade-off is
+acceptable for pub/sub use cases.
+
 The `Fill` helper reads the `SubscribeMap` (via `UpdatablePointer::Get()` — acquire load)
 and creates `Subscriber` structs that hold a `ConnectionRef` (weak reference) obtained via
 `conn->Borrow()`.

@@ -411,20 +420,31 @@ is called:

 ```
 UnsubscribeAfterClusterSlotMigration(deleted_slots)
-  for each (channel, _) in *channels_:
+  // Phase 1: collect matching channels and their subscribers
+  channels_.ForEachShared([&](channel, up) {
     if deleted_slots.Contains(KeySlot(channel)):
-      csu.Record(channel)
-  csu.ApplyAndUnsubscribe()
+      Fill(*up, "", &subs)
+      owned_subs[channel] = sorted subs
+  })
+
+  if owned_subs empty: return
+
+  // Phase 2: remove all subscribers from matched channels
+  for each (channel, _) in owned_subs:
+    RemoveAllSubscribers(false, channel)
+
+  // Phase 3: notify connections on their owning threads
+  pool->AwaitFiberOnAll([&](idx, _) {
+    UnsubscribeConnectionsFromDeletedSlots(owned_subs, idx)
+  })
 ```

-`ApplyAndUnsubscribe()` differs from `Apply()`:
-1. It deep-copies the `ChannelMap` and removes the migrated channels.
-2. It calls `FetchSubscribers` for each removed channel _before_ updating the store
-   (since `FetchSubscribers` reads from the current active store).
-3. It uses `AwaitFiberOnAll` (fiber-based, may preempt) instead of `AwaitBrief` to dispatch
-   both the store update and unsubscription messages.
-4. On each thread, `UnsubscribeConnectionsFromDeletedSlots` sends `PubMessage`s with
-   `force_unsubscribe=true`, which triggers `sunsubscribe` push messages to affected clients.
+`RemoveAllSubscribers` uses `Mutate` to acquire `write_mu_`, then `LockReaders()` to block
+readers while deleting the `SubscribeMap` and erasing the channel entry.
+
+`AwaitFiberOnAll` (fiber-based, may preempt) dispatches to each thread, where
+`UnsubscribeConnectionsFromDeletedSlots` sends `PubMessage`s with `force_unsubscribe=true`
+via `BuildSender`, triggering `sunsubscribe` push messages to affected clients.

 ## Keyspace Event Notifications

@@ -439,8 +459,7 @@ When enabled:
 3. At the end of `DeleteExpiredStep`, batched events are published:

```cpp
-ChannelStore* store = ServerState::tlocal()->channel_store();
-store->SendMessages(
+channel_store->SendMessages(
     absl::StrCat("__keyevent@", cntx.db_index, "__:expired"),
     events, false);
 events.clear();

@@ -476,11 +495,11 @@ Notable flags:

 | Purpose | File Path |
 |---------|-----------|
 | ChannelStore & ChannelStoreUpdater | `src/server/channel_store.h`, `src/server/channel_store.cc` |
+| ShardedHashMap (underlying data structure) | `src/core/sharded_hash_map.h` |
 | Pub/Sub command handlers | `src/server/main_service.cc` (`Publish`, `Subscribe`, `Unsubscribe`, `PSubscribe`, `PUnsubscribe`, `Pubsub`) |
 | Connection-level subscription state | `src/server/conn_context.h`, `src/server/conn_context.cc` (`ChangeSubscriptions`, `UnsubscribeAll`, `PUnsubscribeAll`) |
 | PubMessage, AsyncFiber, backpressure | `src/facade/dragonfly_connection.h`, `src/facade/dragonfly_connection.cc` |
 | ConnectionRef (weak subscriber refs) | `src/facade/connection_ref.h` |
-| ServerState channel_store_ pointer | `src/server/server_state.h`, `src/server/server_state.cc` |
 | Keyspace event integration | `src/server/db_slice.cc` (`DeleteExpiredStep`) |
-| Cluster slot migration unsub | `src/server/channel_store.cc` (`UnsubscribeAfterClusterSlotMigration`, `ApplyAndUnsubscribe`) |
+| Cluster slot migration unsub | `src/server/channel_store.cc` (`UnsubscribeAfterClusterSlotMigration`, `RemoveAllSubscribers`) |
 | GlobMatcher for pattern matching | `src/core/glob_matcher.h` |
diff --git a/docs/transaction.md b/docs/transaction.md
index 8eff4366c619..2317648a3247 100644
--- a/docs/transaction.md
+++ b/docs/transaction.md
@@ -8,7 +8,7 @@ This document describes how Dragonfly transactions provide atomicity and seriali

 Serializability is an isolation level for database transactions. Serializability describes multiple transactions, where a transaction is usually composed of multiple operations on multiple objects.

-Database can executed transactions in parallel (and the operations in parallel). Serializability guarantees the result is the same with, as if the transactions were executed one by one. i.e. to behave like executed in a serial order.
+Databases can execute transactions in parallel (and their operations in parallel). Serializability guarantees that the result is the same as if the transactions were executed one by one, i.e. that they behave as if executed in some serial order.

 Serializability doesn’t guarantee the resulting serial order respects recency. I.e. the serial order can be different from the order in which transactions were actually executed. E.g. Tx1 begins earlier than Tx2, but the result behaves as if Tx2 executed before Tx1. That is also to say, to satisfy the same Serializability, there can be more than one possible execution schedulings.

@@ -124,7 +124,7 @@ There are three modes called "multi modes" in which a multi transaction can be e

 __1. Global mode__

-The transaction is equivalent to a global transaction with multiple hops. It is scheduled globally and the commands are executed as a series of consequitive hops. This mode is required for global commands (like MOVE) and for accessing undeclared keys in Lua scripts. Otherwise, it should be avoided, because it prevents Dragonfly from running concurrently and thus greatly decreases throughput.
+The transaction is equivalent to a global transaction with multiple hops. It is scheduled globally and the commands are executed as a series of consecutive hops. This mode is required for global commands (like MOVE) and for accessing undeclared keys in Lua scripts. Otherwise, it should be avoided, because it prevents Dragonfly from running concurrently and thus greatly decreases throughput.

 __2. Lock ahead mode__

@@ -144,11 +144,11 @@ Luckily we can make one important observation about command sequences. Given a s
 * each command needs to preserve its order only relative to other commands accessing the same shard
 * commands accessing different shards can run in parallel

-The basic idea behind command squashing is identifying consecutive series of single-shard commands and separating them by shards, while maintaing their relative order withing each shard. Once the commands are separated, we can execute a single hop on all relevant shards. Within each shard the hop callback will execute one by one only those commands, that assigned to its respective shard. Because all commands are already placed on their relevant threads, no further hops are required and all command callbacks are executed inline.
+The basic idea behind command squashing is identifying consecutive series of single-shard commands and separating them by shards, while maintaining their relative order within each shard. Once the commands are separated, we can execute a single hop on all relevant shards. Within each shard the hop callback will execute, one by one, only those commands that are assigned to its respective shard. Because all commands are already placed on their relevant threads, no further hops are required and all command callbacks are executed inline.

 Reviewing our initial problems, command squashing:
 * Allows executing many commands with only one hop
-* Allows executing commands in pararllel
+* Allows executing commands in parallel

 ## Optimizations
 Out of order transactions - TBD

@@ -192,7 +192,7 @@ For the single-threaded Redis the order is determined by following the natural e

 However with blocking scenario for BLPOP, we do not have a built-in mechanism to determine which key was filled earlier - since, as stated, the concept of total order does not exist for multiple shards.
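+To make the problem concrete, here is a small illustration (hypothetical helper code,
+not the engine implementation) of why shard-local orders cannot be compared:
+
+```cpp
+#include <iostream>
+#include <string>
+#include <vector>
+
+// Illustrative only: each shard orders just its own effects; there is no shared clock.
+struct ShardLog {
+  std::vector<std::string> effects;  // index in this vector = shard-local order
+
+  size_t Record(const std::string& effect) {
+    effects.push_back(effect);
+    return effects.size() - 1;  // meaningful only within this shard
+  }
+};
+
+int main() {
+  ShardLog shard_a, shard_b;  // no shared counter between them
+  size_t pos_a = shard_a.Record("LPUSH keyA v");
+  size_t pos_b = shard_b.Record("LPUSH keyB v");
+  // Comparing pos_a with pos_b says nothing about which push happened first in real
+  // time, which is exactly the information a blocking BLPOP would need.
+  std::cout << pos_a << " vs " << pos_b << "\n";
+  return 0;
+}
+```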
-### Interesing examples to consider: +### Interesting examples to consider: **Ex1:** ``` diff --git a/go.work b/go.work index 7d37c78ac8d1..1f70aa0629a4 100644 --- a/go.work +++ b/go.work @@ -1,6 +1,4 @@ -go 1.24.0 - -toolchain go1.24.7 +go 1.25.0 use ( ./contrib/charts/dragonfly diff --git a/go.work.sum b/go.work.sum index 6e1391921a15..670dd5687d8f 100644 --- a/go.work.sum +++ b/go.work.sum @@ -41,10 +41,10 @@ github.com/bradleyfalzon/ghinstallation v1.1.1/go.mod h1:vyCmHTciHx/uuyN82Zc3rXN github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/containerd/stargz-snapshotter/estargz v0.14.3/go.mod h1:KY//uOCIkSuNAHhJogcZtrNHdKrA99/FCCRjE3HD36o= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dimchansky/utfbom v1.1.1/go.mod h1:SxdoEBH5qIqFocHMyGOXVAybYJdr71b1Q/j0mACtrfE= @@ -87,6 +87,7 @@ github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHW github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jstemmer/go-junit-report v1.0.0/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/klauspost/compress v1.16.5/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= diff --git a/helio b/helio index e8a8a0a67d90..2ed01846f979 160000 --- a/helio +++ b/helio @@ -1 +1 @@ -Subproject commit e8a8a0a67d90814dfaee57bc8371dcee122376e8 +Subproject commit 2ed01846f97968b9e684c22e9e1bef840893f72b diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 111d5a023d77..2e00bb377ee4 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -61,6 +61,7 @@ helio_cxx_test(zstd_test dfly_core TRDP::zstd LABELS DFLY) helio_cxx_test(dict_builder_test dfly_core LABELS DFLY) helio_cxx_test(top_keys_test dfly_core LABELS DFLY) helio_cxx_test(topk_test dfly_core LABELS DFLY) +helio_cxx_test(sharded_hash_map_test dfly_core LABELS DFLY) helio_cxx_test(page_usage_stats_test dfly_core LABELS DFLY) helio_cxx_test(cms_test dfly_core LABELS DFLY) helio_cxx_test(memory_test TRDP::mimalloc2 LABELS DFLY) diff --git a/src/core/bloom.cc b/src/core/bloom.cc index 0942439949cd..bcb2b248e29b 100644 --- a/src/core/bloom.cc +++ b/src/core/bloom.cc @@ -196,11 +196,11 @@ SBF& SBF::operator=(SBF&& src) noexcept { return *this; } -void SBF::AddFilter(const std::string& blob, unsigned hash_cnt) { +uint8_t* SBF::AllocateFilter(size_t alloc_size, unsigned hash_cnt) { PMR_NS::memory_resource* mr = 
filters_.get_allocator().resource();
-  uint8_t* ptr = (uint8_t*)mr->allocate(blob.size());
-  memcpy(ptr, blob.data(), blob.size());
-  filters_.emplace_back().Init(ptr, blob.size(), hash_cnt);
+  const auto ptr = static_cast<uint8_t*>(mr->allocate(alloc_size));
+  filters_.emplace_back().Init(ptr, alloc_size, hash_cnt);
+  return ptr;
 }

 bool SBF::Add(std::string_view str) {
@@ -387,51 +387,79 @@ void SBFDumpIterator::ResolveCursorToPos() {

 nonstd::expected<SBF*, SBFLoadResult> LoadSBFHeader(std::string_view header_data, PMR_NS::memory_resource* mr) {
+  using enum SBFLoadResult;
+  using nonstd::make_unexpected;
+
   if (header_data.size() < kDumpHeaderSize)
-    return nonstd::make_unexpected(SBFLoadResult::kTruncatedInput);
+    return make_unexpected(kTruncatedInput);
+
+  if (header_data.size() > kDumpHeaderSize)
+    return make_unexpected(kBadInput);

   const char* ptr = header_data.data();
   if (const uint32_t version = absl::little_endian::Load32(ptr); version != kSbfDumpVersion)
-    return nonstd::make_unexpected(SBFLoadResult::kBadVersion);
+    return make_unexpected(kBadVersion);

   const double grow_factor = std::bit_cast<double>(absl::little_endian::Load64(ptr + 4));
+  if (!std::isfinite(grow_factor) || grow_factor < 1.0)
+    return make_unexpected(kBadInput);
+
   // Initialize everything to 0, later filters will overwrite these values
   return CompactObj::AllocateMR<SBF>(grow_factor, 0.0, 0UL, 0UL, 0UL, mr);
 }

 SBFLoadResult AddNewFilterToSBF(std::string_view data, SBF* sbf) {
+  using enum SBFLoadResult;
+
   if (data.size() < kDumpFilterMetaSize)
-    return SBFLoadResult::kTruncatedInput;
+    return kTruncatedInput;

   auto [hash_cnt, data_length, state] = SBFFilterMeta::Parse(data.data());
+  if (hash_cnt == 0)
+    return kBadInput;
+
+  if (hash_cnt > std::numeric_limits::max())
+    return kBadInput;
+
+  if (data_length == 0 || !absl::has_single_bit(data_length))
+    return kBadInput;
+
+  // probability should be 0 to 1 (probably less than 1)
+  if (!std::isfinite(state.fp_prob) || state.fp_prob <= 0.0 || state.fp_prob >= 1.0)
+    return kBadInput;
+
+  if (state.max_capacity == 0 || state.current_size >= state.max_capacity)
+    return kBadInput;
+
   const size_t payload = data.size() - kDumpFilterMetaSize;
   if (payload > data_length)
-    return SBFLoadResult::kOutOfRange;
+    return kOutOfRange;

   sbf->ApplyStateUpdate(state);
   const uint32_t new_index = sbf->num_filters();
-  // TODO validate variables against bloom invariants (power of two etc)
-  sbf->AddFilter(std::string(data_length, '\0'), hash_cnt);
+  auto* ptr = sbf->AllocateFilter(data_length, hash_cnt);
+  memset(ptr, 0, data_length);

   if (payload > 0)
     memcpy(sbf->filter_data(new_index), data.data() + kDumpFilterMetaSize, payload);

-  return SBFLoadResult::kOk;
+  return kOk;
 }

 SBFLoadResult LoadSBFChunk(int64_t cursor, std::string_view data, SBF* sbf) {
   DCHECK_NE(sbf, nullptr) << "Input ptr must be valid SBF";

-  // TODO on implementing LOADCHUNK there should be closer validation of the data fed into the SBF.
-  // This current implementation is mostly a test helper and proof that the SCANDUMP algorithm is
-  // actually loadable.
-
-  size_t global_offset = cursor - static_cast<int64_t>(data.size()) - 1;
+  const int64_t write_pos = cursor - static_cast<int64_t>(data.size());
+  if (write_pos < 1)
+    return SBFLoadResult::kOutOfRange;
+  size_t global_offset = write_pos - 1;

   for (uint32_t i = 0; i < sbf->num_filters(); ++i) {
     const size_t filter_span = kDumpFilterMetaSize + sbf->data(i).size();
     if (global_offset < filter_span) {
+      // we should never have a write position inside the header. The header is always fully
+      // written.
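+      // Worked example (hypothetical sizes): say filter 0 declared data_length = 128 but
+      // its first chunk carried only 64 payload bytes, occupying stream positions 1..108
+      // (44B meta + 64B data); its span here is 44 + 128 = 172. A 64-byte continuation
+      // then arrives with cursor = 173: write_pos = 173 - 64 = 109, global_offset = 108,
+      // which is < 172 and >= 44, so the bytes land at data offset 108 - 44 = 64.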
if (global_offset < kDumpFilterMetaSize) return SBFLoadResult::kOutOfRange; @@ -448,7 +476,24 @@ SBFLoadResult LoadSBFChunk(int64_t cursor, std::string_view data, SBF* sbf) { if (global_offset != 0) return SBFLoadResult::kOutOfRange; + // global offset is 0, ie ended exactly at the end of the filter. data goes into a new filter. return AddNewFilterToSBF(data, sbf); } +const char* ToString(SBFLoadResult res) { + switch (res) { + case SBFLoadResult::kOk: + return "ok"; + case SBFLoadResult::kBadInput: + return "bad_input"; + case SBFLoadResult::kOutOfRange: + return "out_of_range"; + case SBFLoadResult::kTruncatedInput: + return "truncated_input"; + case SBFLoadResult::kBadVersion: + return "bad_version"; + } + return "unknown"; +} + } // namespace dfly diff --git a/src/core/bloom.h b/src/core/bloom.h index e3bb9598edaa..dcb2f1102245 100644 --- a/src/core/bloom.h +++ b/src/core/bloom.h @@ -22,6 +22,8 @@ enum class SBFLoadResult : uint8_t { kOutOfRange, }; +const char* ToString(SBFLoadResult res); + /// Bloom filter based on the design of https://github.com/jvirkki/libbloom class Bloom { public: @@ -102,14 +104,14 @@ class SBF { SBF(const SBF&) = delete; // C'tor used for loading persisted filters into SBF. - // Should be followed by AddFilter. + // Should be followed by AllocateFilter. SBF(double grow_factor, double fp_prob, size_t max_capacity, size_t prev_size, size_t current_size, PMR_NS::memory_resource* mr); ~SBF(); SBF& operator=(SBF&& src) noexcept; - void AddFilter(const std::string& blob, unsigned hash_cnt); + uint8_t* AllocateFilter(size_t alloc_size, unsigned hash_cnt); bool Add(std::string_view str); bool Exists(std::string_view str) const; @@ -190,6 +192,34 @@ struct SBFChunk { // maximum of 16MiB in size. The first chunk sent back contains only the SBF metadata. Following // chunks contain filter data and a state of the SBF. The loader uses per filter data to update the // SBF as it encounters new filter items. + +/* +SCANDUMP wire output format (all fields little-endian) + + cursor=1 returns the SBF header (12 bytes): + +-------------------+--------------------+ + | version (4B) | grow_factor (8B) | + +-------------------+--------------------+ + + cursor>1 chunks carry filter data. Each filter begins with + 44 bytes of metadata, followed by the raw filter bytes. + A single filter may span multiple chunks. + + First chunk of a filter: + +-----------------+----------------+------------+---------------------+ + | hash_cnt 4B | data_length 8B | fp_prob 8B | max_capacity 8B | + +-----------------+----------------+------------+---------------------+ + | current_size 8B | prev_size 8B | filter bytes (up to 16MiB - 44B) | + +-----------------+----------------+------------ ... -----------------+ + + Continuation chunks (same filter, if >16MiB): + +------------------------ ... -------------------------+ + | filter bytes (up to 16MiB) | + +------------------------ ... -------------------------+ + + cursor=0 signals end of iteration (empty data). 
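+
+  Illustrative sequence for an SBF with one 128-byte filter (hypothetical sizes):
+    (cursor=1,   12B header)
+    (cursor=173, 172B chunk: 44B filter meta + 128B filter bytes)
+    (cursor=0,   empty chunk, iteration done)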
+*/
+
 class SBFDumpIterator {
  public:
  static constexpr uint64_t kMaxChunkSize = 16 * 1024 * 1024;
@@ -203,10 +233,6 @@ class SBFDumpIterator {
   SBFChunk Next();

  private:
-  // Sends the SBF wide header (little endian):
-  // +-------------------+-------------------+
-  // | version (4 bytes) | grow_factor (8B)  |
-  // +-------------------+-------------------+
   std::string SerializeHeader() const;

   // Converts a cursor to the specific filter and the offset inside it
diff --git a/src/core/dash.h b/src/core/dash.h
index 7545544f17ee..a2710122d2b3 100644
--- a/src/core/dash.h
+++ b/src/core/dash.h
@@ -348,7 +348,7 @@ class DashTable : public detail::DashTableBase {
   // Unlike Traverse, TraverseBuckets calls cb once on bucket iterator and not on each entry in
   // bucket. TraverseBuckets is stable during table mutations. It guarantees traversing all buckets
   // that existed at the beginning of traversal.
-  template <typename Cb> Cursor TraverseBuckets(Cursor curs, Cb&& cb);
+  template <typename Cb> Cursor TraverseBuckets(Cursor curs, Cb&& cb, bool visit_empty = false);

   // Traverses over a single bucket in table and calls cb(iterator). The traverse order will be
   // segment by segment over physical backets.
@@ -460,15 +460,16 @@ class DashTable<_Key, _Value, Policy>::Iterator {
   uint32_t seg_id_;
   detail::PhysicalBid bucket_id_;
   uint8_t slot_id_;
+  bool done_;

   friend class DashTable;

   Iterator(Owner* me, uint32_t seg_id, detail::PhysicalBid bid, uint8_t sid)
-      : owner_(me), seg_id_(seg_id), bucket_id_(bid), slot_id_(sid) {
+      : owner_(me), seg_id_(seg_id), bucket_id_(bid), slot_id_(sid), done_(false) {
   }

   Iterator(Owner* me, uint32_t seg_id, detail::PhysicalBid bid)
-      : owner_(me), seg_id_(seg_id), bucket_id_(bid), slot_id_(0) {
+      : owner_(me), seg_id_(seg_id), bucket_id_(bid), slot_id_(0), done_(false) {
     Seek2Occupied();
   }

@@ -486,7 +487,8 @@ class DashTable<_Key, _Value, Policy>::Iterator {
       : owner_(other.owner_),
         seg_id_(other.seg_id_),
         bucket_id_(other.bucket_id_),
-        slot_id_(other.slot_id_) {
+        slot_id_(other.slot_id_),
+        done_(other.done_) {
   }

   // Copy constructor from iterator to bucket_iterator and vice versa.
@@ -495,14 +497,15 @@ class DashTable<_Key, _Value, Policy>::Iterator {
       : owner_(other.owner_),
         seg_id_(other.seg_id_),
         bucket_id_(other.bucket_id_),
-        slot_id_(IsSingleBucket ? 0 : other.slot_id_) {
+        slot_id_(IsSingleBucket ? 0 : other.slot_id_),
+        done_(other.done_) {
     // if this - is a bucket_iterator - we reset slot_id to the first occupied space.
     if constexpr (IsSingleBucket) {
       Seek2Occupied();
     }
   }

-  Iterator() : owner_(nullptr), seg_id_(0), bucket_id_(0), slot_id_(0) {
+  Iterator() : owner_(nullptr), seg_id_(0), bucket_id_(0), slot_id_(0), done_(true) {
   }

   Iterator(const Iterator& other) = default;

@@ -539,7 +542,7 @@ class DashTable<_Key, _Value, Policy>::Iterator {
   // Make it self-contained. Does not need container::end().
bool is_done() const {
-    return owner_ == nullptr;
+    return done_;
   }

   bool IsOccupied() const {
@@ -564,10 +567,11 @@ class DashTable<_Key, _Value, Policy>::Iterator {
   }

   friend bool operator==(const Iterator& lhs, const Iterator& rhs) {
-    if (lhs.owner_ == nullptr && rhs.owner_ == nullptr)
+    if (lhs.done_ && rhs.done_)
       return true;
     return lhs.owner_ == rhs.owner_ && lhs.seg_id_ == rhs.seg_id_ &&
-           lhs.bucket_id_ == rhs.bucket_id_ && lhs.slot_id_ == rhs.slot_id_;
+           lhs.bucket_id_ == rhs.bucket_id_ && lhs.slot_id_ == rhs.slot_id_ &&
+           lhs.done_ == rhs.done_;
   }

   friend bool operator!=(const Iterator& lhs, const Iterator& rhs) {

@@ -649,7 +653,7 @@ struct DashTable<_Key, _Value, Policy>::BucketSet {

 template <typename _Key, typename _Value, typename Policy>
 template <bool IsConst, bool IsSingleBucket>
 void DashTable<_Key, _Value, Policy>::Iterator<IsConst, IsSingleBucket>::Seek2Occupied() {
-  if (owner_ == nullptr)
+  if (done_)
     return;
   assert(seg_id_ < owner_->segment_.size());

@@ -673,7 +677,7 @@ void DashTable<_Key, _Value, Policy>::Iterator::Seek2Oc
       bucket_id_ = slot_id_ = 0;
     }
   }
-  owner_ = nullptr;
+  done_ = true;
 }

 template <typename _Key, typename _Value, typename Policy>

@@ -1164,7 +1168,8 @@ auto DashTable<_Key, _Value, Policy>::AdvanceCursorBucketOrder(Cursor cursor) ->

 template <typename _Key, typename _Value, typename Policy>
 template <typename Cb>
-auto DashTable<_Key, _Value, Policy>::TraverseBuckets(Cursor cursor, Cb&& cb) -> Cursor {
+auto DashTable<_Key, _Value, Policy>::TraverseBuckets(Cursor cursor, Cb&& cb, bool visit_empty)
+    -> Cursor {
   if (SegmentType::OutOfRange(cursor.bucket_id()))  // sanity.
     return Cursor::end();

@@ -1178,7 +1183,7 @@ auto DashTable<_Key, _Value, Policy>::TraverseBuckets(Cursor cursor, Cb&& cb) ->
       assert(s);
       if (bid < s->num_buckets()) {
         const auto& bucket = s->GetBucket(bid);
-        if (bucket.GetBusy()) {  // Invoke callback only if bucket has elements.
+        if (visit_empty || bucket.GetBusy()) {
           cb(BucketIt(sid, bid));
           invoked = true;
         }
diff --git a/src/core/dragonfly_core.cc b/src/core/dragonfly_core.cc
index 1b046ddc3e86..43b3937227f0 100644
--- a/src/core/dragonfly_core.cc
+++ b/src/core/dragonfly_core.cc
@@ -4,11 +4,17 @@

 #include
+#include <ostream>
+
 #include "base/logging.h"
 #include "core/intent_lock.h"

 namespace dfly {

+std::ostream& operator<<(std::ostream& o, const IntentLock& lock) {
+  return o << "{SHARED: " << lock.cnt_[0] << ", EXCLUSIVE: " << lock.cnt_[1] << "}";
+}
+
 const char* IntentLock::ModeName(Mode m) {
   switch (m) {
     case IntentLock::SHARED:
diff --git a/src/core/intent_lock.h b/src/core/intent_lock.h
index 6a565e6bde1d..77e5a3643230 100644
--- a/src/core/intent_lock.h
+++ b/src/core/intent_lock.h
@@ -1,11 +1,11 @@
 // Copyright 2022, DragonflyDB authors. All rights reserved.
 // See LICENSE for licensing terms.
//
+#pragma once
+
 #include
-#include
-
-#pragma once
+#include

 namespace dfly {

@@ -60,9 +60,7 @@ class IntentLock {

   void VerifyDebug();

-  friend std::ostream& operator<<(std::ostream& o, const IntentLock& lock) {
-    return o << "{SHARED: " << lock.cnt_[0] << ", EXCLUSIVE: " << lock.cnt_[1] << "}";
-  }
+  friend std::ostream& operator<<(std::ostream& o, const IntentLock& lock);

  private:
   unsigned cnt_[2] = {0, 0};
diff --git a/src/core/search/ast_expr.h b/src/core/search/ast_expr.h
index da239ba839bb..6bf214dfef8b 100644
--- a/src/core/search/ast_expr.h
+++ b/src/core/search/ast_expr.h
@@ -5,9 +5,8 @@
 #pragma once

 #include
-#include
+#include
 #include
-#include
 #include
 #include
diff --git a/src/core/search/hnsw_alg.h b/src/core/search/hnsw_alg.h
index 4327ebb07417..a9ae39377617 100644
--- a/src/core/search/hnsw_alg.h
+++ b/src/core/search/hnsw_alg.h
@@ -78,6 +78,10 @@ template class HierarchicalNSW : public hnswlib::AlgorithmInte

   bool copy_vector_ = true;

+  // Cached in-memory footprint (bytes) — maintained by the constructor and resizeIndex;
+  // read lock-free by metrics. See memorySize() for what is / isn't counted.
+  std::atomic<size_t> memory_size_{0};
+
   bool allow_replace_deleted_ = false;  // flag to replace deleted elements (marked as deleted) during insertions

@@ -155,6 +159,7 @@ template class HierarchicalNSW : public hnswlib::AlgorithmInte
     size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
     mult_ = 1 / log(1.0 * M_);
     revSize_ = 1.0 / mult_;
+    updateMemorySize();
   }

   ~HierarchicalNSW() {

@@ -689,6 +694,37 @@ template class HierarchicalNSW : public hnswlib::AlgorithmInte
     linkLists_ = linkLists_new;
     max_elements_ = new_max_elements;
+    updateMemorySize();
+  }
+
+  // Approximate in-memory footprint in bytes. Lock-free: reads the cached
+  // capacity-based total plus rough estimates of the two dynamic containers
+  // (label_lookup_, deleted_elements) from their already-atomic counters.
+  // Per-element upper-layer link lists remain uncounted (< 5% of the total).
+  size_t memorySize() const {
+    // Rough std::unordered_map<labeltype, tableint> entry: node (key+value+hash+
+    // next-ptr) plus amortized bucket-slot overhead.
+    constexpr size_t kLabelLookupEntryBytes = sizeof(labeltype) + sizeof(tableint) + 32;
+    // std::unordered_set<tableint> entry, populated only in allow_replace_deleted mode.
+    constexpr size_t kDeletedEntryBytes = sizeof(tableint) + 24;
+    size_t total = memory_size_.load(std::memory_order_relaxed);
+    total += cur_element_count.load(std::memory_order_relaxed) * kLabelLookupEntryBytes;
+    total += num_deleted_.load(std::memory_order_relaxed) * kDeletedEntryBytes;
+    return total;
+  }
+
+  // Recomputes memory_size_ from the current allocation-defining fields.
+  // Must be called whenever max_elements_ changes.
+  void updateMemorySize() {
+    // Per-element costs: level-0 block, linkLists_ pointer slot, element_levels_ entry,
+    // link_list_locks_ mutex; plus the copied-vector block when enabled.
+    size_t per_element = size_data_per_element_ + sizeof(char*) + sizeof(int) + sizeof(std::mutex);
+    if (copy_vector_) {
+      per_element += data_size_;
+    }
+    // label_op_locks_ is a fixed-size shard of mutexes independent of max_elements_.
+ size_t fixed = MAX_LABEL_OPERATION_LOCKS * sizeof(std::mutex); + memory_size_.store(fixed + max_elements_ * per_element, std::memory_order_relaxed); } size_t indexFileSize() const { diff --git a/src/core/search/hnsw_index.cc b/src/core/search/hnsw_index.cc index 6cd660359aed..0e444e7812a3 100644 --- a/src/core/search/hnsw_index.cc +++ b/src/core/search/hnsw_index.cc @@ -132,35 +132,13 @@ struct HnswlibAdapter { HnswIndexMetadata GetMetadata() const { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock); HnswIndexMetadata metadata; - metadata.max_elements = world_.max_elements_; - metadata.cur_element_count = world_.cur_element_count.load(); - metadata.maxlevel = world_.maxlevel_; metadata.enterpoint_node = world_.enterpoint_node_; return metadata; } - void SetMetadata(const HnswIndexMetadata& metadata) { - MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock); - absl::WriterMutexLock resize_lock(&resize_mutex_); - - // SetMetadata is only called during deserialization before the index is used. - // Assert the index is empty to ensure no concurrent operations are possible. - DCHECK_EQ(world_.cur_element_count.load(), 0u) - << "SetMetadata should only be called on an empty index during deserialization"; - - // Runtime check for release builds to prevent silent corruption - if (world_.cur_element_count.load() != 0) { - LOG(ERROR) << "SetMetadata called on non-empty HNSW index with " - << world_.cur_element_count.load() << " elements, ignoring"; - return; - } - - // Pre-allocate capacity based on expected element count, but don't set cur_element_count. - // cur_element_count will be set by RestoreFromNodes when the actual nodes are restored. - if (world_.max_elements_ < metadata.cur_element_count) { - world_.resizeIndex(metadata.cur_element_count); - } - // Note: Don't set cur_element_count here - RestoreFromNodes will set it after restoring nodes. + int GetMaxLevel() const { + MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock); + return world_.maxlevel_; } size_t GetNodeCount() const { @@ -280,13 +258,15 @@ struct HnswlibAdapter { } public: - // Restore HNSW graph structure from serialized nodes with metadata - void RestoreFromNodes(const std::vector& nodes, const HnswIndexMetadata& metadata) { + // Restore HNSW graph structure from serialized nodes with metadata. + // Returns false if the input is inconsistent (e.g. entry point not in node set) — + // caller should fall back to rebuilding the index from the keyspace. + bool RestoreFromNodes(const std::vector& nodes, const HnswIndexMetadata& metadata) { MRMWMutexLock lock(&mrmw_mutex_, MRMWMutex::LockMode::kWriteLock); absl::WriterMutexLock resize_lock(&resize_mutex_); if (nodes.empty()) { - return; + return true; } // RestoreFromNodes is only called during deserialization on a freshly created index. @@ -294,17 +274,25 @@ struct HnswlibAdapter { DCHECK_EQ(world_.cur_element_count.load(), 0u) << "RestoreFromNodes should only be called on an empty index during deserialization"; - // Ensure we have enough capacity. - // Metadata may have been captured before the snapshot read-lock, so - // cur_element_count can be smaller than actual node internal_ids when - // concurrent writes happen. Compute the real requirement from nodes. + // hnswlib pairs enterpoint_node_ with maxlevel_; node levels are immutable after + // creation, so the entry point's level in the serialized set equals the live + // maxlevel at metadata capture. 
max(node.level) would risk OOB reads when a + // concurrent Add raised maxlevel between capture and node serialization. size_t max_internal_id = 0; + int entrypoint_level = -1; for (const auto& node : nodes) { max_internal_id = std::max(max_internal_id, node.internal_id); + if (node.internal_id == metadata.enterpoint_node) + entrypoint_level = node.level; } - size_t required_capacity = std::max(metadata.cur_element_count, max_internal_id + 1); - if (world_.max_elements_ < required_capacity) { - world_.resizeIndex(required_capacity); + if (entrypoint_level < 0) { + LOG(ERROR) << "HNSW restore: entry point internal_id=" << metadata.enterpoint_node + << " not present in serialized node set (" << nodes.size() + << " nodes); skipping restore — index will be rebuilt from the keyspace"; + return false; + } + if (world_.max_elements_ < max_internal_id + 1) { + world_.resizeIndex(max_internal_id + 1); } // Restore each node - directly set up memory and fields @@ -378,12 +366,13 @@ struct HnswlibAdapter { } // Set the metadata for the graph - world_.maxlevel_ = metadata.maxlevel; + world_.maxlevel_ = entrypoint_level; world_.enterpoint_node_ = metadata.enterpoint_node; VLOG(1) << "Restored HNSW index with " << restored_count - << " nodes, maxlevel=" << metadata.maxlevel + << " nodes, maxlevel=" << entrypoint_level << ", enterpoint=" << metadata.enterpoint_node; + return true; } // Update vector data for an existing node (used after RestoreFromNodes). @@ -424,6 +413,10 @@ struct HnswlibAdapter { return MRMWMutexLock(&mrmw_mutex_, MRMWMutex::LockMode::kReadLock); } + size_t GetMemoryUsage() const { + return world_.memorySize(); + } + private: HnswSpace space_; HierarchicalNSW world_; @@ -498,8 +491,8 @@ HnswIndexMetadata HnswVectorIndex::GetMetadata() const { return adapter_->GetMetadata(); } -void HnswVectorIndex::SetMetadata(const HnswIndexMetadata& metadata) { - adapter_->SetMetadata(metadata); +int HnswVectorIndex::GetMaxLevel() const { + return adapter_->GetMaxLevel(); } size_t HnswVectorIndex::GetNodeCount() const { @@ -510,9 +503,9 @@ std::vector HnswVectorIndex::GetNodesRange(size_t start, size_t en return adapter_->GetNodesRange(start, end); } -void HnswVectorIndex::RestoreFromNodes(const std::vector& nodes, +bool HnswVectorIndex::RestoreFromNodes(const std::vector& nodes, const HnswIndexMetadata& metadata) { - adapter_->RestoreFromNodes(nodes, metadata); + return adapter_->RestoreFromNodes(nodes, metadata); } bool HnswVectorIndex::UpdateVectorData(GlobalDocId id, const DocumentAccessor& doc, @@ -542,4 +535,8 @@ MRMWMutexLock HnswVectorIndex::GetReadLock() const { return adapter_->GetReadLock(); } +size_t HnswVectorIndex::GetMemoryUsage() const { + return adapter_->GetMemoryUsage(); +} + } // namespace dfly::search diff --git a/src/core/search/hnsw_index.h b/src/core/search/hnsw_index.h index 55de0d1b5236..55f817351be9 100644 --- a/src/core/search/hnsw_index.h +++ b/src/core/search/hnsw_index.h @@ -11,16 +11,12 @@ namespace dfly::search { -// Metadata structure for HNSW index serialization -// Contains the key parameters needed to restore the index state +// Wire format for HNSW index AUX. Only the entry point is persisted: capacity is +// derived from max(internal_id)+1 in the node set and maxlevel from the entry-point +// node's level (hnswlib pairs enterpoint_node_ with maxlevel_, and node levels are +// immutable after creation). 
struct HnswIndexMetadata { - size_t max_elements = 0; // Maximum number of elements the index can hold - // Note: cur_element_count may be smaller than actual node count during concurrent writes, - // so we compute the real requirement from nodes during restoration. - // TODO: consider removing it from metadata and rely entirely on node data for restoration. - size_t cur_element_count = 0; // Current number of elements in the index - int maxlevel = -1; // Maximum level of the graph - size_t enterpoint_node = 0; // Entry point node for the graph + size_t enterpoint_node = 0; }; // Node data structure for HNSW serialization @@ -75,8 +71,9 @@ class HnswVectorIndex { // Get metadata for serialization HnswIndexMetadata GetMetadata() const; - // Set metadata (used during restoration) - void SetMetadata(const HnswIndexMetadata& metadata); + // Current graph maxlevel_. Exposed for introspection and tests that need to + // verify invariants preserved by RestoreFromNodes (entry point must sit at maxlevel). + int GetMaxLevel() const; // Get total number of nodes in the index size_t GetNodeCount() const; @@ -85,10 +82,12 @@ class HnswVectorIndex { // Returns vector of node data for serialization std::vector GetNodesRange(size_t start, size_t end) const; - // Restore graph structure from serialized nodes with metadata - // This restores the HNSW graph links but NOT the vector data - // Vector data must be populated separately via UpdateVectorData - void RestoreFromNodes(const std::vector& nodes, const HnswIndexMetadata& metadata); + // Restore graph structure from serialized nodes with metadata. + // Restores links only; vector data must be populated separately via UpdateVectorData. + // Returns false if the metadata is inconsistent with the node set (e.g. the entry + // point is missing from the serialized nodes) — caller should then leave the index + // empty and let the higher-level rebuild path repopulate it from the keyspace. + bool RestoreFromNodes(const std::vector& nodes, const HnswIndexMetadata& metadata); // Update vector data for an existing node (used after RestoreFromNodes) // This populates the vector data for a node that already has graph links @@ -98,6 +97,9 @@ class HnswVectorIndex { // Use this during serialization to block concurrent Add/Remove (write) operations. MRMWMutexLock GetReadLock() const; + // Approximate in-memory footprint of this HNSW graph, in bytes. 
+ size_t GetMemoryUsage() const; + private: bool copy_vector_; size_t dim_; diff --git a/src/core/search/scoring.cc b/src/core/search/scoring.cc index 4f14294d30b7..f69f3e405fb7 100644 --- a/src/core/search/scoring.cc +++ b/src/core/search/scoring.cc @@ -6,15 +6,11 @@ namespace dfly::search { -double ScoreDocument(ScorerType scorer, const ScoringContext& ctx, +double ScoreDocument(ScorerFn scorer, const ScoringContext& ctx, const std::vector& terms) { double score = 0.0; - switch (scorer) { - case ScorerType::BM25STD: - for (const auto& term : terms) - score += BM25Std(ctx, term); - break; - } + for (const auto& term : terms) + score += scorer(ctx, term); return score; } diff --git a/src/core/search/scoring.h b/src/core/search/scoring.h index e9e634a5dd07..ec9a2e0ada51 100644 --- a/src/core/search/scoring.h +++ b/src/core/search/scoring.h @@ -17,11 +17,6 @@ namespace dfly::search { class FieldIndices; struct TextIndex; -// Supported scorer types -enum class ScorerType : int { - BM25STD, // Standard Okapi BM25 (default) -}; - // Per-term information needed for scoring a single document struct ScoringTermInfo { uint32_t term_freq = 0; // How many times this term appears in the document @@ -35,6 +30,11 @@ struct ScoringContext { size_t num_docs = 0; // Total documents in index }; +// Scorer function signature: computes the score for a single (term, document) pair. +// Register new scorers by adding a function with this signature and exposing it via +// ParseScorer in the command layer. +using ScorerFn = double (*)(const ScoringContext&, const ScoringTermInfo&); + // Compute BM25STD score for a single term in a document. // // Formula: IDF * f * (k1 + 1) / (f + k1 * (1 - b + b * docLen / avgDocLen)) @@ -63,9 +63,33 @@ inline double BM25Std(const ScoringContext& ctx, const ScoringTermInfo& term) { return idf * tf; } -// Compute BM25STD score for a document matched against multiple terms. -// Returns sum of per-term BM25 scores. -double ScoreDocument(ScorerType scorer, const ScoringContext& ctx, +// Compute TFIDF score for a single term in a document. +// +// Formula: f * IDF +// where IDF = ln(N / n), clamped to be non-negative. +// +// Note: returns 0 when a term appears in every document (N == n, no discriminating power). +// This differs from BM25STD, which adds a "+1" inside the log to keep the score positive. +inline double TfIdf(const ScoringContext& ctx, const ScoringTermInfo& term) { + if (term.term_docs == 0) + return 0.0; + + // Clamp N >= n to avoid negative IDF during transient states + double N = std::max(ctx.num_docs, term.term_docs); + return std::log(N / term.term_docs) * term.term_freq; +} + +// Compute TFIDF with document length normalization for a single term. +// +// Formula: (f * IDF) / fieldDocLen +inline double TfIdfDocNorm(const ScoringContext& ctx, const ScoringTermInfo& term) { + auto d_len = term.field_doc_len == 0 ? 1 : term.field_doc_len; + return TfIdf(ctx, term) / d_len; +} + +// Compute score for a document matched against multiple terms. +// Returns sum of per-term scores produced by the given scorer function. 
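+//
+// Worked example (illustrative numbers): with N = 100 docs and a term present in
+// n = 10 of them, TfIdf gives IDF = ln(100 / 10) ≈ 2.303; a document containing the
+// term f = 3 times scores ≈ 6.908, and with field_doc_len = 20 TfIdfDocNorm yields
+// ≈ 0.345. A multi-term match simply sums these per-term values.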
+double ScoreDocument(ScorerFn scorer, const ScoringContext& ctx, const std::vector& terms); } // namespace dfly::search diff --git a/src/core/search/search.cc b/src/core/search/search.cc index aff562f3bf0b..aaf7d8a20694 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -115,8 +115,8 @@ struct ProfileBuilder { struct BasicSearch { using LogicOp = AstLogicalNode::LogicOp; - BasicSearch(const FieldIndices* indices, std::optional scorer = std::nullopt) - : indices_{indices}, scorer_type_{scorer} { + BasicSearch(const FieldIndices* indices, ScorerFn scorer = nullptr) + : indices_{indices}, scorer_{scorer} { } void EnableProfiling() { @@ -234,7 +234,7 @@ struct BasicSearch { // Track matched terms for scoring (prefix/suffix/infix expand to multiple terms). // Synonym shadow entries (freq=0) are resolved to their group_id for correct scoring. - if (scorer_type_) { + if (scorer_) { for (auto* index : indices) { auto term_cb = [this, index](string_view term, const auto*) { std::string resolved{term}; @@ -280,7 +280,7 @@ struct BasicSearch { if (!active_field.empty()) { if (auto* index = GetIndex(active_field); index) { - if (scorer_type_) + if (scorer_) AddMatchedTerm(index, term); return IndexResult{index->Matching(term, strip_whitespace)}; } @@ -290,7 +290,7 @@ struct BasicSearch { vector selected_indices = indices_->GetAllTextIndices(); // Track terms for scoring - if (scorer_type_) { + if (scorer_) { for (auto* index : selected_indices) AddMatchedTerm(index, term); } @@ -505,7 +505,7 @@ struct BasicSearch { optional profile = profile_builder_ ? make_optional(profile_builder_->Take()) : nullopt; - if (scorer_type_ && !matched_text_terms_.empty()) { + if (scorer_ && !matched_text_terms_.empty()) { // Score ALL matched docs and return top-K by score (not arbitrary cutoff). auto [out, total_size, text_scores] = TakeScoredTopK(std::move(result), cuttoff_limit); return SearchResult{ @@ -573,7 +573,7 @@ struct BasicSearch { term_infos[t].field_avg_doc_len = cursors[t].index->GetFieldAvgDocLen(); } } - scored.emplace_back(static_cast(ScoreDocument(*scorer_type_, ctx, term_infos)), doc); + scored.emplace_back(static_cast(ScoreDocument(scorer_, ctx, term_infos)), doc); } // Top-K by score (skip sort when no actual cutoff, e.g. 
@@ -602,7 +602,7 @@ struct BasicSearch {
   }

   const FieldIndices* indices_;
-  std::optional<ScorerType> scorer_type_;
+  ScorerFn scorer_ = nullptr;

   string error_;
   optional<ProfileBuilder> profile_builder_ = ProfileBuilder{};
@@ -866,7 +866,7 @@ bool SearchAlgorithm::Init(string_view query, const QueryParams* params,
 SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t cuttoff_limit) const {
   DCHECK(query_);
-  auto bs = BasicSearch{index, scorer_type_};
+  auto bs = BasicSearch{index, scorer_};
   if (profiling_enabled_)
     bs.EnableProfiling();
   return bs.Search(*query_, cuttoff_limit);
@@ -915,8 +915,8 @@ void SearchAlgorithm::EnableProfiling() {
   profiling_enabled_ = true;
 }

-void SearchAlgorithm::SetScorer(ScorerType type) {
-  scorer_type_ = type;
+void SearchAlgorithm::SetScorer(ScorerFn scorer) {
+  scorer_ = scorer;
 }

 const AstVectorRangeNode* SearchAlgorithm::GetVectorRangeNode() const {
diff --git a/src/core/search/search.h b/src/core/search/search.h
index 1d4f23038edd..0578ead865ef 100644
--- a/src/core/search/search.h
+++ b/src/core/search/search.h
@@ -15,6 +15,7 @@
 #include "base/pmr/memory_resource.h"
 #include "core/search/base.h"
 #include "core/search/range_tree.h"
+#include "core/search/scoring.h"
 #include "core/search/synonyms.h"

 namespace dfly::search {
@@ -209,8 +210,6 @@
 struct KnnScoreSortOption {
   size_t limit = std::numeric_limits<size_t>::max();
 };

-enum class ScorerType : int;
-
 // SearchAlgorithm allows searching field indices with a query
 class SearchAlgorithm {
  public:
@@ -237,11 +236,11 @@
   void EnableProfiling();

-  void SetScorer(ScorerType type);
+  void SetScorer(ScorerFn scorer);

  private:
   bool profiling_enabled_ = false;
-  std::optional<ScorerType> scorer_type_;
+  ScorerFn scorer_ = nullptr;
   std::unique_ptr query_;
   std::optional<KnnScoreSortOption> knn_hnsw_score_sort_option_;
 };
diff --git a/src/core/search/search_test.cc b/src/core/search/search_test.cc
index aeacd41970ca..9c7338d2a44b 100644
--- a/src/core/search/search_test.cc
+++ b/src/core/search/search_test.cc
@@ -1072,6 +1072,30 @@ TEST_F(KnnTest, AutoResize) {
   EXPECT_EQ(indices.GetAllDocs().size(), 100);
 }

+// Seeds the given HNSW index with `n` deterministic random vectors of dim `dim` using
+// the given RNG seed. Returns the owning MockedDocuments so the caller can pass them
+// back to UpdateVectorData after a restore. Used by the serialization/restore tests.
+inline vector<MockedDocument> SeedHnswIndex(HnswVectorIndex& index, size_t n, size_t dim,
+                                            uint32_t rng_seed) {
+  vector<MockedDocument> docs(n);
+  std::mt19937 rng(rng_seed);
+  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+  for (size_t i = 0; i < n; i++) {
+    vector<float> coords(dim);
+    for (size_t d = 0; d < dim; d++)
+      coords[d] = dist(rng);
+    docs[i] = MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}};
+    index.Add(i, docs[i], "vec");
+  }
+  return docs;
+}
+
+// Snapshots all nodes from the index under its read lock.
+inline auto SnapshotHnswNodes(const HnswVectorIndex& index) {
+  auto lock = index.GetReadLock();
+  return index.GetNodesRange(0, index.GetNodeCount());
+}
+
+// Parameterized HNSW serialization round-trip test.
// Parameters: {num_elements, dim, similarity} struct HnswSerParam { @@ -1108,27 +1132,12 @@ TEST_P(HnswSerializationTest, RoundTrip) { params.hnsw_ef_construction = 200; HnswVectorIndex original(params, /*copy_vector=*/true); + vector docs = SeedHnswIndex(original, num_elements, dim, /*rng_seed=*/42); - std::mt19937 rng(42); - std::uniform_real_distribution dist(0.0f, 1.0f); - vector docs(num_elements); - for (size_t i = 0; i < num_elements; i++) { - vector coords(dim); - for (size_t d = 0; d < dim; d++) - coords[d] = dist(rng); - docs[i] = MockedDocument::Map{{"vec", ToBytes(absl::MakeConstSpan(coords))}}; - original.Add(i, docs[i], "vec"); - } - - // Serialize auto metadata = original.GetMetadata(); - ASSERT_EQ(metadata.cur_element_count, num_elements); + ASSERT_EQ(original.GetNodeCount(), num_elements); - std::vector nodes; - { - auto lock = original.GetReadLock(); - nodes = original.GetNodesRange(0, metadata.cur_element_count); - } + std::vector nodes = SnapshotHnswNodes(original); ASSERT_EQ(nodes.size(), num_elements); // Verify node data integrity @@ -1139,8 +1148,7 @@ TEST_P(HnswSerializationTest, RoundTrip) { // Deserialize into a fresh index HnswVectorIndex restored(params, /*copy_vector=*/true); - restored.SetMetadata(metadata); - restored.RestoreFromNodes(nodes, metadata); + ASSERT_TRUE(restored.RestoreFromNodes(nodes, metadata)); // Before UpdateVectorData, all nodes must be marked deleted. // KNN should safely return empty results (no crash from nullptr dereference). @@ -1153,17 +1161,16 @@ TEST_P(HnswSerializationTest, RoundTrip) { for (size_t i = 0; i < num_elements; i++) restored.UpdateVectorData(i, docs[i], "vec"); - // Metadata must match auto rm = restored.GetMetadata(); - EXPECT_EQ(rm.cur_element_count, metadata.cur_element_count); - EXPECT_EQ(rm.maxlevel, metadata.maxlevel); + EXPECT_EQ(restored.GetNodeCount(), num_elements); EXPECT_EQ(rm.enterpoint_node, metadata.enterpoint_node); + EXPECT_EQ(restored.GetMaxLevel(), original.GetMaxLevel()); // Graph links must be identical std::vector restored_nodes; { auto lock = restored.GetReadLock(); - restored_nodes = restored.GetNodesRange(0, rm.cur_element_count); + restored_nodes = restored.GetNodesRange(0, restored.GetNodeCount()); } ASSERT_EQ(restored_nodes.size(), nodes.size()); for (size_t i = 0; i < nodes.size(); i++) { @@ -1209,6 +1216,76 @@ TEST_P(HnswSerializationTest, RoundTrip) { } } +// Regression for the save-side race where an Add raises maxlevel between metadata +// capture and node serialization (see RestoreFromNodes for the rationale). Simulated +// by forging metadata with a low-level entry point against a multi-level node set; +// expects maxlevel_ to clamp to the entry point's level rather than max(node.level). 
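RestoreFromNodes itself is not part of this excerpt, so the invariant the comment above describes can only be sketched. A hypothetical shape of the validation, with stand-in types (NodeData, Metadata and the optional return are assumptions; only maxlevel_'s clamp-to-entry-point behavior and the fail-gracefully contract come from this diff):

    #include <cstdint>
    #include <optional>
    #include <vector>

    struct NodeData {            // stand-in for a serialized HNSW node record
      uint32_t internal_id;
      int level;
    };
    struct Metadata {            // stand-in for HnswIndexMetadata
      uint32_t enterpoint_node;
    };

    // Returns the restored max level, or nullopt when the metadata does not match
    // the node set (mirroring RestoreFromNodes returning false).
    std::optional<int> RestoreMaxLevel(const std::vector<NodeData>& nodes, Metadata meta) {
      for (const NodeData& n : nodes) {
        if (n.internal_id == meta.enterpoint_node)
          return n.level;  // clamp to the entry point's level, NOT max(node.level)
      }
      return std::nullopt;  // entry point absent: fail gracefully, caller rebuilds
    }

The two tests that follow pin down both branches of this contract.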
+TEST(HnswRestoreInvariant, MaxLevelClampedToEntryPointLevel) { + constexpr size_t kDim = 8; + constexpr size_t kN = 100; + + InitTLSearchMR(PMR_NS::get_default_resource()); + absl::Cleanup cleanup = [] { InitTLSearchMR(nullptr); }; + + SchemaField::VectorParams params; + params.use_hnsw = true; + params.dim = kDim; + params.sim = VectorSimilarity::L2; + params.capacity = kN; + params.hnsw_m = 16; + params.hnsw_ef_construction = 200; + + HnswVectorIndex original(params, /*copy_vector=*/true); + SeedHnswIndex(original, kN, kDim, /*rng_seed=*/42); + std::vector nodes = SnapshotHnswNodes(original); + + int global_max_level = -1; + std::optional low_level_internal_id; + for (const auto& n : nodes) { + global_max_level = std::max(global_max_level, n.level); + if (!low_level_internal_id && n.level == 0) + low_level_internal_id = n.internal_id; + } + ASSERT_GT(global_max_level, 0) << "test setup: need a multi-level graph"; + ASSERT_TRUE(low_level_internal_id.has_value()) << "test setup: need a level-0 node"; + + HnswIndexMetadata forged_metadata{.enterpoint_node = *low_level_internal_id}; + + HnswVectorIndex restored(params, /*copy_vector=*/true); + ASSERT_TRUE(restored.RestoreFromNodes(nodes, forged_metadata)); + + EXPECT_EQ(restored.GetMaxLevel(), 0) + << "maxlevel_ must equal entry-point level; got " << restored.GetMaxLevel() + << " while node set max level=" << global_max_level; +} + +// Malformed/mismatched metadata (entry point not in serialized node set) must +// fail restoration gracefully — returning false — instead of SIGABRT'ing via +// CHECK. Callers then rebuild the index from the keyspace. +TEST(HnswRestoreInvariant, MissingEntrypointFailsGracefully) { + constexpr size_t kDim = 4; + constexpr size_t kN = 10; + + InitTLSearchMR(PMR_NS::get_default_resource()); + absl::Cleanup cleanup = [] { InitTLSearchMR(nullptr); }; + + SchemaField::VectorParams params; + params.use_hnsw = true; + params.dim = kDim; + params.sim = VectorSimilarity::L2; + params.capacity = kN; + params.hnsw_m = 16; + params.hnsw_ef_construction = 200; + + HnswVectorIndex original(params, /*copy_vector=*/true); + SeedHnswIndex(original, kN, kDim, /*rng_seed=*/7); + std::vector nodes = SnapshotHnswNodes(original); + + HnswIndexMetadata bad_metadata{.enterpoint_node = 999999}; // well past any real id + HnswVectorIndex restored(params, /*copy_vector=*/true); + EXPECT_FALSE(restored.RestoreFromNodes(nodes, bad_metadata)); +} + // Regression: in borrowed mode (copy_vector=false), Remove marks the node deleted // but hnswlib still traverses it and dereferences its data pointer. If the external // data is freed (as happens after DEL), the pointer dangles. 
The fix in DoRemove @@ -2838,12 +2915,65 @@ TEST_F(ScoringTest, BM25StdMultiTerm) { ScoringTermInfo t2{ .term_freq = 1, .term_docs = 20, .field_doc_len = 10, .field_avg_doc_len = 10.0}; - double multi = ScoreDocument(ScorerType::BM25STD, ctx, {t1, t2}); + double multi = ScoreDocument(&BM25Std, ctx, {t1, t2}); double sum = BM25Std(ctx, t1) + BM25Std(ctx, t2); EXPECT_DOUBLE_EQ(multi, sum); } +TEST_F(ScoringTest, TfIdfFormula) { + // f=2, N=10, n=3 + // IDF = ln(10/3) ~ 1.2039 + // score = 2 * 1.2039 ~ 2.4079 + ScoringContext ctx{.num_docs = 10}; + ScoringTermInfo term{.term_freq = 2, .term_docs = 3}; + + EXPECT_NEAR(TfIdf(ctx, term), 2.4079, 0.01); +} + +TEST_F(ScoringTest, TfIdfZeroFreq) { + ScoringContext ctx{.num_docs = 10}; + ScoringTermInfo term{.term_freq = 0, .term_docs = 3}; + + EXPECT_EQ(TfIdf(ctx, term), 0.0); +} + +TEST_F(ScoringTest, TfIdfRareTermHigherScore) { + // Same TF, but rare term (small n) should score higher than common term (large n) + ScoringContext ctx{.num_docs = 100}; + ScoringTermInfo rare{.term_freq = 1, .term_docs = 2}; + ScoringTermInfo common{.term_freq = 1, .term_docs = 50}; + + EXPECT_GT(TfIdf(ctx, rare), TfIdf(ctx, common)); +} + +TEST_F(ScoringTest, TfIdfDocNormShorterDocScoresHigher) { + // Same TF/IDF, but shorter doc should score higher after length normalization + ScoringContext ctx{.num_docs = 10}; + ScoringTermInfo short_doc{.term_freq = 1, .term_docs = 3, .field_doc_len = 5}; + ScoringTermInfo long_doc{.term_freq = 1, .term_docs = 3, .field_doc_len = 50}; + + EXPECT_GT(TfIdfDocNorm(ctx, short_doc), TfIdfDocNorm(ctx, long_doc)); +} + +TEST_F(ScoringTest, TfIdfDocNormZeroDocLen) { + // field_doc_len = 0 should not cause division by zero — falls back to unnormalized score + ScoringContext ctx{.num_docs = 10}; + ScoringTermInfo term{.term_freq = 1, .term_docs = 3, .field_doc_len = 0}; + + EXPECT_EQ(TfIdfDocNorm(ctx, term), TfIdf(ctx, term)); +} + +TEST_F(ScoringTest, ScoreDocumentDispatchesByScorerType) { + ScoringContext ctx{.num_docs = 10}; + ScoringTermInfo term{ + .term_freq = 2, .term_docs = 3, .field_doc_len = 5, .field_avg_doc_len = 5.0}; + + EXPECT_DOUBLE_EQ(ScoreDocument(&BM25Std, ctx, {term}), BM25Std(ctx, term)); + EXPECT_DOUBLE_EQ(ScoreDocument(&TfIdf, ctx, {term}), TfIdf(ctx, term)); + EXPECT_DOUBLE_EQ(ScoreDocument(&TfIdfDocNorm, ctx, {term}), TfIdfDocNorm(ctx, term)); +} + TEST_F(ScoringTest, SearchWithScorer) { // Integration test: build index, search with scorer, verify scores are non-zero Schema schema = MakeSimpleSchema({{"field", SchemaField::TEXT}}); @@ -2861,7 +2991,7 @@ TEST_F(ScoringTest, SearchWithScorer) { QueryParams params; SearchAlgorithm algo; ASSERT_TRUE(algo.Init("hello", ¶ms)); - algo.SetScorer(ScorerType::BM25STD); + algo.SetScorer(&BM25Std); auto result = algo.Search(&index); @@ -2901,7 +3031,7 @@ TEST_F(ScoringTest, SearchPrefixWithScorer) { QueryParams params; SearchAlgorithm algo; ASSERT_TRUE(algo.Init("hel*", ¶ms)); - algo.SetScorer(ScorerType::BM25STD); + algo.SetScorer(&BM25Std); auto result = algo.Search(&index); @@ -3008,7 +3138,7 @@ TEST_F(ScoringTest, BM25StdAfterDocRemoval) { QueryParams params; SearchAlgorithm algo; ASSERT_TRUE(algo.Init("hello", ¶ms)); - algo.SetScorer(ScorerType::BM25STD); + algo.SetScorer(&BM25Std); auto result_before = algo.Search(&index); ASSERT_EQ(result_before.ids.size(), 3u); @@ -3020,7 +3150,7 @@ TEST_F(ScoringTest, BM25StdAfterDocRemoval) { // Re-search SearchAlgorithm algo2; ASSERT_TRUE(algo2.Init("hello", ¶ms)); - algo2.SetScorer(ScorerType::BM25STD); + 
algo2.SetScorer(&BM25Std);
   auto result_after = algo2.Search(&index);

   ASSERT_EQ(result_after.ids.size(), 2u);
@@ -3055,7 +3185,7 @@ TEST_F(ScoringTest, ScorerTopKCutoff) {
   QueryParams params;
   SearchAlgorithm algo;
   ASSERT_TRUE(algo.Init("hello", &params));
-  algo.SetScorer(ScorerType::BM25STD);
+  algo.SetScorer(&BM25Std);

   // Request only top 3 - should return docs 9, 8, 7 (highest TF)
   auto result = algo.Search(&index, 3);
diff --git a/src/core/sharded_hash_map.h b/src/core/sharded_hash_map.h
new file mode 100644
index 000000000000..bebe13e93096
--- /dev/null
+++ b/src/core/sharded_hash_map.h
@@ -0,0 +1,218 @@
+// Copyright 2026, DragonflyDB authors. All rights reserved.
+// See LICENSE for licensing terms.
+//
+#pragma once
+
+#include <absl/container/flat_hash_map.h>
+#include <absl/hash/hash.h>
+
+#include <array>
+#include <mutex>
+#include <shared_mutex>
+#include <utility>
+
+#include "base/logging.h"
+#include "util/fibers/synchronization.h"
+
+namespace dfly {
+
+// Thread-safe hash map sharded into NUM_SHARDS independent shards.
+//
+// Each shard contains an absl::flat_hash_map protected by two fiber-aware locks:
+//  - write_mu_ (Mutex): serializes writers. Only one writer can modify the shard at a time.
+//  - read_mu_ (SharedMutex): guards readers. Acquired in shared mode for reads (FindIf,
+//    ForEachShared, SizeApproximate) and in exclusive mode when a writer needs to commit
+//    changes that must be visible atomically to readers.
+//
+// The two-lock design allows multiple concurrent readers on a shard while a single writer
+// prepares its mutation (holding only write_mu_). The writer then briefly acquires read_mu_
+// exclusively to publish the change, minimizing the window during which readers are blocked.
+//
+// Shard selection is determined by hashing the key with Hash (default: absl::Hash<K>) and
+// taking modulo NUM_SHARDS. Both Hash and Eq are forwarded to the underlying
+// absl::flat_hash_map, so a custom Hash can be supplied as the fourth template argument and
+// a custom equality as the fifth. To enable heterogeneous lookup (e.g. finding a std::string
+// key via std::string_view), both Hash and Eq must be transparent. absl::Hash is NOT
+// transparent — its operator() only accepts const K&. Supply a custom hash that declares
+// is_transparent and accepts all query types (e.g. std::string_view for string keys), paired
+// with std::equal_to<> as Eq. Without both being transparent, heterogeneous lookups will
+// not compile or will silently fall back to non-heterogeneous comparison.
+//
+// Thread safety guarantees:
+//  - Concurrent reads on the same shard are safe (shared read_mu_).
+//  - Concurrent writes to different shards are safe (independent locks).
+//  - A write and a read on the same shard are safe (write_mu_ + exclusive read_mu_).
+//  - Concurrent writes to the same shard are serialized by write_mu_.
+//
+// Re-entrancy: callbacks passed to FindIf, ForEachShared, ForEachExclusive, and
+// WithReadExclusiveLock are invoked while one or more shard locks are held. Calling any
+// ShardedHashMap method that would re-acquire the same lock on the same shard from within
+// a callback will deadlock.
+//
+template <typename K, typename V, size_t NUM_SHARDS = 16, typename Hash = absl::Hash<K>,
+          typename Eq = std::equal_to<>>
+class ShardedHashMap {
+  static_assert(NUM_SHARDS > 0, "NUM_SHARDS must be greater than 0");
+  using InternalMap = absl::flat_hash_map<K, V, Hash, Eq>;
+
+ public:
+  static constexpr size_t kNumShards = NUM_SHARDS;
+
+  // Tag type to disambiguate shard-index Mutate(ShardId{idx}, ...) from key-based
+  // Mutate(key, ...).
+  struct ShardId {
+    size_t value;
+    explicit ShardId(size_t v) : value(v) {
+    }
+  };
+
+  // Returned by the AcquireReaderLock callable passed to Mutate(). Holds an exclusive lock on
+  // read_mu_ for the duration of its lifetime and exposes a mutable reference to the shard
+  // map. Mutations must be performed through LockedMap::map to guarantee that no reader
+  // observes a partial update.
+  struct LockedMap {
+    std::unique_lock<util::fb2::SharedMutex> lock;
+    InternalMap& map;
+  };
+
+  // Looks up `key` under a shared read lock on its shard. If found, invokes f(const V&)
+  // with the mapped value while still holding the lock, then returns true.
+  // Returns false if the key is not present. The callback must not modify the value.
+  //
+  // The template parameter Q allows heterogeneous lookup — any type hashable via
+  // Hash and comparable against K can be used.
+  template <typename Q, typename F> bool FindIf(const Q& key, F&& f) const {
+    const Shard& shard = shards_[ShardOf(key)];
+    std::shared_lock read_lock(shard.read_mu_);
+    auto it = shard.map_.find(key);
+    if (it == shard.map_.end()) {
+      return false;
+    }
+    std::forward<F>(f)(it->second);
+    return true;
+  }
+
+  // Iterates over all entries across every shard, invoking f(const K&, const V&) for each.
+  // Each shard's read_mu_ is acquired in shared mode independently — the iteration is NOT
+  // a global snapshot, so entries may be added or removed in other shards concurrently.
+  // Suitable for building approximate views or collecting statistics.
+  template <typename F> void ForEachShared(F&& f) const {
+    for (const Shard& shard : shards_) {
+      std::shared_lock read_lock(shard.read_mu_);
+      for (const auto& [k, v] : shard.map_) {
+        f(k, v);
+      }
+    }
+  }
+
+  // Iterates over all entries with full exclusive access, invoking f(const K&, V&) for each.
+  // Both write_mu_ and read_mu_ are held exclusively per shard, so no concurrent readers
+  // or writers can access the shard during iteration. This is the heaviest locking mode —
+  // use it only when entries must be mutated in-place or when a consistent per-shard view
+  // is required. Note: like ForEachShared, this is still not a global snapshot across shards.
+  template <typename F> void ForEachExclusive(F&& f) {
+    for (Shard& shard : shards_) {
+      std::unique_lock write_lock{shard.write_mu_};
+      std::unique_lock reader_lock{shard.read_mu_};
+      for (auto& [k, v] : shard.map_) {
+        f(k, v);
+      }
+    }
+  }
+
+  // Primary mutation interface. Acquires write_mu_ exclusively on the shard that owns `key`,
+  // then invokes f(const InternalMap& map, auto AcquireReaderLock).
+  //
+  // The callback receives:
+  //  - map: a const reference to the shard's underlying absl::flat_hash_map. The caller
+  //    may inspect data while only write_mu_ is held (readers still proceed).
+  //  - AcquireReaderLock: a callable that returns LockedMap, which holds an exclusive lock
+  //    on read_mu_ and a mutable InternalMap& reference. Mutations must go through
+  //    LockedMap::map only — this ensures no reader observes a partial update.
+  //
+  // Do not hold multiple LockedMap instances simultaneously within the callback — read_mu_ is
+  // non-recursive, so acquiring it twice will deadlock. Calling lock_readers() more than once
+  // is safe only if the previous LockedMap has gone out of scope first.
+  //
+  // Typical usage pattern:
+  //   map.Mutate(key, [&](const auto& m, auto lock_readers) {
+  //     /* optionally inspect m (const) without blocking readers */
+  //     auto lm = lock_readers();
+  //     lm.map[key] = new_value;  // now no reader sees a partial update
+  //   });
+  //
+  // The template parameter Q allows heterogeneous lookup — any type hashable via
+  // Hash and comparable against K can be used.
+  template <typename Q, typename F> void Mutate(const Q& key, F&& f) {
+    Shard& shard = shards_[ShardOf(key)];
+    std::unique_lock write_lock{shard.write_mu_};
+    std::forward<F>(f)(static_cast<const InternalMap&>(shard.map_), [&shard]() -> LockedMap {
+      return {std::unique_lock{shard.read_mu_}, shard.map_};
+    });
+  }
+
+  // Shard-index overload of Mutate. Same semantics as Mutate(key, f) but addresses the
+  // shard directly by its index `sid` (0 <= sid < NUM_SHARDS). Useful when the caller has
+  // already computed the shard via ShardOf() or needs to batch multiple keys that map to
+  // the same shard under a single lock acquisition. The same lock_readers() re-entrancy
+  // restriction applies: do not hold two LockedMap instances at the same time.
+  template <typename F> void Mutate(ShardId sid, F&& f) {
+    DCHECK_LT(sid.value, NUM_SHARDS);
+    Shard& shard = shards_[sid.value];
+    std::unique_lock write_lock{shard.write_mu_};
+    std::forward<F>(f)(static_cast<const InternalMap&>(shard.map_), [&shard]() -> LockedMap {
+      return {std::unique_lock{shard.read_mu_}, shard.map_};
+    });
+  }
+
+  // Returns the shard index (0 .. NUM_SHARDS-1) that `key` maps to. Can be used to
+  // pre-compute the shard for later use with the shard-index overloads of Mutate() or
+  // WithReadExclusiveLock(), or to group operations on keys that share a shard.
+  template <typename Q> size_t ShardOf(const Q& key) const {
+    return Hash{}(key) % NUM_SHARDS;
+  }
+
+  // Acquires read_mu_ exclusively on the shard that owns `key`, blocking all concurrent
+  // readers (FindIf, ForEachShared, SizeApproximate) on that shard, then invokes f(). The
+  // write_mu_ is NOT acquired, so this does not serialize against other writers. Use this
+  // when you need to perform an external side-effect that must not race with readers of
+  // this shard but the map itself is not being modified.
+  //
+  // The template parameter Q allows heterogeneous lookup — any type hashable via
+  // Hash and comparable against K can be used.
+  template <typename Q, typename F> void WithReadExclusiveLock(const Q& key, F&& f) {
+    Shard& shard = shards_[ShardOf(key)];
+    std::unique_lock l{shard.read_mu_};
+    std::forward<F>(f)();
+  }
+
+  // Shard-index overload of WithReadExclusiveLock. Same semantics but addresses the shard
+  // directly by its index `sid` (0 <= sid < NUM_SHARDS).
+  template <typename F> void WithReadExclusiveLock(ShardId sid, F&& f) {
+    DCHECK_LT(sid.value, NUM_SHARDS);
+    std::unique_lock l{shards_[sid.value].read_mu_};
+    std::forward<F>(f)();
+  }
+
+  // Returns the approximate total number of entries across all shards. Each shard's
+  // read_mu_ is acquired in shared mode independently and its size accumulated.
+  size_t SizeApproximate() const {
+    size_t total = 0;
+    for (const Shard& shard : shards_) {
+      std::shared_lock read_lock{shard.read_mu_};
+      total += shard.map_.size();
+    }
+    return total;
+  }
+
+ private:
+  // Aligned to cache line.
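+  // Each Shard deliberately occupies its own (typically 64-byte) cache line. Without the
+  // alignas, two shards could share a line, and lock traffic on one shard would keep
+  // invalidating that line for readers of its neighbor (false sharing), partially
+  // defeating the independent-shard design.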
+ struct alignas(64) Shard { + util::fb2::Mutex write_mu_; + mutable util::fb2::SharedMutex read_mu_; + InternalMap map_; + }; + + std::array shards_; +}; + +} // namespace dfly diff --git a/src/core/sharded_hash_map_test.cc b/src/core/sharded_hash_map_test.cc new file mode 100644 index 000000000000..8b0de8d46cd5 --- /dev/null +++ b/src/core/sharded_hash_map_test.cc @@ -0,0 +1,292 @@ +// Copyright 2026, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "core/sharded_hash_map.h" + +#include + +#include +#include +#include + +#include "base/gtest.h" +#include "base/logging.h" +#include "util/fibers/fibers.h" +#include "util/fibers/synchronization.h" + +namespace dfly { + +using namespace std; + +// Transparent hash for string-like types. absl::Hash is not transparent (its +// operator() only accepts const string&), so heterogeneous lookup requires a custom hash +// that declares is_transparent and accepts string_view (which string and const char* both +// convert to). absl guarantees that hashing equal string contents produces the same value +// regardless of the concrete string type, so this is consistent with the stored keys. +struct TransparentStringHash { + using is_transparent = void; + size_t operator()(std::string_view sv) const { + return absl::Hash{}(sv); + } +}; + +class ShardedHashMapTest : public testing::Test { + protected: + ShardedHashMap map_; +}; + +TEST_F(ShardedHashMapTest, EmptyMap) { + EXPECT_EQ(map_.SizeApproximate(), 0u); + + bool found = map_.FindIf(string("missing"), [](const int&) {}); + EXPECT_FALSE(found); +} + +TEST_F(ShardedHashMapTest, MutateInsertAndFind) { + map_.Mutate(string("key1"), [](const auto& m, auto lock_readers) { + auto lm = lock_readers(); + lm.map["key1"] = 42; + }); + + EXPECT_EQ(map_.SizeApproximate(), 1u); + + bool found = map_.FindIf(string("key1"), [](const int& v) { EXPECT_EQ(v, 42); }); + EXPECT_TRUE(found); +} + +TEST_F(ShardedHashMapTest, MutateOverwrite) { + map_.Mutate(string("key1"), [](const auto& m, auto lock_readers) { + auto lm = lock_readers(); + lm.map["key1"] = 10; + }); + + map_.Mutate(string("key1"), [](const auto& m, auto lock_readers) { + auto lm = lock_readers(); + lm.map["key1"] = 20; + }); + + EXPECT_TRUE(map_.FindIf(string("key1"), [](const int& v) { EXPECT_EQ(v, 20); })); + EXPECT_EQ(map_.SizeApproximate(), 1u); +} + +TEST_F(ShardedHashMapTest, MutateErase) { + map_.Mutate(string("key1"), [](const auto& m, auto lock_readers) { + auto lm = lock_readers(); + lm.map["key1"] = 1; + }); + EXPECT_EQ(map_.SizeApproximate(), 1u); + + map_.Mutate(string("key1"), [](const auto& m, auto lock_readers) { + auto lm = lock_readers(); + lm.map.erase("key1"); + }); + EXPECT_EQ(map_.SizeApproximate(), 0u); + + EXPECT_FALSE(map_.FindIf(string("key1"), [](const int&) {})); +} + +TEST_F(ShardedHashMapTest, FindIfReturnsFalseForMissing) { + map_.Mutate(string("a"), [](const auto& m, auto lock_readers) { + auto lm = lock_readers(); + lm.map["a"] = 1; + }); + + EXPECT_FALSE(map_.FindIf(string("b"), [](const int&) {})); +} + +TEST_F(ShardedHashMapTest, MultipleKeys) { + for (int i = 0; i < 100; ++i) { + string key = "key" + to_string(i); + map_.Mutate(key, [&key, i](const auto& m, auto lock_readers) { + auto lm = lock_readers(); + lm.map[key] = i; + }); + } + + EXPECT_EQ(map_.SizeApproximate(), 100u); + + for (int i = 0; i < 100; ++i) { + string key = "key" + to_string(i); + bool found = map_.FindIf(key, [i](const int& v) { EXPECT_EQ(v, i); }); + EXPECT_TRUE(found); + } +} + 
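The tests above exercise the API key by key; the pattern the two-lock design is really built for is a write whose expensive part runs while readers are still being served. A sketch (ExpensiveTransform is a hypothetical stand-in for any slow computation such as serialization or compression):

    #include <string>
    #include <utility>

    #include "core/sharded_hash_map.h"

    std::string ExpensiveTransform(const std::string& raw);  // hypothetical helper

    void InsertIfAbsent(dfly::ShardedHashMap<std::string, std::string>& cache,
                        const std::string& key, const std::string& raw) {
      cache.Mutate(key, [&](const auto& map, auto lock_readers) {
        // Phase 1: only write_mu_ is held. Concurrent FindIf on this shard still
        // proceeds; other writers to this shard queue behind write_mu_.
        if (map.contains(key))
          return;
        std::string value = ExpensiveTransform(raw);  // slow work, readers not blocked

        // Phase 2: brief exclusive window. Readers block only for the insertion itself.
        auto lm = lock_readers();
        lm.map.emplace(key, std::move(value));
      });
    }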
+TEST_F(ShardedHashMapTest, HeterogeneousLookup) { + // Use transparent Eq so that string_view / C-string queries compile and match correctly. + ShardedHashMap> hmap; + + hmap.Mutate(string("hello"), [](const auto& m, auto lr) { + auto lm = lr(); + lm.map["hello"] = 7; + }); + + string_view sv = "hello"; + bool found = hmap.FindIf(sv, [](const int& v) { EXPECT_EQ(v, 7); }); + EXPECT_TRUE(found); + + const char* cstr = "hello"; + found = hmap.FindIf(cstr, [](const int& v) { EXPECT_EQ(v, 7); }); + EXPECT_TRUE(found); + + EXPECT_FALSE(hmap.FindIf(string_view{"missing"}, [](const int&) {})); +} + +TEST_F(ShardedHashMapTest, ShardOf) { + // ShardOf should be deterministic and within range. + string key = "test_key"; + size_t shard = map_.ShardOf(key); + EXPECT_LT(shard, map_.kNumShards); + // Same key always maps to same shard. + EXPECT_EQ(shard, map_.ShardOf(key)); +} + +TEST_F(ShardedHashMapTest, MutateByShard) { + string key = "key1"; + size_t sid = map_.ShardOf(key); + + map_.Mutate(ShardedHashMap::ShardId{sid}, [&key](const auto& m, auto lock_readers) { + auto lm = lock_readers(); + lm.map[key] = 99; + }); + + bool found = map_.FindIf(key, [](const int& v) { EXPECT_EQ(v, 99); }); + EXPECT_TRUE(found); +} + +TEST_F(ShardedHashMapTest, ForEachShared) { + map_.Mutate(string("a"), [](const auto& m, auto lr) { + auto lm = lr(); + lm.map["a"] = 1; + }); + map_.Mutate(string("b"), [](const auto& m, auto lr) { + auto lm = lr(); + lm.map["b"] = 2; + }); + + int sum = 0; + map_.ForEachShared([&sum](const string&, const int& v) { sum += v; }); + EXPECT_EQ(sum, 3); +} + +TEST_F(ShardedHashMapTest, ForEachExclusive) { + map_.Mutate(string("x"), [](const auto& m, auto lr) { + auto lm = lr(); + lm.map["x"] = 10; + }); + map_.Mutate(string("y"), [](const auto& m, auto lr) { + auto lm = lr(); + lm.map["y"] = 20; + }); + + // Double all values via exclusive iteration. + map_.ForEachExclusive([](const string&, int& v) { v *= 2; }); + + EXPECT_TRUE(map_.FindIf(string("x"), [](const int& v) { EXPECT_EQ(v, 20); })); + EXPECT_TRUE(map_.FindIf(string("y"), [](const int& v) { EXPECT_EQ(v, 40); })); +} + +TEST_F(ShardedHashMapTest, WithReadExclusiveLockByKey) { + map_.Mutate(string("k"), [](const auto& m, auto lr) { + auto lm = lr(); + lm.map["k"] = 5; + }); + + bool executed = false; + map_.WithReadExclusiveLock(string("k"), [&executed]() { executed = true; }); + EXPECT_TRUE(executed); +} + +TEST_F(ShardedHashMapTest, WithReadExclusiveLockByShard) { + bool executed = false; + map_.WithReadExclusiveLock(ShardedHashMap::ShardId{0}, + [&executed]() { executed = true; }); + EXPECT_TRUE(executed); +} + +TEST_F(ShardedHashMapTest, ConcurrentReadersAndWriter) { + // Insert initial data. + for (int i = 0; i < 50; ++i) { + string key = "key" + to_string(i); + map_.Mutate(key, [&key, i](const auto& m, auto lr) { + auto lm = lr(); + lm.map[key] = i; + }); + } + + constexpr int kReaders = 4; + constexpr int kReadsPerFiber = 200; + + util::fb2::Barrier barrier(kReaders + 1); // +1 for writer fiber + vector fibers; + + // Launch reader fibers. + for (int r = 0; r < kReaders; ++r) { + fibers.emplace_back("reader", [&] { + barrier.Wait(); + for (int j = 0; j < kReadsPerFiber; ++j) { + string key = "key" + to_string(j % 50); + map_.FindIf(key, [](const int&) {}); + } + }); + } + + // Launch writer fiber. 
+ fibers.emplace_back("writer", [&] { + barrier.Wait(); + for (int i = 50; i < 100; ++i) { + string key = "key" + to_string(i); + map_.Mutate(key, [&key, i](const auto& m, auto lr) { + auto lm = lr(); + lm.map[key] = i; + }); + } + }); + + for (auto& fb : fibers) { + fb.Join(); + } + + EXPECT_EQ(map_.SizeApproximate(), 100u); +} + +TEST_F(ShardedHashMapTest, ConcurrentWriters) { + constexpr int kWriters = 4; + constexpr int kKeysPerWriter = 50; + + vector fibers; + util::fb2::Barrier barrier(kWriters); + + for (int w = 0; w < kWriters; ++w) { + fibers.emplace_back("writer", [&, w] { + barrier.Wait(); + for (int i = 0; i < kKeysPerWriter; ++i) { + // Each writer writes to its own key space to avoid contention on values. + string key = "w" + to_string(w) + "_k" + to_string(i); + map_.Mutate(key, [&key, val = w * 1000 + i](const auto& m, auto lr) { + auto lm = lr(); + lm.map[key] = val; + }); + } + }); + } + + for (auto& fb : fibers) { + fb.Join(); + } + + EXPECT_EQ(map_.SizeApproximate(), kWriters * kKeysPerWriter); + + // Verify all values. + for (int w = 0; w < kWriters; ++w) { + for (int i = 0; i < kKeysPerWriter; ++i) { + string key = "w" + to_string(w) + "_k" + to_string(i); + int expected = w * 1000 + i; + bool found = map_.FindIf(key, [expected](const int& v) { EXPECT_EQ(v, expected); }); + EXPECT_TRUE(found) << "missing key: " << key; + } + } +} + +} // namespace dfly diff --git a/src/facade/cmd_arg_parser.h b/src/facade/cmd_arg_parser.h index e1ac88450a78..6073adb0a7db 100644 --- a/src/facade/cmd_arg_parser.h +++ b/src/facade/cmd_arg_parser.h @@ -7,15 +7,65 @@ #include #include +#include #include #include +#include +#include #include #include "facade/facade_types.h" namespace facade { -// Helper class for numerical range restriction during parsing +// CmdArgParser — utility for parsing command option lists. +// +// Reading individual args: +// CmdArgParser parser(args); +// auto key = parser.Next(); // read one arg by type +// auto [src, dst] = parser.Next(); // read several at once (tuple) +// auto db = parser.Next>(); // range-restricted int +// // (INVALID_INT if out of range) +// auto count = parser.NextOrDefault(10); // read optional with default +// +// Tag matching: +// parser.ExpectTag("LOAD"); // required literal keyword +// if (parser.Check("NX")) { ... } // consume tag only if matched +// auto mode = parser.MapNext("EX", Mode::EX, "PX", Mode::PX); // tag -> enum mapping +// auto maybe_mode = parser.TryMapNext("ASC", Dir::ASC, // like MapNext but returns +// "DESC", Dir::DESC); // nullopt (no error) on miss +// +// Bulk named options with Apply(): +// parser.Apply( +// Exist("WITHSCORES", &with_scores), // tag present -> sets bool true +// Tag("LIMIT", &offset, &limit), // tag -> reads following args +// Tag("COUNT", &optional_count), // std::optional* supported directly +// Tag("GET", [&](CmdArgParser* p) { // lambda: custom parsing on tag match +// patterns.push_back(p->Next()); +// }), +// Map(&dir, "ASC", Dir::ASC, "DESC", Dir::DESC), // tag -> fixed value mapping +// Tag("ATTR", Map(&mask, "v", Mask::Volatile, // nested: outer tag + inner Map +// "p", Mask::Permanent)), // (inner keyword required on match) +// OneOf(Exist("NX", &nx), Exist("XX", &xx)), // mutex — at most one may match +// If(!read_only, Tag("STORE", &store_key))); // runtime-gated option +// +// Strict vs lenient dispatch: +// parser.Apply(...) — stops at first unmatched arg; pair with Finalize() to error +// parser.ApplyOrSkip(...) 
— silently skips unknown tags one-by-one +// +// Navigating manually: +// if (parser.HasNext()) { ... } // is there another arg? +// if (parser.HasAtLeast(3)) { ... } // at least N args remain? +// auto peek = parser.Peek(); // look at next without consuming +// parser.Skip(n); // advance n args +// CmdArgList rest = parser.Tail(); // remaining args (e.g. k/v pairs) +// +// Error surfacing (at the end of parse): +// if (!parser.Finalize()) // also reports UNPROCESSED on +// return cmd_cntx->SendError(parser.TakeError().MakeReply()); // trailing args +// // or: if (parser.HasError()) ... + +// Numerical range restriction used with Next>(). template struct FInt { decltype(min) value = {}; operator decltype(min)() { @@ -31,7 +81,10 @@ template constexpr bool is_fint = false; template constexpr bool is_fint> = true; -// Utility class for easily parsing command options from argument lists. +template constexpr bool is_optional = false; + +template constexpr bool is_optional> = true; + struct CmdArgParser { enum ErrorType { NO_ERROR, @@ -59,15 +112,13 @@ struct CmdArgParser { CmdArgParser(ArgSlice args) : args_{args} { } - // Debug asserts sure error was consumed + // DCHECKs that any error was consumed. ~CmdArgParser(); - // Get next value without consuming it std::string_view Peek() { return SafeSV(cur_i_); } - // Consume next value template auto Next() { if (cur_i_ + sizeof...(Ts) >= args_.size()) { Report(OUT_OF_BOUNDS, cur_i_); @@ -85,15 +136,13 @@ struct CmdArgParser { } } - // returns next value if exists or default value template auto NextOrDefault(T default_value = {}) { return HasNext() ? Next() : default_value; } - // check next value ignoring case and consume it + // Consumes the next arg; reports INVALID_NEXT if it doesn't match (case-insensitive). void ExpectTag(std::string_view tag); - // Consume next value template auto MapNext(Cases&&... cases) { if (cur_i_ >= args_.size()) { Report(OUT_OF_BOUNDS, cur_i_); @@ -110,7 +159,7 @@ struct CmdArgParser { return *res; } - // Consume next value if can map it and return mapped result or return nullopt + // Same as MapNext, but returns nullopt (no error) if no case matches. template auto TryMapNext(Cases&&... cases) -> std::optional>> { @@ -123,7 +172,7 @@ struct CmdArgParser { return res; } - // Check if the next value is equal to a specific tag. If equal, its consumed. + // If the next arg matches `tag`, consume it and the following args-into-pointers; else no-op. template bool Check(std::string_view tag, Args*... args) { if (cur_i_ + sizeof...(Args) >= args_.size()) return false; @@ -139,7 +188,22 @@ struct CmdArgParser { return true; } - // Skip specified number of arguments + // Greedily matches remaining args against the options. See the file header for usage. + template void Apply(Opts... opts) { + while (HasNext() && (opts.TryApply(this) || ...)) { + } + } + + // Like Apply, but silently skips unmatched args (one at a time) instead of stopping. Use when + // unknown tags should be ignored rather than reported. Prefer Apply + Finalize when strictness + // is desired. + template void ApplyOrSkip(Opts... opts) { + while (HasNext()) { + if (!(opts.TryApply(this) || ...)) + Skip(1); + } + } + CmdArgParser& Skip(size_t n) { if (cur_i_ + n > args_.size()) { Report(OUT_OF_BOUNDS, cur_i_); @@ -149,7 +213,7 @@ struct CmdArgParser { return *this; } - // Expect no more arguments and return if no error has occured + // Requires no leftover args and no prior errors. Reports UNPROCESSED if args remain. 
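Putting the pieces above together, a handler-side parse typically ends with Finalize so trailing junk is reported rather than ignored. An illustrative sketch (not a real Dragonfly command; the reply wiring follows the usage comment at the top of this header):

    #include <cstdint>
    #include <string_view>

    #include "facade/cmd_arg_parser.h"

    void ParseExampleCommand(facade::CmdArgParser* parser) {
      auto key = parser->Next<std::string_view>();  // positional argument
      uint32_t count = 10;                          // default when COUNT is absent
      if (parser->Check("COUNT"))
        count = parser->Next<uint32_t>();
      if (!parser->Finalize()) {
        // parser->TakeError() carries the error; a real handler converts it into
        // a client reply here.
      }
      (void)key;
      (void)count;
    }

Finalize, whose implementation follows, is the strict backstop that flags anything left over.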
bool Finalize() {
     if (HasNext()) {
       Report(UNPROCESSED, cur_i_);
     }
     return !HasError();
   }

-  // Return remaining arguments
   ArgSlice Tail() const {
     return args_.subspan(cur_i_);
   }

-  // Return true if arguments are left and no errors occured
   bool HasNext() {
     return cur_i_ < args_.size() && !error_;
   }
@@ -182,11 +244,10 @@
     return cur_i_;
   }

-  // Custom error_type should start from CUSTOM_ERROR
+  // Reports a custom error (error_type >= CUSTOM_ERROR) at the previously-consumed index
+  // (or 0 if called before any arg was consumed).
   void Report(int error_type) {
-    // we use previous index, because the check was done outside and it's done after element is
-    // processed
-    Report(error_type, cur_i_ - 1);
+    Report(error_type, cur_i_ > 0 ? cur_i_ - 1 : 0);
   }

  private:
@@ -216,10 +277,12 @@
   }

   template <typename T> T Convert(size_t idx) {
-    static_assert(
-        std::is_arithmetic_v<T> || std::is_constructible_v<T, std::string_view> || is_fint<T>,
-        "incorrect type");
-    if constexpr (std::is_arithmetic_v<T>) {
+    static_assert(std::is_arithmetic_v<T> || std::is_constructible_v<T, std::string_view> ||
+                      is_fint<T> || is_optional<T>,
+                  "incorrect type");
+    if constexpr (is_optional<T>) {
+      return T{Convert<typename T::value_type>(idx)};
+    } else if constexpr (std::is_arithmetic_v<T>) {
       return Num<T>(idx);
     } else if constexpr (std::is_constructible_v<T, std::string_view>) {
       return static_cast<T>(SafeSV(idx));
@@ -280,4 +343,153 @@
   ErrorInfo error_;
 };

+namespace detail {
+
+struct ExistOpt {
+  std::string_view tag;
+  bool* field;
+
+  bool TryApply(CmdArgParser* parser) const {
+    if (parser->Check(tag)) {
+      *field = true;
+      return true;
+    }
+    return false;
+  }
+};
+
+template <typename... Args> struct TagOpt {
+  std::string_view tag;
+  std::tuple<Args*...> args;
+
+  bool TryApply(CmdArgParser* parser) const {
+    // Match the tag first, then read fields via Next<>() — so a missing value surfaces
+    // OUT_OF_BOUNDS instead of being swallowed by ApplyOrSkip as "no match".
+    if (!parser->Check(tag))
+      return false;
+    std::apply(
+        [&](auto*... ptrs) {
+          (((*ptrs) = parser->template Next<std::decay_t<decltype(*ptrs)>>()), ...);
+        },
+        args);
+    return true;
+  }
+};
+
+template <typename Func> struct LambdaOpt {
+  std::string_view tag;
+  Func func;
+
+  bool TryApply(CmdArgParser* parser) const {
+    if (parser->Check(tag)) {
+      func(parser);
+      return true;
+    }
+    return false;
+  }
+};
+
+template <typename T, typename... Cases> struct MapOpt {
+  static_assert(sizeof...(Cases) % 2 == 0, "Map expects alternating tag/value pairs");
+
+  T* field;
+  std::tuple<Cases...> cases;
+
+  bool TryApply(CmdArgParser* parser) const {
+    return TryMatch<0>(parser);
+  }
+
+ private:
+  template <size_t I> bool TryMatch(CmdArgParser* parser) const {
+    if constexpr (I >= sizeof...(Cases)) {
+      return false;
+    } else if (parser->Check(std::get<I>(cases))) {
+      *field = std::get<I + 1>(cases);
+      return true;
+    } else {
+      return TryMatch<I + 2>(parser);
+    }
+  }
+};
+
+template <typename Inner> struct IfOpt {
+  bool cond;
+  Inner inner;
+
+  bool TryApply(CmdArgParser* parser) const {
+    return cond && inner.TryApply(parser);
+  }
+};
+
+template <typename... Opts> struct OneOfOpt {
+  std::tuple<Opts...> opts;
+  mutable bool matched = false;
+
+  bool TryApply(CmdArgParser* parser) const {
+    bool any = std::apply([&](auto&... os) { return (os.TryApply(parser) || ...); }, opts);
+    if (!any)
+      return false;
+    if (matched)
+      parser->Report(CmdArgParser::INVALID_CASES);
+    matched = true;
+    return true;
+  }
+};
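The option structs above compose freely; a full Apply call for a SORT-like option surface looks like this (mirroring the usage comment in the header, with a parser assumed in scope and names chosen for illustration):

    bool with_scores = false, nx = false, xx = false;
    uint32_t offset = 0, limit = 0;
    std::optional<uint32_t> count;
    bool reversed = false;

    parser.Apply(Exist("WITHSCORES", &with_scores),           // bare flag
                 Tag("LIMIT", &offset, &limit),               // keyword + two values
                 Tag("COUNT", &count),                        // optional<> engages on match
                 Map(&reversed, "DESC", true, "ASC", false),  // keyword selects a value
                 OneOf(Exist("NX", &nx), Exist("XX", &xx)));  // at most one may match

    if (!parser.Finalize()) {
      // unknown keyword, missing value, or NX+XX together: surface TakeError() here.
    }

One shape remains: a keyword whose value must itself match one of a fixed set. TagNestedOpt, defined next, covers it.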
+// Nested: outer tag consumes one arg, then inner option runs against the next arg. If the inner
+// doesn't match, reports INVALID_CASES (the inner keyword is required once the outer matched).
+template <typename Inner> struct TagNestedOpt {
+  std::string_view tag;
+  Inner inner;
+
+  bool TryApply(CmdArgParser* parser) const {
+    if (!parser->Check(tag))
+      return false;
+    if (!inner.TryApply(parser))
+      parser->Report(CmdArgParser::INVALID_CASES);
+    return true;
+  }
+};
+
+// Concept matching any of the Apply options (has a TryApply(CmdArgParser*) method).
+template <typename T>
+concept ParseOption = requires(const T& t, CmdArgParser* p) {
+  { t.TryApply(p) } -> std::same_as<bool>;
+};
+
+}  // namespace detail
+
+inline detail::ExistOpt Exist(std::string_view tag, bool* field) {
+  return {tag, field};
+}
+
+template <typename... Args> detail::TagOpt<Args...> Tag(std::string_view tag, Args*... args) {
+  return detail::TagOpt<Args...>{tag, std::make_tuple(args...)};
+}
+
+template <typename Func>
+requires std::is_invocable_v<Func, CmdArgParser*> detail::LambdaOpt<Func> Tag(
+    std::string_view tag, Func func) {
+  return {tag, std::move(func)};
+}
+
+// Nested option: outer tag + inner sub-option (e.g. Map). After outer matches, inner must match
+// the following arg or INVALID_CASES is reported.
+template <detail::ParseOption Inner>
+detail::TagNestedOpt<Inner> Tag(std::string_view tag, Inner inner) {
+  return {tag, std::move(inner)};
+}
+
+template <typename T, typename... Cases>
+detail::MapOpt<T, Cases...> Map(T* field, Cases... cases) {
+  return {field, std::make_tuple(std::move(cases)...)};
+}
+
+template <typename Inner> detail::IfOpt<Inner> If(bool cond, Inner inner) {
+  return {cond, std::move(inner)};
+}
+
+template <typename... Opts> detail::OneOfOpt<Opts...> OneOf(Opts... opts) {
+  return {{std::move(opts)...}, false};
+}
+
 }  // namespace facade
diff --git a/src/facade/cmd_arg_parser_test.cc b/src/facade/cmd_arg_parser_test.cc
index 7906a36286ac..5bc3c15748c3 100644
--- a/src/facade/cmd_arg_parser_test.cc
+++ b/src/facade/cmd_arg_parser_test.cc
@@ -143,6 +143,358 @@ TEST_F(CmdArgParserTest, IgnoreCase) {
   EXPECT_EQ(absl::implicit_cast<string_view>(parser.Next()), "world"sv);
 }

+TEST_F(CmdArgParserTest, Apply) {
+  // All option shapes: Exist sets a bool, Tag-with-one-field, Tag-with-two-fields.
+  {
+    auto parser = Make({"FLAG", "COUNT", "5", "LIMIT", "10", "20"});
+
+    bool flag = false;
+    uint32_t count = 0;
+    uint32_t offset = 0;
+    uint32_t limit = 0;
+
+    parser.Apply(Exist("FLAG", &flag), Tag("COUNT", &count), Tag("LIMIT", &offset, &limit));
+
+    EXPECT_TRUE(flag);
+    EXPECT_EQ(count, 5u);
+    EXPECT_EQ(offset, 10u);
+    EXPECT_EQ(limit, 20u);
+    EXPECT_FALSE(parser.HasError());
+  }
+
+  // Unknown option is left unconsumed (no error). The caller decides what to do next.
+  {
+    auto parser = Make({"COUNT", "5", "BOGUS"});
+
+    uint32_t count = 0;
+    parser.Apply(Tag("COUNT", &count));
+
+    EXPECT_EQ(count, 5u);
+    EXPECT_FALSE(parser.HasError());
+    EXPECT_TRUE(parser.HasNext());
+    EXPECT_EQ(parser.Peek(), "BOGUS");
+  }
+
+  // Case-insensitive matching (consistent with Check).
+  {
+    auto parser = Make({"count", "7"});
+
+    uint32_t count = 0;
+    parser.Apply(Tag("COUNT", &count));
+
+    EXPECT_EQ(count, 7u);
+    EXPECT_FALSE(parser.HasError());
+  }
+
+  // Invalid integer in a Tag arg propagates the error.
+  {
+    auto parser = Make({"COUNT", "NAN"});
+
+    uint32_t count = 0;
+    parser.Apply(Tag("COUNT", &count));
+
+    auto err = parser.TakeError();
+    EXPECT_TRUE(err);
+    EXPECT_EQ(err.type, CmdArgParser::INVALID_INT);
+  }
+}
+
+TEST_F(CmdArgParserTest, ApplyOrSkip) {
+  // ApplyOrSkip silently skips any unknown arg (1 at a time) and keeps going.
+  {
+    auto parser = Make({"BOGUS", "COUNT", "5", "MORE_BOGUS", "STUFF"});
+
+    uint32_t count = 0;
+    parser.ApplyOrSkip(Tag("COUNT", &count));
+
+    EXPECT_EQ(count, 5u);
+    EXPECT_FALSE(parser.HasError());
+    EXPECT_FALSE(parser.HasNext());  // everything consumed
+  }
+  // Empty input — no error, no work.
+ { + auto parser = Make({}); + uint32_t count = 0; + parser.ApplyOrSkip(Tag("COUNT", &count)); + EXPECT_FALSE(parser.HasError()); + EXPECT_FALSE(parser.HasNext()); + } + // Trailing unknown at end-of-args: the skip must not trip OUT_OF_BOUNDS. + { + auto parser = Make({"BOGUS"}); + uint32_t count = 0; + parser.ApplyOrSkip(Tag("COUNT", &count)); + EXPECT_FALSE(parser.HasError()); + EXPECT_FALSE(parser.HasNext()); + } +} + +TEST_F(CmdArgParserTest, ApplyTagMissingValue) { + // A matched tag with missing trailing value(s) must surface an error, not be silently skipped. + // This guards against a subtle interaction with ApplyOrSkip: if TagOpt treated "tag matches, + // values missing" as "no match", the skip path would swallow the malformed option. + { + auto parser = Make({"COUNT"}); // tag matches, value missing + uint32_t count = 0; + parser.Apply(Tag("COUNT", &count)); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.type, CmdArgParser::OUT_OF_BOUNDS); + } + { + auto parser = Make({"COUNT"}); + uint32_t count = 0; + parser.ApplyOrSkip(Tag("COUNT", &count)); + // Tag must have been consumed (not left for Skip to swallow silently). + EXPECT_FALSE(parser.HasNext()); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.type, CmdArgParser::OUT_OF_BOUNDS); + } + // Also guard the two-field case: LIMIT with only one trailing value. + { + auto parser = Make({"LIMIT", "10"}); // needs offset + limit + uint32_t offset = 0, limit = 0; + parser.Apply(Tag("LIMIT", &offset, &limit)); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.type, CmdArgParser::OUT_OF_BOUNDS); + } +} + +TEST_F(CmdArgParserTest, ReportBeforeAnyNext) { + // Report(code) at cur_i_ == 0 must clamp the error index to 0 rather than underflow to SIZE_MAX. + auto parser = Make({"x"}); + parser.Report(CmdArgParser::CUSTOM_ERROR); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.index, 0u); +} + +TEST_F(CmdArgParserTest, ApplyLambda) { + // Tag() with a lambda lets callers run custom parsing on match. Useful for side-effectful cases + // like push_back or toggling a bool to false. + auto parser = Make({"GET", "p1", "ASC", "GET", "p2"}); + + std::vector patterns; + bool reversed = true; + + parser.Apply( + Tag("ASC", [&](CmdArgParser*) { reversed = false; }), + Tag("GET", [&](CmdArgParser* p) { patterns.push_back(p->Next()); })); + + EXPECT_FALSE(reversed); + ASSERT_EQ(patterns.size(), 2u); + EXPECT_EQ(patterns[0], "p1"); + EXPECT_EQ(patterns[1], "p2"); + EXPECT_FALSE(parser.HasError()); +} + +TEST_F(CmdArgParserTest, ApplyMap) { + // Map(&field, tag, value, ...) — matches any tag and writes the corresponding value. + // Standalone Map allows repeated matches (last wins); wrap in OneOf to require at most one. + { + auto parser = Make({"DESC"}); + bool reversed = false; + parser.Apply(Map(&reversed, "DESC", true, "ASC", false)); + EXPECT_TRUE(reversed); + EXPECT_FALSE(parser.HasError()); + } + { + auto parser = Make({"ASC"}); + bool reversed = true; + parser.Apply(Map(&reversed, "DESC", true, "ASC", false)); + EXPECT_FALSE(reversed); + EXPECT_FALSE(parser.HasError()); + } + // Unrelated tag leaves field untouched and stops Apply. + { + auto parser = Make({"OTHER"}); + bool reversed = false; + parser.Apply(Map(&reversed, "DESC", true, "ASC", false)); + EXPECT_FALSE(reversed); + EXPECT_TRUE(parser.HasNext()); + } + // Standalone Map allows repeated matches — last wins, no error. This matches Redis SORT + // semantics where "ASC DESC" is equivalent to "DESC". 
+ { + auto parser = Make({"DESC", "ASC"}); + bool reversed = true; + parser.Apply(Map(&reversed, "DESC", true, "ASC", false)); + EXPECT_FALSE(reversed); // ASC came last + EXPECT_FALSE(parser.HasError()); + } + { + auto parser = Make({"ASC", "DESC"}); + bool reversed = false; + parser.Apply(Map(&reversed, "DESC", true, "ASC", false)); + EXPECT_TRUE(reversed); // DESC came last + EXPECT_FALSE(parser.HasError()); + } + // OneOf + Map — DESC followed by ASC is a mutex violation. + { + auto parser = Make({"DESC", "ASC"}); + bool reversed = false; + parser.Apply(OneOf(Map(&reversed, "DESC", true, "ASC", false))); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.type, CmdArgParser::INVALID_CASES); + } +} + +TEST_F(CmdArgParserTest, ApplyTagNested) { + // Tag(tag, inner_opt) — outer tag matches, then inner option runs against the next arg. + // If the inner doesn't match, INVALID_CASES is reported (the inner keyword is required). + enum class Mode { A, B, C }; + { + auto parser = Make({"MODE", "B"}); + Mode mode = Mode::A; + parser.Apply(Tag("MODE", Map(&mode, "A", Mode::A, "B", Mode::B, "C", Mode::C))); + EXPECT_EQ(mode, Mode::B); + EXPECT_FALSE(parser.HasError()); + } + // Unknown inner tag -> INVALID_CASES. + { + auto parser = Make({"MODE", "BOGUS"}); + Mode mode = Mode::A; + parser.Apply(Tag("MODE", Map(&mode, "A", Mode::A, "B", Mode::B))); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.type, CmdArgParser::INVALID_CASES); + } + // Outer tag absent -> no effect, no error. + { + auto parser = Make({}); + Mode mode = Mode::A; + parser.Apply(Tag("MODE", Map(&mode, "A", Mode::A, "B", Mode::B))); + EXPECT_EQ(mode, Mode::A); + EXPECT_FALSE(parser.HasError()); + } +} + +TEST_F(CmdArgParserTest, ApplyTagIf) { + // If(cond, opt) behaves like `opt` when cond is true, and never matches when false. + // Use to gate an option on a runtime flag (e.g. is_read_only). + + // cond=true -> delegate to inner (matches and sets field). + { + auto parser = Make({"STORE", "dest"}); + std::string_view store; + parser.Apply(If(true, Tag("STORE", &store))); + EXPECT_EQ(store, "dest"); + EXPECT_FALSE(parser.HasError()); + } + + // cond=false -> inner is skipped. Apply stops at the (now unmatched) arg; Finalize reports + // UNPROCESSED so the caller can surface a syntax error. + { + auto parser = Make({"STORE", "dest"}); + std::string_view store; + parser.Apply(If(false, Tag("STORE", &store))); + EXPECT_EQ(store, ""); + EXPECT_FALSE(parser.HasError()); + EXPECT_TRUE(parser.HasNext()); + EXPECT_FALSE(parser.Finalize()); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.type, CmdArgParser::UNPROCESSED); + } + + // Composes: cond=false + Exist - does not toggle the bool even when the tag is present. + { + auto parser = Make({"FLAG"}); + bool flag = false; + parser.Apply(If(false, Exist("FLAG", &flag))); + EXPECT_FALSE(flag); + } +} + +TEST_F(CmdArgParserTest, ApplyOneOf) { + // OneOf groups mutually-exclusive options. Zero or one may match across the Apply loop. + // A second match reports an error instead of being quietly accepted. + + // Zero matches — fine. + { + auto parser = Make({}); + bool nx = false, xx = false; + parser.Apply(OneOf(Exist("NX", &nx), Exist("XX", &xx))); + EXPECT_FALSE(nx); + EXPECT_FALSE(xx); + EXPECT_FALSE(parser.HasError()); + } + + // Single match — fine. 
+ { + auto parser = Make({"NX"}); + bool nx = false, xx = false; + parser.Apply(OneOf(Exist("NX", &nx), Exist("XX", &xx))); + EXPECT_TRUE(nx); + EXPECT_FALSE(xx); + EXPECT_FALSE(parser.HasError()); + } + + // Two different members of the group match -> error. + { + auto parser = Make({"NX", "XX"}); + bool nx = false, xx = false; + parser.Apply(OneOf(Exist("NX", &nx), Exist("XX", &xx))); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.type, CmdArgParser::INVALID_CASES); + } + + // Same member twice also counts as a second match -> error. + { + auto parser = Make({"NX", "NX"}); + bool nx = false, xx = false; + parser.Apply(OneOf(Exist("NX", &nx), Exist("XX", &xx))); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.type, CmdArgParser::INVALID_CASES); + } + + // OneOf composes with other Apply options. Unrelated tags are not affected. + { + auto parser = Make({"NX", "COUNT", "5"}); + bool nx = false, xx = false; + uint32_t count = 0; + parser.Apply(OneOf(Exist("NX", &nx), Exist("XX", &xx)), Tag("COUNT", &count)); + EXPECT_TRUE(nx); + EXPECT_EQ(count, 5u); + EXPECT_FALSE(parser.HasError()); + } +} + +TEST_F(CmdArgParserTest, ApplyOptional) { + // Tag present -> optional engaged. + { + auto parser = Make({"COUNT", "5"}); + std::optional count; + parser.Apply(Tag("COUNT", &count)); + ASSERT_TRUE(count.has_value()); + EXPECT_EQ(*count, 5u); + EXPECT_FALSE(parser.HasError()); + } + // Tag absent -> optional stays empty. + { + auto parser = Make({}); + std::optional count; + parser.Apply(Tag("COUNT", &count)); + EXPECT_FALSE(count.has_value()); + EXPECT_FALSE(parser.HasError()); + } + // Invalid value -> INVALID_INT reported. The optional's state on error is undefined; callers + // must check for the parse error first. + { + auto parser = Make({"COUNT", "NAN"}); + std::optional count; + parser.Apply(Tag("COUNT", &count)); + auto err = parser.TakeError(); + EXPECT_TRUE(err); + EXPECT_EQ(err.type, CmdArgParser::INVALID_INT); + } +} + TEST_F(CmdArgParserTest, FixedRangeInt) { { auto parser = Make({"10", "-10", "12"}); diff --git a/src/facade/dragonfly_connection.cc b/src/facade/dragonfly_connection.cc index e44fe2a5b30c..6aa6e67efe7a 100644 --- a/src/facade/dragonfly_connection.cc +++ b/src/facade/dragonfly_connection.cc @@ -157,16 +157,23 @@ bool MatchHttp11Line(string_view line) { absl::EndsWith(line, "HTTP/1.1"); } -void UpdateIoBufCapacity(const io::IoBuf& io_buf, ConnectionStats* stats, - absl::FunctionRef f) { - const size_t prev_capacity = io_buf.Capacity(); - f(); - const size_t capacity = io_buf.Capacity(); - if (prev_capacity != capacity) { - VLOG(2) << "Grown io_buf to " << capacity; - stats->read_buf_capacity += capacity - prev_capacity; +struct ReadBufTracker { + explicit ReadBufTracker(const io::IoBuf& io_buf) + : io_buf_(io_buf), last_capacity_(io_buf.Capacity()) { + } + + ~ReadBufTracker() { + size_t capacity = io_buf_.Capacity(); + if (last_capacity_ != capacity) { + VLOG(2) << "Grown io_buf to " << capacity; + tl_facade_stats->conn_stats.read_buf_capacity += capacity - last_capacity_; + } } -} + + private: + const io::IoBuf& io_buf_; + size_t last_capacity_; +}; size_t UsedMemoryInternal(const ParsedCommand& msg) { return msg.GetSize() + msg.HeapMemory(); @@ -177,6 +184,10 @@ struct TrafficLogger { // Also, makes sure that LogTraffic are executed atomically. fb2::Mutex mutex; unique_ptr log_file; + // Listener type that this thread's file is recording. 
Only connections with a + // matching `listener_type_` produce records; others are skipped on the hot path. + // Set once when the file is opened, cleared in ResetLocked(). + Connection::ListenerType listener_type = Connection::ListenerType::MAIN_RESP; void ResetLocked(); // Returns true if Write succeeded, false if it failed and the recording should be aborted. @@ -189,6 +200,7 @@ void TrafficLogger::ResetLocked() { std::ignore = log_file->Close(); log_file.reset(); } + listener_type = Connection::ListenerType::MAIN_RESP; } // Returns true if Write succeeded, false if it failed and the recording should be aborted. @@ -218,10 +230,16 @@ thread_local base::Histogram* io_req_size_hist = nullptr; thread_local const size_t reply_size_limit = absl::GetFlag(FLAGS_squashed_reply_size_limit); thread_local uint32 pipeline_wait_batch_usec = absl::GetFlag(FLAGS_pipeline_wait_batch_usec); -void OpenTrafficLogger(string_view base_path) { +// Opens the per-thread traffic log file. Distinguishes three outcomes so the caller +// can report an accurate error to the user (was the logger already running, or did +// we fail to open a file). `listener_type` is only committed after the file is +// successfully opened so the logger's state stays consistent on failure. +Connection::StartTrafficResult OpenTrafficLogger(string_view base_path, + Connection::ListenerType listener_type) { + using Res = Connection::StartTrafficResult; unique_lock lk{tl_traffic_logger.mutex}; if (tl_traffic_logger.log_file) - return; + return Res::kAlreadyLogging; #ifdef __linux__ // Open file with append mode, without it concurrent fiber writes seem to conflict @@ -230,21 +248,30 @@ void OpenTrafficLogger(string_view base_path) { auto file = util::fb2::OpenWrite(path, io::WriteFile::Options{/*.append = */ false}); if (!file) { LOG(ERROR) << "Error opening a file " << path << " for traffic logging: " << file.error(); - return; + return Res::kOpenFailed; } tl_traffic_logger.log_file = unique_ptr{file.value()}; + tl_traffic_logger.listener_type = listener_type; #else LOG(WARNING) << "Traffic logger is only supported on Linux"; + return Res::kOpenFailed; #endif - // Write version, incremental numbering :) - uint8_t version[1] = {2}; - std::ignore = tl_traffic_logger.log_file->Write(version); + // File header: version byte (v3), followed by a single byte carrying the listener + // type for the whole file. Every record in the file belongs to this listener. + uint8_t header[2] = {3, static_cast(listener_type)}; + std::ignore = tl_traffic_logger.log_file->Write(header); + return Res::kStarted; } -void LogTraffic(uint32_t id, bool has_more, const cmn::BackedArguments& args, - ServiceInterface::ContextInfo ci) { - string_view cmd = args.Front(); +// Writes a single record. `parts[0]` is the command name, following entries are its arguments. +// Callers must guarantee a non-empty span (both LogTraffic and LogMemcacheTraffic push +// the command name as the first element before invoking this function). +void LogTrafficParts(uint32_t id, bool has_more, uint32_t db_index, + absl::Span parts) { + DCHECK(!parts.empty()); + + string_view cmd = parts.front(); if (absl::EqualsIgnoreCase(cmd, "debug"sv)) return; @@ -253,26 +280,22 @@ void LogTraffic(uint32_t id, bool has_more, const cmn::BackedArguments& args, char stack_buf[1024]; char* next = stack_buf; - // We write id, timestamp, db_index, has_more, num_parts, part_len, part_len, part_len, ... - // And then all the part blobs concatenated together. 
+ // Record header: id, timestamp, db_index, has_more, num_parts, followed by + // part_len, part_len, ... and finally the concatenated part blobs. + // The listener type is stored once in the file header; it is not repeated per record. auto write_u32 = [&next](uint32_t i) { absl::little_endian::Store32(next, i); next += 4; }; - // id write_u32(id); - // timestamp absl::little_endian::Store64(next, absl::GetCurrentTimeNanos()); next += 8; - // db_index - write_u32(ci.db_index); - - // has_more, num_parts - write_u32(has_more ? 1 : 0); - write_u32(uint32_t(args.size())); + write_u32(db_index); + write_u32(has_more ? 1u : 0u); + write_u32(uint32_t(parts.size())); // Grab the lock and check if the file is still open. lock_guard lk{tl_traffic_logger.mutex}; @@ -280,7 +303,7 @@ void LogTraffic(uint32_t id, bool has_more, const cmn::BackedArguments& args, return; // part_len, ... - for (auto part : args.view()) { + for (string_view part : parts) { if (size_t(next - stack_buf + 4) > sizeof(stack_buf)) { if (!tl_traffic_logger.Write(string_view{stack_buf, size_t(next - stack_buf)})) { return; @@ -297,7 +320,7 @@ void LogTraffic(uint32_t id, bool has_more, const cmn::BackedArguments& args, blobs[index++] = iovec{.iov_base = stack_buf, .iov_len = size_t(next - stack_buf)}; } - for (auto part : args.view()) { + for (string_view part : parts) { if (auto blob_len = part.size(); blob_len > 0) { blobs[index++] = iovec{.iov_base = const_cast(part.data()), .iov_len = blob_len}; @@ -315,6 +338,90 @@ void LogTraffic(uint32_t id, bool has_more, const cmn::BackedArguments& args, } } +void LogTraffic(uint32_t id, bool has_more, const cmn::BackedArguments& args, + ServiceInterface::ContextInfo ci) { + absl::InlinedVector parts; + parts.reserve(args.size()); + for (auto v : args.view()) + parts.push_back(v); + LogTrafficParts(id, has_more, ci.db_index, absl::MakeSpan(parts)); +} + +// Variant used by the Memcache protocol path. +// +// The memcache parser keeps fields that are NOT arguments in scalar Command members +// (flags, expire_ts, delta, cas_unique) rather than in `backed_args`. We serialize +// them into the record so that tools/replay has enough context to reproduce the +// command faithfully. Record layout per command type: +// +// SET/ADD/REPLACE/APPEND/PREPEND : [cmd, key, value, flags, expire_ts] +// CAS : [cas, key, value, flags, expire_ts, cas_unique] +// INCR/DECR : [cmd, key, delta] +// GAT/GATS : [cmd, expire_ts, key+] (expire BEFORE keys, matches wire) +// all others (GET/GETS/DELETE/ +// FLUSHALL/STATS/ +// QUIT/VERSION) : [cmd, *backed_args] +void LogMemcacheTraffic(uint32_t id, bool has_more, const MemcacheParser::Command& mc, + ServiceInterface::ContextInfo ci) { + using MP = MemcacheParser; + string_view cmd_name = MP::CmdName(mc.type); + if (cmd_name.empty()) + return; + + // owned backs stringified numeric fields. We use a fixed-size std::array + // rather than a resizable vector so that string_views inserted into `parts` + // remain stable even if more fields are appended in the future: std::array + // never reallocates. kMaxOwned must be >= the largest per-type push count + // (currently 3, for CAS: flags + expire_ts + cas_unique). 
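For tooling, the v3 layout spelled out above decodes mechanically. A minimal reader sketch; ReadExact is a hypothetical helper that fills a buffer from the log stream, and the include path for absl's endian helpers is approximate:

    #include <cstdint>
    #include <string>
    #include <vector>

    #include "absl/base/internal/endian.h"  // little_endian::Load32/Load64

    bool ReadExact(void* buf, size_t len);  // hypothetical: false on EOF

    // The file starts with a version byte (3) and a listener-type byte, read once
    // up front (not shown). Each record then follows the layout described above.
    bool ReadRecord() {
      uint8_t hdr[24];  // id(4) + timestamp(8) + db_index(4) + has_more(4) + num_parts(4)
      if (!ReadExact(hdr, sizeof(hdr)))
        return false;
      uint32_t id = absl::little_endian::Load32(hdr);
      int64_t ts = absl::little_endian::Load64(hdr + 4);
      uint32_t db_index = absl::little_endian::Load32(hdr + 12);
      uint32_t has_more = absl::little_endian::Load32(hdr + 16);
      uint32_t num_parts = absl::little_endian::Load32(hdr + 20);

      std::vector<std::string> parts(num_parts);
      for (auto& p : parts) {  // part_len per part, little-endian u32
        uint8_t len_buf[4];
        ReadExact(len_buf, sizeof(len_buf));
        p.resize(absl::little_endian::Load32(len_buf));
      }
      for (auto& p : parts)  // then all blobs, concatenated
        ReadExact(p.data(), p.size());
      (void)id; (void)ts; (void)db_index; (void)has_more;
      return true;
    }

Returning to the writer: the fixed-size array below backs the stringified numeric fields.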
+  constexpr size_t kMaxOwned = 4;
+  std::array<std::string, kMaxOwned> owned;
+  size_t owned_n = 0;
+
+  absl::InlinedVector<string_view, 16> parts;
+  parts.reserve(mc.backed_args->size() + kMaxOwned + 1);
+  parts.push_back(cmd_name);
+
+  auto push_num = [&](uint64_t n) {
+    DCHECK_LT(owned_n, kMaxOwned);
+    owned[owned_n] = absl::StrCat(n);
+    parts.push_back(owned[owned_n]);
+    ++owned_n;
+  };
+
+  // For GAT/GATS we want expire_ts to precede the key list because the parser can
+  // push multiple keys into backed_args; placing expire at the end would make the
+  // expire index depend on the number of keys.
+  if (mc.type == MP::GAT || mc.type == MP::GATS)
+    push_num(mc.raw_expire_ts);
+
+  for (string_view a : mc.backed_args->view())
+    parts.push_back(a);
+
+  switch (mc.type) {
+    case MP::SET:
+    case MP::ADD:
+    case MP::REPLACE:
+    case MP::APPEND:
+    case MP::PREPEND:
+      push_num(mc.flags);
+      push_num(mc.raw_expire_ts);
+      break;
+    case MP::CAS:
+      push_num(mc.flags);
+      push_num(mc.raw_expire_ts);
+      push_num(mc.cas_unique);
+      break;
+    case MP::INCR:
+    case MP::DECR:
+      push_num(mc.delta);
+      break;
+    default:
+      break;
+  }
+
+  LogTrafficParts(id, has_more, ci.db_index, absl::MakeSpan(parts));
+}
+
 constexpr size_t kMinReadSize = 256;
 
 const char* kPhaseName[Connection::NUM_PHASES] = {"SETUP", "READ", "PROCESS", "SHUTTING_DOWN",
@@ -722,6 +829,14 @@ void Connection::OnConnectionStart() {
   // is null in unit-tests.
   if (const Listener* lsnr = static_cast<const Listener*>(listener()); lsnr) {
     is_main_ = lsnr->IsMainInterface();
+    if (lsnr->IsPrivilegedInterface()) {
+      listener_type_ = ListenerType::ADMIN_RESP;
+    } else if (protocol_ == Protocol::MEMCACHE) {
+      listener_type_ = ListenerType::MEMCACHE;
+    } else {
+      // MAIN_RESP covers the TCP main listener as well as unix-socket RESP listeners.
+      listener_type_ = ListenerType::MAIN_RESP;
+    }
   }
 
   if (GetFlag(FLAGS_tcp_nodelay) && !socket_->IsUDS()) {
@@ -1049,7 +1164,6 @@ io::Result<bool> Connection::CheckForHttpProto() {
   size_t last_len = 0;
   auto* peer = socket_.get();
-  auto& conn_stats = tl_facade_stats->conn_stats;
 
   do {
     auto buf = io_buf_.AppendBuffer();
     DCHECK(!buf.empty());
@@ -1082,7 +1196,10 @@ io::Result<bool> Connection::CheckForHttpProto() {
       return MatchHttp11Line(ib);
     }
     last_len = io_buf_.InputLen();
-    UpdateIoBufCapacity(io_buf_, &conn_stats, [&]() { io_buf_.EnsureCapacity(128); });
+    {
+      ReadBufTracker tracker(io_buf_);
+      io_buf_.EnsureCapacity(128);
+    }
   } while (last_len < 1024);
 
   return false;
@@ -1108,7 +1225,7 @@ void Connection::ConnectionFlow() {
   if (io_buf_.InputLen() > 0) {
     phase_ = PROCESS;
     if (redis_parser_ && !ioloop_v2_) {
-      parse_status = ParseRedis(10000);
+      parse_status = ParseRedis(io_buf_, 10000);
    } else {
      parse_status = ParseLoop();
    }
@@ -1118,7 +1235,10 @@
   // Main loop.
if (parse_status != ERROR && !ec) { - UpdateIoBufCapacity(io_buf_, &conn_stats, [&]() { io_buf_.EnsureCapacity(64); }); + { + ReadBufTracker tracker(io_buf_); + io_buf_.EnsureCapacity(64); + } variant res; if (ioloop_v2_) { res = IoLoopV2(); @@ -1266,7 +1386,8 @@ void Connection::DispatchSingle(bool has_more, absl::FunctionRef invoke_ } } -Connection::ParserStatus Connection::ParseRedis(unsigned max_busy_cycles, bool enqueue_only) { +Connection::ParserStatus Connection::ParseRedis(base::IoBuf& io_buf, unsigned max_busy_cycles, + bool enqueue_only) { uint32_t consumed = 0; RespSrvParser::Result result = RespSrvParser::OK; @@ -1281,7 +1402,7 @@ Connection::ParserStatus Connection::ParseRedis(unsigned max_busy_cycles, bool e auto* cmd = std::exchange(parsed_cmd_, ptr.release()); EnqueueParsedCommand(cmd); }; - io::Bytes read_buffer = io_buf_.InputBuffer(); + io::Bytes read_buffer = io_buf.InputBuffer(); // Keep track of total bytes consumed/parsed. The do/while{} loop below preempts, // and InputBuffer() size might change between preemption points. There is a corner case, // that ConsumeInput() will strip a portion of the request which makes the test_publish_stuck @@ -1302,7 +1423,7 @@ Connection::ParserStatus Connection::ParseRedis(unsigned max_busy_cycles, bool e request_consumed_bytes_ = 0; bool has_more = consumed < read_buffer.size(); - if (tl_traffic_logger.log_file && IsMain() /* log only on the main interface */) { + if (tl_traffic_logger.log_file && tl_traffic_logger.listener_type == listener_type_) { LogTraffic(id_, has_more, *parsed_cmd_, service_->GetContextInfo(cc_.get())); } @@ -1336,7 +1457,7 @@ Connection::ParserStatus Connection::ParseRedis(unsigned max_busy_cycles, bool e } } while (RespSrvParser::OK == result && read_buffer.size() > 0 && !reply_builder_->GetError()); - io_buf_.ConsumeInput(total_consumed); + io_buf.ConsumeInput(total_consumed); parser_error_ = result; if (result == RespSrvParser::OK) @@ -1359,7 +1480,7 @@ auto Connection::ParseLoop() -> ParserStatus { bool commands_parsed = false; do { - commands_parsed = (this->*parse_func)(); + commands_parsed = (this->*parse_func)(io_buf_); if (!ExecuteBatch()) return ERROR; @@ -1476,10 +1597,10 @@ variant Connection::IoLoop() { } phase_ = PROCESS; - bool is_iobuf_full = io_buf_.AppendLen() == 0; + bool reached_capacity = io_buf_.AppendLen() == 0; if (redis_parser_) { - parse_status = ParseRedis(max_busy_read_cycles_cached); + parse_status = ParseRedis(io_buf_, max_busy_read_cycles_cached); } else { DCHECK(memcache_parser_); parse_status = ParseLoop(); @@ -1504,19 +1625,16 @@ variant Connection::IoLoop() { // (Note: The buffer object is only working in power-of-2 sizes, // so there's no danger of accidental O(n^2) behavior.) if (parser_hint > capacity) { - auto& conn_stats = GetLocalConnStats(); - UpdateIoBufCapacity(io_buf_, &conn_stats, - [&]() { io_buf_.Reserve(std::min(max_iobfuf_len, parser_hint)); }); + ReadBufTracker tracker(io_buf_); + io_buf_.Reserve(std::min(max_iobfuf_len, parser_hint)); } // If we got a partial request because iobuf was full, grow it up to // a reasonable limit to save on Recv() calls. - if (is_iobuf_full && capacity < max_iobfuf_len / 2) { - auto& conn_stats = GetLocalConnStats(); + if (reached_capacity && capacity < max_iobfuf_len / 2) { // Last io used most of the io_buf to the end. - UpdateIoBufCapacity(io_buf_, &conn_stats, [&]() { - io_buf_.Reserve(capacity * 2); // Valid growth range. - }); + ReadBufTracker tracker(io_buf_); + io_buf_.Reserve(capacity * 2); // Valid growth range. 
} if (io_buf_.AppendLen() == 0U) { @@ -2201,8 +2319,9 @@ void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest, bool force } } -void Connection::StartTrafficLogging(string_view path) { - OpenTrafficLogger(path); +Connection::StartTrafficResult Connection::StartTrafficLogging(string_view path, + ListenerType listener_type) { + return OpenTrafficLogger(path, listener_type); } void Connection::StopTrafficLogging() { @@ -2210,6 +2329,26 @@ void Connection::StopTrafficLogging() { tl_traffic_logger.ResetLocked(); } +void Connection::LogReplicaCommand(const cmn::BackedArguments& args, uint32_t db_index) { + // Contract: LSN/PING opcodes are filtered before ExecuteTx, and + // COMMAND/EXPIRED journal entries always carry at least a command name. + DCHECK(!args.empty()); + // Fast-path gate: cheap thread-local reads without the mutex. If the logger + // was swapped out concurrently, LogTrafficParts re-checks `log_file` inside + // the lock so at worst we do a bit of wasted work (building `parts`). + // id=0 is a synthetic client id — replication has no connection/client of its + // own, and callers on the same fiber serialise naturally. + if (!tl_traffic_logger.log_file || + tl_traffic_logger.listener_type != ListenerType::REPLICA_RESP) { + return; + } + absl::InlinedVector parts; + parts.reserve(args.size()); + for (auto v : args.view()) + parts.push_back(v); + LogTrafficParts(/*id=*/0, /*has_more=*/false, db_index, absl::MakeSpan(parts)); +} + bool Connection::IsHttp() const { return is_http_; } @@ -2285,7 +2424,7 @@ void Connection::BreakOnce(uint32_t ev_mask) { if (breaker_cb_) { DVLOG(1) << "[" << id_ << "] Connection::breaker_cb_ " << ev_mask; auto fun = std::move(breaker_cb_); - DCHECK(!breaker_cb_); + breaker_cb_ = nullptr; fun(ev_mask); } } @@ -2306,7 +2445,7 @@ bool Connection::IsReplySizeOverLimit() const { return over_limit; } -bool Connection::ParseRedisBatch() { +bool Connection::ParseRedisBatch(base::IoBuf& buf) { QueueBackpressure& qbp = GetQueueBackpressure(); // Only throttle parsing if this connection is actively contributing to the queue. 
@@ -2319,11 +2458,11 @@
     GetLocalConnStats().pipeline_throttle_count++;
     return false;
   }
-  return ParseRedis(max_busy_read_cycles_cached, true) == ParserStatus::OK;
+  return ParseRedis(buf, max_busy_read_cycles_cached, true) == ParserStatus::OK;
 }
 
-bool Connection::ParseMCBatch() {
-  CHECK(io_buf_.InputLen() > 0);
+bool Connection::ParseMCBatch(base::IoBuf& io_buf) {
+  CHECK(io_buf.InputLen() > 0);
 
   do {
     if (parsed_cmd_ == nullptr) {
@@ -2333,15 +2472,22 @@
     }
     uint32_t consumed = 0;
     memcache_parser_->set_last_unix_time(time(nullptr));
-    MemcacheParser::Result result = memcache_parser_->Parse(io::View(io_buf_.InputBuffer()),
+    MemcacheParser::Result result = memcache_parser_->Parse(io::View(io_buf.InputBuffer()),
                                                             &consumed, parsed_cmd_->mc_command());
-    io_buf_.ConsumeInput(consumed);
+    io_buf.ConsumeInput(consumed);
 
     DVLOG(2) << "mc_result " << unsigned(result) << " consumed: " << consumed << " type "
              << unsigned(parsed_cmd_->mc_command()->type);
 
     if (result == MemcacheParser::INPUT_PENDING)
       return false;
 
+    if (result == MemcacheParser::OK && tl_traffic_logger.log_file &&
+        tl_traffic_logger.listener_type == listener_type_) {
+      bool has_more = io_buf.InputLen() > 0;
+      LogMemcacheTraffic(id_, has_more, *parsed_cmd_->mc_command(),
+                         service_->GetContextInfo(cc_.get()));
+    }
+
     // We push the command to the parsed queue even in case of parse errors,
     // so that we can reply in order.
     EnqueueParsedCommand(parsed_cmd_);
@@ -2370,7 +2516,7 @@
         break;
       }
     }
-  } while (parsed_cmd_q_len_ < 128 && io_buf_.InputLen() > 0);
+  } while (parsed_cmd_q_len_ < 128 && io_buf.InputLen() > 0);
 
   return true;
 }
 
@@ -2453,12 +2599,12 @@ bool Connection::ExecuteBatch() {
 
 bool Connection::ReplyBatch() {
   reply_builder_->SetBatchMode(true);
-  while (HasDispatchedCommands() && parsed_head_->CanReply()) {
+  while (HasInFlightCommands() && parsed_head_->CanReply()) {
     current_wait_.reset();  // Clear the subscription before moving to the next command
     auto* cmd = parsed_head_;
     parsed_head_ = cmd->next;
     cmd->SendReply();
-    ReleaseParsedCommand(cmd, HasDispatchedCommands() /* is_pipelined */);
+    ReleaseParsedCommand(cmd, HasInFlightCommands() /* is_pipelined */);
     if (reply_builder_->GetError())
       return false;
   }
@@ -2473,7 +2619,18 @@
   }
 
   reply_builder_->SetBatchMode(false);
-  reply_builder_->Flush();
+
+  // V1 handles its pipeline batching inside AsyncFiber, so it flushes unconditionally here.
+  //
+  // V2 operates as a single-fiber event loop where reading, parsing, and executing happen
+  // sequentially. Because ParseLoop processes pipelines in chunks, flushing here would trigger a
+  // sendmsg syscall for every single chunk. Instead, V2 delegates flushing to IoLoopV2, which
+  // safely flushes the coalesced buffer right before the fiber yields (await) or when memory limits
+  // are reached.
+  if (!ioloop_v2_) {
+    reply_builder_->Flush();
+  }
+
   return !reply_builder_->GetError();
 }
 
@@ -2663,8 +2820,10 @@ void Connection::NotifyOnRecv(const util::FiberSocketBase::RecvNotification& n)
     pending_input_ = true;
   } else if (std::holds_alternative<io::MutableBytes>(n.read_result)) {  // provided buffer.
    io::MutableBytes buf = std::get<io::MutableBytes>(n.read_result);
-    UpdateIoBufCapacity(io_buf_, &tl_facade_stats->conn_stats,
-                        [&]() { io_buf_.WriteAndCommit(buf.data(), buf.size()); });
+    {
+      ReadBufTracker tracker(io_buf_);
+      io_buf_.WriteAndCommit(buf.data(), buf.size());
+    }
     last_interaction_ = time(nullptr);
   } else {
     LOG(FATAL) << "Should not reach here";
   }
@@ -2700,11 +2859,10 @@ void Connection::ReadPendingInput() {
   }
 }
 
-void Connection::CheckIoBufCapacity(bool is_iobuf_full) {
-  auto& conn_stats = tl_facade_stats->conn_stats;
+void Connection::CheckIoBufCapacity(bool reached_capacity, base::IoBuf* io_buf) {
   size_t max_io_buf_len = GetFlag(FLAGS_max_client_iobuf_len);
 
-  size_t capacity = io_buf_.Capacity();
+  size_t capacity = io_buf->Capacity();
   if (capacity < max_io_buf_len) {
     size_t parser_hint = 0;
     if (redis_parser_)
@@ -2716,23 +2874,22 @@ void Connection::CheckIoBufCapacity(bool is_iobuf_full) {
     // (Note: The buffer object is only working in power-of-2 sizes,
     // so there's no danger of accidental O(n^2) behavior.)
     if (parser_hint > capacity) {
-      UpdateIoBufCapacity(io_buf_, &conn_stats,
-                          [&]() { io_buf_.Reserve(std::min(max_io_buf_len, parser_hint)); });
+      ReadBufTracker tracker(*io_buf);
+      io_buf->Reserve(std::min(max_io_buf_len, parser_hint));
     }
 
     // If we got a partial request because iobuf was full, grow it up to
     // a reasonable limit to save on Recv() calls.
-    if (is_iobuf_full && capacity < max_io_buf_len / 2) {
+    if (reached_capacity && capacity < max_io_buf_len / 2) {
       // Last io used most of the io_buf to the end.
-      UpdateIoBufCapacity(io_buf_, &conn_stats, [&]() {
-        io_buf_.Reserve(capacity * 2);  // Valid growth range.
-      });
+      ReadBufTracker tracker(*io_buf);
+      io_buf->Reserve(capacity * 2);  // Valid growth range.
     }
 
-    if (io_buf_.AppendLen() == 0U) {
+    if (io_buf->AppendLen() == 0U) {
      // it can happen with memcached but not for RedisParser, because RedisParser fully
      // consumes the passed buffer
-      LOG_EVERY_T(WARNING, 10) << "Maximum io_buf length reached " << io_buf_.Capacity()
+      LOG_EVERY_T(WARNING, 10) << "Maximum io_buf length reached " << io_buf->Capacity()
                                << ", consider to increase max_client_iobuf_len flag";
     }
   }
@@ -2795,7 +2952,7 @@ variant<error_code, ParserStatus> Connection::IoLoopV2() {
   HandleMigrateRequest();
 
   // Register completion for current head if it's pending and we don't wait on current_wait_.
-  if (HasDispatchedCommands() && !current_wait_.has_value()) {
+  if (HasInFlightCommands() && !current_wait_.has_value()) {
     current_wait_.emplace(parsed_head_, &cmd_completion_waiter);
   }
 
@@ -2807,26 +2964,33 @@
   if (io_buf_.InputLen() == 0) {
     phase_ = READ_SOCKET;
 
+    // Flush replies deferred by ReplyBatch before sleeping - ensures the client
+    // gets its response even when no more data arrives (single commands, end of pipeline).
+    reply_builder_->Flush();
+    if (auto err = reply_builder_->GetError(); err) {
+      return err;
+    }
+
     io_event_.await([this, &is_ready_to_migrate]() {
       // TODO: optimize CanReply with looking up waiter key
       // io_buf_.InputLen() > 0 is still needed for multishot flow.
      // We wake up if:
      // 1. New data arrived or is pending (io_buf_.InputLen() > 0 || pending_input_).
-      // 2. A parsed command is ready to execute (HeadReadyToDispatch()).
+      // 2. A parsed command is ready to execute (HasCommandToExecute()).
      // 3. An executed command is ready to send its reply (parsed_head_ &&
      //    parsed_head_->CanReply()).
      // 4. Control-plane messages arrived (!dispatch_q_.empty()).
      // 5. The socket encountered an error/closed (io_ec_).
      // 6.
A migration to another thread was requested AND is actionable now (no subscriptions). - return io_buf_.InputLen() > 0 || pending_input_ || HeadReadyToDispatch() || + return io_buf_.InputLen() > 0 || pending_input_ || HasCommandToExecute() || (parsed_head_ && parsed_head_->CanReply()) || !dispatch_q_.empty() || io_ec_ || is_ready_to_migrate(); }); } phase_ = PROCESS; - bool is_iobuf_full = io_buf_.AppendLen() == 0; + bool reached_capacity = io_buf_.AppendLen() == 0; // Temporary: Handle dispatch queue items (Control Path) one by one blocking command execution if (!dispatch_q_.empty()) { @@ -2844,10 +3008,9 @@ variant Connection::IoLoopV2() { std::visit(AsyncOperations{reply_builder_.get(), this}, msg.handle); } - // TODO: Possibly don't flush unconditionally - optimize it - reply_builder_->Flush(); - if (auto ec = reply_builder_->GetError(); ec) - return ec; + // Note: No flush needed here: the `continue` below re-enters the loop, which either + // hits the data path (ParseLoop flushes via ReplyBatch) or the idle-await block + // (Flush 1), which always flushes before sleeping. // TODO: Properly handle backpressure GetQueueBackpressure().pubsub_ec.notifyAll(); @@ -2880,7 +3043,7 @@ variant Connection::IoLoopV2() { size_t mem_before = conn_stats.pipeline_queue_bytes; if (parsed_head_) { - if (HeadReadyToDispatch()) + if (HasCommandToExecute()) ExecuteBatch(); ReplyBatch(); } @@ -2916,6 +3079,12 @@ variant Connection::IoLoopV2() { // us "deaf" to future memory relief. auto sub_key = qbp.v2_pipeline_backpressure_ec.subscribe_persistent(&backpressure_waiter); + // Client needs replies to free its send buffer and relieve backpressure. + reply_builder_->Flush(); + if (auto err = reply_builder_->GetError(); err) { + return err; + } + io_event_.await([this, &is_ready_to_migrate]() { bool cmd_ready = parsed_head_ && parsed_head_->CanReply(); bool under_limit = !GetQueueBackpressure().IsPipelineBufferOverLimit( @@ -2939,6 +3108,10 @@ variant Connection::IoLoopV2() { // Check io_ec_ after parsing and flushing replies, so that half-closed // connections get their responses before we close. if (io_ec_) { + reply_builder_->Flush(); + if (auto err = reply_builder_->GetError(); err) { + return err; + } LOG_IF(WARNING, cntx()->replica_conn) << "async io error: " << io_ec_; return std::exchange(io_ec_, {}); } @@ -2949,12 +3122,18 @@ variant Connection::IoLoopV2() { // Migration requested and actionable: skip buffer bookkeeping, jump to HandleMigrateRequest(). if (is_ready_to_migrate()) { + // Flush before migrating: handing off unflushed thread-local buffers to a + // new thread will cause data corruption or a hard crash. + reply_builder_->Flush(); + if (auto err = reply_builder_->GetError(); err) { + return err; // Connection is dead, no point migrating it cross-thread. + } continue; } if (parse_status == NEED_MORE) { parse_status = OK; - CheckIoBufCapacity(is_iobuf_full); + CheckIoBufCapacity(reached_capacity, &io_buf_); } } while (peer->IsOpen()); diff --git a/src/facade/dragonfly_connection.h b/src/facade/dragonfly_connection.h index 3b6d1792f10a..7279030f1833 100644 --- a/src/facade/dragonfly_connection.h +++ b/src/facade/dragonfly_connection.h @@ -191,6 +191,24 @@ class Connection : public util::Connection { // This method returns true for customer facing listeners. bool IsMainOrMemcache() const; + // Classification of the traffic source for the DEBUG TRAFFIC recorder. 
+ // Persisted as the second byte of the file header in the on-disk format; + // the numeric values are part of that format — do not change them. + // MAIN_RESP / ADMIN_RESP / REPLICA_RESP all carry RESP-format commands; + // MAIN vs ADMIN differ by the port they were accepted on, while REPLICA + // covers commands that arrived on a replica via the replication stream + // (not from a client-facing listener). + enum class ListenerType : uint8_t { + MAIN_RESP = 1, // main RESP listener (TCP and unix-socket) + MEMCACHE = 2, // memcached protocol listener + ADMIN_RESP = 3, // privileged / admin listener (RESP protocol on admin port) + REPLICA_RESP = 4, // commands arriving on a replica from its master + }; + + ListenerType GetListenerType() const { + return listener_type_; + } + void SetName(std::string name); void SetLibName(std::string name); @@ -218,14 +236,32 @@ class Connection : public util::Connection { // and only when the flag --migrate_connections is true. void RequestAsyncMigration(util::fb2::ProactorBase* dest, bool force); + // Outcome of a StartTrafficLogging call on a single thread. + enum class StartTrafficResult : uint8_t { + kStarted, // new recording started successfully on this thread + kAlreadyLogging, // this thread already had an active recording (noop) + kOpenFailed, // failed to open the log file (or unsupported platform) + }; + // Starts traffic logging in the calling thread. Must be a proactor thread. - // Each thread creates its own log file combining requests from all the connections in - // that thread. A noop if the thread is already logging. - static void StartTrafficLogging(std::string_view base_path); + // Each thread creates its own log file containing requests from connections on + // that thread whose listener type equals `listener_type`. Exactly one listener + // kind per recording — mixing protocols in a single file is not supported. + static StartTrafficResult StartTrafficLogging(std::string_view base_path, + ListenerType listener_type); // Stops traffic logging in this thread. A noop if the thread is not logging. static void StopTrafficLogging(); + // Writes a single command to the per-thread traffic log if (and only if) the + // logger on this thread is currently recording the REPLICA_RESP source. + // Used by the replication read path on replicas to capture commands that + // arrived from the master — they do not travel through a Connection, so the + // regular per-connection hot path does not see them. + // `db_index` is the database the command should be applied to; it is stored + // in the record so replay tools can issue SELECT before dispatch. + static void LogReplicaCommand(const cmn::BackedArguments& args, uint32_t db_index); + // Get quick debug info for logs std::string DebugInfo() const; @@ -287,7 +323,7 @@ class Connection : public util::Connection { // Drains currently available bytes from socket into io_buf_ using non-blocking reads. void ReadPendingInput(); - void CheckIoBufCapacity(bool is_iobuf_full); + void CheckIoBufCapacity(bool reached_capacity, base::IoBuf* buf); // Main loop reading client messages and passing requests to dispatch queue. std::variant IoLoopV2(); @@ -319,7 +355,7 @@ class Connection : public util::Connection { // If add is true, stats are incremented, otherwise decremented. 
  void UpdateDispatchStats(const MessageHandle& msg, bool add);
 
-  ParserStatus ParseRedis(unsigned max_busy_cycles, bool enqueue_only = false);
+  ParserStatus ParseRedis(base::IoBuf& buf, unsigned max_busy_cycles, bool enqueue_only = false);
 
   void OnBreakCb(int32_t mask);
 
@@ -366,9 +402,13 @@ class Connection : public util::Connection {
   // Returns true if one or more commands were parsed from the read buffer,
   // and false if no complete commands could be parsed (for example, when
   // parsing is pending more input).
-  bool ParseMCBatch();
+  bool ParseMCBatch(base::IoBuf& buf);
 
-  bool ParseRedisBatch();
+  bool ParseRedisBatch(base::IoBuf& buf);
+
+  // Call the appropriate ParseMCBatch or ParseRedisBatch based on the protocol.
+  // Only CPU-bound work; must not perform I/O or fiber suspension.
+  void ParseFromBuffer(base::IoBuf& buf);
 
   // Call the appropriate ParseBatch function, then proceed with Execute and Reply while input remains
   ParserStatus ParseLoop();
 
@@ -453,13 +493,13 @@ class Connection : public util::Connection {
   size_t parsed_cmd_q_bytes_ = 0;
 
   // Returns true if there are dispatched commands that haven't been replied yet.
-  bool HasDispatchedCommands() const {
+  bool HasInFlightCommands() const {
     return parsed_head_ != parsed_to_execute_;
   }
 
-  // Returns true if the head command is ready to dispatch (nothing in-flight ahead of it).
-  bool HeadReadyToDispatch() const {
-    return parsed_head_ && !HasDispatchedCommands();
+  // Returns true if the head command is ready to execute (nothing in-flight ahead of it).
+  bool HasCommandToExecute() const {
+    return parsed_head_ && !HasInFlightCommands();
   }
 
   // Returns true if there are any commands pending in the parsed command queue or dispatch queue.
@@ -532,6 +572,8 @@ class Connection : public util::Connection {
     };
   };
 
+  ListenerType listener_type_ = ListenerType::MAIN_RESP;
+
   bool request_shutdown_ = false;
 };
 
diff --git a/src/facade/error.h b/src/facade/error.h
index aae0a72fc99e..88262358db30 100644
--- a/src/facade/error.h
+++ b/src/facade/error.h
@@ -52,4 +52,6 @@ inline constexpr char kRestrictDenied[] = "restrict_denied";
 inline constexpr char kNoGroupErrType[] = "no_group_error";
 inline constexpr char kNoAuthErrType[] = "no_auth";
 
+inline constexpr char kBloomFilterLoadInProgress[] = "bloom filter load in progress";
+
 }  // namespace facade
 
diff --git a/src/facade/facade_test.h b/src/facade/facade_test.h
index 76b6e090eea2..3394698b1f59 100644
--- a/src/facade/facade_test.h
+++ b/src/facade/facade_test.h
@@ -6,7 +6,7 @@
 #include
 
-#include
+#include
 
 #include
 #include
 
diff --git a/src/facade/memcache_parser.cc b/src/facade/memcache_parser.cc
index 6750c8a0fc11..841e3ae2bf22 100644
--- a/src/facade/memcache_parser.cc
+++ b/src/facade/memcache_parser.cc
@@ -94,6 +94,7 @@ MP::Result ParseStore(ArgSlice tokens, int64_t now, MP::Command* res, uint32_t m
     return MP::PARSE_ERROR;
   }
 
+  res->raw_expire_ts = expire_ts;
   res->expire_ts = ToAbsolute(expire_ts, now);
 
   if (res->type == MP::CAS && !absl::SimpleAtoi(tokens[4], &res->cas_unique)) {
@@ -126,6 +127,7 @@ MP::Result ParseValueless(ArgSlice tokens, int64_t now, MP::Command* res) {
     if (!absl::SimpleAtoi(tokens[0], &expire_ts)) {
       return MP::BAD_INT;
     }
+    res->raw_expire_ts = expire_ts;
     res->expire_ts = ToAbsolute(expire_ts, now);
     ++key_pos;
   }
@@ -277,6 +279,7 @@ MP::Result ParseMeta(ArgSlice tokens, int64_t now, MP::Command* res, uint32_t ma
       case 'T':
         if (!absl::SimpleAtoi(token.substr(1), &expire_ts))
           return MP::BAD_INT;
+        res->raw_expire_ts = expire_ts;
        res->expire_ts = ToAbsolute(expire_ts, now);
        if
(res->type == MP::GET) res->type = MP::GAT; @@ -473,4 +476,60 @@ auto MP::ConsumeValue(std::string_view str, uint32_t* consumed, Command* dest) - return val_len_to_read_ > 0 ? MP::INPUT_PENDING : MP::OK; } +// Inverse of the token map in From(): enum -> wire token. Only used by the +// traffic logger, which is off most of the time, so a switch is plenty. +string_view MP::CmdName(CmdType type) { + switch (type) { + case MP::SET: + return "set"sv; + case MP::ADD: + return "add"sv; + case MP::REPLACE: + return "replace"sv; + case MP::APPEND: + return "append"sv; + case MP::PREPEND: + return "prepend"sv; + case MP::CAS: + return "cas"sv; + case MP::GET: + return "get"sv; + case MP::GETS: + return "gets"sv; + case MP::GAT: + return "gat"sv; + case MP::GATS: + return "gats"sv; + case MP::STATS: + return "stats"sv; + case MP::INCR: + return "incr"sv; + case MP::DECR: + return "decr"sv; + case MP::DELETE: + return "delete"sv; + case MP::FLUSHALL: + return "flush_all"sv; + case MP::QUIT: + return "quit"sv; + case MP::VERSION: + return "version"sv; + case MP::META_NOOP: + return "mn"sv; + case MP::META_SET: + return "ms"sv; + case MP::META_DEL: + return "md"sv; + case MP::META_ARITHM: + return "ma"sv; + case MP::META_GET: + return "mg"sv; + case MP::META_DEBUG: + return "me"sv; + case MP::INVALID: + return ""sv; + } + return ""sv; +} + } // namespace facade diff --git a/src/facade/memcache_parser.h b/src/facade/memcache_parser.h index 68127ef0828b..ad008ec12c5f 100644 --- a/src/facade/memcache_parser.h +++ b/src/facade/memcache_parser.h @@ -90,6 +90,12 @@ class MemcacheParser { int64_t expire_ts = 0; // unix time (expire_ts > month) in seconds + // Original, pre-ToAbsolute exptime token as sent by the client. Kept so that + // tools/replay can reproduce the exact wire command: relative exptimes stay + // relative on replay (re-resolved against the replayer's "now"), absolute + // exptimes stay absolute. `expire_ts` above is always the absolutised form. + uint32_t raw_expire_ts = 0; + // flags for STORE commands uint32_t flags = 0; @@ -99,7 +105,7 @@ class MemcacheParser { cmn::BackedArguments* backed_args = nullptr; }; - static_assert(sizeof(Command) == 40); + static_assert(sizeof(Command) == 48); enum Result : uint8_t { OK, @@ -114,6 +120,11 @@ class MemcacheParser { return type >= SET && type <= CAS; } + // Returns the wire-protocol token for `type` (e.g. "set", "mg"), or an empty + // string_view for INVALID / unrecognized values. Used by the traffic logger + // so that the memcache command name does not need to be duplicated in callers. 
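+  // For example, CmdName(CmdType::GAT) returns "gat" and CmdName(CmdType::META_GET)
+  // returns "mg"; callers are expected to treat the empty result for INVALID as
+  // "nothing to log".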
+ static std::string_view CmdName(CmdType type); + size_t UsedMemory() const { return tmp_buf_.capacity(); } diff --git a/src/facade/op_status.cc b/src/facade/op_status.cc index e429d8b42ce4..049796776137 100644 --- a/src/facade/op_status.cc +++ b/src/facade/op_status.cc @@ -1,9 +1,19 @@ #include "facade/op_status.h" +#include + #include "base/logging.h" #include "facade/error.h" #include "facade/resp_expr.h" +namespace std { + +std::ostream& operator<<(std::ostream& os, facade::OpStatus op) { + return os << static_cast(op); +} + +} // namespace std + namespace facade { std::string_view StatusToMsg(OpStatus status) { @@ -44,6 +54,8 @@ std::string_view StatusToMsg(OpStatus status) { return kNanOrInfDuringIncr; case OpStatus::IO_ERROR: return kTieredIoError; + case OpStatus::BLOOM_FILTER_LOAD_IN_PROGRESS: + return kBloomFilterLoadInProgress; default: LOG(ERROR) << "Unsupported status " << status; return "Internal error"; diff --git a/src/facade/op_status.h b/src/facade/op_status.h index 1fedca1cb4df..939f06a187ea 100644 --- a/src/facade/op_status.h +++ b/src/facade/op_status.h @@ -5,7 +5,9 @@ #pragma once #include -#include +#include +#include +#include namespace facade { @@ -35,6 +37,7 @@ enum class OpStatus : uint16_t { INVALID_JSON, IO_ERROR, NAN_OR_INF_DURING_INCR, + BLOOM_FILTER_LOAD_IN_PROGRESS, }; class OpResultBase { @@ -127,14 +130,10 @@ std::string_view StatusToMsg(OpStatus status); namespace std { -template std::ostream& operator<<(std::ostream& os, const facade::OpResult& res) { - os << res.status(); - return os; -} +std::ostream& operator<<(std::ostream& os, facade::OpStatus op); -inline std::ostream& operator<<(std::ostream& os, const facade::OpStatus op) { - os << int(op); - return os; +template std::ostream& operator<<(std::ostream& os, const facade::OpResult& res) { + return os << res.status(); } } // namespace std diff --git a/src/facade/reply_builder.cc b/src/facade/reply_builder.cc index 90b04d900757..8e217b1d6441 100644 --- a/src/facade/reply_builder.cc +++ b/src/facade/reply_builder.cc @@ -134,6 +134,10 @@ void SinkReplyBuilder::WriteRef(std::string_view str) { } void SinkReplyBuilder::Flush(size_t expected_buffer_cap) { + // Fast path: nothing buffered and no buffer resize requested. 
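+  // With IoLoopV2 flushing eagerly (before awaits, before migration), many calls
+  // arrive here with nothing to send; returning early keeps those calls O(1).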
+ if (vecs_.empty() && (expected_buffer_cap == 0)) + return; + if (!vecs_.empty()) Send(); diff --git a/src/facade/reply_payload.h b/src/facade/reply_payload.h index b40741eaa8ed..a122178bdf49 100644 --- a/src/facade/reply_payload.h +++ b/src/facade/reply_payload.h @@ -27,7 +27,7 @@ struct BulkString : public std::string {}; // SendBulkString using Payload = std::variant>; -#ifdef __linux__ +#if defined(__linux__) && !defined(_LIBCPP_VERSION) static_assert(sizeof(Payload) == 40); #endif diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index 8e3b3cc5fd92..f7bbd89d80b5 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -165,6 +165,7 @@ helio_cxx_test(cluster/cluster_config_test dfly_test_lib LABELS DFLY) helio_cxx_test(cluster/cluster_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(acl/acl_family_test dfly_test_lib LABELS DFLY) helio_cxx_test(engine_shard_set_test dfly_test_lib LABELS DFLY) +helio_cxx_test(serializer_base_test dfly_test_lib LABELS DFLY) add_dependencies(check_dfly dragonfly_test json_family_test list_family_test generic_family_test memcache_parser_test rdb_test journal_test diff --git a/src/server/acl/acl_family.cc b/src/server/acl/acl_family.cc index 9d1a72d08d47..98f18320bf9f 100644 --- a/src/server/acl/acl_family.cc +++ b/src/server/acl/acl_family.cc @@ -147,6 +147,7 @@ void AclFamily::SetUser(CmdArgList args, CommandContext* cmd_cntx) { auto update_case = [username, ®, cmd_cntx, this, exists](User::UpdateRequest&& req) { auto& user = reg.registry[username]; + const User::MemoryUsage before = exists ? user.GetMemoryUsage() : User::MemoryUsage{}; if (!exists) { User::UpdateRequest default_req; default_req.updates = {User::UpdateRequest::CategoryValueType{User::Sign::MINUS, acl::ALL}}; @@ -155,6 +156,7 @@ void AclFamily::SetUser(CmdArgList args, CommandContext* cmd_cntx) { } const bool reset_channels = req.reset_channels; user.Update(std::move(req), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); + registry_->TrackUser(before, user.GetMemoryUsage(), !exists); // Send ok first because the connection might get evicted cmd_cntx->SendOk(); if (exists) { @@ -333,6 +335,7 @@ GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path, // Evict open connections for old users EvictOpenConnectionsOnAllProactorsWithRegistry(registry); registry.clear(); + registry_->ResetStats(); } for (size_t i = 0; i < usernames.size(); ++i) { @@ -343,12 +346,14 @@ GenericError AclFamily::LoadToRegistryFromFile(std::string_view full_path, CategoryToCommandsIndex()); user.Update(std::move(requests[i]), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); + registry_->TrackUser({}, user.GetMemoryUsage(), /*is_new=*/true); } if (!registry.contains("default")) { auto& user = registry["default"]; user.Update(registry_->DefaultUserUpdateRequest(), CategoryToIdx(), reverse_cat_table_, CategoryToCommandsIndex()); + registry_->TrackUser({}, user.GetMemoryUsage(), /*is_new=*/true); } return {}; diff --git a/src/server/acl/acl_family_test.cc b/src/server/acl/acl_family_test.cc index 870b1d750408..e58b3b3dc901 100644 --- a/src/server/acl/acl_family_test.cc +++ b/src/server/acl/acl_family_test.cc @@ -575,4 +575,49 @@ TEST_F(AclFamilyTest, TestAclLogUB) { EXPECT_THAT(resp, ErrArg("ERR index out of range")); } +TEST_F(AclFamilyTest, AclInfoMetrics) { + TestInitAclFam(); + + // After init only the default user exists. 
+  auto stats = GetMetrics().acl_stats;
+  EXPECT_EQ(stats.num_users, 1);
+  EXPECT_EQ(stats.num_passwords, 0);  // default user uses nopass
+  EXPECT_EQ(stats.num_key_globs, 0);
+  EXPECT_EQ(stats.num_pubsub_globs, 0);
+
+  // Add a new user with two passwords.
+  Run("ACL SETUSER alice >pass1 >pass2");
+  stats = GetMetrics().acl_stats;
+  EXPECT_EQ(stats.num_users, 2);
+  EXPECT_EQ(stats.num_passwords, 2);
+
+  // Adding a key glob is tracked in num_key_globs and key_globs_bytes.
+  Run("ACL SETUSER alice ~mykey*");
+  stats = GetMetrics().acl_stats;
+  EXPECT_EQ(stats.num_key_globs, 1);
+  EXPECT_EQ(stats.key_globs_bytes, std::string("mykey*").size());
+
+  // Adding a pubsub glob is tracked in num_pubsub_globs and pubsub_globs_bytes.
+  Run("ACL SETUSER alice &news.*");
+  stats = GetMetrics().acl_stats;
+  EXPECT_EQ(stats.num_pubsub_globs, 1);
+  EXPECT_EQ(stats.pubsub_globs_bytes, std::string("news.*").size());
+
+  // Removing a password is reflected immediately.
+  Run("ACL SETUSER alice <pass1");
+  stats = GetMetrics().acl_stats;
+  EXPECT_EQ(stats.num_passwords, 1);
+}
+
diff --git a/src/server/acl/user.h b/src/server/acl/user.h
index fb8daa418fc5..c107bf5f0717 100644
--- a/src/server/acl/user.h
+++ b/src/server/acl/user.h
@@ -134,6 +134,22 @@ class User final {
 
   const CommandChanges& CmdChanges() const;
 
+  // Per-user heap-allocated collection sizes, used by UserRegistry for aggregate stats.
+  struct MemoryUsage {
+    size_t num_passwords = 0;
+    size_t num_cat_changes = 0;
+    size_t num_cmd_changes = 0;
+    size_t num_key_globs = 0;
+    size_t key_globs_bytes = 0;  // total byte length of key glob strings
+    size_t num_pubsub_globs = 0;
+    size_t pubsub_globs_bytes = 0;  // total byte length of pubsub glob strings
+
+    MemoryUsage& operator+=(const MemoryUsage& u);
+    MemoryUsage& operator-=(const MemoryUsage& u);
+  };
+
+  MemoryUsage GetMemoryUsage() const;
+
  private:
   void SetAclCategoriesAndIncrSeq(uint32_t cat, const CategoryToIdxStore& cat_to_id,
                                   const ReverseCategoryIndexTable& reverse_cat,
 
diff --git a/src/server/acl/user_registry.cc b/src/server/acl/user_registry.cc
index efdc876ce1e8..915d1e533803 100644
--- a/src/server/acl/user_registry.cc
+++ b/src/server/acl/user_registry.cc
@@ -17,15 +17,75 @@ using namespace util;
 
 namespace dfly::acl {
 
+// SHA256 produces 32-byte binary hashes. Each is stored as a std::string in the
+// flat_hash_set, which exceeds SSO capacity and thus heap-allocates its content.
+static constexpr size_t kSHA256Bytes = 32;
+
+size_t UserRegistry::AclStats::TotalBytes() const {
+  // Fixed per-user cost: the User object itself plus the always-allocated commands_ vector.
+  const size_t per_user_base = sizeof(User) + NumberOfFamilies() * sizeof(uint64_t);
+
+  // Each password hash is a 32-byte binary string stored in an absl flat_hash_set.
+  // The std::string object lives inline in the set slot; the content (>SSO) is heap-allocated.
+  const size_t per_password = sizeof(std::string) + kSHA256Bytes;
+
+  // Category-change map entry: uint32_t key + ChangeMetadata value.
+  const size_t per_cat_change = sizeof(User::CategoryChange) + sizeof(User::ChangeMetadata);
+
+  // Command-change map entry: pair key + ChangeMetadata value.
+  const size_t per_cmd_change = sizeof(User::CommandChange) + sizeof(User::ChangeMetadata);
+
+  // Key-glob vector entry: pair object + any string content exceeding SSO.
+  const size_t per_key_glob = sizeof(std::pair<std::string, KeyOp>);
+
+  // PubSub-glob vector entry: pair object + any string content exceeding SSO.
+ const size_t per_pubsub_glob = sizeof(std::pair); + + return num_users * per_user_base + // + num_passwords * per_password + // + num_cat_changes * per_cat_change + // + num_cmd_changes * per_cmd_change + // + num_key_globs * per_key_glob + key_globs_bytes + // + num_pubsub_globs * per_pubsub_glob + pubsub_globs_bytes; // +} + +UserRegistry::AclStats UserRegistry::GetAclStats() const { + std::shared_lock lock(mu_); + return stats_; +} + +void UserRegistry::TrackUser(const User::MemoryUsage& before, const User::MemoryUsage& after, + bool is_new) { + if (is_new) + ++stats_.num_users; + stats_ -= before; + stats_ += after; +} + +void UserRegistry::ResetStats() { + stats_ = AclStats{}; +} + void UserRegistry::MaybeAddAndUpdate(std::string_view username, User::UpdateRequest req) { std::unique_lock lock(mu_); + const bool is_new = !registry_.contains(username); auto& user = registry_[username]; + + const User::MemoryUsage before = is_new ? User::MemoryUsage{} : user.GetMemoryUsage(); user.Update(std::move(req), *cat_to_id_table_, *reverse_cat_table_, *cat_to_commands_table_); + TrackUser(before, user.GetMemoryUsage(), is_new); } bool UserRegistry::RemoveUser(std::string_view username) { std::unique_lock lock(mu_); - return registry_.erase(username); + auto it = registry_.find(username); + if (it == registry_.end()) { + return false; + } + TrackUser(it->second.GetMemoryUsage(), {}, false); + --stats_.num_users; + registry_.erase(it); + return true; } UserCredentials UserRegistry::GetCredentials(std::string_view username) const { diff --git a/src/server/acl/user_registry.h b/src/server/acl/user_registry.h index 529ef8654ef8..6e38b18410a4 100644 --- a/src/server/acl/user_registry.h +++ b/src/server/acl/user_registry.h @@ -75,9 +75,30 @@ class UserRegistry { User::UpdateRequest DefaultUserUpdateRequest() const; + // Aggregate memory-usage stats across all users. Updated incrementally on every + // user creation, deletion, or mutation so that INFO ACL can read them lock-free. + struct AclStats : User::MemoryUsage { + size_t num_users = 0; + + // Estimated total bytes consumed by all ACL users (base structs + heap collections). + size_t TotalBytes() const; + }; + + // Returns a snapshot of the aggregate stats under the registry read lock. + AclStats GetAclStats() const; + + // Updates aggregate stats after a user mutation performed via GetRegistryWithWriteLock. + // Must be called while the write lock is held. `before` is zeroed for new users. + void TrackUser(const User::MemoryUsage& before, const User::MemoryUsage& after, bool is_new); + + // Resets aggregate stats to zero. Must be called while the write lock is held, + // immediately after registry.clear(). + void ResetStats(); + private: RegistryType registry_; mutable util::fb2::SharedMutex mu_; + AclStats stats_; // maintained under mu_ // Helper class for accessing the registry with a ReadLock outside the scope of UserRegistry template