Skip to content

Commit dcfd709

Browse files
committed
feat: require numpy arrays in insert many to use the same dtype as the quadtree
1 parent 49b7e9a commit dcfd709

File tree

8 files changed

+67
-13
lines changed

8 files changed

+67
-13
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "fastquadtree"
3-
version = "1.3.0"
3+
version = "1.3.1"
44
edition = "2021"
55

66
[lib]

pysrc/fastquadtree/_base_quadtree.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@
2929
HitT = TypeVar("HitT") # raw native tuple, e.g. (id,x,y) or (id,x0,y0,x1,y1)
3030
ItemType = TypeVar("ItemType", bound=Item) # e.g. PointItem or RectItem
3131

32+
# Quadtree dtype to numpy dtype mapping
33+
QUADTREE_DTYPE_TO_NP_DTYPE = {
34+
"f32": "float32",
35+
"f64": "float64",
36+
"i32": "int32",
37+
"i64": "int64",
38+
}
39+
3240

3341
def _is_np_array(x: Any) -> bool:
3442
mod = getattr(x.__class__, "__module__", "")
@@ -250,7 +258,7 @@ def insert_many(
250258
) -> int:
251259
"""
252260
Bulk insert with auto-assigned contiguous ids. Faster than inserting one-by-one.<br>
253-
Can accept either a Python sequence of geometries or a NumPy array of shape (N,2) or (N,4) with dtype float32.
261+
Can accept either a Python sequence of geometries or a NumPy array of shape (N,2) or (N,4) with a dtype that matches the quadtree's dtype.
254262
255263
If tracking is enabled, the objects will be bulk stored internally.
256264
If no objects are provided, the items will have obj=None (if tracking).
@@ -289,8 +297,12 @@ def insert_many(
289297
if geoms.size == 0:
290298
return 0
291299

292-
if geoms.dtype != _np.float32:
293-
raise TypeError("Numpy array must use dtype float32")
300+
# Check if dtype matches quadtree dtype
301+
expected_np_dtype = QUADTREE_DTYPE_TO_NP_DTYPE.get(self._dtype)
302+
if geoms.dtype != expected_np_dtype:
303+
raise TypeError(
304+
f"Numpy array dtype {geoms.dtype} does not match quadtree dtype {self._dtype}"
305+
)
294306

295307
if self._store is None:
296308
# Simple contiguous path with native bulk insert

pysrc/fastquadtree/point_quadtree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,15 @@ def _new_native(self, bounds: Bounds, capacity: int, max_depth: int | None) -> A
161161
"""Create the native engine instance."""
162162
rust_cls = DTYPE_MAP.get(self._dtype)
163163
if rust_cls is None:
164-
raise ValueError(f"Unsupported dtype: {self._dtype}")
164+
raise TypeError(f"Unsupported dtype: {self._dtype}")
165165
return rust_cls(bounds, capacity, max_depth)
166166

167167
@classmethod
168168
def _new_native_from_bytes(cls, data: bytes, dtype: str = "f32") -> Any:
169169
"""Create a new native engine instance from serialized bytes."""
170170
rust_cls = DTYPE_MAP.get(dtype)
171171
if rust_cls is None:
172-
raise ValueError(f"Unsupported dtype: {dtype}")
172+
raise TypeError(f"Unsupported dtype: {dtype}")
173173
return rust_cls.from_bytes(data)
174174

175175
@staticmethod

pysrc/fastquadtree/rect_quadtree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,15 @@ def _new_native(self, bounds: Bounds, capacity: int, max_depth: int | None) -> A
101101
"""Create the native engine instance."""
102102
rust_cls = DTYPE_MAP.get(self._dtype)
103103
if rust_cls is None:
104-
raise ValueError(f"Unsupported dtype: {self._dtype}")
104+
raise TypeError(f"Unsupported dtype: {self._dtype}")
105105
return rust_cls(bounds, capacity, max_depth)
106106

107107
@classmethod
108108
def _new_native_from_bytes(cls, data: bytes, dtype: str = "f32") -> Any:
109109
"""Create a new native engine instance from serialized bytes."""
110110
rust_cls = DTYPE_MAP.get(dtype)
111111
if rust_cls is None:
112-
raise ValueError(f"Unsupported dtype: {dtype}")
112+
raise TypeError(f"Unsupported dtype: {dtype}")
113113
return rust_cls.from_bytes(data)
114114

115115
@staticmethod

tests/test_insert_many_numpy.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,48 @@ def test_type_error_on_wrong_dtype():
5252
assert len(qt) == 0
5353

5454

55+
def test_non_default_dtype_insert_many():
56+
qt = QuadTree(BOUNDS, capacity=8, track_objects=True, dtype="f64")
57+
points = np.array([[10, 10], [20, 20], [30, 30]], dtype=np.float64)
58+
n = qt.insert_many(points)
59+
assert n == 3
60+
assert len(qt) == 3
61+
62+
raw = qt.query((0, 0, 40, 40), as_items=False)
63+
64+
assert len(raw) == 3
65+
# ids and positions match
66+
m_raw = {t[0]: (t[1], t[2]) for t in raw}
67+
for t in raw:
68+
assert (t[1], t[2]) == m_raw[t[0]]
69+
70+
71+
def test_non_default_quadtree_dtype_with_default_numpy_dtype_raises():
72+
qt = QuadTree(BOUNDS, capacity=8, track_objects=True, dtype="f64")
73+
points = np.array([[10, 10], [20, 20], [30, 30]], dtype=np.float32) # Wrong dtype
74+
with pytest.raises(TypeError):
75+
qt.insert_many(points)
76+
assert len(qt) == 0
77+
78+
79+
def test_unspported_quadtree_dtype_insert_many_raises():
80+
qt = QuadTree(BOUNDS, capacity=8, track_objects=True, dtype="i32")
81+
points = np.array([[10, 10], [20, 20], [30, 30]], dtype=np.float32) # Wrong dtype
82+
with pytest.raises(TypeError):
83+
qt.insert_many(points)
84+
assert len(qt) == 0
85+
86+
points = np.array(
87+
[[10, 10], [20, 20], [30, 30]], dtype=np.uint32
88+
) # unsupported dtype
89+
with pytest.raises(TypeError):
90+
qt.insert_many(points)
91+
92+
# QT is also unsupported
93+
with pytest.raises(TypeError):
94+
qt = QuadTree(BOUNDS, capacity=8, track_objects=True, dtype="u32")
95+
96+
5597
def test_insert_empty_numpy_array():
5698
qt = QuadTree(BOUNDS, capacity=8, track_objects=True)
5799
points = np.empty((0, 2), dtype=np.float32)

tests/test_point_quadtree_dtypes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55

66
def test_unsupported_dtype():
77
"""Test that providing an unsupported dtype raises ValueError."""
8-
with pytest.raises(ValueError):
8+
with pytest.raises(TypeError):
99
QuadTree((0, 0, 100, 100), capacity=4, track_objects=True, dtype="f128") # type: ignore
1010

1111
# From bytes
1212
qt = QuadTree((0, 0, 100, 100), capacity=4, track_objects=True, dtype="f32")
1313
data = qt.to_bytes()
14-
with pytest.raises(ValueError):
14+
with pytest.raises(TypeError):
1515
QuadTree.from_bytes(data, dtype="f128") # type: ignore
1616

1717

tests/test_rect_quadtree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def test_accurate_obj_output_with_tracking():
195195

196196
def test_unsupported_dtype():
197197
"""Test that providing an unsupported dtype raises ValueError."""
198-
with pytest.raises(ValueError):
198+
with pytest.raises(TypeError):
199199
rq.RectQuadTree(
200200
b_to_float(0, 0, 100, 100), capacity=4, track_objects=True, dtype="f128"
201201
) # type: ignore
@@ -205,7 +205,7 @@ def test_unsupported_dtype():
205205
b_to_float(0, 0, 100, 100), capacity=4, track_objects=True, dtype="f32"
206206
)
207207
data = qt.to_bytes()
208-
with pytest.raises(ValueError):
208+
with pytest.raises(TypeError):
209209
rq.RectQuadTree.from_bytes(data, dtype="f128") # type: ignore
210210

211211

0 commit comments

Comments
 (0)