@@ -48,6 +48,12 @@ def main():
4848 default = "./configs/config_infer.json" ,
4949 help = "config json file that stores inference hyper-parameters" ,
5050 )
51+ parser .add_argument (
52+ "-x" ,
53+ "--extra-config-file" ,
54+ default = None ,
55+ help = "config json file that stores inference extra parameters" ,
56+ )
5157 parser .add_argument (
5258 "-s" ,
5359 "--random-seed" ,
@@ -122,7 +128,7 @@ def main():
122128 print (f"{ k } : { val } " )
123129 print ("Global config variables have been loaded." )
124130
125- # ## Read in configuration setting, including network definition, body region and anatomy to generate, etc.
131+ # ## Read in configuration setting, including network definition, body region and anatomy to generate, etc.
126132 #
127133 # The information for the inference input, like body region and anatomy to generate, is stored in "./configs/config_infer.json".
128134 # Please refer to README.md for the details.
@@ -135,11 +141,21 @@ def main():
135141 # override num_split if asked
136142 if "autoencoder_tp_num_splits" in config_infer_dict :
137143 args .autoencoder_def ["num_splits" ] = config_infer_dict ["autoencoder_tp_num_splits" ]
138- args .mask_generation_autoencoder_def ["num_splits" ] = config_infer_dict ["autoencoder_tp_num_splits" ]
144+ args .mask_generation_autoencoder_def ["num_splits" ] = config_infer_dict ["autoencoder_tp_num_splits" ]
139145 for k , v in config_infer_dict .items ():
140146 setattr (args , k , v )
141147 print (f"{ k } : { v } " )
142148
149+ #
150+ # ## Read in optional extra configuration setting - typically acceleration options (TRT)
151+ #
152+ #
153+ if args .extra_config_file is not None :
154+ extra_config_dict = json .load (open (args .extra_config_file , "r" ))
155+ for k , v in extra_config_dict .items ():
156+ setattr (args , k , v )
157+ print (f"{ k } : { v } " )
158+
143159 check_input (
144160 args .body_region ,
145161 args .anatomy_list ,
@@ -158,25 +174,25 @@ def main():
158174
159175 device = torch .device ("cuda" )
160176
161- autoencoder = define_instance (args , "autoencoder_def " ).to (device )
177+ autoencoder = define_instance (args , "autoencoder " ).to (device )
162178 checkpoint_autoencoder = torch .load (args .trained_autoencoder_path )
163179 autoencoder .load_state_dict (checkpoint_autoencoder )
164180
165- diffusion_unet = define_instance (args , "diffusion_unet_def " ).to (device )
181+ diffusion_unet = define_instance (args , "diffusion_unet " ).to (device )
166182 checkpoint_diffusion_unet = torch .load (args .trained_diffusion_path )
167183 diffusion_unet .load_state_dict (checkpoint_diffusion_unet ["unet_state_dict" ], strict = True )
168184 scale_factor = checkpoint_diffusion_unet ["scale_factor" ].to (device )
169185
170- controlnet = define_instance (args , "controlnet_def " ).to (device )
186+ controlnet = define_instance (args , "controlnet " ).to (device )
171187 checkpoint_controlnet = torch .load (args .trained_controlnet_path )
172188 monai .networks .utils .copy_model_state (controlnet , diffusion_unet .state_dict ())
173189 controlnet .load_state_dict (checkpoint_controlnet ["controlnet_state_dict" ], strict = True )
174190
175- mask_generation_autoencoder = define_instance (args , "mask_generation_autoencoder_def " ).to (device )
191+ mask_generation_autoencoder = define_instance (args , "mask_generation_autoencoder " ).to (device )
176192 checkpoint_mask_generation_autoencoder = torch .load (args .trained_mask_generation_autoencoder_path )
177193 mask_generation_autoencoder .load_state_dict (checkpoint_mask_generation_autoencoder )
178194
179- mask_generation_diffusion_unet = define_instance (args , "mask_generation_diffusion_def " ).to (device )
195+ mask_generation_diffusion_unet = define_instance (args , "mask_generation_diffusion " ).to (device )
180196 checkpoint_mask_generation_diffusion_unet = torch .load (args .trained_mask_generation_diffusion_path )
181197 mask_generation_diffusion_unet .load_state_dict (checkpoint_mask_generation_diffusion_unet ["unet_state_dict" ])
182198 mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet ["scale_factor" ]
0 commit comments