Skip to content

Commit

Permalink
Add ZarrTrace
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Oct 16, 2024
1 parent 5352798 commit 4cf7d0c
Showing 1 changed file with 302 additions and 0 deletions.
302 changes: 302 additions & 0 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any

Check warning on line 15 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L14-L15

Added lines #L14 - L15 were not covered by tests

import numcodecs
import numpy as np
import zarr

Check warning on line 19 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L17-L19

Added lines #L17 - L19 were not covered by tests

from pytensor.tensor.variable import TensorVariable
from zarr.storage import BaseStore
from zarr.sync import Synchronizer

Check warning on line 23 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L21-L23

Added lines #L21 - L23 were not covered by tests

from pymc.backends.arviz import (

Check warning on line 25 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L25

Added line #L25 was not covered by tests
coords_and_dims_for_inferencedata,
find_constants,
find_observations,
)
from pymc.backends.base import BaseTrace
from pymc.model.core import Model, modelcontext
from pymc.step_methods.compound import (

Check warning on line 32 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L30-L32

Added lines #L30 - L32 were not covered by tests
BlockedStep,
CompoundStep,
StatsBijection,
get_stats_dtypes_shapes_from_steps,
)
from pymc.util import get_default_varnames

Check warning on line 38 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L38

Added line #L38 was not covered by tests


class ZarrChain(BaseTrace):
    """Per-chain recorder that writes draws and sampler stats into zarr groups.

    The chain opens the ``posterior``, ``sample_stats`` and ``_sampling_state``
    groups of ``store`` in append mode and writes one ``(chain, draw)`` entry
    per call to :meth:`record`.

    Parameters
    ----------
    store : BaseStore or MutableMapping
        Storage that holds the three groups listed above.
    stats_bijection : StatsBijection
        Object whose ``map`` method turns the sampler stats handed to
        :meth:`record` into a ``{stat_name: value}`` mapping.
    synchronizer : Synchronizer, optional
        Forwarded to every ``zarr.open_group`` call.
    model, vars, test_point
        Forwarded unchanged to ``BaseTrace.__init__``.
    """

    def __init__(
        self,
        store: BaseStore | MutableMapping,
        stats_bijection: StatsBijection,
        synchronizer: Synchronizer | None = None,
        model: Model | None = None,
        vars: Sequence[TensorVariable] | None = None,
        test_point: Sequence[dict[str, np.ndarray]] | None = None,
    ):
        super().__init__(name="zarr", model=model, vars=vars, test_point=test_point)
        self.draw_idx = 0

        def open_subgroup(path: str):
            # Every group lives in the same store and shares the synchronizer.
            return zarr.open_group(store, synchronizer=synchronizer, path=path, mode="a")

        self._posterior = open_subgroup("posterior")
        self._sample_stats = open_subgroup("sample_stats")
        self._sampling_state = open_subgroup("_sampling_state")
        self.stats_bijection = stats_bijection

    def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None):  # type: ignore[override]
        """Remember which chain index this recorder writes to.

        ``draws`` and ``sampler_vars`` are accepted but not used here.
        """
        self.chain = chain

    def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
        """Write one draw and its sampler stats at the current draw index.

        The traced variables are computed with ``self.fn(draw)`` and stored in
        the posterior group; ``stats`` are mapped through ``stats_bijection``
        and stored in the sample stats group. The draw index is advanced by
        one afterwards.
        """
        position = (self.chain, self.draw_idx)

        posterior = self._posterior
        for name, value in zip(self.varnames, self.fn(draw)):
            posterior[name].set_orthogonal_selection(position, value)

        sample_stats = self._sample_stats
        for name, value in self.stats_bijection.map(stats).items():
            sample_stats[name].set_orthogonal_selection(position, value)

        self.draw_idx += 1

    def record_sampling_state(self, step):
        """Store ``step.sampling_state`` and the current draw index for this chain."""
        chain = self.chain
        state_group = self._sampling_state

        state_array = state_group.sampling_state
        # The state is wrapped in a one-element object array before it is written.
        state_payload = np.array([step.sampling_state], dtype="object")
        state_array.set_coordinate_selection(chain, state_payload)

        state_group.draw_idx.set_coordinate_selection(chain, self.draw_idx)

