Skip to content

Commit 2bff1d9

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3ae25ba commit 2bff1d9

File tree

2 files changed

+17
-39
lines changed

2 files changed

+17
-39
lines changed

generation/maisi/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,13 @@ We provide the `compute_fid_2-5d_ct.py` script that calculates the Frechet Incep
253253

254254
#### Key Features
255255

256-
- **Distributed Processing**
256+
- **Distributed Processing**
257257
Scales to multiple GPUs and larger datasets by splitting the workload across devices.
258258

259-
- **2.5D Feature Extraction**
259+
- **2.5D Feature Extraction**
260260
Uses a slice-based technique, applying a 2D model across all slices in each dimension.
261261

262-
- **Flexible Preprocessing**
262+
- **Flexible Preprocessing**
263263
Supports optional center-cropping, padding, and resampling to target shapes or voxel spacings.
264264

265265
#### Usage Example

generation/maisi/scripts/compute_fid_2-5d_ct.py

Lines changed: 14 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def main(
390390
391391
This function loads two datasets (real vs. synthetic) in 3D medical format (NIfTI)
392392
and extracts feature maps via a 2.5D approach, then computes the Frechet Inception
393-
Distance (FID) across three orthogonal planes. Data parallelism is implemented
393+
Distance (FID) across three orthogonal planes. Data parallelism is implemented
394394
using torch.distributed with an NCCL backend.
395395
396396
Args:
@@ -406,7 +406,7 @@ def main(
406406
...
407407
These entries will be appended to `real_dataset_root`.
408408
real_features_dir (str):
409-
Name of the directory under `output_root` in which to store
409+
Name of the directory under `output_root` in which to store
410410
extracted features for the real dataset.
411411
412412
synth_dataset_root (str):
@@ -420,7 +420,7 @@ def main(
420420
...
421421
These entries will be appended to `synth_dataset_root`.
422422
synth_features_dir (str):
423-
Name of the directory under `output_root` in which to store
423+
Name of the directory under `output_root` in which to store
424424
extracted features for the synthetic dataset.
425425
426426
enable_center_slices_ratio (float or None):
@@ -500,14 +500,12 @@ def main(
500500
# -------------------------------------------------------------------------
501501
if model_name == "radimagenet_resnet50":
502502
feature_network = torch.hub.load(
503-
"Warvito/radimagenet-models",
504-
model="radimagenet_resnet50",
505-
verbose=True,
506-
trust_repo=True
503+
"Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True, trust_repo=True
507504
)
508505
suffix = "radimagenet_resnet50"
509506
else:
510507
import torchvision
508+
511509
feature_network = torchvision.models.squeezenet1_1(pretrained=True)
512510
suffix = "squeezenet1_1"
513511

@@ -545,10 +543,7 @@ def main(
545543

546544
real_filenames = [{"image": os.path.join(real_dataset_root, f)} for f in real_lines]
547545
real_filenames = monai.data.partition_dataset(
548-
data=real_filenames,
549-
shuffle=False,
550-
num_partitions=world_size,
551-
even_divisible=False
546+
data=real_filenames, shuffle=False, num_partitions=world_size, even_divisible=False
552547
)[local_rank]
553548

554549
# -------------------------------------------------------------------------
@@ -562,10 +557,7 @@ def main(
562557

563558
synth_filenames = [{"image": os.path.join(synth_dataset_root, f)} for f in synth_lines]
564559
synth_filenames = monai.data.partition_dataset(
565-
data=synth_filenames,
566-
shuffle=False,
567-
num_partitions=world_size,
568-
even_divisible=False
560+
data=synth_filenames, shuffle=False, num_partitions=world_size, even_divisible=False
569561
)[local_rank]
570562

571563
# -------------------------------------------------------------------------
@@ -578,23 +570,15 @@ def main(
578570
]
579571

580572
if enable_resampling:
581-
transform_list.append(
582-
monai.transforms.Spacingd(
583-
keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"]
584-
)
585-
)
573+
transform_list.append(monai.transforms.Spacingd(keys=["image"], pixdim=rs_spacing_tuple, mode=["bilinear"]))
586574

587575
if enable_padding:
588576
transform_list.append(
589-
monai.transforms.SpatialPadd(
590-
keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000
591-
)
577+
monai.transforms.SpatialPadd(keys=["image"], spatial_size=target_shape_tuple, mode="constant", value=-1000)
592578
)
593579

594580
if enable_center_cropping:
595-
transform_list.append(
596-
monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple)
597-
)
581+
transform_list.append(monai.transforms.CenterSpatialCropd(keys=["image"], roi_size=target_shape_tuple))
598582

599583
transform_list.append(
600584
monai.transforms.ScaleIntensityRanged(
@@ -638,9 +622,7 @@ def main(
638622
center_slices_ratio=center_slices_ratio_final,
639623
xy_only=False,
640624
)
641-
logger.info(
642-
f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}"
643-
)
625+
logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
644626
torch.save(feats, out_fp)
645627

646628
real_features_xy.append(feats[0])
@@ -651,8 +633,7 @@ def main(
651633
real_features_yz = torch.vstack(real_features_yz)
652634
real_features_zx = torch.vstack(real_features_zx)
653635
logger.info(
654-
f"Real feature shapes: {real_features_xy.shape}, "
655-
f"{real_features_yz.shape}, {real_features_zx.shape}"
636+
f"Real feature shapes: {real_features_xy.shape}, " f"{real_features_yz.shape}, {real_features_zx.shape}"
656637
)
657638

658639
# -------------------------------------------------------------------------
@@ -681,9 +662,7 @@ def main(
681662
center_slices_ratio=center_slices_ratio_final,
682663
xy_only=False,
683664
)
684-
logger.info(
685-
f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}"
686-
)
665+
logger.info(f"feats shapes: {feats[0].shape}, {feats[1].shape}, {feats[2].shape}")
687666
torch.save(feats, out_fp)
688667

689668
synth_features_xy.append(feats[0])
@@ -694,8 +673,7 @@ def main(
694673
synth_features_yz = torch.vstack(synth_features_yz)
695674
synth_features_zx = torch.vstack(synth_features_zx)
696675
logger.info(
697-
f"Synth feature shapes: {synth_features_xy.shape}, "
698-
f"{synth_features_yz.shape}, {synth_features_zx.shape}"
676+
f"Synth feature shapes: {synth_features_xy.shape}, " f"{synth_features_yz.shape}, {synth_features_zx.shape}"
699677
)
700678

701679
# -------------------------------------------------------------------------

0 commit comments

Comments
 (0)