Skip to content

Commit aab8ad9

Browse files
authored
Merge branch 'dev' into add-monaihosting-backup-url
2 parents 1d92253 + ab07523 commit aab8ad9

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

monai/inferers/inferer.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,13 +1334,15 @@ def __call__( # type: ignore[override]
13341334
raise NotImplementedError(f"{mode} condition is not supported")
13351335

13361336
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1337-
down_block_res_samples, mid_block_res_sample = controlnet(
1338-
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
1339-
)
1337+
13401338
if mode == "concat" and condition is not None:
13411339
noisy_image = torch.cat([noisy_image, condition], dim=1)
13421340
condition = None
13431341

1342+
down_block_res_samples, mid_block_res_sample = controlnet(
1343+
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
1344+
)
1345+
13441346
diffuse = diffusion_model
13451347
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
13461348
diffuse = partial(diffusion_model, seg=seg)
@@ -1396,17 +1398,21 @@ def sample( # type: ignore[override]
13961398
progress_bar = iter(scheduler.timesteps)
13971399
intermediates = []
13981400
for t in progress_bar:
1399-
# 1. ControlNet forward
1400-
down_block_res_samples, mid_block_res_sample = controlnet(
1401-
x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
1402-
)
1403-
# 2. predict noise model_output
14041401
diffuse = diffusion_model
14051402
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
14061403
diffuse = partial(diffusion_model, seg=seg)
14071404

14081405
if mode == "concat" and conditioning is not None:
1406+
# 1. Conditioning
14091407
model_input = torch.cat([image, conditioning], dim=1)
1408+
# 2. ControlNet forward
1409+
down_block_res_samples, mid_block_res_sample = controlnet(
1410+
x=model_input,
1411+
timesteps=torch.Tensor((t,)).to(input_noise.device),
1412+
controlnet_cond=cn_cond,
1413+
context=None,
1414+
)
1415+
# 3. predict noise model_output
14101416
model_output = diffuse(
14111417
model_input,
14121418
timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1415,6 +1421,12 @@ def sample( # type: ignore[override]
14151421
mid_block_additional_residual=mid_block_res_sample,
14161422
)
14171423
else:
1424+
down_block_res_samples, mid_block_res_sample = controlnet(
1425+
x=image,
1426+
timesteps=torch.Tensor((t,)).to(input_noise.device),
1427+
controlnet_cond=cn_cond,
1428+
context=conditioning,
1429+
)
14181430
model_output = diffuse(
14191431
image,
14201432
timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1485,16 +1497,16 @@ def get_likelihood( # type: ignore[override]
14851497
for t in progress_bar:
14861498
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
14871499
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1488-
down_block_res_samples, mid_block_res_sample = controlnet(
1489-
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
1490-
)
14911500

14921501
diffuse = diffusion_model
14931502
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
14941503
diffuse = partial(diffusion_model, seg=seg)
14951504

14961505
if mode == "concat" and conditioning is not None:
14971506
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
1507+
down_block_res_samples, mid_block_res_sample = controlnet(
1508+
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
1509+
)
14981510
model_output = diffuse(
14991511
noisy_image,
15001512
timesteps=timesteps,
@@ -1503,6 +1515,12 @@ def get_likelihood( # type: ignore[override]
15031515
mid_block_additional_residual=mid_block_res_sample,
15041516
)
15051517
else:
1518+
down_block_res_samples, mid_block_res_sample = controlnet(
1519+
x=noisy_image,
1520+
timesteps=torch.Tensor((t,)).to(inputs.device),
1521+
controlnet_cond=cn_cond,
1522+
context=conditioning,
1523+
)
15061524
model_output = diffuse(
15071525
x=noisy_image,
15081526
timesteps=timesteps,

tests/inferers/test_controlnet_inferers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
550550
def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
551551
model_params["with_conditioning"] = True
552552
model_params["cross_attention_dim"] = 3
553+
controlnet_params["with_conditioning"] = True
554+
controlnet_params["cross_attention_dim"] = 3
553555
model = DiffusionModelUNet(**model_params)
554556
controlnet = ControlNet(**controlnet_params)
555557
device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -619,8 +621,11 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input
619621
model_params = model_params.copy()
620622
n_concat_channel = 2
621623
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
624+
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
622625
model_params["cross_attention_dim"] = None
626+
controlnet_params["cross_attention_dim"] = None
623627
model_params["with_conditioning"] = False
628+
controlnet_params["with_conditioning"] = False
624629
model = DiffusionModelUNet(**model_params)
625630
device = "cuda:0" if torch.cuda.is_available() else "cpu"
626631
model.to(device)
@@ -1023,8 +1028,10 @@ def test_prediction_shape_conditioned_concat(
10231028
if ae_model_type == "SPADEAutoencoderKL":
10241029
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
10251030
stage_2_params = stage_2_params.copy()
1031+
controlnet_params = controlnet_params.copy()
10261032
n_concat_channel = 3
10271033
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
1034+
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
10281035
if dm_model_type == "SPADEDiffusionModelUNet":
10291036
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
10301037
else:
@@ -1106,8 +1113,10 @@ def test_sample_shape_conditioned_concat(
11061113
if ae_model_type == "SPADEAutoencoderKL":
11071114
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
11081115
stage_2_params = stage_2_params.copy()
1116+
controlnet_params = controlnet_params.copy()
11091117
n_concat_channel = 3
11101118
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
1119+
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
11111120
if dm_model_type == "SPADEDiffusionModelUNet":
11121121
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
11131122
else:

Comments: 0 commit comments