diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index b81e7b46..04137ba1 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -15,6 +15,7 @@ import hashlib import json +import warnings from abc import abstractmethod from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -290,12 +291,18 @@ def sample_model(self, **kwargs): raise RuntimeError( "The model hasn't been built yet, call .build_model() first or call .fit() instead." ) - with self.model: sampler_args = {**self.sampler_config, **kwargs} - idata = pm.sample(**sampler_args) - idata.extend(pm.sample_prior_predictive()) - idata.extend(pm.sample_posterior_predictive(idata)) + if "step" in sampler_args: + step_function_name = sampler_args["step"] + step_function = getattr(pm, step_function_name) + sampler_args["step"] = step_function() + idata = pm.sample(**sampler_args) + idata.extend(pm.sample_prior_predictive()) + idata.extend(pm.sample_posterior_predictive(idata)) + sampler_args["step"] = step_function_name + else: + idata = pm.sample(**sampler_args) idata = self.set_idata_attrs(idata) return idata @@ -496,7 +503,13 @@ def fit( X_df = pd.DataFrame(X, columns=X.columns) combined_data = pd.concat([X_df, y], axis=1) assert all(combined_data.columns), "All columns must have non-empty names" - self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The group fit_data is not defined in the InferenceData scheme", + ) + self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore return self.idata # type: ignore diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index 3327d1e0..60cbcca3 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -237,3 +237,12 @@ def test_id(): ).hexdigest()[:16] assert model_builder.id == expected_id + + +def test_step_selection_in_sample_config(toy_X, toy_y): + sampler_config = { + "step": "Slice", + } + model = test_ModelBuilder(sampler_config=sampler_config) + model.fit(toy_X, toy_y) + assert model.idata is not None