Skip to content

Commit ee18a86

Browse files
committed
feat: major optimizations against STRtree
1 parent c73ea7b commit ee18a86

File tree

13 files changed

+127
-63
lines changed

13 files changed

+127
-63
lines changed

README.md

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,27 @@ fastquadtree **outperforms** all other quadtree Python packages, including the R
3535
![Throughput](https://raw.githubusercontent.com/Elan456/fastquadtree/main/assets/quadtree_bench_throughput.png)
3636

3737
### Summary (largest dataset, PyQtree baseline)
38-
- Points: **500,000**, Queries: **500**
38+
- Points: **250,000**, Queries: **500**
3939
--------------------
40-
- Fastest total: **fastquadtree** at **1.591 s**
40+
- Fastest total: **fastquadtree** at **0.120 s**
4141

4242
| Library | Build (s) | Query (s) | Total (s) | Speed vs PyQtree |
4343
|---|---:|---:|---:|---:|
44-
| fastquadtree | 0.165 | 1.427 | 1.591 | 5.09× |
45-
| Rtree | 1.320 | 2.369 | 3.688 | 2.20× |
46-
| PyQtree | 2.687 | 5.415 | 8.102 | 1.00× |
47-
| nontree-QuadTree | 1.284 | 9.891 | 11.175 | 0.73× |
48-
| quads | 2.346 | 10.129 | 12.475 | 0.65× |
49-
| e-pyquadtree | 1.795 | 11.855 | 13.650 | 0.59× |
44+
| fastquadtree | 0.031 | 0.089 | 0.120 | 14.64× |
45+
| Shapely STRtree | 0.179 | 0.100 | 0.279 | 6.29× |
46+
| nontree-QuadTree | 0.595 | 0.605 | 1.200 | 1.46× |
47+
| Rtree | 0.961 | 0.300 | 1.261 | 1.39× |
48+
| e-pyquadtree | 1.005 | 0.660 | 1.665 | 1.05× |
49+
| PyQtree | 1.492 | 0.263 | 1.755 | 1.00× |
50+
| quads | 1.407 | 0.484 | 1.890 | 0.93× |
51+
52+
#### Benchmark Configuration
53+
| Parameter | Value |
54+
|---|---:|
55+
| Bounds | (0, 0, 1000, 1000) |
56+
| Max points per node | 128 |
57+
| Max depth | 16 |
58+
| Queries per experiment | 500 |
5059

5160
## Install
5261

5.77 KB
Loading

assets/quadtree_bench_time.png

34.1 KB
Loading

benchmarks/quadtree_bench/engines.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77

88
from typing import Any, Callable, Dict, List, Optional, Tuple
99

10+
import numpy as np
1011
from pyqtree import Index as PyQTree # Pyqtree
1112

1213
# Built-in engines (always available in this repo)
1314
from pyquadtree.quadtree import QuadTree as EPyQuadTree # e-pyquadtree
15+
from shapely import box as shp_box, points # Shapely 2.x
16+
from shapely.strtree import STRtree
1417

1518
from fastquadtree import QuadTree as RustQuadTree # fastquadtree
1619

@@ -243,14 +246,9 @@ def _create_strtree_engine(
243246
) -> Optional[Engine]:
244247
"""Create engine adapter for Shapely STRtree (optional)."""
245248

246-
from shapely import box as shp_box, points # Shapely 2.x
247-
from shapely.strtree import STRtree
248-
249249
def build(points_list: List[Tuple[int, int]]):
250250
# Build geometries efficiently
251251

252-
import numpy as np
253-
254252
xs = np.fromiter(
255253
(x for x, _ in points_list), dtype="float32", count=len(points_list)
256254
)
@@ -296,15 +294,15 @@ def get_engines(
296294
# Always available engines
297295
engines = {
298296
"fastquadtree": _create_fastquadtree_engine(bounds, max_points, max_depth),
299-
# "e-pyquadtree": _create_e_pyquadtree_engine(bounds, max_points, max_depth),
297+
"e-pyquadtree": _create_e_pyquadtree_engine(bounds, max_points, max_depth),
300298
"PyQtree": _create_pyqtree_engine(bounds, max_points, max_depth),
301299
# "Brute force": _create_brute_force_engine(bounds, max_points, max_depth), # Brute force doesn't scale well on the graphs so omit it from the main set
302300
}
303301

304302
# Optional engines (only include if import succeeded)
305303
optional_engines = [
306-
# _create_quads_engine,
307-
# _create_nontree_engine,
304+
_create_quads_engine,
305+
_create_nontree_engine,
308306
_create_rtree_engine,
309307
_create_strtree_engine,
310308
]

benchmarks/quadtree_bench/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def main():
2525
parser.add_argument(
2626
"--max-points",
2727
type=int,
28-
default=20,
28+
default=128,
2929
help="Maximum points per node before splitting",
3030
)
31-
parser.add_argument("--max-depth", type=int, default=10, help="Maximum tree depth")
31+
parser.add_argument("--max-depth", type=int, default=16, help="Maximum tree depth")
3232
parser.add_argument(
3333
"--n-queries", type=int, default=500, help="Number of queries per experiment"
3434
)
@@ -41,7 +41,7 @@ def main():
4141
parser.add_argument(
4242
"--max-experiment-points",
4343
type=int,
44-
default=500_000,
44+
default=250_000,
4545
help="Maximum number of points in largest experiment",
4646
)
4747
parser.add_argument(

benchmarks/quadtree_bench/runner.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ class BenchmarkConfig:
2323
"""Configuration for benchmark runs."""
2424

2525
bounds: Tuple[int, int, int, int] = (0, 0, 1000, 1000)
26-
max_points: int = 20 # node capacity where supported
27-
max_depth: int = 10 # depth cap for fairness where supported
28-
n_queries: int = 500 # queries per experiment
26+
max_points: int = 64 # node capacity where supported
27+
max_depth: int = 1_000 # depth cap for fairness where supported
28+
n_queries: int = 100 # queries per experiment
2929
repeats: int = 3 # median over repeats
3030
rng_seed: int = 42 # random seed for reproducibility
31-
max_experiment_points: int = 500_000
31+
max_experiment_points: int = 100_000
3232

3333
def __post_init__(self):
3434
"""Generate experiment point sizes."""
@@ -69,8 +69,8 @@ def generate_queries(
6969
for _ in range(m):
7070
x = rng.randint(x_min, x_max)
7171
y = rng.randint(y_min, y_max)
72-
w = rng.randint(0, x_max - x)
73-
h = rng.randint(0, y_max - y)
72+
w = rng.randint(0, x_max - x) // rng.randint(1, 8)
73+
h = rng.randint(0, y_max - y) // rng.randint(1, 8)
7474
queries.append((x, y, x + w, y + h))
7575
return queries
7676

@@ -306,3 +306,12 @@ def rel_speed(name: str) -> str:
306306
print(f"| {name:12} | {fmt(b)} | {fmt(q)} | {fmt(t)} | {rel_speed(name)} |")
307307

308308
print("")
309+
310+
# Config table
311+
print("#### Benchmark Configuration")
312+
print("| Parameter | Value |")
313+
print("|---|---:|")
314+
print(f"| Bounds | {config.bounds} |")
315+
print(f"| Max points per node | {config.max_points} |")
316+
print(f"| Max depth | {config.max_depth} |")
317+
print(f"| Queries per experiment | {config.n_queries} |")

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ module-name = "fastquadtree._native"
4141
compatibility = "manylinux2014"
4242

4343
[tool.pytest.ini_options]
44-
addopts = "--cov=fastquadtree --cov-branch --cov-report=xml --cov-fail-under=95"
44+
addopts = "--cov=fastquadtree --cov-branch --cov-report=xml --cov-fail-under=100"
4545
testpaths = ["tests"] # still run tests
4646

4747
[tool.coverage.run]
@@ -110,6 +110,7 @@ ignore = [
110110
"PLR0915",
111111
"PLR0913",
112112
"PLR0912",
113+
"PLC0415",
113114
]
114115

115116
# Make pytest files less strict where asserts and fixtures are common.

pysrc/fastquadtree/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def nearest_neighbor(self, xy: Point, *, as_item: bool = False):
259259

260260
if self._items is None:
261261
raise ValueError("Cannot return result as item with track_objects=False")
262-
id_, x, y = t
262+
id_, _x, _y = t
263263
item = self._items.by_id(id_)
264264
if item is None:
265265
raise RuntimeError(

src/lib.rs

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pub use crate::geom::{Point, Rect, dist_sq_point_to_rect, dist_sq_points};
77
pub use crate::quadtree::{Item, QuadTree};
88

99
use pyo3::prelude::*;
10-
use pyo3::types::{PyList, PyTuple};
10+
use pyo3::types::{PyList};
1111

1212
fn item_to_tuple(it: Item) -> (u64, f32, f32) {
1313
(it.id, it.point.x, it.point.y)
@@ -57,22 +57,12 @@ impl PyQuadTree {
5757
// Public behavior is unchanged: returns list[(id, x, y)].
5858
pub fn query<'py>(&self, py: Python<'py>, rect: (f32, f32, f32, f32)) -> Bound<'py, PyList> {
5959
let (min_x, min_y, max_x, max_y) = rect;
60-
let items = self.inner.query(Rect { min_x, min_y, max_x, max_y }); // Vec<Item>
61-
62-
// Preallocate to reduce re-allocations
63-
let mut objs: Vec<PyObject> = Vec::with_capacity(items.len());
64-
for it in items {
65-
let tup = PyTuple::new_bound(py, &[
66-
it.id.into_py(py),
67-
it.point.x.into_py(py),
68-
it.point.y.into_py(py),
69-
]);
70-
objs.push(tup.into_py(py));
71-
}
72-
73-
PyList::new_bound(py, &objs)
60+
let tuples = self.inner.query(Rect { min_x, min_y, max_x, max_y });
61+
// PyO3 will turn Vec<(u64,f32,f32)> into a Python list of tuples
62+
PyList::new_bound(py, &tuples)
7463
}
7564

65+
7666
pub fn nearest_neighbor(&self, xy: (f32, f32)) -> Option<(u64, f32, f32)> {
7767
let (x, y) = xy;
7868
self.inner.nearest_neighbor(Point { x, y }).map(item_to_tuple)

src/quadtree.rs

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -121,26 +121,83 @@ impl QuadTree {
121121
self.children = Some(Box::new(kids));
122122
}
123123

124-
pub fn query(&self, range: Rect) -> Vec<Item> {
125-
let mut out = Vec::new();
126-
let mut stack: SmallVec<[&QuadTree; 16]> = SmallVec::new();
127-
stack.push(self);
128-
129-
while let Some(node) = stack.pop() {
130-
for it in &node.items {
131-
if range.contains(&it.point) {
132-
out.push(*it);
124+
#[inline(always)]
125+
fn rect_contains_rect(a: &Rect, b: &Rect) -> bool {
126+
a.min_x <= b.min_x && a.min_y <= b.min_y &&
127+
a.max_x >= b.max_x && a.max_y >= b.max_y
128+
}
129+
130+
pub fn query(&self, range: Rect) -> Vec<(u64, f32, f32)> {
131+
#[derive(Copy, Clone)]
132+
enum Mode { Filter, ReportAll }
133+
134+
// Hoist bounds for tight leaf checks
135+
let rx0 = range.min_x;
136+
let ry0 = range.min_y;
137+
let rx1 = range.max_x;
138+
let ry1 = range.max_y;
139+
140+
let mut out: Vec<(u64, f32, f32)> = Vec::with_capacity(128);
141+
let mut stack: SmallVec<[(&QuadTree, Mode); 64]> = SmallVec::new();
142+
stack.push((self, Mode::Filter));
143+
144+
while let Some((node, mode)) = stack.pop() {
145+
match mode {
146+
Mode::ReportAll => {
147+
if let Some(children) = node.children.as_ref() {
148+
// Entire subtree is inside the query.
149+
// No filtering, just recurse in ReportAll.
150+
stack.push((&children[0], Mode::ReportAll));
151+
stack.push((&children[1], Mode::ReportAll));
152+
stack.push((&children[2], Mode::ReportAll));
153+
stack.push((&children[3], Mode::ReportAll));
154+
} else {
155+
// Leaf: append all items, no per-point test
156+
let items = &node.items;
157+
out.reserve(items.len());
158+
out.extend(items.iter().map(|it| (it.id, it.point.x, it.point.y)));
159+
}
133160
}
134-
}
135-
if let Some(children) = node.children.as_ref() {
136-
// Push children that intersect the query range
137-
for child in children.iter() {
138-
if range.intersects(&child.boundary) {
139-
stack.push(child);
161+
162+
Mode::Filter => {
163+
// Node cull
164+
if !range.intersects(&node.boundary) {
165+
continue;
166+
}
167+
168+
// Full cover: switch to ReportAll
169+
if Self::rect_contains_rect(&range, &node.boundary) {
170+
stack.push((node, Mode::ReportAll));
171+
continue;
172+
}
173+
174+
// Partial overlap
175+
if let Some(children) = node.children.as_ref() {
176+
// Only push intersecting children
177+
let c0 = &children[0];
178+
if range.intersects(&c0.boundary) { stack.push((c0, Mode::Filter)); }
179+
let c1 = &children[1];
180+
if range.intersects(&c1.boundary) { stack.push((c1, Mode::Filter)); }
181+
let c2 = &children[2];
182+
if range.intersects(&c2.boundary) { stack.push((c2, Mode::Filter)); }
183+
let c3 = &children[3];
184+
if range.intersects(&c3.boundary) { stack.push((c3, Mode::Filter)); }
185+
} else {
186+
// Leaf scan with tight predicate
187+
let items = &node.items;
188+
// Reserve a little to reduce reallocs if many will pass
189+
out.reserve(items.len().min(64));
190+
for it in items {
191+
let p = &it.point;
192+
if p.x >= rx0 && p.x < rx1 && p.y >= ry0 && p.y < ry1 {
193+
out.push((it.id, p.x, p.y));
194+
}
195+
}
140196
}
141197
}
142198
}
143199
}
200+
144201
out
145202
}
146203

0 commit comments

Comments
 (0)