From 30c1dd79c13bd4cb9fa01a190d5a04ab7d6facaf Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Wed, 10 Jan 2024 13:26:58 -0500 Subject: [PATCH 1/5] Switch to pytorch3d backend --- diffpose/calibration.py | 18 ++++++-------- diffpose/metrics.py | 18 +++++--------- environment.yml | 3 +-- notebooks/api/02_calibration.ipynb | 18 ++++++-------- notebooks/api/04_metrics.ipynb | 40 ++++++++++-------------------- settings.ini | 2 +- 6 files changed, 37 insertions(+), 62 deletions(-) diff --git a/diffpose/calibration.py b/diffpose/calibration.py index 66a1c9e..b2b73f8 100644 --- a/diffpose/calibration.py +++ b/diffpose/calibration.py @@ -12,15 +12,14 @@ from beartype import beartype from diffdrr.utils import convert as convert_so3 from jaxtyping import Float, jaxtyped -from pytorch3d.transforms import Transform3d -from pytorchse3.se3 import se3_exp_map, se3_log_map +from pytorch3d.transforms import Transform3d, se3_exp_map, se3_log_map # %% ../notebooks/api/02_calibration.ipynb 7 @beartype class RigidTransform(Transform3d): """Wrapper of pytorch3d.transforms.Transform3d with extra functionalities.""" - @jaxtyped + @jaxtyped(typechecker=beartype) def __init__( self, R: Float[torch.Tensor, "..."], @@ -74,7 +73,7 @@ def clone(self): return RigidTransform(R, t, device=self.device, dtype=self.dtype) def get_se3_log(self): - return se3_log_map(self.get_matrix().transpose(-1, -2)) + return se3_log_map(self.get_matrix()) # %% ../notebooks/api/02_calibration.ipynb 8 def convert( @@ -88,8 +87,8 @@ def convert( # Convert any input parameterization to a RigidTransform if input_parameterization == "se3_log_map": - transform = torch.concat([*transform], axis=-1) - matrix = se3_exp_map(transform) + transform = torch.concat([transform[1], transform[0]], axis=-1) + matrix = se3_exp_map(transform).transpose(-1, -2) transform = RigidTransform( R=matrix[..., :3, :3], t=matrix[..., :3, 3], @@ -111,8 +110,8 @@ def convert( return transform elif output_parameterization == "se3_log_map": se3_log = transform.get_se3_log() - log_R_vee = se3_log[..., :3] - log_t_vee = se3_log[..., 3:] + log_t_vee = se3_log[..., :3] + log_R_vee = se3_log[..., 3:] return log_R_vee, log_t_vee else: return ( @@ -121,8 +120,7 @@ def convert( ) # %% ../notebooks/api/02_calibration.ipynb 10 -@beartype -@jaxtyped +@jaxtyped(typechecker=beartype) def perspective_projection( extrinsic: RigidTransform, # Extrinsic camera matrix (world to camera) intrinsic: Float[torch.Tensor, "3 3"], # Intrinsic camera matrix (camera to image) diff --git a/diffpose/metrics.py b/diffpose/metrics.py index 32c1dd9..8d43a70 100644 --- a/diffpose/metrics.py +++ b/diffpose/metrics.py @@ -66,25 +66,22 @@ def __init__(self, patch_size=None): from diffdrr.utils import convert from jaxtyping import Float, jaxtyped from pytorch3d.transforms import ( - so3_rotation_angle, + so3_log_map, so3_relative_angle, + so3_rotation_angle, standardize_quaternion, ) from .calibration import RigidTransform # %% ../notebooks/api/04_metrics.ipynb 10 -from pytorchse3.so3 import so3_log_map - - class GeodesicSO3(torch.nn.Module): """Calculate the angular distance between two rotations in SO(3).""" def __init__(self): super().__init__() - @beartype - @jaxtyped + @jaxtyped(typechecker=beartype) def forward( self, pose_1: RigidTransform, @@ -102,8 +99,7 @@ class GeodesicTranslation(torch.nn.Module): def __init__(self): super().__init__() - @beartype - @jaxtyped + @jaxtyped(typechecker=beartype) def forward( self, pose_1: RigidTransform, @@ -120,8 +116,7 @@ class GeodesicSE3(torch.nn.Module): def __init__(self): super().__init__() - @beartype - @jaxtyped + @jaxtyped(typechecker=beartype) def forward( self, pose_1: RigidTransform, @@ -146,8 +141,7 @@ def __init__( self.rotation = GeodesicSO3() self.translation = GeodesicTranslation() - @beartype - @jaxtyped + @jaxtyped(typechecker=beartype) def forward(self, pose_1: RigidTransform, pose_2: RigidTransform): angular_geodesic = self.sdr * self.rotation(pose_1, pose_2) translation_geodesic = self.translation(pose_1, pose_2) diff --git a/environment.yml b/environment.yml index 189d3ac..1f58326 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: preop +name: diffpose channels: - conda-forge - pytorch @@ -15,7 +15,6 @@ dependencies: - scikit-image - seaborn - pytorch-transformers - - pytorchse3 - timm - torchmetrics - tqdm diff --git a/notebooks/api/02_calibration.ipynb b/notebooks/api/02_calibration.ipynb index 6ad9380..dc4eab0 100644 --- a/notebooks/api/02_calibration.ipynb +++ b/notebooks/api/02_calibration.ipynb @@ -99,8 +99,7 @@ "from beartype import beartype\n", "from diffdrr.utils import convert as convert_so3\n", "from jaxtyping import Float, jaxtyped\n", - "from pytorch3d.transforms import Transform3d\n", - "from pytorchse3.se3 import se3_exp_map, se3_log_map" + "from pytorch3d.transforms import Transform3d, se3_exp_map, se3_log_map" ] }, { @@ -115,7 +114,7 @@ "class RigidTransform(Transform3d):\n", " \"\"\"Wrapper of pytorch3d.transforms.Transform3d with extra functionalities.\"\"\"\n", "\n", - " @jaxtyped\n", + " @jaxtyped(typechecker=beartype)\n", " def __init__(\n", " self,\n", " R: Float[torch.Tensor, \"...\"],\n", @@ -169,7 +168,7 @@ " return RigidTransform(R, t, device=self.device, dtype=self.dtype)\n", "\n", " def get_se3_log(self):\n", - " return se3_log_map(self.get_matrix().transpose(-1, -2))" + " return se3_log_map(self.get_matrix())" ] }, { @@ -191,8 +190,8 @@ "\n", " # Convert any input parameterization to a RigidTransform\n", " if input_parameterization == \"se3_log_map\":\n", - " transform = torch.concat([*transform], axis=-1)\n", - " matrix = se3_exp_map(transform)\n", + " transform = torch.concat([transform[1], transform[0]], axis=-1)\n", + " matrix = se3_exp_map(transform).transpose(-1, -2)\n", " transform = RigidTransform(\n", " R=matrix[..., :3, :3],\n", " t=matrix[..., :3, 3],\n", @@ -214,8 +213,8 @@ " return transform\n", " elif output_parameterization == \"se3_log_map\":\n", " se3_log = transform.get_se3_log()\n", - " log_R_vee = se3_log[..., :3]\n", - " log_t_vee = se3_log[..., 3:]\n", + " log_t_vee = se3_log[..., :3]\n", + " log_R_vee = se3_log[..., 3:]\n", " return log_R_vee, log_t_vee\n", " else:\n", " return (\n", @@ -243,8 +242,7 @@ "outputs": [], "source": [ "#| export\n", - "@beartype\n", - "@jaxtyped\n", + "@jaxtyped(typechecker=beartype)\n", "def perspective_projection(\n", " extrinsic: RigidTransform, # Extrinsic camera matrix (world to camera)\n", " intrinsic: Float[torch.Tensor, \"3 3\"], # Intrinsic camera matrix (camera to image)\n", diff --git a/notebooks/api/04_metrics.ipynb b/notebooks/api/04_metrics.ipynb index 752f035..985a016 100644 --- a/notebooks/api/04_metrics.ipynb +++ b/notebooks/api/04_metrics.ipynb @@ -171,7 +171,12 @@ "from beartype import beartype\n", "from diffdrr.utils import convert\n", "from jaxtyping import Float, jaxtyped\n", - "from pytorch3d.transforms import so3_rotation_angle, so3_relative_angle, standardize_quaternion\n", + "from pytorch3d.transforms import (\n", + " so3_log_map,\n", + " so3_relative_angle,\n", + " so3_rotation_angle,\n", + " standardize_quaternion,\n", + ")\n", "\n", "from diffpose.calibration import RigidTransform" ] @@ -181,32 +186,16 @@ "execution_count": null, "id": "1ff308dc-4807-46dd-bd10-ef9dca35c4c9", "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'torch' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[1], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| export\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpytorchse3\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mso3\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m so3_log_map\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mGeodesicSO3\u001b[39;00m(\u001b[43mtorch\u001b[49m\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mModule):\n\u001b[1;32m 6\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Calculate the angular distance between two rotations in SO(3).\"\"\"\u001b[39;00m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, eps\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-4\u001b[39m):\n", - "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined" - ] - } - ], + "outputs": [], "source": [ "#| export\n", - "from pytorchse3.so3 import so3_log_map\n", - "\n", - "\n", "class GeodesicSO3(torch.nn.Module):\n", " \"\"\"Calculate the angular distance between two rotations in SO(3).\"\"\"\n", "\n", " def __init__(self):\n", " super().__init__()\n", "\n", - " @beartype\n", - " @jaxtyped\n", + " @jaxtyped(typechecker=beartype)\n", " def forward(\n", " self,\n", " pose_1: RigidTransform,\n", @@ -214,7 +203,7 @@ " ) -> Float[torch.Tensor, \"b\"]:\n", " r1 = pose_1.get_rotation()\n", " r2 = pose_2.get_rotation()\n", - " rdiff = r1 @ r2.transpose(-1,-2)\n", + " rdiff = r1 @ r2.transpose(-1, -2)\n", " return so3_log_map(rdiff).norm(dim=-1)\n", "\n", "\n", @@ -224,8 +213,7 @@ " def __init__(self):\n", " super().__init__()\n", "\n", - " @beartype\n", - " @jaxtyped\n", + " @jaxtyped(typechecker=beartype)\n", " def forward(\n", " self,\n", " pose_1: RigidTransform,\n", @@ -250,8 +238,7 @@ " def __init__(self):\n", " super().__init__()\n", "\n", - " @beartype\n", - " @jaxtyped\n", + " @jaxtyped(typechecker=beartype)\n", " def forward(\n", " self,\n", " pose_1: RigidTransform,\n", @@ -284,8 +271,7 @@ " self.rotation = GeodesicSO3()\n", " self.translation = GeodesicTranslation()\n", "\n", - " @beartype\n", - " @jaxtyped\n", + " @jaxtyped(typechecker=beartype)\n", " def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):\n", " angular_geodesic = self.sdr * self.rotation(pose_1, pose_2)\n", " translation_geodesic = self.translation(pose_1, pose_2)\n", @@ -391,7 +377,7 @@ { "data": { "text/plain": [ - "(tensor([50.9999]), tensor([1.7321]), tensor([51.0293]))" + "(tensor([51.0000]), tensor([1.7321]), tensor([51.0294]))" ] }, "execution_count": null, diff --git a/settings.ini b/settings.ini index b94d27b..d1772a3 100644 --- a/settings.ini +++ b/settings.ini @@ -39,6 +39,6 @@ status = 3 user = eigenvivek ### Optional ### -requirements = diffdrr h5py scikit-image seaborn torch torchvision pytorch3d timm pytorch-transformers pytorchse3 torchmetrics tqdm beartype jaxtyping +requirements = diffdrr h5py scikit-image seaborn torch torchvision pytorch3d timm pytorch-transformers torchmetrics tqdm beartype jaxtyping dev_requirements = jupyterlab_code_formatter black flake8 isort nbdev ipykernel jupyter-server-proxy optional_requirements = submitit From 7283f6ad26a3b44bb72a74ebce030480551ced1e Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Wed, 10 Jan 2024 13:27:19 -0500 Subject: [PATCH 2/5] Use new jaxtype decorator convention --- diffpose/deepfluoro.py | 2 -- diffpose/ljubljana.py | 2 -- notebooks/api/00_deepfluoro.ipynb | 2 -- notebooks/api/01_ljubljana.ipynb | 2 -- 4 files changed, 8 deletions(-) diff --git a/diffpose/deepfluoro.py b/diffpose/deepfluoro.py index 9e3c6fa..7fbb58b 100644 --- a/diffpose/deepfluoro.py +++ b/diffpose/deepfluoro.py @@ -306,8 +306,6 @@ def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)): return img # %% ../notebooks/api/00_deepfluoro.ipynb 26 -from beartype import beartype - from .calibration import RigidTransform, convert diff --git a/diffpose/ljubljana.py b/diffpose/ljubljana.py index d973136..b46b759 100644 --- a/diffpose/ljubljana.py +++ b/diffpose/ljubljana.py @@ -115,8 +115,6 @@ def __getitem__(self, idx): ) # %% ../notebooks/api/01_ljubljana.ipynb 7 -from beartype import beartype - from .calibration import RigidTransform, convert diff --git a/notebooks/api/00_deepfluoro.ipynb b/notebooks/api/00_deepfluoro.ipynb index d577c76..d857a50 100644 --- a/notebooks/api/00_deepfluoro.ipynb +++ b/notebooks/api/00_deepfluoro.ipynb @@ -712,8 +712,6 @@ "outputs": [], "source": [ "#| export\n", - "from beartype import beartype\n", - "\n", "from diffpose.calibration import RigidTransform, convert\n", "\n", "\n", diff --git a/notebooks/api/01_ljubljana.ipynb b/notebooks/api/01_ljubljana.ipynb index 48eee1a..8cdad9e 100644 --- a/notebooks/api/01_ljubljana.ipynb +++ b/notebooks/api/01_ljubljana.ipynb @@ -186,8 +186,6 @@ "outputs": [], "source": [ "#| export\n", - "from beartype import beartype\n", - "\n", "from diffpose.calibration import RigidTransform, convert\n", "\n", "\n", From 716901bdeff51b89f6f5ba17653a9477e600191d Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Fri, 12 Jan 2024 08:48:12 -0500 Subject: [PATCH 3/5] Add checkpoint evaluation script --- .gitignore | 1 + experiments/deepfluoro/evaluate.py | 82 ++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 experiments/deepfluoro/evaluate.py diff --git a/.gitignore b/.gitignore index d18163e..b8d30e5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ checkpoints/ +evaluations/ data/ logs/ runs/ diff --git a/experiments/deepfluoro/evaluate.py b/experiments/deepfluoro/evaluate.py new file mode 100644 index 0000000..517dbca --- /dev/null +++ b/experiments/deepfluoro/evaluate.py @@ -0,0 +1,82 @@ +from pathlib import Path + +import pandas as pd +import submitit +import torch +from tqdm import tqdm + +from diffpose.deepfluoro import DeepFluoroDataset, Evaluator, Transforms +from diffpose.registration import PoseRegressor + + +def load_specimen(id_number, device): + specimen = DeepFluoroDataset(id_number) + isocenter_pose = specimen.isocenter_pose.to(device) + return specimen, isocenter_pose + + +def load_model(model_name, device): + ckpt = torch.load(model_name) + model = PoseRegressor( + ckpt["model_name"], + ckpt["parameterization"], + ckpt["convention"], + norm_layer=ckpt["norm_layer"], + ).to(device) + model.load_state_dict(ckpt["model_state_dict"]) + transforms = Transforms(ckpt["height"]) + return model, transforms + + +def evaluate(specimen, isocenter_pose, model, transforms, device): + error = [] + model.eval() + for idx in tqdm(range(len(specimen)), ncols=100): + target_registration_error = Evaluator(specimen, idx) + img, _ = specimen[idx] + img = img.to(device) + img = transforms(img) + with torch.no_grad(): + offset = model(img) + pred_pose = isocenter_pose.compose(offset) + mtre = target_registration_error(pred_pose.cpu()).item() + error.append(mtre) + return error + + +def main(id_number): + device = torch.device("cuda") + specimen, isocenter_pose = load_specimen(id_number, device) + models = sorted(Path("checkpoints/").glob(f"specimen_{id_number:02d}_epoch*.ckpt")) + + errors = [] + for model_name in models: + model, transforms = load_model(model_name, device) + error = evaluate(specimen, isocenter_pose, model, transforms, device) + errors.append([model_name.stem] + error) + + df = pd.DataFrame(errors) + df.to_csv(f"evaluations/subject{id_number}.csv", index=False) + + +if __name__ == "__main__": + seed = 123 + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + Path("evaluations").mkdir(exist_ok=True) + id_numbers = [1, 2, 3, 4, 5, 6] + + executor = submitit.AutoExecutor(folder="logs") + executor.update_parameters( + name="eval", + gpus_per_node=1, + mem_gb=10.0, + slurm_array_parallelism=len(id_numbers), + slurm_exclude="curcum", + slurm_partition="2080ti", + timeout_min=10_000, + ) + jobs = executor.map_array(main, id_numbers) From 5d247c02ff4a76693d22e978dd7c35bc0aea460c Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Fri, 12 Jan 2024 08:53:49 -0500 Subject: [PATCH 4/5] Decrease number of checkpoints saved --- experiments/deepfluoro/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experiments/deepfluoro/train.py b/experiments/deepfluoro/train.py index 0ee17e7..5ff74e5 100644 --- a/experiments/deepfluoro/train.py +++ b/experiments/deepfluoro/train.py @@ -57,7 +57,7 @@ def train( best_loss = torch.inf model.train() - for epoch in range(n_epochs): + for epoch in range(n_epochs + 1): losses = [] for _ in (itr := tqdm(range(n_batches_per_epoch), leave=False)): contrast = contrast_distribution.sample().item() @@ -144,7 +144,7 @@ def train( f"checkpoints/specimen_{id_number:02d}_best.ckpt", ) - if epoch % 25 == 0 and epoch != 0: + if epoch % 50 == 0: torch.save( { "model_state_dict": model.state_dict(), From 66e5bfd9cfd0b968d3c9700532c317dff32322da Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Wed, 24 Jan 2024 12:09:47 -0500 Subject: [PATCH 5/5] Switch to diffdrr's wrapper of pytorch3d --- diffpose/calibration.py | 3 ++- diffpose/metrics.py | 6 +++--- environment.yml | 2 -- notebooks/api/02_calibration.ipynb | 5 +++-- notebooks/api/04_metrics.ipynb | 6 +++--- settings.ini | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/diffpose/calibration.py b/diffpose/calibration.py index b2b73f8..6d396ed 100644 --- a/diffpose/calibration.py +++ b/diffpose/calibration.py @@ -10,9 +10,10 @@ from typing import Optional from beartype import beartype +from diffdrr.utils import Transform3d from diffdrr.utils import convert as convert_so3 +from diffdrr.utils import se3_exp_map, se3_log_map from jaxtyping import Float, jaxtyped -from pytorch3d.transforms import Transform3d, se3_exp_map, se3_log_map # %% ../notebooks/api/02_calibration.ipynb 7 @beartype diff --git a/diffpose/metrics.py b/diffpose/metrics.py index 8d43a70..126df22 100644 --- a/diffpose/metrics.py +++ b/diffpose/metrics.py @@ -63,14 +63,14 @@ def __init__(self, patch_size=None): # %% ../notebooks/api/04_metrics.ipynb 9 import torch from beartype import beartype -from diffdrr.utils import convert -from jaxtyping import Float, jaxtyped -from pytorch3d.transforms import ( +from diffdrr.utils import ( + convert, so3_log_map, so3_relative_angle, so3_rotation_angle, standardize_quaternion, ) +from jaxtyping import Float, jaxtyped from .calibration import RigidTransform diff --git a/environment.yml b/environment.yml index 1f58326..09c21a3 100644 --- a/environment.yml +++ b/environment.yml @@ -2,13 +2,11 @@ name: diffpose channels: - conda-forge - pytorch - - pytorch3d - nvidia dependencies: - pip - pytorch - torchvision - - pytorch3d - pip: - diffdrr>=0.3.8 - h5py diff --git a/notebooks/api/02_calibration.ipynb b/notebooks/api/02_calibration.ipynb index dc4eab0..142be05 100644 --- a/notebooks/api/02_calibration.ipynb +++ b/notebooks/api/02_calibration.ipynb @@ -97,9 +97,10 @@ "from typing import Optional\n", "\n", "from beartype import beartype\n", + "from diffdrr.utils import Transform3d\n", "from diffdrr.utils import convert as convert_so3\n", - "from jaxtyping import Float, jaxtyped\n", - "from pytorch3d.transforms import Transform3d, se3_exp_map, se3_log_map" + "from diffdrr.utils import se3_exp_map, se3_log_map\n", + "from jaxtyping import Float, jaxtyped" ] }, { diff --git a/notebooks/api/04_metrics.ipynb b/notebooks/api/04_metrics.ipynb index 985a016..84695ad 100644 --- a/notebooks/api/04_metrics.ipynb +++ b/notebooks/api/04_metrics.ipynb @@ -169,14 +169,14 @@ "#| export\n", "import torch\n", "from beartype import beartype\n", - "from diffdrr.utils import convert\n", - "from jaxtyping import Float, jaxtyped\n", - "from pytorch3d.transforms import (\n", + "from diffdrr.utils import (\n", + " convert,\n", " so3_log_map,\n", " so3_relative_angle,\n", " so3_rotation_angle,\n", " standardize_quaternion,\n", ")\n", + "from jaxtyping import Float, jaxtyped\n", "\n", "from diffpose.calibration import RigidTransform" ] diff --git a/settings.ini b/settings.ini index d1772a3..3bd0efb 100644 --- a/settings.ini +++ b/settings.ini @@ -39,6 +39,6 @@ status = 3 user = eigenvivek ### Optional ### -requirements = diffdrr h5py scikit-image seaborn torch torchvision pytorch3d timm pytorch-transformers torchmetrics tqdm beartype jaxtyping +requirements = diffdrr h5py scikit-image seaborn torch torchvision timm pytorch-transformers torchmetrics tqdm beartype jaxtyping dev_requirements = jupyterlab_code_formatter black flake8 isort nbdev ipykernel jupyter-server-proxy optional_requirements = submitit