Skip to content

Commit 196954a

Browse files
authored
Add 'input_cond' and 'input_uncond' to the args dictionary passed into sampler_cfg_function (#10044)
1 parent 1e098d6 commit 196954a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

comfy/samplers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
360360
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
361361
if "sampler_cfg_function" in model_options:
362362
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
363-
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
363+
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options, "input_cond": cond, "input_uncond": uncond}
364364
cfg_result = x - model_options["sampler_cfg_function"](args)
365365
else:
366366
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
@@ -390,7 +390,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
390390
for fn in model_options.get("sampler_pre_cfg_function", []):
391391
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
392392
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
393-
out = fn(args)
393+
out = fn(args)
394394

395395
return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)
396396

0 commit comments

Comments
 (0)