
Commit ab07523

Authored by virginiafdez (Virginia Fernandez), with ericspod and KumoLiu
Modify ControlNet inferer so that it takes in context when the diffus… (#8360)
Fixes #8344

### Description

The ControlNet inferers (latent and non-latent) currently work in such a way that, when conditioning is used, the ControlNet does not receive the conditioning, even though the diffusion model does. It should, in theory, exhibit the same behaviour as the diffusion model. This change fixes that behaviour by modifying the `call`, `sample`, and `get_likelihood` methods of `ControlNetDiffusionInferer` and `ControlNetLatentDiffusionInferer`, and updating the tests to take this into account.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes (modified, rather than new).
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.

Signed-off-by: Virginia Fernandez <[email protected]>
Co-authored-by: Virginia Fernandez <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
1 parent a790590 · commit ab07523
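In practical terms, the fix means the ControlNet and the diffusion UNet must now agree on their conditioning configuration, since both receive the same context. Below is a minimal, illustrative sketch of the fixed training-style `__call__` with cross-attention conditioning; the shapes and network hyperparameters are invented for the example, and keyword names follow the signatures visible in the diff below (adjust for your MONAI version):

```python
import torch
from monai.inferers import ControlNetDiffusionInferer
from monai.networks.nets import ControlNet, DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# illustrative hyperparameters, shared so both networks agree on conditioning
common = dict(
    spatial_dims=2,
    in_channels=1,
    channels=(8, 8),
    attention_levels=(False, True),
    num_res_blocks=1,
    num_head_channels=8,
    norm_num_groups=8,
    with_conditioning=True,   # the ControlNet now consumes the context too
    cross_attention_dim=3,
)
model = DiffusionModelUNet(out_channels=1, **common).to(device)
controlnet = ControlNet(conditioning_embedding_num_channels=(8,), **common).to(device)

inferer = ControlNetDiffusionInferer(scheduler=DDPMScheduler(num_train_timesteps=10))

inputs = torch.randn(1, 1, 16, 16, device=device)      # image batch
noise = torch.randn_like(inputs)
cn_cond = torch.randn(1, 1, 16, 16, device=device)     # ControlNet conditioning image
condition = torch.randn(1, 1, 3, device=device)        # cross-attention context
timesteps = torch.randint(0, 10, (1,), device=device).long()

# before this commit the ControlNet forward ignored `condition`;
# now it is passed through as `context`, as in DiffusionInferer
prediction = inferer(
    inputs=inputs,
    diffusion_model=model,
    controlnet=controlnet,
    noise=noise,
    timesteps=timesteps,
    cn_cond=cn_cond,
    condition=condition,
    mode="crossattn",
)
print(prediction.shape)  # same shape as inputs
```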

File tree

monai/inferers/inferer.py
tests/inferers/test_controlnet_inferers.py

2 files changed: +38 −11 lines changed


monai/inferers/inferer.py

Lines changed: 29 additions & 11 deletions
```diff
@@ -1334,13 +1334,15 @@ def __call__( # type: ignore[override]
             raise NotImplementedError(f"{mode} condition is not supported")
 
         noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-        down_block_res_samples, mid_block_res_sample = controlnet(
-            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
-        )
+
         if mode == "concat" and condition is not None:
             noisy_image = torch.cat([noisy_image, condition], dim=1)
             condition = None
 
+        down_block_res_samples, mid_block_res_sample = controlnet(
+            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
+        )
+
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
             diffuse = partial(diffusion_model, seg=seg)
```
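Note the ordering this hunk establishes: in `concat` mode the condition is folded into the image channels and then cleared before the ControlNet call, so `context=condition` resolves to `None`; in `crossattn` mode the condition survives and is passed through as context. A standalone sketch of that decision (the helper name is ours, not MONAI's):

```python
import torch

def prepare_controlnet_inputs(noisy_image, condition, mode):
    """Mimics the call order in the hunk above: concatenate first, then call the ControlNet."""
    if mode == "concat" and condition is not None:
        # concat mode: fold the condition into the channel dimension...
        noisy_image = torch.cat([noisy_image, condition], dim=1)
        condition = None  # ...and clear it, so no cross-attention context is passed
    # crossattn mode: the image is untouched and the condition becomes the context
    return noisy_image, condition

img, ctx = prepare_controlnet_inputs(torch.randn(2, 1, 16, 16), torch.randn(2, 2, 16, 16), mode="concat")
print(img.shape, ctx)  # torch.Size([2, 3, 16, 16]) None
```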
```diff
@@ -1396,17 +1398,21 @@ def sample( # type: ignore[override]
         progress_bar = iter(scheduler.timesteps)
         intermediates = []
         for t in progress_bar:
-            # 1. ControlNet forward
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
-            )
-            # 2. predict noise model_output
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
 
             if mode == "concat" and conditioning is not None:
+                # 1. Conditioning
                 model_input = torch.cat([image, conditioning], dim=1)
+                # 2. ControlNet forward
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=model_input,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=None,
+                )
+                # 3. predict noise model_output
                 model_output = diffuse(
                     model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
```
```diff
@@ -1415,6 +1421,12 @@
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=image,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     image,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
```
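The same rule now applies at every denoising step in `sample`: in `concat` mode the ControlNet sees the concatenated `model_input` with `context=None`, while in `crossattn` mode it sees the raw image and receives the conditioning as context. A hedged sketch of sampling in `concat` mode, mirroring the updated tests; hyperparameters and shapes are again illustrative:

```python
import torch
from monai.inferers import ControlNetDiffusionInferer
from monai.networks.nets import ControlNet, DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler

device = "cuda:0" if torch.cuda.is_available() else "cpu"
n_concat_channel = 2

# in concat mode BOTH networks need the widened in_channels (see the test diff below)
common = dict(
    spatial_dims=2,
    in_channels=1 + n_concat_channel,
    channels=(8, 8),
    attention_levels=(False, True),
    num_res_blocks=1,
    num_head_channels=8,
    norm_num_groups=8,
    with_conditioning=False,
    cross_attention_dim=None,
)
model = DiffusionModelUNet(out_channels=1, **common).to(device)
controlnet = ControlNet(conditioning_embedding_num_channels=(8,), **common).to(device)

scheduler = DDPMScheduler(num_train_timesteps=10)
scheduler.set_timesteps(num_inference_steps=10)
inferer = ControlNetDiffusionInferer(scheduler=scheduler)

noise = torch.randn(1, 1, 16, 16, device=device)
cn_cond = torch.randn(1, 1, 16, 16, device=device)
conditioning = torch.randn(1, n_concat_channel, 16, 16, device=device)

sample = inferer.sample(
    input_noise=noise,
    diffusion_model=model,
    controlnet=controlnet,
    cn_cond=cn_cond,
    conditioning=conditioning,
    mode="concat",
    verbose=False,
)
print(sample.shape)  # torch.Size([1, 1, 16, 16])
```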
```diff
@@ -1485,16 +1497,16 @@ def get_likelihood( # type: ignore[override]
         for t in progress_bar:
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
-            )
 
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
 
             if mode == "concat" and conditioning is not None:
                 noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
+                )
                 model_output = diffuse(
                     noisy_image,
                     timesteps=timesteps,
```
```diff
@@ -1503,6 +1515,12 @@
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image,
+                    timesteps=torch.Tensor((t,)).to(inputs.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     x=noisy_image,
                     timesteps=timesteps,
```
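`get_likelihood` now follows the same branching. Continuing the sampling sketch above (reusing `inferer`, `model`, `controlnet`, `cn_cond`, and `conditioning`; keyword names as in the diff, adjust for your MONAI version):

```python
# in concat mode the ControlNet again sees the concatenated noisy image with
# context=None; in crossattn mode it would receive the conditioning as context
likelihood = inferer.get_likelihood(
    inputs=torch.randn(1, 1, 16, 16, device=device),
    diffusion_model=model,
    controlnet=controlnet,
    cn_cond=cn_cond,
    conditioning=conditioning,
    mode="concat",
    verbose=False,
)
print(likelihood.shape)  # one log-likelihood value per batch element
```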

tests/inferers/test_controlnet_inferers.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -550,6 +550,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
     def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
         model_params["with_conditioning"] = True
         model_params["cross_attention_dim"] = 3
+        controlnet_params["with_conditioning"] = True
+        controlnet_params["cross_attention_dim"] = 3
         model = DiffusionModelUNet(**model_params)
         controlnet = ControlNet(**controlnet_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
```
```diff
@@ -619,8 +621,11 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input
         model_params = model_params.copy()
         n_concat_channel = 2
         model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         model_params["cross_attention_dim"] = None
+        controlnet_params["cross_attention_dim"] = None
         model_params["with_conditioning"] = False
+        controlnet_params["with_conditioning"] = False
         model = DiffusionModelUNet(**model_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
         model.to(device)
```
```diff
@@ -1023,8 +1028,10 @@ def test_prediction_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
```
```diff
@@ -1106,8 +1113,10 @@ def test_sample_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
```
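All four test changes follow a single pattern: every conditioning-related parameter set on the diffusion UNet is now mirrored onto the ControlNet, because both networks consume the same context. A small illustrative helper (ours, not part of the test suite) that captures the rule:

```python
def mirror_conditioning_params(model_params: dict, controlnet_params: dict) -> dict:
    """Copy conditioning-related settings from the UNet params onto the ControlNet params."""
    mirrored = controlnet_params.copy()
    for key in ("with_conditioning", "cross_attention_dim", "in_channels"):
        if key in model_params:
            mirrored[key] = model_params[key]
    return mirrored
```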
