diff --git a/src/cryo_sbi/inference/models/build_models.py b/src/cryo_sbi/inference/models/build_models.py index 1b0ccc3..9a392fd 100644 --- a/src/cryo_sbi/inference/models/build_models.py +++ b/src/cryo_sbi/inference/models/build_models.py @@ -49,7 +49,7 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module: flow=model, theta_shift=config["THETA_SHIFT"], theta_scale=config["THETA_SCALE"], - **{"activation": nn.GELU},#partial(nn.LeakyReLU, 0.1), nn.GELU + **{"activation": partial(nn.LeakyReLU, 0.1)}, # TDOD add changeable activation #partial(nn.LeakyReLU, 0.1), nn.GELU ) print("Training with GELU")