Skip to content

Commit

Permalink
Bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Dec 2, 2023
1 parent e79d1f0 commit 9db26a4
Show file tree
Hide file tree
Showing 7 changed files with 445 additions and 77 deletions.
24 changes: 20 additions & 4 deletions diffpose/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
'diffpose/deepfluoro.py'),
'diffpose.deepfluoro.get_3d_fiducials': ( 'api/deepfluoro.html#get_3d_fiducials',
'diffpose/deepfluoro.py'),
'diffpose.deepfluoro.get_random_offset': ( 'api/deepfluoro.html#get_random_offset',
'diffpose/deepfluoro.py'),
'diffpose.deepfluoro.load_deepfluoro_dataset': ( 'api/deepfluoro.html#load_deepfluoro_dataset',
'diffpose/deepfluoro.py'),
'diffpose.deepfluoro.parse_proj_params': ( 'api/deepfluoro.html#parse_proj_params',
Expand All @@ -72,13 +74,29 @@
'diffpose/jacobians.py'),
'diffpose.jacobians.plot_img_jacobian': ( 'api/jacobians.html#plot_img_jacobian',
'diffpose/jacobians.py')},
'diffpose.ljubljana': { 'diffpose.ljubljana.LjubljanaDataset': ('api/ljubljana.html#ljubljanadataset', 'diffpose/ljubljana.py'),
'diffpose.ljubljana': { 'diffpose.ljubljana.Evaluator': ('api/ljubljana.html#evaluator', 'diffpose/ljubljana.py'),
'diffpose.ljubljana.Evaluator.__call__': ( 'api/ljubljana.html#evaluator.__call__',
'diffpose/ljubljana.py'),
'diffpose.ljubljana.Evaluator.__init__': ( 'api/ljubljana.html#evaluator.__init__',
'diffpose/ljubljana.py'),
'diffpose.ljubljana.Evaluator.project': ( 'api/ljubljana.html#evaluator.project',
'diffpose/ljubljana.py'),
'diffpose.ljubljana.LjubljanaDataset': ('api/ljubljana.html#ljubljanadataset', 'diffpose/ljubljana.py'),
'diffpose.ljubljana.LjubljanaDataset.__getitem__': ( 'api/ljubljana.html#ljubljanadataset.__getitem__',
'diffpose/ljubljana.py'),
'diffpose.ljubljana.LjubljanaDataset.__init__': ( 'api/ljubljana.html#ljubljanadataset.__init__',
'diffpose/ljubljana.py'),
'diffpose.ljubljana.LjubljanaDataset.__iter__': ( 'api/ljubljana.html#ljubljanadataset.__iter__',
'diffpose/ljubljana.py'),
'diffpose.ljubljana.LjubljanaDataset.__len__': ( 'api/ljubljana.html#ljubljanadataset.__len__',
'diffpose/ljubljana.py')},
'diffpose/ljubljana.py'),
'diffpose.ljubljana.Transforms': ('api/ljubljana.html#transforms', 'diffpose/ljubljana.py'),
'diffpose.ljubljana.Transforms.__call__': ( 'api/ljubljana.html#transforms.__call__',
'diffpose/ljubljana.py'),
'diffpose.ljubljana.Transforms.__init__': ( 'api/ljubljana.html#transforms.__init__',
'diffpose/ljubljana.py'),
'diffpose.ljubljana.get_random_offset': ( 'api/ljubljana.html#get_random_offset',
'diffpose/ljubljana.py')},
'diffpose.metrics': { 'diffpose.metrics.CustomMetric': ('api/metrics.html#custommetric', 'diffpose/metrics.py'),
'diffpose.metrics.CustomMetric.__init__': ( 'api/metrics.html#custommetric.__init__',
'diffpose/metrics.py'),
Expand Down Expand Up @@ -136,8 +154,6 @@
'diffpose/registration.py'),
'diffpose.registration.VectorizedNormalizedCrossCorrelation2d.norm': ( 'api/registration.html#vectorizednormalizedcrosscorrelation2d.norm',
'diffpose/registration.py'),
'diffpose.registration.get_random_offset': ( 'api/registration.html#get_random_offset',
'diffpose/registration.py'),
'diffpose.registration.img_to_patches': ( 'api/registration.html#img_to_patches',
'diffpose/registration.py'),
'diffpose.registration.mask_to_img': ( 'api/registration.html#mask_to_img',
Expand Down
25 changes: 23 additions & 2 deletions diffpose/deepfluoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# %% auto 0
__all__ = ['DeepFluoroDataset', 'convert_deepfluoro_to_diffdrr', 'convert_diffdrr_to_deepfluoro', 'Evaluator', 'preprocess',
'Transforms']
'get_random_offset', 'Transforms']

