Merge pull request #102 from kadirnar/add_load_model
💯 Add load_model file
kadirnar authored May 10, 2024
2 parents 3d535cf + 72f744c commit fa07a7e
Showing 2 changed files with 98 additions and 64 deletions.
78 changes: 78 additions & 0 deletions whisperplus/model/load_model.py
@@ -0,0 +1,78 @@
from typing import Optional

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


def load_model_whisper(
        model_id: str = "distil-whisper/distil-large-v3",
        quant_config=None,
        hqq_compile: bool = False,
        flash_attention_2: bool = False,
        device=None):
    """
    Loads a speech-to-text model and its processor.

    Args:
    - model_id (str): The model ID to load (default: "distil-whisper/distil-large-v3").
    - quant_config: The quantization configuration (optional).
    - hqq_compile (bool): Whether to use HQQ quantization with torch.compile (default: False).
    - flash_attention_2 (bool): Whether to use Flash Attention 2 (default: False).
    - device: The device to use (e.g., "cuda" or "cpu").

    Returns:
    - A tuple of the loaded model and its processor.
    """
    if hqq_compile:
        import hqq.models.base as hqq_base
        import torch._dynamo
        from hqq.core.quantize import HQQBackend, HQQLinear
        from hqq.models.hf.base import AutoHQQHFModel
        from hqq.utils.patching import prepare_for_inference

        torch._dynamo.config.suppress_errors = True

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")

        processor = AutoProcessor.from_pretrained(model_id)
        HQQLinear.set_backend(HQQBackend.PYTORCH)

        AutoHQQHFModel.quantize_model(
            model.model.encoder, quant_config=quant_config, compute_dtype=torch.bfloat16, device=device)

        AutoHQQHFModel.quantize_model(
            model.model.decoder, quant_config=quant_config, compute_dtype=torch.bfloat16, device=device)

        hqq_base._QUANT_LAYERS = [torch.nn.Linear, HQQLinear]
        AutoHQQHFModel.set_auto_linear_tags(model.model.encoder)
        prepare_for_inference(model.model.encoder)

        AutoHQQHFModel.set_auto_linear_tags(model.model.decoder)
        prepare_for_inference(model.model.decoder, backend="torchao_int4")

        model.model.encoder.forward = torch.compile(
            model.model.encoder.forward, mode="reduce-overhead", fullgraph=True)
        model.model.decoder.forward = torch.compile(
            model.model.decoder.forward, mode="reduce-overhead", fullgraph=True)

    else:
        if flash_attention_2:
            attn_implementation = "flash_attention_2"
        else:
            attn_implementation = "sdpa"

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            quantization_config=quant_config,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation=attn_implementation,
            torch_dtype=torch.bfloat16,
            device_map=device,
        )

    processor = AutoProcessor.from_pretrained(model_id)

    return model, processor
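
Not part of the diff: a minimal usage sketch of the new helper, added for orientation. The BaseQuantizeConfig values (nbits=4, axis=1) follow the "Please keep nbits=4 and axis=1" comment in the previous whisper.py implementation and are an assumption, not something this commit pins down.

# Editor's sketch (not in this commit): minimal calls into the new helper.
from whisperplus.model.load_model import load_model_whisper

# Plain bfloat16 + SDPA path: no quantization, no torch.compile.
model, processor = load_model_whisper(
    model_id="distil-whisper/distil-large-v3",
    flash_attention_2=False,
    device="cuda",
)

# Assumed HQQ path: the config values follow the old "keep nbits=4 and axis=1" note.
from hqq.core.quantize import BaseQuantizeConfig

quant_config = BaseQuantizeConfig(nbits=4, axis=1)
hqq_model, hqq_processor = load_model_whisper(
    model_id="distil-whisper/distil-large-v3",
    quant_config=quant_config,
    hqq_compile=True,
    device="cuda",
)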
84 changes: 20 additions & 64 deletions whisperplus/pipelines/whisper.py
@@ -4,6 +4,8 @@
 import torch
 from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

+from whisperplus.model.load_model import load_model_whisper
+
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


@@ -15,81 +17,35 @@ def __init__(
         model_id: str = "distil-whisper/distil-large-v3",
         quant_config=None,
         flash_attention_2: Optional[bool] = True,
-        hqq_compile: Optional[bool] = True,
+        hqq_compile: Optional[bool] = False,
     ):
         self.model = None
-        self.device = None
+        self.device = "cuda"
         self.hqq_compile = hqq_compile
         self.flash_attention_2 = flash_attention_2

         if self.model is None:
-            self.load_model(model_id)
+            self.load_plus_model(model_id, quant_config, hqq_compile, flash_attention_2)
         else:
             logging.info("Model already loaded.")

-    def load_model(self, model_id: str = "distil-whisper/distil-large-v3", quant_config=None):
-        if self.hqq_compile:
-            import hqq.models.base as hqq_base
-            import torch._dynamo
-            from hqq.core.quantize import HQQBackend, HQQLinear
-            from hqq.models.hf.base import AutoHQQHFModel
-            from hqq.utils.patching import prepare_for_inference
-
-            torch._dynamo.config.suppress_errors = True
-
-            model = AutoModelForSpeechSeq2Seq.from_pretrained(
-                model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
-
-            processor = AutoProcessor.from_pretrained(model_id)
-
-            # Please keep nbits=4 and axis=1
-            HQQLinear.set_backend(HQQBackend.PYTORCH)
-
-            AutoHQQHFModel.quantize_model(
-                model.model.encoder,
-                quant_config=quant_config,
-                compute_dtype=torch.bfloat16,
-                device=self.device)
-            AutoHQQHFModel.quantize_model(
-                model.model.decoder,
-                quant_config=quant_config,
-                compute_dtype=torch.bfloat16,
-                device=self.device)
-
-            # Replace HQQLinear layers matmuls to support int4 mm
-            hqq_base._QUANT_LAYERS = [torch.nn.Linear, HQQLinear]
-            AutoHQQHFModel.set_auto_linear_tags(model.model.encoder)
-            prepare_for_inference(model.model.encoder)
-
-            AutoHQQHFModel.set_auto_linear_tags(model.model.decoder)
-            prepare_for_inference(model.model.decoder, backend="torchao_int4")
-
-            model.model.encoder.forward = torch.compile(
-                model.model.encoder.forward, mode="reduce-overhead", fullgraph=True)
-            model.model.decoder.forward = torch.compile(
-                model.model.decoder.forward, mode="reduce-overhead", fullgraph=True)
-
-        elif self.hqq_compile is False:
-            if self.flash_attention_2:
-                attn_implementation = "flash_attention_2"
-            else:
-                attn_implementation = "sdpa"
-
-            model = AutoModelForSpeechSeq2Seq.from_pretrained(
-                model_id,
-                quantization_config=quant_config,
-                low_cpu_mem_usage=True,
-                use_safetensors=True,
-                attn_implementation=attn_implementation,
-                torch_dtype=torch.bfloat16,
-                device_map=self.device)
-
-            logging.info("Model loaded successfully.")
+    def load_plus_model(
+            self,
+            model_id: str = "distil-whisper/distil-large-v3",
+            quant_config=None,
+            hqq_compile: bool = False,
+            flash_attention_2: bool = True,
+    ):

-        processor = AutoProcessor.from_pretrained(model_id)
+        model, processor = load_model_whisper(
+            model_id=model_id,
+            quant_config=quant_config,
+            hqq_compile=hqq_compile,
+            flash_attention_2=flash_attention_2,
+            device=self.device)

-        self.processor = processor
         self.model = model
+        self.processor = processor

         return model

@@ -124,7 +80,7 @@ def __call__(
             feature_extractor=self.processor.feature_extractor,
             model_kwargs={"use_flash_attention_2": self.flash_attention_2},
             generate_kwargs={"language": language},
-        )
+            device_map=self.device)
         logging.info("Transcribing audio...")
         result = pipe(audio_path)
         return result
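
For context (not part of the commit): a short sketch of driving the refactored pipeline end to end. The class name SpeechToTextPipeline and the __call__ arguments are assumed from the surrounding whisperplus package; only __init__ and __call__ fragments appear in this diff.

# Editor's sketch (not in this commit): end-to-end use of the refactored class.
# SpeechToTextPipeline and the call signature below are assumptions, not shown in this hunk.
from whisperplus.pipelines.whisper import SpeechToTextPipeline

stt = SpeechToTextPipeline(
    model_id="distil-whisper/distil-large-v3",
    flash_attention_2=True,
    hqq_compile=False,  # new default introduced by this change
)
result = stt("test.mp3", language="english")  # placeholder audio path
print(result["text"])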