Check warning on line 86 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L86

Added line #L86 was not covered by tests


# Scalar types used to annotate the fill value returned by ``get_fill_value_and_codec``.
FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None

Check warning on line 89 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L89

Added line #L89 was not covered by tests


def get_fill_value_and_codec(
    dtype: Any,
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
    """Choose a fill value and, if needed, an object codec for ``dtype``.

    Parameters
    ----------
    dtype : Any
        Anything accepted by ``numpy.dtype``.

    Returns
    -------
    fill_value : FILL_VALUE_TYPE
        ``nan`` for floats, ``0`` timedeltas/datetimes (unit ``"Y"``) for the
        time types, ``-1_000_000`` for integers (clamped to the smallest value
        the integer dtype can represent), ``False`` for booleans, ``""`` for
        strings and ``None`` for everything else.
    dtype : numpy.dtype
        The normalised ``numpy.dtype`` instance.
    object_codec : numcodecs.abc.Codec or None
        ``numcodecs.Pickle()`` for dtypes that match none of the cases above,
        otherwise ``None``.
    """
    _dtype = np.dtype(dtype)
    if np.issubdtype(_dtype, np.floating):
        return (np.nan, _dtype, None)
    elif np.issubdtype(_dtype, "timedelta64"):
        # ``np.timedelta64`` derives from ``np.signedinteger``; it has to be
        # tested before the generic integer branch or this case is unreachable.
        return (np.timedelta64(0, "Y"), _dtype, None)
    elif np.issubdtype(_dtype, np.integer):
        # -1_000_000 does not fit into narrow or unsigned integer dtypes, so
        # fall back to the dtype's own minimum when the sentinel is too small.
        fill_value = max(-1_000_000, int(np.iinfo(_dtype).min))
        return (fill_value, _dtype, None)
    elif np.issubdtype(_dtype, "bool"):
        return (False, _dtype, None)
    elif np.issubdtype(_dtype, "str"):
        return ("", _dtype, None)
    elif np.issubdtype(_dtype, "datetime64"):
        return (np.datetime64(0, "Y"), _dtype, None)
    else:
        return (None, _dtype, numcodecs.Pickle())

Check warning on line 109 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L109

Added line #L109 was not covered by tests


class ZarrTrace:
    """Store the output of a sampling run in a zarr hierarchy.

    The root group is created on ``store`` with ``overwrite=True``.
    :meth:`init_trace` then populates it with the groups ``posterior``,
    ``sample_stats`` and ``_sampling_state``, plus ``constant_data`` and
    ``observed_data`` when the model has such data. Every array gets an
    ``_ARRAY_DIMENSIONS`` attribute naming its dimensions, and each dimension
    is written as a coordinate array into the same group.

    Parameters
    ----------
    store : BaseStore or MutableMapping, optional
        Forwarded to ``zarr.group``.
    synchronizer : Synchronizer, optional
        Forwarded to ``zarr.group`` and to every :class:`ZarrChain`.
    model : Model, optional
        Model to trace; resolved through ``modelcontext``.
    vars : sequence of TensorVariable, optional
        Variables to trace. Defaults to ``model.unobserved_value_vars``.
    include_transformed : bool, default False
        Forwarded to ``get_default_varnames`` when selecting variable names.

    Raises
    ------
    Exception
        If any of the variables to trace has no name.
    """

    def __init__(
        self,
        store: BaseStore | MutableMapping | None = None,
        synchronizer: Synchronizer | None = None,
        model: Model | None = None,
        vars: Sequence[TensorVariable] | None = None,
        include_transformed: bool = False,
    ):
        model = modelcontext(model)
        self.model = model

        self.synchronizer = synchronizer
        self.root = zarr.group(
            store=store,
            overwrite=True,
            synchronizer=synchronizer,
        )
        self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model)

        if vars is None:
            vars = model.unobserved_value_vars

        unnamed_vars = {var for var in vars if var.name is None}
        if unnamed_vars:
            raise Exception(f"Can't trace unnamed variables: {unnamed_vars}")
        self.varnames = get_default_varnames(
            [var.name for var in vars], include_transformed=include_transformed
        )
        self.vars = [var for var in vars if var.name in self.varnames]

        self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore")

        # Evaluate the traced variables once at the initial point to learn the
        # dtype and shape of every array that has to be allocated.
        test_point = model.initial_point()
        var_values = list(zip(self.varnames, self.fn(test_point)))
        self.var_dtype_shapes = {var: (value.dtype, value.shape) for var, value in var_values}
        self._is_base_setup = False

        # Per-chain recorders are created by ``init_trace``; starting with an
        # empty list keeps ``close`` usable before that happens.
        self.straces: list[ZarrChain] = []

    @property
    def posterior(self):
        """The ``posterior`` group of the root hierarchy."""
        return self.root.posterior

    @property
    def sample_stats(self):
        """The ``sample_stats`` group of the root hierarchy."""
        return self.root.sample_stats

    @property
    def constant_data(self):
        """The ``constant_data`` group of the root hierarchy."""
        return self.root.constant_data

    @property
    def observed_data(self):
        """The ``observed_data`` group of the root hierarchy."""
        return self.root.observed_data

    @property
    def sampling_state(self):
        """The ``_sampling_state`` group of the root hierarchy."""
        # ``init_sampling_state_group`` creates the group under the name
        # "_sampling_state"; no root member called "sampling_state" exists.
        return self.root["_sampling_state"]

    def init_trace(self, chains: int, draws: int, step: BlockedStep | CompoundStep):
        """Allocate all groups and create one :class:`ZarrChain` per chain.

        Parameters
        ----------
        chains : int
            Number of chains; first dimension of every posterior and sample
            stats array, and number of recorders stored in ``self.straces``.
        draws : int
            Number of draws per chain; second dimension of those arrays.
        step : BlockedStep or CompoundStep
            Step method(s) whose stats dtypes and shapes define the
            ``sample_stats`` arrays.
        """
        self.create_group(
            name="constant_data",
            data_dict=find_constants(self.model),
        )

        self.create_group(
            name="observed_data",
            data_dict=find_observations(self.model),
        )

        self.init_group_with_empty(
            group=self.root.create_group(name="posterior", overwrite=True),
            var_dtype_and_shape=self.var_dtype_shapes,
            chains=chains,
            draws=draws,
        )
        stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
            [step] if isinstance(step, BlockedStep) else step.methods
        )
        self.init_group_with_empty(
            group=self.root.create_group(name="sample_stats", overwrite=True),
            var_dtype_and_shape=stats_dtypes_shapes,
            chains=chains,
            draws=draws,
        )

        self.init_sampling_state_group(chains=chains)

        # One recorder per chain: each one is bound to its own chain index by
        # ``setup`` below and keeps its own draw counter.
        self.straces = [
            ZarrChain(
                store=self.root.store,
                synchronizer=self.synchronizer,
                model=self.model,
                vars=self.vars,
                test_point=None,
                stats_bijection=StatsBijection(step.stats_dtypes),
            )
            for _ in range(chains)
        ]
        for chain, strace in enumerate(self.straces):
            strace.setup(draws=draws, chain=chain, sampler_vars=None)

    def close(self):
        """Close the per-chain groups, consolidate metadata and close the store."""
        # NOTE(review): zarr v2 ``Group`` does not appear to define ``close()``
        # (only the store does) - confirm these three calls work as intended.
        for strace in self.straces:
            strace._posterior.close()
            strace._sample_stats.close()
            strace._sampling_state.close()
        zarr.consolidate_metadata(self.root.store)
        self.root.store.close()

    def init_sampling_state_group(self, chains):
        """Create the ``_sampling_state`` group with one slot per chain.

        The group holds a pickled object array ``sampling_state``, an integer
        array ``draw_idx`` initialised to zero, and the ``chain`` coordinate.
        """
        state = self.root.create_group(name="_sampling_state", overwrite=True)
        sampling_state = state.empty(
            name="sampling_state",
            overwrite=True,
            shape=(chains,),
            chunks=(1,),
            dtype="object",
            object_codec=numcodecs.Pickle(),
        )
        sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
        draw_idx = state.array(
            name="draw_idx",
            overwrite=True,
            data=np.zeros(chains, dtype="int"),
            chunks=(1,),
            dtype="int",
            fill_value=-1,
        )
        draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
        chain = state.array(name="chain", data=range(chains))
        chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})

    def _resolve_dims(self, name, shape, group_coords):
        """Return the dimension names of ``name`` and register their coordinates.

        Dimensions come from ``self.vars_to_dims`` with coordinates from
        ``self.coords``. If either lookup raises ``KeyError``, default
        dimensions ``"{name}_dim_{i}"`` with integer range coordinates are
        generated from ``shape`` instead. ``group_coords`` is updated in place.
        """
        try:
            dims = self.vars_to_dims[name]
            for dim in dims:
                group_coords[dim] = self.coords[dim]
        except KeyError:
            dims = []
            for i, shape_i in enumerate(shape):
                dim = f"{name}_dim_{i}"
                dims.append(dim)
                group_coords[dim] = list(range(shape_i))
        return dims

    @staticmethod
    def _write_coords(group, group_coords):
        """Write every coordinate in ``group_coords`` as a 1-d array of ``group``."""
        for dim, coord in group_coords.items():
            array = group.array(name=dim, data=coord, fill_value=None)
            array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})

    def init_group_with_empty(self, group, var_dtype_and_shape, chains, draws):
        """Allocate fill-valued ``(chains, draws, *shape)`` arrays in ``group``.

        Parameters
        ----------
        group : zarr group
            Group that receives the arrays and their coordinates.
        var_dtype_and_shape : mapping
            Maps each variable name to a ``(dtype, shape)`` pair. A falsy
            ``shape`` is treated as a scalar.
        chains, draws : int
            Sizes of the leading ``chain`` and ``draw`` dimensions.

        Returns
        -------
        The same ``group`` that was passed in.
        """
        group_coords = {"chain": range(chains), "draw": range(draws)}
        for name, (dtype, shape) in var_dtype_and_shape.items():
            fill_value, dtype, object_codec = get_fill_value_and_codec(dtype)
            shape = shape or ()
            array = group.full(
                name=name,
                dtype=dtype,
                fill_value=fill_value,
                object_codec=object_codec,
                shape=(chains, draws, *shape),
                # One chunk per (chain, draw) entry.
                chunks=(1, 1, *shape),
            )
            dims = self._resolve_dims(name, shape, group_coords)
            dims = ("chain", "draw", *dims)
            array.attrs.update({"_ARRAY_DIMENSIONS": dims})
        self._write_coords(group, group_coords)
        return group

    def create_group(self, name, data_dict):
        """Create group ``name`` under the root and fill it from ``data_dict``.

        Parameters
        ----------
        name : str
            Name of the group to create (an existing one is overwritten).
        data_dict : mapping
            Maps variable names to array values exposing ``dtype`` and
            ``shape``.

        Returns
        -------
        The created group, or ``None`` when ``data_dict`` is empty, in which
        case nothing is created.
        """
        if not data_dict:
            return None

        group_coords = {}
        group = self.root.create_group(name=name, overwrite=True)
        for var_name, var_value in data_dict.items():
            fill_value, dtype, object_codec = get_fill_value_and_codec(var_value.dtype)
            array = group.array(
                name=var_name,
                data=var_value,
                fill_value=fill_value,
                dtype=dtype,
                object_codec=object_codec,
            )
            dims = self._resolve_dims(var_name, var_value.shape, group_coords)
            array.attrs.update({"_ARRAY_DIMENSIONS": dims})
        self._write_coords(group, group_coords)
        return group

Check warning on line 302 in pymc/backends/zarr.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/zarr.py#L288-L302

Added lines #L288 - L302 were not covered by tests

0 comments on commit 4cf7d0c

Please sign in to comment.