
Commit 32048e6

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent babaf90 commit 32048e6

File tree: 3 files changed, +9 -5 lines


generation/maisi/scripts/diff_model_infer.py

Lines changed: 3 additions & 1 deletion
@@ -95,7 +95,9 @@ def prepare_tensors(args: argparse.Namespace, device: torch.device) -> tuple:
     top_region_index_tensor = torch.from_numpy(top_region_index_tensor[np.newaxis, :]).half().to(device)
     bottom_region_index_tensor = torch.from_numpy(bottom_region_index_tensor[np.newaxis, :]).half().to(device)
     spacing_tensor = torch.from_numpy(spacing_tensor[np.newaxis, :]).half().to(device)
-    modality_tensor = args.diffusion_unet_inference["modality"]*torch.ones((len(spacing_tensor)),dtype=torch.long).to(device)
+    modality_tensor = args.diffusion_unet_inference["modality"] * torch.ones(
+        (len(spacing_tensor)), dtype=torch.long
+    ).to(device)

     return top_region_index_tensor, bottom_region_index_tensor, spacing_tensor, modality_tensor
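
For reference, a minimal standalone sketch of the pattern reformatted above, using placeholder values that are not part of the commit: the scalar modality code is broadcast to one entry per batch element as a long tensor.

# Minimal sketch (placeholder values, not from the commit): broadcast a scalar
# modality code to one entry per batch element, as in the reformatted line above.
import torch

device = torch.device("cpu")               # assumption: CPU stands in for the inference device
spacing_tensor = torch.rand(4, 3).half()   # assumption: a batch of 4 spacing vectors
modality_code = 1                          # assumption: placeholder for args.diffusion_unet_inference["modality"]

modality_tensor = modality_code * torch.ones((len(spacing_tensor)), dtype=torch.long).to(device)
print(modality_tensor)                     # tensor([1, 1, 1, 1])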

generation/maisi/scripts/infer_controlnet.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def main():
         top_region_index_tensor = None
         bottom_region_index_tensor = None
         spacing_tensor = batch["spacing"].to(device)
-        modality_tensor = args.controlnet_infer["modality"]*torch.ones((len(labels),),dtype=torch.long).to(device)
+        modality_tensor = args.controlnet_infer["modality"] * torch.ones((len(labels),), dtype=torch.long).to(device)
         out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist())
         # get target dimension
         dim = batch["dim"]
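
The division by 100 in out_spacing above appears to mirror the 1e2 scaling applied to "spacing" in the dataloader transforms (see utils.py below). A minimal round-trip sketch with placeholder values, not taken from the commit:

# Minimal sketch (placeholder values): spacing is scaled by 1e2 in the transform
# and divided by 100 again at inference time, recovering the original values.
import torch

spacing = torch.FloatTensor([1.0, 1.0, 1.2])   # assumption: example voxel spacing
scaled = spacing * 1e2                         # as in the Lambdad transform in utils.py
batch_spacing = scaled.unsqueeze(0)            # assumption: batch dimension added by the dataloader
out_spacing = tuple((batch_spacing.squeeze().numpy() / 100).tolist())
print(out_spacing)                             # approximately (1.0, 1.0, 1.2)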

generation/maisi/scripts/utils.py

Lines changed: 5 additions & 3 deletions
@@ -306,10 +306,12 @@ def prepare_maisi_controlnet_json_dataloader(
         LoadImaged(keys=["image", "label"], image_only=True, ensure_channel_first=True),
         Orientationd(keys=["label"], axcodes="RAS"),
         EnsureTyped(keys=["label"], dtype=torch.uint8, track_meta=True),
-        Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x),allow_missing_keys=True),
-        Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x),allow_missing_keys=True),
+        Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True),
+        Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True),
         Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
-        Lambdad(keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2,allow_missing_keys=True),
+        Lambdad(
+            keys=["top_region_index", "bottom_region_index", "spacing"], func=lambda x: x * 1e2, allow_missing_keys=True
+        ),
     ]
     train_transforms, val_transforms = Compose(common_transform), Compose(common_transform)
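
As a usage sketch for the transforms reformatted above (synthetic sample dict, not part of the commit): allow_missing_keys=True lets the Lambdad transforms skip keys that are absent from a sample instead of raising a KeyError.

# Minimal sketch (synthetic sample, placeholder values): Lambdad with
# allow_missing_keys=True silently skips keys that are absent from the sample.
import torch
from monai.transforms import Compose, Lambdad

transforms = Compose(
    [
        Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x), allow_missing_keys=True),
        Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)),
        Lambdad(keys=["top_region_index", "spacing"], func=lambda x: x * 1e2, allow_missing_keys=True),
    ]
)

sample = {"spacing": [0.01, 0.01, 0.012]}      # assumption: sample without "top_region_index"
out = transforms(sample)
print(out["spacing"])                          # tensor([1.0000, 1.0000, 1.2000])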
