|
1 | 1 | # pylint: disable=no-member, invalid-name, redefined-outer-name
|
2 | 2 | from sys import version_info
|
3 | 3 | from typing import Dict, Tuple
|
4 |
| -import packaging |
5 | 4 |
|
6 | 5 | import numpy as np
|
7 | 6 | import pytest
|
@@ -399,8 +398,10 @@ def test_no_model_deprecation(self):
|
399 | 398 | fails = check_multiple_attrs(test_dict, inference_data)
|
400 | 399 | assert not fails
|
401 | 400 |
|
| 401 | + |
| 402 | +class TestPyMC3WarmupHandling: |
402 | 403 | @pytest.mark.skipif(
|
403 |
| - packaging.version.parse(pm.__version__) < packaging.version.parse("3.9"), |
| 404 | + not hasattr(pm.backends.base.SamplerReport, "n_draws"), |
404 | 405 | reason="requires pymc3 3.9 or higher",
|
405 | 406 | )
|
406 | 407 | @pytest.mark.parametrize("save_warmup", [False, True])
|
@@ -434,3 +435,58 @@ def test_save_warmup(self, save_warmup):
|
434 | 435 | if save_warmup:
|
435 | 436 | assert idata.warmup_posterior.dims["chain"] == 2
|
436 | 437 | assert idata.warmup_posterior.dims["draw"] == 100
|
| 438 | + |
| 439 | + @pytest.mark.skipif( |
| 440 | + hasattr(pm.backends.base.SamplerReport, "n_draws"), reason="requires pymc3 3.8 or lower", |
| 441 | + ) |
| 442 | + def test_save_warmup_issue_1208_before_3_9(self): |
| 443 | + with pm.Model(): |
| 444 | + pm.Uniform("u1") |
| 445 | + pm.Normal("n1") |
| 446 | + trace = pm.sample( |
| 447 | + tune=100, |
| 448 | + draws=200, |
| 449 | + chains=2, |
| 450 | + cores=1, |
| 451 | + step=pm.Metropolis(), |
| 452 | + discard_tuned_samples=False, |
| 453 | + ) |
| 454 | + assert isinstance(trace, pm.backends.base.MultiTrace) |
| 455 | + assert len(trace) == 300 |
| 456 | + |
| 457 | + # <=3.8 did not track n_draws in the sampler report, |
| 458 | + # making from_pymc3 fall back to len(trace) and triggering a warning |
| 459 | + with pytest.warns(UserWarning, match="Warmup samples"): |
| 460 | + idata = from_pymc3(trace, save_warmup=True) |
| 461 | + assert idata.posterior.dims["draw"] == 300 |
| 462 | + assert idata.posterior.dims["chain"] == 2 |
| 463 | + |
| 464 | + @pytest.mark.skipif( |
| 465 | + not hasattr(pm.backends.base.SamplerReport, "n_draws"), |
| 466 | + reason="requires pymc3 3.9 or higher", |
| 467 | + ) |
| 468 | + def test_save_warmup_issue_1208_after_3_9(self): |
| 469 | + with pm.Model(): |
| 470 | + pm.Uniform("u1") |
| 471 | + pm.Normal("n1") |
| 472 | + trace = pm.sample( |
| 473 | + tune=100, |
| 474 | + draws=200, |
| 475 | + chains=2, |
| 476 | + cores=1, |
| 477 | + step=pm.Metropolis(), |
| 478 | + discard_tuned_samples=False, |
| 479 | + ) |
| 480 | + assert isinstance(trace, pm.backends.base.MultiTrace) |
| 481 | + assert len(trace) == 300 |
| 482 | + |
| 483 | + # from original trace, warmup draws should be separated out |
| 484 | + idata = from_pymc3(trace, save_warmup=True) |
| 485 | + assert idata.posterior.dims["chain"] == 2 |
| 486 | + assert idata.posterior.dims["draw"] == 200 |
| 487 | + |
| 488 | + # manually sliced trace triggers the same warning as <=3.8 |
| 489 | + with pytest.warns(UserWarning, match="Warmup samples"): |
| 490 | + idata = from_pymc3(trace[-30:], save_warmup=True) |
| 491 | + assert idata.posterior.dims["chain"] == 2 |
| 492 | + assert idata.posterior.dims["draw"] == 30 |
0 commit comments