Skip to content

Commit

Permalink
batched simulator in CryoEmSimulator class
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Oct 6, 2023
1 parent 7a92214 commit 47631ef
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,46 @@ def max_index(self) -> int:
int: Maximum index of the model file.
"""
return len(self._models) - 1

def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=None):
parameters = self._priors.sample((num_sim,))
indices = parameters[0] if indices is None else indices
if indices is not None:
assert isinstance(
indices, torch.Tensor
), "Indices are not a torch.tensor, converting to torch.tensor."
assert (
indices.dtype == torch.float32
), "Indices are not a torch.float32, converting to torch.float32."
assert (
indices.ndim == 2
), "Indices are not a 2D tensor, converting to 2D tensor. With shape (batch_size, 1)."
indices = torch.tensor(indices, dtype=torch.float32)

images = []
if batch_size is None:
batch_size = num_sim
for i in range(0, num_sim, batch_size):
batch_indices = indices[i:i+batch_size]
batch_parameters = [param[i:i+batch_size] for param in parameters[1:]]
batch_images = cryo_em_simulator(
self._models,
batch_indices,
*batch_parameters,
self._num_pixels,
self._pixel_size,
)
images.append(batch_images.cpu())

images = torch.cat(images, dim=0)

if return_parameters:
return images.cpu(), parameters
else:
return images.cpu()

def simulate(self, num_sim, indices=None, return_parameters=False):
"""def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=None):
parameters = self._priors.sample((num_sim,))
indices = parameters[0] if indices is None else indices
if indices is not None:
Expand Down Expand Up @@ -143,3 +181,4 @@ def simulate(self, num_sim, indices=None, return_parameters=False):
return images.cpu(), parameters
else:
return images.cpu()
"""

0 comments on commit 47631ef

Please sign in to comment.