From 46a22afe1a5508348c8c82f7a18aba3851e9b4ac Mon Sep 17 00:00:00 2001
From: Lorenzo Gaifas
Date: Fri, 6 Sep 2024 17:23:52 +0200
Subject: [PATCH 1/2] use 3D cubic B-spline and 3D std for 3D erasure of features

Co-authored-by: Alessio d'Acapito
---
 src/fidder/erase/cli.py               | 41 ++++++++++++++++
 src/fidder/erase/erase.py             | 61 ++++++++++++++++++++++-
 src/fidder/erase/sparse_local_mean.py | 69 +++++++++++++++++++++++++++
 src/fidder/utils.py                   | 47 ++++++++++++++++++
 4 files changed, 216 insertions(+), 2 deletions(-)

diff --git a/src/fidder/erase/cli.py b/src/fidder/erase/cli.py
index f28f998..34e8782 100644
--- a/src/fidder/erase/cli.py
+++ b/src/fidder/erase/cli.py
@@ -7,6 +7,7 @@
 from typer import Option
 
 from .erase import erase_masked_region as _erase_masked_region
+from .erase import erase_masked_region_3d as _erase_masked_region_3d
 from ..utils import get_pixel_spacing_from_header
 from .._cli import cli, OPTION_PROMPT_KWARGS as PKWARGS
 
@@ -54,3 +55,43 @@ def erase_masked_region(
         voxel_size=pixel_spacing,
         overwrite=True,
     )
+
+
+@cli.command(name="erase_3d", no_args_is_help=True)
+def erase_masked_region_3d(
+    input_image: Path = Option(
+        default=...,
+        help="Volume file in MRC format.",
+        **PKWARGS
+    ),
+    input_mask: Path = Option(
+        default=...,
+        help="Mask file in MRC format.",
+        **PKWARGS
+    ),
+    output_image: Path = Option(
+        default=...,
+        help="Output file in MRC format.",
+        **PKWARGS
+    ),
+):
+    """Erase a masked region in a cryo-EM volume."""
+    volume = torch.as_tensor(mrcfile.read(input_image)).squeeze().float()
+    mask = torch.as_tensor(mrcfile.read(input_mask), dtype=torch.bool).squeeze()
+    if volume.shape != mask.shape:
+        raise ValueError('Shape mismatch between data in volume and mask files.')
+
+    erased_volume = _erase_masked_region_3d(
+        volume=volume,
+        mask=mask,
+        background_intensity_model_resolution=(8, 8, 8),
+        background_intensity_model_samples=25000,
+    )
+
+    pixel_spacing = get_pixel_spacing_from_header(input_image)
+    mrcfile.write(
+        name=output_image,
+        data=np.array(erased_volume, dtype=np.float32),
+        voxel_size=pixel_spacing,
+        overwrite=True,
+    )
diff --git a/src/fidder/erase/erase.py b/src/fidder/erase/erase.py
index 7192fa1..e26a965 100644
--- a/src/fidder/erase/erase.py
+++ b/src/fidder/erase/erase.py
@@ -4,8 +4,8 @@
 import torch
 from einops import einops
 
-from ..utils import estimate_background_std
-from .sparse_local_mean import estimate_local_mean
+from ..utils import estimate_background_std, estimate_background_std_3d
+from .sparse_local_mean import estimate_local_mean, estimate_local_mean_3d
 
 
 def erase_masked_region(
@@ -108,3 +108,60 @@ def _erase_single_image(
                              size=n_pixels_to_inpaint)
     inpainted_image[idx_foreground] += torch.as_tensor(noise)
     return inpainted_image
+
+
+def erase_masked_region_3d(
+    volume: torch.Tensor,
+    mask: torch.Tensor,
+    background_intensity_model_resolution: Tuple[int, int, int] = (5, 5, 5),
+    background_intensity_model_samples: int = 20000,
+) -> torch.Tensor:
+    """Inpaint a 3D volume with Gaussian noise.
+
+    Parameters
+    ----------
+    volume: torch.Tensor
+        `(d, h, w)` array containing volume data to erase.
+    mask: torch.Tensor
+        `(d, h, w)` binary mask separating foreground from background voxels.
+        Foreground voxels (1) will be inpainted.
+    background_intensity_model_resolution: Tuple[int, int, int]
+        Number of points in each volume dimension for the background mean model.
+        Minimum of two points in each dimension.
+    background_intensity_model_samples: int
+        Number of sample points used to determine the model of the background mean.
+
+    Returns
+    -------
+    inpainted_volume: torch.Tensor
+        `(d, h, w)` array containing the volume with foreground voxels of the mask
+        inpainted using Gaussian noise matching the local mean and the estimated
+        standard deviation of the background.
+    """
+    volume = torch.as_tensor(volume)
+    mask = torch.as_tensor(mask, dtype=torch.bool)
+    if volume.shape != mask.shape:
+        raise ValueError("volume shape must match mask shape.")
+
+    inpainted = torch.clone(volume)
+    local_mean = estimate_local_mean_3d(
+        volume=volume,
+        mask=torch.logical_not(mask),
+        resolution=background_intensity_model_resolution,
+        n_samples_for_fit=background_intensity_model_samples,
+    )
+
+    # fill foreground voxels with the local mean of the background
+    idx_foreground = torch.argwhere(mask)
+    idx_foreground = (idx_foreground[:, 0], idx_foreground[:, 1], idx_foreground[:, 2])
+
+    inpainted[idx_foreground] = local_mean[idx_foreground]
+
+    # add noise with mean=0 and std equal to the mean of the per-axis background std estimates
+    background_std = estimate_background_std_3d(volume, mask)
+    n_voxels_to_inpaint = idx_foreground[0].shape[0]
+    noise = np.random.normal(loc=0, scale=np.mean(background_std), size=n_voxels_to_inpaint)
+    inpainted[idx_foreground] += torch.as_tensor(noise, dtype=inpainted.dtype)
+
+    return torch.as_tensor(inpainted, dtype=torch.float32)
diff --git a/src/fidder/erase/sparse_local_mean.py b/src/fidder/erase/sparse_local_mean.py
index e27dfcd..77a23b8 100644
--- a/src/fidder/erase/sparse_local_mean.py
+++ b/src/fidder/erase/sparse_local_mean.py
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
 from scipy.interpolate import LSQBivariateSpline
+from torch_cubic_spline_grids.b_spline_grids import CubicBSplineGrid3d
 
 
 def estimate_local_mean(
@@ -59,3 +60,71 @@ def estimate_local_mean(
     x = np.arange(image.shape[-1])
     local_mean = background_model(y, x, grid=True)
     return torch.tensor(local_mean, dtype=input_dtype)
+
+
+def estimate_local_mean_3d(
+    volume: torch.Tensor,
+    mask: Optional[torch.Tensor] = None,
+    resolution: Tuple[int, int, int] = (5, 5, 5),
+    n_samples_for_fit: int = 20000,
+) -> torch.Tensor:
+    """Estimate the local mean of a volume with a 3D cubic B-spline grid.
+
+    A mask can be provided to restrict the fit to a subset of voxels.
+
+    Parameters
+    ----------
+    volume: torch.Tensor
+        `(d, h, w)` array containing volume data.
+    mask: Optional[torch.Tensor]
+        `(d, h, w)` array containing a binary mask specifying which voxels
+        are used for the estimation.
+    resolution: Tuple[int, int, int]
+        Resolution of the local mean estimate in each dimension.
+    n_samples_for_fit: int
+        Number of voxels sampled from the masked-in (mask == 1) region for the fit.
+        All masked-in voxels are used if this number exceeds the number available.
+
+    Returns
+    -------
+    local_mean: torch.Tensor
+        `(d, h, w)` array containing the local mean estimate.
+    """
+    input_dtype = volume.dtype
+    volume = volume.numpy()
+    mask = np.ones_like(volume) if mask is None else mask.numpy()
+
+    # get a random set of masked-in voxels to fit the background model
+    foreground_sample_idx = np.argwhere(mask == 1)
+
+    n_samples_for_fit = min(n_samples_for_fit, len(foreground_sample_idx))
+    selection = np.random.choice(
+        foreground_sample_idx.shape[0], size=n_samples_for_fit, replace=False
+    )
+    foreground_sample_idx = foreground_sample_idx[selection]
+    z, y, x = foreground_sample_idx[:, 0], foreground_sample_idx[:, 1], foreground_sample_idx[:, 2]
+
+    w = torch.as_tensor(volume[(z, y, x)])
+
+    grid = CubicBSplineGrid3d(resolution=resolution)
+    optimiser = torch.optim.Adam(grid.parameters(), lr=0.01)
+
+    for i in range(500):
+        # what does the model predict for our observations?
+        prediction = grid(foreground_sample_idx).squeeze()
+
+        # zero gradients and calculate loss between observations and model prediction
+        optimiser.zero_grad()
+        loss = torch.sum((prediction - w)**2)**0.5
+
+        # backpropagate loss and update values at points on grid
+        loss.backward()
+        optimiser.step()
+
+    # evaluate the fitted model over the whole volume on normalised coordinates;
+    # 'ij' indexing keeps the axes in (d, h, w) order
+    tz = torch.linspace(0, 1, volume.shape[0])
+    ty = torch.linspace(0, 1, volume.shape[1])
+    tx = torch.linspace(0, 1, volume.shape[2])
+    zz, yy, xx = torch.meshgrid(tz, ty, tx, indexing='ij')
+    local_mean = grid(torch.stack((zz, yy, xx), dim=-1)).detach().numpy().reshape(volume.shape)
+    return torch.tensor(local_mean, dtype=input_dtype)
diff --git a/src/fidder/utils.py b/src/fidder/utils.py
index 3dca809..600c2d1 100644
--- a/src/fidder/utils.py
+++ b/src/fidder/utils.py
@@ -101,6 +101,28 @@ def central_crop_2d(image: torch.Tensor, percentage: float = 25) -> torch.Tensor
     return image[..., hf:hc, wf:wc]
 
 
+def central_crop_3d(image: torch.Tensor, percentage: float = 25) -> torch.Tensor:
+    """Get a central crop of (a batch of) 3D volume(s).
+
+    Parameters
+    ----------
+    image: torch.Tensor
+        `(b, d, h, w)` or `(d, h, w)` array of 3D volumes.
+    percentage: float
+        percentage of volume depth, height and width for the cropped region.
+
+    Returns
+    -------
+    cropped_image: torch.Tensor
+        `(b, d, h, w)` or `(d, h, w)` array of cropped 3D volumes.
+    """
+    d, h, w = image.shape[-3], image.shape[-2], image.shape[-1]
+    md, mh, mw = d // 2, h // 2, w // 2
+    dd, dh, dw = int(d * (percentage / 100 / 2)), int(h * (percentage / 100 / 2)), int(w * (percentage / 100 / 2))
+    df, hf, wf = md - dd, mh - dh, mw - dw
+    dc, hc, wc = md + dd, mh + dh, mw + dw
+    return image[..., df:dc, hf:hc, wf:wc]
+
+
 def estimate_background_std(image: torch.Tensor, mask: torch.Tensor):
     """Estimate the standard deviation of the background from a central crop.
 
@@ -120,6 +142,31 @@ def estimate_background_std(image: torch.Tensor, mask: torch.Tensor):
     return torch.std(image[mask == 0])
 
 
+def estimate_background_std_3d(image: torch.Tensor, mask: torch.Tensor):
+    """Estimate per-axis standard deviations of the background from a central crop.
+
+    Parameters
+    ----------
+    image: torch.Tensor
+        `(d, h, w)` array containing data for which background standard deviations
+        will be estimated.
+    mask: torch.Tensor of 0 or 1
+        Binary mask separating foreground and background.
+
+    Returns
+    -------
+    standard_deviation: Tuple[float, float, float]
+        estimated standard deviation of the background along each axis.
+    """
+    image = central_crop_3d(image, percentage=25).float()
+    mask = central_crop_3d(mask, percentage=25)
+    image_masked = image.clone()
+    image_masked[mask == 1] = np.nan
+    # std along each axis over background voxels, averaged across the other two axes
+    return (
+        np.nanmean(np.nanstd(image_masked, axis=0)),
+        np.nanmean(np.nanstd(image_masked, axis=1)),
+        np.nanmean(np.nanstd(image_masked, axis=2)),
+    )
+
+
 def get_pixel_spacing_from_header(image: Path) -> float:
     with mrcfile.open(image, header_only=True, permissive=True) as mrc:
         return float(mrc.voxel_size.x)

From 40d4584d3dde15e04885c29b10ecae74b537e947 Mon Sep 17 00:00:00 2001
From: Lorenzo Gaifas
Date: Wed, 11 Sep 2024 12:01:25 +0200
Subject: [PATCH 2/2] rescale sampling coordinates

---
 src/fidder/erase/sparse_local_mean.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/fidder/erase/sparse_local_mean.py b/src/fidder/erase/sparse_local_mean.py
index 77a23b8..c1f98e6 100644
--- a/src/fidder/erase/sparse_local_mean.py
+++ b/src/fidder/erase/sparse_local_mean.py
@@ -110,9 +110,10 @@ def estimate_local_mean_3d(
     grid = CubicBSplineGrid3d(resolution=resolution)
     optimiser = torch.optim.Adam(grid.parameters(), lr=0.01)
 
+    foreground_sample_idx_rescaled = foreground_sample_idx / volume.shape
     for i in range(500):
         # what does the model predict for our observations?
-        prediction = grid(foreground_sample_idx).squeeze()
+        prediction = grid(foreground_sample_idx_rescaled).squeeze()
 
         # zero gradients and calculate loss between observations and model prediction
         optimiser.zero_grad()
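
Usage note for reviewers (not part of the patch series): a minimal sketch of how the Python API added by these two commits could be driven, mirroring what the new `erase_3d` CLI command does; the MRC file names are placeholders.

    from pathlib import Path

    import mrcfile
    import numpy as np
    import torch

    from fidder.erase.erase import erase_masked_region_3d
    from fidder.utils import get_pixel_spacing_from_header

    # placeholder inputs: a tomogram and a binary fiducial mask with matching shapes
    volume_path = Path("tomogram.mrc")
    volume = torch.as_tensor(mrcfile.read(volume_path)).squeeze().float()
    mask = torch.as_tensor(mrcfile.read(Path("fiducial_mask.mrc")), dtype=torch.bool).squeeze()

    # same parameters as the erase_3d CLI command in this patch
    erased = erase_masked_region_3d(
        volume=volume,
        mask=mask,
        background_intensity_model_resolution=(8, 8, 8),
        background_intensity_model_samples=25000,
    )

    # write out the erased volume, preserving the input pixel spacing
    mrcfile.write(
        name="erased.mrc",
        data=np.array(erased, dtype=np.float32),
        voxel_size=get_pixel_spacing_from_header(volume_path),
        overwrite=True,
    )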