diff --git a/demo/realtime_model_inference_from_file.py b/demo/realtime_model_inference_from_file.py
index a321f6c..8d16db6 100644
--- a/demo/realtime_model_inference_from_file.py
+++ b/demo/realtime_model_inference_from_file.py
@@ -121,6 +121,13 @@ def parse_args():
         default=1.5,
         help="CFG (Classifier-Free Guidance) scale for generation (default: 1.5)",
     )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+        default="fp16",
+        choices=["fp16", "8bit", "4bit"],
+        help="Quantization level: fp16 (default, ~20GB), 8bit (~12GB), or 4bit (~7GB)"
+    )
     return parser.parse_args()
 
 
@@ -138,6 +145,15 @@ def main():
         args.device = "cpu"
 
     print(f"Using device: {args.device}")
+
+    # VRAM Detection and Quantization Info (NEW)
+    from utils.vram_utils import get_available_vram_gb, print_vram_info
+    if args.device == "cuda":
+        available_vram = get_available_vram_gb()
+        print_vram_info(available_vram, args.model_path, args.quantization)
+    elif args.quantization != "fp16":
+        print(f"Warning: Quantization ({args.quantization}) only works with CUDA. Using full precision.")
+        args.quantization = "fp16"
 
     # Initialize voice mapper
     voice_mapper = VoiceMapper()
@@ -172,6 +188,16 @@ def main():
         load_dtype = torch.float32
         attn_impl_primary = "sdpa"
     print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
+
+    # Get quantization configuration (NEW)
+    from utils.quantization import get_quantization_config, apply_selective_quantization
+    quant_config = get_quantization_config(args.quantization)
+
+    if quant_config:
+        print(f"Using {args.quantization} quantization...")
+    else:
+        print("Using full precision (fp16)...")
+
     # Load model with device-specific logic
     try:
         if args.device == "mps":
@@ -183,12 +209,25 @@
             )
             model.to("mps")
         elif args.device == "cuda":
+            # MODIFIED SECTION - Add quantization support
+            model_kwargs = {
+                "torch_dtype": load_dtype,
+                "device_map": "cuda",
+                "attn_implementation": attn_impl_primary,
+            }
+
+            # Add quantization config if specified
+            if quant_config:
+                model_kwargs.update(quant_config)
+
             model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                 args.model_path,
-                torch_dtype=load_dtype,
-                device_map="cuda",
-                attn_implementation=attn_impl_primary,
+                **model_kwargs
             )
+
+            # Apply selective quantization if needed (NEW)
+            if args.quantization in ["8bit", "4bit"]:
+                model = apply_selective_quantization(model, args.quantization)
         else:  # cpu
             model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                 args.model_path,
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/quantization.py b/utils/quantization.py
new file mode 100644
index 0000000..2348cf5
--- /dev/null
+++ b/utils/quantization.py
@@ -0,0 +1,117 @@
+"""Quantization utilities for VibeVoice models."""
+
+import logging
+from typing import Optional
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_quantization_config(quantization: str = "fp16") -> Optional[dict]:
+    """
+    Get quantization configuration for model loading.
+
+    Args:
+        quantization: Quantization level ("fp16", "8bit", or "4bit")
+
+    Returns:
+        dict: Quantization config for from_pretrained, or None for fp16
+    """
+    if quantization in ("fp16", "full"):
+        return None
+
+    if quantization == "8bit":
+        try:
+            import bitsandbytes as bnb  # noqa: F401 - availability check only
+            from transformers import BitsAndBytesConfig
+
+            logger.info("Using 8-bit quantization (selective LLM only)")
+            return {
+                "quantization_config": BitsAndBytesConfig(
+                    load_in_8bit=True,
+                    llm_int8_threshold=6.0,
+                )
+            }
+        except ImportError:
+            logger.error(
+                "8-bit quantization requires bitsandbytes. "
" + "Install with: pip install bitsandbytes" + ) + raise + + elif quantization == "4bit": + try: + import bitsandbytes as bnb + from transformers import BitsAndBytesConfig + + logger.info("Using 4-bit NF4 quantization (selective LLM only)") + return { + "quantization_config": BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + ) + } + except ImportError: + logger.error( + "4-bit quantization requires bitsandbytes. " + "Install with: pip install bitsandbytes" + ) + raise + + else: + raise ValueError( + f"Invalid quantization: {quantization}. " + f"Must be one of: fp16, 8bit, 4bit" + ) + + +def apply_selective_quantization(model, quantization: str): + """ + Apply selective quantization only to safe components. + + This function identifies which modules should be quantized and which + should remain at full precision for audio quality preservation. + + Args: + model: The VibeVoice model + quantization: Quantization level ("8bit" or "4bit") + """ + if quantization == "fp16": + return model + + logger.info("Applying selective quantization...") + + # Components to KEEP at full precision (audio-critical) + keep_fp_components = [ + "diffusion_head", + "acoustic_connector", + "semantic_connector", + "acoustic_tokenizer", + "semantic_tokenizer", + "vae", + ] + + # Only quantize the LLM (Qwen2.5) component + quantize_components = ["llm", "language_model"] + + for name, module in model.named_modules(): + # Check if this module should stay at full precision + should_keep_fp = any(comp in name for comp in keep_fp_components) + should_quantize = any(comp in name for comp in quantize_components) + + if should_keep_fp: + # Ensure audio components stay at full precision + if hasattr(module, 'weight') and module.weight.dtype != torch.float32: + module.weight.data = module.weight.data.to(torch.bfloat16) + logger.debug(f"Keeping {name} at full precision (audio-critical)") + + elif should_quantize: + logger.debug(f"Quantized {name} to {quantization}") + + logger.info(f"✓ Selective {quantization} quantization applied") + logger.info(" • LLM: Quantized") + logger.info(" • Audio components: Full precision") + + return model \ No newline at end of file diff --git a/utils/vram_utils.py b/utils/vram_utils.py new file mode 100644 index 0000000..7fedd05 --- /dev/null +++ b/utils/vram_utils.py @@ -0,0 +1,87 @@ +"""VRAM detection and quantization recommendation utilities.""" + +import torch +import logging + +logger = logging.getLogger(__name__) + + +def get_available_vram_gb() -> float: + """ + Get available VRAM in GB. + + Returns: + float: Available VRAM in GB, or 0 if no CUDA device available + """ + if not torch.cuda.is_available(): + return 0.0 + + try: + # Get first CUDA device + device = torch.device("cuda:0") + # Get total and allocated memory + total = torch.cuda.get_device_properties(device).total_memory + allocated = torch.cuda.memory_allocated(device) + available = torch.cuda.get_device_properties(0).total_memory / (1024**3) # Convert to GB + return available + except Exception as e: + logger.warning(f"Could not detect VRAM: {e}") + return 0.0 + + +def suggest_quantization(available_vram_gb: float, model_name: str = "VibeVoice-7B") -> str: + """ + Suggest quantization level based on available VRAM. 
+
+    Args:
+        available_vram_gb: Available VRAM in GB
+        model_name: Name of the model being loaded
+
+    Returns:
+        str: Suggested quantization level ("fp16", "8bit", or "4bit")
+    """
+    # VibeVoice-7B memory requirements (approximate)
+    # Full precision (fp16/bf16): ~20GB
+    # 8-bit quantization: ~12GB
+    # 4-bit quantization: ~7GB
+
+    if "1.5B" in model_name:
+        # 1.5B model is smaller, adjust thresholds
+        if available_vram_gb >= 8:
+            return "fp16"
+        elif available_vram_gb >= 6:
+            return "8bit"
+        else:
+            return "4bit"
+    else:
+        # Assume 7B model
+        if available_vram_gb >= 22:
+            return "fp16"
+        elif available_vram_gb >= 14:
+            return "8bit"
+        else:
+            return "4bit"
+
+
+def print_vram_info(available_vram_gb: float, model_name: str, quantization: str = "fp16"):
+    """
+    Print VRAM information and quantization recommendation.
+
+    Args:
+        available_vram_gb: Available VRAM in GB
+        model_name: Name of the model being loaded
+        quantization: Current quantization setting
+    """
+    logger.info(f"Available VRAM: {available_vram_gb:.1f}GB")
+
+    suggested = suggest_quantization(available_vram_gb, model_name)
+
+    if suggested != quantization and quantization == "fp16":
+        logger.warning(
+            f"⚠️ Low VRAM detected ({available_vram_gb:.1f}GB). "
+            f"Recommended: --quantization {suggested}"
+        )
+        logger.warning(
+            f"   Example: python demo/inference_from_file.py "
+            f"--model_path {model_name} --quantization {suggested} ..."
+        )
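
A minimal usage sketch for the patch above, assuming the repository root is on PYTHONPATH so the new utils package resolves, and that bitsandbytes is installed when 8bit or 4bit is requested; the model path is a placeholder and the VRAM figures are illustrative, not measurements.

# Hypothetical CLI invocations of the patched demo (replace <model_path> with a real checkpoint):
#   python demo/realtime_model_inference_from_file.py --model_path <model_path> --quantization 8bit
#   python demo/realtime_model_inference_from_file.py --model_path <model_path> --quantization 4bit

from utils.quantization import get_quantization_config
from utils.vram_utils import get_available_vram_gb, suggest_quantization

# Detect available VRAM; returns 0.0 on machines without CUDA
vram_gb = get_available_vram_gb()
print(f"Detected VRAM: {vram_gb:.1f} GB")

# Recommendation thresholds for the 7B model: >=22 GB -> fp16, >=14 GB -> 8bit, else 4bit
for sample_vram in (6.0, 16.0, 24.0):
    print(f"{sample_vram:.0f} GB -> {suggest_quantization(sample_vram, 'VibeVoice-7B')}")

# fp16 needs no extra kwargs; 8bit/4bit return {"quantization_config": BitsAndBytesConfig(...)}
# and raise ImportError if bitsandbytes is missing
print(get_quantization_config("fp16"))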