Commit 872b11c

update figure notebooks and utils, submitted.

janfb committed Jun 28, 2022
1 parent 5f43e7e commit 872b11c
Showing 7 changed files with 1,197 additions and 699 deletions.
738 changes: 320 additions & 418 deletions notebooks/Paper-Figure-2-likelihood-comparison.ipynb

Large diffs are not rendered by default.

514 changes: 435 additions & 79 deletions notebooks/Paper-Figure-3&4-posterior-comparison.ipynb

Large diffs are not rendered by default.

67 changes: 40 additions & 27 deletions notebooks/Paper-Figure-5-posterior-DDM-with-collapsing-bounds.ipynb

Large diffs are not rendered by default.

286 changes: 140 additions & 146 deletions notebooks/Paper-Figure-A1-3-parameter-recovery-SBC.ipynb

Large diffs are not rendered by default.

57 changes: 34 additions & 23 deletions notebooks/Paper-Figure-A4-synthetic-data.ipynb

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions notebooks/plotting_settings.mplstyle
@@ -1,11 +1,12 @@
-font.sans-serif : Arial
+font.sans-serif : open-sans
 font.family : sans-serif
 axes.spines.top : False
 axes.spines.right : False
-axes.labelsize : medium
-xtick.labelsize : small
-ytick.labelsize : small
+axes.labelsize : 17
+xtick.labelsize : 15
+ytick.labelsize : 15
 legend.frameon : False
-legend.fontsize : 18
-font.size : 20
+legend.fontsize : 17
+font.size : 17
 
+text.usetex: false
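For context, these settings only take effect once the style sheet is loaded; a minimal sketch of how a notebook might apply it (assuming the path is resolved relative to the notebooks/ directory):

    import matplotlib.pyplot as plt

    # Apply the shared paper plotting style.
    plt.style.use("plotting_settings.mplstyle")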
221 changes: 221 additions & 0 deletions notebooks/utils.py
@@ -10,6 +10,10 @@
from sbibm.utils.io import get_float_from_csv
from tqdm.auto import tqdm

from torch.distributions.transforms import AffineTransform

from sbi.inference.potentials.base_potential import BasePotential

def compile_df(basepath: str,) -> pd.DataFrame:
    """Compile dataframe for further analyses
@@ -116,6 +120,223 @@ def compile_df(basepath: str,) -> pd.DataFrame:

    return df
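# Usage (illustrative; the basepath is hypothetical): df = compile_df("results/")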

import lanfactory
import numpy as np
import sbibm
import torch

# Affine transform for the boundary separation parameter "a": to map to the
# LAN prior space, scale "a" by 0.5 (i.e., convert DDM boundary separation to
# the symmetric boundary size used by LANs).
a_transform = AffineTransform(torch.zeros(1, 4), torch.tensor([[1.0, 0.5, 1.0, 1.0]]))
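# For illustration: a_transform(torch.tensor([[1.0, 2.0, 0.5, 0.3]])) returns
# tensor([[1.0, 1.0, 0.5, 0.3]]); only the second entry ("a") is scaled by 0.5.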

# LAN potential function.
class LANPotential(BasePotential):
    """Potential function based on a pre-trained LAN (likelihood approximation network)."""

    allow_iid_x = True  # type: ignore

    def __init__(
        self,
        lan,
        prior: torch.distributions.Distribution,
        x_o: torch.Tensor,
        device: str = "cpu",
        apply_a_transform: bool = False,
        ll_lower_bound: float = np.log(1e-7),
        apply_ll_lower_bound: bool = False,
    ):
        super().__init__(prior, x_o, device)

        self.lan = lan
        self.device = device
        self.ll_lower_bound = ll_lower_bound
        self.apply_ll_lower_bound = apply_ll_lower_bound
        self.apply_a_transform = apply_a_transform

        # RTs are encoded as the magnitudes of x_o, choices as its sign.
        assert x_o.ndim == 2
        assert x_o.shape[1] == 1
        rts = abs(x_o)
        num_trials = rts.numel()
        assert rts.shape == torch.Size([num_trials, 1])
        # Code choices: down as -1, up as +1.
        cs = torch.ones_like(rts)
        cs[x_o < 0] *= -1

        self.num_trials = num_trials
        self.rts = rts
        self.cs = cs

    def __call__(self, theta, track_gradients=False):
        num_parameters = theta.shape[0]
        # Convert DDM boundary separation to symmetric boundary size.
        theta_lan = a_transform(theta) if self.apply_a_transform else theta

        # Evaluate the LAN on one batch (as suggested in the LANfactory README).
        batch = torch.hstack(
            (
                theta_lan.repeat(self.num_trials, 1),  # repeat params for each trial
                self.rts.repeat_interleave(num_parameters, dim=0),  # repeat data for each param
                self.cs.repeat_interleave(num_parameters, dim=0),
            )
        )
        log_likelihood_trials = self.lan(batch).reshape(self.num_trials, num_parameters)

        # Lower-bound each per-trial log likelihood, then sum across trials:
        # keep the LAN value only where the RT exceeds the non-decision time
        # (theta[:, -1]) and the value is above the bound.
        if self.apply_ll_lower_bound:
            log_likelihood_trial_sum = (
                torch.where(
                    torch.logical_and(
                        self.rts.repeat(1, num_parameters) > theta[:, -1],
                        log_likelihood_trials > self.ll_lower_bound,
                    ),
                    log_likelihood_trials,
                    self.ll_lower_bound * torch.ones_like(log_likelihood_trials),
                )
                .sum(0)
                .squeeze()
            )
        else:
            log_likelihood_trial_sum = log_likelihood_trials.sum(0).squeeze()

        # Maybe apply correction for the transform on the "a" parameter.
        if self.apply_a_transform:
            log_abs_det = a_transform.log_abs_det_jacobian(theta_lan, theta)
            if log_abs_det.ndim > 1:
                log_abs_det = log_abs_det.sum(-1)
            log_likelihood_trial_sum -= log_abs_det

        return log_likelihood_trial_sum + self.prior.log_prob(theta)
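# Usage sketch (illustrative; assumes a pre-trained `lan` and a `prior` are
# available) of plugging the potential into sbi's MCMC machinery:
#
#   from sbi.inference import MCMCPosterior
#   from sbi.utils import mcmc_transform
#
#   potential_fn = LANPotential(lan, prior, x_o, apply_a_transform=True)
#   posterior = MCMCPosterior(
#       potential_fn,
#       proposal=prior,
#       theta_transform=mcmc_transform(prior),
#       method="slice_np_vectorized",
#   )
#   samples = posterior.sample((1000,), x=x_o)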

class OldLANPotential(BasePotential):
    """Like LANPotential, but evaluates the LAN via its predict_on_batch method."""

    allow_iid_x = True  # type: ignore

    def __init__(
        self,
        lan,
        prior: torch.distributions.Distribution,
        x_o: torch.Tensor,
        device: str = "cpu",
        apply_a_transform: bool = False,
        ll_lower_bound: float = np.log(1e-7),
        apply_ll_lower_bound: bool = False,
    ):
        super().__init__(prior, x_o, device)

        self.lan = lan
        self.device = device
        self.ll_lower_bound = ll_lower_bound
        self.apply_ll_lower_bound = apply_ll_lower_bound
        self.apply_a_transform = apply_a_transform

        # RTs are encoded as the magnitudes of x_o, choices as its sign.
        assert x_o.ndim == 2
        assert x_o.shape[1] == 1
        rts = abs(x_o)
        num_trials = rts.numel()
        assert rts.shape == torch.Size([num_trials, 1])
        # Code choices: down as -1, up as +1.
        cs = torch.ones_like(rts)
        cs[x_o < 0] *= -1

        self.num_trials = num_trials
        self.rts = rts
        self.cs = cs

    def __call__(self, theta, track_gradients=False):
        num_parameters = theta.shape[0]
        # Convert DDM boundary separation to symmetric boundary size.
        theta_lan = a_transform(theta) if self.apply_a_transform else theta

        # Evaluate the LAN on one batch (as suggested in the LANfactory README).
        batch = torch.hstack(
            (
                theta_lan.repeat(self.num_trials, 1),  # repeat params for each trial
                self.rts.repeat_interleave(num_parameters, dim=0),  # repeat data for each param
                self.cs.repeat_interleave(num_parameters, dim=0),
            )
        )
        log_likelihood_trials = torch.tensor(
            self.lan.predict_on_batch(batch.numpy()),
            dtype=torch.float32,
        ).reshape(self.num_trials, num_parameters)

        # Lower-bound each per-trial log likelihood, then sum across trials:
        # keep the LAN value only where the RT exceeds the non-decision time
        # (theta[:, -1]) and the value is above the bound.
        if self.apply_ll_lower_bound:
            log_likelihood_trial_sum = (
                torch.where(
                    torch.logical_and(
                        self.rts.repeat(1, num_parameters) > theta[:, -1],
                        log_likelihood_trials > self.ll_lower_bound,
                    ),
                    log_likelihood_trials,
                    self.ll_lower_bound * torch.ones_like(log_likelihood_trials),
                )
                .sum(0)
                .squeeze()
            )
        else:
            log_likelihood_trial_sum = log_likelihood_trials.sum(0).squeeze()

        # Maybe apply correction for the transform on the "a" parameter.
        if self.apply_a_transform:
            log_abs_det = a_transform.log_abs_det_jacobian(theta_lan, theta)
            if log_abs_det.ndim > 1:
                log_abs_det = log_abs_det.sum(-1)
            log_likelihood_trial_sum -= log_abs_det

        return log_likelihood_trial_sum + self.prior.log_prob(theta)


def lan_likelihood_on_batch(
    data: torch.Tensor,
    theta: torch.Tensor,
    net,
    transform,
    device,
):
    """Return LAN log-likelihoods for a batch of data and parameters.

    Return shape: (batch_size_data, batch_size_parameters).
    """
    # Convert to positive RTs.
    rts = abs(data)
    num_trials = rts.numel()
    num_parameters = theta.shape[0]
    assert rts.shape == torch.Size([num_trials, 1])
    theta = torch.as_tensor(theta, dtype=torch.float32)
    # Convert DDM boundary separation to symmetric boundary size.
    theta_lan = transform(theta)

    # Code choices: down as -1, up as +1.
    cs = torch.ones_like(rts)
    cs[data < 0] *= -1

    # Evaluate the LAN on one batch (as suggested in the LANfactory README).
    batch = torch.hstack(
        (
            theta_lan.repeat(num_trials, 1),  # repeat params for each trial
            rts.repeat_interleave(num_parameters, dim=0),  # repeat data for each param
            cs.repeat_interleave(num_parameters, dim=0),
        )
    )
    log_likelihood_trials = net(batch.to(device)).reshape(num_trials, num_parameters)

    return log_likelihood_trials.to("cpu")


def apply_lower_bound_given_mask(ll, mask, ll_lower_bound: float = np.log(1e-7)):
    """Replace values at mask with the lower bound (modifies ll in place)."""

    assert mask.shape == ll.shape, "Mask must have the same shape as the input."

    ll[mask] = ll_lower_bound

    return ll
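# Usage sketch (illustrative; assumes `net`, `data`, and `theta` as above),
# combining the two helpers to bound very small likelihood values:
#
#   ll = lan_likelihood_on_batch(data, theta, net, a_transform, "cpu")
#   ll = apply_lower_bound_given_mask(ll, ll <= np.log(1e-7))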

def decode_1d_to_2d_x(x1d):
    """Decode RTs with choices encoded as sign into (rts, 0-1 choices)."""
    x = torch.zeros((x1d.shape[0], 2))
    # Absolute RTs in the first column.
    x[:, 0] = abs(x1d[:, 0])
    # 0-1 coded choices in the second column.
    x[x1d[:, 0] > 0, 1] = 1

    return x
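# Example: decode_1d_to_2d_x(torch.tensor([[0.6], [-0.4]]))
# -> tensor([[0.6, 1.0], [0.4, 0.0]])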

# Define losses.
def huber_loss(y, yhat):
    diff = abs(y - yhat)
