def camel_to_snake(name: str) -> str:
    """Convert a CamelCase identifier to snake_case.

    Runs of uppercase letters are kept together as one word, so ``"ROIRef"``
    becomes ``"roi_ref"`` rather than ``"r_o_i_ref"``.  Leading ``"@"``
    characters are dropped and spaces are turned into underscores.

    Approach taken from https://stackoverflow.com/a/1176023

    Note: an equivalent helper lives in ``ome_autogen._util``; it is duplicated
    here because nothing from that module should be imported at runtime.
    """
    import re

    # strip the leading "@" used in names such as "@any_element"
    text = name.lstrip("@")
    # 1st pattern: split an uppercase run from a following capitalised word
    # 2nd pattern: split a lowercase letter or digit from a following capital
    for pattern in ("([A-Z]+)([A-Z][a-z]+)", "([a-z0-9])([A-Z])"):
        text = re.sub(pattern, r"\1_\2", text)
    return "_".join(text.lower().split(" "))
illumination_type: m.Channel_IlluminationType | str | None + name: str | None + nd_filter: float | None + pinhole_size_unit: m.UnitsLength | str + pinhole_size: float | None + pockel_cell_setting: int | None + # samples_per_pixel : None | int # will be derived from the array + + # same as above, but in {name: Sequence[values]} format + class ChannelTable(TypedDict, total=False): + acquisition_mode: Sequence[m.Channel_AcquisitionMode | None | str] + color: Sequence[m.Color | ColorType | None] | m.Color + contrast_method: Sequence[m.Channel_ContrastMethod | None | str] + emission_wavelength_unit: Sequence[m.UnitsLength | str] + emission_wavelength: Sequence[float | None] | float + excitation_wavelength_unit: Sequence[m.UnitsLength | str] + excitation_wavelength: Sequence[float | None] | float + fluor: Sequence[str | None] + illumination_type: Sequence[m.Channel_IlluminationType | str | None] + name: Sequence[str | None] + nd_filter: Sequence[float | None] | float + pinhole_size_unit: Sequence[m.UnitsLength | str] + pinhole_size: Sequence[float | None] | float + pockel_cell_setting: Sequence[int | None] + + class PlaneKwargs(TypedDict, total=False): + delta_t: float | None + delta_t_unit: m.UnitsTime | str + exposure_time: float | None + exposure_time_unit: m.UnitsTime | str + position_x: float | None + position_x_unit: m.UnitsLength | str + position_y: float | None + position_y_unit: m.UnitsLength | str + position_z: float | None + position_z_unit: m.UnitsLength | str + + class PlaneTable(TypedDict, total=False): + delta_t: Sequence[float | None] | float + delta_t_unit: Sequence[m.UnitsTime | str] + exposure_time: Sequence[float | None] | float + exposure_time_unit: Sequence[m.UnitsTime | str] + position_x: Sequence[float | None] | float + position_x_unit: Sequence[m.UnitsLength | str] + position_y: Sequence[float | None] | float + position_y_unit: Sequence[m.UnitsLength | str] + position_z: Sequence[float | None] | float + position_z_unit: Sequence[m.UnitsLength | 
def ome_image(
    shape: Sequence[int],
    dtype: npt.DTypeLike,
    axes: Sequence[str] = "",
    *,
    channels: Sequence[m.Channel] | Sequence[ChannelKwargs] | ChannelTable = (),
    planes: Sequence[PlaneKwargs] | PlaneTable = (),
    **img_kwargs: Unpack[ImagePixelsKwargs],
) -> m.Image:
    """Build an ``Image`` (with its ``Pixels``) describing an array layout.

    Parameters
    ----------
    shape
        Array shape; at most 6 dimensions.
    dtype
        Anything ``numpy.dtype`` accepts; converted to a pixel type with
        ``numpy_dtype_to_pixel_type``.
    axes
        One name per dimension of ``shape`` (only the first letter of each
        name is used, case-insensitively).  ``"S"`` marks the samples axis.
        When empty, the axes are derived from ``dimension_order``.
    channels
        Channel metadata, see ``ome_channels``.
    planes
        Plane metadata, see ``ome_planes``.
    **img_kwargs
        Keys found in ``Image.__annotations__`` are passed to ``Image``, keys
        found in ``Pixels.__annotations__`` are passed to ``Pixels``; any other
        key is ignored.  ``dimension_order`` is consumed here.

    Raises
    ------
    ValueError
        If ``shape`` has more than 6 dimensions, or the axes cannot be
        reconciled with the shape or with any dimension order.
    """
    shape = tuple(int(i) for i in shape)
    ndim = len(shape)
    if ndim > 6:
        raise ValueError(f"shape must have at most 6 dimensions, not {ndim}")

    # unify axes argument with dimension_order and validate
    axes, dims_order = _determine_axes(
        axes, shape, img_kwargs.pop("dimension_order", None)
    )

    # determine pixel axis sizes ------------------------------------
    nc, nz, nt, nsamp = (shape[axes.index(x)] if x in axes else 1 for x in "CZTS")
    sizes = {f"size_{ax}": 1 for ax in "xyczt"}
    for ax, size in zip(axes, shape):
        if ax != "S":
            sizes[f"size_{ax.lower()}"] = size
    # size_c counts samples, so it must include the samples axis even when the
    # axes contain "S" but no "C" (e.g. a plain RGB image).
    sizes["size_c"] = nc * nsamp

    # position of C, Z and T among the last three letters of the dimension order
    czt_order = tuple(dims_order.value[2:].index(ax) for ax in "CZT")

    # pull out the kwargs that belong to Image and Pixels
    _img_kwargs: dict[str, Any] = {}
    _pix_kwargs: dict[str, Any] = {}
    for k, v in img_kwargs.items():
        if k in m.Image.__annotations__:
            _img_kwargs[k] = v
        elif k in m.Pixels.__annotations__:
            _pix_kwargs[k] = v

    img = m.Image(
        pixels=m.Pixels(
            dimension_order=dims_order,
            **sizes,
            type=numpy_dtype_to_pixel_type(dtype),
            **_convert_keys_to_snake_case(_pix_kwargs),
            channels=ome_channels(channels, nc, nsamp),
            planes=ome_planes((nc, nz, nt), czt_order, planes),
        ),
        **_convert_keys_to_snake_case(_img_kwargs),
    )
    # TODO: validate against shape and dtype here
    return img


