From 91c6f218a9a99efaf6b8e50f5eff2ee2df410281 Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Tue, 17 Jun 2025 15:25:31 -0400 Subject: [PATCH] Add local tiled thresholding to 2d filter. --- benchmarks/filter_2d.py | 2 + cellfinder/core/detect/detect.py | 26 ++- .../core/detect/filters/plane/plane_filter.py | 115 +++++++++- .../core/detect/filters/setup_filters.py | 24 ++- cellfinder/core/main.py | 48 +++-- cellfinder/napari/detect/detect.py | 22 +- cellfinder/napari/detect/detect_containers.py | 11 +- .../test_classical_filters.py | 196 ++++++++++++++++++ 8 files changed, 412 insertions(+), 32 deletions(-) diff --git a/benchmarks/filter_2d.py b/benchmarks/filter_2d.py index 067fc5d8..96d4e002 100644 --- a/benchmarks/filter_2d.py +++ b/benchmarks/filter_2d.py @@ -50,6 +50,8 @@ def setup_filter( soma_diameter=settings.soma_diameter, log_sigma_size=settings.log_sigma_size, n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh, + n_sds_above_mean_tiled_thresh=settings.n_sds_above_mean_tiled_thresh, + tiled_thresh_tile_size=settings.tiled_thresh_tile_size, torch_device=torch_device, dtype=settings.filtering_dtype.__name__, use_scipy=use_scipy, diff --git a/cellfinder/core/detect/detect.py b/cellfinder/core/detect/detect.py index 076b3ae5..2647a206 100644 --- a/cellfinder/core/detect/detect.py +++ b/cellfinder/core/detect/detect.py @@ -53,6 +53,8 @@ def main( split_ball_z_size: float = 15, split_ball_overlap_fraction: float = 0.8, n_splitting_iter: int = 10, + n_sds_above_mean_tiled_thresh: float = 10, + tiled_thresh_tile_size: float | None = None, *, callback: Optional[Callable[[int], None]] = None, ) -> List[Cell]: @@ -95,8 +97,8 @@ def main( Gaussian filter width (as a fraction of soma diameter) used during 2d in-plane Laplacian of Gaussian filtering. n_sds_above_mean_thresh : float - Intensity threshold (the number of standard deviations above - the mean) of the filtered 2d planes used to mark pixels as + Per-plane intensity threshold (the number of standard deviations + above the mean) of the filtered 2d planes used to mark pixels as foreground or background. outlier_keep : bool, optional Whether to keep outliers during detection. Defaults to False. @@ -129,6 +131,20 @@ def main( The number of iterations to run the 3d filtering on a cluster. Each iteration reduces the cluster size by the voxels not retained in the previous iteration. + n_sds_above_mean_tiled_thresh : float + Per-plane, per-tile intensity threshold (the number of standard + deviations above the mean) for the filtered 2d planes used to mark + pixels as foreground or background. When used, (tile size is not zero) + a pixel is marked as foreground if its intensity is above both the + per-plane and per-tile threshold. I.e. it's above the set number of + standard deviations of the per-plane average and of the per-plane + per-tile average for the tile that contains it. + tiled_thresh_tile_size : float + The tile size used to tile the x, y plane to calculate the local + average intensity for the tiled threshold. The value is multiplied + by soma diameter (i.e. 1 means one soma diameter). If zero or None, the + tiled threshold is disabled and only the per-plane threshold is used. + Tiling is done with 50% overlap when striding. callback : Callable[int], optional A callback function that is called every time a plane has finished being processed. Called with the plane number that has finished. @@ -186,6 +202,8 @@ def main( ball_overlap_fraction=ball_overlap_fraction, log_sigma_size=log_sigma_size, n_sds_above_mean_thresh=n_sds_above_mean_thresh, + n_sds_above_mean_tiled_thresh=n_sds_above_mean_tiled_thresh, + tiled_thresh_tile_size=tiled_thresh_tile_size, outlier_keep=outlier_keep, artifact_keep=artifact_keep, save_planes=save_planes, @@ -220,7 +238,9 @@ def main( plane_shape=settings.plane_shape, clipping_value=settings.clipping_value, threshold_value=settings.threshold_value, - n_sds_above_mean_thresh=n_sds_above_mean_thresh, + n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh, + n_sds_above_mean_tiled_thresh=settings.n_sds_above_mean_tiled_thresh, + tiled_thresh_tile_size=settings.tiled_thresh_tile_size, log_sigma_size=log_sigma_size, soma_diameter=settings.soma_diameter, torch_device=torch_device, diff --git a/cellfinder/core/detect/filters/plane/plane_filter.py b/cellfinder/core/detect/filters/plane/plane_filter.py index 20fabc27..29606666 100644 --- a/cellfinder/core/detect/filters/plane/plane_filter.py +++ b/cellfinder/core/detect/filters/plane/plane_filter.py @@ -1,13 +1,12 @@ -from dataclasses import dataclass, field from typing import Tuple import torch +import torch.nn.functional as F from cellfinder.core.detect.filters.plane.classical_filter import PeakEnhancer from cellfinder.core.detect.filters.plane.tile_walker import TileWalker -@dataclass class TileProcessor: """ Processor that filters each plane to highlight the peaks and also @@ -63,12 +62,22 @@ class TileProcessor: # voxels who are this many std above mean or more are set to # threshold_value n_sds_above_mean_thresh: float + # If used, voxels who are this many or more std above mean of the + # containing tile as well as above n_sds_above_mean_thresh for the plane + # average are set to threshold_value. + n_sds_above_mean_tiled_thresh: float + # the tile size, in pixels, that will be used to tile the x, y plane when + # we calculate the per-tile mean / std for use with + # n_sds_above_mean_tiled_thresh. We use 50% overlap when tiling. + local_threshold_tile_size_px: int = 0 + # the torch device name + torch_device: str = "" # filter that finds the peaks in the planes - peak_enhancer: PeakEnhancer = field(init=False) + peak_enhancer: PeakEnhancer = None # generates tiles of the planes, with each tile marked as being inside # or outside the brain based on brightness - tile_walker: TileWalker = field(init=False) + tile_walker: TileWalker = None def __init__( self, @@ -76,6 +85,8 @@ def __init__( clipping_value: int, threshold_value: int, n_sds_above_mean_thresh: float, + n_sds_above_mean_tiled_thresh: float, + tiled_thresh_tile_size: float | None, log_sigma_size: float, soma_diameter: int, torch_device: str, @@ -85,6 +96,12 @@ def __init__( self.clipping_value = clipping_value self.threshold_value = threshold_value self.n_sds_above_mean_thresh = n_sds_above_mean_thresh + self.n_sds_above_mean_tiled_thresh = n_sds_above_mean_tiled_thresh + if tiled_thresh_tile_size: + self.local_threshold_tile_size_px = int( + round(soma_diameter * tiled_thresh_tile_size) + ) + self.torch_device = torch_device laplace_gaussian_sigma = log_sigma_size * soma_diameter self.peak_enhancer = PeakEnhancer( @@ -131,7 +148,10 @@ def get_tile_mask( planes, enhanced_planes, self.n_sds_above_mean_thresh, + self.n_sds_above_mean_tiled_thresh, + self.local_threshold_tile_size_px, self.threshold_value, + self.torch_device, ) return planes, inside_brain_tiles @@ -145,21 +165,98 @@ def _threshold_planes( planes: torch.Tensor, enhanced_planes: torch.Tensor, n_sds_above_mean_thresh: float, + n_sds_above_mean_tiled_thresh: float, + local_threshold_tile_size_px: int, threshold_value: int, + torch_device: str, ) -> None: """ Sets each plane (in-place) to threshold_value, where the corresponding enhanced_plane > mean + n_sds_above_mean_thresh*std. Each plane will be set to zero elsewhere. """ - planes_1d = enhanced_planes.view(enhanced_planes.shape[0], -1) + z, y, x = enhanced_planes.shape + # ---- get per-plane global threshold ---- + planes_1d = enhanced_planes.view(z, -1) # add back last dim - avg = torch.mean(planes_1d, dim=1, keepdim=True).unsqueeze(2) - sd = torch.std(planes_1d, dim=1, keepdim=True).unsqueeze(2) - threshold = avg + n_sds_above_mean_thresh * sd + std, mean = torch.std_mean(planes_1d, dim=1, keepdim=True) + threshold = mean.unsqueeze(2) + n_sds_above_mean_thresh * std.unsqueeze(2) + above_global = enhanced_planes > threshold + + # ---- calculate the local tiled threshold ---- + # we do 50% overlap so there's no jumps at boundaries + stride = local_threshold_tile_size_px // 2 + # make tile even for ease of computation + tile_size = stride * 2 + # Due to 50% overlap, to get tiles we move the tile by half tile (stride). + # Total moves will be y // stride - 2 (we start already with mask on first + # tile). So add back 1 for the first tile. Partial tiles are dropped + n_y_tiles = max(y // stride - 1, 1) if stride else 1 + n_x_tiles = max(x // stride - 1, 1) if stride else 1 + do_tile_y = n_y_tiles >= 2 + do_tile_x = n_x_tiles >= 2 + # we want at least one axis to have at least two tiles + if local_threshold_tile_size_px >= 2 and (do_tile_y or do_tile_x): + # num edge pixels dropped b/c moving by stride would move tile off edge + y_rem = y % stride + x_rem = x % stride + enhanced_planes_raw = enhanced_planes + if do_tile_y: + enhanced_planes = enhanced_planes[:, y_rem // 2 :, :] + if do_tile_x: + enhanced_planes = enhanced_planes[:, :, x_rem // 2 :] + + # add empty channel dim after z "batch" dim -> zcyx + enhanced_planes = enhanced_planes.unsqueeze(1) + # unfold makes it 3 dim, z, M, L. L is number of tiles, M is tile area + unfolded = F.unfold( + enhanced_planes, + (tile_size if do_tile_y else y, tile_size if do_tile_x else x), + stride=stride, + ) + # average the tile areas, for each tile + std, mean = torch.std_mean(unfolded, dim=1, keepdim=True) + threshold = mean + n_sds_above_mean_tiled_thresh * std + + # reshape it back into Y by X tiles, instead of YX being one dim + threshold = threshold.reshape((z, n_y_tiles, n_x_tiles)) + + # we need total size of n_tiles * stride + stride + rem for the + # original size. So we add 2 strides and then chop off the excess above + # rem. We center it because of 50% overlap, the first tile is actually + # centered in between the first two strides + offsets = [(0, y), (0, x)] + for dim, do_tile, n_tiles, n, rem in [ + (1, do_tile_y, n_y_tiles, y, y_rem), + (2, do_tile_x, n_x_tiles, x, x_rem), + ]: + if do_tile: + repeats = ( + torch.ones(n_tiles, dtype=torch.int, device=torch_device) + * stride + ) + # add total of 2 additional strides + repeats[0] = 2 * stride + repeats[-1] = 2 * stride + output_size = (n_tiles + 2) * stride + + threshold = threshold.repeat_interleave( + repeats, dim=dim, output_size=output_size + ) + # drop the excess we gained from padding rem to whole stride + offset = (stride - rem) // 2 + offsets[dim - 1] = offset, n + offset + + # can't use slice(...) objects in jit code so use actual indices + (a, b), (c, d) = offsets + threshold = threshold[:, a:b, c:d] + + above_local = enhanced_planes_raw > threshold + above = torch.logical_and(above_global, above_local) + else: + above = above_global - above = enhanced_planes > threshold planes[above] = threshold_value # subsequent steps only care about the values that are set to threshold or # above in planes. We set values in *planes* to threshold based on the diff --git a/cellfinder/core/detect/filters/setup_filters.py b/cellfinder/core/detect/filters/setup_filters.py index 1cf13ee8..6f9d94bf 100644 --- a/cellfinder/core/detect/filters/setup_filters.py +++ b/cellfinder/core/detect/filters/setup_filters.py @@ -133,11 +133,31 @@ class DetectionSettings: n_sds_above_mean_thresh: float = 10 """ - Intensity threshold (the number of standard deviations above - the mean) of the filtered 2d planes used to mark pixels as + Per-plane intensity threshold (the number of standard deviations + above the mean) of the 2d filtered planes used to mark pixels as foreground or background. """ + n_sds_above_mean_tiled_thresh: float = 10 + """ + Per-plane, per-tile intensity threshold (the number of standard deviations + above the mean) for the filtered 2d planes used to mark pixels as + foreground or background. When used, (tile size is not zero) a pixel is + marked as foreground if its intensity is above both the per-plane and + per-tile threshold. I.e. it's above the set number of standard deviations + of the per-plane average and of the per-plane per-tile average for the tile + that contains it. + """ + + tiled_thresh_tile_size: float | None = None + """ + The tile size used to tile the x, y plane to calculate the local average + intensity for the tiled threshold. The value is multiplied by soma + diameter (i.e. 1 means one soma diameter). If zero or None, the tiled + threshold is disabled and only the per-plane threshold is used. Tiling is + done with 50% overlap when striding. + """ + outlier_keep: bool = False """Whether to keep outlier structures during detection.""" diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index 7ff3418b..0b0876cc 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -41,6 +41,8 @@ def main( split_ball_z_size: float = 15, split_ball_overlap_fraction: float = 0.8, n_splitting_iter: int = 10, + n_sds_above_mean_tiled_thresh: float = 10, + tiled_thresh_tile_size: float | None = None, *, detect_callback: Optional[Callable[[int], None]] = None, classify_callback: Optional[Callable[[int], None]] = None, @@ -93,8 +95,8 @@ def main( Gaussian filter width (as a fraction of soma diameter) used during 2d in-plane Laplacian of Gaussian filtering. n_sds_above_mean_thresh : float - Intensity threshold (the number of standard deviations above - the mean) of the filtered 2d planes used to mark pixels as + Per-plane intensity threshold (the number of standard deviations + above the mean) of the filtered 2d planes used to mark pixels as foreground or background. soma_spread_factor : float Cell spread factor for determining the largest cell volume before @@ -148,6 +150,20 @@ def main( The number of iterations to run the 3d filtering on a cluster. Each iteration reduces the cluster size by the voxels not retained in the previous iteration. + n_sds_above_mean_tiled_thresh : float + Per-plane, per-tile intensity threshold (the number of standard + deviations above the mean) for the filtered 2d planes used to mark + pixels as foreground or background. When used, (tile size is not zero) + a pixel is marked as foreground if its intensity is above both the + per-plane and per-tile threshold. I.e. it's above the set number of + standard deviations of the per-plane average and of the per-plane + per-tile average for the tile that contains it. + tiled_thresh_tile_size : float + The tile size used to tile the x, y plane to calculate the local + average intensity for the tiled threshold. The value is multiplied + by soma diameter (i.e. 1 means one soma diameter). If zero or None, the + tiled threshold is disabled and only the per-plane threshold is used. + Tiling is done with 50% overlap when striding. detect_callback : Callable[int], optional Called every time a plane has finished being processed during the detection stage. Called with the plane number that has finished. @@ -165,19 +181,21 @@ def main( logger.info("Detecting cell candidates") points = detect.main( - signal_array, - start_plane, - end_plane, - voxel_sizes, - soma_diameter, - max_cluster_size, - ball_xy_size, - ball_z_size, - ball_overlap_fraction, - soma_spread_factor, - n_free_cpus, - log_sigma_size, - n_sds_above_mean_thresh, + signal_array=signal_array, + start_plane=start_plane, + end_plane=end_plane, + voxel_sizes=voxel_sizes, + soma_diameter=soma_diameter, + max_cluster_size=max_cluster_size, + ball_xy_size=ball_xy_size, + ball_z_size=ball_z_size, + ball_overlap_fraction=ball_overlap_fraction, + soma_spread_factor=soma_spread_factor, + n_free_cpus=n_free_cpus, + log_sigma_size=log_sigma_size, + n_sds_above_mean_thresh=n_sds_above_mean_thresh, + n_sds_above_mean_tiled_thresh=n_sds_above_mean_tiled_thresh, + tiled_thresh_tile_size=tiled_thresh_tile_size, batch_size=detection_batch_size, torch_device=torch_device, callback=detect_callback, diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py index 2503d412..a42508a3 100644 --- a/cellfinder/napari/detect/detect.py +++ b/cellfinder/napari/detect/detect.py @@ -246,6 +246,8 @@ def widget( soma_diameter: float, log_sigma_size: float, n_sds_above_mean_thresh: float, + n_sds_above_mean_tiled_thresh: float, + tiled_thresh_tile_size: float, ball_xy_size: float, ball_z_size: float, ball_overlap_fraction: float, @@ -287,9 +289,23 @@ def widget( Gaussian filter width (as a fraction of soma diameter) used during 2d in-plane Laplacian of Gaussian filtering n_sds_above_mean_thresh : float - Intensity threshold (the number of standard deviations above - the mean) of the filtered 2d planes used to mark pixels as + Per-plane intensity threshold (the number of standard deviations + above the mean) of the filtered 2d planes used to mark pixels as foreground or background + n_sds_above_mean_tiled_thresh : float + Per-plane, per-tile intensity threshold (the number of standard + deviations above the mean) for the filtered 2d planes used to mark + pixels as foreground or background. When used, (tile size is not + zero) a pixel is marked as foreground if its intensity is above + both the per-plane and per-tile threshold. I.e. it's above the set + number of standard deviations of the per-plane average and of the + per-plane per-tile average for the tile that contains it. + tiled_thresh_tile_size : float + The tile size used to tile the x, y plane to calculate the local + average intensity for the tiled threshold. The value is multiplied + by soma diameter (i.e. 1 means one soma diameter). If zero, the + tiled threshold is disabled and only the per-plane threshold is + used. Tiling is done with 50% overlap when striding. ball_xy_size : float 3d filter's in-plane (xy) filter ball size (microns) ball_z_size : float @@ -389,6 +405,8 @@ def widget( ball_overlap_fraction, log_sigma_size, n_sds_above_mean_thresh, + n_sds_above_mean_tiled_thresh, + tiled_thresh_tile_size, soma_spread_factor, max_cluster_size, detection_batch_size, diff --git a/cellfinder/napari/detect/detect_containers.py b/cellfinder/napari/detect/detect_containers.py index 5a130853..3fdf9653 100644 --- a/cellfinder/napari/detect/detect_containers.py +++ b/cellfinder/napari/detect/detect_containers.py @@ -69,6 +69,8 @@ class DetectionInputs(InputContainer): ball_overlap_fraction: float = 0.6 log_sigma_size: float = 0.2 n_sds_above_mean_thresh: float = 10 + n_sds_above_mean_tiled_thresh: float = 10 + tiled_thresh_tile_size: float = 0 soma_spread_factor: float = 1.4 max_cluster_size: float = 100000 detection_batch_size: int = 1 @@ -95,7 +97,14 @@ def widget_representation(cls) -> dict: "log_sigma_size", custom_label="Filter width" ), n_sds_above_mean_thresh=cls._custom_widget( - "n_sds_above_mean_thresh", custom_label="Threshold" + "n_sds_above_mean_thresh", custom_label="Plane threshold" + ), + n_sds_above_mean_tiled_thresh=cls._custom_widget( + "n_sds_above_mean_tiled_thresh", custom_label="Tiled threshold" + ), + tiled_thresh_tile_size=cls._custom_widget( + "tiled_thresh_tile_size", + custom_label="Thresholding tile size", ), soma_spread_factor=cls._custom_widget( "soma_spread_factor", custom_label="Split cell spread" diff --git a/tests/core/test_unit/test_detect/test_filters/test_plane_filters/test_classical_filters.py b/tests/core/test_unit/test_detect/test_filters/test_plane_filters/test_classical_filters.py index 181c912c..4dba4fb9 100644 --- a/tests/core/test_unit/test_detect/test_filters/test_plane_filters/test_classical_filters.py +++ b/tests/core/test_unit/test_detect/test_filters/test_plane_filters/test_classical_filters.py @@ -140,6 +140,8 @@ def test_2d_filtering_parity( soma_diameter=soma_diameter, log_sigma_size=0.2, n_sds_above_mean_thresh=10, + n_sds_above_mean_tiled_thresh=10, + tiled_thresh_tile_size=0, torch_device=torch_device, dtype=settings.filtering_dtype.__name__, use_scipy=use_scipy, @@ -187,6 +189,8 @@ def test_2d_filter_padding(plane_size): soma_diameter=16, log_sigma_size=0.2, n_sds_above_mean_thresh=10, + n_sds_above_mean_tiled_thresh=10, + tiled_thresh_tile_size=0, torch_device="cpu", dtype=settings.filtering_dtype.__name__, use_scipy=False, @@ -243,3 +247,195 @@ def test_tile_walker_size(sizes, soma_diameter=5): data = torch.rand((1, *plane_size), dtype=torch.float32) tiles = walker.get_bright_tiles(data) assert tiles.shape == (1, *tile_size) + + +def get_filtered_data( + data: np.ndarray, + soma_diameter=16, + log_sigma_size=0.2, + n_sds_above_mean_thresh=10.0, + n_sds_above_mean_tiled_thresh=10.0, + tiled_thresh_tile_size=0.0, +) -> np.ndarray: + settings = DetectionSettings(plane_original_np_dtype=np.uint16) + data = data.astype(settings.filtering_dtype) + + tile_processor = TileProcessor( + plane_shape=data.shape[1:], + clipping_value=settings.clipping_value, + threshold_value=settings.threshold_value, + soma_diameter=soma_diameter, + log_sigma_size=log_sigma_size, + n_sds_above_mean_thresh=n_sds_above_mean_thresh, + n_sds_above_mean_tiled_thresh=n_sds_above_mean_tiled_thresh, + tiled_thresh_tile_size=tiled_thresh_tile_size, + torch_device="cpu", + dtype=settings.filtering_dtype.__name__, + use_scipy=True, + ) + + filtered, _ = tile_processor.get_tile_mask(torch.from_numpy(data)) + return (filtered == settings.threshold_value).numpy() + + +def test_2d_filter_plane_threshold_single_spot(): + # make bright area of 5x5 = 25 + data = np.zeros((1, 50, 50)) + data[0, 23:28, 23:28] = 10 + + # use normal threshold + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=1, + n_sds_above_mean_tiled_thresh=1, + tiled_thresh_tile_size=0, + ) + # about 25 pixels should be marked + assert 20 <= np.sum(filtered) <= 30 + + # use very high threshold + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=50, + n_sds_above_mean_tiled_thresh=1, + tiled_thresh_tile_size=0, + ) + # with high threshold, should be no marked pixels + assert not np.sum(filtered) + + +def test_2d_filter_plane_threshold_2_spots(): + # create 2 bright areas of 5x5 = 25px, one bright, one darker + data = np.zeros((1, 50, 50)) + data[0, 13:18, 13:18] = 5 + data[0, 33:38, 33:38] = 20 + + # low threshold should get both areas + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=0.1, + n_sds_above_mean_tiled_thresh=1, + tiled_thresh_tile_size=0, + ) + assert 40 <= np.sum(filtered) <= 60 + + # medium threshold should get very bright area + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=2, + n_sds_above_mean_tiled_thresh=1, + tiled_thresh_tile_size=0, + ) + assert 20 <= np.sum(filtered) <= 35 + + # high threshold should get no area + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=50, + n_sds_above_mean_tiled_thresh=1, + tiled_thresh_tile_size=0, + ) + assert not np.sum(filtered) + + +def test_2d_filter_tiled_threshold_2_spots(): + # create 2 bright areas of 5x5 = 25px, one bright, one darker + data = np.zeros((1, 50, 50)) + data[0, 3:8, 3:8] = 5 + data[0, 43:48, 43:48] = 20 + + # medium plane threshold should get only very bright area + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=2, + n_sds_above_mean_tiled_thresh=2, + tiled_thresh_tile_size=0, + ) + assert 20 <= np.sum(filtered) <= 35 + + # with small tiles (size of soma) the mean would be high for the tiles with + # both bright areas so we should get no pixels + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=2, + n_sds_above_mean_tiled_thresh=2, + tiled_thresh_tile_size=1, + ) + assert not np.sum(filtered) + + # but with a very low tiled threshold we should get same as with plane + # threshold only + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=2, + n_sds_above_mean_tiled_thresh=-2, + tiled_thresh_tile_size=1, + ) + assert 20 <= np.sum(filtered) <= 35 + + # and with a low plane threshold as well we should get everything + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=0, + n_sds_above_mean_tiled_thresh=-2, + tiled_thresh_tile_size=1, + ) + assert 40 <= np.sum(filtered) <= 60 + + +@pytest.mark.parametrize( + "shape", [(1, 50, 23), (1, 23, 50), (1, 25, 25), (1, 57, 57)] +) +def test_2d_filter_tiled_threshold_odd_shapes(shape): + # our tile size is 5 * 5 = 25, check that plane shapes that don't fit two + # tiles or are not multiple of tile size still works + # create bright area of 5x5 = 25px + data = np.zeros(shape) + data[0, 3:8, 3:8] = 5 + + # use tiles size of 25 (5 x soma diameter of 5) + filtered = get_filtered_data( + data, + soma_diameter=5, + log_sigma_size=0.2, + n_sds_above_mean_thresh=1, + n_sds_above_mean_tiled_thresh=1, + tiled_thresh_tile_size=5, + ) + # about 25 pixels should be marked + assert 20 <= np.sum(filtered) <= 30 + + +@pytest.mark.parametrize("size", [0, 1, 2, 3]) +def test_2d_filter_tiled_threshold_odd_tile_size(size): + # check that tiny tile sizes works. + data = np.zeros((1, 10, 10)) + + # use tiles size of 25 (5 x soma diameter of 5) + filtered = get_filtered_data( + data, + soma_diameter=1, + log_sigma_size=0.2, + n_sds_above_mean_thresh=1, + n_sds_above_mean_tiled_thresh=1, + tiled_thresh_tile_size=size, + ) + assert filtered.shape == (1, 10, 10)