From 3ac64b13db256b598bf8951eb9a66137c04b86b6 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Tue, 16 May 2023 20:14:53 +0100 Subject: [PATCH] Add multi-frame micrograph support to predict and erase CLIs (#29) * add multi-frame micrograph support to fidder * add multi-frame support to fidder erase * force float in erase cli * remove stupid bug * pack properly * pack properly again * add fix from andriko --- src/fidder/erase/cli.py | 30 ++++++++++++++++++++---------- src/fidder/predict/cli.py | 30 ++++++++++++++++++++---------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/fidder/erase/cli.py b/src/fidder/erase/cli.py index 7038abf..f28f998 100644 --- a/src/fidder/erase/cli.py +++ b/src/fidder/erase/cli.py @@ -1,11 +1,12 @@ from pathlib import Path +import einops import mrcfile import numpy as np import torch from typer import Option -from .erase import erase_masked_region +from .erase import erase_masked_region as _erase_masked_region from ..utils import get_pixel_spacing_from_header from .._cli import cli, OPTION_PROMPT_KWARGS as PKWARGS @@ -29,18 +30,27 @@ def erase_masked_region( ), ): """Erase a masked region in a cryo-EM image.""" - image = torch.as_tensor(mrcfile.read(input_image)).squeeze() - mask = torch.as_tensor(mrcfile.read(input_mask), dtype=torch.bool).squeeze() - inpainted_image = erase_masked_region( - image=image, - mask=mask, - background_intensity_model_resolution=(8, 8), - background_intensity_model_samples=25000, - ) + images = torch.as_tensor(mrcfile.read(input_image)).squeeze().float() + masks = torch.as_tensor(mrcfile.read(input_mask), dtype=torch.bool).squeeze() + if images.shape != masks.shape: + raise ValueError('Shape mismatch between data in image and mask files.') + images, ps = einops.pack([images], pattern='* h w') + masks, ps = einops.pack([masks], pattern='* h w') + + erased_images = torch.empty_like(images, dtype=torch.float32) + for idx, (image, mask) in enumerate(zip(images, masks)): + _erased_image = _erase_masked_region( + image=image, + mask=mask, + background_intensity_model_resolution=(8, 8), + background_intensity_model_samples=25000, + ) + erased_images[idx] = _erased_image + [erased_images] = einops.unpack(erased_images, pattern='* h w', packed_shapes=ps) pixel_spacing = get_pixel_spacing_from_header(input_image) mrcfile.write( name=output_image, - data=np.array(inpainted_image, dtype=np.float32), + data=np.array(erased_images, dtype=np.float32), voxel_size=pixel_spacing, overwrite=True, ) diff --git a/src/fidder/predict/cli.py b/src/fidder/predict/cli.py index 669b446..6a26a93 100644 --- a/src/fidder/predict/cli.py +++ b/src/fidder/predict/cli.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Optional +import einops import mrcfile import numpy as np import torch @@ -36,20 +37,29 @@ def predict_fiducial_mask( ), ): """Predict a fiducial mask using a pretrained model.""" - image = torch.tensor(mrcfile.read(input_image)) + images = torch.tensor(mrcfile.read(input_image)).float() + images, ps = einops.pack([images], pattern='* h w') if pixel_spacing is None: pixel_spacing = get_pixel_spacing_from_header(input_image) - mask, probabilities = _predict_fiducial_mask( - image=image, - pixel_spacing=pixel_spacing, - probability_threshold=probability_threshold, - model_checkpoint_file=model_checkpoint_file, - ) - mask = mask.cpu().numpy().astype(np.int8) - probabilities = probabilities.cpu().numpy() + + masks = torch.empty_like(images, dtype=torch.int8) + probabilities = torch.empty_like(images) + for idx, image in enumerate(images): + _mask, _probabilities = _predict_fiducial_mask( + image=image, + pixel_spacing=pixel_spacing, + probability_threshold=probability_threshold, + model_checkpoint_file=model_checkpoint_file, + ) + masks[idx] = _mask + probabilities[idx] = _probabilities + masks = masks.cpu().numpy().astype(np.int8) + probabilities = probabilities.float().cpu().numpy() + [masks] = einops.unpack(masks, pattern='* h w', packed_shapes=ps) + [probabilities] = einops.unpack(probabilities, pattern='* h w', packed_shapes=ps) output_pixel_spacing = (1, pixel_spacing, pixel_spacing) mrcfile.write( - name=output_mask, data=mask, voxel_size=output_pixel_spacing, overwrite=True + name=output_mask, data=masks, voxel_size=output_pixel_spacing, overwrite=True ) if output_probabilities is not None: probabilities = probabilities.astype(np.float32)