Commit

added fmpe to multi bin cvs
Dingel321 committed Jul 12, 2024
1 parent c58f616 commit 29e2792
Showing 6 changed files with 371 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,4 +28,5 @@ dependencies = [

[project.scripts]
train_npe_model = "cryo_sbi.inference.command_line_tools:cl_npe_train_no_saving"
train_fmpe_model = "cryo_sbi.inference.command_line_tools:cl_fmpe_train_no_saving"
model_to_tensor = "cryo_sbi.utils.command_line_tools:cl_models_to_tensor"
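For context, this registers a second console entry point next to train_npe_model; once the package is installed, running train_fmpe_model dispatches to the CLI function added below. A minimal sketch of that equivalence, assuming an installed package:

# Calling the CLI function directly is equivalent to invoking the
# `train_fmpe_model` console script; both parse sys.argv the same way.
from cryo_sbi.inference.command_line_tools import cl_fmpe_train_no_saving

if __name__ == "__main__":
    cl_fmpe_train_no_saving()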
69 changes: 69 additions & 0 deletions src/cryo_sbi/inference/command_line_tools.py
@@ -2,6 +2,9 @@
from cryo_sbi.inference.train_npe_model import (
npe_train_no_saving,
)
from cryo_sbi.inference.train_fmpe_model import (
fmpe_train_no_saving,
)


def cl_npe_train_no_saving():
@@ -68,3 +71,69 @@ def cl_npe_train_no_saving():
validation_set=args.val_set,
validation_frequency=args.val_freq,
)


def cl_fmpe_train_no_saving():
cl_parser = argparse.ArgumentParser()

cl_parser.add_argument(
"--image_config_file", action="store", type=str, required=True
)
cl_parser.add_argument(
"--train_config_file", action="store", type=str, required=True
)
cl_parser.add_argument("--epochs", action="store", type=int, required=True)
cl_parser.add_argument("--estimator_file", action="store", type=str, required=True)
cl_parser.add_argument("--loss_file", action="store", type=str, required=True)
cl_parser.add_argument(
"--train_from_checkpoint",
action="store",
type=bool,
nargs="?",
required=False,
const=True,
default=False,
)
cl_parser.add_argument(
"--state_dict_file", action="store", type=str, required=False, default=False
)
cl_parser.add_argument(
"--n_workers", action="store", type=int, required=False, default=1
)
cl_parser.add_argument(
"--train_device", action="store", type=str, required=False, default="cpu"
)
cl_parser.add_argument(
"--saving_freq", action="store", type=int, required=False, default=20
)
cl_parser.add_argument(
"--val_set", action="store", type=str, required=False, default=None
)
cl_parser.add_argument(
"--val_freq", action="store", type=int, required=False, default=10
)
cl_parser.add_argument(
"--simulation_batch_size",
action="store",
type=int,
required=False,
default=1024,
)

args = cl_parser.parse_args()

fmpe_train_no_saving(
image_config=args.image_config_file,
train_config=args.train_config_file,
epochs=args.epochs,
estimator_file=args.estimator_file,
loss_file=args.loss_file,
train_from_checkpoint=args.train_from_checkpoint,
model_state_dict=args.state_dict_file,
n_workers=args.n_workers,
device=args.train_device,
saving_frequency=args.saving_freq,
simulation_batch_size=args.simulation_batch_size,
validation_set=args.val_set,
validation_frequency=args.val_freq,
)
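A minimal sketch of driving the new parser programmatically; the file names and the epoch count are placeholders, and only the required flags are set, so the optional ones fall back to their defaults (CPU device, one worker, saving every 20 epochs):

import sys
from cryo_sbi.inference.command_line_tools import cl_fmpe_train_no_saving

# Placeholder arguments; swap in real config paths before running.
sys.argv = [
    "train_fmpe_model",
    "--image_config_file", "image_params.json",
    "--train_config_file", "fmpe_train_config.json",
    "--epochs", "150",
    "--estimator_file", "fmpe_posterior.estimator",
    "--loss_file", "fmpe_loss.txt",
]
cl_fmpe_train_no_saving()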
5 changes: 5 additions & 0 deletions src/cryo_sbi/inference/models/activations.py
@@ -0,0 +1,5 @@
import torch.nn as nn

ACTIVATIONS = {}


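The ACTIVATIONS registry is committed empty. A hedged sketch of how such a registry might later be populated, mapping config strings to torch activation classes; none of these keys are defined by this commit:

import torch.nn as nn

# Hypothetical entries only; the commit itself leaves the dict empty.
ACTIVATIONS = {
    "relu": nn.ReLU,
    "leaky_relu": nn.LeakyReLU,
    "elu": nn.ELU,
    "gelu": nn.GELU,
}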
31 changes: 31 additions & 0 deletions src/cryo_sbi/inference/models/build_models.py
@@ -39,6 +39,12 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
f"Model : {config['EMBEDDING']} has not been implemented yet! \
The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}"
)

if "BINS" in config:
bins = config["BINS"]
print(f"Using {bins} bins for NPE")
else:
bins = 8

estimator = estimator_models.NPEWithEmbedding(
embedding_net=embedding,
@@ -49,11 +55,36 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
flow=model,
theta_shift=config["THETA_SHIFT"],
theta_scale=config["THETA_SCALE"],
bins=bins,
**{"activation": partial(nn.LeakyReLU, 0.1)},
)

return estimator


def build_fmpe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:

try:
embedding = partial(
EMBEDDING_NETS[config["EMBEDDING"]], config["OUT_DIM"], **embedding_kwargs
)
except KeyError:
raise NotImplementedError(
f"Model : {config['EMBEDDING']} has not been implemented yet! \
The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}"
)

estimator = estimator_models.FMPEWithEmbedding(
embedding_net=embedding,
output_embedding_dim=config["OUT_DIM"],
num_hidden_flow=config["NUM_HIDDEN_FLOW"],
hidden_flow_dim=config["HIDDEN_DIM_FLOW"],
theta_shift=config["THETA_SHIFT"],
theta_scale=config["THETA_SCALE"],
)

return estimator


def build_nre_classifier_model(config: dict, **embedding_kwargs) -> nn.Module:
raise NotImplementedError("NRE classifier model has not been implemented yet!")
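For reference, a minimal sketch of the config keys that build_fmpe_flow_model reads, mirroring the config[...] lookups above; the values are placeholders and "RESNET18" is assumed to be a registered embedding name in EMBEDDING_NETS:

from cryo_sbi.inference.models.build_models import build_fmpe_flow_model

# Keys mirror the lookups in build_fmpe_flow_model; values are illustrative.
fmpe_config = {
    "EMBEDDING": "RESNET18",   # must be a key in EMBEDDING_NETS
    "OUT_DIM": 256,
    "NUM_HIDDEN_FLOW": 2,
    "HIDDEN_DIM_FLOW": 128,
    "THETA_SHIFT": 25.0,
    "THETA_SCALE": 25.0,
}
estimator = build_fmpe_flow_model(fmpe_config)

For build_npe_flow_model, the new optional "BINS" key works analogously: when present it sets the number of spline bins, otherwise the default of 8 is used.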
44 changes: 41 additions & 3 deletions src/cryo_sbi/inference/models/estimator_models.py
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
import zuko
-from lampe.inference import NPE, NRE
-
+from lampe.inference import NPE, NRE, FMPE
+from lampe.nn import ResMLP

class Standardize(nn.Module):
"""
@@ -71,6 +71,7 @@ def __init__(
flow: nn.Module = zuko.flows.MAF,
theta_shift: float = 0.0,
theta_scale: float = 1.0,
bins: int = 8,
**kwargs,
) -> None:
"""
@@ -99,9 +100,10 @@ def __init__(
transforms=num_transforms,
build=flow,
hidden_features=[*[hidden_flow_dim] * num_hidden_flow, 128, 64],
bins=bins,
**kwargs,
)

self.type = "NPE"
self.embedding = embedding_net()
self.standardize = Standardize(theta_shift, theta_scale)

@@ -145,3 +147,39 @@ def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor:

samples_standardized = self.flow(x).sample(shape)
return self.standardize.transform(samples_standardized)

class FMPEWithEmbedding(nn.Module):
def __init__(
self,
embedding_net: nn.Module,
output_embedding_dim: int,
num_hidden_flow: int = 2,
hidden_flow_dim: int = 128,
theta_shift: float = 0.0,
theta_scale: float = 1.0,
**kwargs,
) -> None:

super().__init__()

self.fmpe = FMPE(
theta_dim=1,
x_dim=output_embedding_dim,
freqs=5,
build=ResMLP,
hidden_features=[*[hidden_flow_dim] * num_hidden_flow],
activation=nn.ELU
)
self.type = "FMPE"
self.embedding = embedding_net()
self.standardize = Standardize(theta_shift, theta_scale)

def forward(self, theta: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return self.fmpe(self.standardize(theta), self.embedding(x), t)

def flow(self, x: torch.Tensor):
return self.fmpe.flow(self.embedding(x))

def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor:
samples_standardized = self.flow(x).sample(shape)
return self.standardize.transform(samples_standardized)
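A minimal usage sketch of the new FMPEWithEmbedding estimator, mirroring its NPE counterpart; the dummy embedding, the 128x128 image size, and the theta shift/scale values are placeholders, and a real run would use one of the registered EMBEDDING_NETS instead:

import torch
import torch.nn as nn
from cryo_sbi.inference.models.estimator_models import FMPEWithEmbedding

# Stand-in embedding: flattens a 128x128 image to a 256-dim feature vector.
class DummyEmbedding(nn.Module):
    def __init__(self, out_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(128 * 128, out_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

estimator = FMPEWithEmbedding(
    embedding_net=DummyEmbedding,
    output_embedding_dim=256,
    theta_shift=25.0,
    theta_scale=25.0,
)

images = torch.randn(8, 128, 128)
samples = estimator.sample(images, shape=(100,))  # -> shape (100, 8, 1)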