Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add ComfyUI Native Support #59

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .hunyuandit_comfy_nodes.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
113 changes: 113 additions & 0 deletions hunyuandit_comfy_nodes/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.utils
from ..hydit.modules.text_encoder import MT5Embedder
from transformers import BertModel, BertTokenizer
import torch
import os

class CLIP:
def __init__(self, root):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
text_encoder_path = os.path.join(root,"clip_text_encoder")
clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
tokenizer_path = os.path.join(root,"tokenizer")
self.tokenizer = HyBertTokenizer(tokenizer_path)
t5_text_encoder_path = os.path.join(root,'mt5')
embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
self.tokenizer_t5 = HyT5Tokenizer(embedder_t5.tokenizer, max_length=embedder_t5.max_length)
self.embedder_t5 = embedder_t5.model

self.cond_stage_model = clip_text_encoder

def tokenize(self, text):
tokens = self.tokenizer.tokenize(text)
t5_tokens = self.tokenizer_t5.tokenize(text)
tokens.update(t5_tokens)
return tokens

def tokenize_t5(self, text):
return self.tokenizer_t5.tokenize(text)

def encode_from_tokens(self, tokens, return_pooled=False):
attention_mask = tokens['attention_mask'].to(self.device)
with torch.no_grad():
prompt_embeds = self.cond_stage_model(
tokens['text_input_ids'].to(self.device),
attention_mask=attention_mask
)
prompt_embeds = prompt_embeds[0]
t5_attention_mask = tokens['t5_attention_mask'].to(self.device)
with torch.no_grad():
t5_prompt_cond = self.embedder_t5(
tokens['t5_text_input_ids'].to(self.device),
attention_mask=t5_attention_mask
)
t5_embeds = t5_prompt_cond[0]

addit_embeds = {
"t5_embeds": t5_embeds,
"attention_mask": attention_mask.float(),
"t5_attention_mask": t5_attention_mask.float()
}
prompt_embeds.addit_embeds = addit_embeds

if return_pooled:
return prompt_embeds, None
else:
return prompt_embeds

class HyBertTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, truncation=True, return_attention_mask=True, device='cpu'):
self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
self.max_length = self.tokenizer.model_max_length or max_length
self.truncation = truncation
self.return_attention_mask = return_attention_mask
self.device = device

def tokenize(self, text:str):
text_inputs = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
truncation=self.truncation,
return_attention_mask=self.return_attention_mask,
add_special_tokens = True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
tokens = {
'text_input_ids': text_input_ids,
'attention_mask': attention_mask
}
return tokens

class HyT5Tokenizer:
def __init__(self, tokenizer, max_length=77, truncation=True, return_attention_mask=True, device='cpu'):
self.tokenizer = tokenizer
self.max_length = max_length
self.truncation = truncation
self.return_attention_mask = return_attention_mask
self.device = device

def tokenize(self, text:str):
text_inputs = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
truncation=self.truncation,
return_attention_mask=self.return_attention_mask,
add_special_tokens = True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
tokens = {
't5_text_input_ids': text_input_ids,
't5_attention_mask': attention_mask
}
return tokens

84 changes: 84 additions & 0 deletions hunyuandit_comfy_nodes/dit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.utils
from comfy import model_management
from .supported_dit_models import HunYuan_DiT, HYDiT_Model, ModifiedHunYuanDiT
from .clip import CLIP
import os
import folder_paths
import torch

sampling_settings = {
"beta_schedule" : "linear",
"linear_start" : 0.00085,
"linear_end" : 0.03,
"timesteps" : 1000,
}

hydit_conf = {
"G/2": { # Seems to be the main one
"unet_config": {
"depth" : 40,
"num_heads" : 16,
"patch_size" : 2,
"hidden_size" : 1408,
"mlp_ratio" : 4.3637,
"input_size": (1024//8, 1024//8),
},
"sampling_settings" : sampling_settings,
},
}

def load_dit(model_path, output_clip=True, output_model=True, output_vae=True):
state_dict = comfy.utils.load_torch_file(model_path)
state_dict = state_dict.get("model", state_dict)
parameters = comfy.utils.calculate_parameters(state_dict)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
clip = None,
vae = None
model_patcher = None

# ignore fp8/etc and use directly for now
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
root = os.path.join(folder_paths.models_dir, "hunyuan/ckpts/t2i")
if manual_cast_dtype:
print(f"DiT: falling back to {manual_cast_dtype}")
unet_dtype = manual_cast_dtype

#model_conf["unet_config"]["num_classes"] = state_dict["y_embedder.embedding_table.weight"].shape[0] - 1 # adj. for empty

if output_model:
model_conf = HunYuan_DiT(hydit_conf["G/2"])
model = HYDiT_Model(
model_conf,
model_type=comfy.model_base.ModelType.V_PREDICTION,
device=model_management.get_torch_device()
)

model.diffusion_model = ModifiedHunYuanDiT(model_conf.dit_conf, **model_conf.unet_config).half().to(load_device)

model.diffusion_model.load_state_dict(state_dict)
#model.diffusion_model.dtype = unet_dtype
model.diffusion_model.eval()
model.diffusion_model.to(unet_dtype)

model_patcher = comfy.model_patcher.ModelPatcher(
model,
load_device = load_device,
offload_device = offload_device,
current_device = "cpu",
)
#model_patcher['model_options']['dit'] = 'hunyuan'
if output_clip:
clip = CLIP(root)

if output_vae:
vae_path = os.path.join(root, 'sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors')
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)

