From f487bce5912264191a74532a8a23447be9e7b091 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Thu, 20 Jun 2024 16:44:40 -0400 Subject: [PATCH] added clear images into mmd loss --- src/cryo_sbi/inference/train_npe_model.py | 10 ++++------ src/cryo_sbi/wpa_simulator/cryo_em_simulator.py | 16 ++++++++++------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/cryo_sbi/inference/train_npe_model.py b/src/cryo_sbi/inference/train_npe_model.py index 2e9fef9..70e8810 100644 --- a/src/cryo_sbi/inference/train_npe_model.py +++ b/src/cryo_sbi/inference/train_npe_model.py @@ -58,7 +58,6 @@ def npe_train_no_saving( saving_frequency: int = 20, simulation_batch_size: int = 1024, gamma: float = 1.0, - experimental_particles: Union[str, None] = None, ) -> None: """ Train NPE model by simulating training data on the fly. @@ -119,7 +118,6 @@ def npe_train_no_saving( ) loss = NPERobustStatsLoss(estimator, gamma) - experimental_particles = torch.load(experimental_particles, map_location=device) optimizer = optim.AdamW( estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001 @@ -143,7 +141,7 @@ def npe_train_no_saving( amp, snr, ) = parameters - images = cryo_em_simulator( + images, clear_images = cryo_em_simulator( models, indices.to(device, non_blocking=True), quaternions.to(device, non_blocking=True), @@ -156,17 +154,17 @@ def npe_train_no_saving( num_pixels, pixel_size, ) - for _indices, _images in zip( + for _indices, _images, _clear_images in zip( indices.split(train_config["BATCH_SIZE"]), images.split(train_config["BATCH_SIZE"]), + clear_images.split(train_config["BATCH_SIZE"]), ): - random_indices = torch.randperm(experimental_particles.size(0))[:train_config["BATCH_SIZE"]] losses.append( step( loss( _indices.to(device, non_blocking=True), _images.to(device, non_blocking=True), - experimental_particles[random_indices].to(device, non_blocking=True), + _clear_images.to(device, non_blocking=True), ) ) ) diff --git a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py index efcc7c9..93abcd4 100644 --- a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py +++ b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py @@ -52,10 +52,11 @@ def cryo_em_simulator( num_pixels, pixel_size, ) - image = apply_ctf(image, defocus, b_factor, amp, pixel_size) - image = add_noise(image, snr) + clear_image = apply_ctf(image, defocus, b_factor, amp, pixel_size) + image = add_noise(clear_image, snr) image = gaussian_normalize_image(image) - return image + clear_image = gaussian_normalize_image(clear_image) + return image, clear_image class CryoEmSimulator: @@ -159,12 +160,13 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No parameters[0] = indices images = [] + clear_images = [] if batch_size is None: batch_size = num_sim for i in range(0, num_sim, batch_size): batch_indices = indices[i : i + batch_size] batch_parameters = [param[i : i + batch_size] for param in parameters[1:]] - batch_images = cryo_em_simulator( + batch_images, batch_clear_images = cryo_em_simulator( self._models, batch_indices, *batch_parameters, @@ -172,10 +174,12 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No self._pixel_size, ) images.append(batch_images.cpu()) + clear_images.append(batch_clear_images.cpu()) images = torch.cat(images, dim=0) + clear_images = torch.cat(clear_images, dim=0) if return_parameters: - return images.cpu(), parameters + return images.cpu(), clear_images.cpu(), parameters else: - return images.cpu() + return images.cpu(), clear_images.cpu()