diff --git a/benchmarks/filter_2d.py b/benchmarks/filter_2d.py index 067fc5d8..4b398802 100644 --- a/benchmarks/filter_2d.py +++ b/benchmarks/filter_2d.py @@ -50,6 +50,8 @@ def setup_filter( soma_diameter=settings.soma_diameter, log_sigma_size=settings.log_sigma_size, n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh, + n_sds_above_mean_local_thresh=settings.n_sds_above_mean_local_thresh, + local_thresh_tile_size=settings.local_thresh_tile_size, torch_device=torch_device, dtype=settings.filtering_dtype.__name__, use_scipy=use_scipy, diff --git a/benchmarks/filter_debug.py b/benchmarks/filter_debug.py new file mode 100644 index 00000000..37e38b45 --- /dev/null +++ b/benchmarks/filter_debug.py @@ -0,0 +1,424 @@ +""" +Given an input folder containing a list of tiffs, it loads them using dask +and does full 2d and 3d filtering, cell detection and cluster splitting. +For each filtering step it outputs the filtered image so you can see what +the step produces and tune parameters. + +It outputs the following folders in the provided output directory (comment +out `save_tiffs` or filtering steps if you don't wish to run / save them): + +- input: The original input tiffs +- clipped: The first step of 2d filtering where the data is clipped to a max + value (typically should be unchanged). +- enhanced: After it was run through the 2d filters. +- inside_brain: Tiled planes, with each tile indicating whether the tile is + inside or outside the brain. +- filtered_2d: After the enhanced images were thresholded to a binary foreground + / background. +- filtered_3d: After the 3d ball filtering. +- struct_id: The output from the cell detector, where each voxel has an ID with + its cell number or zero if it's background. +- struct_type: Voxel values are 0 for background, 1 if it's a potential cell, + 2 if it's a structure to be split, and 3 if it's too big to be split.
+- struct_type_split: Same as struct_type except we put a sphere with value 1 + centered on the structures that were split into cells. + +There's also a `structures.csv` output that lists all the detected structures, +their volumes, and type. + +To analyze, in Fiji open each directory as an image sequence (use virtual +option), check that each image sequence is 32-bit (the max of all of them) +or change the type to 32-bit. Then merge them as color channels with the +composite option and using the Image -> Color -> Channels tool switch them +from composite to grayscale. + +This will load all the images into memory! So select only the directories +you wish to inspect. They should now be overlaid and you can inspect how +the algorithms processed cells. +""" + +import csv +import dataclasses +import math +from pathlib import Path + +import numpy as np +import tifffile +import torch +import tqdm +from brainglobe_utils.IO.image.load import read_with_dask + +from cellfinder.core.detect.filters.plane import TileProcessor +from cellfinder.core.detect.filters.setup_filters import DetectionSettings +from cellfinder.core.detect.filters.volume.ball_filter import BallFilter +from cellfinder.core.detect.filters.volume.structure_detection import ( + CellDetector, + get_structure_centre, +) +from cellfinder.core.detect.filters.volume.structure_splitting import ( + split_cells, +) + + +def setup_filter( + signal_path: Path, # expect to load z, y, x + batch_size: int = 1, + torch_device="cpu", + dtype=np.uint16, + use_scipy=True, + voxel_sizes: tuple[float, float, float] = (5, 2, 2), + soma_diameter: float = 16, + max_cluster_size: float = 100_000, + ball_xy_size: float = 6, + ball_z_size: float = 15, + ball_overlap_fraction: float = 0.6, + soma_spread_factor: float = 1.4, + n_free_cpus: int = 2, + log_sigma_size: float = 0.2, + n_sds_above_mean_thresh: float = 10, + n_sds_above_mean_local_thresh: float = 10, + local_thresh_tile_size: float | None = None, + split_ball_xy_size: int = 
3, + split_ball_z_size: int = 3, + split_ball_overlap_fraction: float = 0.8, + n_splitting_iter: int = 10, + start_plane: int = 0, + end_plane: int = 0, +): + signal_array = read_with_dask(str(signal_path)) + if end_plane <= 0: + end_plane = len(signal_array) + signal_array = signal_array[start_plane:end_plane, :, :] + + signal_array = np.asarray(signal_array).astype(dtype) + shape = signal_array.shape + + settings = DetectionSettings( + plane_original_np_dtype=dtype, + plane_shape=shape[1:], + voxel_sizes=voxel_sizes, + soma_spread_factor=soma_spread_factor, + soma_diameter_um=soma_diameter, + max_cluster_size_um3=max_cluster_size, + ball_xy_size_um=ball_xy_size, + ball_z_size_um=ball_z_size, + start_plane=0, + end_plane=len(signal_array), + n_free_cpus=n_free_cpus, + ball_overlap_fraction=ball_overlap_fraction, + log_sigma_size=log_sigma_size, + n_sds_above_mean_thresh=n_sds_above_mean_thresh, + n_sds_above_mean_local_thresh=n_sds_above_mean_local_thresh, + local_thresh_tile_size=local_thresh_tile_size, + outlier_keep=False, + artifact_keep=False, + save_planes=False, + batch_size=batch_size, + torch_device=torch_device, + n_splitting_iter=n_splitting_iter, + ) + + kwargs = dataclasses.asdict(settings) + kwargs["ball_z_size_um"] = split_ball_z_size + kwargs["ball_xy_size_um"] = split_ball_xy_size + kwargs["ball_overlap_fraction"] = split_ball_overlap_fraction + kwargs["torch_device"] = "cpu" + kwargs["plane_original_np_dtype"] = np.float32 + splitting_settings = DetectionSettings(**kwargs) + + signal_array = settings.filter_data_converter_func(signal_array) + signal_array = torch.from_numpy(signal_array).to(torch_device) + + tile_processor = TileProcessor( + plane_shape=shape[1:], + clipping_value=settings.clipping_value, + threshold_value=settings.threshold_value, + soma_diameter=settings.soma_diameter, + log_sigma_size=settings.log_sigma_size, + n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh, + 
n_sds_above_mean_local_thresh=settings.n_sds_above_mean_local_thresh, + local_thresh_tile_size=settings.local_thresh_tile_size, + torch_device=torch_device, + dtype=settings.filtering_dtype.__name__, + use_scipy=use_scipy, + ) + + ball_filter = BallFilter( + plane_height=settings.plane_height, + plane_width=settings.plane_width, + ball_xy_size=settings.ball_xy_size, + ball_z_size=settings.ball_z_size, + overlap_fraction=settings.ball_overlap_fraction, + threshold_value=settings.threshold_value, + soma_centre_value=settings.soma_centre_value, + tile_height=settings.tile_height, + tile_width=settings.tile_width, + dtype=settings.filtering_dtype.__name__, + batch_size=batch_size, + torch_device=torch_device, + use_mask=True, + ) + + cell_detector = CellDetector( + settings.plane_height, + settings.plane_width, + start_z=ball_filter.first_valid_plane, + soma_centre_value=settings.detection_soma_centre_value, + ) + + return ( + settings, + splitting_settings, + tile_processor, + ball_filter, + cell_detector, + signal_array, + batch_size, + ) + + +def save_tiffs( + root: Path, + prefix: str, + start_index: int, + buffer: torch.Tensor | np.ndarray, + total: int, +): + root = root / prefix + root.mkdir(parents=True, exist_ok=True) + + if isinstance(buffer, np.ndarray): + arr = buffer + else: + arr = buffer.cpu().numpy() + digits = int(math.ceil(math.log10(total))) + for i, plane in enumerate(arr, start_index): + tifffile.imwrite( + root / f"{prefix}_{i:0{digits}}.tif", + plane, + compression="LZW", + ) + + +def dump_structures( + output_root: Path, + settings: DetectionSettings, + splitting_settings: DetectionSettings, + cell_detector: CellDetector, + signal_array, +): + max_vol = settings.max_cell_volume + max_cluster = settings.max_cluster_size + shape = signal_array.shape + struct_type = np.zeros(shape, dtype=np.uint8) + struct_type_split = np.zeros(shape, dtype=np.uint8) + + dia = splitting_settings.soma_diameter + r1 = int(dia / 2) + r2 = int(dia - r1) + sphere = 
np.indices((dia,) * 3) + position = np.array( + np.array( + [ + dia / 2, + ] + * 3 + ) + ).reshape((-1, 1, 1, 1)) + arr = np.linalg.norm(sphere - position, axis=0) + sphere_mask = arr <= dia / 2 + + with open(output_root / "structures.csv", "w", newline="") as fh: + writer = csv.writer(fh, delimiter=",") + writer.writerow(["id", "x", "y", "z", "volume", "volume_type"]) + + for cell_id, cell_points in cell_detector.get_structures().items(): + vol = len(cell_points) + x, y, z = get_structure_centre(cell_points) + + if vol < max_vol: + tp = "maybe_cell" + color = 1 + elif vol < max_cluster: + tp = "needs_split" + color = 2 + else: + tp = "too_big" + color = 3 + + writer.writerow(list(map(str, [cell_id, x, y, z, vol, tp]))) + for p in cell_points: + struct_type[p[2], p[1], p[0]] = color + if tp != "needs_split": + struct_type_split[p[2], p[1], p[0]] = color + + if tp == "needs_split": + centers = split_cells(cell_points, settings=splitting_settings) + for x, y, z in centers: + x, y, z = map(int, [x, y, z]) + if any(v < r1 for v in [x, y, z]): + continue + if any(v + r2 > d for d, v in zip(shape, [z, y, x])): + continue + struct_type_split[ + z - r1 : z + r2, y - r1 : y + r2, x - r1 : x + r2 + ][sphere_mask] = 1 + + save_tiffs(output_root, "struct_type", 0, struct_type, len(struct_type)) + save_tiffs( + output_root, + "struct_type_split", + 0, + struct_type_split, + len(struct_type_split), + ) + + +def pad_3d_filtered_images( + output_root: Path, + ball_filter: BallFilter, + signal_array: torch.Tensor | np.ndarray, + sample_plane: torch.Tensor | np.ndarray, + n_saved_planes: int, +): + """ + 3d filters skip the first and last planes. This pads them by creating those + planes as blank planes so that all outputs have same number of planes. 
+ """ + n = len(signal_array) + + if ball_filter.first_valid_plane: + # 3d filters skip first few planes + buff = np.zeros( + (ball_filter.first_valid_plane, *sample_plane.shape), + dtype=sample_plane.dtype, + ) + save_tiffs(output_root, "filtered_3d", 0, buff, n) + save_tiffs( + output_root, + "struct_id", + 0, + buff.astype(np.uint32), + n, + ) + + n_saved_planes += ball_filter.first_valid_plane + + if n_saved_planes < n: + buff = np.zeros( + (n - n_saved_planes, *sample_plane.shape), dtype=sample_plane.dtype + ) + save_tiffs(output_root, "filtered_3d", n_saved_planes, buff, n) + save_tiffs( + output_root, + "struct_id", + n_saved_planes, + buff.astype(np.uint32), + n, + ) + + +def run_filter( + output_root: Path, + settings: DetectionSettings, + splitting_settings: DetectionSettings, + tile_processor: TileProcessor, + ball_filter: BallFilter, + cell_detector: CellDetector, + signal_array, + batch_size, +): + detection_converter = settings.detection_data_converter_func + previous_plane = None + n = len(signal_array) + n_3d_planes = 0 + middle_planes = None + + for i in tqdm.tqdm(range(0, len(signal_array), batch_size)): + batch = signal_array[i : i + batch_size] + save_tiffs(output_root, "input", i, batch, n) + + batch_clipped = torch.clone(batch) + torch.clip_(batch_clipped, 0, tile_processor.clipping_value) + save_tiffs(output_root, "clipped", i, batch_clipped, n) + + enhanced_planes = tile_processor.peak_enhancer.enhance_peaks( + batch_clipped + ) + save_tiffs(output_root, "enhanced", i, enhanced_planes, n) + + filtered_2d, inside_brain_tiles = tile_processor.get_tile_mask(batch) + save_tiffs(output_root, "inside_brain", i, inside_brain_tiles, n) + save_tiffs(output_root, "filtered_2d", i, filtered_2d, n) + + ball_filter.append(filtered_2d, inside_brain_tiles) + if ball_filter.ready: + ball_filter.walk() + middle_planes = ball_filter.get_processed_planes() + buff = middle_planes.copy() + buff[buff != settings.soma_centre_value] = 0 + save_tiffs( + output_root, 
+ "filtered_3d", + n_3d_planes + ball_filter.first_valid_plane, + buff, + n, + ) + + detection_middle_planes = detection_converter(middle_planes) + + for k, (plane, detection_plane) in enumerate( + zip(middle_planes, detection_middle_planes) + ): + previous_plane = cell_detector.process( + detection_plane, previous_plane + ) + save_tiffs( + output_root, + "struct_id", + i + k + ball_filter.first_valid_plane, + previous_plane[None, :, :].astype(np.uint32), + n, + ) + + n_3d_planes += len(middle_planes) + + pad_3d_filtered_images( + output_root, + ball_filter, + signal_array, + middle_planes[0, :, :], + n_3d_planes, + ) + + dump_structures( + output_root, settings, splitting_settings, cell_detector, signal_array + ) + + +if __name__ == "__main__": + with torch.inference_mode(True): + filter_args = setup_filter( + Path(r"D:\tiffs\MF1_158F_W\debug\input"), + soma_diameter=8, + ball_xy_size=8, + ball_z_size=8, + end_plane=0, + ball_overlap_fraction=0.8, + log_sigma_size=0.35, + n_sds_above_mean_thresh=1, + n_sds_above_mean_local_thresh=1, + local_thresh_tile_size=0, + soma_spread_factor=4, + max_cluster_size=1000, + voxel_sizes=(4, 2.03, 2.03), + torch_device="cuda", + batch_size=4, + split_ball_xy_size=10, + split_ball_z_size=12, + split_ball_overlap_fraction=0.8, + n_splitting_iter=2, + ) + run_filter( + Path(r"D:\tiffs\MF1_158F_W\debug\output_sig_thresh"), *filter_args + ) diff --git a/cellfinder/core/classify/classify.py b/cellfinder/core/classify/classify.py index 8ae1f5ff..4c44f08d 100644 --- a/cellfinder/core/classify/classify.py +++ b/cellfinder/core/classify/classify.py @@ -19,8 +19,8 @@ def main( signal_array: types.array, background_array: types.array, n_free_cpus: int, - voxel_sizes: Tuple[int, int, int], - network_voxel_sizes: Tuple[int, int, int], + voxel_sizes: Tuple[float, float, float], + network_voxel_sizes: Tuple[float, float, float], batch_size: int, cube_height: int, cube_width: int, @@ -35,6 +35,48 @@ def main( """ Parameters ---------- + + points: 
List of Cell objects + The potential cells to classify. + signal_array : numpy.ndarray or dask array + 3D array representing the signal data in z, y, x order. + background_array : numpy.ndarray or dask array + 3D array representing the background data in z, y, x order. + n_free_cpus : int + How many CPU cores to leave free. + voxel_sizes : 3-tuple of floats + Size of your voxels in the z, y, and x dimensions. + network_voxel_sizes : 3-tuple of floats + Size of the pre-trained network's voxels in the z, y, and x dimensions. + batch_size : int + How many potential cells to classify at one time. The GPU/CPU + memory must be able to contain at once this many data cubes for + the models. Tune to maximize memory usage without running + out. Check your GPU/CPU memory to verify it's not full. + cube_height: int + The height of the data cube centered on the cell used for + classification. Defaults to `50`. + cube_width: int + The width of the data cube centered on the cell used for + classification. Defaults to `50`. + cube_depth: int + The depth of the data cube centered on the cell used for + classification. Defaults to `20`. + trained_model : Optional[Path] + Trained model file path (home directory (default) -> pretrained + weights). + model_weights : Optional[Path] + Model weights path (home directory (default) -> pretrained + weights). + network_depth: str + The network depth to use during classification. Defaults to `"50"`. + max_workers: int + The number of sub-processes to use for data loading / processing. + Defaults to 8. + pin_memory: bool + Whether torch should pin any memory to be sent to the GPU. This results + in faster GPU uploads, but the memory cannot be paged while it's in use. + So only use if you have enough RAM. callback : Callable[int], optional A callback function that is called during classification. Called with the batch number once that batch has been classified.
diff --git a/cellfinder/core/detect/detect.py b/cellfinder/core/detect/detect.py index e7ba8e95..131fce07 100644 --- a/cellfinder/core/detect/detect.py +++ b/cellfinder/core/detect/detect.py @@ -43,16 +43,18 @@ def main( n_free_cpus: int = 2, log_sigma_size: float = 0.2, n_sds_above_mean_thresh: float = 10, + n_sds_above_mean_local_thresh: float = 10, + local_thresh_tile_size: float | None = None, outlier_keep: bool = False, artifact_keep: bool = False, save_planes: bool = False, plane_directory: Optional[str] = None, batch_size: Optional[int] = None, torch_device: Optional[str] = None, - split_ball_xy_size: int = 3, - split_ball_z_size: int = 3, + split_ball_xy_size: int = 6, + split_ball_z_size: int = 15, split_ball_overlap_fraction: float = 0.8, - split_soma_diameter: int = 7, + n_splitting_iter: int = 10, *, callback: Optional[Callable[[int], None]] = None, ) -> List[Cell]: @@ -61,69 +63,71 @@ def main( Parameters ---------- - signal_array : numpy.ndarray - 3D array representing the signal data. - + signal_array : numpy.ndarray or dask array + 3D array representing the signal data in z, y, x order. start_plane : int - Index of the starting plane for detection. - + First plane to process (to process a subset of the data). end_plane : int - Index of the ending plane for detection. - - voxel_sizes : Tuple[float, float, float] - Tuple of voxel sizes in each dimension (z, y, x). - + Last plane to process (to process a subset of the data). + voxel_sizes : 3-tuple of floats + Size of your voxels in the z, y, and x dimensions. soma_diameter : float - Diameter of the soma in physical units. - - max_cluster_size : float - Maximum size of a cluster in physical units. - + The expected in-plane (xy) soma diameter (microns). + max_cluster_size : int + Largest detected cell cluster (in cubic um) where splitting + should be attempted. Clusters above this size will be labeled + as artifacts. ball_xy_size : float - Size of the XY ball used for filtering in physical units. 
- + 3d filter's in-plane (xy) filter ball size (microns). ball_z_size : float - Size of the Z ball used for filtering in physical units. - + 3d filter's axial (z) filter ball size (microns). ball_overlap_fraction : float - Fraction of overlap allowed between balls. - + 3d filter's fraction of the ball filter needed to be filled by + foreground voxels, centered on a voxel, to retain the voxel. soma_spread_factor : float - Spread factor for soma size. - + Cell spread factor for determining the largest cell volume before + splitting up cell clusters. Structures with spherical volume of + diameter `soma_spread_factor * soma_diameter` or less will not be + split. n_free_cpus : int - Number of free CPU cores available for parallel processing. - + How many CPU cores to leave free. log_sigma_size : float - Size of the sigma for the log filter. - + Gaussian filter width (as a fraction of soma diameter) used during + 2d in-plane filtering. n_sds_above_mean_thresh : float - Number of standard deviations above the mean threshold. - + Intensity threshold (the number of standard deviations above + the mean) of the filtered 2d planes used to mark pixels as + foreground or background. outlier_keep : bool, optional Whether to keep outliers during detection. Defaults to False. - artifact_keep : bool, optional Whether to keep artifacts during detection. Defaults to False. - save_planes : bool, optional Whether to save the planes during detection. Defaults to False. - plane_directory : str, optional Directory path to save the planes. Defaults to None. - - batch_size : int, optional - The number of planes to process in each batch. Defaults to 1. - For CPU, there's no benefit for a larger batch size. Only a memory - usage increase. For CUDA, the larger the batch size the better the - performance. Until it fills up the GPU memory - after which it - becomes slower. - + batch_size: int + The number of planes of the original data volume to process at + once. 
The GPU/CPU memory must be able to contain this many planes + for all the filters. Tune to maximize memory usage without running + out. Check your GPU/CPU memory to verify it's not full. torch_device : str, optional The device on which to run the computation. If not specified (None), "cuda" will be used if a GPU is available, otherwise "cpu". You can also manually specify "cuda" or "cpu". - + split_ball_xy_size: int + Similar to `ball_xy_size`, except the value to use for the 3d + filter during cluster splitting. + split_ball_z_size: int + Similar to `ball_z_size`, except the value to use for the 3d filter + during cluster splitting. + split_ball_overlap_fraction: float + Similar to `ball_overlap_fraction`, except the value to use for the + 3d filter during cluster splitting. + n_splitting_iter: int + The number of iterations to run the 3d filtering on a cluster. Each + iteration reduces the cluster size by the voxels not retained in + the previous iteration. callback : Callable[int], optional A callback function that is called every time a plane has finished being processed. Called with the plane number that has finished. @@ -131,7 +135,7 @@ def main( Returns ------- List[Cell] - List of detected cells. + List of detected potential cells and artifacts. 
""" start_time = datetime.now() if torch_device is None: @@ -181,25 +185,23 @@ def main( ball_overlap_fraction=ball_overlap_fraction, log_sigma_size=log_sigma_size, n_sds_above_mean_thresh=n_sds_above_mean_thresh, + n_sds_above_mean_local_thresh=n_sds_above_mean_local_thresh, + local_thresh_tile_size=local_thresh_tile_size, outlier_keep=outlier_keep, artifact_keep=artifact_keep, save_planes=save_planes, plane_directory=plane_directory, batch_size=batch_size, torch_device=torch_device, + n_splitting_iter=n_splitting_iter, ) # replicate the settings specific to splitting, before we access anything # of the original settings, causing cached properties kwargs = dataclasses.asdict(settings) - kwargs["ball_z_size_um"] = split_ball_z_size * settings.z_pixel_size - kwargs["ball_xy_size_um"] = ( - split_ball_xy_size * settings.in_plane_pixel_size - ) + kwargs["ball_z_size_um"] = split_ball_z_size + kwargs["ball_xy_size_um"] = split_ball_xy_size kwargs["ball_overlap_fraction"] = split_ball_overlap_fraction - kwargs["soma_diameter_um"] = ( - split_soma_diameter * settings.in_plane_pixel_size - ) # always run on cpu because copying to gpu overhead is likely slower than # any benefit for detection on smallish volumes kwargs["torch_device"] = "cpu" @@ -219,7 +221,9 @@ def main( plane_shape=settings.plane_shape, clipping_value=settings.clipping_value, threshold_value=settings.threshold_value, - n_sds_above_mean_thresh=n_sds_above_mean_thresh, + n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh, + n_sds_above_mean_local_thresh=settings.n_sds_above_mean_local_thresh, + local_thresh_tile_size=settings.local_thresh_tile_size, log_sigma_size=log_sigma_size, soma_diameter=settings.soma_diameter, torch_device=torch_device, diff --git a/cellfinder/core/detect/filters/plane/plane_filter.py b/cellfinder/core/detect/filters/plane/plane_filter.py index 82434d16..e8a8f1d5 100644 --- a/cellfinder/core/detect/filters/plane/plane_filter.py +++ 
b/cellfinder/core/detect/filters/plane/plane_filter.py @@ -2,6 +2,7 @@ from typing import Tuple import torch +import torch.nn.functional as F from cellfinder.core.detect.filters.plane.classical_filter import PeakEnhancer from cellfinder.core.detect.filters.plane.tile_walker import TileWalker @@ -63,6 +64,9 @@ class TileProcessor: # voxels who are this many std above mean or more are set to # threshold_value n_sds_above_mean_thresh: float + n_sds_above_mean_local_thresh: float + local_threshold_tile_size_px: int = 0 + torch_device: str = "" # filter that finds the peaks in the planes peak_enhancer: PeakEnhancer = field(init=False) @@ -76,6 +80,8 @@ def __init__( clipping_value: int, threshold_value: int, n_sds_above_mean_thresh: float, + n_sds_above_mean_local_thresh: float, + local_thresh_tile_size: float | None, log_sigma_size: float, soma_diameter: int, torch_device: str, @@ -85,6 +91,12 @@ def __init__( self.clipping_value = clipping_value self.threshold_value = threshold_value self.n_sds_above_mean_thresh = n_sds_above_mean_thresh + self.n_sds_above_mean_local_thresh = n_sds_above_mean_local_thresh + if local_thresh_tile_size: + self.local_threshold_tile_size_px = int( + round(soma_diameter * local_thresh_tile_size) + ) + self.torch_device = torch_device laplace_gaussian_sigma = log_sigma_size * soma_diameter self.peak_enhancer = PeakEnhancer( @@ -131,7 +143,10 @@ def get_tile_mask( planes, enhanced_planes, self.n_sds_above_mean_thresh, + self.n_sds_above_mean_local_thresh, + self.local_threshold_tile_size_px, self.threshold_value, + self.torch_device, ) return planes, inside_brain_tiles @@ -145,21 +160,99 @@ def _threshold_planes( planes: torch.Tensor, enhanced_planes: torch.Tensor, n_sds_above_mean_thresh: float, + n_sds_above_mean_local_thresh: float, + local_threshold_tile_size_px: int, threshold_value: int, + torch_device: str, ) -> None: """ Sets each plane (in-place) to threshold_value, where the corresponding enhanced_plane > mean + 
n_sds_above_mean_thresh*std. Each plane will be set to zero elsewhere. """ - planes_1d = enhanced_planes.view(enhanced_planes.shape[0], -1) + z, y, x = enhanced_planes.shape + # ---- get per-plane global threshold ---- + planes_1d = enhanced_planes.view(z, -1) # add back last dim - avg = torch.mean(planes_1d, dim=1, keepdim=True).unsqueeze(2) - sd = torch.std(planes_1d, dim=1, keepdim=True).unsqueeze(2) - threshold = avg + n_sds_above_mean_thresh * sd + std, mean = torch.std_mean(planes_1d, dim=1, keepdim=True) + threshold = mean.unsqueeze(2) + n_sds_above_mean_thresh * std.unsqueeze(2) + above_global = enhanced_planes > threshold + + # ---- calculate the local threshold ---- + # we do 50% overlap so there's no jumps at boundaries + stride = local_threshold_tile_size_px // 2 + # make tile even for ease of computation + tile_size = stride * 2 + # Due to 50% overlap, to get tiles we move the tile by half tile (stride). + # Total moves will be y // stride - 2 (we start already with mask on first + # tile). So add back 1 for the first tile. Partial tiles are dropped + n_y_tiles = max(y // stride - 1, 1) + n_x_tiles = max(x // stride - 1, 1) + do_tile_y = n_y_tiles >= 2 + do_tile_x = n_x_tiles >= 2 + # num edge pixels dropped b/c moving by a stride would move tile off edge + y_rem = y % stride + x_rem = x % stride + # we want at least one axis to have at least two tiles + if local_threshold_tile_size_px >= 2 and (do_tile_y or do_tile_x): + enhanced_planes_raw = enhanced_planes + enhanced_planes = enhanced_planes[:, y_rem // 2 :, x_rem // 2 :] + + # add empty channel dim after z "batch" dim -> zcyx + enhanced_planes = enhanced_planes.unsqueeze(1) + # unfold makes it 3 dim, z, M, L. 
L is number of tiles, M is tile area + unfolded = F.unfold( + enhanced_planes, + (tile_size if do_tile_y else y, tile_size if do_tile_x else x), + stride=stride, + ) + # average the tile areas, for each tile + std, mean = torch.std_mean(unfolded, dim=1, keepdim=True) + threshold = mean + n_sds_above_mean_local_thresh * std + + # reshape it back into Y by X tiles, instead of YX being one dim + threshold = threshold.reshape((z, n_y_tiles, n_x_tiles)) + + # we need total size of n_tiles * stride + stride + rem for the + # original size. So we add 2 strides and then chop off the excess. + # We center it because of 50% overlap, the first tile is actually + # centered in between the first two strides + if do_tile_y: + repeats = ( + torch.ones(n_y_tiles, dtype=torch.int, device=torch_device) + * stride + ) + repeats[0] = 2 * stride + repeats[-1] = 2 * stride + output_size = (n_y_tiles + 2) * stride + + threshold = threshold.repeat_interleave( + repeats, dim=1, output_size=output_size + ) + offset = (stride - y_rem) // 2 + threshold = threshold[:, offset : y + offset, :] + + if do_tile_x: + repeats = ( + torch.ones(n_x_tiles, dtype=torch.int, device=torch_device) + * stride + ) + repeats[0] = 2 * stride + repeats[-1] = 2 * stride + output_size = (n_x_tiles + 2) * stride + + threshold = threshold.repeat_interleave( + repeats, dim=2, output_size=output_size + ) + offset = (stride - x_rem) // 2 + threshold = threshold[:, :, offset : x + offset] + + above_local = enhanced_planes_raw > threshold + above = torch.logical_and(above_global, above_local) + else: + above = above_global - above = enhanced_planes > threshold planes[above] = threshold_value # subsequent steps only care about the values that are set to threshold or # above in planes. 
We set values in *planes* to threshold based on the diff --git a/cellfinder/core/detect/filters/setup_filters.py b/cellfinder/core/detect/filters/setup_filters.py index 15db72ff..36b1afdb 100644 --- a/cellfinder/core/detect/filters/setup_filters.py +++ b/cellfinder/core/detect/filters/setup_filters.py @@ -85,18 +85,23 @@ class DetectionSettings: """ soma_spread_factor: float = 1.4 - """Spread factor for soma size - how much it may stretch in the images.""" + """ + Cell spread factor for determining the largest cell volume before + splitting up cell clusters. Structures with spherical volume of + diameter `soma_spread_factor * soma_diameter` or less will not be + split. + """ soma_diameter_um: float = 16 """ - Diameter of a typical soma in um. Bright areas larger than this will be - split. + Diameter of a typical soma in-plane (xy) in microns. """ max_cluster_size_um3: float = 100_000 """ - Maximum size of a cluster (bright area) that will be processed, in um. - Larger bright areas are skipped as artifacts. + Largest detected cell cluster (in cubic um) where splitting + should be attempted. Clusters above this size will be labeled + as artifacts. """ ball_xy_size_um: float = 6 @@ -116,19 +121,27 @@ class DetectionSettings: ball_overlap_fraction: float = 0.6 """ - Fraction of overlap between a bright area and the spherical kernel, - for the area to be considered a single ball. + Fraction of the 3d ball filter needed to be filled by foreground voxels, + centered on a voxel, to retain the voxel. """ log_sigma_size: float = 0.2 - """Size of the sigma for the 2d Gaussian filter.""" + """ + Gaussian filter width (as a fraction of soma diameter) used during + 2d in-plane filtering. + """ n_sds_above_mean_thresh: float = 10 """ - Number of standard deviations above the mean intensity to use for a - threshold to define bright areas. Below it, it's not considered bright. 
+ Intensity threshold (the number of standard deviations above + the mean) of the filtered 2d planes used to mark pixels as + foreground or background. """ + n_sds_above_mean_local_thresh: float = 10 + + local_thresh_tile_size: float | None = None + outlier_keep: bool = False """Whether to keep outlier structures during detection.""" @@ -191,6 +204,8 @@ class DetectionSettings: """ During the structure splitting phase we iteratively shrink the bright areas and re-filter with the 3d filter. This is the number of iterations to do. + Each iteration reduces the cluster size by the voxels not retained in the + previous iteration. This is a maximum because we also stop if there are no more structures left during any iteration. diff --git a/cellfinder/core/detect/filters/volume/structure_splitting.py b/cellfinder/core/detect/filters/volume/structure_splitting.py index e1d4a5f0..0d3bd358 100644 --- a/cellfinder/core/detect/filters/volume/structure_splitting.py +++ b/cellfinder/core/detect/filters/volume/structure_splitting.py @@ -1,3 +1,4 @@ +from copy import copy from typing import List, Tuple, Type import numpy as np @@ -224,6 +225,7 @@ def split_cells( where M is the number of individual cells and each centre is represented by its x, y, and z coordinates. 
""" + settings = copy(settings) # these points are in x, y, z order columnwise, in absolute pixels orig_centre = get_structure_centre(cell_points) diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index 5f7e39d1..a597d0de 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -1,34 +1,39 @@ import os from typing import Callable, List, Optional, Tuple -import numpy as np from brainglobe_utils.cells.cells import Cell -from cellfinder.core import logger +from cellfinder.core import logger, types from cellfinder.core.download.download import model_type from cellfinder.core.train.train_yaml import depth_type def main( - signal_array: np.ndarray, - background_array: np.ndarray, - voxel_sizes: Tuple[int, int, int], + signal_array: types.array, + background_array: types.array, + voxel_sizes: Tuple[float, float, float], start_plane: int = 0, end_plane: int = -1, trained_model: Optional[os.PathLike] = None, model_weights: Optional[os.PathLike] = None, model: model_type = "resnet50_tv", - batch_size: int = 64, + classification_batch_size: int = 64, n_free_cpus: int = 2, - network_voxel_sizes: Tuple[int, int, int] = (5, 1, 1), + network_voxel_sizes: Tuple[float, float, float] = (5, 1, 1), soma_diameter: int = 16, ball_xy_size: int = 6, ball_z_size: int = 15, ball_overlap_fraction: float = 0.6, log_sigma_size: float = 0.2, n_sds_above_mean_thresh: float = 10, + n_sds_above_mean_local_thresh: float = 10, + local_thresh_tile_size: float | None = None, soma_spread_factor: float = 1.4, max_cluster_size: int = 100000, + split_ball_xy_size: int = 6, + split_ball_z_size: int = 15, + split_ball_overlap_fraction: float = 0.8, + n_splitting_iter: int = 10, cube_width: int = 50, cube_height: int = 50, cube_depth: int = 20, @@ -36,7 +41,7 @@ def main( skip_detection: bool = False, skip_classification: bool = False, detected_cells: List[Cell] = None, - classification_batch_size: Optional[int] = None, + detection_batch_size: Optional[int] = None, torch_device: 
Optional[str] = None, *, detect_callback: Optional[Callable[[int], None]] = None, @@ -46,6 +51,100 @@ def main( """ Parameters ---------- + signal_array : numpy.ndarray or dask array + 3D array representing the signal data in z, y, x order. + background_array : numpy.ndarray or dask array + 3D array representing the background data in z, y, x order. + voxel_sizes : 3-tuple of floats + Size of your voxels in the z, y, and x dimensions. + start_plane : int + First plane to process (to process a subset of the data). + end_plane : int + Last plane to process (to process a subset of the data). + trained_model : Optional[Path] + Trained model file path (home directory (default) -> pretrained + weights). + model_weights : Optional[Path] + Model weights path (home directory (default) -> pretrained + weights). + model: str + Type of model to use. Defaults to `"resnet50_tv"`. + classification_batch_size : int + How many potential cells to classify at one time. The GPU/CPU + memory must be able to contain at once this many data cubes for + the models. Tune to maximize memory usage without running + out. Check your GPU/CPU memory to verify it's not full. + n_free_cpus : int + How many CPU cores to leave free. + network_voxel_sizes : 3-tuple of floats + Size of the pre-trained network's voxels in the z, y, and x dimensions. + soma_diameter : float + The expected in-plane (xy) soma diameter (microns). + ball_xy_size : float + 3d filter's in-plane (xy) filter ball size (microns). + ball_z_size : float + 3d filter's axial (z) filter ball size (microns). + ball_overlap_fraction : float + 3d filter's fraction of the ball filter needed to be filled by + foreground voxels, centered on a voxel, to retain the voxel. + log_sigma_size : float + Gaussian filter width (as a fraction of soma diameter) used during + 2d in-plane filtering. 
+ n_sds_above_mean_thresh : float + Intensity threshold (the number of standard deviations above + the mean) of the filtered 2d planes used to mark pixels as + foreground or background. + soma_spread_factor : float + Cell spread factor for determining the largest cell volume before + splitting up cell clusters. Structures with spherical volume of + diameter `soma_spread_factor * soma_diameter` or less will not be + split. + max_cluster_size : int + Largest detected cell cluster (in cubic um) where splitting + should be attempted. Clusters above this size will be labeled + as artifacts. + split_ball_xy_size: int + Similar to `ball_xy_size`, except the value to use for the 3d + filter during cluster splitting. + split_ball_z_size: int + Similar to `ball_z_size`, except the value to use for the 3d filter + during cluster splitting. + split_ball_overlap_fraction: float + Similar to `ball_overlap_fraction`, except the value to use for the + 3d filter during cluster splitting. + n_splitting_iter: int + The number of iterations to run the 3d filtering on a cluster. Each + iteration reduces the cluster size by the voxels not retained in + the previous iteration. + cube_width: int + The width of the data cube centered on the cell used for + classification. Defaults to `50`. + cube_height: int + The height of the data cube centered on the cell used for + classification. Defaults to `50`. + cube_depth: int + The depth of the data cube centered on the cell used for + classification. Defaults to `20`. + network_depth: str + The network depth to use during classification. Defaults to `"50"`. + skip_detection : bool + If selected, the detection step is skipped and instead we get the + detected cells from the cell layer below (from a previous + detection run or import). + skip_classification : bool + If selected, the classification step is skipped and all cells from + the detection stage are added. + detected_cells: Optional list of Cell objects. 
+ If specified, the cells to use during classification. + detection_batch_size: int + The number of planes of the original data volume to process at + once. The GPU/CPU memory must be able to contain this many planes + for all the filters. Tune to maximize memory usage without running + out. Check your GPU/CPU memory to verify it's not full. + torch_device : str, optional + The device on which to run the computation. If not specified (None), + "cuda" will be used if a GPU is available, otherwise "cpu". + You can also manually specify "cuda" or "cpu". detect_callback : Callable[int], optional Called every time a plane has finished being processed during the detection stage. Called with the plane number that has finished. @@ -76,9 +175,15 @@ def main( n_free_cpus, log_sigma_size, n_sds_above_mean_thresh, - batch_size=classification_batch_size, + n_sds_above_mean_local_thresh, + local_thresh_tile_size, + batch_size=detection_batch_size, torch_device=torch_device, callback=detect_callback, + split_ball_z_size=split_ball_z_size, + split_ball_xy_size=split_ball_xy_size, + split_ball_overlap_fraction=split_ball_overlap_fraction, + n_splitting_iter=n_splitting_iter, ) if detect_finished_callback is not None: @@ -101,7 +206,7 @@ def main( n_free_cpus, voxel_sizes, network_voxel_sizes, - batch_size, + classification_batch_size, cube_height, cube_width, cube_depth, diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py index fcb01ce5..eb527b74 100644 --- a/cellfinder/napari/detect/detect.py +++ b/cellfinder/napari/detect/detect.py @@ -244,18 +244,21 @@ def widget( detection_options, skip_detection: bool, soma_diameter: float, + log_sigma_size: float, + n_sds_above_mean_thresh: float, + n_sds_above_mean_local_thresh: float, + local_thresh_tile_size: float, ball_xy_size: float, ball_z_size: float, ball_overlap_fraction: float, - log_sigma_size: float, - n_sds_above_mean_thresh: int, + detection_batch_size: int, soma_spread_factor: float, 
max_cluster_size: int, classification_options, skip_classification: bool, use_pre_trained_weights: bool, trained_model: Optional[Path], - batch_size: int, + classification_batch_size: int, misc_options, start_plane: int, end_plane: int, @@ -281,33 +284,48 @@ def widget( detected cells from the cell layer below (from a previous detection run or import) soma_diameter : float - The expected in-plane soma diameter (microns) + The expected in-plane (xy) soma diameter (microns) + log_sigma_size : float + Gaussian filter width (as a fraction of soma diameter) used during + 2d in-plane filtering + n_sds_above_mean_thresh : float + Intensity threshold (the number of standard deviations above + the mean) of the filtered 2d planes used to mark pixels as + foreground or background ball_xy_size : float - Elliptical morphological in-plane filter size (microns) + 3d filter's in-plane (xy) filter ball size (microns) ball_z_size : float - Elliptical morphological axial filter size (microns) + 3d filter's axial (z) filter ball size (microns) ball_overlap_fraction : float - Fraction of the morphological filter needed to be filled - to retain a voxel - log_sigma_size : float - Laplacian of Gaussian filter width (as a fraction of soma diameter) - n_sds_above_mean_thresh : int - Cell intensity threshold (as a multiple of noise above the mean) + 3d filter's fraction of the ball filter needed to be filled by + foreground voxels, centered on a voxel, to retain the voxel + detection_batch_size: int + The number of planes of the original data volume to process at + once. The GPU/CPU memory must be able to contain this many planes + for all the filters. Tune to maximize memory usage without running + out. Check your GPU/CPU memory to verify it's not full soma_spread_factor : float - Cell spread factor (for splitting up cell clusters) + Cell spread factor for determining the largest cell volume before + splitting up cell clusters. 
Structures with spherical volume of + diameter `soma_spread_factor * soma_diameter` or less will not be + split max_cluster_size : int - Largest putative cell cluster (in cubic um) where splitting - should be attempted - use_pre_trained_weights : bool - Select to use pre-trained model weights - batch_size : int - How many points to classify at one time + Largest detected cell cluster (in cubic um) where splitting + should be attempted. Clusters above this size will be labeled + as artifacts skip_classification : bool If selected, the classification step is skipped and all cells from the detection stage are added + use_pre_trained_weights : bool + Select to use pre-trained model weights trained_model : Optional[Path] Trained model file path (home directory (default) -> pretrained weights) + classification_batch_size : int + How many potential cells to classify at one time. The GPU/CPU + memory must be able to contain at once this many data cubes for + the models. Tune to maximize memory usage without running + out. 
Check your GPU/CPU memory to verify it's not full start_plane : int First plane to process (to process a subset of the data) end_plane : int @@ -371,8 +389,11 @@ def widget( ball_overlap_fraction, log_sigma_size, n_sds_above_mean_thresh, + n_sds_above_mean_local_thresh, + local_thresh_tile_size, soma_spread_factor, max_cluster_size, + detection_batch_size, ) if use_pre_trained_weights: @@ -381,7 +402,7 @@ def widget( skip_classification, use_pre_trained_weights, trained_model, - batch_size, + classification_batch_size, ) if analyse_local: diff --git a/cellfinder/napari/detect/detect_containers.py b/cellfinder/napari/detect/detect_containers.py index 5d8cdb7a..b3c84da1 100644 --- a/cellfinder/napari/detect/detect_containers.py +++ b/cellfinder/napari/detect/detect_containers.py @@ -68,9 +68,12 @@ class DetectionInputs(InputContainer): ball_z_size: float = 15 ball_overlap_fraction: float = 0.6 log_sigma_size: float = 0.2 - n_sds_above_mean_thresh: int = 10 + n_sds_above_mean_thresh: float = 10 + n_sds_above_mean_local_thresh: float = 10 + local_thresh_tile_size: float = 0 soma_spread_factor: float = 1.4 max_cluster_size: int = 100000 + detection_batch_size: int = 4 def as_core_arguments(self) -> dict: return super().as_core_arguments() @@ -96,15 +99,25 @@ def widget_representation(cls) -> dict: n_sds_above_mean_thresh=cls._custom_widget( "n_sds_above_mean_thresh", custom_label="Threshold" ), + n_sds_above_mean_local_thresh=cls._custom_widget( + "n_sds_above_mean_local_thresh", custom_label="Local threshold" + ), + local_thresh_tile_size=cls._custom_widget( + "local_thresh_tile_size", + custom_label="Local thresholding tile size", + ), soma_spread_factor=cls._custom_widget( - "soma_spread_factor", custom_label="Cell spread" + "soma_spread_factor", custom_label="Split cell spread" ), max_cluster_size=cls._custom_widget( "max_cluster_size", - custom_label="Max cluster", + custom_label="Split max cluster", min=0, max=10000000, ), + 
detection_batch_size=cls._custom_widget( + "detection_batch_size", custom_label="Batch size" + ), ) @@ -115,7 +128,7 @@ class ClassificationInputs(InputContainer): skip_classification: bool = False use_pre_trained_weights: bool = True trained_model: Optional[Path] = Path.home() - batch_size: int = 64 + classification_batch_size: int = 64 def as_core_arguments(self) -> dict: args = super().as_core_arguments() @@ -133,7 +146,10 @@ def widget_representation(cls) -> dict: skip_classification=dict( value=cls.defaults()["skip_classification"] ), - batch_size=dict(value=cls.defaults()["batch_size"]), + classification_batch_size=dict( + value=cls.defaults()["classification_batch_size"], + label="Batch size", + ), )