Skip to content

Commit

Permalink
Add multi-frame micrograph support to predict and erase CLIs (#29)
Browse files Browse the repository at this point in the history
* add multi-frame micrograph support to fidder

* add multi-frame support to fidder erase

* force float in erase cli

* remove stupid bug

* pack properly

* pack properly again

* add fix from andriko
  • Loading branch information
alisterburt authored May 16, 2023
1 parent 3b1b4f0 commit 3ac64b1
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
30 changes: 20 additions & 10 deletions src/fidder/erase/cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from pathlib import Path

import einops
import mrcfile
import numpy as np
import torch
from typer import Option

from .erase import erase_masked_region
from .erase import erase_masked_region as _erase_masked_region
from ..utils import get_pixel_spacing_from_header
from .._cli import cli, OPTION_PROMPT_KWARGS as PKWARGS

Expand All @@ -29,18 +30,27 @@ def erase_masked_region(
),
):
"""Erase a masked region in a cryo-EM image."""
image = torch.as_tensor(mrcfile.read(input_image)).squeeze()
mask = torch.as_tensor(mrcfile.read(input_mask), dtype=torch.bool).squeeze()
inpainted_image = erase_masked_region(
image=image,
mask=mask,
background_intensity_model_resolution=(8, 8),
background_intensity_model_samples=25000,
)
images = torch.as_tensor(mrcfile.read(input_image)).squeeze().float()
masks = torch.as_tensor(mrcfile.read(input_mask), dtype=torch.bool).squeeze()
if images.shape != masks.shape:
raise ValueError('Shape mismatch between data in image and mask files.')
images, ps = einops.pack([images], pattern='* h w')
masks, ps = einops.pack([masks], pattern='* h w')

erased_images = torch.empty_like(images, dtype=torch.float32)
for idx, (image, mask) in enumerate(zip(images, masks)):
_erased_image = _erase_masked_region(
image=image,
mask=mask,
background_intensity_model_resolution=(8, 8),
background_intensity_model_samples=25000,
)
erased_images[idx] = _erased_image
[erased_images] = einops.unpack(erased_images, pattern='* h w', packed_shapes=ps)
pixel_spacing = get_pixel_spacing_from_header(input_image)
mrcfile.write(
name=output_image,
data=np.array(inpainted_image, dtype=np.float32),
data=np.array(erased_images, dtype=np.float32),
voxel_size=pixel_spacing,
overwrite=True,
)
30 changes: 20 additions & 10 deletions src/fidder/predict/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Optional

import einops
import mrcfile
import numpy as np
import torch
Expand Down Expand Up @@ -36,20 +37,29 @@ def predict_fiducial_mask(
),
):
"""Predict a fiducial mask using a pretrained model."""
image = torch.tensor(mrcfile.read(input_image))
images = torch.tensor(mrcfile.read(input_image)).float()
images, ps = einops.pack([images], pattern='* h w')
if pixel_spacing is None:
pixel_spacing = get_pixel_spacing_from_header(input_image)
mask, probabilities = _predict_fiducial_mask(
image=image,
pixel_spacing=pixel_spacing,
probability_threshold=probability_threshold,
model_checkpoint_file=model_checkpoint_file,
)
mask = mask.cpu().numpy().astype(np.int8)
probabilities = probabilities.cpu().numpy()

masks = torch.empty_like(images, dtype=torch.int8)
probabilities = torch.empty_like(images)
for idx, image in enumerate(images):
_mask, _probabilities = _predict_fiducial_mask(
image=image,
pixel_spacing=pixel_spacing,
probability_threshold=probability_threshold,
model_checkpoint_file=model_checkpoint_file,
)
masks[idx] = _mask
probabilities[idx] = _probabilities
masks = masks.cpu().numpy().astype(np.int8)
probabilities = probabilities.float().cpu().numpy()
[masks] = einops.unpack(masks, pattern='* h w', packed_shapes=ps)
[probabilities] = einops.unpack(probabilities, pattern='* h w', packed_shapes=ps)
output_pixel_spacing = (1, pixel_spacing, pixel_spacing)
mrcfile.write(
name=output_mask, data=mask, voxel_size=output_pixel_spacing, overwrite=True
name=output_mask, data=masks, voxel_size=output_pixel_spacing, overwrite=True
)
if output_probabilities is not None:
probabilities = probabilities.astype(np.float32)
Expand Down

0 comments on commit 3ac64b1

Please sign in to comment.