diff --git a/references/detection/engine.py b/references/detection/engine.py index 0e9bfffdf8a..fa0f4fe01db 100644 --- a/references/detection/engine.py +++ b/references/detection/engine.py @@ -27,7 +27,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc for images, targets in metric_logger.log_every(data_loader, print_freq, header): images = list(image.to(device) for image in images) targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets] - with torch.cuda.amp.autocast(enabled=scaler is not None): + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=scaler is not None): loss_dict = model(images, targets) losses = sum(loss for loss in loss_dict.values())