diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index 986a34f4ba2..2c7164fc7cf 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -72,6 +72,7 @@ from pymc.backends.arviz import predictions_to_inference_data, to_inference_data from pymc.backends.base import BaseTrace, IBaseTrace from pymc.backends.ndarray import NDArray +from pymc.backends.zarr import ZarrTrace from pymc.model import Model from pymc.step_methods.compound import BlockedStep, CompoundStep @@ -118,15 +119,27 @@ def _init_trace( def init_traces( *, - backend: TraceOrBackend | None, + backend: TraceOrBackend | ZarrTrace | None, chains: int, expected_length: int, step: BlockedStep | CompoundStep, initial_point: Mapping[str, np.ndarray], model: Model, trace_vars: list[TensorVariable] | None = None, + tune: int = 0, ) -> tuple[RunType | None, Sequence[IBaseTrace]]: """Initialize a trace recorder for each chain.""" + if isinstance(backend, ZarrTrace): + backend.init_trace( + chains=chains, + draws=expected_length - tune, + tune=tune, + step=step, + model=model, + vars=trace_vars, + test_point=initial_point, + ) + return None, backend.straces if HAS_MCB and isinstance(backend, Backend): return init_chain_adapters( backend=backend, diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4ee79607b79..bf9a758db44 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -50,6 +50,7 @@ find_observations, ) from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains +from pymc.backends.zarr import ZarrTrace from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain @@ -808,6 +809,7 @@ def joined_blas_limiter(): trace_vars=trace_vars, initial_point=ip, model=model, + tune=tune, ) sample_args = { @@ -890,7 +892,7 @@ def joined_blas_limiter(): # into a function to make it easier to test and refactor. return _sample_return( run=run, - traces=traces, + traces=trace if isinstance(trace, ZarrTrace) else traces, tune=tune, t_sampling=t_sampling, discard_tuned_samples=discard_tuned_samples, @@ -905,7 +907,7 @@ def joined_blas_limiter(): def _sample_return( *, run: RunType | None, - traces: Sequence[IBaseTrace], + traces: Sequence[IBaseTrace] | ZarrTrace, tune: int, t_sampling: float, discard_tuned_samples: bool, @@ -919,13 +921,65 @@ def _sample_return( Final step of `pm.sampler`. """ + if isinstance(traces, ZarrTrace): + # Split warmup from posterior samples + traces.split_warmup_groups() + + # Set sampling time + traces._sampling_state.sampling_time[:] = t_sampling + + # Compute number of actual draws per chain + total_draws_per_chain = traces._sampling_state.draw_idx[:] + n_chains = len(traces.straces) + desired_tune = traces.tuning_steps + desired_draw = len(traces.posterior.draw) + tuning_steps_per_chain = np.clip(total_draws_per_chain, 0, desired_tune) + draws_per_chain = total_draws_per_chain - tuning_steps_per_chain + + total_n_tune = tuning_steps_per_chain.sum() + total_draws = draws_per_chain.sum() + + _log.info( + f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {desired_tune:_d} desired tune and {desired_draw:_d} desired draw iterations ' + f"(Actually sampled {total_n_tune:_d} tune and {total_draws:_d} draws total) " + f"took {t_sampling:.0f} seconds." + ) + + if compute_convergence_checks or return_inferencedata: + idata = traces.to_inferencedata(save_warmup=not discard_tuned_samples) + log_likelihood = idata_kwargs.pop("log_likelihood", False) + if log_likelihood: + from pymc.stats.log_density import compute_log_likelihood + + idata = compute_log_likelihood( + idata, + var_names=None if log_likelihood is True else log_likelihood, + extend_inferencedata=True, + model=model, + sample_dims=["chain", "draw"], + progressbar=False, + ) + + if compute_convergence_checks: + warns = run_convergence_checks(idata, model) + for warn in warns: + traces._sampling_state.global_warnings.append(warn) + log_warnings(warns) + + if return_inferencedata: + # By default we drop the "warning" stat which contains `SamplerWarning` + # objects that can not be stored with `.to_netcdf()`. + if not keep_warning_stat: + return drop_warning_stat(idata) + return idata + return traces + # Pick and slice chains to keep the maximum number of samples if discard_tuned_samples: traces, length = _choose_chains(traces, tune) else: traces, length = _choose_chains(traces, 0) mtrace = MultiTrace(traces)[:length] - # count the number of tune/draw iterations that happened # ideally via the "tune" statistic, but not all samplers record it! if "tune" in mtrace.stat_names: