Skip to content

Commit

Permalink
Add constant and observed data to nutpie idata
Browse files Browse the repository at this point in the history
  • Loading branch information
Y0dler committed Oct 6, 2023
1 parent 602234b commit 8d60b7d
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
40 changes: 38 additions & 2 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,20 @@
import numpy as np
import pytensor.gradient as tg

from arviz import InferenceData
from arviz import InferenceData, dict_to_dataset
from arviz.data.base import make_attrs
from fastprogress.fastprogress import progress_bar
from pytensor.graph.basic import Variable
from typing_extensions import Protocol, TypeAlias

import pymc as pm

from pymc.backends import RunType, TraceOrBackend, init_traces
from pymc.backends.arviz import (
coords_and_dims_for_inferencedata,
find_constants,
find_observations,
)
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
Expand Down Expand Up @@ -293,8 +299,24 @@ def _sample_external_nuts(
"`idata_kwargs` are currently ignored by the nutpie sampler",
UserWarning,
)

# gather observed and constant data as nutpie.sample() has no access to the PyMC model
coords, dims = coords_and_dims_for_inferencedata(model)
constant_data = dict_to_dataset(
find_constants(model),
library=pm,
coords=coords,
dims=dims,
default_dims=[],
)
observed_data = dict_to_dataset(
find_observations(model),
library=pm,
coords=coords,
dims=dims,
default_dims=[],
)
compiled_model = nutpie.compile_pymc_model(model)
t_start = time.time()
idata = nutpie.sample(
compiled_model,
draws=draws,
Expand All @@ -305,6 +327,20 @@ def _sample_external_nuts(
progress_bar=progressbar,
**nuts_sampler_kwargs,
)
t_sample = time.time() - t_start
attrs = make_attrs(
{
"sampling_time": t_sample,
},
library=nutpie,
)
for k, v in attrs.items():
idata.posterior.attrs[k] = v
idata.add_groups(
{"constant_data": constant_data, "observed_data": observed_data},
coords=coords,
dims=dims,
)
return idata

elif sampler == "numpyro":
Expand Down
12 changes: 9 additions & 3 deletions tests/sampling/test_mcmc_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy.testing as npt
import pytest

from pymc import Model, Normal, sample
from pymc import ConstantData, Model, Normal, sample


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
Expand All @@ -25,7 +25,11 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
pytest.importorskip(nuts_sampler)

with Model():
Normal("x")
x = Normal("x", 100, 5)
y = ConstantData("y", [1, 2, 3, 4])
ConstantData("z", [100, 190, 310, 405])

Normal("L", mu=x, sigma=0.1, observed=y)

kwargs = dict(
nuts_sampler=nuts_sampler,
Expand Down Expand Up @@ -55,7 +59,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
)
)
assert warns == expected

assert "y" in idata1.constant_data
assert "z" in idata1.constant_data
assert "L" in idata1.observed_data
assert idata1.posterior.chain.size == 2
assert idata1.posterior.draw.size == 500
np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)
Expand Down

0 comments on commit 8d60b7d

Please sign in to comment.