
Commit 04a9326

Added TRT config for inference
Signed-off-by: Boris Fomitchev <[email protected]>
1 parent e0bba33 commit 04a9326

3 files changed: +51 -8 lines changed
generation/maisi/configs/config_infer.json

Lines changed: 6 additions & 1 deletion
@@ -18,5 +18,10 @@
         2.0
     ],
     "autoencoder_sliding_window_infer_size": [48,48,48],
-    "autoencoder_sliding_window_infer_overlap": 0.25
+    "autoencoder_sliding_window_infer_overlap": 0.25,
+    "controlnet": "$@controlnet_def",
+    "diffusion_unet": "$@diffusion_unet_def",
+    "autoencoder": "$@autoencoder_def",
+    "mask_generation_autoencoder": "$@mask_generation_autoencoder_def",
+    "mask_generation_diffusion": "$@mask_generation_diffusion_def"
 }
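The five new entries add a level of indirection: by default each model key simply points back at the corresponding `_def` network definition, so `define_instance(args, "controlnet")` still builds the plain PyTorch network, while an optional extra config can overwrite the same key with a TRT-wrapped expression (see the new file below). A minimal sketch of how such `$@<name>_def` references resolve, assuming MAISI's `define_instance` builds a `monai.bundle.ConfigParser` over the merged arguments; the `tiny_def` network here is a hypothetical stand-in:

# Minimal sketch: how a "$@<name>_def" reference resolves with monai.bundle.ConfigParser.
# "tiny_def" is a hypothetical stand-in for the MAISI network definitions; the real
# script is assumed to do the equivalent via define_instance(args, "controlnet") etc.
from monai.bundle import ConfigParser

config = {
    "tiny_def": {"_target_": "torch.nn.Linear", "in_features": 4, "out_features": 2},
    "tiny": "$@tiny_def",  # same indirection pattern as "controlnet": "$@controlnet_def"
}
parser = ConfigParser(config)
model = parser.get_parsed_content("tiny")  # instantiates torch.nn.Linear(4, 2)
# Overriding config["tiny"] with a "$trt_compile(@tiny_def, ...)" expression before
# parsing would return a TRT-wrapped module instead, without touching this call site.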
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+{
+    "+imports": [
+        "$from monai.networks import trt_compile"
+    ],
+    "c_trt_args": {
+        "export_args": {
+            "dynamo": "$False",
+            "report": "$True"
+        },
+        "output_lists": [
+            [
+                -1
+            ],
+            [
+            ]
+        ]
+    },
+    "device": "cuda",
+    "controlnet": "$trt_compile(@controlnet_def.to(@device), @trained_controlnet_path, @c_trt_args)",
+    "diffusion_unet": "$trt_compile(@diffusion_unet_def.to(@device), @trained_diffusion_path)",
+    "mask_generation_diffusion": "$trt_compile(@mask_generation_diffusion_def.to(@device), @trained_mask_generation_diffusion_path)"
+}
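This new file is the TRT config named in the commit message: it imports `trt_compile` into the config expression namespace, sets shared build options, and overrides the `controlnet`, `diffusion_unet`, and `mask_generation_diffusion` keys so that `define_instance` returns TRT-wrapped modules, reusing the trained checkpoint paths as the location for the serialized engines. Below is a minimal stand-alone sketch of the same wrapping in plain Python, following the `trt_compile(model, base_path, args)` call pattern used above; the tiny network and the base path are hypothetical placeholders, and the deferred engine build is an assumption about how the wrapped model behaves:

# Minimal sketch of the wrapping done by the config above, mirroring its "c_trt_args".
# The tiny network and the base path are hypothetical placeholders; running this
# requires a CUDA device with TensorRT available.
import torch
from monai.networks import trt_compile

model = torch.nn.Sequential(torch.nn.Conv3d(1, 8, 3, padding=1), torch.nn.ReLU()).to("cuda")
trt_args = {"export_args": {"dynamo": False, "report": True}}  # same export options as "c_trt_args"
# The second argument is assumed to be the base path under which the serialized engine
# plan is cached; the config above reuses the trained checkpoint paths for this purpose.
model = trt_compile(model, "./models/demo_model.pt", trt_args)
# Assumption: the engine is built (or loaded from cache) lazily on the first forward call.
out = model(torch.randn(1, 1, 16, 16, 16, device="cuda"))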

generation/maisi/scripts/inference.py

Lines changed: 23 additions & 7 deletions
@@ -48,6 +48,12 @@ def main():
         default="./configs/config_infer.json",
         help="config json file that stores inference hyper-parameters",
     )
+    parser.add_argument(
+        "-x",
+        "--extra-config-file",
+        default=None,
+        help="config json file that stores inference extra parameters",
+    )
     parser.add_argument(
         "-s",
         "--random-seed",
@@ -122,7 +128,7 @@ def main():
         print(f"{k}: {val}")
     print("Global config variables have been loaded.")

-    # ## Read in configuration setting, including network definition, body region and anatomy to generate, etc.
+    # ## Read in configuration setting, including network definition, body region and anatomy to generate, etc.
     #
     # The information for the inference input, like body region and anatomy to generate, is stored in "./configs/config_infer.json".
     # Please refer to README.md for the details.
@@ -135,11 +141,21 @@ def main():
     # override num_split if asked
     if "autoencoder_tp_num_splits" in config_infer_dict:
         args.autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"]
-        args.mask_generation_autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"]
+        args.mask_generation_autoencoder_def["num_splits"] = config_infer_dict["autoencoder_tp_num_splits"]
     for k, v in config_infer_dict.items():
         setattr(args, k, v)
         print(f"{k}: {v}")

+    #
+    # ## Read in optional extra configuration setting - typically acceleration options (TRT)
+    #
+    #
+    if args.extra_config_file is not None:
+        extra_config_dict = json.load(open(args.extra_config_file, "r"))
+        for k, v in extra_config_dict.items():
+            setattr(args, k, v)
+            print(f"{k}: {v}")
+
     check_input(
         args.body_region,
         args.anatomy_list,
@@ -158,25 +174,25 @@

     device = torch.device("cuda")

-    autoencoder = define_instance(args, "autoencoder_def").to(device)
+    autoencoder = define_instance(args, "autoencoder").to(device)
     checkpoint_autoencoder = torch.load(args.trained_autoencoder_path)
     autoencoder.load_state_dict(checkpoint_autoencoder)

-    diffusion_unet = define_instance(args, "diffusion_unet_def").to(device)
+    diffusion_unet = define_instance(args, "diffusion_unet").to(device)
     checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path)
     diffusion_unet.load_state_dict(checkpoint_diffusion_unet["unet_state_dict"], strict=True)
     scale_factor = checkpoint_diffusion_unet["scale_factor"].to(device)

-    controlnet = define_instance(args, "controlnet_def").to(device)
+    controlnet = define_instance(args, "controlnet").to(device)
     checkpoint_controlnet = torch.load(args.trained_controlnet_path)
     monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())
     controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True)

-    mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder_def").to(device)
+    mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder").to(device)
     checkpoint_mask_generation_autoencoder = torch.load(args.trained_mask_generation_autoencoder_path)
     mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)

-    mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion_def").to(device)
+    mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion").to(device)
     checkpoint_mask_generation_diffusion_unet = torch.load(args.trained_mask_generation_diffusion_path)
     mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet["unet_state_dict"])
     mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet["scale_factor"]
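With these changes the override order is: argparse defaults, then config_infer.json, then the optional file passed via the new -x/--extra-config-file flag, so whatever the extra config sets last wins. When that file is the TRT config above, the `define_instance(args, "controlnet")` calls return TensorRT-wrapped modules while the rest of the pipeline is unchanged. A minimal sketch of that merge order; the extra-config filename below is a hypothetical placeholder, while the config_infer.json path matches the script's default:

# Minimal sketch of the merge order the script now follows: the optional extra config
# is applied after config_infer.json, so its keys overwrite the "$@<name>_def"
# indirections. "./configs/config_trt.json" is a hypothetical name for the file
# passed via -x/--extra-config-file.
import argparse
import json

args = argparse.Namespace()
for path in ["./configs/config_infer.json", "./configs/config_trt.json"]:
    with open(path) as f:
        config = json.load(f)
    for k, v in config.items():
        # later files win: after the second pass, "controlnet" holds the
        # "$trt_compile(@controlnet_def, ...)" expression instead of "$@controlnet_def"
        setattr(args, k, v)
print(args.controlnet)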
