diff --git a/README.md b/README.md
index 74afe9d..623628d 100644
--- a/README.md
+++ b/README.md
@@ -283,12 +283,20 @@ python app/hydit_app.py --infer-mode fa
 # The enhancement will be unavailable until you restart the app without the `--no-enhance` flag.
 python app/hydit_app.py --no-enhance
 
+# You can specify the device for HunyuanDiT inference and DialogGen inference respectively,
+# with the `--device` and `--enhance-device` flags.
+python app/hydit_app.py --device cuda:0 --enhance-device cuda:1
+python app/hydit_app.py --device xpu:0 --enhance-device xpu:1 # Intel GPU
+
 # Start with English UI
 python app/hydit_app.py --lang en
 
 # Start a multi-turn T2I generation UI.
 # If your GPU memory is less than 32GB, use '--load-4bit' to enable 4-bit quantization, which requires at least 22GB of memory.
 python app/multiTurnT2I_app.py
+# Using multiple GPU devices.
+python app/multiTurnT2I_app.py --device cuda:0 --enhance-device cuda:1 --load-4bit
+python app/multiTurnT2I_app.py --device xpu:0 --enhance-device xpu:1 --load-4bit # Intel GPU
 ```
 
 Then the demo can be accessed through http://0.0.0.0:443
diff --git a/dialoggen/dialoggen_demo.py b/dialoggen/dialoggen_demo.py
index 581fc03..383b502 100644
--- a/dialoggen/dialoggen_demo.py
+++ b/dialoggen/dialoggen_demo.py
@@ -50,10 +50,10 @@ def load_images(image_files):
     return out
 
 
-def init_dialoggen_model(model_path, model_base=None, load_4bit=False):
+def init_dialoggen_model(model_path, model_base=None, device="cuda", load_4bit=False):
     model_name = get_model_name_from_path(model_path)
     tokenizer, model, image_processor, context_len = load_pretrained_model(
-        model_path, model_base, model_name, llava_type_model=True, load_4bit=load_4bit)
+        model_path, model_base, model_name, llava_type_model=True, load_4bit=load_4bit, device=device)
     return {"tokenizer": tokenizer,
             "model": model,
             "image_processor": image_processor}
@@ -117,7 +117,7 @@ def eval_model(models,
     input_ids = (
         tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
         .unsqueeze(0)
-        .cuda()
+        .to(models["model"].device)
     )
     with torch.inference_mode():
         output_ids = models["model"].generate(
@@ -149,8 +149,8 @@ def remove_prefix(text):
 
 
 class DialogGen(object):
-    def __init__(self, model_path, load_4bit=False):
-        self.models = init_dialoggen_model(model_path, load_4bit=load_4bit)
+    def __init__(self, model_path, device="cuda", load_4bit=False):
+        self.models = init_dialoggen_model(model_path, device=device, load_4bit=load_4bit)
         self.query_template = "请先判断用户的意图,若为画图则在输出前加入<画图>:{}"
 
     def __call__(self, prompt, return_history=False, history=None, skip_special=False):
@@ -176,11 +176,13 @@ def __call__(self, prompt, return_history=False, history=None, skip_special=Fals
     parser.add_argument('--model_path', type=str, default='./ckpts/dialoggen')
    parser.add_argument('--prompt', type=str, default='画一只小猫')
     parser.add_argument('--image_file', type=str, default=None) # 'images/demo1.jpeg'
+    parser.add_argument("--enhance-device", type=str, default="cuda", help="Device for DialogGen model inference.")
+    parser.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true")
     args = parser.parse_args()
 
     query = f"请先判断用户的意图,若为画图则在输出前加入<画图>:{args.prompt}"
 
-    models = init_dialoggen_model(args.model_path)
+    models = init_dialoggen_model(args.model_path, device=args.enhance_device, load_4bit=args.load_4bit)
 
     res = eval_model(models,
         query=query,
diff --git a/dialoggen/llava/model/builder.py b/dialoggen/llava/model/builder.py
index 263d5d1..913fc16 100644
--- a/dialoggen/llava/model/builder.py
+++ b/dialoggen/llava/model/builder.py
@@ -17,13 +17,23 @@
 import warnings
 import shutil
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+import transformers
+from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
 import torch
 from llava.model import *
 from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 
 
 def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, llava_type_model=True, **kwargs):
+    AutoModelForCausalLM = transformers.AutoModelForCausalLM
+    if "xpu" in device and (load_8bit or load_4bit):
+        try:
+            import ipex_llm.transformers
+            AutoModelForCausalLM = ipex_llm.transformers.AutoModelForCausalLM
+        except ImportError:
+            raise ImportError("""Please install the ipex_llm package to load 8bit/4bit models on XPU.
+                pip install --pre ipex-llm[xpu]""")
+
     kwargs = {"device_map": device_map, **kwargs}
 
     if device != "cuda":
@@ -32,12 +42,16 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
         kwargs['load_in_8bit'] = True
     elif load_4bit:
         kwargs['load_in_4bit'] = True
-        kwargs['quantization_config'] = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_compute_dtype=torch.float16,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type='nf4'
-        )
+        if "cuda" in device:
+            kwargs['quantization_config'] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type='nf4'
+            )
+        if "xpu" in device:
+            kwargs['torch_dtype'] = torch.float16
+            kwargs['modules_to_not_convert'] = ['lm_head', 'mm_projector']
     else:
         kwargs['torch_dtype'] = torch.float16
 
@@ -163,4 +177,5 @@ def load_from_hf(repo_id, filename, subfolder=None):
     else:
         context_len = 2048
 
+    model = model.to(device=device)
     return tokenizer, model, image_processor, context_len
diff --git a/hydit/config.py b/hydit/config.py
index 2a731c0..0750e99 100644
--- a/hydit/config.py
+++ b/hydit/config.py
@@ -26,12 +26,14 @@ def get_args(default_args=None):
     parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.")
 
     # Prompt enhancement
+    parser.add_argument("--enhance-device", type=str, default="cuda", help="Device for DialogGen model inference.")
     parser.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.")
     parser.add_argument("--no-enhance", dest="enhance", action="store_false")
     parser.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true")
     parser.set_defaults(enhance=True)
 
     # Diffusion
+    parser.add_argument("--device", type=str, default="cuda", help="Device for diffusion model inference")
     parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
     parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false")
     parser.set_defaults(learn_sigma=True)
diff --git a/hydit/inference.py b/hydit/inference.py
index 7751ffb..a9698fa 100644
--- a/hydit/inference.py
+++ b/hydit/inference.py
@@ -4,6 +4,10 @@
 
 import numpy as np
 import torch
+try:
+    import intel_extension_for_pytorch as ipex
+except ImportError:
+    ipex = None
 
 # For reproducibility
 # torch.backends.cudnn.benchmark = False
@@ -159,7 +163,7 @@ def __init__(self, args, models_root_path):
         logger.info(f"Got text-to-image model root path: {t2i_root_path}")
 
         # Set device and disable gradient
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = args.device
         torch.set_grad_enabled(False)
         # Disable BertModel logging checkpoint info
         tf_logger.setLevel('ERROR')
@@ -179,7 +183,7 @@ def __init__(self, args, models_root_path):
         # ========================================================================
         logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
         t5_text_encoder_path = self.root / 'mt5'
-        embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
+        embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256).to(self.device)
         self.embedder_t5 = embedder_t5
         logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")
 
diff --git a/sample_t2i.py b/sample_t2i.py
index f017839..9f330d9 100644
--- a/sample_t2i.py
+++ b/sample_t2i.py
@@ -19,7 +19,7 @@ def inferencer():
     # Try to enhance prompt
     if args.enhance:
         logger.info("Loading DialogGen model (for prompt enhancement)...")
-        enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
+        enhancer = DialogGen(str(models_root_path / "dialoggen"), device=args.enhance_device, load_4bit=args.load_4bit)
         logger.info("DialogGen model loaded.")
     else:
         enhancer = None
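Usage reference (not part of the diff): the new device argument threads from the CLI flags (--device, --enhance-device) through DialogGen and init_dialoggen_model down to load_pretrained_model, so the enhancer can also be placed on a chosen device directly from Python. The sketch below is a minimal example under stated assumptions: the package-style import path and the ./ckpts/dialoggen checkpoint location follow the repository defaults but are not shown in the hunks above. With an XPU device and 4-bit loading it exercises the ipex_llm branch added in builder.py; on CUDA the original BitsAndBytesConfig path is kept.

    # Hypothetical usage sketch; import path and checkpoint location are assumptions.
    from dialoggen.dialoggen_demo import init_dialoggen_model, eval_model

    # device="xpu:0" + load_4bit=True -> ipex_llm.transformers.AutoModelForCausalLM (builder.py)
    # device="cuda:0" + load_4bit=True -> BitsAndBytesConfig nf4 quantization (original path)
    models = init_dialoggen_model("./ckpts/dialoggen", device="xpu:0", load_4bit=True)

    # Same query template as the dialoggen_demo.py __main__ block.
    query = "请先判断用户的意图,若为画图则在输出前加入<画图>:画一只小猫"
    print(eval_model(models, query=query))

This mirrors what sample_t2i.py now does internally when called with --enhance-device xpu:0 --load-4bit.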