Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: exact solver #7

Merged
merged 8 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8, 3.9, "3.10"]
python-version: [3.9, "3.10"]
poetry-version: [1.4.0]
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
Expand All @@ -33,7 +33,7 @@ jobs:
id: cache
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.full-python-version.outputs.version }}-${{ hashFiles('**/poetry.lock') }}
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Ensure cache is healthy
if: steps.cache.outputs.cache-hit == 'true'
run: poetry run pip --version >/dev/null 2>&1 || rm -rf .venv
Expand Down
169 changes: 74 additions & 95 deletions examples/faces.ipynb

Large diffs are not rendered by default.

349 changes: 151 additions & 198 deletions examples/metrics.ipynb

Large diffs are not rendered by default.

329 changes: 138 additions & 191 deletions examples/performance.ipynb

Large diffs are not rendered by default.

201 changes: 49 additions & 152 deletions examples/performance_trick.ipynb

Large diffs are not rendered by default.

8 changes: 1 addition & 7 deletions phomo/metrics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
import sys

# prior to python 3.8, Protocol is in typing_extensions
if sys.version_info[0] == 3 and sys.version_info[1] < 8:
from typing_extensions import Protocol
else:
from typing import Protocol
from typing import Protocol

import numpy as np

Expand Down
64 changes: 61 additions & 3 deletions phomo/mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from scipy.optimize import linear_sum_assignment

from .grid import Grid
from .master import Master
Expand Down Expand Up @@ -101,7 +102,7 @@ def compute_d_matrix(
if isinstance(metric, str):
if metric not in METRICS.keys():
raise KeyError(
f"'%s' not in available metrics: %s",
"'%s' not in available metrics: %s",
metric,
repr(list(METRICS.keys())),
)
Expand Down Expand Up @@ -141,14 +142,14 @@ def compute_d_matrix(
self._log.debug("d_matrix shape: %s", d_matrix.shape)
return d_matrix

def build(
def build_greedy(
self,
workers: int = 1,
metric: Union[str, MetricCallable] = "norm",
d_matrix: Optional[np.ndarray] = None,
**kwargs,
) -> Image.Image:
"""Construct the mosaic image.
"""Construct the mosaic image using a greedy tile assignement algorithm.

Args:
workers: The number of workers to use when computing the
Expand Down Expand Up @@ -204,6 +205,63 @@ def build(
pbar.close()
return Image.fromarray(np.uint8(mosaic))

def build(
self,
workers: int = 1,
metric: Union[str, MetricCallable] = "norm",
d_matrix: Optional[np.ndarray] = None,
**kwargs,
) -> Image.Image:
"""Construct the mosaic image by solving the linear sum assignment problem.
See: https://en.wikipedia.org/wiki/Assignment_problem

Args:
workers: The number of workers to use when computing the
distance matrix.
metric: The distance metric used for the distance matrix. Either
provide a string, for implemented metrics see ``phomo.metrics.METRICS``.
Or a callable, which should take two ``np.ndarray``s and return a float.
d_matrix: Use a pre-computed distance matrix.
**kwargs: Passed to the `metric` function.

Returns:
The PIL.Image instance of the mosaic.
"""
mosaic = np.zeros((self.size[1], self.size[0], 3))

# Compute the distance matrix.
if d_matrix is None:
d_matrix = self.compute_d_matrix(workers=workers, metric=metric, **kwargs)

# expand the dmatrix to allow for repeated tiles
if self.n_appearances > 0:
d_matrix = np.tile(d_matrix, self.n_appearances)
print("dmatrix", d_matrix.shape)

self._log.info("Computing optimal tile assignment.")
row_ind, col_ind = linear_sum_assignment(d_matrix)
pbar = tqdm(total=d_matrix.shape[0], desc="Building mosaic")
for row, col in zip(row_ind, col_ind):
slices = self.grid.slices[row]
tile_array = self.pool.arrays[col % len(self.pool.arrays)]
# if the grid has been subdivided then the tile should be shrunk to
# the size of the subdivision
array_size = (
slices[1].stop - slices[1].start,
slices[0].stop - slices[0].start,
)
if tile_array.shape[:-1] != array_size[::-1]:
tile_array = resize_array(tile_array, array_size)

# shift slices back so that the centering of the mosaic within the
# master image is removed
slices = self.grid.remove_origin(slices)
mosaic[slices[0], slices[1]] = tile_array
pbar.update(1)
pbar.close()

return Image.fromarray(np.uint8(mosaic))

def __repr__(self) -> str:
# indent these guys
master = repr(self.master).replace("\n", "\n ")
Expand Down
8 changes: 4 additions & 4 deletions phomo/palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _cdfs(frequencies: np.ndarray) -> np.ndarray:
cdfs /= cdfs[-1]
return cdfs

def plot(self, log: bool = False) -> Tuple[plt.Figure, plt.Axes]:
def plot(self, log: bool = False) -> Tuple[plt.Figure, np.ndarray]:
"""Plot the colour distribution.

Args:
Expand All @@ -56,8 +56,8 @@ def plot(self, log: bool = False) -> Tuple[plt.Figure, plt.Axes]:
"""

bin_edges, values = self.palette()
fig, axes = plt.subplots(3, figsize=(12, 6))
for i, ax in enumerate(axes):
fig, axs = plt.subplots(3, figsize=(12, 6))
for i, ax in enumerate(axs):
ax.bar(
bin_edges[:-1, i],
values[:, i],
Expand All @@ -68,4 +68,4 @@ def plot(self, log: bool = False) -> Tuple[plt.Figure, plt.Axes]:
ax.set_yscale("log")
ax.set_title(f"Channel {i+1}")
fig.tight_layout()
return fig, axes
return fig, axs
3 changes: 2 additions & 1 deletion phomo/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def from_dir(

Args:
tile_dir: path to directory containing the images.
crop_ratio: width to height ratio to crop the master image to. 1 results in a square image.
crop_ratio: width to height ratio to crop the tile images to. 1 results in a
square image.
tile_size: resize the image to the provided size, width followed by height.
convert: convert the image to the provided mode. See PIL Modes.
"""
Expand Down
Loading