Fix & Add all unit tests
Beihao-Zhou committed on Jul 7, 2024 · 1 parent a98376f · commit 7bbcd93
Showing 5 changed files with 449 additions and 207 deletions.
232 changes: 140 additions & 92 deletions src/search/hnsw_indexer.cc
@@ -96,49 +96,13 @@ Status Node::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& sear
return Status::OK();
}

Status Node::UpdateNeighbours(std::vector<NodeKey>& neighbours, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
std::unordered_set<NodeKey>& deleted_neighbours) {
deleted_neighbours.clear();
auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search);
auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key);
std::unordered_set<NodeKey> to_be_added{neighbours.begin(), neighbours.end()};

util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search);
for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) {
if (!iter->key().starts_with(edge_prefix)) {
break;
}
auto neighbour_edge = iter->key();
neighbour_edge.remove_prefix(edge_prefix.size());
Slice neighbour;
GetSizedString(&neighbour_edge, &neighbour);
auto neighbour_key = neighbour.ToString();

if (to_be_added.count(neighbour_key) == 0) {
batch->Delete(cf_handle, iter->key());
deleted_neighbours.insert(neighbour_key);
} else {
to_be_added.erase(neighbour_key);
}
}

for (const auto& neighbour : to_be_added) {
auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour);
batch->Put(cf_handle, edge_index_key, Slice());
}

HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key, storage));
node_metadata.num_neighbours = static_cast<uint16_t>(neighbours.size());
PutMetadata(&node_metadata, search_key, storage, batch);
return Status::OK();
}

VectorItem::VectorItem(const NodeKey& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata)
: key(key), vector(std::move(vector)), metadata(metadata) {}
VectorItem::VectorItem(const NodeKey& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata)
: key(key), vector(std::move(vector)), metadata(metadata) {}

bool VectorItem::operator==(const VectorItem& other) const { return key == other.key; }

bool VectorItem::operator<(const VectorItem& other) const { return key < other.key; }

StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& right) {
@@ -190,7 +154,10 @@ HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vecto

uint16_t HnswIndex::RandomizeLayer() {
std::uniform_real_distribution<double> level_dist(0.0, 1.0);
return static_cast<uint16_t>(std::floor(-std::log(level_dist(generator_)) * m_level_normalization_factor_));
double r = level_dist(generator_);
double log_val = -std::log(r);
double layer_val = log_val * m_level_normalization_factor_;
return static_cast<uint16_t>(std::floor(layer_val));
}
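
RandomizeLayer() draws a node's top layer from the standard HNSW distribution, floor(-ln(U) * mL) with U uniform in (0, 1). The sketch below reproduces that sampling rule in isolation; the value of m_level_normalization_factor_ is not shown in this diff, so the conventional choice 1/ln(M) is assumed, and the guard against U == 0 is an addition for the sketch, not part of the committed code.

// Standalone sketch of the layer-sampling rule in RandomizeLayer(). Assumes
// m_level_normalization_factor_ = 1/ln(M), the usual HNSW choice; M = 16 is
// an arbitrary example value.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <map>
#include <random>

int main() {
  const double m = 16.0;                         // HNSW "M" parameter (assumed)
  const double norm_factor = 1.0 / std::log(m);  // assumed normalization factor
  std::mt19937_64 gen{42};
  std::uniform_real_distribution<double> dist(0.0, 1.0);

  std::map<int, int> histogram;
  for (int i = 0; i < 100000; i++) {
    double u = std::max(dist(gen), 1e-12);       // guard: avoid -log(0); sketch-only
    int layer = static_cast<int>(std::floor(-std::log(u) * norm_factor));
    histogram[layer]++;
  }
  // Expect a geometric-style decay: each layer holds roughly 1/M of the layer below it.
  for (const auto& [layer, count] : histogram) std::printf("layer %d: %d\n", layer, count);
}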

StatusOr<HnswIndex::NodeKey> HnswIndex::DefaultEntryPoint(uint16_t level) {
@@ -210,40 +177,49 @@ StatusOr<HnswIndex::NodeKey> HnswIndex::DefaultEntryPoint(uint16_t level) {
return {Status::NotFound, fmt::format("No node found in layer {}", level)};
}

Status HnswIndex::Connect(uint16_t layer, const NodeKey& node_key1, const NodeKey& node_key2,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto node1 = Node(node_key1, layer);
GET_OR_RET(node1.AddNeighbour(node_key2, search_key_, storage_, batch));
StatusOr<std::vector<VectorItem>> HnswIndex::DecodeNodesToVectorItems(const std::vector<NodeKey>& node_keys,
uint16_t level, const SearchKey& search_key,
engine::Storage* storage,
const HnswVectorFieldMetadata* metadata) {
std::vector<VectorItem> vector_items;
vector_items.reserve(node_keys.size());

for (const auto& neighbour_key : node_keys) {
Node neighbour_node(neighbour_key, level);
auto neighbour_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
vector_items.emplace_back(VectorItem(neighbour_key, std::move(neighbour_metadata.vector), metadata));
}
return vector_items;
}

auto node2 = Node(node_key2, layer);
GET_OR_RET(node2.AddNeighbour(node_key1, search_key_, storage_, batch));
Status HnswIndex::AddEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto edge_index_key1 = search_key_.ConstructHnswEdge(layer, node_key1, node_key2);
auto s = batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key1, Slice());
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to add edge, {}", s.ToString())};
}

auto edge_index_key2 = search_key_.ConstructHnswEdge(layer, node_key2, node_key1);
s = batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key2, Slice());
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to add edge, {}", s.ToString())};
}
return Status::OK();
}

Status HnswIndex::PruneEdges(const VectorItem& vec, const std::vector<VectorItem>& new_neighbour_vectors,
uint16_t layer, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto node = Node(vec.key, layer);
node.DecodeNeighbours(search_key_, storage_);
std::unordered_set original_neighbours{node.neighbours.begin(), node.neighbours.end()};

uint16_t neighbours_sz = static_cast<uint16_t>(new_neighbour_vectors.size());
std::vector<NodeKey> neighbours(neighbours_sz);
for (auto i = 0; i < neighbours_sz; i++) {
auto neighbour_key = new_neighbour_vectors[i].key;
if (original_neighbours.count(neighbour_key) == 0) {
return {Status::InvalidArgument,
fmt::format("Node \"{}\" is not a neighbour of \"{}\" and can't be pruned", neighbour_key, vec.key)};
}
neighbours[i] = new_neighbour_vectors[i].key;
Status HnswIndex::RemoveEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto edge_index_key = search_key_.ConstructHnswEdge(layer, node_key1, node_key2);
auto s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key);
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
}

std::unordered_set<NodeKey> deleted_neighbours;
GET_OR_RET(node.UpdateNeighbours(neighbours, search_key_, storage_, batch, deleted_neighbours));

for (const auto& key : deleted_neighbours) {
auto neighbour_node = Node(key, layer);
GET_OR_RET(neighbour_node.RemoveNeighbour(vec.key, search_key_, storage_, batch));
edge_index_key = search_key_.ConstructHnswEdge(layer, node_key2, node_key1);
s = batch->Delete(storage_->GetCFHandle(ColumnFamilyID::Search), edge_index_key);
if (!s.ok()) {
return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
}
return Status::OK();
}
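
The Connect/PruneEdges pair from the previous revision is replaced by the symmetric AddEdge/RemoveEdge above: an undirected HNSW edge is materialized as two directed, value-less keys in the Search column family, so both operations reduce to a pair of Put or Delete calls on the write batch. Below is a toy model of that layout, with a std::set standing in for the column family and an illustrative key encoding (the real encoding comes from SearchKey::ConstructHnswEdge and is not reproduced here).

// Toy model of the edge layout used by AddEdge/RemoveEdge. std::set stands in
// for the Search column family; the key encoding below is illustrative only.
#include <cstdint>
#include <set>
#include <string>

using EdgeKeySpace = std::set<std::string>;

std::string EdgeKey(uint16_t layer, const std::string& from, const std::string& to) {
  return std::to_string(layer) + "|" + from + "|" + to;  // not the real SearchKey format
}

void AddEdgeSketch(EdgeKeySpace& cf, uint16_t layer, const std::string& a, const std::string& b) {
  cf.insert(EdgeKey(layer, a, b));  // a -> b
  cf.insert(EdgeKey(layer, b, a));  // b -> a
}

void RemoveEdgeSketch(EdgeKeySpace& cf, uint16_t layer, const std::string& a, const std::string& b) {
  cf.erase(EdgeKey(layer, a, b));
  cf.erase(EdgeKey(layer, b, a));
}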
@@ -290,7 +266,7 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
}

while (!explore_heap.empty()) {
auto& [dist, current_vector] = explore_heap.top();
auto [dist, current_vector] = explore_heap.top();
explore_heap.pop();
if (dist > result_heap.top().first) {
break;
Expand All @@ -317,23 +293,26 @@ StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const V
}
}
}

while (!result_heap.empty()) {
candidates.push_back(result_heap.top().second);
result_heap.pop();
}

std::reverse(candidates.begin(), candidates.end());
return candidates;
}
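
The fragment above shows SearchLayer's greedy expansion: a min-heap (explore_heap) of candidates still to visit and a bounded max-heap (result_heap) of the best ef results, stopping once the closest unexplored candidate is farther than the worst kept result. The following is a simplified, self-contained sketch of that pattern; the key type, distance function, and neighbour lookup are placeholders, not the actual kvrocks structures.

// Simplified sketch of the two-heap greedy search pattern used by SearchLayer.
#include <algorithm>
#include <cstddef>
#include <functional>
#include <queue>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

using NodeKey = std::string;
using DistAndKey = std::pair<double, NodeKey>;

std::vector<NodeKey> SearchLayerSketch(
    const NodeKey& entry, size_t ef,
    const std::function<double(const NodeKey&)>& dist_to_target,
    const std::function<std::vector<NodeKey>(const NodeKey&)>& neighbours_of) {
  std::unordered_set<NodeKey> visited{entry};
  // Min-heap: closest unexplored candidate on top.
  std::priority_queue<DistAndKey, std::vector<DistAndKey>, std::greater<>> explore;
  // Max-heap: worst of the current best-ef results on top.
  std::priority_queue<DistAndKey> results;

  double entry_dist = dist_to_target(entry);
  explore.emplace(entry_dist, entry);
  results.emplace(entry_dist, entry);

  while (!explore.empty()) {
    auto [dist, current] = explore.top();
    explore.pop();
    if (dist > results.top().first) break;  // nothing closer left to explore

    for (const auto& nb : neighbours_of(current)) {
      if (!visited.insert(nb).second) continue;  // already seen
      double d = dist_to_target(nb);
      if (results.size() < ef || d < results.top().first) {
        explore.emplace(d, nb);
        results.emplace(d, nb);
        if (results.size() > ef) results.pop();  // keep only the best ef
      }
    }
  }

  std::vector<NodeKey> out;
  while (!results.empty()) {
    out.push_back(results.top().second);
    results.pop();
  }
  std::reverse(out.begin(), out.end());  // closest first
  return out;
}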

Status HnswIndex::InsertVectorEntry(std::string_view key, kqir::NumericArray vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
Status HnswIndex::InsertVectorEntryInternal(std::string_view key, kqir::NumericArray vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
uint16_t target_level) {
auto cf_handle = storage_->GetCFHandle(ColumnFamilyID::Search);
auto inserted_vector_item = VectorItem(std::string(key), vector, metadata_);
auto target_level = RandomizeLayer();
std::vector<VectorItem> nearest_vec_items;

if (metadata_->num_levels != 0) {
auto level = metadata_->num_levels - 1;

auto default_entry_node = GET_OR_RET(DefaultEntryPoint(level));
std::vector<NodeKey> entry_points{default_entry_node};

@@ -345,32 +324,95 @@ Status HnswIndex::InsertVectorEntry(std::string_view key, kqir::NumericArray vec
for (; level >= 0; level--) {
nearest_vec_items =
GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata_->ef_construction, entry_points));
auto connect_vec_items = GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level));
auto candidate_vec_items = GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level));
auto node = Node(std::string(key), level);
auto m_max = level == 0 ? 2 * metadata_->m : metadata_->m;

std::unordered_set<NodeKey> connected_edges_set;
std::unordered_map<NodeKey, std::unordered_set<NodeKey>> deleted_edges_map;

// Check whether the candidate node has room for more outgoing edges
auto has_room_for_more_edges = [&](int candidate_node_num_neighbours) {
return candidate_node_num_neighbours < m_max;
};

// Check whether the candidate node has room once edges already pruned from it in the current batch are accounted for
auto has_room_after_deletions = [&](const Node& candidate_node, int candidate_node_num_neighbours) {
auto it = deleted_edges_map.find(candidate_node.key);
if (it != deleted_edges_map.end()) {
int num_deleted_edges = it->second.size();
return (candidate_node_num_neighbours - num_deleted_edges) < m_max;
}
return false;
};

for (const auto& candidate_vec : candidate_vec_items) {
auto candidate_node = Node(candidate_vec.key, level);
auto candidate_node_metadata = GET_OR_RET(candidate_node.DecodeMetadata(search_key_, storage_));
uint16_t candidate_node_num_neighbours = candidate_node_metadata.num_neighbours;

if (has_room_for_more_edges(candidate_node_num_neighbours) ||
has_room_after_deletions(candidate_node, candidate_node_num_neighbours)) {
GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch));
connected_edges_set.insert(candidate_node.key);
continue;
}

for (const auto& connected_vec_item : connect_vec_items) {
GET_OR_RET(Connect(level, inserted_vector_item.key, connected_vec_item.key, batch));
// Re-evaluate the neighbours for the candidate node
candidate_node.DecodeNeighbours(search_key_, storage_);
auto candidate_node_neighbour_vec_items =
GET_OR_RET(DecodeNodesToVectorItems(candidate_node.neighbours, level, search_key_, storage_, metadata_));
candidate_node_neighbour_vec_items.push_back(inserted_vector_item);
auto sorted_neighbours = GET_OR_RET(SelectNeighbors(candidate_vec, candidate_node_neighbour_vec_items, level));

bool inserted_node_is_selected = std::find(sorted_neighbours.begin(), sorted_neighbours.end(),
inserted_vector_item) != sorted_neighbours.end();
if (inserted_node_is_selected) {
// Add the edge between candidate and inserted node
GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch));
connected_edges_set.insert(candidate_node.key);

auto find_deleted_item = [&](const std::vector<VectorItem>& candidate_node_neighbour_vec_items,
const std::vector<VectorItem>& sorted_neighbours) -> VectorItem {
auto it = std::find_if(candidate_node_neighbour_vec_items.begin(), candidate_node_neighbour_vec_items.end(),
[&](const VectorItem& item) {
return std::find(sorted_neighbours.begin(), sorted_neighbours.end(), item) ==
sorted_neighbours.end();
});
return *it;
};

// Remove the edge between the candidate and the pruned node
auto deleted_node = find_deleted_item(candidate_node_neighbour_vec_items, sorted_neighbours);
GET_OR_RET(RemoveEdge(deleted_node.key, candidate_node.key, level, batch));
deleted_edges_map[candidate_node.key].insert(deleted_node.key);
deleted_edges_map[deleted_node.key].insert(candidate_node.key);
}
}

for (const auto& connected_vec_item : connect_vec_items) {
auto connected_node = Node(connected_vec_item.key, level);
auto connected_node_metadata = GET_OR_RET(connected_node.DecodeMetadata(search_key_, storage_));

uint16_t connected_node_num_neighbours = connected_node_metadata.num_neighbours;
auto m_max = level == 0 ? 2 * metadata_->m : metadata_->m;
if (connected_node_num_neighbours <= m_max) continue;

connected_node.DecodeNeighbours(search_key_, storage_);
std::vector<VectorItem> connected_node_neighbour_vec_items;
for (const auto& connected_node_neighbour_key : connected_node.neighbours) {
Node connected_node_neighbour = Node(connected_node_neighbour_key, level);
auto connected_node_neighbour_metadata =
GET_OR_RET(connected_node_neighbour.DecodeMetadata(search_key_, storage_));
auto neighbour_vector =
VectorItem(connected_node_neighbour_key, std::move(connected_node_neighbour_metadata.vector), metadata_);
connected_node_neighbour_vec_items.push_back(neighbour_vector);
// Update inserted node metadata
HnswNodeFieldMetadata node_metadata(static_cast<uint16_t>(connected_edges_set.size()), vector);
node.PutMetadata(&node_metadata, search_key_, storage_, batch);

// Update the metadata of modified nodes
for (const auto& node_edges : deleted_edges_map) {
auto& current_node_key = node_edges.first;
auto current_node = Node(current_node_key, level);
auto current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key_, storage_));
auto new_num_neighbours = current_node_metadata.num_neighbours - node_edges.second.size();
if (connected_edges_set.count(current_node_key) != 0) {
new_num_neighbours++;
connected_edges_set.erase(current_node_key);
}
auto new_neighbors = GET_OR_RET(SelectNeighbors(connected_vec_item, connected_node_neighbour_vec_items, level));
GET_OR_RET(PruneEdges(connected_vec_item, new_neighbors, level, batch));
current_node_metadata.num_neighbours = new_num_neighbours;
current_node.PutMetadata(&current_node_metadata, search_key_, storage_, batch);
}

for (const auto& current_node_key : connected_edges_set) {
auto current_node = Node(current_node_key, level);
HnswNodeFieldMetadata current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key_, storage_));
current_node_metadata.num_neighbours++;
current_node.PutMetadata(&current_node_metadata, search_key_, storage_, batch);
}

entry_points.clear();
@@ -400,4 +442,10 @@ Status HnswIndex::InsertVectorEntry(std::string_view key, kqir::NumericArray vec
return Status::OK();
}
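
The metadata updates at the end of InsertVectorEntryInternal touch each affected node exactly once: a node in deleted_edges_map loses one neighbour per pruned edge and gains one back if it also connected to the inserted node, while the remaining entries of connected_edges_set simply gain one. A toy version of that bookkeeping follows; plain containers stand in for the RocksDB write batch and the per-node metadata.

// Toy sketch of the neighbour-count bookkeeping in InsertVectorEntryInternal.
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>

using NodeKey = std::string;

void ApplyDegreeUpdates(
    std::unordered_map<NodeKey, uint16_t>& num_neighbours,
    std::unordered_map<NodeKey, std::unordered_set<NodeKey>> deleted_edges_map,
    std::unordered_set<NodeKey> connected_edges_set) {
  // Nodes that lost edges (and possibly also gained the edge to the inserted node).
  for (const auto& [key, removed] : deleted_edges_map) {
    auto updated = static_cast<uint16_t>(num_neighbours[key] - removed.size());
    if (connected_edges_set.erase(key) != 0) updated++;  // also gained an edge
    num_neighbours[key] = updated;
  }
  // Nodes that only gained the edge to the newly inserted node.
  for (const auto& key : connected_edges_set) num_neighbours[key]++;
}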

Status HnswIndex::InsertVectorEntry(std::string_view key, kqir::NumericArray vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
auto target_level = RandomizeLayer();
return InsertVectorEntryInternal(key, vector, batch, target_level);
}
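
Splitting the public InsertVectorEntry into this thin wrapper plus InsertVectorEntryInternal is what makes the insertion path unit-testable: a test can pass a fixed target_level instead of depending on RandomizeLayer(). Below is a hypothetical GoogleTest-style sketch; the helper names (MakeTestStorage, GetWriteBatchBase, the SearchKey constructor arguments, the literal vector) are assumptions for illustration and are not taken from this commit.

// Hypothetical test sketch; helper names are assumptions, not from this commit.
TEST(HnswIndexTest, InsertAtFixedLevel) {
  auto storage = MakeTestStorage();                     // assumed test fixture helper
  SearchKey search_key("test_index", "vector_field");   // assumed constructor arguments
  HnswVectorFieldMetadata metadata;
  HnswIndex index(search_key, &metadata, storage.get());

  auto batch = storage->GetWriteBatchBase();            // assumed write-batch accessor
  // Deterministic: bypass RandomizeLayer() and insert directly at layer 3.
  auto s = index.InsertVectorEntryInternal("vec_key", {1.0, 2.0, 3.0}, batch, 3);
  ASSERT_TRUE(s.IsOK());
}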

} // namespace redis
16 changes: 10 additions & 6 deletions src/search/hnsw_indexer.h
@@ -49,10 +49,6 @@ struct Node {
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status UpdateNeighbours(std::vector<NodeKey>& neighbours, const SearchKey& search_key, engine::Storage* storage,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
std::unordered_set<NodeKey>& deleted_neighbours);

friend class HnswIndex;
};

@@ -66,6 +62,7 @@ struct VectorItem {
VectorItem(const NodeKey& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata);
VectorItem(const NodeKey& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata);

bool operator==(const VectorItem& other) const;
bool operator<(const VectorItem& other) const;
};

@@ -84,16 +81,23 @@ class HnswIndex {

HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage);

static StatusOr<std::vector<VectorItem>> DecodeNodesToVectorItems(const std::vector<NodeKey>& node_key,
uint16_t level, const SearchKey& search_key,
engine::Storage* storage,
const HnswVectorFieldMetadata* metadata);
uint16_t RandomizeLayer();
StatusOr<NodeKey> DefaultEntryPoint(uint16_t level);
Status Connect(uint16_t layer, const NodeKey& node_key1, const NodeKey& node_key2,
Status AddEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
Status PruneEdges(const VectorItem& vec, const std::vector<VectorItem>& new_neighbour_vectors, uint16_t layer,
Status RemoveEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);

StatusOr<std::vector<VectorItem>> SelectNeighbors(const VectorItem& vec, const std::vector<VectorItem>& vectors,
uint16_t layer);
StatusOr<std::vector<VectorItem>> SearchLayer(uint16_t level, const VectorItem& target_vector, uint32_t ef_runtime,
const std::vector<NodeKey>& entry_points);
Status InsertVectorEntryInternal(std::string_view key, kqir::NumericArray vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, uint16_t layer);
Status InsertVectorEntry(std::string_view key, kqir::NumericArray vector,
ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
};
2 changes: 1 addition & 1 deletion src/search/search_encoding.h
@@ -369,7 +369,7 @@ struct HnswVectorFieldMetadata : IndexFieldMetadata {
uint32_t ef_construction = 200;
uint32_t ef_runtime = 10;
double epsilon = 0.01;
uint16_t num_levels = 10;
uint16_t num_levels = 0;

HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {}

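
The num_levels default in HnswVectorFieldMetadata drops from 10 to 0 so that a freshly created field reports no layers until nodes are actually inserted; the insertion path above only descends through layers that exist (the metadata_->num_levels != 0 branch). The toy illustration below shows how the counter would grow as inserts land on higher layers; the exact update statement lies outside the hunks shown in this commit, so the helper here is an assumption.

// Toy illustration of the num_levels invariant: it tracks layers actually built.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct FieldMetadata {
  uint16_t num_levels = 0;  // fresh index: no layers yet
};

void OnInsert(FieldMetadata& meta, uint16_t target_level) {
  // Assumed update rule: extend the level count only when the new node's
  // layer exceeds the current top (the real statement is not shown in this diff).
  meta.num_levels = std::max<uint16_t>(meta.num_levels, static_cast<uint16_t>(target_level + 1));
}

int main() {
  FieldMetadata meta;
  OnInsert(meta, 0);  // most inserts land on layer 0
  OnInsert(meta, 2);  // occasionally a node is promoted higher
  std::printf("num_levels = %d\n", meta.num_levels);  // prints 3
}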
