|
104 | 104 | "import torch\n", |
105 | 105 | "import torch.nn.functional as F\n", |
106 | 106 | "import torchvision\n", |
107 | | - "import sys\n", |
108 | 107 | "from monai import transforms\n", |
109 | 108 | "from monai.apps import DecathlonDataset\n", |
110 | 109 | "from monai.config import print_config\n", |
|
191 | 190 | "2. `EnsureChannelFirstd` ensures the original data to construct \"channel first\" shape.\n", |
192 | 191 | "3. The first `Lambdad` transform chooses the first channel of the image, which is the Flair image.\n", |
193 | 192 | "4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm.\n", |
194 | | - "5. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n", |
195 | | - "6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n", |
| 193 | + "5. `CenterSpatialCropd`: we crop the 3D images to a specific size\n", |
| 194 | + "6. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n", |
| 195 | + "7. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n", |
196 | 196 | "6. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2` )." |
197 | 197 | ] |
198 | 198 | }, |
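
A minimal sketch of how the transform list above could be assembled with `monai.transforms.Compose`. The dictionary keys, crop sizes and percentile bounds here are illustrative assumptions; the notebook's exact values are not visible in this diff:

```python
from monai import transforms

# Illustrative sketch only: keys, roi sizes and percentile bounds are assumed values.
preprocess = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),                                   # load the NIfTI volumes
        transforms.EnsureChannelFirstd(keys=["image", "label"]),                          # channel-first layout
        transforms.Lambdad(keys=["image"], func=lambda x: x[:1]),                         # keep first (FLAIR) channel
        transforms.Spacingd(keys=["image", "label"], pixdim=(3.0, 3.0, 2.0), mode=("bilinear", "nearest")),
        transforms.CenterSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 44)),    # fixed-size 3D crop
        transforms.ScaleIntensityRangePercentilesd(keys=["image"], lower=0, upper=99.5, b_min=0.0, b_max=1.0),
        transforms.RandSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 1), random_size=False),  # 2D patch
        transforms.Lambdad(keys=["label"], func=lambda x: 2.0 if x.sum() > 0 else 1.0),   # scalar slice label
    ]
)
```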
|
388 | 388 | }, |
389 | 389 | "outputs": [], |
390 | 390 | "source": [ |
391 | | - "class Diffusion_AE(torch.nn.Module):\n", |
| 391 | + "class DiffusionAE(torch.nn.Module):\n", |
392 | 392 | " def __init__(self, embedding_dimension=64):\n", |
393 | 393 | " super().__init__()\n", |
394 | 394 | " self.unet = DiffusionModelUNet(\n", |
|
413 | 413 | "\n", |
414 | 414 | "\n", |
415 | 415 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", |
416 | | - "model = Diffusion_AE(embedding_dimension=512).to(device)\n", |
| 416 | + "model = DiffusionAE(embedding_dimension=512).to(device)\n", |
417 | 417 | "scheduler = DDIMScheduler(num_train_timesteps=1000)\n", |
418 | 418 | "optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)\n", |
419 | 419 | "inferer = DiffusionInferer(scheduler)" |
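
The body of `DiffusionAE` (in particular its semantic encoder) is not visible in this hunk. Purely as a hypothetical sketch consistent with the `torchvision` import and the `embedding_dimension` argument, such an encoder is often a ResNet-18 adapted to single-channel slices; the notebook's actual implementation may differ:

```python
import torch
import torchvision

# Hypothetical semantic-encoder sketch (not the notebook's actual code):
# ResNet-18 with a 1-channel stem and a linear head emitting the latent.
embedding_dimension = 512
encoder = torchvision.models.resnet18(weights=None)
encoder.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
encoder.fc = torch.nn.Linear(encoder.fc.in_features, embedding_dimension)

x = torch.randn(2, 1, 64, 64)   # two single-channel 64x64 slices
latent = encoder(x)
print(latent.shape)             # torch.Size([2, 512])
```

In the training loop below, this latent is what `model.semantic_encoder(images)` returns and, after `unsqueeze`, conditions the UNet through cross-attention.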
|
492 | 492 | " # Create timesteps\n", |
493 | 493 | " timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (batch_size,)).to(device).long()\n", |
494 | 494 | " # Get model prediction\n", |
495 | | - " # cross attention expects shape [batch size, sequence length, channels], we are use channels = latent dimension and sequence length = 1\n", |
| 495 | + " # cross attention expects shape [batch size, sequence length, channels], \n", |
| 496 | + " #we are use channels = latent dimension and sequence length = 1\n", |
496 | 497 | " latent = model.semantic_encoder(images)\n", |
497 | 498 | " noise_pred = inferer(\n", |
498 | 499 | " inputs=images, diffusion_model=model.unet, noise=noise, timesteps=timesteps, condition=latent.unsqueeze(2)\n", |
|
509 | 510 | " if epoch % val_interval == 0:\n", |
510 | 511 | " model.eval()\n", |
511 | 512 | " val_iter_loss = 0\n", |
512 | | - " for val_step, val_batch in enumerate(val_loader):\n", |
| 513 | + " for _, val_batch in enumerate(val_loader):\n", |
513 | 514 | " with torch.no_grad():\n", |
514 | 515 | " images = val_batch[\"image\"].to(device)\n", |
515 | 516 | " timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (batch_size,)).to(device).long()\n", |
|
526 | 527 | "\n", |
527 | 528 | " val_iter_loss += val_loss.item()\n", |
528 | 529 | " iter_loss_list.append(iter_loss / val_interval)\n", |
529 | | - " val_iter_loss_list.append(val_iter_loss / (val_step + 1))\n", |
| 530 | + " val_iter_loss_list.append(val_iter_loss / len(val_loader))\n", |
530 | 531 | " iter_loss = 0\n", |
531 | 532 | " print(\n", |
532 | | - " f\"Iteration {epoch} - Interval Loss {iter_loss_list[-1]:.4f}, Interval Loss Val {val_iter_loss_list[-1]:.4f}\"\n", |
| 533 | + " f\"Iteration {epoch} - Interval Loss {iter_loss_list[-1]:.4f}, \n", |
| 534 | + " Interval Loss Val {val_iter_loss_list[-1]:.4f}\"\n", |
533 | 535 | " )\n", |
534 | 536 | "\n", |
535 | 537 | "total_time = time.time() - total_start\n", |
|
566 | 568 | "plt.title(\"Learning Curves Diffusion Model\", fontsize=20)\n", |
567 | 569 | "plt.plot(list(range(len(iter_loss_list))), iter_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n", |
568 | 570 | "plt.plot(list(range(len(iter_loss_list))), val_iter_loss_list, color=\"C4\", linewidth=2.0, label=\"Validation\")\n", |
569 | | - "plt.yticks(fontsize=12), plt.xticks(fontsize=12)\n", |
570 | | - "plt.xlabel(\"Iterations\", fontsize=16), plt.ylabel(\"Loss\", fontsize=16)\n", |
| 571 | + "plt.yticks(fontsize=12)\n", |
| 572 | + "plt.xticks(fontsize=12)\n", |
| 573 | + "plt.xlabel(\"Iterations\", fontsize=16)\n", |
| 574 | + "plt.ylabel(\"Loss\", fontsize=16)\n", |
571 | 575 | "plt.legend(prop={\"size\": 14})\n", |
572 | 576 | "plt.show()" |
573 | 577 | ] |
|
713 | 717 | } |
714 | 718 | ], |
715 | 719 | "source": [ |
716 | | - "latents_train.shape, classes_train.shape" |
| 720 | + "print(latents_train.shape)\n", |
| 721 | + "print(classes_train.shape)" |
717 | 722 | ] |
718 | 723 | }, |
719 | 724 | { |
|
735 | 740 | ], |
736 | 741 | "source": [ |
737 | 742 | "clf = LogisticRegression(solver=\"newton-cg\", random_state=0).fit(latents_train, classes_train)\n", |
738 | | - "clf.score(latents_train, classes_train), clf.score(latents_val, classes_val)" |
739 | | - ] |
740 | | - }, |
741 | | - { |
742 | | - "cell_type": "code", |
743 | | - "execution_count": 22, |
744 | | - "id": "73df71e0", |
745 | | - "metadata": {}, |
746 | | - "outputs": [], |
747 | | - "source": [ |
748 | | - "w = torch.Tensor(clf.coef_).float().to(device)" |
| 743 | + "print(clf.score(latents_train, classes_train))\n", |
| 744 | + "print(clf.score(latents_val, classes_val))" |
749 | 745 | ] |
750 | 746 | }, |
751 | 747 | { |
|
777 | 773 | "source": [ |
778 | 774 | "s = -1.5\n", |
779 | 775 | "\n", |
| 776 | + "w = torch.Tensor(clf.coef_).float().to(device)\n", |
780 | 777 | "scheduler.set_timesteps(num_inference_steps=100)\n", |
781 | 778 | "batch = next(iter(val_loader))\n", |
782 | 779 | "images = batch[\"image\"].to(device)\n", |
|
802 | 799 | ")" |
803 | 800 | ] |
804 | 801 | }, |
| 802 | + { |
| 803 | + "cell_type": "markdown", |
| 804 | + "id": "525702b5", |
| 805 | + "metadata": {}, |
| 806 | + "source": [ |
| 807 | + "Although not perfectly, the manipulated slices do not present a tumour (unlike the middle - \"reconstructed\" - ones), because we tweaked the latents to move away from the abnormality cluster: " |
| 808 | + ] |
| 809 | + }, |
805 | 810 | { |
806 | 811 | "cell_type": "code", |
807 | 812 | "execution_count": 28, |
|
831 | 836 | "plt.figure(figsize=(15, 5))\n", |
832 | 837 | "plt.imshow(grid.detach().cpu().numpy()[0], cmap=\"gray\")\n", |
833 | 838 | "plt.axis(\"off\")\n", |
834 | | - "plt.title(f\"Original (top), Reconstruction (middle), Manipulated (bottom) s = {s}\");" |
| 839 | + "plt.title(f\"Original (top), Reconstruction (middle), Manipulated (bottom) s = {s}\")" |
835 | 840 | ] |
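
The manipulation step behind this figure is not fully visible in the diff. The idea is to shift each semantic latent along the direction learned by the logistic-regression probe (`w = clf.coef_`), scaled by `s`, before using it as the conditioning for DDIM sampling. A minimal stand-alone sketch of that shift, using random stand-ins for the latents and classifier weights (the notebook's sign convention and any normalisation of `w` may differ):

```python
import torch

embedding_dimension = 512
latents = torch.randn(4, embedding_dimension)   # stand-in for model.semantic_encoder(images)
w = torch.randn(1, embedding_dimension)         # stand-in for torch.Tensor(clf.coef_)
s = -1.5                                        # scale: sign and magnitude control the push away from "abnormal"

latents_manipulated = latents + s * w           # broadcast the shift over the batch
print(latents_manipulated.shape)                # torch.Size([4, 512])
```

The shifted latents would then condition the UNet during sampling, analogous to `condition=latent.unsqueeze(2)` in the training loop above.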
836 | 841 | }, |
837 | 842 | { |
838 | | - "cell_type": "markdown", |
839 | | - "id": "b5ac0b8c-0f9d-43ba-9959-488ab62e892e", |
| 843 | + "cell_type": "code", |
| 844 | + "execution_count": null, |
| 845 | + "id": "9cf8fbf9", |
840 | 846 | "metadata": {}, |
| 847 | + "outputs": [], |
841 | 848 | "source": [ |
842 | | - "Although not perfectly, the manipulated slices do not present a tumour (unlike the middle - \"reconstructed\" - ones), because we tweaked the latents to move away from the abnormality cluster." |
| 849 | + "if directory is None:\n", |
| 850 | + " shutil.rmtree(root_dir)" |
843 | 851 | ] |
844 | 852 | } |
845 | 853 | ], |
|