diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 2a4e5828..3110eaed 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -185,6 +185,7 @@ jobs: python -m pip install --upgrade pip wheel # Install cellfinder from the latest SHA on this branch python -m pip install "cellfinder @ git+$GITHUB_SERVER_URL/$GITHUB_REPOSITORY@$GITHUB_SHA" + python -m pip install monai --no-deps # Install checked out copy of brainglobe-workflows python -m pip install .[dev] diff --git a/cellfinder/core/classify/augment.py b/cellfinder/core/classify/augment.py index 44a064c9..6643d1b7 100644 --- a/cellfinder/core/classify/augment.py +++ b/cellfinder/core/classify/augment.py @@ -1,199 +1,249 @@ -from typing import List, Tuple - -import numpy as np -from scipy.ndimage import rotate, zoom - -from cellfinder.core.tools.tools import ( - all_elements_equal, - random_bool, - random_probability, - random_sign, -) - -all_axes = np.array((0, 1, 2)) - - -def augment( - augmentation_parameters: "AugmentationParameters", - image: np.ndarray, - scale_back: bool = True, -) -> np.ndarray: - pixel_sizes = image.shape - min_pixel_size = min(pixel_sizes) - relative_pixel_sizes = [] - for pixel_size in pixel_sizes: - relative_pixel_sizes.append(pixel_size / min_pixel_size) - - image, normalised_pixel_sizes = rescale_to_isotropic( - image, - relative_pixel_sizes, - augmentation_parameters.interpolation_order, - ) - # TODO: is this a sensible order? - if augmentation_parameters.flip_axis is not None: - image = flip_image(image, augmentation_parameters.axes_to_flip) - - if augmentation_parameters.translate is not None: - image = translate_image( - image, - augmentation_parameters.translate_axes, - augmentation_parameters.random_translate_multipliers, - ) +from typing import Literal, Sequence - # if augmentation_parameters.scale is not None: - # image = scale_image(image, augmentation_parameters.scale) +import torch +import torch.nn.functional as F +from monai.transforms import RandAffine - if augmentation_parameters.rotate_max_axes is not None: - image = rotate_image(image, augmentation_parameters.rotation_angles) +from cellfinder.core.tools.tools import get_axis_reordering, random_bool - if scale_back: - image = rescale_to_original_size( - image, - relative_pixel_sizes, - normalised_pixel_sizes, - augmentation_parameters.interpolation_order, - ) - return image - - -def rescale_to_isotropic( - image: np.ndarray, - relative_pixel_sizes: List[float], - interpolation_order: int, -) -> Tuple[np.ndarray, List[float]]: - if not all_elements_equal(relative_pixel_sizes): - min_pixel_size = min(relative_pixel_sizes) - normalised_pixel_sizes = [] - for pixel_size in relative_pixel_sizes: - normalised_pixel_sizes.append( - round(pixel_size / min_pixel_size, 2) - ) +DIM = Literal["x", "y", "z", "c"] +AXIS = Literal["x", "y", "z"] +RandRange = Sequence[float] | Sequence[tuple[float, float]] | None - image = zoom(image, normalised_pixel_sizes, order=interpolation_order) - else: - normalised_pixel_sizes = relative_pixel_sizes - return image, normalised_pixel_sizes - - -def rescale_to_original_size( - image: np.ndarray, - relative_pixel_sizes: List[float], - normalised_pixel_sizes: List[float], - interpolation_order: int, -) -> np.ndarray: - if not all_elements_equal(relative_pixel_sizes): - inverse_pixel_sizes = [] - for pixel_size in normalised_pixel_sizes: - inverse_pixel_sizes.append(round(1 / pixel_size, 2)) - - image = zoom(image, inverse_pixel_sizes, 
order=interpolation_order) - return image - - -def flip_image(image: np.ndarray, axes_to_flip: List[int]) -> np.ndarray: - for axis in axes_to_flip: - image = np.flip(image, axis) - return image - - -def translate_image( - image: np.ndarray, - translate_axes: List[int], - random_translate_multipliers: List[float], -) -> np.ndarray: - pixel_shifts = [] - for idx, axis in enumerate(translate_axes): - pixel_shifts.append( - int(round(random_translate_multipliers[idx] * image.shape[axis])) - ) - image = np.roll(image, pixel_shifts, axis=translate_axes) - return image +class DataAugmentation: + """ + Randomly augments the input data when called. + Typical example:: -def rotate_image( - image: np.ndarray, rotation_angles: List[float] -) -> np.ndarray: - for axis, angle in enumerate(rotation_angles): - if angle != 0: - rotate_axes = all_axes[all_axes != axis] - image = rotate( - image, angle, axes=rotate_axes, reshape=False, mode="constant" - ) - return image + augmenter = DataAugmentation(...) + augmented_data = augmenter(data) + Parameters + ---------- + volume_size : dict + Dict whose keys are x, y, and z and whose values are the size of the + input data at the given dimension. + """ -# def scale_image(image, scale, ndigits=2): -# scale_factor = round( -# uniform(scale[0], scale[1]), ndigits=ndigits -# ) -# return image + DIM_ORDER = "c", "y", "x", "z" + """ + The dimension order we internally expect the data to be. The user passes + in data in `data_dim_order` order that we convert to this order before + handling. + """ -# def shear_image(image): -# return image + AXIS_ORDER = "y", "x", "z" + """ + Similar to `DIM_ORDER`, except it's the order of the 3 axes of the cuboid. + """ - -class AugmentationParameters: # precomputed, so both channels are treated identically def __init__( self, - flip_axis: Tuple[int, int, int], - translate: Tuple[float, float, float], - rotate_max_axes: Tuple[float, float, float], - interpolation_order: int, + volume_size: dict[str, int], + data_dim_order: tuple[DIM, DIM, DIM, DIM], augment_likelihood: float, + flippable_axis: Sequence[int] = (), + translate_range: RandRange = None, + scale_range: RandRange = None, + rotate_range: RandRange = None, ): - # this is a clumsy way of passing parameters to the augment function - self.flip_axis = flip_axis - self.translate = translate - self.rotate_max_axes = rotate_max_axes - self.interpolation_order = interpolation_order + volume_values = list(volume_size.values()) + self.needs_affine = translate_range or scale_range or rotate_range + self.needs_isotropy = ( + max(volume_values) != min(volume_values) and self.needs_affine + ) + self.isotropic_volume_size = (max(volume_values),) * 3 + + self._volume_size = [volume_size[ax] for ax in self.AXIS_ORDER] + + self._input_axis_order = [d for d in data_dim_order if d != "c"] + self._data_reordering = [] + self._data_reordering_back = [] + self._compute_data_reordering(data_dim_order) + + # do prob = 1 because we decide when to apply it + self.affine = RandAffine( + prob=1, + rotate_range=self._fix_rotate_range(rotate_range), + shear_range=None, + translate_range=self._fix_translate_range(translate_range), + scale_range=self._fix_scale_range(scale_range), + cache_grid=True, + spatial_size=self.isotropic_volume_size, + lazy=False, + ) + self._flippable_axis = self._fix_flippable_axis(flippable_axis) self.augment_likelihood = augment_likelihood - self.axes_to_flip: List[int] = [] - self.translate_axes: List[int] = [] - self.random_translate_multipliers: List[float] = [] - 
self.rotation_angles: List[float] = [] - - if flip_axis: - self.get_flip_parameters(flip_axis) - if translate: - self.get_translation_parameters(translate) - if rotate_max_axes: - self.get_rotation_parameters(rotate_max_axes) - - def get_flip_parameters(self, flip_axis: Tuple[int, int, int]) -> None: - self.axes_to_flip = [] - for axis in all_axes: - if axis in flip_axis: - if random_bool(likelihood=self.augment_likelihood): - self.axes_to_flip.append(axis) - - def get_translation_parameters( - self, translate: Tuple[float, float, float] - ) -> None: - self.translate_axes = [] - self.random_translate_multipliers = [] - for axis, translate_mag in enumerate(translate): - if translate_mag > 0: - if random_bool(likelihood=self.augment_likelihood): - self.translate_axes.append(axis) - self.random_translate_multipliers.append( - random_sign() * random_probability() * translate_mag - ) - - def get_rotation_parameters( - self, rotate_max_axes: Tuple[float, float, float] + self.axes_to_flip: list[int] = [] + self.do_affine = False + + def _compute_data_reordering( + self, data_dim_order: tuple[DIM, DIM, DIM, DIM] ) -> None: - self.rotation_angles = [] - for max_rotation in rotate_max_axes: - if random_bool(likelihood=self.augment_likelihood): - angle = int( - round( - -max_rotation + 2 * random_probability() * max_rotation - ) - ) + self._data_reordering = [] + self._data_reordering_back = [] + + if data_dim_order != self.DIM_ORDER: + self._data_reordering = get_axis_reordering( + data_dim_order, + self.DIM_ORDER, + ) + self._data_reordering_back = get_axis_reordering( + self.DIM_ORDER, + data_dim_order, + ) + + def _fix_flippable_axis( + self, flippable_axis: Sequence[int] + ) -> Sequence[int]: + if not flippable_axis: + return flippable_axis + if self._input_axis_order == self.AXIS_ORDER: + return flippable_axis + + fixed_axis = [] + for ax_i in flippable_axis: + ax = self._input_axis_order[ax_i] + fixed_axis.append(self.AXIS_ORDER.index(ax)) + return fixed_axis + + def _fix_rotate_range(self, rotate_range: RandRange) -> RandRange: + if rotate_range is None: + return None + if len(rotate_range) != 3: + raise ValueError("Must specify rotate value for each dimension") + if self._input_axis_order == self.AXIS_ORDER: + return rotate_range + + fixed_range = [None, None, None] + for i in range(3): + new_i = self.AXIS_ORDER.index(self._input_axis_order[i]) + fixed_range[new_i] = rotate_range[i] + return fixed_range + + def _fix_translate_range(self, translate_range: RandRange) -> RandRange: + if translate_range is None: + return None + if len(translate_range) != 3: + raise ValueError("Must specify translate value for each dimension") + + translate_range = list(translate_range) + for i, (val, size) in enumerate( + zip(translate_range, self.isotropic_volume_size) + ): + # we expect the values as fraction of the size of the volume in + # the given dim. monai expects it as pixel offsets so we need + # to multiply by dim size. 
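+ # (e.g. a fraction of 0.05 on a 40-voxel axis becomes a ±2 voxel offset)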
Also, it does negative translation + if isinstance(val, Sequence): + translate_range[i] = -val[0] * size, -val[1] * size + else: + translate_range[i] = -val * size + + if self._input_axis_order == self.AXIS_ORDER: + return translate_range + + fixed_range = [None, None, None] + for i in range(3): + new_i = self.AXIS_ORDER.index(self._input_axis_order[i]) + fixed_range[new_i] = translate_range[i] + return fixed_range + + def _fix_scale_range(self, scale_range: RandRange) -> RandRange: + if scale_range is None: + return None + if len(scale_range) != 3: + raise ValueError("Must specify scale value for each dimension") + + scale_range = list(scale_range) + for i, val in enumerate(scale_range): + # we get scale values where 1 means original size. monai + # expects values around 0, where 0 means original size + if isinstance(val, Sequence): + scale_range[i] = 1 / val[0] - 1, 1 / val[1] - 1 else: - angle = 0 - self.rotation_angles.append(angle) + scale_range[i] = 1 / val - 1 + + if self._input_axis_order == self.AXIS_ORDER: + return scale_range + + fixed_range = [None, None, None] + for i in range(3): + new_i = self.AXIS_ORDER.index(self._input_axis_order[i]) + fixed_range[new_i] = scale_range[i] + return fixed_range + + def update_parameters(self) -> bool: + self.do_affine = False + if self.needs_affine: + self.do_affine = random_bool( + likelihood=1 - self.augment_likelihood + ) + self.update_flip_parameters() + + return bool(self.do_affine or self.axes_to_flip) + + def update_flip_parameters(self) -> None: + flippable_axis = self._flippable_axis + if not flippable_axis: + return + + axes_to_flip = self.axes_to_flip = [] + for axis in flippable_axis: + if random_bool(likelihood=1 - self.augment_likelihood): + # add 1 because of initial channel dim + axes_to_flip.append(axis + 1) + + def rescale_to_isotropic(self, data: torch.Tensor) -> torch.Tensor: + if not self.needs_isotropy: + return data + + # needs batch dim + data = data.unsqueeze(0) + data = F.interpolate( + data, size=self.isotropic_volume_size, mode="trilinear" + ) + data = data.squeeze(0) + return data + + def rescale_to_original(self, data: torch.Tensor) -> torch.Tensor: + if not self.needs_isotropy: + return data + + # needs batch dim + data = data.unsqueeze(0) + data = F.interpolate(data, size=self._volume_size, mode="trilinear") + data = data.squeeze(0) + return data + + def apply_affine(self, data: torch.Tensor) -> torch.Tensor: + if not self.do_affine: + return data + + return self.affine(data, padding_mode="border") + + def flip_axis(self, data: torch.Tensor) -> torch.Tensor: + if not self.axes_to_flip: + return data + + return torch.flip(data, self.axes_to_flip) + + def __call__(self, data: torch.Tensor) -> torch.Tensor: + if self._data_reordering: + data = torch.permute(data, self._data_reordering) + + data = self.rescale_to_isotropic(data) + + data = self.apply_affine(data) + data = self.flip_axis(data) + + data = self.rescale_to_original(data) + + if self._data_reordering_back: + data = torch.permute(data, self._data_reordering_back) + + return data diff --git a/cellfinder/core/classify/classify.py b/cellfinder/core/classify/classify.py index 37fc06cc..2845856c 100644 --- a/cellfinder/core/classify/classify.py +++ b/cellfinder/core/classify/classify.py @@ -1,16 +1,22 @@ import os +from copy import deepcopy from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple import keras -import numpy as np +import torch from brainglobe_utils.cells.cells import Cell from 
brainglobe_utils.general.system import get_num_processes +from torch.utils.data import DataLoader from cellfinder.core import logger, types -from cellfinder.core.classify.cube_generator import CubeGeneratorFromFile +from cellfinder.core.classify.cube_generator import ( + CuboidBatchSampler, + CuboidStackDataset, +) from cellfinder.core.classify.tools import get_model +from cellfinder.core.tools.image_processing import dataset_mean_std from cellfinder.core.train.train_yaml import depth_type, models @@ -28,9 +34,12 @@ def main( trained_model: Optional[os.PathLike], model_weights: Optional[os.PathLike], network_depth: depth_type, - max_workers: int = 3, + max_workers: int = 6, + pin_memory: bool = True, *, callback: Optional[Callable[[int], None]] = None, + normalize_channels: bool = False, + normalization_down_sampling: int = 32, ) -> List[Cell]: """ Parameters @@ -72,11 +81,19 @@ def main( network_depth: str The network depth to use during classification. Defaults to `"50"`. max_workers: int - The number of sub-processes to use for data loading / processing. + The max number of sub-processes to use for data loading / processing. Defaults to 6. callback : Callable[int], optional A callback function that is called during classification. Called with the batch number once that batch has been classified. + normalize_channels : bool + If True, the signal and background data will each be normalized + to a mean of zero and a standard deviation of 1. Defaults to False. + normalization_down_sampling : int + If `normalize_channels` is True, the data arrays will be down-sampled + in the first axis by this value before calculating their statistics. + E.g. a value of 2 means every second plane will be used. Defaults to + 32. """ if signal_array.ndim != 3: raise IOError("Signal data must be 3D") @@ -90,22 +107,52 @@ def main( # Too many workers doesn't increase speed, and uses huge amounts of RAM workers = get_num_processes(min_free_cpu_cores=n_free_cpus) + workers = min(workers, max_workers) start_time = datetime.now() + voxel_sizes = list(map(float, voxel_sizes)) + + signal_normalization = background_normalization = None + if normalize_channels: + logger.debug("Calculating channel norms") + signal_normalization = dataset_mean_std( + signal_array, normalization_down_sampling + ) + background_normalization = dataset_mean_std( + background_array, normalization_down_sampling + ) + logger.debug( + f"Signal channel norm is: {signal_normalization}.
" + f"Background channel norm is: {background_normalization}" + ) + logger.debug("Initialising cube generator") - inference_generator = CubeGeneratorFromFile( - points, - signal_array, - background_array, - voxel_sizes, - network_voxel_sizes, + dataset = CuboidStackDataset( + signal_array=signal_array, + background_array=background_array, + signal_normalization=signal_normalization, + background_normalization=background_normalization, + points=points, + data_voxel_sizes=voxel_sizes, + network_voxel_sizes=network_voxel_sizes, + network_cuboid_voxels=(cube_depth, cube_height, cube_width), + axis_order=("z", "y", "x"), + max_axis_0_cuboids_buffered=1, + ) + # we use our own sampler so we can control the ordering + sampler = CuboidBatchSampler( + dataset=dataset, batch_size=batch_size, - cube_width=cube_width, - cube_height=cube_height, - cube_depth=cube_depth, - use_multiprocessing=False, - workers=workers, + sort_by_axis="z", + auto_shuffle=False, + ) + data_loader = DataLoader( + dataset=dataset, + sampler=sampler, + batch_size=None, + num_workers=workers, + pin_memory=pin_memory, ) if trained_model and Path(trained_model).suffix == ".h5": @@ -125,25 +172,39 @@ def main( logger.info("Running inference") # in Keras 3.0 multiprocessing params are specified in the generator - predictions = model.predict( - inference_generator, - verbose=True, - callbacks=callbacks, - ) - predictions = predictions.round() - predictions = predictions.astype("uint16") + if workers: + dataset.start_dataset_thread(workers) + try: + predictions = model.predict( + data_loader, + verbose=True, + callbacks=callbacks, + ) + finally: + dataset.stop_dataset_thread() - predictions = np.argmax(predictions, axis=1) + predictions = torch.argmax(torch.from_numpy(predictions), dim=1) points_list = [] # only go through the "extractable" points - for idx, cell in enumerate(inference_generator.ordered_points): - cell.type = predictions[idx] + 1 - points_list.append(cell) + k = 0 + # the sampler doesn't auto shuffle, so the classification order (i.e. order + # in `predictions`) is the same order as the sampler returns the batches. 
+ # Use that to get the corresponding row in points_arr, which gives us the + # `index` of the original point in the input points list + for arr in sampler: + for i in arr: + p_idx = int(dataset.points_arr[i, 4].item()) + # don't use the original cell, use a copy + cell = deepcopy(points[p_idx]) + cell.type = int((predictions[k] + 1).item()) + points_list.append(cell) + k += 1 time_elapsed = datetime.now() - start_time logger.info( - "Classfication complete - all points done in : {}".format(time_elapsed) + f"Classification complete - {len(points_list)} points " + f"done in : {time_elapsed}" ) return points_list diff --git a/cellfinder/core/classify/cube_generator.py b/cellfinder/core/classify/cube_generator.py index 601a5a0b..37065820 100644 --- a/cellfinder/core/classify/cube_generator.py +++ b/cellfinder/core/classify/cube_generator.py @@ -1,495 +1,1413 @@ -from pathlib import Path -from random import shuffle -from typing import Dict, List, Optional, Tuple, Union +import math +from collections import OrderedDict, defaultdict +from collections.abc import Sequence +from numbers import Integral +from typing import Any, Hashable, Literal -import keras import numpy as np -from brainglobe_utils.cells.cells import Cell, group_cells_by_z -from brainglobe_utils.general.numerical import is_even -from keras.utils import Sequence -from scipy.ndimage import zoom -from skimage.io import imread +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F +from brainglobe_utils.cells.cells import Cell +from tifffile import imread +from torch.multiprocessing import Queue +from torch.utils.data import Dataset, Sampler, get_worker_info from cellfinder.core import types -from cellfinder.core.classify.augment import AugmentationParameters, augment +from cellfinder.core.classify.augment import DataAugmentation +from cellfinder.core.tools.threading import ( + EOFSignal, + ExecutionFailure, + ThreadWithExceptionMPSafe, +) +from cellfinder.core.tools.tools import ( + get_axis_reordering, + get_data_converter, + get_max_possible_int_value, +) -# TODO: rename, as now using dask arrays - -# actually should combine to one generator +AXIS = Literal["x", "y", "z"] +DIM = Literal[AXIS, "c"] +RandRange = Sequence[float] | Sequence[tuple[float, float]] | None + + +_point_ax_map = {"x": 0, "y": 1, "z": 2} class StackSizeError(Exception): pass -class CubeGeneratorFromFile(Sequence): +def _read_data_send_cuboids( + thread: ThreadWithExceptionMPSafe, + dataset: "ImageDataBase", + queues: Sequence[Queue], +) -> None: + """ + Function run by sub-thread that reads data from a dataset, extracts the + cuboids, and sends them back to the main thread for further processing. + + Each thread listens on its own main queue, via its + `thread.get_msg_from_mainthread`, for requests. We also pass in a list of + response queues to which the response of a given request is sent. Each + request indicates which queue in the list to send back the read data. + + Requests also pass in a torch buffer into which we read the cube data, + saving us having to send back newly allocated buffers. + + If there's an exception serving a particular request, the indicated queue + is sent back the exception, instead of the normal data response. + + :param thread: The `ThreadWithExceptionMPSafe`, automatically passed to the + func. + :param dataset: The `ImageDataBase` used to read the data and cubes. + :param queues: A list of queues for sending data.
Each request to the + sub-thread indicates which queue in the list to send the cubes. This + allows one thread to serve multiple consumers. + """ + while True: + msg = thread.get_msg_from_mainthread() + if msg == EOFSignal: + return + + key, point, buffer, queue_id = msg + queue = queues[queue_id] + try: + if key == "point": + buffer[:] = dataset.get_point_cuboid_data(point) + elif key == "points": + dataset.get_point_batch_cuboid_data(buffer, point) + else: + raise ValueError(f"Bad message {key}") + except BaseException as e: + queue.put(("exception", e), block=True, timeout=None) + else: + queue.put((key, point), block=True, timeout=None) + + +def get_data_cuboid_voxels( + network_cuboid_voxels: int, + network_voxel_size_um: float, + data_voxel_size_um: float, +) -> int: + """In a given dimension, the network is trained with + `network_cuboid_voxels` voxels and at `network_voxel_size_um` um per voxel. + This returns the corresponding number of voxels of the input data with its + `data_voxel_size_um` um per voxel. + + :param network_cuboid_voxels: The trained network's cube number of voxels + in that dimension. + :param network_voxel_size_um: The network's um per voxel in the dim. + :param data_voxel_size_um: The input data's um per voxel in the dim. + """ + return int( + round( + network_cuboid_voxels * network_voxel_size_um / data_voxel_size_um + ) + ) + + +def get_data_cuboid_range( + pos: float, num_voxels: int, axis: AXIS +) -> tuple[int, int]: + """For a given dim, takes the location of a point in the input data and + returns the start and end index in the data that centers a cube on the + point. + + :param pos: The position, in voxels, of the point in the input data. For + the given dim. + :param num_voxels: The number of voxels of the cube in this dim. + :param axis: The name of the axis we're calculating. Can be the str + `"x"`, `"y"`, or `"z"`. The axes are centered differently for backward + compatibility. + :return: tuple of ints, `(start, end)`. Cube can be extracted then with + `data[start:end]`. + """ + match axis: + case "x" | "y": + start = int(round(pos - num_voxels / 2)) + case "z": + start = int(pos - num_voxels // 2) + case _: + raise ValueError(f"Unknown axis {axis}") + + return start, start + num_voxels + + +class ImageDataBase: + """ + A base class for extracting cuboids out of image data. + + At the base level we initialize with an array of 3d positions (center + location of potential cells) and we can return a cuboid centered on a + position, given its index in the list. + + The returned cuboid is of size `cuboid_with_channels_size` with the + corresponding axis order given by `data_with_channels_axis_order`. These + correspond to the provided `cuboid_size` and `data_axis_order`, only with + the addition of the channels dimension. + """ + + points_arr: torch.Tensor = None + """An Nx3 array. Each row is a 3d point with the position of a potential + cell. Units are voxels of the input data. The columns are in x, y, z order. """ - Reads cubes (defined as e.g. xml, csv) from raw data to pass to - keras.Model.fit_generator() or keras.Model.predict_generator() - If augment=True, each augmentation selected has an "augment_likelihood" - chance of being applied to each cube + num_channels: int = 1 + """The number of channels contained in the input data.""" + + cuboid_size: tuple[int, int, int] = (1, 1, 1) + """Size of the cuboid in 3d. Cuboids of this size, centered at a given + point will be returned. + + Units are voxels in the input data. 
The axis order corresponds + to `data_axis_order`. """ - # TODO: shuffle within (and maybe between) batches - # TODO: limit workers based on RAM + data_axis_order: tuple[AXIS, AXIS, AXIS] = ("z", "y", "x") + """The axis order of the input data. It's a tuple of `"x"`, `"y"`, and + `"z"` matching the dim order of the input data. + """ def __init__( self, - points: List[Cell], - signal_array: types.array, - background_array: types.array, - voxel_sizes: Tuple[int, int, int], - network_voxel_sizes: Tuple[int, int, int], - batch_size: int = 64, - cube_width: int = 50, - cube_height: int = 50, - cube_depth: int = 20, - channels: int = 2, # No other option currently - classes: int = 2, - extract: bool = False, - train: bool = False, - augment: bool = False, - augment_likelihood: float = 0.1, - flip_axis: Tuple[int, int, int] = (0, 1, 2), - rotate_max_axes: Tuple[float, float, float] = (1, 1, 1), # degrees - # scale=[0.5, 2], # min, max - translate: Tuple[float, float, float] = (0.05, 0.05, 0.05), - shuffle: bool = False, - interpolation_order: int = 2, - *args, - **kwargs, + points_arr: torch.Tensor, + data_axis_order: tuple[AXIS, AXIS, AXIS] = ("z", "y", "x"), + cuboid_size: tuple[int, int, int] = (1, 1, 1), ): - # pass any additional arguments not specified in signature to the - # constructor of the superclass (e.g.: `use_multiprocessing` or - # `workers`) - super().__init__(*args, **kwargs) - - self.points = points - self.signal_array = signal_array - self.background_array = background_array - self.batch_size = batch_size - self.axis_2_pixel_um = float(voxel_sizes[2]) - self.axis_1_pixel_um = float(voxel_sizes[1]) - self.axis_0_pixel_um = float(voxel_sizes[0]) - self.network_axis_2_pixel_um = float(network_voxel_sizes[2]) - self.network_axis_1_pixel_um = float(network_voxel_sizes[1]) - self.network_axis_0_pixel_um = float(network_voxel_sizes[0]) - self.cube_width = cube_width - self.cube_height = cube_height - self.cube_depth = cube_depth - self.channels = channels - self.classes = classes - - # saving training data to file - self.extract = extract - - self.train = train - self.augment = augment - self.augment_likelihood = augment_likelihood - self.flip_axis = flip_axis - self.rotate_max_axes = rotate_max_axes - # self.scale = scale - self.translate = translate - self.shuffle = shuffle - self.interpolation_order = interpolation_order - - self.scale_cubes = False - - self.rescaling_factor_axis_2: float = 1 - self.rescaling_factor_axis_1: float = 1 - self.rescaled_cube_width: float = self.cube_width - self.rescaled_cube_height: float = self.cube_height - - self.__check_image_sizes() - self.__get_image_size() - self.__check_z_scaling() - self.__check_in_plane_scaling() - self.__remove_outlier_points() - self.__get_batches() - if shuffle: - self.on_epoch_end() - - def __check_image_sizes(self) -> None: - if len(self.signal_array) != len(self.background_array): + self.points_arr = points_arr + self.cuboid_size = cuboid_size + + if len(data_axis_order) != 3 or set(data_axis_order) != { + "z", + "y", + "x", + }: raise ValueError( - f"Number of signal images ({len(self.signal_array)}) does not " - f"match the number of background images " - f"({len(self.background_array)}" + f"Expected the axis order to list x, y, z, but got " + f"{data_axis_order}" ) + self.data_axis_order = data_axis_order - def __get_image_size(self) -> None: - self.image_z_size = len(self.signal_array) - self.image_height, self.image_width = self.signal_array[0].shape + @property + def data_with_channels_axis_order(self) ->
tuple[DIM, DIM, DIM, DIM]: + """Same as `data_axis_order`, but it's a 4-tuple because we also + include `"c"`. The output cube ordering is 4 dimensional, x, y, z, + and c. This specifies the order. - def __check_in_plane_scaling(self) -> None: - if self.axis_2_pixel_um != self.network_axis_2_pixel_um: - self.rescaling_factor_axis_2 = ( - self.network_axis_2_pixel_um / self.axis_2_pixel_um - ) - self.rescaled_cube_width = ( - self.cube_width * self.rescaling_factor_axis_2 - ) - self.scale_cubes = True - if self.axis_1_pixel_um != self.network_axis_1_pixel_um: - self.rescaling_factor_axis_1 = ( - self.network_axis_1_pixel_um / self.axis_1_pixel_um - ) - self.rescaled_cube_height = ( - self.cube_height * self.rescaling_factor_axis_1 - ) - self.scale_cubes = True + By default, it's just `data_axis_order` plus `c` at + the end. But it could be different for different loaders. + """ + return *self.data_axis_order, "c" - def __check_z_scaling(self) -> None: - if self.axis_0_pixel_um != self.network_axis_0_pixel_um: - plane_scaling_factor = ( - self.network_axis_0_pixel_um / self.axis_0_pixel_um - ) - self.num_planes_needed_for_cube = round( - self.cube_depth * plane_scaling_factor - ) - else: - self.num_planes_needed_for_cube = self.cube_depth - - if self.num_planes_needed_for_cube > self.image_z_size: - raise StackSizeError( - f"The number of planes provided ({self.image_z_size}) " - "is not sufficient for any cubes to be extracted " - f"(need at least {self.num_planes_needed_for_cube}). " - "Please check the input data" - ) + @property + def cuboid_with_channels_size(self) -> tuple[int, int, int, int]: + """Similar to `cuboid_size`, but it also includes channels size.""" + return *self.cuboid_size, self.num_channels - def __remove_outlier_points(self) -> None: + def get_point_cuboid_data(self, point_key: int) -> torch.Tensor: """ - Remove points that won't get extracted (i.e too close to the edge) + Takes a key used to identify a specific point and returns the cuboid + centered around the point. + + :param point_key: A unique key used to identify the point. E.g. the + index in `points_arr`. + :return: The torch Tensor of size `cuboid_with_channels_size` centered + on the point. """ - self.points = [ - point for point in self.points if self.extractable(point) - ] + raise NotImplementedError + + def get_point_batch_cuboid_data( + self, batch: torch.Tensor, points_key: Sequence[int] + ) -> None: + """ + Similar to `get_point_cuboid_data` except it passes a sequence of keys + that identifies a list of points. We then fill `batch` with the cuboids + centered around the points. + + :param batch: A 5d torch tensor of size N x cuboid_with_channels_size + into which the cuboids corresponding to the points will be filled + into. N is the number of points (the length of `points_key`) and is + the first dimension. + :param points_key: A list of unique keys used to identify the points. + E.g. their indices in `points_arr`. + """ + raise NotImplementedError + + +class CachedStackImageDataBase(ImageDataBase): + """ + Takes a 3d image stack (potentially a folder of tiffs or a 3d numpy array) + and extracts requested cuboids from the stack as torch Tensors. It also + buffers some planes (at the first axis of the data stack) in memory so they + don't have to be repeatedly read. + + This is especially efficient when we request cuboids sequentially ordered + by the first axis. E.g. if the first axis is z, then requesting cuboids + ordered by increasing z will be very fast. 
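+ (As a rough worked example: with cuboids 20 planes deep and
+ `max_axis_0_cuboids_buffered=1`, `(1 + 1) * 20 = 40` recent planes are
+ kept in memory.)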
With a buffer size at least as + large as the number of processing workers, even if we read z planes + slightly out of order (e.g. if multiple workers request them in parallel + and the requests are serialized slightly out of order), it should still be very fast. + + + :param max_axis_0_cuboids_buffered: Each cuboid requires `n` planes + along axis 0. With the assumption that data is read from disk in + planes of this first axis, buffering these planes is advantageous. + This determines how many multiples (or fractions) of `n` such planes + to buffer, in addition to `n` that is always buffered. So e.g. `1` + means `2n` and `0.5` means `1.5n`. `1` is a good default. + """ - def extractable(self, point: Cell) -> bool: - x0, x1, y0, y1, z0, z1 = self.__get_boundaries() - return ( - x0 <= point.x <= x1 and y0 <= point.y <= y1 and z0 <= point.z <= z1 + stack_shape: tuple[int, int, int] + """ + The 3d input array size. Its axis order is `data_axis_order`. + """ + + max_axis_0_planes_buffered: int = 0 + """ + The number of planes to buffer in memory for the first axis, in addition to + the number of planes in a cube corresponding to this axis. This + corresponds to the first axis in `data_axis_order`. The assumption is that + data is read from disk in planes of that axis. + + A good default is the max of the number of planes in a cube and the number + of workers used by the torch data loaders. + """ + + _planes_buffer: OrderedDict[int, torch.Tensor] + """ + Cache that maps plane numbers to the plane tensor. Plane axis and max dict + size is as in `max_axis_0_planes_buffered`. + """ + + def __init__( + self, + max_axis_0_cuboids_buffered: float = 0, + **kwargs, + ): + super().__init__(**kwargs) + self.max_axis_0_planes_buffered = int( + round((max_axis_0_cuboids_buffered + 1) * self.cuboid_size[0]) ) - def __get_boundaries(self) -> Tuple[int, int, int, int, int, int]: - x0 = int(round((self.cube_width / 2) * self.rescaling_factor_axis_2)) - x1 = int(round(self.image_width - x0)) + self._planes_buffer = OrderedDict() + + def get_point_cuboid_data(self, point_key: int) -> torch.Tensor: + """ + See base-class, except `point_key` *is* the point's row index in + `points_arr`. + """ + data = torch.empty(self.cuboid_with_channels_size, dtype=torch.float32) - y0 = int(round((self.cube_height / 2) * self.rescaling_factor_axis_1)) - y1 = int(round(self.image_height - y0)) + point = self.points_arr[point_key, :].tolist() + for j, plane in enumerate(self._get_cuboid_planes(point)): + data[j, ...] = plane - z0 = int(round(self.num_planes_needed_for_cube / 2)) - z1 = self.image_z_size - z0 - return x0, x1, y0, y1, z0, z1 + return data - def __get_batches(self) -> None: - self.points_groups = group_cells_by_z(self.points) - # TODO: add optional shuffling of each group here - self.batches = [] - for centre_plane in self.points_groups.keys(): - points_per_plane = self.points_groups[centre_plane] - for i in range(0, len(points_per_plane), self.batch_size): - self.batches.append(points_per_plane[i : i + self.batch_size]) + def get_point_batch_cuboid_data( + self, batch: torch.Tensor, points_key: Sequence[int] + ) -> None: + """ + See base-class, except `points_key` *are* the points' row indices in + `points_arr`.
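+
+         A minimal usage sketch (assuming `image_data` is an instance of a
+         concrete subclass and `keys` is a sequence of row indices into
+         `points_arr`)::
+
+             batch = torch.empty(
+                 (len(keys), *image_data.cuboid_with_channels_size),
+                 dtype=torch.float32,
+             )
+             image_data.get_point_batch_cuboid_data(batch, keys)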
+ """ + if (len(points_key), *self.cuboid_with_channels_size) != batch.shape: + raise ValueError( + "Expected the buffer to match the points X cubes shape" + ) - self.ordered_points = [] - for batch in self.batches: - for cell in batch: - self.ordered_points.append(cell) + points_arr = self.points_arr + for i, point_key in enumerate(points_key): + point = points_arr[point_key, :].tolist() + for j, plane in enumerate(self._get_cuboid_planes(point)): + batch[i, j, ...] = plane - def __len__(self) -> int: + def _get_cuboid_planes(self, point: Sequence[float]) -> list[torch.Tensor]: """ - Number of batches - :return: Number of batches per epoch + Takes a 3d point and returns a list of sequential planes which, when + concatenated along the first axis, yield a cuboid of the correct size. + It is more efficient for the calling function to copy the planes one + by one into its buffer than for us to concatenate them and have the + caller copy the whole cuboid. """ - return len(self.batches) + max_planes = self.max_axis_0_planes_buffered + planes_buffer = self._planes_buffer + + cuboid_indices = [] + for ax, size in zip(self.data_axis_order, self.cuboid_size): + idx = _point_ax_map[ax] + cuboid_indices.append(get_data_cuboid_range(point[idx], size, ax)) + ax1 = slice(*cuboid_indices[1]) + ax2 = slice(*cuboid_indices[2]) + + planes = [] + # buffered axis + for i in range(*cuboid_indices[0]): + if i not in planes_buffer: + plane_shape = *self.stack_shape[1:], self.num_channels + plane = torch.empty(plane_shape, dtype=torch.float32) + for channel in range(self.num_channels): + plane[:, :, channel] = self.read_plane(i, channel) + + if len(planes_buffer) == max_planes: + # fifo when last=False + planes_buffer.popitem(last=False) + assert len(planes_buffer) < max_planes - def __getitem__(self, index: int) -> Union[ - np.ndarray, - Tuple[np.ndarray, List[Dict[str, float]]], - Tuple[np.ndarray, Dict], - ]: + planes_buffer[i] = plane + + planes.append(planes_buffer[i][ax1, ax2, :]) + + return planes + + def read_plane(self, plane: int, channel: int) -> torch.Tensor: """ - Generates a single batch of cubes - :param index: - :return: + Takes a plane number along the first axis and a channel number and + returns the torch tensor containing that plane for that channel. + + :param plane: The plane index in the first dim of `stack_shape` to + read. + :param channel: The channel number ` Tuple[np.ndarray, np.ndarray]: - centre_plane = self.batches[index][0].z + def __init__( + self, + input_arrays: Sequence[types.array], + **kwargs, + ): + super().__init__(**kwargs) + + self.input_arrays = input_arrays + self.stack_shape = input_arrays[0].shape + self.num_channels = len(self.input_arrays) + self._converters = [ + get_data_converter(arr.dtype, np.float32) + for arr in self.input_arrays + ] - min_plane, max_plane = get_cube_depth_min_max( - centre_plane, self.num_planes_needed_for_cube + def read_plane(self, plane: int, channel: int) -> torch.Tensor: + converter = self._converters[channel] + return torch.from_numpy( + converter(self.input_arrays[channel][plane, ...]) ) - signal_stack = np.array(self.signal_array[min_plane:max_plane]) - background_stack = np.array(self.background_array[min_plane:max_plane]) - return signal_stack, background_stack +class CachedCuboidImageDataBase(ImageDataBase): + """ + Takes a collection of cuboids (e.g. a folder of tiff files, each a cuboid + or a list of 3d numpy arrays) and returns a requested cuboid as a torch + Tensor.
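+
+     A usage sketch (assuming `reader` is an instance of a concrete
+     subclass, e.g. one reading a tiff per point)::
+
+         cuboid = reader.get_point_cuboid_data(0)
+         assert tuple(cuboid.shape) == reader.cuboid_with_channels_size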
+ + It also buffers `max_cuboids_buffered` recent cuboids so they + don't have to be repeatedly read. + """ + + max_cuboids_buffered: int = 0 + """ + The number of most recently read cuboids to keep in memory. + + This will be the number passed to `__init__` plus 1 so at least the last + one is always cached. + """ - def __generate_cubes( + _cuboids_buffer: OrderedDict[Hashable, torch.Tensor] + """ + Maps the cuboids hashable point keys to the cuboids. The cuboid buffered + includes the channels in the `data_with_channels_axis_order` order. + """ + + def __init__( self, - cell_batch: List[Cell], - signal_stack: np.ndarray, - background_stack: np.ndarray, - ) -> np.ndarray: - number_images = len(cell_batch) - images = np.empty( - ( - (number_images,) - + (self.cube_height, self.cube_width, self.cube_depth) - + (self.channels,) - ), - dtype=np.float32, - ) + max_cuboids_buffered: int = 0, + **kwargs, + ): + super().__init__(**kwargs) + self.max_cuboids_buffered = max_cuboids_buffered + 1 + self._cuboids_buffer = OrderedDict() + + def get_point_cuboid_data(self, point_key: int) -> torch.Tensor: + max_cuboids = self.max_cuboids_buffered + cuboids_buffer = self._cuboids_buffer - for idx, cell in enumerate(cell_batch): - images = self.__populate_array_with_cubes( - images, idx, cell, signal_stack, background_stack + if point_key not in cuboids_buffer: + cuboid = torch.empty( + self.cuboid_with_channels_size, dtype=torch.float32 ) + for channel in range(self.num_channels): + cuboid[:, :, :, channel] = self.read_cuboid(point_key, channel) - return images + if len(cuboids_buffer) == max_cuboids: + # fifo when last=False + cuboids_buffer.popitem(last=False) + assert len(cuboids_buffer) < max_cuboids + + cuboids_buffer[point_key] = cuboid + + return cuboids_buffer[point_key] + + def get_point_batch_cuboid_data( + self, batch: torch.Tensor, points_key: Sequence[int] + ) -> None: + if (len(points_key), *self.cuboid_with_channels_size) != batch.shape: + raise ValueError( + "Expected the buffer to match the points X cubes shape" + ) + + for i, point_key in enumerate(points_key): + batch[i, ...] = self.get_point_cuboid_data(point_key) + + def read_cuboid(self, point_key: int, channel: int) -> torch.Tensor: + """ + Takes a key used to identify a point and a channel number and + returns the torch tensor containing the cuboid for that channel. - def __populate_array_with_cubes( + :param point_key: A key used to identify the point. 
+ :param channel: The channel number ` np.ndarray: - if self.augment: - self.augmentation_parameters = AugmentationParameters( - self.flip_axis, - self.translate, - self.rotate_max_axes, - self.interpolation_order, + filenames_arr: np.ndarray, + **kwargs, + ): + super().__init__(**kwargs) + if not len(filenames_arr): + raise ValueError( + "No data was provided, must have at least one readable point" ) - images[idx, :, :, :, 0] = self.__get_oriented_image(cell, signal_stack) - images[idx, :, :, :, 1] = self.__get_oriented_image( - cell, background_stack - ) - return images - - def __get_oriented_image( - self, cell: Cell, image_stack: np.ndarray - ) -> np.ndarray: - x0 = int(round(cell.x - (self.rescaled_cube_width / 2))) - x1 = int(x0 + self.rescaled_cube_width) - y0 = int(round(cell.y - (self.rescaled_cube_height / 2))) - y1 = int(y0 + self.rescaled_cube_height) - image = image_stack[:, y0:y1, x0:x1] - image = np.moveaxis(image, 0, 2) - - if self.augment: - # scale to isotropic, but don't scale back - image = augment( - self.augmentation_parameters, image, scale_back=False + if len(filenames_arr) != len(self.points_arr): + raise ValueError( + "Points and filenames must have same number of elements" ) - pixel_scalings = [ - self.cube_height / image.shape[0], - self.cube_width / image.shape[1], - self.cube_depth / image.shape[2], # type: ignore[misc] - # Not sure why mypy thinks .shape[2] is out of bounds above? + self.filenames_arr = filenames_arr + self.num_channels = len(filenames_arr[0]) + self._converters = [ + get_data_converter( + imread(channel_filenames.item()).dtype, np.float32 + ) + for channel_filenames in filenames_arr[0] ] - # TODO: ensure this is always the correct size - image = zoom(image, pixel_scalings, order=self.interpolation_order) - return image - - @staticmethod - def __get_batch_dict(cell_batch: List[Cell]) -> List[Dict[str, float]]: - return [cell.to_dict() for cell in cell_batch] - - def on_epoch_end(self) -> None: + def read_cuboid(self, point_key: int, channel: int) -> torch.Tensor: """ - Shuffle data for each epoch - :return: Shuffled indexes + See super-class. `point_key` is the cuboid's index in `filenames_arr` + that we want to read. """ - shuffle(self.batches) + converter = self._converters[channel] + data = imread(self.filenames_arr[point_key][channel].item()) + return torch.from_numpy(converter(data)) + + +class CuboidDatasetBase(Dataset): + """ + Implements a pytorch `Dataset` that takes a list of 3d point coordinates + (centers of potential cells) and an `ImageDataBase` instance that contains + voxel data. The dataset yields batches of torch Tensors with cuboids + of the voxel data centered at these points. + + Data is accessed like a normal torch Dataset, e.g. with `len(dataset)` + or `dataset[i]`. If `i` is a single int it returns the corresponding 4d + item (includes channel), otherwise it's a sequence of ints and it returns + a 5d batch of those items. The batch dimension is always the first + dimension followed by `output_axis_order` dimensions. + + The data output is either just the data or it also includes the labels, + depending on `target_output`. + + :param points: A list of `Cell` instances containing the cell centers. + Units are in voxels of the input data - not in microns. The cells are + saved in `points_arr`, with some caveats. See its docs. + :param data_voxel_sizes: A 3-tuple indicating the input data's 3d voxel + size in `um`. The tuple's order corresponds to `axis_order`.
This is + used along with `network_voxel_sizes` to extract and scale a cube + around each point of a similar size to the one the network was trained on. + :param network_voxel_sizes: A 3-tuple indicating the trained network's 3d + voxel size in `um`. The tuple's order corresponds to `axis_order`. + :param network_cuboid_voxels: A 3-tuple indicating the cuboid size used to + train the network, in voxels. The tuple's order corresponds to + `axis_order`. + :param axis_order: A 3-tuple indicating the input data's three + dimensions' axis order. It's any permutation of `("x", "y", "z")`. + :param output_axis_order: A 4-tuple indicating the desired output data's + three dimensions plus channel axis order. It's any permutation of + `("x", "y", "z", "c")`. For now, `"c"` is assumed last. + :param src_image_data: The `ImageDataBase` that will be used to read the + voxel cuboids for a given point. + :param classes: The number of classes used by the network when classifying + cuboids. + :param target_output: A literal indicating the type of label the dataset + should return during training / testing. It is one of `"index"` + (it returns the `index` of the point in the original `points` input + list), `"label"` (it returns a one-hot vector indicating the instance + class label - values are the valid `Cell.type` minus 1, i.e. a cell + would be `Cell.CELL - 1`), or `None` (no label is returned). + + I.e. if it's None, we get `batch_data = dataset[i]`. Otherwise, it's + `batch_data, batch_label = dataset[i]`. + + Use `"index"` to recover the original point index, because internally + the input points may be filtered or re-ordered. + :param augment: Whether to augment the dataset with the subsequent + parameters. + :param augment_likelihood: A value in `[0, 1]` with the probability of a data + item being augmented. I.e. `0.9` means 90% of the data will have been + augmented. + :param flippable_axis: A sequence of the dimensions in the output + data to reverse, if any, with probability `augment_likelihood`. + :param rotate_range: A sequence of floats or sequence of 2-tuples with the + radian angle or range of angles to rotate the output data. Each item + for the corresponding output data dim, with probability + `augment_likelihood`. Or `None` if there's no rotation. + :param translate_range: A sequence of floats or sequence of 2-tuples with + the pixel distance or range of distance to translate the output data. + Each item for the corresponding output data dim, with probability + `augment_likelihood`. Or `None` if there's no translation. + :param scale_range: A sequence of floats or sequence of 2-tuples with + the amount or range of amount to scale the output data. `1` means + no scaling. Each item for the corresponding output data dim, with + probability `augment_likelihood`. Or `None` if there's no scaling. + """ + + points_arr: torch.Tensor = None + """A generated Nx5 tensor. Each row is `(x, y, z, type, index)` columns + with the 3d position of the potential cell. + + Units are voxels in the input data, not microns. + """ + + src_image_data: ImageDataBase | None = None + """ + The `ImageDataBase` that will be used to read the voxel cuboids for a given + point. + """ + + data_cuboid_voxels: tuple[int, int, int] = 1, 1, 1 + """A 3-tuple of the number of voxels in a cuboid of the output data. + The order corresponds to the `output_axis_order`, excluding the channel + dim. See also `cuboid_with_channels_size`. + """ + + num_channels: int = 1 + """ + The number of channels in the image data / cuboids. + """ -class CubeGeneratorFromDisk(Sequence): + augmentation: DataAugmentation | None = None + """ + If provided, used to augment the data during training. """ - Reads in cubes from a list of paths for keras.Model.fit_generator() or - keras.Model.predict_generator() - If augment=True, each augmentation selected has a 50/50 chance of being - applied to each cube + _output_data_dim_reordering: list[int] | None = None + """ + Cached indices that are used with `get_axis_reordering` to convert the + cuboids from the input `axis_order` / + `src_image_data.data_with_channels_axis_order` to the `output_axis_order`. """ def __init__( self, - signal_list: List[Union[str, Path]], - background_list: List[Union[str, Path]], - labels: Optional[List[int]] = None, # only if training or validating - batch_size: int = 64, - shape: Tuple[int, int, int] = (50, 50, 20), - channels: int = 2, + points: Sequence[Cell], + data_voxel_sizes: tuple[float, float, float], + network_voxel_sizes: tuple[float, float, float] = (5, 1, 1), + network_cuboid_voxels: tuple[int, int, int] = (20, 50, 50), + axis_order: tuple[AXIS, AXIS, AXIS] = ("z", "y", "x"), + output_axis_order: tuple[DIM, DIM, DIM, DIM] = ("y", "x", "z", "c"), + src_image_data: ImageDataBase | None = None, classes: int = 2, - shuffle: bool = False, + target_output: Literal["index", "label"] | None = None, augment: bool = False, - augment_likelihood: float = 0.1, - flip_axis: Tuple[int, int, int] = (0, 1, 2), - rotate_max_axes: Tuple[int, int, int] = (45, 45, 45), # degrees - # scale=[0.5, 2], # min, max - translate: Tuple[float, float, float] = (0.2, 0.2, 0.2), - train: bool = False, # also return labels - interpolation_order: int = 2, - *args, + augment_likelihood: float = 0.9, + flippable_axis: Sequence[int] = (0, 1, 2), + rotate_range: RandRange = (math.pi / 4,) * 3, + translate_range: RandRange = (0.05,) * 3, + scale_range: RandRange = ((0.6, 1.4),) * 3, **kwargs, ): - # pass any additional arguments not specified in signature to the - # constructor of the superclass (e.g.: `use_multiprocessing` or - # `workers`) - super().__init__(*args, **kwargs) + super().__init__(**kwargs) + if len(axis_order) != 3 or set(axis_order) != {"z", "y", "x"}: + raise ValueError( + f"Expected the axis order to list x, y, z, but got " + f"{axis_order}" + ) + if len(output_axis_order) != 4 or set(output_axis_order) != { + "z", + "y", + "x", + "c", + }: + raise ValueError( + f"Expected the axis order to list x, y, z, c, but got " + f"{output_axis_order}" + ) + if output_axis_order[-1] != "c": + raise ValueError("output_axis_order must have c last, for now") + + # very important: we can't save the original points in the instance, + # because then they get duplicated in each sub-process. For many million + # cells, this is easily tens of GB. So only save a tensor representing + # the cells that can be shared in memory between processes.
+ self.points_arr = torch.empty((len(points), 5), dtype=torch.float64) + data = self.points_arr + for i, cell in enumerate(points): + data[i, 0] = cell.x + data[i, 1] = cell.y + data[i, 2] = cell.z + data[i, 3] = cell.type + data[i, 4] = i + # move it to shared memory so it doesn't get duplicated in workers + data.share_memory_() + + self.src_image_data = src_image_data + self.data_voxel_sizes = tuple(data_voxel_sizes) + self.network_voxel_sizes = tuple(network_voxel_sizes) + self.network_cuboid_voxels = tuple(network_cuboid_voxels) + self.axis_order = axis_order + self.output_axis_order = output_axis_order + + data_cuboid_voxels = [] + for data_um, network_um, cuboid_voxels in zip( + data_voxel_sizes, network_voxel_sizes, network_cuboid_voxels + ): + data_cuboid_voxels.append( + get_data_cuboid_voxels(cuboid_voxels, network_um, data_um) + ) + self.data_cuboid_voxels = tuple(data_cuboid_voxels) + if len(self.data_cuboid_voxels) != 3: + raise ValueError("sizes must be length 3 for the 3 axes") - self.im_shape = shape - self.batch_size = batch_size - self.labels = labels - self.signal_list = signal_list - self.background_list = background_list - self.channels = channels self.classes = classes - self.augment = augment - self.augment_likelihood = augment_likelihood - self.flip_axis = flip_axis - self.rotate_max_axes = rotate_max_axes - # self.scale = scale - self.translate = translate - self.train = train - self.interpolation_order = interpolation_order - self.indexes = np.arange(len(self.signal_list)) - if shuffle: - self.on_epoch_end() + self.target_output = target_output + + if augment: + vol_size = { + ax: n for ax, n in zip(axis_order, network_cuboid_voxels) + } + self.augmentation = DataAugmentation( + vol_size, + output_axis_order, + augment_likelihood, + flippable_axis, + translate_range, + scale_range, + rotate_range, + ) - # TODO: implement scale and shear + if src_image_data is not None: + self._set_output_data_dim_reordering(src_image_data) - def on_epoch_end(self) -> None: + @property + def cuboid_with_channels_size(self) -> tuple[int, int, int, int]: + """A 4-tuple of the number of voxels in the cuboid in the output data + and the number of channels. + + The order corresponds to the `output_axis_order`. For now, `"c"` is + assumed last. """ - Shuffle data for each epoch - :return: Shuffled indexes + return *self.data_cuboid_voxels, self.num_channels + + def __len__(self): + return len(self.points_arr) + + def __getitem__(self, idx: int | Sequence[int]): + if isinstance(idx, Integral): + return self._get_single_item(idx) + return self._get_multiple_items(idx) + + def _set_output_data_dim_reordering( + self, src_image_data: ImageDataBase + ) -> None: """ - self.indexes = np.arange(len(self.signal_list)) - np.random.shuffle(self.indexes) + Sets `_output_data_dim_reordering`. + """ + if src_image_data.data_axis_order != self.output_axis_order: + self._output_data_dim_reordering = get_axis_reordering( + ("b", *src_image_data.data_with_channels_axis_order), + ("b", *self.output_axis_order), + ) - def __len__(self) -> int: + def _get_single_item( + self, idx: int + ) -> torch.Tensor | tuple[torch.Tensor, Any]: + """ + Handles `dataset[i]`, when `i` is an int. + """ + point = self.points_arr[idx] + data = self.get_point_data(idx) + + # batch dim + data = data[None, ...] + if self._output_data_dim_reordering is not None: + data = torch.permute(data, self._output_data_dim_reordering) + + data = self.convert_to_output(data) + data = data[0, ...] 
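+         # data is now a single 4d cuboid in output_axis_order, the layout
+         # the augmenter below was configured with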
+ + augmentation = self.augmentation + if augmentation is not None and augmentation.update_parameters(): + data[:] = augmentation(data) + + match self.target_output: + case None: + return data + + case "index": + label = point[4] + case "label": + cls = torch.tensor(point[3] - 1, dtype=torch.long) + label = F.one_hot(cls, num_classes=self.classes) + case _: + raise ValueError(f"Unknown target value {self.target_output}") + + return data, label + + def _get_multiple_items( + self, indices: Sequence[int] + ) -> torch.Tensor | tuple[torch.Tensor, Any]: + """ + Handles `dataset[i]`, when `i` is a list of ints. """ - Number of batches - :return: Number of batches per epoch + # numpy arrays can't be indexed with tuples + points = self.points_arr[indices, :] + + data = self.get_points_data(indices) + if self._output_data_dim_reordering is not None: + data = torch.permute(data, self._output_data_dim_reordering) + data = self.convert_to_output(data) + + augmentation = self.augmentation + if augmentation is not None: + # batch is always first index + for b in range(len(indices)): + if augmentation.update_parameters(): + data[b, ...] = augmentation(data[b, ...]) + + match self.target_output: + case None: + return data + + case "index": + labels = points[:, 4] + case "label": + cls = points[:, 3] - 1 + labels = F.one_hot( + cls.to(torch.long), num_classes=self.classes + ) + case _: + raise ValueError(f"Unknown target value {self.target_output}") + + return data, labels + + def convert_to_output(self, data: torch.Tensor) -> torch.Tensor: """ - return int(np.ceil(len(self.signal_list) / self.batch_size)) + Takes the input cuboids, ordered already according to the + `output_axis_order` with the additional batch dim at the start, and + scales the cuboids to the network cuboid size to be returned by the + dataset. + """ + if self.data_voxel_sizes == self.network_voxel_sizes: + return data + + # batch dimension + if len(data.shape) != 5: + raise ValueError("Needs 5 dimensions: batch, channel and space") + + # our data arrives in output_axis_order order. To scale we need to + # convert first to torch_order, which is torch's expected order. We + # then re-order back to the output_axis_order before returning. + torch_order = "b", "c", "z", "y", "x" + data_order = "b", *self.output_axis_order + voxel_order = self.axis_order + + torch_voxel_map = get_axis_reordering(voxel_order, torch_order[2:]) + scaled_cuboid_size = [ + self.network_cuboid_voxels[i] for i in torch_voxel_map + ] + + data = torch.permute( + data, get_axis_reordering(data_order, torch_order) + ) + data = F.interpolate(data, size=scaled_cuboid_size, mode="trilinear") + data = torch.permute( + data, get_axis_reordering(torch_order, data_order) + ) + + return data - def __getitem__(self, index: int) -> Union[ - np.ndarray, - Tuple[np.ndarray, List[Dict[str, float]]], - Tuple[np.ndarray, Dict], - ]: + def get_point_data(self, point_key: int) -> torch.Tensor: """ - Generates a single batch of cubes - :param index: - :return: + Takes a key used to identify a specific point (typically the index in + points_arr) and returns the cuboid centered around the point. + + This handles getting the cuboids for `dataset[i]`, when `i` is an int. + + :param point_key: A unique key used to identify the point. E.g. the + index in `points_arr`. + :return: The 4d cuboid in the output_axis_order order, without the + batch dim. """ - # Generate indexes of the batch - start_index = index * self.batch_size - end_index = start_index + self.batch_size - indexes = self.indexes[start_index:end_index] + return self.get_points_data([point_key])[0, ...] - # Get data corresponding to batch - list_signal_tmp = [self.signal_list[k] for k in indexes] - list_background_tmp = [self.background_list[k] for k in indexes] + def get_points_data(self, points_key: Sequence[int]) -> torch.Tensor: + """ + Takes a list of keys used to identify points (typically the indices in + the points list) and returns the cuboids centered around the points. - images = self.__generate_cubes(list_signal_tmp, list_background_tmp) + This handles getting the cuboids for `dataset[i]`, when `i` is a list + of ints. - if self.train and self.labels is not None: - batch_labels = [self.labels[k] for k in indexes] - batch_labels = keras.utils.to_categorical( - batch_labels, num_classes=self.classes - ) - return images, batch_labels.astype(np.float32) - else: - return images + :param points_key: A list of the unique keys used to identify the + points. E.g. the indices in `points_arr`. + :return: The 5d cuboids in the output_axis_order order. With the + batch dim as the first axis. + """ + data = torch.empty( + (len(points_key), *self.cuboid_with_channels_size), + dtype=torch.float32, + ) + self.src_image_data.get_point_batch_cuboid_data(data, points_key) + + return data + + +class CuboidThreadedDatasetBase(CuboidDatasetBase): + """ + Adds the ability to share a single `src_image_data` `ImageDataBase` + instance among multiple sub-process workers so that only the main + process reads the data, while the other processes only handle processing + the batches. + + The overall pattern is as follows. When we use a PyTorch data loader, it + creates worker processes internally and duplicates the dataset for each + worker (via pickling). The PyTorch data loader also automatically + assigns the batches each worker loads and yields their result. + + `CuboidDatasetBase` loads the data directly via its `src_image_data` + (`ImageDataBase`) instance. When the `Dataset` is duplicated for each + worker, each worker gets its own copy of `src_image_data` that it reads + from. This causes the sub-processes to read similar data from disk, causing + disk contention, as well as cache duplication. + + This class adds the ability for the dataset to read the data from its + `src_image_data` via a sub-thread in the main process, if + `start_dataset_thread` is called. Each worker (including the main process) + in its own process uses a queue to request data from this singular thread, + which reads it from disk (or cache) and sends it back to the worker to + process (rescale / orient etc.). This ensures that only one thread reads + the data from disk. And only the main process gets access to + `src_image_data` so there are no copies. + + Each worker allocates a batch buffer, which is sent and shared in + memory with the main process sub-thread when it requests it to read a + batch. The thread can then directly write into it the loaded data without + having to copy the data back and forth between processes. + + `start_dataset_thread` and `stop_dataset_thread` must be called to start + and close this main sub-thread. Otherwise, each sub-process worker will + have its own copy of the dataset, obsoleting this class' functionality. + + Overall, each worker uses its own queue to communicate with the main + thread. The main reader thread listens to requests via its own single + queue.
Each worker's request tells it the points for which to load cuboids
+    and the queue to use to send the data back along. If it encounters an error
+    reading the data, it'll instead send the exception back with that queue
+    so it doesn't hang. The reader thread never exits on its own because it
+    catches all exceptions. It only exits when `stop_dataset_thread` is called.
+
+    Similarly, the worker processes don't exit until PyTorch closes them at its
+    own leisure. But if we call `stop_dataset_thread` while the workers are
+    still running, it'll cause them to raise exceptions to PyTorch (and exit).
+
+    This class doesn't interfere with our ability to manually request a cube
+    from the dataset in the main process. Depending on whether
+    `start_dataset_thread` was called, it may load the data via the reader
+    thread. But from the requester's POV, `dataset[indices]` will always work.
+    """
+
+    _dataset_thread: ThreadWithExceptionMPSafe | None = None
+    """The singular sub-thread that all workers use to request data from.
+
+    The thread is created when `start_dataset_thread` is called.
+    """
+
+    _worker_queues: list[Queue] | None = None
+    """
+    The queues used to get back data from the singular main thread.
 
-    def __generate_cubes(
+    Each worker is identified via its
+    `0 <= get_worker_info().id < num_workers`. Queue `len(queues) - 1`
+    is used by the main process if it tries reading data.
+    """
+
+    def __init__(
        self,
-        list_signal_tmp: List[Union[str, Path]],
-        list_background_tmp: List[Union[str, Path]],
-    ) -> np.ndarray:
-        number_images = len(list_signal_tmp)
-        images = np.empty(
-            ((number_images,) + self.im_shape + (self.channels,)),
-            dtype=np.float32,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._worker_queues = []
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        if self._dataset_thread is not None:
+            # we have to prevent copies of src_image_data so that data is not
+            # duplicated among sub-processes if we use them.
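+            # workers will request cuboids from the reader thread instead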
+            del state["src_image_data"]
+        return state
+
+    def get_point_data(self, point_key: int) -> torch.Tensor:
+        thread = self._dataset_thread
+        if thread is None:
+            # if start_dataset_thread was not called, use the default
+            # functionality of each process reading on its own
+            return super().get_point_data(point_key)
+
+        return self._send_rcv_thread_msg(
+            self.cuboid_with_channels_size,
+            "point",
+            point_key,
+        )
+
+    def get_points_data(self, points_key: Sequence[int]) -> torch.Tensor:
+        thread = self._dataset_thread
+        if thread is None:
+            # if start_dataset_thread was not called, use the default
+            # functionality of each process reading on its own
+            return super().get_points_data(points_key)
+
+        return self._send_rcv_thread_msg(
+            (len(points_key), *self.cuboid_with_channels_size),
+            "points",
+            points_key,
        )
-        for idx, signal_im in enumerate(list_signal_tmp):
-            background_im = list_background_tmp[idx]
-            images = self.__populate_array_with_cubes(
-                images, idx, signal_im, background_im
+    def _send_rcv_thread_msg(
+        self, buffer_shape: tuple[int, ...], subject: str, key: Any
+    ) -> torch.Tensor:
+        # create the buffer into which the reader thread will write the data
+        data = torch.empty(buffer_shape, dtype=torch.float32)
+
+        queues = self._worker_queues
+        # the main process uses the last queue (when there are no workers)
+        queue_id = len(queues) - 1
+        if get_worker_info() is not None:
+            queue_id = get_worker_info().id
+        queue = queues[queue_id]
+
+        # request the data from the main process data thread
+        self._dataset_thread.send_msg_to_thread((subject, key, data, queue_id))
+        # do a single request and wait for the result
+        msg, value = queue.get(block=True)
+        # on success, the reader thread replies with the subject we sent
+        if msg == "exception":
+            # the reader thread reported an error
+            raise ExecutionFailure(
+                "Reporting failure from data thread"
+            ) from value
+        if msg == "eof":
+            # the main process wants to exit
+            raise ValueError(
+                "Worker process was asked to exit while waiting for data"
            )
+        assert msg == subject
+
+        return data
+
+    def start_dataset_thread(self, num_workers: int) -> None:
+        """
+        Must be called if we want the functionality of this class when using
+        multiple sub-process workers. This creates the singular reader
+        sub-thread and the per-worker queues.
+
+        :param num_workers: The number of sub-process workers torch will use.
+            If it's zero, meaning only the main process will process the
+            batches, we still use a sub-thread to read the data. If you don't
+            wish this, don't call `start_dataset_thread` and the data will be
+            loaded directly in the main process without the sub-thread.
+        """
+        # one queue per worker, plus one for the main (host) process
+        ctx = mp.get_context("spawn")
+        # we use maxsize=0 to prevent potential locking issues. But, we never
+        # actually request more than one data batch at a time over a queue
+        queues = [ctx.Queue(maxsize=0) for _ in range(num_workers + 1)]
+        self._worker_queues = queues
+
+        self._dataset_thread = ThreadWithExceptionMPSafe(
+            target=_read_data_send_cuboids,
+            args=(self.src_image_data, queues),
+            pass_self=True,
+        )
+        self._dataset_thread.start()
+
+    def stop_dataset_thread(self) -> None:
+        """
+        Must be called to shut down the data loading sub-thread when done.
+
+        This should be called when the sub-process workers have stopped
+        reading data. If the worker processes are still waiting for returned
+        data in their queues, they will raise an error.
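+
+        Typical pairing with `start_dataset_thread` (a sketch)::
+
+            dataset.start_dataset_thread(num_workers)
+            try:
+                ...  # e.g. iterate the DataLoader / fit the model
+            finally:
+                dataset.stop_dataset_thread()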
+        """
+        thread = self._dataset_thread
+        if thread is None:
+            return
+
+        thread.notify_to_end_thread()
+        thread.clear_remaining()
+        thread.join()
+
+        self._dataset_thread = None
+        # if we exit while workers are still waiting, let them know to not hang
+        for queue in self._worker_queues:
+            queue.put(("eof", None), block=False)
+        self._worker_queues = []
+
-        return images
+class CuboidStackDataset(CuboidThreadedDatasetBase):
+    """
+    Implements `CuboidThreadedDatasetBase` using a `CachedArrayStackImageData`,
+    which reads the cuboids from array type data (e.g. Dask arrays).
+
+    :param signal_array: The signal data array.
+    :param background_array: The background data array. If None, only the
+        signal channel is used.
+    :param max_axis_0_cuboids_buffered: Each cuboid requires `n` planes
+        along axis 0. With the assumption that data is read from disk in
+        planes of this first axis, buffering these planes is advantageous.
+        This determines how many multiples (or fractions) of `n` such planes
+        to buffer, in addition to `n` that is always buffered. So e.g. `1`
+        means `2n` and `0.5` means `1.5n`. `1` is a good default.
+    :param signal_normalization: None or a 2-tuple of `(mean, std)`.
+        If not None, the signal channel in the cubes will be normalized to the
+        provided mean and standard deviation.
+    :param background_normalization: None or a 2-tuple of `(mean, std)`.
+        If not None, the background channel in the cubes will be normalized to
+        the provided mean and standard deviation.
+    """
 
-    def __populate_array_with_cubes(
+    def __init__(
        self,
-        images: np.ndarray,
-        idx: int,
-        signal_im: Union[str, Path],
-        background_im: Union[str, Path],
-    ) -> np.ndarray:
-        if self.augment:
-            self.augmentation_parameters = AugmentationParameters(
-                self.flip_axis,
-                self.translate,
-                self.rotate_max_axes,
-                self.interpolation_order,
-                self.augment_likelihood,
+        signal_array: types.array,
+        background_array: types.array | None,
+        max_axis_0_cuboids_buffered: float = 0,
+        signal_normalization: None | tuple[float, float] = None,
+        background_normalization: None | tuple[float, float] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        if get_max_possible_int_value(
+            signal_array.dtype
+        ) > get_max_possible_int_value(np.float32):
+            raise ValueError(
+                f"Input signal array has data type {signal_array.dtype}, "
+                f"which cannot fit in a float32"
+            )
+        if background_array is not None and get_max_possible_int_value(
+            background_array.dtype
+        ) > get_max_possible_int_value(np.float32):
+            raise ValueError(
+                f"Input background array has data type "
+                f"{background_array.dtype}, which cannot fit in a float32"
            )
-        images[idx, :, :, :, 0] = self.__get_oriented_image(signal_im)
-        images[idx, :, :, :, 1] = self.__get_oriented_image(background_im)
-        return images
+        if background_array is None:
+            data_arrays = [
+                signal_array,
+            ]
+            self.num_channels = 1
+        else:
+            data_arrays = [signal_array, background_array]
+            self.num_channels = 2
+
+        if (
+            background_array is not None
+            and signal_array.shape != background_array.shape
+        ):
+            raise ValueError(
+                f"Shape of signal images ({signal_array.shape}) does not "
+                f"match the shape of the background images "
+                f"({background_array.shape})"
+            )
+
+        if len(signal_array.shape) != 3:
+            raise ValueError("Expected a 3d input data array")
+
+        self.stack_shape = signal_array.shape
+
+        mask = torch.tensor(
+            [
+                self.point_has_full_cuboid(p[:3].tolist())
+                for p in self.points_arr
+            ],
+            dtype=torch.bool,
+        )
+        self.points_arr = self.points_arr[mask, :]
+        # move it to shared memory so it doesn't get duplicated in workers
+        self.points_arr.share_memory_()
+
+        self.src_image_data = CachedArrayStackImageData(
+            points_arr=self.points_arr[:, :3],
+            input_arrays=data_arrays,
+            data_axis_order=self.axis_order,
+            max_axis_0_cuboids_buffered=max_axis_0_cuboids_buffered,
+            cuboid_size=self.data_cuboid_voxels,
+        )
+        self._set_output_data_dim_reordering(self.src_image_data)
+
+        self.signal_normalization = signal_normalization
+        self.background_normalization = background_normalization
+
+    def point_has_full_cuboid(self, point: Sequence[float]) -> bool:
+        """
+        Takes a 3d point and returns whether a cuboid centered on this point
+        is fully contained within the signal / background array. I.e. it's not
+        too close to the edges etc.
+        """
+        for ax, axis_size, cuboid_size in zip(
+            self.axis_order, self.stack_shape, self.data_cuboid_voxels
+        ):
+            idx = _point_ax_map[ax]
+            start, end = get_data_cuboid_range(point[idx], cuboid_size, ax)
+            if start < 0:
+                return False
+            if end > axis_size:
+                # if it's axis_size it's fine because end is not inclusive
+                return False
+
+        return True
+
+    def get_points_data(self, points_key: Sequence[int]) -> torch.Tensor:
+        data = super().get_points_data(points_key)
+
+        if self.signal_normalization is not None:
+            mean, std = self.signal_normalization
+            data[..., 0] -= mean
+            data[..., 0] /= std
+        if self.background_normalization is not None:
+            mean, std = self.background_normalization
+            data[..., 1] -= mean
+            data[..., 1] /= std
+
+        return data
+
+
+class CuboidTiffDataset(CuboidThreadedDatasetBase):
+    """
+    Implements `CuboidThreadedDatasetBase` using a `CachedTiffCuboidImageData`,
+    which reads the cuboids from individual tiff files.
+
+    :param points_filenames: A sequence of sequences of the filenames of the
+        cubes. E.g. `[("cube1.1.tiff", "cube1.2.tiff"), ("cube2.1.tiff",
+        "cube2.2.tiff")]`.
+
+        The outer sequence has one item per point/sample. Each inner sequence
+        has one filename per channel (e.g. signal/background) for the given
+        point.
+    :param points_normalization: None or a sequence of sequences of 2-tuples
+        of `(mean, std)`.
+
+        If not None, each 2-tuple corresponds to a single filename in
+        `points_filenames` and that cube will be normalized by the given
+        mean and standard deviation before returning it.
+    :param max_cuboids_buffered: The number of the most recently accessed
+        cuboids to cache so they aren't read from disk again.
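+
+    A construction sketch (filenames are hypothetical; the remaining keyword
+    arguments are those of the base dataset)::
+
+        dataset = CuboidTiffDataset(
+            points=cells,
+            points_filenames=[("cube1.1.tiff", "cube1.2.tiff")],
+            max_cuboids_buffered=16,
+            ...
+        )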
+ """ + + def __init__( + self, + points_filenames: Sequence[Sequence[str]], + points_normalization: ( + Sequence[Sequence[tuple[float, float]]] | None + ) = None, + max_cuboids_buffered: int = 0, + **kwargs, + ): + super().__init__(**kwargs) + if not len(points_filenames): + raise ValueError("No data provided") + if len(points_filenames) != len(self.points_arr): + raise ValueError( + "Points and filenames must have same number of elements" + ) + if points_normalization is not None and len(points_filenames) != len( + points_normalization + ): + raise ValueError("Must have normalizations for all elements") + + self.num_channels = len(points_filenames[0]) + filenames_arr = np.array(points_filenames).astype(np.str_) + self.filenames_arr = filenames_arr + self.points_norm_arr = None + if points_normalization is not None: + self.points_norm_arr = torch.tensor(points_normalization) + + self.src_image_data = CachedTiffCuboidImageData( + points_arr=self.points_arr[:, :3], + filenames_arr=filenames_arr, + data_axis_order=self.axis_order, + max_cuboids_buffered=max_cuboids_buffered, + cuboid_size=self.data_cuboid_voxels, + ) + self._set_output_data_dim_reordering(self.src_image_data) + + def get_points_data(self, points_key: Sequence[int]) -> torch.Tensor: + data = super().get_points_data(points_key) - def __get_oriented_image(self, image_path: Union[str, Path]) -> np.ndarray: - # if paths are pathlib objs, skimage only reads one plane - image = np.moveaxis(imread(image_path), 0, 2) - if self.augment: - image = augment(self.augmentation_parameters, image) - return image + if self.points_norm_arr is not None: + norms = self.points_norm_arr[tuple(points_key), ...] + mean = norms[:, :, 0].unsqueeze(1).unsqueeze(1).unsqueeze(1) + std = norms[:, :, 1].unsqueeze(1).unsqueeze(1).unsqueeze(1) + data -= mean + data /= std + return data -def get_cube_depth_min_max( - centre_plane: int, num_planes_needed_for_cube: int -) -> Tuple[int, int]: - half_cube_depth = num_planes_needed_for_cube // 2 - min_plane = centre_plane - half_cube_depth - if is_even(num_planes_needed_for_cube): - # WARNING: not centered because even - max_plane = centre_plane + half_cube_depth - else: - # centered - max_plane = centre_plane + half_cube_depth + 1 +class CuboidBatchSampler(Sampler): + """ + Custom Sampler for our `CuboidDatasetBase` that helps with sampling planes + of data. + + It is used as:: + + dataset = CuboidStackDataset(...) + sampler = CuboidBatchSampler(dataset=dataset, ...) + data_loader = torch.utils.data.DataLoader( + dataset=dataset, + sampler=sampler, + batch_size=None, + ... + ) + + Or e.g. just:: + + dataset = CuboidStackDataset(...) + sampler = CuboidBatchSampler(dataset=dataset, ...) + for batch in sampler: + data, labels = dataset[batch] + ... + + To get the batch values. + + When used with a `DataLoader`, `CuboidBatchSampler` is doing the + batching, including any shuffling, instead of the `DataLoader` itself. + Our `sampler` must be passed to the `sampler` argument, `batch_size` must + be set to `None`, and `shuffle` shouldn't be used. + + `DataLoader` will return the data according to the order specified by the + sampler. So e.g. `items = list(DataLoader(..., sampler=...))` will + have yielded items in the same index order as the corresponding indices in + `batches = list(sampler)` (assuming the sampler doesn't automatically + reshuffle). So to get the original cells passed to a dataset, associated + with a data item, just index dataset.points[i] for a given i in a batch in + `batches`. 
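+
+    For example, to recover the cell for the first item of the first
+    batch (a sketch)::
+
+        batches = list(sampler)
+        first_cell = dataset.points[batches[0][0]]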
+
+    :param dataset: The `CuboidDatasetBase` that will be sampled.
+    :param batch_size: The size of the batches that `CuboidBatchSampler` will
+        yield to the data loader.
+    :param auto_shuffle: If True, every time we create an iterator of the
+        sampler, the data is reshuffled. I.e. every epoch the data is served
+        in a new order. If `sort_by_axis` is not None, shuffling happens
+        within each plane group, so the sorting is preserved.
+    :param sort_by_axis: If not None, the name of the axis (e.g. `"z"`) by
+        which the points are sorted, so that points that share a plane along
+        that axis end up in the same or adjacent batches.
+    """
+
+    def __init__(
+        self,
+        dataset: CuboidDatasetBase,
+        batch_size: int,
+        auto_shuffle: bool = False,
+        sort_by_axis: str | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size
+        self.auto_shuffle = auto_shuffle
+
+        if sort_by_axis is None:
+            # if not sorted, we just have an array of all the data. When
+            # shuffling later, this array is shuffled
+            plane_indices = [
+                np.arange(len(dataset.points_arr), dtype=np.int64)
+            ]
+        else:
+            # if it is sorted, then sort and segregate into arrays split by
+            # the unique values of the given axis. When shuffling later, each
+            # segregated array is shuffled separately so the sorting is
+            # respected
+            points_raw = defaultdict(list)
+            for i, point in enumerate(dataset.points_arr):
+                p_idx = _point_ax_map[sort_by_axis]
+                plane = point[p_idx].item()
+                points_raw[plane].append(i)
+            points_sorted = sorted(points_raw.items(), key=lambda x: x[0])
+
+            # convert the segregated lists into arrays
+            plane_indices = [
+                np.array(indices, dtype=np.int64)
+                for _, indices in points_sorted
+            ]
+
+        self.plane_indices = plane_indices
+
+        self.n_batches = len(self.get_batches(False))
+
+    def get_batches(self, shuffle):
+        indices = self.plane_indices
+        batch_size = self.batch_size
+
+        if shuffle:
+            rng = np.random.default_rng()
+            # if sorted, each item is an array of the indices of the points
+            # that share a plane value, and the items are ordered by those
+            # values. So we shuffle each item separately and the overall
+            # sorting is preserved.
+            # permuted creates a copy that is shuffled
+            indices = [rng.permuted(items) for items in indices]
+
+        batches = []
+        # whether sorted or not, make it one giant array and create batches.
+        # If sorted, a batch may have multiple plane indices, e.g. if there's
+        # only one point for a given plane, but they are still sorted across
+        # batches. So only the last batch may not be a full batch.
+ arr = np.concatenate(indices) + for i in range(math.ceil(len(arr) / batch_size)): + batches.append(arr[i * batch_size : (i + 1) * batch_size]) + + return batches + + def __len__(self) -> int: + return self.n_batches - return min_plane, max_plane + def __iter__(self): + yield from self.get_batches(self.auto_shuffle) diff --git a/cellfinder/core/classify/tools.py b/cellfinder/core/classify/tools.py index 7ad13546..df22f924 100644 --- a/cellfinder/core/classify/tools.py +++ b/cellfinder/core/classify/tools.py @@ -1,9 +1,7 @@ import os -from collections.abc import Sequence -from typing import List, Optional, Tuple, Union +from typing import Optional import keras -import numpy as np from keras import Model from cellfinder.core import logger @@ -35,7 +33,7 @@ def get_model( """ if existing_model is not None or network_depth is None: logger.debug(f"Loading model: {existing_model}") - return keras.models.load_model(existing_model) + model = keras.models.load_model(existing_model) else: logger.debug(f"Creating a new instance of model: {network_depth}") model = build_model( @@ -59,29 +57,7 @@ def get_model( "Provided weights don't match the model architecture.\n" ) from e - return model - - -def make_lists( - tiff_files: Sequence, - train: bool = True, -) -> Union[Tuple[List, List], Tuple[List, List, np.ndarray]]: - signal_list = [] - background_list = [] - if train: - labels = [] + if inference: + model.trainable = False - for group in tiff_files: - for image in group: - signal_list.append(image.img_files[0]) - background_list.append(image.img_files[1]) - if train: - if image.label == "no_cell": - labels.append(0) - elif image.label == "cell": - labels.append(1) - - if train: - return signal_list, background_list, np.array(labels) - else: - return signal_list, background_list + return model diff --git a/cellfinder/core/main.py b/cellfinder/core/main.py index 7ff3418b..1d6d0ade 100644 --- a/cellfinder/core/main.py +++ b/cellfinder/core/main.py @@ -45,6 +45,9 @@ def main( detect_callback: Optional[Callable[[int], None]] = None, classify_callback: Optional[Callable[[int], None]] = None, detect_finished_callback: Optional[Callable[[list], None]] = None, + normalize_channels: bool = False, + normalization_down_sampling: int = 32, + classification_max_workers: int = 6, ) -> List[Cell]: """ Parameters @@ -156,6 +159,18 @@ def main( Called with the batch number that has just finished. detect_finished_callback : Callable[list], optional Called after detection is finished with the list of detected points. + normalize_channels : bool + If True, the signal and background data will be each normalized + to a mean of zero and standard deviation of 1 before classification. + Defaults to False. + normalization_down_sampling : int + If `normalize_channels` is True, the data arrays will be down-sampled + in the first axis by this value before calculating their statistics + before classification. E.g. a value of 2 means every second plane will + be used. Defaults to 32. + classification_max_workers : int + The max number of sub-processes to use for data loading / processing + during classification. Defaults to 6. 
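+
+    For example (a sketch; the arrays are assumed already loaded)::
+
+        points = main(
+            signal_array,
+            background_array,
+            voxel_sizes,
+            normalize_channels=True,
+        )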
""" from cellfinder.core.classify import classify from cellfinder.core.detect import detect @@ -215,6 +230,9 @@ def main( model_weights, network_depth, callback=classify_callback, + normalize_channels=normalize_channels, + normalization_down_sampling=normalization_down_sampling, + max_workers=classification_max_workers, ) else: logger.info("No candidates, skipping classification") diff --git a/cellfinder/core/tools/image_processing.py b/cellfinder/core/tools/image_processing.py index 742220d6..bf3811ac 100644 --- a/cellfinder/core/tools/image_processing.py +++ b/cellfinder/core/tools/image_processing.py @@ -1,6 +1,10 @@ import numpy as np +import tqdm from brainglobe_utils.general.numerical import is_even +from cellfinder.core import types +from cellfinder.core.tools.tools import get_data_converter + def crop_center_2d(img, crop_x=None, crop_y=None): """ @@ -85,3 +89,62 @@ def pad_center_2d(img, x_size=None, y_size=None, pad_mode="edge"): y_front = y_back = 0 return np.pad(img, ((y_front, y_back), (x_front, x_back)), pad_mode) + + +def dataset_mean_std( + dataset: types.array, + sampling_factor: int, + show_progress: bool = True, + progress_desc="Estimating channel mean/std", +) -> tuple[float, float]: + """ + Calculates the mean and sample standard deviation of a 3d dataset using + Welford's online algorithm, sampling it along its first dimension. + + :param dataset: A 3d dataset, such as a numpy or dask array. + :param sampling_factor: The sampling factor to sample along the first + dimension. E.g. if the dataset is 10 x 100 x 100 and `sampling_factor` + is 3, then we'll use planes 0, 3, 6, 9 for the calculation (40_000 + data points). + :param show_progress: Whether to show a progress bar during the + calculation. + :param progress_desc: If showing a progress bar, the description to use in + it. + :return: A 2-tuple of `(mean, std)` estimate of the dataset. 
+ """ + # based on https://en.wikipedia.org/wiki/ + # Algorithms_for_calculating_variance#Welford's_online_algorithm + # and https://stackoverflow.com/q/56402955 + plane_n = dataset.shape[1] * dataset.shape[2] + # get data converter from dataset to float64 + converter = get_data_converter(dataset.dtype, np.float64) + + count = 0 + mean = np.array(0, dtype=np.float64) + sq_dist = np.array(0, dtype=np.float64) + + # make it a list so tqdm will know its full size + samples = list(range(0, len(dataset), sampling_factor)) + if show_progress: + it = tqdm.tqdm(samples, desc=progress_desc, unit="planes") + else: + it = samples + + for i in it: + plane = converter(dataset[i, ...]) + # flatten it + new_value = plane.reshape((plane_n,)) + + count += plane_n + delta = new_value - mean + mean += np.sum(delta) / count + delta2 = new_value - mean + sq_dist += np.sum(np.multiply(delta, delta2)) + + if count <= 1: + raise ValueError("Not enough data to compute the variance") + + var_sample = sq_dist / (count - 1) + std = np.sqrt(var_sample) + + return mean.item(), std.item() diff --git a/cellfinder/core/tools/threading.py b/cellfinder/core/tools/threading.py index e5d1e97a..e3071511 100644 --- a/cellfinder/core/tools/threading.py +++ b/cellfinder/core/tools/threading.py @@ -323,11 +323,14 @@ class ThreadWithException(ExceptionWithQueueMixIn): def __init__(self, target, args=(), **kwargs): super().__init__(target=target, **kwargs) - self.to_thread_queue = Queue(maxsize=0) - self.from_thread_queue = Queue(maxsize=0) + self._setup_queues() self.args = args self.thread = Thread(target=self.user_func_runner) + def _setup_queues(self): + self.to_thread_queue = Queue(maxsize=0) + self.from_thread_queue = Queue(maxsize=0) + def start(self) -> None: """Starts the thread that runs the target function.""" self.thread.start() @@ -342,6 +345,43 @@ def join(self, timeout: Optional[float] = None) -> None: self.thread.join(timeout=timeout) +class ThreadWithExceptionMPSafe(ThreadWithException): + """ + Similar to `ThreadWithException`, except it is safe to use in objects + that are expected to be shared across multiple sub-process. + + `ThreadWithException` cannot be used if it's an attribute of an object that + will be "forked" or "spawned" (i.e. duplicated) in a sub-process because + some of the objects used are multiprocess incompatible as well as + unserializable and you'll get exceptions. This class fixes it. + + The downside is it uses more resilient queues for communication, which + requires any message content to be serializable as well as passable to a + sub-process. E.g. `ThreadWithExceptionMPSafe` itself or any + multiprocessing queues cannot be passed across multiprocess queues. So it + cannot be sent as a message. Instead, when it's an attribute of an object + Python ensures that all spawned sub-processes that can access that object + (or rather a copy of it) can also access queues or this class via the + attribute. But it cannot be sent in a queue. + + So *only* use this class if you must and only send serializable messages. 
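+
+    A minimal usage sketch (`reader` and its args are hypothetical; messages
+    must stay serializable)::
+
+        thread = ThreadWithExceptionMPSafe(
+            target=reader, args=(src,), pass_self=True
+        )
+        thread.start()
+        thread.send_msg_to_thread(("subject", payload))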
+ """ + + def _setup_queues(self): + # use multiprocess queues so subprocesses can communicate with thread + ctx = mp.get_context("spawn") + self.to_thread_queue = ctx.Queue(maxsize=0) + self.from_thread_queue = ctx.Queue(maxsize=0) + + def __getstate__(self): + items = self.__dict__.copy() + # don't pickle thread or args passed to it because thread and + # potentially args can't be sent to other processes + del items["thread"] + del items["args"] + return items + + class ProcessWithException(ExceptionWithQueueMixIn): """ Runs a target function in a sub-process. diff --git a/cellfinder/core/tools/tiff.py b/cellfinder/core/tools/tiff.py index a2813ce8..b0925e17 100644 --- a/cellfinder/core/tools/tiff.py +++ b/cellfinder/core/tools/tiff.py @@ -5,32 +5,39 @@ from brainglobe_utils.cells.cells import Cell, UntypedCell -class TiffList(object): - """This class represents list of tiff files. These tiff files are the - output from the cell extractor plugin - and are used as training and classification data. +class TiffList: + """ + Represents a list of tiff files output from the cell extractor plugin and + used for training and classification. + + Represents tiff files whose names end with `Ch[ch].tif`, where `[ch]` is + the non-zero-padded channel number, starting with channel 1. Given a list + of tiff files for the first channel (`...Ch1.tif`), it will find the other + corresponding tif files for the channels passed in the `channels` + parameter. + + :param ch1_list: List of the paths to the tiff files of the first channel. + :param channels: List of channels numbers for which we have tif files + (zero or one based). + :param label: Label of all these tiffs (e.g. `"cell"`, `"no_cell"`). """ - def __init__(self, ch1_list, channels, label=None): - """A list of tiff files output by the cell extractor plugin to be - used in machine learning. - Expects file names to end with "ch[ch].tif", where [ch] is the - non-zero-padded channel index. - Given a list of tiff files for the first channel, it will find the - corresponding files for the - channels passed in the [channels] parameter. - :param ch1_list: List of the tiff files of the first channel. - :param channels: List of channels to use. - :param label: Label of the directory (e.g. 2 for cell, 1 for no cell). - Can be ignored on classification runs. - """ - + def __init__( + self, + ch1_list: list[str], + channels: list[int], + channels_metadata: list[dict], + label: str | None = None, + ): self.ch1_list = natsort.natsorted(ch1_list) self.label = label self.channels = channels + self.channels_metadata = channels_metadata - def make_tifffile_list(self): + def make_tifffile_list(self) -> list["TiffFile"]: """ + Splits the list of tiffs represented by this instance into a list of + `TiffFile` instances. :return: Returns the relevant tiff files as a list of TiffFile objects. """ @@ -41,17 +48,26 @@ def make_tifffile_list(self): ] tiff_files = [ - TiffFile(tiffFile, self.channels, self.label) for tiffFile in files + TiffFile( + tiffFile, self.channels, self.channels_metadata, self.label + ) + for tiffFile in files ] return tiff_files class TiffDir(TiffList): - """A simplified version of TiffList that uses all tiff files without - any filtering. + """Like `TiffList` except it takes a directory (`tiff_dir`) and gets all + the tiffs in that directory that match the `Ch[ch].tif` pattern. 
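+
+    E.g. (hypothetical directory)::
+
+        tiffs = TiffDir("/data/cubes", [1, 2], [{}, {}], label="cell")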
""" - def __init__(self, tiff_dir, channels, label=None): + def __init__( + self, + tiff_dir: str, + channels: list[int], + channels_metadata: list[dict], + label: str | None = None, + ): super(TiffDir, self).__init__( [ join(tiff_dir, f) @@ -59,37 +75,71 @@ def __init__(self, tiff_dir, channels, label=None): if f.lower().endswith("ch" + str(channels[0]) + ".tif") ], channels, + channels_metadata, label, ) -class TiffFile(object): +class TiffFile: """This class represents a multichannel tiff file, with one individual file per channel. + + :param path: The full path to the first (zero or one based) channel's tiff + file. + :param channels: List of channels numbers for which we have tiff files + next to the first channel. It must match the `Ch[ch].tif` pattern. + :param label: Label of the tiffs (e.g. `"cell"`, `"no_cell"`). """ - def __init__(self, path, channels, label=None): + def __init__( + self, + path: str, + channels: list[int], + channels_metadata: list[dict], + label: str | None = None, + ): self.path = path self.channels = channels + self.channels_metadata = channels_metadata self.label = label def files_exist(self): + """ + Returns whether the tiffs actually exist on disk for all the channels + represented by this instance. + """ return all([isfile(tif) for tif in self.img_files]) - def as_cell(self, force_typed=True): + def as_cell(self, force_typed=True) -> Cell | UntypedCell: + """ + Returns a `Cell` instance that represents the (potential) cell for + whom the tiff files was saved. + + :param force_typed: If True returns a `Cell`. If False, it returns a + `UntypedCell` if `self.label` is None and `Cell` otherwise. + :return: + """ if force_typed: - return ( - Cell(self.path, -1) - if self.label is None - else Cell(self.path, self.label) - ) - else: - return ( - UntypedCell(self.path) - if self.label is None - else Cell(self.path, self.label) - ) + match self.label: + case None: + cell_type = Cell.ARTIFACT + case "cell": + cell_type = Cell.CELL + case "no_cell": + cell_type = Cell.NO_CELL + case _: + raise ValueError(f"Unknown cell type {self.label}") + + return Cell(self.path, cell_type) + + if self.label is None: + return UntypedCell(self.path) + return Cell(self.path, self.label) @property - def img_files(self): + def img_files(self) -> list[str]: + """ + Returns a list of the full filenames to the tiffs for all channels + represented by this instance. 
+ """ return [self.path[:-5] + str(ch) + ".tif" for ch in self.channels] diff --git a/cellfinder/core/tools/tools.py b/cellfinder/core/tools/tools.py index 231515de..53832dcb 100644 --- a/cellfinder/core/tools/tools.py +++ b/cellfinder/core/tools/tools.py @@ -1,6 +1,6 @@ -from functools import wraps +from functools import partial, wraps from random import getrandbits, uniform -from typing import Callable, Optional, Type +from typing import Any, Callable, Optional, Sequence, Type import numpy as np import torch @@ -56,6 +56,34 @@ def get_min_possible_int_value(dtype: Type[np.number]) -> int: raise ValueError("datatype must be of integer or floating data type") +def _unchanged(data: np.ndarray) -> np.ndarray: + return np.asarray(data) + + +def _float_to_float_scale_down( + data: np.ndarray, + in_abs_max: float, + out_abs_max: float, + dest_dtype: np.dtype, +) -> np.ndarray: + return ((np.asarray(data) / in_abs_max) * out_abs_max).astype(dest_dtype) + + +def _int_to_float_scale_down( + data: np.ndarray, + in_abs_max: float, + out_abs_max: float, + dest_dtype: np.dtype, +) -> np.ndarray: + # data must fit in float64 + data = np.asarray(data).astype(np.float64) + return ((data / in_abs_max) * out_abs_max).astype(dest_dtype) + + +def _to_float_unscaled(data: np.ndarray, dest_dtype: np.dtype) -> np.ndarray: + return np.asarray(data).astype(dest_dtype) + + def get_data_converter( src_dtype: Type[np.number], dest_dtype: Type[np.floating] ) -> Callable[[np.ndarray], np.ndarray]: @@ -82,6 +110,8 @@ def get_data_converter( A function that takes a single input data parameter and returns the converted data. """ + # converter functions must be global functions so they can be serialized + # and passed to other processes if not np.issubdtype(dest_dtype, np.float32) and not np.issubdtype( dest_dtype, np.float64 ): @@ -97,28 +127,12 @@ def get_data_converter( in_abs_max = max(in_max, abs(in_min)) out_abs_max = max(out_max, abs(out_min)) - def unchanged(data: np.ndarray) -> np.ndarray: - return np.asarray(data) - - def float_to_float_scale_down(data: np.ndarray) -> np.ndarray: - return ((np.asarray(data) / in_abs_max) * out_abs_max).astype( - dest_dtype - ) - - def int_to_float_scale_down(data: np.ndarray) -> np.ndarray: - # data must fit in float64 - data = np.asarray(data).astype(np.float64) - return ((data / in_abs_max) * out_abs_max).astype(dest_dtype) - - def to_float_unscaled(data: np.ndarray) -> np.ndarray: - return np.asarray(data).astype(dest_dtype) - if src_dtype == dest_dtype: - return unchanged + return _unchanged # out can hold the largest in values - just convert to float if out_min <= in_min < in_max <= out_max: - return to_float_unscaled + return partial(_to_float_unscaled, dest_dtype=dest_dtype) # need to scale down before converting to float if np.issubdtype(src_dtype, np.integer): @@ -131,11 +145,21 @@ def to_float_unscaled(data: np.ndarray) -> np.ndarray: f"The input datatype {src_dtype} cannot fit in a " f"64-bit float" ) - return int_to_float_scale_down + return partial( + _int_to_float_scale_down, + in_abs_max=in_abs_max, + out_abs_max=out_abs_max, + dest_dtype=dest_dtype, + ) # for float input, however big it is, we can always scale it down in the # input data type before changing type - return float_to_float_scale_down + return partial( + _float_to_float_scale_down, + in_abs_max=in_abs_max, + out_abs_max=out_abs_max, + dest_dtype=dest_dtype, + ) def union(a, b): @@ -293,3 +317,27 @@ def all_elements_equal(x) -> bool: :return: True if all elements are equal, False otherwise. 
""" return len(set(x)) <= 1 + + +def get_axis_reordering( + in_order: Sequence[Any], out_order: Sequence[Any] +) -> list[int]: + """ + Helps re-order a tensor, given an arbitrary labeled input and output axis + ordering. + + E.g. if the original ordering of 3d data was (a, b, c) and we want + (b, a, c):: + + reordering = get_axis_reordering(("a", "b", "c"), ("b", "a", "c")) + reordered = torch.permute(data, reordering) + + :param in_order: A sequence of named axes. They could be arbitrary values. + :param out_order: A re-ordered sequence of in_order with the desired order. + :return: A list of indices that can be passed to `torch.permute` to reorder + the data tensor. + """ + indices = [] + for value in out_order: + indices.append(in_order.index(value)) + return indices diff --git a/cellfinder/core/train/train_yaml.py b/cellfinder/core/train/train_yaml.py index 4c173a74..e6d5f35a 100644 --- a/cellfinder/core/train/train_yaml.py +++ b/cellfinder/core/train/train_yaml.py @@ -12,9 +12,11 @@ ArgumentTypeError, ) from datetime import datetime +from functools import partial from pathlib import Path -from typing import Dict, Literal +from typing import Dict, Literal, Sequence +from brainglobe_utils.cells.cells import Cell from brainglobe_utils.general.numerical import ( check_positive_float, check_positive_int, @@ -26,16 +28,26 @@ from brainglobe_utils.IO.cells import find_relevant_tiffs from brainglobe_utils.IO.yaml import read_yaml_section from fancylog import fancylog -from keras.callbacks import CSVLogger, ModelCheckpoint, TensorBoard +from keras.callbacks import ( + CSVLogger, + LearningRateScheduler, + ModelCheckpoint, + TensorBoard, +) from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader import cellfinder.core as program_for_log from cellfinder.core import logger -from cellfinder.core.classify.cube_generator import CubeGeneratorFromDisk +from cellfinder.core.classify.cube_generator import ( + CuboidBatchSampler, + CuboidTiffDataset, +) from cellfinder.core.classify.resnet import layer_type -from cellfinder.core.classify.tools import get_model, make_lists +from cellfinder.core.classify.tools import get_model from cellfinder.core.download.download import DEFAULT_DOWNLOAD_DIRECTORY from cellfinder.core.tools.prep import prep_model_weights +from cellfinder.core.tools.tiff import TiffDir, TiffFile, TiffList depth_type = Literal["18", "34", "50", "101", "152"] @@ -48,6 +60,22 @@ } +CUBE_WIDTH = 50 +CUBE_HEIGHT = 50 +CUBE_DEPTH = 20 + + +def lr_scheduler( + epoch: int, + lr: float, + multiplier: float, + epoch_list: Sequence[int], +) -> float: + if epoch in epoch_list: + return lr * multiplier + return lr + + def valid_model_depth(depth): """ Ensures a correct existing_model is chosen @@ -175,6 +203,13 @@ def training_parse(): default=100, help="Number of training epochs", ) + training_parser.add_argument( + "--max-workers", + dest="max_workers", + type=check_positive_int, + default=3, + help="Maximum number of worker processes to use to load data", + ) training_parser.add_argument( "--test-fraction", dest="test_fraction", @@ -195,6 +230,15 @@ def training_parse(): action="store_true", help="Don't apply data augmentation", ) + training_parser.add_argument( + "--augment-likelihood", + dest="augment_likelihood", + type=check_positive_float, + default=0.9, + help="Value `[0, 1]` with the probability of a data item being " + "augmented. I.e. 
`0.9` means 90%% of the data will have been "
+        "augmented.",
+    )
     training_parser.add_argument(
         "--save-weights",
         dest="save_weights",
@@ -219,6 +263,33 @@
         action="store_true",
         help="Save training progress to a .csv file",
     )
+    training_parser.add_argument(
+        "--normalize-channels",
+        dest="normalize_channels",
+        action="store_true",
+        help="Normalize the training data to the mean/std of the datasets "
+        "the cubes came from",
+    )
+    training_parser.add_argument(
+        "--lr-schedule",
+        dest="lr_schedule",
+        nargs="*",
+        type=partial(check_positive_int, none_allowed=False),
+        default=(),
+        help="If not empty, the list of epochs at which to multiply the "
+        "current learning rate by the lr_multiplier. E.g. if it's [10, 25], "
+        "we start with a learning rate of 0.001, and lr_multiplier is "
+        "0.1, then the LR will be 0.001 for epochs 0-9, 0.0001 for 10-24,"
+        " and 0.00001 for epoch 25 and beyond.",
+    )
+    training_parser.add_argument(
+        "--lr-multiplier",
+        dest="lr_multiplier",
+        type=partial(check_positive_float, none_allowed=False),
+        default=0.1,
+        help="The multiplier by which to multiply the previous learning rate "
+        "at the epochs listed in lr_schedule.",
+    )
 
     training_parser = misc_parse(training_parser)
     training_parser = download_parser(training_parser)
@@ -234,15 +305,36 @@
         return data
 
 
-def get_tiff_files(yaml_contents):
-    from cellfinder.core.tools.tiff import TiffDir, TiffList
-
+def get_tiff_files(yaml_contents: list[dict]) -> list[list[TiffFile]]:
+    """
+    Takes the parsed yaml contents describing multiple folders, each
+    containing many extracted cube tiff files. It returns a corresponding
+    list of lists of `TiffFile`, where each `TiffFile` in a sub-list
+    represents a tiff in the given directory.
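+
+    Each yaml entry is expected to provide at least the `cube_dir`, `type`,
+    `signal_channel`, and `bg_channel` keys, and optionally `cell_def` and
+    the `signal_mean` / `signal_std` (and `bg_mean` / `bg_std`) statistics.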
+ """ tiff_lists = [] for d in yaml_contents: if d["bg_channel"] < 0: channels = [d["signal_channel"]] + channels_metadata = [ + {}, + ] else: channels = [d["signal_channel"], d["bg_channel"]] + channels_metadata = [{}, {}] + + if "signal_mean" in d: + channels_metadata[0] = { + "mean": float(d["signal_mean"]), + "std": float(d["signal_std"]), + } + # if we have norm for signal we must have for background + if "signal_mean" in d and d["bg_channel"] >= 0: + channels_metadata[1] = { + "mean": float(d["bg_mean"]), + "std": float(d["bg_std"]), + } + if "cell_def" in d and d["cell_def"]: ch1_tiffs = [ os.path.join(d["cube_dir"], f) @@ -253,16 +345,34 @@ def get_tiff_files(yaml_contents): TiffList( find_relevant_tiffs(ch1_tiffs, d["cell_def"]), channels, + channels_metadata, d["type"], ) ) else: - tiff_lists.append(TiffDir(d["cube_dir"], channels, d["type"])) + tiff_lists.append( + TiffDir(d["cube_dir"], channels, channels_metadata, d["type"]) + ) tiff_files = [tiff_dir.make_tifffile_list() for tiff_dir in tiff_lists] return tiff_files +def make_tiff_lists( + tiff_files: list[list[TiffFile]], +) -> tuple[list[tuple[list[str], list[dict]]], list[Cell]]: + + cells = [] + filenames = [] + + for group in tiff_files: + for image in group: + filenames.append((image.img_files, image.channels_metadata)) + cells.append(image.as_cell()) + + return filenames, cells + + def cli(): args = training_parse() ensure_directory_exists(args.output_dir) @@ -288,15 +398,73 @@ def cli(): continue_training=args.continue_training, test_fraction=args.test_fraction, batch_size=args.batch_size, + max_workers=args.max_workers, no_augment=args.no_augment, + augment_likelihood=args.augment_likelihood, tensorboard=args.tensorboard, save_weights=args.save_weights, no_save_checkpoints=args.no_save_checkpoints, save_progress=args.save_progress, epochs=args.epochs, + normalize_channels=args.normalize_channels, + lr_schedule=args.lr_schedule, + lr_multiplier=args.lr_multiplier, ) +def get_dataloader( + cells: list[Cell], + filenames: list[tuple[list[str], list[dict]]], + batch_size: int, + n_processes: int, + pin_memory: bool, + auto_shuffle: bool, + augment: bool, + augment_likelihood: float, + normalize_channels: bool, +) -> tuple[DataLoader, CuboidTiffDataset]: + points_filenames = [f[0] for f in filenames] + + points_norm = None + if normalize_channels: + points_norm = [] + for names, channels_norm in filenames: + # check the first channel for metadata. 
We expect all or none + # of the channels to have metadata + if not channels_norm[0]: + raise ValueError(f"Data mean and std not found for {names}") + + norms = [(ch["mean"], ch["std"]) for ch in channels_norm] + points_norm.append(norms) + + dataset = CuboidTiffDataset( + points=cells, + points_filenames=points_filenames, + points_normalization=points_norm, + data_voxel_sizes=(1, 1, 1), + network_voxel_sizes=(1, 1, 1), + network_cuboid_voxels=(CUBE_DEPTH, CUBE_HEIGHT, CUBE_WIDTH), + axis_order=("z", "y", "x"), + target_output="label", + augment=augment, + augment_likelihood=augment_likelihood, + ) + # we use our own sampler so we can control the ordering + sampler = CuboidBatchSampler( + dataset=dataset, + batch_size=batch_size, + auto_shuffle=auto_shuffle, + ) + data_loader = DataLoader( + dataset=dataset, + sampler=sampler, + batch_size=None, + num_workers=n_processes, + pin_memory=pin_memory, + ) + return data_loader, dataset + + def run( output_dir, yaml_file, @@ -316,6 +484,12 @@ def run( no_save_checkpoints=False, save_progress=False, epochs=100, + max_workers: int = 3, + pin_memory: bool = True, + normalize_channels: bool = False, + lr_schedule: Sequence[int] = (), + lr_multiplier: float = 0.1, + augment_likelihood: float = 0.9, ): start_time = datetime.now() @@ -343,37 +517,37 @@ def run( continue_training=continue_training, ) - signal_train, background_train, labels_train = make_lists(tiff_files) + filenames_train, cells_train = make_tiff_lists(tiff_files) n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus) + n_processes = min(n_processes, max_workers) if test_fraction > 0: logger.info("Splitting data into training and validation datasets") ( - signal_train, - signal_test, - background_train, - background_test, - labels_train, - labels_test, + filenames_train, + filenames_test, + cells_train, + cells_test, ) = train_test_split( - signal_train, - background_train, - labels_train, + filenames_train, + cells_train, test_size=test_fraction, ) logger.info( - f"Using {len(signal_train)} images for training and " - f"{len(signal_test)} images for validation" + f"Using {len(filenames_train)} images for training and " + f"{len(filenames_test)} images for validation" ) - validation_generator = CubeGeneratorFromDisk( - signal_test, - background_test, - labels=labels_test, - batch_size=batch_size, - train=True, - use_multiprocessing=False, - workers=n_processes, + validation_data_loader, validation_dataset = get_dataloader( + cells_test, + filenames_test, + batch_size, + n_processes, + pin_memory, + auto_shuffle=False, + augment=False, + augment_likelihood=augment_likelihood, + normalize_channels=normalize_channels, ) # for saving checkpoints @@ -381,19 +555,20 @@ def run( else: logger.info("No validation data selected.") - validation_generator = None + validation_data_loader = None + validation_dataset = None base_checkpoint_file_name = "-epoch.{epoch:02d}" - training_generator = CubeGeneratorFromDisk( - signal_train, - background_train, - labels=labels_train, - batch_size=batch_size, - shuffle=True, - train=True, + training_data_loader, training_dataset = get_dataloader( + cells_train, + filenames_train, + batch_size, + n_processes, + pin_memory, + auto_shuffle=True, augment=not no_augment, - use_multiprocessing=False, - workers=n_processes, + augment_likelihood=augment_likelihood, + normalize_channels=normalize_channels, ) callbacks = [] @@ -430,15 +605,33 @@ def run( csv_logger = CSVLogger(csv_filepath) callbacks.append(csv_logger) + if lr_schedule: + # we need to drop the lr 
by a given schedule. This is called at the + # start of each epoch and is zero based. E.g. if epoch 10 is listed, + # it'll drop at the start of the 11th epoch. + lr_callback = partial( + lr_scheduler, multiplier=lr_multiplier, epoch_list=lr_schedule + ) + callbacks.append(LearningRateScheduler(lr_callback)) + logger.info("Beginning training.") - # Keras 3.0: `use_multiprocessing` input is set in the - # `training_generator` (False by default) - model.fit( - training_generator, - validation_data=validation_generator, - epochs=epochs, - callbacks=callbacks, - ) + if n_processes: + training_dataset.start_dataset_thread(n_processes) + if validation_dataset is not None: + validation_dataset.start_dataset_thread(n_processes) + try: + model.fit( + x=training_data_loader, + validation_data=validation_data_loader, + epochs=epochs, + callbacks=callbacks, + ) + finally: + try: + training_dataset.stop_dataset_thread() + finally: + if validation_dataset is not None: + validation_dataset.stop_dataset_thread() if save_weights: logger.info("Saving model weights") diff --git a/cellfinder/napari/curation.py b/cellfinder/napari/curation.py index 2dfc7d92..76b72424 100644 --- a/cellfinder/napari/curation.py +++ b/cellfinder/napari/curation.py @@ -1,3 +1,4 @@ +from functools import partial from pathlib import Path from typing import List, Optional, Tuple @@ -12,7 +13,12 @@ from napari.qt.threading import thread_worker from napari.utils.notifications import show_info from qt_niu.dialog import display_warning -from qt_niu.interaction import add_button, add_combobox +from qt_niu.interaction import ( + add_button, + add_combobox, + add_float_box, + add_int_box, +) from qtpy import QtCore from qtpy.QtWidgets import ( QComboBox, @@ -23,6 +29,12 @@ QWidget, ) +from cellfinder.core.classify.cube_generator import ( + CuboidBatchSampler, + CuboidStackDataset, +) +from cellfinder.core.tools.image_processing import dataset_mean_std + # Constants used throughout WINDOW_HEIGHT = 750 WINDOW_WIDTH = 1500 @@ -30,6 +42,10 @@ class CurationWidget(QWidget): + """ + Voxel size parameters are in z, y, x order. 
+ """ + def __init__( self, viewer: napari.viewer.Viewer, @@ -43,8 +59,8 @@ def __init__( ): super(CurationWidget, self).__init__() - self.non_cells_to_extract = None - self.cells_to_extract = None + self.non_cells_to_extract: list[Cell] = [] + self.cells_to_extract: list[Cell] = [] self.cube_depth = cube_depth self.cube_width = cube_width @@ -54,6 +70,7 @@ def __init__( self.save_empty_cubes = save_empty_cubes self.max_ram = max_ram self.voxel_sizes = [5, 2, 2] + self.normalization_down_sampling = 32 self.batch_size = 64 self.viewer = viewer @@ -159,6 +176,12 @@ def setup_main_layout(self): self.setLayout(self.layout) + def _set_voxel_size(self, value: float, index: int) -> None: + self.voxel_sizes[index] = value + + def _set_normalization_down_sampling(self, value: int) -> None: + self.normalization_down_sampling = value + def add_loading_panel(self, row: int, column: int = 0): self.load_data_panel = QGroupBox("Load data") self.load_data_layout = QGridLayout() @@ -180,32 +203,81 @@ def add_loading_panel(self, row: int, column: int = 0): 2, callback=self.set_background_image, ) + box_z = add_float_box( + self.load_data_layout, + self.voxel_sizes[0], + 0, + 1000, + "Voxel size (z)", + 0.01, + tooltip="Size of your voxels in the axial dimension (microns)", + row=3, + ) + box_z.valueChanged.connect(partial(self._set_voxel_size, index=0)) + box_y = add_float_box( + self.load_data_layout, + self.voxel_sizes[1], + 0, + 1000, + "Voxel size (y)", + 0.01, + tooltip="Size of your voxels in the y direction " + "(top to bottom) (microns)", + row=4, + ) + box_y.valueChanged.connect(partial(self._set_voxel_size, index=1)) + box_x = add_float_box( + self.load_data_layout, + self.voxel_sizes[2], + 0, + 1000, + "Voxel size (x)", + 0.01, + tooltip="Size of your voxels in the x direction " + "(left to right) (microns)", + row=5, + ) + box_x.valueChanged.connect(partial(self._set_voxel_size, index=2)) + self.voxel_sizes_boxes = box_z, box_y, box_x + box_norm = add_int_box( + self.load_data_layout, + self.normalization_down_sampling, + 1, + 1000, + "Normalization down-sampling", + 6, + tooltip="Down-sampling factor of the z-dimension used to calculate" + " the mean and std of the dataset. Used to normalize the " + "channels during training.", + ) + box_norm.valueChanged.connect(self._set_normalization_down_sampling) + self.norm_sampling_box = box_norm self.training_data_cell_choice, _ = add_combobox( self.load_data_layout, "Training data (cells)", self.point_layer_names, - 3, + 7, callback=self.set_training_data_cell, ) self.training_data_non_cell_choice, _ = add_combobox( self.load_data_layout, "Training_data (non_cells)", self.point_layer_names, - row=4, + row=8, callback=self.set_training_data_non_cell, ) self.mark_as_cell_button = add_button( "Mark as cell(s)", self.load_data_layout, self.mark_as_cell, - row=5, + row=9, tooltip="Mark all selected points as non cell. Shortcut: 'c'", ) self.mark_as_non_cell_button = add_button( "Mark as non cell(s)", self.load_data_layout, self.mark_as_non_cell, - row=5, + row=9, column=1, tooltip="Mark all selected points as non cell. 
Shortcut: 'x'", ) @@ -213,13 +285,13 @@ def add_loading_panel(self, row: int, column: int = 0): "Add training data layers", self.load_data_layout, self.add_training_data, - row=6, + row=10, ) self.save_training_data_button = add_button( "Save training data", self.load_data_layout, self.save_training_data, - row=6, + row=10, column=1, ) self.load_data_layout.setColumnMinimumWidth(0, COLUMN_WIDTH) @@ -544,7 +616,17 @@ def convert_layers_to_cells(self): self.cells_to_extract = list(set(self.cells_to_extract)) self.non_cells_to_extract = list(set(self.non_cells_to_extract)) + def _calculate_channel_stats(self): + signal_stat = dataset_mean_std( + self.signal_layer.data, self.normalization_down_sampling + ) + bg_stat = dataset_mean_std( + self.background_layer.data, self.normalization_down_sampling + ) + return signal_stat, bg_stat + def __save_yaml_file(self): + signal_stat, bg_stat = self._calculate_channel_stats() yaml_section = [ { "cube_dir": str(self.cell_cube_dir), @@ -552,6 +634,10 @@ def __save_yaml_file(self): "type": "cell", "signal_channel": 0, "bg_channel": 1, + "signal_mean": signal_stat[0], + "signal_std": signal_stat[1], + "bg_mean": bg_stat[0], + "bg_std": bg_stat[1], }, { "cube_dir": str(self.no_cell_cube_dir), @@ -559,6 +645,10 @@ def __save_yaml_file(self): "type": "no_cell", "signal_channel": 0, "bg_channel": 1, + "signal_mean": signal_stat[0], + "signal_std": signal_stat[1], + "bg_mean": bg_stat[0], + "bg_std": bg_stat[1], }, ] @@ -583,9 +673,6 @@ def extract_cubes(self): Attributes used to update a progress bar. The keys can be any of the properties of `magicgui.widgets.ProgressBar`. """ - from cellfinder.core.classify.cube_generator import ( - CubeGeneratorFromFile, - ) to_extract = { "cells": self.cells_to_extract, @@ -603,40 +690,56 @@ def extract_cubes(self): self.update_status_label(f"Saving {cell_type}...") - cube_generator = CubeGeneratorFromFile( - cell_list, - self.signal_layer.data, - self.background_layer.data, - self.voxel_sizes, - self.network_voxel_sizes, - batch_size=self.batch_size, - cube_width=self.cube_width, - cube_height=self.cube_height, - cube_depth=self.cube_depth, - extract=True, + cube_generator = CuboidStackDataset( + signal_array=self.signal_layer.data, + background_array=self.background_layer.data, + points=cell_list, + data_voxel_sizes=self.voxel_sizes, + network_voxel_sizes=self.network_voxel_sizes, + network_cuboid_voxels=( + self.cube_depth, + self.cube_height, + self.cube_width, + ), + axis_order=("z", "y", "x"), + output_axis_order=("z", "y", "x", "c"), + max_axis_0_cuboids_buffered=1, + target_output="index", + ) + # use sampler and data loader so we can use the z sorting for + # better caching. Potentially also for multiple workers. 
+ sampler = CuboidBatchSampler( + dataset=cube_generator, + batch_size=1, + sort_by_axis="z", ) # Set up progress bar yield { "value": 0, "min": 0, - "max": len(cube_generator), + "max": len(sampler), } - for i, (image_batch, batch_info) in enumerate(cube_generator): - image_batch = image_batch.astype(np.int16) - - for point, point_info in zip(image_batch, batch_info): - point = np.moveaxis(point, 2, 0) - for channel in range(point.shape[-1]): + i = 0 + for batch in sampler: + # manually sample the dataset using the sampler so it's sampled + # ordered by z, for efficiency + images, points_index = cube_generator[batch] + for image, point_index in zip(images, points_index): + image = image.numpy() + # the output is the index of the input points + cell = cell_list[int(point_index.item())] + for channel in range(image.shape[-1]): save_cube( - point, - point_info, + image, + cell.to_dict(), channel, cell_type_output_directory, ) - # Update progress bar - yield {"value": i + 1} + # Update progress bar + yield {"value": i + 1} + i += 1 self.update_status_label("Finished saving cubes") diff --git a/cellfinder/napari/detect/detect.py b/cellfinder/napari/detect/detect.py index 2503d412..81988fc3 100644 --- a/cellfinder/napari/detect/detect.py +++ b/cellfinder/napari/detect/detect.py @@ -1,5 +1,4 @@ from functools import partial -from math import ceil from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple @@ -11,7 +10,10 @@ from napari.utils.notifications import show_info from qtpy.QtWidgets import QScrollArea -from cellfinder.core.classify.cube_generator import get_cube_depth_min_max +from cellfinder.core.classify.cube_generator import ( + get_data_cuboid_range, + get_data_cuboid_voxels, +) from cellfinder.napari.utils import ( add_classified_layers, add_single_layer, @@ -30,7 +32,7 @@ NETWORK_VOXEL_SIZES = [5, 1, 1] CUBE_WIDTH = 50 -CUBE_HEIGHT = 20 +CUBE_HEIGHT = 50 CUBE_DEPTH = 20 # If using ROI, how many extra planes to analyse @@ -188,12 +190,12 @@ def find_local_planes( current_plane = viewer.dims.current_step[0] # so a reasonable number of cells in the plane are detected - planes_needed = MIN_PLANES_ANALYSE + int( - ceil((CUBE_DEPTH * NETWORK_VOXEL_SIZES[0]) / voxel_size_z) + planes_needed = MIN_PLANES_ANALYSE + get_data_cuboid_voxels( + CUBE_DEPTH, NETWORK_VOXEL_SIZES[0], voxel_size_z ) - start_plane, end_plane = get_cube_depth_min_max( - current_plane, planes_needed + start_plane, end_plane = get_data_cuboid_range( + current_plane, planes_needed, "z" ) start_plane = max(0, start_plane) end_plane = min(len(signal_image.data), end_plane) @@ -257,6 +259,8 @@ def widget( use_pre_trained_weights: bool, trained_model: Optional[Path], classification_batch_size: int, + normalize_channels: bool, + normalization_down_sampling: int, misc_options, start_plane: int, end_plane: int, @@ -326,6 +330,15 @@ def widget( the models. For performance-critical applications, tune to maximize memory usage without running out. Check your GPU/CPU memory to verify it's not full + normalize_channels : bool + For classification only - whether to normalize the cubes to the + mean/std of the image channels before classification. If the model + used for classification was trained on normalized data, this should + be enabled. + normalization_down_sampling : int + If normalizing the cubes is enabled, the input channels will be + down-sampled in z by this value before calculating their mean/std. + E.g. a value of 2 means every second z plane will be used. 
    start_plane : int
        First plane to process (to process a subset of the data)
    end_plane : int
@@ -401,6 +414,8 @@ def widget(
         use_pre_trained_weights,
         trained_model,
         classification_batch_size,
+        normalize_channels,
+        normalization_down_sampling,
     )

     if analyse_local:
diff --git a/cellfinder/napari/detect/detect_containers.py b/cellfinder/napari/detect/detect_containers.py
index 5a130853..e57bba20 100644
--- a/cellfinder/napari/detect/detect_containers.py
+++ b/cellfinder/napari/detect/detect_containers.py
@@ -120,6 +120,8 @@ class ClassificationInputs(InputContainer):
     use_pre_trained_weights: bool = True
     trained_model: Optional[Path] = Path.home()
     classification_batch_size: int = 64
+    normalize_channels: bool = False
+    normalization_down_sampling: int = 32

     def as_core_arguments(self) -> dict:
         args = super().as_core_arguments()
@@ -141,6 +143,14 @@ def widget_representation(cls) -> dict:
             value=cls.defaults()["classification_batch_size"],
             label="Batch size (classification)",
         ),
+        normalize_channels=dict(
+            value=cls.defaults()["normalize_channels"],
+            label="Normalize data",
+        ),
+        normalization_down_sampling=dict(
+            value=cls.defaults()["normalization_down_sampling"],
+            label="Normalization down-sampling",
+        ),
     )

diff --git a/cellfinder/napari/train/train.py b/cellfinder/napari/train/train.py
index 79d92b6b..db74af5d 100644
--- a/cellfinder/napari/train/train.py
+++ b/cellfinder/napari/train/train.py
@@ -59,11 +59,14 @@ def widget(
     training_options: dict,
     continue_training: bool,
     augment: bool,
+    normalize_channels: bool,
     tensorboard: bool,
     save_checkpoints: bool,
     save_progress: bool,
     epochs: int,
     learning_rate: float,
+    lr_schedule: list[int],
+    lr_multiplier: float,
     batch_size: int,
     test_fraction: float,
     misc_options: dict,
@@ -93,6 +96,10 @@ def widget(
         this will continue from the pretrained model
     augment : bool
         Augment the training data to improve generalisation
+    normalize_channels : bool
+        Whether to normalize the cubes by the mean/std of their origin
+        dataset. If True, the yaml files must include the mean/std of
+        the origin dataset.
     tensorboard : bool
         Log to output_directory/tensorboard
     save_checkpoints : bool
@@ -104,6 +111,15 @@ def widget(
         (How many times to use each training data point)
     learning_rate : float
         Learning rate for training the model
+    lr_schedule : list of ints
+        If not empty, the epochs at which to multiply the current learning
+        rate by `lr_multiplier`. E.g. with a schedule of [10, 25], a
+        starting learning rate of 0.001, and an `lr_multiplier` of 0.1,
+        the LR will be 0.001 for epochs 0-9, 0.0001 for epochs 10-24, and
+        0.00001 for epoch 25 and beyond.
+    lr_multiplier : float
+        The multiplier by which to multiply the previous learning rate
+        at the epochs listed in `lr_schedule`.
     batch_size : int
         Training batch size
     test_fraction : float
@@ -135,6 +151,9 @@ def widget(
         learning_rate,
         batch_size,
         test_fraction,
+        normalize_channels,
+        lr_schedule,
+        lr_multiplier,
     )

     misc_training_inputs = MiscTrainingInputs(number_of_free_cpus)
diff --git a/cellfinder/napari/train/train_containers.py b/cellfinder/napari/train/train_containers.py
index c77ece05..01aa5897 100644
--- a/cellfinder/napari/train/train_containers.py
+++ b/cellfinder/napari/train/train_containers.py
@@ -81,6 +81,9 @@ class OptionalTrainingInputs(InputContainer):
     learning_rate: float = 1e-4
     batch_size: int = 16
     test_fraction: float = 0.1
+    normalize_channels: bool = False
+    lr_schedule: list[int] | tuple[int, ...] 
= () + lr_multiplier: float = 0.1 def as_core_arguments(self) -> dict: arguments = super().as_core_arguments() @@ -105,6 +108,13 @@ def widget_representation(cls) -> dict: test_fraction=cls._custom_widget( "test_fraction", step=0.05, min=0.05, max=0.95 ), + normalize_channels=cls._custom_widget("normalize_channels"), + lr_schedule=cls._custom_widget( + "lr_schedule", custom_label="LR schedule" + ), + lr_multiplier=cls._custom_widget( + "lr_multiplier", custom_label="LR multiplier" + ), ) diff --git a/pyproject.toml b/pyproject.toml index 1528d76d..d66c02b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dev = [ "pytest-qt", "pytest-timeout", "pytest", + "PyYAML", "tox", "pooch >= 1", ] @@ -123,7 +124,9 @@ python = 3.13: py313 [testenv] -commands = python -m pytest -v --color=yes --cov=cellfinder --cov-report=xml +commands = + python -m pip install monai --no-deps + python -m pytest -v --color=yes --cov=cellfinder --cov-report=xml extras = dev napari diff --git a/tests/core/test_integration/test_detection.py b/tests/core/test_integration/test_detection.py index 5e145875..544d2b7d 100644 --- a/tests/core/test_integration/test_detection.py +++ b/tests/core/test_integration/test_detection.py @@ -178,13 +178,15 @@ def detect_finished_callback(points): assert npoints == 120, f"Expected 120 points, found {npoints}" -def test_synthetic_data(synthetic_bright_spots, no_free_cpus): +@pytest.mark.parametrize("normalize", [True, False]) +def test_synthetic_data(synthetic_bright_spots, no_free_cpus, normalize): signal_array, background_array = synthetic_bright_spots detected = main( signal_array, background_array, voxel_sizes, n_free_cpus=no_free_cpus, + normalize_channels=normalize, ) assert len(detected) == 8 diff --git a/tests/core/test_integration/test_train.py b/tests/core/test_integration/test_train.py index 64e0d443..e0d2e7b9 100644 --- a/tests/core/test_integration/test_train.py +++ b/tests/core/test_integration/test_train.py @@ -1,7 +1,7 @@ import os -import sys import pytest +from pytest_mock.plugin import MockerFixture from cellfinder.core.train.train_yaml import cli as train_run @@ -11,6 +11,7 @@ cell_cubes = os.path.join(data_dir, "cells") non_cell_cubes = os.path.join(data_dir, "non_cells") training_yaml_file = os.path.join(data_dir, "training.yaml") +training_yaml_file_stats = os.path.join(data_dir, "training_with_stats.yaml") EPOCHS = "2" @@ -20,7 +21,7 @@ @pytest.mark.slow -def test_train(tmpdir): +def test_train(mocker, tmpdir): tmpdir = str(tmpdir) train_args = [ @@ -32,8 +33,105 @@ def test_train(tmpdir): "--epochs", EPOCHS, ] - sys.argv = train_args + mocker.patch("sys.argv", train_args) + train_run() model_file = os.path.join(tmpdir, "model.keras") assert os.path.exists(model_file) + + +@pytest.mark.parametrize("normalize", [True, False]) +@pytest.mark.parametrize("has_norms", [True, False]) +def test_train_normalization_missing_stats( + mocker: MockerFixture, tmpdir, has_norms, normalize +): + tmpdir = str(tmpdir) + + train_args = [ + "cellfinder_train", + "-y", + training_yaml_file_stats if has_norms else training_yaml_file, + "-o", + tmpdir, + "--epochs", + EPOCHS, + ] + if normalize: + train_args.append("--normalize-channels") + + mocker.patch("sys.argv", train_args) + get_model = mocker.patch( + "cellfinder.core.train.train_yaml.get_model", autospec=True + ) + + if normalize and not has_norms: + # if the yaml doesn't have normalization info an error will be raised + with pytest.raises(ValueError): + train_run() + else: + train_run() + # get the data sets 
passed to fit() to verify if it has norm data + # there's no clear name property of the mock fit call, so use its repr + (fit_mock,) = [ + m + for m in get_model.mock_calls + if repr(m).startswith("call().fit(") + ] + train_dataset = fit_mock.kwargs["x"].dataset + val_dataset = fit_mock.kwargs["validation_data"].dataset + + if normalize: + # if we normalize, the normalization data should be in dataset + assert train_dataset.points_norm_arr is not None + assert val_dataset.points_norm_arr is not None + else: + # otherwise, no normalization data should have been passed, even if + # the yaml has it + assert train_dataset.points_norm_arr is None + assert val_dataset.points_norm_arr is None + + +@pytest.mark.parametrize("lr_schedule", [True, False]) +def test_train_lr_schedule(mocker: MockerFixture, tmpdir, lr_schedule): + tmpdir = str(tmpdir) + + train_args = [ + "cellfinder_train", + "-y", + training_yaml_file, + "-o", + tmpdir, + "--epochs", + EPOCHS, + "--lr-multiplier", + "0.3", + ] + if lr_schedule: + train_args.extend(["--lr-schedule", "10", "20"]) + + mocker.patch("sys.argv", train_args) + get_model = mocker.patch( + "cellfinder.core.train.train_yaml.get_model", autospec=True + ) + + train_run() + # get the data sets passed to fit(). There's no clear name property of + # the mock fit call, so use its repr + (fit_mock,) = [ + m for m in get_model.mock_calls if repr(m).startswith("call().fit(") + ] + callbacks = fit_mock.kwargs["callbacks"] + + # locate the scheduler callback, if any + from keras.callbacks import LearningRateScheduler + + callbacks = [c for c in callbacks if isinstance(c, LearningRateScheduler)] + if lr_schedule: + assert len(callbacks) == 1 + # the callback is a partial function with these args + partial_callback = callbacks[0].schedule + assert partial_callback.keywords["multiplier"] == 0.3 + assert partial_callback.keywords["epoch_list"] == [10, 20] + else: + assert not callbacks diff --git a/tests/core/test_unit/test_classify/__init__.py b/tests/core/test_unit/test_classify/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/test_unit/test_classify/test_cube_gen.py b/tests/core/test_unit/test_classify/test_cube_gen.py new file mode 100644 index 00000000..e939776d --- /dev/null +++ b/tests/core/test_unit/test_classify/test_cube_gen.py @@ -0,0 +1,1334 @@ +from typing import Sequence + +import numpy as np +import pytest +import tifffile +import torch +from brainglobe_utils.cells.cells import Cell +from pytest_mock.plugin import MockerFixture +from torch.utils.data import DataLoader + +from cellfinder.core.classify.cube_generator import ( + CachedArrayStackImageData, + CachedCuboidImageDataBase, + CachedStackImageDataBase, + CachedTiffCuboidImageData, + CuboidBatchSampler, + CuboidDatasetBase, + CuboidStackDataset, + CuboidTiffDataset, + ImageDataBase, + get_data_cuboid_range, +) +from cellfinder.core.tools.threading import ExecutionFailure + +try: + from brainglobe_utils.cells.cells import file_name_from_cell +except ImportError: + + def file_name_from_cell(cell: Cell, channel: int) -> str: + name = f"x{int(cell.x)}_y{int(cell.y)}_z{int(cell.z)}Ch{channel}.tif" + return name + + +PT_TYPE = tuple[int, int, int] + + +unique_int_val = 0 + + +@pytest.fixture(scope="function") +def unique_int() -> int: + """Returns a unique int every time called.""" + global unique_int_val + unique_int_val += 1 + return unique_int_val + + +def sample_volume(x: int, y: int, z: int, c: int, seed: int) -> np.ndarray: + """ + Returns numpy volume with the given size 
using the given seed to make it
+    unique.
+    """
+    data = np.arange(x * y * z * c) + seed
+    data = data.reshape((x, y, z, c)).astype(np.uint16)
+    return data
+
+
+def point_to_slice(
+    point: PT_TYPE,
+    cube_size: PT_TYPE,
+) -> list[slice]:
+    """Returns slices to index volume to get cube around point."""
+    slices = []
+    for p, s, ax in zip(point, cube_size, ["x", "y", "z"]):
+        slices.append(slice(*get_data_cuboid_range(p, s, ax)))
+    return slices
+
+
+def to_numpy_cubes(
+    volume: np.ndarray, points: Sequence[PT_TYPE], cube_size: PT_TYPE
+) -> tuple[list[np.ndarray], torch.Tensor]:
+    """Extracts numpy cubes around the points in the volume."""
+    points_arr = torch.empty((len(points), 5))
+    cubes = []
+    for i, point in enumerate(points):
+        points_arr[i] = torch.tensor([*point, 0, i])
+        cube = volume[tuple(point_to_slice(point, cube_size))]
+        cubes.append(cube)
+    return cubes, points_arr
+
+
+def to_tiff_cubes(
+    volume: np.ndarray, cube_size: PT_TYPE, points: Sequence[PT_TYPE], tmp_path
+) -> tuple[list[Sequence[str]], list[np.ndarray], torch.Tensor]:
+    """Creates tiff files for the cubes around the points in the volume."""
+    cubes, points_arr = to_numpy_cubes(volume, points, cube_size)
+    # create the tiff files, one per channel per point
+    filenames = []
+    for (x, y, z), cube in zip(points, cubes):
+        # don't have to center cube on point, just get a unique cube
+        sig = tmp_path / file_name_from_cell(
+            Cell([x, y, z], Cell.UNKNOWN), channel=0
+        )
+        tifffile.imwrite(sig, cube[:, :, :, 0])
+
+        back = tmp_path / file_name_from_cell(
+            Cell([x, y, z], Cell.UNKNOWN), channel=1
+        )
+        tifffile.imwrite(back, cube[:, :, :, 1])
+
+        filenames.append((str(sig), str(back)))
+
+    return filenames, cubes, points_arr
+
+
+def assert_loader_cubes_matches_cubes(
+    cubes: Sequence[np.ndarray],
+    data_loader: ImageDataBase,
+    batches: Sequence[Sequence[int]],
+) -> None:
+    """
+    Checks that the provided cubes from the input data are the same as the
+    cubes we get from the data loader.
+    """
+    cubes = [torch.from_numpy(cube) for cube in cubes]
+    for batch in batches:
+        for i in batch:
+            # compare each cube with cube from data loader
+            assert torch.equal(cubes[i], data_loader.get_point_cuboid_data(i))
+        batch_cubes = [cubes[i][None, ...] for i in batch]
+
+        # compare batch of cubes with cubes from data loader
+        cube_batch = torch.concatenate(batch_cubes, 0)
+        loader_batch = torch.zeros_like(cube_batch)
+        data_loader.get_point_batch_cuboid_data(loader_batch, batch)
+        assert torch.equal(cube_batch, loader_batch)
+
+
+def assert_loader_cubes_bad_indices(
+    data_shape: Sequence[int],
+    data_loader: ImageDataBase,
+    individuals: Sequence[int],
+    batches: Sequence[Sequence[int]],
+) -> None:
+    """Checks that accessing cube indices beyond the valid range fails."""
+    for i in individuals:
+        # check individual cubes indices that are out of bounds
+        with pytest.raises(IndexError):
+            data_loader.get_point_cuboid_data(i)
+
+    for batch in batches:
+        # check batches that have indices that are out of bounds
+        loader_batch = torch.empty(data_shape, dtype=torch.uint16)
+        with pytest.raises(IndexError):
+            data_loader.get_point_batch_cuboid_data(loader_batch, batch)
+
+
+def assert_loader_cubes_bad_size(
+    data_shape: Sequence[int],
+    data_loader: ImageDataBase,
+    batch: Sequence[int],
+) -> None:
+    """Checks that providing a wrong sized buffer for cube batches fails."""
+    # for batches, we need to provide buffer. 
Check buffers with wrong size + for i in range(1, len(data_shape)): + shape = list(data_shape) + # increase that dim size by one + shape[i] += 1 + + loader_batch = torch.empty(shape, dtype=torch.uint16) + with pytest.raises(ValueError): + data_loader.get_point_batch_cuboid_data(loader_batch, batch) + + +def test_array_image_data(unique_int): + """ + Checks that the data returned by the CachedArrayStackImageData for given + points matches the data it should return. + """ + volume = sample_volume(20, 20, 30, 2, unique_int) + points = [(5, 5, 10), (10, 10, 20)] + cube_size = 3, 3, 5 + cubes, points_arr = to_numpy_cubes(volume, points, cube_size) + + stack = CachedArrayStackImageData( + input_arrays=[volume[:, :, :, 0], volume[:, :, :, 1]], + max_axis_0_cuboids_buffered=3, + data_axis_order=("x", "y", "z"), + cuboid_size=cube_size, + points_arr=points_arr[:, :3], + ) + + assert stack.cuboid_with_channels_size == (*cube_size, 2) + assert stack.data_with_channels_axis_order == ("x", "y", "z", "c") + + assert_loader_cubes_matches_cubes( + cubes, stack, [[0, 1], [0], [1], [0], [1, 0], [1, 1]] + ) + assert_loader_cubes_bad_indices((2, *cube_size, 2), stack, [2], [[1, 2]]) + assert_loader_cubes_bad_size((2, *cube_size, 2), stack, [0, 1]) + + +def test_tiff_image_data(unique_int, tmp_path): + """ + Checks that the data returned by the CachedTiffCuboidImageData for given + points matches the data it should return. + """ + volume = sample_volume(20, 20, 30, 2, unique_int) + points = [(5, 5, 10), (10, 10, 20)] + cube_size = 3, 3, 5 + filenames, cubes, points_arr = to_tiff_cubes( + volume, cube_size, points, tmp_path + ) + + tiffs = CachedTiffCuboidImageData( + filenames_arr=np.array(filenames).astype(np.str_), + max_cuboids_buffered=3, + data_axis_order=("x", "y", "z"), + cuboid_size=cube_size, + points_arr=points_arr[:, :3], + ) + + assert tiffs.cuboid_with_channels_size == (*cube_size, 2) + assert tiffs.data_with_channels_axis_order == ("x", "y", "z", "c") + + assert_loader_cubes_matches_cubes( + cubes, tiffs, [[0, 1], [0], [1], [0], [1, 0], [1, 1]] + ) + assert_loader_cubes_bad_indices((2, *cube_size, 2), tiffs, [2], [[1, 2]]) + assert_loader_cubes_bad_size((2, *cube_size, 2), tiffs, [0, 1]) + + +@pytest.mark.parametrize("cached", [0, 1]) +def test_array_image_data_cache(unique_int, cached, mocker: MockerFixture): + """Checks that CachedArrayStackImageData properly caches the first axis.""" + volume = sample_volume(20, 20, 30, 2, unique_int) + points = [(5, 5, 10), (10, 10, 20)] + cube_size = 3, 3, 5 + cubes, points_arr = to_numpy_cubes(volume, points, cube_size) + + stack = CachedArrayStackImageData( + input_arrays=[volume[:, :, :, 0], volume[:, :, :, 1]], + max_axis_0_cuboids_buffered=cached, + data_axis_order=("x", "y", "z"), + cuboid_size=cube_size, + points_arr=points_arr[:, :3], + ) + + spy = mocker.spy(stack, "read_plane") + stack.get_point_cuboid_data(0) + stack.get_point_cuboid_data(1) + stack.get_point_cuboid_data(0) + batch = torch.zeros((2, *cube_size, 2)) + stack.get_point_batch_cuboid_data(batch, [0, 1]) + + # number of planes per cache dim cube + n = cube_size[0] + match cached: + # cache is always in addition to one cube + case 0: + # only one cube is ever cached in memory. 
Plus 2 channels
+            assert spy.call_count == (n + n + n + n) * 2
+        case 1:
+            # should only ever read each plane once as there's enough cache
+            assert spy.call_count == n * 2 * 2
+
+
+@pytest.mark.parametrize("cached", [0, 1])
+def test_tiff_image_data_cache(
+    unique_int, tmp_path, cached, mocker: MockerFixture
+):
+    """Checks that CachedTiffCuboidImageData properly caches cuboids."""
+    volume = sample_volume(20, 20, 30, 2, unique_int)
+    points = [(5, 5, 10), (10, 10, 20)]
+    cube_size = 3, 3, 5
+    filenames, cubes, points_arr = to_tiff_cubes(
+        volume, cube_size, points, tmp_path
+    )
+
+    tiffs = CachedTiffCuboidImageData(
+        filenames_arr=np.array(filenames).astype(np.str_),
+        max_cuboids_buffered=cached,
+        data_axis_order=("x", "y", "z"),
+        cuboid_size=cube_size,
+        points_arr=points_arr[:, :3],
+    )
+
+    spy = mocker.spy(tiffs, "read_cuboid")
+    tiffs.get_point_cuboid_data(0)
+    tiffs.get_point_cuboid_data(1)
+    tiffs.get_point_cuboid_data(0)
+    batch = torch.zeros((2, *cube_size, 2))
+    tiffs.get_point_batch_cuboid_data(batch, [0, 1])
+
+    # number of cuboid reads expected per cache setting
+    match cached:
+        # cache is always in addition to one cube
+        case 0:
+            # only one cube is ever cached in memory, per channel
+            assert spy.call_count == (1 + 1 + 1 + 1) * 2
+        case 1:
+            # should only ever read each cube once as there's enough cache
+            assert spy.call_count == 2 * 2
+
+
+def assert_dataset_cubes_matches_cubes(
+    cubes: Sequence[np.ndarray],
+    data_loader: CuboidDatasetBase | dict,
+    batches: Sequence[Sequence[int]],
+) -> None:
+    """
+    Checks that the reference cubes match the cubes we get from the dataset
+    (or from a dict of precomputed outputs).
+    """
+    cubes = [torch.from_numpy(cube) for cube in cubes]
+    for batch in batches:
+        for i in batch:
+            # compare each cube with cube from data loader
+            assert torch.equal(cubes[i], data_loader[i])
+        batch_cubes = [cubes[i][None, ...] for i in batch]
+
+        # compare batch of cubes with cubes from data loader
+        cube_batch = torch.concatenate(batch_cubes, 0)
+        assert torch.equal(cube_batch, data_loader[batch])
+
+
+def assert_dataset_cubes_bad_indices(
+    data_loader: CuboidDatasetBase,
+    individuals: Sequence[int],
+    batches: Sequence[Sequence[int]],
+) -> None:
+    """Checks that accessing cube indices beyond the valid range fails."""
+    for i in individuals:
+        # check individual cubes indices that are out of bounds
+        with pytest.raises(IndexError):
+            data_loader[i]
+
+    for batch in batches:
+        # check batches that have indices that are out of bounds
+        with pytest.raises(IndexError):
+            data_loader[batch]
+
+
+def get_sample_dataset_12(
+    seed=0,
+    target_output=None,
+    augment=False,
+    augment_likelihood=0.9,
+    flippable_axis=(0, 1, 2),
+    output_axis_order=("x", "y", "z", "c"),
+):
+    """
+    Returns a numpy volume, a set of 12 points in the volume, extracted cubes
+    from that original volume centered on the points, and a CuboidStackDataset
+    representing the volume.
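+
+    As a rough guide (assuming the symmetric split that
+    test_get_data_cuboid_range pins down below), each cube spans
+    get_data_cuboid_range(p, size, axis) == (p - size // 2, p - size // 2 + size)
+    per axis, so a cube of size 50 centered on x = 27 covers x in [2, 52).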
+ """ + volume = sample_volume(60, 60, 30, 2, seed) + points = [ + (x, y, z) for x in (27, 29) for y in (28, 30) for z in (12, 13, 18) + ] + cube_size = 50, 50, 20 + cubes, _ = to_numpy_cubes(volume, points, cube_size) + + stack = CuboidStackDataset( + points=[Cell(pos, Cell.UNKNOWN) for pos in points], + data_voxel_sizes=(1, 1, 5), + network_voxel_sizes=(1, 1, 5), + network_cuboid_voxels=cube_size, + axis_order=("x", "y", "z"), + output_axis_order=output_axis_order, + augment=augment, + augment_likelihood=augment_likelihood, + flippable_axis=flippable_axis, + rotate_range=None, + translate_range=None, + scale_range=None, + target_output=target_output, + signal_array=volume[..., 0], + background_array=volume[..., 1], + max_axis_0_cuboids_buffered=3, + ) + return stack, points, cubes + + +def test_array_dataset(unique_int): + """ + Checks that the data returned by the CuboidStackDataset for given + points matches the data it should return. + """ + stack, points, cubes = get_sample_dataset_12(unique_int) + cube_size = cubes[0].shape + + assert stack.cuboid_with_channels_size == cube_size + assert stack.src_image_data.cuboid_with_channels_size == cube_size + assert stack.src_image_data.data_with_channels_axis_order == ( + "x", + "y", + "z", + "c", + ) + + # check various batches are correctly returned + assert_dataset_cubes_matches_cubes( + cubes, stack, [[0, 5], [0], [3], [0], [5, 0], [2, 2]] + ) + assert_dataset_cubes_bad_indices(stack, [15], [[1, 14]]) + + +def test_array_dataset_signal_only(unique_int): + """ + Checks that when using only the signal channel, the data returned by the + CuboidStackDataset for given points matches the data it should return. + """ + volume = sample_volume(60, 60, 30, 1, unique_int) + points = [(x, 28, 18) for x in (27, 29)] + cube_size = 50, 50, 20 + cubes, _ = to_numpy_cubes(volume, points, cube_size) + + stack = CuboidStackDataset( + points=[Cell(pos, Cell.UNKNOWN) for pos in points], + data_voxel_sizes=(1, 1, 5), + network_voxel_sizes=(1, 1, 5), + network_cuboid_voxels=cube_size, + axis_order=("x", "y", "z"), + output_axis_order=("x", "y", "z", "c"), + signal_array=volume[..., 0], + background_array=None, + ) + + cube_size = cubes[0].shape + + assert stack.cuboid_with_channels_size == cube_size + assert stack.src_image_data.cuboid_with_channels_size == cube_size + + # check various batches are correctly returned + assert_dataset_cubes_matches_cubes( + cubes, + stack, + [ + [0, 1], + ], + ) + + +def test_tiff_image_dataset(unique_int, tmp_path): + """ + Checks that the data returned by the CuboidTiffDataset for given points + matches the data it should return. 
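+
+    The per-point tiff names come from file_name_from_cell (or the fallback
+    defined at the top of this module), e.g.
+    file_name_from_cell(Cell([5, 5, 10], Cell.UNKNOWN), channel=0) gives
+    "x5_y5_z10Ch0.tif".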
+ """ + volume = sample_volume(60, 60, 30, 2, unique_int) + points = [ + (x, y, z) for x in (27, 29) for y in (28, 30) for z in (12, 13, 18) + ] + cube_size = 50, 50, 20 + filenames, cubes, _ = to_tiff_cubes(volume, cube_size, points, tmp_path) + + tiffs = CuboidTiffDataset( + points=[Cell(pos, Cell.UNKNOWN) for pos in points], + data_voxel_sizes=(1, 1, 5), + network_voxel_sizes=(1, 1, 5), + network_cuboid_voxels=cube_size, + axis_order=("x", "y", "z"), + output_axis_order=("x", "y", "z", "c"), + augment=False, + points_filenames=filenames, + max_cuboids_buffered=3, + ) + + assert tiffs.cuboid_with_channels_size == (*cube_size, 2) + assert tiffs.src_image_data.cuboid_with_channels_size == (*cube_size, 2) + assert tiffs.src_image_data.data_with_channels_axis_order == ( + "x", + "y", + "z", + "c", + ) + + # check various batches are correctly returned + assert_dataset_cubes_matches_cubes( + cubes, tiffs, [[0, 5], [0], [3], [0], [5, 0], [2, 2]] + ) + assert_dataset_cubes_bad_indices(tiffs, [15], [[1, 14]]) + + +@pytest.mark.parametrize( + "batch_size,batches", + [ + (1, [[i] for i in range(12)]), + (4, [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]), + (5, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]]), + ], +) +def test_sampler_batch_size(batch_size: int, batches: list[list[int]]): + """ + Checks that the dataset sampler with different batch sizes returns the + correct batches. + """ + dataset, points, _ = get_sample_dataset_12() + + sampler = CuboidBatchSampler( + dataset=dataset, + batch_size=batch_size, + auto_shuffle=False, + sort_by_axis=None, + ) + + assert len(sampler) == len(batches) + samples = list(sampler) + assert len(samples) == len(batches) + + for i, batch in enumerate(batches): + assert np.array_equal(samples[i], batch) + + +@pytest.mark.parametrize( + "axis_name,axis", + [ + ("y", 1), + ("z", 2), + ], +) +def test_sampler_sort(axis_name: str, axis: int): + """ + Checks that the dataset sampler with sorting for different axes returns the + batches correctly sorted. + """ + dataset, points, _ = get_sample_dataset_12() + + sampler = CuboidBatchSampler( + dataset=dataset, + batch_size=4, + auto_shuffle=False, + sort_by_axis=axis_name, + ) + + assert len(sampler) == 3 + samples = list(sampler) + assert len(samples) == 3 + + # for the given axis, points should only be increasing + indices = np.concatenate(samples) + last = -1 + for i in indices: + assert points[i][axis] >= last + last = points[i][axis] + + +def test_sampler_shuffle(): + """Checks that the dataset sampler with shuffling works.""" + dataset, points, _ = get_sample_dataset_12() + + sampler = CuboidBatchSampler( + dataset=dataset, + batch_size=4, + auto_shuffle=True, + sort_by_axis=None, + ) + + # we do 5 different sampling / shuffling. The prob they are all the same + # by chance is astronomical, unless it's not shuffling properly + same = True + last = None + for i in range(5): + assert len(sampler) == 3 + samples = list(sampler) + assert len(samples) == 3 + indices = np.concatenate(samples) + + if last is not None: + same = same and np.array_equal(last, indices) + if not same: + break + last = indices + + assert not same + + +def test_sampler_shuffle_sort(): + """ + Checks that when shuffling and sorting then only within batch is shuffled. + Across batches the data stays the same. + """ + dataset, points, _ = get_sample_dataset_12() + + sampler = CuboidBatchSampler( + dataset=dataset, + batch_size=4, + auto_shuffle=True, + sort_by_axis="z", + ) + + # prob of a particular ordering is 1 / (4! * 4! * 4!) 
because each batch
+    # is individually reshuffled. This is 1 / 13,824. Doing this 5 times with
+    # them all being the same is astronomical, since they are all independent
+    # events
+    same = True
+    last = None
+    last_raw = None
+    for i in range(5):
+        assert len(sampler) == 3
+        samples = list(sampler)
+        assert len(samples) == 3
+        indices = np.concatenate(samples)
+
+        if last is not None:
+            # check batches were only shuffled within batch
+            for i in range(3):
+                assert set(last_raw[i]) == set(samples[i])
+
+            same = same and np.array_equal(last, indices)
+            if not same:
+                break
+
+        last = indices
+        last_raw = samples
+
+    assert not same
+
+
+@pytest.mark.parametrize("data_thread", [True, False])
+@pytest.mark.parametrize("num_workers", [0, 1, 4])
+def test_dataset_dataloader_threads(unique_int, num_workers, data_thread):
+    """
+    Checks that the torch/keras dataloaders can load the data properly under
+    different threading conditions, including whether data is loaded in
+    separate threads or only the main thread.
+
+    Also check that we can load the data in each sub-process directly, without
+    a main data reading thread.
+    """
+    dataset, points, cubes = get_sample_dataset_12(unique_int)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=5,
+        shuffle=False,
+        num_workers=num_workers,
+        drop_last=False,
+    )
+
+    chunks = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]]
+    try:
+        if data_thread:
+            dataset.start_dataset_thread(num_workers)
+        # check that going through the data loader works
+        batches = list(dataloader)
+        # check that asking the dataset directly works
+        batches_manual = [dataset[chunk] for chunk in chunks]
+    finally:
+        dataset.stop_dataset_thread()
+
+    cubes = [torch.from_numpy(cube)[None, ...] for cube in cubes]
+    for k, batch in enumerate(chunks):
+        # compare batch of cubes with cubes from data loader
+        cube_batch = torch.concatenate([cubes[i] for i in batch], 0)
+        assert torch.equal(cube_batch, batches[k])
+        assert torch.equal(cube_batch, batches_manual[k])
+
+
+def test_dataset_dataloader_worker_exit_early():
+    """
+    Checks that if we request to exit the reader thread while torch is still
+    reading the data, instead of hanging forever, torch will raise an error
+    because we closed the reader.
+    """
+    dataset, points, cubes = get_sample_dataset_12(0)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=1,
+        shuffle=False,
+        num_workers=2,
+        drop_last=False,
+        prefetch_factor=1,
+    )
+
+    try:
+        dataset.start_dataset_thread(2)
+        it = iter(dataloader)
+        next(it)
+    finally:
+        dataset.stop_dataset_thread()
+
+    # this should raise an exception that the workers were closed
+    with pytest.raises(ValueError):
+        # there might be some data already fetched, but not more than ~4
+        for i in range(10):
+            next(it)
+
+
+@pytest.mark.parametrize(
+    "batch_size,batch_idx",
+    [
+        (1, [[i] for i in range(12)]),
+        (4, [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
+    ],
+)
+def test_dataset_dataloader_sampler(
+    unique_int, batch_size, batch_idx, mocker: MockerFixture
+):
+    """
+    Checks that the torch/keras dataloaders can load the data properly with
+    different batch sizes using the provided sampler.
+    """
+    dataset, points, cubes = get_sample_dataset_12(unique_int)
+
+    sampler = CuboidBatchSampler(
+        dataset=dataset,
+        batch_size=batch_size,
+        sort_by_axis=None,
+        auto_shuffle=False,
+    )
+    dataloader = DataLoader(
+        dataset,
+        sampler=sampler,
+        batch_size=None,
+        num_workers=0,
+    )
+
+    spy = mocker.spy(dataset, "_get_multiple_items")
+
+    batches = list(dataloader)
+
+    cubes = [torch.from_numpy(cube)[None, ...] 
for cube in cubes] + for k, batch in enumerate(batch_idx): + # compare batch of cubes with cubes from data loader + cube_batch = torch.concatenate([cubes[i] for i in batch], 0) + assert torch.equal(cube_batch, batches[k]) + + assert len(spy.call_args_list) == len(batch_idx) + args = [call.args[0].tolist() for call in spy.call_args_list] + assert args == batch_idx + + +@pytest.mark.parametrize("target_output", ["index", "label", None]) +def test_dataset_target_output(unique_int, target_output): + """Checks that dataset's output for different requested target labels.""" + stack, points, cubes = get_sample_dataset_12(unique_int, target_output) + + # get the batches or single items + data = {k: stack[k] for k in [(1, 5), (1,), 1, 5]} + target = {} + if target_output is not None: + # we have some labels, break it into target labels and data + target = {k: v[1] for k, v in data.items()} + data = {k: v[0] for k, v in data.items()} + + assert_dataset_cubes_matches_cubes(cubes, data, [(1, 5), (1,)]) + if target_output == "index": + # labels are sample indices + assert torch.equal( + target[(1, 5)], torch.tensor([1, 5], dtype=torch.int) + ) + assert torch.equal(target[(1,)], torch.tensor([1], dtype=torch.int)) + assert torch.equal(target[1], torch.tensor(1, dtype=torch.int)) + assert torch.equal(target[5], torch.tensor(5, dtype=torch.int)) + elif target_output == "label": + # labels are one-hot vectors + assert torch.equal( + target[(1, 5)], torch.tensor([[1, 0], [1, 0]], dtype=torch.int) + ) + assert torch.equal( + target[(1,)], torch.tensor([[1, 0]], dtype=torch.int) + ) + assert torch.equal(target[1], torch.tensor([1, 0], dtype=torch.int)) + assert torch.equal(target[5], torch.tensor([1, 0], dtype=torch.int)) + + +def test_dataset_augment(unique_int): + """Checks that augment works using axis flipping.""" + stack, points, cubes = get_sample_dataset_12( + unique_int, augment=True, augment_likelihood=1, flippable_axis=(1,) + ) + # torch can't handle negative views, so copy + cubes = [np.flip(c, 1).copy() for c in cubes] + assert_dataset_cubes_matches_cubes( + cubes, stack, [[0, 5], [0], [3], [0], [5, 0], [2, 2]] + ) + + +def test_dataset_dim_switch(unique_int): + """Checks that input/output axis ordering works.""" + stack, points, cubes = get_sample_dataset_12( + unique_int, + output_axis_order=("y", "z", "x", "c"), + ) + # all our test data is x, y, z order + cubes = [np.moveaxis(c, [0, 1, 2], [2, 0, 1]).copy() for c in cubes] + assert_dataset_cubes_matches_cubes( + cubes, stack, [[0, 5], [0], [3], [0], [5, 0], [2, 2]] + ) + + +@pytest.mark.parametrize( + "scale,axis", + [ + (0.5, 0), + (2, 0), + (2, 1), + (2, 2), + ], +) +def test_dataset_voxel_scale(scale, axis): + """ + Checks that cube scaling works properly when the input/network voxel size + are different. Try different scale factors, as well as scaling only one + axis at a time for each axis. 
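+
+    E.g. with scale = 2 along x, the data voxels are twice the network voxel
+    size there, so a feature 2 data voxels from the cube center should land
+    2 * 2 = 4 voxels from the center of the resampled cube; that is why the
+    probe offset below is multiplied by the scale for the scaled axis.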
+    """
+    # generate points whose cubes don't much overlap, even when cube is 2x size
+    cube_size = 15, 15, 17
+    volume = np.zeros((60, 35, 40, 2), dtype=np.uint16)
+    points = [(15, 17, 20), (45, 17, 20)]
+    # fill a 5x5x5 cube around the center of each point
+    for x, y, z in points:
+        volume[x - 2 : x + 3, y - 2 : y + 3, z - 2 : z + 3, :] = 100
+
+    network_voxel_sizes = (6, 8, 10)
+    data_voxel_sizes = list(network_voxel_sizes)
+    data_voxel_sizes[axis] = int(data_voxel_sizes[axis] * scale)
+    stack = CuboidStackDataset(
+        points=[Cell(pos, Cell.UNKNOWN) for pos in points],
+        data_voxel_sizes=data_voxel_sizes,
+        network_voxel_sizes=network_voxel_sizes,
+        network_cuboid_voxels=cube_size,
+        axis_order=("x", "y", "z"),
+        output_axis_order=("x", "y", "z", "c"),
+        augment=False,
+        signal_array=volume[..., 0],
+        background_array=volume[..., 1],
+        max_axis_0_cuboids_buffered=3,
+    )
+
+    # get cubes from data sets, batch and individual items
+    scaled_cubes = [
+        stack[0],
+        stack[1],
+        *stack[(0, 1)],
+    ]
+    # we test a voxel in the cube along the given axis, for each axis
+    for ax_i in (0, 1, 2):
+        # both before and after the center in that axis
+        for offset_sign in (-1, 1):
+            # and at different offsets from center. Offsets of 1 and 2 should
+            # be set, an offset of 4 should be blank, because we filled 5x5x5
+            for offset, f in ((1, torch.gt), (2, torch.gt), (4, torch.lt)):
+                # for only the axis that was scaled, also scale the offset
+                if ax_i == axis:
+                    offset = int(offset * scale)
+
+                pos = [s // 2 for s in cube_size]
+                # keep it within cube in case of large scaling
+                pos[ax_i] = min(
+                    max(pos[ax_i] + offset_sign * offset, 0),
+                    cube_size[ax_i] - 1,
+                )
+
+                for cube in scaled_cubes:
+                    # check that for all channels, the voxel is set/unset
+                    assert torch.all(f(cube[tuple(pos)], 25))
+
+
+def get_single_point_dataset():
+    volume = np.empty((30, 60, 60, 2), dtype=np.uint16)
+    dataset = CuboidStackDataset(
+        points=[Cell((30, 30, 15), Cell.UNKNOWN)],
+        data_voxel_sizes=(5, 1, 1),
+        augment=False,
+        network_cuboid_voxels=(20, 50, 50),
+        signal_array=volume[..., 0],
+        background_array=volume[..., 1],
+    )
+    return dataset
+
+
+def test_dataset_thread_exception():
+    """
+    Checks that the external thread that reads the data for all the
+    requesting threads/processes properly forwards exceptions.
+    """
+    dataset = get_single_point_dataset()
+
+    try:
+        dataset.start_dataset_thread(1)
+        with pytest.raises(ExecutionFailure):
+            dataset.get_point_data(5)
+    finally:
+        dataset.stop_dataset_thread()
+
+
+def test_dataset_thread_bad_msg():
+    """
+    Checks that the external thread can handle a bad queue msg.
+    """
+    dataset = get_single_point_dataset()
+
+    try:
+        dataset.start_dataset_thread(1)
+        with pytest.raises(ExecutionFailure):
+            dataset._send_rcv_thread_msg(
+                dataset.cuboid_with_channels_size, "baaad", 0
+            )
+    finally:
+        dataset.stop_dataset_thread()
+
+
+@pytest.mark.parametrize("index", [1, [1]])
+@pytest.mark.parametrize("do_thread", [True, False])
+def test_dataset_single_point_access(unique_int, do_thread, index):
+    """
+    Checks getting one and a batch of cubes with / without an external thread.
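+
+    The reader-thread lifecycle exercised here (and in the other thread
+    tests) is roughly:
+
+        try:
+            dataset.start_dataset_thread(num_workers)
+            cube = dataset[0]  # served via the reader thread
+        finally:
+            dataset.stop_dataset_thread()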
+ """ + dataset, _, cubes = get_sample_dataset_12(unique_int) + + try: + if do_thread: + dataset.start_dataset_thread(1) + + res = dataset[index] + if isinstance(index, list): + # remove batch dim + res = res[0] + assert np.array_equal(res.numpy(), cubes[1]) + finally: + dataset.stop_dataset_thread() + + +def test_get_data_cuboid_range(): + """Validate input parameters.""" + assert get_data_cuboid_range(10, 10, "x") == (5, 15) + assert get_data_cuboid_range(10, 10, "y") == (5, 15) + assert get_data_cuboid_range(10, 10, "z") == (5, 15) + + with pytest.raises(ValueError): + get_data_cuboid_range(10, 10, "m") + + +def test_img_data_base_bad_args(): + """Validate input parameters.""" + points = torch.empty((5, 5)) + + with pytest.raises(ValueError): + ImageDataBase(points_arr=points, data_axis_order=("x", "y")) + + +def test_img_data_not_impl(): + """Validate calling base, not-implemented functions.""" + points = torch.empty((5, 5)) + data = ImageDataBase(points_arr=points) + + with pytest.raises(NotImplementedError): + data.get_point_cuboid_data(0) + with pytest.raises(NotImplementedError): + data.get_point_batch_cuboid_data( + torch.empty(1, *data.cuboid_with_channels_size), [0] + ) + + +def test_img_stack_not_impl(): + """Validate calling base, not-implemented functions.""" + points = torch.empty((5, 5)) + data = CachedStackImageDataBase(points_arr=points) + + with pytest.raises(NotImplementedError): + data.read_plane(0, 0) + + +def test_img_cuboid_not_impl(): + """Validate calling base, not-implemented functions.""" + points = torch.empty((5, 5)) + data = CachedCuboidImageDataBase(points_arr=points) + + with pytest.raises(NotImplementedError): + data.read_cuboid(0, 0) + + +def test_img_cuboid_bad_arg(tmp_path): + """Validate input parameters.""" + # 2 points but only 1 filename set + points = torch.empty((2, 5)) + filenames = np.array( + [(str(tmp_path / "a.tif"), str(tmp_path / "b.tif"))] + ).astype(np.str_) + + tifffile.imwrite( + filenames[0, 0].item(), np.empty((10, 10, 10), dtype=np.uint16) + ) + tifffile.imwrite( + filenames[0, 1].item(), np.empty((10, 10, 10), dtype=np.uint16) + ) + + with pytest.raises(ValueError): + # no filenames, 1 point + CachedTiffCuboidImageData( + points_arr=points[:1, :3], filenames_arr=filenames[:0] + ) + + with pytest.raises(ValueError): + # 1 filename set, but 2 points + CachedTiffCuboidImageData( + points_arr=points[:, :3], filenames_arr=filenames + ) + + # one of each - should work + data = CachedTiffCuboidImageData( + points_arr=points[:1, :3], filenames_arr=filenames + ) + assert len(data.points_arr) + + +def test_dataset_base_bad_args(): + """Validate input parameters.""" + points = [Cell((0, 0, 0), Cell.CELL)] + + with pytest.raises(ValueError): + # needs 3 axis xyz + CuboidDatasetBase( + points=points, + data_voxel_sizes=(5, 1, 1), + axis_order=("x", "y"), + ) + + with pytest.raises(ValueError): + # needs 4 dims xyzc + CuboidDatasetBase( + points=points, + data_voxel_sizes=(5, 1, 1), + output_axis_order=("x", "y", "c"), + ) + + with pytest.raises(ValueError): + # c must be last dim + CuboidDatasetBase( + points=points, + data_voxel_sizes=(5, 1, 1), + output_axis_order=("x", "y", "c", "z"), + ) + + with pytest.raises(ValueError): + # sizes must be 3-tuple + CuboidDatasetBase( + points=points, + data_voxel_sizes=(5, 1, 1), + network_voxel_sizes=(20, 50), + ) + + +def test_dataset_manual_image_data(): + """Check that we can manually pass an image data instance to dataset.""" + points = torch.zeros((1, 5)) + data = CachedStackImageDataBase( + 
        points_arr=points,
+        data_axis_order=("x", "y", "z"),
+        cuboid_size=(20, 50, 50),
+    )
+
+    dataset = CuboidDatasetBase(
+        points=[Cell((0, 0, 0), Cell.CELL)],
+        data_voxel_sizes=(5, 1, 1),
+        axis_order=("z", "y", "x"),
+        network_cuboid_voxels=(20, 50, 50),
+        src_image_data=data,
+    )
+    # should have 5 elements (x, y, z, c, b) because we are re-ordering
+    assert len(dataset._output_data_dim_reordering) == 5
+
+
+def test_dataset_target_bad_values():
+    """Validate that target must be valid both for single point and batch."""
+    dataset = get_single_point_dataset()
+    dataset.target_output = "blah"
+
+    with pytest.raises(ValueError):
+        dataset[0]
+
+    with pytest.raises(ValueError):
+        # batch of points
+        dataset[[0]]
+
+
+def test_rescale_dataset_bad_arg():
+    """Validate input parameters."""
+    volume = np.empty((30, 60, 60, 2), dtype=np.uint16)
+    dataset = CuboidStackDataset(
+        points=[Cell((30, 30, 15), Cell.UNKNOWN)],
+        data_voxel_sizes=(4, 1, 1),
+        augment=False,
+        network_cuboid_voxels=(20, 50, 50),
+        signal_array=volume[..., 0],
+        background_array=volume[..., 1],
+    )
+
+    with pytest.raises(ValueError):
+        # must be 5-dim, batch, xyzc
+        dataset.convert_to_output(torch.zeros((50, 50, 20, 2)))
+
+
+def test_stack_dataset_bad_arg_diff_shapes():
+    """Validates that we check signal and background must have same shape."""
+    volume = np.empty((30, 60, 60, 2, 2), dtype=np.uint16)
+    with pytest.raises(ValueError):
+        CuboidStackDataset(
+            points=[Cell((30, 30, 15), Cell.UNKNOWN)],
+            data_voxel_sizes=(5, 1, 1),
+            signal_array=volume[..., 0, 0],
+            background_array=volume[..., 1],
+        )
+
+
+@pytest.mark.parametrize("channel", ["signal", "background"])
+def test_stack_dataset_bad_arg_signal_dtype(channel):
+    """Validates that we check signal and background have a supported dtype."""
+    volume_bad = np.empty((30, 60, 60), dtype=np.float64)
+    volume_good = np.empty((30, 60, 60), dtype=np.float32)
+
+    with pytest.raises(ValueError):
+        CuboidStackDataset(
+            points=[Cell((30, 30, 15), Cell.UNKNOWN)],
+            data_voxel_sizes=(5, 1, 1),
+            signal_array=volume_bad if channel == "signal" else volume_good,
+            background_array=(
+                volume_bad if channel == "background" else volume_good
+            ),
+        )
+
+
+def test_stack_dataset_bad_arg_bad_shape():
+    """Validates that we reject signal/background arrays that are not 3-dim."""
+    volume = np.empty((30, 60, 60, 2, 2), dtype=np.uint16)
+    with pytest.raises(ValueError):
+        CuboidStackDataset(
+            points=[Cell((30, 30, 15), Cell.UNKNOWN)],
+            data_voxel_sizes=(5, 1, 1),
+            signal_array=volume[..., 0],
+            background_array=volume[..., 1],
+        )
+
+
+def test_dataset_cuboid_bad_arg(tmp_path):
+    """Validate input parameters."""
+    filenames = [(str(tmp_path / "a.tif"), str(tmp_path / "b.tif"))]
+
+    tifffile.imwrite(filenames[0][0], np.empty((10, 10, 10), dtype=np.uint16))
+    tifffile.imwrite(filenames[0][1], np.empty((10, 10, 10), dtype=np.uint16))
+
+    with pytest.raises(ValueError):
+        # filenames must have at least one sample
+        CuboidTiffDataset(
+            points=[Cell((30, 30, 15), Cell.UNKNOWN)],
+            data_voxel_sizes=(5, 1, 1),
+            points_filenames=filenames[:0],
+        )
+
+    with pytest.raises(ValueError):
+        # filenames (1) must have same num as points (2)
+        CuboidTiffDataset(
+            points=[
+                Cell((30, 30, 15), Cell.UNKNOWN),
+                Cell((32, 30, 15), Cell.UNKNOWN),
+            ],
+            data_voxel_sizes=(5, 1, 1),
+            points_filenames=filenames,
+        )
+
+
+def test_point_has_full_cuboid_unscaled():
+    """
+    Tests that only points whose full cuboid (around the point's center)
+    fits in the volume are included. Tested under the condition where
+    network and data have the same voxel size.
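+
+    Concretely: with a 30-voxel extent and a 10-voxel cuboid, x = 5 gives the
+    window (0, 10), which fits, while x = 3 -> (-2, 8) and x = 27 -> (22, 32)
+    stick out of [0, 30) and should be dropped.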
+ """ + volume = np.empty((30, 30, 30, 2), dtype=np.uint16) + dataset = CuboidStackDataset( + points=[ + Cell((5, 15, 15), Cell.UNKNOWN), + Cell((3, 15, 15), Cell.UNKNOWN), + Cell((27, 15, 15), Cell.UNKNOWN), + ], + data_voxel_sizes=(1, 1, 1), + network_voxel_sizes=(1, 1, 1), + network_cuboid_voxels=(10, 10, 10), + axis_order=("x", "y", "z"), + signal_array=volume[..., 0], + background_array=volume[..., 1], + ) + assert len(dataset.points_arr) == 1 + p = dataset.points_arr[0].tolist() + assert tuple(p[:3]) == (5, 15, 15) + assert p[3] == Cell.UNKNOWN + assert p[4] == 0 + + +def test_point_has_full_cuboid_scaled(): + """ + Tests that only cuboids that have full cubes around the point's center + are included. Tested under condition where data is half of network's + voxel size. + """ + volume = np.empty((30, 30, 30, 2), dtype=np.uint16) + dataset = CuboidStackDataset( + points=[ + Cell((10, 15, 15), Cell.UNKNOWN), + Cell((6, 15, 15), Cell.UNKNOWN), + Cell((24, 15, 15), Cell.UNKNOWN), + ], + data_voxel_sizes=(1, 1, 1), + network_voxel_sizes=(2, 1, 1), + network_cuboid_voxels=(10, 10, 10), + axis_order=("x", "y", "z"), + signal_array=volume[..., 0], + background_array=volume[..., 1], + ) + assert len(dataset.points_arr) == 1 + p = dataset.points_arr[0].tolist() + assert tuple(p[:3]) == (10, 15, 15) + assert p[3] == Cell.UNKNOWN + assert p[4] == 0 + + +def test_points_unchanged(): + volume = np.empty((30, 60, 60, 2), dtype=np.uint16) + cell = Cell((30, 30, 15), Cell.UNKNOWN) + cell2 = Cell((30, 33, 15), Cell.CELL) + dataset = CuboidStackDataset( + points=[cell, cell2], + data_voxel_sizes=(5, 1, 1), + augment=False, + network_cuboid_voxels=(20, 50, 50), + signal_array=volume[..., 0], + background_array=volume[..., 1], + ) + + assert len(dataset.points_arr) == 2 + x, y, z, tp, i = dataset.points_arr[0].tolist() + assert (x, y, z) == (30, 30, 15) + assert tp == Cell.UNKNOWN + assert i == 0 + + x, y, z, tp, i = dataset.points_arr[1].tolist() + assert (x, y, z) == (30, 33, 15) + assert tp == Cell.CELL + assert i == 1 + + +def _get_volume_with_stats(normalize): + sig_norm = back_norm = None + sig_mean, sig_std = 222, 20 + back_mean, back_std = 555, 5 + if normalize: + sig_norm = sig_mean, sig_std + back_norm = back_mean, back_std + + volume = np.empty((20, 20, 30, 2), dtype=np.float32) + volume[..., 0] = np.random.normal(sig_mean, sig_std, (20, 20, 30)) + volume[..., 1] = np.random.normal(back_mean, back_std, (20, 20, 30)) + + return ( + volume, + sig_norm, + back_norm, + (sig_mean, sig_std), + (back_mean, back_std), + ) + + +def _check_cube_normalization( + stack, sig_mean, sig_std, back_mean, back_std, normalize +): + # get a single and a batch of 2 cubes + cube_1 = stack[0] + cube_2 = stack[[0, 0]] + for cube in [cube_1, cube_2]: + for ch, ex_mean, ex_std in [ + (0, sig_mean, sig_std), + (1, back_mean, back_std), + ]: + # get output stats + std, mean = torch.std_mean(cube[..., ch]) + + lower_mean = ex_mean * 0.8 + upper_mean = ex_mean * 1.2 + # if normalized, it should be standard normal + if normalize: + ex_mean, ex_std = 0, 1 + lower_mean = -0.2 + upper_mean = 0.2 + + assert lower_mean <= mean.item() < upper_mean + assert ex_std * 0.8 <= std.item() < ex_std * 1.2 + + +@pytest.mark.parametrize("normalize", [True, False]) +def test_array_image_data_normalization(normalize): + """ + Checks that the data returned by the CuboidStackDataset is normalized + if requested, otherwise it shouldn't be normalized. 
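+
+    "Normalized" is taken here to mean the standard per-channel transform
+    (x - mean) / std, so a signal channel drawn from N(222, 20) with
+    signal_normalization=(222, 20) should come out approximately N(0, 1),
+    which is what _check_cube_normalization asserts.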
+ """ + volume, sig_norm, back_norm, sig_stat, back_stat = _get_volume_with_stats( + normalize + ) + + stack = CuboidStackDataset( + points=[Cell((10, 10, 10), Cell.UNKNOWN)], + data_voxel_sizes=(1, 1, 1), + network_voxel_sizes=(1, 1, 1), + network_cuboid_voxels=(5, 5, 8), + axis_order=("x", "y", "z"), + output_axis_order=("x", "y", "z", "c"), + signal_array=volume[..., 0], + background_array=volume[..., 1], + signal_normalization=sig_norm, + background_normalization=back_norm, + ) + _check_cube_normalization(stack, *sig_stat, *back_stat, normalize) + + +@pytest.mark.parametrize("normalize", [True, False]) +def test_tiff_image_data_normalization(normalize, tmp_path): + """ + Checks that the data returned by the CuboidTiffDataset is normalized + if requested, otherwise it shouldn't be normalized. + """ + volume, sig_norm, back_norm, sig_stat, back_stat = _get_volume_with_stats( + normalize + ) + + points = [(10, 10, 10)] + cube_size = 5, 5, 8 + filenames, _, _ = to_tiff_cubes(volume, cube_size, points, tmp_path) + + tiffs = CuboidTiffDataset( + points=[Cell(p, Cell.UNKNOWN) for p in points], + data_voxel_sizes=(1, 1, 1), + network_voxel_sizes=(1, 1, 1), + network_cuboid_voxels=(5, 5, 8), + axis_order=("x", "y", "z"), + output_axis_order=("x", "y", "z", "c"), + points_filenames=filenames, + points_normalization=[[sig_norm, back_norm]] if normalize else None, + ) + + _check_cube_normalization(tiffs, *sig_stat, *back_stat, normalize) diff --git a/tests/core/test_unit/test_classify/test_data_augment.py b/tests/core/test_unit/test_classify/test_data_augment.py new file mode 100644 index 00000000..3c8455d3 --- /dev/null +++ b/tests/core/test_unit/test_classify/test_data_augment.py @@ -0,0 +1,108 @@ +import math + +import pytest +import torch + +from cellfinder.core.classify.augment import DataAugmentation + + +@pytest.fixture +def cube_with_side_dot() -> torch.Tensor: + # DataAugmentation.DIM_ORDER of (c, y, x, z) + data = torch.zeros((2, 11, 11, 7)) + # put dot at pixel in plane z = 5, on the x/y diagonal - both x and y are 8 + data[0, 8, 8, 5] = 1 + # put dot at pixel in plane z = 5, but x is 2 and y is 8 + data[1, 8, 2, 5] = 1 + return data + + +@pytest.fixture +def cube_with_center_dot() -> torch.Tensor: + # DataAugmentation.DIM_ORDER of (c, y, x, z) + data = torch.zeros((2, 11, 11, 7)) + data[:, 5, 5, 3] = 1 + return data + + +def test_augment_translate(cube_with_side_dot): + c, y, x, z = cube_with_side_dot.shape + translate_range = [(1 / 11, 1 / 11), (2 / 11, 2 / 11), (1 / 7, 1 / 7)] + augmenter = DataAugmentation( + volume_size={"x": x, "y": y, "z": z}, + augment_likelihood=1, + translate_range=translate_range, + data_dim_order=("c", "y", "x", "z"), + ) + assert augmenter.update_parameters(), "Parameters should be randomized" + augmented = augmenter(cube_with_side_dot) + + assert augmented.shape == cube_with_side_dot.shape + assert augmented[0, 8, 8, 5] < 0.1 + assert augmented[1, 8, 2, 5] < 0.1 + assert augmented[0, 8 + 1, 8 + 2, 5 + 1] > 0.1 + assert augmented[1, 8 + 1, 2 + 2, 5 + 1] > 0.1 + + +def test_augment_rotate(cube_with_side_dot): + c, y, x, z = cube_with_side_dot.shape + augmenter = DataAugmentation( + volume_size={"x": x, "y": y, "z": z}, + augment_likelihood=1, + rotate_range=[(0, 0), (0, 0), (math.pi / 2, math.pi / 2)], + data_dim_order=("c", "y", "x", "z"), + ) + assert augmenter.update_parameters(), "Parameters should be randomized" + augmented = augmenter(cube_with_side_dot) + + assert augmented.shape == cube_with_side_dot.shape + assert augmented[0, 8, 8, 5] < 0.1 + 
assert augmented[1, 8, 2, 5] < 0.1
+    # we rotated around the z axis by 90 degrees. For the channel-0 dot on
+    # the diagonal (x = y = 8), y stays 8 (3 from the end) while x becomes
+    # 10 - 8 = 2 (len = 11)
+    assert augmented[0, 8, 10 - 8, 5] > 0.1
+    # for a 90 degree rotation about the center starting from x = 2, y = 8,
+    # we end up at x = 10 - 8 = 2 and y = 2
+    assert augmented[1, 2, 10 - 8, 5] > 0.1
+
+
+def test_augment_scale(cube_with_center_dot):
+    c, y, x, z = cube_with_center_dot.shape
+    augmenter = DataAugmentation(
+        volume_size={"x": x, "y": y, "z": z},
+        augment_likelihood=1,
+        scale_range=((3, 3),) * 3,
+        data_dim_order=("c", "y", "x", "z"),
+    )
+    assert augmenter.update_parameters(), "Parameters should be randomized"
+    augmented = augmenter(cube_with_center_dot)
+
+    assert augmented.shape == cube_with_center_dot.shape
+    assert augmented[0, 5, 5, 3] > 0.1
+    assert augmented[0, 5 + 1, 5, 3] > 0.1
+    assert augmented[0, 5, 5 + 1, 3] > 0.1
+    assert augmented[0, 5, 5, 3 + 1] > 0.1
+
+    assert augmented[1, 5, 5, 3] > 0.1
+    assert augmented[1, 5 + 1, 5, 3] > 0.1
+    assert augmented[1, 5, 5 + 1, 3] > 0.1
+    assert augmented[1, 5, 5, 3 + 1] > 0.1
+
+
+def test_augment_axis_flip(cube_with_side_dot):
+    c, y, x, z = cube_with_side_dot.shape
+    augmenter = DataAugmentation(
+        volume_size={"x": x, "y": y, "z": z},
+        augment_likelihood=1,
+        flippable_axis=(0, 1),
+        data_dim_order=("c", "y", "x", "z"),
+    )
+    assert augmenter.update_parameters(), "Parameters should be randomized"
+    augmented = augmenter(cube_with_side_dot)
+
+    assert augmented.shape == cube_with_side_dot.shape
+    assert augmented[0, 8, 8, 5] < 0.1
+    assert augmented[1, 8, 2, 5] < 0.1
+    assert augmented[0, 2, 2, 5] > 0.1
+    assert augmented[1, 2, 8, 5] > 0.1
diff --git a/tests/core/test_unit/test_tools/test_image_processing.py b/tests/core/test_unit/test_tools/test_image_processing.py
index 64ad4891..d56d189e 100644
--- a/tests/core/test_unit/test_tools/test_image_processing.py
+++ b/tests/core/test_unit/test_tools/test_image_processing.py
@@ -1,6 +1,7 @@
 import random

 import numpy as np
+import pytest

 from cellfinder.core.tools import image_processing as img_tools

@@ -35,3 +36,16 @@ def test_pad_centre_2d():
         img, x_size=new_x_shape, y_size=new_y_shape
     )
     assert (new_y_shape, new_x_shape) == pad_img.shape
+
+
+@pytest.mark.parametrize("progress", [True, False])
+def test_dataset_mean_std(progress):
+    # checks that dataset_mean_std correctly computes the std/mean
+    data = np.random.normal(100, 10, (10, 10, 10))
+
+    mean, std = img_tools.dataset_mean_std(
+        data, sampling_factor=2, show_progress=progress
+    )
+    # give it enough room for estimation error
+    assert 90 < mean < 110
+    assert 8 < std < 12
diff --git a/tests/core/test_unit/test_tools/test_threading.py b/tests/core/test_unit/test_tools/test_threading.py
index 0a341068..747be98a 100644
--- a/tests/core/test_unit/test_tools/test_threading.py
+++ b/tests/core/test_unit/test_tools/test_threading.py
@@ -1,3 +1,5 @@
+import multiprocessing as mp
+
 import pytest

 from cellfinder.core.tools.threading import (
@@ -6,6 +8,7 @@
     ExecutionFailure,
     ProcessWithException,
     ThreadWithException,
+    ThreadWithExceptionMPSafe,
 )

 cls_to_test = [ThreadWithException, ProcessWithException]
@@ -33,7 +36,7 @@ def do_nothing(*args):
     pass


-def send_back_msg(thread: ExceptionWithQueueMixIn):
+def send_back_msg(thread: ExceptionWithQueueMixIn, *args):
     # do this single op and exit
     thread.send_msg_to_mainthread(("back", thread.get_msg_from_mainthread()))

@@ -139,3 +142,70 @@ def test_skip_until_eof(cls):
     # eof to main-thread, which is 
+    assert augmented.shape == cube_with_side_dot.shape
+    assert augmented[0, 8, 8, 5] < 0.1
+    assert augmented[1, 8, 2, 5] < 0.1
+    assert augmented[0, 2, 2, 5] > 0.1
+    assert augmented[1, 2, 8, 5] > 0.1
diff --git a/tests/core/test_unit/test_tools/test_image_processing.py b/tests/core/test_unit/test_tools/test_image_processing.py
index 64ad4891..d56d189e 100644
--- a/tests/core/test_unit/test_tools/test_image_processing.py
+++ b/tests/core/test_unit/test_tools/test_image_processing.py
@@ -1,6 +1,7 @@
 import random
 
 import numpy as np
+import pytest
 
 from cellfinder.core.tools import image_processing as img_tools
 
@@ -35,3 +36,16 @@ def test_pad_centre_2d():
         img, x_size=new_x_shape, y_size=new_y_shape
     )
     assert (new_y_shape, new_x_shape) == pad_img.shape
+
+
+@pytest.mark.parametrize("progress", [True, False])
+def test_dataset_mean_std(progress):
+    # checks that dataset_mean_std correctly computes the mean/std
+    data = np.random.normal(100, 10, (10, 10, 10))
+
+    mean, std = img_tools.dataset_mean_std(
+        data, sampling_factor=2, show_progress=progress
+    )
+    # allow enough room for estimation error
+    assert 90 < mean < 110
+    assert 8 < std < 12
diff --git a/tests/core/test_unit/test_tools/test_threading.py b/tests/core/test_unit/test_tools/test_threading.py
index 0a341068..747be98a 100644
--- a/tests/core/test_unit/test_tools/test_threading.py
+++ b/tests/core/test_unit/test_tools/test_threading.py
@@ -1,3 +1,5 @@
+import multiprocessing as mp
+
 import pytest
 
 from cellfinder.core.tools.threading import (
@@ -6,6 +8,7 @@
     ExecutionFailure,
     ProcessWithException,
     ThreadWithException,
+    ThreadWithExceptionMPSafe,
 )
 
 cls_to_test = [ThreadWithException, ProcessWithException]
@@ -33,7 +36,7 @@ def do_nothing(*args):
     pass
 
 
-def send_back_msg(thread: ExceptionWithQueueMixIn):
+def send_back_msg(thread: ExceptionWithQueueMixIn, *args):
     # do this single op and exit
     thread.send_msg_to_mainthread(("back", thread.get_msg_from_mainthread()))
 
@@ -139,3 +142,70 @@ def test_skip_until_eof(cls):
     # eof to main-thread, which is the last thing thread does before exiting
     assert thread._saw_eof
     thread.join()
+
+
+def _do_nothing(*args):
+    pass
+
+
+class BlankClass:
+    pass
+
+
+@pytest.mark.parametrize(
+    "cls", [ThreadWithException, ThreadWithExceptionMPSafe]
+)
+def test_thread_with_multiprocess(cls):
+    inst = BlankClass()
+    inst.thread = cls(target=_do_nothing)
+
+    proc = ProcessWithException(target=_do_nothing, args=(inst,))
+
+    if cls is ThreadWithException:
+        # ThreadWithException cannot be shared with a subprocess
+        with pytest.raises(TypeError):
+            proc.start()
+    else:
+        proc.start()
+        proc.join()
+
+
+def test_share_queue_sub_process_bad(capsys):
+    # we cannot pass queues to a sub-process via messages
+    ctx = mp.get_context("spawn")
+    queue = ctx.Queue(maxsize=0)
+
+    proc = ProcessWithException(target=send_back_msg, pass_self=True)
+    proc.start()
+    proc.send_msg_to_thread(queue)
+
+    proc.notify_to_end_thread()
+    proc.clear_remaining()
+    proc.join()
+
+    # we cannot catch the exception, because "RuntimeError: Queue objects
+    # should only be shared between processes through inheritance" is raised
+    # in some 3rd-party thread/process and is not propagated, only printed
+    # to stderr
+    err = capsys.readouterr().err
+    assert "RuntimeError" in err
+
+
+def test_share_queue_sub_process_good(capsys):
+    # we can pass queues to a sub-process via args
+    ctx = mp.get_context("spawn")
+    queue = ctx.Queue(maxsize=0)
+
+    proc = ProcessWithException(
+        target=send_back_msg, args=(queue,), pass_self=True
+    )
+    proc.start()
+    proc.send_msg_to_thread("hello")
+    proc.get_msg_from_thread()
+
+    proc.notify_to_end_thread()
+    proc.clear_remaining()
+    proc.join()
+
+    err = capsys.readouterr().err
+    assert "RuntimeError" not in err
diff --git a/tests/core/test_unit/test_tools/test_tools_general.py b/tests/core/test_unit/test_tools/test_tools_general.py
index 9a3c573e..a34f10dc 100644
--- a/tests/core/test_unit/test_tools/test_tools_general.py
+++ b/tests/core/test_unit/test_tools/test_tools_general.py
@@ -1,5 +1,7 @@
 import numpy as np
 import pytest
+import torch
+from pytest_mock.plugin import MockerFixture
 
 import cellfinder.core.tools.tools as tools
 
@@ -220,6 +222,7 @@ def test_check_unique_list():
 
 def test_common_member():
     assert (True, [10, 30]) == tools.common_member(a, b)
+    assert (False, []) == tools.common_member(a, [])
 
 
 def test_get_number_of_bins_nd():
@@ -246,3 +249,44 @@ def test_swap_elements_list():
 def test_is_any_list_overlap():
     assert tools.is_any_list_overlap(a, b)
     assert not tools.is_any_list_overlap(a, [2, "b", (1, 2, 3)])
+
+
+def test_random_bool():
+    assert tools.random_bool() in (0, 1)
+    assert tools.random_bool(0.5) in (True, False)
+
+
+@pytest.mark.parametrize("ret_val,sign_val", [(1, 1), (0, -1)])
+def test_random_sign_false(mocker: MockerFixture, ret_val, sign_val):
+    def ret():
+        return ret_val
+
+    mocker.patch("cellfinder.core.tools.tools.random_bool", new=ret)
+    assert tools.random_sign() == sign_val
+
+
+def test_random_probability():
+    assert 0 <= tools.random_probability() <= 1
+
+
+def test_all_elements_equal():
+    assert not tools.all_elements_equal([1, 1, 2])
+    assert tools.all_elements_equal([1, 1])
+    assert tools.all_elements_equal([])
+
+
+def test_get_axis_reordering():
+    x = torch.arange(10)
+    y = torch.arange(10) / 10 + 2
+    z = torch.arange(10) + 33
+
+    data = x[:, None, None] * y[None, :, None] * z[None, None, :]
+
+    reordering = tools.get_axis_reordering(("x", "y", "z"), ("y", "x", "z"))
+    reordered = torch.permute(data, reordering)
+
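+    # as a cross-check, build the same outer product by hand with x and y
+    # swapped; permuting the data tensor should reproduce it exactly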
+    manual_reordered = y[:, None, None] * x[None, :, None] * z[None, None, :]
+
+    assert torch.allclose(reordered, manual_reordered)
+    # check data is not already symmetric in x, y
+    assert data[0, 1, 5] != data[1, 0, 5]
diff --git a/tests/data/integration/training/training_with_stats.yaml b/tests/data/integration/training/training_with_stats.yaml
new file mode 100644
index 00000000..b2ea7b72
--- /dev/null
+++ b/tests/data/integration/training/training_with_stats.yaml
@@ -0,0 +1,19 @@
+data:
+- bg_channel: 1
+  cell_def: ''
+  cube_dir: tests/data/integration/training/cells
+  signal_channel: 0
+  type: cell
+  signal_mean: 241.31
+  signal_std: 154.92
+  bg_mean: 650.94
+  bg_std: 217.90
+- bg_channel: 1
+  cell_def: ''
+  cube_dir: tests/data/integration/training/cells
+  signal_channel: 0
+  type: no_cell
+  signal_mean: 231.28
+  signal_std: 79.60
+  bg_mean: 836.21
+  bg_std: 348.35
diff --git a/tests/napari/test_curation.py b/tests/napari/test_curation.py
index 083f90fa..8e8350a0 100644
--- a/tests/napari/test_curation.py
+++ b/tests/napari/test_curation.py
@@ -4,6 +4,7 @@
 import napari
 import numpy as np
 import pytest
+import yaml
 from napari.layers import Image, Points
 
 from cellfinder.napari import sample_data
@@ -12,7 +13,7 @@
 
 
 @pytest.fixture
-def curation_widget(make_napari_viewer):
+def curation_widget(make_napari_viewer) -> CurationWidget:
     """
     Create a viewer, add the curation widget, and return the widget.
     The viewer can be accessed using ``widget.viewer``.
@@ -39,6 +40,20 @@ def test_add_new_training_layers(curation_widget):
     assert layers[1].name == "Training data (non cells)"
 
 
+def test_update_voxel_size(curation_widget: CurationWidget):
+    assert curation_widget.voxel_sizes == [5, 2, 2]
+    curation_widget.voxel_sizes_boxes[0].setValue(3)
+    curation_widget.voxel_sizes_boxes[1].setValue(4)
+    curation_widget.voxel_sizes_boxes[2].setValue(5)
+    assert curation_widget.voxel_sizes == [3, 4, 5]
+
+
+def test_update_normalization_down_sampling(curation_widget: CurationWidget):
+    assert curation_widget.normalization_down_sampling == 32
+    curation_widget.norm_sampling_box.setValue(8)
+    assert curation_widget.normalization_down_sampling == 8
+
+
 @pytest.mark.xfail(reason="See discussion in #443", raises=AssertionError)
 def test_cell_marking(curation_widget, tmp_path):
     """
@@ -90,6 +105,19 @@ def test_cell_marking(curation_widget, tmp_path):
     assert len(list((tmp_path / "non_cells").glob("*.tif"))) == 2
     assert len(list((tmp_path / "cells").glob("*.tif"))) == 2
 
+    with open(tmp_path / "training.yaml", "r") as fh:
+        yaml_data = yaml.safe_load(fh)
+
+    for item in yaml_data["data"]:
+        assert "cube_dir" in item
+        assert "signal_channel" in item
+        assert "bg_channel" in item
+        assert "type" in item
+        assert "signal_mean" in item
+        assert "signal_std" in item
+        assert "bg_mean" in item
+        assert "bg_std" in item
+
 
 @pytest.fixture
 def valid_curation_widget(make_napari_viewer) -> CurationWidget: