Skip to content

Commit

Permalink
new version gpu simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jul 14, 2023
1 parent 3867c34 commit f34e26b
Show file tree
Hide file tree
Showing 12 changed files with 393 additions and 187 deletions.
348 changes: 251 additions & 97 deletions notebooks/Untitled.ipynb

Large diffs are not rendered by default.

90 changes: 64 additions & 26 deletions notebooks/analysis_nma.ipynb

Large diffs are not rendered by default.

12 changes: 3 additions & 9 deletions notebooks/image_params_mixed_training.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,9 @@
"PIXEL_SIZE": 2.06,
"SIGMA": [0.5, 5.0],
"MODEL_FILE": "../data/protein_models/6wxb_mixed_models.npy",
"ROTATIONS": true,
"SHIFT": true,
"CTF": true,
"NOISE": true,
"SHIFT": 40,
"DEFOCUS": [0.5, 5.0],
"SNR": [0.01, 0.1],
"RADIUS_MASK": 100,
"SNR": [0.01,1.0],
"AMP": 0.1,
"B_FACTOR": [1.0, 100.0],
"ELECWAVE": 0.019866,
"NOISE_INTENSITY": 0.5
"B_FACTOR": [1.0, 100.0]
}
2 changes: 1 addition & 1 deletion src/cryo_sbi/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@

from cryo_sbi.wpa_simulator.cryo_em_simulator import CryoEmSimulator
4 changes: 4 additions & 0 deletions src/cryo_sbi/inference/command_line_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def cl_npe_train_no_saving():
cl_parser.add_argument(
"--saving_freq", action="store", type=int, required=False, default=20
)
cl_parser.add_argument(
"--simulation_batch_size", action="store", type=int, required=False, default=1024
)

args = cl_parser.parse_args()

Expand All @@ -51,4 +54,5 @@ def cl_npe_train_no_saving():
n_workers=args.n_workers,
device=args.train_device,
saving_frequency=args.saving_freq,
simulation_batch_size=args.simulation_batch_size,
)
1 change: 0 additions & 1 deletion src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,5 @@ def forward(self, x):
x = self.resnet(x)
return x


if __name__ == "__main__":
pass
21 changes: 19 additions & 2 deletions src/cryo_sbi/inference/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,12 @@ def get_image_priors(
lower=torch.tensor([0], dtype=torch.float32, device=device),
upper=torch.tensor([max_index], dtype=torch.float32, device=device),
)
quaternion_prior = QuaternionPrior(device)
if image_config.get("ROTATIONS") and isinstance(image_config["ROTATIONS"], list) and len(image_config["ROTATIONS"]) == 4:
test_quat = image_config["ROTATIONS"]
quaternion_prior = QuaternionTestPrior(test_quat, device)

return ImagePrior(index_prior, sigma, shift, defocus, b_factor, snr, amp, device=device)
return ImagePrior(index_prior, quaternion_prior, sigma, shift, defocus, b_factor, snr, amp, device=device)


class QuaternionPrior:
Expand All @@ -112,12 +116,25 @@ def sample(self, shape) -> torch.Tensor:
[gen_quat().to(self.device) for _ in range(shape[0])], dim=0
)
return quats


class QuaternionTestPrior:
def __init__(self, quat, device) -> None:
self.device = device
self.quat = torch.tensor(quat, device=device)

def sample(self, shape) -> torch.Tensor:
quats = torch.stack(
[self.quat for _ in range(shape[0])], dim=0
)
return quats


