Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add: Add index_gt::merge() #572

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions cpp/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,108 @@ template <typename key_at, typename slot_at> void test_replacing_update() {
expect_eq(final_search[2].member.key, 44);
}

/**
* @brief Tests merging.
*/
void test_merge() {
using index_t = index_gt<>;
using distance_t = typename index_t::distance_t;
using key_t = typename index_t::key_t;
using compressed_slot_t = typename index_t::compressed_slot_t;
using member_ref_t = typename index_t::member_ref_t;
using member_cref_t = typename index_t::member_cref_t;
using member_citerator_t = typename index_t::member_citerator_t;
using add_result_t = typename index_t::add_result_t;

using value_t = float;

auto create_index = []() {
auto index_result = index_t::make();
expect(index_result);
return std::move(index_result.index);
};

struct metric_t {
std::unordered_map<compressed_slot_t, value_t> values;

metric_t() : values() {}
distance_t compute(value_t const& a, value_t const& b) {
if (b > a) {
return b - a;
} else {
return a - b;
}
}
distance_t operator()(value_t const& a, member_cref_t const& b) { return compute(a, values.at(get_slot(b))); }
distance_t operator()(value_t const& a, member_citerator_t const& b) {
return compute(a, values.at(get_slot(b)));
}
distance_t operator()(member_citerator_t const& a, member_citerator_t const& b) {
return compute(values.at(get_slot(a)), values.at(get_slot(b)));
}
};

auto add = [](index_t& index, key_t const key, value_t const value, metric_t& metric) {
auto on_success = [&](member_ref_t member) { metric.values[member.slot] = value; };
add_result_t result = index.add(key, value, metric, {}, on_success);
expect(result);
};

// Prepare index 1
auto index1 = create_index();
metric_t metric1;
expect(index1.reserve(3));
add(index1, 11, 1.1f, metric1);
add(index1, 12, 2.1f, metric1);
add(index1, 13, 3.1f, metric1);
expect_eq(index1.size(), 3);

// Prepare index 2
auto index2 = create_index();
metric_t metric2;
expect(index2.reserve(4));
add(index2, 21, -1.1f, metric2);
add(index2, 22, -2.1f, metric2);
add(index2, 23, -3.1f, metric2);
add(index2, 24, -4.1f, metric2);
expect_eq(index2.size(), 4);

// Merge indexes
char const* merge_file_path = "merge.usearch";
auto merged_index = create_index();
expect(merged_index.save(merge_file_path));
memory_mapped_file_t file{merge_file_path, true};
expect(merged_index.load(std::move(file)));
metric_t merged_metric;
auto merge_on_success = [&](member_ref_t member, value_t const& value) {
merged_metric.values[member.slot] = value;
};
auto get_value1 = [&](member_cref_t member) -> value_t& { return metric1.values[member.slot]; };
expect(merged_index.merge(index1, get_value1, merged_metric, {}, merge_on_success));
auto get_value2 = [&](member_cref_t member) -> value_t& { return metric2.values[member.slot]; };
expect(merged_index.merge(index2, get_value2, merged_metric, {}, merge_on_success));

// Assert
expect_eq(merged_index.size(), 7);
auto search = merged_index.search(0.75f, 3, merged_metric);
expect_eq(search.size(), 3);
expect_eq(static_cast<key_t>(search[0].member.key), 11);
expect_eq(static_cast<key_t>(search[1].member.key), 12);
expect_eq(static_cast<key_t>(search[2].member.key), 21);

// Re-load merged indexes
merged_index.reset();
merged_index.load(merge_file_path);

// Assert
expect_eq(merged_index.size(), 7);
search = merged_index.search(0.75f, 3, merged_metric);
expect_eq(search.size(), 3);
expect_eq(static_cast<key_t>(search[0].member.key), 11);
expect_eq(static_cast<key_t>(search[1].member.key), 12);
expect_eq(static_cast<key_t>(search[2].member.key), 21);
}

int main(int, char**) {
test_uint40();
test_cosine<float, std::int64_t, uint40_t>(10, 10);
Expand Down Expand Up @@ -1163,5 +1265,9 @@ int main(int, char**) {
test_sets<std::int64_t, slot32_t>(set_size, 20, 30);
test_strings<std::int64_t, slot32_t>();

// Test merge
std::printf("Testing merge\n");
test_merge();

return 0;
}
Loading