Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updates including tests #6

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,18 @@ dependencies = [
"numpy",
"einops",
"torch_image_lerp",
"torch_grid_utils",
]

# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
# "extras" (e.g. for `pip install .[test]`)
[project.optional-dependencies]
# add dependencies used for testing here
test = ["pytest", "pytest-cov"]
test = [
"pytest",
"pytest-cov",
"torch-fourier-shell-correlation",
]
# add anything else you like to have in your dev environment here
dev = [
"ipython",
Expand Down
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/backproject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .grids import fftfreq_grid
from .slice_insertion import insert_central_slices_rfft_3d


Expand Down
2 changes: 1 addition & 1 deletion src/torch_fourier_slice/grids/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .fftfreq_grid import fftfreq_grid
from .central_slice_fftfreq_grid import central_slice_fftfreq_grid
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import einops
import torch

from .fftfreq_grid import _construct_fftfreq_grid_2d
from torch_grid_utils import fftfreq_grid
from ..dft_utils import rfft_shape, fftshift_2d


Expand All @@ -13,7 +13,7 @@ def central_slice_fftfreq_grid(
) -> torch.Tensor:
# generate 2d grid of DFT sample frequencies, shape (h, w, 2)
h, w = volume_shape[-2:]
grid = _construct_fftfreq_grid_2d(
grid = fftfreq_grid(
image_shape=(h, w),
rfft=rfft,
device=device
Expand Down
158 changes: 0 additions & 158 deletions src/torch_fourier_slice/grids/fftfreq_grid.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/torch_fourier_slice/project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from torch_grid_utils import fftfreq_grid

from .grids import fftfreq_grid
from .slice_extraction import extract_central_slices_rfft_3d


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_image_lerp import sample_image_3d

from ..dft_utils import fftfreq_to_dft_coordinates
from ..grids.central_slice_grid import central_slice_fftfreq_grid
from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid


def extract_central_slices_rfft_3d(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_image_lerp import insert_into_image_3d

from ..dft_utils import fftfreq_to_dft_coordinates, rfft_shape
from ..grids.central_slice_grid import central_slice_fftfreq_grid
from ..grids.central_slice_fftfreq_grid import central_slice_fftfreq_grid


def insert_central_slices_rfft_3d(
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from pytest import fixture


@fixture
def cube() -> torch.Tensor:
volume = torch.zeros((32, 32, 32))
volume[8:24, 8:24, 8:24] = 1
volume[16, 16, 16] = 32
return volume
61 changes: 58 additions & 3 deletions tests/test_torch_fourier_slice.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,59 @@
# temporary stub
import pytest
import torch

def test_something():
pass
from torch_fourier_slice import project_3d_to_2d, backproject_2d_to_3d
from torch_fourier_shell_correlation import fsc
from scipy.stats import special_ortho_group


def test_project_3d_to_2d_rotation_center():
# rotation center should be at position of DC in DFT
volume = torch.zeros((32, 32, 32))
volume[16, 16, 16] = 1

# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=100)).float()
projections = project_3d_to_2d(
volume=volume,
rotation_matrices=rotation_matrices,
)

# check max is always at (16, 16), implying point (16, 16) never moves
for image in projections:
max = torch.argmax(image)
i, j = divmod(max.item(), 32)
assert (i, j) == (16, 16)


def test_3d_2d_projection_backprojection_cycle(cube):
# make projections
rotation_matrices = torch.tensor(special_ortho_group.rvs(dim=3, size=1500)).float()
projections = project_3d_to_2d(
volume=cube,
rotation_matrices=rotation_matrices,
)

# reconstruct
volume = backproject_2d_to_3d(
images=projections,
rotation_matrices=rotation_matrices,
)

# calculate FSC between the projections and the reconstructions
_fsc = fsc(cube, volume.float())

assert torch.all(_fsc[3:] > 0.99) # one low res shell at 0.98...


@pytest.mark.parametrize(
"images, rotation_matrices",
[
(
torch.rand((10, 28, 28)).float(),
torch.tensor(special_ortho_group.rvs(dim=3, size=10)).float()
),
]
)
def test_dtypes_slice_insertion(images, rotation_matrices):
result = backproject_2d_to_3d(images, rotation_matrices)
assert result.dtype == images.dtype
Loading