# %% ../notebooks/api/00_deepfluoro.ipynb 3
from pathlib import Path
Expand Down Expand Up @@ -305,7 +305,28 @@ def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)):
img = (img - img.min()) / (img.max() - img.min())
return img

# %% ../notebooks/api/00_deepfluoro.ipynb 30
# %% ../notebooks/api/00_deepfluoro.ipynb 26
from beartype import beartype
from pytorch3d.transforms import se3_exp_map

from .calibration import RigidTransform


@beartype
def get_random_offset(batch_size: int, device) -> RigidTransform:
t1 = torch.distributions.Normal(10, 70).sample((batch_size,))
t2 = torch.distributions.Normal(250, 90).sample((batch_size,))
t3 = torch.distributions.Normal(5, 50).sample((batch_size,))
r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))
r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))
logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)
T = se3_exp_map(logmap)
R = T[..., :3, :3].transpose(-1, -2)
t = T[..., 3, :3]
return RigidTransform(R, t)

# %% ../notebooks/api/00_deepfluoro.ipynb 32
from torchvision.transforms import Compose, Lambda, Normalize, Resize


Expand Down
114 changes: 112 additions & 2 deletions diffpose/ljubljana.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/01_ljubljana.ipynb.

# %% auto 0
__all__ = ['LjubljanaDataset']
__all__ = ['LjubljanaDataset', 'get_random_offset', 'Evaluator', 'Transforms']

# %% ../notebooks/api/01_ljubljana.ipynb 3
from pathlib import Path
Expand Down Expand Up @@ -53,6 +53,9 @@ def __init__(
def __len__(self):
return 10

def __iter__(self):
return iter(self[idx] for idx in range(len(self)))

def __getitem__(self, idx):
idx += 1
extrinsic = self.f[f"subject{idx:02d}/proj-{self.view}/extrinsic"][:]
Expand All @@ -69,8 +72,9 @@ def __getitem__(self, idx):
if self.preprocess:
img += 1
img = img.max().log() - img.log()

height, width = img.shape
img = img.unsqueeze(0).unsqueeze(0)

focal_len, x0, y0 = parse_intrinsic_matrix(
intrinsic,
height,
Expand All @@ -88,6 +92,13 @@ def __getitem__(self, idx):
volume = self.f[f"subject{idx:02d}/volume/pixels"][:]
spacing = self.f[f"subject{idx:02d}/volume/spacing"][:]

isocenter_rot = torch.tensor([[torch.pi / 2, 0.0, -torch.pi / 2]])
isocenter_xyz = torch.tensor(volume.shape) * spacing / 2
isocenter_xyz = isocenter_xyz.unsqueeze(0)
isocenter_pose = RigidTransform(
isocenter_rot, isocenter_xyz, "euler_angles", "ZYX"
)

return (
volume,
spacing,
Expand All @@ -100,4 +111,103 @@ def __getitem__(self, idx):
y0,
img,
pose,
isocenter_pose,
)

# %% ../notebooks/api/01_ljubljana.ipynb 10
from beartype import beartype
from pytorch3d.transforms import se3_exp_map

from .calibration import RigidTransform


@beartype
def get_random_offset(view, batch_size: int, device) -> RigidTransform:
if view == "ap":
t1 = torch.distributions.Normal(-10, 20).sample((batch_size,))
t2 = torch.distributions.Normal(175, 30).sample((batch_size,))
t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))
r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))
elif view == "lat":
t1 = torch.distributions.Normal(75, 15).sample((batch_size,))
t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))
t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))
r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))
else:
raise ValueError(f"view must be 'ap' or 'lat', not {view}")

logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)
T = se3_exp_map(logmap)
R = T[..., :3, :3].transpose(-1, -2)
t = T[..., 3, :3]
return RigidTransform(R, t)

# %% ../notebooks/api/01_ljubljana.ipynb 12
from torch.nn.functional import pad

from .calibration import perspective_projection


class Evaluator:
def __init__(self, specimen, idx):
# Save matrices to device
self.translate = specimen.translate
self.flip_xz = specimen.flip_xz
self.intrinsic = specimen.intrinsic
self.intrinsic_inv = specimen.intrinsic.inverse()

# Get gt fiducial locations
self.specimen = specimen
self.fiducials = specimen.fiducials
gt_pose = specimen[idx][1]
self.true_projected_fiducials = self.project(gt_pose)

def project(self, pose):
extrinsic = convert_diffdrr_to_deepfluoro(self.specimen, pose)
x = perspective_projection(extrinsic, self.intrinsic, self.fiducials)
x = -self.specimen.focal_len * torch.einsum(
"ij, bnj -> bni",
self.intrinsic_inv,
pad(x, (0, 1), value=1), # Convert to homogenous coordinates
)
extrinsic = (
self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose)
)
return extrinsic.transform_points(x)

def __call__(self, pose):
pred_projected_fiducials = self.project(pose)
registration_error = (
(self.true_projected_fiducials - pred_projected_fiducials)
.norm(dim=-1)
.mean()
)
registration_error *= 0.154 # Pixel spacing is 0.154 mm / pixel isotropic
return registration_error

# %% ../notebooks/api/01_ljubljana.ipynb 15
from torchvision.transforms import Compose, Lambda, Normalize, Resize


class Transforms:
def __init__(
self,
height: int,
width: int,
eps: float = 1e-6,
):
"""Transform X-rays and DRRs before inputting to CNN."""
self.transforms = Compose(
[
Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + eps)),
Resize((height, width), antialias=True),
Normalize(mean=0.0774, std=0.0569),
]
)

def __call__(self, x):
return self.transforms(x)
47 changes: 17 additions & 30 deletions diffpose/registration.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/03_registration.ipynb.

# %% auto 0
__all__ = ['PoseRegressor', 'get_random_offset', 'SparseRegistration', 'VectorizedNormalizedCrossCorrelation2d']
__all__ = ['PoseRegressor', 'SparseRegistration', 'VectorizedNormalizedCrossCorrelation2d']

# %% ../notebooks/api/03_registration.ipynb 3
import timm
import torch

# %% ../notebooks/api/03_registration.ipynb 5
from .calibration import RigidTransform, convert


