File tree Expand file tree Collapse file tree 2 files changed +9
-3
lines changed
competitions/MICCAI/surgtoolloc/classification_files Expand file tree Collapse file tree 2 files changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -62,7 +62,9 @@ def main(cfg):
6262 model .to (cfg .device )
6363
6464 if cfg .weights is not None :
65- model .load_state_dict (torch .load (os .path .join (f"{ cfg .output_dir } /fold{ cfg .fold } " , cfg .weights ), weights_only = True )["model" ])
65+ model .load_state_dict (
66+ torch .load (os .path .join (f"{ cfg .output_dir } /fold{ cfg .fold } " , cfg .weights ), weights_only = True )["model" ]
67+ )
6668 print (f"weights from: { cfg .weights } are loaded." )
6769
6870 # set optimizer, lr scheduler
Original file line number Diff line number Diff line change @@ -189,11 +189,15 @@ def main():
189189 controlnet .load_state_dict (checkpoint_controlnet ["controlnet_state_dict" ], strict = True )
190190
191191 mask_generation_autoencoder = define_instance (args , "mask_generation_autoencoder" ).to (device )
192- checkpoint_mask_generation_autoencoder = torch .load (args .trained_mask_generation_autoencoder_path , weights_only = True )
192+ checkpoint_mask_generation_autoencoder = torch .load (
193+ args .trained_mask_generation_autoencoder_path , weights_only = True
194+ )
193195 mask_generation_autoencoder .load_state_dict (checkpoint_mask_generation_autoencoder )
194196
195197 mask_generation_diffusion_unet = define_instance (args , "mask_generation_diffusion" ).to (device )
196- checkpoint_mask_generation_diffusion_unet = torch .load (args .trained_mask_generation_diffusion_path , weights_only = False )
198+ checkpoint_mask_generation_diffusion_unet = torch .load (
199+ args .trained_mask_generation_diffusion_path , weights_only = False
200+ )
197201 mask_generation_diffusion_unet .load_state_dict (checkpoint_mask_generation_diffusion_unet ["unet_state_dict" ])
198202 mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet ["scale_factor" ]
199203
You can’t perform that action at this time.
0 commit comments