diff --git a/src/membrain_seg/segmentation/segment.py b/src/membrain_seg/segmentation/segment.py index 93802ed..f5ccda0 100644 --- a/src/membrain_seg/segmentation/segment.py +++ b/src/membrain_seg/segmentation/segment.py @@ -78,8 +78,9 @@ def segment( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Initialize the model and load trained weights from checkpoint - pl_model = SemanticSegmentationUnet.load_from_checkpoint(model_checkpoint, map_location=device, strict=False) - + pl_model = SemanticSegmentationUnet.load_from_checkpoint( + model_checkpoint, map_location=device, strict=False + ) pl_model.to(device) # Preprocess the new data