Skip to content

Commit 71d1fa8

Browse files
committed
feat: add shim support for multiple dtypes
1 parent 7feb310 commit 71d1fa8

File tree

7 files changed

+504
-302
lines changed

7 files changed

+504
-302
lines changed

pysrc/fastquadtree/_base_quadtree.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class _BaseQuadTree(Generic[G, HitT, ItemType], ABC):
4848
"_bounds",
4949
"_capacity",
5050
"_count",
51+
"_dtype",
5152
"_max_depth",
5253
"_native",
5354
"_next_id",
@@ -62,7 +63,7 @@ def _new_native(self, bounds: Bounds, capacity: int, max_depth: int | None) -> A
6263
"""Create the native engine instance."""
6364

6465
@classmethod
65-
def _new_native_from_bytes(cls, data: bytes) -> Any:
66+
def _new_native_from_bytes(cls, data: bytes, dtype: str) -> Any:
6667
"""Create the native engine instance from serialized bytes."""
6768

6869
@staticmethod
@@ -79,10 +80,12 @@ def __init__(
7980
*,
8081
max_depth: int | None = None,
8182
track_objects: bool = False,
83+
dtype: str = "f32",
8284
):
8385
self._bounds = bounds
8486
self._max_depth = max_depth
8587
self._capacity = capacity
88+
self._dtype = dtype
8689
self._native = self._new_native(bounds, capacity, max_depth)
8790

8891
self._track_objects = bool(track_objects)
@@ -138,12 +141,13 @@ def to_bytes(self) -> bytes:
138141
return pickle.dumps(self.to_dict())
139142

140143
@classmethod
141-
def from_bytes(cls, data: bytes) -> Self:
144+
def from_bytes(cls, data: bytes, dtype: str = "f32") -> Self:
142145
"""
143-
Deserialize a quadtree from bytes.
146+
Deserialize a quadtree from bytes. Specifiy the dtype if the original tree that was serialized used a non-default dtype.
144147
145148
Args:
146149
data: Bytes representing the serialized quadtree from `to_bytes()`.
150+
dtype: The data type used in the native engine ('f32', 'f64', 'i32', 'i64') when saved to bytes.
147151
148152
Returns:
149153
A new quadtree instance with the same state as when serialized.
@@ -160,7 +164,15 @@ def from_bytes(cls, data: bytes) -> Self:
160164
store_dict = in_dict["store"]
161165

162166
qt = cls.__new__(cls) # type: ignore[call-arg]
163-
qt._native = cls._new_native_from_bytes(core_bytes)
167+
try:
168+
qt._native = cls._new_native_from_bytes(core_bytes, dtype=dtype)
169+
except ValueError as ve:
170+
raise ValueError(
171+
"Failed to deserialize quadtree native core. "
172+
"This may be due to a dtype mismatch. "
173+
"Ensure the dtype used in from_bytes() matches the original tree. "
174+
"Error details: " + str(ve)
175+
) from ve
164176

165177
if store_dict is not None:
166178
qt._store = ObjStore.from_dict(store_dict, qt._make_item)

pysrc/fastquadtree/_item.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
from typing import Any, Tuple
55

6-
Bounds = Tuple[float, float, float, float]
6+
Bounds = Tuple[float | int, float | int, float | int, float | int]
77
"""Axis-aligned rectangle as (min_x, min_y, max_x, max_y)."""
88

9-
Point = Tuple[float, float]
9+
Point = Tuple[float | int, float | int]
1010
"""2D point as (x, y)."""
1111

1212

pysrc/fastquadtree/point_quadtree.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@
55

66
from ._base_quadtree import Bounds, _BaseQuadTree
77
from ._item import Point, PointItem
8-
from ._native import QuadTree as _RustQuadTree # native point tree
8+
from ._native import QuadTree as QuadTreeF32, QuadTreeF64, QuadTreeI32, QuadTreeI64
99

1010
_IdCoord = Tuple[int, float, float]
1111

12+
DTYPE_MAP = {
13+
"f32": QuadTreeF32,
14+
"f64": QuadTreeF64,
15+
"i32": QuadTreeI32,
16+
"i64": QuadTreeI64,
17+
}
18+
1219

1320
class QuadTree(_BaseQuadTree[Point, _IdCoord, PointItem]):
1421
"""
@@ -29,6 +36,7 @@ class QuadTree(_BaseQuadTree[Point, _IdCoord, PointItem]):
2936
capacity: Max number of points per node before splitting.
3037
max_depth: Optional max tree depth. If omitted, engine decides.
3138
track_objects: Enable id <-> object mapping inside Python.
39+
dtype: Data type for coordinates and ids in the native engine. Default is 'f32'. Options are 'f32', 'f64', 'i32', 'i64'.
3240
3341
Raises:
3442
ValueError: If parameters are invalid or inserts are out of bounds.
@@ -41,12 +49,14 @@ def __init__(
4149
*,
4250
max_depth: int | None = None,
4351
track_objects: bool = False,
52+
dtype: str = "f32",
4453
):
4554
super().__init__(
4655
bounds,
4756
capacity,
4857
max_depth=max_depth,
4958
track_objects=track_objects,
59+
dtype=dtype,
5060
)
5161

