Skip to content

Commit

Permalink
Changed from torch.cuda.amp.autocast to torch.amp.autocast
Browse files Browse the repository at this point in the history
torch.cuda.amp.autocast to be deprecated
  • Loading branch information
jamesmuking5 committed Jun 30, 2024
1 parent bf01bab commit b367f2e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion references/detection/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
for images, targets in metric_logger.log_every(data_loader, print_freq, header):
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
with torch.cuda.amp.autocast(enabled=scaler is not None):
with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=scaler is not None):
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())

Expand Down

0 comments on commit b367f2e

Please sign in to comment.