@@ -52,7 +52,7 @@ def forward(self, x):
52
52
53
53
54
54
class ControlNetFlux (Flux ):
55
- def __init__ (self , latent_input = False , num_union_modes = 0 , mistoline = False , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
55
+ def __init__ (self , latent_input = False , num_union_modes = 0 , mistoline = False , control_latent_channels = None , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
56
56
super ().__init__ (final_layer = False , dtype = dtype , device = device , operations = operations , ** kwargs )
57
57
58
58
self .main_model_double = 19
@@ -80,7 +80,12 @@ def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image
80
80
81
81
self .gradient_checkpointing = False
82
82
self .latent_input = latent_input
83
- self .pos_embed_input = operations .Linear (self .in_channels , self .hidden_size , bias = True , dtype = dtype , device = device )
83
+ if control_latent_channels is None :
84
+ control_latent_channels = self .in_channels
85
+ else :
86
+ control_latent_channels *= 2 * 2 #patch size
87
+
88
+ self .pos_embed_input = operations .Linear (control_latent_channels , self .hidden_size , bias = True , dtype = dtype , device = device )
84
89
if not self .latent_input :
85
90
if self .mistoline :
86
91
self .input_cond_block = MistolineCondDownsamplBlock (dtype = dtype , device = device , operations = operations )
0 commit comments