Skip to content

Commit f0bb2e5

Browse files
author
Virginia
committed
Updates
Signed-off-by: Virginia <[email protected]>
1 parent 04f4d21 commit f0bb2e5

File tree

1 file changed

+35
-27
lines changed

1 file changed

+35
-27
lines changed

generation/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@
104104
"import torch\n",
105105
"import torch.nn.functional as F\n",
106106
"import torchvision\n",
107-
"import sys\n",
108107
"from monai import transforms\n",
109108
"from monai.apps import DecathlonDataset\n",
110109
"from monai.config import print_config\n",
@@ -191,8 +190,9 @@
191190
"2. `EnsureChannelFirstd` ensures the data has a \"channel first\" shape.\n",
192191
"3. The first `Lambdad` transform chooses the first channel of the image, which is the Flair image.\n",
193192
"4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm.\n",
194-
"5. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n",
195-
"6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n",
193+
"5. `CenterSpatialCropd` crops the 3D images to a specific size.\n",
194+
"6. `ScaleIntensityRangePercentilesd` applies range scaling to a numpy array based on the intensity distribution of the input. This transform is very common with MRI images.\n",
195+
"7. `RandSpatialCropd` randomly crops out a 2D patch from the 3D image.\n",
196196
"8. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2`)."
197197
]
198198
},
@@ -388,7 +388,7 @@
388388
},
389389
"outputs": [],
390390
"source": [
391-
"class Diffusion_AE(torch.nn.Module):\n",
391+
"class DiffusionAE(torch.nn.Module):\n",
392392
" def __init__(self, embedding_dimension=64):\n",
393393
" super().__init__()\n",
394394
" self.unet = DiffusionModelUNet(\n",
@@ -413,7 +413,7 @@
413413
"\n",
414414
"\n",
415415
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
416-
"model = Diffusion_AE(embedding_dimension=512).to(device)\n",
416+
"model = DiffusionAE(embedding_dimension=512).to(device)\n",
417417
"scheduler = DDIMScheduler(num_train_timesteps=1000)\n",
418418
"optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)\n",
419419
"inferer = DiffusionInferer(scheduler)"
@@ -492,7 +492,8 @@
492492
" # Create timesteps\n",
493493
" timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (batch_size,)).to(device).long()\n",
494494
" # Get model prediction\n",
495-
" # cross attention expects shape [batch size, sequence length, channels], we are use channels = latent dimension and sequence length = 1\n",
495+
" # cross attention expects shape [batch size, sequence length, channels], \n",
496+
"        # we use channels = latent dimension and sequence length = 1\n",
496497
" latent = model.semantic_encoder(images)\n",
497498
" noise_pred = inferer(\n",
498499
" inputs=images, diffusion_model=model.unet, noise=noise, timesteps=timesteps, condition=latent.unsqueeze(2)\n",
@@ -509,7 +510,7 @@
509510
" if epoch % val_interval == 0:\n",
510511
" model.eval()\n",
511512
" val_iter_loss = 0\n",
512-
" for val_step, val_batch in enumerate(val_loader):\n",
513+
" for _, val_batch in enumerate(val_loader):\n",
513514
" with torch.no_grad():\n",
514515
" images = val_batch[\"image\"].to(device)\n",
515516
" timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (batch_size,)).to(device).long()\n",
@@ -526,10 +527,11 @@
526527
"\n",
527528
" val_iter_loss += val_loss.item()\n",
528529
" iter_loss_list.append(iter_loss / val_interval)\n",
529-
" val_iter_loss_list.append(val_iter_loss / (val_step + 1))\n",
530+
" val_iter_loss_list.append(val_iter_loss / len(val_loader))\n",
530531
" iter_loss = 0\n",
531532
" print(\n",
532-
" f\"Iteration {epoch} - Interval Loss {iter_loss_list[-1]:.4f}, Interval Loss Val {val_iter_loss_list[-1]:.4f}\"\n",
533+
" f\"Iteration {epoch} - Interval Loss {iter_loss_list[-1]:.4f}, \n",
534+
" Interval Loss Val {val_iter_loss_list[-1]:.4f}\"\n",
533535
" )\n",
534536
"\n",
535537
"total_time = time.time() - total_start\n",
@@ -566,8 +568,10 @@
566568
"plt.title(\"Learning Curves Diffusion Model\", fontsize=20)\n",
567569
"plt.plot(list(range(len(iter_loss_list))), iter_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n",
568570
"plt.plot(list(range(len(iter_loss_list))), val_iter_loss_list, color=\"C4\", linewidth=2.0, label=\"Validation\")\n",
569-
"plt.yticks(fontsize=12), plt.xticks(fontsize=12)\n",
570-
"plt.xlabel(\"Iterations\", fontsize=16), plt.ylabel(\"Loss\", fontsize=16)\n",
571+
"plt.yticks(fontsize=12)\n",
572+
"plt.xticks(fontsize=12)\n",
573+
"plt.xlabel(\"Iterations\", fontsize=16)\n",
574+
"plt.ylabel(\"Loss\", fontsize=16)\n",
571575
"plt.legend(prop={\"size\": 14})\n",
572576
"plt.show()"
573577
]
@@ -713,7 +717,8 @@
713717
}
714718
],
715719
"source": [
716-
"latents_train.shape, classes_train.shape"
720+
"print(latents_train.shape)\n",
721+
"print(classes_train.shape)"
717722
]
718723
},
719724
{
@@ -735,17 +740,8 @@
735740
],
736741
"source": [
737742
"clf = LogisticRegression(solver=\"newton-cg\", random_state=0).fit(latents_train, classes_train)\n",
738-
"clf.score(latents_train, classes_train), clf.score(latents_val, classes_val)"
739-
]
740-
},
741-
{
742-
"cell_type": "code",
743-
"execution_count": 22,
744-
"id": "73df71e0",
745-
"metadata": {},
746-
"outputs": [],
747-
"source": [
748-
"w = torch.Tensor(clf.coef_).float().to(device)"
743+
"print(clf.score(latents_train, classes_train))\n",
744+
"print(clf.score(latents_val, classes_val))"
749745
]
750746
},
751747
{
@@ -777,6 +773,7 @@
777773
"source": [
778774
"s = -1.5\n",
779775
"\n",
776+
"w = torch.Tensor(clf.coef_).float().to(device)\n",
780777
"scheduler.set_timesteps(num_inference_steps=100)\n",
781778
"batch = next(iter(val_loader))\n",
782779
"images = batch[\"image\"].to(device)\n",
@@ -802,6 +799,14 @@
802799
")"
803800
]
804801
},
802+
{
803+
"cell_type": "markdown",
804+
"id": "525702b5",
805+
"metadata": {},
806+
"source": [
807+
"Although the result is not perfect, the manipulated slices do not present a tumour (unlike the middle - \"reconstructed\" - ones), because we tweaked the latents to move away from the abnormality cluster:"
808+
]
809+
},
805810
{
806811
"cell_type": "code",
807812
"execution_count": 28,
@@ -831,15 +836,18 @@
831836
"plt.figure(figsize=(15, 5))\n",
832837
"plt.imshow(grid.detach().cpu().numpy()[0], cmap=\"gray\")\n",
833838
"plt.axis(\"off\")\n",
834-
"plt.title(f\"Original (top), Reconstruction (middle), Manipulated (bottom) s = {s}\");"
839+
"plt.title(f\"Original (top), Reconstruction (middle), Manipulated (bottom) s = {s}\")"
835840
]
836841
},
837842
{
838-
"cell_type": "markdown",
839-
"id": "b5ac0b8c-0f9d-43ba-9959-488ab62e892e",
843+
"cell_type": "code",
844+
"execution_count": null,
845+
"id": "9cf8fbf9",
840846
"metadata": {},
847+
"outputs": [],
841848
"source": [
842-
"Although not perfectly, the manipulated slices do not present a tumour (unlike the middle - \"reconstructed\" - ones), because we tweaked the latents to move away from the abnormality cluster."
849+
"if directory is None:\n",
850+
" shutil.rmtree(root_dir)"
843851
]
844852
}
845853
],

0 commit comments

Comments
 (0)