Skip to content

Commit

Permalink
Add example of sparse rendering
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Nov 30, 2023
1 parent 83e1f95 commit d39cb5d
Show file tree
Hide file tree
Showing 6 changed files with 383 additions and 113 deletions.
7 changes: 3 additions & 4 deletions diffpose/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
'diffpose/calibration.py'),
'diffpose.calibration.RigidTransform.inverse': ( 'api/calibration.html#rigidtransform.inverse',
'diffpose/calibration.py'),
'diffpose.calibration.convert': ('api/calibration.html#convert', 'diffpose/calibration.py'),
'diffpose.calibration.perspective_projection': ( 'api/calibration.html#perspective_projection',
'diffpose/calibration.py')},
'diffpose.deepfluoro': { 'diffpose.deepfluoro.DeepFluoroDataset': ( 'api/deepfluoro.html#deepfluorodataset',
Expand Down Expand Up @@ -123,10 +124,8 @@
'diffpose/registration.py'),
'diffpose.registration.SparseRegistration.forward': ( 'api/registration.html#sparseregistration.forward',
'diffpose/registration.py'),
'diffpose.registration.SparseRegistration.get_rotation': ( 'api/registration.html#sparseregistration.get_rotation',
'diffpose/registration.py'),
'diffpose.registration.SparseRegistration.get_translation': ( 'api/registration.html#sparseregistration.get_translation',
'diffpose/registration.py'),
'diffpose.registration.SparseRegistration.get_current_pose': ( 'api/registration.html#sparseregistration.get_current_pose',
'diffpose/registration.py'),
'diffpose.registration.VectorizedNormalizedCrossCorrelation2d': ( 'api/registration.html#vectorizednormalizedcrosscorrelation2d',
'diffpose/registration.py'),
'diffpose.registration.VectorizedNormalizedCrossCorrelation2d.__init__': ( 'api/registration.html#vectorizednormalizedcrosscorrelation2d.__init__',
Expand Down
55 changes: 50 additions & 5 deletions diffpose/calibration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/02_calibration.ipynb.

# %% auto 0
__all__ = ['RigidTransform', 'perspective_projection']
__all__ = ['RigidTransform', 'convert', 'perspective_projection']

# %% ../notebooks/api/02_calibration.ipynb 4
import torch
Expand All @@ -10,7 +10,7 @@
from typing import Optional

from beartype import beartype
from diffdrr.utils import convert
from diffdrr.utils import convert as convert_so3
from jaxtyping import Float, jaxtyped
from pytorch3d.transforms import Transform3d

Expand All @@ -30,7 +30,7 @@ def __init__(
if device is None and (R.device == t.device):
device = R.device

R = convert(R, parameterization, "matrix", convention)
R = convert_so3(R, parameterization, "matrix", convention)
if R.dim() == 2 and t.dim() == 1:
R = R.unsqueeze(0)
t = t.unsqueeze(0)
Expand All @@ -46,7 +46,7 @@ def __init__(
def get_rotation(self, parameterization=None, convention=None):
R = self.get_matrix()[..., :3, :3].transpose(-1, -2)
if parameterization is not None:
R = convert(R, "matrix", parameterization, None, convention)
R = convert_so3(R, "matrix", parameterization, None, convention)
return R

def get_translation(self):
Expand All @@ -70,7 +70,52 @@ def clone(self):
t = self.get_matrix()[..., 3, :3].clone()
return RigidTransform(R, t, device=self.device, dtype=self.dtype)

# %% ../notebooks/api/02_calibration.ipynb 8
# %% ../notebooks/api/02_calibration.ipynb 7
from pytorch3d.transforms import se3_exp_map


def convert(
transform,
input_parameterization,
output_parameterization,
input_convention=None,
output_convention=None,
):
"""Convert between representations of SE(3)."""

# Convert any input parameterization to a RigidTransform
if input_parameterization == "se3_log_map":
transform = torch.concat((transform[1], transform[0]), axis=-1)
matrix = se3_exp_map(transform)
transform = RigidTransform(
R=matrix[..., :3, :3].transpose(-1, -2),
t=matrix[..., 3, :3],
device=matrix.device,
dtype=matrix.dtype,
)
elif input_parameterization == "se3_exp_map":
pass
else:
transform = RigidTransform(
R=transform[0],
t=transform[1],
parameterization=input_parameterization,
convention=input_convention,
)

# Convert the RigidTransform to any output
if output_parameterization == "se3_exp_map":
return transform
elif output_parameterization == "se3_log_map":
se3_log = transform.get_se3_log()
return se3_log[..., 3:], se3_log[..., :3]
else:
return (
transform.get_rotation(output_parameterization, output_convention),
transform.get_translation(),
)

# %% ../notebooks/api/02_calibration.ipynb 9
@beartype
@jaxtyped
def perspective_projection(
Expand Down
129 changes: 79 additions & 50 deletions diffpose/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,60 +81,100 @@ def get_random_offset(batch_size: int, device) -> RigidTransform:

# %% ../notebooks/api/03_registration.ipynb 12
from diffdrr.detector import make_xrays
from diffdrr.drr import DRR
from diffdrr.siddon import siddon_raycast
from diffdrr.utils import convert
from pytorch3d.transforms import se3_exp_map, se3_log_map

from .calibration import RigidTransform, convert


class SparseRegistration(torch.nn.Module):
def __init__(self, drr, pose, features=None):
def __init__(
self,
drr: DRR,
pose: RigidTransform,
parameterization: str,
convention: str = None,
features=None, # Used to compute biased estimate of mNCC
n_patches: int = None, # If n_patches is None, render the whole image
patch_size: int = 13,
):
super().__init__()
self.drr = drr
log = se3_log_map(pose.get_matrix())
self.translation = torch.nn.Parameter(log[..., :3])
self.rotation = torch.nn.Parameter(log[..., 3:])

# Crop 10 pixels off the edge (i.e., use patch radius < 10 pixels)
# Parse the input pose
rotation, translation = convert(
pose,
input_parameterization="se3_exp_map",
output_parameterization=parameterization,
output_convention=convention,
)
self.parameterization = parameterization
self.convention = convention
self.rotation = torch.nn.Parameter(rotation)
self.translation = torch.nn.Parameter(translation)

# Crop pixels off the edge such that pixels don't fall outside the image
self.n_patches = n_patches
self.patch_size = patch_size
self.patch_radius = self.patch_size // 2 + 1
self.height = self.drr.detector.height
self.f_height = self.height - 2 * 10
self.f_height = self.height - 2 * self.patch_radius

if features is None: # Sample all pixels equally
# Define the distribution over patch centers
if features is None:
features = torch.ones(
self.height, self.height, device=self.rotation.device
) / (self.height**2)
self.m = torch.distributions.categorical.Categorical(
probs=features.squeeze()[10:-10, 10:-10].flatten()
self.patch_centers = torch.distributions.categorical.Categorical(
probs=features.squeeze()[
self.patch_radius : -self.patch_radius,
self.patch_radius : -self.patch_radius,
].flatten()
)

def forward(self, n_patches, patch_size):
"""If n_patches is None, render the whole image."""

def forward(self, n_patches=None, patch_size=None):
# Parse initial density
if not hasattr(self.drr, "density"):
self.drr.set_bone_attenuation_multiplier(
self.drr.bone_attenuation_multiplier
)

# Make the mask
if n_patches is not None:
if n_patches is not None or patch_size is not None:
self.n_patches = n_patches
self.patch_size = patch_size

# Make the mask for sparse rendering
if self.n_patches is None:
mask = torch.ones(
1,
self.height,
self.height,
dtype=torch.bool,
device=self.rotation.device,
)
else:
mask = torch.zeros(
n_patches,
self.n_patches,
self.height,
self.height,
dtype=torch.bool,
device=self.rotation.device,
)
radius = patch_size // 2
idxs = self.m.sample(sample_shape=torch.Size([n_patches]))
idxs, jdxs = idxs // self.f_height + 10, idxs % self.f_height + 10
radius = self.patch_size // 2
idxs = self.patch_centers.sample(sample_shape=torch.Size([self.n_patches]))
idxs, jdxs = (
idxs // self.f_height + self.patch_radius,
idxs % self.f_height + self.patch_radius,
)

idx = torch.arange(-radius, radius + 1, device=self.rotation.device)
patches = torch.cartesian_prod(idx, idx).expand(n_patches, -1, -1)
patches = torch.cartesian_prod(idx, idx).expand(self.n_patches, -1, -1)
patches = patches + torch.stack([idxs, jdxs], dim=-1).unsqueeze(1)
patches = torch.concat(
[
torch.arange(n_patches, device=self.rotation.device)
torch.arange(self.n_patches, device=self.rotation.device)
.unsqueeze(-1)
.expand(-1, patch_size**2)
.expand(-1, self.patch_size**2)
.unsqueeze(-1),
patches,
],
Expand All @@ -145,22 +185,14 @@ def forward(self, n_patches, patch_size):
patches[..., 1],
patches[..., 2],
] = True
else:
mask = torch.ones(
1,
self.height,
self.height,
dtype=torch.bool,
device=self.rotation.device,
)

# Get the source and target
T = se3_exp_map(
torch.concat([self.translation, self.rotation], dim=1)
).transpose(-1, -2)
R = T[..., :3, :3]
t = T[..., :3, 3]
pose = RigidTransform(R, t, "matrix")
pose = convert(
[self.rotation, self.translation],
input_parameterization=self.parameterization,
output_parameterization="se3_exp_map",
input_convention=self.convention,
)
source, target = make_xrays(
pose,
self.drr.detector.source,
Expand All @@ -170,21 +202,18 @@ def forward(self, n_patches, patch_size):
# Render the sparse image
target = target[mask.any(dim=0).view(1, -1)]
img = siddon_raycast(source, target, self.drr.density, self.drr.spacing)
if self.n_patches is None:
img = self.drr.reshape_transform(img, batch_size=len(self.rotation))
return img, mask

def get_rotation(self):
T = se3_exp_map(
torch.concat([self.translation, self.rotation], dim=1)
).transpose(-1, -2)
R = T[..., :3, :3]
return convert(R, "matrix", "euler_angles", output_convention="ZYX")

def get_translation(self):
T = se3_exp_map(
torch.concat([self.translation, self.rotation], dim=1)
).transpose(-1, -2)
t = T[..., :3, 3]
return t
def get_current_pose(self):
return convert(
[self.rotation, self.translation],
input_parameterization=self.parameterization,
output_parameterization="euler_angles",
input_convention=self.convention,
output_convention="ZYX",
)

# %% ../notebooks/api/03_registration.ipynb 14
def preprocess(x, eps=1e-4):
Expand Down
59 changes: 56 additions & 3 deletions notebooks/api/02_calibration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
"from typing import Optional\n",
"\n",
"from beartype import beartype\n",
"from diffdrr.utils import convert\n",
"from diffdrr.utils import convert as convert_so3\n",
"from jaxtyping import Float, jaxtyped\n",
"from pytorch3d.transforms import Transform3d\n",
"\n",
Expand All @@ -117,7 +117,7 @@
" if device is None and (R.device == t.device):\n",
" device = R.device\n",
"\n",
" R = convert(R, parameterization, \"matrix\", convention)\n",
" R = convert_so3(R, parameterization, \"matrix\", convention)\n",
" if R.dim() == 2 and t.dim() == 1:\n",
" R = R.unsqueeze(0)\n",
" t = t.unsqueeze(0)\n",
Expand All @@ -133,7 +133,7 @@
" def get_rotation(self, parameterization=None, convention=None):\n",
" R = self.get_matrix()[..., :3, :3].transpose(-1, -2)\n",
" if parameterization is not None:\n",
" R = convert(R, \"matrix\", parameterization, None, convention)\n",
" R = convert_so3(R, \"matrix\", parameterization, None, convention)\n",
" return R\n",
"\n",
" def get_translation(self):\n",
Expand All @@ -158,6 +158,59 @@
" return RigidTransform(R, t, device=self.device, dtype=self.dtype)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff7984c3-a8f5-435f-b504-dce55787f517",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from pytorch3d.transforms import se3_exp_map\n",
"\n",
"\n",
"def convert(\n",
" transform,\n",
" input_parameterization,\n",
" output_parameterization,\n",
" input_convention=None,\n",
" output_convention=None,\n",
"):\n",
" \"\"\"Convert between representations of SE(3).\"\"\"\n",
"\n",
" # Convert any input parameterization to a RigidTransform\n",
" if input_parameterization == \"se3_log_map\":\n",
" transform = torch.concat((transform[1], transform[0]), axis=-1)\n",
" matrix = se3_exp_map(transform)\n",
" transform = RigidTransform(\n",
" R=matrix[..., :3, :3].transpose(-1, -2),\n",
" t=matrix[..., 3, :3],\n",
" device=matrix.device,\n",
" dtype=matrix.dtype,\n",
" )\n",
" elif input_parameterization == \"se3_exp_map\":\n",
" pass\n",
" else:\n",
" transform = RigidTransform(\n",
" R=transform[0],\n",
" t=transform[1],\n",
" parameterization=input_parameterization,\n",
" convention=input_convention,\n",
" )\n",
"\n",
" # Convert the RigidTransform to any output\n",
" if output_parameterization == \"se3_exp_map\":\n",
" return transform\n",
" elif output_parameterization == \"se3_log_map\":\n",
" se3_log = transform.get_se3_log()\n",
" return se3_log[..., 3:], se3_log[..., :3]\n",
" else:\n",
" return (\n",
" transform.get_rotation(output_parameterization, output_convention),\n",
" transform.get_translation(),\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "7968b28c-ba34-49c7-8acc-bb080c0e4556",
Expand Down
131 changes: 80 additions & 51 deletions notebooks/api/03_registration.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit d39cb5d

Please sign in to comment.