Skip to content

Commit

Permalink
support comfyui
Browse files Browse the repository at this point in the history
  • Loading branch information
zml-ai committed Jun 6, 2024
1 parent 39d26b4 commit 454c7ba
Show file tree
Hide file tree
Showing 31 changed files with 4,276 additions and 1 deletion.
3 changes: 3 additions & 0 deletions Notice
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ Copyright (c) 2021 Vision and Computational Cognition Group
8. sd-vae-ft-ema
Copyright (c) sd-vae-ft-ema original author and authors

9. ComfyUI-Diffusers
Copyright (c) 2023 Limitex


Terms of the MIT License:
--------------------------------------------------------------------
Expand Down
44 changes: 43 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ This repo contains PyTorch model definitions, pre-trained weights and inference/
> [**DialogGen: Multi-modal Interactive Dialogue System for Multi-turn Text-to-Image Generation**](https://arxiv.org/abs/2403.08857) <br>
## 🔥🔥🔥 News!!
* Jun 06, 2024: :tada: Hunyuan-DiT is now available in ComfyUI. Please check [ComfyUI](#using-comfyui) for more details.
* Jun 06, 2024: 🚀 We introduce Distillation version for Hunyuan-DiT acceleration, which achieves **50%** acceleration on NVIDIA GPUs. Please check [Tencent-Hunyuan/Distillation](https://huggingface.co/Tencent-Hunyuan/Distillation) for more details.
* Jun 05, 2024: 🤗 Hunyuan-DiT is now available in 🤗 Diffusers! Please check the [example](#using--diffusers) below.
* Jun 04, 2024: :globe_with_meridians: Support Tencent Cloud links to download the pretrained models! Please check the [links](#-download-pretrained-models) below.
Expand Down Expand Up @@ -73,7 +74,7 @@ or multi-turn language interactions to create the picture.
- [X] Web Demo (Gradio)
- [x] Multi-turn T2I Demo (Gradio)
- [X] Cli Demo
- [ ] ComfyUI
- [X] ComfyUI
- [X] Diffusers
- [ ] WebUI

Expand All @@ -94,6 +95,7 @@ or multi-turn language interactions to create the picture.
- [Using Diffusers](#using--diffusers)
- [Using Command Line](#using-command-line)
- [More Configurations](#more-configurations)
- [Using ComfyUI](#using-comfyui)
- [🚀 Acceleration (for Linux)](#-acceleration-for-linux)
- [🔗 BibTeX](#-bibtex)

Expand Down Expand Up @@ -389,6 +391,46 @@ We list some more useful configurations for easy usage:
| `--load-key` | ema | Load the student model or EMA model (ema or module) |
| `--load-4bit` | Fasle | Load DialogGen model with 4bit quantization |

### Using ComfyUI

We provide several commands to quick start:

```shell
# Download comfyui code
git clone https://github.com/comfyanonymous/ComfyUI.git

# Install torch, torchvision, torchaudio
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117

# Install Comfyui essential python package
cd ComfyUI
pip install -r requirements.txt

# ComfyUI has been successfully installed!

# Download model weight as before or link the existing model folder to ComfyUI.
python -m pip install "huggingface_hub[cli]"
mkdir models/hunyuan
huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./models/hunyuan/ckpts

# Move to the ComfyUI custom_nodes folder and copy comfyui-hydit folder from HunyuanDiT Repo.
cd custom_nodes
cp -r ${HunyuanDiT}/comfyui-hydit ./
cd comfyui-hydit

# Install some essential python Package.
pip install -r requirements.txt

# Our tool has been successfully installed!

# Go to ComfyUI main folder
cd ../..
# Run the ComfyUI Lauch command
python main.py --listen --port 80

# Running ComfyUI successfully!
```
More details can be found in [ComfyUI README](comfyui-hydit/README.md)

## 🚀 Acceleration (for Linux)

Expand Down
21 changes: 21 additions & 0 deletions comfyui-hydit/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Limitex

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
95 changes: 95 additions & 0 deletions comfyui-hydit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# comfyui-hydit

This repository contains a customized node and workflow designed specifically for HunYuan DIT. The official tests conducted on DDPM, DDIM, and DPMMS have consistently yielded results that align with those obtained through the Diffusers library. However, it's important to note that we cannot assure the consistency of results from other ComfyUI native samplers with the Diffusers inference. We cordially invite users to explore our workflow and are open to receiving any inquiries or suggestions you may have.

## Overview


### Workflow text2image

![Workflow](img/txt2img_v2.png)

[workflow_diffusers](workflow/hunyuan_diffusers_api.json) file for HunyuanDiT txt2image with diffusers backend.
[workflow_ksampler](workflow/hunyuan_ksampler_api.json) file for HunyuanDiT txt2image with ksampler backend.


## Usage

We provide several commands to quick start:

```shell
# Download comfyui code
git clone https://github.com/comfyanonymous/ComfyUI.git

# Install torch, torchvision, torchaudio
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117

# Install Comfyui essential python package
cd ComfyUI
pip install -r requirements.txt

# ComfyUI has been successfully installed!

# Download model weight as before or link the existing model folder to ComfyUI.
python -m pip install "huggingface_hub[cli]"
mkdir models/hunyuan
huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./models/hunyuan/ckpts

# Move to the ComfyUI custom_nodes folder and copy comfyui-hydit folder from HunyuanDiT Repo.
cd custom_nodes
cp -r ${HunyuanDiT}/comfyui-hydit ./
cd comfyui-hydit

# Install some essential python Package.
pip install -r requirements.txt

# Our tool has been successfully installed!

# Go to ComfyUI main folder
cd ../..
# Run the ComfyUI Lauch command
python main.py --listen --port 80

# Running ComfyUI successfully!
```



## Custom Node
Below I'm trying to document all the nodes, thanks for some good work[[1]](#1)[[2]](#2).
#### HunYuan Pipeline Loader
- Loads the full stack of models needed for HunYuanDiT.
- **pipeline_folder_name** is the official weight folder path for hunyuan dit including clip_text_encoder, model, mt5, sdxl-vae-fp16-fix and tokenizer.
- **model_name** is the weight list of comfyui checkpoint folder.
- **vae_name** is the weight list of comfyui vae folder.
- **backend** "diffusers" means using diffusers as the backend, while "ksampler" means using comfyui ksampler for the backend.
- **PIPELINE** is the instance of StableDiffusionPipeline.
- **MODEL** is the instance of comfyui MODEL.
- **CLIP** is the instance of comfyui CLIP.
- **VAE** is the instance of comfyui VAE.

#### HunYuan Scheduler Loader
- Loads the scheduler algorithm for HunYuanDiT.
- **Input** is the algorithm name including ddpm, ddim and dpmms.
- **Output** is the instance of diffusers.schedulers.

#### HunYuan Model Makeup
- Assemble the models and scheduler module.
- **Input** is the instance of StableDiffusionPipeline and diffusers.schedulers.
- **Output** is the updated instance of StableDiffusionPipeline.

#### HunYuan Clip Text Encode
- Assemble the models and scheduler module.
- **Input** is the string of positive and negative prompts.
- **Output** is the converted string for model.

#### HunYuan Sampler
- Similar with KSampler in ComfyUI.
- **Input** is the instance of StableDiffusionPipeline and some hyper-parameters for sampling.
- **Output** is the generated image.

## Reference
<a id="1">[1]</a>
https://github.com/Limitex/ComfyUI-Diffusers
<a id="2">[2]</a>
https://github.com/Tencent/HunyuanDiT/pull/59
4 changes: 4 additions & 0 deletions comfyui-hydit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .nodes import *
#aa = DiffusersSampler()
#print(aa)
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
113 changes: 113 additions & 0 deletions comfyui-hydit/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import comfy.supported_models_base
import comfy.latent_formats
import comfy.model_patcher
import comfy.model_base
import comfy.utils
from .hydit.modules.text_encoder import MT5Embedder
from transformers import BertModel, BertTokenizer
import torch
import os

class CLIP:
def __init__(self, root):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
text_encoder_path = os.path.join(root,"clip_text_encoder")
clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
tokenizer_path = os.path.join(root,"tokenizer")
self.tokenizer = HyBertTokenizer(tokenizer_path)
t5_text_encoder_path = os.path.join(root,'mt5')
embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
self.tokenizer_t5 = HyT5Tokenizer(embedder_t5.tokenizer, max_length=embedder_t5.max_length)
self.embedder_t5 = embedder_t5.model

self.cond_stage_model = clip_text_encoder

def tokenize(self, text):
tokens = self.tokenizer.tokenize(text)
t5_tokens = self.tokenizer_t5.tokenize(text)
tokens.update(t5_tokens)
return tokens

def tokenize_t5(self, text):
return self.tokenizer_t5.tokenize(text)

def encode_from_tokens(self, tokens, return_pooled=False):
attention_mask = tokens['attention_mask'].to(self.device)
with torch.no_grad():
prompt_embeds = self.cond_stage_model(
tokens['text_input_ids'].to(self.device),
attention_mask=attention_mask
)
prompt_embeds = prompt_embeds[0]
t5_attention_mask = tokens['t5_attention_mask'].to(self.device)
with torch.no_grad():
t5_prompt_cond = self.embedder_t5(
tokens['t5_text_input_ids'].to(self.device),
attention_mask=t5_attention_mask
)
t5_embeds = t5_prompt_cond[0]

addit_embeds = {
"t5_embeds": t5_embeds,
"attention_mask": attention_mask.float(),
"t5_attention_mask": t5_attention_mask.float()
}
prompt_embeds.addit_embeds = addit_embeds

if return_pooled:
return prompt_embeds, None
else:
return prompt_embeds

class HyBertTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, truncation=True, return_attention_mask=True, device='cpu'):
self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
self.max_length = self.tokenizer.model_max_length or max_length
self.truncation = truncation
self.return_attention_mask = return_attention_mask
self.device = device

def tokenize(self, text:str):
text_inputs = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
truncation=self.truncation,
return_attention_mask=self.return_attention_mask,
add_special_tokens = True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
tokens = {
'text_input_ids': text_input_ids,
'attention_mask': attention_mask
}
return tokens

class HyT5Tokenizer:
def __init__(self, tokenizer, max_length=77, truncation=True, return_attention_mask=True, device='cpu'):
self.tokenizer = tokenizer
self.max_length = max_length
self.truncation = truncation
self.return_attention_mask = return_attention_mask
self.device = device

def tokenize(self, text:str):
text_inputs = self.tokenizer(
text,
padding="max_length",
max_length=self.max_length,
truncation=self.truncation,
return_attention_mask=self.return_attention_mask,
add_special_tokens = True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
tokens = {
't5_text_input_ids': text_input_ids,
't5_attention_mask': attention_mask
}
return tokens

6 changes: 6 additions & 0 deletions comfyui-hydit/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import os
from .hydit.constants import SAMPLER_FACTORY

base_path = os.path.dirname(os.path.realpath(__file__))
HUNYUAN_PATH = os.path.join(base_path, "..", "..", "models", "hunyuan")
SCHEDULERS_hunyuan = list(SAMPLER_FACTORY.keys())
Loading

0 comments on commit 454c7ba

Please sign in to comment.