diff --git a/nessai/samplers/importancesampler.py b/nessai/samplers/importancesampler.py index ec12cdb4..7e04fa21 100644 --- a/nessai/samplers/importancesampler.py +++ b/nessai/samplers/importancesampler.py @@ -36,12 +36,26 @@ class OrderedSamples: - """Samples ordered by log-likelihood.""" + """Samples ordered by log-likelihood. + + Parameters + ---------- + strict_threshold + If true, when adding new samples, only those above the current + log-likelihood threshold will be added to the live points. + replace_all + If true, all samples are removed when calling :code:`remove_samples`. + save_log_q + If true, :code:`log_q` will be saved when the instance is pickled. This + makes resuming faster but increases the disk usage. If false, the + values will not be saved and must be recomputed. + """ def __init__( self, strict_threshold: bool = False, replace_all: bool = False, + save_log_q: bool = False, ) -> None: self.samples = None self.log_q = None @@ -51,6 +65,7 @@ def __init__( self.replace_all = replace_all self.state = _INSIntegralState() self.log_likelihood_threshold = None + self.save_log_q = save_log_q @property def live_points(self) -> np.ndarray: @@ -254,7 +269,10 @@ def __getstate__(self): d = self.__dict__ exclude = {"log_q"} state = {k: d[k] for k in d.keys() - exclude} - state["log_q"] = None + if self.save_log_q: + state["log_q"] = self.log_q + else: + state["log_q"] = None return state @@ -303,6 +321,9 @@ class ImportanceNestedSampler(BaseNestedSampler): If true, when drawing new samples, only those with likelihoods above the current threshold will be added to the live points. If false, all new samples are added to the live points. + save_log_q : bool + Boolean that determines if the log_q array is saved when checkpointing. + If False, this can help reduce the disk usage. 
""" stopping_criterion_aliases = dict( @@ -328,6 +349,7 @@ def __init__( checkpoint_on_iteration: bool = False, checkpoint_callback: Optional[Callable] = None, save_existing_checkpoint: bool = False, + save_log_q: bool = False, logging_interval: int = None, log_on_iteration: bool = True, resume_file: Optional[str] = None, @@ -427,6 +449,7 @@ def __init__( self.bootstrap_log_evidence_error = None self.weighted_kl = weighted_kl self.save_existing_checkpoint = save_existing_checkpoint + self.save_log_q = save_log_q self.log_dZ = np.inf self.ratio = np.inf @@ -448,11 +471,13 @@ def __init__( self.training_samples = OrderedSamples( strict_threshold=self.strict_threshold, replace_all=self.replace_all, + save_log_q=self.save_log_q, ) if self.draw_iid_live: self.iid_samples = OrderedSamples( strict_threshold=self.strict_threshold, replace_all=self.replace_all, + save_log_q=self.save_log_q, ) else: self.iid_samples = None @@ -2261,31 +2286,33 @@ def resume_from_pickled_sampler( obj = super(ImportanceNestedSampler, cls).resume_from_pickled_sampler( sampler, model, **kwargs ) + logger.info(f"Resuming sampler at iteration {obj.iteration}") + logger.info(f"Current number of samples: {len(obj.nested_samples)}") + logger.info( + f"Current log-evidence: {obj.log_evidence:3f} " + f"+/- {obj.log_evidence_error:.3f}" + ) if flow_config is None: flow_config = {} obj.proposal.resume(model, flow_config, weights_path=weights_path) - logger.debug("Recomputing log_q") - ( - _, - obj.training_samples.log_q, - ) = obj.proposal.compute_meta_proposal_samples( - obj.training_samples.samples - ) - if obj.iid_samples: + if obj.training_samples.log_q is None: + logger.info("Recomputing log_q") + ( + _, + obj.training_samples.log_q, + ) = obj.proposal.compute_meta_proposal_samples( + obj.training_samples.samples + ) + if obj.iid_samples and obj.iid_samples.log_q is None: + logger.info("Recomputing log_q for i.i.d samples") ( _, obj.iid_samples.log_q, ) = obj.proposal.compute_meta_proposal_samples( 
obj.iid_samples.samples ) - - logger.info(f"Resuming sampler at iteration {obj.iteration}") - logger.info(f"Current number of samples: {len(obj.nested_samples)}") - logger.info( - f"Current logZ: {obj.log_evidence:3f} " - f"+/- {obj.log_evidence_error:.3f}" - ) + logger.info("Finished resuming sampler") return obj def __getstate__(self): @@ -2308,7 +2335,6 @@ def __getstate__(self): else: state["_previous_likelihood_evaluations"] = 0 state["_previous_likelihood_evaluation_time"] = 0 - state["log_q"] = None return state, self.proposal, self.training_samples, self.iid_samples def __setstate__(self, state): diff --git a/tests/test_samplers/test_importance_nested_sampler/test_config.py b/tests/test_samplers/test_importance_nested_sampler/test_config.py index 61de0b98..a8463c3b 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_config.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_config.py @@ -7,10 +7,13 @@ import pytest -def test_init(ins, model): +@pytest.mark.parametrize("save_log_q", [False, True]) +def test_init(ins, model, save_log_q): ins.add_fields = MagicMock() - INS.__init__(ins, model) + INS.__init__(ins, model, save_log_q=save_log_q, draw_iid_live=True) ins.add_fields.assert_called_once() + assert ins.training_samples.save_log_q is save_log_q + assert ins.iid_samples.save_log_q is save_log_q def test_add_fields(): diff --git a/tests/test_samplers/test_importance_nested_sampler/test_ordered_samples.py b/tests/test_samplers/test_importance_nested_sampler/test_ordered_samples.py index 305b7470..e5cffe1f 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_ordered_samples.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_ordered_samples.py @@ -310,3 +310,18 @@ def test_computed_evidence_ratio(ordered_samples, samples, threshold): mock_log_evidence.call_args_list[0][0][0], samples[above_threshold] ) assert out == (log_z - log_z_total) + + +@pytest.mark.parametrize("save_log_q", [False, True]) 
+def test_getstate(ordered_samples, save_log_q): + samples = np.random.randn(20, 4) + log_q = np.random.randn(2, 20) + ordered_samples.save_log_q = save_log_q + ordered_samples.log_q = log_q + ordered_samples.samples = samples + state = OrderedSamples.__getstate__(ordered_samples) + assert state["samples"] is samples + if save_log_q: + assert state["log_q"] is log_q + else: + assert state["log_q"] is None diff --git a/tests/test_samplers/test_importance_nested_sampler/test_resume.py b/tests/test_samplers/test_importance_nested_sampler/test_resume.py index ed6a7f1c..a7bb4f02 100644 --- a/tests/test_samplers/test_importance_nested_sampler/test_resume.py +++ b/tests/test_samplers/test_importance_nested_sampler/test_resume.py @@ -1,7 +1,8 @@ import datetime -from unittest.mock import MagicMock, patch - import numpy as np +import pickle +import pytest +from unittest.mock import MagicMock, patch from nessai.samplers.importancesampler import ImportanceNestedSampler as INS @@ -10,7 +11,6 @@ def test_getstate_no_model(ins): ins.proposal = MagicMock() ins.model = None state, proposal, training_samples, iid_samples = INS.__getstate__(ins) - assert state["log_q"] is None assert "model" not in state assert state["_previous_likelihood_evaluations"] == 0 assert state["_previous_likelihood_evaluation_time"] == 0 @@ -30,7 +30,6 @@ def test_getstate_model(ins): ins.model.likelihood_evaluation_time = time state, proposal, training_samples, iid_samples = INS.__getstate__(ins) - assert state["log_q"] is None assert "model" not in state assert state["_previous_likelihood_evaluations"] == evals assert state["_previous_likelihood_evaluation_time"] == 30 @@ -39,7 +38,8 @@ def test_getstate_model(ins): assert iid_samples is ins.iid_samples -def test_resume_from_pickled_sampler(model, samples): +@pytest.mark.parametrize("has_log_q", [False, True]) +def test_resume_from_pickled_sampler(model, samples, has_log_q): sampler = MagicMock() @@ -48,12 +48,24 @@ def 
test_resume_from_pickled_sampler(model, samples): obj.log_evidence = 0.0 obj.log_evidence_error = 1.0 obj.proposal = MagicMock() - obj.training_samples.samples = samples log_meta_proposal = np.log(np.random.rand(len(samples))) log_q = np.log(np.random.rand(len(samples))) + log_meta_proposal_iid = np.log(np.random.rand(len(samples))) + log_q_iid = np.log(np.random.rand(len(samples))) obj.proposal.compute_meta_proposal_samples = MagicMock( - return_value=(log_meta_proposal, log_q) + side_effect=[ + (log_meta_proposal, log_q), + (log_meta_proposal_iid, log_q_iid), + ] ) + obj.training_samples.samples = samples + obj.iid_samples.samples = samples + if has_log_q: + obj.training_samples.log_q = log_q + obj.iid_samples.log_q = log_q_iid + else: + obj.training_samples.log_q = None + obj.iid_samples.log_q = None with patch( "nessai.samplers.importancesampler.BaseNestedSampler.resume_from_pickled_sampler", # noqa @@ -62,5 +74,35 @@ def test_resume_from_pickled_sampler(model, samples): out = INS.resume_from_pickled_sampler(sampler, model) mock_resume.assert_called_once_with(sampler, model) + if has_log_q: + obj.proposal.compute_meta_proposal_samples.assert_not_called() + else: + obj.proposal.compute_meta_proposal_samples.assert_called() assert out.training_samples.log_q is log_q + assert out.iid_samples.log_q is log_q_iid + + +@pytest.mark.parametrize("save_log_q", [True, False]) +@pytest.mark.integration_test +def test_pickling_sampler_integration(integration_model, tmp_path, save_log_q): + outdir = tmp_path / "test_pickle" + ins = INS( + model=integration_model, + output=outdir, + nlive=50, + min_samples=10, + max_iteration=1, + save_log_q=save_log_q, + plot=False, + checkpointing=False, + ) + ins.nested_sampling_loop() + data = pickle.dumps(ins) + loaded_ins = pickle.loads(data) + if save_log_q: + np.testing.assert_array_equal( + loaded_ins._ordered_samples.log_q, ins._ordered_samples.log_q + ) + else: + assert loaded_ins._ordered_samples.log_q is None diff --git 
a/tests/test_sampling/test_ins_sampling.py b/tests/test_sampling/test_ins_sampling.py index dba8e153..e860941a 100644 --- a/tests/test_sampling/test_ins_sampling.py +++ b/tests/test_sampling/test_ins_sampling.py @@ -9,7 +9,8 @@ @pytest.mark.slow_integration_test @pytest.mark.flaky(reruns=3) -def test_ins_resume(tmp_path, model, flow_config): +@pytest.mark.parametrize("save_log_q", [False, True]) +def test_ins_resume(tmp_path, model, flow_config, save_log_q): """Assert the INS sampler resumes correctly""" output = tmp_path / "test_ins_resume" fp = FlowSampler( @@ -22,6 +23,7 @@ def test_ins_resume(tmp_path, model, flow_config): flow_config=flow_config, importance_nested_sampler=True, max_iteration=2, + save_log_q=save_log_q, ) fp.run()