Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions src/core/search/hnsw_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,25 +274,21 @@ struct HnswlibAdapter {
DCHECK_EQ(world_.cur_element_count.load(), 0u)
<< "RestoreFromNodes should only be called on an empty index during deserialization";

// hnswlib pairs enterpoint_node_ with maxlevel_; node levels are immutable after
// creation, so the entry point's level in the serialized set equals the live
// maxlevel at metadata capture. max(node.level) would risk OOB reads when a
// concurrent Add raised maxlevel between capture and node serialization.
size_t max_internal_id = 0;
int entrypoint_level = -1;
for (const auto& node : nodes) {
max_internal_id = std::max<size_t>(max_internal_id, node.internal_id);
if (node.internal_id == metadata.enterpoint_node)
entrypoint_level = node.level;
}
if (entrypoint_level < 0) {
// Wire-ordering invariant: GetNodesRange writes nodes by ascending internal_id
// 0..count-1 under the saver's read lock, and the loader reads them sequentially
// (LoadVectorIndexNodes), so nodes[i].internal_id == i and nodes.size() is the
// capacity we need. Validate the entry-point in O(1) and read its level
// directly — by the hnswlib invariant it equals world_.maxlevel_ at save time.
if (metadata.enterpoint_node >= nodes.size() ||
nodes[metadata.enterpoint_node].internal_id != metadata.enterpoint_node) {
LOG(ERROR) << "HNSW restore: entry point internal_id=" << metadata.enterpoint_node
<< " not present in serialized node set (" << nodes.size()
<< " out of range or wire ordering violated (" << nodes.size()
<< " nodes); skipping restore — index will be rebuilt from the keyspace";
return false;
}
if (world_.max_elements_ < max_internal_id + 1) {
world_.resizeIndex(max_internal_id + 1);
int entrypoint_level = nodes[metadata.enterpoint_node].level;
if (world_.max_elements_ < nodes.size()) {
Comment thread
BorysTheDev marked this conversation as resolved.
world_.resizeIndex(nodes.size());
}
Comment thread
BorysTheDev marked this conversation as resolved.

// Restore each node - directly set up memory and fields
Expand Down
8 changes: 4 additions & 4 deletions src/core/search/hnsw_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

namespace dfly::search {

// Wire format for HNSW index AUX. Only the entry point is persisted: capacity is
// derived from max(internal_id)+1 in the node set and maxlevel from the entry-point
// node's level (hnswlib pairs enterpoint_node_ with maxlevel_, and node levels are
// immutable after creation).
// HNSW graph state needed at restore time. Capacity is derived from nodes.size()
// (internal_ids are contiguous 0..N-1 because hnswlib uses tombstones for deletes
// and GetNodesRange writes them in order); maxlevel is the entry-point node's
// level by hnswlib invariant, looked up in O(1) at restore.
struct HnswIndexMetadata {
// internal_id of the graph's entry-point node. RdbLoader::HandleVectorIndex reads it
// from the serialized index, and RestoreFromNodes rejects the restore when it does not
// index a matching entry in the serialized node set.
size_t enterpoint_node = 0;
};
Expand Down
51 changes: 14 additions & 37 deletions src/server/rdb_load.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2702,8 +2702,6 @@ error_code RdbLoader::HandleAux() {
/* Just ignored. */
} else if (auxkey == "search-index") {
LoadSearchIndexDefFromAux(std::move(auxval));
Comment thread
BorysTheDev marked this conversation as resolved.
} else if (auxkey == "hnsw-index-metadata") {
LoadHnswIndexMetadataFromAux(std::move(auxval));
} else if (auxkey == "search-synonyms") {
LoadSearchSynonymsFromAux(std::move(auxval));
} else if (auxkey == "shard-count") {
Expand Down Expand Up @@ -3055,45 +3053,29 @@ void RdbLoader::LoadSearchIndexDefFromAux(string&& def) {
LoadSearchCommandFromAux(service_, std::move(def), "FT.CREATE", "index definition", true);
}

// Decodes the value of an "hnsw-index-metadata" AUX field and queues the result on the
// load context. The payload is JSON carrying "index_name", "field_name" and
// "enterpoint_node". Malformed input is logged and dropped; no error reaches the caller.
void RdbLoader::LoadHnswIndexMetadataFromAux(string&& def) {
  try {
    auto parsed = JsonFromString(def);
    if (parsed) {
      const auto& doc = *parsed;

      PendingHnswMetadata pending;
      pending.index_name = doc["index_name"].as<string>();
      pending.field_name = doc["field_name"].as<string>();
      pending.metadata.enterpoint_node = doc["enterpoint_node"].as<size_t>();

      LOG(INFO) << "Loaded HNSW metadata for index=" << pending.index_name
                << " field=" << pending.field_name
                << " enterpoint=" << pending.metadata.enterpoint_node;

      load_context_->AddPendingHnswMetadata(std::move(pending));
      return;
    }

    // JsonFromString produced no document: the payload is not valid JSON.
    LOG(ERROR) << "Invalid HNSW index metadata JSON: " << def;
  } catch (const std::exception& e) {
    // Covers a failure to extract one of the fields with the expected type.
    LOG(ERROR) << "Failed to parse HNSW index metadata JSON: " << e.what() << " def: " << def;
  }
}

error_code RdbLoader::HandleVectorIndex() {
// HNSW vector index graph data.
// Binary format: [index_key, elements_number,
// then for each node (little-endian):
// Binary format: [index_key, enterpoint_node, elements_number,
Comment thread
BorysTheDev marked this conversation as resolved.
// then for each node (little-endian, ascending internal_id 0..count-1):
// internal_id (4 bytes), global_id (8 bytes), level (4 bytes),
// for each level (0 to level): links_num (4 bytes) + links (4 bytes each)]
string index_key;
SET_OR_RETURN(FetchGenericString(), index_key);

search::HnswIndexMetadata metadata;
SET_OR_RETURN(LoadLen(nullptr), metadata.enterpoint_node);
Comment thread
BorysTheDev marked this conversation as resolved.

uint64_t elements_number;
Comment thread
BorysTheDev marked this conversation as resolved.
SET_OR_RETURN(LoadLen(nullptr), elements_number);

if (!deserialize_hnsw_index_) {
return SkipVectorIndex(index_key, elements_number);
}

if (elements_number > 0) {
load_context_->MarkHnswIndexRestored();
}
Comment thread
BorysTheDev marked this conversation as resolved.
Outdated

DCHECK_GT(shard_count_, 0u);
// Parse "index_name:field_name" from the composite key.
size_t colon_pos = index_key.rfind(':');
Expand All @@ -3104,12 +3086,12 @@ error_code RdbLoader::HandleVectorIndex() {

if (shard_count_ == shard_set->size()) {
// Same shard count: restore directly.
return RestoreVectorIndex(index_key, index_name, field_name, elements_number);
return RestoreVectorIndex(index_key, index_name, field_name, elements_number, metadata);
}

// Different shard count: load nodes and defer restoration.
// Global_ids will be remapped in PerformPostLoad after all key mappings are collected.
PendingHnswNodes pending{std::string(index_name), std::string(field_name), {}};
PendingHnswNodes pending{std::string(index_name), std::string(field_name), metadata, {}};
RETURN_ON_ERR(LoadVectorIndexNodes(elements_number, &pending.nodes));
LOG(INFO) << "Deferred HNSW index restore for " << index_key << " with " << pending.nodes.size()
<< " nodes (shard count mismatch: " << shard_count_ << " vs " << shard_set->size()
Expand Down Expand Up @@ -3179,7 +3161,8 @@ error_code RdbLoader::LoadVectorIndexNodes(uint64_t elements_number,
}

error_code RdbLoader::RestoreVectorIndex(string_view index_key, string_view index_name,
string_view field_name, uint64_t elements_number) {
string_view field_name, uint64_t elements_number,
const search::HnswIndexMetadata& metadata) {
#ifdef WITH_SEARCH
// Look up the HNSW index in the global registry. It should exist from FT.CREATE in aux.
auto hnsw_index = GlobalHnswIndexRegistry::Instance().Get(index_name, field_name);
Expand All @@ -3194,13 +3177,7 @@ error_code RdbLoader::RestoreVectorIndex(string_view index_key, string_view inde
if (nodes.empty())
return {};

auto metadata = load_context_->FindHnswMetadata(index_name, field_name);
if (!metadata) {
LOG(ERROR) << "HNSW metadata missing for " << index_key
<< "; skipping graph restore — index will be rebuilt from keyspace";
return {};
}
if (!hnsw_index->RestoreFromNodes(nodes, *metadata)) {
if (!hnsw_index->RestoreFromNodes(nodes, metadata)) {
LOG(WARNING) << "HNSW graph restore rejected for " << index_key
<< "; index will be rebuilt from keyspace";
return {};
Expand Down
6 changes: 2 additions & 4 deletions src/server/rdb_load.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,15 +395,13 @@ class RdbLoader : protected RdbLoaderBase {
// issues an FT.CREATE call, but does not start indexing
void LoadSearchIndexDefFromAux(std::string&& value);

// Parses HNSW index metadata from a JSON AUX value and queues it on the load context
// as pending metadata.
void LoadHnswIndexMetadataFromAux(std::string&& value);

// Load synonyms from RESP string and issue FT.SYNUPDATE call
void LoadSearchSynonymsFromAux(std::string&& value);

// Restore HNSW vector index graph from serialized node data.
std::error_code RestoreVectorIndex(std::string_view index_key, std::string_view index_name,
std::string_view field_name, uint64_t elements_number);
std::string_view field_name, uint64_t elements_number,
const search::HnswIndexMetadata& metadata);

// Load HNSW vector index nodes into a vector for deferred restoration.
std::error_code LoadVectorIndexNodes(uint64_t elements_number,
Expand Down
51 changes: 8 additions & 43 deletions src/server/rdb_load_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ HnswRemapTable BuildRemapTable(
// Remaps global_ids in deferred HNSW nodes and restores the graphs.
// Returns the set of index names that failed restoration (to be excluded from key mappings).
absl::flat_hash_set<std::string> RemapAndRestoreHnswGraphs(
std::vector<PendingHnswNodes>& pending_nodes,
const std::vector<PendingHnswMetadata>& hnsw_metadata, const HnswRemapTable& remap_table) {
std::vector<PendingHnswNodes>& pending_nodes, const HnswRemapTable& remap_table) {
absl::flat_hash_set<std::string> failed_indices;
#ifdef WITH_SEARCH
for (auto& pn : pending_nodes) {
Expand Down Expand Up @@ -120,21 +119,7 @@ absl::flat_hash_set<std::string> RemapAndRestoreHnswGraphs(
continue;
}

const PendingHnswMetadata* phm_ptr = nullptr;
for (const auto& phm : hnsw_metadata) {
if (phm.index_name == pn.index_name && phm.field_name == pn.field_name) {
phm_ptr = &phm;
break;
}
}
if (!phm_ptr) {
LOG(ERROR) << "HNSW metadata missing for " << pn.index_name << ":" << pn.field_name
<< ". Will rebuild from scratch.";
failed_indices.insert(pn.index_name);
continue;
}

if (!hnsw_index->RestoreFromNodes(pn.nodes, phm_ptr->metadata)) {
if (!hnsw_index->RestoreFromNodes(pn.nodes, pn.metadata)) {
Comment thread
BorysTheDev marked this conversation as resolved.
LOG(WARNING) << "HNSW graph restore rejected for " << pn.index_name << ":" << pn.field_name
<< ". Will rebuild from scratch.";
failed_indices.insert(pn.index_name);
Expand Down Expand Up @@ -258,11 +243,6 @@ void RdbLoadContext::AddPendingIndexMapping(uint32_t shard_id, PendingIndexMappi
pending_index_mappings_[shard_id].emplace_back(std::move(mapping));
}

// Appends one HNSW metadata record to the pending list while holding mu_.
void RdbLoadContext::AddPendingHnswMetadata(PendingHnswMetadata metadata) {
  util::fb2::LockGuard<util::fb2::Mutex> guard(mu_);
  pending_hnsw_metadata_.push_back(std::move(metadata));
}

void RdbLoadContext::AddPendingHnswNodes(PendingHnswNodes nodes) {
util::fb2::LockGuard<util::fb2::Mutex> lk(mu_);
pending_hnsw_nodes_.emplace_back(std::move(nodes));
Expand All @@ -272,15 +252,8 @@ void RdbLoadContext::SetMasterShardCount(uint32_t count) {
master_shard_count_ = count;
}

std::optional<search::HnswIndexMetadata> RdbLoadContext::FindHnswMetadata(
std::string_view index_name, std::string_view field_name) const {
util::fb2::LockGuard<util::fb2::Mutex> lk(mu_);
for (const auto& phm : pending_hnsw_metadata_) {
if (phm.index_name == index_name && phm.field_name == field_name) {
return phm.metadata;
}
}
return std::nullopt;
void RdbLoadContext::MarkHnswIndexRestored() {
// Raises the flag that PerformPostLoad loads (also with relaxed ordering) into
// has_hnsw_restore.
hnsw_index_restored_.store(true, std::memory_order_relaxed);
}

std::vector<std::string> RdbLoadContext::TakePendingSynonymCommands() {
Expand All @@ -305,16 +278,15 @@ std::vector<PendingHnswNodes> RdbLoadContext::TakePendingHnswNodes() {

RdbLoadContext::PerShardMappings RdbLoadContext::RemapHnswForDifferentShardCount(
const absl::flat_hash_map<uint32_t, std::vector<PendingIndexMapping>>& index_mappings,
std::vector<PendingHnswNodes>& pending_nodes,
const std::vector<PendingHnswMetadata>& hnsw_metadata) {
std::vector<PendingHnswNodes>& pending_nodes) {
const ShardId new_shard_count = shard_set->size();

// Build remap table: index_name -> master_shard_id -> new_global_ids indexed by old doc_id.
// Freed when this function returns.
HnswRemapTable remap_table = BuildRemapTable(index_mappings, new_shard_count);

// Remap global_ids, restore HNSW graphs; failed indices are excluded from key mappings.
auto failed = RemapAndRestoreHnswGraphs(pending_nodes, hnsw_metadata, remap_table);
auto failed = RemapAndRestoreHnswGraphs(pending_nodes, remap_table);
for (const auto& name : failed) {
remap_table.erase(name);
}
Expand All @@ -333,15 +305,9 @@ void RdbLoadContext::PerformPostLoad(Service* service, bool is_error) {
auto index_mappings = TakePendingIndexMappings();
auto pending_nodes = TakePendingHnswNodes();

// Extract remaining shared state under lock. After this, no member access is needed.
std::vector<PendingHnswMetadata> hnsw_metadata;
{
util::fb2::LockGuard<util::fb2::Mutex> lk(mu_);
hnsw_metadata.swap(pending_hnsw_metadata_);
}
uint32_t master_shards = master_shard_count_;

bool has_hnsw_restore = !hnsw_metadata.empty();
bool has_hnsw_restore = hnsw_index_restored_.load(std::memory_order_relaxed);

if (is_error)
return;
Expand All @@ -352,8 +318,7 @@ void RdbLoadContext::PerformPostLoad(Service* service, bool is_error) {
if (shard_count_differs && !index_mappings.empty()) {
// Remaps HNSW global_ids, restores HNSW graphs, and pre-distributes key mappings by target
// shard. The internal remap table is local to the function and freed when it returns.
auto per_shard_mappings =
RemapHnswForDifferentShardCount(index_mappings, pending_nodes, hnsw_metadata);
auto per_shard_mappings = RemapHnswForDifferentShardCount(index_mappings, pending_nodes);

// Each shard reads only its own pre-built slice — no per-shard filtering of all N keys.
shard_set->AwaitRunningOnShardQueue([&per_shard_mappings](EngineShard* es) {
Expand Down
21 changes: 8 additions & 13 deletions src/server/rdb_load_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,12 @@ struct PendingIndexMapping {
std::vector<std::pair<std::string, search::DocId>> mappings;
};

// HNSW metadata loaded from "hnsw-index-metadata" AUX fields.
struct PendingHnswMetadata {
// Name of the search index the metadata belongs to ("index_name" in the AUX JSON).
std::string index_name;
// Name of the field within that index ("field_name" in the AUX JSON).
std::string field_name;
// Restore-time graph state; matched to pending nodes by (index_name, field_name).
search::HnswIndexMetadata metadata;
};

// Deferred HNSW graph nodes for restoration when shard counts differ.
// The entry-point travels with the nodes inside RDB_OPCODE_VECTOR_INDEX.
struct PendingHnswNodes {
// Index and field names, split from the composite "index_name:field_name" key.
std::string index_name;
std::string field_name;
// Entry-point metadata read alongside the nodes; passed to RestoreFromNodes.
search::HnswIndexMetadata metadata;
// Serialized graph nodes whose global_ids are remapped before the graph is restored.
std::vector<search::HnswNodeData> nodes;
};

Expand All @@ -54,12 +49,13 @@ class RdbLoadContext {

void AddPendingSynonymCommand(std::string cmd);
void AddPendingIndexMapping(uint32_t shard_id, PendingIndexMapping mapping);
void AddPendingHnswMetadata(PendingHnswMetadata metadata);
void AddPendingHnswNodes(PendingHnswNodes nodes);
void SetMasterShardCount(uint32_t count);

std::optional<search::HnswIndexMetadata> FindHnswMetadata(std::string_view index_name,
std::string_view field_name) const;
// Marks that an HNSW index with non-empty graph data was received in this load session.
// Used by PerformPostLoad to tell RebuildAllIndices to populate vectors into the
// already-restored graph instead of rebuilding from scratch.
void MarkHnswIndexRestored();

// Performs post load procedures while still remaining in global LOADING state.
// Called once immediately after loading the snapshot / full sync succeeded from the coordinator.
Expand All @@ -79,15 +75,14 @@ class RdbLoadContext {
// Failed indices are excluded from the returned mappings so they fall back to a full rebuild.
PerShardMappings RemapHnswForDifferentShardCount(
const absl::flat_hash_map<uint32_t, std::vector<PendingIndexMapping>>& index_mappings,
std::vector<PendingHnswNodes>& pending_nodes,
const std::vector<PendingHnswMetadata>& hnsw_metadata);
std::vector<PendingHnswNodes>& pending_nodes);

mutable util::fb2::Mutex mu_;
std::vector<std::string> pending_synonym_cmds_ ABSL_GUARDED_BY(mu_);
absl::flat_hash_map<uint32_t, std::vector<PendingIndexMapping>> pending_index_mappings_
ABSL_GUARDED_BY(mu_);
std::vector<PendingHnswMetadata> pending_hnsw_metadata_ ABSL_GUARDED_BY(mu_);
std::vector<PendingHnswNodes> pending_hnsw_nodes_ ABSL_GUARDED_BY(mu_);
std::atomic<bool> hnsw_index_restored_{false};
uint32_t master_shard_count_ = 0; // Set identically by all loaders from AUX field.
Comment thread
BorysTheDev marked this conversation as resolved.
};

Expand Down
Loading
Loading