Skip to content

Commit cf80d28

Browse files
Support loading controlnets with different input.
1 parent 6fb44c4 commit cf80d28

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

comfy/controlnet.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -449,7 +449,9 @@ def load_controlnet_flux_instantx(sd):
449449
if union_cnet in new_sd:
450450
num_union_modes = new_sd[union_cnet].shape[0]
451451

452-
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
452+
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
453+
454+
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
453455
control_model = controlnet_load_state_dict(control_model, new_sd)
454456

455457
latent_format = comfy.latent_formats.Flux()

comfy/ldm/flux/controlnet.py

+7-2
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ def forward(self, x):
5252

5353

5454
class ControlNetFlux(Flux):
55-
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
55+
def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
5656
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
5757

5858
self.main_model_double = 19
@@ -80,7 +80,12 @@ def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image
8080

8181
self.gradient_checkpointing = False
8282
self.latent_input = latent_input
83-
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
83+
if control_latent_channels is None:
84+
control_latent_channels = self.in_channels
85+
else:
86+
control_latent_channels *= 2 * 2 #patch size
87+
88+
self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
8489
if not self.latent_input:
8590
if self.mistoline:
8691
self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)

0 commit comments

Comments
 (0)