Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmarks/filter_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def setup_filter(
soma_diameter=settings.soma_diameter,
log_sigma_size=settings.log_sigma_size,
n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh,
n_sds_above_mean_tiled_thresh=settings.n_sds_above_mean_tiled_thresh,
tiled_thresh_tile_size=settings.tiled_thresh_tile_size,
torch_device=torch_device,
dtype=settings.filtering_dtype.__name__,
use_scipy=use_scipy,
Expand Down
26 changes: 23 additions & 3 deletions cellfinder/core/detect/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def main(
split_ball_z_size: float = 15,
split_ball_overlap_fraction: float = 0.8,
n_splitting_iter: int = 10,
n_sds_above_mean_tiled_thresh: float = 10,
tiled_thresh_tile_size: float | None = None,
*,
callback: Optional[Callable[[int], None]] = None,
) -> List[Cell]:
Expand Down Expand Up @@ -95,8 +97,8 @@ def main(
Gaussian filter width (as a fraction of soma diameter) used during
2d in-plane Laplacian of Gaussian filtering.
n_sds_above_mean_thresh : float
Intensity threshold (the number of standard deviations above
the mean) of the filtered 2d planes used to mark pixels as
Per-plane intensity threshold (the number of standard deviations
above the mean) of the filtered 2d planes used to mark pixels as
foreground or background.
outlier_keep : bool, optional
Whether to keep outliers during detection. Defaults to False.
Expand Down Expand Up @@ -129,6 +131,20 @@ def main(
The number of iterations to run the 3d filtering on a cluster. Each
iteration reduces the cluster size by the voxels not retained in
the previous iteration.
n_sds_above_mean_tiled_thresh : float
Per-plane, per-tile intensity threshold (the number of standard
deviations above the mean) for the filtered 2d planes used to mark
pixels as foreground or background. When used, (tile size is not zero)
a pixel is marked as foreground if its intensity is above both the
per-plane and per-tile threshold. I.e. it's above the set number of
standard deviations of the per-plane average and of the per-plane
per-tile average for the tile that contains it.
tiled_thresh_tile_size : float
The tile size used to tile the x, y plane to calculate the local
average intensity for the tiled threshold. The value is multiplied
by soma diameter (i.e. 1 means one soma diameter). If zero or None, the
tiled threshold is disabled and only the per-plane threshold is used.
Tiling is done with 50% overlap when striding.
callback : Callable[int], optional
A callback function that is called every time a plane has finished
being processed. Called with the plane number that has finished.
Expand Down Expand Up @@ -186,6 +202,8 @@ def main(
ball_overlap_fraction=ball_overlap_fraction,
log_sigma_size=log_sigma_size,
n_sds_above_mean_thresh=n_sds_above_mean_thresh,
n_sds_above_mean_tiled_thresh=n_sds_above_mean_tiled_thresh,
tiled_thresh_tile_size=tiled_thresh_tile_size,
outlier_keep=outlier_keep,
artifact_keep=artifact_keep,
save_planes=save_planes,
Expand Down Expand Up @@ -220,7 +238,9 @@ def main(
plane_shape=settings.plane_shape,
clipping_value=settings.clipping_value,
threshold_value=settings.threshold_value,
n_sds_above_mean_thresh=n_sds_above_mean_thresh,
n_sds_above_mean_thresh=settings.n_sds_above_mean_thresh,
n_sds_above_mean_tiled_thresh=settings.n_sds_above_mean_tiled_thresh,
tiled_thresh_tile_size=settings.tiled_thresh_tile_size,
log_sigma_size=log_sigma_size,
soma_diameter=settings.soma_diameter,
torch_device=torch_device,
Expand Down
115 changes: 106 additions & 9 deletions cellfinder/core/detect/filters/plane/plane_filter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from dataclasses import dataclass, field
from typing import Tuple

import torch
import torch.nn.functional as F

from cellfinder.core.detect.filters.plane.classical_filter import PeakEnhancer
from cellfinder.core.detect.filters.plane.tile_walker import TileWalker


@dataclass
class TileProcessor:
"""
Processor that filters each plane to highlight the peaks and also
Expand Down Expand Up @@ -63,19 +62,31 @@
# voxels who are this many std above mean or more are set to
# threshold_value
n_sds_above_mean_thresh: float
# If used, voxels who are this many or more std above mean of the
# containing tile as well as above n_sds_above_mean_thresh for the plane
# average are set to threshold_value.
n_sds_above_mean_tiled_thresh: float
# the tile size, in pixels, that will be used to tile the x, y plane when
# we calculate the per-tile mean / std for use with
# n_sds_above_mean_tiled_thresh. We use 50% overlap when tiling.
local_threshold_tile_size_px: int = 0
# the torch device name
torch_device: str = ""

# filter that finds the peaks in the planes
peak_enhancer: PeakEnhancer = field(init=False)
peak_enhancer: PeakEnhancer = None
# generates tiles of the planes, with each tile marked as being inside
# or outside the brain based on brightness
tile_walker: TileWalker = field(init=False)
tile_walker: TileWalker = None

def __init__(
self,
plane_shape: Tuple[int, int],
clipping_value: int,
threshold_value: int,
n_sds_above_mean_thresh: float,
n_sds_above_mean_tiled_thresh: float,
tiled_thresh_tile_size: float | None,
log_sigma_size: float,
soma_diameter: int,
torch_device: str,
Expand All @@ -85,6 +96,12 @@
self.clipping_value = clipping_value
self.threshold_value = threshold_value
self.n_sds_above_mean_thresh = n_sds_above_mean_thresh
self.n_sds_above_mean_tiled_thresh = n_sds_above_mean_tiled_thresh
if tiled_thresh_tile_size:
self.local_threshold_tile_size_px = int(
round(soma_diameter * tiled_thresh_tile_size)
)
self.torch_device = torch_device

laplace_gaussian_sigma = log_sigma_size * soma_diameter
self.peak_enhancer = PeakEnhancer(
Expand Down Expand Up @@ -131,7 +148,10 @@
planes,
enhanced_planes,
self.n_sds_above_mean_thresh,
self.n_sds_above_mean_tiled_thresh,
self.local_threshold_tile_size_px,
self.threshold_value,
self.torch_device,
)

return planes, inside_brain_tiles
Expand All @@ -145,21 +165,98 @@
planes: torch.Tensor,
enhanced_planes: torch.Tensor,
n_sds_above_mean_thresh: float,
n_sds_above_mean_tiled_thresh: float,
local_threshold_tile_size_px: int,
threshold_value: int,
torch_device: str,
) -> None:
"""
Sets each plane (in-place) to threshold_value, where the corresponding
enhanced_plane > mean + n_sds_above_mean_thresh*std. Each plane will be
set to zero elsewhere.
"""
planes_1d = enhanced_planes.view(enhanced_planes.shape[0], -1)
z, y, x = enhanced_planes.shape

Check warning on line 178 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L178

Added line #L178 was not covered by tests

# ---- get per-plane global threshold ----
planes_1d = enhanced_planes.view(z, -1)

Check warning on line 181 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L181

Added line #L181 was not covered by tests
# add back last dim
avg = torch.mean(planes_1d, dim=1, keepdim=True).unsqueeze(2)
sd = torch.std(planes_1d, dim=1, keepdim=True).unsqueeze(2)
threshold = avg + n_sds_above_mean_thresh * sd
std, mean = torch.std_mean(planes_1d, dim=1, keepdim=True)
threshold = mean.unsqueeze(2) + n_sds_above_mean_thresh * std.unsqueeze(2)
above_global = enhanced_planes > threshold

Check warning on line 185 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L183-L185

Added lines #L183 - L185 were not covered by tests

# ---- calculate the local tiled threshold ----
# we do 50% overlap so there's no jumps at boundaries
stride = local_threshold_tile_size_px // 2

Check warning on line 189 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L189

Added line #L189 was not covered by tests
# make tile even for ease of computation
tile_size = stride * 2

Check warning on line 191 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L191

Added line #L191 was not covered by tests
# Due to 50% overlap, to get tiles we move the tile by half tile (stride).
# Total moves will be y // stride - 2 (we start already with mask on first
# tile). So add back 1 for the first tile. Partial tiles are dropped
n_y_tiles = max(y // stride - 1, 1) if stride else 1
n_x_tiles = max(x // stride - 1, 1) if stride else 1
do_tile_y = n_y_tiles >= 2
do_tile_x = n_x_tiles >= 2

Check warning on line 198 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L195-L198

Added lines #L195 - L198 were not covered by tests
# we want at least one axis to have at least two tiles
if local_threshold_tile_size_px >= 2 and (do_tile_y or do_tile_x):

Check warning on line 200 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L200

Added line #L200 was not covered by tests
# num edge pixels dropped b/c moving by stride would move tile off edge
y_rem = y % stride
x_rem = x % stride
enhanced_planes_raw = enhanced_planes
if do_tile_y:
enhanced_planes = enhanced_planes[:, y_rem // 2 :, :]
if do_tile_x:
enhanced_planes = enhanced_planes[:, :, x_rem // 2 :]

Check warning on line 208 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L202-L208

Added lines #L202 - L208 were not covered by tests

# add empty channel dim after z "batch" dim -> zcyx
enhanced_planes = enhanced_planes.unsqueeze(1)

Check warning on line 211 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L211

Added line #L211 was not covered by tests
# unfold makes it 3 dim, z, M, L. L is number of tiles, M is tile area
unfolded = F.unfold(

Check warning on line 213 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L213

Added line #L213 was not covered by tests
enhanced_planes,
(tile_size if do_tile_y else y, tile_size if do_tile_x else x),
stride=stride,
)
# average the tile areas, for each tile
std, mean = torch.std_mean(unfolded, dim=1, keepdim=True)
threshold = mean + n_sds_above_mean_tiled_thresh * std

Check warning on line 220 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L219-L220

Added lines #L219 - L220 were not covered by tests

# reshape it back into Y by X tiles, instead of YX being one dim
threshold = threshold.reshape((z, n_y_tiles, n_x_tiles))

Check warning on line 223 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L223

Added line #L223 was not covered by tests

# we need total size of n_tiles * stride + stride + rem for the
# original size. So we add 2 strides and then chop off the excess above
# rem. We center it because of 50% overlap, the first tile is actually
# centered in between the first two strides
offsets = [(0, y), (0, x)]
for dim, do_tile, n_tiles, n, rem in [

Check warning on line 230 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L229-L230

Added lines #L229 - L230 were not covered by tests
(1, do_tile_y, n_y_tiles, y, y_rem),
(2, do_tile_x, n_x_tiles, x, x_rem),
]:
if do_tile:
repeats = (

Check warning on line 235 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L234-L235

Added lines #L234 - L235 were not covered by tests
torch.ones(n_tiles, dtype=torch.int, device=torch_device)
* stride
)
# add total of 2 additional strides
repeats[0] = 2 * stride
repeats[-1] = 2 * stride
output_size = (n_tiles + 2) * stride

Check warning on line 242 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L240-L242

Added lines #L240 - L242 were not covered by tests

threshold = threshold.repeat_interleave(

Check warning on line 244 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L244

Added line #L244 was not covered by tests
repeats, dim=dim, output_size=output_size
)
# drop the excess we gained from padding rem to whole stride
offset = (stride - rem) // 2
offsets[dim - 1] = offset, n + offset

Check warning on line 249 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L248-L249

Added lines #L248 - L249 were not covered by tests

# can't use slice(...) objects in jit code so use actual indices
(a, b), (c, d) = offsets
threshold = threshold[:, a:b, c:d]

Check warning on line 253 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L252-L253

Added lines #L252 - L253 were not covered by tests

above_local = enhanced_planes_raw > threshold
above = torch.logical_and(above_global, above_local)

Check warning on line 256 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L255-L256

Added lines #L255 - L256 were not covered by tests
else:
above = above_global

Check warning on line 258 in cellfinder/core/detect/filters/plane/plane_filter.py

View check run for this annotation

Codecov / codecov/patch

cellfinder/core/detect/filters/plane/plane_filter.py#L258

Added line #L258 was not covered by tests

above = enhanced_planes > threshold
planes[above] = threshold_value
# subsequent steps only care about the values that are set to threshold or
# above in planes. We set values in *planes* to threshold based on the
Expand Down
24 changes: 22 additions & 2 deletions cellfinder/core/detect/filters/setup_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,31 @@ class DetectionSettings:

n_sds_above_mean_thresh: float = 10
"""
Intensity threshold (the number of standard deviations above
the mean) of the filtered 2d planes used to mark pixels as
Per-plane intensity threshold (the number of standard deviations
above the mean) of the 2d filtered planes used to mark pixels as
foreground or background.
"""

n_sds_above_mean_tiled_thresh: float = 10
"""
Per-plane, per-tile intensity threshold (the number of standard deviations
above the mean) for the filtered 2d planes used to mark pixels as
foreground or background. When used, (tile size is not zero) a pixel is
marked as foreground if its intensity is above both the per-plane and
per-tile threshold. I.e. it's above the set number of standard deviations
of the per-plane average and of the per-plane per-tile average for the tile
that contains it.
"""

tiled_thresh_tile_size: float | None = None
"""
The tile size used to tile the x, y plane to calculate the local average
intensity for the tiled threshold. The value is multiplied by soma
diameter (i.e. 1 means one soma diameter). If zero or None, the tiled
threshold is disabled and only the per-plane threshold is used. Tiling is
done with 50% overlap when striding.
"""

outlier_keep: bool = False
"""Whether to keep outlier structures during detection."""

Expand Down
48 changes: 33 additions & 15 deletions cellfinder/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def main(
split_ball_z_size: float = 15,
split_ball_overlap_fraction: float = 0.8,
n_splitting_iter: int = 10,
n_sds_above_mean_tiled_thresh: float = 10,
tiled_thresh_tile_size: float | None = None,
*,
detect_callback: Optional[Callable[[int], None]] = None,
classify_callback: Optional[Callable[[int], None]] = None,
Expand Down Expand Up @@ -93,8 +95,8 @@ def main(
Gaussian filter width (as a fraction of soma diameter) used during
2d in-plane Laplacian of Gaussian filtering.
n_sds_above_mean_thresh : float
Intensity threshold (the number of standard deviations above
the mean) of the filtered 2d planes used to mark pixels as
Per-plane intensity threshold (the number of standard deviations
above the mean) of the filtered 2d planes used to mark pixels as
foreground or background.
soma_spread_factor : float
Cell spread factor for determining the largest cell volume before
Expand Down Expand Up @@ -148,6 +150,20 @@ def main(
The number of iterations to run the 3d filtering on a cluster. Each
iteration reduces the cluster size by the voxels not retained in
the previous iteration.
n_sds_above_mean_tiled_thresh : float
Per-plane, per-tile intensity threshold (the number of standard
deviations above the mean) for the filtered 2d planes used to mark
pixels as foreground or background. When used, (tile size is not zero)
a pixel is marked as foreground if its intensity is above both the
per-plane and per-tile threshold. I.e. it's above the set number of
standard deviations of the per-plane average and of the per-plane
per-tile average for the tile that contains it.
tiled_thresh_tile_size : float
The tile size used to tile the x, y plane to calculate the local
average intensity for the tiled threshold. The value is multiplied
by soma diameter (i.e. 1 means one soma diameter). If zero or None, the
tiled threshold is disabled and only the per-plane threshold is used.
Tiling is done with 50% overlap when striding.
detect_callback : Callable[int], optional
Called every time a plane has finished being processed during the
detection stage. Called with the plane number that has finished.
Expand All @@ -165,19 +181,21 @@ def main(
logger.info("Detecting cell candidates")

points = detect.main(
signal_array,
start_plane,
end_plane,
voxel_sizes,
soma_diameter,
max_cluster_size,
ball_xy_size,
ball_z_size,
ball_overlap_fraction,
soma_spread_factor,
n_free_cpus,
log_sigma_size,
n_sds_above_mean_thresh,
signal_array=signal_array,
start_plane=start_plane,
end_plane=end_plane,
voxel_sizes=voxel_sizes,
soma_diameter=soma_diameter,
max_cluster_size=max_cluster_size,
ball_xy_size=ball_xy_size,
ball_z_size=ball_z_size,
ball_overlap_fraction=ball_overlap_fraction,
soma_spread_factor=soma_spread_factor,
n_free_cpus=n_free_cpus,
log_sigma_size=log_sigma_size,
n_sds_above_mean_thresh=n_sds_above_mean_thresh,
n_sds_above_mean_tiled_thresh=n_sds_above_mean_tiled_thresh,
tiled_thresh_tile_size=tiled_thresh_tile_size,
batch_size=detection_batch_size,
torch_device=torch_device,
callback=detect_callback,
Expand Down
Loading
Loading