Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove dependency on pytorch3d #5

Merged
merged 5 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
checkpoints/
evaluations/
data/
logs/
runs/
Expand Down
19 changes: 9 additions & 10 deletions diffpose/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
from typing import Optional

from beartype import beartype
from diffdrr.utils import Transform3d
from diffdrr.utils import convert as convert_so3
from diffdrr.utils import se3_exp_map, se3_log_map
from jaxtyping import Float, jaxtyped
from pytorch3d.transforms import Transform3d
from pytorchse3.se3 import se3_exp_map, se3_log_map

# %% ../notebooks/api/02_calibration.ipynb 7
@beartype
class RigidTransform(Transform3d):
"""Wrapper of pytorch3d.transforms.Transform3d with extra functionalities."""

@jaxtyped
@jaxtyped(typechecker=beartype)
def __init__(
self,
R: Float[torch.Tensor, "..."],
Expand Down Expand Up @@ -74,7 +74,7 @@ def clone(self):
return RigidTransform(R, t, device=self.device, dtype=self.dtype)

def get_se3_log(self):
return se3_log_map(self.get_matrix().transpose(-1, -2))
return se3_log_map(self.get_matrix())

# %% ../notebooks/api/02_calibration.ipynb 8
def convert(
Expand All @@ -88,8 +88,8 @@ def convert(

# Convert any input parameterization to a RigidTransform
if input_parameterization == "se3_log_map":
transform = torch.concat([*transform], axis=-1)
matrix = se3_exp_map(transform)
transform = torch.concat([transform[1], transform[0]], axis=-1)
matrix = se3_exp_map(transform).transpose(-1, -2)
transform = RigidTransform(
R=matrix[..., :3, :3],
t=matrix[..., :3, 3],
Expand All @@ -111,8 +111,8 @@ def convert(
return transform
elif output_parameterization == "se3_log_map":
se3_log = transform.get_se3_log()
log_R_vee = se3_log[..., :3]
log_t_vee = se3_log[..., 3:]
log_t_vee = se3_log[..., :3]
log_R_vee = se3_log[..., 3:]
return log_R_vee, log_t_vee
else:
return (
Expand All @@ -121,8 +121,7 @@ def convert(
)

# %% ../notebooks/api/02_calibration.ipynb 10
@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def perspective_projection(
extrinsic: RigidTransform, # Extrinsic camera matrix (world to camera)
intrinsic: Float[torch.Tensor, "3 3"], # Intrinsic camera matrix (camera to image)
Expand Down
2 changes: 0 additions & 2 deletions diffpose/deepfluoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,6 @@ def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)):
return img

# %% ../notebooks/api/00_deepfluoro.ipynb 26
from beartype import beartype

from .calibration import RigidTransform, convert


Expand Down
2 changes: 0 additions & 2 deletions diffpose/ljubljana.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ def __getitem__(self, idx):
)

# %% ../notebooks/api/01_ljubljana.ipynb 7
from beartype import beartype

from .calibration import RigidTransform, convert


Expand Down
24 changes: 9 additions & 15 deletions diffpose/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,25 @@ def __init__(self, patch_size=None):
# %% ../notebooks/api/04_metrics.ipynb 9
import torch
from beartype import beartype
from diffdrr.utils import convert
from jaxtyping import Float, jaxtyped
from pytorch3d.transforms import (
so3_rotation_angle,
from diffdrr.utils import (
convert,
so3_log_map,
so3_relative_angle,
so3_rotation_angle,
standardize_quaternion,
)
from jaxtyping import Float, jaxtyped

from .calibration import RigidTransform

# %% ../notebooks/api/04_metrics.ipynb 10
from pytorchse3.so3 import so3_log_map


class GeodesicSO3(torch.nn.Module):
"""Calculate the angular distance between two rotations in SO(3)."""

def __init__(self):
super().__init__()

@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def forward(
self,
pose_1: RigidTransform,
Expand All @@ -102,8 +99,7 @@ class GeodesicTranslation(torch.nn.Module):
def __init__(self):
super().__init__()

@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def forward(
self,
pose_1: RigidTransform,
Expand All @@ -120,8 +116,7 @@ class GeodesicSE3(torch.nn.Module):
def __init__(self):
super().__init__()

@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def forward(
self,
pose_1: RigidTransform,
Expand All @@ -146,8 +141,7 @@ def __init__(
self.rotation = GeodesicSO3()
self.translation = GeodesicTranslation()

@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):
angular_geodesic = self.sdr * self.rotation(pose_1, pose_2)
translation_geodesic = self.translation(pose_1, pose_2)
Expand Down
5 changes: 1 addition & 4 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
name: preop
name: diffpose
channels:
- conda-forge
- pytorch
- pytorch3d
- nvidia
dependencies:
- pip
- pytorch
- torchvision
- pytorch3d
- pip:
- diffdrr>=0.3.8
- h5py
- scikit-image
- seaborn
- pytorch-transformers
- pytorchse3
- timm
- torchmetrics
- tqdm
Expand Down
82 changes: 82 additions & 0 deletions experiments/deepfluoro/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from pathlib import Path

import pandas as pd
import submitit
import torch
from tqdm import tqdm

from diffpose.deepfluoro import DeepFluoroDataset, Evaluator, Transforms
from diffpose.registration import PoseRegressor


def load_specimen(id_number, device):
specimen = DeepFluoroDataset(id_number)
isocenter_pose = specimen.isocenter_pose.to(device)
return specimen, isocenter_pose


def load_model(model_name, device):
ckpt = torch.load(model_name)
model = PoseRegressor(
ckpt["model_name"],
ckpt["parameterization"],
ckpt["convention"],
norm_layer=ckpt["norm_layer"],
).to(device)
model.load_state_dict(ckpt["model_state_dict"])
transforms = Transforms(ckpt["height"])
return model, transforms


def evaluate(specimen, isocenter_pose, model, transforms, device):
error = []
model.eval()
for idx in tqdm(range(len(specimen)), ncols=100):
target_registration_error = Evaluator(specimen, idx)
img, _ = specimen[idx]
img = img.to(device)
img = transforms(img)
with torch.no_grad():
offset = model(img)
pred_pose = isocenter_pose.compose(offset)
mtre = target_registration_error(pred_pose.cpu()).item()
error.append(mtre)
return error


def main(id_number):
device = torch.device("cuda")
specimen, isocenter_pose = load_specimen(id_number, device)
models = sorted(Path("checkpoints/").glob(f"specimen_{id_number:02d}_epoch*.ckpt"))

errors = []
for model_name in models:
model, transforms = load_model(model_name, device)
error = evaluate(specimen, isocenter_pose, model, transforms, device)
errors.append([model_name.stem] + error)

df = pd.DataFrame(errors)
df.to_csv(f"evaluations/subject{id_number}.csv", index=False)


if __name__ == "__main__":
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Path("evaluations").mkdir(exist_ok=True)
id_numbers = [1, 2, 3, 4, 5, 6]

executor = submitit.AutoExecutor(folder="logs")
executor.update_parameters(
name="eval",
gpus_per_node=1,
mem_gb=10.0,
slurm_array_parallelism=len(id_numbers),
slurm_exclude="curcum",
slurm_partition="2080ti",
timeout_min=10_000,
)
jobs = executor.map_array(main, id_numbers)
4 changes: 2 additions & 2 deletions experiments/deepfluoro/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def train(
best_loss = torch.inf

model.train()
for epoch in range(n_epochs):
for epoch in range(n_epochs + 1):
losses = []
for _ in (itr := tqdm(range(n_batches_per_epoch), leave=False)):
contrast = contrast_distribution.sample().item()
Expand Down Expand Up @@ -144,7 +144,7 @@ def train(
f"checkpoints/specimen_{id_number:02d}_best.ckpt",
)

if epoch % 25 == 0 and epoch != 0:
if epoch % 50 == 0:
torch.save(
{
"model_state_dict": model.state_dict(),
Expand Down
2 changes: 0 additions & 2 deletions notebooks/api/00_deepfluoro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,6 @@
"outputs": [],
"source": [
"#| export\n",
"from beartype import beartype\n",
"\n",
"from diffpose.calibration import RigidTransform, convert\n",
"\n",
"\n",
Expand Down
2 changes: 0 additions & 2 deletions notebooks/api/01_ljubljana.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,6 @@
"outputs": [],
"source": [
"#| export\n",
"from beartype import beartype\n",
"\n",
"from diffpose.calibration import RigidTransform, convert\n",
"\n",
"\n",
Expand Down
21 changes: 10 additions & 11 deletions notebooks/api/02_calibration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@
"from typing import Optional\n",
"\n",
"from beartype import beartype\n",
"from diffdrr.utils import Transform3d\n",
"from diffdrr.utils import convert as convert_so3\n",
"from jaxtyping import Float, jaxtyped\n",
"from pytorch3d.transforms import Transform3d\n",
"from pytorchse3.se3 import se3_exp_map, se3_log_map"
"from diffdrr.utils import se3_exp_map, se3_log_map\n",
"from jaxtyping import Float, jaxtyped"
]
},
{
Expand All @@ -115,7 +115,7 @@
"class RigidTransform(Transform3d):\n",
" \"\"\"Wrapper of pytorch3d.transforms.Transform3d with extra functionalities.\"\"\"\n",
"\n",
" @jaxtyped\n",
" @jaxtyped(typechecker=beartype)\n",
" def __init__(\n",
" self,\n",
" R: Float[torch.Tensor, \"...\"],\n",
Expand Down Expand Up @@ -169,7 +169,7 @@
" return RigidTransform(R, t, device=self.device, dtype=self.dtype)\n",
"\n",
" def get_se3_log(self):\n",
" return se3_log_map(self.get_matrix().transpose(-1, -2))"
" return se3_log_map(self.get_matrix())"
]
},
{
Expand All @@ -191,8 +191,8 @@
"\n",
" # Convert any input parameterization to a RigidTransform\n",
" if input_parameterization == \"se3_log_map\":\n",
" transform = torch.concat([*transform], axis=-1)\n",
" matrix = se3_exp_map(transform)\n",
" transform = torch.concat([transform[1], transform[0]], axis=-1)\n",
" matrix = se3_exp_map(transform).transpose(-1, -2)\n",
" transform = RigidTransform(\n",
" R=matrix[..., :3, :3],\n",
" t=matrix[..., :3, 3],\n",
Expand All @@ -214,8 +214,8 @@
" return transform\n",
" elif output_parameterization == \"se3_log_map\":\n",
" se3_log = transform.get_se3_log()\n",
" log_R_vee = se3_log[..., :3]\n",
" log_t_vee = se3_log[..., 3:]\n",
" log_t_vee = se3_log[..., :3]\n",
" log_R_vee = se3_log[..., 3:]\n",
" return log_R_vee, log_t_vee\n",
" else:\n",
" return (\n",
Expand Down Expand Up @@ -243,8 +243,7 @@
"outputs": [],
"source": [
"#| export\n",
"@beartype\n",
"@jaxtyped\n",
"@jaxtyped(typechecker=beartype)\n",
"def perspective_projection(\n",
" extrinsic: RigidTransform, # Extrinsic camera matrix (world to camera)\n",
" intrinsic: Float[torch.Tensor, \"3 3\"], # Intrinsic camera matrix (camera to image)\n",
Expand Down
Loading
Loading