Multiple GPUs support and XPU device support #71

Open · wants to merge 2 commits into main
8 changes: 8 additions & 0 deletions README.md
@@ -283,12 +283,20 @@ python app/hydit_app.py --infer-mode fa
# The enhancement will be unavailable until you restart the app without the `--no-enhance` flag.
python app/hydit_app.py --no-enhance

# You can specify the devices for HunyuanDiT inference and DialogGen inference separately
# with the `--device` and `--enhance-device` flags.
python app/hydit_app.py --device cuda:0 --enhance-device cuda:1
python app/hydit_app.py --device xpu:0 --enhance-device xpu:1 # Intel GPU

# Start with English UI
python app/hydit_app.py --lang en

# Start a multi-turn T2I generation UI.
# If your GPU memory is less than 32GB, use `--load-4bit` to enable 4-bit quantization, which requires at least 22GB of memory.
python app/multiTurnT2I_app.py
# Using multiple GPU devices.
python app/multiTurnT2I_app.py --device cuda:0 --enhance-device cuda:1 --load-4bit
python app/multiTurnT2I_app.py --device xpu:0 --enhance-device xpu:1 --load-4bit # Intel GPU
```
Then the demo can be accessed through http://0.0.0.0:443
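The two flags let HunyuanDiT and the DialogGen enhancer run on separate devices. Below is a minimal pre-flight check sketch, assuming intel_extension_for_pytorch is installed for the XPU case; the helper name is illustrative and not part of this PR.

```python
# Minimal sketch (not part of this diff): sanity-check the requested devices
# before launching the app. Assumes intel_extension_for_pytorch for XPU use.
import torch

def check_device(name: str) -> None:
    if name.startswith("cuda"):
        assert torch.cuda.is_available(), f"{name} requested but CUDA is unavailable"
    elif name.startswith("xpu"):
        import intel_extension_for_pytorch  # noqa: F401  registers the torch.xpu backend
        assert torch.xpu.is_available(), f"{name} requested but no XPU device is visible"

check_device("cuda:0")  # HunyuanDiT device (--device)
check_device("cuda:1")  # DialogGen device (--enhance-device)
```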

14 changes: 8 additions & 6 deletions dialoggen/dialoggen_demo.py
@@ -50,10 +50,10 @@ def load_images(image_files):
return out


def init_dialoggen_model(model_path, model_base=None, load_4bit=False):
def init_dialoggen_model(model_path, model_base=None, device="cuda", load_4bit=False):
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path, model_base, model_name, llava_type_model=True, load_4bit=load_4bit)
model_path, model_base, model_name, llava_type_model=True, load_4bit=load_4bit, device=device)
return {"tokenizer": tokenizer,
"model": model,
"image_processor": image_processor}
@@ -117,7 +117,7 @@ def eval_model(models,
input_ids = (
tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
.unsqueeze(0)
.cuda()
.to(models["model"].device)
)
with torch.inference_mode():
output_ids = models["model"].generate(
@@ -149,8 +149,8 @@ def remove_prefix(text):


class DialogGen(object):
def __init__(self, model_path, load_4bit=False):
self.models = init_dialoggen_model(model_path, load_4bit=load_4bit)
def __init__(self, model_path, device="cuda", load_4bit=False):
self.models = init_dialoggen_model(model_path, device=device, load_4bit=load_4bit)
self.query_template = "请先判断用户的意图,若为画图则在输出前加入<画图>:{}"  # "First decide the user's intent; if they want an image, prefix the output with <画图>:" followed by the prompt

def __call__(self, prompt, return_history=False, history=None, skip_special=False):
@@ -176,11 +176,13 @@ def __call__(self, prompt, return_history=False, history=None, skip_special=False):
parser.add_argument('--model_path', type=str, default='./ckpts/dialoggen')
parser.add_argument('--prompt', type=str, default='画一只小猫')  # "draw a kitten"
parser.add_argument('--image_file', type=str, default=None) # 'images/demo1.jpeg'
parser.add_argument("--enhance-device", type=str, default="cuda", help="Device for DialogGen model inference.")
parser.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true")
args = parser.parse_args()

query = f"请先判断用户的意图,若为画图则在输出前加入<画图>:{args.prompt}"

models = init_dialoggen_model(args.model_path)
models = init_dialoggen_model(args.model_path, device=args.enhance_device, load_4bit=args.load_4bit)

res = eval_model(models,
query=query,
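For reference, constructing the enhancer on a secondary device with the new parameters might look like the sketch below; the import path and the return value of `__call__` are assumptions, not shown in this diff.

```python
# Hedged usage sketch for the updated wrapper; `device` and `load_4bit` are the
# parameters added in this diff. Import path and return shape are assumed.
from dialoggen.dialoggen_demo import DialogGen

enhancer = DialogGen("./ckpts/dialoggen", device="cuda:1", load_4bit=True)
result = enhancer("画一只小猫")  # "draw a kitten"; enhancement runs entirely on cuda:1
```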
29 changes: 22 additions & 7 deletions dialoggen/llava/model/builder.py
@@ -17,13 +17,23 @@
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import transformers
from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
import torch
from llava.model import *
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, llava_type_model=True, **kwargs):
AutoModelForCausalLM = transformers.AutoModelForCausalLM
if "xpu" in device and (load_8bit or load_4bit):
try:
import ipex_llm.transformers
AutoModelForCausalLM = ipex_llm.transformers.AutoModelForCausalLM
except ImportError:
raise ImportError("""Please install the ipex_llm package to load 8bit/4bit models on XPU.
pip install --pre ipex-llm[xpu]""")

kwargs = {"device_map": device_map, **kwargs}

if device != "cuda":
@@ -32,12 +42,16 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, llava_type_model=True, **kwargs):
kwargs['load_in_8bit'] = True
elif load_4bit:
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
if "cuda" in device:
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
if "xpu" in device:
kwargs['torch_dtype'] = torch.float16
kwargs['modules_to_not_convert'] = ['lm_head', 'mm_projector']
else:
kwargs['torch_dtype'] = torch.float16

@@ -163,4 +177,5 @@ def load_from_hf(repo_id, filename, subfolder=None):
else:
context_len = 2048

model = model.to(device=device)
return tokenizer, model, image_processor, context_len
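Distilled, the device-dependent model-class selection added above amounts to the following restatement (illustrative only, not additional PR code):

```python
import transformers

def pick_model_class(device: str, load_8bit: bool, load_4bit: bool):
    """Return the AutoModelForCausalLM implementation used for this device/quantization combo."""
    if "xpu" in device and (load_8bit or load_4bit):
        # ipex_llm provides a drop-in AutoModelForCausalLM that handles low-bit loading on Intel GPUs
        from ipex_llm.transformers import AutoModelForCausalLM
        return AutoModelForCausalLM
    # on CUDA, 4-bit quantization instead goes through bitsandbytes via BitsAndBytesConfig
    return transformers.AutoModelForCausalLM
```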
2 changes: 2 additions & 0 deletions hydit/config.py
@@ -26,12 +26,14 @@ def get_args(default_args=None):
parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.")

# Prompt enhancement
parser.add_argument("--enhance-device", type=str, default="cuda", help="Device for DialogGen model inference.")
parser.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.")
parser.add_argument("--no-enhance", dest="enhance", action="store_false")
parser.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true")
parser.set_defaults(enhance=True)

# Diffusion
parser.add_argument("--device", type=str, default="cuda", help="Device for diffusion model inference")
parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false")
parser.set_defaults(learn_sigma=True)
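A quick way to exercise the new flags, assuming `get_args` forwards `default_args` to argparse's `parse_args` (a sketch, not part of the diff):

```python
from hydit.config import get_args

args = get_args(["--device", "xpu:0", "--enhance-device", "xpu:1", "--load-4bit"])
print(args.device, args.enhance_device, args.load_4bit)  # xpu:0 xpu:1 True
```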
8 changes: 6 additions & 2 deletions hydit/inference.py
@@ -4,6 +4,10 @@

import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
except ImportError:
ipex = None

# For reproducibility
# torch.backends.cudnn.benchmark = False
@@ -159,7 +163,7 @@ def __init__(self, args, models_root_path):
logger.info(f"Got text-to-image model root path: {t2i_root_path}")

# Set device and disable gradient
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = args.device
torch.set_grad_enabled(False)
# Disable BertModel logging checkpoint info
tf_logger.setLevel('ERROR')
@@ -179,7 +183,7 @@ def __init__(self, args, models_root_path):
# ========================================================================
logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
t5_text_encoder_path = self.root / 'mt5'
embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256).to(self.device)
self.embedder_t5 = embedder_t5
logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")

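The underlying pattern is that once intel_extension_for_pytorch is importable, `.to(device)` works the same for "xpu" devices as it does for CUDA; a minimal, self-contained sketch:

```python
import torch
try:
    import intel_extension_for_pytorch  # noqa: F401  registers the torch.xpu backend
except ImportError:
    pass

if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = "xpu:0"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

x = torch.randn(2, 4, dtype=torch.float16).to(device)  # same call path for CUDA and XPU
```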
2 changes: 1 addition & 1 deletion sample_t2i.py
@@ -19,7 +19,7 @@ def inferencer():
# Try to enhance prompt
if args.enhance:
logger.info("Loading DialogGen model (for prompt enhancement)...")
enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
enhancer = DialogGen(str(models_root_path / "dialoggen"), device=args.enhance_device, load_4bit=args.load_4bit)
logger.info("DialogGen model loaded.")
else:
enhancer = None
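Taken together, the two-device flow in sample_t2i.py roughly becomes the sketch below; the import, the tuple returned by `inferencer()`, and the return shape of the enhancer call are assumptions based on the surrounding script, not shown in this diff.

```python
from sample_t2i import inferencer  # assumed import; sample_t2i.py lives at the repo root

args, gen, enhancer = inferencer()        # gen built with args.device, enhancer with args.enhance_device
prompt = args.prompt
if enhancer is not None:
    success, enhanced = enhancer(prompt)  # return shape assumed
    if success:
        prompt = enhanced
# gen (the End2End pipeline) then samples images from `prompt` on args.device.
```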