Skip to content

Commit 799284c

Browse files
committed
Add: Jensen-Shannon Divergence-based distance
1 parent d28dc1a commit 799284c

File tree

12 files changed

+64
-3
lines changed

12 files changed

+64
-3
lines changed

Diff for: c/lib.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ metric_kind_t metric_kind_to_cpp(usearch_metric_kind_t kind) {
2222
case usearch_metric_l2sq_k: return metric_kind_t::l2sq_k;
2323
case usearch_metric_cos_k: return metric_kind_t::cos_k;
2424
case usearch_metric_haversine_k: return metric_kind_t::haversine_k;
25+
case usearch_metric_divergence_k: return metric_kind_t::divergence_k;
2526
case usearch_metric_pearson_k: return metric_kind_t::pearson_k;
2627
case usearch_metric_jaccard_k: return metric_kind_t::jaccard_k;
2728
case usearch_metric_hamming_k: return metric_kind_t::hamming_k;
@@ -37,6 +38,7 @@ usearch_metric_kind_t metric_kind_to_c(metric_kind_t kind) {
3738
case metric_kind_t::l2sq_k: return usearch_metric_l2sq_k;
3839
case metric_kind_t::cos_k: return usearch_metric_cos_k;
3940
case metric_kind_t::haversine_k: return usearch_metric_haversine_k;
41+
case metric_kind_t::divergence_k: return usearch_metric_divergence_k;
4042
case metric_kind_t::pearson_k: return usearch_metric_pearson_k;
4143
case metric_kind_t::jaccard_k: return usearch_metric_jaccard_k;
4244
case metric_kind_t::hamming_k: return usearch_metric_hamming_k;

Diff for: c/usearch.h

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ USEARCH_EXPORT typedef enum usearch_metric_kind_t {
4343
usearch_metric_ip_k,
4444
usearch_metric_l2sq_k,
4545
usearch_metric_haversine_k,
46+
usearch_metric_divergence_k,
4647
usearch_metric_pearson_k,
4748
usearch_metric_jaccard_k,
4849
usearch_metric_hamming_k,

Diff for: cpp/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ The following distances are pre-packaged:
7676
- `metric_sorensen_gt<scalar_t>` for "Dice-Sorensen" coefficient for bit-strings.
7777
- `metric_pearson_gt<scalar_t>` for "Pearson" correlation between probability distributions.
7878
- `metric_haversine_gt<scalar_t>` for "Haversine" or "Great Circle" distance between coordinates used in GIS applications.
79+
- `metric_divergence_gt<scalar_t>` for the "Jensen Shannon" similarity between probability distributions.
7980

8081
## Multi-Threading
8182

Diff for: cpp/bench.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ struct args_t {
447447
bool metric_l2 = false;
448448
bool metric_cos = false;
449449
bool metric_haversine = false;
450+
bool metric_divergence = false;
450451
bool metric_hamming = false;
451452
bool metric_tanimoto = false;
452453
bool metric_sorensen = false;
@@ -458,6 +459,8 @@ struct args_t {
458459
return metric_kind_t::cos_k;
459460
if (metric_haversine)
460461
return metric_kind_t::haversine_k;
462+
if (metric_divergence)
463+
return metric_kind_t::divergence_k;
461464
if (metric_hamming)
462465
return metric_kind_t::hamming_k;
463466
if (metric_tanimoto)

Diff for: golang/lib.go

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ const (
2121
Cosine
2222
L2sq
2323
Haversine
24+
Divergence
2425
Pearson
2526
Hamming
2627
Tanimoto
@@ -37,6 +38,8 @@ func (m Metric) String() string {
3738
return "cos"
3839
case Haversine:
3940
return "haversine"
41+
case Divergence:
42+
return "divergence"
4043
case Pearson:
4144
return "pearson"
4245
case Hamming:

Diff for: include/usearch/index_plugins.hpp

+46-2
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ enum class metric_kind_t : std::uint8_t {
112112
// Custom:
113113
pearson_k = 'p',
114114
haversine_k = 'h',
115+
divergence_k = 'd',
115116

116117
// Sets:
117118
jaccard_k = 'j',
@@ -239,6 +240,7 @@ inline char const* metric_kind_name(metric_kind_t metric) noexcept {
239240
case metric_kind_t::l2sq_k: return "l2sq";
240241
case metric_kind_t::pearson_k: return "pearson";
241242
case metric_kind_t::haversine_k: return "haversine";
243+
case metric_kind_t::divergence_k: return "divergence";
242244
case metric_kind_t::jaccard_k: return "jaccard";
243245
case metric_kind_t::hamming_k: return "hamming";
244246
case metric_kind_t::tanimoto_k: return "tanimoto";
@@ -275,6 +277,8 @@ inline expected_gt<metric_kind_t> metric_from_name(char const* name, std::size_t
275277
parsed.result = metric_kind_t::cos_k;
276278
} else if (str_equals(name, len, "haversine")) {
277279
parsed.result = metric_kind_t::haversine_k;
280+
} else if (str_equals(name, len, "divergence")) {
281+
parsed.result = metric_kind_t::divergence_k;
278282
} else if (str_equals(name, len, "pearson")) {
279283
parsed.result = metric_kind_t::pearson_k;
280284
} else if (str_equals(name, len, "hamming")) {
@@ -284,8 +288,8 @@ inline expected_gt<metric_kind_t> metric_from_name(char const* name, std::size_t
284288
} else if (str_equals(name, len, "sorensen")) {
285289
parsed.result = metric_kind_t::sorensen_k;
286290
} else
287-
parsed.failed(
288-
"Unknown distance, choose: l2sq, ip, cos, haversine, jaccard, pearson, hamming, tanimoto, sorensen");
291+
parsed.failed("Unknown distance, choose: l2sq, ip, cos, haversine, divergence, jaccard, pearson, hamming, "
292+
"tanimoto, sorensen");
289293
return parsed;
290294
}
291295

@@ -1180,6 +1184,35 @@ template <typename scalar_at = float, typename result_at = float> struct metric_
11801184
}
11811185
};
11821186

1187+
#include <cmath>
1188+
#include <vector>
1189+
1190+
/**
1191+
* @brief Measures Jensen-Shannon Divergence between two probability distributions.
1192+
*/
1193+
template <typename scalar_t = float, typename result_t = float> struct metric_divergence_gt {
1194+
using scalar_t = scalar_t;
1195+
using result_t = result_t;
1196+
1197+
inline result_t operator()(scalar_t const* p, scalar_t const* q, std::size_t dim) const noexcept {
1198+
result_t kld_pm{}, kld_qm{};
1199+
scalar_t epsilon = std::numeric_limits<scalar_t>::epsilon();
1200+
#if USEARCH_USE_OPENMP
1201+
#pragma omp simd reduction(+ : kld_pm, kld_qm)
1202+
#elif defined(USEARCH_DEFINED_CLANG)
1203+
#pragma clang loop vectorize(enable)
1204+
#elif defined(USEARCH_DEFINED_GCC)
1205+
#pragma GCC ivdep
1206+
#endif
1207+
for (std::size_t i = 0; i != dim; ++i) {
1208+
scalar_t mi = (p[i] + q[i]) / 2 + epsilon;
1209+
kld_pm += p[i] * std::log((p[i] + epsilon) / mi);
1210+
kld_qm += q[i] * std::log((q[i] + epsilon) / mi);
1211+
}
1212+
return (kld_pm + kld_qm) / 2;
1213+
}
1214+
};
1215+
11831216
struct cos_i8_t {
11841217
using scalar_t = i8_t;
11851218
using result_t = f32_t;
@@ -1452,6 +1485,17 @@ class metric_punned_t {
14521485
}
14531486
break;
14541487
}
1488+
case metric_kind_t::divergence_k: {
1489+
switch (scalar_kind_) {
1490+
case scalar_kind_t::f16_k:
1491+
raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_divergence_gt<f16_t, f32_t>>;
1492+
break;
1493+
case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_divergence_gt<f32_t>>; break;
1494+
case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_divergence_gt<f64_t>>; break;
1495+
default: raw_ptr_ = nullptr; break;
1496+
}
1497+
break;
1498+
}
14551499
case metric_kind_t::jaccard_k: // Equivalent to Tanimoto
14561500
case metric_kind_t::tanimoto_k:
14571501
raw_ptr_ = (punned_ptr_t)&equidimensional_<metric_tanimoto_gt<b1x8_t>>,

Diff for: javascript/usearch.js

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const MetricKind = {
1111
IP: 'ip',
1212
L2sq: 'l2sq',
1313
Haversine: 'haversine',
14+
Divergence: 'divergence',
1415
Pearson: 'pearson',
1516
Jaccard: 'jaccard',
1617
Hamming: 'hamming',

Diff for: objc/USearchObjective.mm

+3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ metric_kind_t to_native_metric(USearchMetric m) {
3333
case USearchMetricHaversine:
3434
return metric_kind_t::haversine_k;
3535

36+
case USearchMetricDivergence:
37+
return metric_kind_t::divergence_k;
38+
3639
case USearchMetricJaccard:
3740
return metric_kind_t::jaccard_k;
3841

Diff for: python/lib.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,7 @@ PYBIND11_MODULE(compiled, m) {
864864
.value("L2sq", metric_kind_t::l2sq_k)
865865

866866
.value("Haversine", metric_kind_t::haversine_k)
867+
.value("Divergence", metric_kind_t::divergence_k)
867868
.value("Pearson", metric_kind_t::pearson_k)
868869
.value("Jaccard", metric_kind_t::jaccard_k)
869870
.value("Hamming", metric_kind_t::hamming_k)

Diff for: python/usearch/index.py

+1
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def _normalize_metric(metric) -> MetricKind:
151151
"l2sq": MetricKind.L2sq,
152152
"l2_sq": MetricKind.L2sq,
153153
"haversine": MetricKind.Haversine,
154+
"divergence": MetricKind.Divergence,
154155
"pearson": MetricKind.Pearson,
155156
"hamming": MetricKind.Hamming,
156157
"tanimoto": MetricKind.Tanimoto,

Diff for: rust/lib.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ metric_kind_t rust_to_cpp_metric(MetricKind value) {
107107
case MetricKind::Cos: return metric_kind_t::cos_k;
108108
case MetricKind::Pearson: return metric_kind_t::pearson_k;
109109
case MetricKind::Haversine: return metric_kind_t::haversine_k;
110+
case MetricKind::Divergence: return metric_kind_t::divergence_k;
110111
case MetricKind::Hamming: return metric_kind_t::hamming_k;
111112
case MetricKind::Tanimoto: return metric_kind_t::tanimoto_k;
112113
case MetricKind::Sorensen: return metric_kind_t::sorensen_k;

0 commit comments

Comments
 (0)