diff --git a/diffsynth_engine/conf/models/wan/dit/wan2.1-flf2v-14b.json b/diffsynth_engine/conf/models/wan/dit/wan2.1_flf2v_14b.json similarity index 100% rename from diffsynth_engine/conf/models/wan/dit/wan2.1-flf2v-14b.json rename to diffsynth_engine/conf/models/wan/dit/wan2.1_flf2v_14b.json diff --git a/diffsynth_engine/conf/models/wan/dit/wan2.1-i2v-14b.json b/diffsynth_engine/conf/models/wan/dit/wan2.1_i2v_14b.json similarity index 100% rename from diffsynth_engine/conf/models/wan/dit/wan2.1-i2v-14b.json rename to diffsynth_engine/conf/models/wan/dit/wan2.1_i2v_14b.json diff --git a/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-1.3b.json b/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_1.3b.json similarity index 100% rename from diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-1.3b.json rename to diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_1.3b.json diff --git a/diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-14b.json b/diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_14b.json similarity index 100% rename from diffsynth_engine/conf/models/wan/dit/wan2.1-t2v-14b.json rename to diffsynth_engine/conf/models/wan/dit/wan2.1_t2v_14b.json diff --git a/diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json b/diffsynth_engine/conf/models/wan/dit/wan2.2_i2v_a14b.json similarity index 100% rename from diffsynth_engine/conf/models/wan/dit/wan2.2-i2v-a14b.json rename to diffsynth_engine/conf/models/wan/dit/wan2.2_i2v_a14b.json diff --git a/diffsynth_engine/conf/models/wan/dit/wan2.2-s2v-14b.json b/diffsynth_engine/conf/models/wan/dit/wan2.2_s2v_14b.json similarity index 100% rename from diffsynth_engine/conf/models/wan/dit/wan2.2-s2v-14b.json rename to diffsynth_engine/conf/models/wan/dit/wan2.2_s2v_14b.json diff --git a/diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json b/diffsynth_engine/conf/models/wan/dit/wan2.2_t2v_a14b.json similarity index 100% rename from diffsynth_engine/conf/models/wan/dit/wan2.2-t2v-a14b.json rename to diffsynth_engine/conf/models/wan/dit/wan2.2_t2v_a14b.json diff --git a/diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json b/diffsynth_engine/conf/models/wan/dit/wan2.2_ti2v_5b.json similarity index 100% rename from diffsynth_engine/conf/models/wan/dit/wan2.2-ti2v-5b.json rename to diffsynth_engine/conf/models/wan/dit/wan2.2_ti2v_5b.json diff --git a/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json b/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json new file mode 100644 index 00000000..2e45127e --- /dev/null +++ b/diffsynth_engine/conf/models/wan/dit/wan_dit_keymap.json @@ -0,0 +1,41 @@ +{ + "diffusers": { + "global_rename_dict": { + "patch_embedding": "patch_embedding", + "condition_embedder.text_embedder.linear_1": "text_embedding.0", + "condition_embedder.text_embedder.linear_2": "text_embedding.2", + "condition_embedder.time_embedder.linear_1": "time_embedding.0", + "condition_embedder.time_embedder.linear_2": "time_embedding.2", + "condition_embedder.time_proj": "time_projection.1", + "condition_embedder.image_embedder.norm1": "img_emb.proj.0", + "condition_embedder.image_embedder.ff.net.0.proj": "img_emb.proj.1", + "condition_embedder.image_embedder.ff.net.2": "img_emb.proj.3", + "condition_embedder.image_embedder.norm2": "img_emb.proj.4", + "condition_embedder.image_embedder.pos_embed": "img_emb.emb_pos", + "proj_out": "head.head", + "scale_shift_table": "head.modulation" + }, + "rename_dict": { + "attn1.to_q": "self_attn.q", + "attn1.to_k": "self_attn.k", + "attn1.to_v": "self_attn.v", + "attn1.to_out.0": "self_attn.o", + 
"attn1.norm_q": "self_attn.norm_q", + "attn1.norm_k": "self_attn.norm_k", + "to_gate_compress": "self_attn.gate_compress", + "attn2.to_q": "cross_attn.q", + "attn2.to_k": "cross_attn.k", + "attn2.to_v": "cross_attn.v", + "attn2.to_out.0": "cross_attn.o", + "attn2.norm_q": "cross_attn.norm_q", + "attn2.norm_k": "cross_attn.norm_k", + "attn2.add_k_proj": "cross_attn.k_img", + "attn2.add_v_proj": "cross_attn.v_img", + "attn2.norm_added_k": "cross_attn.norm_k_img", + "norm2": "norm3", + "ffn.net.0.proj": "ffn.0", + "ffn.net.2": "ffn.2", + "scale_shift_table": "modulation" + } + } +} \ No newline at end of file diff --git a/diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json b/diffsynth_engine/conf/models/wan/vae/wan2.1_vae.json similarity index 100% rename from diffsynth_engine/conf/models/wan/vae/wan2.1-vae.json rename to diffsynth_engine/conf/models/wan/vae/wan2.1_vae.json diff --git a/diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json b/diffsynth_engine/conf/models/wan/vae/wan2.2_vae.json similarity index 100% rename from diffsynth_engine/conf/models/wan/vae/wan2.2-vae.json rename to diffsynth_engine/conf/models/wan/vae/wan2.2_vae.json diff --git a/diffsynth_engine/conf/models/wan/vae/wan-vae-keymap.json b/diffsynth_engine/conf/models/wan/vae/wan_vae_keymap.json similarity index 100% rename from diffsynth_engine/conf/models/wan/vae/wan-vae-keymap.json rename to diffsynth_engine/conf/models/wan/vae/wan_vae_keymap.json diff --git a/diffsynth_engine/configs/pipeline.py b/diffsynth_engine/configs/pipeline.py index d36d3aa9..131883de 100644 --- a/diffsynth_engine/configs/pipeline.py +++ b/diffsynth_engine/configs/pipeline.py @@ -5,6 +5,7 @@ from typing import List, Dict, Tuple, Optional from diffsynth_engine.configs.controlnet import ControlType +from diffsynth_engine.models.basic.video_sparse_attention import get_vsa_kwargs @dataclass @@ -30,16 +31,43 @@ class AttnImpl(Enum): SDPA = "sdpa" # Scaled Dot Product Attention SAGE = "sage" # Sage Attention SPARGE = "sparge" # Sparge Attention + VSA = "vsa" # Video Sparse Attention + + +@dataclass +class SpargeAttentionParams: + smooth_k: bool = True + cdfthreshd: float = 0.6 + simthreshd1: float = 0.98 + pvthreshd: float = 50.0 + + +@dataclass +class VideoSparseAttentionParams: + sparsity: float = 0.9 @dataclass class AttentionConfig: dit_attn_impl: AttnImpl = AttnImpl.AUTO - # Sparge Attention - sparge_smooth_k: bool = True - sparge_cdfthreshd: float = 0.6 - sparge_simthreshd1: float = 0.98 - sparge_pvthreshd: float = 50.0 + attn_params: Optional[SpargeAttentionParams | VideoSparseAttentionParams] = None + + def get_attn_kwargs(self, latents: torch.Tensor, device: str) -> Dict: + attn_kwargs = {"attn_impl": self.dit_attn_impl.value} + if isinstance(self.attn_params, SpargeAttentionParams): + assert self.dit_attn_impl == AttnImpl.SPARGE + attn_kwargs.update( + { + "smooth_k": self.attn_params.smooth_k, + "simthreshd1": self.attn_params.simthreshd1, + "cdfthreshd": self.attn_params.cdfthreshd, + "pvthreshd": self.attn_params.pvthreshd, + } + ) + elif isinstance(self.attn_params, VideoSparseAttentionParams): + assert self.dit_attn_impl == AttnImpl.VSA + attn_kwargs.update(get_vsa_kwargs(latents.shape[2:], (1, 2, 2), self.attn_params.sparsity, device=device)) + return attn_kwargs @dataclass diff --git a/diffsynth_engine/models/basic/attention.py b/diffsynth_engine/models/basic/attention.py index 6b949784..9954672e 100644 --- a/diffsynth_engine/models/basic/attention.py +++ b/diffsynth_engine/models/basic/attention.py @@ -12,6 +12,7 @@ SDPA_AVAILABLE, 
SAGE_ATTN_AVAILABLE, SPARGE_ATTN_AVAILABLE, + VIDEO_SPARSE_ATTN_AVAILABLE, ) from diffsynth_engine.utils.platform import DTYPE_FP8 @@ -20,12 +21,6 @@ logger = logging.get_logger(__name__) -def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8): - padding_size = (alignment - x.shape[dim] % alignment) % alignment - padded_x = F.pad(x, (0, padding_size), "constant", 0) - return padded_x[..., : x.shape[dim]] - - if FLASH_ATTN_3_AVAILABLE: from flash_attn_interface import flash_attn_func as flash_attn3 if FLASH_ATTN_2_AVAILABLE: @@ -33,6 +28,11 @@ def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8): if XFORMERS_AVAILABLE: from xformers.ops import memory_efficient_attention + def memory_align(x: torch.Tensor, dim=-1, alignment: int = 8): + padding_size = (alignment - x.shape[dim] % alignment) % alignment + padded_x = F.pad(x, (0, padding_size), "constant", 0) + return padded_x[..., : x.shape[dim]] + def xformers_attn(q, k, v, attn_mask=None, scale=None): if attn_mask is not None: if attn_mask.ndim == 2: @@ -94,6 +94,13 @@ def sparge_attn( return out.transpose(1, 2) +if VIDEO_SPARSE_ATTN_AVAILABLE: + from diffsynth_engine.models.basic.video_sparse_attention import ( + video_sparse_attn, + distributed_video_sparse_attn, + ) + + def eager_attn(q, k, v, attn_mask=None, scale=None): q = q.transpose(1, 2) k = k.transpose(1, 2) @@ -109,9 +116,10 @@ def eager_attn(q, k, v, attn_mask=None, scale=None): def attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, attn_impl: Optional[str] = "auto", attn_mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, @@ -133,6 +141,7 @@ def attention( "sdpa", "sage", "sparge", + "vsa", ] flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM if attn_impl is None or attn_impl == "auto": @@ -189,10 +198,24 @@ def attention( v, attn_mask=attn_mask, scale=scale, - smooth_k=kwargs.get("sparge_smooth_k", True), - simthreshd1=kwargs.get("sparge_simthreshd1", 0.6), - cdfthreshd=kwargs.get("sparge_cdfthreshd", 0.98), - pvthreshd=kwargs.get("sparge_pvthreshd", 50), + smooth_k=kwargs.get("smooth_k", True), + simthreshd1=kwargs.get("simthreshd1", 0.6), + cdfthreshd=kwargs.get("cdfthreshd", 0.98), + pvthreshd=kwargs.get("pvthreshd", 50), + ) + if attn_impl == "vsa": + return video_sparse_attn( + q, + k, + v, + g, + sparsity=kwargs.get("sparsity"), + num_tiles=kwargs.get("num_tiles"), + total_seq_length=kwargs.get("total_seq_length"), + tile_partition_indices=kwargs.get("tile_partition_indices"), + reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"), + variable_block_sizes=kwargs.get("variable_block_sizes"), + non_pad_index=kwargs.get("non_pad_index"), ) raise ValueError(f"Invalid attention implementation: {attn_impl}") @@ -242,9 +265,10 @@ def forward( def long_context_attention( - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, attn_impl: Optional[str] = None, attn_mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, @@ -267,6 +291,7 @@ def long_context_attention( "sdpa", "sage", "sparge", + "vsa", ] assert attn_mask is None, "long context attention does not support attention mask" flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM @@ -307,11 +332,25 @@ def long_context_attention( if attn_impl == "sparge": attn_processor = SparseAttentionMeansim() # default args from spas_sage2_attn_meansim_cuda - attn_processor.smooth_k = torch.tensor(kwargs.get("sparge_smooth_k", True)) - 
attn_processor.simthreshd1 = torch.tensor(kwargs.get("sparge_simthreshd1", 0.6)) - attn_processor.cdfthreshd = torch.tensor(kwargs.get("sparge_cdfthreshd", 0.98)) - attn_processor.pvthreshd = torch.tensor(kwargs.get("sparge_pvthreshd", 50)) + attn_processor.smooth_k = torch.tensor(kwargs.get("smooth_k", True)) + attn_processor.simthreshd1 = torch.tensor(kwargs.get("simthreshd1", 0.6)) + attn_processor.cdfthreshd = torch.tensor(kwargs.get("cdfthreshd", 0.98)) + attn_processor.pvthreshd = torch.tensor(kwargs.get("pvthreshd", 50)) return LongContextAttention(attn_type=AttnType.SPARSE_SAGE, attn_processor=attn_processor)( q, k, v, softmax_scale=scale ) + if attn_impl == "vsa": + return distributed_video_sparse_attn( + q, + k, + v, + g, + sparsity=kwargs.get("sparsity"), + num_tiles=kwargs.get("num_tiles"), + total_seq_length=kwargs.get("total_seq_length"), + tile_partition_indices=kwargs.get("tile_partition_indices"), + reverse_tile_partition_indices=kwargs.get("reverse_tile_partition_indices"), + variable_block_sizes=kwargs.get("variable_block_sizes"), + non_pad_index=kwargs.get("non_pad_index"), + ) raise ValueError(f"Invalid long context attention implementation: {attn_impl}") diff --git a/diffsynth_engine/models/basic/video_sparse_attention.py b/diffsynth_engine/models/basic/video_sparse_attention.py new file mode 100644 index 00000000..2665ce05 --- /dev/null +++ b/diffsynth_engine/models/basic/video_sparse_attention.py @@ -0,0 +1,235 @@ +import torch +import math +import functools + +from vsa import video_sparse_attn as vsa_core +from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size + +VSA_TILE_SIZE = (4, 4, 4) + + +@functools.lru_cache(maxsize=10) +def get_tile_partition_indices( + dit_seq_shape: tuple[int, int, int], + tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + T, H, W = dit_seq_shape + ts, hs, ws = tile_size + indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W) + ls = [] + for t in range(math.ceil(T / ts)): + for h in range(math.ceil(H / hs)): + for w in range(math.ceil(W / ws)): + ls.append( + indices[ + t * ts : min(t * ts + ts, T), h * hs : min(h * hs + hs, H), w * ws : min(w * ws + ws, W) + ].flatten() + ) + index = torch.cat(ls, dim=0) + return index + + +@functools.lru_cache(maxsize=10) +def get_reverse_tile_partition_indices( + dit_seq_shape: tuple[int, int, int], + tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device)) + + +@functools.lru_cache(maxsize=10) +def construct_variable_block_sizes( + dit_seq_shape: tuple[int, int, int], + num_tiles: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + """ + Compute the number of valid (non-padded) tokens inside every + (ts_t x ts_h x ts_w) tile after padding -- flattened in the order + (t-tile, h-tile, w-tile) that `rearrange` uses. 
+ + Returns + ------- + torch.LongTensor # shape: [∏ full_window_size] + """ + # unpack + t, h, w = dit_seq_shape + ts_t, ts_h, ts_w = VSA_TILE_SIZE + n_t, n_h, n_w = num_tiles + + def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor: + """Vector with the size of each tile along one dimension.""" + sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device) + # size of last (possibly partial) tile + remainder = dim_len - (n_tiles - 1) * tile + sizes[-1] = remainder if remainder > 0 else tile + return sizes + + t_sizes = _sizes(t, ts_t, n_t) # [n_t] + h_sizes = _sizes(h, ts_h, n_h) # [n_h] + w_sizes = _sizes(w, ts_w, n_w) # [n_w] + + # broadcast‑multiply to get voxels per tile, then flatten + block_sizes = ( + t_sizes[:, None, None] # [n_t, 1, 1] + * h_sizes[None, :, None] # [1, n_h, 1] + * w_sizes[None, None, :] # [1, 1, n_w] + ).reshape(-1) # [n_t * n_h * n_w] + + return block_sizes + + +@functools.lru_cache(maxsize=10) +def get_non_pad_index( + variable_block_sizes: torch.LongTensor, + max_block_size: int, +): + n_win = variable_block_sizes.shape[0] + device = variable_block_sizes.device + starts_pad = torch.arange(n_win, device=device) * max_block_size + index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :] + index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None] + return index_pad[index_mask] + + +def get_vsa_kwargs( + latent_shape: tuple[int, int, int], + patch_size: tuple[int, int, int], + sparsity: float, + device: torch.device, +): + dit_seq_shape = ( + latent_shape[0] // patch_size[0], + latent_shape[1] // patch_size[1], + latent_shape[2] // patch_size[2], + ) + + num_tiles = ( + math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]), + math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]), + math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]), + ) + total_seq_length = math.prod(dit_seq_shape) + + tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device) + reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device) + variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device) + non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE)) + + return { + "sparsity": sparsity, + "num_tiles": num_tiles, + "total_seq_length": total_seq_length, + "tile_partition_indices": tile_partition_indices, + "reverse_tile_partition_indices": reverse_tile_partition_indices, + "variable_block_sizes": variable_block_sizes, + "non_pad_index": non_pad_index, + } + + +def tile( + x: torch.Tensor, + num_tiles: tuple[int, int, int], + tile_partition_indices: torch.LongTensor, + non_pad_index: torch.LongTensor, +) -> torch.Tensor: + t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0] + h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1] + w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2] + + x_padded = torch.zeros( + (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]), + device=x.device, + dtype=x.dtype, + ) + x_padded[:, non_pad_index] = x[:, tile_partition_indices] + return x_padded + + +def untile( + x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor, non_pad_index: torch.LongTensor +) -> torch.Tensor: + x = x[:, non_pad_index][:, reverse_tile_partition_indices] + return x + + +def video_sparse_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + sparsity: float, + num_tiles: tuple[int, int, int], + total_seq_length: int, + 
tile_partition_indices: torch.LongTensor, + reverse_tile_partition_indices: torch.LongTensor, + variable_block_sizes: torch.LongTensor, + non_pad_index: torch.LongTensor, +): + q = tile(q, num_tiles, tile_partition_indices, non_pad_index) + k = tile(k, num_tiles, tile_partition_indices, non_pad_index) + v = tile(v, num_tiles, tile_partition_indices, non_pad_index) + g = tile(g, num_tiles, tile_partition_indices, non_pad_index) + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + g = g.transpose(1, 2).contiguous() + + topk = math.ceil((1 - sparsity) * (total_seq_length / math.prod(VSA_TILE_SIZE))) + out = vsa_core( + q, + k, + v, + variable_block_sizes=variable_block_sizes, + topk=topk, + block_size=VSA_TILE_SIZE, + compress_attn_weight=g, + ).transpose(1, 2) + out = untile(out, reverse_tile_partition_indices, non_pad_index) + return out + + +def distributed_video_sparse_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + sparsity: float, + num_tiles: tuple[int, int, int], + total_seq_length: int, + tile_partition_indices: torch.LongTensor, + reverse_tile_partition_indices: torch.LongTensor, + variable_block_sizes: torch.LongTensor, + non_pad_index: torch.LongTensor, + scatter_idx: int = 2, + gather_idx: int = 1, +): + from yunchang.comm.all_to_all import SeqAllToAll4D + + assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1" + sp_ulysses_group = get_sp_ulysses_group() + + q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx) + k = SeqAllToAll4D.apply(sp_ulysses_group, k, scatter_idx, gather_idx) + v = SeqAllToAll4D.apply(sp_ulysses_group, v, scatter_idx, gather_idx) + g = SeqAllToAll4D.apply(sp_ulysses_group, g, scatter_idx, gather_idx) + + out = video_sparse_attn( + q, + k, + v, + g, + sparsity, + num_tiles, + total_seq_length, + tile_partition_indices, + reverse_tile_partition_indices, + variable_block_sizes, + non_pad_index, + ) + + out = SeqAllToAll4D.apply(sp_ulysses_group, out, gather_idx, scatter_idx) + return out diff --git a/diffsynth_engine/models/flux/flux_controlnet.py b/diffsynth_engine/models/flux/flux_controlnet.py index 889c28ec..27eef140 100644 --- a/diffsynth_engine/models/flux/flux_controlnet.py +++ b/diffsynth_engine/models/flux/flux_controlnet.py @@ -86,7 +86,6 @@ class FluxControlNet(PreTrainedModel): def __init__( self, condition_channels: int = 64, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -103,10 +102,7 @@ def __init__( self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype) self.controlnet_x_embedder = nn.Linear(condition_channels, 3072) self.blocks = nn.ModuleList( - [ - FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype) - for _ in range(6) - ] + [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(6)] ) # controlnet projection self.blocks_proj = nn.ModuleList( @@ -128,6 +124,7 @@ def forward( image_ids: torch.Tensor, text_ids: torch.Tensor, guidance: torch.Tensor, + attn_kwargs: Optional[Dict[str, Any]] = None, ): hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition) condition = ( @@ -141,7 +138,9 @@ def forward( # double block double_block_outputs = [] for i, block in enumerate(self.blocks): - hidden_states, prompt_emb = block(hidden_states, prompt_emb, condition, image_rotary_emb) + hidden_states, prompt_emb = 
block( + hidden_states, prompt_emb, condition, image_rotary_emb, attn_kwargs=attn_kwargs + ) double_block_outputs.append(self.blocks_proj[i](hidden_states)) # apply control scale @@ -149,24 +148,13 @@ def forward( return double_block_outputs, None @classmethod - def from_state_dict( - cls, - state_dict: Dict[str, torch.Tensor], - device: str, - dtype: torch.dtype, - attn_kwargs: Optional[Dict[str, Any]] = None, - ): + def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype): if "controlnet_x_embedder.weight" in state_dict: condition_channels = state_dict["controlnet_x_embedder.weight"].shape[1] else: condition_channels = 64 - model = cls( - condition_channels=condition_channels, - attn_kwargs=attn_kwargs, - device="meta", - dtype=dtype, - ) + model = cls(condition_channels=condition_channels, device="meta", dtype=dtype) model.requires_grad_(False) model.load_state_dict(state_dict, assign=True) model.to(device=device, dtype=dtype, non_blocking=True) diff --git a/diffsynth_engine/models/flux/flux_dit.py b/diffsynth_engine/models/flux/flux_dit.py index 2fa0dd72..0d3470a2 100644 --- a/diffsynth_engine/models/flux/flux_dit.py +++ b/diffsynth_engine/models/flux/flux_dit.py @@ -176,7 +176,6 @@ def __init__( dim_b, num_heads, head_dim, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -194,19 +193,20 @@ def __init__( self.a_to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype) self.b_to_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype) - self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {} def attention_callback(self, attn_out_a, attn_out_b, x_a, x_b, q_a, q_b, k_a, k_b, v_a, v_b, rope_emb, image_emb): return attn_out_a, attn_out_b - def forward(self, image, text, rope_emb, image_emb): + def forward(self, image, text, rope_emb, image_emb, attn_kwargs=None): q_a, k_a, v_a = rearrange(self.a_to_qkv(image), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2) q_b, k_b, v_b = rearrange(self.b_to_qkv(text), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2) q = torch.cat([self.norm_q_b(q_b), self.norm_q_a(q_a)], dim=1) k = torch.cat([self.norm_k_b(k_b), self.norm_k_a(k_a)], dim=1) v = torch.cat([v_b, v_a], dim=1) q, k = apply_rope(q, k, rope_emb) - attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs) + + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + attn_out = attention_ops.attention(q, k, v, **attn_kwargs) attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype) text_out, image_out = attn_out[:, : text.shape[1]], attn_out[:, text.shape[1] :] image_out, text_out = self.attention_callback( @@ -231,14 +231,11 @@ def __init__( self, dim, num_heads, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): super().__init__() - self.attn = FluxDoubleAttention( - dim, dim, num_heads, dim // num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype - ) + self.attn = FluxDoubleAttention(dim, dim, num_heads, dim // num_heads, device=device, dtype=dtype) # Image self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype) self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype) @@ -256,11 +253,11 @@ def __init__( nn.Linear(dim * 4, dim, device=device, dtype=dtype), ) - def forward(self, image, text, t_emb, rope_emb, image_emb=None): + def forward(self, image, text, t_emb, rope_emb, image_emb=None, attn_kwargs=None): # AdaLayerNorm-Zero for Image 
and Text MSA image_in, gate_a = self.norm_msa_a(image, t_emb) text_in, gate_b = self.norm_msa_b(text, t_emb) - image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb) + image_out, text_out = self.attn(image_in, text_in, rope_emb, image_emb, attn_kwargs) image = image + gate_a * image_out text = text + gate_b * text_out @@ -279,7 +276,6 @@ def __init__( self, dim, num_heads, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -288,15 +284,16 @@ def __init__( self.to_qkv = nn.Linear(dim, dim * 3, device=device, dtype=dtype) self.norm_q_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype) self.norm_k_a = RMSNorm(dim // num_heads, eps=1e-6, device=device, dtype=dtype) - self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {} def attention_callback(self, attn_out, x, q, k, v, rope_emb, image_emb): return attn_out - def forward(self, x, rope_emb, image_emb): + def forward(self, x, rope_emb, image_emb, attn_kwargs=None): q, k, v = rearrange(self.to_qkv(x), "b s (h d) -> b s h d", h=(3 * self.num_heads)).chunk(3, dim=2) q, k = apply_rope(self.norm_q_a(q), self.norm_k_a(k), rope_emb) - attn_out = attention_ops.attention(q, k, v, **self.attn_kwargs) + + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + attn_out = attention_ops.attention(q, k, v, **attn_kwargs) attn_out = rearrange(attn_out, "b s h d -> b s (h d)").to(q.dtype) return self.attention_callback(attn_out=attn_out, x=x, q=q, k=k, v=v, rope_emb=rope_emb, image_emb=image_emb) @@ -306,23 +303,22 @@ def __init__( self, dim, num_heads, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): super().__init__() self.dim = dim self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype) - self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype) + self.attn = FluxSingleAttention(dim, num_heads, device=device, dtype=dtype) self.mlp = nn.Sequential( nn.Linear(dim, dim * 4, device=device, dtype=dtype), nn.GELU(approximate="tanh"), ) self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype) - def forward(self, x, t_emb, rope_emb, image_emb=None): + def forward(self, x, t_emb, rope_emb, image_emb=None, attn_kwargs=None): h, gate = self.norm(x, emb=t_emb) - attn_output = self.attn(h, rope_emb, image_emb) + attn_output = self.attn(h, rope_emb, image_emb, attn_kwargs) mlp_output = self.mlp(h) return x + gate * self.proj_out(torch.cat([attn_output, mlp_output], dim=2)) @@ -334,7 +330,6 @@ class FluxDiT(PreTrainedModel): def __init__( self, in_channel: int = 64, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -352,16 +347,10 @@ def __init__( self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype) self.blocks = nn.ModuleList( - [ - FluxDoubleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype) - for _ in range(19) - ] + [FluxDoubleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(19)] ) self.single_blocks = nn.ModuleList( - [ - FluxSingleTransformerBlock(3072, 24, attn_kwargs=attn_kwargs, device=device, dtype=dtype) - for _ in range(38) - ] + [FluxSingleTransformerBlock(3072, 24, device=device, dtype=dtype) for _ in range(38)] ) self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype) self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype) @@ -403,6 +392,7 @@ def forward( 
text_ids: torch.Tensor, guidance: torch.Tensor, image_emb: torch.Tensor | None = None, + attn_kwargs: Optional[Dict[str, Any]] = None, controlnet_double_block_output: List[torch.Tensor] | None = None, controlnet_single_block_output: List[torch.Tensor] | None = None, **kwargs, @@ -470,14 +460,16 @@ def forward( rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2) for i, block in enumerate(self.blocks): - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb) + hidden_states, prompt_emb = block( + hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs + ) if len(controlnet_double_block_output) > 0: interval_control = len(self.blocks) / len(controlnet_double_block_output) interval_control = int(np.ceil(interval_control)) hidden_states = hidden_states + controlnet_double_block_output[i // interval_control] hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) for i, block in enumerate(self.single_blocks): - hidden_states = block(hidden_states, conditioning, rope_emb, image_emb) + hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs) if len(controlnet_single_block_output) > 0: interval_control = len(self.single_blocks) / len(controlnet_double_block_output) interval_control = int(np.ceil(interval_control)) @@ -498,14 +490,8 @@ def from_state_dict( device: str, dtype: torch.dtype, in_channel: int = 64, - attn_kwargs: Optional[Dict[str, Any]] = None, ): - model = cls( - device="meta", - dtype=dtype, - in_channel=in_channel, - attn_kwargs=attn_kwargs, - ) + model = cls(device="meta", dtype=dtype, in_channel=in_channel) model = model.requires_grad_(False) model.load_state_dict(state_dict, assign=True) model.to(device=device, dtype=dtype, non_blocking=True) diff --git a/diffsynth_engine/models/flux/flux_dit_fbcache.py b/diffsynth_engine/models/flux/flux_dit_fbcache.py index 1b7d59d1..15c41c50 100644 --- a/diffsynth_engine/models/flux/flux_dit_fbcache.py +++ b/diffsynth_engine/models/flux/flux_dit_fbcache.py @@ -20,12 +20,11 @@ class FluxDiTFBCache(FluxDiT): def __init__( self, in_channel: int = 64, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, relative_l1_threshold: float = 0.05, ): - super().__init__(in_channel=in_channel, attn_kwargs=attn_kwargs, device=device, dtype=dtype) + super().__init__(in_channel=in_channel, device=device, dtype=dtype) self.relative_l1_threshold = relative_l1_threshold self.step_count = 0 self.num_inference_steps = 0 @@ -56,6 +55,7 @@ def forward( text_ids: torch.Tensor, guidance: torch.Tensor, image_emb: torch.Tensor | None = None, + attn_kwargs: Optional[Dict[str, Any]] = None, controlnet_double_block_output: List[torch.Tensor] | None = None, controlnet_single_block_output: List[torch.Tensor] | None = None, **kwargs, @@ -124,7 +124,9 @@ def forward( # first block original_hidden_states = hidden_states - hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb) + hidden_states, prompt_emb = self.blocks[0]( + hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs + ) first_hidden_states_residual = hidden_states - original_hidden_states (first_hidden_states_residual,) = sequence_parallel_unshard( @@ -149,14 +151,16 @@ def forward( first_hidden_states = hidden_states.clone() for i, block in enumerate(self.blocks[1:]): - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb) + hidden_states, prompt_emb = 
block( + hidden_states, prompt_emb, conditioning, rope_emb, image_emb, attn_kwargs + ) if len(controlnet_double_block_output) > 0: interval_control = len(self.blocks) / len(controlnet_double_block_output) interval_control = int(np.ceil(interval_control)) hidden_states = hidden_states + controlnet_double_block_output[i // interval_control] hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) for i, block in enumerate(self.single_blocks): - hidden_states = block(hidden_states, conditioning, rope_emb, image_emb) + hidden_states = block(hidden_states, conditioning, rope_emb, image_emb, attn_kwargs) if len(controlnet_single_block_output) > 0: interval_control = len(self.single_blocks) / len(controlnet_double_block_output) interval_control = int(np.ceil(interval_control)) @@ -182,14 +186,12 @@ def from_state_dict( device: str, dtype: torch.dtype, in_channel: int = 64, - attn_kwargs: Optional[Dict[str, Any]] = None, relative_l1_threshold: float = 0.05, ): model = cls( device="meta", dtype=dtype, in_channel=in_channel, - attn_kwargs=attn_kwargs, relative_l1_threshold=relative_l1_threshold, ) model = model.requires_grad_(False) diff --git a/diffsynth_engine/models/flux/flux_ipadapter.py b/diffsynth_engine/models/flux/flux_ipadapter.py index cd5063ec..15280dd7 100644 --- a/diffsynth_engine/models/flux/flux_ipadapter.py +++ b/diffsynth_engine/models/flux/flux_ipadapter.py @@ -2,7 +2,7 @@ from einops import rearrange from torch import nn from PIL import Image -from typing import Any, Dict, List, Optional +from typing import Dict, List from functools import partial from diffsynth_engine.models.text_encoder.siglip import SiglipImageEncoder from diffsynth_engine.models.basic.transformer_helper import RMSNorm @@ -18,7 +18,6 @@ def __init__( dim: int = 3072, head_num: int = 24, scale: float = 1.0, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -28,12 +27,13 @@ def __init__( self.to_v_ip = nn.Linear(image_emb_dim, dim, device=device, dtype=dtype, bias=False) self.head_num = head_num self.scale = scale - self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {} - def forward(self, query: torch.Tensor, image_emb: torch.Tensor): + def forward(self, query: torch.Tensor, image_emb: torch.Tensor, attn_kwargs=None): key = rearrange(self.norm_k(self.to_k_ip(image_emb)), "b s (h d) -> b s h d", h=self.head_num) value = rearrange(self.to_v_ip(image_emb), "b s (h d) -> b s h d", h=self.head_num) - attn_out = attention(query, key, value, **self.attn_kwargs) + + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + attn_out = attention(query, key, value, **attn_kwargs) return self.scale * rearrange(attn_out, "b s h d -> b s (h d)") @classmethod diff --git a/diffsynth_engine/models/qwen_image/qwen_image_dit.py b/diffsynth_engine/models/qwen_image/qwen_image_dit.py index f8d81c18..77e0f551 100644 --- a/diffsynth_engine/models/qwen_image/qwen_image_dit.py +++ b/diffsynth_engine/models/qwen_image/qwen_image_dit.py @@ -167,7 +167,6 @@ def __init__( dim_b, num_heads, head_dim, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -189,7 +188,6 @@ def __init__( self.to_out = nn.Linear(dim_a, dim_a, device=device, dtype=dtype) self.to_add_out = nn.Linear(dim_b, dim_b, device=device, dtype=dtype) - self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {} def forward( self, @@ -197,6 +195,7 @@ def forward( text: torch.FloatTensor, rotary_emb: 
Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image) txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text) @@ -227,7 +226,8 @@ def forward( joint_k = joint_k.transpose(1, 2) joint_v = joint_v.transpose(1, 2) - joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **self.attn_kwargs) + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs) joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype) @@ -247,7 +247,6 @@ def __init__( num_attention_heads: int, attention_head_dim: int, eps: float = 1e-6, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -267,7 +266,6 @@ def __init__( dim_b=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, - attn_kwargs=attn_kwargs, device=device, dtype=dtype, ) @@ -293,6 +291,7 @@ def forward( temb: torch.Tensor, rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_mask: Optional[torch.Tensor] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1) # [B, 3*dim] each @@ -308,6 +307,7 @@ def forward( text=txt_modulated, rotary_emb=rotary_emb, attn_mask=attn_mask, + attn_kwargs=attn_kwargs, ) image = image + img_gate * img_attn_out @@ -335,7 +335,6 @@ class QwenImageDiT(PreTrainedModel): def __init__( self, num_layers: int = 60, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -356,7 +355,6 @@ def __init__( dim=3072, num_attention_heads=24, attention_head_dim=128, - attn_kwargs=attn_kwargs, device=device, dtype=dtype, ) @@ -444,6 +442,7 @@ def forward( entity_text: Optional[List[torch.Tensor]] = None, entity_seq_lens: Optional[List[torch.LongTensor]] = None, entity_masks: Optional[List[torch.Tensor]] = None, + attn_kwargs: Optional[Dict[str, Any]] = None, ): h, w = image.shape[-2:] fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False) @@ -509,7 +508,12 @@ def forward( rotary_emb = (img_freqs, txt_freqs) for block in self.transformer_blocks: text, image = block( - image=image, text=text, temb=conditioning, rotary_emb=rotary_emb, attn_mask=attn_mask + image=image, + text=text, + temb=conditioning, + rotary_emb=rotary_emb, + attn_mask=attn_mask, + attn_kwargs=attn_kwargs, ) image = self.norm_out(image, conditioning) image = self.proj_out(image) @@ -527,14 +531,8 @@ def from_state_dict( device: str, dtype: torch.dtype, num_layers: int = 60, - attn_kwargs: Optional[Dict[str, Any]] = None, ): - model = cls( - device="meta", - dtype=dtype, - num_layers=num_layers, - attn_kwargs=attn_kwargs, - ) + model = cls(device="meta", dtype=dtype, num_layers=num_layers) model = model.requires_grad_(False) model.load_state_dict(state_dict, assign=True) model.to(device=device, dtype=dtype, non_blocking=True) diff --git a/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py b/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py index b4c28abe..641168b5 100644 --- 
a/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py +++ b/diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py @@ -11,12 +11,11 @@ class QwenImageDiTFBCache(QwenImageDiT): def __init__( self, num_layers: int = 60, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, relative_l1_threshold: float = 0.05, ): - super().__init__(num_layers=num_layers, attn_kwargs=attn_kwargs, device=device, dtype=dtype) + super().__init__(num_layers=num_layers, device=device, dtype=dtype) self.relative_l1_threshold = relative_l1_threshold self.step_count = 0 self.num_inference_steps = 0 @@ -43,6 +42,7 @@ def forward( text: torch.Tensor = None, timestep: torch.LongTensor = None, txt_seq_lens: torch.LongTensor = None, + attn_kwargs: Optional[Dict[str, Any]] = None, ): h, w = image.shape[-2:] fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False) @@ -72,7 +72,11 @@ def forward( # first block original_hidden_states = image text, image = self.transformer_blocks[0]( - image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + attn_kwargs=attn_kwargs, ) first_hidden_states_residual = image - original_hidden_states @@ -94,7 +98,13 @@ def forward( first_hidden_states = image.clone() for block in self.transformer_blocks[1:]: - text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb) + text, image = block( + image=image, + text=text, + temb=conditioning, + image_rotary_emb=image_rotary_emb, + attn_kwargs=attn_kwargs, + ) previous_residual = image - first_hidden_states self.previous_residual = previous_residual @@ -114,14 +124,12 @@ def from_state_dict( device: str, dtype: torch.dtype, num_layers: int = 60, - attn_kwargs: Optional[Dict[str, Any]] = None, relative_l1_threshold: float = 0.05, ): model = cls( device="meta", dtype=dtype, num_layers=num_layers, - attn_kwargs=attn_kwargs, relative_l1_threshold=relative_l1_threshold, ) model = model.requires_grad_(False) diff --git a/diffsynth_engine/models/wan/wan_dit.py b/diffsynth_engine/models/wan/wan_dit.py index 0c274836..948d6751 100644 --- a/diffsynth_engine/models/wan/wan_dit.py +++ b/diffsynth_engine/models/wan/wan_dit.py @@ -17,6 +17,7 @@ WAN2_2_DIT_TI2V_5B_CONFIG_FILE, WAN2_2_DIT_I2V_A14B_CONFIG_FILE, WAN2_2_DIT_T2V_A14B_CONFIG_FILE, + WAN_DIT_KEYMAP_FILE, ) from diffsynth_engine.utils.gguf import gguf_inference from diffsynth_engine.utils.fp8_linear import fp8_inference @@ -30,6 +31,9 @@ T5_TOKEN_NUM = 512 FLF_TOKEN_NUM = 257 * 2 +with open(WAN_DIT_KEYMAP_FILE, "r", encoding="utf-8") as f: + config = json.load(f) + def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): return x * (1 + scale) + shift @@ -73,7 +77,7 @@ def __init__( dim: int, num_heads: int, eps: float = 1e-6, - attn_kwargs: Optional[Dict[str, Any]] = None, + use_vsa: bool = False, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -86,19 +90,25 @@ def __init__( self.o = nn.Linear(dim, dim, device=device, dtype=dtype) self.norm_q = RMSNorm(dim, eps=eps, device=device, dtype=dtype) self.norm_k = RMSNorm(dim, eps=eps, device=device, dtype=dtype) - self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + self.gate_compress = nn.Linear(dim, dim, device=device, dtype=dtype) if use_vsa else None - def forward(self, x, freqs): + def forward(self, x, freqs, attn_kwargs=None): q, k, v = self.norm_q(self.q(x)), self.norm_k(self.k(x)), 
self.v(x) + g = self.gate_compress(x) if self.gate_compress is not None else None + num_heads = q.shape[2] // self.head_dim q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) + g = rearrange(g, "b s (n d) -> b s n d", n=num_heads) if g is not None else None + + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} x = attention_ops.attention( q=rope_apply(q, freqs), k=rope_apply(k, freqs), v=v, - **self.attn_kwargs, + g=g, + **attn_kwargs, ) x = x.flatten(2) return self.o(x) @@ -111,7 +121,6 @@ def __init__( num_heads: int, eps: float = 1e-6, has_image_input: bool = False, - attn_kwargs: Optional[Dict[str, Any]] = None, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -130,9 +139,8 @@ def __init__( self.k_img = nn.Linear(dim, dim, device=device, dtype=dtype) self.v_img = nn.Linear(dim, dim, device=device, dtype=dtype) self.norm_k_img = RMSNorm(dim, eps=eps, device=device, dtype=dtype) - self.attn_kwargs = attn_kwargs if attn_kwargs is not None else {} - def forward(self, x: torch.Tensor, y: torch.Tensor): + def forward(self, x: torch.Tensor, y: torch.Tensor, attn_kwargs=None): if self.has_image_input: img = y[:, :-T5_TOKEN_NUM] ctx = y[:, -T5_TOKEN_NUM:] @@ -144,12 +152,16 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) - x = attention(q, k, v, **self.attn_kwargs).flatten(2) + attn_kwargs = attn_kwargs if attn_kwargs is not None else {} + if attn_kwargs.get("attn_impl", None) == "vsa": + attn_kwargs = attn_kwargs.copy() + attn_kwargs["attn_impl"] = "sdpa" + x = attention(q, k, v, **attn_kwargs).flatten(2) if self.has_image_input: k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img) k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads) v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads) - y = attention(q, k_img, v_img, **self.attn_kwargs).flatten(2) + y = attention(q, k_img, v_img, **attn_kwargs).flatten(2) x = x + y return self.o(x) @@ -162,7 +174,7 @@ def __init__( num_heads: int, ffn_dim: int, eps: float = 1e-6, - attn_kwargs: Optional[Dict[str, Any]] = None, + use_vsa: bool = False, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -170,9 +182,9 @@ def __init__( self.dim = dim self.num_heads = num_heads self.ffn_dim = ffn_dim - self.self_attn = SelfAttention(dim, num_heads, eps, attn_kwargs=attn_kwargs, device=device, dtype=dtype) + self.self_attn = SelfAttention(dim, num_heads, eps, use_vsa=use_vsa, device=device, dtype=dtype) self.cross_attn = CrossAttention( - dim, num_heads, eps, has_image_input=has_image_input, attn_kwargs=attn_kwargs, device=device, dtype=dtype + dim, num_heads, eps, has_image_input=has_image_input, device=device, dtype=dtype ) self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype) self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, device=device, dtype=dtype) @@ -184,14 +196,14 @@ def __init__( ) self.modulation = nn.Parameter(torch.randn(1, 6, dim, device=device, dtype=dtype) / dim**0.5) - def forward(self, x, context, t_mod, freqs): + def forward(self, x, context, t_mod, freqs, attn_kwargs=None): # msa: multi-head self-attention mlp: multi-layer perceptron shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ t.squeeze(1) for t in (self.modulation + t_mod).chunk(6, dim=1) ] input_x = 
modulate(self.norm1(x), shift_msa, scale_msa) - x = x + gate_msa * self.self_attn(input_x, freqs) - x = x + self.cross_attn(self.norm3(x), context) + x = x + gate_msa * self.self_attn(input_x, freqs, attn_kwargs) + x = x + self.cross_attn(self.norm3(x), context, attn_kwargs) input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) x = x + gate_mlp * self.ffn(input_x) return x @@ -249,7 +261,26 @@ def forward(self, x, t_mod): class WanDiTStateDictConverter(StateDictConverter): + def _from_diffusers(self, state_dict): + global_rename_dict = config["diffusers"]["global_rename_dict"] + rename_dict = config["diffusers"]["rename_dict"] + state_dict_ = {} + for name, param in state_dict.items(): + suffix = "" + suffix = ".weight" if name.endswith(".weight") else suffix + suffix = ".bias" if name.endswith(".bias") else suffix + prefix = name[: -len(suffix)] if suffix else name + if prefix in global_rename_dict: + state_dict_[f"{global_rename_dict[prefix]}{suffix}"] = param + if prefix.startswith("blocks."): + _, idx, middle = prefix.split(".", 2) + if middle in rename_dict: + state_dict_[f"blocks.{idx}.{rename_dict[middle]}{suffix}"] = param + return state_dict_ + def convert(self, state_dict): + if "condition_embedder.time_proj.weight" in state_dict: + return self._from_diffusers(state_dict) return state_dict @@ -273,7 +304,7 @@ def __init__( has_vae_feature: bool = False, fuse_image_latents: bool = False, flf_pos_emb: bool = False, - attn_kwargs: Optional[Dict[str, Any]] = None, + use_vsa: bool = False, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, ): @@ -307,7 +338,16 @@ def __init__( ) self.blocks = nn.ModuleList( [ - DiTBlock(has_clip_feature, dim, num_heads, ffn_dim, eps, attn_kwargs, device=device, dtype=dtype) + DiTBlock( + has_clip_feature, + dim, + num_heads, + ffn_dim, + eps, + use_vsa, + device=device, + dtype=dtype, + ) for _ in range(num_layers) ] ) @@ -344,6 +384,7 @@ def forward( timestep: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, # clip_vision_encoder(img) y: Optional[torch.Tensor] = None, # vae_encoder(img) + attn_kwargs: Optional[Dict[str, Any]] = None, ): fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False) use_cfg = x.shape[0] > 1 @@ -376,7 +417,7 @@ def forward( with sequence_parallel((x, t, t_mod, freqs), seq_dims=(1, 0, 0, 0)): for block in self.blocks: - x = block(x, context, t_mod, freqs) + x = block(x, context, t_mod, freqs, attn_kwargs) x = self.head(x, t) (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,)) x = self.unpatchify(x, (f, h, w)) @@ -409,12 +450,11 @@ def from_state_dict( config: Dict[str, Any], device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16, - attn_kwargs: Optional[Dict[str, Any]] = None, - assign: bool = True, + use_vsa: bool = False, ): - model = cls(**config, device="meta", dtype=dtype, attn_kwargs=attn_kwargs) + model = cls(**config, device="meta", dtype=dtype, use_vsa=use_vsa) model = model.requires_grad_(False) - model.load_state_dict(state_dict, assign=assign) + model.load_state_dict(state_dict, assign=True) model.to(device=device, dtype=dtype, non_blocking=True) return model diff --git a/diffsynth_engine/pipelines/flux_image.py b/diffsynth_engine/pipelines/flux_image.py index e7b2b409..f698a707 100644 --- a/diffsynth_engine/pipelines/flux_image.py +++ b/diffsynth_engine/pipelines/flux_image.py @@ -17,7 +17,12 @@ flux_dit_config, flux_text_encoder_config, ) -from diffsynth_engine.configs import FluxPipelineConfig, FluxStateDicts, ControlType, ControlNetParams +from 
diffsynth_engine.configs import ( + FluxPipelineConfig, + FluxStateDicts, + ControlType, + ControlNetParams, +) from diffsynth_engine.models.basic.lora import LoRAContext from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter from diffsynth_engine.pipelines.utils import accumulate, calculate_shift @@ -143,7 +148,7 @@ def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic layer_id, layer_type = name.split("_", 1) layer_type = layer_type.replace("self_attn_", "self_attn.").replace("mlp_", "mlp.") rename = ".".join(["encoders", layer_id, clip_attn_rename_dict[layer_type]]) - + lora_args = {} lora_args["alpha"] = param lora_args["up"] = lora_state_dict[origin_key.replace(".alpha", ".lora_up.weight")] @@ -507,20 +512,12 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi vae_encoder = FluxVAEEncoder.from_state_dict(state_dicts.vae, device=init_device, dtype=config.vae_dtype) with LoRAContext(): - attn_kwargs = { - "attn_impl": config.dit_attn_impl.value, - "sparge_smooth_k": config.sparge_smooth_k, - "sparge_cdfthreshd": config.sparge_cdfthreshd, - "sparge_simthreshd1": config.sparge_simthreshd1, - "sparge_pvthreshd": config.sparge_pvthreshd, - } if config.use_fbcache: dit = FluxDiTFBCache.from_state_dict( state_dicts.model, device=init_device, dtype=config.model_dtype, in_channel=config.control_type.get_in_channel(), - attn_kwargs=attn_kwargs, relative_l1_threshold=config.fbcache_relative_l1_threshold, ) else: @@ -529,7 +526,6 @@ def _from_state_dict(cls, state_dicts: FluxStateDicts, config: FluxPipelineConfi device=init_device, dtype=config.model_dtype, in_channel=config.control_type.get_in_channel(), - attn_kwargs=attn_kwargs, ) if config.use_fp8_linear: enable_fp8_linear(dit) @@ -755,6 +751,7 @@ def predict_noise( latents = latents.to(self.dtype) self.load_models_to_device(["dit"]) + attn_kwargs = self.config.get_attn_kwargs(latents, self.device) noise_pred = self.dit( hidden_states=latents, timestep=timestep, @@ -766,6 +763,7 @@ def predict_noise( image_ids=image_ids, controlnet_double_block_output=double_block_output, controlnet_single_block_output=single_block_output, + attn_kwargs=attn_kwargs, ) noise_pred = noise_pred[:, :image_seq_len] noise_pred = self.dit.unpatchify(noise_pred, height, width) @@ -887,6 +885,8 @@ def predict_multicontrolnet( if self.offload_mode is not None: empty_cache() param.model.to(self.device) + + attn_kwargs = self.config.get_attn_kwargs(latents, self.device) double_block_output, single_block_output = param.model( hidden_states=latents, control_condition=control_condition, @@ -897,6 +897,7 @@ def predict_multicontrolnet( image_ids=image_ids, text_ids=text_ids, guidance=guidance, + attn_kwargs=attn_kwargs, ) if self.offload_mode is not None: param.model.to("cpu") diff --git a/diffsynth_engine/pipelines/qwen_image.py b/diffsynth_engine/pipelines/qwen_image.py index 55d20b24..41af4a1d 100644 --- a/diffsynth_engine/pipelines/qwen_image.py +++ b/diffsynth_engine/pipelines/qwen_image.py @@ -91,7 +91,7 @@ def _from_diffusers(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, if "lora_A.weight" in key: lora_a_suffix = "lora_A.weight" lora_b_suffix = "lora_B.weight" - + if lora_a_suffix is None: continue @@ -247,19 +247,11 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip ) with LoRAContext(): - attn_kwargs = { - "attn_impl": config.dit_attn_impl.value, - "sparge_smooth_k": config.sparge_smooth_k, - "sparge_cdfthreshd": config.sparge_cdfthreshd, - 
"sparge_simthreshd1": config.sparge_simthreshd1, - "sparge_pvthreshd": config.sparge_pvthreshd, - } if config.use_fbcache: dit = QwenImageDiTFBCache.from_state_dict( state_dicts.model, device=init_device, dtype=config.model_dtype, - attn_kwargs=attn_kwargs, relative_l1_threshold=config.fbcache_relative_l1_threshold, ) else: @@ -267,7 +259,6 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip state_dicts.model, device=init_device, dtype=config.model_dtype, - attn_kwargs=attn_kwargs, ) if config.use_fp8_linear: enable_fp8_linear(dit) @@ -542,6 +533,7 @@ def predict_noise( entity_masks: Optional[List[torch.Tensor]] = None, ): self.load_models_to_device(["dit"]) + attn_kwargs = self.config.get_attn_kwargs(latents, self.device) noise_pred = self.dit( image=latents, edit=image_latents, @@ -552,6 +544,7 @@ def predict_noise( entity_text=entity_prompt_embs, entity_seq_lens=[mask.sum(dim=1) for mask in entity_prompt_emb_masks] if entity_prompt_emb_masks else None, entity_masks=entity_masks, + attn_kwargs=attn_kwargs, ) return noise_pred diff --git a/diffsynth_engine/pipelines/wan_s2v.py b/diffsynth_engine/pipelines/wan_s2v.py index ab7d001b..f8cafb6a 100644 --- a/diffsynth_engine/pipelines/wan_s2v.py +++ b/diffsynth_engine/pipelines/wan_s2v.py @@ -394,6 +394,7 @@ def predict_noise( void_audio_input: torch.Tensor | None = None, ): latents = latents.to(dtype=self.config.model_dtype, device=self.device) + attn_kwargs = self.config.get_attn_kwargs(latents, self.device) noise_pred = model( x=latents, @@ -408,6 +409,7 @@ def predict_noise( drop_motion_frames=drop_motion_frames, audio_mask=audio_mask, void_audio_input=void_audio_input, + attn_kwargs=attn_kwargs, ) return noise_pred @@ -654,19 +656,12 @@ def _from_state_dict( ) with LoRAContext(): - attn_kwargs = { - "attn_impl": config.dit_attn_impl.value, - "sparge_smooth_k": config.sparge_smooth_k, - "sparge_cdfthreshd": config.sparge_cdfthreshd, - "sparge_simthreshd1": config.sparge_simthreshd1, - "sparge_pvthreshd": config.sparge_pvthreshd, - } dit = WanS2VDiT.from_state_dict( state_dicts.model, config=model_config, device=init_device, dtype=config.model_dtype, - attn_kwargs=attn_kwargs, + use_vsa=(config.dit_attn_impl.value == "vsa"), ) if config.use_fp8_linear: enable_fp8_linear(dit) diff --git a/diffsynth_engine/pipelines/wan_video.py b/diffsynth_engine/pipelines/wan_video.py index abae7b1a..538562a0 100644 --- a/diffsynth_engine/pipelines/wan_video.py +++ b/diffsynth_engine/pipelines/wan_video.py @@ -301,6 +301,7 @@ def predict_noise_with_cfg( def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, context): latents = latents.to(dtype=self.config.model_dtype, device=self.device) + attn_kwargs = self.config.get_attn_kwargs(latents, self.device) noise_pred = model( x=latents, @@ -308,6 +309,7 @@ def predict_noise(self, model, latents, image_clip_feature, image_y, timestep, c context=context, clip_feature=image_clip_feature, y=image_y, + attn_kwargs=attn_kwargs, ) return noise_pred @@ -556,19 +558,12 @@ def _from_state_dict(cls, state_dicts: WanStateDicts, config: WanPipelineConfig) dit_state_dict = state_dicts.model with LoRAContext(): - attn_kwargs = { - "attn_impl": config.dit_attn_impl.value, - "sparge_smooth_k": config.sparge_smooth_k, - "sparge_cdfthreshd": config.sparge_cdfthreshd, - "sparge_simthreshd1": config.sparge_simthreshd1, - "sparge_pvthreshd": config.sparge_pvthreshd, - } dit = WanDiT.from_state_dict( dit_state_dict, config=dit_config, device=init_device, 
dtype=config.model_dtype, - attn_kwargs=attn_kwargs, + use_vsa=(config.dit_attn_impl.value == "vsa"), ) if config.use_fp8_linear: enable_fp8_linear(dit) @@ -580,7 +575,7 @@ def _from_state_dict(cls, state_dicts: WanStateDicts, config: WanPipelineConfig) config=dit_config, device=init_device, dtype=config.model_dtype, - attn_kwargs=attn_kwargs, + use_vsa=(config.dit_attn_impl.value == "vsa"), ) if config.use_fp8_linear: enable_fp8_linear(dit2) @@ -618,19 +613,22 @@ def _from_state_dict(cls, state_dicts: WanStateDicts, config: WanPipelineConfig) @staticmethod def _get_dit_type(model_state_dict: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]) -> str: # determine wan dit type by model params + def has_any_key(*xs): + return any(x in model_state_dict for x in xs) + dit_type = None - if "high_noise_model" in model_state_dict and "low_noise_model" in model_state_dict: + if has_any_key("high_noise_model"): if model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 36: dit_type = "wan2.2-i2v-a14b" elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16: dit_type = "wan2.2-t2v-a14b" elif model_state_dict["patch_embedding.weight"].shape[1] == 48: dit_type = "wan2.2-ti2v-5b" - elif "img_emb.emb_pos" in model_state_dict: + elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"): dit_type = "wan2.1-flf2v-14b" - elif "img_emb.proj.0.weight" in model_state_dict: + elif has_any_key("img_emb.proj.0.weight", "condition_embedder.image_embedder.norm1"): dit_type = "wan2.1-i2v-14b" - elif "blocks.39.self_attn.norm_q.weight" in model_state_dict: + elif has_any_key("blocks.39.self_attn.norm_q.weight", "blocks.39.attn1.norm_q.weight"): dit_type = "wan2.1-t2v-14b" else: dit_type = "wan2.1-t2v-1.3b" diff --git a/diffsynth_engine/utils/constants.py b/diffsynth_engine/utils/constants.py index d001123a..9e38dee4 100644 --- a/diffsynth_engine/utils/constants.py +++ b/diffsynth_engine/utils/constants.py @@ -27,18 +27,19 @@ SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json") SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json") -WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-1.3b.json") -WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-14b.json") -WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-i2v-14b.json") -WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-flf2v-14b.json") -WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json") -WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json") -WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json") -WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-s2v-14b.json") - -WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1-vae.json") -WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2-vae.json") -WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan-vae-keymap.json") +WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_1.3b.json") +WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_14b.json") 
diff --git a/diffsynth_engine/utils/constants.py b/diffsynth_engine/utils/constants.py
index d001123a..9e38dee4 100644
--- a/diffsynth_engine/utils/constants.py
+++ b/diffsynth_engine/utils/constants.py
@@ -27,18 +27,19 @@
 SDXL_TEXT_ENCODER_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_text_encoder.json")
 SDXL_UNET_CONFIG_FILE = os.path.join(CONF_PATH, "models", "sdxl", "sdxl_unet.json")
 
-WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-1.3b.json")
-WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-t2v-14b.json")
-WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-i2v-14b.json")
-WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1-flf2v-14b.json")
-WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-ti2v-5b.json")
-WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-t2v-a14b.json")
-WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-i2v-a14b.json")
-WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2-s2v-14b.json")
-
-WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1-vae.json")
-WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2-vae.json")
-WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan-vae-keymap.json")
+WAN2_1_DIT_T2V_1_3B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_1.3b.json")
+WAN2_1_DIT_T2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_t2v_14b.json")
+WAN2_1_DIT_I2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_i2v_14b.json")
+WAN2_1_DIT_FLF2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.1_flf2v_14b.json")
+WAN2_2_DIT_TI2V_5B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_ti2v_5b.json")
+WAN2_2_DIT_T2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_t2v_a14b.json")
+WAN2_2_DIT_I2V_A14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_i2v_a14b.json")
+WAN2_2_DIT_S2V_14B_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan2.2_s2v_14b.json")
+WAN_DIT_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "dit", "wan_dit_keymap.json")
+
+WAN2_1_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.1_vae.json")
+WAN2_2_VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan2.2_vae.json")
+WAN_VAE_KEYMAP_FILE = os.path.join(CONF_PATH, "models", "wan", "vae", "wan_vae_keymap.json")
 
 QWEN_IMAGE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_config.json")
 QWEN_IMAGE_VISION_CONFIG_FILE = os.path.join(CONF_PATH, "models", "qwen_image", "qwen2_5_vl_vision_config.json")
diff --git a/diffsynth_engine/utils/flag.py b/diffsynth_engine/utils/flag.py
index b14388d9..7ac0b3e3 100644
--- a/diffsynth_engine/utils/flag.py
+++ b/diffsynth_engine/utils/flag.py
@@ -44,3 +44,9 @@
     logger.info("Sparge attention is available")
 else:
     logger.info("Sparge attention is not available")
+
+VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None
+if VIDEO_SPARSE_ATTN_AVAILABLE:
+    logger.info("Video sparse attention is available")
+else:
+    logger.info("Video sparse attention is not available")
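VIDEO_SPARSE_ATTN_AVAILABLE follows the existing importlib.util.find_spec probe pattern for optional attention kernels. A small sketch of how a caller might degrade when the vsa package is missing; the choice of "sdpa" as the fallback is an assumption for illustration, not something this diff prescribes.

import importlib.util

VIDEO_SPARSE_ATTN_AVAILABLE = importlib.util.find_spec("vsa") is not None

def pick_attn_impl(requested: str) -> str:
    # Degrade gracefully if the requested sparse kernel is not installed.
    if requested == "vsa" and not VIDEO_SPARSE_ATTN_AVAILABLE:
        return "sdpa"  # assumed dense fallback, for illustration only
    return requested

print(pick_attn_impl("vsa"))  # "vsa" if the package is importable, otherwise "sdpa"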
diff --git a/diffsynth_engine/utils/parallel.py b/diffsynth_engine/utils/parallel.py
index 62c2b99e..eb954f40 100644
--- a/diffsynth_engine/utils/parallel.py
+++ b/diffsynth_engine/utils/parallel.py
@@ -40,10 +40,14 @@ class ProcessGroupSingleton(Singleton):
     def __init__(self):
         self.CFG_GROUP: Optional[dist.ProcessGroup] = None
         self.SP_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
+        self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
         self.TP_GROUP: Optional[dist.ProcessGroup] = None
 
         self.CFG_RANKS: List[int] = []
         self.SP_RANKS: List[int] = []
+        self.SP_ULYSSUES_RANKS: List[int] = []
+        self.SP_RING_RANKS: List[int] = []
         self.TP_RANKS: List[int] = []
 
 
@@ -82,6 +86,38 @@ def get_sp_ranks():
     return PROCESS_GROUP.SP_RANKS
 
 
+def get_sp_ulysses_group():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP
+
+
+def get_sp_ulysses_world_size():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
+
+
+def get_sp_ulysses_rank():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
+
+
+def get_sp_ulysses_ranks():
+    return PROCESS_GROUP.SP_ULYSSUES_RANKS
+
+
+def get_sp_ring_group():
+    return PROCESS_GROUP.SP_RING_GROUP
+
+
+def get_sp_ring_world_size():
+    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
+
+
+def get_sp_ring_rank():
+    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
+
+
+def get_sp_ring_ranks():
+    return PROCESS_GROUP.SP_RING_RANKS
+
+
 def get_tp_group():
     return PROCESS_GROUP.TP_GROUP
 
@@ -127,23 +163,32 @@ def make_parallel_groups(blocks: List[List[int]], degree: int):
     blocks = [list(range(world_size))]
     cfg_groups, cfg_blocks = make_parallel_groups(blocks, cfg_degree)
     for cfg_ranks in cfg_groups:
-        cfg_group = dist.new_group(cfg_ranks)
         if rank in cfg_ranks:
-            PROCESS_GROUP.CFG_GROUP = cfg_group
+            PROCESS_GROUP.CFG_GROUP = dist.new_group(cfg_ranks)
             PROCESS_GROUP.CFG_RANKS = cfg_ranks
 
     sp_groups, sp_blocks = make_parallel_groups(cfg_blocks, sp_degree)
     for sp_ranks in sp_groups:
-        group = dist.new_group(sp_ranks)
         if rank in sp_ranks:
-            PROCESS_GROUP.SP_GROUP = group
+            PROCESS_GROUP.SP_GROUP = dist.new_group(sp_ranks)
             PROCESS_GROUP.SP_RANKS = sp_ranks
 
+    sp_ulysses_groups, sp_ulysses_blocks = make_parallel_groups(cfg_blocks, sp_ulysses_degree)
+    for sp_ulysses_ranks in sp_ulysses_groups:
+        if rank in sp_ulysses_ranks:
+            PROCESS_GROUP.SP_ULYSSUES_GROUP = dist.new_group(sp_ulysses_ranks)
+            PROCESS_GROUP.SP_ULYSSUES_RANKS = sp_ulysses_ranks
+
+    sp_ring_groups, _ = make_parallel_groups(sp_ulysses_blocks, sp_ring_degree)
+    for sp_ring_ranks in sp_ring_groups:
+        if rank in sp_ring_ranks:
+            PROCESS_GROUP.SP_RING_GROUP = dist.new_group(sp_ring_ranks)
+            PROCESS_GROUP.SP_RING_RANKS = sp_ring_ranks
+
     tp_groups, _ = make_parallel_groups(sp_blocks, tp_degree)
     for tp_ranks in tp_groups:
-        group = dist.new_group(tp_ranks)
         if rank in tp_ranks:
-            PROCESS_GROUP.TP_GROUP = group
+            PROCESS_GROUP.TP_GROUP = dist.new_group(tp_ranks)
             PROCESS_GROUP.TP_RANKS = tp_ranks
 
     set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
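The new Ulysses and ring getters degrade gracefully when a group was never created: world size falls back to 1 and rank to 0, so single-process runs need no distributed setup. A self-contained sketch of that accessor pattern; the _FakeGroup class stands in for torch.distributed.ProcessGroup purely for illustration.

from typing import Optional

class _FakeGroup:
    """Stand-in for torch.distributed.ProcessGroup in this sketch."""
    def __init__(self, ranks, my_rank):
        self._ranks, self._my_rank = ranks, my_rank
    def size(self) -> int:
        return len(self._ranks)
    def rank(self) -> int:
        return self._ranks.index(self._my_rank)

SP_RING_GROUP: Optional[_FakeGroup] = None

def get_ring_world_size() -> int:
    # Mirror of the "size() if group is not None else 1" pattern above.
    return SP_RING_GROUP.size() if SP_RING_GROUP is not None else 1

def get_ring_rank() -> int:
    return SP_RING_GROUP.rank() if SP_RING_GROUP is not None else 0

print(get_ring_world_size(), get_ring_rank())  # 1 0 when no group was initialized
SP_RING_GROUP = _FakeGroup(ranks=[4, 5, 6, 7], my_rank=6)
print(get_ring_world_size(), get_ring_rank())  # 4 2 inside a 4-rank ring group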