Skip to content

Commit efb4cab

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3842e59 commit efb4cab

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

competitions/MICCAI/surgtoolloc/classification_files/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def main(cfg):
6262
model.to(cfg.device)
6363

6464
if cfg.weights is not None:
65-
model.load_state_dict(torch.load(os.path.join(f"{cfg.output_dir}/fold{cfg.fold}", cfg.weights), weights_only=True)["model"])
65+
model.load_state_dict(
66+
torch.load(os.path.join(f"{cfg.output_dir}/fold{cfg.fold}", cfg.weights), weights_only=True)["model"]
67+
)
6668
print(f"weights from: {cfg.weights} are loaded.")
6769

6870
# set optimizer, lr scheduler

generation/maisi/scripts/inference.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,15 @@ def main():
189189
controlnet.load_state_dict(checkpoint_controlnet["controlnet_state_dict"], strict=True)
190190

191191
mask_generation_autoencoder = define_instance(args, "mask_generation_autoencoder").to(device)
192-
checkpoint_mask_generation_autoencoder = torch.load(args.trained_mask_generation_autoencoder_path, weights_only=True)
192+
checkpoint_mask_generation_autoencoder = torch.load(
193+
args.trained_mask_generation_autoencoder_path, weights_only=True
194+
)
193195
mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)
194196

195197
mask_generation_diffusion_unet = define_instance(args, "mask_generation_diffusion").to(device)
196-
checkpoint_mask_generation_diffusion_unet = torch.load(args.trained_mask_generation_diffusion_path, weights_only=False)
198+
checkpoint_mask_generation_diffusion_unet = torch.load(
199+
args.trained_mask_generation_diffusion_path, weights_only=False
200+
)
197201
mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet["unet_state_dict"])
198202
mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet["scale_factor"]
199203

0 commit comments

Comments
 (0)