
Commit df0af78 (1 parent: 8d5eb16)

add device type in autocast and GradScaler

Signed-off-by: YunLiu <[email protected]>

File tree: 25 files changed, +535 −533 lines

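The same mechanical change repeats across all 25 files: the zero-argument GradScaler() and autocast() calls gain an explicit device-type string. A minimal sketch of the pattern, assuming a PyTorch version where both classes are importable from the top-level torch namespace as in the diffs below (model, inputs, targets, and loss_fn are placeholder names, not from this commit):

import torch
from torch import autocast, GradScaler

# Before: device type left implicit (deprecated on recent PyTorch):
#   scaler = GradScaler()
#   with autocast(): ...
# After: the device type is named explicitly.
scaler = GradScaler("cuda")

with autocast("cuda"):                # ops inside run in mixed precision
    outputs = model(inputs)           # placeholder model / inputs
    loss = loss_fn(outputs, targets)  # placeholder loss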

automl/DiNTS/search_dints.py — 3 additions, 3 deletions

@@ -431,7 +431,7 @@ def main():
     if amp:
         from torch import autocast, GradScaler
 
-        scaler = GradScaler()
+        scaler = GradScaler("cuda")
         if dist.get_rank() == 0:
             print("[info] amp enabled")
 
@@ -487,7 +487,7 @@ def main():
             optimizer.zero_grad()
 
             if amp:
-                with autocast():
+                with autocast("cuda"):
                     outputs = model(inputs)
                     if output_classes == 2:
                         loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
@@ -559,7 +559,7 @@ def main():
             combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)
 
             if amp:
-                with autocast():
+                with autocast("cuda"):
                     outputs_search = model(inputs_search)
                     if output_classes == 2:
                         loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
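The scaler created in the first hunk drives the rest of the AMP update; a condensed sketch of the step these hunks sit inside, assuming the usual scale/step/update wiring (only model, inputs, labels, loss_func, optimizer, and amp appear in the diff — the backward half is inferred, not shown in this commit):

optimizer.zero_grad()
if amp:
    with autocast("cuda"):
        outputs = model(inputs)
        loss = loss_func(outputs, labels)
    scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)         # unscale grads, skip the step if any are inf/nan
    scaler.update()                # adapt the scale factor for the next iteration
else:
    outputs = model(inputs)
    loss = loss_func(outputs, labels)
    loss.backward()
    optimizer.step()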

automl/DiNTS/train_dints.py — 2 additions, 2 deletions

@@ -408,7 +408,7 @@ def main():
     if amp:
         from torch import autocast, GradScaler
 
-        scaler = GradScaler()
+        scaler = GradScaler("cuda")
         if dist.get_rank() == 0:
             print("[info] amp enabled")
 
@@ -450,7 +450,7 @@ def main():
                 param.grad = None
 
             if amp:
-                with autocast():
+                with autocast("cuda"):
                     outputs = model(inputs)
                     if output_classes == 2:
                         loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)

competitions/MICCAI/surgtoolloc/classification_files/train.py — 3 additions, 3 deletions

@@ -85,7 +85,7 @@ def main(cfg):
     metric = ConfusionMatrixMetric(metric_name="F1", reduction="mean_batch")
 
     # set other tools
-    scaler = GradScaler()
+    scaler = GradScaler("cuda")
     writer = SummaryWriter(str(cfg.output_dir + f"/fold{cfg.fold}/"))
 
     # train and val loop
@@ -171,11 +171,11 @@ def run_train(
         torch.set_grad_enabled(True)
         if torch.rand(1) > 0.5:
             inputs, labels_a, labels_b, lam = mixup_data(inputs, labels)
-            with autocast():
+            with autocast("cuda"):
                 outputs = model(inputs)
                 loss = lam * loss_function(outputs, labels_a) + (1 - lam) * loss_function(outputs, labels_b)
         else:
-            with autocast():
+            with autocast("cuda"):
                 outputs = model(inputs)
                 loss = loss_function(outputs, labels)
         losses.append(loss.item())
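The mixup_data helper called above lives elsewhere in the file; its internals are not part of this diff. A hypothetical sketch of what such a helper typically does, consistent with the call site (the Beta(alpha, alpha) draw and the permutation strategy are assumptions):

import numpy as np
import torch

def mixup_data(inputs: torch.Tensor, labels: torch.Tensor, alpha: float = 1.0):
    # Hypothetical mixup helper: blend random pairs of samples in the batch.
    lam = np.random.beta(alpha, alpha)                # mixing coefficient in [0, 1]
    index = torch.randperm(inputs.size(0), device=inputs.device)
    mixed = lam * inputs + (1 - lam) * inputs[index]  # convex combination of inputs
    return mixed, labels, labels[index], lam          # both label sets for the mixed loss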

competitions/kaggle/RANZCR/4th_place_solution/train.py — 4 additions, 4 deletions

@@ -83,7 +83,7 @@ def main(cfg):
 
     # set other tools
     if cfg.mixed_precision:
-        scaler = GradScaler()
+        scaler = GradScaler("cuda")
     else:
         scaler = None
 
@@ -168,7 +168,7 @@ def run_train(
         torch.set_grad_enabled(True)
 
         if cfg.mixed_precision:
-            with autocast():
+            with autocast("cuda"):
                 output_dict = model(batch)
         else:
            output_dict = model(batch)
@@ -210,7 +210,7 @@ def run_eval(model, val_dataloader, cfg, writer, epoch):
     for batch in val_dataloader:
         batch = cfg.to_device_transform(batch)
         if cfg.mixed_precision:
-            with autocast():
+            with autocast("cuda"):
                 output = model(batch)
         else:
             output = model(batch)
@@ -271,7 +271,7 @@ def run_infer(weights_folder_path, cfg):
         batch = to_device_transform(batch)
         for i, net in enumerate(nets):
             if cfg.mixed_precision:
-                with autocast():
+                with autocast("cuda"):
                     logits = net(batch)["logits"].cpu().numpy()
             else:
                 logits = net(batch)["logits"].cpu().numpy()
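Note that the run_eval and run_infer hunks show autocast("cuda") used outside training, where it shrinks the forward-pass memory footprint without any scaler involvement. A sketch of that evaluation pattern under an assumed torch.no_grad() guard (the guard is conventional; this commit does not show it):

model.eval()
with torch.no_grad():              # no gradient tracking at eval time
    for batch in val_dataloader:
        batch = cfg.to_device_transform(batch)
        if cfg.mixed_precision:
            with autocast("cuda"):  # mixed-precision forward only
                output = model(batch)
        else:
            output = model(batch)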

generation/2d_ddpm/2d_ddpm_inpainting.ipynb — 1 addition, 1 deletion

@@ -476,7 +476,7 @@
     "    epoch_loss_list = []\n",
     "    val_epoch_loss_list = []\n",
     "\n",
-    "    scaler = GradScaler()\n",
+    "    scaler = GradScaler(\"cuda\")\n",
     "    total_start = time.time()\n",
     "    for epoch in range(max_epochs):\n",
     "        model.train()\n",

generation/2d_ddpm/2d_ddpm_tutorial.ipynb — 1 addition, 1 deletion

@@ -494,7 +494,7 @@
     "    epoch_loss_list = []\n",
     "    val_epoch_loss_list = []\n",
     "\n",
-    "    scaler = GradScaler()\n",
+    "    scaler = GradScaler(\"cuda\")\n",
     "    total_start = time.time()\n",
     "    for epoch in range(max_epochs):\n",
     "        model.train()\n",

generation/2d_ddpm/2d_ddpm_tutorial_v_prediction.ipynb — 1 addition, 1 deletion

@@ -466,7 +466,7 @@
     "epoch_loss_list = []\n",
     "val_epoch_loss_list = []\n",
     "\n",
-    "scaler = GradScaler()\n",
+    "scaler = GradScaler(\"cuda\")\n",
     "total_start = time.time()\n",
     "for epoch in range(max_epochs):\n",
     "    model.train()\n",

generation/2d_ldm/2d_ldm_tutorial.ipynb — 3 additions, 3 deletions

@@ -401,8 +401,8 @@
     "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=5e-4)\n",
     "\n",
     "# For mixed precision training\n",
-    "scaler_g = GradScaler()\n",
-    "scaler_d = GradScaler()"
+    "scaler_g = GradScaler(\"cuda\")\n",
+    "scaler_d = GradScaler(\"cuda\")"
 ]
 },
 {
@@ -751,7 +751,7 @@
     "val_interval = 40\n",
     "epoch_losses = []\n",
     "val_losses = []\n",
-    "scaler = GradScaler()\n",
+    "scaler = GradScaler(\"cuda\")\n",
     "\n",
     "for epoch in range(max_epochs):\n",
     "    unet.train()\n",

generation/2d_ldm/train_diffusion.py — 1 addition, 1 deletion

@@ -182,7 +182,7 @@ def main():
     max_epochs = args.diffusion_train["max_epochs"]
     val_interval = args.diffusion_train["val_interval"]
     autoencoder.eval()
-    scaler = GradScaler()
+    scaler = GradScaler("cuda")
     total_step = 0
     best_val_recon_epoch_loss = 100.0
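Beyond the device string, torch.autocast also takes an optional dtype and an enabled flag, and GradScaler accepts enabled too, which makes the device argument easy to parameterize. A sketch of a portable setup, assuming a CPU fallback that this commit does not itself add:

import torch
from torch import autocast, GradScaler

device_type = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

# Loss scaling only matters for fp16 on CUDA; disable it elsewhere.
scaler = GradScaler(device_type, enabled=(device_type == "cuda"))

with autocast(device_type, dtype=amp_dtype):
    ...  # forward pass and loss computation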

generation/2d_super_resolution/2d_sd_super_resolution.ipynb — 3 additions, 3 deletions

@@ -407,8 +407,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-    "scaler_g = GradScaler()\n",
-    "scaler_d = GradScaler()"
+    "scaler_g = GradScaler(\"cuda\")\n",
+    "scaler_d = GradScaler(\"cuda\")"
 ]
 },
 {
@@ -973,7 +973,7 @@
     "# Optimizers\n",
     "optimizer = torch.optim.Adam(unet.parameters(), lr=5e-5)\n",
     "\n",
-    "scaler_diffusion = GradScaler()\n",
+    "scaler_diffusion = GradScaler(\"cuda\")\n",
     "\n",
     "max_epochs = 200\n",
     "val_interval = 20\n",
