Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ classifiers = [
]
requires-python = ">=3.10"
dependencies = [
"interpn>=0.8.2",
"matplotlib>=3.5.0",
"numpy>=1.22, <3",
"rlic>=0.2.1",
Expand Down
37 changes: 36 additions & 1 deletion src/lick/_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = [
"get_grid_or_mesh",
"get_indexing",
"get_interpolator",
"get_kernel",
"get_layering",
"get_niter_lic",
Expand All @@ -23,7 +24,13 @@
MixMulDict,
NorthWestLightSource,
)
from lick._interpolation import Grid, Mesh
from lick._interpolation import (
Grid,
Interpolator,
Mesh,
RegularGridInterpolator,
UnstruscturedGridInterpolator,
)
from lick._typing import D, F, FArray, FArray1D, FArray2D

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -254,6 +261,34 @@ def get_mesh(
assert_never(unreachable)


def get_interpolator(
input_grid_or_mesh: Grid | Mesh,
*,
target_mesh: Mesh,
indexing: Literal["xy", "ij"] | UnsetType,
) -> Interpolator:
match input_grid_or_mesh:
case Mesh() as m:
return UnstruscturedGridInterpolator(
input_mesh=m,
target_mesh=target_mesh,
)
case Grid() as g if indexing == "xy" and g.is_mono_increasing():
return RegularGridInterpolator(
input_grid=g,
target_mesh=target_mesh,
)
case Grid() as g:
indexing = get_indexing(indexing)
return get_interpolator(
Mesh.from_grid(g, indexing=indexing),
target_mesh=target_mesh,
indexing=UNSET,
)
case _ as unreachable:
assert_never(unreachable)


def get_layering(
layering: AlphaDict | MixMulDict | UnsetType,
*,
Expand Down
64 changes: 58 additions & 6 deletions src/lick/_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
"Method",
]

from dataclasses import dataclass
from dataclasses import KW_ONLY, dataclass
from math import isfinite
from typing import Generic, Literal, TypeAlias, final
from typing import Generic, Literal, Protocol, TypeAlias, final

import numpy as np

Expand Down Expand Up @@ -97,6 +97,9 @@ def from_intervals(
def dtype(self) -> np.dtype[F]:
return self.x.dtype

def is_mono_increasing(self) -> bool:
return bool(np.all(np.diff(self.x) > 0.0) and np.all(np.diff(self.y) > 0.0))


@final
@dataclass(kw_only=True, slots=True, frozen=True)
Expand Down Expand Up @@ -136,16 +139,17 @@ def astype(self, dtype: F, /) -> "Mesh[F]":


@final
@dataclass(kw_only=True, slots=True, frozen=True)
class Interpolator(Generic[F]):
@dataclass(slots=True, frozen=True)
class UnstruscturedGridInterpolator(Generic[F]):
input_mesh: Mesh[F]
_: KW_ONLY
target_mesh: Mesh[F]

def __post_init__(self):
if self.target_mesh.dtype == self.input_mesh.dtype:
return
raise TypeError(
"input and target meshes must use the same data type. "
"input and target must use the same data type. "
f"Got input_mesh.dtype={self.input_mesh.dtype!s}, target_mesh.dtype={self.target_mesh.dtype!s}"
)

Expand All @@ -158,7 +162,7 @@ def __call__(
) -> FArray2D[F]:
if vals.dtype != self.input_mesh.dtype or vals.shape != self.input_mesh.shape:
raise TypeError(
f"Expected values to match the input mesh's data type ({self.input_mesh.dtype}) "
f"Expected values to match the input data type ({self.input_mesh.dtype}) "
f"and shape {self.input_mesh.shape}. "
f"Received values with dtype={vals.dtype!s}, shape={vals.shape}"
)
Expand All @@ -176,3 +180,51 @@ def __call__(
),
method=method,
).astype(vals.dtype)


class Interpolator(Protocol, Generic[F]):
def __call__(self, vals: FArray2D[F], /, *, method: Method) -> FArray2D[F]: ...


@final
@dataclass(slots=True, frozen=True)
class RegularGridInterpolator(Generic[F]):
input_grid: Grid[F]
_: KW_ONLY
target_mesh: Mesh[F]

def __post_init__(self):
if self.target_mesh.dtype == self.input_grid.dtype:
return
raise TypeError(
"input and target must use the same data type. "
f"Got input_grid.dtype={self.input_grid.dtype!s}, target_mesh.dtype={self.target_mesh.dtype!s}"
)

def __call__(
self,
vals: FArray2D[F],
/,
*,
method: Method,
) -> FArray2D[F]:
if (
vals.shape
!= (input_shape := (self.input_grid.y.size, self.input_grid.x.size))
or vals.dtype != self.input_grid.dtype
):
raise TypeError(
f"Expected values to match the input data type ({self.input_grid.dtype}) "
f"and shape {input_shape}. "
f"Received values with dtype={vals.dtype!s}, shape={vals.shape}"
)
from interpn import interpn

# TODO: disable extrapolation for backward compat
# upstream patch required
return interpn(
grids=(self.input_grid.y, self.input_grid.x),
obs=(self.target_mesh.y, self.target_mesh.x),
vals=vals,
method=method,
)
10 changes: 5 additions & 5 deletions src/lick/_publib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
MixMulDict,
Normalizer,
)
from lick._interpolation import Grid, Interpolator, Interval, Mesh, Method
from lick._interpolation import Grid, Interval, Mesh, Method
from lick._typing import D, F, FArray, FArray1D, FArray2D

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -58,7 +58,7 @@ def interpol(
) -> InterpolationResults[F]:
if len(all_dtypes := {_.dtype for _ in (x, y, v1, v2, field)}) > 1:
raise TypeError(f"Received inputs with mixed datatypes ({all_dtypes})")
input_mesh = _api.get_mesh(x, y, indexing=indexing)
input_grid_or_mesh = _api.get_grid_or_mesh(x, y) # type: ignore[arg-type]

target_grid = Grid.from_intervals(
x=Interval(
Expand All @@ -73,9 +73,9 @@ def interpol(
dtype=cast(F, x.dtype),
)

interpolate = Interpolator(
input_mesh=input_mesh,
target_mesh=Mesh.from_grid(target_grid, indexing="xy"),
target_mesh = Mesh.from_grid(target_grid, indexing="xy")
interpolate = _api.get_interpolator(
input_grid_or_mesh, target_mesh=target_mesh, indexing=indexing
)

return InterpolationResults(
Expand Down
16 changes: 16 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,22 @@ def test_get_grid_or_mesh_mesh():
assert type(res) is Mesh


def test_get_mesh_from_grid():
_api.get_mesh(
np.linspace(-1.0, 1.0, 8),
np.linspace(-1.0, 1.0, 8),
indexing="ij",
)


def test_get_mesh_no_indexing():
_api.get_mesh(
np.atleast_2d(np.linspace(-1.0, 1.0, 8)),
np.atleast_2d(np.linspace(-1.0, 1.0, 8)),
indexing=_api.UNSET,
)


def test_get_mesh_ignored_indexing():
with pytest.warns(
UserWarning,
Expand Down
83 changes: 50 additions & 33 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import re
from itertools import permutations
from itertools import permutations, product

import numpy as np
import numpy.testing as npt
import pytest

from lick import interpol
from lick._interpolation import Grid, Interpolator, Interval, Mesh
from lick._interpolation import (
Grid,
Interval,
Mesh,
RegularGridInterpolator,
UnstruscturedGridInterpolator,
)

f64 = np.float64

Expand Down Expand Up @@ -223,52 +229,63 @@ def test_mesh_from_grid(dtype, indexing):

@pytest.mark.parametrize("dtype", ["float32", "float64"])
@pytest.mark.parametrize("indexing", ["xy", "ij"])
def test_interpolator_dunder_call(subtests, dtype, indexing):
def test_unstructured_grid_interpolator_dunder_call(subtests, dtype, indexing):
x = np.geomspace(1, 2, 5, dtype=dtype)
y = np.linspace(3, 4, 7, dtype=dtype)
grid = Grid(x=x, y=y)
mesh = Mesh.from_grid(grid, indexing=indexing)

interpolator = Interpolator(input_mesh=mesh, target_mesh=mesh)
for method in ["nearest", "linear", "cubic"]:
with subtests.test(method=method):
for interpolator, method in product(
[
RegularGridInterpolator(input_grid=grid, target_mesh=mesh),
UnstruscturedGridInterpolator(input_mesh=mesh, target_mesh=mesh),
],
["nearest", "linear", "cubic"],
):
if type(interpolator) is RegularGridInterpolator and indexing == "ij":
continue
with subtests.test(type=type(interpolator).__name__, method=method):
res = interpolator(mesh.x, method=method)
npt.assert_array_almost_equal_nulp(res, mesh.x)
npt.assert_array_almost_equal(res, mesh.x)


@pytest.mark.parametrize("dt1, dt2", permutations(["float32", "float64"]))
def test_interpolator_dunder_call_mixed_dtype(subtests, dt1, dt2):
def test_unstructured_grid_interpolator_dunder_call_mixed_dtype(subtests, dt1, dt2):
x = np.geomspace(1, 2, 5, dtype=dt1)
y = np.linspace(3, 4, 7, dtype=dt1)
grid = Grid(x=x, y=y)
mesh = Mesh.from_grid(grid, indexing="ij")
shape = mesh.shape

interpolator = Interpolator(input_mesh=mesh, target_mesh=mesh)
with (
subtests.test(vals_dtype=dt2),
pytest.raises(
TypeError,
match=re.escape(
f"Expected values to match the input mesh's data type ({mesh.dtype}) "
f"and shape {mesh.shape}. "
f"Received values with dtype={dt2!s}, shape={shape}"

grid_interpolator = RegularGridInterpolator(input_grid=grid, target_mesh=mesh)
mesh_interpolator = UnstruscturedGridInterpolator(input_mesh=mesh, target_mesh=mesh)
for interpolator, obj, obj_shape, obj_name in [
(grid_interpolator, grid, (grid.y.size, grid.x.size), "grid"),
(mesh_interpolator, mesh, mesh.shape, "mesh"),
]:
with (
subtests.test(obj_name=obj_name, vals_dtype=dt2),
pytest.raises(
TypeError,
match=re.escape(
f"Expected values to match the input data type ({obj.dtype}) "
f"and shape {obj_shape}. "
f"Received values with dtype={dt2!s}, shape={obj.x.shape}"
),
),
),
):
interpolator(mesh.x.astype(dt2), method="nearest")

with (
subtests.test(vals_dtype=dt1),
pytest.raises(
TypeError,
match=(
r"input and target meshes must use the same data type\. "
rf"Got input_mesh.dtype={dt1!s}, target_mesh\.dtype={dt2!s}"
):
interpolator(obj.x.astype(dt2), method="nearest")

with (
subtests.test(input="mesh", vals_dtype=dt1),
pytest.raises(
TypeError,
match=(
r"input and target must use the same data type\. "
rf"Got input_{obj_name}.dtype={dt1!s}, target_mesh\.dtype={dt2!s}"
),
),
),
):
Interpolator(input_mesh=mesh, target_mesh=mesh.astype(dt2))
):
type(interpolator)(obj.x, target_mesh=mesh.astype(dt2))


@pytest.mark.parametrize("dtype", ["float32", "float64"])
Expand Down
29 changes: 29 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading