
Commit b265a91

train rflow
Signed-off-by: Can-Zhao <[email protected]>
1 parent 5cff0a4 commit b265a91

3 files changed: +6 −5 lines changed


generation/maisi/scripts/diff_model_infer.py

Lines changed: 1 addition & 2 deletions
@@ -91,12 +91,11 @@ def prepare_tensors(args: argparse.Namespace, device: torch.device) -> tuple:
     top_region_index_tensor = np.array(args.diffusion_unet_inference["top_region_index"]).astype(float) * 1e2
     bottom_region_index_tensor = np.array(args.diffusion_unet_inference["bottom_region_index"]).astype(float) * 1e2
     spacing_tensor = np.array(args.diffusion_unet_inference["spacing"]).astype(float) * 1e2
-    modality_tensor = np.array([args.diffusion_unet_inference["modality"]]).astype(int)

     top_region_index_tensor = torch.from_numpy(top_region_index_tensor[np.newaxis, :]).half().to(device)
     bottom_region_index_tensor = torch.from_numpy(bottom_region_index_tensor[np.newaxis, :]).half().to(device)
     spacing_tensor = torch.from_numpy(spacing_tensor[np.newaxis, :]).half().to(device)
-    modality_tensor = torch.from_numpy(modality_tensor[np.newaxis, :]).long().to(device)
+    modality_tensor = args.diffusion_unet_inference["modality"]*torch.ones((len(spacing_tensor)),dtype=torch.long).to(device)

     return top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor

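For context on the diff_model_infer.py change above: the single modality id from the inference config is now broadcast into a batch-sized long tensor instead of going through numpy. A minimal standalone sketch of that pattern (the values and batch size below are illustrative, not taken from the repository):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spacing_tensor = torch.tensor([[1.0, 1.0, 1.5]]).half().to(device)  # shape (1, 3), as built earlier in prepare_tensors
modality = 1  # stand-in for args.diffusion_unet_inference["modality"]

# one long-typed modality id per batch element, on the same device as the other conditioning tensors
modality_tensor = modality * torch.ones((len(spacing_tensor),), dtype=torch.long).to(device)
print(modality_tensor)  # tensor([1]) for this batch of one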
generation/maisi/scripts/infer_controlnet.py

Lines changed: 2 additions & 0 deletions
@@ -161,6 +161,7 @@ def main():
         top_region_index_tensor = None
         bottom_region_index_tensor = None
         spacing_tensor = batch["spacing"].to(device)
+        modality_tensor = args.controlnet_infer["modality"]*torch.ones((len(labels),),dtype=torch.long).to(device)
         out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist())
         # get target dimension
         dim = batch["dim"]
@@ -180,6 +181,7 @@ def main():
             top_region_index_tensor=top_region_index_tensor,
             bottom_region_index_tensor=bottom_region_index_tensor,
             spacing_tensor=spacing_tensor,
+            modality_tensor=modality_tensor,
             latent_shape=latent_shape,
             output_size=output_size,
             noise_factor=1.0,

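The second hunk in infer_controlnet.py only threads the new tensor into the sampling call as an extra keyword argument; the callee itself is outside this diff. A schematic of the pattern, using a hypothetical sample_fn in place of the real sampler:

import torch

def sample_fn(*, spacing_tensor, modality_tensor, **kwargs):
    # hypothetical stand-in for the sampler invoked in infer_controlnet.py
    assert modality_tensor.dtype == torch.long and len(modality_tensor) == len(spacing_tensor)

labels = torch.zeros(2, 1, 8, 8, 8, dtype=torch.uint8)               # illustrative batch of two label volumes
spacing_tensor = torch.tensor([[1.5, 1.5, 2.0], [1.5, 1.5, 2.0]])    # one spacing vector per item
modality_tensor = 0 * torch.ones((len(labels),), dtype=torch.long)   # modality id 0, one entry per label

sample_fn(spacing_tensor=spacing_tensor, modality_tensor=modality_tensor)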
generation/maisi/scripts/utils.py

Lines changed: 3 additions & 3 deletions
@@ -306,10 +306,10 @@ def prepare_maisi_controlnet_json_dataloader(
         LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
         Orientationd(keys=["label"], axcodes="RAS"),
         EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
-        Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)),
-        Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)),
+        Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x),allow_missing_keys=True),
+        Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x),allow_missing_keys=True),
         Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
-        Lambdad(keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2),
+        Lambdad(keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2,allow_missing_keys=True),
     ]
     train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)

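For readers unfamiliar with the MONAI flag added in utils.py: allow_missing_keys=True makes a dictionary transform skip samples that lack the key instead of raising a KeyError, which is what lets top_region_index and bottom_region_index become optional here. A small sketch with made-up sample dictionaries:

import torch
from monai.transforms import Lambdad

scale = Lambdad(keys=["top_region_index", "spacing"], func=lambda x: x * 1e2, allow_missing_keys=True)

with_region = {"top_region_index": torch.tensor([0.01, 0.02]), "spacing": torch.tensor([0.015, 0.015, 0.02])}
without_region = {"spacing": torch.tensor([0.015, 0.015, 0.02])}  # no region-index keys at all

print(scale(with_region)["top_region_index"])  # tensor([1., 2.])
print(scale(without_region)["spacing"])        # tensor([1.5000, 1.5000, 2.0000]); the missing key is simply skipped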