Skip to content

Commit

Permalink
updates including tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alisterburt committed Aug 9, 2024
1 parent 6abf032 commit 6281aa7
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 169 deletions.
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,18 @@ dependencies = [
"numpy",
"einops",
"torch_image_lerp",
"torch_grid_utils",
]

# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
# "extras" (e.g. for `pip install .[test]`)
[project.optional-dependencies]
# add dependencies used for testing here
test = ["pytest", "pytest-cov"]
test = [
"pytest",
"pytest-cov",
"torch-fourier-shell-correlation",
]
# add anything else you like to have in your dev environment here
dev = [
"ipython",
Expand Down
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/backproject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .grids import fftfreq_grid
from .slice_insertion import insert_central_slices_rfft_3d


Expand Down
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/grids/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .fftfreq_grid import fftfreq_grid
from .central_slice_fftfreq_grid import central_slice_fftfreq_grid
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import einops
import torch

from .fftfreq_grid import _construct_fftfreq_grid_2d
from torch_grid_utils import fftfreq_grid
from ..dft_utils import rfft_shape, fftshift_2d


Expand All @@ -13,7 +13,7 @@ def central_slice_fftfreq_grid(
) -> torch.Tensor:
# generate 2d grid of DFT sample frequencies, shape (h, w, 2)
h, w = volume_shape[-2:]
grid = _construct_fftfreq_grid_2d(
grid = fftfreq_grid(
image_shape=(h, w),
rfft=rfft,
device=device
Expand Down
158 changes: 0 additions & 158 deletions src/torch_fourier_slice/grids/fftfreq_grid.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/torch_fourier_slice/project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .grids import fftfreq_grid
from .slice_extraction import extract_central_slices_rfft_3d


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_image_lerp import sample_image_3d

from ..dft_utils import fftfreq_to_dft_coordinates
from ..grids.central_slice_grid import central_slice_fftfreq_grid
from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid


def extract_central_slices_rfft_3d(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_image_lerp import insert_into_image_3d

from ..dft_utils import fftfreq_to_dft_coordinates, rfft_shape
from ..grids.central_slice_grid import central_slice_fftfreq_grid
from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid


def insert_central_slices_rfft_3d(
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from pytest import fixture


@fixture
def cube() -> torch.Tensor:
volume = torch.zeros((32, 32, 32))
volume[8:24, 8:24, 8:24] = 1
volume[16, 16, 16] = 32
return volume
61 changes: 58 additions & 3 deletions tests/test_torch_fourier_slice.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,59 @@
# temporary stub
import pytest
import torch

def test_something():
pass
from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d
from torch_fourier_shell_correlation import fsc
from scipy.stats import special_ortho_group


def test_project_3d_to_2d_rotation_center():
# rotation center should be at position of DC in DFT
volume = torch.zeros((32, 32, 32))
volume[16, 16, 16] = 1

# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=100)).float()
projections = project_3d_to_2d(
volume=volume,
rotation_matrices=rotation_matrices,
)

# check max is always at (16, 16), implying point (16, 16) never moves
for image in projections:
max = torch.argmax(image)
i, j = divmod(max.item(), 32)
assert (i, j) == (16, 16)


def test_3d_2d_projection_backprojection_cycle(cube):
# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=1500)).float()
projections = project_3d_to_2d(
volume=cube,
rotation_matrices=rotation_matrices,
)

# reconstruct
volume = backproject_2d_to_3d(
images=projections,
rotation_matrices=rotation_matrices,
)

# calculate FSC between the projections and the reconstructions
_fsc = fsc(cube, volume.float())

assert torch.all(_fsc[3:] > 0.99) # one low res shell at 0.98...


@pytest.mark.parametrize(
"images, rotation_matrices",
[
(
torch.rand((10, 28, 28)).float(),
torch.tensor(special_ortho_group.rvs(dim=3, size=10)).float()
),
]
)
def test_dtypes_slice_insertion(images, rotation_matrices):
result = backproject_2d_to_3d(images, rotation_matrices)
assert result.dtype == images.dtype

0 comments on commit 6281aa7

Please sign in to comment.