Skip to content

Commit cf9fb59

Browse files
virginiafdez (Virginia Fernandez), ericspod, and KumoLiu
authored and committed
Modify ControlNet inferer so that it takes in context when the diffus… (Project-MONAI#8360)
Fixes Project-MONAI#8344

### Description

The ControlNet inferers (both latent and non-latent) currently behave in such a way that, when conditioning is used, the ControlNet itself does not receive the conditioning. In theory, it should exhibit the same behaviour as the diffusion model. I've changed this behaviour, which involved modifying the `__call__`, `sample` and `get_likelihood` methods of ControlNetDiffusionInferer and ControlNetLatentDiffusionInferer. I've also modified the tests to take this into account.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes (modified, rather than new)
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.

Signed-off-by: Virginia Fernandez <[email protected]>
Co-authored-by: Virginia Fernandez <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Signed-off-by: Can-Zhao <[email protected]>
1 parent a9a7082 commit cf9fb59

File tree

2 files changed: +38 −11 lines changed

monai/inferers/inferer.py

Lines changed: 29 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1345,13 +1345,15 @@ def __call__( # type: ignore[override]
13451345
raise NotImplementedError(f"{mode} condition is not supported")
13461346

13471347
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1348-
down_block_res_samples, mid_block_res_sample = controlnet(
1349-
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
1350-
)
1348+
13511349
if mode == "concat" and condition is not None:
13521350
noisy_image = torch.cat([noisy_image, condition], dim=1)
13531351
condition = None
13541352

1353+
down_block_res_samples, mid_block_res_sample = controlnet(
1354+
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
1355+
)
1356+
13551357
diffuse = diffusion_model
13561358
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
13571359
diffuse = partial(diffusion_model, seg=seg)
@@ -1407,17 +1409,21 @@ def sample( # type: ignore[override]
14071409
progress_bar = iter(scheduler.timesteps)
14081410
intermediates = []
14091411
for t in progress_bar:
1410-
# 1. ControlNet forward
1411-
down_block_res_samples, mid_block_res_sample = controlnet(
1412-
x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
1413-
)
1414-
# 2. predict noise model_output
14151412
diffuse = diffusion_model
14161413
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
14171414
diffuse = partial(diffusion_model, seg=seg)
14181415

14191416
if mode == "concat" and conditioning is not None:
1417+
# 1. Conditioning
14201418
model_input = torch.cat([image, conditioning], dim=1)
1419+
# 2. ControlNet forward
1420+
down_block_res_samples, mid_block_res_sample = controlnet(
1421+
x=model_input,
1422+
timesteps=torch.Tensor((t,)).to(input_noise.device),
1423+
controlnet_cond=cn_cond,
1424+
context=None,
1425+
)
1426+
# 3. predict noise model_output
14211427
model_output = diffuse(
14221428
model_input,
14231429
timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1426,6 +1432,12 @@ def sample( # type: ignore[override]
14261432
mid_block_additional_residual=mid_block_res_sample,
14271433
)
14281434
else:
1435+
down_block_res_samples, mid_block_res_sample = controlnet(
1436+
x=image,
1437+
timesteps=torch.Tensor((t,)).to(input_noise.device),
1438+
controlnet_cond=cn_cond,
1439+
context=conditioning,
1440+
)
14291441
model_output = diffuse(
14301442
image,
14311443
timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1496,16 +1508,16 @@ def get_likelihood( # type: ignore[override]
14961508
for t in progress_bar:
14971509
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
14981510
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
1499-
down_block_res_samples, mid_block_res_sample = controlnet(
1500-
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
1501-
)
15021511

15031512
diffuse = diffusion_model
15041513
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
15051514
diffuse = partial(diffusion_model, seg=seg)
15061515

15071516
if mode == "concat" and conditioning is not None:
15081517
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
1518+
down_block_res_samples, mid_block_res_sample = controlnet(
1519+
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
1520+
)
15091521
model_output = diffuse(
15101522
noisy_image,
15111523
timesteps=timesteps,
@@ -1514,6 +1526,12 @@ def get_likelihood( # type: ignore[override]
15141526
mid_block_additional_residual=mid_block_res_sample,
15151527
)
15161528
else:
1529+
down_block_res_samples, mid_block_res_sample = controlnet(
1530+
x=noisy_image,
1531+
timesteps=torch.Tensor((t,)).to(inputs.device),
1532+
controlnet_cond=cn_cond,
1533+
context=conditioning,
1534+
)
15171535
model_output = diffuse(
15181536
x=noisy_image,
15191537
timesteps=timesteps,

tests/inferers/test_controlnet_inferers.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -550,6 +550,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
550550
def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
551551
model_params["with_conditioning"] = True
552552
model_params["cross_attention_dim"] = 3
553+
controlnet_params["with_conditioning"] = True
554+
controlnet_params["cross_attention_dim"] = 3
553555
model = DiffusionModelUNet(**model_params)
554556
controlnet = ControlNet(**controlnet_params)
555557
device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -619,8 +621,11 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input
619621
model_params = model_params.copy()
620622
n_concat_channel = 2
621623
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
624+
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
622625
model_params["cross_attention_dim"] = None
626+
controlnet_params["cross_attention_dim"] = None
623627
model_params["with_conditioning"] = False
628+
controlnet_params["with_conditioning"] = False
624629
model = DiffusionModelUNet(**model_params)
625630
device = "cuda:0" if torch.cuda.is_available() else "cpu"
626631
model.to(device)
@@ -1023,8 +1028,10 @@ def test_prediction_shape_conditioned_concat(
10231028
if ae_model_type == "SPADEAutoencoderKL":
10241029
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
10251030
stage_2_params = stage_2_params.copy()
1031+
controlnet_params = controlnet_params.copy()
10261032
n_concat_channel = 3
10271033
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
1034+
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
10281035
if dm_model_type == "SPADEDiffusionModelUNet":
10291036
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
10301037
else:
@@ -1106,8 +1113,10 @@ def test_sample_shape_conditioned_concat(
11061113
if ae_model_type == "SPADEAutoencoderKL":
11071114
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
11081115
stage_2_params = stage_2_params.copy()
1116+
controlnet_params = controlnet_params.copy()
11091117
n_concat_channel = 3
11101118
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
1119+
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
11111120
if dm_model_type == "SPADEDiffusionModelUNet":
11121121
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
11131122
else:

0 commit comments

Comments
 (0)