@@ -360,7 +360,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
360
360
def cfg_function (model , cond_pred , uncond_pred , cond_scale , x , timestep , model_options = {}, cond = None , uncond = None ):
361
361
if "sampler_cfg_function" in model_options :
362
362
args = {"cond" : x - cond_pred , "uncond" : x - uncond_pred , "cond_scale" : cond_scale , "timestep" : timestep , "input" : x , "sigma" : timestep ,
363
- "cond_denoised" : cond_pred , "uncond_denoised" : uncond_pred , "model" : model , "model_options" : model_options }
363
+ "cond_denoised" : cond_pred , "uncond_denoised" : uncond_pred , "model" : model , "model_options" : model_options , "input_cond" : cond , "input_uncond" : uncond }
364
364
cfg_result = x - model_options ["sampler_cfg_function" ](args )
365
365
else :
366
366
cfg_result = uncond_pred + (cond_pred - uncond_pred ) * cond_scale
@@ -390,7 +390,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
390
390
for fn in model_options .get ("sampler_pre_cfg_function" , []):
391
391
args = {"conds" :conds , "conds_out" : out , "cond_scale" : cond_scale , "timestep" : timestep ,
392
392
"input" : x , "sigma" : timestep , "model" : model , "model_options" : model_options }
393
- out = fn (args )
393
+ out = fn (args )
394
394
395
395
return cfg_function (model , out [0 ], out [1 ], cond_scale , x , timestep , model_options = model_options , cond = cond , uncond = uncond_ )
396
396
0 commit comments