diff --git a/pyproject.toml b/pyproject.toml index bd3a40f..36bfc22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ ] requires-python = ">=3.10" dependencies = [ + "interpn>=0.8.2", "matplotlib>=3.5.0", "numpy>=1.22, <3", "rlic>=0.2.1", diff --git a/src/lick/_api.py b/src/lick/_api.py index ba9901c..0f76a70 100644 --- a/src/lick/_api.py +++ b/src/lick/_api.py @@ -1,6 +1,7 @@ __all__ = [ "get_grid_or_mesh", "get_indexing", + "get_interpolator", "get_kernel", "get_layering", "get_niter_lic", @@ -23,7 +24,13 @@ MixMulDict, NorthWestLightSource, ) -from lick._interpolation import Grid, Mesh +from lick._interpolation import ( + Grid, + Interpolator, + Mesh, + RegularGridInterpolator, + UnstruscturedGridInterpolator, +) from lick._typing import D, F, FArray, FArray1D, FArray2D if sys.version_info >= (3, 11): @@ -254,6 +261,34 @@ def get_mesh( assert_never(unreachable) +def get_interpolator( + input_grid_or_mesh: Grid | Mesh, + *, + target_mesh: Mesh, + indexing: Literal["xy", "ij"] | UnsetType, +) -> Interpolator: + match input_grid_or_mesh: + case Mesh() as m: + return UnstruscturedGridInterpolator( + input_mesh=m, + target_mesh=target_mesh, + ) + case Grid() as g if indexing == "xy" and g.is_mono_increasing(): + return RegularGridInterpolator( + input_grid=g, + target_mesh=target_mesh, + ) + case Grid() as g: + indexing = get_indexing(indexing) + return get_interpolator( + Mesh.from_grid(g, indexing=indexing), + target_mesh=target_mesh, + indexing=UNSET, + ) + case _ as unreachable: + assert_never(unreachable) + + def get_layering( layering: AlphaDict | MixMulDict | UnsetType, *, diff --git a/src/lick/_interpolation.py b/src/lick/_interpolation.py index 364849e..da2b640 100644 --- a/src/lick/_interpolation.py +++ b/src/lick/_interpolation.py @@ -6,9 +6,9 @@ "Method", ] -from dataclasses import dataclass +from dataclasses import KW_ONLY, dataclass from math import isfinite -from typing import Generic, Literal, TypeAlias, final +from typing import Generic, Literal, Protocol, TypeAlias, final import numpy as np @@ -97,6 +97,9 @@ def from_intervals( def dtype(self) -> np.dtype[F]: return self.x.dtype + def is_mono_increasing(self) -> bool: + return bool(np.all(np.diff(self.x) > 0.0) and np.all(np.diff(self.y) > 0.0)) + @final @dataclass(kw_only=True, slots=True, frozen=True) @@ -136,16 +139,17 @@ def astype(self, dtype: F, /) -> "Mesh[F]": @final -@dataclass(kw_only=True, slots=True, frozen=True) -class Interpolator(Generic[F]): +@dataclass(slots=True, frozen=True) +class UnstruscturedGridInterpolator(Generic[F]): input_mesh: Mesh[F] + _: KW_ONLY target_mesh: Mesh[F] def __post_init__(self): if self.target_mesh.dtype == self.input_mesh.dtype: return raise TypeError( - "input and target meshes must use the same data type. " + "input and target must use the same data type. " f"Got input_mesh.dtype={self.input_mesh.dtype!s}, target_mesh.dtype={self.target_mesh.dtype!s}" ) @@ -158,7 +162,7 @@ def __call__( ) -> FArray2D[F]: if vals.dtype != self.input_mesh.dtype or vals.shape != self.input_mesh.shape: raise TypeError( - f"Expected values to match the input mesh's data type ({self.input_mesh.dtype}) " + f"Expected values to match the input data type ({self.input_mesh.dtype}) " f"and shape {self.input_mesh.shape}. " f"Received values with dtype={vals.dtype!s}, shape={vals.shape}" ) @@ -176,3 +180,51 @@ def __call__( ), method=method, ).astype(vals.dtype) + + +class Interpolator(Protocol, Generic[F]): + def __call__(self, vals: FArray2D[F], /, *, method: Method) -> FArray2D[F]: ... + + +@final +@dataclass(slots=True, frozen=True) +class RegularGridInterpolator(Generic[F]): + input_grid: Grid[F] + _: KW_ONLY + target_mesh: Mesh[F] + + def __post_init__(self): + if self.target_mesh.dtype == self.input_grid.dtype: + return + raise TypeError( + "input and target must use the same data type. " + f"Got input_grid.dtype={self.input_grid.dtype!s}, target_mesh.dtype={self.target_mesh.dtype!s}" + ) + + def __call__( + self, + vals: FArray2D[F], + /, + *, + method: Method, + ) -> FArray2D[F]: + if ( + vals.shape + != (input_shape := (self.input_grid.y.size, self.input_grid.x.size)) + or vals.dtype != self.input_grid.dtype + ): + raise TypeError( + f"Expected values to match the input data type ({self.input_grid.dtype}) " + f"and shape {input_shape}. " + f"Received values with dtype={vals.dtype!s}, shape={vals.shape}" + ) + from interpn import interpn + + # TODO: disable extrapolation for backward compat + # upstream patch required + return interpn( + grids=(self.input_grid.y, self.input_grid.x), + obs=(self.target_mesh.y, self.target_mesh.x), + vals=vals, + method=method, + ) diff --git a/src/lick/_publib.py b/src/lick/_publib.py index e01f2bf..02f9f1a 100644 --- a/src/lick/_publib.py +++ b/src/lick/_publib.py @@ -19,7 +19,7 @@ MixMulDict, Normalizer, ) -from lick._interpolation import Grid, Interpolator, Interval, Mesh, Method +from lick._interpolation import Grid, Interval, Mesh, Method from lick._typing import D, F, FArray, FArray1D, FArray2D if sys.version_info >= (3, 11): @@ -58,7 +58,7 @@ def interpol( ) -> InterpolationResults[F]: if len(all_dtypes := {_.dtype for _ in (x, y, v1, v2, field)}) > 1: raise TypeError(f"Received inputs with mixed datatypes ({all_dtypes})") - input_mesh = _api.get_mesh(x, y, indexing=indexing) + input_grid_or_mesh = _api.get_grid_or_mesh(x, y) # type: ignore[arg-type] target_grid = Grid.from_intervals( x=Interval( @@ -73,9 +73,9 @@ def interpol( dtype=cast(F, x.dtype), ) - interpolate = Interpolator( - input_mesh=input_mesh, - target_mesh=Mesh.from_grid(target_grid, indexing="xy"), + target_mesh = Mesh.from_grid(target_grid, indexing="xy") + interpolate = _api.get_interpolator( + input_grid_or_mesh, target_mesh=target_mesh, indexing=indexing ) return InterpolationResults( diff --git a/tests/test_api.py b/tests/test_api.py index f88c58f..2a4c753 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -212,6 +212,22 @@ def test_get_grid_or_mesh_mesh(): assert type(res) is Mesh +def test_get_mesh_from_grid(): + _api.get_mesh( + np.linspace(-1.0, 1.0, 8), + np.linspace(-1.0, 1.0, 8), + indexing="ij", + ) + + +def test_get_mesh_no_indexing(): + _api.get_mesh( + np.atleast_2d(np.linspace(-1.0, 1.0, 8)), + np.atleast_2d(np.linspace(-1.0, 1.0, 8)), + indexing=_api.UNSET, + ) + + def test_get_mesh_ignored_indexing(): with pytest.warns( UserWarning, diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 2d9de43..ee7786e 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -1,12 +1,18 @@ import re -from itertools import permutations +from itertools import permutations, product import numpy as np import numpy.testing as npt import pytest from lick import interpol -from lick._interpolation import Grid, Interpolator, Interval, Mesh +from lick._interpolation import ( + Grid, + Interval, + Mesh, + RegularGridInterpolator, + UnstruscturedGridInterpolator, +) f64 = np.float64 @@ -223,52 +229,63 @@ def test_mesh_from_grid(dtype, indexing): @pytest.mark.parametrize("dtype", ["float32", "float64"]) @pytest.mark.parametrize("indexing", ["xy", "ij"]) -def test_interpolator_dunder_call(subtests, dtype, indexing): +def test_unstructured_grid_interpolator_dunder_call(subtests, dtype, indexing): x = np.geomspace(1, 2, 5, dtype=dtype) y = np.linspace(3, 4, 7, dtype=dtype) grid = Grid(x=x, y=y) mesh = Mesh.from_grid(grid, indexing=indexing) - interpolator = Interpolator(input_mesh=mesh, target_mesh=mesh) - for method in ["nearest", "linear", "cubic"]: - with subtests.test(method=method): + for interpolator, method in product( + [ + RegularGridInterpolator(input_grid=grid, target_mesh=mesh), + UnstruscturedGridInterpolator(input_mesh=mesh, target_mesh=mesh), + ], + ["nearest", "linear", "cubic"], + ): + if type(interpolator) is RegularGridInterpolator and indexing == "ij": + continue + with subtests.test(type=type(interpolator).__name__, method=method): res = interpolator(mesh.x, method=method) - npt.assert_array_almost_equal_nulp(res, mesh.x) + npt.assert_array_almost_equal(res, mesh.x) @pytest.mark.parametrize("dt1, dt2", permutations(["float32", "float64"])) -def test_interpolator_dunder_call_mixed_dtype(subtests, dt1, dt2): +def test_unstructured_grid_interpolator_dunder_call_mixed_dtype(subtests, dt1, dt2): x = np.geomspace(1, 2, 5, dtype=dt1) y = np.linspace(3, 4, 7, dtype=dt1) grid = Grid(x=x, y=y) mesh = Mesh.from_grid(grid, indexing="ij") - shape = mesh.shape - - interpolator = Interpolator(input_mesh=mesh, target_mesh=mesh) - with ( - subtests.test(vals_dtype=dt2), - pytest.raises( - TypeError, - match=re.escape( - f"Expected values to match the input mesh's data type ({mesh.dtype}) " - f"and shape {mesh.shape}. " - f"Received values with dtype={dt2!s}, shape={shape}" + + grid_interpolator = RegularGridInterpolator(input_grid=grid, target_mesh=mesh) + mesh_interpolator = UnstruscturedGridInterpolator(input_mesh=mesh, target_mesh=mesh) + for interpolator, obj, obj_shape, obj_name in [ + (grid_interpolator, grid, (grid.y.size, grid.x.size), "grid"), + (mesh_interpolator, mesh, mesh.shape, "mesh"), + ]: + with ( + subtests.test(obj_name=obj_name, vals_dtype=dt2), + pytest.raises( + TypeError, + match=re.escape( + f"Expected values to match the input data type ({obj.dtype}) " + f"and shape {obj_shape}. " + f"Received values with dtype={dt2!s}, shape={obj.x.shape}" + ), ), - ), - ): - interpolator(mesh.x.astype(dt2), method="nearest") - - with ( - subtests.test(vals_dtype=dt1), - pytest.raises( - TypeError, - match=( - r"input and target meshes must use the same data type\. " - rf"Got input_mesh.dtype={dt1!s}, target_mesh\.dtype={dt2!s}" + ): + interpolator(obj.x.astype(dt2), method="nearest") + + with ( + subtests.test(input="mesh", vals_dtype=dt1), + pytest.raises( + TypeError, + match=( + r"input and target must use the same data type\. " + rf"Got input_{obj_name}.dtype={dt1!s}, target_mesh\.dtype={dt2!s}" + ), ), - ), - ): - Interpolator(input_mesh=mesh, target_mesh=mesh.astype(dt2)) + ): + type(interpolator)(obj.x, target_mesh=mesh.astype(dt2)) @pytest.mark.parametrize("dtype", ["float32", "float64"]) diff --git a/uv.lock b/uv.lock index 9068cfa..0ad7216 100644 --- a/uv.lock +++ b/uv.lock @@ -369,6 +369,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "interpn" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/b9/ce865437f4a1b52349c3c3f58f6423cab3f36462c42b1e6ac1f1a1720777/interpn-0.8.2.tar.gz", hash = "sha256:297c07c20423144bbd498881e4bad08b49375b23e1712c43ea96c2fe5353de1c", size = 1415340, upload-time = "2025-11-13T02:04:16.139Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/03/7038c8a3769c433b049d42102c29443b324a0787f4d1a706c76e582bb1f2/interpn-0.8.2-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d06b83b67befbc2d992094e68b979819e2c622b7dbafae066a358f43ed789612", size = 561171, upload-time = "2025-11-13T02:04:07.52Z" }, + { url = "https://files.pythonhosted.org/packages/45/47/6babee99519fb691700418157c3052cc13d3eec266a7c26c8d4a2e756cd7/interpn-0.8.2-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:1e67ad82aa4b8d794e0676a36ed6c5442eaf7a3d0a253d2ad1f3dba2f0c1ecd4", size = 495627, upload-time = "2025-11-13T02:04:05.777Z" }, + { url = "https://files.pythonhosted.org/packages/bc/4d/ee025c6e71cbf0c7f07dc33ed5283c04722f869d897f31c44751cb8137e2/interpn-0.8.2-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28fe42cabe73e06bc707f09f1d73b02c1e06450d7dd181244cba8dcf10824f64", size = 509061, upload-time = "2025-11-13T02:03:54.546Z" }, + { url = "https://files.pythonhosted.org/packages/87/19/83d1b4b469322d62ea13a5e1ece74cc79faafdac6f770e90546606ada4ae/interpn-0.8.2-cp310-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1347bf47908d2ddf93349d458a1975583fd320daae45e62eda602a8d5b15c5af", size = 394100, upload-time = "2025-11-13T02:03:57.91Z" }, + { url = "https://files.pythonhosted.org/packages/0c/32/2e648fa80b8082e5a608be8cdd1155b388d120dbd8b4ff33ee68e5b542bb/interpn-0.8.2-cp310-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6b95430b2137f643faba7f42e57a3fc3b94c80f00df611728618da122799601b", size = 398544, upload-time = "2025-11-13T02:03:59.553Z" }, + { url = "https://files.pythonhosted.org/packages/e8/2e/f78de7e14f6ad37cb3463ebd55b064758dc6a0616e6fd7e46f2417e52d2b/interpn-0.8.2-cp310-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ed558a2a4b5d21eb7a4f01fdee42c6244207e6744080dbf44733f34fff232ee9", size = 487355, upload-time = "2025-11-13T02:04:01.015Z" }, + { url = "https://files.pythonhosted.org/packages/bc/39/005eb2c09e19d90887217610c7f53e20f171db5e4f875b78d1356de524d4/interpn-0.8.2-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0f19fc58b2d40500101145cd420671d16479524759e27387ecc75b9a0b70bcd", size = 563887, upload-time = "2025-11-13T02:04:02.447Z" }, + { url = "https://files.pythonhosted.org/packages/f2/0f/79755bb13c4813caacb6622737e3933b46722eaec47a487c0387c939ef5b/interpn-0.8.2-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b22d3206435f0c008e7dea8da5800e9f4b666bf2eacf21b094e3833c7059e333", size = 538132, upload-time = "2025-11-13T02:04:09.142Z" }, + { url = "https://files.pythonhosted.org/packages/92/c5/044c4cb7354c350826415e82173ac38ba64241272e1cfca3c22019239a3b/interpn-0.8.2-cp310-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:e967d886b3ab7220f096fd0ca7e3e0feebf67e8efd9c621dfbe666c44c3bcfcd", size = 663225, upload-time = "2025-11-13T02:04:10.748Z" }, + { url = "https://files.pythonhosted.org/packages/36/67/2fe23128a89d45d438d3eee1f89cd6f091e24e53767372f0602cd83e002b/interpn-0.8.2-cp310-abi3-musllinux_1_2_i686.whl", hash = "sha256:1be431ada1856f741baeaa6bb3848348980db97cbbf091eec1554441f4313b18", size = 588138, upload-time = "2025-11-13T02:04:12.414Z" }, + { url = "https://files.pythonhosted.org/packages/fc/bc/f29e1613ba9d66e2129b04880c80b22bf90437f15e2ee5ba83a93f67ba7f/interpn-0.8.2-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2d92aa971b0733c5dda0956788c31ca6294959b8c43d16c7f76643ecefb3d75b", size = 554986, upload-time = "2025-11-13T02:04:14.621Z" }, + { url = "https://files.pythonhosted.org/packages/db/13/18a9e0a100e2931e57453381b9f77c18632ed4bcc44e0cb3f945eee407c6/interpn-0.8.2-cp310-abi3-win32.whl", hash = "sha256:0981f951aa8bfd86a4180af648f2ba50de4bde1954e681ceb163eef95369d485", size = 304919, upload-time = "2025-11-13T02:04:19.279Z" }, + { url = "https://files.pythonhosted.org/packages/3f/e0/270c6eadb91b47128da19df2d5e480bb9df2af76b884100f0e3405098a91/interpn-0.8.2-cp310-abi3-win_amd64.whl", hash = "sha256:c234363466403bd58712c84ba618e33fdd0413a2cd233ab049057247b9bf7045", size = 506661, upload-time = "2025-11-13T02:04:17.49Z" }, + { url = "https://files.pythonhosted.org/packages/f6/03/c799ade7354bd731c27511b80038fddd036a329e1cd87bdd7d187e902abd/interpn-0.8.2-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:96ef5e8520711dbb3967b3df5b7d963176cfd8d5b906d146ebb60e95e7c33a57", size = 487660, upload-time = "2025-11-13T02:03:56.605Z" }, + { url = "https://files.pythonhosted.org/packages/f2/e0/e132ef2630bd691b1cce2fea982776ee5e53d7ab6d5ad405484ab171d567/interpn-0.8.2-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d07de3386cb94ec46c18278fb61e3dd7f192b0175835a681051545fc51f5ea2", size = 538163, upload-time = "2025-11-13T02:04:04.013Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -494,6 +521,7 @@ name = "lick" version = "0.9.0" source = { editable = "." } dependencies = [ + { name = "interpn" }, { name = "matplotlib" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, @@ -518,6 +546,7 @@ typecheck = [ [package.metadata] requires-dist = [ + { name = "interpn", specifier = ">=0.8.2" }, { name = "matplotlib", specifier = ">=3.5.0" }, { name = "numpy", specifier = ">=1.22,<3" }, { name = "rlic", specifier = ">=0.2.1" },