Skip to content

Commit

Permalink
Merge pull request #18 from mj-will/reduce-overhead
Browse files Browse the repository at this point in the history
Reduce overhead
  • Loading branch information
mj-will authored Jan 26, 2021
2 parents 4da03f7 + 6399944 commit f27e269
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 110 deletions.
11 changes: 4 additions & 7 deletions nessai/flowsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, model, output='./', resume=True,
signal.signal(signal.SIGINT, self.safe_exit)
signal.signal(signal.SIGALRM, self.safe_exit)
except AttributeError:
logger.debug('Can not set signal attributes on this system')
logger.critical('Can not set signal attributes on this system')

def run(self, plot=True, save=True):
"""
Expand Down Expand Up @@ -191,8 +191,8 @@ def save_results(self, filename):
d['information'] = self.ns.information
d['sampling_time'] = self.ns.sampling_time.total_seconds()
d['training_time'] = self.ns.training_time.total_seconds()
d['population_time'] = self.ns.proposal.population_time.total_seconds()
if (t := self.ns.proposal.logl_eval_time.total_seconds()):
d['population_time'] = self.ns.proposal_population_time.total_seconds()
if (t := self.ns.likelihood_evaluation_time.total_seconds()):
d['likelihood_evaluation_time'] = t

with open(filename, 'w') as wf:
Expand All @@ -203,11 +203,8 @@ def safe_exit(self, signum=None, frame=None):
Safely exit. This includes closing the multiprocessing pool.
"""
logger.warning(f'Trying to safely exit with code {signum}')

self.ns.proposal.close_pool(code=signum)
self.ns.checkpoint()

if self.ns.proposal.pool is not None:
self.ns.proposal.close_pool()

