#   Copyright 2024 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any

import numcodecs
import numpy as np
import zarr

from pytensor.tensor.variable import TensorVariable
from zarr.storage import BaseStore
from zarr.sync import Synchronizer

from pymc.backends.arviz import (
    coords_and_dims_for_inferencedata,
    find_constants,
    find_observations,
)
from pymc.backends.base import BaseTrace
from pymc.model.core import Model, modelcontext
from pymc.step_methods.compound import (
    BlockedStep,
    CompoundStep,
    StatsBijection,
    get_stats_dtypes_shapes_from_steps,
)
from pymc.util import get_default_varnames


class ZarrChain(BaseTrace):
    """Per-chain trace that writes draws and sampler stats into zarr groups.

    The ``posterior``, ``sample_stats`` and ``_sampling_state`` groups are
    opened in append mode (``mode="a"``) from ``store``; this class only
    writes into arrays inside them and never creates those arrays itself.
    """

    def __init__(
        self,
        store: BaseStore | MutableMapping,
        stats_bijection: StatsBijection,
        synchronizer: Synchronizer | None = None,
        model: Model | None = None,
        vars: Sequence[TensorVariable] | None = None,
        test_point: Sequence[dict[str, np.ndarray]] | None = None,
    ):
        """Open the storage groups of ``store`` and reset the draw counter.

        Parameters
        ----------
        store
            Zarr store (or mutable mapping) holding the trace groups.
        stats_bijection
            Object whose ``map`` method turns the per-step stats sequence
            into a mapping from stat name to value (used by ``record``).
        synchronizer
            Optional zarr synchronizer forwarded to ``zarr.open_group``.
        model, vars, test_point
            Forwarded unchanged to ``BaseTrace.__init__``.
        """
        super().__init__(name="zarr", model=model, vars=vars, test_point=test_point)
        # Index of the next draw to be written by ``record``.
        self.draw_idx = 0
        self._posterior = zarr.open_group(
            store, synchronizer=synchronizer, path="posterior", mode="a"
        )
        self._sample_stats = zarr.open_group(
            store, synchronizer=synchronizer, path="sample_stats", mode="a"
        )
        self._sampling_state = zarr.open_group(
            store, synchronizer=synchronizer, path="_sampling_state", mode="a"
        )
        self.stats_bijection = stats_bijection

    def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None):  # type: ignore[override]
        """Store the chain index; ``draws`` and ``sampler_vars`` are unused here."""
        self.chain = chain

    def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
        """Write one draw and its sampler stats at ``(chain, draw_idx)``.

        Parameters
        ----------
        draw
            Point passed to ``self.fn``; its outputs are paired with
            ``self.varnames`` in order.
        stats
            Per-step stats, converted to a flat name -> value mapping by
            ``self.stats_bijection.map``.
        """
        chain = self.chain
        draw_idx = self.draw_idx
        for var_name, var_value in zip(self.varnames, self.fn(draw)):
            self._posterior[var_name].set_orthogonal_selection(
                (chain, draw_idx),
                var_value,
            )
        for var_name, var_value in self.stats_bijection.map(stats).items():
            self._sample_stats[var_name].set_orthogonal_selection(
                (chain, draw_idx),
                var_value,
            )
        self.draw_idx += 1

    def record_sampling_state(self, step):
        """Store ``step.sampling_state`` and the current draw index for this chain."""
        self._sampling_state.sampling_state.set_coordinate_selection(
            self.chain, np.array([step.sampling_state], dtype="object")
        )
        self._sampling_state.draw_idx.set_coordinate_selection(self.chain, self.draw_idx)


FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None