class PoseRegressor(torch.nn.Module):
"""
A PoseRegressor is comprised of a pretrained backbone model that extracts features
Expand Down Expand Up @@ -45,46 +48,31 @@ def forward(self, x):
x = self.backbone(x)
rot = self.rot_regression(x)
xyz = self.xyz_regression(x)
return RigidTransform(rot, xyz, self.parameterization, self.convention)
return convert(
[rot, xyz],
input_parameterization=self.parameterization,
output_parameterization="se3_exp_map",
input_convention=self.convention,
)

# %% ../notebooks/api/03_registration.ipynb 6
N_ANGULAR_COMPONENTS = {
"axis_angle": 3,
"euler_angles": 3,
"se3": 3,
"se3_log_map": 3,
"quaternion": 4,
"rotation_6d": 6,
"rotation_10d": 10,
"quaternion_adjugate": 10,
}

# %% ../notebooks/api/03_registration.ipynb 8
from beartype import beartype
from pytorch3d.transforms import se3_exp_map

from .calibration import RigidTransform


@beartype
def get_random_offset(batch_size: int, device) -> RigidTransform:
t1 = torch.distributions.Normal(10, 70).sample((batch_size,))
t2 = torch.distributions.Normal(250, 90).sample((batch_size,))
t3 = torch.distributions.Normal(5, 50).sample((batch_size,))
r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))
r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))
logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)
T = se3_exp_map(logmap)
R = T[..., :3, :3].transpose(-1, -2)
t = T[..., 3, :3]
return RigidTransform(R, t)

# %% ../notebooks/api/03_registration.ipynb 12
# %% ../notebooks/api/03_registration.ipynb 11
from diffdrr.detector import make_xrays
from diffdrr.drr import DRR
from diffdrr.siddon import siddon_raycast
from pytorch3d.transforms import se3_exp_map

from .calibration import RigidTransform, convert
from .calibration import RigidTransform


class SparseRegistration(torch.nn.Module):
Expand Down Expand Up @@ -210,12 +198,11 @@ def get_current_pose(self):
return convert(
[self.rotation, self.translation],
input_parameterization=self.parameterization,
output_parameterization="euler_angles",
output_parameterization="se3_exp_map",
input_convention=self.convention,
output_convention="ZYX",
)

# %% ../notebooks/api/03_registration.ipynb 14
# %% ../notebooks/api/03_registration.ipynb 13
def preprocess(x, eps=1e-4):
x = (x - x.min()) / (x.max() - x.min() + eps)
return (x - 0.3080) / 0.1494
Expand Down Expand Up @@ -245,7 +232,7 @@ def vector_to_img(pred_img, mask):
patches.append(patch)
return filled

# %% ../notebooks/api/03_registration.ipynb 15
# %% ../notebooks/api/03_registration.ipynb 14
class VectorizedNormalizedCrossCorrelation2d(torch.nn.Module):
def __init__(self, eps=1e-4):
super().__init__()
Expand Down
42 changes: 40 additions & 2 deletions notebooks/api/00_deepfluoro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -694,14 +694,53 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "dece6bbf-d973-41a4-8b38-078ab64e6f79",
"metadata": {},
"source": [
"## Distribution over camera poses\n",
"\n",
"We sample the three rotational and three translational parameters of $\\mathfrak{se}(3)$ from independent normal distributions defined with sufficient variance to capture wide perturbations from the isocenter."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32d6434d-41c3-4bf6-b12f-ea3d2721c753",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from beartype import beartype\n",
"from pytorch3d.transforms import se3_exp_map\n",
"\n",
"from diffpose.calibration import RigidTransform\n",
"\n",
"\n",
"@beartype\n",
"def get_random_offset(batch_size: int, device) -> RigidTransform:\n",
" t1 = torch.distributions.Normal(10, 70).sample((batch_size,))\n",
" t2 = torch.distributions.Normal(250, 90).sample((batch_size,))\n",
" t3 = torch.distributions.Normal(5, 50).sample((batch_size,))\n",
" r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))\n",
" r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
" r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))\n",
" logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)\n",
" T = se3_exp_map(logmap)\n",
" R = T[..., :3, :3].transpose(-1, -2)\n",
" t = T[..., 3, :3]\n",
" return RigidTransform(R, t)"
]
},
{
"cell_type": "markdown",
"id": "76efedd3-103c-43dd-944b-fe0c06e2c87b",
"metadata": {},
"source": [
"## Fiducial markers\n",
"\n",
"The `DeepFluoroDataset` class also contains a method for evaluating the registration error for a predicted pose. Fiducial markers were implanted in the original cadavers. Projecting them with predicted pose parameters can be used to measure their distance from the true fiducials."
"The `DeepFluoroDataset` class also contains a method for evaluating the registration error for a predicted pose. Fiducial markers were digitally placed on the preoperative CT. Projecting them with predicted pose parameters can be used to measure their distance from the true fiducials."
]
},
{
Expand All @@ -720,7 +759,6 @@
],
"source": [
"#| eval: false\n",
"# from pytorch3d.transforms import euler_angles_to_matrix, matrix_to_euler_angles\n",
"from diffdrr.utils import convert\n",
"\n",
"# Perturb the ground truth rotations by 0.05 degrees and 2 mm\n",
Expand Down
Loading

0 comments on commit 9db26a4

Please sign in to comment.