diff --git a/src/core/compact_object.cc b/src/core/compact_object.cc index 7030b39e87b9..55f0c2fd7924 100644 --- a/src/core/compact_object.cc +++ b/src/core/compact_object.cc @@ -30,6 +30,7 @@ extern "C" { #include "core/cms.h" #include "core/detail/bitpacking.h" #include "core/huff_coder.h" +#include "core/oah_set.h" #include "core/page_usage/page_usage_stats.h" #include "core/qlist.h" #include "core/sorted_map.h" @@ -63,10 +64,12 @@ size_t UpdateSize(size_t size, int64_t update) { inline void FreeObjSet(unsigned encoding, void* ptr, MemoryResource* mr) { switch (encoding) { - case kEncodingStrMap2: { - CompactObj::DeleteMR(ptr); + case kEncodingStrMap2: + VisitSet(ptr, [](auto* ss) { + using T = std::remove_pointer_t; + CompactObj::DeleteMR(ss); + }); break; - } case kEncodingIntSet: zfree((void*)ptr); @@ -87,10 +90,10 @@ void FreeList(unsigned encoding, void* ptr, MemoryResource* mr) { size_t MallocUsedSet(unsigned encoding, void* ptr) { switch (encoding) { - case kEncodingStrMap2: { - StringSet* ss = (StringSet*)ptr; - return ss->ObjMallocUsed() + ss->SetMallocUsed() + zmalloc_usable_size(ptr); - } + case kEncodingStrMap2: + return VisitSet(ptr, [ptr](auto* ss) { + return ss->ObjMallocUsed() + ss->SetMallocUsed() + zmalloc_usable_size(ptr); + }); case kEncodingIntSet: return intsetBlobLen((intset*)ptr); } @@ -277,12 +280,10 @@ pair DefragSortedMap(detail::SortedMap* sm, PageUsage* page_usage) return {sm, reallocated}; } -pair DefragStrSet(StringSet* ss, PageUsage* page_usage) { +template pair DefragDenseSet(Set* ss, PageUsage* page_usage) { bool realloced = false; - for (auto it = ss->begin(); it != ss->end(); ++it) realloced |= it.ReallocIfNeeded(page_usage); - return {ss, realloced}; } @@ -313,9 +314,8 @@ pair DefragSet(unsigned encoding, void* ptr, PageUsage* page_usage) return DefragIntSet((intset*)ptr, page_usage); } - case kEncodingStrMap2: { - return DefragStrSet((StringSet*)ptr, page_usage); - } + case kEncodingStrMap2: + return VisitSet(ptr, [page_usage](auto* ss) { return DefragDenseSet(ss, page_usage); }); default: ABSL_UNREACHABLE(); @@ -460,10 +460,8 @@ size_t RobjWrapper::Size() const { intset* is = (intset*)inner_obj_; return intsetLen(is); } - case kEncodingStrMap2: { - StringSet* ss = (StringSet*)inner_obj_; - return ss->UpperBoundSize(); - } + case kEncodingStrMap2: + return VisitSet(inner_obj_, [](auto* ss) { return ss->UpperBoundSize(); }); default: LOG(FATAL) << "Unexpected encoding " << encoding_; }; diff --git a/src/core/oah_set.h b/src/core/oah_set.h index 2746449662fc..47d8109030e9 100644 --- a/src/core/oah_set.h +++ b/src/core/oah_set.h @@ -12,6 +12,7 @@ #include #include "core/detail/stateless_allocator.h" +#include "core/string_set.h" #include "oah_entry.h" namespace dfly { @@ -107,11 +108,19 @@ class OAHSet { // Open Addressing Hash Set void SetEntryIt() { if (!owner_) return; + // time_now_ == 0 disables expiry (callers set it to 0 around serialization). + const uint32_t now = owner_->time_now_; for (auto num_entries = owner_->entries_.size(); bucket_ < num_entries; ++bucket_) { auto& bucket = owner_->entries_[bucket_]; for (uint32_t bucket_size = bucket.ElementsNum(); pos_ < bucket_size; ++pos_) { - if (bucket[pos_]) - return; + auto& entry = bucket[pos_]; + if (!entry) + continue; + if (now != 0 && entry.HasExpiry() && entry.GetExpiry() <= now) { + entry.ExpireIfNeeded(now, &owner_->size_, &owner_->obj_alloc_used_); + continue; + } + return; } pos_ = 0; } @@ -468,11 +477,11 @@ class OAHSet { // Open Addressing Hash Set return time_now_; } - size_t ObjAllocUsed() const { + size_t ObjMallocUsed() const { return obj_alloc_used_; } - size_t SetAllocUsed() const { + size_t SetMallocUsed() const { return entries_.capacity() * sizeof(OAHEntry) + ptr_vectors_alloc_used_; } @@ -762,4 +771,28 @@ class OAHSet { // Open Addressing Hash Set Buckets entries_; }; +// Snapshot of --use_oah_set captured once at startup. +inline bool g_use_oah_set = false; + +// Dispatches a generic lambda over the runtime-selected dense-set type backing +// kEncodingStrMap2 SETs. Both StringSet and OAHSet expose the same surface +// (set_time, Empty, BucketCount, Reserve, ObjMallocUsed, ...) so the lambda +// can be written once and visit either concrete type. +template auto VisitSet(void* ptr, Fn&& fn) { + return g_use_oah_set ? fn(static_cast(ptr)) : fn(static_cast(ptr)); +} + +// Extracts the current member as a string_view from either a StringSet or an +// OAHSet iterator. Free functions so generic code (e.g. inside VisitSet +// lambdas) can write `Key(it)` without a member-method asymmetry between the +// two iterator types. +inline std::string_view Key(StringSet::iterator it) { + sds s = *it; + return {s, sdslen(s)}; +} + +inline std::string_view Key(OAHSet::iterator it) { + return it->Key(); +} + } // namespace dfly diff --git a/src/core/oah_set_test.cc b/src/core/oah_set_test.cc index 091a963f442c..96b2a6c49dfa 100644 --- a/src/core/oah_set_test.cc +++ b/src/core/oah_set_test.cc @@ -614,7 +614,7 @@ TEST_F(OAHSetTest, ReallocIfNeededForceReallocates) { for (size_t i = 0; i < 50; ++i) { EXPECT_TRUE(ss_->Add(absl::StrCat("key_", i, "_xxxxxxxx"), 100 + i)); } - size_t alloc_before = ss_->ObjAllocUsed(); + size_t alloc_before = ss_->ObjMallocUsed(); EXPECT_GT(alloc_before, 0u); PageUsage page_usage{CollectPageStats::NO, 0.9}; @@ -633,8 +633,8 @@ TEST_F(OAHSetTest, ReallocIfNeededForceReallocates) { ASSERT_NE(it, ss_->end()); EXPECT_EQ(it.ExpiryTime(), 100u + i); } - // ObjAllocUsed remains roughly consistent (mimalloc usable size for same logical size). - EXPECT_GT(ss_->ObjAllocUsed(), 0u); + // ObjMallocUsed remains roughly consistent (mimalloc usable size for same logical size). + EXPECT_GT(ss_->ObjMallocUsed(), 0u); } TEST_F(OAHSetTest, ReallocIfNeededVectorEntry) { @@ -743,7 +743,7 @@ TEST_F(OAHSetTest, ClearStepIncremental) { } EXPECT_EQ(cursor, total); EXPECT_EQ(ss_->UpperBoundSize(), 0u); - EXPECT_EQ(ss_->ObjAllocUsed(), 0u); + EXPECT_EQ(ss_->ObjMallocUsed(), 0u); } TEST_F(OAHSetTest, ClearStepFullBucketCount) { @@ -753,7 +753,7 @@ TEST_F(OAHSetTest, ClearStepFullBucketCount) { uint32_t end = ss_->ClearStep(0, ss_->Capacity()); EXPECT_EQ(end, ss_->Capacity()); EXPECT_EQ(ss_->UpperBoundSize(), 0u); - EXPECT_EQ(ss_->ObjAllocUsed(), 0u); + EXPECT_EQ(ss_->ObjMallocUsed(), 0u); } TEST_F(OAHSetTest, GetRandomMemberEmpty) { @@ -810,7 +810,7 @@ TEST_F(OAHSetTest, ClearStepResetsExpirationUsed) { << "ExpirationUsed must be false after ClearStep fully empties the set"; } -TEST_F(OAHSetTest, ReallocIfNeededObjAllocUsedConsistent) { +TEST_F(OAHSetTest, ReallocIfNeededObjMallocUsedConsistent) { // Sanity: after force-realloc, obj_alloc_used_ remains the sum of all entries' // current AllocSize. Guards against signed-delta arithmetic going wrong on the counter. for (size_t i = 0; i < 100; ++i) @@ -824,20 +824,20 @@ TEST_F(OAHSetTest, ReallocIfNeededObjAllocUsedConsistent) { size_t expected = 0; for (auto it = ss_->begin(); it != ss_->end(); ++it) expected += (*it).AllocSize(); - EXPECT_EQ(ss_->ObjAllocUsed(), expected); + EXPECT_EQ(ss_->ObjMallocUsed(), expected); } -TEST_F(OAHSetTest, ClearResetsObjAllocUsed) { +TEST_F(OAHSetTest, ClearResetsObjMallocUsed) { for (size_t i = 0; i < 100; ++i) { ss_->Add(random_string(generator_, 10)); } - EXPECT_GT(ss_->ObjAllocUsed(), 0u); + EXPECT_GT(ss_->ObjMallocUsed(), 0u); EXPECT_GT(ss_->UpperBoundSize(), 0u); ss_->Clear(); - EXPECT_EQ(ss_->ObjAllocUsed(), 0u); + EXPECT_EQ(ss_->ObjMallocUsed(), 0u); EXPECT_EQ(ss_->UpperBoundSize(), 0u); } @@ -849,7 +849,7 @@ TEST_F(OAHSetTest, IterateEmpty) { } static size_t MemUsed(OAHSet& obj) { - return obj.ObjAllocUsed() + obj.SetAllocUsed(); + return obj.ObjMallocUsed() + obj.SetMallocUsed(); } void BM_Clone(benchmark::State& state) { diff --git a/src/server/collection_family_fallback.cc b/src/server/collection_family_fallback.cc index 8932a915690f..803f93c2bced 100644 --- a/src/server/collection_family_fallback.cc +++ b/src/server/collection_family_fallback.cc @@ -30,7 +30,7 @@ StringMap* HSetFamily::ConvertToStrMap(uint8_t* lp) { return nullptr; } -StringSet* SetFamily::ConvertToStrSet(const intset* is, size_t expected_len) { +void* SetFamily::ConvertToStrSet(const intset* is, size_t expected_len) { Fail(); return nullptr; } diff --git a/src/server/container_utils.cc b/src/server/container_utils.cc index 3bfc54b6695c..45bdcdb2ffa8 100644 --- a/src/server/container_utils.cc +++ b/src/server/container_utils.cc @@ -6,6 +6,7 @@ #include "base/flags.h" #include "base/logging.h" #include "core/detail/listpack_wrap.h" +#include "core/oah_set.h" #include "core/qlist.h" #include "core/sorted_map.h" #include "core/string_map.h" @@ -205,12 +206,15 @@ bool IterateSet(const PrimeValue& pv, const IterateFunc& func) { success = func(ContainerEntry{ival}); } } else { - for (sds ptr : *static_cast(pv.RObjPtr())) { - if (!func(ContainerEntry{ptr, sdslen(ptr)})) { - success = false; - break; + VisitSet(pv.RObjPtr(), [&](auto* set) { + for (auto it = set->begin(); it != set->end(); ++it) { + std::string_view key = Key(it); + if (!func(ContainerEntry{key.data(), key.size()})) { + success = false; + break; + } } - } + }); } return success; diff --git a/src/server/db_slice.cc b/src/server/db_slice.cc index c8322c805889..aa9d69705a79 100644 --- a/src/server/db_slice.cc +++ b/src/server/db_slice.cc @@ -5,6 +5,7 @@ #include "server/db_slice.h" #include "core/dense_set.h" +#include "core/oah_set.h" #include "strings/human_readable.h" extern "C" { @@ -254,35 +255,48 @@ unsigned PrimeEvictionPolicy::Evict(const PrimeTable::HotBuckets& eb, PrimeTable class AsyncDeleter { public: - static void EnqueDeletion(uint32_t next, DenseSet* ds); + template static void EnqueDeletion(uint32_t next, Set* ds); static void Shutdown(); private: static constexpr uint32_t kClearStepSize = 1024; struct ClearNode { - DenseSet* ds; + void* ds; uint32_t cursor; + // Advances the node by one step. Deletes the underlying set and returns true + // when the set is fully cleared; otherwise updates cursor and returns false. + bool (*step)(ClearNode*); ClearNode* next; - - ClearNode(DenseSet* d, uint32_t c, ClearNode* n) : ds(d), cursor(c), next(n) { - } }; - // Asynchronously deletes entries during the cpu-idle time. static int32_t IdleCb(); - - // We add async deletion requests to a linked list and process them asynchronously - // in each thread. static __thread ClearNode* head_; }; __thread AsyncDeleter::ClearNode* AsyncDeleter::head_ = nullptr; -void AsyncDeleter::EnqueDeletion(uint32_t next, DenseSet* ds) { +// ClearStep returns the next cursor; the table is empty when it equals the +// underlying entries-vector size. DenseSet exposes that as BucketCount(); +// OAHSet exposes it as Capacity() (BucketCount() omits displacement slots). +template uint32_t ClearStepEnd(Set* s) { + if constexpr (std::is_same_v) + return s->Capacity(); + else + return s->BucketCount(); +} + +template void AsyncDeleter::EnqueDeletion(uint32_t next, Set* ds) { + auto step = +[](ClearNode* n) { + auto* s = static_cast(n->ds); + n->cursor = s->ClearStep(n->cursor, kClearStepSize); + if (n->cursor == ClearStepEnd(s)) { + CompactObj::DeleteMR(s); + return true; + } + return false; + }; bool launch_task = (head_ == nullptr); - - // register ds - head_ = new ClearNode{ds, next, head_}; + head_ = new ClearNode{ds, next, step, head_}; ProactorBase* pb = ProactorBase::me(); DCHECK(pb); DVLOG(2) << "Adding async deletion task, thread " << pb->GetPoolIndex() << " " << launch_task; @@ -306,15 +320,10 @@ int32_t AsyncDeleter::IdleCb() { return -1; // unregister itself. auto* current = head_; - DVLOG(2) << "IdleCb " << current->cursor; - uint32_t next = current->ds->ClearStep(current->cursor, kClearStepSize); - if (next == current->ds->BucketCount()) { // reached the end. - CompactObj::DeleteMR(current->ds); + if (current->step(current)) { head_ = current->next; delete current; - } else { - current->cursor = next; } return ProactorBase::kOnIdleMaxLevel; }; @@ -1839,16 +1848,21 @@ void DbSlice::PerformDeletionAtomic(const Iterator& del_it, DbTable* table, bool AccountObjectMemory(del_it.key(), pv.ObjType(), -value_heap_size, table); // Value if (async && MayDeleteAsynchronously(pv)) { - DenseSet* ds = (DenseSet*)pv.RObjPtr(); + auto schedule = [](auto* ds) { + using Ds = std::remove_pointer_t; + uint32_t next = ds->ClearStep(0, 512); + if (next < ClearStepEnd(ds)) + AsyncDeleter::EnqueDeletion(next, ds); + else + CompactObj::DeleteMR(ds); + }; + void* obj_ptr = pv.RObjPtr(); pv.SetRObjPtr(nullptr); - const size_t kClearStepSize = 512; - - uint32_t next = ds->ClearStep(0, kClearStepSize); - if (next < ds->BucketCount()) { - AsyncDeleter::EnqueDeletion(next, ds); - } else { - CompactObj::DeleteMR(ds); - } + // SET dispatches via VisitSet (StringSet/OAHSet); HASH is always StringMap (DenseSet-derived). + if (pv.ObjType() == OBJ_SET) + VisitSet(obj_ptr, schedule); + else + schedule(static_cast(obj_ptr)); } if (table->slots_stats) { diff --git a/src/server/dfly_main.cc b/src/server/dfly_main.cc index c85de72326e4..1e54adb75b84 100644 --- a/src/server/dfly_main.cc +++ b/src/server/dfly_main.cc @@ -84,6 +84,11 @@ ABSL_DECLARE_FLAG(std::string, admin_bind); ABSL_DECLARE_FLAG(strings::MemoryBytesFlag, maxmemory); ABSL_DECLARE_FLAG(uint32_t, proactor_threads); ABSL_DECLARE_FLAG(std::string, dbfilename); +ABSL_DECLARE_FLAG(bool, use_oah_set); + +namespace dfly { +extern bool g_use_oah_set; // defined in core/oah_set.h +} #ifdef USE_ABSL_LOG ABSL_FLAG(bool, alsologtostderr, false, "also log messages to stderr in addition to logfiles"); @@ -1139,6 +1144,7 @@ Usage: dragonfly [FLAGS] LOG(WARNING) << "SWAP is enabled. Consider disabling it when running Dragonfly."; dfly::max_memory_limit = absl::GetFlag(FLAGS_maxmemory); + dfly::g_use_oah_set = absl::GetFlag(FLAGS_use_oah_set); if (dfly::max_memory_limit == 0) { LOG(INFO) << "maxmemory has not been specified. Deciding myself...."; diff --git a/src/server/generic_family.cc b/src/server/generic_family.cc index bbc7f26b28a7..b30b004d7501 100644 --- a/src/server/generic_family.cc +++ b/src/server/generic_family.cc @@ -21,6 +21,7 @@ extern "C" { #include "base/flags.h" #include "base/logging.h" #include "core/glob_matcher.h" +#include "core/oah_set.h" #include "core/qlist.h" #include "core/string_set.h" #include "redis/rdb.h" @@ -1668,8 +1669,8 @@ OpResult, CompactObjType>> OpFetchContainerElements(const Op // IterateSet would skip expiry entirely and empty-set cleanup below would // depend on a prior command having set time_now_. if (obj_type == OBJ_SET && it->second.Encoding() == kEncodingStrMap2) { - static_cast(it->second.RObjPtr()) - ->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms)); + uint32_t t = MemberTimeSeconds(op_args.db_cntx.time_now_ms); + VisitSet(it->second.RObjPtr(), [t](auto* s) { s->set_time(t); }); } Iterate(it->second, [&elements](const ContainerEntry& entry) { diff --git a/src/server/journal/cmd_serializer.cc b/src/server/journal/cmd_serializer.cc index 382828f651c3..52955966e912 100644 --- a/src/server/journal/cmd_serializer.cc +++ b/src/server/journal/cmd_serializer.cc @@ -4,6 +4,7 @@ #include "server/journal/cmd_serializer.h" +#include "core/oah_set.h" #include "core/string_map.h" #include "core/string_set.h" #include "server/container_utils.h" @@ -162,12 +163,13 @@ void CmdSerializer::SerializeExpireIfNeeded(string_view key, uint64_t expire_ms) size_t CmdSerializer::SerializeSet(string_view key, const PrimeValue& pv) { // Disable lazy expiry during serialization (same as rdb_save.cc). // We are called under bucket lock so DeleteIfEmpty is not possible. - StringSet* ss = nullptr; uint32_t prev_time = 0; - if (pv.Encoding() == kEncodingStrMap2) { - ss = static_cast(pv.RObjPtr()); - prev_time = ss->time_now(); - ss->set_time(0); + bool dense = pv.Encoding() == kEncodingStrMap2; + if (dense) { + VisitSet(pv.RObjPtr(), [&](auto* s) { + prev_time = s->time_now(); + s->set_time(0); + }); } CommandAggregator aggregator( @@ -181,8 +183,8 @@ size_t CmdSerializer::SerializeSet(string_view key, const PrimeValue& pv) { }); // Restore previous time so subsequent operations can trigger lazy expiry. - if (ss) - ss->set_time(prev_time); + if (dense) + VisitSet(pv.RObjPtr(), [prev_time](auto* s) { s->set_time(prev_time); }); return commands; } diff --git a/src/server/rdb_load.cc b/src/server/rdb_load.cc index fe8b99b1be04..745144a68c5a 100644 --- a/src/server/rdb_load.cc +++ b/src/server/rdb_load.cc @@ -35,6 +35,7 @@ extern "C" { #include "core/cms.h" #include "core/detail/listpack_wrap.h" #include "core/json/json_object.h" +#include "core/oah_set.h" #include "core/qlist.h" #include "core/sorted_map.h" #include "core/string_map.h" @@ -372,6 +373,8 @@ void RdbLoaderBase::OpaqueObjLoader::CreateSet(const LoadTrace* ltrace) { if (inner_obj) { if (is_intset) { zfree(inner_obj); + } else if (g_use_oah_set) { + CompactObj::DeleteMR(inner_obj); } else { CompactObj::DeleteMR(inner_obj); } @@ -394,62 +397,67 @@ void RdbLoaderBase::OpaqueObjLoader::CreateSet(const LoadTrace* ltrace) { return true; }); } else { - StringSet* set; - if (config_.append) { - // Note we always use StringSet when the object is being chunked. - if (!EnsureObjEncoding(OBJ_SET, kEncodingStrMap2)) { - return; + auto load = [&]() { + Set* set; + if (config_.append) { + if (!EnsureObjEncoding(OBJ_SET, kEncodingStrMap2)) { + return; + } + set = static_cast(pv_->RObjPtr()); + } else { + set = CompactObj::AllocateMR(); + set->set_time(MemberTimeSeconds(GetCurrentTimeMs())); + inner_obj = set; + + // Expand the set up front to avoid rehashing. + set->Reserve((config_.reserve > len) ? config_.reserve : len); } - set = static_cast(pv_->RObjPtr()); - } else { - set = CompactObj::AllocateMR(); - set->set_time(MemberTimeSeconds(GetCurrentTimeMs())); - inner_obj = set; - // Expand the set up front to avoid rehashing. - set->Reserve((config_.reserve > len) ? config_.reserve : len); - } + size_t increment = 1; + if (rdb_type_ == RDB_TYPE_SET_WITH_EXPIRY) { + increment = 2; + } - size_t increment = 1; - if (rdb_type_ == RDB_TYPE_SET_WITH_EXPIRY) { - increment = 2; - } + bool values_expired = false; - bool values_expired = false; + for (size_t i = 0; i < ltrace->arr.size(); i += increment) { + string_view element = ToSV(ltrace->arr[i].rdb_var, &buf1_); - for (size_t i = 0; i < ltrace->arr.size(); i += increment) { - string_view element = ToSV(ltrace->arr[i].rdb_var, &buf1_); + uint32_t ttl_sec = UINT32_MAX; + if (increment == 2) { + int64_t ttl_time = -1; + string_view ttl_str = ToSV(ltrace->arr[i + 1].rdb_var, &buf2_); + if (!absl::SimpleAtoi(ttl_str, &ttl_time)) { + LOG(ERROR) << "Can't parse set TTL " << ttl_str; + ec_ = RdbError(errc::rdb_file_corrupted); + return; + } - uint32_t ttl_sec = UINT32_MAX; - if (increment == 2) { - int64_t ttl_time = -1; - string_view ttl_str = ToSV(ltrace->arr[i + 1].rdb_var, &buf2_); - if (!absl::SimpleAtoi(ttl_str, &ttl_time)) { - LOG(ERROR) << "Can't parse set TTL " << ttl_str; - ec_ = RdbError(errc::rdb_file_corrupted); - return; - } + if (ttl_time != -1) { + if (ttl_time <= set->time_now()) { + values_expired = true; + continue; + } - if (ttl_time != -1) { - if (ttl_time <= set->time_now()) { - values_expired = true; - continue; + ttl_sec = ttl_time - set->time_now(); } - - ttl_sec = ttl_time - set->time_now(); + } + if (!set->Add(element, ttl_sec)) { + LOG(ERROR) << "Duplicate set members detected " << absl::CHexEscape(element) + << " with TTL " << ttl_sec << " " << rdb_type_ << " " << set->ExpirationUsed() + << " " << config_.append; + ec_ = RdbError(errc::duplicate_key); + return; } } - if (!set->Add(element, ttl_sec)) { - LOG(ERROR) << "Duplicate set members detected " << absl::CHexEscape(element) << " with TTL " - << ttl_sec << " " << rdb_type_ << " " << set->ExpirationUsed() << " " - << config_.append; - ec_ = RdbError(errc::duplicate_key); - return; + if (set->Empty() && values_expired) { + ec_ = RdbError(errc::value_expired); } - } - if (set->Empty() && values_expired) { - ec_ = RdbError(errc::value_expired); - } + }; + if (g_use_oah_set) + load.template operator()(); + else + load.template operator()(); } if (ec_) diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc index cf6149af341c..d81384733bc3 100644 --- a/src/server/rdb_save.cc +++ b/src/server/rdb_save.cc @@ -28,6 +28,7 @@ extern "C" { #include "core/bloom.h" #include "core/cms.h" #include "core/json/json_object.h" +#include "core/oah_set.h" #include "core/qlist.h" #include "core/search/hnsw_index.h" #include "core/size_tracking_channel.h" @@ -177,10 +178,8 @@ uint8_t RdbObjectType(const CompactObj& pv) { if (compact_enc == kEncodingIntSet) return RDB_TYPE_SET_INTSET; else if (compact_enc == kEncodingStrMap2) { - if (((StringSet*)pv.RObjPtr())->ExpirationUsed()) - return RDB_TYPE_SET_WITH_EXPIRY; - else - return RDB_TYPE_SET; + bool has_expiry = VisitSet(pv.RObjPtr(), [](auto* s) { return s->ExpirationUsed(); }); + return has_expiry ? RDB_TYPE_SET_WITH_EXPIRY : RDB_TYPE_SET; } break; case OBJ_ZSET: @@ -437,28 +436,29 @@ error_code RdbSerializer::SaveListObject(const PrimeValue& pv) { error_code RdbSerializer::SaveSetObject(const PrimeValue& obj) { if (obj.Encoding() == kEncodingStrMap2) { - StringSet* set = (StringSet*)obj.RObjPtr(); - - // We don't expire any data during serialization - set->set_time(0); - - // due to we avoid expiring we can use UpperBoundSize() instead of SlowSize() - RETURN_ON_ERR(SaveLen(set->UpperBoundSize())); - for (auto it = set->begin(); it != set->end();) { - RETURN_ON_ERR(SaveString(string_view{*it, sdslen(*it)})); - if (set->ExpirationUsed()) { - int64_t expiry = -1; - if (it.HasExpiry()) - expiry = it.ExpiryTime(); - RETURN_ON_ERR(SaveLongLongAsString(expiry)); + auto save_loop = [this](auto* set) -> error_code { + // set_time(0) disables lazy expiry during serialization. Restore on every + // exit path (including the early returns inside RETURN_ON_ERR) so a failed + // SAVE doesn't leave the set with expiry permanently disabled. + set->set_time(0); + absl::Cleanup restore_time = [set] { set->set_time(MemberTimeSeconds(GetCurrentTimeMs())); }; + + RETURN_ON_ERR(SaveLen(set->UpperBoundSize())); + for (auto it = set->begin(); it != set->end();) { + RETURN_ON_ERR(SaveString(Key(it))); + if (set->ExpirationUsed()) { + int64_t expiry = it.HasExpiry() ? int64_t{it.ExpiryTime()} : -1; + RETURN_ON_ERR(SaveLongLongAsString(expiry)); + } + ++it; + FlushState flush_state = + it == set->end() ? FlushState::kFlushEndEntry : FlushState::kFlushMidEntry; + PushToConsumerIfNeeded(flush_state); } - ++it; - FlushState flush_state = FlushState::kFlushMidEntry; - if (it == set->end()) - flush_state = FlushState::kFlushEndEntry; - PushToConsumerIfNeeded(flush_state); - } - set->set_time(MemberTimeSeconds(GetCurrentTimeMs())); + return error_code{}; + }; + + RETURN_ON_ERR(VisitSet(obj.RObjPtr(), save_loop)); } else { CHECK_EQ(obj.Encoding(), kEncodingIntSet); intset* is = (intset*)obj.RObjPtr(); diff --git a/src/server/server_family.cc b/src/server/server_family.cc index 3ecc4b1eb5be..bf10c2b7046a 100644 --- a/src/server/server_family.cc +++ b/src/server/server_family.cc @@ -41,6 +41,7 @@ extern "C" { #include "base/logging.h" #include "core/compact_object.h" #include "core/dense_set.h" +#include "core/oah_set.h" #include "facade/cmd_arg_parser.h" #include "facade/dragonfly_connection.h" #include "facade/dragonfly_listener.h" @@ -2829,62 +2830,86 @@ void ServerFamily::Shrink(CmdArgList args, CommandContext* cmd_cntx) { auto cb = [key](Transaction* t, EngineShard* shard) -> OpResult { auto& db_slice = t->GetDbSlice(shard->shard_id()); - // First, do a read-only check: validate type/encoding and decide whether - // shrink is needed. This avoids bumping the key version, firing WATCH - // invalidations, or running PostUpdate for no-op / WRONGTYPE paths. - { - auto it = db_slice.FindReadOnly(t->GetDbContext(), key); - if (!IsValid(it)) { - return OpStatus::KEY_NOTFOUND; - } + // SET with --use_oah_set uses OAHSet; HASH (StringMap) and SET-without-OAH + // (StringSet) both inherit DenseSet, so the original code path covers them. + auto shrink = [&]() -> OpResult { + // First, do a read-only check: validate type/encoding and decide whether + // shrink is needed. This avoids bumping the key version, firing WATCH + // invalidations, or running PostUpdate for no-op / WRONGTYPE paths. + { + auto it = db_slice.FindReadOnly(t->GetDbContext(), key); + if (!IsValid(it)) { + return OpStatus::KEY_NOTFOUND; + } - const PrimeValue& pv = it->second; - unsigned encoding = pv.Encoding(); - unsigned obj_type = pv.ObjType(); + const PrimeValue& pv = it->second; + unsigned encoding = pv.Encoding(); + unsigned obj_type = pv.ObjType(); - if (encoding != kEncodingStrMap2 || (obj_type != OBJ_SET && obj_type != OBJ_HASH)) { - return OpStatus::WRONG_TYPE; - } + if (encoding != kEncodingStrMap2 || (obj_type != OBJ_SET && obj_type != OBJ_HASH)) { + return OpStatus::WRONG_TYPE; + } - DenseSet* ds = static_cast(pv.RObjPtr()); - ds->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms)); - size_t current_size = ds->UpperBoundSize(); - size_t bucket_count = ds->BucketCount(); + Set* ds = static_cast(pv.RObjPtr()); + ds->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms)); + size_t current_size = ds->UpperBoundSize(); + size_t bucket_count = ds->BucketCount(); - if (current_size == 0 || bucket_count == 0) { - return 0; + if (current_size == 0 || bucket_count == 0) { + return 0; + } + + size_t optimal_size = std::max(size_t(8), absl::bit_ceil(current_size)); + if (optimal_size >= bucket_count) { + return 0; + } } - size_t optimal_size = std::max(size_t(8), absl::bit_ceil(current_size)); - if (optimal_size >= bucket_count) { - return 0; + // Shrink is needed — use FindMutable so the AutoUpdater tracks the + // MallocUsed() delta (bucket array resize, link changes, expired-entry + // deletions) and keeps obj_memory_usage in sync. + auto it_res = db_slice.FindMutable(t->GetDbContext(), key); + if (!IsValid(it_res.it)) { + return OpStatus::KEY_NOTFOUND; // raced away between the two lookups } - } - // Shrink is needed — use FindMutable so the AutoUpdater tracks the - // MallocUsed() delta (bucket array resize, link changes, expired-entry - // deletions) and keeps obj_memory_usage in sync. - auto it_res = db_slice.FindMutable(t->GetDbContext(), key); - if (!IsValid(it_res.it)) { - return OpStatus::KEY_NOTFOUND; // raced away between the two lookups - } + PrimeValue& pv = it_res.it->second; + Set* ds = static_cast(pv.RObjPtr()); + ds->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms)); - PrimeValue& pv = it_res.it->second; - DenseSet* ds = static_cast(pv.RObjPtr()); - ds->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms)); + // Bucket-array bytes only. We can't use SetMallocUsed() because it also + // counts collision-link / vector-bucket bytes which can grow when the + // table shrinks, and the SHRINK reply is meant to report the bucket-vector + // delta. DenseSet buckets are DensePtr (pointer-sized); OAHSet buckets are + // OAHEntry-sized and the entries vector includes displacement slots. + auto bucket_bytes = [](Set* s) -> size_t { + if constexpr (std::is_same_v) + return size_t{s->Capacity()} * sizeof(OAHEntry); + else + return s->BucketCount() * sizeof(void*); + }; + size_t bytes_before = bucket_bytes(ds); + size_t optimal_size = std::max(size_t(8), absl::bit_ceil(ds->UpperBoundSize())); + ds->Shrink(optimal_size); + size_t bytes_after = bucket_bytes(ds); + + // Shrink expires entries during bucket compaction. If all entries expired, + // delete the now-empty key to prevent zombie keys that crash SAVE. + if (ds->Empty()) { + db_slice.DelMutable(t->GetDbContext(), std::move(it_res)); + } - size_t bucket_bytes_before = ds->BucketCount() * sizeof(void*); - size_t optimal_size = std::max(size_t(8), absl::bit_ceil(ds->UpperBoundSize())); - ds->Shrink(optimal_size); - size_t bucket_bytes_after = ds->BucketCount() * sizeof(void*); + return bytes_before - bytes_after; + }; - // Shrink expires entries during bucket compaction. If all entries expired, - // delete the now-empty key to prevent zombie keys that crash SAVE. - if (ds->Empty()) { - db_slice.DelMutable(t->GetDbContext(), std::move(it_res)); + bool use_oah; + { + auto it = db_slice.FindReadOnly(t->GetDbContext(), key); + if (!IsValid(it)) + return OpStatus::KEY_NOTFOUND; + use_oah = it->second.ObjType() == OBJ_SET && g_use_oah_set; } - - return bucket_bytes_before - bucket_bytes_after; + return use_oah ? shrink.template operator()() : shrink.template operator()(); }; OpResult result = cmd_cntx->tx()->ScheduleSingleHopT(std::move(cb)); diff --git a/src/server/set_family.cc b/src/server/set_family.cc index e1e81642adb0..cf4a6bf86fa7 100644 --- a/src/server/set_family.cc +++ b/src/server/set_family.cc @@ -12,10 +12,13 @@ extern "C" { #include "redis/util.h" // for string2ll } +#include + #include "base/cycle_clock.h" #include "base/logging.h" #include "base/stl_util.h" #include "core/detail/listpack_wrap.h" +#include "core/oah_set.h" #include "core/string_set.h" #include "facade/cmd_arg_parser.h" #include "server/acl/acl_commands_def.h" @@ -28,6 +31,8 @@ extern "C" { #include "server/journal/journal.h" #include "server/transaction.h" +ABSL_FLAG(bool, use_oah_set, false, "If true, store SET values in OAHSet instead of StringSet."); + namespace rng = std::ranges; namespace dfly { @@ -56,6 +61,11 @@ bool IsDenseEncoding(const CompactObj& co) { return co.Encoding() == kEncodingStrMap2; } +inline void DenseSetTouchTime(void* robj_ptr, uint64_t now_ms) { + uint32_t t = MemberTimeSeconds(now_ms); + VisitSet(robj_ptr, [t](auto* s) { s->set_time(t); }); +} + intset* IntsetAddSafe(string_view val, intset* is, bool* success, bool* added) { long long llval; *added = false; @@ -89,7 +99,9 @@ struct StringSetWrapper { } static void Init(CompactObj* obj) { - obj->InitRobj(OBJ_SET, kEncodingStrMap2, CompactObj::AllocateMR()); + void* set = g_use_oah_set ? static_cast(CompactObj::AllocateMR()) + : static_cast(CompactObj::AllocateMR()); + obj->InitRobj(OBJ_SET, kEncodingStrMap2, set); } unsigned Add(const NewEntries& entries, uint32_t ttl_sec, bool keepttl) const { @@ -97,29 +109,33 @@ struct StringSetWrapper { string_view members[StringSet::kMaxBatchLen]; size_t entries_len = std::visit([](const auto& e) { return e.size(); }, entries); unsigned len = 0; - if (ss->BucketCount() < entries_len) { - ss->Reserve(entries_len); - } + VisitSet(obj_, [entries_len](auto* s) { + if (s->BucketCount() < entries_len) + s->Reserve(entries_len); + }); + auto add_many = [&](unsigned n) { + return VisitSet( + obj_, [&](auto* s) { return s->AddMany(absl::MakeSpan(members, n), ttl_sec, keepttl); }); + }; for (string_view member : EntriesRange(entries)) { members[len++] = member; if (len == StringSet::kMaxBatchLen) { - res += ss->AddMany(absl::MakeSpan(members, StringSet::kMaxBatchLen), ttl_sec, keepttl); + res += add_many(len); len = 0; } } - - if (len) { - res += ss->AddMany(absl::MakeSpan(members, len), ttl_sec, keepttl); - } - + if (len) + res += add_many(len); return res; } pair Remove(const facade::ArgRange& entries) const { - unsigned removed = 0; - for (string_view member : entries) - removed += ss->Erase(member); - return {removed, ss->Empty()}; + return VisitSet(obj_, [&](auto* s) { + unsigned removed = 0; + for (string_view member : entries) + removed += s->Erase(member); + return std::pair{removed, s->Empty()}; + }); } uint64_t Scan(uint64_t curs, const ScanOpts& scan_op, StringVec* res) const { @@ -130,36 +146,51 @@ struct StringSetWrapper { // Approximately 100usec const uint64_t timeout_cycles = base::CycleClock::Now() + base::CycleClock::Frequency() / 10000; + auto record = [&](string_view str) { + if (scan_op.Matches(str)) + res->emplace_back(str); + }; do { - auto scan_callback = [&](sds ptr) { - if (string_view str{ptr, sdslen(ptr)}; scan_op.Matches(str)) - res->emplace_back(str); - }; - curs = ss->Scan(curs, scan_callback); + curs = VisitSet(obj_, [&](auto* s) { + return s->Scan(static_cast(curs), [&](auto key) { + if constexpr (std::is_same_v) + record(string_view{key, sdslen(key)}); + else + record(key); + }); + }); } while (curs && maxiterations-- && res->size() < count && (base::CycleClock::Now() - start_cycles) < timeout_cycles); return curs; } - explicit operator StringSet*() const { - return ss; + template void ForEach(Cb&& cb) const { + VisitSet(obj_, [&](auto* s) { + for (auto it = s->begin(); it != s->end(); ++it) + cb(Key(it)); + }); } - StringSet* operator->() const { - return ss; + size_t UpperBoundSize() const { + return VisitSet(obj_, [](auto* s) { return s->UpperBoundSize(); }); + } + bool Empty() const { + return VisitSet(obj_, [](auto* s) { return s->Empty(); }); + } + bool Contains(string_view member) const { + return VisitSet(obj_, [member](auto* s) { return s->Contains(member); }); } - auto Range() const { - auto transform = [](sds ptr) { return string_view{ptr, sdslen(ptr)}; }; - return base::it::Transform(transform, base::it::Range(ss->begin(), ss->end())); + void* obj() const { + return obj_; } private: - StringSetWrapper(void* robj_ptr, uint64_t now_ms) : ss(static_cast(robj_ptr)) { - ss->set_time(MemberTimeSeconds(now_ms)); + StringSetWrapper(void* robj_ptr, uint64_t now_ms) : obj_(robj_ptr) { + DenseSetTouchTime(obj_, now_ms); } - StringSet* const ss; + void* const obj_; }; // returns (removed, isempty) @@ -210,7 +241,7 @@ uint32_t SetTypeLen(const DbContext& db_context, const SetType& set) { if (set.second == kEncodingIntSet) { return intsetLen((const intset*)set.first); } else { - return StringSetWrapper(set, db_context)->UpperBoundSize(); + return StringSetWrapper(set, db_context).UpperBoundSize(); } } @@ -222,7 +253,7 @@ bool IsInSet(const DbContext& db_context, const SetType& st, int64_t val) { char* next = absl::numbers_internal::FastIntToBuffer(val, buf); string_view str{buf, size_t(next - buf)}; - return StringSetWrapper(st, db_context)->Contains(str); + return StringSetWrapper(st, db_context).Contains(str); } bool IsInSet(const DbContext& db_context, const SetType& st, string_view member) { @@ -233,7 +264,7 @@ bool IsInSet(const DbContext& db_context, const SetType& st, string_view member) return intsetFind((intset*)st.first, llval); } else { - return StringSetWrapper(st, db_context)->Contains(member); + return StringSetWrapper(st, db_context).Contains(member); } } @@ -247,23 +278,23 @@ int32_t GetExpiry(const DbContext& db_context, const SetType& st, string_view me return -1; } else { StringSetWrapper ss{st, db_context}; - auto it = ss->Find(member); - if (it == ss->end()) - return -3; - - return it.HasExpiry() ? it.ExpiryTime() : -1; + return VisitSet(ss.obj(), [member](auto* s) -> int32_t { + auto it = s->Find(member); + if (it == s->end()) + return -3; + return it.HasExpiry() ? it.ExpiryTime() : -1; + }); } } // Removes arg from result. void DiffStrSet(const DbContext& db_context, const SetType& st, absl::flat_hash_set* result) { - for (string_view entry : StringSetWrapper{st, db_context}.Range()) - result->erase(entry); + StringSetWrapper{st, db_context}.ForEach([&](string_view entry) { result->erase(entry); }); } void InterStrSet(const DbContext& db_context, const vector& vec, StringVec* result) { - for (string_view str : StringSetWrapper{vec.front(), db_context}.Range()) { + StringSetWrapper{vec.front(), db_context}.ForEach([&](string_view str) { size_t j = 1; for (j = 1; j < vec.size(); ++j) { if (vec[j].first != vec.front().first && !IsInSet(db_context, vec[j], str)) { @@ -274,22 +305,23 @@ void InterStrSet(const DbContext& db_context, const vector& vec, String if (j == vec.size()) { result->emplace_back(str); } - } + }); } template > -StringVec RandMemberStrSetPicky(StringSet* strset, size_t count) { +StringVec RandMemberStrSetPicky(const StringSetWrapper& strset, size_t count) { C picks; picks.reserve(count); - size_t tries = 0; - while (picks.size() < count && tries++ < count * 2) { - auto it = strset->GetRandomMember(); - if (it == strset->end()) - break; - sds member = *it; - picks.insert(picks.end(), {member, sdslen(member)}); - } + VisitSet(strset.obj(), [&](auto* s) { + size_t tries = 0; + while (picks.size() < count && tries++ < count * 2) { + auto it = s->GetRandomMember(); + if (it == s->end()) + break; + picks.insert(picks.end(), string{Key(it)}); + } + }); if constexpr (is_same_v) return picks; @@ -301,13 +333,12 @@ StringVec RandMemberStrSet(const DbContext& db_context, const CompactObj& co, CHECK(IsDenseEncoding(co)); StringSetWrapper strset{co, db_context}; - // If the set is small, extract entries with StringSet::GetRandomMember - if (picks_count * 5 < strset->UpperBoundSize()) { - StringSet* ss(strset); + // If the set is small, extract entries with random sampling. + if (picks_count * 5 < strset.UpperBoundSize()) { if (bool unique = (dynamic_cast(&generator) != nullptr); unique) - return RandMemberStrSetPicky(ss, picks_count); + return RandMemberStrSetPicky(strset, picks_count); else - return RandMemberStrSetPicky(ss, picks_count); + return RandMemberStrSetPicky(strset, picks_count); } std::unordered_map times_index_is_picked; @@ -319,13 +350,13 @@ StringVec RandMemberStrSet(const DbContext& db_context, const CompactObj& co, result.reserve(picks_count); std::uint32_t ss_entry_index = 0; - for (string_view str : strset.Range()) { + strset.ForEach([&](string_view str) { auto it = times_index_is_picked.find(ss_entry_index++); if (it != times_index_is_picked.end()) { while (it->second--) result.emplace_back(str); } - } + }); /* Equal elements in the result are always successive. So, it is necessary to shuffle them */ absl::BitGen gen; std::shuffle(result.begin(), result.end(), gen); @@ -544,7 +575,7 @@ OpResult OpAdd(const OpArgs& op_args, std::string_view key, const NewE if (!success) { co.SetRObjPtr(is); - StringSet* ss = SetFamily::ConvertToStrSet(is, intsetLen(is)); + void* ss = SetFamily::ConvertToStrSet(is, intsetLen(is)); if (!ss) { return OpStatus::OUT_OF_MEMORY; } @@ -594,7 +625,7 @@ OpResult OpAddEx(const OpArgs& op_args, string_view key, uint32_t ttl_ // Update stats and trigger any handle the old value if needed. if (co.Encoding() == kEncodingIntSet) { intset* is = (intset*)co.RObjPtr(); - StringSet* ss = SetFamily::ConvertToStrSet(is, intsetLen(is)); + void* ss = SetFamily::ConvertToStrSet(is, intsetLen(is)); if (!ss) { return OpStatus::OUT_OF_MEMORY; } @@ -737,8 +768,7 @@ OpResult OpUnion(const OpArgs& op_args, ShardArgs::Iterator start, if (find_res) { const PrimeValue& pv = find_res.value()->second; if (IsDenseEncoding(pv)) { - StringSet* ss = (StringSet*)pv.RObjPtr(); - ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms)); + DenseSetTouchTime(pv.RObjPtr(), op_args.db_cntx.time_now_ms); } container_utils::IterateSet(pv, [&uniques](container_utils::ContainerEntry ce) { uniques.emplace(ce.ToString()); @@ -771,8 +801,7 @@ OpResult OpDiff(const OpArgs& op_args, ShardArgs::Iterator start, absl::flat_hash_set uniques; const PrimeValue& pv = find_res.value()->second; if (IsDenseEncoding(pv)) { - StringSet* ss = (StringSet*)pv.RObjPtr(); - ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms)); + DenseSetTouchTime(pv.RObjPtr(), op_args.db_cntx.time_now_ms); } container_utils::IterateSet(pv, [&uniques](container_utils::ContainerEntry ce) { @@ -834,8 +863,7 @@ OpResult OpInter(const Transaction* t, EngineShard* es, bool remove_f const PrimeValue& pv = find_res.value()->second; if (IsDenseEncoding(pv)) { - StringSet* ss = (StringSet*)pv.RObjPtr(); - ss->set_time(MemberTimeSeconds(t->GetDbContext().time_now_ms)); + DenseSetTouchTime(pv.RObjPtr(), t->GetDbContext().time_now_ms); } result.reserve(pv.Size()); @@ -962,8 +990,7 @@ OpResult OpPop(const OpArgs& op_args, string_view key, unsigned count * the number of elements inside the set: simply return the whole set. */ if (count >= size) { if (IsDenseEncoding(co)) { - StringSet* ss = (StringSet*)co.RObjPtr(); - ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms)); + DenseSetTouchTime(co.RObjPtr(), op_args.db_cntx.time_now_ms); } StringVec result; @@ -1606,7 +1633,7 @@ bool SetFamily::DeleteSetIfEmpty(DbSlice& db_slice, const DbContext& db_cntx, st if (!IsDenseEncoding(pv)) return false; - if (StringSet* ss = (StringSet*)pv.RObjPtr(); !ss->Empty()) + if (!VisitSet(pv.RObjPtr(), [](auto* s) { return s->Empty(); })) return false; if (auto res = db_slice.FindMutable(db_cntx, key, OBJ_SET); res) { @@ -1630,7 +1657,7 @@ auto SetFamily::LoadIntSetBlob(std::string_view blob, PrimeValue* pv) -> LoadBlo unsigned len = intsetLen(is); if (len > SetFamily::MaxIntsetEntries()) { - StringSet* set = SetFamily::ConvertToStrSet(is, len); + void* set = SetFamily::ConvertToStrSet(is, len); if (!set) { LOG(ERROR) << "OOM in ConvertToStrSet " << len; @@ -1646,46 +1673,60 @@ auto SetFamily::LoadIntSetBlob(std::string_view blob, PrimeValue* pv) -> LoadBlo return LoadBlobResult::kSuccess; } -auto SetFamily::LoadLPSetBlob(std::string_view blob, PrimeValue* pv) -> LoadBlobResult { - if (!lpValidateIntegrity((uint8_t*)blob.data(), blob.size(), 0, nullptr, nullptr)) { - LOG(ERROR) << "ListPack integrity check failed."; - return LoadBlobResult::kCorrupted; - } - - unsigned char* lp = (unsigned char*)blob.data(); - StringSet* set = CompactObj::AllocateMR(); +// Allocate a fresh dense set and populate it from the listpack. Returns nullptr +// (and deletes the partial set) if a duplicate member is detected. +template static Set* BuildDenseSetFromLP(unsigned char* lp) { + Set* set = CompactObj::AllocateMR(); for (unsigned char* cur = lpFirst(lp); cur != nullptr; cur = lpNext(lp, cur)) { unsigned char field_buf[LP_INTBUF_SIZE]; string_view elem = detail::ListpackWrap::GetView(cur, field_buf); if (!set->Add(elem)) { LOG(ERROR) << "Duplicate member " << elem; - CompactObj::DeleteMR(set); - return LoadBlobResult::kCorrupted; + CompactObj::DeleteMR(set); + return nullptr; } } - pv->InitRobj(OBJ_SET, kEncodingStrMap2, set); - return LoadBlobResult::kSuccess; + return set; } -StringSet* SetFamily::ConvertToStrSet(const intset* is, size_t expected_len) { - int64_t intele; - char buf[32]; - int ii = 0; +auto SetFamily::LoadLPSetBlob(std::string_view blob, PrimeValue* pv) -> LoadBlobResult { + if (!lpValidateIntegrity((uint8_t*)blob.data(), blob.size(), 0, nullptr, nullptr)) { + LOG(ERROR) << "ListPack integrity check failed."; + return LoadBlobResult::kCorrupted; + } - StringSet* ss = CompactObj::AllocateMR(); + unsigned char* lp = (unsigned char*)blob.data(); + void* set_ptr = g_use_oah_set ? static_cast(BuildDenseSetFromLP(lp)) + : static_cast(BuildDenseSetFromLP(lp)); + if (!set_ptr) + return LoadBlobResult::kCorrupted; + + pv->InitRobj(OBJ_SET, kEncodingStrMap2, set_ptr); + return LoadBlobResult::kSuccess; +} + +// Allocate a fresh dense set, reserve capacity, and copy each intset member as +// the decimal string form. +template static Set* BuildDenseSetFromIntSet(const intset* is, size_t expected_len) { + Set* ss = CompactObj::AllocateMR(); if (expected_len) { ss->Reserve(expected_len); } - + int64_t intele; + char buf[32]; + int ii = 0; while (intsetGet(const_cast(is), ii++, &intele)) { char* next = absl::numbers_internal::FastIntToBuffer(intele, buf); - string_view str{buf, size_t(next - buf)}; - CHECK(ss->Add(str)); + CHECK(ss->Add(string_view{buf, size_t(next - buf)})); } - return ss; } +void* SetFamily::ConvertToStrSet(const intset* is, size_t expected_len) { + return g_use_oah_set ? static_cast(BuildDenseSetFromIntSet(is, expected_len)) + : static_cast(BuildDenseSetFromIntSet(is, expected_len)); +} + using CI = CommandId; #define HFUNC(x) SetHandler(&Cmd##x) @@ -1734,7 +1775,7 @@ vector SetFamily::SetFieldsExpireTime(const OpArgs& op_args, uint32_t ttl_ if (pv->Encoding() == kEncodingIntSet) { // a valid result can never be a intset, since it doesnt keep ttl intset* is = (intset*)pv->RObjPtr(); - StringSet* ss = SetFamily::ConvertToStrSet(is, intsetLen(is)); + void* ss = SetFamily::ConvertToStrSet(is, intsetLen(is)); if (!ss) { std::vector out(values.size(), -2); return out; @@ -1742,9 +1783,10 @@ vector SetFamily::SetFieldsExpireTime(const OpArgs& op_args, uint32_t ttl_ pv->InitRobj(OBJ_SET, kEncodingStrMap2, ss); } - auto ss = static_cast(pv->RObjPtr()); - ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms)); - return ExpireElements(ss, values, ttl_sec); + return VisitSet(pv->RObjPtr(), [&](auto* ss) { + ss->set_time(MemberTimeSeconds(op_args.db_cntx.time_now_ms)); + return ExpireElements(ss, values, ttl_sec); + }); } } // namespace dfly diff --git a/src/server/set_family.h b/src/server/set_family.h index cdf7109b9b34..c58af2d32fd1 100644 --- a/src/server/set_family.h +++ b/src/server/set_family.h @@ -25,8 +25,10 @@ class SetFamily { static uint32_t MaxIntsetEntries(); - // Returns nullptr on OOM. - static StringSet* ConvertToStrSet(const intset* is, size_t expected_len); + // Returns nullptr on OOM. The returned pointer is StringSet* if --use_oah_set is false, + // or OAHSet* if --use_oah_set is true. Callers store it as a void* in CompactObj and + // dispatch via dfly::g_use_oah_set. + static void* ConvertToStrSet(const intset* is, size_t expected_len); // returns expiry time in seconds since kMemberExpiryBase date. // returns -3 if field was not found, -1 if no ttl is associated with the item. diff --git a/src/server/test_utils.cc b/src/server/test_utils.cc index a5cb91fff888..22df31b65ede 100644 --- a/src/server/test_utils.cc +++ b/src/server/test_utils.cc @@ -20,6 +20,7 @@ extern "C" { #include "base/flags.h" #include "base/logging.h" #include "base/stl_util.h" +#include "core/oah_set.h" #include "facade/dragonfly_connection.h" #include "facade/reply_builder.h" #include "io/file_util.h" @@ -31,6 +32,7 @@ using namespace std; ABSL_DECLARE_FLAG(string, dbfilename); ABSL_DECLARE_FLAG(double, rss_oom_deny_ratio); ABSL_DECLARE_FLAG(uint32_t, num_shards); +ABSL_DECLARE_FLAG(bool, use_oah_set); ABSL_FLAG(bool, force_epoll, false, "If true, uses epoll api instead iouring to run tests"); ABSL_DECLARE_FLAG(uint32_t, acllog_max_len); ABSL_DECLARE_FLAG(bool, enable_heartbeat_rss_eviction); @@ -215,6 +217,7 @@ void BaseFamilyTest::SetUpTestSuite() { void BaseFamilyTest::SetUp() { max_memory_limit = INT_MAX; + g_use_oah_set = absl::GetFlag(FLAGS_use_oah_set); ResetService(); } diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index dba90bc7b757..3eb9fcdc44f1 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -15,6 +15,7 @@ extern "C" { #include "base/logging.h" #include "base/stl_util.h" +#include "core/oah_set.h" #include "core/sorted_map.h" #include "core/string_set.h" #include "facade/cmd_arg_parser.h" @@ -741,7 +742,8 @@ ScoredMap ScoreMapFromSet(const PrimeValue& pv, double weight, const DbContext& // Enable lazy member expiry before iterating dense sets so expired members // do not pollute the result (and so the caller can detect an emptied set). if (pv.Encoding() == kEncodingStrMap2) { - static_cast(pv.RObjPtr())->set_time(MemberTimeSeconds(db_cntx.time_now_ms)); + uint32_t t = MemberTimeSeconds(db_cntx.time_now_ms); + VisitSet(pv.RObjPtr(), [t](auto* s) { s->set_time(t); }); } ScoredMap result;