diff --git a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py index 2b3021d..a34192f 100644 --- a/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py +++ b/src/cryo_sbi/wpa_simulator/cryo_em_simulator.py @@ -109,8 +109,46 @@ def max_index(self) -> int: int: Maximum index of the model file. """ return len(self._models) - 1 + + def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=None): + parameters = self._priors.sample((num_sim,)) + indices = parameters[0] if indices is None else indices + if indices is not None: + assert isinstance( + indices, torch.Tensor + ), "Indices are not a torch.tensor, converting to torch.tensor." + assert ( + indices.dtype == torch.float32 + ), "Indices are not a torch.float32, converting to torch.float32." + assert ( + indices.ndim == 2 + ), "Indices are not a 2D tensor, converting to 2D tensor. With shape (batch_size, 1)." + indices = torch.tensor(indices, dtype=torch.float32) + + images = [] + if batch_size is None: + batch_size = num_sim + for i in range(0, num_sim, batch_size): + batch_indices = indices[i:i+batch_size] + batch_parameters = [param[i:i+batch_size] for param in parameters[1:]] + batch_images = cryo_em_simulator( + self._models, + batch_indices, + *batch_parameters, + self._num_pixels, + self._pixel_size, + ) + images.append(batch_images.cpu()) + + images = torch.cat(images, dim=0) + + if return_parameters: + return images.cpu(), parameters + else: + return images.cpu() - def simulate(self, num_sim, indices=None, return_parameters=False): + """def simulate(self, num_sim, indices=None, return_parameters=False, batch_size=None): + parameters = self._priors.sample((num_sim,)) indices = parameters[0] if indices is None else indices if indices is not None: @@ -143,3 +181,4 @@ def simulate(self, num_sim, indices=None, return_parameters=False): return images.cpu(), parameters else: return images.cpu() +""" \ No newline at end of file