class ImagePrior:
def __init__(
self,
index_prior,
quaternion_prior,
sigma_prior,
shift_prior,
defocus_prior,
Expand All @@ -128,7 +145,7 @@ def __init__(
) -> None:
self.priors = [
index_prior,
QuaternionPrior(device),
quaternion_prior,
sigma_prior,
shift_prior,
defocus_prior,
Expand Down
36 changes: 22 additions & 14 deletions src/cryo_sbi/inference/train_npe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from cryo_sbi.inference.models.build_models import build_npe_flow_model
from cryo_sbi.inference.validate_train_config import check_train_params
from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator
from cryo_sbi.wpa_simulator.validate_image_config import check_image_params
from cryo_sbi.inference.validate_train_config import check_train_params


def load_model(
Expand Down Expand Up @@ -52,6 +54,7 @@ def npe_train_no_saving(
n_workers: int = 1,
device: str = "cpu",
saving_frequency: int = 20,
simulation_batch_size: int = 1024
) -> None:
"""
Train NPE model by simulating training data on the fly.
Expand All @@ -78,27 +81,31 @@ def npe_train_no_saving(
"""

train_config = json.load(open(train_config))
check_train_params(train_config)
image_config = json.load(open(image_config))


assert simulation_batch_size > train_config["BATCH_SIZE"]
assert simulation_batch_size % train_config["BATCH_SIZE"] == 0

if image_config["MODEL_FILE"].endswith("npy"):
models = (
torch.from_numpy(
np.load(image_config["MODEL_FILE"]),
)
.to(device)
.to(torch.float32)
)
models = (
torch.from_numpy(
np.load(image_config["MODEL_FILE"]),
)
.to(device)
.to(torch.float32)
)
else:
models = torch.load(
image_config["MODEL_FILE"],
dtype=torch.float32,
device=device
image_config["MODEL_FILE"],
dtype=torch.float32,
device=device
)

image_prior = get_image_priors(len(models) - 1, image_config, device="cpu")
prior_loader = PriorLoader(
image_prior,
batch_size=train_config["BATCH_SIZE"],
batch_size=simulation_batch_size,
num_workers=n_workers
)

Expand All @@ -119,7 +126,7 @@ def npe_train_no_saving(

loss = NPELoss(estimator)
optimizer = optim.AdamW(
estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.0001
estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001
)
step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"])
mean_loss = []
Expand All @@ -144,7 +151,8 @@ def npe_train_no_saving(
num_pixels,
pixel_size,
)
losses.append(step(loss(indices.to(device), images.to(device))))
for _indices, _images in zip(indices.split(train_config["BATCH_SIZE"]), images.split(train_config["BATCH_SIZE"])):
losses.append(step(loss(_indices.to(device, non_blocking=True), _images.to(device, non_blocking=True))))
losses = torch.stack(losses)

tq.set_postfix(loss=losses.mean().item())
Expand Down
34 changes: 15 additions & 19 deletions src/cryo_sbi/wpa_simulator/cryo_em_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cryo_sbi.wpa_simulator.noise import add_noise
from cryo_sbi.wpa_simulator.normalization import gaussian_normalize_image
from cryo_sbi.inference.priors import get_image_priors
from cryo_sbi.wpa_simulator.validate_image_config import check_image_params


def cryo_em_simulator(
Expand Down Expand Up @@ -40,12 +41,12 @@ def cryo_em_simulator(

class CryoEmSimulator:
def __init__(self, config_fname: str, device: str = "cpu"):
self._device = device
self._load_params(config_fname)
self._load_models()
self._priors = get_image_priors(config_fname, device=device)
self._priors = get_image_priors(self.max_index, self._config, device=device)
self._num_pixels = torch.tensor(self._config["N_PIXELS"], dtype=torch.float32, device=device)
self._pixel_size = torch.tensor(self._config["PIXEL_SIZE"], dtype=torch.float32, device=device)
self._device = device

def _load_params(self, config_fname: str) -> None:
"""
Expand All @@ -59,6 +60,7 @@ def _load_params(self, config_fname: str) -> None:
"""

config = json.load(open(config_fname))
check_image_params(config)
self._config = config

def _load_models(self) -> None:
Expand All @@ -69,23 +71,20 @@ def _load_models(self) -> None:
None
"""
print("Loading models without template... assuming shape (models, 3, atoms)")
if self._config["MODEL_FILE"].endswith("npy"):
models = (
torch.from_numpy(
np.load(self._config["MODEL_FILE"]),
)
.to(self._device)
.to(torch.float32)
).to(self._device).to(torch.float32)
)
else:
models = torch.load(
self._config["MODEL_FILE"],
dtype=torch.float32,
device=self._device
)
self._config["MODEL_FILE"]
).to(self._device).to(torch.float32)

self._models = models
assert self._models.ndim != 3, "Models are not of shape (models, 3, atoms)."

assert self._models.ndim == 3, "Models are not of shape (models, 3, atoms)."
assert self._models.shape[1] == 3, "Models are not of shape (models, 3, atoms)."

@property
Expand All @@ -98,12 +97,8 @@ def max_index(self) -> int:
"""
return len(self._models) - 1

def _sample_prior(self, shape: tuple) -> torch.Tensor:
prior_samples = self._priors.sample()
return prior_samples

def simulate(self, indices=None, return_parameters=False):
parameters = self._sample_prior()
def simulate(self, num_sim, indices=None, return_parameters=False):
parameters = self._priors.sample((num_sim,))
indices = parameters[0] if indices is None else indices
if indices is not None:
assert isinstance(
Expand All @@ -126,11 +121,12 @@ def simulate(self, indices=None, return_parameters=False):
parameters[4],
parameters[5],
parameters[6],
parameters[7],
self._num_pixels,
self._pixel_size,
)

if return_parameters:
return images, parameters
return images.cpu(), parameters
else:
return images
return images.cpu()
11 changes: 6 additions & 5 deletions src/cryo_sbi/wpa_simulator/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def gen_quat() -> torch.Tensor:


def gen_rot_matrix(quats: torch.Tensor) -> torch.Tensor:
# TODO add docstring explaining the quaternion convention qr, qx, qy, qz
"""
Generate a rotation matrix from a quaternion.
Expand Down Expand Up @@ -73,14 +74,14 @@ def project_density(
num_batch, _, num_atoms = coords.shape
norm = 1 / (2 * torch.pi * sigma**2 * num_atoms)

grid_min = -pixel_size * (num_pixels - 1) * 0.5
grid_max = pixel_size * (num_pixels - 1) * 0.5 + pixel_size
grid_min = -pixel_size * num_pixels * 0.5
grid_max = pixel_size * num_pixels * 0.5

rot_matrix = gen_rot_matrix(quats)
grid = torch.arange(grid_min, grid_max, pixel_size, device=coords.device).repeat(
grid = torch.arange(grid_min, grid_max, pixel_size, device=coords.device)[0:num_pixels.long()].repeat(
num_batch, 1
)

) # [0: num_pixels.long()] is needed due to single precision error in some cases
coords_rot = torch.bmm(rot_matrix, coords)
coords_rot[:, :2, :] += shift.unsqueeze(-1)

Expand Down
12 changes: 6 additions & 6 deletions src/cryo_sbi/wpa_simulator/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,18 @@ def get_snr(images, snr):

mask = circular_mask(
n_pixels=images.shape[-1],
radius=64, # todo: make this a parameter
radius=images.shape[-1]//2, # TODO: make this a parameter
device=images.device,
)
signal_power = images[:, mask].pow(2).mean().sqrt() # torch.std(image[mask])
signal_power = torch.std(images[:, mask], dim=[-1, -2]) #images are not centered at 0, so std is not the same as power
noise_power = signal_power / torch.sqrt(torch.pow(10, snr))

return noise_power


def add_noise(image: torch.Tensor, snr, seed=None) -> torch.Tensor:
"""
Adds noise to image
Adds noise to image.
Args:
image (torch.Tensor): Image of shape (n_pixels, n_pixels).
Expand All @@ -54,12 +54,12 @@ def add_noise(image: torch.Tensor, snr, seed=None) -> torch.Tensor:
"""

if seed is not None:
torch.manual_seed(seed)
torch.manual_seed(seed) #

snr = get_snr(image, snr)
noise_power = get_snr(image, snr)

noise = torch.randn_like(image, device=image.device)
noise = noise * snr.reshape(-1, 1, 1)
noise = noise * noise_power.reshape(-1, 1, 1)

image_noise = image + noise

Expand Down
9 changes: 2 additions & 7 deletions src/cryo_sbi/wpa_simulator/validate_image_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def check_params(config: dict) -> None:
def check_image_params(config: dict) -> None:
"""
Checks if all necessary parameters are provided.
Expand All @@ -14,16 +14,11 @@ def check_params(config: dict) -> None:
"PIXEL_SIZE",
"SIGMA",
"SHIFT",
"CTF",
"NOISE",
"DEFOCUS",
"SNR",
"MODEL_FILE",
"ROTATIONS",
"RADIUS_MASK",
"AMP",
"B_FACTOR",
"ELECWAVE",
"B_FACTOR"
]

for key in needed_keys:
Expand Down

0 comments on commit f34e26b

Please sign in to comment.