From 4329ce29bbc9d03cd2a04db0a27818c940bc58fc Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Sun, 28 Jan 2024 14:30:47 -0500 Subject: [PATCH] Clean up redundant function --- diffpose/deepfluoro.py | 2 +- diffpose/ljubljana.py | 2 +- notebooks/api/00_deepfluoro.ipynb | 35 ------------------------------- notebooks/api/01_ljubljana.ipynb | 2 +- 4 files changed, 3 insertions(+), 38 deletions(-) diff --git a/diffpose/deepfluoro.py b/diffpose/deepfluoro.py index 7fbb58b..341bbda 100644 --- a/diffpose/deepfluoro.py +++ b/diffpose/deepfluoro.py @@ -325,7 +325,7 @@ def get_random_offset(batch_size: int, device) -> RigidTransform: "se3_exp_map", ) -# %% ../notebooks/api/00_deepfluoro.ipynb 33 +# %% ../notebooks/api/00_deepfluoro.ipynb 32 from torchvision.transforms import Compose, Lambda, Normalize, Resize diff --git a/diffpose/ljubljana.py b/diffpose/ljubljana.py index b46b759..59871da 100644 --- a/diffpose/ljubljana.py +++ b/diffpose/ljubljana.py @@ -131,7 +131,7 @@ def get_random_offset(view, batch_size: int, device) -> RigidTransform: t1 = torch.distributions.Normal(75, 30).sample((batch_size,)) t2 = torch.distributions.Normal(-80, 30).sample((batch_size,)) t3 = torch.distributions.Normal(-5, 30).sample((batch_size,)) - r1 = torch.distributions.Normal(0.0, 0.1).sample((batch_size,)) + r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,)) r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,)) r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,)) else: diff --git a/notebooks/api/00_deepfluoro.ipynb b/notebooks/api/00_deepfluoro.ipynb index d857a50..69631e2 100644 --- a/notebooks/api/00_deepfluoro.ipynb +++ b/notebooks/api/00_deepfluoro.ipynb @@ -732,41 +732,6 @@ " )" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "e14ac65e-3c57-4c67-be3a-1d7184f669b3", - "metadata": {}, - "outputs": [], - "source": [ - "@beartype\n", - "def get_random_offset(view, batch_size: int, device) -> RigidTransform:\n", - " if view == \"ap\":\n", - " t1 = torch.distributions.Normal(-6, 20).sample((batch_size,))\n", - " t2 = torch.distributions.Normal(175, 30).sample((batch_size,))\n", - " t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))\n", - " r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n", - " r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n", - " r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))\n", - " elif view == \"lat\":\n", - " t1 = torch.distributions.Normal(75, 15).sample((batch_size,))\n", - " t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))\n", - " t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))\n", - " r1 = torch.distributions.Normal(0.0, 0.1).sample((batch_size,))\n", - " r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n", - " r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))\n", - " else:\n", - " raise ValueError(f\"view must be 'ap' or 'lat', not '{view}'\")\n", - "\n", - " logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1).to(device)\n", - " T = convert(\n", - " [logmap[..., :3], logmap[..., 3:]],\n", - " \"se3_log_map\",\n", - " \"se3_exp_map\",\n", - " )\n", - " return T" - ] - }, { "cell_type": "markdown", "id": "76efedd3-103c-43dd-944b-fe0c06e2c87b", diff --git a/notebooks/api/01_ljubljana.ipynb b/notebooks/api/01_ljubljana.ipynb index 8cdad9e..588214e 100644 --- a/notebooks/api/01_ljubljana.ipynb +++ b/notebooks/api/01_ljubljana.ipynb @@ -202,7 +202,7 @@ " t1 = torch.distributions.Normal(75, 30).sample((batch_size,))\n", " t2 = torch.distributions.Normal(-80, 30).sample((batch_size,))\n", " t3 = torch.distributions.Normal(-5, 30).sample((batch_size,))\n", - " r1 = torch.distributions.Normal(0.0, 0.1).sample((batch_size,))\n", + " r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n", " r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n", " r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))\n", " else:\n",