Skip to content

Commit d9f9096

Browse files
Support block replace patches in auraflow.
1 parent 41886af commit d9f9096

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

comfy/ldm/aura/mmdit.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,8 @@ def apply_pos_embeds(self, x, h, w):
437437
pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
438438
return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
439439

440-
def forward(self, x, timestep, context, **kwargs):
440+
def forward(self, x, timestep, context, transformer_options={}, **kwargs):
441+
patches_replace = transformer_options.get("patches_replace", {})
441442
# patchify x, add PE
442443
b, c, h, w = x.shape
443444

@@ -458,15 +459,36 @@ def forward(self, x, timestep, context, **kwargs):
458459

459460
global_cond = self.t_embedder(t, x.dtype) # B, D
460461

462+
blocks_replace = patches_replace.get("dit", {})
461463
if len(self.double_layers) > 0:
462-
for layer in self.double_layers:
463-
c, x = layer(c, x, global_cond, **kwargs)
464+
for i, layer in enumerate(self.double_layers):
465+
if ("double_block", i) in blocks_replace:
466+
def block_wrap(args):
467+
out = {}
468+
out["txt"], out["img"] = layer(args["txt"],
469+
args["img"],
470+
args["vec"])
471+
return out
472+
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
473+
c = out["txt"]
474+
x = out["img"]
475+
else:
476+
c, x = layer(c, x, global_cond, **kwargs)
464477

465478
if len(self.single_layers) > 0:
466479
c_len = c.size(1)
467480
cx = torch.cat([c, x], dim=1)
468-
for layer in self.single_layers:
469-
cx = layer(cx, global_cond, **kwargs)
481+
for i, layer in enumerate(self.single_layers):
482+
if ("single_block", i) in blocks_replace:
483+
def block_wrap(args):
484+
out = {}
485+
out["img"] = layer(args["img"], args["vec"])
486+
return out
487+
488+
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
489+
cx = out["img"]
490+
else:
491+
cx = layer(cx, global_cond, **kwargs)
470492

471493
x = cx[:, c_len:]
472494

0 commit comments

Comments
 (0)