@@ -419,6 +419,7 @@ def main(
419419 suffix = "radimagenet_resnet50"
420420 else :
421421 import torchvision
422+
422423 feature_network = torchvision .models .squeezenet1_1 (pretrained = True )
423424 suffix = "squeezenet1_1"
424425
@@ -529,10 +530,7 @@ def main(
529530 center_slices_ratio = center_slices_ratio_final ,
530531 xy_only = False ,
531532 )
532- logger .info (
533- f"feats shapes: { feats [0 ].shape } , "
534- f"{ feats [1 ].shape } , { feats [2 ].shape } "
535- )
533+ logger .info (f"feats shapes: { feats [0 ].shape } , " f"{ feats [1 ].shape } , { feats [2 ].shape } " )
536534 torch .save (feats , out_fp )
537535
538536 real_features_xy .append (feats [0 ])
@@ -543,8 +541,7 @@ def main(
543541 real_features_yz = torch .vstack (real_features_yz )
544542 real_features_zx = torch .vstack (real_features_zx )
545543 logger .info (
546- f"Real feature shapes: { real_features_xy .shape } , "
547- f"{ real_features_yz .shape } , { real_features_zx .shape } "
544+ f"Real feature shapes: { real_features_xy .shape } , " f"{ real_features_yz .shape } , { real_features_zx .shape } "
548545 )
549546
550547 # -------------------------------------------------------------------------
@@ -573,10 +570,7 @@ def main(
573570 center_slices_ratio = center_slices_ratio_final ,
574571 xy_only = False ,
575572 )
576- logger .info (
577- f"feats shapes: { feats [0 ].shape } , "
578- f"{ feats [1 ].shape } , { feats [2 ].shape } "
579- )
573+ logger .info (f"feats shapes: { feats [0 ].shape } , " f"{ feats [1 ].shape } , { feats [2 ].shape } " )
580574 torch .save (feats , out_fp )
581575
582576 synth_features_xy .append (feats [0 ])
@@ -587,8 +581,7 @@ def main(
587581 synth_features_yz = torch .vstack (synth_features_yz )
588582 synth_features_zx = torch .vstack (synth_features_zx )
589583 logger .info (
590- f"Synthetic feature shapes: { synth_features_xy .shape } , "
591- f"{ synth_features_yz .shape } , { synth_features_zx .shape } "
584+ f"Synthetic feature shapes: { synth_features_xy .shape } , " f"{ synth_features_yz .shape } , { synth_features_zx .shape } "
592585 )
593586
594587 # -------------------------------------------------------------------------
@@ -640,12 +633,8 @@ def main(
640633 synth_yz = torch .vstack (all_tensors_list [4 ])
641634 synth_zx = torch .vstack (all_tensors_list [5 ])
642635
643- logger .info (
644- f"Final Real shapes: { real_xy .shape } , { real_yz .shape } , { real_zx .shape } "
645- )
646- logger .info (
647- f"Final Synth shapes: { synth_xy .shape } , { synth_yz .shape } , { synth_zx .shape } "
648- )
636+ logger .info (f"Final Real shapes: { real_xy .shape } , { real_yz .shape } , { real_zx .shape } " )
637+ logger .info (f"Final Synth shapes: { synth_xy .shape } , { synth_yz .shape } , { synth_zx .shape } " )
649638
650639 fid = FIDMetric ()
651640 logger .info (f"Computing FID for: { output_root0 } | { output_root1 } " )
0 commit comments