From 86fc13253f0ea0d749065973bf07b60d5bd9fb9f Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Thu, 22 Aug 2024 18:35:05 +0800 Subject: [PATCH] support zipformer2 offline triton recipe --- triton/Dockerfile/Dockerfile.server | 45 +-- triton/README.md | 34 +- .../feature_extractor/config.pbtxt.template | 2 +- .../scorer/config.pbtxt.template | 16 + ...ch_pruned_transducer_stateless7_offline.sh | 2 +- triton/scripts/build_trt.sh | 2 +- ...build_wenetspeech_zipformer_offline_trt.sh | 131 +++++++ .../model_repo_offline/decoder/1/.gitkeep | 0 .../decoder/config.pbtxt.template | 44 --- .../model_repo_offline/encoder/1/.gitkeep | 0 .../encoder/config.pbtxt.template | 55 --- .../feature_extractor/1/model.py | 155 -------- .../feature_extractor/config.pbtxt.template | 72 ---- .../model_repo_offline/joiner/1/.gitkeep | 0 .../joiner/config.pbtxt.template | 49 --- .../model_repo_offline/scorer/1/model.py | 348 ------------------ .../scorer/config.pbtxt.template | 84 ----- .../model_repo_offline/transducer/1/.gitkeep | 0 .../transducer/config.pbtxt.template | 99 ----- 19 files changed, 175 insertions(+), 963 deletions(-) create mode 100644 triton/scripts/build_wenetspeech_zipformer_offline_trt.sh delete mode 100755 triton/zipformer/model_repo_offline/decoder/1/.gitkeep delete mode 100755 triton/zipformer/model_repo_offline/decoder/config.pbtxt.template delete mode 100755 triton/zipformer/model_repo_offline/encoder/1/.gitkeep delete mode 100755 triton/zipformer/model_repo_offline/encoder/config.pbtxt.template delete mode 100755 triton/zipformer/model_repo_offline/feature_extractor/1/model.py delete mode 100755 triton/zipformer/model_repo_offline/feature_extractor/config.pbtxt.template delete mode 100755 triton/zipformer/model_repo_offline/joiner/1/.gitkeep delete mode 100755 triton/zipformer/model_repo_offline/joiner/config.pbtxt.template delete mode 100755 triton/zipformer/model_repo_offline/scorer/1/model.py delete mode 100755 triton/zipformer/model_repo_offline/scorer/config.pbtxt.template delete mode 100755 triton/zipformer/model_repo_offline/transducer/1/.gitkeep delete mode 100644 triton/zipformer/model_repo_offline/transducer/config.pbtxt.template diff --git a/triton/Dockerfile/Dockerfile.server b/triton/Dockerfile/Dockerfile.server index e8a691089..ae8654a44 100755 --- a/triton/Dockerfile/Dockerfile.server +++ b/triton/Dockerfile/Dockerfile.server @@ -1,41 +1,24 @@ -FROM nvcr.io/nvidia/tritonserver:22.12-py3 +FROM nvcr.io/nvidia/tritonserver:24.07-py3 # https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html # Please choose previous tritonserver:xx.xx if you encounter cuda driver mismatch issue - LABEL maintainer="NVIDIA" LABEL repository="tritonserver" -RUN apt-get update && apt-get -y install \ - python3-dev \ - cmake \ - libsndfile1 -RUN pip3 install \ - torch==1.13.1+cu117 \ - torchaudio==0.13.1+cu117 \ - --index-url https://download.pytorch.org/whl/cu117 -RUN pip3 install \ - kaldialign \ - tensorboard \ - sentencepiece \ - lhotse \ - kaldifeat -RUN pip3 install \ - k2==1.24.4.dev20240223+cuda11.7.torch1.13.1 -f https://k2-fsa.github.io/k2/cuda.html -# Dependency for client -RUN pip3 install soundfile grpcio-tools tritonclient pyyaml +RUN apt-get update && apt-get install -y cmake +RUN python3 -m pip install k2==1.24.4.dev20240725+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html && \ + python3 -m pip install -r https://raw.githubusercontent.com/k2-fsa/icefall/master/requirements.txt && \ + pip install -U "huggingface_hub[cli]" lhotse colored onnx_graphsurgeon 
polygraphy
+# Delete the cuda version check in k2 (see https://github.com/k2-fsa/k2/blob/master/k2/python/k2/__init__.py#L13)
+RUN sed -i '/if (/,/^ )/d' /usr/local/lib/python3.10/dist-packages/k2/__init__.py
 WORKDIR /workspace
-# #install k2 from source
-# #"sed -i ..." line tries to turn off the cuda check
-# RUN git clone https://github.com/k2-fsa/k2.git && \
-#    cd k2 && \
-#    sed -i 's/FATAL_ERROR/STATUS/g' cmake/torch.cmake && \
-#    sed -i 's/in running_cuda_version//g' get_version.py && \
-#    python3 setup.py install && \
-#    cd -
+RUN git clone https://github.com/csukuangfj/kaldifeat && \
+    cd kaldifeat && \
+    sed -i 's/in running_cuda_version//g' get_version.py && \
+    python3 setup.py install && \
+    cd -
+
 RUN git clone https://github.com/k2-fsa/icefall.git
 ENV PYTHONPATH "${PYTHONPATH}:/workspace/icefall"
-# https://github.com/k2-fsa/icefall/issues/674
-ENV PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION "python"
-COPY ./scripts scripts
+COPY ./scripts scripts
\ No newline at end of file
diff --git a/triton/README.md b/triton/README.md
index ed2025734..c98c47052 100755
--- a/triton/README.md
+++ b/triton/README.md
@@ -34,19 +34,14 @@ Build the server docker image:
 cd $SHERPA_SRC/triton
 docker build . -f Dockerfile/Dockerfile.server -t sherpa_triton_server:latest --network host
 ```
-Alternatively, you could directly pull the pre-built image based on tritonserver 22.12.
+Alternatively, you could directly pull the pre-built image based on the tritonserver 24.07 image.
 ```
-docker pull soar97/triton-k2:22.12.1
-```
-
-If you are planning to use TRT to accelerate the inference speed, you can use the following prebuit image:
-```
-docker pull wd929/sherpa_wend_23.04:v1.1
+docker pull soar97/triton-k2:24.07
 ```
 
 Start the docker container:
 ```bash
-docker run --gpus all -v $SHERPA_SRC:/workspace/sherpa --name sherpa_server --net host --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -it soar97/triton-k2:22.12.1
+docker run --gpus all -v $SHERPA_SRC:/workspace/sherpa --name sherpa_server --net host --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 -it soar97/triton-k2:24.07
 ```
 Now, you should enter into the container successfully.
 
@@ -69,8 +64,7 @@ apt-get install git-lfs
 pip3 install -r ./requirements.txt
 export CUDA_VISIBLE_DEVICES="your_gpu_id"
 
-bash scripts/build_wenetspeech_pruned_transducer_stateless5_streaming.sh
-bash scripts/build_librispeech_pruned_transducer_stateless3_streaming.sh
+bash scripts/build_wenetspeech_zipformer_offline_trt.sh
 ```
 
 ## Using TensorRT acceleration
 
@@ -83,26 +77,20 @@ You can directly use the following script to export TRT engine and start Triton
 bash scripts/build_librispeech_pruned_transducer_stateless3_offline_trt.sh
 ```
 
-### Export to TensorRT Step by Step
-
-If you want to build TensorRT for your own model, you can try the following steps:
+### Export to TensorRT
 
-#### Preparation for TRT
-
-First of all, you have to install the TensorRT. Here we suggest you to use docker container to run TRT. Just run the following command:
-
-```bash
-docker run --gpus '"device=0"' -it --rm --net host -v $PWD/:/k2 nvcr.io/nvidia/tensorrt:23.04-py3
-```
-You can also see [here](https://github.com/NVIDIA/TensorRT#build) to build TRT on your machine.
+If you want to build a TensorRT engine for your own service, you can try the following steps:
 
 #### Model export
 
-You have to prepare the ONNX model by referring [here](https://github.com/k2-fsa/sherpa/blob/master/triton/scripts/build_librispeech_pruned_transducer_stateless3_offline.sh#L41C1-L41C1) to export your models into ONNX format. Assume you have put your ONNX model in the `$model_dir` directory.
+You have to prepare the ONNX model by referring to [this guide](https://icefall.readthedocs.io/en/latest/model-export/export-onnx.html#export-the-model-to-onnx) to export your models into ONNX format. Assume you have put your ONNX model in the `$model_dir` directory.
 
 Then, just run the command:
 ```bash
-bash scripts/build_trt.sh 128 $model_dir/encoder.onnx model_repo_offline/encoder/1/encoder.trt
+# First, use polygraphy to fold constants and simplify the onnx model.
+polygraphy surgeon sanitize $model_dir/encoder.onnx --fold-constants -o $model_dir/encoder_sanitized.onnx
+# Then build the engine from the sanitized model with the /usr/src/tensorrt/bin/trtexec tool shipped in the tritonserver docker image.
+bash scripts/build_trt.sh 16 $model_dir/encoder_sanitized.onnx model_repo_offline/encoder/1/encoder.trt
 ```
 The generated TRT model will be saved into `model_repo_offline/encoder/1/encoder.trt`.
 
diff --git a/triton/model_repo_offline/feature_extractor/config.pbtxt.template b/triton/model_repo_offline/feature_extractor/config.pbtxt.template
index 17ee879e1..4a1f557d6 100755
--- a/triton/model_repo_offline/feature_extractor/config.pbtxt.template
+++ b/triton/model_repo_offline/feature_extractor/config.pbtxt.template
@@ -44,7 +44,7 @@ input [
   },
   {
     name: "wav_lens"
-    data_type: TYPE_INT64
+    data_type: TYPE_INT32
     dims: [1]
   }
 ]
diff --git a/triton/model_repo_offline/scorer/config.pbtxt.template b/triton/model_repo_offline/scorer/config.pbtxt.template
index 7c37df0b5..9e4d2da97 100755
--- a/triton/model_repo_offline/scorer/config.pbtxt.template
+++ b/triton/model_repo_offline/scorer/config.pbtxt.template
@@ -32,6 +32,22 @@ parameters [
   {
     key: "decoding_method",
     value: { string_value: "greedy_search"}
+  },
+  {
+    key: "beam",
+    value: { string_value: "4"}
+  },
+  {
+    key: "max_contexts",
+    value: { string_value: "4"}
+  },
+  {
+    key: "max_states",
+    value: { string_value: "32"}
+  },
+  {
+    key: "temperature",
+    value: { string_value: "1.0"}
   }
 ]
 
diff --git a/triton/scripts/build_librispeech_pruned_transducer_stateless7_offline.sh b/triton/scripts/build_librispeech_pruned_transducer_stateless7_offline.sh
index 9955d17b8..7c674ec0d 100644
--- a/triton/scripts/build_librispeech_pruned_transducer_stateless7_offline.sh
+++ b/triton/scripts/build_librispeech_pruned_transducer_stateless7_offline.sh
@@ -4,7 +4,7 @@ stop_stage=2
 
 # change to your own model directory
 pretrained_model_dir=/mnt/samsung-t7/wend/github/icefall/egs/librispeech/ASR/pruned_transducer_stateless7/exp/
-model_repo_path=./zipformer/model_repo_offline
+model_repo_path=./model_repo_offline
 
 # modify model specific parameters according to $pretrained_model_dir/exp/onnx_export.log
 VOCAB_SIZE=500
diff --git a/triton/scripts/build_trt.sh b/triton/scripts/build_trt.sh
index 8bd2e3252..bc2bdba6c 100644
--- a/triton/scripts/build_trt.sh
+++ b/triton/scripts/build_trt.sh
@@ -14,7 +14,7 @@
 
 # paramters for TRT engines
 MIN_BATCH=1
-OPT_BATCH=32
+OPT_BATCH=4
 MAX_BATCH=$1
 onnx_model=$2
 trt_model=$3
diff --git a/triton/scripts/build_wenetspeech_zipformer_offline_trt.sh b/triton/scripts/build_wenetspeech_zipformer_offline_trt.sh
new file mode 100644
index 000000000..638a2d435
--- /dev/null
+++ b/triton/scripts/build_wenetspeech_zipformer_offline_trt.sh
@@ -0,0 +1,131 @@
+#!/bin/bash
+stage=-1
+stop_stage=3
+
+export CUDA_VISIBLE_DEVICES=1
+
+pretrained_model_dir=/workspace/icefall-asr-zipformer-wenetspeech-20230615
+model_repo_path=./model_repo_offline
+
+# modify model specific parameters according to $pretrained_model_dir/exp/ log files
+VOCAB_SIZE=5537
+
+DECODER_CONTEXT_SIZE=2
+DECODER_DIM=512
+ENCODER_DIM=512 # max(_to_int_tuple(params.encoder_dim) + + +if [ -d "$pretrained_model_dir/data/lang_char" ] +then + echo "pretrained model using char" + TOKENIZER_FILE=$pretrained_model_dir/data/lang_char +else + echo "pretrained model using bpe" + TOKENIZER_FILE=$pretrained_model_dir/data/lang_bpe_500/bpe.model +fi + +MAX_BATCH=16 +# model instance num +FEATURE_EXTRACTOR_INSTANCE_NUM=2 +ENCODER_INSTANCE_NUM=1 +JOINER_INSTANCE_NUM=1 +DECODER_INSTANCE_NUM=1 +SCORER_INSTANCE_NUM=2 + + +icefall_dir=/workspace/icefall +export PYTHONPATH=$PYTHONPATH:$icefall_dir +recipe_dir=$icefall_dir/egs/wenetspeech/ASR/zipformer + +if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then + if [ -d "$pretrained_model_dir" ] + then + echo "skip download pretrained model" + else + echo "downloading pretrained model" + cd /workspace + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615 + pushd icefall-asr-zipformer-wenetspeech-20230615 + git lfs pull --include "exp/pretrained.pt" + ln -s ./exp/pretrained.pt ./exp/epoch-9999.pt + popd + cd - + fi +fi + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "export onnx" + cd ${recipe_dir} + # WAR: please comment https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/zipformer.py#L1422-L1427 + # if you would like to use the exported onnx to build trt engine later. + python3 ./export-onnx.py \ + --tokens $TOKENIZER_FILE/tokens.txt \ + --use-averaged-model 0 \ + --epoch 9999 \ + --avg 1 \ + --exp-dir $pretrained_model_dir/exp/ \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --causal False || exit 1 + + cd - +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "auto gen config.pbtxt" + dirs="encoder decoder feature_extractor joiner scorer transducer" + + if [ ! 
-d $model_repo_path ]; then
+    echo "$model_repo_path does not exist, please run this script from the triton directory"
+    exit 1
+  fi
+
+  cp -r $TOKENIZER_FILE $model_repo_path/scorer/
+  TOKENIZER_FILE=$model_repo_path/scorer/$(basename $TOKENIZER_FILE)
+  for dir in $dirs
+  do
+    cp $model_repo_path/$dir/config.pbtxt.template $model_repo_path/$dir/config.pbtxt
+
+    sed -i "s|VOCAB_SIZE|${VOCAB_SIZE}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|DECODER_CONTEXT_SIZE|${DECODER_CONTEXT_SIZE}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|DECODER_DIM|${DECODER_DIM}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|ENCODER_LAYERS|${ENCODER_LAYERS}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|ENCODER_DIM|${ENCODER_DIM}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|ENCODER_LEFT_CONTEXT|${ENCODER_LEFT_CONTEXT}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|ENCODER_RIGHT_CONTEXT|${ENCODER_RIGHT_CONTEXT}|g" $model_repo_path/$dir/config.pbtxt
+
+    sed -i "s|TOKENIZER_FILE|${TOKENIZER_FILE}|g" $model_repo_path/$dir/config.pbtxt
+
+    sed -i "s|MAX_BATCH|${MAX_BATCH}|g" $model_repo_path/$dir/config.pbtxt
+
+    sed -i "s|FEATURE_EXTRACTOR_INSTANCE_NUM|${FEATURE_EXTRACTOR_INSTANCE_NUM}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|ENCODER_INSTANCE_NUM|${ENCODER_INSTANCE_NUM}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|JOINER_INSTANCE_NUM|${JOINER_INSTANCE_NUM}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|DECODER_INSTANCE_NUM|${DECODER_INSTANCE_NUM}|g" $model_repo_path/$dir/config.pbtxt
+    sed -i "s|SCORER_INSTANCE_NUM|${SCORER_INSTANCE_NUM}|g" $model_repo_path/$dir/config.pbtxt
+  done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  cp $pretrained_model_dir/exp/encoder-epoch-9999-avg-1.onnx $model_repo_path/encoder/1/encoder.onnx
+  cp $pretrained_model_dir/exp/decoder-epoch-9999-avg-1.onnx $model_repo_path/decoder/1/decoder.onnx
+  cp $pretrained_model_dir/exp/joiner-epoch-9999-avg-1.onnx $model_repo_path/joiner/1/joiner.onnx
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Building TRT engine..., skip this stage if you would like to use onnxruntime"
+  polygraphy surgeon sanitize $pretrained_model_dir/exp/encoder-epoch-9999-avg-1.onnx --fold-constants -o $pretrained_model_dir/exp/encoder.onnx
+  bash scripts/build_trt.sh $MAX_BATCH $pretrained_model_dir/exp/encoder.onnx $model_repo_path/encoder/1/encoder.trt || exit 1
+
+  sed -i "s|onnxruntime|tensorrt|g" $model_repo_path/encoder/config.pbtxt
+  sed -i "s|encoder.onnx|encoder.trt|g" $model_repo_path/encoder/config.pbtxt
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  tritonserver --model-repository=$model_repo_path --pinned-memory-pool-byte-size=512000000 --cuda-memory-pool-byte-size=0:1024000000 --http-port 10086
+fi
diff --git a/triton/zipformer/model_repo_offline/decoder/1/.gitkeep b/triton/zipformer/model_repo_offline/decoder/1/.gitkeep
deleted file mode 100755
index e69de29bb..000000000
diff --git a/triton/zipformer/model_repo_offline/decoder/config.pbtxt.template b/triton/zipformer/model_repo_offline/decoder/config.pbtxt.template
deleted file mode 100755
index 7ffaf6d3b..000000000
--- a/triton/zipformer/model_repo_offline/decoder/config.pbtxt.template
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: "decoder" -backend: "onnxruntime" -default_model_filename: "decoder.onnx" - -max_batch_size: MAX_BATCH -input [ - { - name: "y" - data_type: TYPE_INT64 - dims: [DECODER_CONTEXT_SIZE] - } -] - -output [ - { - name: "decoder_out" - data_type: TYPE_FP32 - dims: [ DECODER_DIM ] - } -] - -dynamic_batching { - } - -instance_group [ - { - count: DECODER_INSTANCE_NUM - kind: KIND_GPU - } -] diff --git a/triton/zipformer/model_repo_offline/encoder/1/.gitkeep b/triton/zipformer/model_repo_offline/encoder/1/.gitkeep deleted file mode 100755 index e69de29bb..000000000 diff --git a/triton/zipformer/model_repo_offline/encoder/config.pbtxt.template b/triton/zipformer/model_repo_offline/encoder/config.pbtxt.template deleted file mode 100755 index 5908f7118..000000000 --- a/triton/zipformer/model_repo_offline/encoder/config.pbtxt.template +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: "encoder" -backend: "onnxruntime" -default_model_filename: "encoder.onnx" - -max_batch_size: MAX_BATCH -input [ - { - name: "x" - data_type: TYPE_FP32 - dims: [-1, 80] - }, - { - name: "x_lens" - data_type: TYPE_INT64 - dims: [1] - reshape: { shape: [ ] } - } -] -output [ - { - name: "encoder_out" - data_type: TYPE_FP32 - dims: [-1, ENCODER_DIM ] - }, - { - name: "encoder_out_lens" - data_type: TYPE_INT64 - dims: [1] - reshape: { shape: [ ] } - } -] - -dynamic_batching { - } - -instance_group [ - { - count: ENCODER_INSTANCE_NUM - kind: KIND_GPU - } -] diff --git a/triton/zipformer/model_repo_offline/feature_extractor/1/model.py b/triton/zipformer/model_repo_offline/feature_extractor/1/model.py deleted file mode 100755 index 7288b3823..000000000 --- a/triton/zipformer/model_repo_offline/feature_extractor/1/model.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import triton_python_backend_utils as pb_utils -from torch.utils.dlpack import to_dlpack -import torch -import numpy as np -import kaldifeat -import _kaldifeat -from typing import List -import json - -class Fbank(torch.nn.Module): - def __init__(self, opts): - super(Fbank, self).__init__() - self.fbank = kaldifeat.Fbank(opts) - - def forward(self, waves: List[torch.Tensor]): - return self.fbank(waves) - - -class TritonPythonModel: - """Your Python model must use the same class name. Every Python model - that is created must have "TritonPythonModel" as the class name. - """ - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. - - Parameters - ---------- - args : dict - Both keys and values are strings. The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - self.model_config = model_config = json.loads(args['model_config']) - self.max_batch_size = max(model_config["max_batch_size"], 1) - - if "GPU" in model_config["instance_group"][0]["kind"]: - self.device = "cuda" - else: - self.device = "cpu" - - # Get OUTPUT0 configuration - output0_config = pb_utils.get_output_config_by_name( - model_config, "speech") - # Convert Triton types to numpy types - output0_dtype = pb_utils.triton_string_to_numpy( - output0_config['data_type']) - if output0_dtype == np.float32: - self.output0_dtype = torch.float32 - else: - self.output0_dtype = torch.float16 - - # Get OUTPUT1 configuration - output1_config = pb_utils.get_output_config_by_name( - model_config, "speech_lengths") - # Convert Triton types to numpy types - self.output1_dtype = pb_utils.triton_string_to_numpy( - output1_config['data_type']) - - params = self.model_config['parameters'] - opts = kaldifeat.FbankOptions() - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - for li in params.items(): - key, value = li - value = value["string_value"] - if key == "num_mel_bins": - opts.mel_opts.num_bins = int(value) - elif key == "frame_shift_in_ms": - opts.frame_opts.frame_shift_ms = float(value) - elif key == "frame_length_in_ms": - opts.frame_opts.frame_length_ms = float(value) - elif key == "sample_rate": - opts.frame_opts.samp_freq = int(value) - opts.device = torch.device(self.device) - self.opts = opts - self.feature_extractor = Fbank(self.opts) - self.feature_size = opts.mel_opts.num_bins - - def execute(self, requests): - """`execute` must be implemented in every Python model. `execute` - function receives a list of pb_utils.InferenceRequest as the only - argument. This function is called when an inference is requested - for this model. - - Parameters - ---------- - requests : list - A list of pb_utils.InferenceRequest - - Returns - ------- - list - A list of pb_utils.InferenceResponse. 
The length of this list must - be the same as `requests` - """ - batch_count = [] - total_waves = [] - batch_len = [] - responses = [] - for request in requests: - input0 = pb_utils.get_input_tensor_by_name(request, "wav") - input1 = pb_utils.get_input_tensor_by_name(request, "wav_lens") - - cur_b_wav = input0.as_numpy() - cur_b_wav_lens = input1.as_numpy() # b x 1 - cur_batch = cur_b_wav.shape[0] - cur_len = cur_b_wav.shape[1] - batch_count.append(cur_batch) - batch_len.append(cur_len) - for wav, wav_len in zip(cur_b_wav, cur_b_wav_lens): - wav_len = wav_len[0] - wav = torch.tensor(wav[0:wav_len], dtype=torch.float32, - device=self.device) - total_waves.append(wav) - - features = self.feature_extractor(total_waves) - for b, l in zip(batch_count, batch_len): - expect_feat_len = _kaldifeat.num_frames(l, self.opts.frame_opts) - speech = torch.zeros((b, expect_feat_len, self.feature_size), - dtype=self.output0_dtype, device=self.device) - speech_lengths = torch.zeros((b, 1), dtype=torch.int64, device=self.device) - for i in range(b): - f = features.pop(0) - f_l = f.shape[0] - speech[i, 0: f_l, :] = f.to(self.output0_dtype) - speech_lengths[i][0] = f_l - speech = speech.cpu() - speech_lengths = speech_lengths.cpu() - out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech)) - out1 = pb_utils.Tensor.from_dlpack("speech_lengths", - to_dlpack(speech_lengths)) - inference_response = pb_utils.InferenceResponse(output_tensors=[out0, out1]) - responses.append(inference_response) - return responses diff --git a/triton/zipformer/model_repo_offline/feature_extractor/config.pbtxt.template b/triton/zipformer/model_repo_offline/feature_extractor/config.pbtxt.template deleted file mode 100755 index 0ae5b166c..000000000 --- a/triton/zipformer/model_repo_offline/feature_extractor/config.pbtxt.template +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: "feature_extractor" -backend: "python" -max_batch_size: MAX_BATCH - -parameters [ - { - key: "num_mel_bins", - value: { string_value: "80"} - }, - { - key: "frame_shift_in_ms" - value: { string_value: "10"} - }, - { - key: "frame_length_in_ms" - value: { string_value: "25"} - }, - { - key: "sample_rate" - value: { string_value: "16000"} - } - -] - -input [ - { - name: "wav" - data_type: TYPE_FP32 - dims: [-1] - }, - { - name: "wav_lens" - data_type: TYPE_INT32 - dims: [1] - } -] - -output [ - { - name: "speech" - data_type: TYPE_FP32 - dims: [-1, 80] - }, - { - name: "speech_lengths" - data_type: TYPE_INT64 - dims: [1] - } -] - -dynamic_batching { - } -instance_group [ - { - count: FEATURE_EXTRACTOR_INSTANCE_NUM - kind: KIND_GPU - } -] diff --git a/triton/zipformer/model_repo_offline/joiner/1/.gitkeep b/triton/zipformer/model_repo_offline/joiner/1/.gitkeep deleted file mode 100755 index e69de29bb..000000000 diff --git a/triton/zipformer/model_repo_offline/joiner/config.pbtxt.template b/triton/zipformer/model_repo_offline/joiner/config.pbtxt.template deleted file mode 100755 index 64b9ff0bf..000000000 --- a/triton/zipformer/model_repo_offline/joiner/config.pbtxt.template +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: "joiner" -backend: "onnxruntime" -default_model_filename: "joiner.onnx" - -max_batch_size: 512 -input [ - { - name: "encoder_out" - data_type: TYPE_FP32 - dims: [ ENCODER_DIM ] - }, - { - name: "decoder_out" - data_type: TYPE_FP32 - dims: [ DECODER_DIM ] - } -] - -output [ - { - name: "logit" - data_type: TYPE_FP32 - dims: [ VOCAB_SIZE ] - } -] - -dynamic_batching { - } - -instance_group [ - { - count: JOINER_INSTANCE_NUM - kind: KIND_GPU - } -] diff --git a/triton/zipformer/model_repo_offline/scorer/1/model.py b/triton/zipformer/model_repo_offline/scorer/1/model.py deleted file mode 100755 index 7a2cff501..000000000 --- a/triton/zipformer/model_repo_offline/scorer/1/model.py +++ /dev/null @@ -1,348 +0,0 @@ -# -*- coding: utf-8 -*- -import triton_python_backend_utils as pb_utils -import numpy as np -import json -import torch -from torch.utils.dlpack import from_dlpack, to_dlpack -import sentencepiece as spm -from icefall.lexicon import Lexicon -from typing import List, Union -import k2 - -class TritonPythonModel: - """Your Python model must use the same class name. Every Python model - that is created must have "TritonPythonModel" as the class name. - """ - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. - - Parameters - ---------- - args : dict - Both keys and values are strings. 
The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - self.model_config = model_config = json.loads(args['model_config']) - self.max_batch_size = max(model_config["max_batch_size"], 1) - - # Get OUTPUT0 configuration - output0_config = pb_utils.get_output_config_by_name( - model_config, "OUTPUT0") - # Convert Triton types to numpy types - self.out0_dtype = pb_utils.triton_string_to_numpy( - output0_config['data_type']) - - model_instance_kind = args['model_instance_kind'] - model_instance_device_id = args['model_instance_device_id'] - if model_instance_kind == 'GPU': - self.device = f'cuda:{model_instance_device_id}' - else: - self.device= 'cpu' - - # Get INPUT configuration - encoder_config = pb_utils.get_input_config_by_name( - model_config, "encoder_out") - self.data_type = pb_utils.triton_string_to_numpy( - encoder_config['data_type']) - if self.data_type == np.float32: - self.torch_dtype = torch.float32 - else: - assert self.data_type == np.float16 - self.torch_dtype = torch.float16 - - self.encoder_dim = encoder_config['dims'][-1] - - - self.init_parameters(self.model_config['parameters']) - - def init_parameters(self, parameters): - for key,value in parameters.items(): - parameters[key] = value["string_value"] - self.context_size = int(parameters['context_size']) - self.decoding_method = parameters['decoding_method'] - if 'bpe' in parameters['tokenizer_file']: - sp = spm.SentencePieceProcessor() - sp.load(parameters['tokenizer_file']) - self.blank_id = sp.piece_to_id("") - self.unk_id = sp.piece_to_id("") - self.vocab_size = sp.get_piece_size() - self.tokenizer = sp - else: - assert 'char' in parameters['tokenizer_file'] - lexicon = Lexicon(parameters['tokenizer_file']) - self.unk_id = lexicon.token_table[""] - self.blank_id = lexicon.token_table[""] - self.vocab_size = max(lexicon.tokens) + 1 - self.tokenizer = lexicon - if self.decoding_method == 'fast_beam_search': - # parameters for fast beam search - self.beam = int(self.model_config['parameters']['beam']) - self.max_contexts = int(self.model_config['parameters']['max_contexts']) - self.max_states = int(self.model_config['parameters']['max_states']) - self.temperature = float(self.model_config['parameters']['temperature']) - # Support fast beam search one best currently - self.decoding_graph = k2.trivial_graph( - self.vocab_size - 1, device=self.device - ) - - def forward_joiner(self, cur_encoder_out, decoder_out): - in_joiner_tensor_0 = pb_utils.Tensor.from_dlpack("encoder_out", to_dlpack(cur_encoder_out)) - in_joiner_tensor_1 = pb_utils.Tensor.from_dlpack("decoder_out", to_dlpack(decoder_out.squeeze(1))) - - inference_request = pb_utils.InferenceRequest( - model_name='joiner', - requested_output_names=['logit'], - inputs=[in_joiner_tensor_0, in_joiner_tensor_1]) - inference_response = inference_request.exec() - if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response.error().message()) - else: - # Extract the output tensors from the inference response. 
- logits = pb_utils.get_output_tensor_by_name(inference_response, - 'logit') - logits = torch.utils.dlpack.from_dlpack(logits.to_dlpack()).cpu() - assert len(logits.shape) == 2, logits.shape - return logits - - def forward_decoder(self, hyps): - decoder_input = np.asarray(hyps,dtype=np.int64) - - in_decoder_input_tensor = pb_utils.Tensor("y", decoder_input) - - inference_request = pb_utils.InferenceRequest( - model_name='decoder', - requested_output_names=['decoder_out'], - inputs=[in_decoder_input_tensor]) - - inference_response = inference_request.exec() - if inference_response.has_error(): - raise pb_utils.TritonModelException(inference_response.error().message()) - else: - # Extract the output tensors from the inference response. - decoder_out = pb_utils.get_output_tensor_by_name(inference_response, - 'decoder_out') - decoder_out = from_dlpack(decoder_out.to_dlpack()) - return decoder_out - - - def greedy_search(self, encoder_out, encoder_out_lens): - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False - ) - - pack_batch_size_list = packed_encoder_out.batch_sizes.tolist() - - hyps = [[self.blank_id] * self.context_size for _ in range(encoder_out.shape[0])] - contexts = [h[-self.context_size:] for h in hyps] - decoder_out = self.forward_decoder(contexts) - - offset = 0 - for batch_size in pack_batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = self.forward_joiner(current_encoder_out, decoder_out) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - - emitted = False - for i, v in enumerate(y): - if v not in (self.blank_id, self.unk_id): - hyps[i].append(v) - emitted = True - if emitted: - hyp = hyps[:batch_size] - contexts = [h[-self.context_size:] for h in hyp] - decoder_out = self.forward_decoder(contexts) - - - sorted_ans = [h[self.context_size:] for h in hyps] - - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(encoder_out.shape[0]): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - # From k2 utils.py - def get_texts(self, - best_paths: k2.Fsa, return_ragged: bool = False - ) -> Union[List[List[int]], k2.RaggedTensor]: - """Extract the texts (as word IDs) from the best-path FSAs. - Args: - best_paths: - A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. - containing multiple FSAs, which is expected to be the result - of k2.shortest_path (otherwise the returned values won't - be meaningful). - return_ragged: - True to return a ragged tensor with two axes [utt][word_id]. - False to return a list-of-list word IDs. - Returns: - Returns a list of lists of int, containing the label sequences we - decoded. - """ - if isinstance(best_paths.aux_labels, k2.RaggedTensor): - # remove 0's and -1's. - aux_labels = best_paths.aux_labels.remove_values_leq(0) - # TODO: change arcs.shape() to arcs.shape - aux_shape = best_paths.arcs.shape().compose(aux_labels.shape) - - # remove the states and arcs axes. - aux_shape = aux_shape.remove_axis(1) - aux_shape = aux_shape.remove_axis(1) - aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values) - else: - # remove axis corresponding to states. - aux_shape = best_paths.arcs.shape().remove_axis(1) - aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels) - # remove 0's and -1's. 
- aux_labels = aux_labels.remove_values_leq(0) - - assert aux_labels.num_axes == 2 - if return_ragged: - return aux_labels - else: - return aux_labels.tolist() - - def fast_beam_search(self, encoder_out, encoder_out_lens): - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=self.vocab_size, - decoder_history_len=self.context_size, - beam=self.beam, - max_contexts=self.max_contexts, - max_states=self.max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(self.decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - for t in range(T): - shape, contexts = decoding_streams.get_contexts() - contexts = contexts.to(torch.int64) - - decoder_out = self.forward_decoder(contexts) - - cur_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - - logits = self.forward_joiner(cur_encoder_out.squeeze(1), - decoder_out) - - logits = logits.squeeze(1).squeeze(1).float() - log_probs = (logits / self.temperature).log_softmax(dim=-1) - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - - best_path = k2.shortest_path(lattice, use_double_scores=True) - hyps_list = self.get_texts(best_path) - - return hyps_list - - def execute(self, requests): - """`execute` must be implemented in every Python model. `execute` - function receives a list of pb_utils.InferenceRequest as the only - argument. This function is called when an inference is requested - for this model. - - Parameters - ---------- - requests : list - A list of pb_utils.InferenceRequest - - Returns - ------- - list - A list of pb_utils.InferenceResponse. The length of this list must - be the same as `requests` - """ - # Every Python backend must iterate through list of requests and create - # an instance of pb_utils.InferenceResponse class for each of them. You - # should avoid storing any of the input Tensors in the class attributes - # as they will be overridden in subsequent inference requests. You can - # make a copy of the underlying NumPy array and store it if it is - # required. - - batch_encoder_out_list, batch_encoder_lens_list = [], [] - batchsize_lists = [] - total_seqs = 0 - encoder_max_len = 0 - - for request in requests: - # Perform inference on the request and append it to responses list... 
- in_0 = pb_utils.get_input_tensor_by_name(request, "encoder_out") - in_1 = pb_utils.get_input_tensor_by_name(request, "encoder_out_lens") - assert not in_0.is_cpu() - batch_encoder_out_list.append(from_dlpack(in_0.to_dlpack())) - encoder_max_len = max(encoder_max_len, batch_encoder_out_list[-1].shape[1]) - cur_b_lens = from_dlpack(in_1.to_dlpack()) - batch_encoder_lens_list.append(cur_b_lens) - cur_batchsize = cur_b_lens.shape[0] - batchsize_lists.append(cur_batchsize) - total_seqs += cur_batchsize - - encoder_out = torch.zeros((total_seqs, encoder_max_len, self.encoder_dim), - dtype=self.torch_dtype, device=self.device) - encoder_out_lens = torch.zeros(total_seqs, dtype=torch.int64) - st = 0 - - for b in batchsize_lists: - t = batch_encoder_out_list.pop(0) - encoder_out[st:st + b, 0:t.shape[1]] = t - encoder_out_lens[st:st + b] = batch_encoder_lens_list.pop(0) - st += b - - if self.decoding_method == 'greedy_search': - ans = self.greedy_search(encoder_out, encoder_out_lens) - elif self.decoding_method == 'fast_beam_search': - ans = self.fast_beam_search(encoder_out, encoder_out_lens) - else: - raise NotImplementedError - - results = [] - if hasattr(self.tokenizer, 'token_table'): - for i in range(len(ans)): - results.append([self.tokenizer.token_table[idx] for idx in ans[i]]) - else: - for hyp in self.tokenizer.decode(ans): - results.append(hyp.split()) - st = 0 - responses = [] - for b in batchsize_lists: - sents = np.array(results[st:st + b]) - out0 = pb_utils.Tensor("OUTPUT0", sents.astype(self.out0_dtype)) - inference_response = pb_utils.InferenceResponse(output_tensors=[out0]) - responses.append(inference_response) - st += b - return responses - - def finalize(self): - """`finalize` is called only once when the model is being unloaded. - Implementing `finalize` function is optional. This function allows - the model to perform any necessary clean ups before exit. - """ - print('Cleaning up...') diff --git a/triton/zipformer/model_repo_offline/scorer/config.pbtxt.template b/triton/zipformer/model_repo_offline/scorer/config.pbtxt.template deleted file mode 100755 index 3dc30db35..000000000 --- a/triton/zipformer/model_repo_offline/scorer/config.pbtxt.template +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: "scorer" -backend: "python" -max_batch_size: MAX_BATCH - -parameters [ - { - key: "context_size", - value: { string_value: "DECODER_CONTEXT_SIZE"} - }, - { - key: "tokenizer_file", - value: { string_value: "TOKENIZER_FILE"} - }, - { - key: "FORCE_CPU_ONLY_INPUT_TENSORS", - value: {string_value:"no"} - }, - { - key: "decoding_method", - value: { string_value: "greedy_search"} - }, - { - key: "beam", - value: { string_value: "4"} - }, - { - key: "max_contexts", - value: { string_value: "4"} - }, - { - key: "max_states", - value: { string_value: "32"} - }, - { - key: "temperature", - value: { string_value: "1.0"} - } -] - - -input [ - { - name: "encoder_out" - data_type: TYPE_FP32 - dims: [-1, ENCODER_DIM] - }, - { - name: "encoder_out_lens" - data_type: TYPE_INT64 - dims: [1] - reshape: { shape: [ ] } - } -] - -output [ - { - name: "OUTPUT0" - data_type: TYPE_STRING - dims: [1] - } -] - -dynamic_batching { - } -instance_group [ - { - count: SCORER_INSTANCE_NUM - kind: KIND_CPU - } - ] diff --git a/triton/zipformer/model_repo_offline/transducer/1/.gitkeep b/triton/zipformer/model_repo_offline/transducer/1/.gitkeep deleted file mode 100755 index e69de29bb..000000000 diff --git a/triton/zipformer/model_repo_offline/transducer/config.pbtxt.template b/triton/zipformer/model_repo_offline/transducer/config.pbtxt.template deleted file mode 100644 index 4f722fc88..000000000 --- a/triton/zipformer/model_repo_offline/transducer/config.pbtxt.template +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: "transducer" -platform: "ensemble" -max_batch_size: MAX_BATCH - -input [ - { - name: "WAV" - data_type: TYPE_FP32 - dims: [-1] - }, - { - name: "WAV_LENS" - data_type: TYPE_INT32 - dims: [1] - } -] - -output [ - { - name: "TRANSCRIPTS" - data_type: TYPE_STRING - dims: [1] - } -] - -ensemble_scheduling { - step [ - { - model_name: "feature_extractor" - model_version: -1 - input_map { - key: "wav" - value: "WAV" - } - input_map { - key: "wav_lens" - value: "WAV_LENS" - } - output_map { - key: "speech" - value: "SPEECH" - } - output_map { - key: "speech_lengths" - value: "SPEECH_LENGTHS" - } - }, - { - model_name: "encoder" - model_version: -1 - input_map { - key: "x" - value: "SPEECH" - } - input_map { - key: "x_lens" - value: "SPEECH_LENGTHS" - } - output_map { - key: "encoder_out" - value: "encoder_out" - } - output_map { - key: "encoder_out_lens" - value: "encoder_out_lens" - } - }, - { - model_name: "scorer" - model_version: -1 - input_map { - key: "encoder_out" - value: "encoder_out" - } - input_map { - key: "encoder_out_lens" - value: "encoder_out_lens" - } - output_map { - key: "OUTPUT0" - value: "TRANSCRIPTS" - } - } - ] -}
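
---

Below is a minimal client sketch that is not part of the patch itself; it illustrates how the deployed `transducer` ensemble (inputs `WAV`/`WAV_LENS`, output `TRANSCRIPTS`) could be queried once `tritonserver` is started with `--http-port 10086`, as in `build_wenetspeech_zipformer_offline_trt.sh`. The server address, the file name `test.wav`, and the availability of the `tritonclient` and `soundfile` packages are assumptions.

```python
# Hedged example: query the "transducer" ensemble over Triton's HTTP API.
# Assumes a 16 kHz mono wav named test.wav and a server on localhost:10086.
import numpy as np
import soundfile as sf
import tritonclient.http as httpclient

samples, sample_rate = sf.read("test.wav", dtype="float32")
assert sample_rate == 16000, "the feature_extractor config expects 16 kHz audio"

client = httpclient.InferenceServerClient(url="localhost:10086")

# Add a batch dimension: WAV is [batch, num_samples], WAV_LENS is [batch, 1].
wav = samples[np.newaxis, :]
wav_lens = np.array([[wav.shape[1]]], dtype=np.int32)

inputs = [
    httpclient.InferInput("WAV", list(wav.shape), "FP32"),
    httpclient.InferInput("WAV_LENS", list(wav_lens.shape), "INT32"),
]
inputs[0].set_data_from_numpy(wav)
inputs[1].set_data_from_numpy(wav_lens)

outputs = [httpclient.InferRequestedOutput("TRANSCRIPTS")]
response = client.infer("transducer", inputs, outputs=outputs)

# TRANSCRIPTS is a TYPE_STRING tensor produced by the scorer model;
# print the raw numpy result returned for the single request.
print(response.as_numpy("TRANSCRIPTS"))
```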