def ome_image_like(
    array: npt.NDArray,
    axes: Sequence[str] = "",
    *,
    channels: Sequence[m.Channel] | Sequence[ChannelKwargs] | ChannelTable = (),
    planes: Sequence[PlaneKwargs] | PlaneTable = (),
    **img_kwargs: Unpack[ImagePixelsKwargs],
) -> m.Image:
    """Build an ``Image`` from the shape and dtype of ``array``.

    Thin wrapper around ``ome_image``; every other argument is forwarded
    unchanged.
    """
    return ome_image(
        shape=array.shape,
        dtype=array.dtype,
        axes=axes,
        channels=channels,
        planes=planes,
        **img_kwargs,
    )


def _determine_axes(
    axes: Sequence[str] | None,
    shape: Sequence[int],
    dimension_order: m.Pixels_DimensionOrder | str | None,
) -> tuple[str, m.Pixels_DimensionOrder]:
    """Reconcile ``axes`` with ``dimension_order`` for the given ``shape``.

    Returns the normalised axes string (one uppercase letter per dimension)
    and the dimension order to use.

    Without ``axes`` they are taken from the reversed dimension order
    (``"XYCZT"`` unless ``dimension_order`` is given), keeping the last
    ``len(shape)`` letters and appending ``"S"`` for a 6-D shape.

    With ``axes``, the requested (or default) dimension order is kept when it
    is compatible with them; otherwise the first compatible order is used, with
    a warning if ``dimension_order`` was passed explicitly.

    Raises
    ------
    ValueError
        If a 6-D shape has no ``"S"`` axis, the number of axes differs from
        the number of dimensions, or no dimension order starts with the
        reversed axes.
    """
    _dims_order = m.Pixels_DimensionOrder(dimension_order or "XYCZT")
    ndim = len(shape)
    if not axes:
        reversed_order = _dims_order.value[::-1]
        # explicit start index: a negative slice would return the whole string
        # for ndim == 0
        start = max(len(reversed_order) - ndim, 0)
        inferred = reversed_order[start:]
        if ndim == 6:
            inferred += "S"
        return inferred, _dims_order

    # normalise *before* looking for "S" so lowercase or long names are accepted
    axes = "".join(x[:1] for x in axes).upper()

    if ndim == 6 and "S" not in axes:
        raise ValueError(
            "shape has 6 dimensions, so axes must be specified with 'S' in it"
        )
    if len(axes) != ndim:
        raise ValueError(f"Axes {axes!r} do not match shape {shape!r}")

    ome_axes = axes[::-1].replace("S", "")

    # prefer the requested/default order whenever it fits the axes
    if _dims_order.value.startswith(ome_axes):
        return axes, _dims_order

    for order in m.Pixels_DimensionOrder:
        if order.value.startswith(ome_axes):
            if dimension_order:
                warnings.warn(
                    f"Provided OME dimension_order {dimension_order!r} does not match "
                    f"provided (reversed) axes {axes[::-1]!r}. Using {order.value!r}",
                    stacklevel=2,
                )
            return axes, order

    raise ValueError(f"Could not determine dimension order from axes {axes!r}")


