1- #!/usr/bin/env python
21# Copyright (c) MONAI Consortium
32# Licensed under the Apache License, Version 2.0 (the "License");
43# you may not use this file except in compliance with the License.
1312# and limitations under the License.
1413
1514"""
16- Compute 2.5D FID using distributed GPU processing, **without** external fid_utils dependencies .
15+ Compute 2.5D FID using distributed GPU processing.
1716
1817SHELL Usage Example:
1918-------------------
2221 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6
2322 NUM_GPUS=7
2423
25- torchrun --nproc_per_node=${NUM_GPUS} compute_fid2p5d_ct.py \
24+ torchrun --nproc_per_node=${NUM_GPUS} compute_fid_2-5d_ct.py \
2625 --model_name "radimagenet_resnet50" \
2726 --data0_dataroot "path/to/datasetA" \
2827 --data0_filelist "path/to/filelistA.txt" \
8281from monai .metrics .fid import FIDMetric
8382from monai .transforms import Compose
8483
84+ import logging
85+
8586# ------------------------------------------------------------------------------
86- # Below are the core utilities originally in fid_utils.py, now inlined here
87- # to remove external dependency.
87+ # Create logger
8888# ------------------------------------------------------------------------------
89+ logger = logging .getLogger ("fid_2-5d_ct" )
90+ if not logger .handlers :
91+ # Configure logger only if it has no handlers (avoid reconfiguring in multi-rank scenarios)
92+ logging .basicConfig (stream = sys .stdout , level = logging .INFO )
93+ logger .setLevel (logging .INFO )
8994
9095
9196def drop_empty_slice (slices , empty_threshold : float ):
@@ -111,7 +116,7 @@ def drop_empty_slice(slices, empty_threshold: float):
111116 else :
112117 outputs .append (True )
113118
114- print (f"Empty slice drop rate { round ((n_drop / len (slices ))* 100 ,1 )} %" )
119+ logger . info (f"Empty slice drop rate { round ((n_drop / len (slices ))* 100 ,1 )} %" )
115120 return outputs
116121
117122
@@ -183,7 +188,7 @@ def radimagenet_intensity_normalisation(volume: torch.Tensor, norm2d: bool = Fal
183188 volume (torch.Tensor): Input (B, C, H, W) or (B, C, H, W, D).
184189 norm2d (bool): If True, normalizes each (H,W) slice to [0,1], then subtracts the ImageNet mean.
185190 """
186- print (f"norm2d: { norm2d } " )
191+ logger . info (f"norm2d: { norm2d } " )
187192 dim = len (volume .shape )
188193 # If norm2d is True, only meaningful for 4D data (B, C, H, W):
189194 if dim == 4 and norm2d :
@@ -236,20 +241,18 @@ def get_features_2p5d(
236241 Returns:
237242 tuple of torch.Tensor or None: (XY_features, YZ_features, ZX_features).
238243 """
239- print (f"center_slices: { center_slices } , ratio: { center_slices_ratio } " )
244+ logger . info (f"center_slices: { center_slices } , ratio: { center_slices_ratio } " )
240245
241246 # If there's only 1 channel, replicate to 3 channels
242247 if image .shape [1 ] == 1 :
243248 image = image .repeat (1 , 3 , 1 , 1 , 1 )
244249
245- # Convert from 'RGB'→(R,G,B) to (B, G, R) ordering
250+ # Convert from 'RGB'→(R,G,B) to (B,G,R)
246251 image = image [:, [2 , 1 , 0 ], ...]
247252
248253 B , C , H , W , D = image .size ()
249254 with torch .no_grad ():
250- # ---------------------------------------------------------------------
251- # 1) XY-plane slicing along D
252- # ---------------------------------------------------------------------
255+ # ---------------------- XY-plane slicing along D ----------------------
253256 if center_slices :
254257 start_d = int ((1.0 - center_slices_ratio ) / 2.0 * D )
255258 end_d = int ((1.0 + center_slices_ratio ) / 2.0 * D )
@@ -268,13 +271,10 @@ def get_features_2p5d(
268271
269272 feature_image_xy = feature_network .forward (images_2d )
270273 feature_image_xy = spatial_average (feature_image_xy , keepdim = False )
271-
272274 if xy_only :
273275 return feature_image_xy , None , None
274276
275- # ---------------------------------------------------------------------
276- # 2) YZ-plane slicing along H
277- # ---------------------------------------------------------------------
277+ # ---------------------- YZ-plane slicing along H ----------------------
278278 if center_slices :
279279 start_h = int ((1.0 - center_slices_ratio ) / 2.0 * H )
280280 end_h = int ((1.0 + center_slices_ratio ) / 2.0 * H )
@@ -294,9 +294,7 @@ def get_features_2p5d(
294294 feature_image_yz = feature_network .forward (images_2d )
295295 feature_image_yz = spatial_average (feature_image_yz , keepdim = False )
296296
297- # ---------------------------------------------------------------------
298- # 3) ZX-plane slicing along W
299- # ---------------------------------------------------------------------
297+ # ---------------------- ZX-plane slicing along W ----------------------
300298 if center_slices :
301299 start_w = int ((1.0 - center_slices_ratio ) / 2.0 * W )
302300 end_w = int ((1.0 + center_slices_ratio ) / 2.0 * W )
@@ -319,11 +317,6 @@ def get_features_2p5d(
319317 return feature_image_xy , feature_image_yz , feature_image_zx
320318
321319
322- # ------------------------------------------------------------------------------
323- # End inline fid_utils code
324- # ------------------------------------------------------------------------------
325-
326-
327320def pad_to_max_size (tensor : torch .Tensor , max_size : int , padding_value : float = 0.0 ) -> torch .Tensor :
328321 """
329322 Zero-pad a 2D feature map or other tensor along the first dimension to match a specified size.
@@ -336,7 +329,6 @@ def pad_to_max_size(tensor: torch.Tensor, max_size: int, padding_value: float =
336329 Returns:
337330 torch.Tensor: Padded tensor matching `max_size` along dim=0.
338331 """
339- # For a shape (B, C, ...), we only pad the B dimension up to `max_size`.
340332 pad_size = [0 , 0 ] * (len (tensor .shape ) - 1 ) + [0 , max_size - tensor .shape [0 ]]
341333 return F .pad (tensor , pad_size , "constant" , padding_value )
342334
@@ -395,11 +387,9 @@ def main(
395387 world_size = int (dist .get_world_size ())
396388 device = torch .device ("cuda" , local_rank )
397389 torch .cuda .set_device (device )
398- print (f"[INFO] Running process on { device } of total { world_size } ranks." )
390+ logger . info (f"[INFO] Running process on { device } of total { world_size } ranks." )
399391
400- # -------------------------------------------------------------------------
401392 # Convert potential string bools to actual bools (Fire sometimes passes strings)
402- # -------------------------------------------------------------------------
403393 if not isinstance (enable_center_slices , bool ):
404394 enable_center_slices = enable_center_slices .lower () == "true"
405395 if not isinstance (enable_padding , bool ):
@@ -413,46 +403,44 @@ def main(
413403
414404 # Print out some flags on rank 0
415405 if local_rank == 0 :
416- print (f"[INFO] enable_center_slices: { enable_center_slices } " )
417- print (f"[INFO] enable_padding: { enable_padding } " )
418- print (f"[INFO] enable_center_cropping: { enable_center_cropping } " )
419- print (f"[INFO] enable_resampling: { enable_resampling } " )
420- print (f"[INFO] ignore_existing: { ignore_existing } " )
406+ logger . info (f"enable_center_slices: { enable_center_slices } " )
407+ logger . info (f"enable_padding: { enable_padding } " )
408+ logger . info (f"enable_center_cropping: { enable_center_cropping } " )
409+ logger . info (f"enable_resampling: { enable_resampling } " )
410+ logger . info (f"ignore_existing: { ignore_existing } " )
421411
422412 # -------------------------------------------------------------------------
423413 # Load feature extraction model
424414 # -------------------------------------------------------------------------
425415 if model_name == "radimagenet_resnet50" :
426- # Using a model from Warvito/radimagenet-models on Torch Hub
427416 feature_network = torch .hub .load (
428417 "Warvito/radimagenet-models" , model = "radimagenet_resnet50" , verbose = True , trust_repo = True
429418 )
430419 suffix = "radimagenet_resnet50"
431420 else :
432421 import torchvision
433-
434422 feature_network = torchvision .models .squeezenet1_1 (pretrained = True )
435423 suffix = "squeezenet1_1"
436424
437425 feature_network .to (device )
438426 feature_network .eval ()
439427
440428 # -------------------------------------------------------------------------
441- # Parse shape/spacings from string
429+ # Parse shape/spacings
442430 # -------------------------------------------------------------------------
443431 t_shape = [int (x ) for x in target_shape .split ("x" )]
444432 target_shape_tuple = tuple (t_shape )
445433 if enable_resampling :
446434 rs_spacing = [float (x ) for x in enable_resampling_spacing .split ("x" )]
447435 rs_spacing_tuple = tuple (rs_spacing )
448436 if local_rank == 0 :
449- print (f"[INFO] resampling spacing: { rs_spacing_tuple } " )
437+ logger . info (f"resampling spacing: { rs_spacing_tuple } " )
450438 else :
451439 rs_spacing_tuple = (1.0 , 1.0 , 1.0 )
452440
453441 center_slices_ratio_final = enable_center_slices_ratio if enable_center_slices else 1.0
454442 if local_rank == 0 :
455- print (f"[INFO] center_slices_ratio: { center_slices_ratio_final } " )
443+ logger . info (f"center_slices_ratio: { center_slices_ratio_final } " )
456444
457445 # -------------------------------------------------------------------------
458446 # Prepare dataset 0
@@ -490,25 +478,20 @@ def main(
490478 monai .transforms .EnsureChannelFirstd (keys = ["image" ]),
491479 monai .transforms .Orientationd (keys = ["image" ], axcodes = "RAS" ),
492480 ]
493-
494481 if enable_resampling :
495482 transform_list .append (monai .transforms .Spacingd (keys = ["image" ], pixdim = rs_spacing_tuple , mode = ["bilinear" ]))
496483 if enable_padding :
497484 transform_list .append (
498- monai .transforms .SpatialPadd (
499- keys = ["image" ], spatial_size = target_shape_tuple , mode = ["constant" ], value = - 1000
500- )
485+ monai .transforms .SpatialPadd (keys = ["image" ], spatial_size = target_shape_tuple , mode = "constant" , value = - 1000 )
501486 )
502487 if enable_center_cropping :
503488 transform_list .append (monai .transforms .CenterSpatialCropd (keys = ["image" ], roi_size = target_shape_tuple ))
504489
505- # Intensity scaling to clamp between [-1000, 1000]
506490 transform_list .append (
507491 monai .transforms .ScaleIntensityRanged (
508492 keys = ["image" ], a_min = - 1000 , a_max = 1000 , b_min = - 1000 , b_max = 1000 , clip = True
509493 )
510494 )
511-
512495 transforms = Compose (transform_list )
513496
514497 # -------------------------------------------------------------------------
@@ -527,7 +510,7 @@ def main(
527510 for idx , batch_data in enumerate (real_loader , start = 1 ):
528511 img = batch_data ["image" ].to (device )
529512 fn = img .meta ["filename_or_obj" ][0 ]
530- print (f"[Rank { local_rank } ] Real data { idx } /{ len (filenames0 )} : { fn } " )
513+ logger . info (f"[Rank { local_rank } ] Real data { idx } /{ len (filenames0 )} : { fn } " )
531514
532515 out_fp = fn .replace (data0_dataroot , output_root0 ).replace (".nii.gz" , ".pt" )
533516 out_fp = Path (out_fp )
@@ -537,17 +520,19 @@ def main(
537520 feats = torch .load (out_fp )
538521 else :
539522 img_t = img .as_tensor ()
540- print (f"[INFO] image shape: { tuple (img_t .shape )} " )
523+ logger . info (f"image shape: { tuple (img_t .shape )} " )
541524
542- # Inline get_features_2p5d
543525 feats = get_features_2p5d (
544526 img_t ,
545527 feature_network ,
546528 center_slices = enable_center_slices ,
547529 center_slices_ratio = center_slices_ratio_final ,
548530 xy_only = False ,
549531 )
550- print (f"[INFO] feats shapes: { feats [0 ].shape } , { feats [1 ].shape } , { feats [2 ].shape } " )
532+ logger .info (
533+ f"feats shapes: { feats [0 ].shape } , "
534+ f"{ feats [1 ].shape } , { feats [2 ].shape } "
535+ )
551536 torch .save (feats , out_fp )
552537
553538 real_features_xy .append (feats [0 ])
@@ -557,7 +542,10 @@ def main(
557542 real_features_xy = torch .vstack (real_features_xy )
558543 real_features_yz = torch .vstack (real_features_yz )
559544 real_features_zx = torch .vstack (real_features_zx )
560- print (f"[INFO] Real feature shapes: { real_features_xy .shape } , { real_features_yz .shape } , { real_features_zx .shape } " )
545+ logger .info (
546+ f"Real feature shapes: { real_features_xy .shape } , "
547+ f"{ real_features_yz .shape } , { real_features_zx .shape } "
548+ )
561549
562550 # -------------------------------------------------------------------------
563551 # Extract features for dataset 1
@@ -566,7 +554,7 @@ def main(
566554 for idx , batch_data in enumerate (synt_loader , start = 1 ):
567555 img = batch_data ["image" ].to (device )
568556 fn = img .meta ["filename_or_obj" ][0 ]
569- print (f"[Rank { local_rank } ] Synthetic data { idx } /{ len (filenames1 )} : { fn } " )
557+ logger . info (f"[Rank { local_rank } ] Synthetic data { idx } /{ len (filenames1 )} : { fn } " )
570558
571559 out_fp = fn .replace (data1_dataroot , output_root1 ).replace (".nii.gz" , ".pt" )
572560 out_fp = Path (out_fp )
@@ -576,7 +564,7 @@ def main(
576564 feats = torch .load (out_fp )
577565 else :
578566 img_t = img .as_tensor ()
579- print (f"[INFO] image shape: { tuple (img_t .shape )} " )
567+ logger . info (f"image shape: { tuple (img_t .shape )} " )
580568
581569 feats = get_features_2p5d (
582570 img_t ,
@@ -585,7 +573,10 @@ def main(
585573 center_slices_ratio = center_slices_ratio_final ,
586574 xy_only = False ,
587575 )
588- print (f"[INFO] feats shapes: { feats [0 ].shape } , { feats [1 ].shape } , { feats [2 ].shape } " )
576+ logger .info (
577+ f"feats shapes: { feats [0 ].shape } , "
578+ f"{ feats [1 ].shape } , { feats [2 ].shape } "
579+ )
589580 torch .save (feats , out_fp )
590581
591582 synth_features_xy .append (feats [0 ])
@@ -595,8 +586,8 @@ def main(
595586 synth_features_xy = torch .vstack (synth_features_xy )
596587 synth_features_yz = torch .vstack (synth_features_yz )
597588 synth_features_zx = torch .vstack (synth_features_zx )
598- print (
599- f"[INFO] Synthetic feature shapes: { synth_features_xy .shape } , "
589+ logger . info (
590+ f"Synthetic feature shapes: { synth_features_xy .shape } , "
600591 f"{ synth_features_yz .shape } , { synth_features_zx .shape } "
601592 )
602593
@@ -649,25 +640,27 @@ def main(
649640 synth_yz = torch .vstack (all_tensors_list [4 ])
650641 synth_zx = torch .vstack (all_tensors_list [5 ])
651642
652- print (f"[INFO] Final Real shapes: { real_xy .shape } , { real_yz .shape } , { real_zx .shape } " )
653- print (f"[INFO] Final Synth shapes: { synth_xy .shape } , { synth_yz .shape } , { synth_zx .shape } " )
643+ logger .info (
644+ f"Final Real shapes: { real_xy .shape } , { real_yz .shape } , { real_zx .shape } "
645+ )
646+ logger .info (
647+ f"Final Synth shapes: { synth_xy .shape } , { synth_yz .shape } , { synth_zx .shape } "
648+ )
654649
655650 fid = FIDMetric ()
656- print (f"\n [INFO] Computing FID for: { output_root0 } | { output_root1 } " )
651+ logger . info (f"Computing FID for: { output_root0 } | { output_root1 } " )
657652 fid_res_xy = fid (synth_xy , real_xy )
658653 fid_res_yz = fid (synth_yz , real_yz )
659654 fid_res_zx = fid (synth_zx , real_zx )
660655
661- print (f" FID XY: { fid_res_xy } " )
662- print (f" FID YZ: { fid_res_yz } " )
663- print (f" FID ZX: { fid_res_zx } " )
656+ logger . info (f"FID XY: { fid_res_xy } " )
657+ logger . info (f"FID YZ: { fid_res_yz } " )
658+ logger . info (f"FID ZX: { fid_res_zx } " )
664659 fid_avg = (fid_res_xy + fid_res_yz + fid_res_zx ) / 3.0
665- print (f" FID Avg: { fid_avg } " )
660+ logger . info (f"FID Avg: { fid_avg } " )
666661
667662 dist .destroy_process_group ()
668663
669664
670665if __name__ == "__main__" :
671- # Using python-fire for command-line interface.
672- # e.g., python compute_fid2d_mgpu.py --model_name=radimagenet_resnet50 --num_images=100 ...
673666 fire .Fire (main )
0 commit comments