Skip to content

Commit 6b4de81

Browse files
Deal with PyMC3 traces that do not contain n_draws information in the sampler report (#1209)
* add regression test for #1208 * check if n_draws is available and informative before using it + also warn the user about manually slicing + closes #1208 + bumps to 0.8.2 because of the hotfix * split new test and adapt conditions for different pymc3 versions * make pylint happy * make black happy * address review feedback
1 parent cf7c19e commit 6b4de81

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

arviz/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# pylint: disable=wildcard-import,invalid-name,wrong-import-position
22
"""ArviZ is a library for exploratory analysis of Bayesian models."""
3-
__version__ = "0.8.1"
3+
__version__ = "0.8.2"
44

55
import os
66
import logging

arviz/data/io_pymc3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898
0
9999
].model
100100
self.nchains = trace.nchains if hasattr(trace, "nchains") else 1
101-
if hasattr(trace.report, "n_tune"):
101+
if hasattr(trace.report, "n_draws") and trace.report.n_draws is not None:
102102
self.ndraws = trace.report.n_draws
103103
self.attrs = {
104104
"sampling_time": trace.report.t_sampling,
@@ -109,7 +109,8 @@ def __init__(
109109
if self.save_warmup:
110110
warnings.warn(
111111
"Warmup samples will be stored in posterior group and will not be"
112-
" excluded from stats and diagnostics. Please consider using PyMC3>=3.9",
112+
" excluded from stats and diagnostics."
113+
" Please consider using PyMC3>=3.9 and do not slice the trace manually.",
113114
UserWarning,
114115
)
115116
else:

arviz/tests/external_tests/test_data_pymc.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# pylint: disable=no-member, invalid-name, redefined-outer-name
22
from sys import version_info
33
from typing import Dict, Tuple
4-
import packaging
54

65
import numpy as np
76
import pytest
@@ -399,8 +398,10 @@ def test_no_model_deprecation(self):
399398
fails = check_multiple_attrs(test_dict, inference_data)
400399
assert not fails
401400

401+
402+
class TestPyMC3WarmupHandling:
402403
@pytest.mark.skipif(
403-
packaging.version.parse(pm.__version__) < packaging.version.parse("3.9"),
404+
not hasattr(pm.backends.base.SamplerReport, "n_draws"),
404405
reason="requires pymc3 3.9 or higher",
405406
)
406407
@pytest.mark.parametrize("save_warmup", [False, True])
@@ -434,3 +435,58 @@ def test_save_warmup(self, save_warmup):
434435
if save_warmup:
435436
assert idata.warmup_posterior.dims["chain"] == 2
436437
assert idata.warmup_posterior.dims["draw"] == 100
438+
439+
@pytest.mark.skipif(
440+
hasattr(pm.backends.base.SamplerReport, "n_draws"), reason="requires pymc3 3.8 or lower",
441+
)
442+
def test_save_warmup_issue_1208_before_3_9(self):
443+
with pm.Model():
444+
pm.Uniform("u1")
445+
pm.Normal("n1")
446+
trace = pm.sample(
447+
tune=100,
448+
draws=200,
449+
chains=2,
450+
cores=1,
451+
step=pm.Metropolis(),
452+
discard_tuned_samples=False,
453+
)
454+
assert isinstance(trace, pm.backends.base.MultiTrace)
455+
assert len(trace) == 300
456+
457+
# <=3.8 did not track n_draws in the sampler report,
458+
# making from_pymc3 fall back to len(trace) and triggering a warning
459+
with pytest.warns(UserWarning, match="Warmup samples"):
460+
idata = from_pymc3(trace, save_warmup=True)
461+
assert idata.posterior.dims["draw"] == 300
462+
assert idata.posterior.dims["chain"] == 2
463+
464+
@pytest.mark.skipif(
465+
not hasattr(pm.backends.base.SamplerReport, "n_draws"),
466+
reason="requires pymc3 3.9 or higher",
467+
)
468+
def test_save_warmup_issue_1208_after_3_9(self):
469+
with pm.Model():
470+
pm.Uniform("u1")
471+
pm.Normal("n1")
472+
trace = pm.sample(
473+
tune=100,
474+
draws=200,
475+
chains=2,
476+
cores=1,
477+
step=pm.Metropolis(),
478+
discard_tuned_samples=False,
479+
)
480+
assert isinstance(trace, pm.backends.base.MultiTrace)
481+
assert len(trace) == 300
482+
483+
# from original trace, warmup draws should be separated out
484+
idata = from_pymc3(trace, save_warmup=True)
485+
assert idata.posterior.dims["chain"] == 2
486+
assert idata.posterior.dims["draw"] == 200
487+
488+
# manually sliced trace triggers the same warning as <=3.8
489+
with pytest.warns(UserWarning, match="Warmup samples"):
490+
idata = from_pymc3(trace[-30:], save_warmup=True)
491+
assert idata.posterior.dims["chain"] == 2
492+
assert idata.posterior.dims["draw"] == 30

0 commit comments

Comments
 (0)