Skip to content

Commit

Permalink
feat(palette): add more plots
Browse files Browse the repository at this point in the history
  • Loading branch information
loiccoyle committed Jul 5, 2024
1 parent 9fde015 commit ce7b124
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 80 deletions.
201 changes: 166 additions & 35 deletions examples/matching.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions phomo/master.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def __init__(self, array: np.ndarray) -> None:
Returns:
Master image instance.
"""
self.array = array
LOGGER.info("master shape: %s", self.array.shape)
super().__init__(array)

@property
def img(self):
Expand Down
197 changes: 160 additions & 37 deletions phomo/palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,47 +14,13 @@ class Palette:

def __init__(self, array: ArrayLike):
self.array = np.array(array)
self.plot = PalettePlotter(self)

@property
def pixels(self) -> np.ndarray:
"""Returns flattened pixels from the array."""
return self.array.reshape(-1, self.array.shape[-1])

def palette(self, bins: int = 256) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]:
"""Compute the 3D colour distribution."""
hist, edges = np.histogramdd(
self.pixels,
bins=bins,
range=[(0, 255), (0, 255), (0, 255)],
)
return edges, hist

def plot(self) -> Tuple[Figure, np.ndarray]:
"""Plot 2D projections of the 3D colour distribution."""
_, hist = self.palette()
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs = axs.ravel()

titles = ["Red-Green", "Green-Blue", "Blue-Red"]
for i, (ax, title) in enumerate(zip(axs, titles)):
i = (i + 2) % 3
proj = np.sum(hist, axis=i)
if i != 1:
proj = proj.T
ax.imshow(
proj,
origin="lower",
extent=[0, 255, 0, 255],
aspect="auto",
vmax=np.mean(proj) + 3 * np.std(proj),
)
ax.set_title(title)
ax.set_xlabel(title.split("-")[0])
ax.set_ylabel(title.split("-")[1])

fig.tight_layout()
return fig, axs

def equalize(self):
"""Equalize the colour distribution using `cv2.equalizeHist`.
Expand All @@ -73,9 +39,10 @@ def equalize(self):