return (model_patcher, clip, vae)
31 changes: 31 additions & 0 deletions hunyuandit_comfy_nodes/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path
import folder_paths
from .dit import load_dit

MAX_RESOLUTION=8192

class DitCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}

RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "ExtraModels/DiT"
TITLE = "DitCheckpointLoader"

def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = load_dit(
model_path = ckpt_path,
)
return out[:3]

NODE_CLASS_MAPPINGS = {
"DitCheckpointLoader":DitCheckpointLoader,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DitCheckpointLoader":"DitCheckpointLoaderSimple",
}

109 changes: 109 additions & 0 deletions hunyuandit_comfy_nodes/supported_dit_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.utils
import torch
import inspect
from collections import namedtuple
from ..hydit.modules.models import HunYuanDiT as HYDiT
from ..hydit.modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop

def batch_embeddings(embeds, batch_size):
bs_embed, seq_len, _ = embeds.shape
embeds = embeds.repeat(1, batch_size, 1)
embeds = embeds.view(bs_embed * batch_size, seq_len, -1)
return embeds

class HunYuan_DiT(comfy.supported_models_base.BASE):
Conf = namedtuple('DiT', ['learn_sigma', 'text_states_dim', 'text_states_dim_t5', 'text_len', 'text_len_t5', 'norm', 'infer_mode', 'use_fp16'])
conf = {
'learn_sigma': True,
'text_states_dim': 1024,
'text_states_dim_t5': 2048,
'text_len': 77,
'text_len_t5': 256,
'norm': 'layer',
'infer_mode': 'torch',
'use_fp16': True
}

unet_config = {}
unet_extra_config = {
"num_heads": 16
}
latent_format = comfy.latent_formats.SDXL

dit_conf = Conf(**conf)

def __init__(self, model_conf):
self.unet_config = model_conf.get("unet_config", {})
self.sampling_settings = model_conf.get("sampling_settings", {})
self.latent_format = self.latent_format()
self.unet_config["disable_unet_model_creation"] = True

def model_type(self, state_dict, prefix=""):
return comfy.model_base.ModelType.V_PREDICTION

class HYDiT_Model(comfy.model_base.BaseModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
addit_embeds = kwargs['cross_attn'].addit_embeds
for name in addit_embeds:
out[name] = comfy.conds.CONDRegular(addit_embeds[name])

return out

class ModifiedHunYuanDiT(HYDiT):
def __init__ (self, *args, **kwargs):
signature = inspect.signature(HYDiT.__init__)
params = set(signature.parameters.keys()) - {'self'}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in params}

super().__init__(*args, **filtered_kwargs)

def forward_core(self, *args, **kwargs):
return super().forward(*args, **kwargs)

def forward(self, x, timesteps, context, t5_embeds=None, attention_mask=None, t5_attention_mask=None, image_meta_size=None, **kwargs):
batch_size, _, width, height = x.shape

style = torch.as_tensor([0, 0] * (batch_size//2), device=x.device)
src_size_cond = (width//2*16, height//2*16)
size_cond = list(src_size_cond) + [width*8, height*8, 0, 0]
image_meta_size = torch.as_tensor([size_cond] * batch_size, device=x.device)
rope = self.calc_rope(*src_size_cond)

noise_pred = self.forward_core(
x = x.to(self.dtype),
t = timesteps.to(self.dtype),
encoder_hidden_states = context.to(self.dtype),
text_embedding_mask = attention_mask.to(self.dtype),
encoder_hidden_states_t5 = t5_embeds.to(self.dtype),
text_embedding_mask_t5 = t5_attention_mask.to(self.dtype),
image_meta_size = image_meta_size.to(self.dtype),
style = style,
cos_cis_img = rope[0],
sin_cis_img = rope[1],
return_dict=False
)
noise_pred = noise_pred.to(torch.float)
eps, _ = noise_pred[:, :self.in_channels], noise_pred[:, self.in_channels:]
return eps

def calc_rope(self, height, width):
"""
Probably not the best in terms of perf to have this here
"""
th = height // 8 // self.patch_size
tw = width // 8 // self.patch_size
base_size = 512 // 8 // self.patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
head_size = self.hidden_size // self.num_heads
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
return rope

2 changes: 1 addition & 1 deletion hydit/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,4 +406,4 @@ def unpatchify(self, x, h, w):
'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16},
'DiT-L/2': {'depth': 24, 'hidden_size': 1024, 'patch_size': 2, 'num_heads': 16},
'DiT-B/2': {'depth': 12, 'hidden_size': 768, 'patch_size': 2, 'num_heads': 12},
}
}