|
28 | 28 | "\n", |
29 | 29 | "In this notebook, we detail the procedure for training a 3D latent diffusion model to generate high-dimensional 3D medical images. Due to the potential for out-of-memory issues on most GPUs when generating large images (e.g., those with dimensions of 512 x 512 x 512 or greater), we have structured the training process into two primary steps: 1) generating image embeddings and 2) training 3D latent diffusion models. The subsequent sections will demonstrate the entire process using a simulated dataset.\n", |
30 | 30 | "\n", |
31 | | - "`[Release Note (March 2025)]:` We are excited to announce the new MAISI Version `'maisi-rflow'`. Compared with the previous version `'maisi-ddpm'`, it accelerated latent diffusion model inference by 33x. Please see the detailed difference in the following section." |
| 31 | + "`[Release Note (March 2025)]:` We are excited to announce the new MAISI version `'maisi3d-rflow'`. Compared with the previous version `'maisi3d-ddpm'`, it accelerates latent diffusion model inference by 33x. See the following section for the detailed differences." |
32 | 32 | ] |
33 | 33 | }, |
34 | 34 | { |
|
38 | 38 | "source": [ |
39 | 39 | "## Set up the MAISI version\n", |
40 | 40 | "\n", |
41 | | - "Choose between `'maisi-ddpm'` and `'maisi-rflow'`. The differences are:\n", |
42 | | - "- The maisi version `'maisi-ddpm'` uses basic noise scheduler DDPM. `'maisi-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n", |
43 | | - "- The maisi version `'maisi-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi-rflow'`.\n", |
44 | | - "- For the released model weights, `'maisi-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi-ddpm'`." |
| 41 | + "Choose between `'maisi3d-ddpm'` and `'maisi3d-rflow'`. The differences are:\n", |
| 42 | + "- `'maisi3d-ddpm'` uses the basic DDPM noise scheduler, while `'maisi3d-rflow'` uses the Rectified Flow scheduler, which can be 33 times faster during inference.\n", |
| 43 | + "- `'maisi3d-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi3d-rflow'` has no such requirement. In other words, training data is easier to prepare for `'maisi3d-rflow'`.\n", |
| 44 | + "- For the released model weights, `'maisi3d-rflow'` generates images of better quality for the head region and small output volumes, and of comparable quality in other cases, compared with `'maisi3d-ddpm'`." |
45 | 45 | ] |
46 | 46 | }, |
47 | 47 | { |
|
51 | 51 | "metadata": {}, |
52 | 52 | "outputs": [], |
53 | 53 | "source": [ |
54 | | - "maisi_version = \"maisi-ddpm\"\n", |
55 | | - "assert maisi_version in [\"maisi-ddpm\", \"maisi-rflow\"]" |
| 54 | + "maisi_version = \"maisi3d-ddpm\"\n", |
| 55 | + "assert maisi_version in [\"maisi3d-ddpm\", \"maisi3d-rflow\"]" |
56 | 56 | ] |
57 | 57 | }, |
58 | 58 | { |
|
131 | 131 | "import numpy as np\n", |
132 | 132 | "import nibabel as nib\n", |
133 | 133 | "import subprocess\n", |
| 134 | + "from IPython.display import Image, display\n", |
134 | 135 | "\n", |
135 | 136 | "from monai.apps import download_url\n", |
136 | 137 | "from monai.data import create_test_image_3d\n", |
137 | 138 | "from monai.config import print_config\n", |
138 | 139 | "\n", |
139 | | - "from IPython.display import Image, display\n", |
140 | | - "\n", |
141 | 140 | "from scripts.diff_model_setting import setup_logging\n", |
142 | 141 | "\n", |
143 | 142 | "print_config()\n", |
|
152 | 151 | "source": [ |
153 | 152 | "## Set up the MAISI version\n", |
154 | 153 | "\n", |
155 | | - "Choose between `'maisi-ddpm'` and `'maisi-rflow'`. The differences are:\n", |
156 | | - "- The maisi version `'maisi-ddpm'` uses basic noise scheduler DDPM. `'maisi-rflow'` uses Rectified Flow scheduler, can be 33 times faster during inference.\n", |
157 | | - "- The maisi version `'maisi-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi-rflow'` does not have such requirement. In other words, it is easier to prepare training data for `'maisi-rflow'`.\n", |
158 | | - "- For the released model weights, `'maisi-rflow'` can generate images with better quality for head region and small output volumes, and comparable quality for other cases compared with `'maisi-ddpm'`." |
| 154 | + "Choose between `'maisi3d-ddpm'` and `'maisi3d-rflow'`. The differences are:\n", |
| 155 | + "- `'maisi3d-ddpm'` uses the basic DDPM noise scheduler, while `'maisi3d-rflow'` uses the Rectified Flow scheduler, which can be 33 times faster during inference.\n", |
| 156 | + "- `'maisi3d-ddpm'` requires training images to be labeled with body region (`\"top_region_index\"` and `\"bottom_region_index\"`), while `'maisi3d-rflow'` has no such requirement. In other words, training data is easier to prepare for `'maisi3d-rflow'`.\n", |
| 157 | + "- For the released model weights, `'maisi3d-rflow'` generates images of better quality for the head region and small output volumes, and of comparable quality in other cases, compared with `'maisi3d-ddpm'`." |
159 | 158 | ] |
160 | 159 | }, |
161 | 160 | { |
|
165 | 164 | "metadata": {}, |
166 | 165 | "outputs": [], |
167 | 166 | "source": [ |
168 | | - "maisi_version = \"maisi-ddpm\"\n", |
169 | | - "assert maisi_version in [\"maisi-ddpm\", \"maisi-rflow\"]" |
| 167 | + "maisi_version = \"maisi3d-ddpm\"\n", |
| 168 | + "assert maisi_version in [\"maisi3d-ddpm\", \"maisi3d-rflow\"]" |
170 | 169 | ] |
171 | 170 | }, |
172 | 171 | { |
|
213 | 212 | "name": "stderr", |
214 | 213 | "output_type": "stream", |
215 | 214 | "text": [ |
216 | | - "[2025-03-11 22:05:02.952][ INFO](notebook) - Generated simulated images.\n" |
| 215 | + "[2025-03-11 22:16:41.000][ INFO](notebook) - Generated simulated images.\n" |
217 | 216 | ] |
218 | 217 | } |
219 | 218 | ], |
|
260 | 259 | "name": "stderr", |
261 | 260 | "output_type": "stream", |
262 | 261 | "text": [ |
263 | | - "[2025-03-11 22:05:02.966][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n", |
264 | | - "[2025-03-11 22:05:02.966][ INFO](notebook) - number of GPUs: 1.\n" |
| 262 | + "[2025-03-11 22:16:41.012][ INFO](notebook) - files and folders under work_dir: ['predictions', 'config_maisi.json', 'models', 'sim_dataroot', 'config_maisi_diff_model.json', 'embeddings', 'environment_maisi_diff_model.json', 'sim_datalist.json'].\n", |
| 263 | + "[2025-03-11 22:16:41.012][ INFO](notebook) - number of GPUs: 1.\n" |
265 | 264 | ] |
266 | 265 | } |
267 | 266 | ], |
268 | 267 | "source": [ |
269 | 268 | "env_config_path = \"./configs/environment_maisi_diff_model.json\"\n", |
270 | 269 | "model_config_path = \"./configs/config_maisi_diff_model.json\"\n", |
271 | | - "if maisi_version == \"maisi-ddpm\":\n", |
272 | | - " model_def_path = \"./configs/config_maisi-ddpm.json\"\n", |
| 270 | + "if maisi_version == \"maisi3d-ddpm\":\n", |
| 271 | + " model_def_path = \"./configs/config_maisi3d-ddpm.json\"\n", |
273 | 272 | " include_body_region = True\n", |
274 | | - "elif maisi_version == \"maisi-rflow\":\n", |
275 | | - " model_def_path = \"./configs/config_maisi-rflow.json\"\n", |
| 273 | + "elif maisi_version == \"maisi3d-rflow\":\n", |
| 274 | + " model_def_path = \"./configs/config_maisi3d-rflow.json\"\n", |
276 | 275 | " include_body_region = False\n", |
277 | 276 | "else:\n", |
278 | | - " raise ValueError(f\"maisi_version has to be chosen from ['maisi-ddpm', 'maisi-rflow'], yet got {maisi_version}.\")\n", |
| 277 | + " raise ValueError(f\"maisi_version has to be chosen from ['maisi3d-ddpm', 'maisi3d-rflow'], yet got {maisi_version}.\")\n", |
279 | 278 | "\n", |
280 | 279 | "# Load environment configuration, model configuration and model definition\n", |
281 | 280 | "with open(env_config_path, \"r\") as f:\n", |
|
407 | 406 | "name": "stderr", |
408 | 407 | "output_type": "stream", |
409 | 408 | "text": [ |
410 | | - "[2025-03-11 22:05:02.977][ INFO](notebook) - Creating training data...\n" |
| 409 | + "[2025-03-11 22:16:41.021][ INFO](notebook) - Creating training data...\n" |
411 | 410 | ] |
412 | 411 | }, |
413 | 412 | { |
414 | 413 | "name": "stdout", |
415 | 414 | "output_type": "stream", |
416 | 415 | "text": [ |
417 | 416 | "\n", |
418 | | - "[2025-03-11 22:05:10.881][ INFO](creating training data) - Using device cuda:0\n", |
419 | | - "[2025-03-11 22:05:11.686][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n", |
| 417 | + "[2025-03-11 22:16:50.396][ INFO](creating training data) - Using device cuda:0\n", |
| 418 | + "[2025-03-11 22:16:51.402][ INFO](creating training data) - filenames_raw: ['tr_image_001.nii.gz', 'tr_image_002.nii.gz']\n", |
420 | 419 | "\n" |
421 | 420 | ] |
422 | 421 | } |
|
460 | 459 | "name": "stderr", |
461 | 460 | "output_type": "stream", |
462 | 461 | "text": [ |
463 | | - "[2025-03-11 22:05:13.881][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
464 | | - "[2025-03-11 22:05:13.884][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
465 | | - "[2025-03-11 22:05:13.885][ INFO](notebook) - Completed creating .json files for all embedding files.\n" |
| 462 | + "[2025-03-11 22:16:53.638][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
| 463 | + "[2025-03-11 22:16:53.640][ INFO](notebook) - data: {'dim': (64, 64, 32), 'spacing': [0.875, 0.875, 0.75], 'top_region_index': [0, 1, 0, 0], 'bottom_region_index': [0, 0, 1, 0]}.\n", |
| 464 | + "[2025-03-11 22:16:53.641][ INFO](notebook) - Completed creating .json files for all embedding files.\n" |
466 | 465 | ] |
467 | 466 | } |
468 | 467 | ], |
|
539 | 538 | "name": "stderr", |
540 | 539 | "output_type": "stream", |
541 | 540 | "text": [ |
542 | | - "[2025-03-11 22:05:13.892][ INFO](notebook) - Training the model...\n" |
| 541 | + "[2025-03-11 22:16:53.646][ INFO](notebook) - Training the model...\n" |
543 | 542 | ] |
544 | 543 | }, |
545 | 544 | { |
546 | 545 | "name": "stdout", |
547 | 546 | "output_type": "stream", |
548 | 547 | "text": [ |
549 | 548 | "\n", |
550 | | - "[2025-03-11 22:05:24.419][ INFO](training) - Using cuda:0 of 1\n", |
551 | | - "[2025-03-11 22:05:24.419][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n", |
552 | | - "[2025-03-11 22:05:24.419][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n", |
553 | | - "[2025-03-11 22:05:24.419][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n", |
554 | | - "[2025-03-11 22:05:24.419][ INFO](training) - [config] lr -> 0.0001.\n", |
555 | | - "[2025-03-11 22:05:24.419][ INFO](training) - [config] num_epochs -> 2.\n", |
556 | | - "[2025-03-11 22:05:24.419][ INFO](training) - [config] num_train_timesteps -> 1000.\n", |
557 | | - "[2025-03-11 22:05:24.420][ INFO](training) - num_files_train: 2\n", |
558 | | - "[2025-03-11 22:05:26.152][ INFO](training) - Training from scratch.\n", |
559 | | - "[2025-03-11 22:05:26.539][ INFO](training) - Scaling factor set to 1.159977912902832.\n", |
560 | | - "[2025-03-11 22:05:26.539][ INFO](training) - scale_factor -> 1.159977912902832.\n", |
561 | | - "[2025-03-11 22:05:26.542][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n", |
562 | | - "[2025-03-11 22:05:26.542][ INFO](training) - Epoch 1, lr 0.0001.\n", |
563 | | - "[2025-03-11 22:05:28.578][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 1/2, loss: 0.7974, lr: 0.000100000000.\n", |
564 | | - "[2025-03-11 22:05:28.719][ INFO](training) - [2025-03-11 22:05:28] epoch 1, iter 2/2, loss: 0.7943, lr: 0.000056250000.\n", |
565 | | - "[2025-03-11 22:05:28.762][ INFO](training) - epoch 1 average loss: 0.7958.\n", |
566 | | - "[2025-03-11 22:05:30.615][ INFO](training) - Epoch 2, lr 2.5e-05.\n", |
567 | | - "[2025-03-11 22:05:31.002][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 1/2, loss: 0.7898, lr: 0.000025000000.\n", |
568 | | - "[2025-03-11 22:05:31.105][ INFO](training) - [2025-03-11 22:05:31] epoch 2, iter 2/2, loss: 0.7886, lr: 0.000006250000.\n", |
569 | | - "[2025-03-11 22:05:31.168][ INFO](training) - epoch 2 average loss: 0.7892.\n", |
| 549 | + "[2025-03-11 22:17:02.004][ INFO](training) - Using cuda:0 of 1\n", |
| 550 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] ckpt_folder -> ./temp_work_dir/./models.\n", |
| 551 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] data_root -> ./temp_work_dir/./embeddings.\n", |
| 552 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] data_list -> ./temp_work_dir/sim_datalist.json.\n", |
| 553 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] lr -> 0.0001.\n", |
| 554 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] num_epochs -> 2.\n", |
| 555 | + "[2025-03-11 22:17:02.004][ INFO](training) - [config] num_train_timesteps -> 1000.\n", |
| 556 | + "[2025-03-11 22:17:02.005][ INFO](training) - num_files_train: 2\n", |
| 557 | + "[2025-03-11 22:17:03.887][ INFO](training) - Training from scratch.\n", |
| 558 | + "[2025-03-11 22:17:04.338][ INFO](training) - Scaling factor set to 1.159977912902832.\n", |
| 559 | + "[2025-03-11 22:17:04.339][ INFO](training) - scale_factor -> 1.159977912902832.\n", |
| 560 | + "[2025-03-11 22:17:04.341][ INFO](training) - torch.set_float32_matmul_precision -> highest.\n", |
| 561 | + "[2025-03-11 22:17:04.341][ INFO](training) - Epoch 1, lr 0.0001.\n", |
| 562 | + "[2025-03-11 22:17:05.278][ INFO](training) - [2025-03-11 22:17:05] epoch 1, iter 1/2, loss: 0.7973, lr: 0.000100000000.\n", |
| 563 | + "[2025-03-11 22:17:05.673][ INFO](training) - [2025-03-11 22:17:05] epoch 1, iter 2/2, loss: 0.7969, lr: 0.000056250000.\n", |
| 564 | + "[2025-03-11 22:17:05.718][ INFO](training) - epoch 1 average loss: 0.7971.\n", |
| 565 | + "[2025-03-11 22:17:07.383][ INFO](training) - Epoch 2, lr 2.5e-05.\n", |
| 566 | + "[2025-03-11 22:17:07.777][ INFO](training) - [2025-03-11 22:17:07] epoch 2, iter 1/2, loss: 0.7932, lr: 0.000025000000.\n", |
| 567 | + "[2025-03-11 22:17:07.881][ INFO](training) - [2025-03-11 22:17:07] epoch 2, iter 2/2, loss: 0.7904, lr: 0.000006250000.\n", |
| 568 | + "[2025-03-11 22:17:07.942][ INFO](training) - epoch 2 average loss: 0.7918.\n", |
570 | 569 | "\n" |
571 | 570 | ] |
572 | 571 | } |
|
612 | 611 | "name": "stderr", |
613 | 612 | "output_type": "stream", |
614 | 613 | "text": [ |
615 | | - "[2025-03-11 22:05:35.033][ INFO](notebook) - Running inference...\n", |
616 | | - "[2025-03-11 22:05:50.259][ INFO](notebook) - Completed all steps.\n" |
| 614 | + "[2025-03-11 22:17:11.993][ INFO](notebook) - Running inference...\n", |
| 615 | + "[2025-03-11 22:17:27.730][ INFO](notebook) - Completed all steps.\n" |
617 | 616 | ] |
618 | 617 | }, |
619 | 618 | { |
620 | 619 | "name": "stdout", |
621 | 620 | "output_type": "stream", |
622 | 621 | "text": [ |
623 | 622 | "\n", |
624 | | - "[2025-03-11 22:05:43.502][ INFO](inference) - Using cuda:0 of 1 with random seed: 7854\n", |
625 | | - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n", |
626 | | - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] random_seed -> 7854.\n", |
627 | | - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_prefix -> unet_3d.\n", |
628 | | - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] output_size -> (256, 256, 128).\n", |
629 | | - "[2025-03-11 22:05:43.502][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n", |
630 | | - "[2025-03-11 22:05:43.502][ INFO](root) - `controllable_anatomy_size` is not provided.\n", |
631 | | - "[2025-03-11 22:05:45.793][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n", |
632 | | - "[2025-03-11 22:05:45.795][ INFO](inference) - scale_factor -> 1.159977912902832.\n", |
633 | | - "[2025-03-11 22:05:45.796][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n", |
634 | | - "[2025-03-11 22:05:45.798][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n", |
| 623 | + "[2025-03-11 22:17:20.465][ INFO](inference) - Using cuda:0 of 1 with random seed: 23141\n", |
| 624 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n", |
| 625 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] random_seed -> 23141.\n", |
| 626 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] output_prefix -> unet_3d.\n", |
| 627 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] output_size -> (256, 256, 128).\n", |
| 628 | + "[2025-03-11 22:17:20.466][ INFO](inference) - [config] out_spacing -> (1.0, 1.0, 0.75).\n", |
| 629 | + "[2025-03-11 22:17:20.466][ INFO](root) - `controllable_anatomy_size` is not provided.\n", |
| 630 | + "[2025-03-11 22:17:23.065][ INFO](inference) - checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n", |
| 631 | + "[2025-03-11 22:17:23.067][ INFO](inference) - scale_factor -> 1.159977912902832.\n", |
| 632 | + "[2025-03-11 22:17:23.068][ INFO](inference) - num_downsample_level -> 4, divisor -> 4.\n", |
| 633 | + "[2025-03-11 22:17:23.070][ INFO](inference) - noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n", |
635 | 634 | "\n", |
636 | 635 | " 0%| | 0/10 [00:00<?, ?it/s]\n", |
637 | | - " 10%|█ | 1/10 [00:00<00:05, 1.78it/s]\n", |
638 | | - " 60%|██████ | 6/10 [00:00<00:00, 11.19it/s]\n", |
639 | | - "100%|██████████| 10/10 [00:00<00:00, 12.88it/s]\n", |
640 | | - "[2025-03-11 22:05:48.356][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed7854_size256x256x128_spacing1.00x1.00x0.75_20250311220547_rank0.nii.gz.\n", |
| 636 | + " 10%|█ | 1/10 [00:00<00:07, 1.24it/s]\n", |
| 637 | + " 60%|██████ | 6/10 [00:00<00:00, 8.37it/s]\n", |
| 638 | + "100%|██████████| 10/10 [00:01<00:00, 9.78it/s]\n", |
| 639 | + "[2025-03-11 22:17:25.828][ INFO](inference) - Saved ./temp_work_dir/./predictions/unet_3d_seed23141_size256x256x128_spacing1.00x1.00x0.75_20250311221725_rank0.nii.gz.\n", |
641 | 640 | "\n" |
642 | 641 | ] |
643 | 642 | } |
|