Skip to content

Commit

Permalink
feat(transforms): add new image transformation classes
Browse files Browse the repository at this point in the history
  • Loading branch information
yzx9 committed Mar 18, 2024
1 parent 80649c0 commit 7b5d805
Showing 1 changed file with 66 additions and 2 deletions.
68 changes: 66 additions & 2 deletions swcgeom/transforms/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import numpy as np
import numpy.typing as npt

from swcgeom.transforms.base import Transform
from swcgeom.transforms.base import Identity, Transform

__all__ = [
"ImagesCenterCrop",
"ImagesScale",
"ImagesClip",
"ImagesNormalizer",
"ImagesMeanVarianceAdjustment",
"ImagesScaleToUnitRange",
"ImagesHistogramEqualization",
"Center", # legacy
]

Expand Down Expand Up @@ -66,6 +68,9 @@ def __init__(self, scaler: float) -> None:
def __call__(self, x: NDArrayf32) -> NDArrayf32:
return self.scaler * x

def extra_repr(self) -> str:
return f"scaler={self.scaler}"


class ImagesClip(Transform[NDArrayf32, NDArrayf32]):
def __init__(self, vmin: float = 0, vmax: float = 1, /) -> None:
Expand All @@ -75,6 +80,9 @@ def __init__(self, vmin: float = 0, vmax: float = 1, /) -> None:
def __call__(self, x: NDArrayf32) -> NDArrayf32:
return np.clip(x, self.vmin, self.vmax)

def extra_repr(self) -> str:
return f"vmin={self.vmin}, vmax={self.vmax}"


class ImagesNormalizer(Transform[NDArrayf32, NDArrayf32]):
"""Normalize image stack."""
Expand All @@ -101,5 +109,61 @@ def __init__(self, mean: float, variance: float) -> None:
def __call__(self, x: NDArrayf32) -> NDArrayf32:
return (x - self.mean) / self.variance

def extra_repr(self):
def extra_repr(self) -> str:
return f"mean={self.mean}, variance={self.variance}"


class ImagesScaleToUnitRange(Transform[NDArrayf32, NDArrayf32]):
"""Scale image stack to unit range."""

def __init__(self, vmin: float, vmax: float, *, clip: bool = True) -> None:
"""Scale image stack to unit range.
Parameters
----------
vmin : float
Minimum value.
vmax : float
Maximum value.
clip : bool, default True
Clip values to [0, 1] to avoid numerical issues.
"""

super().__init__()
self.vmin = vmin
self.vmax = vmax
self.diff = vmax - vmin
self.clip = clip
self.post = ImagesClip(0, 1) if self.clip else Identity()

def __call__(self, x: NDArrayf32) -> NDArrayf32:
return self.post((x - self.vmin) / self.diff)

def extra_repr(self) -> str:
return f"vmin={self.vmin}, vmax={self.vmax}, clip={self.clip}"


class ImagesHistogramEqualization(Transform[NDArrayf32, NDArrayf32]):
"""Image histogram equalization.
References
----------
http://www.janeriksolem.net/histogram-equalization-with-python-and.html
"""

def __init__(self, bins: int = 256) -> None:
super().__init__()
self.bins = bins

def __call__(self, x: NDArrayf32) -> NDArrayf32:
# get image histogram
hist, bin_edges = np.histogram(x.flatten(), self.bins, density=True)
cdf = hist.cumsum() # cumulative distribution function
cdf = cdf / cdf[-1] # normalize

# use linear interpolation of cdf to find new pixel values
equalized = np.interp(x.flatten(), bin_edges[:-1], cdf)
return equalized.reshape(x.shape).astype(np.float32)

def extra_repr(self) -> str:
return f"bins={self.bins}"

0 comments on commit 7b5d805

Please sign in to comment.