@@ -51,12 +51,12 @@ def load_filenames(data_list_path: str) -> list:
5151
5252
5353def prepare_data (
54- train_files : list ,
55- device : torch .device ,
56- cache_rate : float ,
57- num_workers : int = 2 ,
58- batch_size : int = 1 ,
59- include_body_region : bool = False
54+ train_files : list ,
55+ device : torch .device ,
56+ cache_rate : float ,
57+ num_workers : int = 2 ,
58+ batch_size : int = 1 ,
59+ include_body_region : bool = False ,
6060) -> DataLoader :
6161 """
6262 Prepare training data.
@@ -78,11 +78,11 @@ def _load_data_from_file(file_path, key):
7878 return torch .FloatTensor (json .load (f )[key ])
7979
8080 train_transforms_list = [
81- monai .transforms .LoadImaged (keys = ["image" ]),
82- monai .transforms .EnsureChannelFirstd (keys = ["image" ]),
83- monai .transforms .Lambdad (keys = "spacing" , func = lambda x : _load_data_from_file (x , "spacing" )),
84- monai .transforms .Lambdad (keys = "spacing" , func = lambda x : x * 1e2 ),
85- ]
81+ monai .transforms .LoadImaged (keys = ["image" ]),
82+ monai .transforms .EnsureChannelFirstd (keys = ["image" ]),
83+ monai .transforms .Lambdad (keys = "spacing" , func = lambda x : _load_data_from_file (x , "spacing" )),
84+ monai .transforms .Lambdad (keys = "spacing" , func = lambda x : x * 1e2 ),
85+ ]
8686 if include_body_region :
8787 train_transforms_list += [
8888 monai .transforms .Lambdad (
@@ -202,7 +202,7 @@ def train_one_epoch(
202202 logger : logging .Logger ,
203203 local_rank : int ,
204204 amp : bool = True ,
205- include_body_region : bool = False
205+ include_body_region : bool = False ,
206206) -> torch .Tensor :
207207 """
208208 Train the model for one epoch.
@@ -284,9 +284,10 @@ def train_one_epoch(
284284 # predict velocity
285285 loss = loss_pt (model_output .float (), (images - noise ).float ())
286286 else :
287- raise ValueError ("noise scheduler prediction type has to be chosen from " ,
288- f"[{ DDPMPredictionType .EPSILON } ,{ DDPMPredictionType .SAMPLE } ,{ DDPMPredictionType .V_PREDICTION } ]"
289- )
287+ raise ValueError (
288+ "noise scheduler prediction type has to be chosen from " ,
289+ f"[{ DDPMPredictionType .EPSILON } ,{ DDPMPredictionType .SAMPLE } ,{ DDPMPredictionType .V_PREDICTION } ]" ,
290+ )
290291
291292 if amp :
292293 scaler .scale (loss ).backward ()
@@ -349,7 +350,12 @@ def save_checkpoint(
349350
350351
351352def diff_model_train (
352- env_config_path : str , model_config_path : str , model_def_path : str , num_gpus : int , amp : bool = True , include_body_region : bool = False
353+ env_config_path : str ,
354+ model_config_path : str ,
355+ model_def_path : str ,
356+ num_gpus : int ,
357+ amp : bool = True ,
358+ include_body_region : bool = False ,
353359) -> None :
354360 """
355361 Main function to train a diffusion model.
@@ -400,9 +406,11 @@ def diff_model_train(
400406 )[local_rank ]
401407
402408 train_loader = prepare_data (
403- train_files , device , args .diffusion_unet_train ["cache_rate" ],
409+ train_files ,
410+ device ,
411+ args .diffusion_unet_train ["cache_rate" ],
404412 batch_size = args .diffusion_unet_train ["batch_size" ],
405- include_body_region = include_body_region
413+ include_body_region = include_body_region ,
406414 )
407415
408416 unet = load_unet (args , device , logger )
@@ -438,7 +446,7 @@ def diff_model_train(
438446 logger ,
439447 local_rank ,
440448 amp = amp ,
441- include_body_region = include_body_region
449+ include_body_region = include_body_region ,
442450 )
443451
444452 loss_torch = loss_torch .tolist ()
@@ -479,7 +487,14 @@ def diff_model_train(
479487 )
480488 parser .add_argument ("--num_gpus" , type = int , default = 1 , help = "Number of GPUs to use for training" )
481489 parser .add_argument ("--no_amp" , dest = "amp" , action = "store_false" , help = "Disable automatic mixed precision training" )
482- parser .add_argument ("--include_body_region" , dest = "include_body_region" , action = "store_true" , help = "Whether to include body region in data" )
490+ parser .add_argument (
491+ "--include_body_region" ,
492+ dest = "include_body_region" ,
493+ action = "store_true" ,
494+ help = "Whether to include body region in data" ,
495+ )
483496
484497 args = parser .parse_args ()
485- diff_model_train (args .env_config , args .model_config , args .model_def , args .num_gpus , args .amp , args .include_body_region )
498+ diff_model_train (
499+ args .env_config , args .model_config , args .model_def , args .num_gpus , args .amp , args .include_body_region
500+ )
0 commit comments