48 changes: 43 additions & 5 deletions demo/realtime_model_inference_from_file.py
@@ -6,7 +6,8 @@
import time
import torch
import copy

from vibevoice.utils.vram_utils import get_available_vram_gb, print_vram_info
from vibevoice.utils.quantization import get_quantization_config, apply_selective_quantization
from vibevoice.modular.modeling_vibevoice_streaming_inference import VibeVoiceStreamingForConditionalGenerationInference
from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor
from transformers.utils import logging
@@ -129,6 +130,13 @@ def parse_args():
default=1.5,
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.5)",
)
parser.add_argument(
"--quantization",
type=str,
default="fp16",
choices=["fp16", "8bit", "4bit"],
help="Quantization level: fp16 (default, ~20GB), 8bit (~12GB), or 4bit (~7GB)"
)

return parser.parse_args()

@@ -146,6 +154,14 @@ def main():
args.device = "cpu"

print(f"Using device: {args.device}")

# VRAM detection and quantization info
if args.device == "cuda":
available_vram = get_available_vram_gb()
print_vram_info(available_vram, args.model_path, args.quantization)
elif args.quantization != "fp16":
print(f"Warning: Quantization ({args.quantization}) only works with CUDA. Using full precision.")
args.quantization = "fp16"

# Initialize voice mapper
voice_mapper = VoiceMapper()
@@ -164,7 +180,7 @@ def main():
print("Error: No valid scripts found in the txt file")
return

full_script = scripts.replace("’", "'").replace("“", '"').replace("”", '"')
full_script = scripts.replace("'", "'").replace('"', '"').replace('"', '"')
Collaborator

Pay attention to your code agent. DO NOT introduce bugs like this.

Author
@maitrisavaliya Dec 10, 2025

Will pay attention to this, and I have corrected it.
What are your thoughts on the quantization approach? Is it going in the right direction, or should I change something?


print(f"Loading processor & model from {args.model_path}")
processor = VibeVoiceStreamingProcessor.from_pretrained(args.model_path)
@@ -180,6 +196,15 @@
load_dtype = torch.float32
attn_impl_primary = "sdpa"
print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

# Get quantization configuration
quant_config = get_quantization_config(args.quantization)

if quant_config:
print(f"Using {args.quantization} quantization...")
else:
print("Using full precision (fp16)...")

# Load model with device-specific logic
try:
if args.device == "mps":
@@ -191,12 +216,25 @@
)
model.to("mps")
elif args.device == "cuda":
# Add quantization support
model_kwargs = {
"torch_dtype": load_dtype,
"device_map": "cuda",
"attn_implementation": attn_impl_primary,
}

# Add quantization config if specified
if quant_config:
model_kwargs.update(quant_config)

model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
**model_kwargs
)

# Apply selective quantization if needed
if args.quantization in ["8bit", "4bit"]:
model = apply_selective_quantization(model, args.quantization)
else: # cpu
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
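For reference, a typical run of the updated demo would look something like the lines below. Only the script name, the --model_path flag, and the new --quantization choices come from this diff; the model path and any remaining arguments are placeholders:

python demo/realtime_model_inference_from_file.py --model_path <path-or-hub-id> --quantization 8bit [other args...]
python demo/realtime_model_inference_from_file.py --model_path <path-or-hub-id> --quantization 4bit [other args...]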
Empty file added utils/__init__.py
113 changes: 113 additions & 0 deletions utils/quantization.py
@@ -0,0 +1,113 @@
"""Quantization utilities for VibeVoice models."""

import logging
from typing import Optional
import torch

logger = logging.getLogger(__name__)


def get_quantization_config(quantization: str = "fp16") -> Optional[dict]:
"""
Get quantization configuration for model loading.

Args:
quantization: Quantization level ("fp16", "8bit", or "4bit")

Returns:
dict: Quantization config for from_pretrained, or None for fp16
"""
if quantization == "fp16" or quantization == "full":
return None

if quantization == "8bit":
try:
import bitsandbytes as bnb  # noqa: F401 - availability check
from transformers import BitsAndBytesConfig

logger.info("Using 8-bit quantization (selective LLM only)")
return {
"quantization_config": BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
)
}
except ImportError:
logger.error(
"8-bit quantization requires bitsandbytes. "
"Install with: pip install bitsandbytes"
)
raise

elif quantization == "4bit":
try:
import bitsandbytes as bnb  # noqa: F401 - availability check
from transformers import BitsAndBytesConfig

logger.info("Using 4-bit NF4 quantization (selective LLM only)")
return {
"quantization_config": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
}
except ImportError:
logger.error(
"4-bit quantization requires bitsandbytes. "
"Install with: pip install bitsandbytes"
)
raise

else:
raise ValueError(
f"Invalid quantization: {quantization}. "
f"Must be one of: fp16, 8bit, 4bit"
)


def apply_selective_quantization(model, quantization: str):
"""
Apply selective quantization only to safe components.

This function keeps audio-critical modules at higher precision (bfloat16)
and leaves the LLM component to bitsandbytes quantization.

Args:
model: The VibeVoice model
quantization: Quantization level ("8bit" or "4bit")

Returns:
The model, with audio-critical modules kept at higher precision
"""
if quantization == "fp16":
return model

logger.info("Applying selective quantization...")

# Components to KEEP at full precision (audio-critical)
keep_fp_components = [
"diffusion_head",
"acoustic_connector",
"semantic_connector",
"acoustic_tokenizer",
"semantic_tokenizer",
"vae",
]

# Only quantize the LLM (Qwen2.5) component
quantize_components = ["llm", "language_model"]

for name, module in model.named_modules():
# Check whether this module should stay out of low-bit quantization
should_keep_fp = any(comp in name for comp in keep_fp_components)
should_quantize = any(comp in name for comp in quantize_components)

if should_keep_fp:
# Cast audio-critical components to bfloat16 so they are not left in a lower-precision dtype
if hasattr(module, 'weight') and module.weight is not None and module.weight.dtype != torch.float32:
module.weight.data = module.weight.data.to(torch.bfloat16)
logger.debug(f"Keeping {name} in bfloat16 (audio-critical)")

elif should_quantize:
logger.debug(f"{name} handled by bitsandbytes {quantization} quantization")

logger.info(f"✓ Selective {quantization} quantization applied")
logger.info(" • LLM: Quantized")
logger.info(" • Audio components: kept in bfloat16 (not quantized)")

return model
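
Taken together, the two helpers are intended to be wired into model loading roughly as in the condensed sketch below, which mirrors the demo changes earlier in this PR. The import paths follow the demo's import lines, the model path is a placeholder, and the dtype/attention settings are illustrative defaults rather than values fixed by this diff:

import torch

from vibevoice.utils.quantization import get_quantization_config, apply_selective_quantization
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
    VibeVoiceStreamingForConditionalGenerationInference,
)

quantization = "4bit"  # one of "fp16", "8bit", "4bit"

# Base loading arguments (illustrative; the demo chooses these per device)
model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "cuda",
    "attn_implementation": "sdpa",
}

# get_quantization_config returns None for fp16, otherwise extra kwargs for from_pretrained
quant_config = get_quantization_config(quantization)
if quant_config:
    model_kwargs.update(quant_config)

model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
    "<path-or-hub-id>",  # placeholder model path
    **model_kwargs,
)

# Keep audio-critical modules out of low-bit quantization
if quantization in ("8bit", "4bit"):
    model = apply_selective_quantization(model, quantization)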
87 changes: 87 additions & 0 deletions utils/vram_utils.py
@@ -0,0 +1,87 @@
"""VRAM detection and quantization recommendation utilities."""

import torch
import logging

logger = logging.getLogger(__name__)


def get_available_vram_gb() -> float:
"""
Get available VRAM in GB.

Returns:
float: Available VRAM in GB, or 0 if no CUDA device available
"""
if not torch.cuda.is_available():
return 0.0

try:
# Query device-wide free memory via mem_get_info, which also accounts
# for allocations made by other processes on the GPU
free_bytes, _total_bytes = torch.cuda.mem_get_info(torch.device("cuda:0"))
return free_bytes / (1024 ** 3)  # Convert bytes to GB
except Exception as e:
logger.warning(f"Could not detect VRAM: {e}")
return 0.0


def suggest_quantization(available_vram_gb: float, model_name: str = "VibeVoice-7B") -> str:
"""
Suggest quantization level based on available VRAM.

Args:
available_vram_gb: Available VRAM in GB
model_name: Name of the model being loaded

Returns:
str: Suggested quantization level ("fp16", "8bit", or "4bit")
"""
# VibeVoice-7B memory requirements (approximate)
# Full precision (fp16/bf16): ~20GB
# 8-bit quantization: ~12GB
# 4-bit quantization: ~7GB

if "1.5B" in model_name:
# 1.5B model is smaller, adjust thresholds
if available_vram_gb >= 8:
return "fp16"
elif available_vram_gb >= 6:
return "8bit"
else:
return "4bit"
else:
# Assume 7B model
if available_vram_gb >= 22:
return "fp16"
elif available_vram_gb >= 14:
return "8bit"
else:
return "4bit"


def print_vram_info(available_vram_gb: float, model_name: str, quantization: str = "fp16"):
"""
Print VRAM information and quantization recommendation.

Args:
available_vram_gb: Available VRAM in GB
model_name: Name of the model being loaded
quantization: Current quantization setting
"""
logger.info(f"Available VRAM: {available_vram_gb:.1f}GB")

suggested = suggest_quantization(available_vram_gb, model_name)

if suggested != quantization and quantization == "fp16":
logger.warning(
f"⚠️ Low VRAM detected ({available_vram_gb:.1f}GB). "
f"Recommended: --quantization {suggested}"
)
logger.warning(
f" Example: python demo/inference_from_file.py "
f"--model_path {model_name} --quantization {suggested} ..."
)
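
As a quick illustration of the thresholds in suggest_quantization, the calls below show the expected suggestions for a few free-VRAM values. This assumes the module ends up importable as vibevoice.utils.vram_utils, matching the demo's import line; the gigabyte figures are the approximate estimates from the comments above, not measured requirements:

from vibevoice.utils.vram_utils import suggest_quantization

assert suggest_quantization(24.0, "VibeVoice-7B") == "fp16"    # >= 22 GB free
assert suggest_quantization(16.0, "VibeVoice-7B") == "8bit"    # 14-22 GB free
assert suggest_quantization(8.0, "VibeVoice-7B") == "4bit"     # < 14 GB free
assert suggest_quantization(8.0, "VibeVoice-1.5B") == "fp16"   # smaller model, >= 8 GB free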