@@ -390,7 +390,7 @@ def main(
390390
391391 This function loads two datasets (real vs. synthetic) in 3D medical format (NIfTI)
392392 and extracts feature maps via a 2.5D approach, then computes the Frechet Inception
393- Distance (FID) across three orthogonal planes. Data parallelism is implemented
393+ Distance (FID) across three orthogonal planes. Data parallelism is implemented
394394 using torch.distributed with an NCCL backend.
395395
396396 Args:
@@ -406,7 +406,7 @@ def main(
406406 ...
407407 These entries will be appended to `real_dataset_root`.
408408 real_features_dir (str):
409- Name of the directory under `output_root` in which to store
409+ Name of the directory under `output_root` in which to store
410410 extracted features for the real dataset.
411411
412412 synth_dataset_root (str):
@@ -420,7 +420,7 @@ def main(
420420 ...
421421 These entries will be appended to `synth_dataset_root`.
422422 synth_features_dir (str):
423- Name of the directory under `output_root` in which to store
423+ Name of the directory under `output_root` in which to store
424424 extracted features for the synthetic dataset.
425425
426426 enable_center_slices_ratio (float or None):
@@ -500,14 +500,12 @@ def main(
500500 # -------------------------------------------------------------------------
501501 if model_name == "radimagenet_resnet50" :
502502 feature_network = torch .hub .load (
503- "Warvito/radimagenet-models" ,
504- model = "radimagenet_resnet50" ,
505- verbose = True ,
506- trust_repo = True
503+ "Warvito/radimagenet-models" , model = "radimagenet_resnet50" , verbose = True , trust_repo = True
507504 )
508505 suffix = "radimagenet_resnet50"
509506 else :
510507 import torchvision
508+
511509 feature_network = torchvision .models .squeezenet1_1 (pretrained = True )
512510 suffix = "squeezenet1_1"
513511
@@ -545,10 +543,7 @@ def main(
545543
546544 real_filenames = [{"image" : os .path .join (real_dataset_root , f )} for f in real_lines ]
547545 real_filenames = monai .data .partition_dataset (
548- data = real_filenames ,
549- shuffle = False ,
550- num_partitions = world_size ,
551- even_divisible = False
546+ data = real_filenames , shuffle = False , num_partitions = world_size , even_divisible = False
552547 )[local_rank ]
553548
554549 # -------------------------------------------------------------------------
@@ -562,10 +557,7 @@ def main(
562557
563558 synth_filenames = [{"image" : os .path .join (synth_dataset_root , f )} for f in synth_lines ]
564559 synth_filenames = monai .data .partition_dataset (
565- data = synth_filenames ,
566- shuffle = False ,
567- num_partitions = world_size ,
568- even_divisible = False
560+ data = synth_filenames , shuffle = False , num_partitions = world_size , even_divisible = False
569561 )[local_rank ]
570562
571563 # -------------------------------------------------------------------------
@@ -578,23 +570,15 @@ def main(
578570 ]
579571
580572 if enable_resampling :
581- transform_list .append (
582- monai .transforms .Spacingd (
583- keys = ["image" ], pixdim = rs_spacing_tuple , mode = ["bilinear" ]
584- )
585- )
573+ transform_list .append (monai .transforms .Spacingd (keys = ["image" ], pixdim = rs_spacing_tuple , mode = ["bilinear" ]))
586574
587575 if enable_padding :
588576 transform_list .append (
589- monai .transforms .SpatialPadd (
590- keys = ["image" ], spatial_size = target_shape_tuple , mode = "constant" , value = - 1000
591- )
577+ monai .transforms .SpatialPadd (keys = ["image" ], spatial_size = target_shape_tuple , mode = "constant" , value = - 1000 )
592578 )
593579
594580 if enable_center_cropping :
595- transform_list .append (
596- monai .transforms .CenterSpatialCropd (keys = ["image" ], roi_size = target_shape_tuple )
597- )
581+ transform_list .append (monai .transforms .CenterSpatialCropd (keys = ["image" ], roi_size = target_shape_tuple ))
598582
599583 transform_list .append (
600584 monai .transforms .ScaleIntensityRanged (
@@ -638,9 +622,7 @@ def main(
638622 center_slices_ratio = center_slices_ratio_final ,
639623 xy_only = False ,
640624 )
641- logger .info (
642- f"feats shapes: { feats [0 ].shape } , { feats [1 ].shape } , { feats [2 ].shape } "
643- )
625+ logger .info (f"feats shapes: { feats [0 ].shape } , { feats [1 ].shape } , { feats [2 ].shape } " )
644626 torch .save (feats , out_fp )
645627
646628 real_features_xy .append (feats [0 ])
@@ -651,8 +633,7 @@ def main(
651633 real_features_yz = torch .vstack (real_features_yz )
652634 real_features_zx = torch .vstack (real_features_zx )
653635 logger .info (
654- f"Real feature shapes: { real_features_xy .shape } , "
655- f"{ real_features_yz .shape } , { real_features_zx .shape } "
636+ f"Real feature shapes: { real_features_xy .shape } , " f"{ real_features_yz .shape } , { real_features_zx .shape } "
656637 )
657638
658639 # -------------------------------------------------------------------------
@@ -681,9 +662,7 @@ def main(
681662 center_slices_ratio = center_slices_ratio_final ,
682663 xy_only = False ,
683664 )
684- logger .info (
685- f"feats shapes: { feats [0 ].shape } , { feats [1 ].shape } , { feats [2 ].shape } "
686- )
665+ logger .info (f"feats shapes: { feats [0 ].shape } , { feats [1 ].shape } , { feats [2 ].shape } " )
687666 torch .save (feats , out_fp )
688667
689668 synth_features_xy .append (feats [0 ])
@@ -694,8 +673,7 @@ def main(
694673 synth_features_yz = torch .vstack (synth_features_yz )
695674 synth_features_zx = torch .vstack (synth_features_zx )
696675 logger .info (
697- f"Synth feature shapes: { synth_features_xy .shape } , "
698- f"{ synth_features_yz .shape } , { synth_features_zx .shape } "
676+ f"Synth feature shapes: { synth_features_xy .shape } , " f"{ synth_features_yz .shape } , { synth_features_zx .shape } "
699677 )
700678
701679 # -------------------------------------------------------------------------
0 commit comments