def match(self, other: "Palette"):
"""Match the colour distribution of the `Master` to the distribution of the
`Pool` using the colour transfer algorithm explained in this paper:
`Pool` using the colour transfer algorithm.
https://api.semanticscholar.org/CorpusID:14088925
See:
https://api.semanticscholar.org/CorpusID:14088925
Args:
The other `Palette` to match this `Palette`'s colour distribution to.
Expand Down Expand Up @@ -111,3 +78,159 @@ def match(self, other: "Palette"):
result_lab = np.clip(result_lab, 0, 255).astype(np.uint8)
result_rgb = cv2.cvtColor(result_lab, cv2.COLOR_LAB2RGB)
return self.__class__(result_rgb.reshape(self_shape))


class PalettePlotter:
def __init__(self, palette: Palette):
self._palette = palette

def _colour_hist(self, **kwargs) -> Tuple[np.ndarray, np.ndarray]:
"""Compute the 1D colour distributions.
Args:
**kwargs: passed to `numpy.histogram`.
Returns:
Histogram edges and counts.
"""
bins = kwargs.pop("bins", range(256))
values = []
bin_edges = []
for i in range(self._palette.pixels.shape[1]):
freqs, edges = np.histogram(self._palette.pixels[:, i], bins=bins, **kwargs)
bin_edges.append(edges)
values.append(freqs)
values = np.vstack(values).T
bin_edges = np.vstack(bin_edges).T
return bin_edges, values

def _colour_hist_3d(
self, bins: int = 256
) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]:
"""Compute the 3D colour distribution."""
hist, edges = np.histogramdd(
self._palette.pixels,
bins=bins,
range=[(0, 255), (0, 255), (0, 255)],
)
return edges, hist

def _colour_palette(self, depth: int = 3):
pixels = self._palette.array.reshape(-1, 3)

def split(pixels: np.ndarray, depth: int) -> list[np.ndarray]:
if len(pixels) == 0 or depth == 0:
return [pixels]

ranges = np.ptp(pixels, axis=0)
axis = np.argmax(ranges)
median = np.median(pixels[:, axis])

left = pixels[pixels[:, axis] <= median]
right = pixels[pixels[:, axis] > median]

return split(left, depth - 1) + split(right, depth - 1)

quantized = split(pixels, depth)

palette = [np.mean(region, axis=0) for region in quantized if len(region) > 0]
palette = np.array(palette, dtype=np.uint8)

return palette[::-1]

def palette(self, depth: int = 3) -> Tuple[Figure, np.ndarray]:
"""Show the dominant colours of the palette using a median cut algorithm.
See:
https://en.wikipedia.org/wiki/Median_cut
Args:
depth: The number of splits to perform.
Returns:
`Figure` and `np.array` of `Axes`.
"""
palette = self._colour_palette(depth=depth)

square_size = 50
palette_ar = np.zeros(
(square_size, len(palette) * square_size, 3), dtype="uint8"
)

for i, color in enumerate(palette):
palette_ar[:, i * square_size : (i + 1) * square_size, :] = color

fig, ax = plt.subplots(
1,
figsize=(5, 5 * len(palette)),
frameon=False,
)
ax.imshow(palette_ar, aspect="equal")
ax.set_axis_off()
ax.margins(0, 0)
fig.tight_layout(pad=0)
return fig, ax

def distribution(self, log: bool = False) -> Tuple[Figure, np.ndarray]:
"""Plot the colour distribution of each channel.
Args:
log: Plot y axis in log scale.
Returns:
`Figure` and `np.array` of `Axes`.
"""

bin_edges, values = self._colour_hist()
fig, axs = plt.subplots(3, figsize=(12, 6))
channels = ["Red", "Green", "Blue"]
for i, (ax, channel) in enumerate(zip(axs, channels)):
ax.bar(
bin_edges[:-1, i],
values[:, i],
width=np.diff(bin_edges[:, i]),
align="edge",
color=channel,
)
if log:
ax.set_yscale("log")
ax.set_title(channel)
ax.set_xlim(0, 255)
fig.tight_layout()
return fig, axs

def distribution_2d(self) -> Tuple[Figure, np.ndarray]:
"""Plot 2D projections of the 3D colour distribution.
Returns:
`Figure` and `np.array` of `Axes`.
"""
_, hist = self._colour_hist_3d()
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs = axs.ravel()

titles = ["Red-Green", "Green-Blue", "Blue-Red"]
for i, (ax, title) in enumerate(zip(axs, titles)):
i = (i + 2) % 3
proj = np.sum(hist, axis=i)
if i != 1:
proj = proj.T
ax.imshow(
proj,
origin="lower",
extent=[0, 255, 0, 255],
aspect="auto",
vmax=np.mean(proj) + 3 * np.std(proj),
)
ax.set_title(title)
ax.set_xlabel(title.split("-")[0])
ax.set_ylabel(title.split("-")[1])

fig.tight_layout()
return fig, axs

def __call__(self):
"""Plot all the plots."""
self.palette()
self.distribution()
self.distribution_2d()
3 changes: 1 addition & 2 deletions phomo/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ def __init__(
Args:
array: `Pool` image data array. Should be (n_tiles, height, width, 3)
"""
self.array = np.array(array)
LOGGER.info("Number of tiles: %s", len(self.array))
super().__init__(array)

@property
def tiles(self) -> "PoolTiles":
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def test_pixels(self):
== self.master_array.shape[0] * self.master_array.shape[1]
)

# Palette methods
# plot methods
def test_palette(self):
self.master.palette()
self.master.plot()

def test_plot(self):
self.master.plot()
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def test_pixels(self):
def test_len(self):
assert len(self.tile_paths) == len(self.pool)

# Palette methods
# plot methods
def test_palette(self):
self.pool.palette()
self.pool.plot()

def test_plot(self):
self.pool.plot()
Expand Down

0 comments on commit ce7b124

Please sign in to comment.