logger.warning(f'Exiting with code: {self.exit_code}')
sys.exit(self.exit_code)
104 changes: 68 additions & 36 deletions nessai/nestedsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .livepoint import get_dtype, DEFAULT_FLOAT_DTYPE
from .plot import plot_indices, plot_trace
from .posterior import logsubexp, log_integrate_log_trap
from .posterior import log_integrate_log_trap
from .proposal import FlowProposal
from .utils import (
safe_file_dump,
Expand All @@ -25,7 +25,7 @@
logger = logging.getLogger(__name__)


class _NSintegralState(object):
class _NSintegralState:
"""
Stores the state of the nested sampling integrator
"""
Expand All @@ -37,7 +37,6 @@ def reset(self):
"""
Reset the sampler to its initial state at logZ = -infinity
"""
self.iteration = 0
self.logZ = -np.inf
self.oldZ = -np.inf
self.logw = 0
Expand All @@ -59,7 +58,7 @@ def increment(self, logL, nlive=None):
nlive = self.nlive
oldZ = self.logZ
logt = - 1.0 / nlive
Wt = self.logw + logL + logsubexp(0, logt)
Wt = self.logw + logL + np.log1p(-np.exp(logt))
self.logZ = np.logaddexp(self.logZ, Wt)
# Update information estimate
if np.isfinite(oldZ) and np.isfinite(self.logZ) and np.isfinite(logL):
Expand All @@ -73,7 +72,6 @@ def increment(self, logL, nlive=None):

# Update history
self.logw += logt
self.iteration += 1
self.logLs.append(logL)
self.log_vols.append(self.logw)
self.gradients.append((self.logLs[-1] - self.logLs[-2])
Expand All @@ -95,8 +93,8 @@ def plot(self, filename):
"""
fig = plt.figure()
plt.plot(self.log_vols, self.logLs)
plt.title((f'{self.iteration} iterations. logZ={self.logZ:.2f}'
f'H={self.info[-1] * np.log2(np.e):.2f} bits'))
plt.title(f'logZ={self.logZ:.2f}'
f'H={self.info[-1] * np.log2(np.e):.2f} bits')
plt.grid(which='both')
plt.xlabel('log prior_volume')
plt.ylabel('log likelihood')
Expand Down Expand Up @@ -198,8 +196,10 @@ def __init__(self, model, nlive=1000, output=None,
stopping=0.1,
max_iteration=None,
checkpointing=True,
checkpoint_on_training=False,
resume_file=None,
seed=None,
n_pool=None,
plot=True,
proposal_plots=True,
prior_sampling=False,
Expand All @@ -223,17 +223,21 @@ def __init__(self, model, nlive=1000, output=None,

logger.info('Initialising nested sampler')

self.info_enabled = logger.isEnabledFor(logging.INFO)

model.verify_model()

self.model = model
self.nlive = nlive
self.n_pool = n_pool
self.live_points = None
self.prior_sampling = prior_sampling
self.setup_random_seed(seed)
self.accepted = 0
self.rejected = 1

self.checkpointing = checkpointing
self.checkpoint_on_training = checkpoint_on_training
self.iteration = 0
self.acceptance_history = deque(maxlen=(nlive // 10))
self.mean_acceptance_history = []
Expand Down Expand Up @@ -332,6 +336,22 @@ def information(self):
def likelihood_calls(self):
return self.model.likelihood_evaluations

@property
def likelihood_evaluation_time(self):
t = self._uninformed_proposal.logl_eval_time
t += self._flow_proposal.logl_eval_time
return t

@property
def proposal_population_time(self):
t = self._uninformed_proposal.population_time
t += self._flow_proposal.population_time
return t

@property
def acceptance(self):
return self.iteration / self.likelihood_calls

@property
def current_sampling_time(self):
return self.sampling_time \
Expand Down Expand Up @@ -400,8 +420,9 @@ def configure_uninformed_proposal(self,
kwargs['poolsize'] = self.nlive

logger.debug(f'Using uninformed proposal: {uninformed_proposal}')
logger.debug(f'Parsing kwargs to uniformed proposal: {kwargs}')
self._uninformed_proposal = uninformed_proposal(self.model, **kwargs)
logger.debug(f'Parsing kwargs to uninformed proposal: {kwargs}')
self._uninformed_proposal = uninformed_proposal(
self.model, n_pool=self.n_pool, **kwargs)

def configure_flow_proposal(self, flow_class, flow_config, proposal_plots,
**kwargs):
Expand Down Expand Up @@ -445,7 +466,7 @@ def configure_flow_proposal(self, flow_class, flow_config, proposal_plots,
logger.info(f'Parsing kwargs to FlowProposal: {kwargs}')
self._flow_proposal = flow_class(
self.model, flow_config=flow_config, output=proposal_output,
plot=proposal_plots, **kwargs)
plot=proposal_plots, n_pool=self.n_pool, **kwargs)

def setup_output(self, output, resume_file=None):
"""
Expand Down Expand Up @@ -535,19 +556,16 @@ def yield_sample(self, oldparam):
while True:
counter += 1
newparam = self.proposal.draw(oldparam.copy())
newparam['logP'] = self.model.log_prior(newparam)

# Prior is computed in the proposal
if newparam['logP'] != -np.inf:
if not newparam['logL']:
newparam['logL'] = \
self.model.evaluate_log_likelihood(newparam)
self.model.evaluate_log_likelihood(newparam)
if newparam['logL'] > self.logLmin:
self.logLmax = max(self.logLmax, newparam['logL'])
oldparam = newparam.copy()
break
if (1 / counter) < self.acceptance_threshold:
self.max_count += 1
break
# Only here if proposed and then empty
# This returns the old point and allows for a training check
if not self.proposal.populated:
Expand Down Expand Up @@ -608,20 +626,18 @@ def consume_sample(self):
if not self.block_iteration:
self.block_iteration += 1

self.acceptance = self.accepted / (self.accepted + self.rejected)
self.mean_block_acceptance = self.block_acceptance \
/ self.block_iteration
logger.info(f"{self.iteration:5d}: n: {count:3d} "
f"NS_acc: {self.acceptance:.3f} "
f"m_acc: {self.mean_acceptance:.3f} "
f"b_acc: {self.mean_block_acceptance:.3f} "
f"sub_acc: {1 / count:.3f} "
f"H: {self.state.info[-1]:.2f} "
f"logL: {self.logLmin:.5f} --> {proposed['logL']:.5f} "
f"dZ: {self.condition:.3f} "
f"logZ: {self.state.logZ:.3f} "
f"+/- {np.sqrt(self.state.info[-1] / self.nlive):.3f} "
f"logLmax: {self.logLmax:.2f}")

if self.info_enabled:
logger.info(f"{self.iteration:5d}: n: {count:3d} "
f"b_acc: {self.mean_block_acceptance:.3f} "
f"H: {self.state.info[-1]:.2f} "
f"logL: {self.logLmin:.5f} --> {proposed['logL']:.5f} "
f"dZ: {self.condition:.3f} "
f"logZ: {self.state.logZ:.3f} "
f"+/- {np.sqrt(self.state.info[-1] / self.nlive):.3f} "
f"logLmax: {self.logLmax:.2f}")

def populate_live_points(self):
"""
Expand Down Expand Up @@ -680,6 +696,8 @@ def initialise(self, live_points=True):
else:
self.proposal = self._flow_proposal

self.proposal.configure_pool()

if live_points and self.live_points is None:
self.populate_live_points()
flags[2] = True
Expand Down Expand Up @@ -714,6 +732,8 @@ def check_proposal_switch(self):
if self.proposal.pool is not None:
self.proposal.close_pool()
self.proposal = self._flow_proposal
if self.proposal.n_pool is not None:
self.proposal.configure_pool()
self.proposal.ns_acceptance = self.mean_block_acceptance
self.uninformed_sampling = False
return True
Expand Down Expand Up @@ -774,7 +794,7 @@ def check_flow_model_reset(self):
self.proposal.reset_model_weights(weights=True)

if (self.reset_permutations and
not (self.proposal.training_count % self.reset_weights)):
not (self.proposal.training_count % self.reset_permutations)):
self.proposal.reset_model_weights(weights=False, permutations=True)

def train_proposal(self, force=False):
Expand Down Expand Up @@ -806,7 +826,7 @@ def train_proposal(self, force=False):
self.block_iteration = 0
self.block_acceptance = 0.
self.completed_training = True
if self.checkpointing:
if self.checkpoint_on_training:
self.checkpoint(periodic=True)

def check_state(self, force=False):
Expand Down Expand Up @@ -926,6 +946,8 @@ def plot_trace(self):
if self.nested_samples:
plot_trace(self.state.log_vols[1:], self.nested_samples,
filename=f'{self.output}/trace.png')
else:
logger.warning('Could not produce trace plot. No nested samples!')

def update_state(self, force=False):
"""
Expand All @@ -948,6 +970,14 @@ def update_state(self, force=False):
self.mean_acceptance_history.append(self.mean_acceptance)

if not (self.iteration % self.nlive) or force:
logger.warning(
f"it: {self.iteration:5d}: "
f"n eval: {self.likelihood_calls} "
f"H: {self.state.info[-1]:.2f} "
f"dZ: {self.condition:.3f} logZ: {self.state.logZ:.3f} "
f"+/- {np.sqrt(self.state.info[-1] / self.nlive):.3f} "
f"logLmax: {self.logLmax:.2f}")
self.checkpoint(periodic=True)
if not force:
self.check_insertion_indices()
if self.plot:
Expand Down Expand Up @@ -992,7 +1022,7 @@ def check_resume(self):
was resumed.
"""
if self.resumed:
# If pool is populated make reset the flag since it is set to
# If pool is populated reset the flag since it is set to
# false during initialisation
if hasattr(self._flow_proposal, 'resume_populated'):
if (self._flow_proposal.resume_populated and
Expand Down Expand Up @@ -1028,11 +1058,14 @@ def nested_sampling_loop(self):
if self.prior_sampling:
for i in range(self.nlive):
self.nested_samples = self.params.copy()
return 0
return

self.check_resume()

self.update_state()
if self.iteration:
self.update_state()

logger.critical('Starting nested sampling loop')

while self.condition > self.tolerance:

Expand Down Expand Up @@ -1060,18 +1093,17 @@ def nested_sampling_loop(self):
self.check_insertion_indices(rolling=False)

# This includes updating the total sampling time
if self.checkpointing:
self.checkpoint(periodic=True)
self.checkpoint(periodic=True)

logger.info(f'Total sampling time: {self.sampling_time}')
logger.info(f'Total training time: {self.training_time}')
logger.info(f'Total population time: {self.proposal.population_time}')
logger.info(f'Total population time: {self.proposal_population_time}')
logger.info(
f'Total likelihood evaluations: {self.likelihood_calls:3d}')
if self.proposal.logl_eval_time.total_seconds():
logger.info(
'Time spent evaluating likelihood: '
f'{self.proposal.logl_eval_time}')
f'{self.likelihood_evaluation_time}')

return self.state.logZ, np.array(self.nested_samples)

Expand Down
2 changes: 1 addition & 1 deletion nessai/posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def logsubexp(x, y):
Return
z: :float: x + np.log1p(-np.exp(y-x))
"""
if np.all(x < y):
if np.any(x < y):
raise RuntimeError('cannot take log of negative number '
f'{str(x)!s} - {str(y)!s}')

Expand Down
Loading

0 comments on commit f27e269

Please sign in to comment.