|
610 | 610 | "source": [ |
611 | 611 | "# Automatic mixed precision (AMP) for faster training\n", |
612 | 612 | "amp_enabled = True\n", |
613 | | - "scaler = torch.cuda.amp.GradScaler()\n", |
| 613 | + "scaler = torch.GradScaler(\"cuda\")\n", |
614 | 614 | "\n", |
615 | 615 | "# Tensorboard\n", |
616 | 616 | "if do_save:\n", |
|
646 | 646 | "\n", |
647 | 647 | " # Forward pass and loss\n", |
648 | 648 | " optimizer.zero_grad()\n", |
649 | | - " with torch.cuda.amp.autocast(enabled=amp_enabled):\n", |
| 649 | + " with torch.autocast(\"cuda\", enabled=amp_enabled):\n", |
650 | 650 | " ddf_image, pred_image, pred_label_one_hot = forward(\n", |
651 | 651 | " fixed_image, moving_image, moving_label, model, warp_layer, num_classes=4\n", |
652 | 652 | " )\n", |
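The hunks above replace the deprecated `torch.cuda.amp` entry points with the device-generic AMP API. Below is a minimal sketch of the resulting training-step pattern; it assumes PyTorch 2.4 or later, and `model`, `optimizer`, `loss_fn`, `inputs`, and `targets` are placeholder names rather than objects from the notebook.

```python
import torch

amp_enabled = True
# Device-generic scaler, matching the form used in the diff above
scaler = torch.GradScaler("cuda")

def train_step(model, optimizer, loss_fn, inputs, targets):
    optimizer.zero_grad()
    # Run the forward pass and loss in reduced precision where it is safe to do so
    with torch.autocast("cuda", enabled=amp_enabled):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    # Scale the loss so fp16 gradients do not underflow, then step and update the scale
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```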
|
694 | 694 | " # moving_label_35 = batch_data[\"moving_label_35\"].to(device)\n", |
695 | 695 | " n_steps += 1\n", |
696 | 696 | " # Infer\n", |
697 | | - " with torch.cuda.amp.autocast(enabled=amp_enabled):\n", |
| 697 | + " with torch.autocast(\"cuda\", enabled=amp_enabled):\n", |
698 | 698 | " ddf_image, pred_image, pred_label_one_hot = forward(\n", |
699 | 699 | " fixed_image, moving_image, moving_label_4, model, warp_layer, num_classes=4\n", |
700 | 700 | " )\n", |
|
840 | 840 | " model = VoxelMorph()\n", |
841 | 841 | " # load model weights\n", |
842 | 842 | " filename_best_model = glob.glob(os.path.join(dir_load, \"voxelmorph_loss_best_dice_*\"))[0]\n", |
843 | | - " model.load_state_dict(torch.load(filename_best_model))\n", |
| 843 | + " model.load_state_dict(torch.load(filename_best_model, weights_only=True))\n", |
844 | 844 | " # to GPU\n", |
845 | 845 | " model.to(device)\n", |
846 | 846 | "\n", |
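The `weights_only=True` argument added above restricts `torch.load` to deserializing tensors and primitive containers instead of running the full pickle machinery, so a tampered checkpoint cannot execute arbitrary code at load time. A minimal sketch of the same pattern, assuming the `VoxelMorph` network comes from `monai.networks.nets` and using a placeholder checkpoint path:

```python
import torch
from monai.networks.nets import VoxelMorph  # assumption: the notebook's VoxelMorph class

model = VoxelMorph()
# weights_only=True limits unpickling to tensors and basic Python types,
# guarding against untrusted checkpoint files.
state_dict = torch.load("voxelmorph_loss_best_dice_placeholder.pt", weights_only=True)
model.load_state_dict(state_dict)
model.to("cuda")
```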
|
860 | 860 | "# Forward pass\n", |
861 | 861 | "model.eval()\n", |
862 | 862 | "with torch.no_grad():\n", |
863 | | - " with torch.cuda.amp.autocast(enabled=amp_enabled):\n", |
| 863 | + " with torch.autocast(\"cuda\", enabled=amp_enabled):\n", |
864 | 864 | " ddf_image, pred_image, pred_label_one_hot = forward(\n", |
865 | 865 | " fixed_image, moving_image, moving_label_35, model, warp_layer, num_classes=35\n", |
866 | 866 | " )" |
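At inference time only the autocast context is needed; `GradScaler` exists solely to rescale gradients during backpropagation, so it drops out once gradients are disabled. A minimal sketch, with `model` and `inputs` as placeholders for the objects built earlier in the notebook:

```python
import torch

amp_enabled = True
model.eval()
# No GradScaler here: with gradients disabled there is nothing to scale.
with torch.no_grad(), torch.autocast("cuda", enabled=amp_enabled):
    outputs = model(inputs)
```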
|