diff --git a/.commitlintrc.yaml b/.commitlintrc.yaml new file mode 100644 index 00000000..4b583467 --- /dev/null +++ b/.commitlintrc.yaml @@ -0,0 +1,3 @@ +extends: + - "@commitlint/config-conventional" + diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 888eef4c..258ab9d8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,6 +1,11 @@ name: CI -on: [push] +on: + push: + branches: [ "**" ] + pull_request: + branches: [ "**" ] + workflow_dispatch: concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -10,12 +15,21 @@ env: LATEST_STABLE_RUST_VERSION: "TBD" jobs: + commitlint: + name: Commit Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: wagoid/commitlint-github-action@v6 + with: + configFile: .commitlintrc.yaml + format: name: Formatting runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false @@ -44,7 +58,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false @@ -80,7 +94,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false @@ -115,7 +129,7 @@ jobs: name: Cargo Check (Stable) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false @@ -149,7 +163,7 @@ jobs: name: Cargo Check (Nightly) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false @@ -185,7 +199,7 @@ jobs: name: Run Tests (Stable, no crate features enabled that require unstable Rust) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false @@ -218,7 +232,7 @@ jobs: name: Run Tests (Nightly, all features enabled) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml index f65ed274..eeb14ae0 100644 --- a/.github/workflows/codspeed.yml +++ b/.github/workflows/codspeed.yml @@ -12,7 +12,7 @@ jobs: benchmarks: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false @@ -27,7 +27,7 @@ jobs: run: cargo codspeed build - name: Run the benchmarks - uses: CodSpeedHQ/action@v3 + uses: CodSpeedHQ/action@v4 with: run: cargo codspeed run token: ${{ secrets.CODSPEED_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index f85ab75a..aface7a5 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -1,6 +1,11 @@ name: Coverage -on: [push] +on: + push: + branches: [ "**" ] + pull_request: + branches: [ "**" ] + workflow_dispatch: env: CARGO_TERM_COLOR: always @@ -10,7 +15,7 @@ jobs: name: Test Coverage runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false diff --git a/.github/workflows/pre-release-string-updater.yml b/.github/workflows/pre-release-string-updater.yml index 6d3e7df0..bd3ad8f9 100644 --- a/.github/workflows/pre-release-string-updater.yml +++ b/.github/workflows/pre-release-string-updater.yml @@ -12,7 +12,7 @@ jobs: name: Pre Release String Updater runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false @@ -50,7 +50,7 @@ jobs: git commit -a -m "chore: update version references" - name: Push changes - uses: ad-m/github-push-action@v0.8.0 + uses: ad-m/github-push-action@v1.0.0 with: github_token: ${{ secrets.GITHUB_TOKEN }} branch: ${{ github.ref }} diff --git a/.github/workflows/pre-release.yml b/.github/workflows/pre-release.yml index a6526f4f..0159c36b 100644 --- a/.github/workflows/pre-release.yml +++ b/.github/workflows/pre-release.yml @@ -13,9 +13,10 @@ jobs: pull-requests: write contents: write steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false + fetch-depth: 0 - name: Get date run: | diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 506b496c..9b6389e0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,7 +10,7 @@ jobs: if: ${{ github.event.pull_request.merged && contains(github.event.pull_request.labels.*.name, 'release') }} runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: show-progress: false diff --git a/CHANGELOG.md b/CHANGELOG.md index a54e03a9..66d3088d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,66 @@ # Kiddo Changelog +## Unreleased + +I've kept back these changes for now as, whilst extremely welcome, the change to the return type of `size()` would be breaking. +I'm hoping to have v6 available soon, which will include both these changes as well as addressing quite a few other long-running feature requests. + +### ✨ Features + +- Improve the flexibility of T type, KdTree.size. Add generate nearest_one_point (https://github.com/sdd/kiddo/pull/214, @pjkundert) +- Fix broken test of custom struct as T. Add test for `()` as T (https://github.com/sdd/kiddo/pull/243 @pjkundert) + +## [5.2.4] - 2026-01-01 (Happy New Year! šŸŽ‰) + +### Deps + +- update cmov dep from 0.3 to 0.4 after 0.3 got yanked (see https://github.com/RustCrypto/utils/issues/1304). Thanks @yuby and @jqnatividad + +## [5.2.3] - 2025-12-08 + +### šŸ› Bug Fixes + +- Correct slice access in remainder processing and remove unsafe (@MarkusZoppelt) +- Use `try_from()` with error for `leaf_items.len()` (@MarkusZoppelt) + +### ā™»ļø Refactor + +- within_unsorted_iter no longer uses a generator (@KvA2KLvAST) +- Remove needless SubAssign trait bound from Content trait + +### Deps + +- Remove doc-comment dependency and use doc attribute that was added in Rust 1.54 instead (@jqnatividad) +- Use `doc` attribute instead of `doc_comment!` (@jqnatividad) +- Update actions/checkout action to v6 +- Update codspeedhq/action action to v4 +- Update ad-m/github-push-action action to v1 +- Update rust crate rstest to 0.26 +- Update rust crate codspeed-criterion-compat to v4 + +### Ci + +- Update CI workflow triggers to include PR and workflow_dispatch +- Permit coverage to run for PRs as well +- Fix release-plz and add commitlint + +### šŸ’„ Styling + +- Remove unnecessary parentheses +- Fix formatting +- Fix some lint issues + +### 🧪 Testing + +- Add regression test for remainder slice access bug (@MarkusZoppelt) + +## [5.2.2] - 2025-06-30 + +### ā™»ļø Refactor + +- refactor `within_unsorted_iter` to decouple the lifetime of the iterator from that of the query by + copying `query` once at the start of the call (@KvA2KLvAST) + ## [5.2.1] - 2025-06-29 ### šŸ“ Documentation diff --git a/Cargo.toml b/Cargo.toml index d90dbdc3..c3e68580 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "kiddo" -version = "5.2.1" +version = "5.2.4" edition = "2021" authors = ["Scott Donnelly "] description = "A high-performance, flexible, ergonomic k-d tree library. Ideal for geo- and astro- nearest-neighbour and k-nearest-neighbor queries" @@ -27,16 +27,15 @@ opt-level = 3 aligned-vec = "0.6.1" array-init = "2.1.0" az = "1" -cmov = "0.3" +cmov = "0.4" divrem = "1" -doc-comment = "0.3" num-traits = "0.2" ordered-float = "5" sorted-vec = "0.8" [dev-dependencies] bincode = { version = "2", features = ["serde"] } -codspeed-criterion-compat = "2.10" +codspeed-criterion-compat = "4.0" criterion = "0.6" elapsed = "0.1.2" flate2 = { version = "1", features = ["zlib-ng-compat"], default-features = false } @@ -48,7 +47,7 @@ radians = "0.3" rand = "0.9" rand_distr = "0.5" rayon = "1" -rstest = "0.25" +rstest = "0.26" serde = { version = "1", features = ["derive", "rc"] } serde_json = "1" ubyte = "0.10" diff --git a/README.md b/README.md index 32b5dd41..c67f3ecd 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Kiddo provides: Add `kiddo` to `Cargo.toml` ```toml [dependencies] -kiddo = "5.2.0" +kiddo = "5.2.4" ``` Add points to k-d tree and query nearest n points with distance function diff --git a/examples/cities.rs b/examples/cities.rs index 8da29473..1c8a4b1c 100644 --- a/examples/cities.rs +++ b/examples/cities.rs @@ -190,7 +190,7 @@ fn main() -> Result<(), Box> { let dist = kilometres_to_unit_sphere_squared_euclidean(1000.0); let best_3_iter = kdtree.best_n_within::(&query, dist, 3); let best_3 = best_3_iter - .map(|neighbour| (&cities[neighbour.item].name)) + .map(|neighbour| &cities[neighbour.item].name) .collect::>(); println!("\nMost populous 3 cities within 1000km of 0N, 0W: {best_3:?}"); diff --git a/git-cliff.toml b/git-cliff.toml index 815f6214..706f4b0c 100644 --- a/git-cliff.toml +++ b/git-cliff.toml @@ -40,6 +40,7 @@ commit_parsers = [ { message = "^refactor", group = "ā™»ļø Refactor"}, { message = "^style", group = "šŸ’„ Styling"}, { message = "^test", group = "🧪 Testing"}, + { message = "^chore", group = "🧹 Chore"}, ] protect_breaking_commits = true tag_pattern = "v[0-9]\\.[0-9]\\.[0-9]" diff --git a/release-plz.toml b/release-plz.toml index f75cbce8..67776be9 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -1,10 +1,5 @@ [workspace] -allow_dirty = false changelog_config = "git-cliff.toml" # use a custom git-cliff configuration -changelog_update = true # disable changelog updates dependencies_update = true # update dependencies with `cargo update` -git_release_enable = true # disable GitHub/Gitea releases pr_labels = ["release"] # add the `release` label to the release Pull Request -publish = true -publish_allow_dirty = false -semver_check = true # disable API breaking changes checks + diff --git a/src/common/generate_best_n_within.rs b/src/common/generate_best_n_within.rs index 1fbdf16f..21c2e576 100644 --- a/src/common/generate_best_n_within.rs +++ b/src/common/generate_best_n_within.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_best_n_within { ($leafnode:ident, $comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn best_n_within( &self, @@ -77,7 +76,7 @@ macro_rules! generate_best_n_within { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -147,4 +146,5 @@ macro_rules! generate_best_n_within { } } } -}}} + }; +} diff --git a/src/common/generate_nearest_n.rs b/src/common/generate_nearest_n.rs index f518101b..f4efac07 100644 --- a/src/common/generate_nearest_n.rs +++ b/src/common/generate_nearest_n.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_nearest_n { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn nearest_n(&self, query: &[A; K], qty: usize) -> Vec> where @@ -66,7 +65,7 @@ macro_rules! generate_nearest_n { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if Self::dist_belongs_in_heap(rd, results) { off[split_dim] = new_off; @@ -113,8 +112,9 @@ macro_rules! generate_nearest_n { } } - #[inline] + #[inline(always)] fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>) -> bool { heap.is_empty() || dist < heap.peek().unwrap().distance || heap.len() < heap.capacity() } -}}} + }; +} diff --git a/src/common/generate_nearest_n_within_unsorted.rs b/src/common/generate_nearest_n_within_unsorted.rs index ccd741b8..09d56c4e 100644 --- a/src/common/generate_nearest_n_within_unsorted.rs +++ b/src/common/generate_nearest_n_within_unsorted.rs @@ -2,9 +2,7 @@ #[macro_export] macro_rules! generate_nearest_n_within_unsorted { ($comments:tt) => { - doc_comment! { - concat!$comments, - + #[doc = concat!$comments] #[inline] pub fn nearest_n_within(&self, query: &[A; K], dist: A, max_items: std::num::NonZero, sorted: bool) -> Vec> where @@ -88,7 +86,7 @@ macro_rules! generate_nearest_n_within_unsorted { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -118,7 +116,7 @@ macro_rules! generate_nearest_n_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance < radius { + if distance <= radius { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); @@ -130,6 +128,5 @@ macro_rules! generate_nearest_n_within_unsorted { }); } } - } }; } diff --git a/src/common/generate_nearest_one.rs b/src/common/generate_nearest_one.rs index ebf950a6..c5838b20 100644 --- a/src/common/generate_nearest_one.rs +++ b/src/common/generate_nearest_one.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_nearest_one { ($leafnode:ident, $comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn nearest_one(&self, query: &[A; K]) -> NearestNeighbour where @@ -68,7 +67,7 @@ macro_rules! generate_nearest_one { nearest = nearest_neighbour; } - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= nearest.distance { off[split_dim] = new_off; @@ -125,6 +124,5 @@ macro_rules! generate_nearest_one { } }); } - } }; } diff --git a/src/common/generate_within.rs b/src/common/generate_within.rs index bc0b9221..7952636f 100644 --- a/src/common/generate_within.rs +++ b/src/common/generate_within.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_within { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn within(&self, query: &[A; K], dist: A) -> Vec> where @@ -13,6 +12,5 @@ macro_rules! generate_within { matching_items.sort(); matching_items } - } }; } diff --git a/src/common/generate_within_unsorted.rs b/src/common/generate_within_unsorted.rs index f28c2536..656ab033 100644 --- a/src/common/generate_within_unsorted.rs +++ b/src/common/generate_within_unsorted.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_within_unsorted { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn within_unsorted(&self, query: &[A; K], dist: A) -> Vec> where @@ -69,7 +68,7 @@ macro_rules! generate_within_unsorted { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -99,7 +98,7 @@ macro_rules! generate_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance < radius { + if distance <= radius { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); @@ -111,6 +110,5 @@ macro_rules! generate_within_unsorted { }); } } - } }; } diff --git a/src/common/generate_within_unsorted_iter.rs b/src/common/generate_within_unsorted_iter.rs index a3808c7e..a44ad861 100644 --- a/src/common/generate_within_unsorted_iter.rs +++ b/src/common/generate_within_unsorted_iter.rs @@ -2,12 +2,11 @@ #[macro_export] macro_rules! generate_within_unsorted_iter { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn within_unsorted_iter( &'a self, - query: &'a [A; K], + query: &'query [A; K], dist: A, ) -> WithinUnsortedIter<'a, A, T> where @@ -16,10 +15,12 @@ macro_rules! generate_within_unsorted_iter { let mut off = [A::zero(); K]; let root_index: IDX = *transform(&self.root_index); + let query = query.clone(); let gen = Gn::new_scoped(move |gen_scope| { + let query_ref = &query; unsafe { self.within_unsorted_iter_recurse::( - query, + query_ref, dist, root_index, 0, @@ -77,7 +78,7 @@ macro_rules! generate_within_unsorted_iter { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -107,7 +108,7 @@ macro_rules! generate_within_unsorted_iter { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance < radius { + if distance <= radius { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); @@ -121,6 +122,5 @@ macro_rules! generate_within_unsorted_iter { gen_scope } - } }; } diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index 5e901432..c2a6bb9d 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -55,6 +55,70 @@ impl DistanceMetric for Manhattan { b - a } } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } +} + +/// Returns the Chebyshev distance (L-infinity norm) between two points. +/// +/// This is the maximum of the absolute differences between coordinates of points. +/// +/// # Examples +/// +/// ```rust +/// use fixed::types::extra::U0; +/// use fixed::FixedU16; +/// use kiddo::traits::DistanceMetric; +/// use kiddo::fixed::distance::Chebyshev; +/// type Fxd = FixedU16; +/// +/// let ZERO = Fxd::from_num(0); +/// let ONE = Fxd::from_num(1); +/// let TWO = Fxd::from_num(2); +/// +/// assert_eq!(ZERO, Chebyshev::dist(&[ZERO, ZERO], &[ZERO, ZERO])); +/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ZERO])); +/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ONE])); +/// assert_eq!(TWO, Chebyshev::dist(&[ZERO, ZERO], &[TWO, ONE])); +/// ``` +pub struct Chebyshev {} + +impl DistanceMetric for Chebyshev { + #[inline] + fn dist(a: &[A; K], b: &[A; K]) -> A { + a.iter() + .zip(b.iter()) + .map(|(&a_val, &b_val)| { + if a_val > b_val { + a_val - b_val + } else { + b_val - a_val + } + }) + .reduce(|a, b| if a > b { a } else { b }) + .unwrap_or(A::ZERO) + } + + #[inline] + fn dist1(a: A, b: A) -> A { + if a > b { + a - b + } else { + b - a + } + } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + if rd > delta { + rd + } else { + delta + } + } } /// Returns the squared euclidean distance between two points. @@ -99,4 +163,318 @@ impl DistanceMetric for SquaredEuclidean { let diff: A = a.dist(b); diff * diff } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use fixed::types::extra::U0; + use rstest::rstest; + + type FxdU16 = fixed::FixedU16; + + const ZERO: FxdU16 = FxdU16::ZERO; + const ONE: FxdU16 = FxdU16::lit("1"); + const TWO: FxdU16 = FxdU16::lit("2"); + const THREE: FxdU16 = FxdU16::lit("3"); + const FOUR: FxdU16 = FxdU16::lit("4"); + const FIVE: FxdU16 = FxdU16::lit("5"); + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], ONE)] + #[case([ZERO, ZERO], [TWO, ONE], TWO)] + #[case([ZERO, ZERO], [ONE, TWO], TWO)] + fn test_chebyshev_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO, ZERO], [ONE, TWO, THREE], THREE)] + #[case([FIVE, FIVE, FIVE], [ONE, TWO, THREE], FOUR)] + fn test_chebyshev_distance_3d( + #[case] a: [FxdU16; 3], + #[case] b: [FxdU16; 3], + #[case] expected: FxdU16, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], TWO)] + #[case([TWO, THREE], [ONE, ONE], THREE)] + fn test_manhattan_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], FxdU16::lit("2"))] + #[case([TWO, TWO], [ZERO, ZERO], FxdU16::lit("8"))] + #[case([ONE, TWO], [TWO, ONE], FxdU16::lit("2"))] + fn test_squared_euclidean_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO, ZERO], [ONE, ZERO, ZERO], ONE)] + #[case([ONE, ONE, ONE], [TWO, TWO, TWO], THREE)] + fn test_squared_euclidean_distance_3d( + #[case] a: [FxdU16; 3], + #[case] b: [FxdU16; 3], + #[case] expected: FxdU16, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::diff(THREE, ONE, TWO)] + fn test_manhattan_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::a_larger(TWO, ONE, ONE)] + #[case::b_larger(ONE, TWO, ONE)] + fn test_chebyshev_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::a_larger(TWO, ONE, ONE)] + #[case::b_larger(ONE, TWO, ONE)] + fn test_squared_euclidean_dist1( + #[case] a: FxdU16, + #[case] b: FxdU16, + #[case] expected: FxdU16, + ) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero_one(ZERO, ONE, ONE)] + #[case::one_zero(ONE, ZERO, ONE)] + #[case::first_larger(ONE, TWO, TWO)] + #[case::second_larger(TWO, ONE, TWO)] + fn test_chebyshev_accumulate( + #[case] rd: FxdU16, + #[case] delta: FxdU16, + #[case] expected: FxdU16, + ) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use crate::fixed::kdtree::KdTree; + use fixed::types::extra::U0; + use fixed::FixedU16; + use rstest::rstest; + + type FxdU16 = FixedU16; + + const ZERO: FxdU16 = FxdU16::ZERO; + const ONE: FxdU16 = FxdU16::lit("1"); + const TWO: FxdU16 = FxdU16::lit("2"); + const THREE: FxdU16 = FxdU16::lit("3"); + const FOUR: FxdU16 = FxdU16::lit("4"); + const FIVE: FxdU16 = FxdU16::lit("5"); + + enum DataScenario { + NoTies, + Ties, + } + + impl DataScenario { + fn get(&self, dim: usize) -> Vec> { + match (self, dim) { + (DataScenario::NoTies, 1) => { + vec![vec![ONE], vec![TWO], vec![THREE], vec![FOUR], vec![FIVE]] + } + (DataScenario::NoTies, 2) => vec![ + vec![ZERO, ZERO], + vec![ONE, ZERO], + vec![TWO, ZERO], + vec![THREE, ZERO], + vec![FOUR, ZERO], + vec![FIVE, ZERO], + ], + (DataScenario::NoTies, 3) => vec![ + vec![ZERO, ZERO, ZERO], + vec![ONE, ZERO, ZERO], + vec![TWO, ZERO, ZERO], + vec![THREE, ZERO, ZERO], + vec![FOUR, ZERO, ZERO], + vec![FIVE, ZERO, ZERO], + ], + (DataScenario::Ties, 1) => vec![ + vec![ZERO], + vec![ONE], + vec![ONE], + vec![TWO], + vec![THREE], + vec![THREE], + ], + (DataScenario::Ties, 2) => vec![ + vec![ZERO, ZERO], + vec![ONE, ZERO], + vec![ZERO, ONE], + vec![TWO, ZERO], + vec![ZERO, TWO], + vec![TWO, TWO], + ], + (DataScenario::Ties, 3) => vec![ + vec![ZERO, ZERO, ZERO], + vec![ONE, ZERO, ZERO], + vec![ZERO, ONE, ZERO], + vec![ZERO, ZERO, ONE], + vec![TWO, ZERO, ZERO], + vec![ZERO, TWO, ZERO], + ], + _ => panic!("Unsupported dimension"), + } + } + } + + fn run_test_helper>(dim: usize, scenario: DataScenario, n: usize) { + let data = scenario.get(dim); + let query_point = &data[0]; + + let mut points: Vec<[FxdU16; 6]> = Vec::with_capacity(data.len()); + for row in &data { + let mut p = [ZERO; 6]; + for (i, &val) in row.iter().enumerate() { + if i < 6 { + p[i] = val; + } + } + points.push(p); + } + + let mut query_arr = [ZERO; 6]; + for (i, &val) in query_point.iter().enumerate() { + if i < 6 { + query_arr[i] = val; + } + } + + let expected: Vec<(usize, FxdU16)> = points + .iter() + .enumerate() + .map(|(i, &point)| { + let dist = D::dist(&query_arr, &point); + (i, dist) + }) + .collect(); + + let expected_distances: Vec = expected.iter().map(|(_, d)| *d).collect(); + + let mut tree: KdTree = KdTree::new(); + for (i, point) in points.iter().enumerate() { + tree.add(point, i as u32); + } + + let results = tree.nearest_n::(&query_arr, n); + + assert_eq!(results[0].item, 0, "First result should be the query point"); + assert_eq!( + results[0].distance, ZERO, + "First result distance should be 0.0" + ); + + for (i, result) in results.iter().enumerate() { + assert_eq!( + result.distance, expected_distances[i], + "Distance at index {} should be {}, but was {}", + i, expected_distances[i], result.distance + ); + } + + if matches!(scenario, DataScenario::NoTies) { + for (i, result) in results.iter().enumerate() { + let expected_id = expected[i].0; + assert_eq!( + result.item, expected_id as u32, + "Result {}: item ID mismatch. Expected {}, got {}", + i, expected_id, result.item + ); + } + } + } + + #[rstest] + fn test_nearest_n_chebyshev( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } + + #[rstest] + fn test_nearest_n_squared_euclidean( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } + + #[rstest] + fn test_nearest_n_manhattan( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } } diff --git a/src/fixed/query/within.rs b/src/fixed/query/within.rs index 01848e88..094211af 100644 --- a/src/fixed/query/within.rs +++ b/src/fixed/query/within.rs @@ -163,7 +163,7 @@ mod tests { for &(p, item) in content { let dist = Manhattan::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/fixed/query/within_unsorted.rs b/src/fixed/query/within_unsorted.rs index 3075cc15..b5a1dcb8 100644 --- a/src/fixed/query/within_unsorted.rs +++ b/src/fixed/query/within_unsorted.rs @@ -167,7 +167,7 @@ mod tests { for &(p, item) in content { let dist = Manhattan::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/fixed/query/within_unsorted_iter.rs b/src/fixed/query/within_unsorted_iter.rs index 47b42f36..de22be63 100644 --- a/src/fixed/query/within_unsorted_iter.rs +++ b/src/fixed/query/within_unsorted_iter.rs @@ -11,8 +11,15 @@ use crate::within_unsorted_iter::WithinUnsortedIter; use crate::generate_within_unsorted_iter; -impl<'a, A: Axis, T: Content, const K: usize, const B: usize, IDX: Index + Send> - KdTree +impl< + 'a, + 'query, + A: Axis, + T: Content, + const K: usize, + const B: usize, + IDX: Index + Send, + > KdTree where usize: Cast, { @@ -171,7 +178,7 @@ mod tests { for &(p, item) in content { let dist = Manhattan::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/float/distance.rs b/src/float/distance.rs index 3784f946..37beb03d 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -40,6 +40,45 @@ impl DistanceMetric for Manhattan { } } +/// Returns the Chebyshev / L-infinity distance between two points. +/// +/// Chebyshev distance is the maximum absolute difference along any axis. +/// Also known as chessboard distance or L-infinity norm. +/// +/// re-exported as `kiddo::Chebyshev` for convenience +/// +/// # Examples +/// +/// ```rust +/// use kiddo::traits::DistanceMetric; +/// use kiddo::Chebyshev; +/// +/// assert_eq!(0f32, Chebyshev::dist(&[0f32, 0f32], &[0f32, 0f32])); +/// assert_eq!(1f32, Chebyshev::dist(&[0f32, 0f32], &[1f32, 0f32])); +/// assert_eq!(1f32, Chebyshev::dist(&[0f32, 0f32], &[1f32, 1f32])); +/// ``` +pub struct Chebyshev {} + +impl DistanceMetric for Chebyshev { + #[inline] + fn dist(a: &[A; K], b: &[A; K]) -> A { + a.iter() + .zip(b.iter()) + .map(|(&a_val, &b_val)| (a_val - b_val).abs()) + .fold(A::zero(), |acc, val| acc.max(val)) + } + + #[inline] + fn dist1(a: A, b: A) -> A { + (a - b).abs() + } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.max(delta) + } +} + /// Returns the squared euclidean distance between two points. /// /// Faster than Euclidean distance due to not needing a square root, but still @@ -73,3 +112,1007 @@ impl DistanceMetric for SquaredEuclidean { (a - b) * (a - b) } } + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + mod common_metric_tests { + use super::*; + + #[rstest] + #[case::zeros_1d([0.0f32], [0.0f32])] + #[case::normal_1d([1.0f32], [2.0f32])] + #[case::neg_1d([-1.0f32], [1.0f32])] + #[case::zeros_2d([0.0f32, 0.0f32], [0.0f32, 0.0f32])] + #[case::normal_2d([1.0f32, 2.0f32], [3.0f32, 4.0f32])] + #[case::large_2d([1e30f32, 1e30f32], [-1e30f32, -1e30f32])] + #[case::zeros_3d([0.0f32, 0.0f32, 0.0f32], [0.0f32, 0.0f32, 0.0f32])] + #[case::normal_3d([1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32])] + #[case::zeros_4d([0.0f32; 4], [0.0f32; 4])] + #[case::normal_4d([1.0f32; 4], [2.0f32; 4])] + #[case::zeros_5d([0.0f32; 5], [0.0f32; 5])] + #[case::normal_5d([1.0f32; 5], [2.0f32; 5])] + fn test_metric_non_negativity>( + #[values(Manhattan {}, SquaredEuclidean {}, Chebyshev {})] _metric: D, + #[case] a: [A; K], + #[case] b: [A; K], + ) { + let distance = D::dist(&a, &b); + assert!(distance >= A::zero()); + } + + #[rstest] + #[case::zeros_1d([0.0f32])] + #[case::normal_1d([1.0f32])] + #[case::zeros_2d([0.0f32, 0.0f32])] + #[case::normal_2d([1.0f32, 2.0f32])] + #[case::zeros_3d([0.0f32, 0.0f32, 0.0f32])] + #[case::normal_3d([1.0f32, 2.0f32, 3.0f32])] + #[case::zeros_4d([0.0f32; 4])] + #[case::zeros_5d([0.0f32; 5])] + fn test_metric_identity>( + #[values(Manhattan {}, SquaredEuclidean {}, Chebyshev {})] _metric: D, + #[case] a: [A; K], + ) { + assert_eq!(D::dist(&a, &a), A::zero()); + } + + #[rstest] + #[case::normal_1d([1.0f64], [2.0f64])] + #[case::neg_1d([-1.0f64], [1.0f64])] + #[case::normal_2d([1.0f64, 2.0f64], [3.0f64, 4.0f64])] + #[case::normal_3d([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64])] + #[case::normal_4d([1.0f64; 4], [2.0f64; 4])] + #[case::normal_5d([1.0f64; 5], [2.0f64; 5])] + fn test_metric_symmetry>( + #[values(Manhattan {}, SquaredEuclidean {}, Chebyshev {})] _metric: D, + #[case] a: [A; K], + #[case] b: [A; K], + ) { + assert_eq!(D::dist(&a, &b), D::dist(&b, &a)); + } + } + + mod manhattan_tests { + use super::*; + + #[rstest] + #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] // identical points + #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] // single axis difference + #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] // single axis difference (other axis) + #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)] // diagonal + #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 4.0f32)] // negative to positive + #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 4.0f32)] // fractional values + fn test_manhattan_distance_2d( + #[case] a: [f32; 2], + #[case] b: [f32; 2], + #[case] expected: f32, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] // identical points 3D + #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 3.0f64], 6.0f64)] // 3D diagonal + #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 9.0f64)] // 3D offset + fn test_manhattan_distance_3d( + #[case] a: [f64; 3], + #[case] b: [f64; 3], + #[case] expected: f64, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f32], [0.0f32], 0.0f32)] // 1D identical + #[case([0.0f32], [5.0f32], 5.0f32)] // 1D positive + #[case([5.0f32], [0.0f32], 5.0f32)] // 1D negative (reversed) + #[case([-3.0f32], [7.0f32], 10.0f32)] // 1D negative to positive + fn test_manhattan_distance_1d( + #[case] a: [f32; 1], + #[case] b: [f32; 1], + #[case] expected: f32, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[test] + fn test_manhattan_distance_4d() { + let a = [1.0f32, 2.0f32, 3.0f32, 4.0f32]; + let b = [5.0f32, 6.0f32, 7.0f32, 8.0f32]; + let expected = 16.0f32; // |5-1| + |6-2| + |7-3| + |8-4| = 4+4+4+4 = 16 + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[test] + fn test_manhattan_distance_5d() { + let a = [0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64]; + let b = [5.0f64, 6.0f64, 7.0f64, 8.0f64, 9.0f64]; + let expected = 25.0f64; // |5-0| + |6-1| + |7-2| + |8-3| + |9-4| = 5+5+5+5+5 = 25 + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[test] + fn test_manhattan_dist1() { + assert_eq!( + >::dist1(0.0f32, 0.0f32), + 0.0f32 + ); // zero difference + assert_eq!( + >::dist1(1.0f32, 0.0f32), + 1.0f32 + ); // positive difference + assert_eq!( + >::dist1(0.0f32, 1.0f32), + 1.0f32 + ); // negative difference (reversed) + assert_eq!( + >::dist1(-2.5f32, 3.5f32), + 6.0f32 + ); // fractional negative to positive + assert_eq!( + >::dist1(1000.0f32, -1000.0f32), + 2000.0f32 + ); // large values + } + } + + mod squared_euclidean_tests { + use super::*; + + #[rstest] + #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] // identical points + #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] // single axis difference + #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] // single axis difference (other axis) + #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)] // diagonal (1^2 + 1^2) + #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 8.0f32)] // negative to positive (2^2 + 2^2) + #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 8.0f32)] // fractional values (2^2 + 2^2) + #[case([0.0f32, 0.0f32], [3.0f32, 4.0f32], 25.0f32)] // 3-4-5 triangle + fn test_squared_euclidean_distance_2d( + #[case] a: [f32; 2], + #[case] b: [f32; 2], + #[case] expected: f32, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] // identical points 3D + #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 2.0f64], 9.0f64)] // 3D (1^2 + 2^2 + 2^2) + #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 27.0f64)] // 3D offset (3^2 + 3^2 + 3^2) + fn test_squared_euclidean_distance_3d( + #[case] a: [f64; 3], + #[case] b: [f64; 3], + #[case] expected: f64, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f32], [0.0f32], 0.0f32)] // 1D identical + #[case([0.0f32], [5.0f32], 25.0f32)] // 1D positive (5^2) + #[case([5.0f32], [0.0f32], 25.0f32)] // 1D negative (reversed) + #[case([-3.0f32], [7.0f32], 100.0f32)] // 1D negative to positive (10^2) + fn test_squared_euclidean_distance_1d( + #[case] a: [f32; 1], + #[case] b: [f32; 1], + #[case] expected: f32, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[test] + fn test_squared_euclidean_dist1() { + assert_eq!( + >::dist1(0.0f32, 0.0f32), + 0.0f32 + ); // zero difference + assert_eq!( + >::dist1(1.0f32, 0.0f32), + 1.0f32 + ); // positive difference + assert_eq!( + >::dist1(0.0f32, 1.0f32), + 1.0f32 + ); // negative difference (reversed) + assert_eq!( + >::dist1(-2.5f32, 3.5f32), + 36.0f32 + ); // fractional negative to positive (6^2) + assert_eq!( + >::dist1(10.0f32, -10.0f32), + 400.0f32 + ); // large values (20^2) + } + + #[test] + fn test_squared_euclidean_triangle_inequality_property() { + // Test that squared Euclidean distance preserves ordering + let a = [0.0f32, 0.0f32]; + let b = [1.0f32, 0.0f32]; + let c = [1.0f32, 1.0f32]; + + let dist_ab = SquaredEuclidean::dist(&a, &b); + let dist_ac = SquaredEuclidean::dist(&a, &c); + let dist_bc = SquaredEuclidean::dist(&b, &c); + + // For these points: dist(a,b) = 1, dist(b,c) = 1, dist(a,c) = 2 + assert_eq!(dist_ab, 1.0f32); + assert_eq!(dist_bc, 1.0f32); + assert_eq!(dist_ac, 2.0f32); + } + } + + mod chebyshev_tests { + use super::*; + + #[rstest] + #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] // identical points + #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] // single axis difference + #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] // single axis difference (other axis) + #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 1.0f32)] // diagonal + #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 2.0f32)] // negative to positive + #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 2.0f32)] // fractional values + #[case([0.0f32, 0.0f32], [2.0f32, 1.0f32], 2.0f32)] // max on first axis + #[case([0.0f32, 0.0f32], [1.0f32, 2.0f32], 2.0f32)] // max on second axis + fn test_chebyshev_distance_2d( + #[case] a: [f32; 2], + #[case] b: [f32; 2], + #[case] expected: f32, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] // identical points 3D + #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 3.0f64], 3.0f64)] // 3D diagonal (max is 3) + #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 3.0f64)] // 3D offset (max is 3) + fn test_chebyshev_distance_3d( + #[case] a: [f64; 3], + #[case] b: [f64; 3], + #[case] expected: f64, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f32], [0.0f32], 0.0f32)] // 1D identical + #[case([0.0f32], [5.0f32], 5.0f32)] // 1D positive + #[case([5.0f32], [0.0f32], 5.0f32)] // 1D negative (reversed) + #[case([-3.0f32], [7.0f32], 10.0f32)] // 1D negative to positive + fn test_chebyshev_distance_1d( + #[case] a: [f32; 1], + #[case] b: [f32; 1], + #[case] expected: f32, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[test] + fn test_chebyshev_distance_4d() { + let a = [1.0f32, 2.0f32, 3.0f32, 4.0f32]; + let b = [5.0f32, 6.0f32, 7.0f32, 8.0f32]; + let expected = 4.0f32; // max(|5-1|, |6-2|, |7-3|, |8-4|) = max(4, 4, 4, 4) = 4 + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[test] + fn test_chebyshev_distance_5d() { + let a = [0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64]; + let b = [5.0f64, 6.0f64, 7.0f64, 8.0f64, 9.0f64]; + let expected = 5.0f64; // max(|5-0|, |6-1|, |7-2|, |8-3|, |9-4|) = max(5, 5, 5, 5, 5) = 5 + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case(0.0f32, 0.0f32, 0.0f32)] // zero difference + #[case(1.0f32, 0.0f32, 1.0f32)] // positive difference + #[case(0.0f32, 1.0f32, 1.0f32)] // negative difference (reversed) + #[case(-2.5f32, 3.5f32, 6.0f32)] // fractional negative to positive + #[case(1000.0f32, -1000.0f32, 2000.0f32)] // large values + fn test_chebyshev_dist1(#[case] a: f32, #[case] b: f32, #[case] expected: f32) { + assert_eq!(>::dist1(a, b), expected); + } + + #[test] + fn test_chebyshev_max_property() { + // Test that Chebyshev correctly finds the maximum difference + let a = [0.0, 0.0]; + let b = [3.0, 1.0]; + + let result = Chebyshev::dist(&a, &b); + + // max(|0-3|, |0-1|) = max(3, 1) = 3 + assert_eq!(result, 3.0); + + // Verify it's not Manhattan (which would be 4) or Euclidean (sqrt(10)) + assert_ne!(result, 4.0); + assert_ne!(result, (10.0_f64).sqrt()); + } + } + + #[cfg(feature = "f16")] + mod f16_tests { + use super::*; + use half::f16; + + #[test] + fn test_manhattan_f16() { + let a = [f16::from_f32(0.0), f16::from_f32(0.0)]; + let b = [f16::from_f32(1.0), f16::from_f32(1.0)]; + + let result = Manhattan::dist(&a, &b); + let expected = f16::from_f32(2.0); + + assert_eq!(result, expected); + } + + #[test] + fn test_squared_euclidean_f16() { + let a = [f16::from_f32(0.0), f16::from_f32(0.0)]; + let b = [f16::from_f32(1.0), f16::from_f32(1.0)]; + + let result = SquaredEuclidean::dist(&a, &b); + let expected = f16::from_f32(2.0); + + assert_eq!(result, expected); + } + } + + mod integration_tests { + use super::*; + use crate::KdTree; + use rand::prelude::*; + use rand_distr::Normal; + use rstest::rstest; + + #[derive(Debug, Clone, Copy)] + enum DataScenario { + NoTies, + Ties, + Gaussian, + } + + #[derive(Debug, Clone, Copy)] + enum TreeType { + Mutable, + Immutable, + } + + impl DataScenario { + /// Get data scenario + /// + /// Predefined data has input dimension (`dim`) and either + /// with `DataScenario::NoTies` or `DataScenario::Ties`. + /// + /// # Parameters + /// - `dim`: The dimensionality of the data to retrieve. + /// Must be a value between 1 and 4 (inclusive). + /// + /// # Returns + /// - `Vec>`: A 2D vector where each inner vector represents a data point. + fn get(&self, dim: usize) -> Vec> { + match (self, dim) { + (DataScenario::NoTies, 1) => vec![ + vec![1.0], + vec![2.0], + vec![4.0], + vec![7.0], + vec![-9.0], + vec![16.0], + ], + (DataScenario::NoTies, 2) => vec![ + vec![0.0, 0.0], + vec![1.1, 0.1], + vec![2.3, 0.4], + vec![3.6, 0.9], + vec![5.0, 1.6], + vec![6.5, 2.5], + ], + (DataScenario::NoTies, 3) => vec![ + vec![0.0, 0.0, 0.0], + vec![1.1, 0.1, 0.01], + vec![2.3, 0.4, 0.08], + vec![-3.6, -0.9, -0.27], + vec![5.0, 1.6, 0.64], + vec![6.5, 2.5, 1.25], + ], + (DataScenario::NoTies, 4) => vec![ + vec![0.0, 0.0, 0.0, 1000.0], + vec![1.1, 0.1, 0.01, 1000.001], + vec![2.3, 0.4, 0.08, 1000.008], + vec![3.6, 0.9, 0.27, 1000.027], + vec![5.0, 1.6, 0.64, 1000.256], + vec![6.5, 2.5, 1.25, 1000.625], + ], + (DataScenario::Ties, 1) => vec![ + vec![0.0], + vec![1.0], + vec![1.0], + vec![2.0], + vec![2.0], + vec![3.0], + ], + (DataScenario::Ties, 2) => vec![ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![0.0, 1.0], + vec![-1.0, 0.0], + vec![0.0, -1.0], + vec![1.0, 1.0], + ], + (DataScenario::Ties, 3) => vec![ + vec![0.0, 0.0, 0.0], + vec![1.0, 0.0, 0.0], + vec![0.0, 1.0, 0.0], + vec![0.0, 0.0, 1.0], + vec![-1.0, 0.0, 0.0], + vec![0.0, -1.0, 0.0], + ], + (DataScenario::Ties, 4) => vec![ + vec![0.0, 0.0, 0.0, 0.0], + vec![1.0, 0.0, 0.0, 0.0], + vec![0.0, 1.0, 0.0, 0.0], + vec![0.0, 0.0, 1.0, 0.0], + vec![0.0, 0.0, 0.0, 1.0], + vec![-1.0, 0.0, 0.0, 0.0], + ], + (DataScenario::Gaussian, d) => { + let mut rng = StdRng::seed_from_u64(8757); + let normal = Normal::new(1.0, 10.0).unwrap(); + let n_samples = 200; + let mut data = vec![vec![0.0; d]; n_samples]; + for i in 0..n_samples { + for j in 0..d { + data[i][j] = normal.sample(&mut rng); + } + } + data + } + _ => panic!("Unsupported dimension {} for scenario {:?}", dim, self), + } + } + } + + /// Helper function to test nearest_n queries for `D: DistanceMetric` + /// + /// Tests KD-tree Chebyshev distance queries across different tree types and + /// data scenarios. This simplifies testing across different combinations. + /// + /// # What this function does + /// 1. Get test data points based on a scenario (NoTies/Ties) and dimensionality + /// 2. Builds either MutableKdTree (incremental) or ImmutableKdTree (bulk construction) + /// 3. Performs nearest_n query with Chebyshev distance from point 0 + /// 4. Compares results against Brute-force distances, + /// calculated from `>::dist`. + /// + /// # Choices + /// - Fixed-size array `[f64; 6]`. For `dim<6` a subspace/padding is used for practicality + /// + /// # Assertions + /// - Point 0 is always the query point (distance 0, index 0 expected first result) + /// - NoTies scenario: checks distances and item IDs for points with unique distances + /// - Ties scenario: checks distances (order among ties is non-deterministic) + fn run_test_helper>( + dim: usize, + tree_type: TreeType, + scenario: DataScenario, + n: usize, + ) { + let data = scenario.get(dim); + let query_point = &data[0]; + + let mut points: Vec<[f64; 6]> = Vec::with_capacity(data.len()); + for row in &data { + let mut p = [0.0; 6]; + for (i, &val) in row.iter().enumerate() { + p[i] = val; + } + points.push(p); + } + + let mut query_arr = [0.0; 6]; + for (i, &val) in query_point.iter().enumerate() { + if i < 6 { + query_arr[i] = val; + } + } + + // Calculate ground truth with brute-force approach + let mut expected: Vec<(usize, f64)> = points + .iter() + .enumerate() + .map(|(i, &point)| { + let dist = D::dist(&query_arr, &point); + (i, dist) + }) + .collect(); + + expected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + + let expected_distances: Vec = expected.iter().map(|(_, d)| *d).collect(); + + println!( + "Query: {:?}, TreeType: {:?}, Scenario: {:?}, dim={}, n={}", + query_point, tree_type, scenario, dim, n + ); + + // Query based on tree type + let results = match tree_type { + TreeType::Mutable => { + let mut tree: crate::float::kdtree::KdTree = + crate::float::kdtree::KdTree::new(); + for (i, point) in points.iter().enumerate() { + tree.add(point, i as u64); + } + tree.nearest_n::(&query_arr, n) + } + TreeType::Immutable => { + let tree: crate::immutable::float::kdtree::ImmutableKdTree = + crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points); + tree.nearest_n::(&query_arr, std::num::NonZero::new(n).unwrap()) + } + }; + + println!("Results (len: {}):", results.len()); + + assert_eq!(results[0].item, 0, "First result should be the query point"); + assert_eq!( + results[0].distance, 0.0, + "First result distance should be 0.0" + ); + + for (i, result) in results.iter().enumerate() { + assert_eq!( + result.distance, expected_distances[i], + "Distance at index {} should be {}, but was {}", + i, expected_distances[i], result.distance + ); + } + + if matches!(scenario, DataScenario::NoTies) { + for (i, result) in results.iter().enumerate() { + let expected_id = expected[i].0; + assert_eq!( + result.item, expected_id as u64, + "Result {}: item ID mismatch. Expected {}, got {}", + i, expected_id, result.item + ); + } + } + } + + /// Chebyshev distance nearest-neighbor query tests. + /// + /// Test matrix covering all combinations of mutable/immutable trees, + /// data scenarios (with/out ties), dimensions, and neighbor query counts. + /// + /// Currently passing tests: + /// - All MutableKdTree tests pass + /// - ImmutableKdTree with NoTies: + /// - Pass for when just querying the root n=1 or dim=1 + /// - ImmutableKdTree with Ties: Several pass (one edge case failure for n=6, dim=2) + /// + /// Currently failing tests (16 of 96): + /// - ImmutableKdTree + NoTies: fails for dim>=2 AND n>=2 (15 failures) + /// - ImmutableKdTree + Ties: 1 failure (n=6, dim=2) + #[rstest] + fn test_nearest_n_chebyshev( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3, 4)] dim: usize, + ) { + run_test_helper::(dim, tree_type, scenario, n); + } + + #[rstest] + fn test_nearest_n_squared_euclidean( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3, 4)] dim: usize, + ) { + run_test_helper::(dim, tree_type, scenario, n); + } + + #[rstest] + fn test_nearest_n_manhattan( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3, 4)] dim: usize, + ) { + run_test_helper::(dim, tree_type, scenario, n); + } + + #[test] + fn test_nearest_n_manhattan_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points in a simple pattern + let points = [ + ([0.0f32, 0.0f32], 0), // distance 0 from query point + ([1.0f32, 0.0f32], 1), // distance 1 from query point + ([0.0f32, 1.0f32], 2), // distance 1 from query point + ([2.0f32, 0.0f32], 3), // distance 2 from query point + ([0.0f32, 2.0f32], 4), // distance 2 from query point + ([3.0f32, 3.0f32], 5), // distance 6 from query point + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + let results = kdtree.nearest_n::(&query_point, 4); + + // Expected order: [0], [1], [2], [3], [4] + // Distances: 0, 1, 1, 2, 2 + // But we only ask for 4 nearest + assert_eq!(results.len(), 4); + + // First result should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next two should be the points at Manhattan distance 1 + assert_eq!(results[1].item, 1); + assert_eq!(results[1].distance, 1.0); + assert_eq!(results[2].item, 2); + assert_eq!(results[2].distance, 1.0); + + // Fourth should be one of the points at distance 2 + assert!(results[3].item == 3 || results[3].item == 4); + assert_eq!(results[3].distance, 2.0); + } + + #[test] + fn test_nearest_n_squared_euclidean_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points in a pattern where Euclidean and Manhattan differ + let points = [ + ([0.0, 0.0], 0), // distance 0 from query point + ([1.0, 0.0], 1), // Euclidean: 1, Manhattan: 1 + ([0.0, 1.0], 2), // Euclidean: 1, Manhattan: 1 + ([1.0, 1.0], 3), // Euclidean: 2, Manhattan: 2 + ([2.0, 0.0], 4), // Euclidean: 4, Manhattan: 2 + ([0.0, 2.0], 5), // Euclidean: 4, Manhattan: 2 + ([3.0, 4.0], 6), // Euclidean: 25, Manhattan: 7 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0, 0.0]; + let results = kdtree.nearest_n::(&query_point, 5); + + assert_eq!(results.len(), 5); + + // First should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next two should be the points at Euclidean distance 1 + assert_eq!(results[1].item, 1); + assert_eq!(results[1].distance, 1.0); + assert_eq!(results[2].item, 2); + assert_eq!(results[2].distance, 1.0); + + // Next two should be the points at Euclidean distance 2 + assert_eq!(results[3].item, 3); + assert_eq!(results[3].distance, 2.0); + assert_eq!(results[4].item, 4); + assert_eq!(results[4].distance, 4.0); + + // Verify that points at squared Euclidean distance 4 are indeed farther + // than points at squared Euclidean distance 2 + assert!(results[4].distance > results[3].distance); + } + + #[test] + fn test_nearest_n_different_metrics_produce_different_orderings() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points where Manhattan and Euclidean give different orderings + let points = [ + ([0.0, 0.0], 0), // origin + ([2.0, 1.0], 1), // Manhattan: 3, Euclidean^2: 5 + ([1.0, 2.0], 2), // Manhattan: 3, Euclidean^2: 5 + ([3.0, 0.0], 3), // Manhattan: 3, Euclidean^2: 9 + ([0.0, 3.0], 4), // Manhattan: 3, Euclidean^2: 9 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0, 0.0]; + + let manhattan_results = kdtree.nearest_n::(&query_point, 3); + let euclidean_results = kdtree.nearest_n::(&query_point, 3); + + // Both should include the origin as first result + assert_eq!(manhattan_results[0].item, 0); + assert_eq!(euclidean_results[0].item, 0); + + // For Manhattan: points 1, 2, 3, 4 are all at distance 3 + // The ordering among ties depends on tree structure, but they should all have same distance + assert_eq!(manhattan_results[1].distance, 3.0); + assert_eq!(manhattan_results[2].distance, 3.0); + + // For Euclidean: points 1 and 2 are at distance sqrt(5) ā‰ˆ 2.236 (squared: 5) + // Points 3 and 4 are at distance 3 (squared: 9) + assert_eq!(euclidean_results[1].distance, 5.0); + assert_eq!(euclidean_results[2].distance, 5.0); + + // Verify that Euclidean ordering puts points 1 and 2 before 3 and 4 + let euclidean_items: Vec = euclidean_results + .iter() + .skip(1) // skip origin + .take(2) // take next 2 + .map(|nn| nn.item) + .collect(); + + assert!(euclidean_items.contains(&1) || euclidean_items.contains(&2)); + + // Calculate actual distances to verify our understanding + let p1 = [2.0, 1.0]; + let p2 = [1.0, 2.0]; + let p3 = [3.0, 0.0]; + + let manhattan_p1 = Manhattan::dist(&query_point, &p1); + let manhattan_p2 = Manhattan::dist(&query_point, &p2); + let manhattan_p3 = Manhattan::dist(&query_point, &p3); + + let euclidean_p1 = SquaredEuclidean::dist(&query_point, &p1); + let euclidean_p2 = SquaredEuclidean::dist(&query_point, &p2); + let euclidean_p3 = SquaredEuclidean::dist(&query_point, &p3); + + assert_eq!(manhattan_p1, 3.0); + assert_eq!(manhattan_p2, 3.0); + assert_eq!(manhattan_p3, 3.0); + + assert_eq!(euclidean_p1, 5.0); + assert_eq!(euclidean_p2, 5.0); + assert_eq!(euclidean_p3, 9.0); + } + + #[test] + fn test_nearest_n_3d_different_metrics() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points in 3D space + let points = [ + ([1.0, 1.0, 1.0], 0), // origin + ([2.0, 1.0, 1.0], 1), // 1 unit away on x-axis + ([1.0, 2.0, 1.0], 2), // 1 unit away on y-axis + ([1.0, 1.0, 2.0], 3), // 1 unit away on z-axis + ([3.0, 1.0, 1.0], 4), // 2 units away on x-axis + ([0.0, 0.0, 0.0], 5), // sqrt(3) ā‰ˆ 1.732 from origin + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [1.0, 1.0, 1.0]; + let results = kdtree.nearest_n::(&query_point, 4); + + assert_eq!(results.len(), 4); + + // First should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next three should be the points at Manhattan distance 1 + let nearby_items: Vec = results + .iter() + .skip(1) // skip origin + .take(3) // take next 3 + .map(|nn| nn.item) + .collect(); + + assert!(nearby_items.contains(&1)); + assert!(nearby_items.contains(&2)); + assert!(nearby_items.contains(&3)); + + // All nearby points should have distance 1 + for result in results.iter().skip(1).take(3) { + assert_eq!(result.distance, 1.0); + } + + // Point 4 should be farther (distance 2) and not in top 4 + let all_items: Vec = results.iter().map(|nn| nn.item).collect(); + assert!(!all_items.contains(&4)); + + // Point 5 has Manhattan distance 3, so definitely not in top 4 + assert!(!all_items.contains(&5)); + } + + #[test] + fn test_nearest_n_large_scale() { + let mut kdtree: KdTree = KdTree::new(); + + // Create a grid of points + let mut index = 0; + for x in 0i32..10 { + for y in 0i32..10 { + let point = [x as f32, y as f32]; + kdtree.add(&point, index); + index += 1; + } + } + + // Query from center of grid + let query_point = [5.0f32, 5.0f32]; + let results = kdtree.nearest_n::(&query_point, 10); + + assert_eq!(results.len(), 10); + + // First result should be the center point itself (index 55) + assert_eq!(results[0].item, 55); + assert_eq!(results[0].distance, 0.0); + + // Results should be ordered by increasing distance + for i in 1..10 { + assert!(results[i].distance >= results[i - 1].distance); + } + + // Verify distances make sense for a grid + // The nearest points should be at squared distances: 0, 1, 1, 1, 1, 2, 2, 4, 4, 5... + let expected_distances = [0.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32]; + + for (i, &expected_dist) in expected_distances.iter().enumerate() { + if i < results.len() { + assert_eq!(results[i].distance, expected_dist); + } + } + } + + #[test] + fn test_nearest_n_chebyshev_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points that show Chebyshev behavior + let points = [ + ([0.0f32, 0.0f32], 0), // distance 0 from query point + ([1.0f32, 0.0f32], 1), // Chebyshev: 1, Manhattan: 1, Euclidean^2: 1 + ([0.0f32, 1.0f32], 2), // Chebyshev: 1, Manhattan: 1, Euclidean^2: 1 + ([2.0f32, 0.0f32], 3), // Chebyshev: 2, Manhattan: 2, Euclidean^2: 4 + ([0.0f32, 2.0f32], 4), // Chebyshev: 2, Manhattan: 2, Euclidean^2: 4 + ([1.0f32, 1.0f32], 5), // Chebyshev: 1, Manhattan: 2, Euclidean^2: 2 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + let results = kdtree.nearest_n::(&query_point, 5); + + // With Chebyshev, points at (1,0), (0,1), and (1,1) all have distance 1 + // Points at (2,0) and (0,2) have distance 2 + assert_eq!(results.len(), 5); + + // First should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next should all be at Chebyshev distance 1 + let nearby_items: Vec = results + .iter() + .skip(1) // skip origin + .take(4) // take next 4 + .filter(|r| (r.distance - 1.0).abs() < 0.001) // check for distance 1 (with some float tolerance) + .map(|nn| nn.item) + .collect(); + + // All of these should be in the results: 1, 2, 5 + assert!(nearby_items.contains(&1)); + assert!(nearby_items.contains(&2)); + assert!(nearby_items.contains(&5)); + } + + #[test] + fn test_within_chebyshev_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points with varying Chebyshev distances + let points = [ + ([0.0f32, 0.0f32], 0), // distance 0 + ([0.5f32, 0.5f32], 1), // Chebyshev: 0.5 + ([1.0f32, 0.0f32], 2), // Chebyshev: 1.0 + ([0.8f32, 0.9f32], 3), // Chebyshev: 0.9 + ([2.0f32, 0.0f32], 4), // Chebyshev: 2.0 + ([0.0f32, 2.0f32], 5), // Chebyshev: 2.0 + ([1.5f32, 1.5f32], 6), // Chebyshev: 1.5 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + let radius = 1.0; // radius 1 (not squared for Chebyshev) + let mut results = kdtree.within::(&query_point, radius); + + // Sort by distance for easier verification + results.sort_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // These SHOULD be: 0, 1, 2, 3 (distances: 0, 0.5, 1.0, 0.9) + // For `<=5.2.4` found: 0, 1, 3 (index 2 is missing due to dist1 pruning issue) + let found_indices: Vec = results.iter().map(|r| r.item).collect(); + + assert!(found_indices.contains(&0)); + assert!(found_indices.contains(&1)); + assert!(found_indices.contains(&2)); + assert!(found_indices.contains(&3)); + // Should NOT include points with Chebyshev distance > 1 + assert!(!found_indices.contains(&4)); + assert!(!found_indices.contains(&5)); + assert!(!found_indices.contains(&6)); + + // Verify distances + for result in results { + assert!(result.distance <= 1.0 || (result.distance - 1.0).abs() < 0.001); + } + } + + #[test] + fn test_chebyshev_vs_manhattan_ordering() { + let mut kdtree: KdTree = KdTree::new(); + + // Points where Chebyshev and Manhattan differ significantly + let points = [ + ([0.0f32, 0.0f32], 0), // origin + ([3.0f32, 1.0f32], 1), // Chebyshev: 3, Manhattan: 4 + ([1.0f32, 3.0f32], 2), // Chebyshev: 3, Manhattan: 4 + ([2.0f32, 2.0f32], 3), // Chebyshev: 2, Manhattan: 4 + ([4.0f32, 0.5f32], 4), // Chebyshev: 4, Manhattan: 4.5 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + + let chebyshev_results = kdtree.nearest_n::(&query_point, 4); + let manhattan_results = kdtree.nearest_n::(&query_point, 4); + + // Both should include the origin first + assert_eq!(chebyshev_results[0].item, 0); + assert_eq!(manhattan_results[0].item, 0); + + // With Chebyshev, nearest should be point 3 (distance 2) + // With Manhattan, nearest should be points 1 and 2 (distance 4) + assert_eq!(chebyshev_results[1].item, 3); + assert_eq!(chebyshev_results[1].distance, 2.0); + + // With Manhattan, points 1 and 2 should come before point 3 (which is distance 4) + let manhattan_items: Vec = manhattan_results + .iter() + .skip(1) + .take(3) + .map(|r| r.item) + .collect(); + assert!(manhattan_items.contains(&1) || manhattan_items.contains(&2)); + + // Verify the distance calculations are correct + assert_eq!(chebyshev_results[1].distance, 2.0); // Chebyshev: max(|2-0|, |2-0|) = 2 + assert_eq!(manhattan_results[1].distance, 4.0); // Manhattan: |3-0| + |1-0| = 4 + } + } +} diff --git a/src/float/kdtree.rs b/src/float/kdtree.rs index 4757cb3c..0ea74d6e 100644 --- a/src/float/kdtree.rs +++ b/src/float/kdtree.rs @@ -63,6 +63,7 @@ pub trait Axis: FloatCore + Default + Debug + Copy + Sync + Send + std::ops::Add fn saturating_dist(self, other: Self) -> Self; /// Used in query methods to update the rd value. A saturating add for Fixed and an add for Float + #[deprecated(since = "5.3.0", note = "Use D::accumulate instead")] // TODO: change version number if adding this change - or better so: fully get rid off rd_update fn rd_update(rd: Self, delta: Self) -> Self; } @@ -73,6 +74,7 @@ impl #[inline] fn rd_update(rd: Self, delta: Self) -> Self { + // DEPRECATED: Use D::accumulate instead rd + delta } } diff --git a/src/float/query/nearest_n_within.rs b/src/float/query/nearest_n_within.rs index 6a2c70c9..6fc78994 100644 --- a/src/float/query/nearest_n_within.rs +++ b/src/float/query/nearest_n_within.rs @@ -353,7 +353,7 @@ mod tests { for &(p, item) in content { let dist = SquaredEuclidean::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/float/query/within_unsorted.rs b/src/float/query/within_unsorted.rs index 5d0bab97..226270b9 100644 --- a/src/float/query/within_unsorted.rs +++ b/src/float/query/within_unsorted.rs @@ -209,7 +209,7 @@ mod tests { for &(p, item) in content { let dist = SquaredEuclidean::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/float/query/within_unsorted_iter.rs b/src/float/query/within_unsorted_iter.rs index 023e889c..f20df730 100644 --- a/src/float/query/within_unsorted_iter.rs +++ b/src/float/query/within_unsorted_iter.rs @@ -36,8 +36,15 @@ assert_eq!(within.len(), 2); }; } -impl<'a, A: Axis, T: Content, const K: usize, const B: usize, IDX: Index + Send> - KdTree +impl< + 'a, + 'query, + A: Axis, + T: Content, + const K: usize, + const B: usize, + IDX: Index + Send, + > KdTree where usize: Cast, { @@ -54,6 +61,7 @@ use crate::float::kdtree::ArchivedKdTree; #[cfg(feature = "rkyv")] impl< 'a, + 'query, A: Axis + rkyv::Archive, T: Content + rkyv::Archive, const K: usize, @@ -77,6 +85,7 @@ use crate::float::kdtree::ArchivedR8KdTree; #[cfg(feature = "rkyv_08")] impl< 'a, + 'query, A: Axis + Send + rkyv_08::Archive, T: Content + Send + rkyv_08::Archive, const K: usize, @@ -145,10 +154,25 @@ mod tests { let radius = 0.2; let expected = linear_search(&content_to_add, &query_point, radius); - let result: Vec<_> = tree - .within_unsorted_iter::(&query_point, radius) - .collect(); - assert_eq!(result, expected); + // Store some iterators in a way that the test will fail to compile + // if the lifetime of the iterator is tied to the query as well as to + // the lifetime of the tree + let mut iterators = Vec::new(); + for _ in 0..2 { + // take a copy of query_point to ensure that the lifetime of the + // iterator is tied to the lifetime of the tree and not the lifetime + // of the query + let temp_query = query_point; + + let iter = tree.within_unsorted_iter::(&temp_query, radius); + + iterators.push(iter); + } + + for iter in iterators { + let result: Vec<_> = iter.collect(); + assert_eq!(result, expected); + } let mut rng = rand::rng(); for _i in 0..1000 { diff --git a/src/float_leaf_slice/leaf_slice.rs b/src/float_leaf_slice/leaf_slice.rs index be453722..4da31f83 100644 --- a/src/float_leaf_slice/leaf_slice.rs +++ b/src/float_leaf_slice/leaf_slice.rs @@ -237,13 +237,13 @@ where for idx in 0..remainder_items.len() { let mut distance = A::zero(); (0..K).step_by(1).for_each(|dim| { - distance += D::dist1(remainder_points[dim][idx], query[dim]); + distance = + D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); }); - - if distance < radius { + if distance <= radius { results.add(NearestNeighbour { distance, - item: *unsafe { self.content_items.get_unchecked(idx) }, + item: remainder_items[idx], }); } } @@ -271,11 +271,12 @@ where for idx in 0..remainder_items.len() { let mut distance = A::zero(); (0..K).step_by(1).for_each(|dim| { - distance += D::dist1(remainder_points[dim][idx], query[dim]); + distance = + D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); }); - if distance < radius { - let item = *unsafe { remainder_items.get_unchecked(idx) }; + if distance <= radius { + let item = remainder_items[idx]; if results.len() < max_qty { results.push(BestNeighbour { distance, item }); } else { @@ -364,16 +365,13 @@ where D: DistanceMetric, Self: Sized, { - // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration let mut acc = [0f64; C]; (0..K).step_by(1).for_each(|dim| { let qd = [query[dim]; C]; - (0..C).step_by(1).for_each(|idx| { - acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); + acc[idx] = D::accumulate(acc[idx], D::dist1(chunk[dim][idx], qd[idx])); }); }); - acc } } @@ -450,23 +448,20 @@ where D: DistanceMetric, Self: Sized, { - // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration let mut acc = [0f32; C]; (0..K).step_by(1).for_each(|dim| { let qd = [query[dim]; C]; - (0..C).step_by(1).for_each(|idx| { - acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); + acc[idx] = D::accumulate(acc[idx], D::dist1(chunk[dim][idx], qd[idx])); }); }); - acc } } #[cfg(test)] mod test { - use crate::float_leaf_slice::leaf_slice::{LeafFixedSlice, LeafSliceFloat}; + use crate::float_leaf_slice::leaf_slice::{LeafFixedSlice, LeafSlice, LeafSliceFloat}; use crate::{BestNeighbour, NearestNeighbour, SquaredEuclidean}; use std::collections::BinaryHeap; @@ -624,4 +619,35 @@ mod test { ] ); } + + // Test for remainder path processing with non-chunk-aligned sizes (CHUNK_SIZE=32) + // Verifies the fix for using remainder_items[idx] instead of self.content_items[idx] + #[test] + fn test_remainder_processing_finds_correct_item() { + // Size 33 = 1 chunk (32) + 1 remainder + // Item 32 is in the remainder region - if the bug existed, it would return + // self.content_items[0] instead of remainder_items[0] + let mut dim0 = Vec::with_capacity(33); + let mut dim1 = Vec::with_capacity(33); + let mut items = Vec::with_capacity(33); + for i in 0..33 { + dim0.push(i as f64); + dim1.push(0.0f64); + items.push(i as u32); + } + + let slice = LeafSlice { + content_points: [&dim0[..], &dim1[..]], + content_items: &items[..], + }; + + let mut results: BinaryHeap> = BinaryHeap::with_capacity(10); + slice.nearest_n_within::(&[32.0f64, 0.0f64], 4.0f64, &mut results); + + let items_found: Vec<_> = results.iter().map(|n| n.item).collect(); + assert!( + items_found.contains(&32u32), + "Should find item 32 in remainder region" + ); + } } diff --git a/src/hybrid/query/nearest_n.rs b/src/hybrid/query/nearest_n.rs index f6e24fda..9d28118a 100644 --- a/src/hybrid/query/nearest_n.rs +++ b/src/hybrid/query/nearest_n.rs @@ -134,6 +134,7 @@ where } } + #[inline(always)] fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>) -> bool { heap.is_empty() || dist < heap.peek().unwrap().distance || heap.len() < heap.capacity() } diff --git a/src/hybrid/query/within.rs b/src/hybrid/query/within.rs index be5b956b..76505046 100644 --- a/src/hybrid/query/within.rs +++ b/src/hybrid/query/within.rs @@ -254,7 +254,7 @@ mod tests { for &(p, item) in content { let dist = manhattan(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/hybrid/query/within_unsorted.rs b/src/hybrid/query/within_unsorted.rs index 683c58fe..fb93473f 100644 --- a/src/hybrid/query/within_unsorted.rs +++ b/src/hybrid/query/within_unsorted.rs @@ -256,7 +256,7 @@ mod tests { for &(p, item) in content { let dist = squared_euclidean(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/immutable/common/generate_immutable_approx_nearest_one.rs b/src/immutable/common/generate_immutable_approx_nearest_one.rs index 56ded4c8..85d34659 100644 --- a/src/immutable/common/generate_immutable_approx_nearest_one.rs +++ b/src/immutable/common/generate_immutable_approx_nearest_one.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_immutable_approx_nearest_one { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn approx_nearest_one(&self, query: &[A; K]) -> NearestNeighbour where @@ -85,6 +84,5 @@ macro_rules! generate_immutable_approx_nearest_one { item: best_item, } } - } }; } diff --git a/src/immutable/common/generate_immutable_best_n_within.rs b/src/immutable/common/generate_immutable_best_n_within.rs index b37dfd78..e5b58371 100644 --- a/src/immutable/common/generate_immutable_best_n_within.rs +++ b/src/immutable/common/generate_immutable_best_n_within.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_immutable_best_n_within { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn best_n_within( &self, @@ -110,7 +109,7 @@ macro_rules! generate_immutable_best_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -191,7 +190,7 @@ macro_rules! generate_immutable_best_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -232,6 +231,5 @@ macro_rules! generate_immutable_best_n_within { results, ); } - } }; } diff --git a/src/immutable/common/generate_immutable_nearest_n.rs b/src/immutable/common/generate_immutable_nearest_n.rs index b0b8f3c7..27b0bd40 100644 --- a/src/immutable/common/generate_immutable_nearest_n.rs +++ b/src/immutable/common/generate_immutable_nearest_n.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_immutable_nearest_n { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn nearest_n(&self, query: &[A; K], max_qty: NonZero) -> Vec> where @@ -13,6 +12,5 @@ macro_rules! generate_immutable_nearest_n { { self.nearest_n_within::(query, A::infinity(), max_qty, true) } - } }; } diff --git a/src/immutable/common/generate_immutable_nearest_n_within.rs b/src/immutable/common/generate_immutable_nearest_n_within.rs index 1237fa8f..3d3d8435 100644 --- a/src/immutable/common/generate_immutable_nearest_n_within.rs +++ b/src/immutable/common/generate_immutable_nearest_n_within.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_immutable_nearest_n_within { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn nearest_n_within(&self, query: &[A; K], dist: A, max_items: NonZero, sorted: bool) -> Vec> where @@ -113,7 +112,7 @@ macro_rules! generate_immutable_nearest_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius && rd < matching_items.max_dist() { off[split_dim] = new_off; @@ -190,7 +189,7 @@ macro_rules! generate_immutable_nearest_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius && rd < matching_items.max_dist() { off[split_dim] = new_off; @@ -229,6 +228,5 @@ macro_rules! generate_immutable_nearest_n_within { results, ); } - } }; } diff --git a/src/immutable/common/generate_immutable_nearest_one.rs b/src/immutable/common/generate_immutable_nearest_one.rs index 16dca00a..e72007ed 100644 --- a/src/immutable/common/generate_immutable_nearest_one.rs +++ b/src/immutable/common/generate_immutable_nearest_one.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_immutable_nearest_one { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn nearest_one(&self, query: &[A; K]) -> NearestNeighbour where @@ -110,7 +109,7 @@ macro_rules! generate_immutable_nearest_one { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= nearest.distance { off[split_dim as usize] = new_off; @@ -178,7 +177,7 @@ macro_rules! generate_immutable_nearest_one { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= nearest.distance { off[split_dim as usize] = new_off; @@ -211,6 +210,5 @@ macro_rules! generate_immutable_nearest_one { &mut nearest.item ); } - } }; } diff --git a/src/immutable/common/generate_immutable_within.rs b/src/immutable/common/generate_immutable_within.rs index b4595ec5..83cc4725 100644 --- a/src/immutable/common/generate_immutable_within.rs +++ b/src/immutable/common/generate_immutable_within.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_immutable_within { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn within(&self, query: &[A; K], dist: A) -> Vec> where @@ -12,6 +11,5 @@ macro_rules! generate_immutable_within { usize: Cast, { self.nearest_n_within::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), true) } - } }; } diff --git a/src/immutable/common/generate_immutable_within_unsorted.rs b/src/immutable/common/generate_immutable_within_unsorted.rs index e602c453..c8910474 100644 --- a/src/immutable/common/generate_immutable_within_unsorted.rs +++ b/src/immutable/common/generate_immutable_within_unsorted.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_immutable_within_unsorted { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn within_unsorted(&self, query: &[A; K], dist: A) -> Vec> where @@ -12,6 +11,5 @@ macro_rules! generate_immutable_within_unsorted { usize: Cast, { self.nearest_n_within::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), false) } - } }; } diff --git a/src/immutable/common/generate_immutable_within_unsorted_iter.rs b/src/immutable/common/generate_immutable_within_unsorted_iter.rs index 5d656cd6..94f673fd 100644 --- a/src/immutable/common/generate_immutable_within_unsorted_iter.rs +++ b/src/immutable/common/generate_immutable_within_unsorted_iter.rs @@ -2,8 +2,7 @@ #[macro_export] macro_rules! generate_immutable_within_unsorted_iter { ($comments:tt) => { - doc_comment! { - concat!$comments, + #[doc = concat!$comments] #[inline] pub fn within_unsorted_iter( &'a self, @@ -84,7 +83,7 @@ macro_rules! generate_immutable_within_unsorted_iter { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -125,6 +124,5 @@ macro_rules! generate_immutable_within_unsorted_iter { gen_scope } - } }; } diff --git a/src/immutable/float/kdtree.rs b/src/immutable/float/kdtree.rs index 3f468e7a..0328d2ed 100644 --- a/src/immutable/float/kdtree.rs +++ b/src/immutable/float/kdtree.rs @@ -492,11 +492,13 @@ where let chunk_length = sort_index.len(); if level > max_stem_level { + let start = + u32::try_from(leaf_items.len()).expect("Too many points: index exceeds u32::MAX"); + let end = u32::try_from(leaf_items.len() + chunk_length) + .expect("Too many points: index exceeds u32::MAX"); + // Write leaf and terminate recursion - leaf_extents.push(( - leaf_items.len() as u32, - (leaf_items.len() + chunk_length) as u32, - )); + leaf_extents.push((start, end)); (0..chunk_length).for_each(|i| { (0..K).for_each(|dim| leaf_points[dim].push(source[sort_index[i]][dim])); diff --git a/src/immutable/float/query/nearest_n_within.rs b/src/immutable/float/query/nearest_n_within.rs index 7b3cfece..d8b1603a 100644 --- a/src/immutable/float/query/nearest_n_within.rs +++ b/src/immutable/float/query/nearest_n_within.rs @@ -231,7 +231,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = SquaredEuclidean::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/immutable/float/query/within.rs b/src/immutable/float/query/within.rs index a04c9f7a..535d56ee 100644 --- a/src/immutable/float/query/within.rs +++ b/src/immutable/float/query/within.rs @@ -206,7 +206,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = Manhattan::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/immutable/float/query/within_unsorted.rs b/src/immutable/float/query/within_unsorted.rs index 28310c5a..67b4d6c2 100644 --- a/src/immutable/float/query/within_unsorted.rs +++ b/src/immutable/float/query/within_unsorted.rs @@ -204,7 +204,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = SquaredEuclidean::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/immutable/float/query/within_unsorted_iter.rs b/src/immutable/float/query/within_unsorted_iter.rs index 3f3cceba..dc8aa5bc 100644 --- a/src/immutable/float/query/within_unsorted_iter.rs +++ b/src/immutable/float/query/within_unsorted_iter.rs @@ -201,7 +201,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = SquaredEuclidean::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/lib.rs b/src/lib.rs index c8be8a31..102bd044 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ #![warn(missing_docs)] #![warn(rustdoc::broken_intra_doc_links)] #![warn(rustdoc::private_intra_doc_links)] -#![doc(html_root_url = "https://docs.rs/kiddo/5.2.0")] +#![doc(html_root_url = "https://docs.rs/kiddo/5.2.4")] #![doc(issue_tracker_base_url = "https://github.com/sdd/kiddo/issues/")] //! # Kiddo @@ -35,7 +35,7 @@ //! Add `kiddo` to `Cargo.toml` //! ```toml //! [dependencies] -//! kiddo = "5.2.0" +//! kiddo = "5.2.4" //! ``` //! //! ## Usage @@ -83,8 +83,6 @@ //! //! **NOTE**: Support for rkyv 0.7 is now deprecated and will be removed in Kiddo v6. -#[macro_use] -extern crate doc_comment; extern crate core; #[doc(hidden)] @@ -137,6 +135,7 @@ pub type ImmutableKdTree = immutable::float::kdtree::ImmutableKdTree; pub use best_neighbour::BestNeighbour; +pub use float::distance::Chebyshev; pub use float::distance::Manhattan; pub use float::distance::SquaredEuclidean; pub use nearest_neighbour::NearestNeighbour; diff --git a/src/traits.rs b/src/traits.rs index 9561b804..81da5347 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -107,26 +107,65 @@ pub(crate) fn is_stem_index>(x: IDX) -> bool { x < ::leaf_offset() } -/// Trait that needs to be implemented by any potential distance -/// metric to be used within queries +/// Defines how distances are measured and compared for k-d tree queries. +/// +/// Implement this trait to use custom distance metrics with [`kiddo:KdTree`](crate::KdTree). +/// +/// # Distance Metrics in k-d Trees +/// +/// **Distance aggregation**: How to combine per-dimension distances into a total distance +/// - Sum-based: `dist(p,q) = Ī£ |p[i] - q[i]|` (Manhattan, SquaredEuclidean) +/// - Max-based: `dist(p,q) = max_i |p[i] - q[i]|` (Chebyshev/Lāˆž) +/// +/// # Required Methods +/// +/// - [`dist()`]: Compute total distance between two points +/// - [`dist1()`]: Compute per-dimension distance component +/// - [`accumulate()`]: Aggregate distance components (add or max) +/// pub trait DistanceMetric { - /// returns the distance between two K-d points, as measured - /// by a particular distance metric + /// Returns the distance between two K-d points, as measured by this metric. fn dist(a: &[A; K], b: &[A; K]) -> A; - /// returns the distance between two points along a single axis, - /// as measured by a particular distance metric. + /// Returns the distance between two points along a single dimension. /// - /// (needs to be implemented as it is used by the NN query implementations - /// to extend the minimum acceptable distance for a node when recursing - /// back up the tree) + /// Used internally by NN query implementations to extend the minimum + /// acceptable distance for a node when recursing back up the tree. fn dist1(a: A, b: A) -> A; + + /// Aggregates a distance contribution into a running total. + /// + /// This defines how per-dimension distances combine into a total distance. + /// Choose based on your distance metric: + /// + /// - **Sum-based (L1, L2)**: Use `rd + delta` or `rd.saturating_add(delta)` for fixed-point types + /// - **Max-based (Lāˆž/Chebyshev)**: Use `rd.max(delta)` + /// + /// The implementation should match the mathematical definition of your metric: + /// - Manhattan: `dist(p,q) = Ī£ |p[i] - q[i]|` -> accumulate by adding + /// - SquaredEuclidean: `dist(p,q) = Ī£ (p[i] - q[i])²` -> accumulate by adding + /// - Chebyshev: `dist(p,q) = max_i |p[i] - q[i]|` -> accumulate by taking max + /// - Generalised Minkowski (L_p): `dist(p,q) = (Ī£ |p[i] - q[i]|^p)^(1/p)`. + /// For k-d tree pruning, use the sum of powers: accumulate by adding. + /// Only the limit p → āˆž (Chebyshev) uses `max`. + /// + /// The default implementation uses regular addition (`rd + delta`), which works for + /// both integer and floating-point types. For fixed-point types where overflow is a + /// concern, override this with `rd.saturating_add(delta)`. + fn accumulate(rd: A, delta: A) -> A + where + A: std::ops::Add, + { + rd + delta + } } #[cfg(test)] mod tests { + use super::DistanceMetric; use crate::traits::Index; + use rstest::rstest; #[test] fn test_u16() { @@ -158,4 +197,45 @@ mod tests { (u32::MAX - u32::MAX.overflowing_shr(1).0).saturating_mul(bucket_size); assert_eq!(capacity_with_bucket_size, u32::MAX); } + + struct TestMetricU32; + struct TestMetricI64; + + impl DistanceMetric for TestMetricU32 { + fn dist(_a: &[u32; K], _b: &[u32; K]) -> u32 { + 0 + } + fn dist1(a: u32, _b: u32) -> u32 { + a + } + } + + impl DistanceMetric for TestMetricI64 { + fn dist(_a: &[i64; K], _b: &[i64; K]) -> i64 { + 0 + } + fn dist1(a: i64, _b: i64) -> i64 { + a + } + } + + #[rstest] + #[case(5u32, 3u32, 8u32)] + #[case(10u32, 20u32, 30u32)] + fn test_default_accumulate_u32(#[case] rd: u32, #[case] delta: u32, #[case] expected: u32) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } + + #[rstest] + #[case(10i64, 20i64, 30i64)] + #[case(100i64, 200i64, 300i64)] + fn test_default_accumulate_i64(#[case] rd: i64, #[case] delta: i64, #[case] expected: i64) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } }