@@ -1334,13 +1334,15 @@ def __call__( # type: ignore[override]
13341334 raise NotImplementedError (f"{ mode } condition is not supported" )
13351335
13361336 noisy_image = self .scheduler .add_noise (original_samples = inputs , noise = noise , timesteps = timesteps )
1337- down_block_res_samples , mid_block_res_sample = controlnet (
1338- x = noisy_image , timesteps = timesteps , controlnet_cond = cn_cond
1339- )
1337+
13401338 if mode == "concat" and condition is not None :
13411339 noisy_image = torch .cat ([noisy_image , condition ], dim = 1 )
13421340 condition = None
13431341
1342+ down_block_res_samples , mid_block_res_sample = controlnet (
1343+ x = noisy_image , timesteps = timesteps , controlnet_cond = cn_cond , context = condition
1344+ )
1345+
13441346 diffuse = diffusion_model
13451347 if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
13461348 diffuse = partial (diffusion_model , seg = seg )
@@ -1396,17 +1398,21 @@ def sample( # type: ignore[override]
13961398 progress_bar = iter (scheduler .timesteps )
13971399 intermediates = []
13981400 for t in progress_bar :
1399- # 1. ControlNet forward
1400- down_block_res_samples , mid_block_res_sample = controlnet (
1401- x = image , timesteps = torch .Tensor ((t ,)).to (input_noise .device ), controlnet_cond = cn_cond
1402- )
1403- # 2. predict noise model_output
14041401 diffuse = diffusion_model
14051402 if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
14061403 diffuse = partial (diffusion_model , seg = seg )
14071404
14081405 if mode == "concat" and conditioning is not None :
1406+ # 1. Conditioning
14091407 model_input = torch .cat ([image , conditioning ], dim = 1 )
1408+ # 2. ControlNet forward
1409+ down_block_res_samples , mid_block_res_sample = controlnet (
1410+ x = model_input ,
1411+ timesteps = torch .Tensor ((t ,)).to (input_noise .device ),
1412+ controlnet_cond = cn_cond ,
1413+ context = None ,
1414+ )
1415+ # 3. predict noise model_output
14101416 model_output = diffuse (
14111417 model_input ,
14121418 timesteps = torch .Tensor ((t ,)).to (input_noise .device ),
@@ -1415,6 +1421,12 @@ def sample( # type: ignore[override]
14151421 mid_block_additional_residual = mid_block_res_sample ,
14161422 )
14171423 else :
1424+ down_block_res_samples , mid_block_res_sample = controlnet (
1425+ x = image ,
1426+ timesteps = torch .Tensor ((t ,)).to (input_noise .device ),
1427+ controlnet_cond = cn_cond ,
1428+ context = conditioning ,
1429+ )
14181430 model_output = diffuse (
14191431 image ,
14201432 timesteps = torch .Tensor ((t ,)).to (input_noise .device ),
@@ -1485,16 +1497,16 @@ def get_likelihood( # type: ignore[override]
14851497 for t in progress_bar :
14861498 timesteps = torch .full (inputs .shape [:1 ], t , device = inputs .device ).long ()
14871499 noisy_image = self .scheduler .add_noise (original_samples = inputs , noise = noise , timesteps = timesteps )
1488- down_block_res_samples , mid_block_res_sample = controlnet (
1489- x = noisy_image , timesteps = torch .Tensor ((t ,)).to (inputs .device ), controlnet_cond = cn_cond
1490- )
14911500
14921501 diffuse = diffusion_model
14931502 if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
14941503 diffuse = partial (diffusion_model , seg = seg )
14951504
14961505 if mode == "concat" and conditioning is not None :
14971506 noisy_image = torch .cat ([noisy_image , conditioning ], dim = 1 )
1507+ down_block_res_samples , mid_block_res_sample = controlnet (
1508+ x = noisy_image , timesteps = torch .Tensor ((t ,)).to (inputs .device ), controlnet_cond = cn_cond , context = None
1509+ )
14981510 model_output = diffuse (
14991511 noisy_image ,
15001512 timesteps = timesteps ,
@@ -1503,6 +1515,12 @@ def get_likelihood( # type: ignore[override]
15031515 mid_block_additional_residual = mid_block_res_sample ,
15041516 )
15051517 else :
1518+ down_block_res_samples , mid_block_res_sample = controlnet (
1519+ x = noisy_image ,
1520+ timesteps = torch .Tensor ((t ,)).to (inputs .device ),
1521+ controlnet_cond = cn_cond ,
1522+ context = conditioning ,
1523+ )
15061524 model_output = diffuse (
15071525 x = noisy_image ,
15081526 timesteps = timesteps ,
0 commit comments