Skip to content

Commit 965a9b1

Browse files
authoredNov 30, 2024
handle overlapping points during removal (#62)
1 parent 3164525 commit 965a9b1

File tree

2 files changed

+27
-7
lines changed

2 files changed

+27
-7
lines changed
 

‎src/kdtree.rs

+9-7
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,15 @@ impl<A: Float + Zero + One, T: std::cmp::PartialEq, U: AsRef<[A]> + std::cmp::Pa
139139
let mut removed = 0;
140140
self.check_point(point.as_ref())?;
141141
if let (Some(mut points), Some(mut bucket)) = (self.points.take(), self.bucket.take()) {
142-
while let Some(p_index) = points.iter().position(|x| x == point) {
143-
if &bucket[p_index] == data {
144-
points.remove(p_index);
145-
bucket.remove(p_index);
146-
removed += 1;
147-
self.size -= 1;
148-
}
142+
while let Some(p_index) = points
143+
.iter()
144+
.zip(bucket.iter())
145+
.position(|(p, d)| p == point && d == data)
146+
{
147+
points.remove(p_index);
148+
bucket.remove(p_index);
149+
removed += 1;
150+
self.size -= 1;
149151
}
150152
self.points = Some(points);
151153
self.bucket = Some(bucket);

‎tests/kdtree.rs

+18
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,21 @@ fn handles_remove_no_match() {
401401
vec![(16.0, &4), (36.0, &3)]
402402
);
403403
}
404+
405+
#[test]
406+
fn handles_remove_overlapping_points() {
407+
let a = ([0f64, 0f64], 0);
408+
let b = ([0f64, 0f64], 1);
409+
let mut kdtree = KdTree::new(2);
410+
411+
kdtree.add(a.0, a.1).unwrap();
412+
kdtree.add(b.0, b.1).unwrap();
413+
414+
let num_removed = kdtree.remove(&[0f64, 0f64], &1).unwrap();
415+
assert_eq!(kdtree.size(), 1);
416+
assert_eq!(num_removed, 1);
417+
assert_eq!(
418+
kdtree.nearest(&[0f64, 0f64], 1, &squared_euclidean).unwrap(),
419+
vec![(0.0, &0)]
420+
);
421+
}

0 commit comments

Comments
 (0)