
Merge pull request #368 from mj-will/improve-ins-checkpointing
Add option to save log_q
mj-will authored Feb 24, 2024
2 parents 43c6413 + b9047af commit 24e8e5b
Showing 5 changed files with 116 additions and 28 deletions.
62 changes: 44 additions & 18 deletions nessai/samplers/importancesampler.py
@@ -36,12 +36,26 @@


class OrderedSamples:
"""Samples ordered by log-likelihood."""
"""Samples ordered by log-likelihood.
Parameters
----------
strict_threshold
If true, when adding new samples, only those above the current
log-likelihood threshold will be added to the live points.
replace_all
If true, all samples will be removed when calling :code:`remove_samples`.
save_log_q
If true, :code:`log_q` will be saved when the instance is pickled. This
makes resuming faster but increases the disk usage. If false, the
values will not be saved and must be recomputed when resuming.
"""

def __init__(
self,
strict_threshold: bool = False,
replace_all: bool = False,
save_log_q: bool = False,
) -> None:
self.samples = None
self.log_q = None
@@ -51,6 +65,7 @@ def __init__(
self.replace_all = replace_all
self.state = _INSIntegralState()
self.log_likelihood_threshold = None
self.save_log_q = save_log_q

@property
def live_points(self) -> np.ndarray:
@@ -254,7 +269,10 @@ def __getstate__(self):
d = self.__dict__
exclude = {"log_q"}
state = {k: d[k] for k in d.keys() - exclude}
state["log_q"] = None
if self.save_log_q:
state["log_q"] = self.log_q
else:
state["log_q"] = None
return state
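
To make the behaviour concrete, here is a minimal sketch of what `__getstate__` keeps when an `OrderedSamples` instance is pickled. The class and attribute names follow the diff above; the `log_q` array is a random stand-in:

import numpy as np

from nessai.samplers.importancesampler import OrderedSamples

samples = OrderedSamples(save_log_q=True)
samples.log_q = np.random.randn(100, 2)  # stand-in meta-proposal log-probabilities

# With save_log_q=True the array is included in the pickled state,
# so resuming does not need to recompute it.
assert samples.__getstate__()["log_q"] is samples.log_q

# With save_log_q=False the array is dropped to reduce disk usage
# and must be recomputed on resume.
samples.save_log_q = False
assert samples.__getstate__()["log_q"] is None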


@@ -303,6 +321,9 @@ class ImportanceNestedSampler(BaseNestedSampler):
If true, when drawing new samples, only those with likelihoods above
the current threshold will be added to the live points. If false, all
new samples are added to the live points.
save_log_q : bool
Determines whether the :code:`log_q` array is saved when checkpointing.
If false, disk usage is reduced but the array must be recomputed when
resuming.
"""

stopping_criterion_aliases = dict(
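
From the user's side the new option is just another constructor argument. A hedged sketch, assuming a toy one-parameter Gaussian model (the sampler keywords mirror the integration test further down):

import numpy as np

from nessai.model import Model
from nessai.samplers.importancesampler import ImportanceNestedSampler

class GaussianModel(Model):
    """Toy model: uniform prior on [-10, 10], unit-Gaussian likelihood."""

    def __init__(self):
        self.names = ["x"]
        self.bounds = {"x": [-10.0, 10.0]}

    def log_prior(self, x):
        # Flat prior: log(1/20) inside the bounds, -inf outside
        return np.log(self.in_bounds(x), dtype=float) - np.log(20.0)

    def log_likelihood(self, x):
        return -0.5 * x["x"] ** 2 - 0.5 * np.log(2 * np.pi)

sampler = ImportanceNestedSampler(
    model=GaussianModel(),
    output="outdir",
    nlive=50,
    save_log_q=True,  # keep log_q in checkpoints: larger files, faster resume
)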
@@ -328,6 +349,7 @@ def __init__(
checkpoint_on_iteration: bool = False,
checkpoint_callback: Optional[Callable] = None,
save_existing_checkpoint: bool = False,
save_log_q: bool = False,
logging_interval: int = None,
log_on_iteration: bool = True,
resume_file: Optional[str] = None,
@@ -427,6 +449,7 @@ def __init__(
self.bootstrap_log_evidence_error = None
self.weighted_kl = weighted_kl
self.save_existing_checkpoint = save_existing_checkpoint
self.save_log_q = save_log_q

self.log_dZ = np.inf
self.ratio = np.inf
@@ -448,11 +471,13 @@ def __init__(
self.training_samples = OrderedSamples(
strict_threshold=self.strict_threshold,
replace_all=self.replace_all,
save_log_q=self.save_log_q,
)
if self.draw_iid_live:
self.iid_samples = OrderedSamples(
strict_threshold=self.strict_threshold,
replace_all=self.replace_all,
save_log_q=self.save_log_q,
)
else:
self.iid_samples = None
@@ -2261,31 +2286,33 @@ def resume_from_pickled_sampler(
obj = super(ImportanceNestedSampler, cls).resume_from_pickled_sampler(
sampler, model, **kwargs
)
logger.info(f"Resuming sampler at iteration {obj.iteration}")
logger.info(f"Current number of samples: {len(obj.nested_samples)}")
logger.info(
f"Current log-evidence: {obj.log_evidence:3f} "
f"+/- {obj.log_evidence_error:.3f}"
)
if flow_config is None:
flow_config = {}
obj.proposal.resume(model, flow_config, weights_path=weights_path)

logger.debug("Recomputing log_q")
(
_,
obj.training_samples.log_q,
) = obj.proposal.compute_meta_proposal_samples(
obj.training_samples.samples
)
if obj.iid_samples:
if obj.training_samples.log_q is None:
logger.info("Recomputing log_q")
(
_,
obj.training_samples.log_q,
) = obj.proposal.compute_meta_proposal_samples(
obj.training_samples.samples
)
if obj.iid_samples and obj.iid_samples.log_q is None:
logger.info("Recomputing log_q for i.i.d samples")
(
_,
obj.iid_samples.log_q,
) = obj.proposal.compute_meta_proposal_samples(
obj.iid_samples.samples
)

logger.info(f"Resuming sampler at iteration {obj.iteration}")
logger.info(f"Current number of samples: {len(obj.nested_samples)}")
logger.info(
f"Current logZ: {obj.log_evidence:3f} "
f"+/- {obj.log_evidence_error:.3f}"
)
logger.info("Finished resuming sampler")
return obj
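
The effect on resuming follows directly: when the checkpoint was written with `save_log_q=True`, `log_q` is already populated and the recomputation branches above are skipped. A hedged sketch of the round trip, reusing the toy model from the earlier sketch and a placeholder checkpoint path:

import pickle

from nessai.samplers.importancesampler import ImportanceNestedSampler

with open("outdir/checkpoint.pkl", "rb") as f:  # placeholder path
    pickled_sampler = pickle.load(f)

# Same model the run was started with (e.g. GaussianModel from above)
sampler = ImportanceNestedSampler.resume_from_pickled_sampler(
    pickled_sampler,
    GaussianModel(),
)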

def __getstate__(self):
@@ -2308,7 +2335,6 @@ def __getstate__(self):
else:
state["_previous_likelihood_evaluations"] = 0
state["_previous_likelihood_evaluation_time"] = 0
state["log_q"] = None
return state, self.proposal, self.training_samples, self.iid_samples

def __setstate__(self, state):
@@ -7,10 +7,13 @@
import pytest


def test_init(ins, model):
@pytest.mark.parametrize("save_log_q", [False, True])
def test_init(ins, model, save_log_q):
ins.add_fields = MagicMock()
INS.__init__(ins, model)
INS.__init__(ins, model, save_log_q=save_log_q, draw_iid_live=True)
ins.add_fields.assert_called_once()
assert ins.training_samples.save_log_q is save_log_q
assert ins.iid_samples.save_log_q is save_log_q


def test_add_fields():
@@ -310,3 +310,18 @@ def test_computed_evidence_ratio(ordered_samples, samples, threshold):
mock_log_evidence.call_args_list[0][0][0], samples[above_threshold]
)
assert out == (log_z - log_z_total)


@pytest.mark.parametrize("save_log_q", [False, True])
def test_getstate(ordered_samples, save_log_q):
samples = np.random.randn(20, 4)
log_q = np.random.randn(2, 20)
ordered_samples.save_log_q = save_log_q
ordered_samples.log_q = log_q
ordered_samples.samples = samples
state = OrderedSamples.__getstate__(ordered_samples)
assert state["samples"] is samples
if save_log_q:
assert state["log_q"] is log_q
else:
assert state["log_q"] is None
56 changes: 49 additions & 7 deletions tests/test_samplers/test_importance_nested_sampler/test_resume.py
@@ -1,7 +1,8 @@
import datetime
from unittest.mock import MagicMock, patch

import numpy as np
import pickle
import pytest
from unittest.mock import MagicMock, patch

from nessai.samplers.importancesampler import ImportanceNestedSampler as INS

@@ -10,7 +11,6 @@ def test_getstate_no_model(ins):
ins.proposal = MagicMock()
ins.model = None
state, proposal, training_samples, iid_samples = INS.__getstate__(ins)
assert state["log_q"] is None
assert "model" not in state
assert state["_previous_likelihood_evaluations"] == 0
assert state["_previous_likelihood_evaluation_time"] == 0
@@ -30,7 +30,6 @@ def test_getstate_model(ins):
ins.model.likelihood_evaluation_time = time

state, proposal, training_samples, iid_samples = INS.__getstate__(ins)
assert state["log_q"] is None
assert "model" not in state
assert state["_previous_likelihood_evaluations"] == evals
assert state["_previous_likelihood_evaluation_time"] == 30
@@ -39,7 +38,8 @@ def test_getstate_model(ins):
assert iid_samples is ins.iid_samples


def test_resume_from_pickled_sampler(model, samples):
@pytest.mark.parametrize("has_log_q", [False, True])
def test_resume_from_pickled_sampler(model, samples, has_log_q):

sampler = MagicMock()

@@ -48,12 +48,24 @@ def test_resume_from_pickled_sampler(model, samples):
obj.log_evidence = 0.0
obj.log_evidence_error = 1.0
obj.proposal = MagicMock()
obj.training_samples.samples = samples
log_meta_proposal = np.log(np.random.rand(len(samples)))
log_q = np.log(np.random.rand(len(samples)))
log_meta_proposal_iid = np.log(np.random.rand(len(samples)))
log_q_iid = np.log(np.random.rand(len(samples)))
obj.proposal.compute_meta_proposal_samples = MagicMock(
return_value=(log_meta_proposal, log_q)
side_effect=[
(log_meta_proposal, log_q),
(log_meta_proposal_iid, log_q_iid),
]
)
obj.training_samples.samples = samples
obj.iid_samples.samples = samples
if has_log_q:
obj.training_samples.log_q = log_q
obj.iid_samples.log_q = log_q_iid
else:
obj.training_samples.log_q = None
obj.iid_samples.log_q = None

with patch(
"nessai.samplers.importancesampler.BaseNestedSampler.resume_from_pickled_sampler", # noqa
@@ -62,5 +74,35 @@
out = INS.resume_from_pickled_sampler(sampler, model)

mock_resume.assert_called_once_with(sampler, model)
if has_log_q:
obj.proposal.compute_meta_proposal_samples.assert_not_called()
else:
obj.proposal.compute_meta_proposal_samples.assert_called()

assert out.training_samples.log_q is log_q
assert out.iid_samples.log_q is log_q_iid


@pytest.mark.parametrize("save_log_q", [True, False])
@pytest.mark.integration_test
def test_pickling_sampler_integration(integration_model, tmp_path, save_log_q):
outdir = tmp_path / "test_pickle"
ins = INS(
model=integration_model,
output=outdir,
nlive=50,
min_samples=10,
max_iteration=1,
save_log_q=save_log_q,
plot=False,
checkpointing=False,
)
ins.nested_sampling_loop()
data = pickle.dumps(ins)
loaded_ins = pickle.loads(data)
if save_log_q:
np.testing.assert_array_equal(
loaded_ins._ordered_samples.log_q, ins._ordered_samples.log_q
)
else:
assert loaded_ins._ordered_samples.log_q is None
4 changes: 3 additions & 1 deletion tests/test_sampling/test_ins_sampling.py
@@ -9,7 +9,8 @@

@pytest.mark.slow_integration_test
@pytest.mark.flaky(reruns=3)
def test_ins_resume(tmp_path, model, flow_config):
@pytest.mark.parametrize("save_log_q", [False, True])
def test_ins_resume(tmp_path, model, flow_config, save_log_q):
"""Assert the INS sampler resumes correctly"""
output = tmp_path / "test_ins_resume"
fp = FlowSampler(
@@ -22,6 +23,7 @@ def test_ins_resume(tmp_path, model, flow_config):
flow_config=flow_config,
importance_nested_sampler=True,
max_iteration=2,
save_log_q=save_log_q,
)
fp.run()
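
As this test relies on, the option can also be passed through `FlowSampler`, which passes extra keyword arguments on to the underlying sampler. A hedged sketch, again using the toy model from the earlier sketch:

from nessai.flowsampler import FlowSampler

fs = FlowSampler(
    GaussianModel(),
    output="outdir",
    importance_nested_sampler=True,  # use the importance nested sampler
    save_log_q=True,                 # forwarded to ImportanceNestedSampler
    resume=True,                     # pick up an existing checkpoint if present
)
fs.run()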