def ome_channels(
    channels: Sequence[m.Channel] | Sequence[ChannelKwargs] | ChannelTable = (),
    max_channels: int | None = None,
    samples_per_pixel: int = 1,
) -> list[m.Channel]:
    """Create ``Channel`` objects from channel metadata.

    Parameters
    ----------
    channels
        Existing ``Channel`` objects, a sequence of keyword dicts, or a dict
        mapping each keyword to a sequence of per-channel values (a single
        value is repeated for every channel).
    max_channels
        Maximum number of channels to return.  When no ``channels`` are given
        this many default channels are created (one if ``None`` or 0).
    samples_per_pixel
        Value set on every returned channel.
    """
    if not channels:
        # one default channel per channel index; samples_per_pixel is a
        # property of each channel and does not reduce how many there are
        return [
            m.Channel(samples_per_pixel=samples_per_pixel)
            for _ in range(max_channels or 1)
        ]

    # convert dict of lists to list of dicts
    if isinstance(channels, dict):
        channels = cast("Sequence[ChannelKwargs]", _dol2lod(channels, max_channels))

    # limit to max_channels (based on previous shape analysis)
    # TODO: should we warn if too many channels are provided?
    channel_list: list[m.Channel] = []
    for channel in channels[:max_channels]:
        if isinstance(channel, m.Channel):
            kwargs: dict = channel.dict()
        else:
            kwargs = _convert_keys_to_snake_case(channel)
        kwargs["samples_per_pixel"] = samples_per_pixel
        channel_list.append(m.Channel(**kwargs))
    return channel_list


def _plane_indices(
    n_czt: tuple[int, int, int],
    czt_order: tuple[int, int, int],
) -> list[tuple[int, int, int]]:
    """Return ``(the_c, the_z, the_t)`` for every plane, in storage order."""
    # arrange the sizes in storage order (fastest-varying axis first) so each
    # index is unravelled against the size of its own axis
    storage_shape = [1, 1, 1]
    for size, position in zip(n_czt, czt_order):
        storage_shape[position] = size

    indices = []
    for idx in range(int(np.prod(n_czt))):
        unraveled = np.unravel_index(idx, tuple(storage_shape), order="F")
        indices.append(tuple(int(unraveled[position]) for position in czt_order))
    return indices


