Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions src/server/db_slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,18 @@ inline bool MayDeleteAsynchronously(const PrimeValue& pv) {
return (obj_type == OBJ_SET || obj_type == OBJ_HASH) && pv.Encoding() == kEncodingStrMap2;
}

// Implements ChangeConsumerInterface with a single callback for one-shot functions.
template <typename F> struct CallbackConsumer : public DbSlice::ChangeConsumerInterface {
  explicit CallbackConsumer(F f) : f_{std::move(f)} {
  }

  // Forwards the bucket-change notification to the wrapped callable.
  // `override` guards against silent signature drift from the interface.
  void OnChange(DbIndex db_index, const ChangeReq& req) override {
    f_(db_index, req);
  }

  F f_;
};

} // namespace

#define ADD(x) (x) += o.x
Expand Down Expand Up @@ -826,7 +838,7 @@ void DbSlice::DelMutable(Context cntx, ItAndUpdater it_updater) {
}

void DbSlice::FlushSlotsFb(const cluster::SlotSet& slot_ids, uint64_t next_version,
uint64_t cb_id) {
ChangeConsumerInterface* consumer) {
VLOG(1) << "Start FlushSlotsFb";
// Slot deletion can take time as it traverses all the database, hence it runs in fiber.
// We want to flush all the data of a slot that was added till the time the call to FlushSlots
Expand Down Expand Up @@ -870,7 +882,7 @@ void DbSlice::FlushSlotsFb(const cluster::SlotSet& slot_ids, uint64_t next_versi
} while (cursor && etl.gstate() != GlobalState::SHUTTING_DOWN);

VLOG(1) << "FlushSlotsFb del count is: " << del_count;
UnregisterOnChange(cb_id);
UnregisterOnChange(consumer);

if (absl::GetFlag(FLAGS_cluster_flush_decommit_memory)) {
int64_t start = absl::GetCurrentTimeNanos();
Expand Down Expand Up @@ -929,10 +941,11 @@ void DbSlice::FlushSlots(const cluster::SlotRanges& slot_ranges) {
}
};

uint64_t cb_id = RegisterOnChange(std::move(on_change));
auto consumer = std::make_unique<CallbackConsumer<decltype(on_change)>>(std::move(on_change));
RegisterOnChange(consumer.get());

fb2::Fiber("flush_slots", [this, shared_slots, next_version, cb_id]() {
FlushSlotsFb(*shared_slots, next_version, cb_id);
fb2::Fiber("flush_slots", [this, shared_slots, next_version, consumer = std::move(consumer)]() {
FlushSlotsFb(*shared_slots, next_version, consumer.get());
}).Detach();
}

Expand Down Expand Up @@ -1318,7 +1331,8 @@ PrimeIterator DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterator it) con
void DbSlice::ExpireAllIfNeeded() {
// We hold no locks to any of the keys so we should Wait() here such that
// we don't preempt in ExpireIfNeeded
serialization_latch_.Wait();
WaitForUnblockedJournalWrites();

// Disable flush journal changes to prevent preemtion in traverse.
journal::DisableFlushGuard journal_flush_guard(owner_->journal());

Expand All @@ -1340,9 +1354,27 @@ void DbSlice::ExpireAllIfNeeded() {
}
}

uint64_t DbSlice::RegisterOnChange(ChangeCallback cb) {
void DbSlice::RegisterOnChange(ChangeConsumerInterface* consumer) {
DCHECK(!owner_->shard_lock()->IsFree());
return change_cb_.emplace_back(NextVersion(), std::move(cb)).first;

consumer->snapshot_version_ = NextVersion();
change_cb_.emplace_back(consumer);
}

// Unregisters a previously registered consumer.
// Waits on the latch first so change_cb_ is never mutated while callbacks
// are being iterated elsewhere. CHECK-fails if the consumer was never
// registered (or was already removed).
void DbSlice::UnregisterOnChange(ChangeConsumerInterface* consumer) {
  change_cb_latch_.Wait();
  auto it = std::find(change_cb_.begin(), change_cb_.end(), consumer);
  CHECK(it != change_cb_.end());
  change_cb_.erase(it);
}

bool DbSlice::WillBlockOnJournalWrite() const {
return ranges::any_of(change_cb_, &ChangeConsumerInterface::IsAnyBucketBlocked);
}

// Blocks the caller until no registered consumer is mid-bucket-serialization,
// i.e. until journal writes will not be blocked. Loops because a consumer may
// start serializing another bucket after being unblocked.
void DbSlice::WaitForUnblockedJournalWrites() const {
  while (WillBlockOnJournalWrite())
    ranges::for_each(change_cb_, &ChangeConsumerInterface::UnblockAllBuckets);
}

// Ordering invariant (PIT mode):
Expand All @@ -1353,7 +1385,7 @@ uint64_t DbSlice::RegisterOnChange(ChangeCallback cb) {
// snapshot could miss the bucket entirely — its traversal already passed it, and the version
// stamp from the current snapshot would cause the earlier snapshot's OnChangeBlocking to skip it.
void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_t upper_bound) {
unique_lock<LocalLatch> lk(serialization_latch_);
unique_lock<LocalLatch> lk(change_cb_latch_);

uint64_t bucket_version = it.GetVersion();
// change_cb_ is ordered by version.
Expand All @@ -1363,7 +1395,7 @@ void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_
const size_t limit = change_cb_.size();
auto ccb = change_cb_.begin();
for (size_t i = 0; i < limit; ++i) {
uint64_t cb_version = ccb->first;
uint64_t cb_version = (*ccb)->snapshot_version_;
DCHECK_LE(cb_version, upper_bound);
if (cb_version == upper_bound) {
return;
Expand All @@ -1378,21 +1410,12 @@ void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_
// does not change during the serialization, therefore we allow at most one serializer
// reading the bucket at the same time.
if (bucket_version <= cb_version) {
ccb->second(db_ind, ChangeReq{it.GetInnerIt()});
(*ccb)->OnChange(db_ind, ChangeReq{it.GetInnerIt()});
}
++ccb;
}
}

//! Unregisters the callback.
void DbSlice::UnregisterOnChange(uint64_t id) {
serialization_latch_.Wait();
auto it = find_if(change_cb_.begin(), change_cb_.end(),
[id](const auto& cb) { return cb.first == id; });
CHECK(it != change_cb_.end());
change_cb_.erase(it);
}

auto DbSlice::DeleteExpiredStep(const Context& cntx, unsigned count) -> DeleteExpiredStats {
auto& db = *db_arr_[cntx.db_index];
DeleteExpiredStats result;
Expand Down Expand Up @@ -1898,7 +1921,7 @@ void DbSlice::OnCbFinishBlocking() {
}

// We must not change the bucket's internal order during serialization
serialization_latch_.Wait();
WaitForUnblockedJournalWrites();
PrimeBumpPolicy policy;
auto bump_it = db.prime.BumpUp(it, policy);
if (bump_it != it) { // the item was bumped
Expand All @@ -1916,13 +1939,12 @@ void DbSlice::CallChangeCallbacks(DbIndex id, const ChangeReq& cr) const {
return;

// does not preempt, just increments the counter.
unique_lock<LocalLatch> lk(serialization_latch_);
unique_lock<LocalLatch> lk(change_cb_latch_);

const size_t limit = change_cb_.size();
auto ccb = change_cb_.begin();
for (size_t i = 0; i < limit; ++i) {
CHECK(ccb->second);
ccb->second(id, cr);
(*ccb)->OnChange(id, cr);
++ccb;
}
}
Expand Down
46 changes: 29 additions & 17 deletions src/server/db_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,24 @@ class DbSlice {
void operator=(const DbSlice&) = delete;

public:
// Consumer of bucket change events than can be registered inside the slice.
// It also includes additional methods for interfacing with snapshots and migrations.
struct ChangeConsumerInterface {
Copy link
Copy Markdown

@augmentcode augmentcode Bot May 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

src/server/db_slice.h:89: ChangeConsumerInterface is used as a polymorphic base, but it doesn’t declare a virtual destructor, which can cause UB if a consumer is ever deleted through a ChangeConsumerInterface* (easy to do with interfaces). Consider adding a virtual default destructor to make ownership patterns safer.

Severity: low

Fix This in Augment

🤖 Was this useful? React with 👍 or 👎, or 🚀 if it prevented an incident/outage.

// Called before a specific bucket (or set of buckets) will be mutated
virtual void OnChange(DbIndex, const ChangeReq&) = 0;

// Should return true if any bucket is mid-serialization
virtual bool IsAnyBucketBlocked() const {
return false;
}

// Should wait for IsAnyBucketBlocked to return false
virtual void UnblockAllBuckets() const {
}

uint64_t snapshot_version_ = 0;
};

// Auto-laundering iterator wrapper. Laundering means re-finding keys if they moved between
// buckets.
template <typename T> class IteratorT {
Expand Down Expand Up @@ -393,7 +411,8 @@ class DbSlice {
//! Registers the callback to be called for each change.
//! Returns the registration id which is also the unique version of the dbslice
//! at a time of the call.
uint64_t RegisterOnChange(ChangeCallback cb);
void RegisterOnChange(ChangeConsumerInterface* consumer);
void UnregisterOnChange(ChangeConsumerInterface* consumer);

bool HasRegisteredCallbacks() const {
return !change_cb_.empty();
Expand All @@ -402,9 +421,6 @@ class DbSlice {
// Call registered callbacks with version less than upper_bound.
void FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_t upper_bound);

//! Unregisters the callback.
void UnregisterOnChange(uint64_t id);

struct DeleteExpiredStats {
uint32_t deleted = 0; // number of deleted items due to expiry.
uint32_t deleted_bytes = 0; // total bytes of deleted items.
Expand Down Expand Up @@ -483,13 +499,12 @@ class DbSlice {
// if it's not empty and not EX.
void SetNotifyKeyspaceEvents(std::string_view notify_keyspace_events);

bool WillBlockOnJournalWrite() const {
return serialization_latch_.IsBlocked();
}
// Returns true if any registered snapshot is blocked on bucket serialization (big value, delayed)
// and thus might reject the journal change
bool WillBlockOnJournalWrite() const;

LocalLatch* GetLatch() {
return &serialization_latch_;
}
// Block and wait for WillBlockOnJournalWrite to become false
void WaitForUnblockedJournalWrites() const;

void StartSampleTopK(DbIndex db_ind, uint32_t min_freq);

Expand Down Expand Up @@ -521,7 +536,8 @@ class DbSlice {
PrimeValue obj, uint64_t expire_at_ms,
bool force_update);

void FlushSlotsFb(const cluster::SlotSet& slot_ids, uint64_t next_version, uint64_t cb_id);
void FlushSlotsFb(const cluster::SlotSet& slot_ids, uint64_t next_version,
ChangeConsumerInterface* consumer);
util::fb2::Fiber FlushDbIndexes(const std::vector<DbIndex>& indexes);

// Invalidate all watched keys in database. Used on FLUSH.
Expand Down Expand Up @@ -565,11 +581,6 @@ class DbSlice {

void CallChangeCallbacks(DbIndex id, const ChangeReq& cr) const;

// We need this because registered callbacks might yield and when they do so we want
// to avoid Heartbeat or Flushing the db.
// This latch protects us against this case.
mutable LocalLatch serialization_latch_;

ShardId shard_id_;
uint8_t cache_mode_ : 1;

Expand Down Expand Up @@ -608,7 +619,8 @@ class DbSlice {
mutable absl::flat_hash_set<uint64_t, FpHasher> uniq_fps_;

// ordered from the smallest to largest version.
std::list<std::pair<uint64_t, ChangeCallback>> change_cb_;
std::list<ChangeConsumerInterface*> change_cb_;
mutable LocalLatch change_cb_latch_;

// Used in temporary computations in Find item and CbFinish
// This set is used to hold fingerprints of key accessed during the run of
Expand Down
20 changes: 16 additions & 4 deletions src/server/dragonfly_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -703,15 +703,27 @@ TEST_F(DflyEngineTest, Bug468) {
ASSERT_FALSE(IsLocked(0, "foo"));
}

// Test helper: counts how many times the slice invokes OnChange.
struct CountingConsumer : public DbSlice::ChangeConsumerInterface {
  explicit CountingConsumer(unsigned* cb_hits) : cb_hits_(cb_hits) {
  }

  // `override` guards against silent signature drift from the interface.
  void OnChange(DbIndex db_index, const ChangeReq&) override {
    (*cb_hits_)++;
  }

  unsigned* cb_hits_;
};

TEST_F(DflyEngineTest, Bug496) {
shard_set->RunBlockingInParallel([](EngineShard* shard) {
auto& db = namespaces->GetDefaultNamespace().GetDbSlice(shard->shard_id());

int cb_hits = 0;
unsigned cb_hits = 0;
CountingConsumer consumer{&cb_hits};

// RegisterOnChange requires the shard lock to be held (see #7153).
shard->shard_lock()->Acquire(IntentLock::EXCLUSIVE);
uint32_t cb_id =
db.RegisterOnChange([&cb_hits](DbIndex, const DbSlice::ChangeReq&) { cb_hits++; });
db.RegisterOnChange(&consumer);
shard->shard_lock()->Release(IntentLock::EXCLUSIVE);

{
Expand All @@ -732,7 +744,7 @@ TEST_F(DflyEngineTest, Bug496) {
EXPECT_EQ(cb_hits, 3);
}

db.UnregisterOnChange(cb_id);
db.UnregisterOnChange(&consumer);
});
}

Expand Down
2 changes: 1 addition & 1 deletion src/server/generic_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ void OpScan(const OpArgs& op_args, const ScanOpts& scan_opts, uint64_t* cursor,
// ScanCb can preempt due to journaling expired entries and we need to make sure that
// we enter the callback in a timing when journaling will not cause preemption. Otherwise,
// the bucket might change as we Traverse and yield.
db_slice.GetLatch()->Wait();
db_slice.WaitForUnblockedJournalWrites();

// Disable flush journal changes to prevent preemtion in traverse.
journal::DisableFlushGuard journal_flush_guard(op_args.shard->journal());
Expand Down
2 changes: 1 addition & 1 deletion src/server/journal/streamer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ bool RestoreStreamer::Cancel() {
snapshot_version_ = 0; // to prevent double cancel in another fiber
cntx_->Cancel();
if (sver != 0) {
db_slice_->UnregisterOnChange(sver);
db_slice_->UnregisterOnChange(this);
}
bool res = JournalStreamer::Cancel();
LOG_IF(WARNING, res != (sver != 0)) << "Journal and DBSlice unregister state mismatch in "
Expand Down
29 changes: 20 additions & 9 deletions src/server/serializer_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "server/tiered_storage.h"
#include "util/fibers/fibers.h"
#include "util/fibers/stacktrace.h"
#include "util/fibers/synchronization.h"

namespace dfly {

Expand All @@ -36,6 +37,9 @@ void BucketDependencies::Decrement(BucketIdentity bucket) {
it->second->unlock();
if (!it->second->IsBlocked())
deps_.erase(it);

if (deps_.empty())
empty_q_.notify_all();
}

void BucketDependencies::Wait(BucketIdentity bucket) const {
Expand All @@ -47,6 +51,11 @@ void BucketDependencies::Wait(BucketIdentity bucket) const {
counter->Wait();
}

void BucketDependencies::WaitEmpty() const {
util::fb2::NoOpLock lock;
empty_q_.wait(lock, [&] { return deps_.empty(); });
}

// Debug helper: reports whether the given bucket currently has
// serialization dependencies registered.
bool BucketDependencies::DEBUG_IsBusy(BucketIdentity bucket) const {
  return deps_.find(bucket) != deps_.end();
}
Expand Down Expand Up @@ -131,15 +140,12 @@ SerializerBase::~SerializerBase() {
// emitting large values.
// Subscribes this serializer to bucket change events. Copies the database
// table pointers first so serialization keeps working even if the slice is
// flushed afterwards; the order of the two statements matters.
void SerializerBase::RegisterChangeListener() {
  db_array_ = db_slice_->databases();  // copy pointers to survive flush
  db_slice_->RegisterOnChange(this);
}

// Unsubscribes from bucket change events. Idempotent: snapshot_version_ is
// reset to 0 on the first call so a repeated call (e.g. a cancel path that
// runs after an explicit unregister) does not hit the CHECK inside
// DbSlice::UnregisterOnChange for a consumer that is no longer registered.
void SerializerBase::UnregisterChangeListener() {
  if (std::exchange(snapshot_version_, 0) > 0)
    db_slice_->UnregisterOnChange(this);
}

bool SerializerBase::ProcessBucket(DbIndex db_index, PrimeTable::bucket_iterator it,
Expand All @@ -165,12 +171,9 @@ bool SerializerBase::ProcessBucket(DbIndex db_index, PrimeTable::bucket_iterator
// acquire serialization latch.
// We must make sure that earlier snapshots serialized this bucket before we update its
// version below.
std::optional<std::lock_guard<LocalLatch>> db_guard;
if (!on_update) {
if (!on_update)
db_slice_->FlushChangeToEarlierCallbacks(db_index, DbSlice::Iterator::FromPrime(it),
snapshot_version_);
db_guard.emplace(*db_slice_->GetLatch());
}

// The block above with updating earlier callbacks is not exlusive - check version again
if (it.GetVersion() >= snapshot_version_)
Expand All @@ -192,6 +195,14 @@ bool SerializerBase::ProcessBucket(DbIndex db_index, PrimeTable::bucket_iterator
return true;
}

// ChangeConsumerInterface override: blocks until the dependency map tracked
// by BucketDependencies is empty, i.e. no bucket is mid-serialization.
void SerializerBase::UnblockAllBuckets() const {
  BucketDependencies::WaitEmpty();
}

// ChangeConsumerInterface override: unpacks the ChangeReq via std::visit and
// forwards the contained alternative (presumably a bucket iterator — see
// OnChangeBlocking's overloads) to the blocking serialization path.
void SerializerBase::OnChange(DbIndex db_index, const ChangeReq& req) {
  std::visit([&](auto it) { OnChangeBlocking(db_index, it); }, req);
}

void SerializerBase::OnChangeBlocking(DbIndex db_index, PrimeTable::bucket_iterator it) {
std::string_view active_name = util::fb2::detail::FiberActive()->name();
if (!absl::StartsWith(active_name, "shard_queue") && //
Expand Down
Loading
Loading