diff --git a/pyproject.toml b/pyproject.toml index fe287df..d6180eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,13 +40,18 @@ dependencies = [ "numpy", "einops", "torch_image_lerp", + "torch_grid_utils", ] # https://peps.python.org/pep-0621/#dependencies-optional-dependencies # "extras" (e.g. for `pip install .[test]`) [project.optional-dependencies] # add dependencies used for testing here -test = ["pytest", "pytest-cov"] +test = [ + "pytest", + "pytest-cov", + "torch-fourier-shell-correlation", +] # add anything else you like to have in your dev environment here dev = [ "ipython", diff --git a/src/torch_fourier_slice/backproject.py b/src/torch_fourier_slice/backproject.py index efed8c7..fbc8ae9 100644 --- a/src/torch_fourier_slice/backproject.py +++ b/src/torch_fourier_slice/backproject.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F +from torch_grid_utils import fftfreq_grid -from .grids import fftfreq_grid from .slice_insertion import insert_central_slices_rfft_3d diff --git a/src/torch_fourier_slice/grids/__init__.py b/src/torch_fourier_slice/grids/__init__.py index 70a89bb..74d3492 100644 --- a/src/torch_fourier_slice/grids/__init__.py +++ b/src/torch_fourier_slice/grids/__init__.py @@ -1 +1 @@ -from .fftfreq_grid import fftfreq_grid +from .central_slice_fftfreq_grid import central_slice_fftfreq_grid \ No newline at end of file diff --git a/src/torch_fourier_slice/grids/central_slice_grid.py b/src/torch_fourier_slice/grids/central_slice_fftfreq_grid.py similarity index 92% rename from src/torch_fourier_slice/grids/central_slice_grid.py rename to src/torch_fourier_slice/grids/central_slice_fftfreq_grid.py index 3b6f2c1..b09d327 100644 --- a/src/torch_fourier_slice/grids/central_slice_grid.py +++ b/src/torch_fourier_slice/grids/central_slice_fftfreq_grid.py @@ -1,7 +1,7 @@ import einops import torch -from .fftfreq_grid import _construct_fftfreq_grid_2d +from torch_grid_utils import fftfreq_grid from ..dft_utils import rfft_shape, fftshift_2d @@ -13,7 +13,7 @@ def central_slice_fftfreq_grid( ) -> torch.Tensor: # generate 2d grid of DFT sample frequencies, shape (h, w, 2) h, w = volume_shape[-2:] - grid = _construct_fftfreq_grid_2d( + grid = fftfreq_grid( image_shape=(h, w), rfft=rfft, device=device diff --git a/src/torch_fourier_slice/grids/fftfreq_grid.py b/src/torch_fourier_slice/grids/fftfreq_grid.py deleted file mode 100644 index cf55da4..0000000 --- a/src/torch_fourier_slice/grids/fftfreq_grid.py +++ /dev/null @@ -1,158 +0,0 @@ -import functools -from typing import Sequence, Tuple - -import einops -import torch - -from ..dft_utils import rfft_shape, fftshift_2d, fftshift_3d - - -@functools.lru_cache(maxsize=1) -def fftfreq_grid( - image_shape: tuple[int, int] | tuple[int, int, int], - rfft: bool, - fftshift: bool = False, - spacing: float | tuple[float, float] | tuple[float, float, float] = 1, - norm: bool = False, - device: torch.device | None = None, -): - """Construct a 2D or 3D grid of DFT sample frequencies. - - For a 2D image with shape `(h, w)` and `rfft=False` this function will produce - a `(h, w, 2)` array of DFT sample frequencies in the `h` and `w` dimensions. - If `norm` is True the Euclidean norm will be calculated over the last dimension - leaving a `(h, w)` grid. - - Parameters - ---------- - image_shape: tuple[int, int] | tuple[int, int, int] - Shape of the 2D or 3D image before computing the DFT. - rfft: bool - Whether the output should contain frequencies for a real-valued DFT. - fftshift: bool - Whether to fftshift the output grid. - spacing: float | tuple[float, float] | tuple[float, float, float] - Spacing between samples in each dimension. Sampling is considered to be - isotropic if a single value is passed. - norm: bool - Whether to compute the Euclidean norm over the last dimension. - device: torch.device | None - PyTorch device on which the returned grid will be stored. - - Returns - ------- - frequency_grid: torch.Tensor - `(*image_shape, ndim)` array of DFT sample frequencies in each - image dimension if `norm` is `False` else `(*image_shape, )`. - """ - if len(image_shape) == 2: - frequency_grid = _construct_fftfreq_grid_2d( - image_shape=image_shape, - rfft=rfft, - spacing=spacing, - device=device, - ) - if fftshift is True: - frequency_grid = einops.rearrange(frequency_grid, '... freq -> freq ...') - frequency_grid = fftshift_2d(frequency_grid, rfft=rfft) - frequency_grid = einops.rearrange(frequency_grid, 'freq ... -> ... freq') - elif len(image_shape) == 3: - frequency_grid = _construct_fftfreq_grid_3d( - image_shape=image_shape, - rfft=rfft, - spacing=spacing, - device=device, - ) - if fftshift is True: - frequency_grid = einops.rearrange(frequency_grid, '... freq -> freq ...') - frequency_grid = fftshift_3d(frequency_grid, rfft=rfft) - frequency_grid = einops.rearrange(frequency_grid, 'freq ... -> ... freq') - else: - raise NotImplementedError( - "Construction of fftfreq grids is currently only supported for " - "2D and 3D images." - ) - if norm is True: - frequency_grid = einops.reduce( - frequency_grid ** 2, '... d -> ...', reduction='sum' - ) ** 0.5 - return frequency_grid - - -def _construct_fftfreq_grid_2d( - image_shape: Tuple[int, int], - rfft: bool, - spacing: float | tuple[float, float] = 1, - device: torch.device = None -) -> torch.Tensor: - """Construct a grid of DFT sample freqs for a 2D image. - - Parameters - ---------- - image_shape: Sequence[int] - A 2D shape `(h, w)` of the input image for which a grid of DFT sample freqs - should be calculated. - rfft: bool - Whether the frequency grid is for a real fft (rfft). - spacing: float | Tuple[float, float] - Sample spacing in `h` and `w` dimensions of the grid. - device: torch.device - Torch device for the resulting grid. - - Returns - ------- - frequency_grid: torch.Tensor - `(h, w, 2)` array of DFT sample freqs. - Order of freqs in the last dimension corresponds to the order of - the two dimensions of the grid. - """ - dh, dw = spacing if isinstance(spacing, Sequence) else [spacing] * 2 - last_axis_fftfreq = torch.fft.rfftfreq if rfft is True else torch.fft.fftfreq - h, w = image_shape - freq_y = torch.fft.fftfreq(h, d=dh, device=device) - freq_x = last_axis_fftfreq(w, d=dw, device=device) - h, w = rfft_shape(image_shape) if rfft is True else image_shape - freq_yy = einops.repeat(freq_y, 'h -> h w', w=w) - freq_xx = einops.repeat(freq_x, 'w -> h w', h=h) - return einops.rearrange([freq_yy, freq_xx], 'freq h w -> h w freq') - - -def _construct_fftfreq_grid_3d( - image_shape: Sequence[int], - rfft: bool, - spacing: float | Tuple[float, float, float] = 1, - device: torch.device = None -) -> torch.Tensor: - """Construct a grid of DFT sample freqs for a 3D image. - - Parameters - ---------- - image_shape: Sequence[int] - A 3D shape `(d, h, w)` of the input image for which a grid of DFT sample freqs - should be calculated. - rfft: bool - Controls Whether the frequency grid is for a real fft (rfft). - spacing: float | Tuple[float, float, float] - Sample spacing in `d`, `h` and `w` dimensions of the grid. - device: torch.device - Torch device for the resulting grid. - - Returns - ------- - frequency_grid: torch.Tensor - `(h, w, 3)` array of DFT sample freqs. - Order of freqs in the last dimension corresponds to the order of dimensions - of the grid. - """ - dd, dh, dw = spacing if isinstance(spacing, Sequence) else [spacing] * 3 - last_axis_frequency_func = torch.fft.rfftfreq if rfft is True else torch.fft.fftfreq - d, h, w = image_shape - freq_z = torch.fft.fftfreq(d, d=dd, device=device) - freq_y = torch.fft.fftfreq(h, d=dh, device=device) - freq_x = last_axis_frequency_func(w, d=dw, device=device) - d, h, w = rfft_shape(image_shape) if rfft is True else image_shape - freq_zz = einops.repeat(freq_z, 'd -> d h w', h=h, w=w) - freq_yy = einops.repeat(freq_y, 'h -> d h w', d=d, w=w) - freq_xx = einops.repeat(freq_x, 'w -> d h w', d=d, h=h) - return einops.rearrange([freq_zz, freq_yy, freq_xx], 'freq ... -> ... freq') - diff --git a/src/torch_fourier_slice/project.py b/src/torch_fourier_slice/project.py index 1072234..221ee3e 100644 --- a/src/torch_fourier_slice/project.py +++ b/src/torch_fourier_slice/project.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F +from torch_grid_utils import fftfreq_grid -from .grids import fftfreq_grid from .slice_extraction import extract_central_slices_rfft_3d diff --git a/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py b/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py index d216aac..8987443 100644 --- a/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py +++ b/src/torch_fourier_slice/slice_extraction/_extract_central_slices_rfft_3d.py @@ -3,7 +3,7 @@ from torch_image_lerp import sample_image_3d from ..dft_utils import fftfreq_to_dft_coordinates -from ..grids.central_slice_grid import central_slice_fftfreq_grid +from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid def extract_central_slices_rfft_3d( diff --git a/src/torch_fourier_slice/slice_insertion/_insert_central_slices_rfft_3d.py b/src/torch_fourier_slice/slice_insertion/_insert_central_slices_rfft_3d.py index 7b09840..cf02b21 100644 --- a/src/torch_fourier_slice/slice_insertion/_insert_central_slices_rfft_3d.py +++ b/src/torch_fourier_slice/slice_insertion/_insert_central_slices_rfft_3d.py @@ -3,7 +3,7 @@ from torch_image_lerp import insert_into_image_3d from ..dft_utils import fftfreq_to_dft_coordinates, rfft_shape -from ..grids.central_slice_grid import central_slice_fftfreq_grid +from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid def insert_central_slices_rfft_3d( diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..6ecef69 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +import torch +from pytest import fixture + + +@fixture +def cube() -> torch.Tensor: + volume = torch.zeros((32, 32, 32)) + volume[8:24, 8:24, 8:24] = 1 + volume[16, 16, 16] = 32 + return volume diff --git a/tests/test_torch_fourier_slice.py b/tests/test_torch_fourier_slice.py index 4c6e45f..ff02dde 100644 --- a/tests/test_torch_fourier_slice.py +++ b/tests/test_torch_fourier_slice.py @@ -1,4 +1,59 @@ -# temporary stub +import pytest +import torch -def test_something(): - pass +from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d +from torch_fourier_shell_correlation import fsc +from scipy.stats import special_ortho_group + + +def test_project_3d_to_2d_rotation_center(): + # rotation center should be at position of DC in DFT + volume = torch.zeros((32, 32, 32)) + volume[16, 16, 16] = 1 + + # make projections + rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=100)).float() + projections = project_3d_to_2d( + volume=volume, + rotation_matrices=rotation_matrices, + ) + + # check max is always at (16, 16), implying point (16, 16) never moves + for image in projections: + max = torch.argmax(image) + i, j = divmod(max.item(), 32) + assert (i, j) == (16, 16) + + +def test_3d_2d_projection_backprojection_cycle(cube): + # make projections + rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=1500)).float() + projections = project_3d_to_2d( + volume=cube, + rotation_matrices=rotation_matrices, + ) + + # reconstruct + volume = backproject_2d_to_3d( + images=projections, + rotation_matrices=rotation_matrices, + ) + + # calculate FSC between the projections and the reconstructions + _fsc = fsc(cube, volume.float()) + + assert torch.all(_fsc[3:] > 0.99) # one low res shell at 0.98... + + +@pytest.mark.parametrize( + "images, rotation_matrices", + [ + ( + torch.rand((10, 28, 28)).float(), + torch.tensor(special_ortho_group.rvs(dim=3, size=10)).float() + ), + ] +) +def test_dtypes_slice_insertion(images, rotation_matrices): + result = backproject_2d_to_3d(images, rotation_matrices) + assert result.dtype == images.dtype