Skip to content

Commit

Permalink
added hybrid simulator in CryoEMSimulator
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed May 13, 2024
1 parent b45d2a2 commit 9e8e9d7
Showing 1 changed file with 65 additions and 24 deletions.
89 changes: 65 additions & 24 deletions src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ def cryo_em_simulator(
snr,
num_pixels,
pixel_size,
noise=True,
ctf=True,
noise=True,
normalize=True,
):
"""
Simulates a bacth of cryo-electron microscopy (cryo-EM) images of a set of given coars-grained models.
Expand Down Expand Up @@ -60,7 +61,8 @@ def cryo_em_simulator(
image = apply_ctf(image, defocus, b_factor, amp, pixel_size)
if noise:
image = add_noise(image, snr)
image = gaussian_normalize_image(image)
if normalize:
image = gaussian_normalize_image(image)
return image


Expand Down Expand Up @@ -148,7 +150,7 @@ def max_index(self) -> int:
int: Maximum index of the model file.
"""
return len(self._models) - 1

def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=None, noise=True, ctf=True):
"""
Simulate cryo-EM images using the specified models and prior distributions.
Expand Down Expand Up @@ -201,34 +203,73 @@ def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=No
return images.cpu(), parameters
else:
return images.cpu()


def simulate_with_micrograph_noise(self, num_sim, micrographs, indices=None, return_parameters=False, parameters=None, batch_size=None, ctf=True, snr=0.0001):
"""
Simulate cryo-EM images using the specified models and prior distributions.
Args:
num_sim (int): The number of images to simulate.
indices (torch.Tensor, optional): The indices of the images to simulate. If None, all images are simulated.
return_parameters (bool, optional): Whether to return the sampled parameters used for simulation.
batch_size (int, optional): The batch size to use for simulation. If None, all images are simulated in a single batch.
def simulate_with_micrograph_noise(self, num_sim, micrographs, indices=None, return_parameters=False, batch_size=None, ctf=True, snr=0.0001):
self._init_micrograph_loader(micrographs, self._config["N_PIXELS"], num_noise_samples=num_sim)
images_and_maybe_params = self.simulate(
num_sim=num_sim,
indices=indices,
return_parameters=return_parameters,
batch_size=batch_size,
Returns:
torch.Tensor or tuple: The simulated images as a tensor of shape (num_sim, num_pixels, num_pixels),
and optionally the sampled parameters as a tuple of tensors.
"""
if parameters is None:
parameters = self._priors.sample((num_sim,))

indices = parameters[0] if indices is None else indices
if indices is not None:
assert isinstance(
indices, torch.Tensor
), "Indices are not a torch.tensor, converting to torch.tensor."
assert (
indices.dtype == torch.float32
), "Indices are not a torch.float32, converting to torch.float32."
assert (
indices.ndim == 2
), "Indices are not a 2D tensor, converting to 2D tensor. With shape (batch_size, 1)."
parameters[0] = indices

images = []
if batch_size is None:
batch_size = num_sim

self._init_micrograph_loader(micrographs, self._config["N_PIXELS"], num_noise_samples=batch_size)

for i in range(0, num_sim, batch_size):
batch_indices = indices[i : i + batch_size]
batch_parameters = [param[i : i + batch_size] for param in parameters[1:]]
batch_images = cryo_em_simulator(
self._models,
batch_indices,
*batch_parameters,
self._num_pixels,
self._pixel_size,
noise=False,
ctf=ctf
ctf=ctf,
normalize=False
)
if return_parameters:
images, parameters = images_and_maybe_params
else:
images = images_and_maybe_params
print("finished simulating images, drawing noise samples...")
noise_power = get_snr(batch_images, batch_parameters[-1])

noise_samples = []
for noise_sample in self._micrograph_loader:
noise_samples.append(noise_sample)
noise_samples = torch.cat(noise_samples, dim=0)

print("finished drawing noise samples, adding noise to images...")
noise_samples = noise_samples / snr
noise_samples = torch.cat(noise_samples, dim=0).to(self._device)
noise_samples = noise_samples * noise_power
batch_images = batch_images + noise_samples
batch_images = gaussian_normalize_image(batch_images)

images = images + noise_samples
images = gaussian_normalize_image(images)
images.append(batch_images.cpu())

images = torch.cat(images, dim=0)

return images

if return_parameters:
return images.cpu(), parameters
else:
return images.cpu()

0 comments on commit 9e8e9d7

Please sign in to comment.