From ea4ad3fb573909ffa1eb6c298981450489260874 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Tue, 11 Jul 2023 14:57:48 +0200 Subject: [PATCH] supressing ArviZ warning and allowing step selection --- pymc_experimental/model_builder.py | 24 +++++++++++++++---- pymc_experimental/tests/test_model_builder.py | 9 +++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index c3550712..b8f049fa 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -15,6 +15,7 @@ import hashlib import json +import warnings from abc import abstractmethod from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -290,12 +291,18 @@ def sample_model(self, **kwargs): raise RuntimeError( "The model hasn't been built yet, call .build_model() first or call .fit() instead." ) - with self.model: sampler_args = {**self.sampler_config, **kwargs} - idata = pm.sample(**sampler_args) - idata.extend(pm.sample_prior_predictive()) - idata.extend(pm.sample_posterior_predictive(idata)) + if "step" in sampler_args: + step_function_name = sampler_args["step"] + step_function = getattr(pm, step_function_name) + sampler_args["step"] = step_function() + idata = pm.sample(**sampler_args) + idata.extend(pm.sample_prior_predictive()) + idata.extend(pm.sample_posterior_predictive(idata)) + sampler_args["step"] = step_function_name + else: + idata = pm.sample(**sampler_args) self.set_idata_attrs(idata) return idata @@ -479,7 +486,14 @@ def fit( X_df = pd.DataFrame(X, columns=X.columns) combined_data = pd.concat([X_df, y], axis=1) assert all(combined_data.columns), "All columns must have non-empty names" - self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The group fit_data is not defined in the InferenceData scheme", + ) + self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore + return self.idata # type: ignore def predict( diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index dd4a88ab..8a7b9c49 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -209,3 +209,12 @@ def test_id(): ).hexdigest()[:16] assert model_builder.id == expected_id + + +def test_step_selection_in_sample_config(toy_X, toy_y): + sampler_config = { + "step": "Slice", + } + model = test_ModelBuilder(sampler_config=sampler_config) + model.fit(toy_X, toy_y) + assert model.idata is not None