     SPADEAutoencoderKL,
     SPADEDiffusionModelUNet,
 )
-from monai.networks.schedulers import Scheduler
+from monai.networks.schedulers import RFlowScheduler, Scheduler
 from monai.transforms import CenterSpatialCrop, SpatialPad
 from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
 from monai.visualize import CAM, GradCAM, GradCAMpp
@@ -859,12 +859,18 @@ def sample(
         if not scheduler:
             scheduler = self.scheduler
         image = input_noise
+
+        all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
         if verbose and has_tqdm:
-            progress_bar = tqdm(scheduler.timesteps)
+            progress_bar = tqdm(
+                zip(scheduler.timesteps, all_next_timesteps),
+                total=min(len(scheduler.timesteps), len(all_next_timesteps)),
+            )
         else:
-            progress_bar = iter(scheduler.timesteps)
+            progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
         intermediates = []
-        for t in progress_bar:
+
+        for t, next_t in progress_bar:
             # 1. predict noise model_output
             diffusion_model = (
                 partial(diffusion_model, seg=seg)
@@ -882,9 +888,13 @@ def sample(
             )

             # 2. compute previous image: x_t -> x_t-1
-            image, _ = scheduler.step(model_output, t, image)  # type: ignore[operator]
+            if not isinstance(scheduler, RFlowScheduler):
+                image, _ = scheduler.step(model_output, t, image)  # type: ignore
+            else:
+                image, _ = scheduler.step(model_output, t, image, next_t)  # type: ignore
             if save_intermediates and t % intermediate_steps == 0:
                 intermediates.append(image)
+
         if save_intermediates:
             return image, intermediates
         else:
@@ -1392,12 +1402,18 @@ def sample(  # type: ignore[override]
         if not scheduler:
             scheduler = self.scheduler
         image = input_noise
+
+        all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype)))
         if verbose and has_tqdm:
-            progress_bar = tqdm(scheduler.timesteps)
+            progress_bar = tqdm(
+                zip(scheduler.timesteps, all_next_timesteps),
+                total=min(len(scheduler.timesteps), len(all_next_timesteps)),
+            )
         else:
-            progress_bar = iter(scheduler.timesteps)
+            progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
         intermediates = []
-        for t in progress_bar:
+
+        for t, next_t in progress_bar:
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
@@ -1436,7 +1452,11 @@ def sample(  # type: ignore[override]
             )

             # 3. compute previous image: x_t -> x_t-1
-            image, _ = scheduler.step(model_output, t, image)  # type: ignore[operator]
+            if not isinstance(scheduler, RFlowScheduler):
+                image, _ = scheduler.step(model_output, t, image)  # type: ignore
+            else:
+                image, _ = scheduler.step(model_output, t, image, next_t)  # type: ignore
+
             if save_intermediates and t % intermediate_steps == 0:
                 intermediates.append(image)
         if save_intermediates:
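
In short, both sampling loops now walk the scheduler's timesteps in pairs (t, next_t), so an RFlowScheduler, whose step call also consumes the following timestep, can be driven by the same loop as the existing schedulers. A minimal sketch of that dispatch pattern, assuming the RFlowScheduler.step(model_output, t, image, next_t) signature shown in the diff, is below; the helper name denoise and the predict_fn stand-in for the diffusion model call are illustrative, not part of this change.

import torch
from monai.networks.schedulers import RFlowScheduler, Scheduler

def denoise(scheduler: Scheduler, predict_fn, image: torch.Tensor) -> torch.Tensor:
    # Pair each timestep with the one that follows it; the last step is paired with 0.
    all_next_timesteps = torch.cat(
        (scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))
    )
    for t, next_t in zip(scheduler.timesteps, all_next_timesteps):
        model_output = predict_fn(image, t)
        # RFlowScheduler.step additionally takes the next timestep; all other
        # schedulers keep the original three-argument call.
        if isinstance(scheduler, RFlowScheduler):
            image, _ = scheduler.step(model_output, t, image, next_t)
        else:
            image, _ = scheduler.step(model_output, t, image)
    return image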