Merge pull request #97 from kadirnar/add_hqq_compile
Add Hqq + Compile + Flash Attention support
kadirnar authored May 9, 2024
2 parents 222443e + 3ba0d74 commit b7cf37e
Showing 1 changed file with 65 additions and 27 deletions.
whisperplus/pipelines/whisper.py (92 changes: 65 additions & 27 deletions)
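For orientation, a minimal sketch of how the pipeline might be called once this change is in, exercising the two flags the PR adds. The import path and keyword names follow the new signature in the diff below; everything else (the audio file name, the assumption that flash-attn and hqq are installed) is illustrative rather than part of the commit.

# Hypothetical usage sketch, not part of this commit.
from whisperplus.pipelines.whisper import SpeechToTextPipeline

pipeline = SpeechToTextPipeline(
    model_id="distil-whisper/distil-large-v3",
    quant_config=None,        # an HQQ config may be passed instead; see the sketch after the first hunk
    flash_attention_2=True,   # assumes the flash-attn package and a compatible GPU
    hqq_compile=True,         # enables the new HQQ + torch.compile code path
)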
@@ -11,14 +11,15 @@ class SpeechToTextPipeline:
"""Class for converting audio to text using a pre-trained speech recognition model."""

def __init__(
self,
model_id: str = "distil-whisper/distil-large-v3",
quant_config=None,
hqq: Optional[bool] = True,
flash_attention_2: Optional[bool] = True):
self,
model_id: str = "distil-whisper/distil-large-v3",
quant_config=None,
flash_attention_2: Optional[bool] = True,
hqq_compile: Optional[bool] = True,
):
self.model = None
self.device = None
self.hqq = hqq
self.hqq_compile = hqq_compile
self.flash_attention_2 = flash_attention_2

if self.model is None:
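The constructor forwards quant_config to load_model, where the new hqq_compile path hands it to AutoHQQHFModel.quantize_model. Below is a hedged sketch of how a caller might build that config with HQQ's BaseQuantizeConfig; the nbits=4 / axis=1 choice mirrors the "Please keep nbits=4 and axis=1" comment in the second hunk, while group_size=64 and the exact keyword set are assumptions about the hqq API, not something this commit pins down.

# Sketch of an HQQ quantization config (assumption, not part of this commit).
from hqq.core.quantize import BaseQuantizeConfig
from whisperplus.pipelines.whisper import SpeechToTextPipeline

# nbits=4 / axis=1 follow the in-code comment below; group_size=64 is illustrative.
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)

pipeline = SpeechToTextPipeline(
    model_id="distil-whisper/distil-large-v3",
    quant_config=quant_config,
    flash_attention_2=True,
    hqq_compile=True,
)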
@@ -27,34 +28,71 @@ def __init__(
logging.info("Model already loaded.")

def load_model(self, model_id: str = "distil-whisper/distil-large-v3", quant_config=None):
-        if self.hqq:
+        if self.hqq_compile:
+            import hqq.models.base as hqq_base
+            import torch._dynamo
+            from hqq.core.quantize import HQQBackend, HQQLinear
-            HQQLinear.set_backend(HQQBackend.PYTORCH)  # Pytorch backend
-            HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)  # Compiled Pytorch via dynamo
-            HQQLinear.set_backend(
-                HQQBackend.ATEN)  # C++ Aten/CUDA backend (set automatically by default if available)
+            from hqq.models.hf.base import AutoHQQHFModel
+            from hqq.utils.patching import prepare_for_inference

-        if self.flash_attention_2:
-            attn_implementation = "flash_attention_2"
-        else:
-            attn_implementation = "sdpa"

-        model = AutoModelForSpeechSeq2Seq.from_pretrained(
-            model_id,
-            quantization_config=quant_config,
-            low_cpu_mem_usage=True,
-            use_safetensors=True,
-            attn_implementation=attn_implementation,
-            torch_dtype=torch.bfloat16,
-            device_map='auto',
-        )
-        logging.info("Model loaded successfully.")
+            torch._dynamo.config.suppress_errors = True

+            model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")

+            processor = AutoProcessor.from_pretrained(model_id)

+            # Please keep nbits=4 and axis=1
+            HQQLinear.set_backend(HQQBackend.PYTORCH)

+            AutoHQQHFModel.quantize_model(
+                model.model.encoder,
+                quant_config=quant_config,
+                compute_dtype=torch.bfloat16,
+                device=self.device)
+            AutoHQQHFModel.quantize_model(
+                model.model.decoder,
+                quant_config=quant_config,
+                compute_dtype=torch.bfloat16,
+                device=self.device)

+            processor = AutoProcessor.from_pretrained(model_id)
+            # Replace HQQLinear layers matmuls to support int4 mm
+            hqq_base._QUANT_LAYERS = [torch.nn.Linear, HQQLinear]
+            AutoHQQHFModel.set_auto_linear_tags(model.model.encoder)
+            prepare_for_inference(model.model.encoder)

+            AutoHQQHFModel.set_auto_linear_tags(model.model.decoder)
+            prepare_for_inference(model.model.decoder, backend="torchao_int4")

+            model.model.encoder.forward = torch.compile(
+                model.model.encoder.forward, mode="reduce-overhead", fullgraph=True)
+            model.model.decoder.forward = torch.compile(
+                model.model.decoder.forward, mode="reduce-overhead", fullgraph=True)

+        elif self.hqq_compile is False:
+            if self.flash_attention_2:
+                attn_implementation = "flash_attention_2"
+            else:
+                attn_implementation = "sdpa"

+            model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                model_id,
+                quantization_config=quant_config,
+                low_cpu_mem_usage=True,
+                use_safetensors=True,
+                attn_implementation=attn_implementation,
+                torch_dtype=torch.bfloat16,
+                device_map=self.device)

+            logging.info("Model loaded successfully.")

        processor = AutoProcessor.from_pretrained(model_id)

        self.processor = processor
        self.model = model

        return model

    def __call__(
        self,
        audio_path: str = "test.mp3",
        ...
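To close the loop, a hedged sketch of transcribing with the updated pipeline. Only the audio_path keyword comes from the diff (the rest of __call__ is collapsed above), and the assumption that the call returns the transcription is mine. The warm-up loop reflects how torch.compile's "reduce-overhead" mode generally behaves: the first calls pay for compilation and CUDA-graph capture, later calls reuse the captured graphs.

# Hedged usage sketch, not part of this commit.
from whisperplus.pipelines.whisper import SpeechToTextPipeline

pipeline = SpeechToTextPipeline(hqq_compile=True, flash_attention_2=True)

# First calls are slow while the compiled encoder/decoder graphs are built.
for _ in range(2):
    pipeline(audio_path="test.mp3")

transcription = pipeline(audio_path="test.mp3")  # assumed to return the transcribed text
print(transcription)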