5262
@overload
@@ -148,14 +158,19 @@ def nearest_neighbors(self, xy: Point, k: int, *, as_items: bool = False):
148158
return out
149159

150160
def _new_native(self, bounds: Bounds, capacity: int, max_depth: int | None) -> Any:
151-
if max_depth is None:
152-
return _RustQuadTree(bounds, capacity)
153-
return _RustQuadTree(bounds, capacity, max_depth=max_depth)
161+
"""Create the native engine instance."""
162+
rust_cls = DTYPE_MAP.get(self._dtype)
163+
if rust_cls is None:
164+
raise ValueError(f"Unsupported dtype: {self._dtype}")
165+
return rust_cls(bounds, capacity, max_depth)
154166

155167
@classmethod
156-
def _new_native_from_bytes(cls, data: bytes) -> Any:
168+
def _new_native_from_bytes(cls, data: bytes, dtype: str = "f32") -> Any:
157169
"""Create a new native engine instance from serialized bytes."""
158-
return _RustQuadTree.from_bytes(data)
170+
rust_cls = DTYPE_MAP.get(dtype)
171+
if rust_cls is None:
172+
raise ValueError(f"Unsupported dtype: {dtype}")
173+
return rust_cls.from_bytes(data)
159174

160175
@staticmethod
161176
def _make_item(id_: int, geom: Point, obj: Any | None) -> PointItem:

pysrc/fastquadtree/rect_quadtree.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,21 @@
55

66
from ._base_quadtree import Bounds, _BaseQuadTree
77
from ._item import RectItem
8-
from ._native import RectQuadTree as _RustRectQuadTree # native rect tree
8+
from ._native import (
9+
RectQuadTree as RectQuadTreeF32,
10+
RectQuadTreeF64,
11+
RectQuadTreeI32,
12+
RectQuadTreeI64,
13+
)
914

10-
_IdRect = Tuple[int, float, float, float, float]
11-
Point = Tuple[float, float] # only for type hints in docstrings
15+
_IdRect = Tuple[int, float | int, float | int, float | int, float | int]
16+
17+
DTYPE_MAP = {
18+
"f32": RectQuadTreeF32,
19+
"f64": RectQuadTreeF64,
20+
"i32": RectQuadTreeI32,
21+
"i64": RectQuadTreeI64,
22+
}
1223

1324

1425
class RectQuadTree(_BaseQuadTree[Bounds, _IdRect, RectItem]):
@@ -30,6 +41,7 @@ class RectQuadTree(_BaseQuadTree[Bounds, _IdRect, RectItem]):
3041
capacity: Max number of points per node before splitting.
3142
max_depth: Optional max tree depth. If omitted, engine decides.
3243
track_objects: Enable id <-> object mapping inside Python.
44+
dtype: Data type for coordinates and ids in the native engine. Default is 'f32'. Options are 'f32', 'f64', 'i32', 'i64'.
3345
3446
Raises:
3547
ValueError: If parameters are invalid or inserts are out of bounds.
@@ -42,12 +54,14 @@ def __init__(
4254
*,
4355
max_depth: int | None = None,
4456
track_objects: bool = False,
57+
dtype: str = "f32",
4558
):
4659
super().__init__(
4760
bounds,
4861
capacity,
4962
max_depth=max_depth,
5063
track_objects=track_objects,
64+
dtype=dtype,
5165
)
5266

5367
@overload
@@ -84,14 +98,19 @@ def query(
8498
return self._store.get_many_by_ids(self._native.query_ids(rect))
8599

86100
def _new_native(self, bounds: Bounds, capacity: int, max_depth: int | None) -> Any:
87-
if max_depth is None:
88-
return _RustRectQuadTree(bounds, capacity)
89-
return _RustRectQuadTree(bounds, capacity, max_depth=max_depth)
101+
"""Create the native engine instance."""
102+
rust_cls = DTYPE_MAP.get(self._dtype)
103+
if rust_cls is None:
104+
raise ValueError(f"Unsupported dtype: {self._dtype}")
105+
return rust_cls(bounds, capacity, max_depth)
90106

91107
@classmethod
92-
def _new_native_from_bytes(cls, data: bytes) -> Any:
108+
def _new_native_from_bytes(cls, data: bytes, dtype: str = "f32") -> Any:
93109
"""Create a new native engine instance from serialized bytes."""
94-
return _RustRectQuadTree.from_bytes(data)
110+
rust_cls = DTYPE_MAP.get(dtype)
111+
if rust_cls is None:
112+
raise ValueError(f"Unsupported dtype: {dtype}")
113+
return rust_cls.from_bytes(data)
95114

96115
@staticmethod
97116
def _make_item(id_: int, geom: Bounds, obj: Any | None) -> RectItem:

0 commit comments

Comments
 (0)