diff --git a/benchmarks/test_bench_base.py b/benchmarks/test_bench_base.py index b1da22bf3..1a58e5eb1 100644 --- a/benchmarks/test_bench_base.py +++ b/benchmarks/test_bench_base.py @@ -43,47 +43,47 @@ def setup_class(self): "points_simple", { "geom_type": "Point", - "target_rows": num_geoms, + "num_rows": num_geoms, }, ), ( "segments_large", { "geom_type": "LineString", - "target_rows": num_geoms, - "vertices_per_linestring_range": [2, 10], + "num_rows": num_geoms, + "num_vertices": [2, 10], }, ), ( "polygons_simple", { "geom_type": "Polygon", - "target_rows": num_geoms, - "vertices_per_linestring_range": [10, 10], + "num_rows": num_geoms, + "num_vertices": [10, 10], }, ), ( "polygons_complex", { "geom_type": "Polygon", - "target_rows": num_geoms, - "vertices_per_linestring_range": [500, 500], + "num_rows": num_geoms, + "num_vertices": [500, 500], }, ), ( "collections_simple", { "geom_type": "GeometryCollection", - "target_rows": num_geoms, - "vertices_per_linestring_range": [10, 10], + "num_rows": num_geoms, + "num_vertices": [10, 10], }, ), ( "collections_complex", { "geom_type": "GeometryCollection", - "target_rows": num_geoms, - "vertices_per_linestring_range": [500, 500], + "num_rows": num_geoms, + "num_vertices": [500, 500], }, ), ]: @@ -97,7 +97,7 @@ def setup_class(self): { "seed": 42, "bounds": [0.0, 0.0, 80.0, 100.0], # Slightly left-leaning - "size_range": [ + "size": [ 1.0, 15.0, ], # Medium-sized geometries for good intersection chance @@ -110,7 +110,7 @@ def setup_class(self): { "seed": 43, "bounds": [20.0, 0.0, 100.0, 100.0], # Slightly right-leaning - "size_range": [1.0, 15.0], # Same size range for fair comparison + "size": [1.0, 15.0], # Same size range for fair comparison } ) diff --git a/benchmarks/test_knn.py b/benchmarks/test_knn.py index 18fcc02fd..125140d04 100644 --- a/benchmarks/test_knn.py +++ b/benchmarks/test_knn.py @@ -30,9 +30,9 @@ def setup_class(self): # Create building-like polygons (index side - fewer, larger geometries) building_options = { "geom_type": "Polygon", - "target_rows": 2_000, - "vertices_per_linestring_range": [4, 8], - "size_range": [0.001, 0.01], + "num_rows": 2_000, + "num_vertices": [4, 8], + "size": [0.001, 0.01], "seed": 42, } @@ -51,7 +51,7 @@ def setup_class(self): # Create trip pickup points (probe side) trip_options = { "geom_type": "Point", - "target_rows": 10_000, + "num_rows": 10_000, "seed": 43, } diff --git a/python/sedonadb/python/sedonadb/context.py b/python/sedonadb/python/sedonadb/context.py index 714d73339..a3a624ac5 100644 --- a/python/sedonadb/python/sedonadb/context.py +++ b/python/sedonadb/python/sedonadb/context.py @@ -14,15 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ import os import sys +from functools import cached_property from pathlib import Path from typing import Any, Dict, Iterable, Literal, Optional, Union from sedonadb._lib import InternalContext, configure_proj_shared +from sedonadb._options import Options from sedonadb.dataframe import DataFrame, _create_data_frame +from sedonadb.functions import Functions from sedonadb.utility import sedona # noqa: F401 -from sedonadb._options import Options class SedonaContext: @@ -272,6 +275,11 @@ def register_udf(self, udf: Any): """ self._impl.register_udf(udf) + @cached_property + def funcs(self) -> Functions: + """Access Python wrappers for SedonaDB functions""" + return Functions(self) + def connect() -> SedonaContext: """Create a new [SedonaContext][sedonadb.context.SedonaContext]""" diff --git a/python/sedonadb/python/sedonadb/functions/__init__.py b/python/sedonadb/python/sedonadb/functions/__init__.py new file mode 100644 index 000000000..cad3c0f7d --- /dev/null +++ b/python/sedonadb/python/sedonadb/functions/__init__.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from functools import cached_property +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sedonadb.functions.table import TableFunctions + + +class Functions: + """Functions accessor + + This class provides Pythonic wrappers to call SedonaDB functions + given a specific SedonaDB context. + """ + + def __init__(self, ctx): + self._ctx = ctx + + @cached_property + def table(self) -> "TableFunctions": + """Access SedonaDB Table functions""" + from sedonadb.functions.table import TableFunctions + + return TableFunctions(self._ctx) diff --git a/python/sedonadb/python/sedonadb/functions/table.py b/python/sedonadb/python/sedonadb/functions/table.py new file mode 100644 index 000000000..30a2f285d --- /dev/null +++ b/python/sedonadb/python/sedonadb/functions/table.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import json +from typing import Optional, Literal, Union, Tuple, Iterable + +from sedonadb.dataframe import DataFrame +from sedonadb.utility import sedona # noqa: F401 + + +class TableFunctions: + def __init__(self, ctx): + self._ctx = ctx + + def sd_random_geometry( + self, + geom_type: Optional[ + Literal[ + "Geometry", + "Point", + "LineString", + "Polygon", + "MultiPoint", + "MultiLineString", + "MultiPolygon", + "GeometryCollection", + ] + ] = None, + num_rows: Optional[int] = None, + *, + num_vertices: Union[int, Tuple[int, int], None] = None, + num_parts: Union[int, Tuple[int, int], None] = None, + size: Union[float, Tuple[float, float], None] = None, + bounds: Optional[Iterable[float]] = None, + hole_rate: Optional[float] = None, + empty_rate: Optional[float] = None, + null_rate: Optional[float] = None, + seed: Optional[int] = None, + ) -> DataFrame: + """ + Generate a DataFrame with random geometries for testing purposes. + This function creates a DataFrame containing randomly generated geometries with + configurable parameters for geometry type, size, complexity, and spatial distribution. + Returns a DataFrame with columns 'id', 'dist', and 'geometry' containing randomly + generated geometries and distances. + + Parameters + ---------- + geom_type : str, default "Point" + The type of geometry to generate. One of "Geometry", + "Point", "LineString", "Polygon", "MultiPoint", "MultiLineString", + "MultiPolygon", or "GeometryCollection". + num_rows : int, default 1024 + Number of rows to generate. + num_vertices : int or tuple of (int, int), default 4 + Number of vertices per geometry. If a tuple, specifies (min, max) range. + num_parts : int or tuple of (int, int), default (1, 3) + Number of parts for multi-geometries. If a tuple, specifies (min, max) range. + size : float or tuple of (float, float), default (1.0, 10.0) + Spatial size of geometries. If a tuple, specifies (min, max) range. + bounds : iterable of float, default [0.0, 0.0, 100.0, 100.0] + Spatial bounds as [xmin, ymin, xmax, ymax] to constrain generated geometries. + hole_rate : float, default 0.0 + Rate of polygons with holes, between 0.0 and 1.0. + empty_rate : float, default 0.0 + Rate of empty geometries, between 0.0 and 1.0. + null_rate : float, default 0.0 + Rate of null geometries, between 0.0 and 1.0. + seed : int, optional + Random seed for reproducible geometry generation. If omitted, the result is + non-deterministic. 
+ + Examples + -------- + >>> sd = sedona.db.connect() + >>> sd.funcs.table.sd_random_geometry("Point", 1, seed=938).show() + ┌───────┬───────────────────┬────────────────────────────────────────────┐ + │ id ┆ dist ┆ geometry │ + │ int32 ┆ float64 ┆ geometry │ + ╞═══════╪═══════════════════╪════════════════════════════════════════════╡ + │ 0 ┆ 58.86528701627309 ┆ POINT(94.77686827801787 17.65107885959438) │ + └───────┴───────────────────┴────────────────────────────────────────────┘ + """ + + args = { + "bounds": bounds, + "empty_rate": empty_rate, + "geom_type": geom_type, + "null_rate": null_rate, + "num_parts": num_parts, + "hole_rate": hole_rate, + "seed": seed, + "size": size, + "num_rows": num_rows, + "num_vertices": num_vertices, + } + + args = {k: v for k, v in args.items() if v is not None} + + return self._ctx.sql(f"SELECT * FROM sd_random_geometry('{json.dumps(args)}')") diff --git a/python/sedonadb/python/sedonadb/testing.py b/python/sedonadb/python/sedonadb/testing.py index 83533dd6f..ec8b870a3 100644 --- a/python/sedonadb/python/sedonadb/testing.py +++ b/python/sedonadb/python/sedonadb/testing.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os import math +import os import warnings from pathlib import Path from typing import TYPE_CHECKING, List, Tuple @@ -29,6 +29,17 @@ import sedonadb +def random_geometry(*args, **kwargs) -> "sedonadb.dataframe.DataFrame": + """ + Generate a DataFrame with random geometries for testing purposes by + calling sd_random_geometry() on an isolated SedonaDB session. + """ + import sedonadb + + sd = sedonadb.connect() + return sd.funcs.table.sd_random_geometry(*args, **kwargs) + + def skip_if_not_exists(path: Path): """Skip a test using pytest.skip() if path does not exist diff --git a/python/sedonadb/tests/test_funcs.py b/python/sedonadb/tests/test_funcs.py new file mode 100644 index 000000000..2c6fb9956 --- /dev/null +++ b/python/sedonadb/tests/test_funcs.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +def test_random_geometry(con): + df = con.funcs.table.sd_random_geometry("Point", 5, seed=99873) + + # Ensure we produce the correct number of rows + assert df.count() == 5 + + # Ensure the output is reproducible + assert df.to_arrow_table() == df.to_arrow_table() diff --git a/python/sedonadb/tests/test_knnjoin.py b/python/sedonadb/tests/test_knnjoin.py index 485f468fd..331bcf1f2 100644 --- a/python/sedonadb/tests/test_knnjoin.py +++ b/python/sedonadb/tests/test_knnjoin.py @@ -17,7 +17,7 @@ import pytest import json -from sedonadb.testing import PostGIS, SedonaDB +from sedonadb.testing import PostGIS, SedonaDB, random_geometry @pytest.mark.parametrize("k", [1, 3, 5]) @@ -28,28 +28,11 @@ def test_knn_join_basic(k): PostGIS.create_or_skip() as eng_postgis, ): # Create query points (probe side) - point_options = json.dumps( - { - "geom_type": "Point", - "target_rows": 20, - "seed": 42, - } - ) - df_points = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{point_options}') LIMIT 20" - ) + num_points = 20 + df_points = random_geometry("Point", num_points, seed=42) # Create target points (build side) - target_options = json.dumps( - { - "geom_type": "Point", - "target_rows": 50, - "seed": 43, - } - ) - df_targets = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{target_options}') LIMIT 50" - ) + df_targets = random_geometry("Point", 50, seed=43) # Set up tables in both engines eng_sedonadb.create_table_arrow("knn_query_points", df_points) @@ -73,7 +56,7 @@ def test_knn_join_basic(k): # Verify basic correctness assert len(sedonadb_results) > 0 assert ( - len(sedonadb_results) == len(df_points) * k + len(sedonadb_results) == num_points * k ) # Each query point should have k neighbors # Verify results are ordered by distance within each query point @@ -112,29 +95,12 @@ def test_knn_join_with_polygons(): PostGIS.create_or_skip() as eng_postgis, ): # Create query points - point_options = json.dumps( - { - "geom_type": "Point", - "target_rows": 15, - "seed": 100, - } - ) - df_points = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{point_options}') LIMIT 15" - ) + n_points = 15 + df_points = random_geometry("Point", n_points, seed=100) # Create target polygons - polygon_options = json.dumps( - { - "geom_type": "Polygon", - "target_rows": 30, - "vertices_per_linestring_range": [4, 8], - "size_range": [0.001, 0.01], - "seed": 101, - } - ) - df_polygons = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{polygon_options}') LIMIT 30" + df_polygons = random_geometry( + "Polygon", 30, num_vertices=(4, 8), size=(0.001, 0.01), seed=101 ) # Set up tables @@ -159,7 +125,7 @@ def test_knn_join_with_polygons(): # Verify correctness assert len(sedonadb_results) > 0 - assert len(sedonadb_results) == len(df_points) * k + assert len(sedonadb_results) == n_points * k # Verify ordering within each point for point_id in sedonadb_results["point_id"].unique(): @@ -196,27 +162,12 @@ def test_knn_join_edge_cases(): PostGIS.create_or_skip() as eng_postgis, ): # Create small datasets for edge case testing - point_options = json.dumps( - { - "geom_type": "Point", - "target_rows": 5, - "seed": 200, - } - ) - df_points = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{point_options}') LIMIT 5" - ) + n_points = 5 + df_points = random_geometry("Point", n_points, seed=200) - target_options = json.dumps( - { - "geom_type": "Point", - "target_rows": 3, # Fewer targets than k in some tests - "seed": 201, - } - 
) - df_targets = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{target_options}') LIMIT 3" - ) + # Fewer targets than k in some tests + n_targets = 3 + df_targets = random_geometry("Point", n_targets, seed=201) eng_sedonadb.create_table_arrow("knn_query_small", df_points) eng_sedonadb.create_table_arrow("knn_target_small", df_targets) @@ -238,8 +189,8 @@ def test_knn_join_edge_cases(): sedonadb_results = eng_sedonadb.execute_and_collect(sql).to_pandas() # Should return all available targets (3) for each query point - expected_results_per_query = min(k, len(df_targets)) # min(5, 3) = 3 - assert len(sedonadb_results) == len(df_points) * expected_results_per_query + expected_results_per_query = min(k, n_targets) # min(5, 3) = 3 + assert len(sedonadb_results) == n_points * expected_results_per_query # PostGIS syntax postgis_sql = f""" @@ -271,7 +222,7 @@ def test_knn_join_with_attributes(): point_options = json.dumps( { "geom_type": "Point", - "target_rows": 10, + "num_rows": 10, "seed": 300, } ) @@ -290,7 +241,7 @@ def test_knn_join_with_attributes(): target_options = json.dumps( { "geom_type": "Point", - "target_rows": 20, + "num_rows": 20, "seed": 301, } ) @@ -371,27 +322,8 @@ def test_knn_join_correctness_known_points(): PostGIS.create_or_skip() as eng_postgis, ): # Create deterministic synthetic data for reproducible results - query_options = json.dumps( - { - "geom_type": "Point", - "target_rows": 3, - "seed": 1000, - } - ) - df_known = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{query_options}') LIMIT 3" - ) - - target_options = json.dumps( - { - "geom_type": "Point", - "target_rows": 8, - "seed": 1001, - } - ) - df_targets = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{target_options}') LIMIT 8" - ) + df_known = random_geometry("Point", 3, seed=1000) + df_targets = random_geometry("Point", 8, seed=1001) eng_sedonadb.create_table_arrow("knn_known", df_known) eng_sedonadb.create_table_arrow("knn_target_known", df_targets) diff --git a/python/sedonadb/tests/test_sjoin.py b/python/sedonadb/tests/test_sjoin.py index beb412ced..b75831801 100644 --- a/python/sedonadb/tests/test_sjoin.py +++ b/python/sedonadb/tests/test_sjoin.py @@ -16,7 +16,7 @@ # under the License. 
import pytest import json -from sedonadb.testing import PostGIS, SedonaDB +from sedonadb.testing import PostGIS, SedonaDB, random_geometry @pytest.mark.parametrize( @@ -38,30 +38,11 @@ def test_spatial_join(join_type, on): SedonaDB.create_or_skip() as eng_sedonadb, PostGIS.create_or_skip() as eng_postgis, ): - options = json.dumps( - { - "geom_type": "Point", - "polygon_hole_rate": 0.5, - "num_parts_range": [2, 10], - "vertices_per_linestring_range": [2, 10], - "seed": 42, - } - ) - df_point = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" - ) - options = json.dumps( - { - "geom_type": "Polygon", - "polygon_hole_rate": 0.5, - "num_parts_range": [2, 10], - "vertices_per_linestring_range": [2, 10], - "seed": 43, - } - ) - df_polygon = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" + df_point = random_geometry("Point", 100, seed=42) + df_polygon = random_geometry( + "Polygon", 100, hole_rate=0.5, num_vertices=(2, 10), seed=43 ) + eng_sedonadb.create_table_arrow("sjoin_point", df_point) eng_sedonadb.create_table_arrow("sjoin_polygon", df_polygon) eng_postgis.create_table_arrow("sjoin_point", df_point) @@ -102,11 +83,11 @@ def test_spatial_join_geography(join_type, on): options = json.dumps( { "geom_type": "Point", - "num_parts_range": [2, 10], - "vertices_per_linestring_range": [2, 10], + "num_parts": [2, 10], + "num_vertices": [2, 10], "bounds": west_most_bound, - "size_range": [0.1, 5], - "seed": 958, + "size": [0.1, 5], + "seed": 542, } ) df_point = eng_sedonadb.execute_and_collect( @@ -115,11 +96,11 @@ def test_spatial_join_geography(join_type, on): options = json.dumps( { "geom_type": "Polygon", - "polygon_hole_rate": 0.5, - "num_parts_range": [2, 10], - "vertices_per_linestring_range": [2, 10], + "hole_rate": 0.5, + "num_parts": [2, 10], + "num_vertices": [2, 10], "bounds": east_most_bound, - "size_range": [0.1, 5], + "size": [0.1, 5], "seed": 44, } ) @@ -147,28 +128,11 @@ def test_query_window_in_subquery(): SedonaDB.create_or_skip() as eng_sedonadb, PostGIS.create_or_skip() as eng_postgis, ): - options = json.dumps( - { - "geom_type": "Point", - "seed": 42, - } - ) - df_point = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" - ) - options = json.dumps( - { - "geom_type": "Polygon", - "polygon_hole_rate": 0.5, - "num_parts_range": [2, 10], - "vertices_per_linestring_range": [2, 10], - "size_range": [50, 60], - "seed": 43, - } - ) - df_polygon = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" + df_point = random_geometry("Point", 100, seed=100) + df_polygon = random_geometry( + "Polygon", 100, hole_rate=0.5, num_vertices=(2, 10), size=(50, 60), seed=999 ) + eng_sedonadb.create_table_arrow("sjoin_point", df_point) eng_sedonadb.create_table_arrow("sjoin_polygon", df_polygon) eng_postgis.create_table_arrow("sjoin_point", df_point) @@ -195,24 +159,9 @@ def test_non_optimizable_subquery(): SedonaDB.create_or_skip() as eng_sedonadb, PostGIS.create_or_skip() as eng_postgis, ): - options = json.dumps( - { - "geom_type": "Point", - "seed": 42, - } - ) - df_main = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" - ) - options = json.dumps( - { - "geom_type": "Point", - "seed": 43, - } - ) - df_subquery = eng_sedonadb.execute_and_collect( - f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100" - ) + df_main = random_geometry("Point", 100, seed=42) + df_subquery 
= random_geometry("Point", 100, seed=43) + eng_sedonadb.create_table_arrow("sjoin_main", df_main) eng_sedonadb.create_table_arrow("sjoin_subquery", df_subquery) eng_postgis.create_table_arrow("sjoin_main", df_main) diff --git a/rust/sedona-testing/src/benchmark_util.rs b/rust/sedona-testing/src/benchmark_util.rs index 78504b834..4cd3873be 100644 --- a/rust/sedona-testing/src/benchmark_util.rs +++ b/rust/sedona-testing/src/benchmark_util.rs @@ -19,7 +19,7 @@ use std::{fmt::Debug, sync::Arc, vec}; use arrow_array::{ArrayRef, Float64Array, Int64Array}; use arrow_schema::DataType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{exec_datafusion_err, Result, ScalarValue}; use datafusion_expr::{AggregateUDF, ScalarUDF}; use geo_types::Rect; use rand::{distr::Uniform, rngs::StdRng, Rng, SeedableRng}; @@ -402,8 +402,8 @@ impl BenchmarkArgSpec { ), BenchmarkArgSpec::Int64(lo, hi) => { let mut rng = self.rng(i); - let dist = - Uniform::new(lo, hi).map_err(|e| DataFusionError::External(Box::new(e)))?; + let dist = Uniform::new(lo, hi) + .map_err(|e| exec_datafusion_err!("Invalid Int64 range [{lo}, {hi}): {e}"))?; (0..num_batches) .map(|_| -> Result { let int64_array: Int64Array = @@ -414,8 +414,8 @@ impl BenchmarkArgSpec { } BenchmarkArgSpec::Float64(lo, hi) => { let mut rng = self.rng(i); - let dist = - Uniform::new(lo, hi).map_err(|e| DataFusionError::External(Box::new(e)))?; + let dist = Uniform::new(lo, hi) + .map_err(|e| exec_datafusion_err!("Invalid Float64 range [{lo}, {hi}): {e}"))?; (0..num_batches) .map(|_| -> Result { let float64_array: Float64Array = @@ -426,8 +426,8 @@ impl BenchmarkArgSpec { } BenchmarkArgSpec::Int32(lo, hi) => { let mut rng = self.rng(i); - let dist = - Uniform::new(lo, hi).map_err(|e| DataFusionError::External(Box::new(e)))?; + let dist = Uniform::new(lo, hi) + .map_err(|e| exec_datafusion_err!("Invalid Int32 range [{lo}, {hi}): {e}"))?; (0..num_batches) .map(|_| -> Result { let int32_array: arrow_array::Int32Array = diff --git a/rust/sedona-testing/src/datagen.rs b/rust/sedona-testing/src/datagen.rs index 088bf220b..be94f55da 100644 --- a/rust/sedona-testing/src/datagen.rs +++ b/rust/sedona-testing/src/datagen.rs @@ -27,7 +27,7 @@ use arrow_array::{ArrayRef, RecordBatch, RecordBatchReader}; use arrow_array::{BinaryArray, BinaryViewArray}; use arrow_array::{Float64Array, Int32Array}; use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{exec_datafusion_err, plan_err, DataFusionError, Result}; use geo_types::{ Coord, Geometry, GeometryCollection, LineString, MultiLineString, MultiPoint, MultiPolygon, Point, Polygon, Rect, @@ -347,6 +347,32 @@ impl RandomPartitionedDataBuilder { Ok((schema, result)) } + /// Validate options + /// + /// This is called internally before generating batches to prevent panics from + /// occurring while creating random output; however, it may also be called + /// at a higher level to generate an error at a more relevant time. 
+ pub fn validate(&self) -> Result<()> { + self.options.validate()?; + + if self.null_rate < 0.0 || self.null_rate > 1.0 { + return plan_err!( + "Expected null_rate between 0.0 and 1.0 but got {}", + self.null_rate + ); + } + + if self.rows_per_batch == 0 { + return plan_err!("Expected rows_per_batch > 0 but got 0"); + } + + if self.num_partitions == 0 { + return plan_err!("Expected num_partitions > 0 but got 0"); + } + + Ok(()) + } + /// Generate a [Rng] based on a seed /// /// Callers can also supply their own [Rng]. @@ -379,6 +405,9 @@ impl RandomPartitionedDataBuilder { partition_idx: usize, batch_idx: usize, ) -> Result { + // Check for valid ranges to avoid panic in generation + self.validate()?; + // Generate IDs - make them unique across partitions and batches let id_start = (partition_idx * self.batches_per_partition + batch_idx) * self.rows_per_batch; @@ -386,8 +415,13 @@ impl RandomPartitionedDataBuilder { .map(|i| (id_start + i) as i32) .collect(); - // Generate random distances between 0.0 and 100.0 - let distance_dist = Uniform::new(0.0, 100.0).expect("valid input to Uniform::new()"); + // Generate random distances relevant to the bounds (0.0 and 100.0 by default) + let max_dist = self + .options + .bounds + .width() + .min(self.options.bounds.height()); + let distance_dist = Uniform::new(0.0, max_dist).expect("valid input to Uniform::new()"); let distances: Vec = (0..self.rows_per_batch) .map(|_| rng.sample(distance_dist)) .collect(); @@ -395,9 +429,7 @@ impl RandomPartitionedDataBuilder { // Generate random geometries based on the geometry type let wkb_geometries = (0..self.rows_per_batch) .map(|_| -> Result>> { - if rng.sample(Uniform::new(0.0, 1.0).expect("valid input to Uniform::new()")) - < self.null_rate - { + if rng.random_bool(self.null_rate) { Ok(None) } else { Ok(Some(generate_random_wkb(rng, &self.options)?)) @@ -490,6 +522,48 @@ impl RandomGeometryOptions { num_parts_range: (1, 3), } } + + fn validate(&self) -> Result<()> { + if self.bounds.width() <= 0.0 || self.bounds.height() <= 0.0 { + return plan_err!("Expected valid bounds but got {:?}", self.bounds); + } + + if self.size_range.0 <= 0.0 || self.size_range.0 > self.size_range.1 { + return plan_err!("Expected valid size_range but got {:?}", self.size_range); + } + + if self.vertices_per_linestring_range.0 == 0 + || self.vertices_per_linestring_range.0 > self.vertices_per_linestring_range.1 + { + return plan_err!( + "Expected valid vertices_per_linestring_range but got {:?}", + self.vertices_per_linestring_range + ); + } + + if !(0.0..=1.0).contains(&self.empty_rate) { + return plan_err!( + "Expected empty_rate between 0.0 and 1.0 but got {}", + self.empty_rate + ); + } + + if !(0.0..=1.0).contains(&self.polygon_hole_rate) { + return plan_err!( + "Expected polygon_hole_rate between 0.0 and 1.0 but got {}", + self.polygon_hole_rate + ); + } + + if self.num_parts_range.0 == 0 || self.num_parts_range.0 > self.num_parts_range.1 { + return plan_err!( + "Expected valid num_parts_range but got {:?}", + self.num_parts_range + ); + } + + Ok(()) + } } impl Default for RandomGeometryOptions { @@ -560,9 +634,9 @@ fn generate_random_point( } else { // Generate random points within the specified bounds let x_dist = Uniform::new(options.bounds.min().x, options.bounds.max().x) - .map_err(|e| DataFusionError::External(Box::new(e)))?; + .map_err(|e| exec_datafusion_err!("Invalid x bounds for random point: {e}"))?; let y_dist = Uniform::new(options.bounds.min().y, options.bounds.max().y) - .map_err(|e| 
DataFusionError::External(Box::new(e)))?; + .map_err(|e| exec_datafusion_err!("Invalid y bounds for random point: {e}"))?; let x = rng.sample(x_dist); let y = rng.sample(y_dist); Ok(Point::new(x, y)) @@ -581,11 +655,13 @@ fn generate_random_linestring( options.vertices_per_linestring_range.0, options.vertices_per_linestring_range.1, ) - .map_err(|e| DataFusionError::External(Box::new(e)))?; + .map_err(|e| exec_datafusion_err!("Invalid vertex count range for linestring: {e}"))?; // Always sample in such a way that we end up with a valid linestring let num_vertices = rng.sample(vertices_dist).max(2); + // Randomize starting angle (0 to 2 * PI) + let angle = rng.random_range(0.0..(2.0 * PI)); let coords = - generate_circular_vertices(rng, center_x, center_y, half_size, num_vertices, false)?; + generate_circular_vertices(angle, center_x, center_y, half_size, num_vertices, false)?; Ok(LineString::from(coords)) } } @@ -602,22 +678,31 @@ fn generate_random_polygon( options.vertices_per_linestring_range.0, options.vertices_per_linestring_range.1, ) - .map_err(|e| DataFusionError::External(Box::new(e)))?; + .map_err(|e| exec_datafusion_err!("Invalid vertex count range for polygon: {e}"))?; // Always sample in such a way that we end up with a valid Polygon let num_vertices = rng.sample(vertices_dist).max(3); + + // Randomize starting angle (but use the same starting angle for both the shell + // and the hole to ensure a non-intersecting interior) + let angle = rng.random_range(0.0..=(2.0 * PI)); let coords = - generate_circular_vertices(rng, center_x, center_y, half_size, num_vertices, true)?; + generate_circular_vertices(angle, center_x, center_y, half_size, num_vertices, true)?; let shell = LineString::from(coords); let mut holes = Vec::new(); // Potentially add a hole based on probability let add_hole = rng.random_bool(options.polygon_hole_rate); - let hole_scale_factor_dist = Uniform::new(0.1, 0.5).expect("Valid input range"); - let hole_scale_factor = rng.sample(hole_scale_factor_dist); + let hole_scale_factor = rng.random_range(0.1..0.5); if add_hole { let new_size = half_size * hole_scale_factor; - let mut coords = - generate_circular_vertices(rng, center_x, center_y, new_size, num_vertices, true)?; + let mut coords = generate_circular_vertices( + angle, + center_x, + center_y, + new_size, + num_vertices, + true, + )?; coords.reverse(); holes.push(LineString::from(coords)); } @@ -681,7 +766,7 @@ fn generate_random_children ) -> Result> { let num_parts_dist = Uniform::new_inclusive(options.num_parts_range.0, options.num_parts_range.1) - .map_err(|e| DataFusionError::External(Box::new(e)))?; + .map_err(|e| exec_datafusion_err!("Invalid part count range: {e}"))?; let num_parts = rng.sample(num_parts_dist); // Constrain this feature to the size range indicated in the option @@ -733,26 +818,44 @@ fn generate_random_circle( rng: &mut R, options: &RandomGeometryOptions, ) -> Result<(f64, f64, f64)> { - // Generate random diamond polygons (rotated squares) - let size_dist = Uniform::new(options.size_range.0, options.size_range.1) - .map_err(|e| DataFusionError::External(Box::new(e)))?; - let half_size = rng.sample(size_dist) / 2.0; - - // Ensure diamond fits within bounds by constraining center position - let center_x_dist = Uniform::new( - options.bounds.min().x + half_size, - options.bounds.max().x - half_size, - ) - .map_err(|e| DataFusionError::External(Box::new(e)))?; - let center_y_dist = Uniform::new( - options.bounds.min().y + half_size, - options.bounds.max().y - half_size, - ) - 
.map_err(|e| DataFusionError::External(Box::new(e)))?; - let center_x = rng.sample(center_x_dist); - let center_y = rng.sample(center_y_dist); + // Generate random circular polygons + let size_dist = Uniform::new_inclusive(options.size_range.0, options.size_range.1) + .map_err(|e| exec_datafusion_err!("Invalid size range for random region: {e}"))?; + let size = rng.sample(size_dist); + let half_size = size / 2.0; + let height = options.bounds.height(); + let width = options.bounds.width(); + + // Ensure circle fits within bounds by constraining center position + let center_x = if width >= size { + let center_x_dist = Uniform::new( + options.bounds.min().x + half_size, + options.bounds.max().x - half_size, + ) + .map_err(|e| exec_datafusion_err!("Invalid x bounds for random circle center: {e}"))?; + + rng.sample(center_x_dist) + } else { + options.bounds.min().x + width / 2.0 + }; - Ok((center_x, center_y, half_size)) + let center_y = if height >= size { + let center_y_dist = Uniform::new( + options.bounds.min().y + half_size, + options.bounds.max().y - half_size, + ) + .map_err(|e| exec_datafusion_err!("Invalid y bounds for random circle center: {e}"))?; + + rng.sample(center_y_dist) + } else { + options.bounds.min().y + height / 2.0 + }; + + Ok(( + center_x, + center_y, + half_size.min(height / 2.0).min(width / 2.0), + )) } fn generate_non_overlapping_sub_rectangles(num_parts: usize, bounds: &Rect) -> Vec { @@ -784,8 +887,8 @@ fn generate_non_overlapping_sub_rectangles(num_parts: usize, bounds: &Rect) -> V tiles } -fn generate_circular_vertices( - rng: &mut R, +fn generate_circular_vertices( + mut angle: f64, center_x: f64, center_y: f64, radius: f64, @@ -794,11 +897,6 @@ fn generate_circular_vertices( ) -> Result> { let mut out = Vec::new(); - // Randomize starting angle (0 to 2 * PI) - let start_angle_dist = - Uniform::new(0.0, 2.0 * PI).map_err(|e| DataFusionError::External(Box::new(e)))?; - let mut angle: f64 = rng.sample(start_angle_dist); - let dangle = 2.0 * PI / (num_vertices as f64).max(3.0); for _ in 0..num_vertices { out.push(Coord { @@ -1270,4 +1368,183 @@ mod tests { assert!(bounds.y().is_empty()); } } + + #[test] + fn test_random_partitioned_data_builder_validation() { + // Test invalid null_rate (< 0.0) + let err = RandomPartitionedDataBuilder::new() + .null_rate(-0.1) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected null_rate between 0.0 and 1.0 but got -0.1" + ); + + // Test invalid null_rate (> 1.0) + let err = RandomPartitionedDataBuilder::new() + .null_rate(1.5) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected null_rate between 0.0 and 1.0 but got 1.5" + ); + + // Test invalid rows_per_batch (0) + let err = RandomPartitionedDataBuilder::new() + .rows_per_batch(0) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected rows_per_batch > 0 but got 0" + ); + + // Test invalid num_partitions (0) + let err = RandomPartitionedDataBuilder::new() + .num_partitions(0) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected num_partitions > 0 but got 0" + ); + + // Test invalid empty_rate (< 0.0) + let err = RandomPartitionedDataBuilder::new() + .empty_rate(-0.1) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected empty_rate between 0.0 and 1.0 but got -0.1" + ); + + // Test invalid empty_rate (> 1.0) + let err = RandomPartitionedDataBuilder::new() + 
.empty_rate(1.5) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected empty_rate between 0.0 and 1.0 but got 1.5" + ); + + // Test invalid polygon_hole_rate (< 0.0) + let err = RandomPartitionedDataBuilder::new() + .polygon_hole_rate(-0.1) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected polygon_hole_rate between 0.0 and 1.0 but got -0.1" + ); + + // Test invalid polygon_hole_rate (> 1.0) + let err = RandomPartitionedDataBuilder::new() + .polygon_hole_rate(1.5) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected polygon_hole_rate between 0.0 and 1.0 but got 1.5" + ); + + // Test invalid size_range (min <= 0) + let err = RandomPartitionedDataBuilder::new() + .size_range((0.0, 10.0)) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected valid size_range but got (0.0, 10.0)" + ); + + // Test invalid size_range (max <= 0) + let err = RandomPartitionedDataBuilder::new() + .size_range((5.0, -1.0)) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected valid size_range but got (5.0, -1.0)" + ); + + // Test invalid size_range (min > max) + let err = RandomPartitionedDataBuilder::new() + .size_range((10.0, 5.0)) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected valid size_range but got (10.0, 5.0)" + ); + + // Test invalid vertices_per_linestring_range (min == 0) + let err = RandomPartitionedDataBuilder::new() + .vertices_per_linestring_range((0, 5)) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected valid vertices_per_linestring_range but got (0, 5)" + ); + + // Test invalid vertices_per_linestring_range (min > max) + let err = RandomPartitionedDataBuilder::new() + .vertices_per_linestring_range((10, 5)) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected valid vertices_per_linestring_range but got (10, 5)" + ); + + // Test invalid num_parts_range (min == 0) + let err = RandomPartitionedDataBuilder::new() + .num_parts_range((0, 5)) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected valid num_parts_range but got (0, 5)" + ); + + // Test invalid num_parts_range (min > max) + let err = RandomPartitionedDataBuilder::new() + .num_parts_range((10, 5)) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected valid num_parts_range but got (10, 5)" + ); + + // Test invalid bounds (zero width) + let err = RandomPartitionedDataBuilder::new() + .bounds(Rect::new( + Coord { x: 10.0, y: 10.0 }, + Coord { x: 10.0, y: 20.0 }, + )) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected valid bounds but got RECT(10.0 10.0,10.0 20.0)" + ); + + // Test invalid bounds (zero height) + let err = RandomPartitionedDataBuilder::new() + .bounds(Rect::new( + Coord { x: 10.0, y: 10.0 }, + Coord { x: 20.0, y: 10.0 }, + )) + .validate() + .unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: Expected valid bounds but got RECT(10.0 10.0,20.0 10.0)" + ); + } } diff --git a/rust/sedona/src/random_geometry_provider.rs b/rust/sedona/src/random_geometry_provider.rs index 82d1edba9..dc6285d13 100644 --- a/rust/sedona/src/random_geometry_provider.rs +++ b/rust/sedona/src/random_geometry_provider.rs @@ -39,6 +39,8 @@ 
use sedona_geometry::types::GeometryTypeId; use sedona_testing::datagen::RandomPartitionedDataBuilder; use serde::{Deserialize, Serialize}; +use crate::record_batch_reader_provider::RowLimitedIterator; + /// A table function that refers to a table of random geometries /// /// This table function accepts one argument, which is a JSON-specified @@ -84,7 +86,7 @@ pub struct RandomGeometryProvider { builder: RandomPartitionedDataBuilder, num_partitions: usize, rows_per_batch: usize, - target_rows: usize, + num_rows: usize, } impl RandomGeometryProvider { @@ -94,7 +96,7 @@ impl RandomGeometryProvider { /// Parameters that affect the number of partitions or rows have different defaults that /// always override that of the builder (unless manually specified in the option). This /// reflects the SQL use case where often the parameter that needs tweaking is the number - /// of total rows, which in this case can be set with `{"target_rows": 2048}`. The number + /// of total rows, which in this case can be set with `{"num_rows": 2048}`. The number /// of total rows will always be a multiple of the batch size times the number of partitions, /// whose defaults to 1024 and 1, respectively. /// @@ -104,7 +106,7 @@ impl RandomGeometryProvider { match serde_json::from_str::<RandomGeometryFunctionOptions>(&options_str) { Ok(options) => Some(options), Err(e) => { - return plan_err!("Failed to parse options: {e}\nOption were: {options_str}") + return plan_err!("Failed to parse options: {e}\nOptions were: {options_str}") } } } else { @@ -113,7 +115,7 @@ let mut num_partitions = 1; let mut rows_per_batch = 1024; - let mut target_rows = 1024; + let mut num_rows = 1024; if let Some(options) = options { if let Some(opt_num_partitions) = options.num_partitions { @@ -124,11 +126,22 @@ rows_per_batch = opt_rows_per_batch; } - if let Some(opt_target_rows) = options.target_rows { - target_rows = opt_target_rows; + if let Some(opt_num_rows) = options.num_rows { + num_rows = opt_num_rows; } + + // Unlike the Rust version, where we almost always want a set seed by default, + // in SQL, Python, and R we want this to behave like random() and be non-deterministic. 
if let Some(seed) = options.seed { builder = builder.seed(seed); + } else { + builder = builder.seed( + (std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_millis() + % u32::MAX as u128) as u64, + ); } if let Some(null_rate) = options.null_rate { builder = builder.null_rate(null_rate); @@ -140,28 +153,31 @@ impl RandomGeometryProvider { let bounds = Rect::new((bounds.0, bounds.1), (bounds.2, bounds.3)); builder = builder.bounds(bounds); } - if let Some(size_range) = options.size_range { + if let Some(size_range) = options.size { builder = builder.size_range(size_range); } - if let Some(vertices_range) = options.vertices_per_linestring_range { + if let Some(vertices_range) = options.num_vertices { builder = builder.vertices_per_linestring_range(vertices_range); } if let Some(empty_rate) = options.empty_rate { builder = builder.empty_rate(empty_rate); } - if let Some(hole_rate) = options.polygon_hole_rate { + if let Some(hole_rate) = options.hole_rate { builder = builder.polygon_hole_rate(hole_rate); } - if let Some(parts_range) = options.num_parts_range { + if let Some(parts_range) = options.num_parts { builder = builder.num_parts_range(parts_range); } } + // Check options early to provide an error at a more relevant place + builder.validate()?; + Ok(RandomGeometryProvider { builder, num_partitions, rows_per_batch, - target_rows, + num_rows, }) } } @@ -187,14 +203,14 @@ impl TableProvider for RandomGeometryProvider { _filters: &[Expr], _limit: Option, ) -> Result> { - let builder = builder_with_partition_sizes( + let (builder, last_partition_rows) = builder_with_partition_sizes( self.builder.clone(), self.rows_per_batch, self.num_partitions, - self.target_rows, + self.num_rows, ); - let exec = Arc::new(RandomGeometryExec::new(builder)); + let exec = Arc::new(RandomGeometryExec::new(builder, last_partition_rows)); // We're required to handle the projection or we'll get an execution error if let Some(projection) = projection { @@ -216,11 +232,12 @@ impl TableProvider for RandomGeometryProvider { #[derive(Debug)] struct RandomGeometryExec { builder: RandomPartitionedDataBuilder, + last_partition_rows: usize, properties: PlanProperties, } impl RandomGeometryExec { - pub fn new(builder: RandomPartitionedDataBuilder) -> Self { + pub fn new(builder: RandomPartitionedDataBuilder, last_partition_rows: usize) -> Self { let properties = PlanProperties::new( EquivalenceProperties::new(builder.schema().clone()), Partitioning::UnknownPartitioning(builder.num_partitions), @@ -230,6 +247,7 @@ impl RandomGeometryExec { Self { builder, + last_partition_rows, properties, } } @@ -237,7 +255,11 @@ impl RandomGeometryExec { impl DisplayAs for RandomGeometryExec { fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "RecordBatchReaderExec") + write!( + f, + "RandomGeometryExec: builder={:?}, last_partition_rows={}", + self.builder, self.last_partition_rows + ) } } @@ -274,56 +296,108 @@ impl ExecutionPlan for RandomGeometryExec { partition: usize, _context: Arc, ) -> Result { + // Return empty stream for out-of-range partitions (can happen with joins) + if partition >= self.builder.num_partitions { + let stream = Box::pin(futures::stream::empty()); + let record_batch_stream = RecordBatchStreamAdapter::new(self.schema(), stream); + return Ok(Box::pin(record_batch_stream)); + } + let rng = RandomPartitionedDataBuilder::default_rng(self.builder.seed + partition as u64); let reader = self.builder.partition_reader(rng, partition); - 
let iter = reader.map(|item| match item { - Ok(batch) => Ok(batch), - Err(e) => Err(DataFusionError::from(e)), - }); - - let stream = Box::pin(futures::stream::iter(iter)); - let record_batch_stream = RecordBatchStreamAdapter::new(self.schema(), stream); - Ok(Box::pin(record_batch_stream)) + + // If this is the last partition, limit the number of rows from the reader + if partition == (self.builder.num_partitions - 1) { + let iter = Box::new(RowLimitedIterator::new(reader, self.last_partition_rows)); + + let stream = Box::pin(futures::stream::iter(iter)); + let record_batch_stream = RecordBatchStreamAdapter::new(self.schema(), stream); + Ok(Box::pin(record_batch_stream)) + } else { + let iter = reader.map(|item| match item { + Ok(batch) => Ok(batch), + Err(e) => Err(DataFusionError::from(e)), + }); + + let stream = Box::pin(futures::stream::iter(iter)); + let record_batch_stream = RecordBatchStreamAdapter::new(self.schema(), stream); + Ok(Box::pin(record_batch_stream)) + } } } /// These options only exist as a mechanism to deserialize JSON options /// /// See the [RandomPartitionedDataBuilder] for definitive documentation of these -/// values. +/// values. Compared to the lower-level class, these have slightly more compact +/// argument names (whereas the lower level class has more descriptive names +/// for Rust where autocomplete is available to help). #[derive(Serialize, Deserialize, Default)] struct RandomGeometryFunctionOptions { num_partitions: Option, rows_per_batch: Option, - target_rows: Option, + num_rows: Option, seed: Option, null_rate: Option, geom_type: Option, bounds: Option<(f64, f64, f64, f64)>, - size_range: Option<(f64, f64)>, - vertices_per_linestring_range: Option<(usize, usize)>, + #[serde(default, deserialize_with = "deserialize_scalar_or_range")] + size: Option<(f64, f64)>, + #[serde(default, deserialize_with = "deserialize_scalar_or_range")] + num_vertices: Option<(usize, usize)>, empty_rate: Option, - polygon_hole_rate: Option, - num_parts_range: Option<(usize, usize)>, + hole_rate: Option, + #[serde(default, deserialize_with = "deserialize_scalar_or_range")] + num_parts: Option<(usize, usize)>, } fn builder_with_partition_sizes( builder: RandomPartitionedDataBuilder, batch_size: usize, partitions: usize, - target_rows: usize, -) -> RandomPartitionedDataBuilder { + num_rows: usize, +) -> (RandomPartitionedDataBuilder, usize) { let rows_for_one_batch_per_partition = batch_size * partitions; - let batches_per_partition = if target_rows.is_multiple_of(rows_for_one_batch_per_partition) { - target_rows / rows_for_one_batch_per_partition + let batches_per_partition = if num_rows.is_multiple_of(rows_for_one_batch_per_partition) { + num_rows / rows_for_one_batch_per_partition } else { - target_rows / rows_for_one_batch_per_partition + 1 + num_rows / rows_for_one_batch_per_partition + 1 }; - builder + let builder_out = builder .rows_per_batch(batch_size) .num_partitions(partitions) - .batches_per_partition(batches_per_partition) + .batches_per_partition(batches_per_partition); + let normal_partition_rows = batches_per_partition * batch_size; + let remainder = (normal_partition_rows * partitions) - num_rows; + let last_partition_rows = if remainder == 0 { + normal_partition_rows + } else { + normal_partition_rows - remainder + }; + (builder_out, last_partition_rows) +} + +/// Helper to make specifying scalar ranges more concise when only one value is needed +fn deserialize_scalar_or_range<'de, D, T>( + deserializer: D, +) -> std::result::Result, D::Error> +where + D: 
serde::Deserializer<'de>, + T: serde::Deserialize<'de> + Copy, +{ + #[derive(Deserialize)] + #[serde(untagged)] + enum ScalarOrRange { + Scalar(T), + Range((T, T)), + } + + match Option::>::deserialize(deserializer)? { + None => Ok(None), + Some(ScalarOrRange::Scalar(val)) => Ok(Some((val, val))), + Some(ScalarOrRange::Range(range)) => Ok(Some(range)), + } } #[cfg(test)] @@ -341,6 +415,7 @@ mod test { let builder = RandomPartitionedDataBuilder::new() .num_partitions(4) .batches_per_partition(2) + .seed(3840) .rows_per_batch(1024); let (expected_schema, expected_results) = builder.build().unwrap(); assert_eq!(expected_results.len(), 4); @@ -351,7 +426,8 @@ mod test { let provider = RandomGeometryProvider::try_new( builder, Some( - r#"{"target_rows": 8192, "num_partitions": 4, "rows_per_batch": 1024}"#.to_string(), + r#"{"num_rows": 8192, "num_partitions": 4, "seed": 3840, "rows_per_batch": 1024}"# + .to_string(), ), ) .unwrap(); @@ -367,7 +443,7 @@ mod test { ctx.register_udtf("sd_random_geometry", Arc::new(RandomGeometryFunction {})); let df = ctx .sql(r#" - SELECT * FROM sd_random_geometry('{"target_rows": 8192, "num_partitions": 4, "rows_per_batch": 1024}') + SELECT * FROM sd_random_geometry('{"num_rows": 8192, "num_partitions": 4, "seed": 3840, "rows_per_batch": 1024}') "#) .await .unwrap(); @@ -385,24 +461,44 @@ mod test { // an exact number of rows let provider = RandomGeometryProvider::try_new( RandomPartitionedDataBuilder::new(), - Some( - r#"{"target_rows": 8192, "num_partitions": 2, "rows_per_batch": 1024}"#.to_string(), - ), + Some(r#"{"num_rows": 8192, "num_partitions": 2, "rows_per_batch": 1024}"#.to_string()), ) .unwrap(); let df = ctx.read_table(Arc::new(provider)).unwrap(); assert_eq!(df.count().await.unwrap(), 8192); - // If the batch size * num_partitions doesn't fit evenly, we should have more rows - // than target_rows + // If the batch size * num_partitions doesn't fit evenly, we should still get the + // exact number of target rows let provider = RandomGeometryProvider::try_new( RandomPartitionedDataBuilder::new(), - Some( - r#"{"target_rows": 9000, "num_partitions": 2, "rows_per_batch": 1024}"#.to_string(), - ), + Some(r#"{"num_rows": 9000, "num_partitions": 2, "rows_per_batch": 1024}"#.to_string()), + ) + .unwrap(); + let df = ctx.read_table(Arc::new(provider)).unwrap(); + assert_eq!(df.count().await.unwrap(), 9000); + } + + #[tokio::test] + async fn provider_with_scalar_size() { + let ctx = SessionContext::new(); + + // Test that a scalar value for size works (gets converted to (value, value)) + let provider = RandomGeometryProvider::try_new( + RandomPartitionedDataBuilder::new(), + Some(r#"{"num_rows": 1024, "size": 0.5}"#.to_string()), + ) + .unwrap(); + + let df = ctx.read_table(Arc::new(provider)).unwrap(); + assert_eq!(df.count().await.unwrap(), 1024); + + // Test that a range value for size still works + let provider = RandomGeometryProvider::try_new( + RandomPartitionedDataBuilder::new(), + Some(r#"{"num_rows": 1024, "size": [0.1, 0.5]}"#.to_string()), ) .unwrap(); let df = ctx.read_table(Arc::new(provider)).unwrap(); - assert_eq!(df.count().await.unwrap(), 8192 + (2 * 1024)); + assert_eq!(df.count().await.unwrap(), 1024); } } diff --git a/rust/sedona/src/record_batch_reader_provider.rs b/rust/sedona/src/record_batch_reader_provider.rs index e197f89d3..9040a9af0 100644 --- a/rust/sedona/src/record_batch_reader_provider.rs +++ b/rust/sedona/src/record_batch_reader_provider.rs @@ -101,14 +101,14 @@ impl TableProvider for RecordBatchReaderProvider { } /// An 
iterator that limits the number of rows from a RecordBatchReader -struct RowLimitedIterator { +pub struct RowLimitedIterator { reader: Option>, limit: usize, rows_consumed: usize, } impl RowLimitedIterator { - fn new(reader: Box, limit: usize) -> Self { + pub fn new(reader: Box, limit: usize) -> Self { Self { reader: Some(reader), limit, @@ -245,9 +245,16 @@ impl ExecutionPlan for RecordBatchReaderExec { fn execute( &self, - _partition: usize, + partition: usize, _context: Arc, ) -> Result { + // Return empty stream for out-of-range partitions (can happen with joins) + if partition > 0 { + let stream = Box::pin(futures::stream::empty()); + let record_batch_stream = RecordBatchStreamAdapter::new(self.schema(), stream); + return Ok(Box::pin(record_batch_stream)); + } + let mut reader_guard = self.reader.lock(); let reader = if let Some(reader) = reader_guard.take() {