|
78 | 78 | }, |
79 | 79 | { |
80 | 80 | "cell_type": "code", |
81 | | - "execution_count": 69, |
| 81 | + "execution_count": null, |
82 | 82 | "metadata": { |
83 | 83 | "colab": { |
84 | 84 | "base_uri": "https://localhost:8080/" |
|
94 | 94 | " Compose,\n", |
95 | 95 | " LoadImageD,\n", |
96 | 96 | " ScaleIntensityd,\n", |
97 | | - " RandGaussianNoiseD, \n", |
98 | | - " RandGaussianSmoothD, \n", |
99 | | - " RandAdjustContrastD, \n", |
| 97 | + " RandGaussianNoiseD,\n", |
| 98 | + " RandGaussianSmoothD,\n", |
100 | 99 | ")\n", |
101 | | - "import numpy as np\n", |
102 | 100 | "from monai.data import DataLoader, Dataset, CacheDataset\n", |
103 | 101 | "from monai.config import print_config\n", |
104 | 102 | "from monai.networks.nets.restormer import Restormer\n", |
105 | 103 | "from monai.apps import MedNISTDataset\n", |
106 | 104 | "\n", |
107 | | - "import numpy as np\n", |
108 | 105 | "import torch\n", |
109 | 106 | "from monai.losses import SSIMLoss\n", |
110 | 107 | "import matplotlib.pyplot as plt\n", |
111 | 108 | "import os\n", |
112 | 109 | "import tempfile\n", |
113 | 110 | "\n", |
114 | | - "from tqdm.notebook import tqdm\n", |
115 | | - "\n", |
116 | 111 | "\n", |
117 | | - "#print_config()\n", |
118 | | - "#set_determinism(42)" |
| 112 | + "print_config()\n", |
| 113 | + "set_determinism(42)" |
119 | 114 | ] |
120 | 115 | }, |
121 | 116 | { |
|
361 | 356 | }, |
362 | 357 | { |
363 | 358 | "cell_type": "code", |
364 | | - "execution_count": 70, |
| 359 | + "execution_count": null, |
365 | 360 | "metadata": { |
366 | 361 | "id": "zHAj8nuHXG-D", |
367 | 362 | "outputId": "462d37f3-b59e-4d88-ca18-60224f69076d" |
|
374 | 369 | " device = torch.device(\"mps\")\n", |
375 | 370 | "else:\n", |
376 | 371 | " device = torch.device(\"cpu\")\n", |
377 | | - " \n", |
| 372 | + "\n", |
378 | 373 | "model = Restormer(\n", |
379 | 374 | " spatial_dims=2,\n", |
380 | 375 | " in_channels=1,\n", |
|
399 | 394 | }, |
400 | 395 | { |
401 | 396 | "cell_type": "code", |
402 | | - "execution_count": 72, |
| 397 | + "execution_count": null, |
403 | 398 | "metadata": { |
404 | 399 | "id": "eyiL4ccmYsjt" |
405 | 400 | }, |
|
432 | 427 | } |
433 | 428 | ], |
434 | 429 | "source": [ |
435 | | - "max_epochs = 20\n", |
| 430 | + "max_epochs = 2\n", |
436 | 431 | "epoch_loss_values = []\n", |
437 | 432 | "\n", |
438 | 433 | "\n", |
|
448 | 443 | " moving = batch_data[\"moving_hand\"].to(device)\n", |
449 | 444 | " fixed = batch_data[\"fixed_hand\"].to(device)\n", |
450 | 445 | " pred_image = model(moving)\n", |
451 | | - " pred_image=torch.sigmoid(pred_image)\n", |
| 446 | + " pred_image = torch.sigmoid(pred_image)\n", |
452 | 447 | "\n", |
453 | 448 | " loss = image_loss(input=pred_image, target=fixed)\n", |
454 | 449 | " loss.backward()\n", |
|
514 | 509 | }, |
515 | 510 | { |
516 | 511 | "cell_type": "code", |
517 | | - "execution_count": 76, |
| 512 | + "execution_count": null, |
518 | 513 | "metadata": { |
519 | 514 | "colab": { |
520 | 515 | "base_uri": "https://localhost:8080/" |
|
534 | 529 | "source": [ |
535 | 530 | "val_ds = CacheDataset(data=training_datadict[2000:2500], transform=train_transforms, cache_rate=1.0, num_workers=0)\n", |
536 | 531 | "val_loader = DataLoader(val_ds, batch_size=16, num_workers=0)\n", |
537 | | - "model.eval() # Set model to evaluation mode\n", |
| 532 | + "model.eval() # Set model to evaluation mode\n", |
538 | 533 | "\n", |
539 | | - "with torch.no_grad(): # Disable gradient calculation for inference\n", |
| 534 | + "with torch.no_grad(): # Disable gradient calculation for inference\n", |
540 | 535 | " for batch_data in val_loader:\n", |
541 | 536 | " moving = batch_data[\"moving_hand\"].to(device)\n", |
542 | 537 | " fixed = batch_data[\"fixed_hand\"].to(device)\n", |
543 | 538 | " # Pass only the moving image, consistent with training\n", |
544 | 539 | " pred_image = model(moving)\n", |
545 | 540 | " pred_image = torch.sigmoid(pred_image)\n", |
546 | | - " break # Process only the first batch for visualization\n", |
| 541 | + " break # Process only the first batch for visualization\n", |
547 | 542 | "\n", |
548 | 543 | "fixed_image = fixed.detach().cpu().numpy()[:, 0]\n", |
549 | 544 | "moving_image = moving.detach().cpu().numpy()[:, 0]\n", |
|
596 | 591 | "plt.axis(\"off\")\n", |
597 | 592 | "plt.show()" |
598 | 593 | ] |
599 | | - }, |
600 | | - { |
601 | | - "cell_type": "code", |
602 | | - "execution_count": null, |
603 | | - "metadata": {}, |
604 | | - "outputs": [], |
605 | | - "source": [] |
606 | 594 | } |
607 | 595 | ], |
608 | 596 | "metadata": { |
|
0 commit comments