@@ -437,7 +437,8 @@ def apply_pos_embeds(self, x, h, w):
437
437
pos_encoding = pos_encoding [:,from_h :from_h + h ,from_w :from_w + w ]
438
438
return x + pos_encoding .reshape (1 , - 1 , self .positional_encoding .shape [- 1 ])
439
439
440
- def forward (self , x , timestep , context , ** kwargs ):
440
+ def forward (self , x , timestep , context , transformer_options = {}, ** kwargs ):
441
+ patches_replace = transformer_options .get ("patches_replace" , {})
441
442
# patchify x, add PE
442
443
b , c , h , w = x .shape
443
444
@@ -458,15 +459,36 @@ def forward(self, x, timestep, context, **kwargs):
458
459
459
460
global_cond = self .t_embedder (t , x .dtype ) # B, D
460
461
462
+ blocks_replace = patches_replace .get ("dit" , {})
461
463
if len (self .double_layers ) > 0 :
462
- for layer in self .double_layers :
463
- c , x = layer (c , x , global_cond , ** kwargs )
464
+ for i , layer in enumerate (self .double_layers ):
465
+ if ("double_block" , i ) in blocks_replace :
466
+ def block_wrap (args ):
467
+ out = {}
468
+ out ["txt" ], out ["img" ] = layer (args ["txt" ],
469
+ args ["img" ],
470
+ args ["vec" ])
471
+ return out
472
+ out = blocks_replace [("double_block" , i )]({"img" : x , "txt" : c , "vec" : global_cond }, {"original_block" : block_wrap })
473
+ c = out ["txt" ]
474
+ x = out ["img" ]
475
+ else :
476
+ c , x = layer (c , x , global_cond , ** kwargs )
464
477
465
478
if len (self .single_layers ) > 0 :
466
479
c_len = c .size (1 )
467
480
cx = torch .cat ([c , x ], dim = 1 )
468
- for layer in self .single_layers :
469
- cx = layer (cx , global_cond , ** kwargs )
481
+ for i , layer in enumerate (self .single_layers ):
482
+ if ("single_block" , i ) in blocks_replace :
483
+ def block_wrap (args ):
484
+ out = {}
485
+ out ["img" ] = layer (args ["img" ], args ["vec" ])
486
+ return out
487
+
488
+ out = blocks_replace [("single_block" , i )]({"img" : cx , "vec" : global_cond }, {"original_block" : block_wrap })
489
+ cx = out ["img" ]
490
+ else :
491
+ cx = layer (cx , global_cond , ** kwargs )
470
492
471
493
x = cx [:, c_len :]
472
494
0 commit comments