def ome_planes(
    n_czt: tuple[int, int, int],
    czt_order: tuple[int, int, int],
    planes: Sequence[PlaneKwargs] | PlaneTable,
) -> list[m.Plane]:
    """Create one ``Plane`` per (c, z, t) position.

    Parameters
    ----------
    n_czt
        Number of channels, z-sections and timepoints.
    czt_order
        Position of C, Z and T (in that order) among the last three letters of
        the dimension order, e.g. ``(1, 0, 2)`` for ``"XYZCT"``.
    planes
        A sequence of keyword dicts, or a dict mapping each keyword to a
        sequence of per-plane values (a single value is repeated for every
        plane).  When empty, planes carrying only their indices are created.

    Raises
    ------
    ValueError
        If fewer planes are provided than ``n_czt`` requires.  Extra planes
        are dropped with a warning.
    """
    plane_count = int(np.prod(n_czt))

    if isinstance(planes, dict):
        # convert dict of lists to list of dicts
        planes = cast("Sequence[PlaneKwargs]", _dol2lod(planes, plane_count))

    if not planes:
        planes = [{} for _ in range(plane_count)]
    elif len(planes) > plane_count:
        warnings.warn(
            f"Provided {len(planes)} planes, but expected {plane_count}",
            stacklevel=2,
        )
        planes = planes[:plane_count]
    elif len(planes) < plane_count:
        raise ValueError(f"Provided {len(planes)} planes, but expected {plane_count}")

    plane_list = []
    for plane, (c, z, t) in zip(planes, _plane_indices(n_czt, czt_order)):
        plane_list.append(m.Plane(**{**plane, "the_c": c, "the_z": z, "the_t": t}))
    return plane_list


def _dol2lod(dol: Mapping[str, Any], max_items: int | None = None) -> list[dict]:
    """Convert a dict of sequences to a list of dicts.

    Values that are not sequences (and strings) are repeated ``max_items``
    times (once if ``max_items`` is ``None`` or 0).  Shorter sequences are
    padded with ``None``.  The input mapping is left untouched.
    """
    repeat = max_items or 1
    # build a new mapping instead of writing the broadcast values back into the
    # caller's dict
    columns = {
        key: value
        if isinstance(value, Sequence) and not isinstance(value, str)
        else [value] * repeat
        for key, value in dol.items()
    }
    return [dict(zip(columns, row)) for row in zip_longest(*columns.values())]


def _convert_keys_to_snake_case(d: Mapping[str, Vt]) -> dict[str, Vt]:
    """Return a copy of ``d`` with every key passed through ``camel_to_snake``."""
    from ome_types._conversion import camel_to_snake

    return {camel_to_snake(k): v for k, v in d.items()}
# (OME PixelType value, numpy dtype name) for the pixel types whose names differ
_OME_NUMPY_PAIRS = (
    ("float", "float32"),
    ("double", "float64"),
    ("complex", "complex64"),
    ("double-complex", "complex128"),
    ("bit", "bool"),  # ?
)
# maps OME PixelType names to numpy dtype names
NP_DTYPE_MAP: "dict[str, str]" = dict(_OME_NUMPY_PAIRS)
# inverse lookup: numpy dtype name -> OME PixelType name
REV_NP_DTYPE_MAP: "dict[str, str]" = {
    numpy_name: ome_name for ome_name, numpy_name in _OME_NUMPY_PAIRS
}


def pixel_type_to_numpy_dtype(self: "PixelType") -> "DTypeLike":
    """Return the numpy dtype name matching this pixel type.

    Values missing from ``NP_DTYPE_MAP`` are returned as they are.
    """
    ome_name = self.value
    try:
        return NP_DTYPE_MAP[ome_name]
    except KeyError:
        return ome_name


def numpy_dtype_to_pixel_type(dtype: "DTypeLike") -> "PixelType":
    """Return the PixelType matching a numpy dtype.

    The dtype name is translated through ``REV_NP_DTYPE_MAP``; names missing
    from that table are handed to ``PixelType`` as they are.
    """
    from ome_types.model import PixelType

    numpy_name = np.dtype(dtype).name
    if numpy_name in REV_NP_DTYPE_MAP:
        ome_name = REV_NP_DTYPE_MAP[numpy_name]
    else:
        ome_name = numpy_name
    return PixelType(value=ome_name)
np.prod(data.shape[:3]) + assert len(img.pixels.planes) == n_planes + for p in img.pixels.planes: + assert p.position_x == 1.0 + assert p.exposure_time == 3.0