@@ -96,7 +96,9 @@ def forward_orig(
96
96
y : Tensor ,
97
97
guidance : Tensor = None ,
98
98
control = None ,
99
+ transformer_options = {},
99
100
) -> Tensor :
101
+ patches_replace = transformer_options .get ("patches_replace" , {})
100
102
if img .ndim != 3 or txt .ndim != 3 :
101
103
raise ValueError ("Input img and txt tensors must have 3 dimensions." )
102
104
@@ -114,8 +116,19 @@ def forward_orig(
114
116
ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
115
117
pe = self .pe_embedder (ids )
116
118
119
+ blocks_replace = patches_replace .get ("dit" , {})
117
120
for i , block in enumerate (self .double_blocks ):
118
- img , txt = block (img = img , txt = txt , vec = vec , pe = pe )
121
+ if ("double_block" , i ) in blocks_replace :
122
+ def block_wrap (args ):
123
+ out = {}
124
+ out ["img" ], out ["txt" ] = block (img = args ["img" ], txt = args ["txt" ], vec = args ["vec" ], pe = args ["pe" ])
125
+ return out
126
+
127
+ out = blocks_replace [("double_block" , i )]({"img" : img , "txt" : txt , "vec" : vec , "pe" : pe }, {"original_block" : block_wrap })
128
+ txt = out ["txt" ]
129
+ img = out ["img" ]
130
+ else :
131
+ img , txt = block (img = img , txt = txt , vec = vec , pe = pe )
119
132
120
133
if control is not None : # Controlnet
121
134
control_i = control .get ("input" )
@@ -127,7 +140,16 @@ def forward_orig(
127
140
img = torch .cat ((txt , img ), 1 )
128
141
129
142
for i , block in enumerate (self .single_blocks ):
130
- img = block (img , vec = vec , pe = pe )
143
+ if ("single_block" , i ) in blocks_replace :
144
+ def block_wrap (args ):
145
+ out = {}
146
+ out ["img" ] = block (args ["img" ], vec = args ["vec" ], pe = args ["pe" ])
147
+ return out
148
+
149
+ out = blocks_replace [("single_block" , i )]({"img" : img , "vec" : vec , "pe" : pe }, {"original_block" : block_wrap })
150
+ img = out ["img" ]
151
+ else :
152
+ img = block (img , vec = vec , pe = pe )
131
153
132
154
if control is not None : # Controlnet
133
155
control_o = control .get ("output" )
@@ -141,7 +163,7 @@ def forward_orig(
141
163
img = self .final_layer (img , vec ) # (N, T, patch_size ** 2 * out_channels)
142
164
return img
143
165
144
- def forward (self , x , timestep , context , y , guidance , control = None , ** kwargs ):
166
+ def forward (self , x , timestep , context , y , guidance , control = None , transformer_options = {}, ** kwargs ):
145
167
bs , c , h , w = x .shape
146
168
patch_size = 2
147
169
x = comfy .ldm .common_dit .pad_to_patch_size (x , (patch_size , patch_size ))
@@ -156,5 +178,5 @@ def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
156
178
img_ids = repeat (img_ids , "h w c -> b (h w) c" , b = bs )
157
179
158
180
txt_ids = torch .zeros ((bs , context .shape [1 ], 3 ), device = x .device , dtype = x .dtype )
159
- out = self .forward_orig (img , img_ids , context , txt_ids , timestep , y , guidance , control )
181
+ out = self .forward_orig (img , img_ids , context , txt_ids , timestep , y , guidance , control , transformer_options )
160
182
return rearrange (out , "b (h w) (c ph pw) -> b c (h ph) (w pw)" , h = h_len , w = w_len , ph = 2 , pw = 2 )[:,:,:h ,:w ]
0 commit comments