diff --git a/pyproject.toml b/pyproject.toml index 76d9726..7b8ec94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,4 +28,5 @@ dependencies = [ [project.scripts] train_npe_model = "cryo_sbi.inference.command_line_tools:cl_npe_train_no_saving" +train_fmpe_model = "cryo_sbi.inference.command_line_tools:cl_fmpe_train_no_saving" model_to_tensor = "cryo_sbi.utils.command_line_tools:cl_models_to_tensor" diff --git a/src/cryo_sbi/inference/command_line_tools.py b/src/cryo_sbi/inference/command_line_tools.py index 8a1a491..96a72ab 100644 --- a/src/cryo_sbi/inference/command_line_tools.py +++ b/src/cryo_sbi/inference/command_line_tools.py @@ -2,6 +2,9 @@ from cryo_sbi.inference.train_npe_model import ( npe_train_no_saving, ) +from cryo_sbi.inference.train_fmpe_model import ( + fmpe_train_no_saving, +) def cl_npe_train_no_saving(): @@ -68,3 +71,69 @@ def cl_npe_train_no_saving(): validation_set=args.val_set, validation_frequency=args.val_freq, ) + + +def cl_fmpe_train_no_saving(): + cl_parser = argparse.ArgumentParser() + + cl_parser.add_argument( + "--image_config_file", action="store", type=str, required=True + ) + cl_parser.add_argument( + "--train_config_file", action="store", type=str, required=True + ) + cl_parser.add_argument("--epochs", action="store", type=int, required=True) + cl_parser.add_argument("--estimator_file", action="store", type=str, required=True) + cl_parser.add_argument("--loss_file", action="store", type=str, required=True) + cl_parser.add_argument( + "--train_from_checkpoint", + action="store", + type=bool, + nargs="?", + required=False, + const=True, + default=False, + ) + cl_parser.add_argument( + "--state_dict_file", action="store", type=str, required=False, default=False + ) + cl_parser.add_argument( + "--n_workers", action="store", type=int, required=False, default=1 + ) + cl_parser.add_argument( + "--train_device", action="store", type=str, required=False, default="cpu" + ) + cl_parser.add_argument( + "--saving_freq", action="store", 
type=int, required=False, default=20 + ) + cl_parser.add_argument( + "--val_set", action="store", type=str, required=False, default=None + ) + cl_parser.add_argument( + "--val_freq", action="store", type=int, required=False, default=10 + ) + cl_parser.add_argument( + "--simulation_batch_size", + action="store", + type=int, + required=False, + default=1024, + ) + + args = cl_parser.parse_args() + + fmpe_train_no_saving( + image_config=args.image_config_file, + train_config=args.train_config_file, + epochs=args.epochs, + estimator_file=args.estimator_file, + loss_file=args.loss_file, + train_from_checkpoint=args.train_from_checkpoint, + model_state_dict=args.state_dict_file, + n_workers=args.n_workers, + device=args.train_device, + saving_frequency=args.saving_freq, + simulation_batch_size=args.simulation_batch_size, + validation_set=args.val_set, + validation_frequency=args.val_freq, + ) diff --git a/src/cryo_sbi/inference/models/activations.py b/src/cryo_sbi/inference/models/activations.py new file mode 100644 index 0000000..ceaff1e --- /dev/null +++ b/src/cryo_sbi/inference/models/activations.py @@ -0,0 +1,5 @@ +import torch.nn as nn + +ACTIVATIONS = {} + + diff --git a/src/cryo_sbi/inference/models/build_models.py b/src/cryo_sbi/inference/models/build_models.py index f3e1bd5..5b2cf67 100644 --- a/src/cryo_sbi/inference/models/build_models.py +++ b/src/cryo_sbi/inference/models/build_models.py @@ -39,6 +39,12 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module: f"Model : {config['EMBEDDING']} has not been implemented yet! 
\ The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}" ) + + if "BINS" in config: + bins = config["BINS"] + print(f"Using {bins} bins for NPE") + else: + bins = 8 estimator = estimator_models.NPEWithEmbedding( embedding_net=embedding, @@ -49,11 +55,36 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module: flow=model, theta_shift=config["THETA_SHIFT"], theta_scale=config["THETA_SCALE"], + bins=bins, **{"activation": partial(nn.LeakyReLU, 0.1)}, ) return estimator +def build_fmpe_flow_model(config: dict, **embedding_kwargs) -> nn.Module: + + try: + embedding = partial( + EMBEDDING_NETS[config["EMBEDDING"]], config["OUT_DIM"], **embedding_kwargs + ) + except KeyError: + raise NotImplementedError( + f"Model : {config['EMBEDDING']} has not been implemented yet! \ +The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}" + ) + + estimator = estimator_models.FMPEWithEmbedding( + embedding_net=embedding, + output_embedding_dim=config["OUT_DIM"], + num_hidden_flow=config["NUM_HIDDEN_FLOW"], + hidden_flow_dim=config["HIDDEN_DIM_FLOW"], + theta_shift=config["THETA_SHIFT"], + theta_scale=config["THETA_SCALE"], + ) + + return estimator + + def build_nre_classifier_model(config: dict, **embedding_kwargs) -> nn.Module: raise NotImplementedError("NRE classifier model has not been implemented yet!") diff --git a/src/cryo_sbi/inference/models/estimator_models.py b/src/cryo_sbi/inference/models/estimator_models.py index a1ff0fe..2f98a4d 100644 --- a/src/cryo_sbi/inference/models/estimator_models.py +++ b/src/cryo_sbi/inference/models/estimator_models.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import zuko -from lampe.inference import NPE, NRE - +from lampe.inference import NPE, NRE, FMPE +from lampe.nn import ResMLP class Standardize(nn.Module): """ @@ -71,6 +71,7 @@ def __init__( flow: nn.Module = zuko.flows.MAF, theta_shift: float = 0.0, theta_scale: float = 1.0, + bins: int = 8, **kwargs, ) 
-> None: """ @@ -99,9 +100,10 @@ def __init__( transforms=num_transforms, build=flow, hidden_features=[*[hidden_flow_dim] * num_hidden_flow, 128, 64], + bins=bins, **kwargs, ) - + self.type = "NPE" self.embedding = embedding_net() self.standardize = Standardize(theta_shift, theta_scale) @@ -145,3 +147,39 @@ def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor: samples_standardized = self.flow(x).sample(shape) return self.standardize.transform(samples_standardized) + +class FMPEWithEmbedding(nn.Module): + def __init__( + self, + embedding_net: nn.Module, + output_embedding_dim: int, + num_hidden_flow: int = 2, + hidden_flow_dim: int = 128, + theta_shift: float = 0.0, + theta_scale: float = 1.0, + **kwargs, + ) -> None: + + super().__init__() + + self.fmpe = FMPE( + theta_dim=1, + x_dim=output_embedding_dim, + freqs=5, + build=ResMLP, + hidden_features=[*[hidden_flow_dim] * num_hidden_flow], + activation=nn.ELU + ) + self.type = "FMPE" + self.embedding = embedding_net() + self.standardize = Standardize(theta_shift, theta_scale) + + def forward(self, theta: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return self.fmpe(self.standardize(theta), self.embedding(x), t) + + def flow(self, x: torch.Tensor): + return self.fmpe.flow(self.embedding(x)) + + def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor: + samples_standardized = self.flow(x).sample(shape) + return self.standardize.transform(samples_standardized) \ No newline at end of file diff --git a/src/cryo_sbi/inference/train_fmpe_model.py b/src/cryo_sbi/inference/train_fmpe_model.py new file mode 100644 index 0000000..e6130b0 --- /dev/null +++ b/src/cryo_sbi/inference/train_fmpe_model.py @@ -0,0 +1,224 @@ +from typing import Union +import json +import torch +import numpy as np +import torch.optim as optim +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from lampe.inference import FMPELoss +from lampe.utils import GDStep +from itertools import islice 
+import matplotlib.pyplot as plt
+
+from cryo_sbi.inference.priors import get_image_priors, PriorLoader
+from cryo_sbi.inference.models.build_models import build_fmpe_flow_model
+from cryo_sbi.inference.validate_train_config import check_train_params
+from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator
+from cryo_sbi.wpa_simulator.validate_image_config import check_image_params
+from cryo_sbi.inference.validate_train_config import check_train_params
+from cryo_sbi.utils.estimator_utils import sample_posterior, evaluate_log_prob
+
+
+def load_model(
+    train_config: str, model_state_dict: str, device: str, train_from_checkpoint: bool
+) -> torch.nn.Module:
+    """
+    Load model from checkpoint or from scratch.
+
+    Args:
+        train_config (str): path to train config file
+        model_state_dict (str): path to model state dict
+        device (str): device to load model to
+        train_from_checkpoint (bool): whether to load model from checkpoint or from scratch
+    """
+
+    estimator = build_fmpe_flow_model(train_config)
+    if train_from_checkpoint:
+        if not isinstance(model_state_dict, str):
+            raise Warning("No model state dict specified! --model_state_dict is empty")
+        print(f"Loading model parameters from {model_state_dict}")
+        estimator.load_state_dict(torch.load(model_state_dict))
+    estimator.to(device=device)
+    return estimator
+
+
+def fmpe_train_no_saving(
+    image_config: str,
+    train_config: str,
+    epochs: int,
+    estimator_file: str,
+    loss_file: str,
+    train_from_checkpoint: bool = False,
+    model_state_dict: Union[str, None] = None,
+    n_workers: int = 1,
+    device: str = "cpu",
+    saving_frequency: int = 20,
+    simulation_batch_size: int = 1024,
+    validation_set: Union[str, None] = None,
+    validation_frequency: int = 10
+) -> None:
+    """
+    Train FMPE model by simulating training data on the fly.
+    Saves model and loss to disk.
+
+    Args:
+        image_config (str): path to image config file
+        train_config (str): path to train config file
+        epochs (int): number of epochs
+        estimator_file (str): path to estimator file
+        loss_file (str): path to loss file
+        train_from_checkpoint (bool, optional): train from checkpoint. Defaults to False.
+        model_state_dict (str, optional): path to pretrained model state dict. Defaults to None.
+        n_workers (int, optional): number of workers. Defaults to 1.
+        device (str, optional): training device. Defaults to "cpu".
+        saving_frequency (int, optional): frequency of saving model. Defaults to 20.
+        simulation_batch_size (int, optional): number of images simulated per parameter draw. Defaults to 1024.
+
+    Raises:
+        Warning: No model state dict specified! --model_state_dict is empty
+
+    Returns:
+        None
+    """
+    print("Training fmpe model")
+    train_config = json.load(open(train_config))
+    image_config = json.load(open(image_config))
+    check_image_params(image_config)
+
+    assert simulation_batch_size >= train_config["BATCH_SIZE"]
+    assert simulation_batch_size % train_config["BATCH_SIZE"] == 0
+    steps_per_epoch = simulation_batch_size // train_config["BATCH_SIZE"]
+    epoch_repeats = 100 # number of times to simulate a batch of images per epoch
+
+    if image_config["MODEL_FILE"].endswith("npy"):
+        models = (
+            torch.from_numpy(
+                np.load(image_config["MODEL_FILE"]),
+            )
+            .to(device)
+            .to(torch.float32)
+        )
+    else:
+        models = torch.load(image_config["MODEL_FILE"])
+        if isinstance(models, list):
+            models = torch.cat(models, dim=0).to(device).to(torch.float32)
+    print("model shape", models.shape)
+
+    image_prior = get_image_priors(len(models) - 1, image_config, device="cpu")
+    index_to_cv = image_prior.priors[0].index_to_cv.to(device)
+    max_index = index_to_cv.max().cpu()
+    prior_loader = PriorLoader(
+        image_prior, batch_size=simulation_batch_size, num_workers=n_workers
+    )
+
+    num_pixels = torch.tensor(
+        image_config["N_PIXELS"], dtype=torch.float32, device=device
+    )
+    pixel_size = 
torch.tensor(
+        image_config["PIXEL_SIZE"], dtype=torch.float32, device=device
+    )
+
+    estimator = load_model(
+        train_config, model_state_dict, device, train_from_checkpoint
+    )
+
+    loss = FMPELoss(estimator)
+    optimizer = optim.AdamW(estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=train_config["WEIGHT_DECAY"])
+    step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"])
+    mean_loss = []
+
+    if validation_set is not None:
+        validation_set = torch.load(validation_set)
+        assert isinstance(validation_set, dict), "Validation set must be a dictionary"
+        assert "IMAGES" in validation_set, "Validation set must contain images"
+        assert "INDICES" in validation_set, "Validation set must contain ground truth indices"
+
+    #print("Initializing tensorboard writer")
+    #writer = SummaryWriter()
+
+    """if validation_set is not None:
+        num_validation_images = validation_set["IMAGES"].shape[0]
+        for i in range(num_validation_images):
+            fig, axes = plt.subplots(1, 1, figsize=(5, 5))
+            axes.imshow(validation_set["IMAGES"][i].cpu().numpy(), cmap="gray", vmax=1.5, vmin=-1.5)
+            axes.axis("off")
+            writer.add_figure(f"Validation/images", fig, global_step=i)
+            plt.close(fig)
+        writer.flush()"""
+
+    print("Training neural network:")
+    estimator.train()
+    with tqdm(range(epochs), unit="epoch") as tq:
+        for epoch in tq:
+            losses = []
+            for parameters in islice(prior_loader, epoch_repeats):
+                (
+                    indices,
+                    quaternions,
+                    res,
+                    shift,
+                    defocus,
+                    b_factor,
+                    amp,
+                    snr,
+                ) = parameters
+                images = cryo_em_simulator(
+                    models,
+                    indices.to(device, non_blocking=True),
+                    quaternions.to(device, non_blocking=True),
+                    res.to(device, non_blocking=True),
+                    shift.to(device, non_blocking=True),
+                    defocus.to(device, non_blocking=True),
+                    b_factor.to(device, non_blocking=True),
+                    amp.to(device, non_blocking=True),
+                    snr.to(device, non_blocking=True),
+                    num_pixels,
+                    pixel_size,
+                )
+                for _indices, _images in zip(
+                    indices.split(train_config["BATCH_SIZE"]),
+                    
images.split(train_config["BATCH_SIZE"]), + ): + losses.append( + step( + loss( + index_to_cv[_indices].to(device, non_blocking=True), + _images.to(device, non_blocking=True), + ) + ) + ) + losses = torch.stack(losses) + + tq.set_postfix(loss=losses.mean().item()) + mean_loss.append(losses.mean().item()) + current_step = (epoch + 1) * steps_per_epoch * epoch_repeats + + #writer.add_scalar("Loss/mean", losses.mean().item(), current_step) + #writer.add_scalar("Loss/std", losses.std().item(), current_step) + #writer.add_scalar("Loss/last", losses[-1].item(), current_step) + + if epoch % saving_frequency == 0: + torch.save(estimator.state_dict(), estimator_file + f"_epoch={epoch}") + + if False and validation_set is not None and epoch % validation_frequency == 0: + estimator.eval() + with torch.no_grad(): + val_posterior_samples = sample_posterior( + estimator, validation_set["IMAGES"], num_samples=5000, device=device, batch_size=train_config["BATCH_SIZE"] + ) + for i in range(num_validation_images): + writer.add_histogram( + f"Validation/posterior_{i}_index={validation_set['INDICES'][i].item()}", + val_posterior_samples[:, i], + global_step=current_step + ) + estimator.train() + + """writer.add_hparams( + train_config, + {"hparam/best_loss": min(mean_loss), "hparam/last_loss": mean_loss[-1]} + ) + writer.flush() + writer.close()""" + torch.save(estimator.state_dict(), estimator_file) + torch.save(torch.tensor(mean_loss), loss_file)