def get_fill_value_and_codec(
    dtype: Any,
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
    """Return the fill value, numpy dtype and object codec to use for ``dtype``.

    Floating, integer, boolean, string, datetime64 and timedelta64 dtypes
    get a type-specific fill value and no object codec. Any other dtype
    gets ``None`` as fill value and a ``numcodecs.Pickle`` codec.

    Parameters
    ----------
    dtype
        Anything accepted by ``np.dtype``.

    Returns
    -------
    tuple
        ``(fill_value, np.dtype(dtype), object_codec_or_None)``.
    """
    _dtype = np.dtype(dtype)
    if np.issubdtype(_dtype, np.floating):
        return (np.nan, _dtype, None)
    elif np.issubdtype(_dtype, np.integer):
        return (-1_000_000, _dtype, None)
    elif np.issubdtype(_dtype, "bool"):
        return (False, _dtype, None)
    elif np.issubdtype(_dtype, "str"):
        return ("", _dtype, None)
    elif np.issubdtype(_dtype, "datetime64"):
        return (np.datetime64(0, "Y"), _dtype, None)
    elif np.issubdtype(_dtype, "timedelta64"):
        return (np.timedelta64(0, "Y"), _dtype, None)
    else:
        return (None, _dtype, numcodecs.Pickle())


class ZarrTrace:
    """Create and fill a zarr hierarchy that stores sampling results.

    The root group is created with ``overwrite=True``. ``init_trace`` then
    creates the ``constant_data``, ``observed_data``, ``posterior``,
    ``sample_stats`` and ``_sampling_state`` groups and one ``ZarrChain``
    per chain.
    """

    def __init__(
        self,
        store: BaseStore | MutableMapping | None = None,
        synchronizer: Synchronizer | None = None,
        model: Model | None = None,
        vars: Sequence[TensorVariable] | None = None,
        include_transformed: bool = False,
    ):
        """Create the root group and work out which variables to trace.

        Parameters
        ----------
        store
            Store passed to ``zarr.group``; existing content is overwritten.
        synchronizer
            Optional zarr synchronizer, also handed to every ``ZarrChain``.
        model
            Model to trace; resolved through ``modelcontext``.
        vars
            Variables to trace. Defaults to ``model.unobserved_value_vars``.
        include_transformed
            Forwarded to ``get_default_varnames``.

        Raises
        ------
        ValueError
            If any of the traced variables has no name.
        """
        model = modelcontext(model)
        self.model = model

        self.synchronizer = synchronizer
        self.root = zarr.group(
            store=store,
            overwrite=True,
            synchronizer=synchronizer,
        )
        self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model)

        if vars is None:
            vars = model.unobserved_value_vars

        unnamed_vars = {var for var in vars if var.name is None}
        if unnamed_vars:
            # ValueError is still caught by callers using ``except Exception``.
            raise ValueError(f"Can't trace unnamed variables: {unnamed_vars}")
        self.varnames = get_default_varnames(
            [var.name for var in vars], include_transformed=include_transformed
        )
        self.vars = [var for var in vars if var.name in self.varnames]

        self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore")

        # Get variable shapes. Most backends will need this
        # information.
        test_point = model.initial_point()
        var_values = list(zip(self.varnames, self.fn(test_point)))
        self.var_dtype_shapes = {var: (value.dtype, value.shape) for var, value in var_values}
        self._is_base_setup = False

    @property
    def posterior(self):
        """The ``posterior`` group of the root."""
        return self.root.posterior

    @property
    def sample_stats(self):
        """The ``sample_stats`` group of the root."""
        return self.root.sample_stats

    @property
    def constant_data(self):
        """The ``constant_data`` group of the root."""
        return self.root.constant_data

    @property
    def observed_data(self):
        """The ``observed_data`` group of the root."""
        return self.root.observed_data

    @property
    def sampling_state(self):
        """The ``_sampling_state`` group of the root."""
        # The group is created under the name "_sampling_state" (see
        # ``init_sampling_state_group``), so it must be looked up by that key.
        return self.root["_sampling_state"]

    def init_trace(self, chains: int, draws: int, step: BlockedStep | CompoundStep):
        """Create all storage groups and one ``ZarrChain`` per chain.

        Parameters
        ----------
        chains
            Number of chains to allocate storage and chain objects for.
        draws
            Number of draws per chain to allocate storage for.
        step
            Step method; its stats dtypes and shapes define the
            ``sample_stats`` arrays.
        """
        self.create_group(
            name="constant_data",
            data_dict=find_constants(self.model),
        )

        self.create_group(
            name="observed_data",
            data_dict=find_observations(self.model),
        )

        self.init_group_with_empty(
            group=self.root.create_group(name="posterior", overwrite=True),
            var_dtype_and_shape=self.var_dtype_shapes,
            chains=chains,
            draws=draws,
        )
        stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
            [step] if isinstance(step, BlockedStep) else step.methods
        )
        self.init_group_with_empty(
            group=self.root.create_group(name="sample_stats", overwrite=True),
            var_dtype_and_shape=stats_dtypes_shapes,
            chains=chains,
            draws=draws,
        )

        self.init_sampling_state_group(chains=chains)

        # One chain object per requested chain; each is given its index below.
        self.straces = [
            ZarrChain(
                store=self.root.store,
                synchronizer=self.synchronizer,
                model=self.model,
                vars=self.vars,
                test_point=None,
                stats_bijection=StatsBijection(step.stats_dtypes),
            )
            for _ in range(chains)
        ]
        for chain, strace in enumerate(self.straces):
            strace.setup(draws=draws, chain=chain, sampler_vars=None)

    def close(self):
        """Close every chain's groups, consolidate metadata and close the store."""
        # NOTE(review): this relies on the zarr group objects and on the store
        # exposing ``close()``; confirm this holds for plain-mapping stores.
        for strace in self.straces:
            strace._posterior.close()
            strace._sample_stats.close()
            strace._sampling_state.close()
        zarr.consolidate_metadata(self.root.store)
        self.root.store.close()

    def init_sampling_state_group(self, chains):
        """Create the ``_sampling_state`` group with one slot per chain.

        The group holds a pickled-object array ``sampling_state``, an integer
        array ``draw_idx`` initialised to zeros, and a ``chain`` coordinate.
        """
        state = self.root.create_group(name="_sampling_state", overwrite=True)
        sampling_state = state.empty(
            name="sampling_state",
            overwrite=True,
            shape=(chains,),
            chunks=(1,),
            dtype="object",
            object_codec=numcodecs.Pickle(),
        )
        sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
        draw_idx = state.array(
            name="draw_idx",
            overwrite=True,
            data=np.zeros(chains, dtype="int"),
            chunks=(1,),
            dtype="int",
            fill_value=-1,
        )
        draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
        chain = state.array(name="chain", data=range(chains))
        chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})

    def _dims_and_coords(self, name, shape):
        """Return ``(dims, coords)`` for the variable ``name`` of shape ``shape``.

        Uses the model's named dims when ``name`` has dims and every dim has
        coordinate values; otherwise falls back to ``{name}_dim_{i}`` with
        integer ranges. The lookup is all-or-nothing, so a failed lookup
        contributes no partial coordinates.
        """
        try:
            dims = list(self.vars_to_dims[name])
            coords = {dim: self.coords[dim] for dim in dims}
        except KeyError:
            dims = [f"{name}_dim_{i}" for i in range(len(shape))]
            coords = {dim: list(range(size)) for dim, size in zip(dims, shape)}
        return dims, coords

    @staticmethod
    def _write_coords(group, group_coords):
        """Store each coordinate of ``group_coords`` as a 1-d array in ``group``."""
        for dim, coord in group_coords.items():
            coord_array = group.array(name=dim, data=coord, fill_value=None)
            coord_array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})

    def init_group_with_empty(self, group, var_dtype_and_shape, chains, draws):
        """Fill ``group`` with fill-valued arrays of shape ``(chains, draws, *shape)``.

        Parameters
        ----------
        group
            Zarr group to populate.
        var_dtype_and_shape
            Mapping from variable name to ``(dtype, shape)``; a falsy shape
            is treated as scalar.
        chains, draws
            Sizes of the two leading dimensions.

        Returns
        -------
        The populated ``group``.
        """
        group_coords = {"chain": range(chains), "draw": range(draws)}
        for name, (dtype, shape) in var_dtype_and_shape.items():
            fill_value, dtype, object_codec = get_fill_value_and_codec(dtype)
            shape = shape or ()
            array = group.full(
                name=name,
                dtype=dtype,
                fill_value=fill_value,
                object_codec=object_codec,
                shape=(chains, draws, *shape),
                # One chunk per (chain, draw) pair.
                chunks=(1, 1, *shape),
            )
            dims, coords = self._dims_and_coords(name, shape)
            group_coords.update(coords)
            array.attrs.update({"_ARRAY_DIMENSIONS": ("chain", "draw", *dims)})
        self._write_coords(group, group_coords)
        return group

    def create_group(self, name, data_dict):
        """Create group ``name`` holding the arrays of ``data_dict``.

        Parameters
        ----------
        name
            Name of the group to create under the root (overwritten if present).
        data_dict
            Mapping from variable name to array value.

        Returns
        -------
        The created group, or ``None`` when ``data_dict`` is empty (in which
        case no group is created).
        """
        group = None
        if data_dict:
            group_coords = {}
            group = self.root.create_group(name=name, overwrite=True)
            for var_name, var_value in data_dict.items():
                fill_value, dtype, object_codec = get_fill_value_and_codec(var_value.dtype)
                array = group.array(
                    name=var_name,
                    data=var_value,
                    fill_value=fill_value,
                    dtype=dtype,
                    object_codec=object_codec,
                )
                dims, coords = self._dims_and_coords(var_name, var_value.shape)
                group_coords.update(coords)
                array.attrs.update({"_ARRAY_DIMENSIONS": dims})
            self._write_coords(group, group_coords)
        return group