Add whisper tritonserver batch inference (#650) / update pre-commit isort==5.13.2 (#683)
upskyy authored Dec 13, 2024
1 parent 454ef3e commit 4c11c66
Showing 3 changed files with 171 additions and 138 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
args: [--max-line-length=80]

- repo: https://github.com/pycqa/isort
rev: 5.9.2
rev: 5.13.2
hooks:
- id: isort
args: [--profile=black, --line-length=80]
183 changes: 96 additions & 87 deletions triton/whisper/model_repo_whisper_trtllm/infer_bls/1/model.py
@@ -1,14 +1,18 @@
# -*- coding: utf-8 -*-
import triton_python_backend_utils as pb_utils
import numpy as np
# -*- coding: utf-8 -*-

import json
import torch
from torch.utils.dlpack import to_dlpack
import re
from .tokenizer import get_tokenizer
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import to_dlpack

from .tokenizer import get_tokenizer


def read_config(component, engine_dir):
config_path = engine_dir / component / 'config.json'
with open(config_path, 'r') as f:
@@ -18,109 +22,114 @@ def read_config(component, engine_dir):
model_config.update(config['build_config'])
return model_config

class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
that is created must have "TritonPythonModel" as the class name.
"""

class TritonPythonModel:
def initialize(self, args):
"""`initialize` is called only once when the model is being loaded.
Implementing `initialize` function is optional. This function allows
the model to initialize any state associated with this model.
Parameters
----------
args : dict
Both keys and values are strings. The dictionary keys and values are:
* model_config: A JSON string containing the model configuration
* model_instance_kind: A string containing model instance kind
* model_instance_device_id: A string containing model instance device ID
* model_repository: Model repository path
* model_version: Model version
* model_name: Model name
"""
self.model_config = model_config = json.loads(args['model_config'])

# Get OUTPUT0 configuration
output0_config = pb_utils.get_output_config_by_name(
model_config, "TRANSCRIPTS")
# Convert Triton types to numpy types
self.out0_dtype = pb_utils.triton_string_to_numpy(
output0_config['data_type'])
encoder_config = read_config('encoder', Path(self.model_config['parameters']['engine_dir']["string_value"]))
self.tokenizer = get_tokenizer(num_languages=encoder_config['num_languages'])
self.blank = self.tokenizer.encode(" ", allowed_special=self.tokenizer.special_tokens_set)[0]

engine_dir = Path(
self.model_config['parameters']['engine_dir']["string_value"])
encoder_config = read_config('encoder', engine_dir)
self.tokenizer = get_tokenizer(
num_languages=encoder_config['num_languages']
)
self.blank = self.tokenizer.encode(
" ",
allowed_special=self.tokenizer.special_tokens_set
)[0]
self.device = torch.device("cuda")

def process_batch(self, wav, wav_len, prompt_id):
wav = torch.from_numpy(wav[0]).to(self.device)
wav_tensor = pb_utils.Tensor.from_dlpack("WAV", to_dlpack(wav.unsqueeze(0)))
wav_len_tensor = pb_utils.Tensor("WAV_LENS", np.array([[wav_len]], np.int32))
prompt_id = torch.tensor(prompt_id).unsqueeze(0)
def process_batch(self, wav_batch, wav_lens, prompt_id):
# Convert numpy arrays to torch tensors
wav_batch = torch.from_numpy(wav_batch).to(self.device)
wav_tensor = pb_utils.Tensor.from_dlpack(
"WAV",
to_dlpack(wav_batch)
)
wav_len_tensor = pb_utils.Tensor(
"WAV_LENS",
wav_lens.astype(np.int32)
)

# Replicate prompt_id for batch size
batch_size = wav_batch.shape[0]
prompt_ids = np.tile(prompt_id, (batch_size, 1))
prompt_ids_tensor = pb_utils.Tensor(
"DECODER_INPUT_IDS",
prompt_ids.astype(np.int32)
)

prompt_id = pb_utils.Tensor("DECODER_INPUT_IDS", prompt_id.numpy().astype(np.int32))
infer_request = pb_utils.InferenceRequest(
model_name="whisper",
requested_output_names=["OUTPUT_IDS"],
inputs=[wav_tensor, wav_len_tensor, prompt_id]
inputs=[wav_tensor, wav_len_tensor, prompt_ids_tensor]
)

inference_response = infer_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
else:
output_ids = pb_utils.get_output_tensor_by_name(inference_response, "OUTPUT_IDS")
return output_ids.as_numpy()

raise pb_utils.TritonModelException(
inference_response.error().message())

output_ids = pb_utils.get_output_tensor_by_name(
inference_response, "OUTPUT_IDS")
return output_ids.as_numpy()

def execute(self, requests):
"""`execute` must be implemented in every Python model. `execute`
function receives a list of pb_utils.InferenceRequest as the only
argument. This function is called when an inference is requested
for this model.
Parameters
----------
requests : list
A list of pb_utils.InferenceRequest
Returns
-------
list
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""
# Every Python backend must iterate through list of requests and create
# an instance of pb_utils.InferenceResponse class for each of them. You
# should avoid storing any of the input Tensors in the class attributes
# as they will be overridden in subsequent inference requests. You can
# make a copy of the underlying NumPy array and store it if it is
# required.
responses = []

for request in requests:
# Perform inference on the request and append it to responses list...
in_0 = pb_utils.get_input_tensor_by_name(request, "TEXT_PREFIX")
prompt_ids = in_0.as_numpy().tolist()
prompt_ids = prompt_ids[0][0].decode('utf-8')
if prompt_ids == "":
prompt_ids = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
prompt_id = self.tokenizer.encode(prompt_ids, allowed_special=self.tokenizer.special_tokens_set)

wav = pb_utils.get_input_tensor_by_name(request, "WAV").as_numpy()
assert wav.shape[0] == 1, "Only support batch size 1 for now"
wav_len = pb_utils.get_input_tensor_by_name(request, "WAV_LENS").as_numpy()
wav_len = wav_len.item()

output_ids = self.process_batch(wav, wav_len, prompt_id)
s = self.tokenizer.decode(output_ids)
s = re.sub(r'<\|.*?\|>', '', s)
sentence = np.array([s])
out0 = pb_utils.Tensor("TRANSCRIPTS", sentence.astype(self.out0_dtype))
inference_response = pb_utils.InferenceResponse(output_tensors=[out0])
# Get batch inputs
text_prefix = pb_utils.get_input_tensor_by_name(
request, "TEXT_PREFIX").as_numpy()
wav_batch = pb_utils.get_input_tensor_by_name(
request, "WAV").as_numpy()
wav_lens = pb_utils.get_input_tensor_by_name(
request, "WAV_LENS").as_numpy()

# Use the same text_prefix for all items in the request
prefix = text_prefix[0][0].decode('utf-8')
if prefix == "":
prefix = (
"<|startoftranscript|><|ko|><|transcribe|><|notimestamps|>"
)
prompt_id = self.tokenizer.encode(
prefix,
allowed_special=self.tokenizer.special_tokens_set
)

# Process the entire batch
output_ids = self.process_batch(wav_batch, wav_lens, prompt_id)

# Decode outputs for each item in batch
transcripts = []

# Handle case where output_ids is 3-dimensional
# ([batch_size, beam_size, seq_len])
if len(output_ids.shape) == 3:
output_ids = output_ids[:, 0, :] # Remove beam_size dimension

for output_id in output_ids:
token_list = output_id.tolist()
s = self.tokenizer.decode(token_list)
s = re.sub(r'<\|.*?\|>', '', s)
transcripts.append(s)

# Create response tensor
out0 = pb_utils.Tensor(
"TRANSCRIPTS",
np.array(transcripts).astype(self.out0_dtype)
)
inference_response = pb_utils.InferenceResponse(
output_tensors=[out0]
)
responses.append(inference_response)

return responses

def finalize(self):
"""`finalize` is called only once when the model is being unloaded.
Implementing `finalize` function is optional. This function allows
the model to perform any necessary clean ups before exit.
"""
print('Cleaning up...')
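
With the BLS model now accepting batched inputs, a single request can carry several audio clips at once. Below is a minimal client-side sketch (not part of this commit) of how such a request might be assembled with tritonclient; the model name "infer_bls", the FP32/INT32/BYTES data types, and the shapes are assumptions inferred from the directory layout and input names above, and should be checked against the deployment's actual config.pbtxt.

import numpy as np
import tritonclient.grpc as grpcclient

# Hypothetical batched request against the BLS model shown above.
client = grpcclient.InferenceServerClient("localhost:8001")

batch_size, max_samples = 2, 16000 * 30                            # two clips, padded to 30 s
wav = np.zeros((batch_size, max_samples), dtype=np.float32)        # padded audio goes here
wav_lens = np.array([[16000 * 5], [16000 * 12]], dtype=np.int32)   # true lengths in samples
text_prefix = np.array([[b""]] * batch_size, dtype=object)         # empty string -> default prompt

inputs = [
    grpcclient.InferInput("WAV", list(wav.shape), "FP32"),
    grpcclient.InferInput("WAV_LENS", list(wav_lens.shape), "INT32"),
    grpcclient.InferInput("TEXT_PREFIX", list(text_prefix.shape), "BYTES"),
]
inputs[0].set_data_from_numpy(wav)
inputs[1].set_data_from_numpy(wav_lens)
inputs[2].set_data_from_numpy(text_prefix)

result = client.infer(
    model_name="infer_bls",
    inputs=inputs,
    outputs=[grpcclient.InferRequestedOutput("TRANSCRIPTS")],
)
print(result.as_numpy("TRANSCRIPTS"))

Note that with an empty TEXT_PREFIX the model falls back to the default prompt hard-coded in execute above; pass an explicit prompt string to override it.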
124 changes: 74 additions & 50 deletions triton/whisper/model_repo_whisper_trtllm/whisper/1/model.py
File mode changed 100755 → 100644
@@ -1,4 +1,3 @@

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
@@ -24,78 +23,103 @@
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
from pathlib import Path

from .fbank import FeatureExtractor
import torch
from torch.utils.dlpack import from_dlpack

import triton_python_backend_utils as pb_utils
from tensorrt_llm.runtime import ModelRunnerCpp
from tensorrt_llm.bindings import GptJsonConfig
from tensorrt_llm.runtime import ModelRunnerCpp
from torch.utils.dlpack import from_dlpack

from .fbank import FeatureExtractor


class TritonPythonModel:
def initialize(self, args):
parameters = json.loads(args['model_config'])['parameters']
for key,value in parameters.items():
for key, value in parameters.items():
parameters[key] = value["string_value"]
engine_dir = parameters["engine_dir"]
json_config = GptJsonConfig.parse_file(Path(engine_dir) / 'decoder' / 'config.json')
config_path = Path(engine_dir) / 'decoder' / 'config.json'
json_config = GptJsonConfig.parse_file(config_path)
assert json_config.model_config.supports_inflight_batching
runner_kwargs = dict(engine_dir=engine_dir,
is_enc_dec=True,
max_batch_size=64,
max_input_len=3000,
max_output_len=96,
max_beam_width=1,
debug_mode=False,
kv_cache_free_gpu_memory_fraction=0.5)
runner_kwargs = dict(
engine_dir=engine_dir,
is_enc_dec=True,
max_batch_size=64,
max_input_len=3000,
max_output_len=96,
max_beam_width=1,
debug_mode=False,
kv_cache_free_gpu_memory_fraction=0.5,
)
self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)
self.feature_extractor = FeatureExtractor(n_mels = int(parameters["n_mels"]))
self.zero_pad = True if parameters["zero_pad"] == "true" else False
self.feature_extractor = FeatureExtractor(
n_mels=int(parameters["n_mels"])
)
self.zero_pad = parameters["zero_pad"] == "true"
self.eot_id = 50257

def execute(self, requests):
"""
This function receives a list of requests (`pb_utils.InferenceRequest`),
performs inference on every request and appends it to responses.
"""
responses, batch_mel_list, decoder_input_ids = [], [], []
responses = []

for request in requests:
wav_tensor = pb_utils.get_input_tensor_by_name(request, "WAV")
wav_len = pb_utils.get_input_tensor_by_name(request, "WAV_LENS").as_numpy().item()
prompt_ids = pb_utils.get_input_tensor_by_name(request, "DECODER_INPUT_IDS").as_numpy()
wav_lens = pb_utils.get_input_tensor_by_name(
request, "WAV_LENS").as_numpy()
prompt_ids = pb_utils.get_input_tensor_by_name(
request, "DECODER_INPUT_IDS").as_numpy()

# Move WAV data to GPU
wav = from_dlpack(wav_tensor.to_dlpack())
wav = wav[:, :wav_len]
batch_size = wav.shape[0]

padding = 0 if self.zero_pad else 3000
mel = self.feature_extractor.compute_feature(wav[0].to('cuda'), padding_target_len=padding).transpose(1, 2)
batch_mel_list.append(mel.squeeze(0))
decoder_input_ids.append(torch.tensor(prompt_ids, dtype=torch.int32, device='cuda').squeeze(0))

decoder_input_ids = torch.nn.utils.rnn.pad_sequence(decoder_input_ids, batch_first=True, padding_value=self.eot_id)
mel_input_lengths = torch.tensor([mel.shape[0] for mel in batch_mel_list], dtype=torch.int32, device='cuda')

outputs = self.model_runner_cpp.generate(
batch_input_ids=decoder_input_ids,
encoder_input_features=batch_mel_list,
encoder_output_lengths=mel_input_lengths // 2,
max_new_tokens=96,
end_id=self.eot_id,
pad_id=self.eot_id,
num_beams=1,
output_sequence_lengths=True,
return_dict=True)
torch.cuda.synchronize()

output_ids = outputs['output_ids'].cpu().numpy()

for i, output_id in enumerate(output_ids):
batch_mel_list = []

# Batch processing for each sample in the batch
for i in range(batch_size):
wav_i = wav[i:i+1, :int(wav_lens[i].item())]
mel = self.feature_extractor.compute_feature(
wav_i[0].to('cuda'),
padding_target_len=padding
).transpose(1, 2)
batch_mel_list.append(mel.squeeze(0))

# Move prompt IDs to GPU
decoder_input_ids = torch.tensor(
prompt_ids, dtype=torch.int32, device='cuda')

# Calculate mel lengths
mel_input_lengths = torch.tensor(
[mel.shape[0] for mel in batch_mel_list],
dtype=torch.int32,
device='cuda'
)

# Run batch inference
outputs = self.model_runner_cpp.generate(
batch_input_ids=decoder_input_ids,
encoder_input_features=batch_mel_list,
encoder_output_lengths=mel_input_lengths // 2,
max_new_tokens=96,
end_id=self.eot_id,
pad_id=self.eot_id,
num_beams=1,
output_sequence_lengths=True,
return_dict=True
)
torch.cuda.synchronize()

# Process outputs
output_ids = outputs['output_ids'].cpu().numpy()

# Create response for the request
response = pb_utils.InferenceResponse(output_tensors=[
pb_utils.Tensor("OUTPUT_IDS", output_id[0])
pb_utils.Tensor("OUTPUT_IDS", output_ids)
])
responses.append(response)
assert len(responses) == len(requests)
return responses

return responses
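
The reworked whisper backend computes per-sample mel features and passes encoder_output_lengths = mel_input_lengths // 2 to the runner. A rough sketch of that length bookkeeping, assuming Whisper's standard 16 kHz front end with hop length 160 and a stride-2 convolution in the encoder (neither of which is shown in this diff):

# Not from this commit: illustrates why max_input_len=3000 and the // 2 above line up,
# under the assumption of Whisper's usual 16 kHz / hop-length-160 feature extraction.
SAMPLE_RATE = 16000
HOP_LENGTH = 160          # 100 mel frames per second

def mel_frames(num_samples: int) -> int:
    # number of log-mel frames produced for a clip of `num_samples` samples
    return num_samples // HOP_LENGTH

def encoder_frames(num_mel_frames: int) -> int:
    # Whisper's second encoder conv has stride 2, halving the time axis
    return num_mel_frames // 2

thirty_seconds = 30 * SAMPLE_RATE
assert mel_frames(thirty_seconds) == 3000      # matches max_input_len=3000 above
assert encoder_frames(3000) == 1500            # encoder output length fed to the decoder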
