
Commit

update scripts, new script for sampling comparison.
janfb committed Apr 28, 2022
1 parent 89d197e commit f251cb0
Showing 3 changed files with 101 additions and 5 deletions.
5 changes: 3 additions & 2 deletions notebooks/mnle-lan-comparison/lan_run_inference.py
@@ -54,12 +54,13 @@ class LANPotential(BasePotential):
     allow_iid_x = True  # type: ignore
     """Inits LAN potential."""
 
-    def __init__(self, lan, prior, x_o, device="cpu", ll_lower_bound=np.log(1e-7)):
+    def __init__(self, lan, prior, x_o, device="cpu", transform_a=True, ll_lower_bound=np.log(1e-7)):
         super().__init__(prior, x_o, device)
 
         self.lan = lan
         self.device = device
         self.ll_lower_bound = ll_lower_bound
+        self.transform_a = transform_a
         assert x_o.ndim == 2
         assert x_o.shape[1] == 1
         rts = abs(x_o)
@@ -78,7 +79,7 @@ def __call__(self, theta, track_gradients=False):
         num_parameters = theta.shape[0]
         # Convert DDM boundary separation to symmetric boundary size.
         # theta_lan = a_transform(theta)
-        theta_lan = theta
+        theta_lan = a_transform(theta) if self.transform_a else theta
 
         # Evaluate LAN on batch (as suggested in LANfactory README.)
         batch = torch.hstack((
7 changes: 4 additions & 3 deletions notebooks/mnle-lan-comparison/lan_run_training.py
@@ -14,9 +14,10 @@
 
 # NOTE: set budget and n_samples, e.g., for 10k budget set 10_4 and n_samples 100
 # for 100k budget set 10_4 and n_samples 1000
-budget = "10_7"
+budget = "10_5_ours"
 n_samples = 1000
-num_repeats = 4
+num_repeats = 9
+num_epochs = 30
 
 # NOTE: The resulting trained LAN will be saved with a unique ID under torch_models/ddm_{budget}.
 
@@ -53,7 +54,7 @@
 
 network_config = lanfactory.config.network_configs.network_config_mlp
 train_config = lanfactory.config.network_configs.train_config_mlp
-train_config["n_epochs"] = 10
+train_config["n_epochs"] = num_epochs
 
 
 # LOAD NETWORK
94 changes: 94 additions & 0 deletions notebooks/mnle-lan-comparison/mnle_samplemethod_comparison.py
@@ -0,0 +1,94 @@
import pickle
from pathlib import Path
import time

import sbibm
import torch

from sbi.inference import MNLE

# Get benchmark task to load observations
seed = torch.randint(100000, (1,)).item()

task = sbibm.get_task("ddm")
prior = task.get_prior_dist()
simulator = task.get_simulator(seed=seed) # Passing the seed to Julia.

# Observation indices >200 hold 100-trial observations
num_obs = 100
xos = torch.stack([task.get_observation(200 + ii) for ii in range(1, 1+num_obs)]).squeeze()

# Encode xos as (reaction time, choice): time = |rt|, choice = 1 if rt > 0, else 0.
xos_2d = torch.zeros((xos.shape[0], xos.shape[1], 2))
for idx, xo in enumerate(xos):
    xos_2d[idx, :, 0] = abs(xo)
    xos_2d[idx, xo > 0, 1] = 1

BASE_DIR = Path.cwd().parent.parent
save_folder = BASE_DIR / "data/results"

# Load a trained MNLE.
budget = "100000"
model_path = Path.cwd() / "models"
network_file_path = list(model_path.glob(f"mnle_n{budget}_new*"))[5]  # take one of the models trained from different random inits.

with open(network_file_path, "rb") as fh:
    mnle, *_ = pickle.load(fh).values()
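
# MCMC settings for slice sampling: short warmup, thinning, 10 vectorized chains,
# and chain initialization via sequential importance resampling ("sir").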

mcmc_parameters = dict(
    warmup_steps=100,
    thin=10,
    num_chains=10,
    num_workers=1,
    init_strategy="sir",
)

# Build MCMC posterior in SBI.
mnle_posterior = MNLE().build_posterior(
    mnle,
    prior,
    mcmc_method="slice_np_vectorized",
    mcmc_parameters=mcmc_parameters,
)

samples = []
num_samples = 10000
obs_idx = 10
xo = xos_2d[obs_idx]
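# xos_2d[obs_idx] corresponds to benchmark observation 201 + obs_idx.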

reference_samples = task.get_reference_posterior_samples(201+obs_idx)[:num_samples]
true_theta = task.get_true_parameters(201+obs_idx)

## Slice sampling
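# Sample with the slice-sampling posterior built above (vectorized slice sampler, 10 chains).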
tic = time.time()
slice_samples = mnle_posterior.sample((num_samples,), x=xo)
slice_time = time.time() - tic

## VI
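# Variational inference: fit a variational posterior to this x_o, then sample from it.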
tic = time.time()
viposterior = MNLE().build_posterior(
    mnle,
    prior,
    sample_with="vi",
)
viposterior.set_default_x(xo)
viposterior.train()
vi_samples = viposterior.sample((num_samples,))
vi_time = time.time() - tic

## NUTS
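# NUTS-based MCMC; num_chains is reduced to 2 for this sampler.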
tic = time.time()
mcmc_parameters["num_chains"] = 2
mnle_posterior = MNLE().build_posterior(
    mnle,
    prior,
    mcmc_method="nuts",
    mcmc_parameters=mcmc_parameters,
)
nuts_samples = mnle_posterior.sample((num_samples,), x=xo)
nuts_time = time.time() - tic  # time the posterior build and sampling together

with open(save_folder / f"mnle_samplemethod_comparison_obs{201+obs_idx}.p", "wb") as fh:
    pickle.dump(dict(
        reference_samples=reference_samples,
        true_theta=true_theta,
        slice_samples=slice_samples,
        vi_samples=vi_samples,
        nuts_samples=nuts_samples,
        timings=dict(vi=vi_time, slice=slice_time, nuts=nuts_time),
    ), fh)
