diff --git a/src/server/rdb_extensions.h b/src/server/rdb_extensions.h index fee5feb2e74f..b79ad78a299c 100644 --- a/src/server/rdb_extensions.h +++ b/src/server/rdb_extensions.h @@ -58,3 +58,6 @@ constexpr uint8_t RDB_OPCODE_VECTOR_INDEX = 222; // Opcode to store ShardDocIndex key-to-DocId mapping for search indices // Format: [shard_id, index_name, mapping_count, then for each mapping: key_string, doc_id] constexpr uint8_t RDB_OPCODE_SHARD_DOC_INDEX = 223; + +// Used to tag a chunk of serialized data with its stream id +constexpr uint8_t RDB_OPCODE_TAGGED_CHUNK = 224; diff --git a/src/server/rdb_load.cc b/src/server/rdb_load.cc index 824d3fc138bf..123ac4c3f0e7 100644 --- a/src/server/rdb_load.cc +++ b/src/server/rdb_load.cc @@ -287,13 +287,17 @@ void RdbLoaderBase::OpaqueObjLoader::operator()(const unique_ptr& ptr } void RdbLoaderBase::OpaqueObjLoader::operator()(const RdbSBF& src) { - SBF* sbf = - CompactObj::AllocateMR(src.grow_factor, src.fp_prob, src.max_capacity, src.prev_size, + SBF* sbf = config_.append ? pv_->GetSBF() + : CompactObj::AllocateMR( + src.grow_factor, src.fp_prob, src.max_capacity, src.prev_size, src.current_size, CompactObj::memory_resource()); for (unsigned i = 0; i < src.filters.size(); ++i) { sbf->AddFilter(src.filters[i].blob, src.filters[i].hash_cnt); } - pv_->SetSBF(sbf); + + // new obj + if (!config_.append) + pv_->SetSBF(sbf); } void RdbLoaderBase::OpaqueObjLoader::operator()(const RdbTOPK& src) { @@ -1000,7 +1004,7 @@ std::error_code RdbLoaderBase::FetchBuf(size_t size, void* dest) { DVLOG(3) << "Copying " << to_copy << " bytes"; ::memcpy(next, mem_buf_->InputBuffer().data(), to_copy); - mem_buf_->ConsumeInput(to_copy); + RETURN_ON_ERR(ConsumeInput(to_copy)); size -= to_copy; if (size == 0) return kOk; @@ -1016,6 +1020,7 @@ std::error_code RdbLoaderBase::FetchBuf(size_t size, void* dest) { if (size > 512) { // Worth reading directly into next. io::MutableBytes mb{next, size}; + RETURN_ON_ERR(ConsumeChunkBudget(size)); SET_OR_RETURN(src_->Read(mb), bytes_read); if (bytes_read < size) return RdbError(errc::rdb_file_corrupted); @@ -1045,7 +1050,7 @@ std::error_code RdbLoaderBase::FetchBuf(size_t size, void* dest) { mem_buf_->CommitWrite(bytes_read); ::memcpy(next, mem_buf_->InputBuffer().data(), size); - mem_buf_->ConsumeInput(size); + RETURN_ON_ERR(ConsumeInput(size)); return kOk; } @@ -1141,8 +1146,11 @@ auto RdbLoaderBase::FetchLzfStringObject() -> io::Result { // FetchBuf consumes the input but if we have not went through that path // we need to consume now. 
- if (zerocopy_decompress) - mem_buf_->ConsumeInput(clen); + if (zerocopy_decompress) { + if (auto ec = ConsumeInput(clen); ec) { + return make_unexpected(ec); + } + } return res; } @@ -1167,11 +1175,12 @@ io::Result RdbLoaderBase::FetchBinaryDouble() { } u; static_assert(sizeof(u) == sizeof(uint64_t)); - auto ec = EnsureRead(8); - if (ec) + if (auto ec = EnsureRead(8); ec) return make_unexpected(ec); uint8_t buf[8]; + if (auto ec = ConsumeChunkBudget(8); ec) + return make_unexpected(ec); mem_buf_->ReadAndConsume(8, buf); u.val = base::LE::LoadT(buf); return u.d; @@ -1385,17 +1394,20 @@ auto RdbLoaderBase::ReadSet(int rdbtype) -> io::Result { unique_ptr load_trace(new LoadTrace); size_t n = std::min(len, kMaxBlobLen); load_trace->arr.resize(n); - for (size_t i = 0; i < n; i++) { + size_t i = 0; + for (; i < n && !ChunkBudgetExhausted(); i++) { error_code ec = ReadStringObj(&load_trace->arr[i].rdb_var); if (ec) { return make_unexpected(ec); } } + // cut off extra elements we allocated but stopped short due to budget + load_trace->arr.resize(i); // If there are still unread elements, cache the number of remaining // elements, or clear if the full object has been read. - if (len > n) { - pending_read_.remaining = len - n; + if (len > i) { + pending_read_.remaining = len - i; } else if (pending_read_.remaining > 0) { pending_read_.remaining = 0; } @@ -1466,16 +1478,18 @@ auto RdbLoaderBase::ReadHMap(int rdbtype) -> io::Result { unique_ptr load_trace(new LoadTrace); size_t n = std::min(len, kMaxBlobLen); load_trace->arr.resize(n); - for (size_t i = 0; i < n; ++i) { + size_t i = 0; + for (; i < n && !ChunkBudgetExhausted(); ++i) { error_code ec = ReadStringObj(&load_trace->arr[i].rdb_var); if (ec) return make_unexpected(ec); } + load_trace->arr.resize(i); // If there are still unread elements, cache the number of remaining // elements, or clear if the full object has been read. - if (len > n) { - pending_read_.remaining = len - n; + if (len > i) { + pending_read_.remaining = len - i; } else if (pending_read_.remaining > 0) { pending_read_.remaining = 0; } @@ -1501,7 +1515,8 @@ auto RdbLoaderBase::ReadZSet(int rdbtype) -> io::Result { unique_ptr load_trace(new LoadTrace); size_t n = std::min(zsetlen, kMaxBlobLen); load_trace->arr.resize(n); - for (size_t i = 0; i < n; ++i) { + size_t i = 0; + for (; i < n && !ChunkBudgetExhausted(); ++i) { error_code ec = ReadStringObj(&load_trace->arr[i].rdb_var); if (ec) return make_unexpected(ec); @@ -1516,11 +1531,12 @@ auto RdbLoaderBase::ReadZSet(int rdbtype) -> io::Result { } load_trace->arr[i].score = score; } + load_trace->arr.resize(i); // If there are still unread elements, cache the number of remaining // elements, or clear if the full object has been read. - if (zsetlen > n) { - pending_read_.remaining = zsetlen - n; + if (zsetlen > i) { + pending_read_.remaining = zsetlen - i; } else if (pending_read_.remaining > 0) { pending_read_.remaining = 0; } @@ -1545,7 +1561,8 @@ auto RdbLoaderBase::ReadListQuicklist(int rdbtype) -> io::Result { // therefore using a smaller segment length than kMaxBlobLen. 
size_t n = std::min(len, 512); load_trace->arr.resize(n); - for (size_t i = 0; i < n; ++i) { + size_t i = 0; + for (; i < n && !ChunkBudgetExhausted(); ++i) { uint64_t container = QUICKLIST_NODE_CONTAINER_PACKED; if (rdbtype == RDB_TYPE_LIST_QUICKLIST_2) { SET_OR_UNEXPECT(LoadLen(nullptr), container); @@ -1568,11 +1585,12 @@ auto RdbLoaderBase::ReadListQuicklist(int rdbtype) -> io::Result { load_trace->arr[i].rdb_var = std::move(var); load_trace->arr[i].encoding = container; } + load_trace->arr.resize(i); // If there are still unread elements, cache the number of remaining // elements, or clear if the full object has been read. - if (len > n) { - pending_read_.remaining = len - n; + if (len > i) { + pending_read_.remaining = len - i; } else if (pending_read_.remaining > 0) { pending_read_.remaining = 0; } @@ -1595,7 +1613,10 @@ auto RdbLoaderBase::ReadStreams(int rdbtype) -> io::Result { load_trace->arr.resize(n * 2); error_code ec; - for (size_t i = 0; i < n; ++i) { + size_t i = 0; + // The sender always sends stream id and blob together, there is no flush between. So the budget + // check is not midway between the two entries. + for (; i < n && !ChunkBudgetExhausted(); ++i) { /* Get the master ID, the one we'll use as key of the radix tree * node: the entries inside the listpack itself are delta-encoded * relatively to this ID. */ @@ -1620,14 +1641,15 @@ auto RdbLoaderBase::ReadStreams(int rdbtype) -> io::Result { load_trace->arr[2 * i].rdb_var = std::move(stream_id); load_trace->arr[2 * i + 1].rdb_var = std::move(blob); } + load_trace->arr.resize(2 * i); // If there are still unread elements, cache the number of remaining // elements, or clear if the full object has been read. // // We only load the stream metadata and consumer groups in the final read, // so if there are still unread elements return the partial stream. - if (listpacks > n) { - pending_read_.remaining = listpacks - n; + if (listpacks > i) { + pending_read_.remaining = listpacks - i; return OpaqueObj{std::move(load_trace), rdbtype}; } @@ -1805,24 +1827,75 @@ auto RdbLoaderBase::ReadRedisJson() -> io::Result { auto RdbLoaderBase::ReadSBFImpl(bool chunking) -> io::Result { RdbSBF res; - uint64_t options; - SET_OR_UNEXPECT(LoadLen(nullptr), options); - if (options != 0) - return Unexpected(errc::rdb_file_corrupted); - SET_OR_UNEXPECT(FetchBinaryDouble(), res.grow_factor); - SET_OR_UNEXPECT(FetchBinaryDouble(), res.fp_prob); - if (res.fp_prob <= 0 || res.fp_prob > 0.5) { - return Unexpected(errc::rdb_file_corrupted); - } - SET_OR_UNEXPECT(LoadLen(nullptr), res.prev_size); - SET_OR_UNEXPECT(LoadLen(nullptr), res.current_size); - SET_OR_UNEXPECT(LoadLen(nullptr), res.max_capacity); + auto is_power2 = [](size_t n) { return (n & (n - 1)) == 0; }; unsigned num_filters = 0; - SET_OR_UNEXPECT(LoadLen(nullptr), num_filters); - auto is_power2 = [](size_t n) { return (n & (n - 1)) == 0; }; - for (unsigned i = 0; i < num_filters; ++i) { + // Only read SBF metadata if not continuing. 
+ if (!pending_read_.sbf_filter.has_value() && pending_read_.remaining == 0) { + uint64_t options; + SET_OR_UNEXPECT(LoadLen(nullptr), options); + if (options != 0) + return Unexpected(errc::rdb_file_corrupted); + SET_OR_UNEXPECT(FetchBinaryDouble(), res.grow_factor); + SET_OR_UNEXPECT(FetchBinaryDouble(), res.fp_prob); + if (res.fp_prob <= 0 || res.fp_prob > 0.5) { + return Unexpected(errc::rdb_file_corrupted); + } + SET_OR_UNEXPECT(LoadLen(nullptr), res.prev_size); + SET_OR_UNEXPECT(LoadLen(nullptr), res.current_size); + SET_OR_UNEXPECT(LoadLen(nullptr), res.max_capacity); + + SET_OR_UNEXPECT(LoadLen(nullptr), num_filters); + } else { + num_filters = pending_read_.remaining; + pending_read_.remaining = 0; + } + + // Read single filter as chunks into data starting at start_offset, will stop early if chunk + // budget exhausted + auto read_filter_chunks = [&](string& data, size_t start_offset) -> io::Result { + const size_t total = data.size(); + size_t curr_offset = start_offset; + while (curr_offset < total && !ChunkBudgetExhausted()) { + auto chunk_res = LoadLen(nullptr); + if (!chunk_res) + return make_unexpected(chunk_res.error()); + const size_t chunk_size = *chunk_res; + if (chunk_size == 0 || chunk_size > total - curr_offset) + return Unexpected(errc::rdb_file_corrupted); + if (auto ec = FetchBuf(chunk_size, data.data() + curr_offset)) + return make_unexpected(ec); + curr_offset += chunk_size; + } + return curr_offset; + }; + + auto append_filter = [&](unsigned hash_cnt, string filter_data) -> error_code { + if (const size_t bit_len = filter_data.size() * 8; !is_power2(bit_len)) + return RdbError(errc::rdb_file_corrupted); + res.filters.emplace_back(hash_cnt, std::move(filter_data)); + return {}; + }; + + // First, complete a partially read filter from the previous state if there is one + if (pending_read_.sbf_filter) { + auto& sf = *pending_read_.sbf_filter; + + SET_OR_UNEXPECT(read_filter_chunks(sf.filter_data, sf.offset), sf.offset); + + if (sf.offset < sf.filter_data.size()) { + return OpaqueObj{std::move(res), RDB_TYPE_SBF}; + } + + if (auto ec = append_filter(sf.hash_cnt, std::move(sf.filter_data))) + return make_unexpected(ec); + pending_read_.sbf_filter.reset(); + num_filters--; + } + + unsigned i = 0; + for (; i < num_filters && !ChunkBudgetExhausted(); ++i) { unsigned hash_cnt; string filter_data; SET_OR_UNEXPECT(LoadLen(nullptr), hash_cnt); @@ -1830,35 +1903,27 @@ auto RdbLoaderBase::ReadSBFImpl(bool chunking) -> io::Result { if (chunking) { size_t total_size = 0; SET_OR_UNEXPECT(LoadLen(nullptr), total_size); - if (total_size == 0) { + if (total_size == 0) return Unexpected(errc::rdb_file_corrupted); - } filter_data.resize(total_size); size_t offset = 0; - while (offset < total_size) { - size_t chunk_size = 0; - SET_OR_UNEXPECT(LoadLen(nullptr), chunk_size); - if (chunk_size == 0 || chunk_size > total_size - offset) { - return Unexpected(errc::rdb_file_corrupted); - } - error_code ec = FetchBuf(chunk_size, filter_data.data() + offset); - if (ec) { - return make_unexpected(ec); - } + SET_OR_UNEXPECT(read_filter_chunks(filter_data, 0), offset); - offset += chunk_size; + if (offset < total_size) { + pending_read_.sbf_filter = {std::move(filter_data), offset, hash_cnt}; + pending_read_.remaining = num_filters - i; + return OpaqueObj{std::move(res), RDB_TYPE_SBF}; } } else { SET_OR_UNEXPECT(FetchGenericString(), filter_data); } - size_t bit_len = filter_data.size() * 8; - if (!is_power2(bit_len)) { // must be power of two - return Unexpected(errc::rdb_file_corrupted); - } - 
res.filters.emplace_back(hash_cnt, std::move(filter_data)); + if (auto ec = append_filter(hash_cnt, std::move(filter_data))) + return make_unexpected(ec); } + + pending_read_.remaining = num_filters > i ? num_filters - i : 0; return OpaqueObj{std::move(res), RDB_TYPE_SBF}; } @@ -1974,11 +2039,12 @@ io::Result RdbLoaderBase::ReadCMS() { } template io::Result RdbLoaderBase::FetchInt() { - auto ec = EnsureRead(sizeof(T)); - if (ec) + if (auto ec = EnsureRead(sizeof(T)); ec) return make_unexpected(ec); char buf[16]; + if (auto ec = ConsumeChunkBudget(sizeof(T)); ec) + return make_unexpected(ec); mem_buf_->ReadAndConsume(sizeof(T), buf); return base::LE::LoadT>(buf); @@ -2020,6 +2086,18 @@ struct RdbLoader::ObjSettings { ObjSettings() = default; }; +// A key and value can be sent in chunks, which overlap with other keys and values or other journal +// entries. To maintain the state between calls we store enough information so that when a partial +// value is seen it can be mapped to a prime value which is partially built, append to the value, +// and finally save it. +struct RdbLoader::StreamState { + DbIndex db_index; + int type; + PendingRead pending_read; + ObjSettings settings; + std::string key; +}; + RdbLoader::RdbLoader(Service* service, RdbLoadContext* load_context, std::string snapshot_id) : service_{service}, load_context_(load_context), @@ -2162,7 +2240,7 @@ error_code RdbLoader::Load(io::Source* src) { if (type == RDB_OPCODE_FULLSYNC_END) { VLOG(1) << "Read RDB_OPCODE_FULLSYNC_END"; RETURN_ON_ERR(EnsureRead(8)); - mem_buf_->ConsumeInput(8); // ignore 8 bytes + RETURN_ON_ERR(ConsumeInput(8)); // ignore 8 bytes if (full_sync_cut_cb) { FlushAllShards(); // Flush as the handler awakes post load handlers @@ -2341,6 +2419,24 @@ error_code RdbLoader::Load(io::Source* src) { continue; } + if (type == RDB_OPCODE_TAGGED_CHUNK) { + ActiveTaggedChunk state; + SET_OR_RETURN(FetchInt(), state.stream_id); + SET_OR_RETURN(FetchInt(), state.remaining_payload_bytes); + current_chunk_state_ = state; + + if (stream_states_.contains(current_chunk_state_->stream_id)) { + RETURN_ON_ERR(LoadValueChunk()); + // TODO return error instead of CHECK_EQ + CHECK_EQ(current_chunk_state_->remaining_payload_bytes, 0u) + << "chunk fully consumed but payload bytes remain " + << current_chunk_state_->remaining_payload_bytes; + // This chunk is finished. 
So reset the state; the next chunk encountered will read from the map + current_chunk_state_.reset(); + } + continue; + } + if (!rdbIsObjectTypeDF(type)) { LOG(ERROR) << "Unrecognized rdb object type: " << type; LOG(ERROR) << "Last iteration: "; @@ -2354,6 +2450,16 @@ error_code RdbLoader::Load(io::Source* src) { ++keys_loaded; RETURN_ON_ERR(LoadKeyValPair(type, &settings)); settings.Reset(); + + // TODO return error instead of CHECK_EQ + if (current_chunk_state_) + CHECK_EQ(current_chunk_state_->remaining_payload_bytes, 0u) + << "chunk fully consumed but payload bytes remain " + << current_chunk_state_->remaining_payload_bytes; + + // If we just read the first chunk of a key, reset the state here, because LoadKeyValPair will + // only return when the chunk finishes + current_chunk_state_.reset(); } // main load loop DVLOG(1) << "RdbLoad loop finished"; @@ -2433,14 +2539,30 @@ error_code RdbLoaderBase::EnsureReadInternal(size_t min_to_read) { return kOk; } +std::error_code RdbLoaderBase::ConsumeInput(size_t n) { + RETURN_ON_ERR(ConsumeChunkBudget(n)); + mem_buf_->ConsumeInput(n); + return kOk; +} + +std::error_code RdbLoaderBase::ConsumeChunkBudget(size_t n) { + if (!current_chunk_state_) + return kOk; + + if (n > current_chunk_state_->remaining_payload_bytes) + return RdbError(errc::rdb_file_corrupted); + + current_chunk_state_->remaining_payload_bytes -= n; + return kOk; +} + io::Result<uint64_t> RdbLoaderBase::LoadLen(bool* is_encoded) { if (is_encoded) *is_encoded = false; // Every RDB file with rdbver >= 5 has 8-bytes checksum at the end, // so we can ensure we have 9 bytes to read up until that point. - error_code ec = EnsureRead(9); - if (ec) + if (error_code ec = EnsureRead(9)) return make_unexpected(ec); // Read integer meta info. @@ -2455,7 +2577,9 @@ io::Result<uint64_t> RdbLoaderBase::LoadLen(bool* is_encoded) { if (meta.Type() == RDB_ENCVAL && is_encoded) *is_encoded = true; - mem_buf_->ConsumeInput(1 + meta.ByteSize()); + if (auto ec = ConsumeInput(1 + meta.ByteSize()); ec) { + return make_unexpected(ec); + } return res; } @@ -2527,6 +2651,14 @@ error_code RdbLoaderBase::HandleCompressedBlob(int op_type) { string res; SET_OR_RETURN(FetchGenericString(), res); + // Stop counting payload bytes once we switch to decompressed data. At this point the entire + // chunk payload must have been consumed, since it covered the compressed blob. We switch to + // another buffer and must be able to read everything from it without budget checks + if (current_chunk_state_ && current_chunk_state_->remaining_payload_bytes > 0) { + return RdbError(errc::rdb_file_corrupted); + } + current_chunk_state_.reset(); + // Decompress blob and switch membuf pointer // Last type in the compressed blob is RDB_OPCODE_COMPRESSED_BLOB_END // in which we will switch back to the origin membuf (HandleCompressedBlobFinish) @@ -2750,9 +2882,10 @@ void RdbLoader::CreateObjectOnShard(const DbContext& db_cntx, const Item* item, }; LoadConfig config_copy = item->load_config; + ChunkedKey chunked_key{db_ind, item->key}; if (item->load_config.chunked && item->load_config.append) { std::unique_lock lk{now_chunked_mu_}; - if (auto it = now_chunked_.find(item->key); it != now_chunked_.end()) { + if (auto it = now_chunked_.find(chunked_key); it != now_chunked_.end()) { pv_ptr = it->second.get(); } else { // Sets and hashes are deleted when all their entries are expired.
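A note for readers of this patch: every tagged chunk arrives as a fixed 9-byte header followed by exactly `remaining_payload_bytes` of payload, and every byte the loader consumes while a chunk is active is charged against that budget. Below is a minimal standalone model of the framing and the accounting. The names are illustrative, not part of the patch; the layout mirrors `MakeTagHeader()` further down in this diff, and a little-endian host is assumed (as with `FetchInt<>`):

```cpp
#include <cstdint>
#include <cstring>
#include <optional>
#include <system_error>

// Frame layout (matches MakeTagHeader() in rdb_save.cc, little-endian):
//   [RDB_OPCODE_TAGGED_CHUNK:1][stream_id:4][payload_size:4][payload bytes...]
constexpr uint8_t kOpcodeTaggedChunk = 224;  // RDB_OPCODE_TAGGED_CHUNK
constexpr size_t kTagHeaderSize = 9;

struct ChunkState {
  uint32_t stream_id;                // routes the payload to a partially built value
  uint32_t remaining_payload_bytes;  // read budget for this chunk
};

// Parses the 9-byte header; assumes a little-endian host.
std::optional<ChunkState> ParseTagHeader(const uint8_t* p, size_t len) {
  if (len < kTagHeaderSize || p[0] != kOpcodeTaggedChunk)
    return std::nullopt;
  ChunkState st;
  std::memcpy(&st.stream_id, p + 1, 4);
  std::memcpy(&st.remaining_payload_bytes, p + 5, 4);
  return st;
}

// Mirrors ConsumeChunkBudget(): a no-op when no tagged chunk is active, an
// error when a read would cross the chunk boundary, otherwise a decrement.
std::error_code ConsumeBudget(std::optional<ChunkState>& st, size_t n) {
  if (!st)
    return {};
  if (n > st->remaining_payload_bytes)
    return std::make_error_code(std::errc::illegal_byte_sequence);
  st->remaining_payload_bytes -= static_cast<uint32_t>(n);
  return {};
}
```

A `stream_id` already present in `stream_states_` resumes a parked value via `LoadValueChunk()`; an unknown id falls through to `LoadKeyValPair()`, which parks its state under that id whenever the value does not finish within the budget.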
@@ -2793,13 +2926,13 @@ void RdbLoader::CreateObjectOnShard(const DbContext& db_cntx, const Item* item, if (item->load_config.chunked) { std::unique_lock lk{now_chunked_mu_}; - if (!now_chunked_.contains(item->key)) - now_chunked_.emplace(item->key, make_unique(std::move(pv))); + if (!now_chunked_.contains(chunked_key)) + now_chunked_.emplace(chunked_key, make_unique(std::move(pv))); if (!item->load_config.finalize) return; - pv = std::move(*now_chunked_.extract(item->key).mapped()); + pv = std::move(*now_chunked_.extract(chunked_key).mapped()); } // We need this extra check because we don't return empty_key @@ -2879,79 +3012,134 @@ error_code RdbLoader::LoadKeyValPair(int type, ObjSettings* settings) { SET_OR_RETURN(ReadKey(), key); last_key_loaded_ = key; - bool dry_run = absl::GetFlag(FLAGS_rdb_load_dry_run); - bool streamed = false; + bool finalized = false; + auto remaining_payload_bytes = [&] { + return current_chunk_state_ ? current_chunk_state_->remaining_payload_bytes : UINT32_MAX; + }; + do { - // If there is a cached Item in the free pool, take it, otherwise allocate - // a new Item (LoadItemsBuffer returns free items). - Item* item = item_queue_.Pop(); - if (item == nullptr) { - item = new Item; - } - // Delete the item if we fail to load the key/val pair. - auto cleanup = absl::Cleanup([item] { delete item; }); + RETURN_ON_ERR(ReadAndDispatchObject(type, key, *settings, cur_db_index_, &finalized)); + } while (!finalized && remaining_payload_bytes() > 0 && !stop_early_.load(memory_order_relaxed)); - item->load_config.append = pending_read_.remaining > 0; + // If the object is not complete and we're in tagged chunk mode, save state for continuation. + // This is only done on the first chunk. Any next chunks will go through the LoadValueChunk method + if (!finalized && current_chunk_state_ && current_chunk_state_->stream_id != 0) { + stream_states_[current_chunk_state_->stream_id] = { + .db_index = cur_db_index_, + .type = type, + .pending_read = pending_read_, + .settings = *settings, + .key = std::move(key), + }; + } - error_code ec = ReadObj(type, &item->val); - if (ec) { - VLOG(2) << "ReadObj error " << ec << " for key " << key; - return ec; - } + int delta_ms = (absl::GetCurrentTimeNanos() - start) / 1000'000; + LOG_IF(INFO, delta_ms > 1000) << "Took " << delta_ms << " ms to load rdb_type " << type; - // If the key can be discarded, we must still continue to read the - // object from the RDB so we can read the next key. - if (ShouldDiscardKey(key, *settings)) { - pending_read_.reserve = 0; - continue; - } + pending_read_ = {}; + return kOk; +} - if (dry_run) - continue; +std::error_code RdbLoader::LoadValueChunk() { + // At this point we must have a filled in chunk state and the same stream id we loaded + CHECK(current_chunk_state_.has_value()) << "chunk load attempt without expected state"; + const auto it = stream_states_.find(current_chunk_state_->stream_id); + CHECK(it != stream_states_.end()) << "missing stream id " << current_chunk_state_->stream_id; - item->load_config.finalize = pending_read_.remaining == 0; - if (!item->load_config.finalize) { - item->key = key; - streamed = true; - } else { - // Avoid copying the key if this is the last read of the object. 
- item->key = std::move(key); - } + StreamState& state = it->second; + + // Restore before reading one chunk + pending_read_ = state.pending_read; - item->load_config.chunked = streamed; - item->load_config.reserve = pending_read_.reserve; - // Clear 'reserve' as we must only set when the object is first - // initialized. + bool finalized = false; + do { + RETURN_ON_ERR( + ReadAndDispatchObject(state.type, state.key, state.settings, state.db_index, &finalized)); + } while (!finalized && current_chunk_state_->remaining_payload_bytes > 0 && + !stop_early_.load(memory_order_relaxed)); + + if (finalized) { + // done reading this object + stream_states_.erase(current_chunk_state_->stream_id); + } else { + // only pending read changes from chunk -> chunk + state.pending_read = pending_read_; + } + + pending_read_ = {}; + return kOk; +} + +std::error_code RdbLoader::ReadAndDispatchObject(int object_type, std::string& key, + const ObjSettings& obj_settings, DbIndex db_index, + bool* finalized) { + Item* item = item_queue_.Pop(); + if (item == nullptr) { + item = new Item; + } + + auto cleanup = absl::Cleanup([item] { delete item; }); + + // Pending read is restored before loading a value chunk. + auto is_done = [&] { return pending_read_.remaining == 0; }; + + // For first chunk, remaining is 0. append is False. This causes CreateObjectOnShard->FromOpaque + // to build a new data structure. For all next chunks, append is always True. This causes + // FromOpaque to append. + item->load_config.append = !is_done(); + + // Read a part of the object. Updates remaining items + RETURN_ON_ERR(ReadObj(object_type, &item->val)); + + if (ShouldDiscardKey(key, obj_settings)) { pending_read_.reserve = 0; + *finalized = is_done(); + return kOk; + } - item->is_sticky = settings->is_sticky; - item->has_mc_flags = settings->has_mc_flags; - item->mc_flags = settings->mc_flags; - item->expire_ms = settings->expiretime; + if (GetFlag(FLAGS_rdb_load_dry_run)) { + *finalized = is_done(); + return kOk; + } - std::move(cleanup).Cancel(); - ShardId sid = Shard(item->key, shard_set->size()); - EngineShard* es = EngineShard::tlocal(); + item->load_config.finalize = is_done(); - if (es && es->shard_id() == sid) { - DbContext db_cntx{&namespaces->GetDefaultNamespace(), cur_db_index_, GetCurrentTimeMs()}; - CreateObjectOnShard(db_cntx, item, &db_cntx.GetDbSlice(sid)); - item_queue_.Push(item); - } else { - auto& out_buf = shard_buf_[sid]; + // Append will be false on first chunk. This field should still be true if elements are remaining. + // It is used to place items in the now_chunked_ map and support incremental parsing + item->load_config.chunked = item->load_config.append || !is_done(); - out_buf.emplace_back(item); + item->load_config.reserve = pending_read_.reserve; + pending_read_.reserve = 0; - constexpr size_t kBufSize = 64; - if (out_buf.size() >= kBufSize) { - // Despite being async, this function can block if the shard queue is full. 
- FlushShardAsync(sid); - } - } - } while (pending_read_.remaining > 0 && !stop_early_.load(memory_order_relaxed)); + if (item->load_config.finalize) { + item->key = std::move(key); + } else { + item->key = key; + } - int delta_ms = (absl::GetCurrentTimeNanos() - start) / 1000'000; - LOG_IF(INFO, delta_ms > 1000) << "Took " << delta_ms << " ms to load rdb_type " << type; + item->is_sticky = obj_settings.is_sticky; + item->has_mc_flags = obj_settings.has_mc_flags; + item->mc_flags = obj_settings.mc_flags; + item->expire_ms = obj_settings.expiretime; + + *finalized = item->load_config.finalize; + + std::move(cleanup).Cancel(); + + const ShardId sid = Shard(item->key, shard_set->size()); + + if (const EngineShard* es = EngineShard::tlocal(); es && es->shard_id() == sid) { + const DbContext db_cntx{&namespaces->GetDefaultNamespace(), db_index, GetCurrentTimeMs()}; + CreateObjectOnShard(db_cntx, item, &db_cntx.GetDbSlice(sid)); + item_queue_.Push(item); + } else { + auto& out_buf = shard_buf_[sid]; + out_buf.emplace_back(item); + constexpr size_t kBufSize = 64; + if (out_buf.size() >= kBufSize) { + FlushShardAsync(sid); + } + } return kOk; } diff --git a/src/server/rdb_load.h b/src/server/rdb_load.h index cc3d4f270e4a..e4c25a97da49 100644 --- a/src/server/rdb_load.h +++ b/src/server/rdb_load.h @@ -48,9 +48,9 @@ class RdbLoaderBase { }; struct RdbSBF { - double grow_factor, fp_prob; - size_t prev_size, current_size; - size_t max_capacity; + double grow_factor = 0, fp_prob = 0; + size_t prev_size = 0, current_size = 0; + size_t max_capacity = 0; struct Filter { unsigned hash_cnt; @@ -144,6 +144,19 @@ class RdbLoaderBase { // Number of elements remaining in the object. size_t remaining = 0; + + // partial state for single filter in an SBF + // when chunk size runs out mid-filter, saves the partially filled buffer and resumes on the + // next chunk. + struct SbfFilterState { + // Pre-allocated to total_size, partially filled + std::string filter_data; + // Bytes read so far, the point to which we will write next + size_t offset = 0; + // Only read on first chunk of a filter + unsigned hash_cnt = 0; + }; + std::optional sbf_filter; }; struct LoadConfig { @@ -206,6 +219,18 @@ class RdbLoaderBase { std::error_code EnsureReadInternal(size_t min_to_read); + // Wrapper to consume n bytes from mem buf, and also decrement remaining_payload_bytes if a chunk + // read is in progress + std::error_code ConsumeInput(size_t n); + + // If reading a chunk, deducts n bytes from size with error checking. No op if chunk is not being + // read such as journal data etc + std::error_code ConsumeChunkBudget(size_t n); + + bool ChunkBudgetExhausted() const { + return current_chunk_state_ && current_chunk_state_->remaining_payload_bytes == 0; + } + static void CopyStreamId(const StreamID& src, struct streamID* dest); base::IoBuf* mem_buf_ = nullptr; @@ -220,6 +245,19 @@ class RdbLoaderBase { std::optional journal_offset_ = std::nullopt; RdbVersion rdb_version_ = RDB_VERSION; PendingRead pending_read_; + + // Tracks the id and size of a chunked read, if one is in progress + struct ActiveTaggedChunk { + // Mapped to a db and key + uint32_t stream_id; + // How many bytes remaining in the current chunk. Required to know when to stop reading the + // chunk + uint32_t remaining_payload_bytes; + }; + + // Is set to current chunk being parsed. 
nullopt means the data being parsed is not a chunk, but a + // full entry + std::optional<ActiveTaggedChunk> current_chunk_state_ = std::nullopt; +}; class RdbLoader : protected RdbLoaderBase { @@ -324,7 +362,18 @@ class RdbLoader : protected RdbLoaderBase { struct ObjSettings; + struct StreamState; + std::error_code LoadKeyValPair(int type, ObjSettings* settings); + + // Loads a partially chunked value. The key, and possibly part of the value, has already been + // loaded by LoadKeyValPair. The state is restored from the stream_states_ map. + std::error_code LoadValueChunk(); + + std::error_code ReadAndDispatchObject(int object_type, std::string& key, + const ObjSettings& obj_settings, DbIndex db_index, + bool* finalized); + // Returns whether to discard the read key pair. bool ShouldDiscardKey(std::string_view key, const ObjSettings& settings) const; @@ -396,11 +445,17 @@ class RdbLoader : protected RdbLoaderBase { // A free pool of allocated unused items. base::MPSCIntrusiveQueue<Item> item_queue_; - // Map of currently chunked big values - std::unordered_map<std::string, std::unique_ptr<PrimeValue>> now_chunked_; + // Map of currently chunked big values, keyed by (db index, key) to avoid + // collisions when the same key name exists in different databases, and we + // receive chunked data from >1 db with the same key name + using ChunkedKey = std::pair<DbIndex, std::string>; + std::unordered_map<ChunkedKey, std::unique_ptr<PrimeValue>, absl::Hash<ChunkedKey>> now_chunked_; base::SpinLock now_chunked_mu_; // guards now_chunked_ std::string last_key_loaded_; + + // Maps stream id to the partially streamed (chunked) key and value + absl::flat_hash_map<uint32_t, StreamState> stream_states_; }; } // namespace dfly diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc index fd99e1f844dc..2bc2b767fb35 100644 --- a/src/server/rdb_save.cc +++ b/src/server/rdb_save.cc @@ -216,7 +216,6 @@ uint8_t RdbObjectType(const CompactObj& pv) { RdbSerializer::RdbSerializer(CompressionMode compression_mode, ConsumeFun consume_fun, size_t flush_threshold) : compression_mode_(compression_mode), - mem_buf_{4_KB}, tmp_buf_(nullptr), consume_fun_(std::move(consume_fun)), flush_threshold_(flush_threshold) { @@ -275,6 +274,9 @@ io::Result RdbSerializer::SaveEntry(const PrimeKey& pk, const PrimeValu return 0; } + mem_buf_controller_.StartEntry(); + absl::Cleanup cleanup = [&] { mem_buf_controller_.FinishEntry(); }; + DVLOG(3) << "Selecting " << dbid << " previous: " << last_entry_db_index_; auto ec = SelectDb(dbid); if (ec) { @@ -923,7 +925,7 @@ std::error_code RdbSerializer::WriteOpcode(uint8_t opcode) { } size_t RdbSerializer::GetBufferCapacity() const { - return mem_buf_.Capacity(); + return mem_buf_controller_.CurrentBuffer()->Capacity(); } size_t RdbSerializer::GetTempBufferSize() const { @@ -931,27 +933,31 @@ size_t RdbSerializer::GetTempBufferSize() const { } error_code RdbSerializer::WriteRaw(const io::Bytes& buf) { - mem_buf_.Reserve(mem_buf_.InputLen() + buf.size()); - IoBuf::Bytes dest = mem_buf_.AppendBuffer(); + auto mem_buf = mem_buf_controller_.CurrentBuffer(); + mem_buf->Reserve(mem_buf->InputLen() + buf.size()); + IoBuf::Bytes dest = mem_buf_controller_.CurrentBuffer()->AppendBuffer(); memcpy(dest.data(), buf.data(), buf.size()); - mem_buf_.CommitWrite(buf.size()); + mem_buf->CommitWrite(buf.size()); return error_code{}; } string RdbSerializer::Flush(FlushState flush_state) { - auto bytes = PrepareFlush(flush_state); - if (bytes.empty()) + const auto bytes = PrepareFlush(flush_state); + auto result = mem_buf_controller_.BuildBlob(bytes); + if (result.empty()) return {}; - if (bytes.size() > serialization_peak_bytes_) { - serialization_peak_bytes_ = bytes.size(); + // did
not compress during PrepareFlush, try now + if (!allow_prepare_flush_compression_) { + if (auto res = CompressBlob(result); res) + result = std::move(*res); } - DVLOG(2) << "FlushToSink " << bytes.size() << " bytes"; - - string result(io::View(bytes)); + if (result.size() > serialization_peak_bytes_) { + serialization_peak_bytes_ = result.size(); + } - mem_buf_.ConsumeInput(bytes.size()); + DVLOG(2) << "FlushToSink " << result.size() << " bytes"; // After every flush we should write the DB index again because the blobs in the channel are // interleaved and multiple savers can correspond to a single writer (in case of single file rdb @@ -1026,27 +1032,28 @@ string RdbSerializer::DumpValue(const PrimeValue& obj, bool ignore_crc) { return DumpValue(&serializer, obj, ignore_crc); } -size_t RdbSerializer::SerializedLen() const { - return mem_buf_.InputLen(); -} - -io::Bytes RdbSerializer::PrepareFlush(FlushState flush_state) { - size_t sz = mem_buf_.InputLen(); +Bytes RdbSerializer::PrepareFlush(FlushState flush_state) { + auto mem_buf = mem_buf_controller_.CurrentBuffer(); + size_t sz = mem_buf->InputLen(); if (sz == 0) return {}; - bool is_last_chunk = flush_state == FlushState::kFlushEndEntry; + const bool is_last_chunk = flush_state == FlushState::kFlushEndEntry; + const bool should_compress = + is_last_chunk && number_of_chunks_ == 0 && allow_prepare_flush_compression_; VLOG(2) << "PrepareFlush:" << is_last_chunk << " " << number_of_chunks_; - if (is_last_chunk && number_of_chunks_ == 0) { + if (should_compress) { if (compression_mode_ == CompressionMode::MULTI_ENTRY_ZSTD || compression_mode_ == CompressionMode::MULTI_ENTRY_LZ4) { CompressBlob(); } } - number_of_chunks_ = is_last_chunk ? 0 : (number_of_chunks_ + 1); + if (allow_prepare_flush_compression_) { + number_of_chunks_ = is_last_chunk ? 0 : number_of_chunks_ + 1; + } - return mem_buf_.InputBuffer(); + return mem_buf->InputBuffer(); } error_code RdbSerializer::WriteJournalEntry(std::string_view serialized_entry) { @@ -1817,71 +1824,168 @@ void RdbSerializer::AllocateCompressorOnce() { } } -void RdbSerializer::CompressBlob() { - if (!compression_stats_) { +std::optional RdbSerializer::CompressBlob(std::string_view input) { + if (compression_mode_ != CompressionMode::MULTI_ENTRY_ZSTD && + compression_mode_ != CompressionMode::MULTI_ENTRY_LZ4) + return std::nullopt; + + if (!compression_stats_) compression_stats_.emplace(CompressionStats{}); - } - Bytes blob_to_compress = mem_buf_.InputBuffer(); - VLOG(2) << "CompressBlob size " << blob_to_compress.size(); - size_t blob_size = blob_to_compress.size(); + + VLOG(2) << "CompressBlob size " << input.size(); + size_t blob_size = input.size(); if (blob_size < kMinStrSizeToCompress || blob_size > kMaxStrSizeToCompress) { ++compression_stats_->size_skip_count; - return; + return std::nullopt; } AllocateCompressorOnce(); - // Compress the data. We copy compressed data once into the internal buffer of compressor_impl_ - // and then we copy it again into the mem_buf_. - // - // TODO: it is possible to avoid double copying here by changing the compressor interface, - // so that the compressor will accept the output buffer and return the final size. This requires - // exposing the additional compress bound interface as well. 
- io::Result res = compressor_impl_->Compress(blob_to_compress); + io::Result res = compressor_impl_->Compress( + Bytes{reinterpret_cast(input.data()), input.size()}); if (!res) { ++compression_stats_->compression_failed; - return; + return std::nullopt; } Bytes compressed_blob = *res; if (compressed_blob.length() > blob_size * kMinCompressionReductionPrecentage) { ++compression_stats_->compression_no_effective; - return; + return std::nullopt; } - // Clear membuf and write the compressed blob to it - mem_buf_.ConsumeInput(blob_size); - mem_buf_.Reserve(compressed_blob.length() + 1 + 9); // reserve space for blob + opcode + len + const uint8_t opcode = compression_mode_ == CompressionMode::MULTI_ENTRY_ZSTD + ? RDB_OPCODE_COMPRESSED_ZSTD_BLOB_START + : RDB_OPCODE_COMPRESSED_LZ4_BLOB_START; + const size_t clen = compressed_blob.size(); + uint8_t len_buf[16]; + const unsigned encoded_size = WritePackedUInt(compressed_blob.size(), len_buf); - // First write opcode for compressed string - auto dest = mem_buf_.AppendBuffer(); - uint8_t opcode = compression_mode_ == CompressionMode::MULTI_ENTRY_ZSTD - ? RDB_OPCODE_COMPRESSED_ZSTD_BLOB_START - : RDB_OPCODE_COMPRESSED_LZ4_BLOB_START; - dest[0] = opcode; - mem_buf_.CommitWrite(1); + std::string out; + out.reserve(1 + encoded_size + clen); - // Write encoded compressed blob len - dest = mem_buf_.AppendBuffer(); - unsigned enclen = WritePackedUInt(compressed_blob.length(), dest); - mem_buf_.CommitWrite(enclen); + out.push_back(static_cast(opcode)); + out.append(reinterpret_cast(len_buf), encoded_size); + out.append(reinterpret_cast(compressed_blob.data()), compressed_blob.size()); - // Write compressed blob - dest = mem_buf_.AppendBuffer(); - memcpy(dest.data(), compressed_blob.data(), compressed_blob.length()); - mem_buf_.CommitWrite(compressed_blob.length()); ++compression_stats_->compressed_blobs; - auto& stats = ServerState::tlocal()->stats; - ++stats.compressed_blobs; + ++ServerState::tlocal()->stats.compressed_blobs; + return out; +} + +void RdbSerializer::CompressBlob() { + auto mem_buf = mem_buf_controller_.CurrentBuffer(); + Bytes blob_to_compress = mem_buf->InputBuffer(); + std::string_view input{reinterpret_cast(blob_to_compress.data()), + blob_to_compress.size()}; + auto compressed = CompressBlob(input); + if (!compressed) + return; + + mem_buf->ConsumeInput(blob_to_compress.size()); + mem_buf->Reserve(compressed->size()); + auto destination = mem_buf->AppendBuffer(); + memcpy(destination.data(), compressed->data(), compressed->size()); + mem_buf->CommitWrite(compressed->size()); } void RdbSerializer::PushToConsumerIfNeeded(FlushState flush_state) { - if (consume_fun_ && SerializedLen() > flush_threshold_) { - string blob = Flush(flush_state); - DCHECK(!blob.empty()); // SerializedLen() > 0. - consume_fun_(std::move(blob)); + if (!consume_fun_ || mem_buf_controller_.FlushableSize() <= flush_threshold_) + return; + + if (flush_state == FlushState::kFlushMidEntry) + mem_buf_controller_.MarkMidFlush(); + + string blob = Flush(flush_state); + DCHECK(!blob.empty()); // SerializedLen() > 0. 
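The `SaveStateBeforeConsume()`/`RestoreStateAfterConsume()` pair that follows exists because `consume_fun_` may re-enter the serializer and start another entry (the `InterleaveHarness` in the test file below does exactly that by calling `SaveEntry` from inside the consume callback). A toy model of the hazard and the fix; these are not the real classes, just a sketch of the mechanism:

```cpp
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy model of why PushToConsumerIfNeeded() parks the serializer state: the
// consumer callback may re-enter and serialize another entry, which must not
// write into the suspended entry's buffer.
struct ToySerializer {
  uint32_t next_id = 1, active_id = 0;
  std::string default_buf;
  std::vector<std::unique_ptr<std::string>> entry_bufs;  // stable addresses
  std::string* current = &default_buf;
  std::function<void(std::string)> consume;  // caller must set this

  void StartEntry() {
    active_id = next_id++;
    entry_bufs.push_back(std::make_unique<std::string>());
    current = entry_bufs.back().get();
  }

  void FlushToConsumer() {
    const uint32_t saved_id = active_id;
    std::string* saved_buf = current;
    current = &default_buf;  // park: re-entrant writes land elsewhere
    active_id = 0;
    consume(std::exchange(*saved_buf, {}));  // may call StartEntry() again
    active_id = saved_id;                    // restore the suspended entry
    current = saved_buf;
  }
};
```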
+ + const auto state = mem_buf_controller_.SaveStateBeforeConsume(); + consume_fun_(std::move(blob)); + mem_buf_controller_.RestoreStateAfterConsume(state); +} + +void MemBufController::StartEntry() { + active_id_ = next_id_++; + entries_.emplace(active_id_, EntryState{std::make_unique(4096), false}); + current_buffer_ = entries_[active_id_].buffer.get(); +} + +void MemBufController::FinishEntry() { + if (const auto it = entries_.find(active_id_); + it != entries_.end() && current_buffer_ == it->second.buffer.get()) + TagAndDrainToDefaultBuffer(); + + entries_.erase(active_id_); + current_buffer_ = &default_buffer_; + active_id_ = 0; +} + +void MemBufController::TagAndDrainToDefaultBuffer() { + if (current_buffer_->InputLen() == 0) + return; + + const auto bytes = current_buffer_->InputBuffer(); + const auto& entry = entries_.at(active_id_); + if (entry.was_split && send_tagged_entries_) { + const auto header = MakeTagHeader(current_buffer_->InputLen()); + default_buffer_.WriteAndCommit(header.data(), header.size()); + } + + default_buffer_.WriteAndCommit(bytes.data(), bytes.size()); + current_buffer_->ConsumeInput(bytes.size()); +} + +std::array MemBufController::MakeTagHeader(size_t size) const { + DCHECK_NE(active_id_, 0u) << "tagging when active entry is invalid"; + std::array header; + header[0] = RDB_OPCODE_TAGGED_CHUNK; + absl::little_endian::Store32(header.data() + 1, active_id_); + absl::little_endian::Store32(header.data() + 5, size); + return header; +} + +size_t MemBufController::FlushableSize() const { + auto size = current_buffer_->InputLen(); + if (current_buffer_ != &default_buffer_) + size += default_buffer_.InputLen(); + return size; +} + +MemBufController::SaveEntryState MemBufController::SaveStateBeforeConsume() { + const SaveEntryState state{active_id_, current_buffer_}; + current_buffer_ = &default_buffer_; + active_id_ = 0; + return state; +} + +void MemBufController::RestoreStateAfterConsume(const SaveEntryState state) { + active_id_ = state.id; + current_buffer_ = state.ptr; +} + +std::string MemBufController::BuildBlob(Bytes current_bytes) { + const bool has_prefix = current_buffer_ != &default_buffer_ && default_buffer_.InputLen() > 0; + const auto prefix = has_prefix ? default_buffer_.InputBuffer() : Bytes{}; + const bool should_tag = send_tagged_entries_ && active_id_ != 0 && + entries_.at(active_id_).was_split && !current_bytes.empty(); + + std::string out; + out.reserve(prefix.size() + (should_tag ? 
kHeaderSize : 0) + current_bytes.size()); + + if (has_prefix) { + out.append(io::View(prefix)); + default_buffer_.ConsumeInput(prefix.size()); } + + if (should_tag) { + const auto header = MakeTagHeader(current_bytes.size()); + out.append(reinterpret_cast(header.data()), header.size()); + } + + out.append(io::View(current_bytes)); + current_buffer_->ConsumeInput(current_bytes.size()); + return out; } } // namespace dfly diff --git a/src/server/rdb_save.h b/src/server/rdb_save.h index 1704e5f2af2a..99ef4a9c4dbf 100644 --- a/src/server/rdb_save.h +++ b/src/server/rdb_save.h @@ -79,6 +79,59 @@ CompressionMode GetDefaultCompressionMode(); using StringVec = std::vector; +class MemBufController { + friend class MemBufControllerTest; + struct EntryState { + std::unique_ptr buffer; + bool was_split = false; + }; + using EntryId = uint32_t; + + public: + static constexpr auto kHeaderSize = 9; + + void StartEntry(); + void FinishEntry(); + + void TagAndDrainToDefaultBuffer(); + io::IoBuf* CurrentBuffer() const { + return current_buffer_; + } + + std::array MakeTagHeader(size_t size) const; + + void MarkMidFlush() { + entries_.at(active_id_).was_split = true; + } + + size_t FlushableSize() const; + + struct SaveEntryState { + uint32_t id; + io::IoBuf* ptr; + bool operator<=>(const SaveEntryState&) const = default; + }; + + SaveEntryState SaveStateBeforeConsume(); + void RestoreStateAfterConsume(SaveEntryState state); + + std::string BuildBlob(io::Bytes current_bytes); + + void SetTagEntries(bool tag_entries) { + send_tagged_entries_ = tag_entries; + } + + private: + bool send_tagged_entries_ = false; + + EntryId next_id_ = 1; + EntryId active_id_ = 0; + + io::IoBuf default_buffer_{4096}; + io::IoBuf* current_buffer_ = &default_buffer_; + absl::flat_hash_map entries_; +}; + class RdbSaver { public: // Global data which doesn't belong to shards and is serialized in header @@ -182,7 +235,9 @@ class RdbSerializer { bool ignore_crc = false); // Internal buffer size. Might shrink after flush due to compression. - size_t SerializedLen() const; + size_t SerializedLen() const { + return mem_buf_controller_.FlushableSize(); + } // Flush internal buffer and return serialized blob. std::string Flush(FlushState flush_state); @@ -237,6 +292,11 @@ class RdbSerializer { std::error_code SendEofAndChecksum(); + void SetTagEntries(bool tag_entries) { + mem_buf_controller_.SetTagEntries(tag_entries); + allow_prepare_flush_compression_ = !tag_entries; + } + private: // Prepare internal buffer for flush. Compress it. io::Bytes PrepareFlush(FlushState flush_state); @@ -244,6 +304,7 @@ class RdbSerializer { // If membuf data is compressable use compression impl to compress the data and write it to membuf void CompressBlob(); void AllocateCompressorOnce(); + std::optional CompressBlob(std::string_view input); std::error_code SaveLzfBlob(const ::io::Bytes& src, size_t uncompressed_len); @@ -279,18 +340,24 @@ class RdbSerializer { }; CompressionMode compression_mode_; - io::IoBuf mem_buf_; std::unique_ptr compressor_impl_; std::optional compression_stats_; base::PODArray tmp_buf_; std::unique_ptr lzf_; size_t number_of_chunks_ = 0; + + // If tagged chunks are set, compression is not done during PrepareFlush on small chunks. 
Instead + // we compress before flush, after appending the memory buffer to stash + bool allow_prepare_flush_compression_ = true; + uint64_t serialization_peak_bytes_ = 0; std::string tmp_str_; DbIndex last_entry_db_index_ = kInvalidDbId; ConsumeFun consume_fun_; size_t flush_threshold_ = 0; + + MemBufController mem_buf_controller_; }; } // namespace dfly diff --git a/src/server/rdb_test.cc b/src/server/rdb_test.cc index 4c419e4ceafd..7b97307913fb 100644 --- a/src/server/rdb_test.cc +++ b/src/server/rdb_test.cc @@ -3,6 +3,9 @@ // #include +#include "rdb_extensions.h" +#include "serializer_commons.h" + extern "C" { #include "redis/crc64.h" #include "redis/listpack.h" @@ -40,6 +43,7 @@ ABSL_DECLARE_FLAG(uint32_t, num_shards); ABSL_DECLARE_FLAG(bool, rdb_sbf_chunked); ABSL_DECLARE_FLAG(bool, serialize_hnsw_index); ABSL_DECLARE_FLAG(bool, deserialize_hnsw_index); +ABSL_DECLARE_FLAG(bool, serialization_tagged_chunks); namespace dfly { @@ -65,8 +69,9 @@ class RdbTest : public BaseFamilyTest { void RdbTest::SetUp() { // Setting max_memory_limit must be before calling InitWithDbFilename max_memory_limit = 40000000; - absl::SetFlag(&FLAGS_serialize_hnsw_index, true); - absl::SetFlag(&FLAGS_deserialize_hnsw_index, true); + SetFlag(&FLAGS_serialize_hnsw_index, true); + SetFlag(&FLAGS_deserialize_hnsw_index, true); + SetFlag(&FLAGS_serialization_tagged_chunks, true); InitWithDbFilename(); CHECK_EQ(zmalloc_used_memory_tl, 0); } @@ -1216,4 +1221,474 @@ TEST_F(RdbTest, TopkSerializationDecayParameter) { EXPECT_THAT(resp2, RespArray(ElementsAre("item3", "item4"))); } +class MemBufControllerTest : public Test { + protected: + MemBufController controller_; + uint32_t ActiveId() const { + return controller_.active_id_; + } + + const io::IoBuf* DefaultBuffer() const { + return &controller_.default_buffer_; + } + + auto& Entries() const { + return controller_.entries_; + } + + std::string Flush() { + auto current = controller_.CurrentBuffer()->InputBuffer(); + const auto blob = controller_.BuildBlob(current); + EXPECT_EQ(controller_.FlushableSize(), 0); + return blob; + } + + void Write(std::string_view s) { + controller_.CurrentBuffer()->WriteAndCommit(s.data(), s.size()); + } + + void AssertDefaultState() const { + EXPECT_EQ(ActiveId(), 0); + EXPECT_EQ(controller_.CurrentBuffer(), DefaultBuffer()); + } + + void MarkMidFlush() { + controller_.MarkMidFlush(); + EXPECT_TRUE(Entries().at(controller_.active_id_).was_split); + } +}; + +TEST_F(MemBufControllerTest, StartAndEndEntry) { + controller_.StartEntry(); + EXPECT_EQ(ActiveId(), 1); + EXPECT_NE(controller_.CurrentBuffer(), DefaultBuffer()); + controller_.FinishEntry(); + EXPECT_TRUE(Entries().empty()); + AssertDefaultState(); + EXPECT_EQ(controller_.FlushableSize(), 0); +} + +void AssertTaggedData(std::string_view blob, std::string_view expected, uint32_t expected_id = 1) { + using namespace absl::little_endian; + + EXPECT_EQ(blob.size(), MemBufController::kHeaderSize + expected.size()); + EXPECT_EQ(static_cast(blob[0]), RDB_OPCODE_TAGGED_CHUNK); + + auto id = Load32(reinterpret_cast(blob.data()) + 1); + auto len = Load32(reinterpret_cast(blob.data()) + 5); + + EXPECT_EQ(id, expected_id); + EXPECT_EQ(len, expected.size()); + EXPECT_EQ(blob.substr(9), expected); +} + +TEST_F(MemBufControllerTest, TaggedData) { + controller_.SetTagEntries(true); + + controller_.StartEntry(); + + // write some data to entry buffer + const std::string_view data = "a_a_a_"; + Write(data); + EXPECT_EQ(controller_.FlushableSize(), data.size()); + + // entry will be tagged for id=1 
+ MarkMidFlush(); + + AssertTaggedData(Flush(), data); + + // switch to default buffer + const auto entry_buf = controller_.CurrentBuffer(); + const auto state = controller_.SaveStateBeforeConsume(); + + EXPECT_EQ(state, MemBufController::SaveEntryState(1, entry_buf)); + AssertDefaultState(); + EXPECT_FALSE(Entries().empty()); + + EXPECT_EQ(controller_.FlushableSize(), 0); + + // write to default buffer + Write("a"); + + // restore entry buffer + controller_.RestoreStateAfterConsume(state); + + EXPECT_EQ(controller_.CurrentBuffer(), entry_buf); + EXPECT_EQ(ActiveId(), state.id); + + // flushable adds default buffer + EXPECT_EQ(controller_.FlushableSize(), 1); + + Write("b"); + EXPECT_EQ(controller_.FlushableSize(), 2); + controller_.FinishEntry(); + EXPECT_TRUE(Entries().empty()); + + const std::string blob = Flush(); + + // "a" is attached as prefix + EXPECT_EQ(blob.size(), MemBufController::kHeaderSize + 2); + EXPECT_EQ(blob[0], 'a'); + AssertTaggedData(blob.substr(1), "b"); +} + +TEST_F(MemBufControllerTest, Interleaving) { + controller_.SetTagEntries(true); + // A starts and produces first tagged chunk + controller_.StartEntry(); + Write("aaa"); + MarkMidFlush(); + AssertTaggedData(Flush(), "aaa"); + + const auto state_a = controller_.SaveStateBeforeConsume(); + AssertDefaultState(); + + // B writes while A is suspended and finishes without yielding/split + controller_.StartEntry(); + EXPECT_EQ(ActiveId(), 2); + Write("bbb"); + controller_.FinishEntry(); + EXPECT_FALSE(Entries().empty()); + + AssertDefaultState(); + // After finishing B, it is drained into default buffer + EXPECT_EQ(controller_.FlushableSize(), 3); + + // Simulate a public API Flush on serializer (such as snapshot calls) + EXPECT_EQ(Flush(), "bbb"); + + // Restore A and write its tail + controller_.RestoreStateAfterConsume(state_a); + EXPECT_EQ(ActiveId(), 1); + Write("c"); + // Drains tail of A into default buffer (with tagging) + controller_.FinishEntry(); + EXPECT_TRUE(Entries().empty()); + + // Now default has the tagged tail of A + AssertTaggedData(Flush(), "c"); +} + +TEST_F(MemBufControllerTest, NestedInterleaving) { + controller_.SetTagEntries(true); + + // A starts, splits, and emits first chunk. + controller_.StartEntry(); + Write("aaa"); + MarkMidFlush(); + AssertTaggedData(Flush(), "aaa"); + + const auto state_a = controller_.SaveStateBeforeConsume(); + EXPECT_EQ(state_a.id, 1u); + AssertDefaultState(); + + // B starts while A is suspended and yields mid chunk + controller_.StartEntry(); + Write("bbb"); + MarkMidFlush(); + AssertTaggedData(Flush(), "bbb", 2); + + const auto state_b = controller_.SaveStateBeforeConsume(); + EXPECT_EQ(state_b.id, 2u); + AssertDefaultState(); + + // C written fully without split/yield + controller_.StartEntry(); + Write("ccc"); + controller_.FinishEntry(); + AssertDefaultState(); + EXPECT_FALSE(Entries().empty()); + + EXPECT_EQ(controller_.FlushableSize(), 3); + + // no tagging + EXPECT_EQ(Flush(), "ccc"); + + // Restore B, write tail, finish B. 
+ controller_.RestoreStateAfterConsume(state_b); + EXPECT_EQ(ActiveId(), 2); + EXPECT_EQ(controller_.CurrentBuffer(), state_b.ptr); + Write("x"); + controller_.FinishEntry(); + EXPECT_FALSE(Entries().empty()); + + AssertTaggedData(Flush(), "x", 2); + + // Restore A, write tail, finish A + controller_.RestoreStateAfterConsume(state_a); + EXPECT_EQ(ActiveId(), 1); + EXPECT_EQ(controller_.CurrentBuffer(), state_a.ptr); + Write("y"); + controller_.FinishEntry(); + EXPECT_TRUE(Entries().empty()); + + AssertTaggedData(Flush(), "y"); +} + +TEST_F(MemBufControllerTest, TagAndDrain) { + // This is a low level api test + controller_.SetTagEntries(true); + + { + // unsplit case + controller_.StartEntry(); + Write("abc"); + controller_.TagAndDrainToDefaultBuffer(); + EXPECT_EQ(Flush(), "abc"); + controller_.FinishEntry(); + AssertDefaultState(); + } + + controller_.StartEntry(); + Write("abc"); + MarkMidFlush(); + controller_.TagAndDrainToDefaultBuffer(); + AssertTaggedData(Flush(), "abc", 2); + controller_.FinishEntry(); + AssertDefaultState(); + EXPECT_EQ(controller_.FlushableSize(), 0); +} + +TEST_F(MemBufControllerTest, BuildBlobEdgeCases) { + controller_.SetTagEntries(true); + + // plain untagged data, no entries in flight + Write("p"); + EXPECT_EQ(Flush(), "p"); + + // some untagged data, then some tagged data + Write("p"); + controller_.StartEntry(); + Write("x"); + MarkMidFlush(); + { + const auto blob = Flush(); + ASSERT_EQ(blob[0], 'p'); + AssertTaggedData(blob.substr(1), "x"); + } + + controller_.FinishEntry(); + EXPECT_TRUE(Entries().empty()); + AssertDefaultState(); + + // Empty current does not emit 0 len tag + Write("p"); + controller_.StartEntry(); + MarkMidFlush(); + EXPECT_EQ(Flush(), "p"); + controller_.FinishEntry(); + EXPECT_TRUE(Entries().empty()); + AssertDefaultState(); + + // default prefix + unsplit current, nothing tagged + Write("p"); + controller_.StartEntry(); + Write("x"); + EXPECT_EQ(Flush(), "px"); + controller_.FinishEntry(); + EXPECT_TRUE(Entries().empty()); + AssertDefaultState(); +} + +TEST_F(MemBufControllerTest, NonTaggedData) { + // No tagging data added anywhere when send_tagged_entries_=false + controller_.StartEntry(); + + const std::string_view data = "a_a_a_"; + Write(data); + MarkMidFlush(); + EXPECT_EQ(Flush(), data); + + const auto entry_buf = controller_.CurrentBuffer(); + const auto state = controller_.SaveStateBeforeConsume(); + + EXPECT_EQ(state, MemBufController::SaveEntryState(1, entry_buf)); + AssertDefaultState(); + EXPECT_FALSE(Entries().empty()); + + EXPECT_EQ(controller_.FlushableSize(), 0); + Write("a"); + + controller_.RestoreStateAfterConsume(state); + EXPECT_EQ(controller_.CurrentBuffer(), entry_buf); + EXPECT_EQ(ActiveId(), state.id); + + EXPECT_EQ(controller_.FlushableSize(), 1); + + Write("b"); + EXPECT_EQ(controller_.FlushableSize(), 2); + controller_.FinishEntry(); + EXPECT_TRUE(Entries().empty()); + EXPECT_EQ(Flush(), "ab"); +} + +TEST_F(MemBufControllerTest, NonTaggedInterleave) { + controller_.StartEntry(); + Write("aaa"); + MarkMidFlush(); + EXPECT_EQ(Flush(), "aaa"); + + const auto state_a = controller_.SaveStateBeforeConsume(); + AssertDefaultState(); + + controller_.StartEntry(); + EXPECT_EQ(ActiveId(), 2); + Write("bbb"); + controller_.FinishEntry(); + EXPECT_FALSE(Entries().empty()); + + AssertDefaultState(); + EXPECT_EQ(Flush(), "bbb"); + + controller_.RestoreStateAfterConsume(state_a); + EXPECT_EQ(ActiveId(), 1); + Write("c"); + controller_.FinishEntry(); + EXPECT_TRUE(Entries().empty()); + EXPECT_EQ(Flush(), "c"); +} + 
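Before the round-trip test, it may help to see the byte stream that the `Interleaving` test above corresponds to, reconstructed as a self-contained snippet. The `Tagged` helper is hypothetical and assumes a little-endian host; the opcode value and 9-byte header come from this diff:

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <string_view>

// Rebuilds the byte stream the Interleaving test above describes:
// entry A is split and tagged, entry B completes untagged in between.
static std::string Tagged(uint32_t id, std::string_view payload) {
  std::string out(1, static_cast<char>(224));  // RDB_OPCODE_TAGGED_CHUNK
  const uint32_t len = static_cast<uint32_t>(payload.size());
  out.append(reinterpret_cast<const char*>(&id), 4);   // little-endian host
  out.append(reinterpret_cast<const char*>(&len), 4);
  out.append(payload);
  return out;
}

int main() {
  // A yields mid-entry, B finishes whole, A's tail resumes under the same id.
  const std::string wire = Tagged(1, "aaa") + "bbb" + Tagged(1, "c");
  assert(wire.size() == (9 + 3) + 3 + (9 + 1));
  // On load, untagged bytes parse as a plain RDB stream, while both tagged
  // chunks are routed to the same partially built value via stream id 1.
  return 0;
}
```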
+namespace { + +std::string WrapInRdb(std::string_view body) { + std::string out = absl::StrFormat("REDIS%04d", RDB_SER_VERSION); + out.append(body); + out.push_back(static_cast(RDB_OPCODE_EOF)); + constexpr uint8_t checksum[8] = {0}; + out.append(reinterpret_cast(checksum), sizeof(checksum)); + return out; +} + +std::error_code LoadRdbBytes(Service* service, const std::string& rdb, + std::optional expected_journal_offset) { + io::BytesSource src{io::Buffer(rdb)}; + RdbLoadContext load_context; + RdbLoader loader(service, &load_context); + auto ec = loader.Load(&src); + EXPECT_EQ(loader.journal_offset(), expected_journal_offset); + return ec; +} + +struct InterleaveHarness { + struct Pending { + std::string key; + const PrimeValue* value; + }; + + std::vector queued; + size_t next = 0; + std::string body; + RdbSerializer* serializer = nullptr; + std::optional last_journal_offset; + + void AddKey(std::string_view key, DbContext& ctx) { + auto& db = ctx.GetDbSlice(0); + auto it = db.FindReadOnly(ctx, key, OBJ_HASH); + ASSERT_TRUE(it.ok()); + queued.push_back(Pending{std::string{key}, &it.value()->second}); + } + + // picks next item from queue, inserts it as SaveEntry between a SendJournalOffset and a + // WriteJournalEntry. The injected SaveEntry will also call this same method. + void operator()(std::string blob) { + body += blob; + // No op if queue is finished + if (next >= queued.size()) + return; + + uint64_t offset = last_journal_offset.value_or(0) + 100; + last_journal_offset = offset; + ASSERT_FALSE(serializer->SendJournalOffset(offset)); + + const auto& entry = queued[next++]; + ASSERT_TRUE(serializer->SaveEntry(PrimeKey{entry.key}, *entry.value, 0, 0, 0).has_value()); + + io::StringSink sink; + JournalWriter writer(&sink); + writer.Write(journal::Entry{journal::Op::PING, 0, std::nullopt}); + ASSERT_FALSE(serializer->WriteJournalEntry(std::move(sink).str())); + } +}; + +} // namespace + +TEST_F(RdbTest, TaggedInterleavedRoundTrip) { + absl::FlagSaver fs; + SetTestFlag("cache_mode", "false"); + SetTestFlag("num_shards", "1"); + SetTestFlag("serialization_tagged_chunks", "true"); + ResetService(); + + auto fill_hash = [&](std::string_view key, int count, char ch) { + for (int i = 0; i < count; ++i) { + EXPECT_THAT(Run({"HSET", std::string{key}, StrCat("field:", i), std::string(128, ch)}), + IntArg(1)); + } + }; + + auto get_key_size = [](std::string s) { + if ((s[0] - 'A') % 3 == 2) + return 4; + return 200; + }; + + constexpr auto from = 'A'; + constexpr auto to = 'F'; + for (auto ch = from; ch <= to; ++ch) { + std::string s{ch}; + fill_hash(s, get_key_size(s), ch); + } + + std::string body; + std::optional last_journal_offset; + + pp_->at(0)->Await([&] { + DbContext ctx{&namespaces->GetDefaultNamespace(), 0, GetCurrentTimeMs()}; + + InterleaveHarness harness; + // queue up B -> F, when A will yield during SaveEntry -> PushToConsumerIfNeeded + for (auto ch = 'B'; ch <= to; ++ch) { + std::string s{ch}; + harness.AddKey(s, ctx); + } + + RdbSerializer serializer( + CompressionMode::NONE, [&](std::string blob) { harness(std::move(blob)); }, 256); + + harness.serializer = &serializer; + serializer.SetTagEntries(true); + + auto& db = ctx.GetDbSlice(0); + auto it = db.FindReadOnly(ctx, "A", OBJ_HASH); + ASSERT_TRUE(it.ok()); + + // kick off A + ASSERT_TRUE(serializer.SaveEntry(PrimeKey{"A"}, it.value()->second, 0, 0, 0).has_value()); + + if (auto tail = serializer.Flush(RdbSerializer::FlushState::kFlushEndEntry); !tail.empty()) + harness.body += tail; + + body = std::move(harness.body); + 
last_journal_offset = harness.last_journal_offset; }); EXPECT_EQ(Run({"FLUSHALL"}), "OK"); auto ec = pp_->at(0)->Await( [&] { return LoadRdbBytes(service_.get(), WrapInRdb(body), last_journal_offset); }); ASSERT_FALSE(ec) << ec.message(); auto verify_hash = [&](std::string_view key, int count, char ch) { EXPECT_EQ(CheckedInt({"HLEN", std::string{key}}), count); for (int i = 0; i < count; ++i) { EXPECT_EQ(Run({"HGET", std::string{key}, StrCat("field:", i)}), std::string(128, ch)); } }; for (auto ch = from; ch <= to; ++ch) { std::string s{ch}; verify_hash(s, get_key_size(s), ch); } } } // namespace dfly diff --git a/src/server/snapshot.cc b/src/server/snapshot.cc index ce2182ace971..48c0c3d414a2 100644 --- a/src/server/snapshot.cc +++ b/src/server/snapshot.cc @@ -28,6 +28,8 @@ ABSL_FLAG(bool, point_in_time_snapshot, true, "If true replication uses point in time snapshoting"); ABSL_FLAG(bool, background_snapshotting, false, "Whether to run snapshot as a background fiber"); ABSL_FLAG(bool, serialize_hnsw_index, false, "Serialize HNSW vector index graph structure"); +ABSL_FLAG(bool, serialization_tagged_chunks, false, + "Tag each serialized chunk with the id of the stream it belongs to"); namespace dfly { @@ -103,6 +105,10 @@ void SliceSnapshot::Start(bool stream_journal, SnapshotFlush allow_flush) { serializer_ = std::make_unique<RdbSerializer>(compression_mode_, consume_fun, flush_threshold); + if (allow_flush == SnapshotFlush::kAllow) { + serializer_->SetTagEntries(absl::GetFlag(FLAGS_serialization_tagged_chunks)); + } + VLOG(1) << "DbSaver::Start - saving entries with version less than " << snapshot_version_; fb2::Fiber::Opts opts{.priority = use_background_mode_ ? fb2::FiberPriority::BACKGROUND
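One closing observation on how tagging interacts with compression: with `SetTagEntries(true)`, `PrepareFlush()` skips per-entry MULTI_ENTRY compression, and `Flush()` instead compresses the fully assembled blob, so the 9-byte tag header travels inside the compressed envelope. A condensed model of that decision with illustrative names; the real logic lives in `Flush()`/`PrepareFlush()` above:

```cpp
#include <functional>
#include <optional>
#include <string>
#include <utility>

// Sketch: once tagging defers compression out of PrepareFlush(), the whole
// assembled blob (untagged prefix, tag header and payload) is compressed as
// one unit. `compress` returns nullopt when the blob is too small or the
// compression is not effective, mirroring CompressBlob().
std::string FlushModel(
    std::string assembled, bool tag_entries,
    const std::function<std::optional<std::string>(const std::string&)>& compress) {
  if (!tag_entries)
    return assembled;  // compression, if any, already ran in PrepareFlush()
  if (auto res = compress(assembled))
    return std::move(*res);
  return assembled;  // sent uncompressed
}
```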