Skip to content

Commit

Permalink
Merge pull request #72 from michaelosthege/bump026
Browse files Browse the repository at this point in the history
Bump for 0.2.6
  • Loading branch information
michaelosthege authored Dec 19, 2022
2 parents 35f17ad + a97d84d commit 15b6082
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mcbackend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@
pass


__version__ = "0.2.5"
__version__ = "0.2.6"
12 changes: 11 additions & 1 deletion mcbackend/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .meta import ChainMeta, RunMeta, Variable
from .npproto.utils import ndarray_to_numpy
from .utils import as_array_from_ragged

InferenceData = TypeVar("InferenceData")
try:
Expand Down Expand Up @@ -252,7 +253,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
warmup_sample_stats[svar.name].append(stats[tune])
sample_stats[svar.name].append(stats[~tune])

kwargs.setdefault("save_warmup", True)
if not equalize_chain_lengths:
# Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically
warmup_posterior = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()}
warmup_sample_stats = {
k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items()
}
posterior = {k: as_array_from_ragged(v) for k, v in posterior.items()}
sample_stats = {k: as_array_from_ragged(v) for k, v in sample_stats.items()}

idata = from_dict(
warmup_posterior=warmup_posterior,
warmup_sample_stats=warmup_sample_stats,
Expand All @@ -263,6 +272,7 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
attrs=self.meta.attributes,
constant_data=self.constant_data,
observed_data=self.observed_data,
save_warmup=True,
**kwargs,
)
return idata
Expand Down
13 changes: 13 additions & 0 deletions mcbackend/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

import mcbackend
from mcbackend import utils as mutils
from mcbackend.meta import ChainMeta, DataVariable, RunMeta, Variable
from mcbackend.npproto import utils

Expand Down Expand Up @@ -407,3 +408,15 @@ def test__big_variables(self):
speed = self.measure_big_variables()
assert speed.draws_per_second > 500 or speed.mib_per_second > 5
pass


def test_as_array_from_ragged():
even = mutils.as_array_from_ragged(
[
numpy.ones(2),
numpy.ones(3),
]
)
assert isinstance(even, numpy.ndarray)
assert even.dtype == numpy.dtype(object)
pass
11 changes: 11 additions & 0 deletions mcbackend/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Contains helper functions that are independent of McBackend components."""
from typing import Sequence

import numpy as np


def as_array_from_ragged(arrs: Sequence[np.ndarray]) -> np.ndarray:
shapes = {np.shape(arr) for arr in arrs}
if len(shapes) > 1:
return np.array(arrs, dtype=object)
return np.array(arrs)

0 comments on commit 15b6082

Please sign in to comment.