Commit 2e2b05a: update data, training, and inference scripts

janfb committed Jun 28, 2022
1 parent b81821a commit 2e2b05a

Showing 6 changed files with 263 additions and 72 deletions.
8 changes: 5 additions & 3 deletions notebooks/mnle-lan-comparison/lan_generate_data.py
@@ -1,4 +1,6 @@
# Adapted from: https://github.com/AlexanderFengler/ssm_simulators
# Script to generate DDM data with the LANs' ssm_simulators package,
# for different simulation budgets.

# Load necessary packages
from copy import deepcopy
@@ -22,11 +24,11 @@
# Specify number of parameter sets to simulate:
generator_config['n_parameter_sets'] = 100
# Specify how many samples a simulation run should entail
generator_config['n_samples'] = 100
generator_config['n_samples'] = 1000
# Number of KDE samples to draw from KDE to generate NN targets
generator_config['n_training_samples_by_parameter_set'] = 1000
# Specify folder in which to save generated data
generator_config['output_folder'] = 'data/lan_mlp_test/'
generator_config['output_folder'] = 'data/lan_mlp_10_5/'
generator_config['n_cpus'] = 1
generator_config['n_subruns'] = 1

@@ -39,4 +41,4 @@
# Pass our own simulator, which uses our prior.
julia_simulator = simulator,
)
training_data = my_dataset_generator.generate_data_training_uniform(save = False)
training_data = my_dataset_generator.generate_data_training_uniform(save = True)
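The config above fixes the per-run simulation budget: 100 parameter sets times 1000 samples each, i.e. 10^5 simulator calls, with 1000 KDE samples per parameter set as network targets. A small sanity-check sketch of that arithmetic (the helper below is hypothetical, not part of the ssm_simulators API):

```python
# Hypothetical helper, not part of the ssm_simulators API: total simulator
# calls implied by a generator config.
def simulation_budget(config: dict) -> int:
    # Each of the n_parameter_sets parameter vectors is simulated n_samples times.
    return config["n_parameter_sets"] * config["n_samples"]

generator_config = {"n_parameter_sets": 100, "n_samples": 1000}
assert simulation_budget(generator_config) == 10 ** 5  # consistent with the 10_5 output folder
```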
84 changes: 24 additions & 60 deletions notebooks/mnle-lan-comparison/lan_run_inference.py
@@ -1,18 +1,20 @@
# Script to run inference with MCMC given a pretrained LAN
# Uses MCMC methods from the sbi package, via a custom potential
# function wrapper in utils.

import pickle
from pathlib import Path
from joblib import Parallel, delayed

import lanfactory
import numpy as np
import sbibm
import torch

from torch.distributions.transforms import AffineTransform

from sbi.inference import MCMCPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.utils import mcmc_transform

from utils import LANPotential


BASE_DIR = Path.cwd().parent.parent
save_folder = BASE_DIR / "data/results"
@@ -25,7 +27,7 @@
simulator = task.get_simulator(seed=seed) # Passing the seed to Julia.

# Observation indices >200 hold 100-trial observations
num_trials = 1
num_trials = 100
if num_trials == 1:
start_obs = 0
elif num_trials == 10:
@@ -46,7 +48,9 @@


# load a LAN
budget = "10_8_ours"
budget = "10_11"
apply_a_transform = True
apply_ll_lower_bound = True
model_path = Path.cwd() / f"data/torch_models/ddm_{budget}/"
network_file_path = list(model_path.glob("*state_dict*"))[0] # take first model from random inits.

@@ -59,56 +63,13 @@
    network_config = network_config,
    input_dim = 6)  # 4 params plus 2 data dims

# LAN potential function
class LANPotential(BasePotential):
    """LAN-based potential function for MCMC."""

    allow_iid_x = True  # type: ignore

    def __init__(self, lan, prior, x_o, device="cpu", transform_a=False, ll_lower_bound=np.log(1e-7)):
        super().__init__(prior, x_o, device)

        self.lan = lan
        self.device = device
        self.ll_lower_bound = ll_lower_bound
        self.transform_a = transform_a
        assert x_o.ndim == 2
        assert x_o.shape[1] == 1
        rts = abs(x_o)
        num_trials = rts.numel()
        assert rts.shape == torch.Size([num_trials, 1])
        # Encode choices: down = -1, up = +1.
        cs = torch.ones_like(rts)
        cs[x_o < 0] *= -1

        self.num_trials = num_trials
        self.rts = rts
        self.cs = cs

    def __call__(self, theta, track_gradients=False):

        num_parameters = theta.shape[0]
        # Convert DDM boundary separation to symmetric boundary size.
        # theta_lan = a_transform(theta)
        theta_lan = a_transform(theta) if self.transform_a else theta

        # Evaluate LAN on batch (as suggested in the LANfactory README).
        batch = torch.hstack((
            theta_lan.repeat(self.num_trials, 1),  # repeat params for each trial
            self.rts.repeat_interleave(num_parameters, dim=0),  # repeat data for each param
            self.cs.repeat_interleave(num_parameters, dim=0))
        )
        log_likelihood_trials = self.lan(batch).reshape(self.num_trials, num_parameters)

        # Sum over trials.
        log_likelihood_trial_sum = log_likelihood_trials.sum(0).squeeze()
# load old LAN
from tensorflow import keras
# network trained on KDE likelihood for 4-param DDM
lan_kde_path = Path.cwd() / "../../data/pretrained-models/model_final_ddm.h5"
lan = keras.models.load_model(lan_kde_path, compile=False)

        # Apply correction for transform on "a" parameter.
        # log_abs_det = a_transform.log_abs_det_jacobian(theta_lan, theta)
        # if log_abs_det.ndim > 1:
        #     log_abs_det = log_abs_det.sum(-1)
        # log_likelihood_trial_sum -= log_abs_det

        return log_likelihood_trial_sum + self.prior.log_prob(theta)

mcmc_parameters = dict(
    warmup_steps = 100,
@@ -120,16 +81,19 @@ def __call__(self, theta, track_gradients=False):

# Build MCMC posterior in SBI.
theta_transform = mcmc_transform(prior)
# affine transform
# add scale transform for "a": to go to LAN prior space, i.e., multiply with scale 0.5.
a_transform = AffineTransform(torch.zeros(1, 4), torch.tensor([[1.0, 0.5, 1.0, 1.0]]))

samples = []
num_samples = 1000
num_samples = 10000
num_workers = 20

def run(x_o):
    lan_posterior = MCMCPosterior(LANPotential(lan, prior, x_o.reshape(-1, 1)),
    lan_potential = LANPotential(lan,
        prior,
        x_o.reshape(-1, 1),
        apply_a_transform=apply_a_transform,
        apply_ll_lower_bound=apply_ll_lower_bound,
    )
    lan_posterior = MCMCPosterior(lan_potential,
        proposal=prior,
        theta_transform=theta_transform,
        method="slice_np_vectorized",
@@ -142,5 +106,5 @@ def run(x_o):
delayed(run)(x_o) for x_o in xos
)

with open(save_folder / f"lan_{budget}_posterior_samples_{num_obs}x{num_trials}iid_new.p", "wb") as fh:
with open(save_folder / f"lan_{budget}_posterior_samples_{num_obs}x{num_trials}iid_old.p", "wb") as fh:
pickle.dump(results, fh)
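The inline LANPotential above was removed in favor of a version in utils, whose constructor evidently grew apply_a_transform and apply_ll_lower_bound flags (see the call inside run()). Below is a sketch of what that refactored class plausibly looks like, reconstructed from the deleted code and the new call site; the actual utils implementation may differ, and a Keras LAN would additionally need numpy/torch conversion around the forward pass:

```python
import numpy as np
import torch
from torch.distributions.transforms import AffineTransform
from sbi.inference.potentials.base_potential import BasePotential

# Scale transform for boundary separation "a" (second of four DDM parameters):
# multiply by 0.5 to move to the LAN's symmetric-boundary parameterization.
a_transform = AffineTransform(torch.zeros(1, 4), torch.tensor([[1.0, 0.5, 1.0, 1.0]]))


class LANPotential(BasePotential):
    """Sketch of the refactored utils.LANPotential; an assumption, not the actual code."""

    allow_iid_x = True  # type: ignore

    def __init__(self, lan, prior, x_o, device="cpu",
                 apply_a_transform=False, apply_ll_lower_bound=False,
                 ll_lower_bound=np.log(1e-7)):
        super().__init__(prior, x_o, device)
        self.lan = lan
        self.apply_a_transform = apply_a_transform
        self.apply_ll_lower_bound = apply_ll_lower_bound
        self.ll_lower_bound = ll_lower_bound
        # Split signed observations into reaction times and choices (down = -1, up = +1).
        self.rts = abs(x_o)
        self.cs = torch.where(x_o < 0, -torch.ones_like(x_o), torch.ones_like(x_o))
        self.num_trials = self.rts.numel()

    def __call__(self, theta, track_gradients=False):
        num_parameters = theta.shape[0]
        theta_lan = a_transform(theta) if self.apply_a_transform else theta
        # Evaluate the LAN on all (theta, trial) pairs in one batch.
        batch = torch.hstack((
            theta_lan.repeat(self.num_trials, 1),
            self.rts.repeat_interleave(num_parameters, dim=0),
            self.cs.repeat_interleave(num_parameters, dim=0),
        ))
        ll = self.lan(batch).reshape(self.num_trials, num_parameters)
        if self.apply_ll_lower_bound:
            # Clamp very small log-likelihoods; LANs are unreliable far
            # outside their training support.
            ll = torch.clamp(ll, min=self.ll_lower_bound)
        return ll.sum(0).squeeze() + self.prior.log_prob(theta)
```

As in the deleted code, the Jacobian correction for the scale transform on "a" is left out here.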
1 change: 1 addition & 0 deletions notebooks/mnle-lan-comparison/lan_run_training.py
@@ -1,4 +1,5 @@
# Adapted from: https://github.com/AlexanderFengler/LANfactory
# Script to run LAN training given a pre-simulated data set.

# Load necessary packages
import lanfactory
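The training changes themselves are elided in this diff. Conceptually, LAN training is plain supervised regression: an MLP maps (4 DDM parameters, rt, choice) to the KDE-based log-likelihood targets generated above. A schematic torch sketch under that assumption; this is not the lanfactory API, and shapes and hyperparameters are illustrative only:

```python
import torch
from torch import nn

# Schematic LAN: 6 inputs (4 DDM parameters + rt + choice) -> scalar log-likelihood.
lan = nn.Sequential(
    nn.Linear(6, 100), nn.Tanh(),
    nn.Linear(100, 100), nn.Tanh(),
    nn.Linear(100, 1),
)
optimizer = torch.optim.Adam(lan.parameters(), lr=1e-3)

# Placeholder tensors with the right shapes; real inputs and targets come
# from the lan_generate_data.py pipeline.
inputs = torch.randn(1024, 6)
targets = torch.randn(1024, 1)

for epoch in range(10):
    optimizer.zero_grad()
    # Huber loss, reportedly the loss used for LAN training.
    loss = nn.functional.huber_loss(lan(inputs), targets)
    loss.backward()
    optimizer.step()
```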
16 changes: 7 additions & 9 deletions notebooks/mnle-lan-comparison/mnle_run_inference.py
@@ -1,3 +1,5 @@
# Script for running MCMC using a pre-trained MNLE object from sbi.

import pickle
from pathlib import Path

@@ -14,7 +16,7 @@
simulator = task.get_simulator(seed=seed) # Passing the seed to Julia.

# Observation indices >200 hold 100-trial observations
num_trials = 1
num_trials = 100
if num_trials == 1:
start_obs = 0
elif num_trials == 10:
@@ -35,9 +37,9 @@


# load a trained MNLE
budget = "1000000"
budget = "100000"
model_path = Path.cwd() / f"models/"
init_idx = 1
init_idx = 2
network_file_path = list(model_path.glob(f"mnle_n{budget}_new*"))[init_idx] # take one model from random inits.

with open(network_file_path, "rb") as fh:
@@ -64,15 +66,11 @@

mnle_samples = mnle_posterior.sample(
(num_samples,),
x=x_o.reshape(100, 2)
x=x_o.reshape(num_trials, 2)
)

# mnle_posterior.set_default_x(x_o.reshape(100, 2))
# mnle_posterior.train()
# vi_samples = mnle_posterior.sample((num_samples, ))

samples.append(mnle_samples)

with open(f"mnle-{init_idx}_{budget}_posterior_samples_{num_obs}x{num_trials}*iid.p", "wb") as fh:
with open(f"mnle-{init_idx}_{budget}_posterior_samples_{num_obs}x{num_trials}iid.p", "wb") as fh:
pickle.dump(samples, fh)
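Two details are worth noting here: the observation is now reshaped to (num_trials, 2), one (rt, choice) row per trial, instead of a hardcoded (100, 2), so the script also works for the 1- and 10-trial settings; and the posterior construction itself is elided above. A sketch of how the pickled MNLE estimator is presumably wrapped for MCMC with sbi (the elided lines may differ):

```python
from sbi.inference import MCMCPosterior
from sbi.inference.potentials import likelihood_estimator_based_potential

# Assumed wiring for the elided posterior construction.
potential_fn, theta_transform = likelihood_estimator_based_potential(
    mnle, prior, x_o=None  # the observation is passed per call via sample(x=...)
)
mnle_posterior = MCMCPosterior(
    potential_fn,
    proposal=prior,
    theta_transform=theta_transform,
    method="slice_np_vectorized",
)

mnle_samples = mnle_posterior.sample((num_samples,), x=x_o.reshape(num_trials, 2))
```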

2 changes: 2 additions & 0 deletions notebooks/mnle-lan-comparison/mnle_training_script.py
@@ -1,3 +1,5 @@
# Script for training MNLE with pre-simulated data.

import pickle
import torch

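For context, training MNLE on pre-simulated (theta, x) pairs with sbi is a short script; a minimal sketch under assumed file names (the actual script may pass a prior and configure the density estimator differently):

```python
import pickle

from sbi.inference import MNLE

# Hypothetical file of pre-simulated parameters and (rt, choice) data.
with open("data/ddm_presimulated_n100000.p", "rb") as fh:
    theta, x = pickle.load(fh)

trainer = MNLE()  # the script presumably also passes the task prior
estimator = trainer.append_simulations(theta, x).train()

# Hypothetical output name matching the glob in mnle_run_inference.py.
with open("models/mnle_n100000_new_0.p", "wb") as fh:
    pickle.dump(estimator, fh)
```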