Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 36 additions & 20 deletions src/core/search/hnsw_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,35 +274,51 @@ struct HnswlibAdapter {
DCHECK_EQ(world_.cur_element_count.load(), 0u)
<< "RestoreFromNodes should only be called on an empty index during deserialization";

// hnswlib pairs enterpoint_node_ with maxlevel_; node levels are immutable after
// creation, so the entry point's level in the serialized set equals the live
// maxlevel at metadata capture. max(node.level) would risk OOB reads when a
// concurrent Add raised maxlevel between capture and node serialization.
size_t max_internal_id = 0;
int entrypoint_level = -1;
for (const auto& node : nodes) {
max_internal_id = std::max<size_t>(max_internal_id, node.internal_id);
if (node.internal_id == metadata.enterpoint_node)
entrypoint_level = node.level;
}
if (entrypoint_level < 0) {
// Wire-ordering invariant: GetNodesRange writes nodes by ascending internal_id
// 0..count-1 under the saver's read lock, and the loader reads them sequentially
// (LoadVectorIndexNodes), so nodes[i].internal_id == i and nodes.size() is the
// capacity we need. Verify the entry-point in O(1) and read its level directly —
// by the hnswlib invariant it equals world_.maxlevel_ at save time.
if (metadata.enterpoint_node >= nodes.size()) {
LOG(ERROR) << "HNSW restore: entry point internal_id=" << metadata.enterpoint_node
<< " not present in serialized node set (" << nodes.size()
<< " out of range (" << nodes.size()
<< " nodes); skipping restore — index will be rebuilt from the keyspace";
return false;
}
if (world_.max_elements_ < max_internal_id + 1) {
world_.resizeIndex(max_internal_id + 1);
int entrypoint_level = nodes[metadata.enterpoint_node].level;
if (world_.max_elements_ < nodes.size()) {
Comment thread
BorysTheDev marked this conversation as resolved.
world_.resizeIndex(nodes.size());
}
Comment thread
BorysTheDev marked this conversation as resolved.

// Restore each node - directly set up memory and fields
// Restore each node - directly set up memory and fields. We also enforce the
// wire-ordering invariant (nodes[i].internal_id == i) inline: if a corrupted or
// future-format wire violates it we bail out cleanly so the index is rebuilt from
// the keyspace instead of writing past the resized memory. On failure we must roll
// back the partial mutations from iterations 0..i-1 so RebuildAllIndices doesn't
// see a non-empty graph and steer into the (corrupt) restore path.
size_t restored_count = 0;

for (const auto& node : nodes) {
size_t internal_id = node.internal_id;
auto rollback_partial_state = [&]() {
for (size_t k = 0; k < restored_count; ++k) {
if (world_.linkLists_[k]) {
mi_free(world_.linkLists_[k]);
world_.linkLists_[k] = nullptr;
}
}
world_.label_lookup_.clear();
world_.cur_element_count.store(0);
Comment thread
BorysTheDev marked this conversation as resolved.
Outdated
};

// Validate internal_id is within bounds - invalid internal_id indicates corrupted data
CHECK(internal_id < world_.max_elements_);
for (size_t i = 0; i < nodes.size(); ++i) {
const auto& node = nodes[i];
if (node.internal_id != i) {
LOG(ERROR) << "HNSW restore: wire ordering invariant violated at index " << i
<< " (got internal_id=" << node.internal_id << "); index will be rebuilt "
<< "from the keyspace";
rollback_partial_state();
return false;
Comment thread
BorysTheDev marked this conversation as resolved.
}
size_t internal_id = i;

// Register label in lookup table
world_.label_lookup_[node.global_id] = internal_id;
Expand Down
8 changes: 4 additions & 4 deletions src/core/search/hnsw_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

namespace dfly::search {

// Wire format for HNSW index AUX. Only the entry point is persisted: capacity is
// derived from max(internal_id)+1 in the node set and maxlevel from the entry-point
// node's level (hnswlib pairs enterpoint_node_ with maxlevel_, and node levels are
// immutable after creation).
// HNSW graph state needed at restore time. Capacity is derived from nodes.size()
// (internal_ids are contiguous 0..N-1 because hnswlib uses tombstones for deletes
// and GetNodesRange writes them in order); maxlevel is the entry-point node's
// level by hnswlib invariant, looked up in O(1) at restore.
struct HnswIndexMetadata {
// internal_id of the graph's entry-point node as recorded when the index was saved.
// RestoreFromNodes rejects the restore (returns false) when this id cannot be
// resolved against the serialized node set.
size_t enterpoint_node = 0;
};
Expand Down
9 changes: 5 additions & 4 deletions src/server/rdb_extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ constexpr uint8_t RDB_OPCODE_DF_MASK = 220; /* Mask for key properties */
constexpr uint32_t DF_MASK_FLAG_STICKY = (1 << 0);
constexpr uint32_t DF_MASK_FLAG_MC_FLAGS = (1 << 1);

// Opcode to store HNSW vector index node data for global indices
// Format: [index_name, elements_number, internal_id, global_id, level, zero_level_links_num,
// zero_level_links,
// higher_level_links_num (only if level > 0), higher_level_links (only if level > 0)]
// Opcode to store HNSW vector index node data for global indices.
// Format: [index_name, enterpoint_node, elements_number,
// then for each node in ascending internal_id 0..elements_number-1:
// internal_id, global_id, level, zero_level_links_num, zero_level_links,
// higher_level_links_num (only if level > 0), higher_level_links (only if level > 0)]
constexpr uint8_t RDB_OPCODE_VECTOR_INDEX = 222;

// Opcode to store ShardDocIndex key-to-DocId mapping for search indices
Expand Down
47 changes: 10 additions & 37 deletions src/server/rdb_load.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2702,8 +2702,6 @@ error_code RdbLoader::HandleAux() {
/* Just ignored. */
} else if (auxkey == "search-index") {
LoadSearchIndexDefFromAux(std::move(auxval));
Comment thread
BorysTheDev marked this conversation as resolved.
} else if (auxkey == "hnsw-index-metadata") {
LoadHnswIndexMetadataFromAux(std::move(auxval));
} else if (auxkey == "search-synonyms") {
LoadSearchSynonymsFromAux(std::move(auxval));
} else if (auxkey == "shard-count") {
Expand Down Expand Up @@ -3055,38 +3053,18 @@ void RdbLoader::LoadSearchIndexDefFromAux(string&& def) {
LoadSearchCommandFromAux(service_, std::move(def), "FT.CREATE", "index definition", true);
}

void RdbLoader::LoadHnswIndexMetadataFromAux(string&& def) {
  // Best-effort handling of the "hnsw-index-metadata" AUX payload: on any failure the
  // problem is logged and nothing is queued on the load context.
  try {
    auto parsed = JsonFromString(def);
    if (parsed) {
      const auto& doc = *parsed;

      // Exceptions raised while reading the fields are reported by the catch below.
      PendingHnswMetadata pending;
      pending.index_name = doc["index_name"].as<string>();
      pending.field_name = doc["field_name"].as<string>();
      pending.metadata.enterpoint_node = doc["enterpoint_node"].as<size_t>();

      LOG(INFO) << "Loaded HNSW metadata for index=" << pending.index_name
                << " field=" << pending.field_name
                << " enterpoint=" << pending.metadata.enterpoint_node;

      // Hand the parsed record over to the load context.
      load_context_->AddPendingHnswMetadata(std::move(pending));
    } else {
      LOG(ERROR) << "Invalid HNSW index metadata JSON: " << def;
    }
  } catch (const std::exception& e) {
    LOG(ERROR) << "Failed to parse HNSW index metadata JSON: " << e.what() << " def: " << def;
  }
}

error_code RdbLoader::HandleVectorIndex() {
// HNSW vector index graph data.
// Binary format: [index_key, elements_number,
// then for each node (little-endian):
// Binary format: [index_key, enterpoint_node, elements_number,
Comment thread
BorysTheDev marked this conversation as resolved.
// then for each node (little-endian, ascending internal_id 0..count-1):
// internal_id (4 bytes), global_id (8 bytes), level (4 bytes),
// for each level (0 to level): links_num (4 bytes) + links (4 bytes each)]
string index_key;
SET_OR_RETURN(FetchGenericString(), index_key);

search::HnswIndexMetadata metadata;
SET_OR_RETURN(LoadLen(nullptr), metadata.enterpoint_node);
Comment thread
BorysTheDev marked this conversation as resolved.

uint64_t elements_number;
Comment thread
BorysTheDev marked this conversation as resolved.
SET_OR_RETURN(LoadLen(nullptr), elements_number);

Expand All @@ -3104,12 +3082,12 @@ error_code RdbLoader::HandleVectorIndex() {

if (shard_count_ == shard_set->size()) {
// Same shard count: restore directly.
return RestoreVectorIndex(index_key, index_name, field_name, elements_number);
return RestoreVectorIndex(index_key, index_name, field_name, elements_number, metadata);
}

// Different shard count: load nodes and defer restoration.
// Global_ids will be remapped in PerformPostLoad after all key mappings are collected.
PendingHnswNodes pending{std::string(index_name), std::string(field_name), {}};
PendingHnswNodes pending{std::string(index_name), std::string(field_name), metadata, {}};
RETURN_ON_ERR(LoadVectorIndexNodes(elements_number, &pending.nodes));
LOG(INFO) << "Deferred HNSW index restore for " << index_key << " with " << pending.nodes.size()
<< " nodes (shard count mismatch: " << shard_count_ << " vs " << shard_set->size()
Expand Down Expand Up @@ -3179,7 +3157,8 @@ error_code RdbLoader::LoadVectorIndexNodes(uint64_t elements_number,
}

error_code RdbLoader::RestoreVectorIndex(string_view index_key, string_view index_name,
string_view field_name, uint64_t elements_number) {
string_view field_name, uint64_t elements_number,
const search::HnswIndexMetadata& metadata) {
#ifdef WITH_SEARCH
// Look up the HNSW index in the global registry. It should exist from FT.CREATE in aux.
auto hnsw_index = GlobalHnswIndexRegistry::Instance().Get(index_name, field_name);
Expand All @@ -3194,13 +3173,7 @@ error_code RdbLoader::RestoreVectorIndex(string_view index_key, string_view inde
if (nodes.empty())
return {};

auto metadata = load_context_->FindHnswMetadata(index_name, field_name);
if (!metadata) {
LOG(ERROR) << "HNSW metadata missing for " << index_key
<< "; skipping graph restore — index will be rebuilt from keyspace";
return {};
}
if (!hnsw_index->RestoreFromNodes(nodes, *metadata)) {
if (!hnsw_index->RestoreFromNodes(nodes, metadata)) {
LOG(WARNING) << "HNSW graph restore rejected for " << index_key
<< "; index will be rebuilt from keyspace";
return {};
Expand Down
6 changes: 2 additions & 4 deletions src/server/rdb_load.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,15 +395,13 @@ class RdbLoader : protected RdbLoaderBase {
// issues an FT.CREATE call, but does not start indexing
void LoadSearchIndexDefFromAux(std::string&& value);

// Parse HNSW index metadata from JSON and queue it on the load context as pending metadata
void LoadHnswIndexMetadataFromAux(std::string&& value);

// Load synonyms from RESP string and issue FT.SYNUPDATE call
void LoadSearchSynonymsFromAux(std::string&& value);

// Restore HNSW vector index graph from serialized node data.
std::error_code RestoreVectorIndex(std::string_view index_key, std::string_view index_name,
std::string_view field_name, uint64_t elements_number);
std::string_view field_name, uint64_t elements_number,
const search::HnswIndexMetadata& metadata);

// Load HNSW vector index nodes into a vector for deferred restoration.
std::error_code LoadVectorIndexNodes(uint64_t elements_number,
Expand Down
90 changes: 28 additions & 62 deletions src/server/rdb_load_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,16 @@ HnswRemapTable BuildRemapTable(
// Remaps global_ids in deferred HNSW nodes and restores the graphs.
// Returns the set of index names that failed restoration (to be excluded from key mappings).
absl::flat_hash_set<std::string> RemapAndRestoreHnswGraphs(
std::vector<PendingHnswNodes>& pending_nodes,
const std::vector<PendingHnswMetadata>& hnsw_metadata, const HnswRemapTable& remap_table) {
std::vector<PendingHnswNodes>& pending_nodes, const HnswRemapTable& remap_table) {
absl::flat_hash_set<std::string> failed_indices;
#ifdef WITH_SEARCH
for (auto& pn : pending_nodes) {
// Empty graph is a valid state, not a failure — skip restore (the index already
// matches an empty graph) and don't mark it failed.
if (pn.nodes.empty()) {
continue;
}

auto remap_it = remap_table.find(pn.index_name);

auto hnsw_index = GlobalHnswIndexRegistry::Instance().Get(pn.index_name, pn.field_name);
Expand Down Expand Up @@ -120,21 +125,7 @@ absl::flat_hash_set<std::string> RemapAndRestoreHnswGraphs(
continue;
}

const PendingHnswMetadata* phm_ptr = nullptr;
for (const auto& phm : hnsw_metadata) {
if (phm.index_name == pn.index_name && phm.field_name == pn.field_name) {
phm_ptr = &phm;
break;
}
}
if (!phm_ptr) {
LOG(ERROR) << "HNSW metadata missing for " << pn.index_name << ":" << pn.field_name
<< ". Will rebuild from scratch.";
failed_indices.insert(pn.index_name);
continue;
}

if (!hnsw_index->RestoreFromNodes(pn.nodes, phm_ptr->metadata)) {
if (!hnsw_index->RestoreFromNodes(pn.nodes, pn.metadata)) {
Comment thread
BorysTheDev marked this conversation as resolved.
LOG(WARNING) << "HNSW graph restore rejected for " << pn.index_name << ":" << pn.field_name
<< ". Will rebuild from scratch.";
failed_indices.insert(pn.index_name);
Expand Down Expand Up @@ -258,11 +249,6 @@ void RdbLoadContext::AddPendingIndexMapping(uint32_t shard_id, PendingIndexMappi
pending_index_mappings_[shard_id].emplace_back(std::move(mapping));
}

void RdbLoadContext::AddPendingHnswMetadata(PendingHnswMetadata metadata) {
  // Append the record to the pending list while holding mu_.
  util::fb2::LockGuard<util::fb2::Mutex> guard(mu_);
  pending_hnsw_metadata_.push_back(std::move(metadata));
}

void RdbLoadContext::AddPendingHnswNodes(PendingHnswNodes nodes) {
util::fb2::LockGuard<util::fb2::Mutex> lk(mu_);
pending_hnsw_nodes_.emplace_back(std::move(nodes));
Expand All @@ -272,17 +258,6 @@ void RdbLoadContext::SetMasterShardCount(uint32_t count) {
master_shard_count_ = count;
}

std::optional<search::HnswIndexMetadata> RdbLoadContext::FindHnswMetadata(
std::string_view index_name, std::string_view field_name) const {
util::fb2::LockGuard<util::fb2::Mutex> lk(mu_);
for (const auto& phm : pending_hnsw_metadata_) {
if (phm.index_name == index_name && phm.field_name == field_name) {
return phm.metadata;
}
}
return std::nullopt;
}

std::vector<std::string> RdbLoadContext::TakePendingSynonymCommands() {
util::fb2::LockGuard<util::fb2::Mutex> lk(mu_);
std::vector<std::string> result;
Expand All @@ -305,16 +280,15 @@ std::vector<PendingHnswNodes> RdbLoadContext::TakePendingHnswNodes() {

RdbLoadContext::PerShardMappings RdbLoadContext::RemapHnswForDifferentShardCount(
const absl::flat_hash_map<uint32_t, std::vector<PendingIndexMapping>>& index_mappings,
std::vector<PendingHnswNodes>& pending_nodes,
const std::vector<PendingHnswMetadata>& hnsw_metadata) {
std::vector<PendingHnswNodes>& pending_nodes) {
const ShardId new_shard_count = shard_set->size();

// Build remap table: index_name -> master_shard_id -> new_global_ids indexed by old doc_id.
// Freed when this function returns.
HnswRemapTable remap_table = BuildRemapTable(index_mappings, new_shard_count);

// Remap global_ids, restore HNSW graphs; failed indices are excluded from key mappings.
auto failed = RemapAndRestoreHnswGraphs(pending_nodes, hnsw_metadata, remap_table);
auto failed = RemapAndRestoreHnswGraphs(pending_nodes, remap_table);
for (const auto& name : failed) {
remap_table.erase(name);
}
Expand All @@ -333,16 +307,8 @@ void RdbLoadContext::PerformPostLoad(Service* service, bool is_error) {
auto index_mappings = TakePendingIndexMappings();
auto pending_nodes = TakePendingHnswNodes();

// Extract remaining shared state under lock. After this, no member access is needed.
std::vector<PendingHnswMetadata> hnsw_metadata;
{
util::fb2::LockGuard<util::fb2::Mutex> lk(mu_);
hnsw_metadata.swap(pending_hnsw_metadata_);
}
uint32_t master_shards = master_shard_count_;

bool has_hnsw_restore = !hnsw_metadata.empty();

if (is_error)
return;

Expand All @@ -352,8 +318,7 @@ void RdbLoadContext::PerformPostLoad(Service* service, bool is_error) {
if (shard_count_differs && !index_mappings.empty()) {
// Remaps HNSW global_ids, restores HNSW graphs, and pre-distributes key mappings by target
// shard. The internal remap table is local to the function and freed when it returns.
auto per_shard_mappings =
RemapHnswForDifferentShardCount(index_mappings, pending_nodes, hnsw_metadata);
auto per_shard_mappings = RemapHnswForDifferentShardCount(index_mappings, pending_nodes);

// Each shard reads only its own pre-built slice — no per-shard filtering of all N keys.
shard_set->AwaitRunningOnShardQueue([&per_shard_mappings](EngineShard* es) {
Expand Down Expand Up @@ -390,13 +355,14 @@ void RdbLoadContext::PerformPostLoad(Service* service, bool is_error) {
// RestoreKeyIndex (above) and RebuildAllIndices (below) run in separate sequential
// AwaitRunningOnShardQueue calls, so there is no parallel index build that could interfere
// with the doc_ids assigned during key mapping restoration.
LOG(INFO) << "PostLoad: rebuilding search indices across shards has_hnsw_restore="
<< has_hnsw_restore << " rss="
// RebuildAllIndices decides per-index whether to use the restore path or rebuild from
// scratch, based on the index's actual graph + key_index state.
LOG(INFO) << "PostLoad: rebuilding search indices across shards rss="
<< strings::HumanReadableNumBytes(rss_mem_current.load(std::memory_order_relaxed));
shard_set->AwaitRunningOnShardQueue([has_hnsw_restore](EngineShard* es) {
shard_set->AwaitRunningOnShardQueue([](EngineShard* es) {
OpArgs op_args{es, nullptr,
DbContext{&namespaces->GetDefaultNamespace(), 0, GetCurrentTimeMs()}};
es->search_indices()->RebuildAllIndices(op_args, has_hnsw_restore);
es->search_indices()->RebuildAllIndices(op_args);
});

// Now execute all pending synonym commands after indices are rebuilt
Expand All @@ -411,19 +377,19 @@ void RdbLoadContext::PerformPostLoad(Service* service, bool is_error) {
<< strings::HumanReadableNumBytes(rss_mem_current.load(std::memory_order_relaxed));
});

// All shards completed restoration — drain pending ops.
// DrainPendingVectorUpdates sets kBuilding which allows Add calls.
if (has_hnsw_restore) {
shard_set->AwaitRunningOnShardQueue([](EngineShard* es) {
OpArgs op_args{es, nullptr,
DbContext{&namespaces->GetDefaultNamespace(), 0, GetCurrentTimeMs()}};
for (const auto& name : es->search_indices()->GetIndexNames()) {
if (auto* idx = es->search_indices()->GetIndex(name)) {
idx->DrainPendingVectorUpdates(op_args);
}
// Transition every search index out of kRestoring/kSerializing into kBuilding and
// drain any journal-buffered vector updates accumulated during a restoring window.
// For indices already in kBuilding the state assignment is idempotent and the empty
// pending set returns early, so this is cheap when nothing was deferred.
shard_set->AwaitRunningOnShardQueue([](EngineShard* es) {
OpArgs op_args{es, nullptr,
DbContext{&namespaces->GetDefaultNamespace(), 0, GetCurrentTimeMs()}};
for (const auto& name : es->search_indices()->GetIndexNames()) {
if (auto* idx = es->search_indices()->GetIndex(name)) {
idx->DrainPendingVectorUpdates(op_args);
}
});
}
}
});
#endif
}

Expand Down
Loading
Loading